1 | /* |
2 | * This file is part of the MicroPython project, http://micropython.org/ |
3 | * |
4 | * The MIT License (MIT) |
5 | * |
6 | * Copyright (c) 2016 Linaro Ltd. |
7 | * Copyright (c) 2019 Paul Sokolovsky |
8 | * |
9 | * Permission is hereby granted, free of charge, to any person obtaining a copy |
10 | * of this software and associated documentation files (the "Software"), to deal |
11 | * in the Software without restriction, including without limitation the rights |
12 | * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
13 | * copies of the Software, and to permit persons to whom the Software is |
14 | * furnished to do so, subject to the following conditions: |
15 | * |
16 | * The above copyright notice and this permission notice shall be included in |
17 | * all copies or substantial portions of the Software. |
18 | * |
19 | * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
20 | * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
21 | * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
22 | * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
23 | * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
24 | * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN |
25 | * THE SOFTWARE. |
26 | */ |
27 | |
28 | #include "py/mpconfig.h" |
29 | #if MICROPY_PY_USSL && MICROPY_SSL_MBEDTLS |
30 | |
31 | #include <stdio.h> |
32 | #include <string.h> |
33 | #include <errno.h> // needed because mp_is_nonblocking_error uses system error codes |
34 | |
35 | #include "py/runtime.h" |
36 | #include "py/stream.h" |
37 | #include "py/objstr.h" |
38 | |
39 | // mbedtls_time_t |
40 | #include "mbedtls/platform.h" |
41 | #include "mbedtls/ssl.h" |
42 | #include "mbedtls/x509_crt.h" |
43 | #include "mbedtls/pk.h" |
44 | #include "mbedtls/entropy.h" |
45 | #include "mbedtls/ctr_drbg.h" |
46 | #include "mbedtls/debug.h" |
47 | #include "mbedtls/error.h" |
48 | |
49 | typedef struct _mp_obj_ssl_socket_t { |
50 | mp_obj_base_t base; |
51 | mp_obj_t sock; |
52 | mbedtls_entropy_context entropy; |
53 | mbedtls_ctr_drbg_context ctr_drbg; |
54 | mbedtls_ssl_context ssl; |
55 | mbedtls_ssl_config conf; |
56 | mbedtls_x509_crt cacert; |
57 | mbedtls_x509_crt cert; |
58 | mbedtls_pk_context pkey; |
59 | } mp_obj_ssl_socket_t; |
60 | |
61 | struct ssl_args { |
62 | mp_arg_val_t key; |
63 | mp_arg_val_t cert; |
64 | mp_arg_val_t server_side; |
65 | mp_arg_val_t server_hostname; |
66 | mp_arg_val_t do_handshake; |
67 | }; |
68 | |
69 | STATIC const mp_obj_type_t ussl_socket_type; |
70 | |
#ifdef MBEDTLS_DEBUG_C
// Debug callback registered via mbedtls_ssl_conf_dbg(): prints each mbedtls
// debug message prefixed with its source file and line. The user context and
// the message level are ignored.
STATIC void mbedtls_debug(void *ctx, int level, const char *file, int line, const char *str) {
    (void)level;
    (void)ctx;
    printf("DBG:%s:%04d: %s\n", file, line, str);
}
#endif
78 | |
79 | STATIC NORETURN void mbedtls_raise_error(int err) { |
80 | // _mbedtls_ssl_send and _mbedtls_ssl_recv (below) turn positive error codes from the |
81 | // underlying socket into negative codes to pass them through mbedtls. Here we turn them |
82 | // positive again so they get interpreted as the OSError they really are. The |
83 | // cut-off of -256 is a bit hacky, sigh. |
84 | if (err < 0 && err > -256) { |
85 | mp_raise_OSError(-err); |
86 | } |
87 | |
88 | #if defined(MBEDTLS_ERROR_C) |
89 | // Including mbedtls_strerror takes about 1.5KB due to the error strings. |
90 | // MBEDTLS_ERROR_C is the define used by mbedtls to conditionally include mbedtls_strerror. |
91 | // It is set/unset in the MBEDTLS_CONFIG_FILE which is defined in the Makefile. |
92 | |
93 | // Try to allocate memory for the message |
94 | #define ERR_STR_MAX 80 // mbedtls_strerror truncates if it doesn't fit |
95 | mp_obj_str_t *o_str = m_new_obj_maybe(mp_obj_str_t); |
96 | byte *o_str_buf = m_new_maybe(byte, ERR_STR_MAX); |
97 | if (o_str == NULL || o_str_buf == NULL) { |
98 | mp_raise_OSError(err); |
99 | } |
100 | |
101 | // print the error message into the allocated buffer |
102 | mbedtls_strerror(err, (char *)o_str_buf, ERR_STR_MAX); |
103 | size_t len = strlen((char *)o_str_buf); |
104 | |
105 | // Put the exception object together |
106 | o_str->base.type = &mp_type_str; |
107 | o_str->data = o_str_buf; |
108 | o_str->len = len; |
109 | o_str->hash = qstr_compute_hash(o_str->data, o_str->len); |
110 | // raise |
111 | mp_obj_t args[2] = { MP_OBJ_NEW_SMALL_INT(err), MP_OBJ_FROM_PTR(o_str)}; |
112 | nlr_raise(mp_obj_exception_make_new(&mp_type_OSError, 2, 0, args)); |
113 | #else |
114 | // mbedtls is compiled without error strings so we simply return the err number |
115 | mp_raise_OSError(err); // err is typically a large negative number |
116 | #endif |
117 | } |
118 | |
119 | STATIC int _mbedtls_ssl_send(void *ctx, const byte *buf, size_t len) { |
120 | mp_obj_t sock = *(mp_obj_t *)ctx; |
121 | |
122 | const mp_stream_p_t *sock_stream = mp_get_stream(sock); |
123 | int err; |
124 | |
125 | mp_uint_t out_sz = sock_stream->write(sock, buf, len, &err); |
126 | if (out_sz == MP_STREAM_ERROR) { |
127 | if (mp_is_nonblocking_error(err)) { |
128 | return MBEDTLS_ERR_SSL_WANT_WRITE; |
129 | } |
130 | return -err; // convert an MP_ERRNO to something mbedtls passes through as error |
131 | } else { |
132 | return out_sz; |
133 | } |
134 | } |
135 | |
136 | // _mbedtls_ssl_recv is called by mbedtls to receive bytes from the underlying socket |
137 | STATIC int _mbedtls_ssl_recv(void *ctx, byte *buf, size_t len) { |
138 | mp_obj_t sock = *(mp_obj_t *)ctx; |
139 | |
140 | const mp_stream_p_t *sock_stream = mp_get_stream(sock); |
141 | int err; |
142 | |
143 | mp_uint_t out_sz = sock_stream->read(sock, buf, len, &err); |
144 | if (out_sz == MP_STREAM_ERROR) { |
145 | if (mp_is_nonblocking_error(err)) { |
146 | return MBEDTLS_ERR_SSL_WANT_READ; |
147 | } |
148 | return -err; |
149 | } else { |
150 | return out_sz; |
151 | } |
152 | } |
153 | |
154 | |
155 | STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) { |
156 | // Verify the socket object has the full stream protocol |
157 | mp_get_stream_raise(sock, MP_STREAM_OP_READ | MP_STREAM_OP_WRITE | MP_STREAM_OP_IOCTL); |
158 | |
159 | #if MICROPY_PY_USSL_FINALISER |
160 | mp_obj_ssl_socket_t *o = m_new_obj_with_finaliser(mp_obj_ssl_socket_t); |
161 | #else |
162 | mp_obj_ssl_socket_t *o = m_new_obj(mp_obj_ssl_socket_t); |
163 | #endif |
164 | o->base.type = &ussl_socket_type; |
165 | o->sock = sock; |
166 | |
167 | int ret; |
168 | mbedtls_ssl_init(&o->ssl); |
169 | mbedtls_ssl_config_init(&o->conf); |
170 | mbedtls_x509_crt_init(&o->cacert); |
171 | mbedtls_x509_crt_init(&o->cert); |
172 | mbedtls_pk_init(&o->pkey); |
173 | mbedtls_ctr_drbg_init(&o->ctr_drbg); |
174 | #ifdef MBEDTLS_DEBUG_C |
175 | // Debug level (0-4) 1=warning, 2=info, 3=debug, 4=verbose |
176 | mbedtls_debug_set_threshold(0); |
177 | #endif |
178 | |
179 | mbedtls_entropy_init(&o->entropy); |
180 | const byte seed[] = "upy" ; |
181 | ret = mbedtls_ctr_drbg_seed(&o->ctr_drbg, mbedtls_entropy_func, &o->entropy, seed, sizeof(seed)); |
182 | if (ret != 0) { |
183 | goto cleanup; |
184 | } |
185 | |
186 | ret = mbedtls_ssl_config_defaults(&o->conf, |
187 | args->server_side.u_bool ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, |
188 | MBEDTLS_SSL_TRANSPORT_STREAM, |
189 | MBEDTLS_SSL_PRESET_DEFAULT); |
190 | if (ret != 0) { |
191 | goto cleanup; |
192 | } |
193 | |
194 | mbedtls_ssl_conf_authmode(&o->conf, MBEDTLS_SSL_VERIFY_NONE); |
195 | mbedtls_ssl_conf_rng(&o->conf, mbedtls_ctr_drbg_random, &o->ctr_drbg); |
196 | #ifdef MBEDTLS_DEBUG_C |
197 | mbedtls_ssl_conf_dbg(&o->conf, mbedtls_debug, NULL); |
198 | #endif |
199 | |
200 | ret = mbedtls_ssl_setup(&o->ssl, &o->conf); |
201 | if (ret != 0) { |
202 | goto cleanup; |
203 | } |
204 | |
205 | if (args->server_hostname.u_obj != mp_const_none) { |
206 | const char *sni = mp_obj_str_get_str(args->server_hostname.u_obj); |
207 | ret = mbedtls_ssl_set_hostname(&o->ssl, sni); |
208 | if (ret != 0) { |
209 | goto cleanup; |
210 | } |
211 | } |
212 | |
213 | mbedtls_ssl_set_bio(&o->ssl, &o->sock, _mbedtls_ssl_send, _mbedtls_ssl_recv, NULL); |
214 | |
215 | if (args->key.u_obj != mp_const_none) { |
216 | size_t key_len; |
217 | const byte *key = (const byte *)mp_obj_str_get_data(args->key.u_obj, &key_len); |
218 | // len should include terminating null |
219 | ret = mbedtls_pk_parse_key(&o->pkey, key, key_len + 1, NULL, 0); |
220 | if (ret != 0) { |
221 | ret = MBEDTLS_ERR_PK_BAD_INPUT_DATA; // use general error for all key errors |
222 | goto cleanup; |
223 | } |
224 | |
225 | size_t cert_len; |
226 | const byte *cert = (const byte *)mp_obj_str_get_data(args->cert.u_obj, &cert_len); |
227 | // len should include terminating null |
228 | ret = mbedtls_x509_crt_parse(&o->cert, cert, cert_len + 1); |
229 | if (ret != 0) { |
230 | ret = MBEDTLS_ERR_X509_BAD_INPUT_DATA; // use general error for all cert errors |
231 | goto cleanup; |
232 | } |
233 | |
234 | ret = mbedtls_ssl_conf_own_cert(&o->conf, &o->cert, &o->pkey); |
235 | if (ret != 0) { |
236 | goto cleanup; |
237 | } |
238 | } |
239 | |
240 | if (args->do_handshake.u_bool) { |
241 | while ((ret = mbedtls_ssl_handshake(&o->ssl)) != 0) { |
242 | if (ret != MBEDTLS_ERR_SSL_WANT_READ && ret != MBEDTLS_ERR_SSL_WANT_WRITE) { |
243 | goto cleanup; |
244 | } |
245 | } |
246 | } |
247 | |
248 | return o; |
249 | |
250 | cleanup: |
251 | mbedtls_pk_free(&o->pkey); |
252 | mbedtls_x509_crt_free(&o->cert); |
253 | mbedtls_x509_crt_free(&o->cacert); |
254 | mbedtls_ssl_free(&o->ssl); |
255 | mbedtls_ssl_config_free(&o->conf); |
256 | mbedtls_ctr_drbg_free(&o->ctr_drbg); |
257 | mbedtls_entropy_free(&o->entropy); |
258 | |
259 | if (ret == MBEDTLS_ERR_SSL_ALLOC_FAILED) { |
260 | mp_raise_OSError(MP_ENOMEM); |
261 | } else if (ret == MBEDTLS_ERR_PK_BAD_INPUT_DATA) { |
262 | mp_raise_ValueError(MP_ERROR_TEXT("invalid key" )); |
263 | } else if (ret == MBEDTLS_ERR_X509_BAD_INPUT_DATA) { |
264 | mp_raise_ValueError(MP_ERROR_TEXT("invalid cert" )); |
265 | } else { |
266 | mbedtls_raise_error(ret); |
267 | } |
268 | } |
269 | |
270 | STATIC mp_obj_t mod_ssl_getpeercert(mp_obj_t o_in, mp_obj_t binary_form) { |
271 | mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); |
272 | if (!mp_obj_is_true(binary_form)) { |
273 | mp_raise_NotImplementedError(NULL); |
274 | } |
275 | const mbedtls_x509_crt *peer_cert = mbedtls_ssl_get_peer_cert(&o->ssl); |
276 | if (peer_cert == NULL) { |
277 | return mp_const_none; |
278 | } |
279 | return mp_obj_new_bytes(peer_cert->raw.p, peer_cert->raw.len); |
280 | } |
281 | STATIC MP_DEFINE_CONST_FUN_OBJ_2(mod_ssl_getpeercert_obj, mod_ssl_getpeercert); |
282 | |
283 | STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kind_t kind) { |
284 | (void)kind; |
285 | mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(self_in); |
286 | mp_printf(print, "<_SSLSocket %p>" , self); |
287 | } |
288 | |
289 | STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) { |
290 | mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); |
291 | |
292 | int ret = mbedtls_ssl_read(&o->ssl, buf, size); |
293 | if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { |
294 | // end of stream |
295 | return 0; |
296 | } |
297 | if (ret >= 0) { |
298 | return ret; |
299 | } |
300 | if (ret == MBEDTLS_ERR_SSL_WANT_READ) { |
301 | ret = MP_EWOULDBLOCK; |
302 | } else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { |
303 | // If handshake is not finished, read attempt may end up in protocol |
304 | // wanting to write next handshake message. The same may happen with |
305 | // renegotation. |
306 | ret = MP_EWOULDBLOCK; |
307 | } |
308 | *errcode = ret; |
309 | return MP_STREAM_ERROR; |
310 | } |
311 | |
312 | STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) { |
313 | mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in); |
314 | |
315 | int ret = mbedtls_ssl_write(&o->ssl, buf, size); |
316 | if (ret >= 0) { |
317 | return ret; |
318 | } |
319 | if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { |
320 | ret = MP_EWOULDBLOCK; |
321 | } else if (ret == MBEDTLS_ERR_SSL_WANT_READ) { |
322 | // If handshake is not finished, write attempt may end up in protocol |
323 | // wanting to read next handshake message. The same may happen with |
324 | // renegotation. |
325 | ret = MP_EWOULDBLOCK; |
326 | } |
327 | *errcode = ret; |
328 | return MP_STREAM_ERROR; |
329 | } |
330 | |
331 | STATIC mp_obj_t socket_setblocking(mp_obj_t self_in, mp_obj_t flag_in) { |
332 | mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(self_in); |
333 | mp_obj_t sock = o->sock; |
334 | mp_obj_t dest[3]; |
335 | mp_load_method(sock, MP_QSTR_setblocking, dest); |
336 | dest[2] = flag_in; |
337 | return mp_call_method_n_kw(1, 0, dest); |
338 | } |
339 | STATIC MP_DEFINE_CONST_FUN_OBJ_2(socket_setblocking_obj, socket_setblocking); |
340 | |
341 | STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, int *errcode) { |
342 | mp_obj_ssl_socket_t *self = MP_OBJ_TO_PTR(o_in); |
343 | if (request == MP_STREAM_CLOSE) { |
344 | mbedtls_pk_free(&self->pkey); |
345 | mbedtls_x509_crt_free(&self->cert); |
346 | mbedtls_x509_crt_free(&self->cacert); |
347 | mbedtls_ssl_free(&self->ssl); |
348 | mbedtls_ssl_config_free(&self->conf); |
349 | mbedtls_ctr_drbg_free(&self->ctr_drbg); |
350 | mbedtls_entropy_free(&self->entropy); |
351 | } |
352 | // Pass all requests down to the underlying socket |
353 | return mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode); |
354 | } |
355 | |
// Methods of the SSL socket object. read/readinto/readline/write/close are the
// generic stream method objects, which operate through ussl_socket_stream_p.
STATIC const mp_rom_map_elem_t ussl_socket_locals_dict_table[] = {
    { MP_ROM_QSTR(MP_QSTR_read), MP_ROM_PTR(&mp_stream_read_obj) },
    { MP_ROM_QSTR(MP_QSTR_readinto), MP_ROM_PTR(&mp_stream_readinto_obj) },
    { MP_ROM_QSTR(MP_QSTR_readline), MP_ROM_PTR(&mp_stream_unbuffered_readline_obj) },
    { MP_ROM_QSTR(MP_QSTR_write), MP_ROM_PTR(&mp_stream_write_obj) },
    { MP_ROM_QSTR(MP_QSTR_setblocking), MP_ROM_PTR(&socket_setblocking_obj) },
    { MP_ROM_QSTR(MP_QSTR_close), MP_ROM_PTR(&mp_stream_close_obj) },
    #if MICROPY_PY_USSL_FINALISER
    // With the finaliser enabled, objects are allocated with
    // m_new_obj_with_finaliser and __del__ performs the same close as close().
    { MP_ROM_QSTR(MP_QSTR___del__), MP_ROM_PTR(&mp_stream_close_obj) },
    #endif
    { MP_ROM_QSTR(MP_QSTR_getpeercert), MP_ROM_PTR(&mod_ssl_getpeercert_obj) },
};

