// Chat support (including tool-call grammar constraining and output parsing) with generic and custom template handlers.
2
3#pragma once
4
#include "common.h"

#include <chrono>
#include <functional>
#include <map>
#include <memory>
#include <string>
#include <vector>
11
12struct common_chat_templates;
13
// A single tool (function) call: which tool to invoke, its arguments, and a
// call identifier.
struct common_chat_tool_call {
    std::string name;      // name of the tool/function being invoked
    std::string arguments; // arguments stored as a string (presumably JSON-encoded — TODO confirm)
    std::string id;        // call identifier; may be empty

    // Member-wise equality: name, then arguments, then id.
    bool operator==(const common_chat_tool_call & other) const {
        if (name != other.name) {
            return false;
        }
        if (arguments != other.arguments) {
            return false;
        }
        return id == other.id;
    }
};
23
// One typed fragment of a multi-part message body.
struct common_chat_msg_content_part {
    std::string type; // part type tag
    std::string text; // textual payload of the part

    // Parts are equal only when both the type tag and the text match.
    bool operator==(const common_chat_msg_content_part & other) const {
        return !(type != other.type || text != other.text);
    }
};
32
33struct common_chat_msg {
34 std::string role;
35 std::string content;
36 std::vector<common_chat_msg_content_part> content_parts;
37 std::vector<common_chat_tool_call> tool_calls;
38 std::string reasoning_content;
39 std::string tool_name;
40 std::string tool_call_id;
41
42 template <class T> T to_json_oaicompat() const;
43
44 bool empty() const {
45 return content.empty() && content_parts.empty() && tool_calls.empty() && reasoning_content.empty() && tool_name.empty() && tool_call_id.empty();
46 }
47 void set_tool_call_ids(std::vector<std::string> & ids_cache, const std::function<std::string()> & gen_tool_call_id) {
48 for (auto i = 0u; i < tool_calls.size(); i++) {
49 if (ids_cache.size() <= i) {
50 auto id = tool_calls[i].id;
51 if (id.empty()) {
52 id = gen_tool_call_id();
53 }
54 ids_cache.push_back(x: id);
55 }
56 tool_calls[i].id = ids_cache[i];
57 }
58 }
59 bool operator==(const common_chat_msg & other) const {
60 return role == other.role
61 && content == other.content
62 && content_parts == other.content_parts
63 && tool_calls == other.tool_calls
64 && reasoning_content == other.reasoning_content
65 && tool_name == other.tool_name
66 && tool_call_id == other.tool_call_id;
67 }
68 bool operator!=(const common_chat_msg & other) const {
69 return !(*this == other);
70 }
71};
72
73struct common_chat_msg_diff {
74 std::string reasoning_content_delta;
75 std::string content_delta;
76 size_t tool_call_index = std::string::npos;
77 common_chat_tool_call tool_call_delta;
78
79 static std::vector<common_chat_msg_diff> compute_diffs(const common_chat_msg & previous_msg, const common_chat_msg & new_msg);
80
81 bool operator==(const common_chat_msg_diff & other) const {
82 return content_delta == other.content_delta
83 && tool_call_index == other.tool_call_index
84 && tool_call_delta == other.tool_call_delta;
85 }
86};
87
// Description of a tool the model may call.
struct common_chat_tool {
    std::string name{};        // tool/function name
    std::string description{}; // human-readable description of the tool
    std::string parameters{};  // parameter schema as a string (presumably JSON Schema — TODO confirm)
};
93
// Tool-usage policy for a request. The values presumably mirror OpenAI's
// tool_choice "auto" / "required" / "none"; see
// common_chat_tool_choice_parse_oaicompat.
enum common_chat_tool_choice {
    COMMON_CHAT_TOOL_CHOICE_AUTO     = 0,
    COMMON_CHAT_TOOL_CHOICE_REQUIRED = 1,
    COMMON_CHAT_TOOL_CHOICE_NONE     = 2,
};
99
// Output formats understood by the template handlers and common_chat_parse().
// The values are contiguous from 0 and COMMON_CHAT_FORMAT_COUNT is the number
// of real formats. When adding a format, insert it before COUNT and renumber.
enum common_chat_format {
    COMMON_CHAT_FORMAT_CONTENT_ONLY                 = 0,
    COMMON_CHAT_FORMAT_GENERIC                      = 1,
    COMMON_CHAT_FORMAT_MISTRAL_NEMO                 = 2,
    COMMON_CHAT_FORMAT_MAGISTRAL                    = 3,
    COMMON_CHAT_FORMAT_LLAMA_3_X                    = 4,
    COMMON_CHAT_FORMAT_LLAMA_3_X_WITH_BUILTIN_TOOLS = 5,
    COMMON_CHAT_FORMAT_DEEPSEEK_R1                  = 6,
    COMMON_CHAT_FORMAT_FIREFUNCTION_V2              = 7,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_2             = 8,
    COMMON_CHAT_FORMAT_FUNCTIONARY_V3_1_LLAMA_3_1   = 9,
    COMMON_CHAT_FORMAT_DEEPSEEK_V3_1                = 10,
    COMMON_CHAT_FORMAT_HERMES_2_PRO                 = 11,
    COMMON_CHAT_FORMAT_COMMAND_R7B                  = 12,
    COMMON_CHAT_FORMAT_GRANITE                      = 13,
    COMMON_CHAT_FORMAT_GPT_OSS                      = 14,
    COMMON_CHAT_FORMAT_SEED_OSS                     = 15,
    COMMON_CHAT_FORMAT_NEMOTRON_V2                  = 16,
    COMMON_CHAT_FORMAT_APERTUS                      = 17,
    COMMON_CHAT_FORMAT_LFM2_WITH_JSON_TOOLS         = 18,

