From c9e4b41e109b03ca9ffcf789f8278705451026c7 Mon Sep 17 00:00:00 2001 From: Garrett D'Amore Date: Sun, 9 Nov 2025 09:36:52 -0800 Subject: refactor/dtls: Use message oriented send/receive for DTLS. The protocol here needs to know and respect message boundaries. --- src/supplemental/tls/mbedtls/mbedtls.c | 10 +- src/supplemental/tls/openssl/openssl.c | 12 ++- src/supplemental/tls/tls_common.c | 170 +++++++++++++++++++++++++++++---- src/supplemental/tls/tls_common.h | 33 +++++-- src/supplemental/tls/tls_engine.h | 4 +- src/supplemental/tls/tls_stream.c | 2 +- src/supplemental/tls/wolfssl/wolfssl.c | 8 +- 7 files changed, 195 insertions(+), 44 deletions(-) (limited to 'src/supplemental') diff --git a/src/supplemental/tls/mbedtls/mbedtls.c b/src/supplemental/tls/mbedtls/mbedtls.c index 87a8bd96..32ff3c59 100644 --- a/src/supplemental/tls/mbedtls/mbedtls.c +++ b/src/supplemental/tls/mbedtls/mbedtls.c @@ -211,8 +211,8 @@ tls_mk_err(int err) static int net_send(void *tls, const unsigned char *buf, size_t len) { - size_t sz = len; - int rv; + size_t sz = len; + nng_err rv; rv = nng_tls_engine_send(tls, buf, &sz); switch (rv) { @@ -228,8 +228,8 @@ net_send(void *tls, const unsigned char *buf, size_t len) static int net_recv(void *tls, unsigned char *buf, size_t len) { - size_t sz = len; - int rv; + size_t sz = len; + nng_err rv; rv = nng_tls_engine_recv(tls, buf, &sz); switch (rv) { @@ -993,7 +993,7 @@ tls_engine_init(void) #endif // Uncomment the following to have noisy debug from mbedTLS. // This may be useful when trying to debug failures. - // mbedtls_debug_set_threshold(9); + mbedtls_debug_set_threshold(1); mbedtls_ssl_cookie_init(&mbed_ssl_cookie_ctx); rv = mbedtls_ssl_cookie_setup(&mbed_ssl_cookie_ctx, tls_random, NULL); diff --git a/src/supplemental/tls/openssl/openssl.c b/src/supplemental/tls/openssl/openssl.c index 69364dd1..1095bb2f 100644 --- a/src/supplemental/tls/openssl/openssl.c +++ b/src/supplemental/tls/openssl/openssl.c @@ -25,6 +25,7 @@ #include "../../../core/list.h" #include "../../../core/strs.h" #include "../tls_engine.h" +#include "nng/nng.h" // library code for openssl static int ossl_libcode; @@ -116,8 +117,8 @@ tls_log_err(const char *msgid, const char *context, int errnum) static int ossl_net_send(BIO *bio, const char *buf, size_t len, size_t *lenp) { - void *ctx = BIO_get_data(bio); - int rv; + void *ctx = BIO_get_data(bio); + nng_err rv; switch (rv = nng_tls_engine_send(ctx, (const uint8_t *) buf, &len)) { case NNG_OK: @@ -135,8 +136,8 @@ ossl_net_send(BIO *bio, const char *buf, size_t len, size_t *lenp) static int ossl_net_recv(BIO *bio, char *buf, size_t len, size_t *lenp) { - void *ctx = BIO_get_data(bio); - int rv; + void *ctx = BIO_get_data(bio); + nng_err rv; switch (rv = nng_tls_engine_recv(ctx, (uint8_t *) buf, &len)) { case NNG_OK: @@ -348,7 +349,8 @@ ossl_conn_handshake(nng_tls_engine_conn *ec) rv = SSL_do_handshake(ec->ssl); if (rv == 1) { - nng_log_debug("NNG-TLS-HS", "TLS handshake complete"); + nng_log_debug("NNG-TLS-HS", "TLS handshake complete %s", + ec->mode == NNG_TLS_MODE_CLIENT ? "client" : "server"); return (NNG_OK); } rv = SSL_get_error(ec->ssl, rv); diff --git a/src/supplemental/tls/tls_common.c b/src/supplemental/tls/tls_common.c index 75b056b3..1ab687b8 100644 --- a/src/supplemental/tls/tls_common.c +++ b/src/supplemental/tls/tls_common.c @@ -15,6 +15,7 @@ #include "../../core/nng_impl.h" +#include "nng/nng.h" #include "tls_common.h" #include "tls_engine.h" @@ -42,6 +43,9 @@ static void tls_bio_recv_cb(void *arg); static void tls_do_send(nni_tls_conn *); static void tls_do_recv(nni_tls_conn *); static void tls_bio_send_start(nni_tls_conn *); +static void tls_bio_send_msg_start(nni_tls_conn *); +static void tls_bio_recv_start(nni_tls_conn *); +static void tls_bio_recv_msg_start(nni_tls_conn *); static void tls_bio_error(nni_tls_conn *, nng_err); #define nni_tls_conn_ops (nng_tls_engine_ops.conn_ops) @@ -156,17 +160,25 @@ nni_tls_peer_cert(nni_tls_conn *conn, nng_tls_cert **certp) } int -nni_tls_init(nni_tls_conn *conn, nng_tls_config *cfg) +nni_tls_init(nni_tls_conn *conn, nng_tls_config *cfg, bool msg_oriented) { nni_mtx_lock(&cfg->lock); cfg->busy = true; nni_mtx_unlock(&cfg->lock); - if (((conn->bio_send_buf = nni_zalloc(NNG_TLS_MAX_SEND_SIZE)) == - NULL) || - ((conn->bio_recv_buf = nni_zalloc(NNG_TLS_MAX_RECV_SIZE)) == - NULL)) { - return (NNG_ENOMEM); + conn->msg_oriented = msg_oriented; + if (msg_oriented) { + nni_lmq_init(&conn->bio_send_lmq, NNG_TLS_MAX_SEND_MSG_QUEUE); + nni_lmq_init(&conn->bio_recv_lmq, NNG_TLS_MAX_RECV_MSG_QUEUE); + // TODO: REMOVE + conn->bio_recv_buf = nni_zalloc(NNG_TLS_MAX_RECV_SIZE); + } else { + if (((conn->bio_send_buf = + nni_zalloc(NNG_TLS_MAX_SEND_SIZE)) == NULL) || + ((conn->bio_recv_buf = + nni_zalloc(NNG_TLS_MAX_RECV_SIZE)) == NULL)) { + return (NNG_ENOMEM); + } } conn->cfg = cfg; @@ -205,6 +217,9 @@ nni_tls_fini(nni_tls_conn *conn) if (conn->bio != NULL) { conn->bio_ops.bio_free(conn->bio); } + nni_lmq_fini(&conn->bio_recv_lmq); + nni_lmq_fini(&conn->bio_send_lmq); + nni_msg_free(conn->bio_recv_msg); nni_mtx_fini(&conn->bio_lock); nni_mtx_fini(&conn->lock); } @@ -411,7 +426,11 @@ tls_bio_send_cb(void *arg) nni_mtx_lock(&conn->bio_lock); conn->bio_send_active = false; - if ((rv = nni_aio_result(aio)) != 0) { + if ((rv = nni_aio_result(aio)) != NNG_OK) { + if (conn->msg_oriented) { + nni_msg_free(nni_aio_get_msg(aio)); + nni_aio_set_msg(aio, NULL); + } tls_bio_error(conn, rv); nni_mtx_unlock(&conn->bio_lock); @@ -419,12 +438,19 @@ tls_bio_send_cb(void *arg) return; } - count = nni_aio_count(aio); - NNI_ASSERT(count <= conn->bio_send_len); - conn->bio_send_len -= count; - conn->bio_send_tail += count; - conn->bio_send_tail %= NNG_TLS_MAX_SEND_SIZE; - tls_bio_send_start(conn); + if (conn->msg_oriented) { + nng_msg *msg = nni_aio_get_msg(aio); + nni_msg_free(msg); + nni_aio_set_msg(aio, NULL); + tls_bio_send_msg_start(conn); + } else { + count = nni_aio_count(aio); + NNI_ASSERT(count <= conn->bio_send_len); + conn->bio_send_len -= count; + conn->bio_send_tail += count; + conn->bio_send_tail %= NNG_TLS_MAX_SEND_SIZE; + tls_bio_send_start(conn); + } nni_mtx_unlock(&conn->bio_lock); nni_tls_run(conn); @@ -446,9 +472,15 @@ tls_bio_recv_cb(void *arg) return; } - NNI_ASSERT(conn->bio_recv_len == 0); - NNI_ASSERT(conn->bio_recv_off == 0); - conn->bio_recv_len = nni_aio_count(aio); + if (conn->msg_oriented) { + nng_msg *msg = nni_aio_get_msg(aio); + nni_lmq_put(&conn->bio_recv_lmq, msg); + nni_aio_set_msg(aio, NULL); + } else { + NNI_ASSERT(conn->bio_recv_len == 0); + NNI_ASSERT(conn->bio_recv_off == 0); + conn->bio_recv_len = nni_aio_count(aio); + } nni_mtx_unlock(&conn->bio_lock); nni_tls_run(conn); @@ -480,6 +512,21 @@ tls_bio_recv_start(nni_tls_conn *conn) conn->bio_ops.bio_recv(conn->bio, &conn->bio_recv); } +static void +tls_bio_recv_msg_start(nni_tls_conn *conn) +{ + if (conn->bio_recv_pend || conn->bio_closed) { + // Already have a receive in flight. + return; + } + if (nni_lmq_full(&conn->bio_recv_lmq)) { + return; + } + + conn->bio_recv_pend = true; + conn->bio_ops.bio_recv(conn->bio, &conn->bio_recv); +} + static void tls_bio_send_start(nni_tls_conn *conn) { @@ -525,7 +572,46 @@ tls_bio_send_start(nni_tls_conn *conn) conn->bio_ops.bio_send(conn->bio, &conn->bio_send); } -int +static void +tls_bio_send_msg_start(nni_tls_conn *conn) +{ + nni_msg *msg; + + if (conn->bio_send_active || conn->bio_closed) { + return; + } + if (nni_lmq_get(&conn->bio_send_lmq, &msg) == NNG_OK) { + conn->bio_send_active = true; + nni_aio_set_msg(&conn->bio_send, msg); + conn->bio_ops.bio_send(conn->bio, &conn->bio_send); + } +} + +static nng_err +tls_engine_send_msg(nni_tls_conn *conn, const uint8_t *buf, size_t *szp) +{ + nng_msg *msg; + nng_err rv; + + // move the data into a message for the queue + nni_mtx_lock(&conn->bio_lock); + if (nni_lmq_full(&conn->bio_send_lmq)) { + nni_mtx_unlock(&conn->bio_lock); + return (NNG_EAGAIN); + } + if ((rv = nni_msg_alloc(&msg, *szp)) != NNG_OK) { + nni_mtx_unlock(&conn->bio_lock); + return (rv); + } + memcpy(nni_msg_body(msg), buf, *szp); + rv = nni_lmq_put(&conn->bio_send_lmq, msg); + NNI_ASSERT(rv == NNG_OK); // we checked it already above! + tls_bio_send_msg_start(conn); + nni_mtx_unlock(&conn->bio_lock); + return (rv); +} + +nng_err nng_tls_engine_send(void *arg, const uint8_t *buf, size_t *szp) { nni_tls_conn *conn = arg; @@ -535,6 +621,9 @@ nng_tls_engine_send(void *arg, const uint8_t *buf, size_t *szp) size_t space; size_t cnt; + if (conn->msg_oriented) { + return (tls_engine_send_msg(conn, buf, szp)); + } nni_mtx_lock(&conn->bio_lock); head = conn->bio_send_head; tail = conn->bio_send_tail; @@ -576,15 +665,55 @@ nng_tls_engine_send(void *arg, const uint8_t *buf, size_t *szp) tls_bio_send_start(conn); nni_mtx_unlock(&conn->bio_lock); - return (0); + return (NNG_OK); } -int +static nng_err +tls_engine_recv_msg(nni_tls_conn *conn, uint8_t *buf, size_t *szp) +{ + nni_msg *msg; + nng_err rv; + size_t len; + nni_mtx_lock(&conn->bio_lock); + if ((conn->bio_recv_msg == NULL) && + nni_lmq_empty(&conn->bio_recv_lmq)) { + tls_bio_recv_msg_start(conn); + nni_mtx_unlock(&conn->bio_lock); + return (NNG_EAGAIN); + } + + if (conn->bio_recv_msg == NULL) { + rv = nni_lmq_get(&conn->bio_recv_lmq, &conn->bio_recv_msg); + NNI_ASSERT(rv == NNG_OK); + } + msg = conn->bio_recv_msg; + + if ((len = nni_msg_len(msg)) < *szp) { + *szp = len; + } else { + len = *szp; + } + memcpy(buf, nni_msg_body(msg), len); + nni_msg_trim(msg, len); + if (nni_msg_len(msg) == 0) { + nni_msg_free(msg); + conn->bio_recv_msg = NULL; + } + tls_bio_recv_msg_start(conn); + nni_mtx_unlock(&conn->bio_lock); + return (NNG_OK); +} + +nng_err nng_tls_engine_recv(void *arg, uint8_t *buf, size_t *szp) { nni_tls_conn *conn = arg; size_t len = *szp; + if (conn->msg_oriented) { + return (tls_engine_recv_msg(conn, buf, szp)); + } + nni_mtx_lock(&conn->bio_lock); if (conn->bio_recv_len == 0) { tls_bio_recv_start(conn); @@ -604,7 +733,7 @@ nng_tls_engine_recv(void *arg, uint8_t *buf, size_t *szp) nni_mtx_unlock(&conn->bio_lock); *szp = len; - return (0); + return (NNG_OK); } int @@ -773,6 +902,7 @@ nng_tls_config_alloc(nng_tls_config **cfg_p, nng_tls_mode mode) cfg->size = size; cfg->ref = 1; cfg->busy = false; + cfg->mode = mode; nni_mtx_init(&cfg->lock); if ((rv = nni_tls_cfg_ops->init((void *) (cfg + 1), mode)) != 0) { diff --git a/src/supplemental/tls/tls_common.h b/src/supplemental/tls/tls_common.h index 6d163fd5..78cfe793 100644 --- a/src/supplemental/tls/tls_common.h +++ b/src/supplemental/tls/tls_common.h @@ -15,6 +15,7 @@ #include "../../core/nng_impl.h" +#include "core/lmq.h" #include "tls_engine.h" #ifndef NNG_TLS_TLS_COMMON_H @@ -34,17 +35,30 @@ #define NNG_TLS_MAX_RECV_SIZE 16384 #endif +// NNG_TLS_MAX_SEND_MSG_QUEUE limits the number of pending messages for +// sending. This is only used for msg oriented transports like DTLS or SCTP. +#ifndef NNG_TLS_MAX_SEND_MSG_QUEUE +#define NNG_TLS_MAX_SEND_MSG_QUEUE 32 +#endif + +// NNG_TLS_MAX_RECV_MSG_QUEUE limits the number of pending messages for +// receiving. This is only used for msg oriented transports like DTLS or SCTP. +#ifndef NNG_TLS_MAX_RECV_MSG_QUEUE +#define NNG_TLS_MAX_RECV_MSG_QUEUE 32 +#endif + // This file contains common code for TLS, and is only compiled if we // have TLS configured in the system. In particular, this provides the // parts of TLS support that are invariant relative to different TLS // libraries, such as dialer and listener support. struct nng_tls_config { - nni_mtx lock; - int ref; - bool busy; - bool key_is_set; - size_t size; + nni_mtx lock; + int ref; + bool busy; + bool key_is_set; + nng_tls_mode mode; + size_t size; // ... engine config data follows }; @@ -64,9 +78,10 @@ typedef struct { nng_tls_config *cfg; size_t size; nni_mtx lock; - bool closed; nni_atomic_flag did_close; bool hs_done; + bool closed; + bool msg_oriented; // works with messages instead of streams nni_list send_queue; nni_list recv_queue; @@ -86,13 +101,17 @@ typedef struct { size_t bio_send_len; size_t bio_send_head; size_t bio_send_tail; + nni_lmq bio_send_lmq; // for msg oriented only + nni_lmq bio_recv_lmq; // for msg oriented only + nni_msg *bio_recv_msg; // for msg oriented only nni_reap_node reap; // ... engine connection data follows } nni_tls_conn; extern void nni_tls_fini(nni_tls_conn *conn); -extern int nni_tls_init(nni_tls_conn *conn, nng_tls_config *cfg); +extern int nni_tls_init( + nni_tls_conn *conn, nng_tls_config *cfg, bool msg_oriented); extern int nni_tls_start(nni_tls_conn *conn, const nni_tls_bio_ops *biops, void *bio, const nng_sockaddr *sa); extern void nni_tls_stop(nni_tls_conn *conn); diff --git a/src/supplemental/tls/tls_engine.h b/src/supplemental/tls/tls_engine.h index c0e395d5..534e6af2 100644 --- a/src/supplemental/tls/tls_engine.h +++ b/src/supplemental/tls/tls_engine.h @@ -268,13 +268,13 @@ extern nng_tls_engine nng_tls_engine_ops; // accept more data yet), or some other error. On success the count is // updated with the number of bytes actually sent. The first argument // is the context structure passed in when starting the engine. -extern int nng_tls_engine_send(void *, const uint8_t *, size_t *); +extern nng_err nng_tls_engine_send(void *, const uint8_t *, size_t *); // nng_tls_engine_recv is called by the engine to receive data over // the underlying connection. It returns zero on success, NNG_EAGAIN // if the operation can't be completed yet (there is no data available // for reading), or some other error. On success the count is updated // with the number of bytes actually received. -extern int nng_tls_engine_recv(void *, uint8_t *, size_t *); +extern nng_err nng_tls_engine_recv(void *, uint8_t *, size_t *); #endif // NNG_SUPPLEMENTAL_TLS_TLS_ENGINE_H diff --git a/src/supplemental/tls/tls_stream.c b/src/supplemental/tls/tls_stream.c index b523e583..d4a0376b 100644 --- a/src/supplemental/tls/tls_stream.c +++ b/src/supplemental/tls/tls_stream.c @@ -172,7 +172,7 @@ nni_tls_stream_alloc(tls_stream **tsp, nng_tls_config *cfg, nng_aio *user_aio) nni_aio_init(&ts->conn_aio, tls_stream_conn_cb, ts); - if ((rv = nni_tls_init(&ts->conn, cfg)) != 0) { + if ((rv = nni_tls_init(&ts->conn, cfg, false)) != 0) { nni_tls_stream_free(ts); return (rv); } diff --git a/src/supplemental/tls/wolfssl/wolfssl.c b/src/supplemental/tls/wolfssl/wolfssl.c index ef3f7391..5bfaeb61 100644 --- a/src/supplemental/tls/wolfssl/wolfssl.c +++ b/src/supplemental/tls/wolfssl/wolfssl.c @@ -114,8 +114,8 @@ tls_log_err(const char *msgid, const char *context, int errnum) static int wolf_net_send(WOLFSSL *ssl, char *buf, int len, void *ctx) { - size_t sz = len; - int rv; + size_t sz = len; + nng_err rv; (void) ssl; rv = nng_tls_engine_send(ctx, (const uint8_t *) buf, &sz); @@ -136,8 +136,8 @@ wolf_net_send(WOLFSSL *ssl, char *buf, int len, void *ctx) static int wolf_net_recv(WOLFSSL *ssl, char *buf, int len, void *ctx) { - size_t sz = len; - int rv; + size_t sz = len; + nng_err rv; (void) ssl; rv = nng_tls_engine_recv(ctx, (uint8_t *) buf, &sz); -- cgit v1.2.3-70-g09d2