diff options
| author | Garrett D'Amore <garrett@damore.org> | 2024-07-20 16:47:25 -0700 |
|---|---|---|
| committer | Garrett D'Amore <garrett@damore.org> | 2024-07-21 14:23:08 -0700 |
| commit | 0aeed90d9a85eaf6f00e81c6f5f69a7ed9fec8c6 (patch) | |
| tree | 9f1acaa0bc8569a9e8e88e203fddd877f0dbab99 /src | |
| parent | c0b93b441199619d27a1caf201a8c410f4246cf4 (diff) | |
| download | nng-0aeed90d9a85eaf6f00e81c6f5f69a7ed9fec8c6.tar.gz nng-0aeed90d9a85eaf6f00e81c6f5f69a7ed9fec8c6.tar.bz2 nng-0aeed90d9a85eaf6f00e81c6f5f69a7ed9fec8c6.zip | |
fixes #1846 Add support for TLS PSK
This also adds an SP layer transport test for TLS, based on the TCP
test but with some additions; this test does not cover all the edge
cases for TLS, but it does at least show how to use it.
Diffstat (limited to 'src')
| -rw-r--r-- | src/sp/transport/tls/CMakeLists.txt | 5 | ||||
| -rw-r--r-- | src/sp/transport/tls/tls_tran_test.c | 430 | ||||
| -rw-r--r-- | src/supplemental/tls/mbedtls/tls.c | 110 | ||||
| -rw-r--r-- | src/supplemental/tls/tls_common.c | 35 | ||||
| -rw-r--r-- | src/supplemental/tls/tls_test.c | 311 |
5 files changed, 879 insertions, 12 deletions
diff --git a/src/sp/transport/tls/CMakeLists.txt b/src/sp/transport/tls/CMakeLists.txt index 82f24c79..fa3290cc 100644 --- a/src/sp/transport/tls/CMakeLists.txt +++ b/src/sp/transport/tls/CMakeLists.txt @@ -1,5 +1,5 @@ # -# Copyright 2020 Staysail Systems, Inc. <info@staysail.tech> +# Copyright 2024 Staysail Systems, Inc. <info@staysail.tech> # Copyright 2018 Capitar IT Group BV <info@capitar.com> # # This software is supplied under the terms of the MIT License, a @@ -13,4 +13,5 @@ nng_directory(tls) nng_sources_if(NNG_TRANSPORT_TLS tls.c) nng_headers_if(NNG_TRANSPORT_TLS nng/transport/tls/tls.h) -nng_defines_if(NNG_TRANSPORT_TLS NNG_TRANSPORT_TLS)
\ No newline at end of file +nng_defines_if(NNG_TRANSPORT_TLS NNG_TRANSPORT_TLS) +nng_test_if(NNG_TRANSPORT_TLS tls_tran_test) diff --git a/src/sp/transport/tls/tls_tran_test.c b/src/sp/transport/tls/tls_tran_test.c new file mode 100644 index 00000000..94794d6e --- /dev/null +++ b/src/sp/transport/tls/tls_tran_test.c @@ -0,0 +1,430 @@ +// +// Copyright 2024 Staysail Systems, Inc. <info@staysail.tech> +// Copyright 2018 Capitar IT Group BV <info@capitar.com> +// Copyright 2018 Devolutions <info@devolutions.net> +// Copyright 2018 Cody Piersall <cody.piersall@gmail.com> +// +// This software is supplied under the terms of the MIT License, a +// copy of which should be located in the distribution where this +// file was obtained (LICENSE.txt). A copy of the license may also be +// found online at https://opensource.org/licenses/MIT. +// + +#include "nng/nng.h" +#include "nng/supplemental/tls/tls.h" +#include <nuts.h> + +// TLS tests. + +static nng_tls_config * +tls_server_config(void) +{ + nng_tls_config *c; + NUTS_PASS(nng_tls_config_alloc(&c, NNG_TLS_MODE_SERVER)); + NUTS_PASS(nng_tls_config_own_cert( + c, nuts_server_crt, nuts_server_key, NULL)); + return (c); +} + +static nng_tls_config * +tls_config_psk(nng_tls_mode mode, const char *name, uint8_t *key, size_t len) +{ + nng_tls_config *c; + NUTS_PASS(nng_tls_config_alloc(&c, mode)); + NUTS_PASS(nng_tls_config_psk(c, name, key, len)); + return (c); +} + +static nng_tls_config * +tls_client_config(void) +{ + nng_tls_config *c; + NUTS_PASS(nng_tls_config_alloc(&c, NNG_TLS_MODE_CLIENT)); + NUTS_PASS(nng_tls_config_own_cert( + c, nuts_client_crt, nuts_client_key, NULL)); + NUTS_PASS(nng_tls_config_ca_chain(c, nuts_server_crt, NULL)); + return (c); +} + +static void +test_tls_wild_card_connect_fail(void) +{ + nng_socket s; + nng_tls_config *c; + char addr[NNG_MAXADDRLEN]; + + NUTS_OPEN(s); + c = tls_client_config(); + nng_socket_set_ptr(s, NNG_OPT_TLS_CONFIG, c); + (void) snprintf( + addr, sizeof(addr), "tls+tcp://*:%u", nuts_next_port()); + NUTS_FAIL(nng_dial(s, addr, NULL, 0), NNG_EADDRINVAL); + NUTS_CLOSE(s); + nng_tls_config_free(c); +} + +void +test_tls_wild_card_bind(void) +{ + nng_socket s1; + nng_socket s2; + char addr[NNG_MAXADDRLEN]; + uint16_t port; + nng_tls_config *cc; + nng_tls_config *sc; + + port = nuts_next_port(); + + sc = tls_server_config(); + cc = tls_client_config(); + + NUTS_OPEN(s1); + NUTS_OPEN(s2); + (void) snprintf(addr, sizeof(addr), "tls+tcp4://*:%u", port); + nng_socket_set_ptr(s1, NNG_OPT_TLS_CONFIG, sc); + nng_socket_set_ptr(s2, NNG_OPT_TLS_CONFIG, cc); + NUTS_PASS(nng_listen(s1, addr, NULL, 0)); + (void) snprintf(addr, sizeof(addr), "tls+tcp://127.0.0.1:%u", port); + NUTS_PASS(nng_dial(s2, addr, NULL, 0)); + NUTS_CLOSE(s2); + NUTS_CLOSE(s1); + nng_tls_config_free(cc); + nng_tls_config_free(sc); +} + +void +test_tls_port_zero_bind(void) +{ + nng_socket s1; + nng_socket s2; + nng_tls_config *c1, *c2; + nng_sockaddr sa; + nng_listener l; + char *addr; + + c1 = tls_server_config(); + c2 = tls_client_config(); + NUTS_OPEN(s1); + NUTS_OPEN(s2); + nng_socket_set_ptr(s1, NNG_OPT_TLS_CONFIG, c1); + nng_socket_set_ptr(s2, NNG_OPT_TLS_CONFIG, c2); + NUTS_PASS(nng_listen(s1, "tls+tcp://127.0.0.1:0", &l, 0)); + NUTS_PASS(nng_listener_get_string(l, NNG_OPT_URL, &addr)); + NUTS_TRUE(memcmp(addr, "tls+tcp://", 6) == 0); + NUTS_PASS(nng_listener_get_addr(l, NNG_OPT_LOCADDR, &sa)); + NUTS_TRUE(sa.s_in.sa_family == NNG_AF_INET); + NUTS_TRUE(sa.s_in.sa_port != 0); + NUTS_TRUE(sa.s_in.sa_addr = nuts_be32(0x7f000001)); + NUTS_PASS(nng_dial(s2, addr, NULL, 0)); + nng_strfree(addr); + NUTS_CLOSE(s2); + NUTS_CLOSE(s1); + nng_tls_config_free(c1); + nng_tls_config_free(c2); +} + +void +test_tls_local_address_connect(void) +{ + + nng_socket s1; + nng_socket s2; + nng_tls_config *c1, *c2; + char addr[NNG_MAXADDRLEN]; + uint16_t port; + + c1 = tls_server_config(); + c2 = tls_client_config(); + NUTS_OPEN(s1); + NUTS_OPEN(s2); + nng_socket_set_ptr(s1, NNG_OPT_TLS_CONFIG, c1); + nng_socket_set_ptr(s2, NNG_OPT_TLS_CONFIG, c2); + port = nuts_next_port(); + (void) snprintf(addr, sizeof(addr), "tls+tcp://127.0.0.1:%u", port); + NUTS_PASS(nng_listen(s1, addr, NULL, 0)); + (void) snprintf( + addr, sizeof(addr), "tls+tcp://127.0.0.1;127.0.0.1:%u", port); + NUTS_PASS(nng_dial(s2, addr, NULL, 0)); + NUTS_CLOSE(s2); + NUTS_CLOSE(s1); + nng_tls_config_free(c1); + nng_tls_config_free(c2); +} + +void +test_tls_bad_local_interface(void) +{ + nng_socket s1; + nng_tls_config *c1; + int rv; + + c1 = tls_client_config(); + NUTS_OPEN(s1); + nng_socket_set_ptr(s1, NNG_OPT_TLS_CONFIG, c1); + nng_tls_config_free(c1); // ref count held by socket + rv = nng_dial(s1, "tcp://bogus1;127.0.0.1:80", NULL, 0), + NUTS_TRUE(rv != 0); + NUTS_TRUE(rv != NNG_ECONNREFUSED); + NUTS_CLOSE(s1); +} + +void +test_tls_non_local_address(void) +{ + nng_socket s1; + nng_tls_config *c1; + + c1 = tls_client_config(); + NUTS_OPEN(s1); + nng_socket_set_ptr(s1, NNG_OPT_TLS_CONFIG, c1); + NUTS_FAIL(nng_dial(s1, "tls+tcp://8.8.8.8;127.0.0.1:80", NULL, 0), + NNG_EADDRINVAL); + NUTS_CLOSE(s1); + nng_tls_config_free(c1); +} + +void +test_tls_malformed_address(void) +{ + nng_socket s1; + nng_tls_config *c1; + + NUTS_OPEN(s1); + c1 = tls_client_config(); + nng_socket_set_ptr(s1, NNG_OPT_TLS_CONFIG, c1); + nng_tls_config_free(c1); + NUTS_FAIL( + nng_dial(s1, "tls+tcp://127.0.0.1", NULL, 0), NNG_EADDRINVAL); + NUTS_FAIL( + nng_dial(s1, "tls+tcp://127.0.0.1.32", NULL, 0), NNG_EADDRINVAL); + NUTS_FAIL( + nng_dial(s1, "tls+tcp://127.0.x.1.32", NULL, 0), NNG_EADDRINVAL); + NUTS_FAIL( + nng_listen(s1, "tls+tcp://127.0.0.1.32", NULL, 0), NNG_EADDRINVAL); + NUTS_FAIL( + nng_listen(s1, "tls+tcp://127.0.x.1.32", NULL, 0), NNG_EADDRINVAL); + NUTS_CLOSE(s1); +} + +void +test_tls_no_delay_option(void) +{ + nng_socket s; + nng_dialer d; + nng_listener l; + bool v; + int x; + char *addr; + nng_tls_config *dc, *lc; + + NUTS_ADDR(addr, "tls+tcp"); + dc = tls_client_config(); + lc = tls_server_config(); + + NUTS_OPEN(s); +#ifndef NNG_ELIDE_DEPRECATED + NUTS_PASS(nng_socket_get_bool(s, NNG_OPT_TCP_NODELAY, &v)); + NUTS_TRUE(v); +#endif + NUTS_PASS(nng_dialer_create(&d, s, addr)); + NUTS_PASS(nng_dialer_set_ptr(d, NNG_OPT_TLS_CONFIG, dc)); + NUTS_PASS(nng_dialer_get_bool(d, NNG_OPT_TCP_NODELAY, &v)); + NUTS_TRUE(v); + NUTS_PASS(nng_dialer_set_bool(d, NNG_OPT_TCP_NODELAY, false)); + NUTS_PASS(nng_dialer_get_bool(d, NNG_OPT_TCP_NODELAY, &v)); + NUTS_TRUE(v == false); + NUTS_FAIL( + nng_dialer_get_int(d, NNG_OPT_TCP_NODELAY, &x), NNG_EBADTYPE); + x = 0; + NUTS_FAIL(nng_dialer_set_int(d, NNG_OPT_TCP_NODELAY, x), NNG_EBADTYPE); + // This assumes sizeof (bool) != sizeof (int) + if (sizeof(bool) != sizeof(int)) { + NUTS_FAIL( + nng_dialer_set(d, NNG_OPT_TCP_NODELAY, &x, sizeof(x)), + NNG_EINVAL); + } + + NUTS_PASS(nng_listener_create(&l, s, addr)); + NUTS_PASS(nng_listener_set_ptr(l, NNG_OPT_TLS_CONFIG, lc)); + NUTS_PASS(nng_listener_get_bool(l, NNG_OPT_TCP_NODELAY, &v)); + NUTS_TRUE(v == true); + x = 0; + NUTS_FAIL( + nng_listener_set_int(l, NNG_OPT_TCP_NODELAY, x), NNG_EBADTYPE); + // This assumes sizeof (bool) != sizeof (int) + NUTS_FAIL(nng_listener_set(l, NNG_OPT_TCP_NODELAY, &x, sizeof(x)), + NNG_EINVAL); + + NUTS_PASS(nng_dialer_close(d)); + NUTS_PASS(nng_listener_close(l)); + + // Make sure socket wide defaults apply. +#ifndef NNG_ELIDE_DEPRECATED + NUTS_PASS(nng_socket_set_bool(s, NNG_OPT_TCP_NODELAY, true)); + v = false; + NUTS_PASS(nng_socket_get_bool(s, NNG_OPT_TCP_NODELAY, &v)); + NUTS_TRUE(v); + NUTS_PASS(nng_socket_set_bool(s, NNG_OPT_TCP_NODELAY, false)); + NUTS_PASS(nng_dialer_create(&d, s, addr)); + NUTS_PASS(nng_dialer_get_bool(d, NNG_OPT_TCP_NODELAY, &v)); + NUTS_TRUE(v == false); +#endif + NUTS_CLOSE(s); + nng_tls_config_free(lc); + nng_tls_config_free(dc); +} + +void +test_tls_keep_alive_option(void) +{ + nng_socket s; + nng_dialer d; + nng_listener l; + nng_tls_config *dc, *lc; + bool v; + int x; + char *addr; + + dc = tls_client_config(); + lc = tls_server_config(); + NUTS_ADDR(addr, "tls+tcp"); + NUTS_OPEN(s); +#ifndef NNG_ELIDE_DEPRECATED + NUTS_PASS(nng_socket_get_bool(s, NNG_OPT_TCP_KEEPALIVE, &v)); + NUTS_TRUE(v == false); +#endif + NUTS_PASS(nng_dialer_create(&d, s, addr)); + NUTS_PASS(nng_dialer_set_ptr(d, NNG_OPT_TLS_CONFIG, dc)); + NUTS_PASS(nng_dialer_get_bool(d, NNG_OPT_TCP_KEEPALIVE, &v)); + NUTS_TRUE(v == false); + NUTS_PASS(nng_dialer_set_bool(d, NNG_OPT_TCP_KEEPALIVE, true)); + NUTS_PASS(nng_dialer_get_bool(d, NNG_OPT_TCP_KEEPALIVE, &v)); + NUTS_TRUE(v); + NUTS_FAIL( + nng_dialer_get_int(d, NNG_OPT_TCP_KEEPALIVE, &x), NNG_EBADTYPE); + x = 1; + NUTS_FAIL( + nng_dialer_set_int(d, NNG_OPT_TCP_KEEPALIVE, x), NNG_EBADTYPE); + + NUTS_PASS(nng_listener_create(&l, s, addr)); + NUTS_PASS(nng_listener_set_ptr(l, NNG_OPT_TLS_CONFIG, lc)); + NUTS_PASS(nng_listener_get_bool(l, NNG_OPT_TCP_KEEPALIVE, &v)); + NUTS_TRUE(v == false); + x = 1; + NUTS_FAIL( + nng_listener_set_int(l, NNG_OPT_TCP_KEEPALIVE, x), NNG_EBADTYPE); + + NUTS_PASS(nng_dialer_close(d)); + NUTS_PASS(nng_listener_close(l)); + + // Make sure socket wide defaults apply. +#ifndef NNG_ELIDE_DEPRECATED + NUTS_PASS(nng_socket_set_bool(s, NNG_OPT_TCP_KEEPALIVE, false)); + v = true; + NUTS_PASS(nng_socket_get_bool(s, NNG_OPT_TCP_KEEPALIVE, &v)); + NUTS_TRUE(v == false); + NUTS_PASS(nng_socket_set_bool(s, NNG_OPT_TCP_KEEPALIVE, true)); + NUTS_PASS(nng_dialer_create(&d, s, addr)); + NUTS_PASS(nng_dialer_get_bool(d, NNG_OPT_TCP_KEEPALIVE, &v)); + NUTS_TRUE(v); +#endif + NUTS_CLOSE(s); + nng_tls_config_free(lc); + nng_tls_config_free(dc); +} + +void +test_tls_recv_max(void) +{ + char msg[256]; + char buf[256]; + nng_socket s0; + nng_socket s1; + nng_tls_config *c0, *c1; + nng_listener l; + size_t sz; + char *addr; + + NUTS_ADDR(addr, "tls+tcp"); + + c0 = tls_server_config(); + c1 = tls_client_config(); + NUTS_OPEN(s0); + NUTS_PASS(nng_socket_set_ms(s0, NNG_OPT_RECVTIMEO, 100)); + NUTS_PASS(nng_socket_set_size(s0, NNG_OPT_RECVMAXSZ, 200)); + NUTS_PASS(nng_listener_create(&l, s0, addr)); + NUTS_PASS(nng_listener_set_ptr(l, NNG_OPT_TLS_CONFIG, c0)); + NUTS_PASS(nng_socket_get_size(s0, NNG_OPT_RECVMAXSZ, &sz)); + NUTS_TRUE(sz == 200); + NUTS_PASS(nng_listener_set_size(l, NNG_OPT_RECVMAXSZ, 100)); + NUTS_PASS(nng_listener_start(l, 0)); + + NUTS_OPEN(s1); + NUTS_PASS(nng_socket_set_ptr(s1, NNG_OPT_TLS_CONFIG, c1)); + NUTS_PASS(nng_dial(s1, addr, NULL, 0)); + NUTS_PASS(nng_send(s1, msg, 95, 0)); + NUTS_PASS(nng_socket_set_ms(s1, NNG_OPT_SENDTIMEO, 100)); + NUTS_PASS(nng_recv(s0, buf, &sz, 0)); + NUTS_TRUE(sz == 95); + NUTS_PASS(nng_send(s1, msg, 150, 0)); + NUTS_FAIL(nng_recv(s0, buf, &sz, 0), NNG_ETIMEDOUT); + NUTS_PASS(nng_close(s0)); + NUTS_CLOSE(s1); + nng_tls_config_free(c0); + nng_tls_config_free(c1); +} + +void +test_tls_psk(void) +{ + char msg[256]; + char buf[256]; + nng_socket s0; + nng_socket s1; + nng_tls_config *c0, *c1; + nng_listener l; + size_t sz; + char *addr; + uint8_t key[32]; + + for (unsigned i = 0; i < sizeof(key); i++) { + key[i] = rand() % 0xff; + } + + NUTS_ADDR(addr, "tls+tcp"); + + c0 = tls_config_psk(NNG_TLS_MODE_SERVER, "identity", key, sizeof key); + c1 = tls_config_psk(NNG_TLS_MODE_CLIENT, "identity", key, sizeof key); + NUTS_OPEN(s0); + NUTS_PASS(nng_socket_set_ms(s0, NNG_OPT_RECVTIMEO, 100)); + NUTS_PASS(nng_listener_create(&l, s0, addr)); + NUTS_PASS(nng_listener_set_ptr(l, NNG_OPT_TLS_CONFIG, c0)); + NUTS_PASS(nng_listener_start(l, 0)); + + NUTS_OPEN(s1); + NUTS_PASS(nng_socket_set_ptr(s1, NNG_OPT_TLS_CONFIG, c1)); + NUTS_PASS(nng_dial(s1, addr, NULL, 0)); + NUTS_PASS(nng_send(s1, msg, 95, 0)); + NUTS_PASS(nng_recv(s0, buf, &sz, 0)); + NUTS_TRUE(sz == 95); + NUTS_PASS(nng_close(s0)); + NUTS_CLOSE(s1); + nng_tls_config_free(c0); + nng_tls_config_free(c1); +} + +NUTS_TESTS = { + + { "tls wild card connect fail", test_tls_wild_card_connect_fail }, + { "tls wild card bind", test_tls_wild_card_bind }, + { "tls port zero bind", test_tls_port_zero_bind }, + { "tls local address connect", test_tls_local_address_connect }, + { "tls bad local interface", test_tls_bad_local_interface }, + { "tls non-local address", test_tls_non_local_address }, + { "tls malformed address", test_tls_malformed_address }, + { "tls no delay option", test_tls_no_delay_option }, + { "tls keep alive option", test_tls_keep_alive_option }, + { "tls recv max", test_tls_recv_max }, + { "tls preshared key", test_tls_psk }, + { NULL, NULL }, +}; diff --git a/src/supplemental/tls/mbedtls/tls.c b/src/supplemental/tls/mbedtls/tls.c index 8f09558a..3e424cc8 100644 --- a/src/supplemental/tls/mbedtls/tls.c +++ b/src/supplemental/tls/mbedtls/tls.c @@ -17,7 +17,9 @@ #include "mbedtls/version.h" // Must be first in order to pick up version #include "mbedtls/error.h" + #include "nng/nng.h" +#include "nng/supplemental/tls/tls.h" // mbedTLS renamed this header for 2.4.0. #if MBEDTLS_VERSION_MAJOR > 2 || MBEDTLS_VERSION_MINOR >= 4 @@ -38,6 +40,35 @@ typedef struct { nni_list_node node; } pair; +// psk holds an identity and preshared key +typedef struct { + // NB: Technically RFC 4279 requires this be UTF-8 string, although + // others suggest making it opaque bytes. We treat it as a C string, + // so we cannot support embedded zero bytes. + char *identity; + uint8_t *key; + size_t keylen; + nni_list_node node; +} psk; + +static void +psk_free(psk *p) +{ + if (p != NULL) { + NNI_ASSERT(!nni_list_node_active(&p->node)); + if (p->identity != NULL) { + nni_strfree(p->identity); + p->identity = NULL; + } + if (p->key != NULL && p->keylen != 0) { + nni_free(p->key, p->keylen); + p->key = NULL; + p->keylen = 0; + } + NNI_FREE_STRUCT(p); + } +} + #ifdef NNG_TLS_USE_CTR_DRBG // Use a global RNG if we're going to override the builtin. static mbedtls_ctr_drbg_context rng_ctx; @@ -56,7 +87,9 @@ struct nng_tls_engine_config { mbedtls_x509_crl crl; int min_ver; int max_ver; + nng_tls_mode mode; nni_list pairs; + nni_list psks; }; static void @@ -371,6 +404,7 @@ static void config_fini(nng_tls_engine_config *cfg) { pair *p; + psk *psk; mbedtls_ssl_config_free(&cfg->cfg_ctx); #ifdef NNG_TLS_USE_CTR_DRBG @@ -381,13 +415,17 @@ config_fini(nng_tls_engine_config *cfg) if (cfg->server_name) { nni_strfree(cfg->server_name); } - while ((p = nni_list_first(&cfg->pairs))) { + while ((p = nni_list_first(&cfg->pairs)) != NULL) { nni_list_remove(&cfg->pairs, p); mbedtls_x509_crt_free(&p->crt); mbedtls_pk_free(&p->key); NNI_FREE_STRUCT(p); } + while ((psk = nni_list_first(&cfg->psks)) != NULL) { + nni_list_remove(&cfg->psks, psk); + psk_free(psk); + } } static int @@ -405,7 +443,9 @@ config_init(nng_tls_engine_config *cfg, enum nng_tls_mode mode) auth_mode = MBEDTLS_SSL_VERIFY_REQUIRED; } + cfg->mode = mode; NNI_LIST_INIT(&cfg->pairs, pair, node); + NNI_LIST_INIT(&cfg->psks, psk, node); mbedtls_ssl_config_init(&cfg->cfg_ctx); mbedtls_x509_crt_init(&cfg->ca_certs); mbedtls_x509_crl_init(&cfg->crl); @@ -452,6 +492,73 @@ config_server_name(nng_tls_engine_config *cfg, const char *name) return (0); } +// callback used on the server side to select the right key +static int +config_psk_cb(void *arg, mbedtls_ssl_context *ssl, + const unsigned char *identity, size_t id_len) +{ + nng_tls_engine_config *cfg = arg; + psk *psk; + NNI_LIST_FOREACH (&cfg->psks, psk) { + if (id_len == strlen(psk->identity) && + (memcmp(identity, psk->identity, id_len) == 0)) { + nng_log_debug("NNG-TLS-PSK-IDENTITY", + "TLS client using PSK identity %s", psk->identity); + return (mbedtls_ssl_set_hs_psk( + ssl, psk->key, psk->keylen)); + } + } + nng_log_warn( + "NNG-TLS-PSK-NO-IDENTITY", "TLS client PSK identity not found"); + return (MBEDTLS_ERR_SSL_UNKNOWN_IDENTITY); +} + +static int +config_psk(nng_tls_engine_config *cfg, const char *identity, + const uint8_t *key, size_t key_len) +{ + int rv; + psk *srch; + psk *newpsk; + + if (((newpsk = NNI_ALLOC_STRUCT(newpsk)) == NULL) || + ((newpsk->identity = nni_strdup(identity)) == NULL) || + ((newpsk->key = nni_alloc(key_len)) == NULL)) { + psk_free(newpsk); + return (NNG_ENOMEM); + } + newpsk->keylen = key_len; + memcpy(newpsk->key, key, key_len); + + if (cfg->mode == NNG_TLS_MODE_SERVER) { + if (nni_list_empty(&cfg->psks)) { + mbedtls_ssl_conf_psk_cb( + &cfg->cfg_ctx, config_psk_cb, cfg); + } + } else { + if ((rv = mbedtls_ssl_conf_psk(&cfg->cfg_ctx, key, key_len, + (const unsigned char *) identity, + strlen(identity))) != 0) { + psk_free(newpsk); + tls_log_err("NNG-TLS-PSK-FAIL", + "Failed to configure PSK identity", rv); + return (tls_mk_err(rv)); + } + } + + // If the identity was previously configured, replace it. + // The rule here is that last one wins, so we always append. + NNI_LIST_FOREACH (&cfg->psks, srch) { + if (strcmp(srch->identity, identity) == 0) { + nni_list_remove(&cfg->psks, srch); + psk_free(srch); + break; + } + } + nni_list_append(&cfg->psks, newpsk); + return (0); +} + static int config_auth_mode(nng_tls_engine_config *cfg, nng_tls_auth_mode mode) { @@ -630,6 +737,7 @@ static nng_tls_engine_config_ops config_ops = { .ca_chain = config_ca_chain, .own_cert = config_own_cert, .server = config_server_name, + .psk = config_psk, .version = config_version, }; diff --git a/src/supplemental/tls/tls_common.c b/src/supplemental/tls/tls_common.c index 65805b03..62296de0 100644 --- a/src/supplemental/tls/tls_common.c +++ b/src/supplemental/tls/tls_common.c @@ -1,5 +1,5 @@ // -// Copyright 2021 Staysail Systems, Inc. <info@staysail.tech> +// Copyright 2024 Staysail Systems, Inc. <info@staysail.tech> // Copyright 2018 Capitar IT Group BV <info@capitar.com> // Copyright 2019 Devolutions <info@devolutions.net> // @@ -46,7 +46,7 @@ struct nng_tls_config { const nng_tls_engine *engine; // store this so we can verify nni_mtx lock; int ref; - int busy; + bool busy; size_t size; // ... engine config data follows @@ -843,6 +843,10 @@ tls_alloc(tls_conn **conn_p, nng_tls_config *cfg, nng_aio *user_aio) eng = cfg->engine; + nni_mtx_lock(&cfg->lock); + cfg->busy = true; + nni_mtx_unlock(&cfg->lock); + size = NNI_ALIGN_UP(sizeof(*conn)) + eng->conn_ops->size; if ((conn = nni_zalloc(size)) == NULL) { @@ -1325,7 +1329,7 @@ nng_tls_config_version( int rv; nni_mtx_lock(&cfg->lock); - if (cfg->busy != 0) { + if (cfg->busy) { rv = NNG_EBUSY; } else { rv = cfg->ops.version((void *) (cfg + 1), min_ver, max_ver); @@ -1340,7 +1344,7 @@ nng_tls_config_server_name(nng_tls_config *cfg, const char *name) int rv; nni_mtx_lock(&cfg->lock); - if (cfg->busy != 0) { + if (cfg->busy) { rv = NNG_EBUSY; } else { rv = cfg->ops.server((void *) (cfg + 1), name); @@ -1356,7 +1360,7 @@ nng_tls_config_ca_chain( int rv; nni_mtx_lock(&cfg->lock); - if (cfg->busy != 0) { + if (cfg->busy) { rv = NNG_EBUSY; } else { rv = cfg->ops.ca_chain((void *) (cfg + 1), certs, crl); @@ -1371,7 +1375,7 @@ nng_tls_config_own_cert( { int rv; nni_mtx_lock(&cfg->lock); - if (cfg->busy != 0) { + if (cfg->busy) { rv = NNG_EBUSY; } else { rv = cfg->ops.own_cert((void *) (cfg + 1), cert, key, pass); @@ -1381,12 +1385,27 @@ nng_tls_config_own_cert( } int +nng_tls_config_psk(nng_tls_config *cfg, const char *identity, + const uint8_t *key, size_t key_len) +{ + int rv; + nni_mtx_lock(&cfg->lock); + if (cfg->busy) { + rv = NNG_EBUSY; + } else { + rv = cfg->ops.psk((void *) (cfg + 1), identity, key, key_len); + } + nni_mtx_unlock(&cfg->lock); + return (rv); +} + +int nng_tls_config_auth_mode(nng_tls_config *cfg, nng_tls_auth_mode mode) { int rv; nni_mtx_lock(&cfg->lock); - if (cfg->busy != 0) { + if (cfg->busy) { rv = NNG_EBUSY; } else { rv = cfg->ops.auth((void *) (cfg + 1), mode); @@ -1423,7 +1442,7 @@ nng_tls_config_alloc(nng_tls_config **cfg_p, nng_tls_mode mode) cfg->size = size; cfg->engine = eng; cfg->ref = 1; - cfg->busy = 0; + cfg->busy = false; nni_mtx_init(&cfg->lock); if ((rv = cfg->ops.init((void *) (cfg + 1), mode)) != 0) { diff --git a/src/supplemental/tls/tls_test.c b/src/supplemental/tls/tls_test.c index 4a3f6e2d..d4ad2cc4 100644 --- a/src/supplemental/tls/tls_test.c +++ b/src/supplemental/tls/tls_test.c @@ -86,7 +86,7 @@ test_tls_large_message(void) void *t2; int port; - NUTS_ENABLE_LOG(NNG_LOG_INFO); + NUTS_ENABLE_LOG(NNG_LOG_DEBUG); // allocate messages NUTS_ASSERT((buf1 = nng_alloc(size)) != NULL); NUTS_ASSERT((buf2 = nng_alloc(size)) != NULL); @@ -170,10 +170,319 @@ test_tls_garbled_cert(void) nng_tls_config_free(c1); } +void +test_tls_psk(void) +{ + nng_stream_listener *l; + nng_stream_dialer *d; + nng_aio *aio1, *aio2; + nng_stream *s1; + nng_stream *s2; + nng_tls_config *c1; + nng_tls_config *c2; + char addr[32]; + uint8_t key[32]; + uint8_t *buf1; + uint8_t *buf2; + size_t size = 10000; + void *t1; + void *t2; + int port; + + NUTS_ENABLE_LOG(NNG_LOG_DEBUG); + // allocate messages + NUTS_ASSERT((buf1 = nng_alloc(size)) != NULL); + NUTS_ASSERT((buf2 = nng_alloc(size)) != NULL); + + for (size_t i = 0; i < sizeof(key); i++) { + key[i] = rand() & 0xff; + } + for (size_t i = 0; i < size; i++) { + buf1[i] = rand() & 0xff; + } + + NUTS_PASS(nng_aio_alloc(&aio1, NULL, NULL)); + NUTS_PASS(nng_aio_alloc(&aio2, NULL, NULL)); + nng_aio_set_timeout(aio1, 5000); + nng_aio_set_timeout(aio2, 5000); + + // Allocate the listener first. We use a wild-card port. + NUTS_PASS(nng_stream_listener_alloc(&l, "tls+tcp://127.0.0.1:0")); + NUTS_PASS(nng_tls_config_alloc(&c1, NNG_TLS_MODE_SERVER)); + NUTS_PASS(nng_tls_config_psk(c1, "identity", key, sizeof(key))); + NUTS_PASS(nng_stream_listener_set_ptr(l, NNG_OPT_TLS_CONFIG, c1)); + NUTS_PASS(nng_stream_listener_listen(l)); + NUTS_PASS( + nng_stream_listener_get_int(l, NNG_OPT_TCP_BOUND_PORT, &port)); + NUTS_TRUE(port > 0); + NUTS_TRUE(port < 65536); + + snprintf(addr, sizeof(addr), "tls+tcp://127.0.0.1:%d", port); + NUTS_PASS(nng_stream_dialer_alloc(&d, addr)); + NUTS_PASS(nng_tls_config_alloc(&c2, NNG_TLS_MODE_CLIENT)); + NUTS_PASS(nng_tls_config_psk(c2, "identity", key, sizeof(key))); + NUTS_PASS(nng_tls_config_server_name(c2, "localhost")); + + NUTS_PASS(nng_stream_dialer_set_ptr(d, NNG_OPT_TLS_CONFIG, c2)); + + nng_stream_listener_accept(l, aio1); + nng_stream_dialer_dial(d, aio2); + + nng_aio_wait(aio1); + nng_aio_wait(aio2); + + NUTS_PASS(nng_aio_result(aio1)); + NUTS_PASS(nng_aio_result(aio2)); + + NUTS_TRUE((s1 = nng_aio_get_output(aio1, 0)) != NULL); + NUTS_TRUE((s2 = nng_aio_get_output(aio2, 0)) != NULL); + + t1 = nuts_stream_send_start(s1, buf1, size); + t2 = nuts_stream_recv_start(s2, buf2, size); + + NUTS_PASS(nuts_stream_wait(t1)); + NUTS_PASS(nuts_stream_wait(t2)); + NUTS_TRUE(memcmp(buf1, buf2, size) == 0); + + nng_free(buf1, size); + nng_free(buf2, size); + nng_stream_free(s1); + nng_stream_free(s2); + nng_stream_dialer_free(d); + nng_stream_listener_free(l); + nng_tls_config_free(c1); + nng_tls_config_free(c2); + nng_aio_free(aio1); + nng_aio_free(aio2); +} + +void +test_tls_psk_server_identities(void) +{ + nng_stream_listener *l; + nng_stream_dialer *d; + nng_aio *aio1, *aio2; + nng_stream *s1; + nng_stream *s2; + nng_tls_config *c1; + nng_tls_config *c2; + char addr[32]; + uint8_t *buf1; + uint8_t *buf2; + size_t size = 10000; + void *t1; + void *t2; + int port; + char *identity = "test_identity"; + uint8_t key[32]; + + NUTS_ENABLE_LOG(NNG_LOG_INFO); + // allocate messages + NUTS_ASSERT((buf1 = nng_alloc(size)) != NULL); + NUTS_ASSERT((buf2 = nng_alloc(size)) != NULL); + + for (size_t i = 0; i < sizeof(key); i++) { + key[i] = rand() & 0xff; + } + for (size_t i = 0; i < size; i++) { + buf1[i] = rand() & 0xff; + } + + NUTS_PASS(nng_aio_alloc(&aio1, NULL, NULL)); + NUTS_PASS(nng_aio_alloc(&aio2, NULL, NULL)); + nng_aio_set_timeout(aio1, 5000); + nng_aio_set_timeout(aio2, 5000); + + // Allocate the listener first. We use a wild-card port. + NUTS_PASS(nng_stream_listener_alloc(&l, "tls+tcp://127.0.0.1:0")); + NUTS_PASS(nng_tls_config_alloc(&c1, NNG_TLS_MODE_SERVER)); + // Replace the identity .. first write one value, then we change it + NUTS_PASS( + nng_tls_config_psk(c1, "identity2", key + 4, sizeof(key) - 4)); + NUTS_PASS(nng_tls_config_psk(c1, identity, key + 4, sizeof(key) - 4)); + NUTS_PASS(nng_tls_config_psk(c1, identity, key, sizeof(key))); + NUTS_PASS(nng_stream_listener_set_ptr(l, NNG_OPT_TLS_CONFIG, c1)); + NUTS_PASS(nng_stream_listener_listen(l)); + NUTS_PASS( + nng_stream_listener_get_int(l, NNG_OPT_TCP_BOUND_PORT, &port)); + NUTS_TRUE(port > 0); + NUTS_TRUE(port < 65536); + + snprintf(addr, sizeof(addr), "tls+tcp://127.0.0.1:%d", port); + NUTS_PASS(nng_stream_dialer_alloc(&d, addr)); + NUTS_PASS(nng_tls_config_alloc(&c2, NNG_TLS_MODE_CLIENT)); + NUTS_PASS(nng_tls_config_psk(c2, identity, key, sizeof(key))); + NUTS_PASS(nng_tls_config_server_name(c2, "localhost")); + + NUTS_PASS(nng_stream_dialer_set_ptr(d, NNG_OPT_TLS_CONFIG, c2)); + + nng_stream_listener_accept(l, aio1); + nng_stream_dialer_dial(d, aio2); + + nng_aio_wait(aio1); + nng_aio_wait(aio2); + + NUTS_PASS(nng_aio_result(aio1)); + NUTS_PASS(nng_aio_result(aio2)); + + NUTS_TRUE((s1 = nng_aio_get_output(aio1, 0)) != NULL); + NUTS_TRUE((s2 = nng_aio_get_output(aio2, 0)) != NULL); + + t1 = nuts_stream_send_start(s1, buf1, size); + t2 = nuts_stream_recv_start(s2, buf2, size); + + NUTS_PASS(nuts_stream_wait(t1)); + NUTS_PASS(nuts_stream_wait(t2)); + NUTS_TRUE(memcmp(buf1, buf2, size) == 0); + + nng_free(buf1, size); + nng_free(buf2, size); + nng_stream_free(s1); + nng_stream_free(s2); + nng_stream_dialer_free(d); + nng_stream_listener_free(l); + nng_tls_config_free(c1); + nng_tls_config_free(c2); + nng_aio_free(aio1); + nng_aio_free(aio2); +} + +void +test_tls_psk_bad_identity(void) +{ + nng_stream_listener *l; + nng_stream_dialer *d; + nng_aio *aio1, *aio2; + nng_stream *s1; + nng_stream *s2; + nng_tls_config *c1; + nng_tls_config *c2; + char addr[32]; + uint8_t *buf1; + uint8_t *buf2; + size_t size = 10000; + void *t1; + void *t2; + int port; + uint8_t key[32]; + + NUTS_ENABLE_LOG(NNG_LOG_INFO); + // allocate messages + NUTS_ASSERT((buf1 = nng_alloc(size)) != NULL); + NUTS_ASSERT((buf2 = nng_alloc(size)) != NULL); + + for (size_t i = 0; i < sizeof(key); i++) { + key[i] = rand() & 0xff; + } + for (size_t i = 0; i < size; i++) { + buf1[i] = rand() & 0xff; + } + + NUTS_PASS(nng_aio_alloc(&aio1, NULL, NULL)); + NUTS_PASS(nng_aio_alloc(&aio2, NULL, NULL)); + nng_aio_set_timeout(aio1, 5000); + nng_aio_set_timeout(aio2, 5000); + + // Allocate the listener first. We use a wild-card port. + NUTS_PASS(nng_stream_listener_alloc(&l, "tls+tcp://127.0.0.1:0")); + NUTS_PASS(nng_tls_config_alloc(&c1, NNG_TLS_MODE_SERVER)); + // Replace the identity .. first write one value, then we change it + NUTS_PASS(nng_tls_config_psk(c1, "identity1", key, sizeof(key))); + NUTS_PASS(nng_stream_listener_set_ptr(l, NNG_OPT_TLS_CONFIG, c1)); + NUTS_PASS(nng_stream_listener_listen(l)); + NUTS_PASS( + nng_stream_listener_get_int(l, NNG_OPT_TCP_BOUND_PORT, &port)); + NUTS_TRUE(port > 0); + NUTS_TRUE(port < 65536); + + snprintf(addr, sizeof(addr), "tls+tcp://127.0.0.1:%d", port); + NUTS_PASS(nng_stream_dialer_alloc(&d, addr)); + NUTS_PASS(nng_tls_config_alloc(&c2, NNG_TLS_MODE_CLIENT)); + NUTS_PASS(nng_tls_config_psk(c2, "identity2", key, sizeof(key))); + NUTS_PASS(nng_tls_config_server_name(c2, "localhost")); + + NUTS_PASS(nng_stream_dialer_set_ptr(d, NNG_OPT_TLS_CONFIG, c2)); + + nng_stream_listener_accept(l, aio1); + nng_stream_dialer_dial(d, aio2); + + nng_aio_wait(aio1); + nng_aio_wait(aio2); + + NUTS_PASS(nng_aio_result(aio1)); + NUTS_PASS(nng_aio_result(aio2)); + + NUTS_TRUE((s1 = nng_aio_get_output(aio1, 0)) != NULL); + NUTS_TRUE((s2 = nng_aio_get_output(aio2, 0)) != NULL); + + t1 = nuts_stream_send_start(s1, buf1, size); + t2 = nuts_stream_recv_start(s2, buf2, size); + + NUTS_FAIL(nuts_stream_wait(t1), NNG_ECRYPTO); + NUTS_FAIL(nuts_stream_wait(t2), NNG_ECRYPTO); + + nng_free(buf1, size); + nng_free(buf2, size); + nng_stream_free(s1); + nng_stream_free(s2); + nng_stream_dialer_free(d); + nng_stream_listener_free(l); + nng_tls_config_free(c1); + nng_tls_config_free(c2); + nng_aio_free(aio1); + nng_aio_free(aio2); +} + +void +test_tls_psk_key_too_big(void) +{ + nng_tls_config *c1; + uint8_t key[5000]; + + NUTS_ENABLE_LOG(NNG_LOG_INFO); + + // Allocate the listener first. We use a wild-card port. + NUTS_PASS(nng_tls_config_alloc(&c1, NNG_TLS_MODE_CLIENT)); + NUTS_FAIL( + nng_tls_config_psk(c1, "identity", key, sizeof(key)), NNG_ECRYPTO); + nng_tls_config_free(c1); +} + +void +test_tls_psk_config_busy(void) +{ + nng_tls_config *c1; + uint8_t key[32]; + nng_stream_listener *l; + nng_aio *aio; + + nng_aio_alloc(&aio, NULL, NULL); + + NUTS_ENABLE_LOG(NNG_LOG_INFO); + + NUTS_PASS(nng_stream_listener_alloc(&l, "tls+tcp://127.0.0.1:0")); + NUTS_PASS(nng_tls_config_alloc(&c1, NNG_TLS_MODE_SERVER)); + NUTS_PASS(nng_tls_config_psk(c1, "identity", key, sizeof(key))); + NUTS_PASS(nng_stream_listener_set_ptr(l, NNG_OPT_TLS_CONFIG, c1)); + nng_stream_listener_accept(l, aio); + nng_msleep(100); + NUTS_FAIL( + nng_tls_config_psk(c1, "identity2", key, sizeof(key)), NNG_EBUSY); + + nng_stream_listener_free(l); + nng_aio_free(aio); + nng_tls_config_free(c1); +} + TEST_LIST = { { "tls config version", test_tls_config_version }, { "tls conn refused", test_tls_conn_refused }, { "tls large message", test_tls_large_message }, { "tls garbled cert", test_tls_garbled_cert }, + { "tls psk", test_tls_psk }, + { "tls psk server identities", test_tls_psk_server_identities }, + { "tls psk bad identity", test_tls_psk_bad_identity }, + { "tls psk key too big", test_tls_psk_key_too_big }, + { "tls psk key config busy", test_tls_psk_config_busy }, { NULL, NULL }, }; |
