From 81f5d3c6268ff91ee9c36c4cb34f6f9bfd54740d Mon Sep 17 00:00:00 2001 From: Garrett D'Amore Date: Thu, 12 Dec 2024 17:55:48 -0800 Subject: streams: add explicit stop functions This allows us to explicitly stop streams, dialers, and listeners, before we start tearing down things. This hopefully will be useful in resolving use-after-free bugs in http, tls, and websockets. The new functions are not yet documented, but they are nng_stream_stop, nng_stream_dialer_stop, and nng_stream_listener_stop. They should be called after close, and before free. The close functions now close without blocking, but the stop function is allowed to block. --- src/supplemental/http/http_client.c | 2 + src/supplemental/http/http_server.c | 1 + src/supplemental/tls/tls_common.c | 39 ++++++++--- src/supplemental/tls/tls_test.c | 39 ++++++++--- src/supplemental/websocket/websocket.c | 100 ++++++++++++++++++---------- src/supplemental/websocket/websocket_test.c | 15 ++++- 6 files changed, 143 insertions(+), 53 deletions(-) (limited to 'src/supplemental') diff --git a/src/supplemental/http/http_client.c b/src/supplemental/http/http_client.c index af5bc717..03125abf 100644 --- a/src/supplemental/http/http_client.c +++ b/src/supplemental/http/http_client.c @@ -89,6 +89,8 @@ http_dial_cb(void *arg) void nni_http_client_fini(nni_http_client *c) { + nni_aio_stop(c->aio); + nng_stream_dialer_stop(c->dialer); nni_aio_free(c->aio); nng_stream_dialer_free(c->dialer); nni_mtx_fini(&c->mtx); diff --git a/src/supplemental/http/http_server.c b/src/supplemental/http/http_server.c index 2240492f..0d5d9cb0 100644 --- a/src/supplemental/http/http_server.c +++ b/src/supplemental/http/http_server.c @@ -895,6 +895,7 @@ http_server_fini(nni_http_server *s) http_error *epage; nni_aio_stop(s->accaio); + nng_stream_listener_stop(s->listener); nni_mtx_lock(&s->mtx); NNI_ASSERT(nni_list_empty(&s->conns)); diff --git a/src/supplemental/tls/tls_common.c b/src/supplemental/tls/tls_common.c index c04d03a5..a871e74e 100644 --- a/src/supplemental/tls/tls_common.c +++ b/src/supplemental/tls/tls_common.c @@ -126,6 +126,14 @@ tls_dialer_free(void *arg) } } +static void +tls_dialer_stop(void *arg) +{ + tls_dialer *d = arg; + + nng_stream_dialer_stop(d->d); +} + // For dialing, we need to have our own completion callback, instead of // the user's completion callback. @@ -281,6 +289,7 @@ nni_tls_dialer_alloc(nng_stream_dialer **dp, const nng_url *url) d->ops.sd_close = tls_dialer_close; d->ops.sd_free = tls_dialer_free; + d->ops.sd_stop = tls_dialer_stop; d->ops.sd_dial = tls_dialer_dial; d->ops.sd_get = tls_dialer_get; d->ops.sd_set = tls_dialer_set; @@ -306,6 +315,13 @@ tls_listener_close(void *arg) nng_stream_listener_close(l->l); } +static void +tls_listener_stop(void *arg) +{ + tls_listener *l = arg; + nng_stream_listener_close(l->l); +} + static void tls_listener_free(void *arg) { @@ -436,6 +452,7 @@ nni_tls_listener_alloc(nng_stream_listener **lp, const nng_url *url) } l->ops.sl_free = tls_listener_free; l->ops.sl_close = tls_listener_close; + l->ops.sl_stop = tls_listener_stop; l->ops.sl_accept = tls_listener_accept; l->ops.sl_listen = tls_listener_listen; l->ops.sl_get = tls_listener_get; @@ -526,6 +543,18 @@ tls_close(void *arg) nng_stream_close(conn->tcp); } +static void +tls_stop(void *arg) +{ + tls_conn *conn = arg; + + tls_close(conn); + nng_stream_stop(conn->tcp); + nni_aio_stop(&conn->conn_aio); + nni_aio_stop(&conn->tcp_send); + nni_aio_stop(&conn->tcp_recv); +} + static int tls_get_verified(void *arg, void *buf, size_t *szp, nni_type t) { @@ -624,6 +653,7 @@ tls_alloc(tls_conn **conn_p, nng_tls_config *cfg, nng_aio *user_aio) conn->stream.s_close = tls_close; conn->stream.s_free = tls_free; + conn->stream.s_stop = tls_stop; conn->stream.s_send = tls_send; conn->stream.s_recv = tls_recv; conn->stream.s_get = tls_get; @@ -638,14 +668,7 @@ tls_reap(void *arg) { tls_conn *conn = arg; - // Shut it all down first. We should be freed. - if (conn->tcp != NULL) { - nng_stream_close(conn->tcp); - } - nni_aio_stop(&conn->conn_aio); - nni_aio_stop(&conn->tcp_send); - nni_aio_stop(&conn->tcp_recv); - + tls_stop(conn); conn->ops.fini((void *) (conn + 1)); nni_aio_fini(&conn->conn_aio); nni_aio_fini(&conn->tcp_send); diff --git a/src/supplemental/tls/tls_test.c b/src/supplemental/tls/tls_test.c index 43ce0c85..517be143 100644 --- a/src/supplemental/tls/tls_test.c +++ b/src/supplemental/tls/tls_test.c @@ -57,6 +57,7 @@ test_tls_conn_refused(void) NUTS_FAIL(nng_aio_result(aio), NNG_ECONNREFUSED); nng_aio_free(aio); + nng_stream_dialer_stop(dialer); nng_stream_dialer_free(dialer); } @@ -133,6 +134,10 @@ test_tls_large_message(void) nng_free(buf1, size); nng_free(buf2, size); + nng_stream_stop(s1); + nng_stream_stop(s2); + nng_stream_dialer_stop(d); + nng_stream_listener_stop(l); nng_stream_free(s1); nng_stream_free(s2); nng_stream_dialer_free(d); @@ -214,8 +219,10 @@ test_tls_ecdsa(void) NUTS_PASS(nuts_stream_wait(t2)); NUTS_TRUE(memcmp(buf1, buf2, size) == 0); - nng_free(buf1, size); - nng_free(buf2, size); + nng_stream_stop(s1); + nng_stream_stop(s2); + nng_stream_dialer_stop(d); + nng_stream_listener_stop(l); nng_stream_free(s1); nng_stream_free(s2); nng_stream_dialer_free(d); @@ -224,6 +231,8 @@ test_tls_ecdsa(void) nng_tls_config_free(c2); nng_aio_free(aio1); nng_aio_free(aio2); + nng_free(buf1, size); + nng_free(buf2, size); } void @@ -241,6 +250,7 @@ test_tls_garbled_cert(void) c1, nuts_garbled_crt, nuts_server_key, NULL), NNG_ECRYPTO); + nng_stream_listener_stop(l); nng_stream_listener_free(l); nng_tls_config_free(c1); } @@ -318,8 +328,10 @@ test_tls_psk(void) NUTS_PASS(nuts_stream_wait(t2)); NUTS_TRUE(memcmp(buf1, buf2, size) == 0); - nng_free(buf1, size); - nng_free(buf2, size); + nng_stream_stop(s1); + nng_stream_stop(s2); + nng_stream_dialer_stop(d); + nng_stream_listener_stop(l); nng_stream_free(s1); nng_stream_free(s2); nng_stream_dialer_free(d); @@ -328,6 +340,8 @@ test_tls_psk(void) nng_tls_config_free(c2); nng_aio_free(aio1); nng_aio_free(aio2); + nng_free(buf1, size); + nng_free(buf2, size); } void @@ -408,8 +422,10 @@ test_tls_psk_server_identities(void) NUTS_PASS(nuts_stream_wait(t2)); NUTS_TRUE(memcmp(buf1, buf2, size) == 0); - nng_free(buf1, size); - nng_free(buf2, size); + nng_stream_stop(s1); + nng_stream_stop(s2); + nng_stream_dialer_stop(d); + nng_stream_listener_stop(l); nng_stream_free(s1); nng_stream_free(s2); nng_stream_dialer_free(d); @@ -418,6 +434,8 @@ test_tls_psk_server_identities(void) nng_tls_config_free(c2); nng_aio_free(aio1); nng_aio_free(aio2); + nng_free(buf1, size); + nng_free(buf2, size); } void @@ -495,8 +513,10 @@ test_tls_psk_bad_identity(void) NUTS_ASSERT(nuts_stream_wait(t1) != 0); NUTS_ASSERT(nuts_stream_wait(t2) != 0); - nng_free(buf1, size); - nng_free(buf2, size); + nng_stream_stop(s1); + nng_stream_stop(s2); + nng_stream_dialer_stop(d); + nng_stream_listener_stop(l); nng_stream_free(s1); nng_stream_free(s2); nng_stream_dialer_free(d); @@ -505,6 +525,8 @@ test_tls_psk_bad_identity(void) nng_tls_config_free(c2); nng_aio_free(aio1); nng_aio_free(aio2); + nng_free(buf1, size); + nng_free(buf2, size); } void @@ -543,6 +565,7 @@ test_tls_psk_config_busy(void) NUTS_FAIL( nng_tls_config_psk(c1, "identity2", key, sizeof(key)), NNG_EBUSY); + nng_stream_listener_stop(l); nng_stream_listener_free(l); nng_aio_free(aio); nng_tls_config_free(c1); diff --git a/src/supplemental/websocket/websocket.c b/src/supplemental/websocket/websocket.c index 9f3f6d0b..e7372a49 100644 --- a/src/supplemental/websocket/websocket.c +++ b/src/supplemental/websocket/websocket.c @@ -1183,12 +1183,9 @@ ws_close_error(nni_ws *ws, uint16_t code) } static void -ws_fini(void *arg) +ws_stop(void *arg) { - nni_ws *ws = arg; - ws_frame *frame; - nng_aio *aio; - + nni_ws *ws = arg; ws_close_error(ws, WS_CLOSE_NORMAL_CLOSE); // Give a chance for the close frame to drain. @@ -1198,7 +1195,6 @@ ws_fini(void *arg) nni_aio_stop(&ws->txaio); nni_aio_stop(&ws->closeaio); nni_aio_stop(&ws->httpaio); - nni_aio_stop(&ws->connaio); if (nni_list_node_active(&ws->node)) { nni_ws_dialer *d; @@ -1210,6 +1206,16 @@ ws_fini(void *arg) nni_mtx_unlock(&d->mtx); } } +} + +static void +ws_fini(void *arg) +{ + nni_ws *ws = arg; + ws_frame *frame; + nng_aio *aio; + + ws_stop(ws); nni_mtx_lock(&ws->mtx); while ((frame = nni_list_first(&ws->rxq)) != NULL) { @@ -1450,6 +1456,7 @@ ws_init(nni_ws **wsp) ws->ops.s_close = ws_str_close; ws->ops.s_free = ws_str_free; + ws->ops.s_stop = ws_stop; ws->ops.s_send = ws_str_send; ws->ops.s_recv = ws_str_recv; ws->ops.s_get = ws_str_get; @@ -1460,10 +1467,11 @@ ws_init(nni_ws **wsp) } static void -ws_listener_free(void *arg) +ws_listener_stop(void *arg) { - nni_ws_listener *l = arg; - ws_header *hdr; + nni_ws_listener *l = arg; + nni_http_handler *h; + nni_http_server *s; ws_listener_close(l); @@ -1471,16 +1479,27 @@ ws_listener_free(void *arg) while (!nni_list_empty(&l->reply)) { nni_cv_wait(&l->cv); } + h = l->handler; + s = l->server; + l->handler = NULL; + l->server = NULL; nni_mtx_unlock(&l->mtx); - if (l->handler != NULL) { - nni_http_handler_fini(l->handler); - l->handler = NULL; + if (h != NULL) { + nni_http_handler_fini(h); } - if (l->server != NULL) { - nni_http_server_fini(l->server); - l->server = NULL; + if (s != NULL) { + nni_http_server_fini(s); } +} + +static void +ws_listener_free(void *arg) +{ + nni_ws_listener *l = arg; + ws_header *hdr; + + ws_listener_stop(l); nni_cv_fini(&l->cv); nni_mtx_fini(&l->mtx); @@ -2148,6 +2167,7 @@ nni_ws_listener_alloc(nng_stream_listener **wslp, const nng_url *url) l->isstream = true; l->ops.sl_free = ws_listener_free; l->ops.sl_close = ws_listener_close; + l->ops.sl_stop = ws_listener_stop; l->ops.sl_accept = ws_listener_accept; l->ops.sl_listen = ws_listener_listen; l->ops.sl_set = ws_listener_set; @@ -2255,16 +2275,43 @@ err: } static void -ws_dialer_free(void *arg) +ws_dialer_close(void *arg) { nni_ws_dialer *d = arg; - ws_header *hdr; + nni_ws *ws; + nni_mtx_lock(&d->mtx); + if (d->closed) { + nni_mtx_unlock(&d->mtx); + return; + } + d->closed = true; + NNI_LIST_FOREACH (&d->wspend, ws) { + nni_aio_close(&ws->connaio); + nni_aio_close(&ws->httpaio); + } + nni_mtx_unlock(&d->mtx); +} +static void +ws_dialer_stop(void *arg) +{ + nni_ws_dialer *d = arg; + + ws_dialer_close(d); nni_mtx_lock(&d->mtx); while (!nni_list_empty(&d->wspend)) { nni_cv_wait(&d->cv); } nni_mtx_unlock(&d->mtx); +} + +static void +ws_dialer_free(void *arg) +{ + nni_ws_dialer *d = arg; + ws_header *hdr; + + ws_dialer_stop(d); nni_strfree(d->proto); while ((hdr = nni_list_first(&d->headers)) != NULL) { @@ -2284,24 +2331,6 @@ ws_dialer_free(void *arg) NNI_FREE_STRUCT(d); } -static void -ws_dialer_close(void *arg) -{ - nni_ws_dialer *d = arg; - nni_ws *ws; - nni_mtx_lock(&d->mtx); - if (d->closed) { - nni_mtx_unlock(&d->mtx); - return; - } - d->closed = true; - NNI_LIST_FOREACH (&d->wspend, ws) { - nni_aio_close(&ws->connaio); - nni_aio_close(&ws->httpaio); - } - nni_mtx_unlock(&d->mtx); -} - static void ws_dial_cancel(nni_aio *aio, void *arg, int rv) { @@ -2679,6 +2708,7 @@ nni_ws_dialer_alloc(nng_stream_dialer **dp, const nng_url *url) d->ops.sd_free = ws_dialer_free; d->ops.sd_close = ws_dialer_close; + d->ops.sd_stop = ws_dialer_stop; d->ops.sd_dial = ws_dialer_dial; d->ops.sd_set = ws_dialer_set; d->ops.sd_get = ws_dialer_get; diff --git a/src/supplemental/websocket/websocket_test.c b/src/supplemental/websocket/websocket_test.c index 781ca1d8..9a28d69b 100644 --- a/src/supplemental/websocket/websocket_test.c +++ b/src/supplemental/websocket/websocket_test.c @@ -107,9 +107,13 @@ test_websocket_wildcard(void) NUTS_TRUE(memcmp(buf1, buf2, 5) == 0); nng_stream_close(c1); - nng_stream_free(c1); nng_stream_close(c2); + nng_stream_stop(c1); + nng_stream_stop(c2); + nng_stream_free(c1); nng_stream_free(c2); + nng_stream_listener_stop(l); + nng_stream_dialer_stop(d); nng_aio_free(daio); nng_aio_free(laio); nng_aio_free(aio1); @@ -206,9 +210,13 @@ test_websocket_conn_props(void) nng_strfree(str); nng_stream_close(c1); - nng_stream_free(c1); nng_stream_close(c2); + nng_stream_stop(c1); + nng_stream_stop(c2); + nng_stream_free(c1); nng_stream_free(c2); + nng_stream_listener_stop(l); + nng_stream_dialer_stop(d); nng_aio_free(daio); nng_aio_free(laio); nng_stream_listener_free(l); @@ -495,6 +503,7 @@ test_websocket_fragmentation(void) nng_aio_free(caio); nng_stream_close(c); + nng_stream_stop(c); nng_stream_free(c); nng_aio_free(state.aio); @@ -502,6 +511,8 @@ test_websocket_fragmentation(void) nng_cv_free(state.cv); nng_mtx_free(state.lock); + nng_stream_dialer_stop(d); + nng_stream_listener_stop(l); nng_free(send_buf, state.total); nng_free(recv_buf, state.total); nng_aio_free(daio); -- cgit v1.2.3-70-g09d2