// Chat support (including tool-call grammar constraining and output parsing) with generic and custom template handlers.
| 2 | |
| 3 | #pragma once |
| 4 | |
| 5 | #include "common.h" |
| 6 | #include <functional> |
| 7 | #include <chrono> |
| 8 | #include <string> |
| 9 | #include <vector> |
| 10 | #include <map> |
| 11 | |
// Forward declaration only: the type stays incomplete in this header and the
// declarations below refer to it exclusively through pointers.
struct common_chat_templates;
| 13 | |
// A tool call described by three strings: the tool name, its arguments and a call id.
// NOTE(review): `arguments` presumably holds JSON text - confirm against the parser.
struct common_chat_tool_call {
    std::string name;
    std::string arguments;
    std::string id;

    // Field-wise equality: all of name, arguments and id must match.
    bool operator==(const common_chat_tool_call & other) const {
        if (name != other.name) {
            return false;
        }
        if (arguments != other.arguments) {
            return false;
        }
        return id == other.id;
    }
};
| 23 | |
// One typed piece of message content: a `type` tag together with its `text`.
struct common_chat_msg_content_part {
    std::string type;
    std::string text;

    // Field-wise equality: both type and text must match.
    bool operator==(const common_chat_msg_content_part & other) const {
        if (type != other.type) {
            return false;
        }
        return text == other.text;
    }
};
| 32 | |
| 33 | struct common_chat_msg { |
| 34 | std::string role; |
| 35 | std::string content; |
| 36 | std::vector<common_chat_msg_content_part> content_parts; |
| 37 | std::vector<common_chat_tool_call> tool_calls; |
| 38 | std::string reasoning_content; |
| 39 | std::string tool_name; |
| 40 | std::string tool_call_id; |
| 41 | |
| 42 | template <class T> T to_json_oaicompat() const; |
| 43 | |
| 44 | bool empty() const { |
| 45 | return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty(); |
| 46 | } |
| 47 | void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) { |
| 48 | for (auto i = 0u; i < tool_calls.size(); i++) { |
| 49 | if (ids_cache.size() <= i) { |
| 50 | auto id = tool_calls[i].id; |
| 51 | if (id.empty()) { |
| 52 | id = gen_tool_call_id(); |
| 53 | } |
| 54 | ids_cache.push_back(x: id); |
| 55 | } |
| 56 | tool_calls[i].id = ids_cache[i]; |
| 57 | } |
| 58 | } |
| 59 | bool operator==(const common_chat_msg & other) const { |
| 60 | return role == other.role |
| 61 | && content == other.content |
| 62 | && content_parts == other.content_parts |
| 63 | && tool_calls == other.tool_calls |
| 64 | && reasoning_content == other.reasoning_content |
| 65 | && tool_name == other.tool_name |
| 66 | && tool_call_id == other.tool_call_id; |
| 67 | } |
| 68 | bool operator!=(const common_chat_msg & other) const { |
| 69 | return !(*this == other); |
| 70 | } |
| 71 | }; |
| 72 | |
| 73 | struct common_chat_msg_diff { |
| 74 | std::string reasoning_content_delta; |
| 75 | std::string content_delta; |
| 76 | size_t tool_call_index = std::string::npos; |
| 77 | common_chat_tool_call tool_call_delta; |
| 78 | |
| 79 | static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg); |
| 80 | |
| 81 | bool operator==(const common_chat_msg_diff & other) const { |
| 82 | return content_delta == other.content_delta |
| 83 | && tool_call_index == other.tool_call_index |
| 84 | && tool_call_delta == other.tool_call_delta; |
| 85 | } |
| 86 | }; |
| 87 | |
// Description of a tool, stored as three strings.
struct common_chat_tool {
    std::string name;
    std::string description;
    // NOTE(review): presumably a JSON schema serialized as text - confirm against the tools parser.
    std::string parameters;
};
| 93 | |
// Tool-choice modes. common_chat_tool_choice_parse_oaicompat() produces one of these from a string.
enum common_chat_tool_choice {
    COMMON_CHAT_TOOL_CHOICE_AUTO = 0,
    COMMON_CHAT_TOOL_CHOICE_REQUIRED,
    COMMON_CHAT_TOOL_CHOICE_NONE,
};
| 99 | |
// Identifiers for the chat formats known to this header.
// Enumerators are numbered consecutively from 0, so new formats must be added
// before COMMON_CHAT_FORMAT_COUNT to keep that sentinel equal to the number of formats.
enum common_chat_format {
    COMMON_CHAT_FORMAT_CONTENT_ONLY = 0,
    COMMON_CHAT_FORMAT_GENERIC,
    COMMON_CHAT_FORMAT_MISTRAL_NEMO,
    COMMON_CHAT_FORMAT_MAGISTRAL,
    COMMON_CHAT_FORMAT_LLAMA_3_X,
    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS,
    COMMON_CHAT_FORMAT_DEEPSEEK_R1,
    COMMON_CHAT_FORMAT_FIREFUNCTION_V2,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1,
    COMMON_CHAT_FORMAT_DEEPSEEK_V3_1,
    COMMON_CHAT_FORMAT_HERMES_2_PRO,
    COMMON_CHAT_FORMAT_COMMAND_R7B,
    COMMON_CHAT_FORMAT_GRANITE,
    COMMON_CHAT_FORMAT_GPT_OSS,
    COMMON_CHAT_FORMAT_SEED_OSS,
    COMMON_CHAT_FORMAT_NEMOTRON_V2,
    COMMON_CHAT_FORMAT_APERTUS,
    COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS,

