#pragma once
#include <ext/scope_guard.h>
#include <random>
#include <sstream>
#include <vector>
#include <Common/MemoryTracker.h>
#include <Common/OpenSSLHelpers.h>
#include <Common/PODArray.h>
#include <Core/Types.h>
#include <Interpreters/Context.h>
#include <IO/copyData.h>
#include <IO/LimitReadBuffer.h>
#include <IO/ReadBuffer.h>
#include <IO/ReadBufferFromMemory.h>
#include <IO/ReadBufferFromPocoSocket.h>
#include <IO/ReadHelpers.h>
#include <IO/WriteBuffer.h>
#include <IO/WriteBufferFromPocoSocket.h>
#include <IO/WriteBufferFromString.h>
#include <IO/WriteHelpers.h>
#include <Poco/Net/StreamSocket.h>
#include <Poco/RandomStream.h>
#include <Poco/SHA1Engine.h>
#include "config_core.h"
#if USE_SSL
#include <openssl/pem.h>
#include <openssl/rsa.h>
#endif
28
29/// Implementation of MySQL wire protocol.
30/// Works only on little-endian architecture.
31
32namespace DB
33{
34
namespace ErrorCodes
{
    extern const int UNKNOWN_PACKET_FROM_CLIENT;
    extern const int MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES;
    extern const int OPENSSL_ERROR;
    extern const int UNKNOWN_EXCEPTION;
    /// Thrown by PacketPayloadWriteBuffer in this header; declared explicitly so the header does not
    /// depend on some other included header happening to declare it.
    extern const int CANNOT_WRITE_AFTER_END_OF_BUFFER;
}
42
43namespace MySQLProtocol
44{
45
/// Largest payload of a single packet: the payload length field is 3 bytes wide, so 2^24 - 1 bytes (just under 16 MiB).
/// Larger payloads are split into several packets.
constexpr size_t MAX_PACKET_LENGTH = (1 << 24) - 1;
/// Length of the authentication scramble (salt).
constexpr size_t SCRAMBLE_LENGTH = 20;
/// Number of scramble bytes sent in the first auth-plugin-data part of the Handshake packet.
constexpr size_t AUTH_PLUGIN_DATA_PART_1_LENGTH = 8;
/// Error messages longer than this are truncated in ERR_Packet.
constexpr size_t MYSQL_ERRMSG_SIZE = 512;
/// Packet header: 3-byte payload length + 1-byte sequence-id.
constexpr size_t PACKET_HEADER_SIZE = 4;
/// Payload size of an SSLRequest packet.
constexpr size_t SSL_REQUEST_PAYLOAD_SIZE = 32;
52
53
/// Collation ids sent to clients (e.g. in the Handshake packet).
enum CharacterSet
{
    utf8_general_ci = 0x21, // 33
    binary = 0x3f           // 63
};
59
/// Server status flags reported in OK packets.
enum StatusFlags
{
    SERVER_SESSION_STATE_CHANGED = 1 << 14 // 0x4000
};
64
/// Capability flags exchanged between client and server during the handshake (one bit each).
enum Capability
{
    CLIENT_CONNECT_WITH_DB = 1 << 3,                 // 0x00000008
    CLIENT_PROTOCOL_41 = 1 << 9,                     // 0x00000200
    CLIENT_SSL = 1 << 11,                            // 0x00000800
    CLIENT_TRANSACTIONS = 1 << 13,                   // 0x00002000, TODO
    CLIENT_SECURE_CONNECTION = 1 << 15,              // 0x00008000
    CLIENT_PLUGIN_AUTH = 1 << 19,                    // 0x00080000
    CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA = 1 << 21, // 0x00200000
    CLIENT_SESSION_TRACK = 1 << 23,                  // 0x00800000, TODO
    CLIENT_DEPRECATE_EOF = 1 << 24,                  // 0x01000000
};
77
/// Command byte that opens a client packet in the command phase.
enum Command
{
    COM_SLEEP = 0x00,
    COM_QUIT = 0x01,
    COM_INIT_DB = 0x02,
    COM_QUERY = 0x03,
    COM_FIELD_LIST = 0x04,
    COM_CREATE_DB = 0x05,
    COM_DROP_DB = 0x06,
    COM_REFRESH = 0x07,
    COM_SHUTDOWN = 0x08,
    COM_STATISTICS = 0x09,
    COM_PROCESS_INFO = 0x0a,
    COM_CONNECT = 0x0b,
    COM_PROCESS_KILL = 0x0c,
    COM_DEBUG = 0x0d,
    COM_PING = 0x0e,
    COM_TIME = 0x0f,
    COM_DELAYED_INSERT = 0x10,
    COM_CHANGE_USER = 0x11,
    COM_DAEMON = 0x1d,
    COM_RESET_CONNECTION = 0x1f
};
101
/// Column type codes used in column definition packets.
enum ColumnType
{
    /// Low range.
    MYSQL_TYPE_DECIMAL = 0x00,
    MYSQL_TYPE_TINY = 0x01,
    MYSQL_TYPE_SHORT = 0x02,
    MYSQL_TYPE_LONG = 0x03,
    MYSQL_TYPE_FLOAT = 0x04,
    MYSQL_TYPE_DOUBLE = 0x05,
    MYSQL_TYPE_NULL = 0x06,
    MYSQL_TYPE_TIMESTAMP = 0x07,
    MYSQL_TYPE_LONGLONG = 0x08,
    MYSQL_TYPE_INT24 = 0x09,
    MYSQL_TYPE_DATE = 0x0A,
    MYSQL_TYPE_TIME = 0x0B,
    MYSQL_TYPE_DATETIME = 0x0C,
    MYSQL_TYPE_YEAR = 0x0D,
    MYSQL_TYPE_VARCHAR = 0x0F,
    MYSQL_TYPE_BIT = 0x10,

    /// High range.
    MYSQL_TYPE_NEWDECIMAL = 0xF6,
    MYSQL_TYPE_ENUM = 0xF7,
    MYSQL_TYPE_SET = 0xF8,
    MYSQL_TYPE_TINY_BLOB = 0xF9,
    MYSQL_TYPE_MEDIUM_BLOB = 0xFA,
    MYSQL_TYPE_LONG_BLOB = 0xFB,
    MYSQL_TYPE_BLOB = 0xFC,
    MYSQL_TYPE_VAR_STRING = 0xFD,
    MYSQL_TYPE_STRING = 0xFE,
    MYSQL_TYPE_GEOMETRY = 0xFF
};
131
132
// https://dev.mysql.com/doc/dev/mysql-server/latest/group__group__cs__column__definition__flags.html
/// Bits of the `flags` field of a column definition.
enum ColumnDefinitionFlags
{
    UNSIGNED_FLAG = 1 << 5, // 32
    BINARY_FLAG = 1 << 7    // 128
};
139
140
141class ProtocolError : public DB::Exception
142{
143public:
144 using Exception::Exception;
145};
146
147
148/** Reading packets.
149 * Internally, it calls (if no more data) next() method of the underlying ReadBufferFromPocoSocket, and sets the working buffer to the rest part of the current packet payload.
150 */
151class PacketPayloadReadBuffer : public ReadBuffer
152{
153public:
154 PacketPayloadReadBuffer(ReadBuffer & in_, uint8_t & sequence_id_)
155 : ReadBuffer(in_.position(), 0) // not in.buffer().begin(), because working buffer may include previous packet
156 , in(in_)
157 , sequence_id(sequence_id_)
158 {
159 }
160
161private:
162 ReadBuffer & in;
163 uint8_t & sequence_id;
164 const size_t max_packet_size = MAX_PACKET_LENGTH;
165
166 bool has_read_header = false;
167
168 // Size of packet which is being read now.
169 size_t payload_length = 0;
170
171 // Offset in packet payload.
172 size_t offset = 0;
173
174protected:
175 bool nextImpl() override
176 {
177 if (!has_read_header || (payload_length == max_packet_size && offset == payload_length))
178 {
179 has_read_header = true;
180 working_buffer.resize(0);
181 offset = 0;
182 payload_length = 0;
183 in.readStrict(reinterpret_cast<char *>(&payload_length), 3);
184
185 if (payload_length > max_packet_size)
186 {
187 std::ostringstream tmp;
188 tmp << "Received packet with payload larger than max_packet_size: " << payload_length;
189 throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
190 }
191
192 size_t packet_sequence_id = 0;
193 in.read(reinterpret_cast<char &>(packet_sequence_id));
194 if (packet_sequence_id != sequence_id)
195 {
196 std::ostringstream tmp;
197 tmp << "Received packet with wrong sequence-id: " << packet_sequence_id << ". Expected: " << static_cast<unsigned int>(sequence_id) << '.';
198 throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
199 }
200 sequence_id++;
201
202 if (payload_length == 0)
203 return false;
204 }
205 else if (offset == payload_length)
206 {
207 return false;
208 }
209
210 in.nextIfAtEnd();
211 working_buffer = ReadBuffer::Buffer(in.position(), in.buffer().end());
212 size_t count = std::min(in.available(), payload_length - offset);
213 working_buffer.resize(count);
214 in.ignore(count);
215
216 offset += count;
217
218 return true;
219 }
220};
221
222
223class ClientPacket
224{
225public:
226 ClientPacket() = default;
227
228 ClientPacket(ClientPacket &&) = default;
229
230 virtual void read(ReadBuffer & in, uint8_t & sequence_id)
231 {
232 PacketPayloadReadBuffer payload(in, sequence_id);
233 readPayload(payload);
234 if (!payload.eof())
235 {
236 std::stringstream tmp;
237 tmp << "Packet payload is not fully read. Stopped after " << payload.count() << " bytes, while " << payload.available() << " bytes are in buffer.";
238 throw ProtocolError(tmp.str(), ErrorCodes::UNKNOWN_PACKET_FROM_CLIENT);
239 }
240 }
241
242 virtual void readPayload(ReadBuffer & buf) = 0;
243
244 virtual ~ClientPacket() = default;
245};
246
247
248class LimitedClientPacket : public ClientPacket
249{
250public:
251 void read(ReadBuffer & in, uint8_t & sequence_id) override
252 {
253 LimitReadBuffer limited(in, 10000, true, "too long MySQL packet.");
254 ClientPacket::read(limited, sequence_id);
255 }
256};
257
258
259/** Writing packets.
260 * https://dev.mysql.com/doc/internals/en/mysql-packet.html
261 */
262class PacketPayloadWriteBuffer : public WriteBuffer
263{
264public:
265 PacketPayloadWriteBuffer(WriteBuffer & out_, size_t payload_length_, uint8_t & sequence_id_)
266 : WriteBuffer(out_.position(), 0), out(out_), sequence_id(sequence_id_), total_left(payload_length_)
267 {
268 startNewPacket();
269 setWorkingBuffer();
270 pos = out.position();
271 }
272
273 bool remainingPayloadSize()
274 {
275 return total_left;
276 }
277
278private:
279 WriteBuffer & out;
280 uint8_t & sequence_id;
281
282 size_t total_left = 0;
283 size_t payload_length = 0;
284 size_t bytes_written = 0;
285 bool eof = false;
286
287 void startNewPacket()
288 {
289 payload_length = std::min(total_left, MAX_PACKET_LENGTH);
290 bytes_written = 0;
291 total_left -= payload_length;
292
293 out.write(reinterpret_cast<char *>(&payload_length), 3);
294 out.write(sequence_id++);
295 bytes += 4;
296 }
297
298 /// Sets working buffer to the rest of current packet payload.
299 void setWorkingBuffer()
300 {
301 out.nextIfAtEnd();
302 working_buffer = WriteBuffer::Buffer(out.position(), out.position() + std::min(payload_length - bytes_written, out.available()));
303
304 if (payload_length - bytes_written == 0)
305 {
306 /// Finished writing packet. Due to an implementation of WriteBuffer, working_buffer cannot be empty. Further write attempts will throw Exception.
307 eof = true;
308 working_buffer.resize(1);
309 }
310 }
311
312protected:
313 void nextImpl() override
314 {
315 const int written = pos - working_buffer.begin();
316 if (eof)
317 throw Exception("Cannot write after end of buffer.", ErrorCodes::CANNOT_WRITE_AFTER_END_OF_BUFFER);
318
319 out.position() += written;
320 bytes_written += written;
321
322 /// Packets of size greater than MAX_PACKET_LENGTH are split into few packets of size MAX_PACKET_LENGTH and las packet of size < MAX_PACKET_LENGTH.
323 if (bytes_written == payload_length && (total_left > 0 || payload_length == MAX_PACKET_LENGTH))
324 startNewPacket();
325
326 setWorkingBuffer();
327 }
328};
329
330
331class WritePacket
332{
333public:
334 virtual void writePayload(WriteBuffer & buffer, uint8_t & sequence_id) const
335 {
336 PacketPayloadWriteBuffer buf(buffer, getPayloadSize(), sequence_id);
337 writePayloadImpl(buf);
338 buf.next();
339 if (buf.remainingPayloadSize())
340 {
341 std::stringstream ss;
342 ss << "Incomplete payload. Written " << getPayloadSize() - buf.remainingPayloadSize() << " bytes, expected " << getPayloadSize() << " bytes.";
343 throw Exception(ss.str(), 0);
344 }
345 }
346
347 virtual ~WritePacket() = default;
348
349protected:
350 virtual size_t getPayloadSize() const = 0;
351
352 virtual void writePayloadImpl(WriteBuffer & buffer) const = 0;
353};
354
355
356/* Writes and reads packets, keeping sequence-id.
357 * Throws ProtocolError, if packet with incorrect sequence-id was received.
358 */
359class PacketSender
360{
361public:
362 uint8_t & sequence_id;
363 ReadBuffer * in;
364 WriteBuffer * out;
365 size_t max_packet_size = MAX_PACKET_LENGTH;
366
367 /// For reading and writing.
368 PacketSender(ReadBuffer & in_, WriteBuffer & out_, uint8_t & sequence_id_)
369 : sequence_id(sequence_id_)
370 , in(&in_)
371 , out(&out_)
372 {
373 }
374
375 /// For writing.
376 PacketSender(WriteBuffer & out_, uint8_t & sequence_id_)
377 : sequence_id(sequence_id_)
378 , in(nullptr)
379 , out(&out_)
380 {
381 }
382
383 void receivePacket(ClientPacket & packet)
384 {
385 packet.read(*in, sequence_id);
386 }
387
388 template<class T>
389 void sendPacket(const T & packet, bool flush = false)
390 {
391 static_assert(std::is_base_of<WritePacket, T>());
392 packet.writePayload(*out, sequence_id);
393 if (flush)
394 out->next();
395 }
396
397 PacketPayloadReadBuffer getPayload()
398 {
399 return PacketPayloadReadBuffer(*in, sequence_id);
400 }
401
402 /// Sets sequence-id to 0. Must be called before each command phase.
403 void resetSequenceId();
404
405 /// Converts packet to text. Is used for debug output.
406 static String packetToText(const String & payload);
407};
408
409
410uint64_t readLengthEncodedNumber(ReadBuffer & ss);
411
412void writeLengthEncodedNumber(uint64_t x, WriteBuffer & buffer);
413
414inline void writeLengthEncodedString(const String & s, WriteBuffer & buffer)
415{
416 writeLengthEncodedNumber(s.size(), buffer);
417 buffer.write(s.data(), s.size());
418}
419
420inline void writeNulTerminatedString(const String & s, WriteBuffer & buffer)
421{
422 buffer.write(s.data(), s.size());
423 buffer.write(0);
424}
425
426size_t getLengthEncodedNumberSize(uint64_t x);
427
428size_t getLengthEncodedStringSize(const String & s);
429
430
431class Handshake : public WritePacket
432{
433 int protocol_version = 0xa;
434 String server_version;
435 uint32_t connection_id;
436 uint32_t capability_flags;
437 uint8_t character_set;
438 uint32_t status_flags;
439 String auth_plugin_name;
440 String auth_plugin_data;
441public:
442 explicit Handshake(uint32_t capability_flags_, uint32_t connection_id_, String server_version_, String auth_plugin_name_, String auth_plugin_data_)
443 : protocol_version(0xa)
444 , server_version(std::move(server_version_))
445 , connection_id(connection_id_)
446 , capability_flags(capability_flags_)
447 , character_set(CharacterSet::utf8_general_ci)
448 , status_flags(0)
449 , auth_plugin_name(std::move(auth_plugin_name_))
450 , auth_plugin_data(std::move(auth_plugin_data_))
451 {
452 }
453
454protected:
455 size_t getPayloadSize() const override
456 {
457 return 26 + server_version.size() + auth_plugin_data.size() + auth_plugin_name.size();
458 }
459
460 void writePayloadImpl(WriteBuffer & buffer) const override
461 {
462 buffer.write(static_cast<char>(protocol_version));
463 writeNulTerminatedString(server_version, buffer);
464 buffer.write(reinterpret_cast<const char *>(&connection_id), 4);
465 writeNulTerminatedString(auth_plugin_data.substr(0, AUTH_PLUGIN_DATA_PART_1_LENGTH), buffer);
466 buffer.write(reinterpret_cast<const char *>(&capability_flags), 2);
467 buffer.write(reinterpret_cast<const char *>(&character_set), 1);
468 buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
469 buffer.write((reinterpret_cast<const char *>(&capability_flags)) + 2, 2);
470 buffer.write(static_cast<char>(auth_plugin_data.size()));
471 writeChar(0x0, 10, buffer);
472 writeString(auth_plugin_data.substr(AUTH_PLUGIN_DATA_PART_1_LENGTH, auth_plugin_data.size() - AUTH_PLUGIN_DATA_PART_1_LENGTH), buffer);
473 writeString(auth_plugin_name, buffer);
474 writeChar(0x0, 1, buffer);
475 }
476};
477
478class SSLRequest : public ClientPacket
479{
480public:
481 uint32_t capability_flags;
482 uint32_t max_packet_size;
483 uint8_t character_set;
484
485 void readPayload(ReadBuffer & buf) override
486 {
487 buf.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
488 buf.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
489 buf.readStrict(reinterpret_cast<char *>(&character_set), 1);
490 }
491};
492
493class HandshakeResponse : public LimitedClientPacket
494{
495public:
496 uint32_t capability_flags = 0;
497 uint32_t max_packet_size = 0;
498 uint8_t character_set = 0;
499 String username;
500 String auth_response;
501 String database;
502 String auth_plugin_name;
503
504 HandshakeResponse() = default;
505
506 void readPayload(ReadBuffer & payload) override
507 {
508 payload.readStrict(reinterpret_cast<char *>(&capability_flags), 4);
509 payload.readStrict(reinterpret_cast<char *>(&max_packet_size), 4);
510 payload.readStrict(reinterpret_cast<char *>(&character_set), 1);
511 payload.ignore(23);
512
513 readNullTerminated(username, payload);
514
515 if (capability_flags & CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA)
516 {
517 auto len = readLengthEncodedNumber(payload);
518 auth_response.resize(len);
519 payload.readStrict(auth_response.data(), len);
520 }
521 else if (capability_flags & CLIENT_SECURE_CONNECTION)
522 {
523 char len;
524 payload.readStrict(len);
525 auth_response.resize(static_cast<unsigned int>(len));
526 payload.readStrict(auth_response.data(), len);
527 }
528 else
529 {
530 readNullTerminated(auth_response, payload);
531 }
532
533 if (capability_flags & CLIENT_CONNECT_WITH_DB)
534 {
535 readNullTerminated(database, payload);
536 }
537
538 if (capability_flags & CLIENT_PLUGIN_AUTH)
539 {
540 readNullTerminated(auth_plugin_name, payload);
541 }
542 }
543};
544
545class AuthSwitchRequest : public WritePacket
546{
547 String plugin_name;
548 String auth_plugin_data;
549public:
550 AuthSwitchRequest(String plugin_name_, String auth_plugin_data_)
551 : plugin_name(std::move(plugin_name_)), auth_plugin_data(std::move(auth_plugin_data_))
552 {
553 }
554
555protected:
556 size_t getPayloadSize() const override
557 {
558 return 2 + plugin_name.size() + auth_plugin_data.size();
559 }
560
561 void writePayloadImpl(WriteBuffer & buffer) const override
562 {
563 buffer.write(0xfe);
564 writeNulTerminatedString(plugin_name, buffer);
565 writeString(auth_plugin_data, buffer);
566 }
567};
568
569class AuthSwitchResponse : public LimitedClientPacket
570{
571public:
572 String value;
573
574 void readPayload(ReadBuffer & payload) override
575 {
576 readStringUntilEOF(value, payload);
577 }
578};
579
580class AuthMoreData : public WritePacket
581{
582 String data;
583public:
584 explicit AuthMoreData(String data_): data(std::move(data_)) {}
585
586protected:
587 size_t getPayloadSize() const override
588 {
589 return 1 + data.size();
590 }
591
592 void writePayloadImpl(WriteBuffer & buffer) const override
593 {
594 buffer.write(0x01);
595 writeString(data, buffer);
596 }
597};
598
599
600class OK_Packet : public WritePacket
601{
602 uint8_t header;
603 uint32_t capabilities;
604 uint64_t affected_rows;
605 int16_t warnings = 0;
606 uint32_t status_flags;
607 String session_state_changes;
608 String info;
609public:
610 OK_Packet(uint8_t header_,
611 uint32_t capabilities_,
612 uint64_t affected_rows_,
613 uint32_t status_flags_,
614 int16_t warnings_,
615 String session_state_changes_ = "",
616 String info_ = "")
617 : header(header_)
618 , capabilities(capabilities_)
619 , affected_rows(affected_rows_)
620 , warnings(warnings_)
621 , status_flags(status_flags_)
622 , session_state_changes(std::move(session_state_changes_))
623 , info(std::move(info_))
624 {
625 }
626
627protected:
628 size_t getPayloadSize() const override
629 {
630 size_t result = 2 + getLengthEncodedNumberSize(affected_rows);
631
632 if (capabilities & CLIENT_PROTOCOL_41)
633 {
634 result += 4;
635 }
636 else if (capabilities & CLIENT_TRANSACTIONS)
637 {
638 result += 2;
639 }
640
641 if (capabilities & CLIENT_SESSION_TRACK)
642 {
643 result += getLengthEncodedStringSize(info);
644 if (status_flags & SERVER_SESSION_STATE_CHANGED)
645 result += getLengthEncodedStringSize(session_state_changes);
646 }
647 else
648 {
649 result += info.size();
650 }
651
652 return result;
653 }
654
655 void writePayloadImpl(WriteBuffer & buffer) const override
656 {
657 buffer.write(header);
658 writeLengthEncodedNumber(affected_rows, buffer);
659 writeLengthEncodedNumber(0, buffer); /// last insert-id
660
661 if (capabilities & CLIENT_PROTOCOL_41)
662 {
663 buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
664 buffer.write(reinterpret_cast<const char *>(&warnings), 2);
665 }
666 else if (capabilities & CLIENT_TRANSACTIONS)
667 {
668 buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
669 }
670
671 if (capabilities & CLIENT_SESSION_TRACK)
672 {
673 writeLengthEncodedString(info, buffer);
674 if (status_flags & SERVER_SESSION_STATE_CHANGED)
675 writeLengthEncodedString(session_state_changes, buffer);
676 }
677 else
678 {
679 writeString(info, buffer);
680 }
681 }
682};
683
684class EOF_Packet : public WritePacket
685{
686 int warnings;
687 int status_flags;
688public:
689 EOF_Packet(int warnings_, int status_flags_) : warnings(warnings_), status_flags(status_flags_)
690 {}
691
692protected:
693 size_t getPayloadSize() const override
694 {
695 return 5;
696 }
697
698 void writePayloadImpl(WriteBuffer & buffer) const override
699 {
700 buffer.write(0xfe); // EOF header
701 buffer.write(reinterpret_cast<const char *>(&warnings), 2);
702 buffer.write(reinterpret_cast<const char *>(&status_flags), 2);
703 }
704};
705
706class ERR_Packet : public WritePacket
707{
708 int error_code;
709 String sql_state;
710 String error_message;
711public:
712 ERR_Packet(int error_code_, String sql_state_, String error_message_)
713 : error_code(error_code_), sql_state(std::move(sql_state_)), error_message(std::move(error_message_))
714 {
715 }
716
717protected:
718 size_t getPayloadSize() const override
719 {
720 return 4 + sql_state.length() + std::min(error_message.length(), MYSQL_ERRMSG_SIZE);
721 }
722
723 void writePayloadImpl(WriteBuffer & buffer) const override
724 {
725 buffer.write(0xff);
726 buffer.write(reinterpret_cast<const char *>(&error_code), 2);
727 buffer.write('#');
728 buffer.write(sql_state.data(), sql_state.length());
729 buffer.write(error_message.data(), std::min(error_message.length(), MYSQL_ERRMSG_SIZE));
730 }
731};
732
733class ColumnDefinition : public WritePacket
734{
735 String schema;
736 String table;
737 String org_table;
738 String name;
739 String org_name;
740 size_t next_length = 0x0c;
741 uint16_t character_set;
742 uint32_t column_length;
743 ColumnType column_type;
744 uint16_t flags;
745 uint8_t decimals = 0x00;
746public:
747 ColumnDefinition(
748 String schema_,
749 String table_,
750 String org_table_,
751 String name_,
752 String org_name_,
753 uint16_t character_set_,
754 uint32_t column_length_,
755 ColumnType column_type_,
756 uint16_t flags_,
757 uint8_t decimals_)
758
759 : schema(std::move(schema_)), table(std::move(table_)), org_table(std::move(org_table_)), name(std::move(name_)),
760 org_name(std::move(org_name_)), character_set(character_set_), column_length(column_length_), column_type(column_type_), flags(flags_),
761 decimals(decimals_)
762 {
763 }
764
765 /// Should be used when column metadata (original name, table, original table, database) is unknown.
766 ColumnDefinition(
767 String name_,
768 uint16_t character_set_,
769 uint32_t column_length_,
770 ColumnType column_type_,
771 uint16_t flags_,
772 uint8_t decimals_)
773 : ColumnDefinition("", "", "", std::move(name_), "", character_set_, column_length_, column_type_, flags_, decimals_)
774 {
775 }
776
777protected:
778 size_t getPayloadSize() const override
779 {
780 return 13 + getLengthEncodedStringSize("def") + getLengthEncodedStringSize(schema) + getLengthEncodedStringSize(table) + getLengthEncodedStringSize(org_table) + \
781 getLengthEncodedStringSize(name) + getLengthEncodedStringSize(org_name) + getLengthEncodedNumberSize(next_length);
782 }
783
784 void writePayloadImpl(WriteBuffer & buffer) const override
785 {
786 writeLengthEncodedString(std::string("def"), buffer); /// always "def"
787 writeLengthEncodedString(schema, buffer);
788 writeLengthEncodedString(table, buffer);
789 writeLengthEncodedString(org_table, buffer);
790 writeLengthEncodedString(name, buffer);
791 writeLengthEncodedString(org_name, buffer);
792 writeLengthEncodedNumber(next_length, buffer);
793 buffer.write(reinterpret_cast<const char *>(&character_set), 2);
794 buffer.write(reinterpret_cast<const char *>(&column_length), 4);
795 buffer.write(reinterpret_cast<const char *>(&column_type), 1);
796 buffer.write(reinterpret_cast<const char *>(&flags), 2);
797 buffer.write(reinterpret_cast<const char *>(&decimals), 2);
798 writeChar(0x0, 2, buffer);
799 }
800};
801
802class ComFieldList : public LimitedClientPacket
803{
804public:
805 String table, field_wildcard;
806
807 void readPayload(ReadBuffer & payload) override
808 {
809 // Command byte has been already read from payload.
810 readNullTerminated(table, payload);
811 readStringUntilEOF(field_wildcard, payload);
812 }
813};
814
815class LengthEncodedNumber : public WritePacket
816{
817 uint64_t value;
818public:
819 explicit LengthEncodedNumber(uint64_t value_): value(value_)
820 {
821 }
822
823protected:
824 size_t getPayloadSize() const override
825 {
826 return getLengthEncodedNumberSize(value);
827 }
828
829 void writePayloadImpl(WriteBuffer & buffer) const override
830 {
831 writeLengthEncodedNumber(value, buffer);
832 }
833};
834
835
836ColumnDefinition getColumnDefinition(const String & column_name, const TypeIndex index);
837
838
839namespace ProtocolText
840{
841
842class ResultsetRow : public WritePacket
843{
844 const Columns & columns;
845 int row_num;
846 size_t payload_size = 0;
847 std::vector<String> serialized;
848public:
849 ResultsetRow(const DataTypes & data_types, const Columns & columns_, int row_num_)
850 : columns(columns_)
851 , row_num(row_num_)
852 {
853 for (size_t i = 0; i < columns.size(); i++)
854 {
855 if (columns[i]->isNullAt(row_num))
856 {
857 payload_size += 1;
858 serialized.emplace_back("\xfb");
859 }
860 else
861 {
862 WriteBufferFromOwnString ostr;
863 data_types[i]->serializeAsText(*columns[i], row_num, ostr, FormatSettings());
864 payload_size += getLengthEncodedStringSize(ostr.str());
865 serialized.push_back(std::move(ostr.str()));
866 }
867 }
868 }
869protected:
870 size_t getPayloadSize() const override
871 {
872 return payload_size;
873 }
874
875 void writePayloadImpl(WriteBuffer & buffer) const override
876 {
877 for (size_t i = 0; i < columns.size(); i++)
878 {
879 if (columns[i]->isNullAt(row_num))
880 buffer.write(serialized[i].data(), 1);
881 else
882 writeLengthEncodedString(serialized[i], buffer);
883 }
884 }
885};
886
887}
888
889namespace Authentication
890{
891
892class IPlugin
893{
894public:
895 virtual String getName() = 0;
896
897 virtual String getAuthPluginData() = 0;
898
899 virtual void authenticate(const String & user_name, std::optional<String> auth_response, Context & context, std::shared_ptr<PacketSender> packet_sender, bool is_secure_connection,
900 const Poco::Net::SocketAddress & address) = 0;
901
902 virtual ~IPlugin() = default;
903};
904
905/// https://dev.mysql.com/doc/internals/en/secure-password-authentication.html
906class Native41 : public IPlugin
907{
908public:
909 Native41()
910 {
911 scramble.resize(SCRAMBLE_LENGTH + 1, 0);
912 Poco::RandomInputStream generator;
913
914 for (size_t i = 0; i < SCRAMBLE_LENGTH; i++)
915 generator >> scramble[i];
916 }
917
918 String getName() override
919 {
920 return "mysql_native_password";
921 }
922
923 String getAuthPluginData() override
924 {
925 return scramble;
926 }
927
928 void authenticate(
929 const String & user_name,
930 std::optional<String> auth_response,
931 Context & context,
932 std::shared_ptr<PacketSender> packet_sender,
933 bool /* is_secure_connection */,
934 const Poco::Net::SocketAddress & address) override
935 {
936 if (!auth_response)
937 {
938 packet_sender->sendPacket(AuthSwitchRequest(getName(), scramble), true);
939 AuthSwitchResponse response;
940 packet_sender->receivePacket(response);
941 auth_response = response.value;
942 }
943
944 if (auth_response->empty())
945 {
946 context.setUser(user_name, "", address, "");
947 return;
948 }
949
950 if (auth_response->size() != Poco::SHA1Engine::DIGEST_SIZE)
951 throw Exception("Wrong size of auth response. Expected: " + std::to_string(Poco::SHA1Engine::DIGEST_SIZE) + " bytes, received: " + std::to_string(auth_response->size()) + " bytes.",
952 ErrorCodes::UNKNOWN_EXCEPTION);
953
954 auto user = context.getUser(user_name);
955
956 Poco::SHA1Engine::Digest double_sha1_value = user->authentication.getPasswordDoubleSHA1();
957 assert(double_sha1_value.size() == Poco::SHA1Engine::DIGEST_SIZE);
958
959 Poco::SHA1Engine engine;
960 engine.update(scramble.data(), SCRAMBLE_LENGTH);
961 engine.update(double_sha1_value.data(), double_sha1_value.size());
962
963 String password_sha1(Poco::SHA1Engine::DIGEST_SIZE, 0x0);
964 const Poco::SHA1Engine::Digest & digest = engine.digest();
965 for (size_t i = 0; i < password_sha1.size(); i++)
966 {
967 password_sha1[i] = digest[i] ^ static_cast<unsigned char>((*auth_response)[i]);
968 }
969 context.setUser(user_name, password_sha1, address, "");
970 }
971private:
972 String scramble;
973};
974
#if USE_SSL
/// Caching SHA2 plugin is not used because it would be possible to authenticate knowing hash from users.xml.
/// https://dev.mysql.com/doc/internals/en/sha256.html
class Sha256Password : public IPlugin
{
public:
    Sha256Password(RSA & public_key_, RSA & private_key_, Logger * log_)
        : public_key(public_key_)
        , private_key(private_key_)
        , log(log_)
    {
        /** Native authentication sent 20 bytes + '\0' character = 21 bytes.
         *  This plugin must do the same to stay consistent with historical behavior if it is set to operate as a default plugin. [1]
         *  https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L3994
         */
        scramble.resize(SCRAMBLE_LENGTH + 1, 0);
        Poco::RandomInputStream generator;

        /// Mirror MySQL's own salt generation (generate_user_salt): 7-bit characters without '\0' and '$',
        /// so the scramble is a legal UTF-8 string with no embedded NUL. The extra last byte stays '\0'.
        for (size_t i = 0; i < SCRAMBLE_LENGTH; i++)
        {
            generator >> scramble[i];
            scramble[i] &= 0x7f;
            if (scramble[i] == '\0' || scramble[i] == '$')
                scramble[i] = scramble[i] + 1;
        }
    }

    String getName() override
    {
        return "sha256_password";
    }

    String getAuthPluginData() override
    {
        return scramble;
    }

    void authenticate(
        const String & user_name,
        std::optional<String> auth_response,
        Context & context,
        std::shared_ptr<PacketSender> packet_sender,
        bool is_secure_connection,
        const Poco::Net::SocketAddress & address) override
    {
        if (!auth_response)
        {
            packet_sender->sendPacket(AuthSwitchRequest(getName(), scramble), true);

            if (packet_sender->in->eof())
                throw Exception("Client doesn't support authentication method " + getName() + " used by ClickHouse. Specifying user password using 'password_double_sha1_hex' may fix the problem.",
                    ErrorCodes::MYSQL_CLIENT_INSUFFICIENT_CAPABILITIES);

            AuthSwitchResponse response;
            packet_sender->receivePacket(response);
            auth_response = response.value;
            LOG_TRACE(log, "Authentication method mismatch.");
        }
        else
        {
            LOG_TRACE(log, "Authentication method match.");
        }

        /// "\1" is the client's request for the server public key (sent back as AuthMoreData in PEM form).
        if (auth_response == "\1")
        {
            LOG_TRACE(log, "Client requests public key.");
            BIO * mem = BIO_new(BIO_s_mem());
            if (!mem)
                throw Exception("Failed to allocate memory BIO. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
            SCOPE_EXIT(BIO_free(mem));
            if (PEM_write_bio_RSA_PUBKEY(mem, &public_key) != 1)
            {
                throw Exception("Failed to write public key to memory. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
            }
            char * pem_buf = nullptr;
#    pragma GCC diagnostic push
#    pragma GCC diagnostic ignored "-Wold-style-cast"
            long pem_size = BIO_get_mem_data(mem, &pem_buf);
#    pragma GCC diagnostic pop
            String pem(pem_buf, pem_size);

            LOG_TRACE(log, "Key: " << pem);

            AuthMoreData data(pem);
            packet_sender->sendPacket(data, true);

            AuthSwitchResponse response;
            packet_sender->receivePacket(response);
            auth_response = response.value;
        }
        else
        {
            LOG_TRACE(log, "Client didn't request public key.");
        }

        String password;

        /** Decrypt password, if it's not empty.
         *  The original intention was that the password is a string[NUL] but this never got enforced properly so now we have to accept that
         *  an empty packet is a blank password, thus the check for auth_response.empty() has to be made too.
         *  https://github.com/mysql/mysql-server/blob/8.0/sql/auth/sql_authentication.cc#L4017
         */
        if (!is_secure_connection && !auth_response->empty() && auth_response != String("\0", 1))
        {
            LOG_TRACE(log, "Received nonempty password");
            auto ciphertext = reinterpret_cast<unsigned char *>(auth_response->data());

            /// Heap buffer instead of a variable-length array (VLAs are not standard C++).
            std::vector<unsigned char> plaintext(RSA_size(&private_key));
            int plaintext_size = RSA_private_decrypt(auth_response->size(), ciphertext, plaintext.data(), &private_key, RSA_PKCS1_OAEP_PADDING);
            if (plaintext_size == -1)
            {
                throw Exception("Failed to decrypt auth data. Error: " + getOpenSSLErrors(), ErrorCodes::OPENSSL_ERROR);
            }

            /// The password was XOR-ed with the scramble repeated with period SCRAMBLE_LENGTH (as MySQL's client and server
            /// do). scramble.size() is SCRAMBLE_LENGTH + 1 because of the trailing '\0', so using it as the period would
            /// decode passwords of SCRAMBLE_LENGTH or more characters incorrectly.
            password.resize(plaintext_size);
            for (int i = 0; i < plaintext_size; i++)
            {
                password[i] = plaintext[i] ^ static_cast<unsigned char>(scramble[i % SCRAMBLE_LENGTH]);
            }
        }
        else if (is_secure_connection)
        {
            /// Over a secure connection the password arrives in clear text.
            password = *auth_response;
        }
        else
        {
            LOG_TRACE(log, "Received empty password");
        }

        /// Strip the NUL terminator sent by the client, if present.
        if (!password.empty() && password.back() == 0)
        {
            password.pop_back();
        }

        context.setUser(user_name, password, address, "");
    }

private:
    RSA & public_key;
    RSA & private_key;
    Logger * log;
    String scramble;
};
#endif
1112
1113}
1114
1115}
1116}
1117