From a2bc02d68272046b4db327a6407f06e216eb6c92 Mon Sep 17 00:00:00 2001 From: goldsimon Date: Thu, 23 Mar 2017 22:04:36 +0100 Subject: [PATCH] altcp_tls_mbedtls: improve sent/recved handling --- src/apps/altcp_tls/altcp_tls_mbedtls.c | 178 ++++++++++++------ .../altcp_tls/altcp_tls_mbedtls_structs.h | 13 +- 2 files changed, 133 insertions(+), 58 deletions(-) diff --git a/src/apps/altcp_tls/altcp_tls_mbedtls.c b/src/apps/altcp_tls/altcp_tls_mbedtls.c index 627b516a..e2902ae8 100644 --- a/src/apps/altcp_tls/altcp_tls_mbedtls.c +++ b/src/apps/altcp_tls/altcp_tls_mbedtls.c @@ -35,21 +35,22 @@ * * Author: Simon Goldschmidt * - * Missing things / @todo: - * - RX data is acknowledged after receiving (tcp_recved is called when enqueueing - * the pbuf for mbedTLS receive, not when processed by mbedTLS or the inner - * connection; altcp_recved() from inner connection does nothing) - * - TX data is marked as 'sent' (i.e. acknowledged; sent callback is called) right - * after enqueueing for transmission, not when actually ACKed be the remote host. - * - Client connections starting with 'connect()' are not handled yet... - * - some unhandled things are caught by LWIP_ASSERTs... - * - only one mbedTLS configuration is supported yet (i.e. one certificate, settings, etc.) + * Watch out: + * - 'sent' is always called with len==0 to the upper layer. This is because keeping + * track of the ratio of application data and TLS overhead would be too much. * - * Configuration: + * Mandatory security-related configuration: * - define ALTCP_MBEDTLS_RNG_FN to a custom GOOD rng function returning 0 on success: * int my_rng_fn(void *ctx, unsigned char *buffer , size_t len) * - define ALTCP_MBEDTLS_ENTROPY_PTR and ALTCP_MBEDTLS_ENTROPY_LEN to something providing * GOOD custom entropy + * + * Missing things / @todo: + * - RX data is acknowledged after receiving (tcp_recved is called when enqueueing + * the pbuf for mbedTLS receive, not when processed by mbedTLS or the inner + * connection; altcp_recved() from inner connection does nothing) + * - Client connections starting with 'connect()' are not handled yet... + * - some unhandled things are caught by LWIP_ASSERTs... */ #include "lwip/opt.h" @@ -107,7 +108,7 @@ struct altcp_tls_config static err_t altcp_mbedtls_setup(void *conf, struct altcp_pcb *conn, struct altcp_pcb *inner_conn); static void altcp_mbedtls_dealloc(struct altcp_pcb *conn); -static err_t altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn); +static err_t altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state); static int altcp_mbedtls_bio_send(void* ctx, const unsigned char* dataptr, size_t size); @@ -157,6 +158,17 @@ altcp_mbedtls_lower_connected(void *arg, struct altcp_pcb *inner_conn, err_t err return ERR_VAL; } +/* Call recved for possibly more than an u16_t */ +static void +altcp_mbedtls_lower_recved(struct altcp_pcb *inner_conn, int recvd_cnt) +{ + while (recvd_cnt > 0) { + u16_t recvd_part = (u16_t)LWIP_MIN(recvd_cnt, 0xFFFF); + altcp_recved(inner_conn, recvd_part); + recvd_cnt -= recvd_part; + } +} + /** Recv callback from lower connection (i.e. TCP) * This one mainly differs between connection setup/handshake (data is fed into mbedTLS only) * and application phase (data is decoded by mbedTLS and passed on to the application). @@ -231,14 +243,18 @@ altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) { /* handle connection setup (handshake not done) */ - int ret; + int ret = mbedtls_ssl_handshake(&state->ssl_context); + /* try to send data... */ + altcp_output(conn->inner_conn); + if (state->bio_bytes_read) { + /* acknowledge all bytes read */ + altcp_mbedtls_lower_recved(inner_conn, state->bio_bytes_read); + state->bio_bytes_read = 0; + } - /* during handshake: mark all data as received */ - altcp_recved(conn->inner_conn, p->tot_len); - - ret = mbedtls_ssl_handshake(&state->ssl_context); if(ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { /* handshake not done, wait for more recv calls */ + LWIP_ASSERT("in this state, the rx chain should be empty", state->rx == NULL); return ERR_OK; } if (ret != 0) { @@ -260,21 +276,48 @@ altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p return ERR_OK; } else { /* handle application data */ - /* @todo: call recved for unencrypted overhead only */ - altcp_recved(conn->inner_conn, p->tot_len); - return altcp_mbedtls_handle_rx_data(conn); + return altcp_mbedtls_handle_rx_data(conn, state); } } +/* Pass queued decoded rx data to application */ +static err_t +altcp_mbedtls_pass_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state) +{ + struct pbuf *buf; + LWIP_ASSERT("conn != NULL", conn != NULL); + LWIP_ASSERT("state != NULL", state != NULL); + buf = state->rx_app; + if (buf) { + if (conn->recv) { + u16_t tot_len = state->rx_app->tot_len; + err_t err; + /* this needs to be increased first because the 'recved' call may come nested */ + state->rx_passed_unrecved += tot_len; + err = conn->recv(conn->arg, conn, state->rx_app, ERR_OK); + if (err != ERR_OK) { + /* not received, leave the pbuf(s) queued (and decrease 'unrecved' again) */ + state->rx_passed_unrecved -= tot_len; + LWIP_ASSERT("state->rx_passed_unrecved >= 0", state->rx_passed_unrecved >= 0); + if (state->rx_passed_unrecved < 0) { + state->rx_passed_unrecved = 0; + } + return err; + } + } else { + pbuf_free(buf); + } + state->rx_app = NULL; + } + return ERR_OK; +} + /* Helper function that processes rx application data stored in rx pbuf chain */ static err_t -altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn) +altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state) { int ret; - altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state; - if (!state) { - return ERR_VAL; - } + LWIP_ASSERT("state != NULL", state != NULL); if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) { /* handshake not done yet */ return ERR_VAL; @@ -311,19 +354,37 @@ altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn) altcp_abort(conn); return ERR_ABRT; } else { - LWIP_ASSERT("bogus receive length", ret <= 0xFFFF && ret <= PBUF_POOL_BUFSIZE); + err_t err; + LWIP_ASSERT("bogus receive length", ret <= PBUF_POOL_BUFSIZE); /* trim pool pbuf to actually decoded length */ pbuf_realloc(buf, (uint16_t)ret); - if (conn->recv) { - err_t err; - state->rx_passed_unrecved += buf->tot_len; - err = conn->recv(conn->arg, conn, buf, ERR_OK); - if (err == ERR_ABRT) { - return ERR_ABRT; - } + state->bio_bytes_appl += ret; + if (mbedtls_ssl_get_bytes_avail(&state->ssl_context) == 0) { + /* Record is done, now we know the share between application and protocol bytes + and can adjust the RX window by the protocol bytes. + The rest is 'recved' by the application calling our 'recved' fn. */ + int overhead_bytes; + LWIP_ASSERT("bogus byte counts", state->bio_bytes_read > state->bio_bytes_appl); + overhead_bytes = state->bio_bytes_read - state->bio_bytes_appl; + altcp_mbedtls_lower_recved(conn->inner_conn, overhead_bytes); + state->bio_bytes_read = 0; + state->bio_bytes_appl = 0; + } + + if (state->rx_app == NULL) { + state->rx_app = buf; } else { - pbuf_free(buf); + pbuf_cat(state->rx_app, buf); + } + err = altcp_mbedtls_pass_rx_data(conn, state); + if (err != ERR_OK) { + if (err == ERR_ABRT) { + /* recv callback needs to return this as the pcb is deallocated */ + return err; + } + /* we hide all other errors as we retry feeding the pbuf to the app later */ + return ERR_OK; } } } @@ -341,8 +402,7 @@ altcp_mbedtls_bio_recv(void *ctx, unsigned char *buf, size_t len) altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state; struct pbuf* p; u16_t ret; - /* limit number of byts to copy to fit into an s16_t for pbuf_header */ - u16_t copy_len = (u16_t)LWIP_MIN(len, 0x7FFF); + u16_t copy_len; err_t err; if (state == NULL) { @@ -350,46 +410,49 @@ altcp_mbedtls_bio_recv(void *ctx, unsigned char *buf, size_t len) } p = state->rx; - LWIP_ASSERT("len is too big", len <= 0xFFFF); - if (p == NULL) { return MBEDTLS_ERR_SSL_WANT_READ; } + /* limit number of bytes to copy to fit into an s16_t for pbuf_header */ + copy_len = (u16_t)LWIP_MIN(len, 0x7FFF); + /* limit number of bytes again to copy from first pbuf in a chain only */ copy_len = LWIP_MIN(copy_len, p->len); + /* copy the data */ ret = pbuf_copy_partial(p, buf, copy_len, 0); - LWIP_ASSERT("ret <= p->len", ret <= p->len); + LWIP_ASSERT("ret == copy_len", ret == copy_len); + /* hide the copied bytes from the pbuf */ err = pbuf_header(p, -(s16_t)ret); LWIP_ASSERT("error", err == ERR_OK); - if(p->len == 0) { + if (p->len == 0) { + /* the first pbuf has been fully read, free it */ state->rx = p->next; p->next = NULL; pbuf_free(p); } + state->bio_bytes_read += (int)ret; return ret; } /** Sent callback from lower connection (i.e. TCP) - * @todo: Pass on the correct number of bytes to the application. - * This is somewhat tricky as we don't know the data/overhead ratio... + * This only informs the upper layer to try to send more, not about + * the number of ACKed bytes. */ static err_t altcp_mbedtls_lower_sent(void *arg, struct altcp_pcb *inner_conn, u16_t len) { struct altcp_pcb *conn = (struct altcp_pcb *)arg; + LWIP_UNUSED_ARG(len); if (conn) { - u16_t sent_upper; altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state; LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn); if (!state || !(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) { /* @todo: do something here? */ return ERR_OK; } - /* @todo: this is not accurate yet, need to fix byte counting to upper and lower conn */ - sent_upper = (u16_t)LWIP_MIN(len, state->tx_unacked); - state->tx_unacked -= sent_upper; - if (conn->sent && sent_upper) { - return conn->sent(conn->arg, conn, len); + /* call upper sent with len==0 if the application already sent data */ + if ((state->flags & ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT) && conn->sent) { + return conn->sent(conn->arg, conn, 0); } } return ERR_OK; @@ -406,7 +469,9 @@ altcp_mbedtls_lower_poll(void *arg, struct altcp_pcb *inner_conn) if (conn) { LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn); /* check if there's unreceived rx data */ - altcp_mbedtls_handle_rx_data(conn); + if (conn->state) { + altcp_mbedtls_handle_rx_data(conn, (altcp_mbedtls_state_t *)conn->state); + } if (conn->poll) { return conn->poll(conn->arg, conn); } @@ -624,6 +689,7 @@ altcp_mbedtls_set_poll(struct altcp_pcb *conn, u8_t interval) static void altcp_mbedtls_recved(struct altcp_pcb *conn, u16_t len) { + u16_t lower_recved; altcp_mbedtls_state_t *state; if (conn == NULL) { return; @@ -635,11 +701,15 @@ altcp_mbedtls_recved(struct altcp_pcb *conn, u16_t len) if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) { return; } - LWIP_ASSERT("recved mismatch", state->rx_passed_unrecved >= len); - state->rx_passed_unrecved -= len; + lower_recved = len; + if (lower_recved > state->rx_passed_unrecved) { + LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("bogus recved count (len > state->rx_passed_unrecved / %d / %d)", + len, state->rx_passed_unrecved)); + lower_recved = (u16_t)state->rx_passed_unrecved; + } + state->rx_passed_unrecved -= lower_recved; - /* to pass this down, we need to convert 'altcp_recved' handling in lower_recv first - altcp_recved(conn->inner_conn, len);*/ + altcp_recved(conn->inner_conn, lower_recved); } static err_t @@ -741,8 +811,10 @@ altcp_mbedtls_write(struct altcp_pcb *conn, const void *dataptr, u16_t len, u8_t } ret = mbedtls_ssl_write(&state->ssl_context, (const unsigned char *)dataptr, len); + /* try to send data... */ + altcp_output(conn->inner_conn); if(ret == len) { - state->tx_unacked += len; + state->flags |= ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT; return ERR_OK; } else if (ret <= 0) { /* @todo: convert error to err_t */ diff --git a/src/apps/altcp_tls/altcp_tls_mbedtls_structs.h b/src/apps/altcp_tls/altcp_tls_mbedtls_structs.h index 354808d6..12170b02 100644 --- a/src/apps/altcp_tls/altcp_tls_mbedtls_structs.h +++ b/src/apps/altcp_tls/altcp_tls_mbedtls_structs.h @@ -53,19 +53,22 @@ extern "C" { #endif #define ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE 0x01 -#define ALTCP_MBEDTLS_FLAGS_RX_CLOSED 0x02 -#define ALTCP_MBEDTLS_FLAGS_TX_CLOSED 0x04 +#define ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT 0x02 +#define ALTCP_MBEDTLS_FLAGS_RX_CLOSED 0x04 +#define ALTCP_MBEDTLS_FLAGS_TX_CLOSED 0x08 #define ALTCP_MBEDTLS_FLAGS_CLOSED (ALTCP_MBEDTLS_FLAGS_RX_CLOSED|ALTCP_MBEDTLS_FLAGS_TX_CLOSED) -#define ALTCP_MBEDTLS_FLAGS_UPPER_CALLED 0x08 +#define ALTCP_MBEDTLS_FLAGS_UPPER_CALLED 0x10 typedef struct altcp_mbedtls_state_s { void *conf; mbedtls_ssl_context ssl_context; /* chain of rx pbufs (before decryption) */ struct pbuf* rx; + struct pbuf* rx_app; u8_t flags; - size_t rx_passed_unrecved; - size_t tx_unacked; + int rx_passed_unrecved; + int bio_bytes_read; + int bio_bytes_appl; } altcp_mbedtls_state_t; #ifdef __cplusplus