    COMMON_CHAT_FORMAT_COUNT, // Not a format, just the # formats
};
| 123 | |
| 124 | struct common_chat_templates_inputs { |
| 125 | std::vector<common_chat_msg> messages; |
| 126 | std::string grammar; |
| 127 | std::string json_schema; |
| 128 | bool add_generation_prompt = true; |
| 129 | bool use_jinja = true; |
| 130 | // Parameters below only supported when use_jinja is true |
| 131 | std::vector<common_chat_tool> tools; |
| 132 | common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO; |
| 133 | bool parallel_tool_calls = false; |
| 134 | common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; |
| 135 | bool enable_thinking = true; |
| 136 | std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); |
| 137 | std::map<std::string, std::string> chat_template_kwargs; |
| 138 | bool add_bos = false; |
| 139 | bool add_eos = false; |
| 140 | }; |
| 141 | |
| 142 | struct common_chat_params { |
| 143 | common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; |
| 144 | std::string prompt; |
| 145 | std::string grammar; |
| 146 | bool grammar_lazy = false; |
| 147 | bool thinking_forced_open = false; |
| 148 | std::vector<common_grammar_trigger> grammar_triggers; |
| 149 | std::vector<std::string> preserved_tokens; |
| 150 | std::vector<std::string> additional_stops; |
| 151 | }; |
| 152 | |
| 153 | struct common_chat_syntax { |
| 154 | common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY; |
| 155 | common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE; |
| 156 | // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode) |
| 157 | bool reasoning_in_content = false; |
| 158 | bool thinking_forced_open = false; |
| 159 | bool parse_tool_calls = true; |
| 160 | }; |
| 161 | |
// Checks whether the template supplied via "--chat-template" is supported.
// Returns true if it is valid.
bool common_chat_verify_template(
    const std::string & tmpl,
    bool                use_jinja);
| 164 | |
// Frees `tmpls`. This is the function the deleter below forwards to.
void common_chat_templates_free(struct common_chat_templates * tmpls);

// Deleter functor for std::unique_ptr that forwards to common_chat_templates_free().
struct common_chat_templates_deleter {
    void operator()(common_chat_templates * tmpls) {
        common_chat_templates_free(tmpls);
    }
};

// Owning pointer to a common_chat_templates; destruction goes through common_chat_templates_deleter.
using common_chat_templates_ptr = std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter>;
| 170 | |
| 171 | common_chat_templates_ptr common_chat_templates_init( |
| 172 | const struct llama_model * model, |
| 173 | const std::string & chat_template_override, |
| 174 | const std::string & bos_token_override = "" , |
| 175 | const std::string & eos_token_override = "" ); |
| 176 | |
// NOTE(review): presumably reports whether the template was explicitly supplied rather than defaulted - confirm.
bool common_chat_templates_was_explicit(
    const struct common_chat_templates * tmpls);

// Returns a C string for `tmpls`; `variant` is optional and defaults to nullptr.
// NOTE(review): ownership/lifetime of the returned pointer is not visible here - confirm before storing it.
const char * common_chat_templates_source(
    const struct common_chat_templates * tmpls,
    const char * variant = nullptr);
| 179 | |
| 180 | |
// Applies `tmpls` to `inputs` and returns the resulting common_chat_params by value.
struct common_chat_params common_chat_templates_apply(
    const struct common_chat_templates * tmpls,
    const struct common_chat_templates_inputs & inputs);
| 184 | |
| 185 | // Format single message, while taking into account the position of that message in chat history |
| 186 | std::string common_chat_format_single( |
| 187 | const struct common_chat_templates * tmpls, |
| 188 | const std::vector<common_chat_msg> & past_msg, |
| 189 | const common_chat_msg & new_msg, |
| 190 | bool add_ass, |
| 191 | bool use_jinja); |
| 192 | |
// Returns an example of a formatted chat as a string.
std::string common_chat_format_example(
    const struct common_chat_templates * tmpls,
    bool use_jinja,
    const std::map<std::string, std::string> & chat_template_kwargs);
| 198 | |
| 199 | const char* common_chat_format_name(common_chat_format format); |
| 200 | const char* common_reasoning_format_name(common_reasoning_format format); |
| 201 | common_reasoning_format common_reasoning_format_from_name(const std::string & format); |
| 202 | common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax); |
| 203 | |
| 204 | common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice); |
| 205 | |
| 206 | bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates); |
| 207 | |
| 208 | // Parses a JSON array of messages in OpenAI's chat completion API format. |
| 209 | // T can be std::string containing JSON or nlohmann::ordered_json |
| 210 | template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages); |
| 211 | template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false); |
| 212 | |
| 213 | // Parses a JSON array of tools in OpenAI's chat completion tool call API format. |
| 214 | // T can be std::string containing JSON or nlohmann::ordered_json |
| 215 | template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools); |
| 216 | template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools); |
| 217 | |
| 218 | template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff); |
| 219 | |