STATIC MP_DEFINE_CONST_DICT(ussl_socket_locals_dict, ussl_socket_locals_dict_table);
370 | |
371 | STATIC const mp_stream_p_t ussl_socket_stream_p = { |
372 | .read = socket_read, |
373 | .write = socket_write, |
374 | .ioctl = socket_ioctl, |
375 | }; |
376 | |
377 | STATIC const mp_obj_type_t ussl_socket_type = { |
378 | { &mp_type_type }, |
379 | // Save on qstr's, reuse same as for module |
380 | .name = MP_QSTR_ussl, |
381 | .print = socket_print, |
382 | .getiter = NULL, |
383 | .iternext = NULL, |
384 | .protocol = &ussl_socket_stream_p, |
385 | .locals_dict = (void *)&ussl_socket_locals_dict, |
386 | }; |
387 | |
388 | STATIC mp_obj_t mod_ssl_wrap_socket(size_t n_args, const mp_obj_t *pos_args, mp_map_t *kw_args) { |
389 | // TODO: Implement more args |
390 | static const mp_arg_t allowed_args[] = { |
391 | { MP_QSTR_key, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, |
392 | { MP_QSTR_cert, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, |
393 | { MP_QSTR_server_side, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = false} }, |
394 | { MP_QSTR_server_hostname, MP_ARG_KW_ONLY | MP_ARG_OBJ, {.u_rom_obj = MP_ROM_NONE} }, |
395 | { MP_QSTR_do_handshake, MP_ARG_KW_ONLY | MP_ARG_BOOL, {.u_bool = true} }, |
396 | }; |
397 | |
398 | // TODO: Check that sock implements stream protocol |
399 | mp_obj_t sock = pos_args[0]; |
400 | |
401 | struct ssl_args args; |
402 | mp_arg_parse_all(n_args - 1, pos_args + 1, kw_args, |
403 | MP_ARRAY_SIZE(allowed_args), allowed_args, (mp_arg_val_t *)&args); |
404 | |
405 | return MP_OBJ_FROM_PTR(socket_new(sock, &args)); |
406 | } |
407 | STATIC MP_DEFINE_CONST_FUN_OBJ_KW(mod_ssl_wrap_socket_obj, 1, mod_ssl_wrap_socket); |
408 | |
// Globals of the `ussl` module: its name and the wrap_socket() entry point.
STATIC const mp_rom_map_elem_t mp_module_ssl_globals_table[] = {
    { MP_ROM_QSTR(MP_QSTR___name__), MP_ROM_QSTR(MP_QSTR_ussl) },
    { MP_ROM_QSTR(MP_QSTR_wrap_socket), MP_ROM_PTR(&mod_ssl_wrap_socket_obj) },
};

STATIC MP_DEFINE_CONST_DICT(mp_module_ssl_globals, mp_module_ssl_globals_table);

// Module object; non-static so it can be referenced from outside this file.
const mp_obj_module_t mp_module_ussl = {
    .base = { &mp_type_module },
    .globals = (mp_obj_dict_t *)&mp_module_ssl_globals,
};
420 | |
421 | #endif // MICROPY_PY_USSL |
422 | |