diff options
Diffstat (limited to 'src/supplemental')
| -rw-r--r-- | src/supplemental/tls/mbedtls/tls.c | 20 | ||||
| -rw-r--r-- | src/supplemental/tls/tls_common.c | 21 |
2 files changed, 19 insertions, 22 deletions
diff --git a/src/supplemental/tls/mbedtls/tls.c b/src/supplemental/tls/mbedtls/tls.c index d3816747..b7ed0575 100644 --- a/src/supplemental/tls/mbedtls/tls.c +++ b/src/supplemental/tls/mbedtls/tls.c @@ -95,7 +95,7 @@ struct nng_tls_config { mbedtls_x509_crt ca_certs; mbedtls_x509_crl crl; - int refcnt; // servers increment the reference + nni_atomic_u64 refcnt; nni_list certkeys; }; @@ -159,13 +159,9 @@ nni_tls_config_fini(nng_tls_config *cfg) nni_tls_certkey *ck; if (cfg != NULL) { - nni_mtx_lock(&cfg->lk); - cfg->refcnt--; - if (cfg->refcnt != 0) { - nni_mtx_unlock(&cfg->lk); + if (nni_atomic_dec64_nv(&cfg->refcnt) != 0) { return; } - nni_mtx_unlock(&cfg->lk); mbedtls_ssl_config_free(&cfg->cfg_ctx); #ifdef NNG_TLS_USE_CTR_DRBG @@ -199,7 +195,8 @@ nni_tls_config_init(nng_tls_config **cpp, enum nng_tls_mode mode) if ((cfg = NNI_ALLOC_STRUCT(cfg)) == NULL) { return (NNG_ENOMEM); } - cfg->refcnt = 1; + nni_atomic_init64(&cfg->refcnt); + nni_atomic_inc64(&cfg->refcnt); nni_mtx_init(&cfg->lk); if (mode == NNG_TLS_MODE_SERVER) { sslmode = MBEDTLS_SSL_IS_SERVER; @@ -247,9 +244,7 @@ nni_tls_config_init(nng_tls_config **cpp, enum nng_tls_mode mode) void nni_tls_config_hold(nng_tls_config *cfg) { - nni_mtx_lock(&cfg->lk); - cfg->refcnt++; - nni_mtx_unlock(&cfg->lk); + nni_atomic_inc64(&cfg->refcnt); } // tls_mkerr converts an mbed error to an NNG error. In all cases @@ -295,15 +290,15 @@ tls_free(void *arg) } nni_aio_stop(tls->tcp_send); nni_aio_stop(tls->tcp_recv); - nni_aio_fini(tls->com.aio); - nng_tls_config_free(tls->com.cfg); // And finalize / free everything. nng_stream_free(tls->tcp); nni_aio_fini(tls->tcp_send); nni_aio_fini(tls->tcp_recv); mbedtls_ssl_free(&tls->ctx); + nng_tls_config_free(tls->com.cfg); + if (tls->recvbuf != NULL) { nni_free(tls->recvbuf, NNG_TLS_MAX_RECV_SIZE); } @@ -311,6 +306,7 @@ tls_free(void *arg) nni_free(tls->sendbuf, NNG_TLS_MAX_RECV_SIZE); } nni_mtx_fini(&tls->lk); + memset(tls, 0xff, sizeof(*tls)); NNI_FREE_STRUCT(tls); } } diff --git a/src/supplemental/tls/tls_common.c b/src/supplemental/tls/tls_common.c index 990d3add..97bcce26 100644 --- a/src/supplemental/tls/tls_common.c +++ b/src/supplemental/tls/tls_common.c @@ -160,14 +160,14 @@ tls_dialer_set_config(void *arg, const void *buf, size_t sz, nni_type t) if (cfg == NULL) { return (NNG_EINVAL); } - nni_mtx_lock(&d->lk); - old = d->cfg; nng_tls_config_hold(cfg); + + nni_mtx_lock(&d->lk); + old = d->cfg; d->cfg = cfg; nni_mtx_unlock(&d->lk); - if (old != NULL) { - nng_tls_config_free(old); - } + + nng_tls_config_free(old); return (0); } @@ -432,14 +432,15 @@ tls_listener_set_config(void *arg, const void *buf, size_t sz, nni_type t) return (NNG_EINVAL); } - nni_mtx_lock(&l->lk); - old = l->cfg; nng_tls_config_hold(cfg); + + nni_mtx_lock(&l->lk); + old = l->cfg; l->cfg = cfg; nni_mtx_unlock(&l->lk); - if (old != NULL) { - nng_tls_config_free(old); - } + + nng_tls_config_free(old); + return (0); } |