    COMMON_CHAT_FORMAT_COUNT                        = 19, // Not a format, just the # formats
};
123
124struct common_chat_templates_inputs {
125 std::vector<common_chat_msg> messages;
126 std::string grammar;
127 std::string json_schema;
128 bool add_generation_prompt = true;
129 bool use_jinja = true;
130 // Parameters below only supported when use_jinja is true
131 std::vector<common_chat_tool> tools;
132 common_chat_tool_choice tool_choice = COMMON_CHAT_TOOL_CHOICE_AUTO;
133 bool parallel_tool_calls = false;
134 common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
135 bool enable_thinking = true;
136 std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
137 std::map<std::string, std::string> chat_template_kwargs;
138 bool add_bos = false;
139 bool add_eos = false;
140};
141
142struct common_chat_params {
143 common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
144 std::string prompt;
145 std::string grammar;
146 bool grammar_lazy = false;
147 bool thinking_forced_open = false;
148 std::vector<common_grammar_trigger> grammar_triggers;
149 std::vector<std::string> preserved_tokens;
150 std::vector<std::string> additional_stops;
151};
152
153struct common_chat_syntax {
154 common_chat_format format = COMMON_CHAT_FORMAT_CONTENT_ONLY;
155 common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_NONE;
156 // Whether reasoning_content should be inlined in the content (e.g. for reasoning_format=deepseek in stream mode)
157 bool reasoning_in_content = false;
158 bool thinking_forced_open = false;
159 bool parse_tool_calls = true;
160};
161
// Checks whether the template supplied via "--chat-template" is supported.
// Returns true if it is valid.
bool common_chat_verify_template(const std::string & tmpl, bool use_jinja);

// Releases a template set created by common_chat_templates_init().
void common_chat_templates_free(struct common_chat_templates * tmpls);

// std::unique_ptr deleter that forwards to common_chat_templates_free().
struct common_chat_templates_deleter {
    void operator()(common_chat_templates * tmpls) {
        common_chat_templates_free(tmpls);
    }
};

// Owning handle for a template set; frees it automatically when destroyed.
using common_chat_templates_ptr = std::unique_ptr<struct common_chat_templates, common_chat_templates_deleter>;

// Builds the chat templates for `model`. A non-empty chat_template_override
// and bos/eos overrides presumably take precedence over the model's own
// metadata — TODO confirm against the implementation.
common_chat_templates_ptr common_chat_templates_init(
    const struct llama_model * model,
    const std::string & chat_template_override,
    const std::string & bos_token_override = "",
    const std::string & eos_token_override = "");

// Presumably: whether the template was provided explicitly rather than taken from the model.
bool common_chat_templates_was_explicit(const struct common_chat_templates * tmpls);
// Returns the template source text; `variant` presumably selects an alternate template (nullptr = default).
const char * common_chat_templates_source(const struct common_chat_templates * tmpls, const char * variant = nullptr);
179
180
181struct common_chat_params common_chat_templates_apply(
182 const struct common_chat_templates * tmpls,
183 const struct common_chat_templates_inputs & inputs);
184
185// Format single message, while taking into account the position of that message in chat history
186std::string common_chat_format_single(
187 const struct common_chat_templates * tmpls,
188 const std::vector<common_chat_msg> & past_msg,
189 const common_chat_msg & new_msg,
190 bool add_ass,
191 bool use_jinja);
192
193// Returns an example of formatted chat
194std::string common_chat_format_example(
195 const struct common_chat_templates * tmpls,
196 bool use_jinja,
197 const std::map<std::string, std::string> & chat_template_kwargs);
198
199const char* common_chat_format_name(common_chat_format format);
200const char* common_reasoning_format_name(common_reasoning_format format);
201common_reasoning_format common_reasoning_format_from_name(const std::string & format);
202common_chat_msg common_chat_parse(const std::string & input, bool is_partial, const common_chat_syntax & syntax);
203
204common_chat_tool_choice common_chat_tool_choice_parse_oaicompat(const std::string & tool_choice);
205
206bool common_chat_templates_support_enable_thinking(const common_chat_templates * chat_templates);
207
208// Parses a JSON array of messages in OpenAI's chat completion API format.
209// T can be std::string containing JSON or nlohmann::ordered_json
210template <class T> std::vector<common_chat_msg> common_chat_msgs_parse_oaicompat(const T & messages);
211template <class T> T common_chat_msgs_to_json_oaicompat(const std::vector<common_chat_msg> & msgs, bool concat_typed_text = false);
212
213// Parses a JSON array of tools in OpenAI's chat completion tool call API format.
214// T can be std::string containing JSON or nlohmann::ordered_json
215template <class T> std::vector<common_chat_tool> common_chat_tools_parse_oaicompat(const T & tools);
216template <class T> T common_chat_tools_to_json_oaicompat(const std::vector<common_chat_tool> & tools);
217
218template <class T> T common_chat_msg_diff_to_json_oaicompat(const common_chat_msg_diff & diff);
219