/*
    Copyright 2024 Google LLC

    Use of this source code is governed by an MIT-style
    license that can be found in the LICENSE file or at
    https://opensource.org/licenses/MIT.
*/
// SPDX-License-Identifier: MIT
#pragma once

#include "minja.hpp"

#include <chrono>
#include <cstddef>
#include <cstdio>
#include <ctime>
#include <exception>
#include <iomanip>
#include <memory>
#include <sstream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

#include <nlohmann/json.hpp>

using json = nlohmann::ordered_json;

| 29 | namespace minja { |
| 30 | |
| 31 | struct chat_template_caps { |
| 32 | bool supports_tools = false; |
| 33 | bool supports_tool_calls = false; |
| 34 | bool supports_tool_responses = false; |
| 35 | bool supports_system_role = false; |
| 36 | bool supports_parallel_tool_calls = false; |
| 37 | bool supports_tool_call_id = false; |
| 38 | // meta-llama/Llama-3.1-8B-Instruct expects arguments to be an object. |
| 39 | // Most other templates (and OpenAI's API) expect the arguments object to be stringified. |
| 40 | bool requires_object_arguments = false; |
| 41 | // CohereForAI/c4ai-command-r-plus simple variant |
| 42 | bool requires_non_null_content = false; |
| 43 | // MiniMaxAI/MiniMax-Text-01 special |
| 44 | bool requires_typed_content = false; |
| 45 | }; |
| 46 | |
| 47 | struct chat_template_inputs { |
| 48 | nlohmann::ordered_json messages; |
| 49 | nlohmann::ordered_json tools; |
| 50 | bool add_generation_prompt = true; |
| 51 | nlohmann::ordered_json ; |
| 52 | std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); |
| 53 | }; |
| 54 | |
| 55 | struct chat_template_options { |
| 56 | bool apply_polyfills = true; |
| 57 | bool use_bos_token = true; |
| 58 | bool use_eos_token = true; |
| 59 | bool define_strftime_now = true; |
| 60 | |
| 61 | bool polyfill_tools = true; |
| 62 | bool polyfill_tool_call_examples = true; |
| 63 | bool polyfill_tool_calls = true; |
| 64 | bool polyfill_tool_responses = true; |
| 65 | bool polyfill_system_role = true; |
| 66 | bool polyfill_object_arguments = true; |
| 67 | bool polyfill_typed_content = true; |
| 68 | }; |
| 69 | |
| 70 | class chat_template { |
| 71 | |
| 72 | private: |
| 73 | chat_template_caps caps_; |
| 74 | std::string source_; |
| 75 | std::string bos_token_; |
| 76 | std::string eos_token_; |
| 77 | std::shared_ptr<minja::TemplateNode> template_root_; |
| 78 | std::string tool_call_example_; |
| 79 | |
| 80 | std::string try_raw_render( |
| 81 | const nlohmann::ordered_json & messages, |
| 82 | const nlohmann::ordered_json & tools, |
| 83 | bool add_generation_prompt, |
| 84 | const nlohmann::ordered_json & = nlohmann::ordered_json()) const |
| 85 | { |
| 86 | try { |
| 87 | chat_template_inputs inputs; |
| 88 | inputs.messages = messages; |
| 89 | inputs.tools = tools; |
| 90 | inputs.add_generation_prompt = add_generation_prompt; |
| 91 | inputs.extra_context = extra_context; |
| 92 | // Use fixed date for tests |
| 93 | inputs.now = std::chrono::system_clock::from_time_t(t: 0); |
| 94 | |
| 95 | chat_template_options opts; |
| 96 | opts.apply_polyfills = false; |
| 97 | |
| 98 | auto prompt = apply(inputs, opts); |
| 99 | // fprintf(stderr, "try_raw_render: %s\n", prompt.c_str()); |
| 100 | return prompt; |
| 101 | } catch (const std::exception & e) { |
| 102 | // fprintf(stderr, "try_raw_render error: %s\n", e.what()); |
| 103 | return "" ; |
| 104 | } |
| 105 | } |
| 106 | |
| 107 | public: |
| 108 | |
| 109 | chat_template(const std::string & source, const std::string & bos_token, const std::string & eos_token) |
| 110 | : source_(source), bos_token_(bos_token), eos_token_(eos_token) |
| 111 | { |
| 112 | template_root_ = minja::Parser::parse(template_str: source_, options: { |
| 113 | /* .trim_blocks = */ .trim_blocks: true, |
| 114 | /* .lstrip_blocks = */ .lstrip_blocks: true, |
| 115 | /* .keep_trailing_newline = */ .keep_trailing_newline: false, |
| 116 | }); |
| 117 | |
| 118 | auto contains = [](const std::string & haystack, const std::string & needle) { |
| 119 | return haystack.find(str: needle) != std::string::npos; |
| 120 | }; |
| 121 | |
| 122 | const std::string user_needle = "<User Needle>" ; |
| 123 | const std::string sys_needle = "<System Needle>" ; |
| 124 | const json dummy_str_user_msg = {{"role" , "user" }, {"content" , user_needle}}; |
| 125 | const json dummy_typed_user_msg = {{"role" , "user" }, {"content" , json::array(init: {{{"type" , "text" }, {"text" , user_needle}}})}}; |
| 126 | |
| 127 | caps_.requires_typed_content = |
| 128 | !contains(try_raw_render(messages: json::array(init: {dummy_str_user_msg}), tools: {}, add_generation_prompt: false), user_needle) |
| 129 | && contains(try_raw_render(messages: json::array(init: {dummy_typed_user_msg}), tools: {}, add_generation_prompt: false), user_needle); |
| 130 | |
| 131 | const auto dummy_user_msg = caps_.requires_typed_content |
| 132 | ? dummy_typed_user_msg |
| 133 | : dummy_str_user_msg; |
| 134 | const json needle_system_msg = { |
| 135 | {"role" , "system" }, |
| 136 | {"content" , caps_.requires_typed_content ? json::array(init: {{{"type" , "text" }, {"text" , sys_needle}}}) : json(sys_needle)}, |
| 137 | }; |
| 138 | |
| 139 | caps_.supports_system_role = contains(try_raw_render(messages: {needle_system_msg, dummy_user_msg,}, tools: {}, add_generation_prompt: false), sys_needle); |
| 140 | |
| 141 | auto out = try_raw_render(messages: json::array(init: { |
| 142 | dummy_user_msg |
| 143 | }), tools: json::array(init: { |
| 144 | { |
| 145 | {"name" , "some_tool" }, |
| 146 | {"type" , "function" }, |
| 147 | {"function" , { |
| 148 | {"name" , "some_tool" }, |
| 149 | {"description" , "Some tool." }, |
| 150 | {"parameters" , { |
| 151 | {"type" , "object" }, |
| 152 | {"properties" , { |
| 153 | {"arg" , { |
| 154 | {"type" , "string" }, |
| 155 | {"description" , "Some argument." }, |
| 156 | }}, |
| 157 | }}, |
| 158 | {"required" , json::array(init: { "arg" })}, |
| 159 | }}, |
| 160 | }}, |
| 161 | }, |
| 162 | }), add_generation_prompt: false); |
| 163 | caps_.supports_tools = contains(out, "some_tool" ); |
| 164 | |
| 165 | const auto render_with_content = [&](const json & content) { |
| 166 | const json assistant_msg {{"role" , "assistant" }, {"content" , content}}; |
| 167 | // Render two assistant messages as some templates like QwQ-32B are handling |
| 168 | // the content differently depending on whether it's the last message or not |
| 169 | // (to remove the <think> tag in all but the last message). |
| 170 | return try_raw_render(messages: json::array(init: {dummy_user_msg, assistant_msg, dummy_user_msg, assistant_msg}), tools: {}, add_generation_prompt: false); |
| 171 | }; |
| 172 | auto out_empty = render_with_content("" ); |
| 173 | auto out_null = render_with_content(json()); |
| 174 | caps_.requires_non_null_content = contains(out_empty, user_needle) && !contains(out_null, user_needle); |
| 175 | |
| 176 | json j_null; |
| 177 | auto make_tool_calls_msg = [&](const json & tool_calls) { |
| 178 | return json { |
| 179 | {"role" , "assistant" }, |
| 180 | {"content" , caps_.requires_non_null_content? "" : j_null}, |
| 181 | {"tool_calls" , tool_calls}, |
| 182 | }; |
| 183 | }; |
| 184 | auto make_tool_call = [](const std::string & tool_name, const json & arguments) { |
| 185 | return json { |
| 186 | {"id" , "call_1___" }, |
| 187 | {"type" , "function" }, |
| 188 | {"function" , { |
| 189 | {"arguments" , arguments}, |
| 190 | {"name" , tool_name}, |
| 191 | }}, |
| 192 | }; |
| 193 | }; |
| 194 | const json dummy_args_obj {{"argument_needle" , "print('Hello, World!')" }}; |
| 195 | const auto contains_arg_needle = [&](const std::string & out_str) { |
| 196 | return contains(out_str, "<parameter=argument_needle>" ) |
| 197 | || contains(out_str, "\"argument_needle\":" ) |
| 198 | || contains(out_str, "'argument_needle':" ) |
| 199 | || contains(out_str, ">argument_needle<" ) |
| 200 | || contains(out_str, "<parameter name=\"argument_needle\">" ); |
| 201 | }; |
| 202 | |
| 203 | // Note: the arguments are rendered in both cases, but may be double-escaped, which we don't want. |
| 204 | out = try_raw_render(messages: json::array(init: { |
| 205 | dummy_user_msg, |
| 206 | make_tool_calls_msg(json::array(init: {make_tool_call("ipython" , dummy_args_obj.dump())})), |
| 207 | }), tools: {}, add_generation_prompt: false); |
| 208 | auto tool_call_renders_str_arguments = contains_arg_needle(out); |
| 209 | out = try_raw_render(messages: json::array(init: { |
| 210 | dummy_user_msg, |
| 211 | make_tool_calls_msg(json::array(init: {make_tool_call("ipython" , dummy_args_obj)})), |
| 212 | }), tools: {}, add_generation_prompt: false); |
| 213 | auto tool_call_renders_obj_arguments = contains_arg_needle(out); |
| 214 | |
| 215 | caps_.supports_tool_calls = tool_call_renders_str_arguments || tool_call_renders_obj_arguments; |
| 216 | caps_.requires_object_arguments = !tool_call_renders_str_arguments && tool_call_renders_obj_arguments; |
| 217 | |
| 218 | if (caps_.supports_tool_calls) { |
| 219 | auto dummy_args = caps_.requires_object_arguments ? dummy_args_obj : json(dummy_args_obj.dump()); |
| 220 | auto tc1 = make_tool_call("test_tool1" , dummy_args); |
| 221 | auto tc2 = make_tool_call("test_tool2" , dummy_args); |
| 222 | auto out = try_raw_render(messages: json::array(init: { |
| 223 | dummy_user_msg, |
| 224 | make_tool_calls_msg(json::array(init: {tc1, tc2})), |
| 225 | }), tools: {}, add_generation_prompt: false); |
| 226 | caps_.supports_parallel_tool_calls = contains(out, "test_tool1" ) && contains(out, "test_tool2" ); |
| 227 | |
| 228 | out = try_raw_render(messages: json::array(init: { |
| 229 | dummy_user_msg, |
| 230 | make_tool_calls_msg(json::array(init: {tc1})), |
| 231 | { |
| 232 | {"role" , "tool" }, |
| 233 | {"name" , "test_tool1" }, |
| 234 | {"content" , "Some response!" }, |
| 235 | {"tool_call_id" , "call_911_" }, |
| 236 | } |
| 237 | }), tools: {}, add_generation_prompt: false); |
| 238 | caps_.supports_tool_responses = contains(out, "Some response!" ); |
| 239 | caps_.supports_tool_call_id = contains(out, "call_911_" ); |
| 240 | } |
| 241 | |
| 242 | try { |
| 243 | if (!caps_.supports_tools) { |
| 244 | const json user_msg { |
| 245 | {"role" , "user" }, |
| 246 | {"content" , "Hey" }, |
| 247 | }; |
| 248 | const json args { |
| 249 | {"arg1" , "some_value" }, |
| 250 | }; |
| 251 | const json tool_call_msg { |
| 252 | {"role" , "assistant" }, |
| 253 | {"content" , caps_.requires_non_null_content ? "" : j_null}, |
| 254 | {"tool_calls" , json::array(init: { |
| 255 | { |
| 256 | // TODO: detect if requires numerical id or fixed length == 6 like Nemo |
| 257 | {"id" , "call_1___" }, |
| 258 | {"type" , "function" }, |
| 259 | {"function" , { |
| 260 | {"name" , "tool_name" }, |
| 261 | {"arguments" , (caps_.requires_object_arguments ? args : json(minja::Value(args).dump(indent: -1, /* to_json= */ to_json: true)))}, |
| 262 | }}, |
| 263 | }, |
| 264 | })}, |
| 265 | }; |
| 266 | std::string prefix, full; |
| 267 | { |
| 268 | chat_template_inputs inputs; |
| 269 | inputs.messages = json::array(init: {user_msg}); |
| 270 | inputs.add_generation_prompt = true; |
| 271 | prefix = apply(inputs); |
| 272 | } |
| 273 | { |
| 274 | chat_template_inputs inputs; |
| 275 | inputs.messages = json::array(init: {user_msg, tool_call_msg}); |
| 276 | inputs.add_generation_prompt = false; |
| 277 | full = apply(inputs); |
| 278 | } |
| 279 | auto eos_pos_last = full.rfind(str: eos_token_); |
| 280 | if (eos_pos_last == prefix.size() - eos_token_.size() || |
| 281 | (full[full.size() - 1] == '\n' && (eos_pos_last == full.size() - eos_token_.size() - 1))) { |
| 282 | full = full.substr(pos: 0, n: eos_pos_last); |
| 283 | } |
| 284 | size_t common_prefix_length = 0; |
| 285 | for (size_t i = 0; i < prefix.size() && i < full.size(); ++i) { |
| 286 | if (prefix[i] != full[i]) { |
| 287 | break; |
| 288 | } |
| 289 | if (prefix[i] == '<') { |
| 290 | // DeepSeek R1's template (as of 20250209) adds a trailing <think> if add_generation_prompt, |
| 291 | // but it removes thinking tags for past messages. |
| 292 | // The prefix and full strings diverge at <think> vs. <|tool▁calls▁begin|>, we avoid consuming the leading <. |
| 293 | continue; |
| 294 | } |
| 295 | common_prefix_length = i + 1; |
| 296 | } |
| 297 | auto example = full.substr(pos: common_prefix_length); |
| 298 | if (example.find(s: "tool_name" ) == std::string::npos && example.find(s: "some_value" ) == std::string::npos) { |
| 299 | fprintf(stderr, format: "Failed to infer a tool call example (possible template bug)\n" ); |
| 300 | } else { |
| 301 | tool_call_example_ = example; |
| 302 | } |
| 303 | } |
| 304 | } catch (const std::exception & e) { |
| 305 | fprintf(stderr, format: "Failed to generate tool call example: %s\n" , e.what()); |
| 306 | } |
| 307 | } |
| 308 | |
| 309 | const std::string & source() const { return source_; } |
| 310 | const std::string & bos_token() const { return bos_token_; } |
| 311 | const std::string & eos_token() const { return eos_token_; } |
| 312 | const chat_template_caps & original_caps() const { return caps_; } |
| 313 | |
| 314 | // Deprecated, please use the form with chat_template_inputs and chat_template_options |
| 315 | std::string apply( |
| 316 | const nlohmann::ordered_json & messages, |
| 317 | const nlohmann::ordered_json & tools, |
| 318 | bool add_generation_prompt, |
| 319 | const nlohmann::ordered_json & = nlohmann::ordered_json(), |
| 320 | bool apply_polyfills = true) |
| 321 | { |
| 322 | fprintf(stderr, format: "[%s] Deprecated!\n" , __func__); |
| 323 | chat_template_inputs inputs; |
| 324 | inputs.messages = messages; |
| 325 | inputs.tools = tools; |
| 326 | inputs.add_generation_prompt = add_generation_prompt; |
| 327 | inputs.extra_context = extra_context; |
| 328 | inputs.now = std::chrono::system_clock::now(); |
| 329 | |
| 330 | chat_template_options opts; |
| 331 | opts.apply_polyfills = apply_polyfills; |
| 332 | |
| 333 | return apply(inputs, opts); |
| 334 | } |
| 335 | |
| 336 | std::string apply( |
| 337 | const chat_template_inputs & inputs, |
| 338 | const chat_template_options & opts = chat_template_options()) const |
| 339 | { |
| 340 | json actual_messages; |
| 341 | |
| 342 | auto has_tools = inputs.tools.is_array() && !inputs.tools.empty(); |
| 343 | auto has_tool_calls = false; |
| 344 | auto has_tool_responses = false; |
| 345 | auto has_string_content = false; |
| 346 | for (const auto & message : inputs.messages) { |
| 347 | if (message.contains(key: "tool_calls" ) && !message["tool_calls" ].is_null()) { |
| 348 | has_tool_calls = true; |
| 349 | } |
| 350 | if (message.contains(key: "role" ) && message["role" ] == "tool" ) { |
| 351 | has_tool_responses = true; |
| 352 | } |
| 353 | if (message.contains(key: "content" ) && message["content" ].is_string()) { |
| 354 | has_string_content = true; |
| 355 | } |
| 356 | } |
| 357 | |
| 358 | auto polyfill_system_role = opts.polyfill_system_role && !caps_.supports_system_role; |
| 359 | auto polyfill_tools = opts.polyfill_tools && has_tools && !caps_.supports_tools; |
| 360 | auto polyfill_tool_call_example = polyfill_tools && opts.polyfill_tool_call_examples; |
| 361 | auto polyfill_tool_calls = opts.polyfill_tool_calls && has_tool_calls && !caps_.supports_tool_calls; |
| 362 | auto polyfill_tool_responses = opts.polyfill_tool_responses && has_tool_responses && !caps_.supports_tool_responses; |
| 363 | auto polyfill_object_arguments = opts.polyfill_object_arguments && has_tool_calls && caps_.requires_object_arguments; |
| 364 | auto polyfill_typed_content = opts.polyfill_typed_content && has_string_content && caps_.requires_typed_content; |
| 365 | |
| 366 | auto needs_polyfills = opts.apply_polyfills && (false |
| 367 | || polyfill_system_role |
| 368 | || polyfill_tools |
| 369 | || polyfill_tool_calls |
| 370 | || polyfill_tool_responses |
| 371 | || polyfill_object_arguments |
| 372 | || polyfill_typed_content |
| 373 | ); |
| 374 | |
| 375 | if (needs_polyfills) { |
| 376 | actual_messages = json::array(); |
| 377 | |
| 378 | auto add_message = [&](const json & msg) { |
| 379 | if (polyfill_typed_content && msg.contains(key: "content" ) && !msg.at(key: "content" ).is_null() && msg.at(key: "content" ).is_string()) { |
| 380 | actual_messages.push_back(init: { |
| 381 | {"role" , msg.at(key: "role" )}, |
| 382 | {"content" , {{ |
| 383 | {"type" , "text" }, |
| 384 | {"text" , msg.at(key: "content" )}, |
| 385 | }}}, |
| 386 | }); |
| 387 | } else { |
| 388 | actual_messages.push_back(val: msg); |
| 389 | } |
| 390 | }; |
| 391 | |
| 392 | std::string pending_system; |
| 393 | auto flush_sys = [&]() { |
| 394 | if (!pending_system.empty()) { |
| 395 | add_message({ |
| 396 | {"role" , "user" }, |
| 397 | {"content" , pending_system}, |
| 398 | }); |
| 399 | pending_system.clear(); |
| 400 | } |
| 401 | }; |
| 402 | |
| 403 | json adjusted_messages; |
| 404 | if (polyfill_tools) { |
| 405 | adjusted_messages = add_system(messages: inputs.messages, |
| 406 | system_prompt: "You can call any of the following tools to satisfy the user's requests: " + minja::Value(inputs.tools).dump(indent: 2, /* to_json= */ to_json: true) + |
| 407 | (!polyfill_tool_call_example || tool_call_example_.empty() ? "" : "\n\nExample tool call syntax:\n\n" + tool_call_example_ + "\n\n" )); |
| 408 | } else { |
| 409 | adjusted_messages = inputs.messages; |
| 410 | } |
| 411 | |
| 412 | for (const auto & message_ : adjusted_messages) { |
| 413 | auto message = message_; |
| 414 | if (!message.contains(key: "role" ) || (!message.contains(key: "content" ) && !message.contains(key: "tool_calls" ))) { |
| 415 | throw std::runtime_error("message must have 'role' and one of 'content' or 'tool_calls' fields: " + message.dump()); |
| 416 | } |
| 417 | std::string role = message.at(key: "role" ); |
| 418 | |
| 419 | if (message.contains(key: "tool_calls" )) { |
| 420 | if (polyfill_object_arguments || polyfill_tool_calls) { |
| 421 | for (auto & tool_call : message.at(key: "tool_calls" )) { |
| 422 | if (tool_call["type" ] == "function" ) { |
| 423 | auto & function = tool_call.at(key: "function" ); |
| 424 | auto & arguments = function.at(key: "arguments" ); |
| 425 | if (arguments.is_string()) { |
| 426 | try { |
| 427 | arguments = json::parse(i: arguments.get<std::string>()); |
| 428 | } catch (const std::exception & ecvt) { |
| 429 | fprintf(stderr, format: "Failed to parse arguments: %s\n" , ecvt.what()); |
| 430 | } |
| 431 | } |
| 432 | } |
| 433 | } |
| 434 | } |
| 435 | if (polyfill_tool_calls) { |
| 436 | auto tool_calls = json::array(); |
| 437 | for (const auto & tool_call : message.at(key: "tool_calls" )) { |
| 438 | if (tool_call.at(key: "type" ) != "function" ) { |
| 439 | continue; |
| 440 | } |
| 441 | const auto & function = tool_call.at(key: "function" ); |
| 442 | auto tc = json { |
| 443 | {"name" , function.at(key: "name" )}, |
| 444 | {"arguments" , function.at(key: "arguments" )}, |
| 445 | }; |
| 446 | if (tool_call.contains(key: "id" )) { |
| 447 | tc["id" ] = tool_call["id" ]; |
| 448 | } |
| 449 | tool_calls.push_back(val: tc); |
| 450 | } |
| 451 | auto obj = json { |
| 452 | {"tool_calls" , tool_calls}, |
| 453 | }; |
| 454 | if (message.contains(key: "content" )) { |
| 455 | auto content = message.at(key: "content" ); |
| 456 | if (!content.is_null() && !content.empty()) { |
| 457 | obj["content" ] = content; |
| 458 | } |
| 459 | } |
| 460 | message["content" ] = obj.dump(indent: 2); |
| 461 | message.erase(key: "tool_calls" ); |
| 462 | } |
| 463 | } |
| 464 | if (polyfill_tool_responses && role == "tool" ) { |
| 465 | message["role" ] = "user" ; |
| 466 | auto obj = json { |
| 467 | {"tool_response" , json::object()}, |
| 468 | }; |
| 469 | if (message.contains(key: "name" )) { |
| 470 | obj["tool_response" ]["tool" ] = message.at(key: "name" ); |
| 471 | } |
| 472 | obj["tool_response" ]["content" ] = message.at(key: "content" ); |
| 473 | if (message.contains(key: "tool_call_id" )) { |
| 474 | obj["tool_response" ]["tool_call_id" ] = message.at(key: "tool_call_id" ); |
| 475 | } |
| 476 | message["content" ] = obj.dump(indent: 2); |
| 477 | message.erase(key: "name" ); |
| 478 | } |
| 479 | |
| 480 | if (!message["content" ].is_null() && polyfill_system_role) { |
| 481 | std::string content = message.at(key: "content" ); |
| 482 | if (role == "system" ) { |
| 483 | if (!pending_system.empty()) pending_system += "\n" ; |
| 484 | pending_system += content; |
| 485 | continue; |
| 486 | } else { |
| 487 | if (role == "user" ) { |
| 488 | if (!pending_system.empty()) { |
| 489 | message["content" ] = pending_system + (content.empty() ? "" : "\n" + content); |
| 490 | pending_system.clear(); |
| 491 | } |
| 492 | } else { |
| 493 | flush_sys(); |
| 494 | } |
| 495 | } |
| 496 | } |
| 497 | add_message(message); |
| 498 | } |
| 499 | flush_sys(); |
| 500 | } else { |
| 501 | actual_messages = inputs.messages; |
| 502 | } |
| 503 | |
| 504 | auto context = minja::Context::make(values: json({ |
| 505 | {"messages" , actual_messages}, |
| 506 | {"add_generation_prompt" , inputs.add_generation_prompt}, |
| 507 | })); |
| 508 | context->set(key: "bos_token" , value: opts.use_bos_token ? bos_token_ : "" ); |
| 509 | context->set(key: "eos_token" , value: opts.use_eos_token ? eos_token_ : "" ); |
| 510 | if (opts.define_strftime_now) { |
| 511 | auto now = inputs.now; |
| 512 | context->set(key: "strftime_now" , value: Value::callable(callable: [now](const std::shared_ptr<minja::Context> &, minja::ArgumentsValue & args) { |
| 513 | args.expectArgs(method_name: "strftime_now" , pos_count: {1, 1}, kw_count: {0, 0}); |
| 514 | auto format = args.args[0].get<std::string>(); |
| 515 | |
| 516 | auto time = std::chrono::system_clock::to_time_t(t: now); |
| 517 | auto local_time = *std::localtime(timer: &time); |
| 518 | std::ostringstream ss; |
| 519 | ss << std::put_time(tmb: &local_time, fmt: format.c_str()); |
| 520 | return ss.str(); |
| 521 | })); |
| 522 | } |
| 523 | if (!inputs.tools.is_null()) { |
| 524 | context->set(key: "tools" , value: minja::Value(inputs.tools)); |
| 525 | } |
| 526 | if (!inputs.extra_context.is_null()) { |
| 527 | for (auto & kv : inputs.extra_context.items()) { |
| 528 | context->set(key: kv.key(), value: minja::Value(kv.value())); |
| 529 | } |
| 530 | } |
| 531 | |
| 532 | auto ret = template_root_->render(context); |
| 533 | // fprintf(stderr, "actual_messages: %s\n", actual_messages.dump(2).c_str()); |
| 534 | // fprintf(stderr, "apply: %s\n\n", ret.c_str()); |
| 535 | return ret; |
| 536 | } |
| 537 | |
| 538 | static nlohmann::ordered_json add_system(const nlohmann::ordered_json & messages, const std::string & system_prompt) { |
| 539 | json messages_with_system = messages; |
| 540 | |
| 541 | if (!messages_with_system.empty() && messages_with_system[0].at(key: "role" ) == "system" ) { |
| 542 | std::string existing_system = messages_with_system.at(idx: 0).at(key: "content" ); |
| 543 | messages_with_system[0] = json { |
| 544 | {"role" , "system" }, |
| 545 | {"content" , existing_system + "\n\n" + system_prompt}, |
| 546 | }; |
| 547 | } else { |
| 548 | messages_with_system.insert(pos: messages_with_system.begin(), val: json { |
| 549 | {"role" , "system" }, |
| 550 | {"content" , system_prompt}, |
| 551 | }); |
| 552 | } |
| 553 | return messages_with_system; |
| 554 | } |
| 555 | }; |
| 556 | |
| 557 | } // namespace minja |
| 558 | |