altcp_tls_mbedtls: improve sent/recved handling

This commit is contained in:
goldsimon 2017-03-23 22:04:36 +01:00
parent 1ddb125e2c
commit a2bc02d682
2 changed files with 133 additions and 58 deletions

View File

@ -35,21 +35,22 @@
* *
* Author: Simon Goldschmidt <goldsimon@gmx.de> * Author: Simon Goldschmidt <goldsimon@gmx.de>
* *
* Missing things / @todo: * Watch out:
* - RX data is acknowledged after receiving (tcp_recved is called when enqueueing * - 'sent' is always called with len==0 to the upper layer. This is because keeping
* the pbuf for mbedTLS receive, not when processed by mbedTLS or the inner * track of the ratio of application data and TLS overhead would be too much.
* connection; altcp_recved() from inner connection does nothing)
* - TX data is marked as 'sent' (i.e. acknowledged; sent callback is called) right
* after enqueueing for transmission, not when actually ACKed be the remote host.
* - Client connections starting with 'connect()' are not handled yet...
* - some unhandled things are caught by LWIP_ASSERTs...
* - only one mbedTLS configuration is supported yet (i.e. one certificate, settings, etc.)
* *
* Configuration: * Mandatory security-related configuration:
* - define ALTCP_MBEDTLS_RNG_FN to a custom GOOD rng function returning 0 on success: * - define ALTCP_MBEDTLS_RNG_FN to a custom GOOD rng function returning 0 on success:
* int my_rng_fn(void *ctx, unsigned char *buffer , size_t len) * int my_rng_fn(void *ctx, unsigned char *buffer , size_t len)
* - define ALTCP_MBEDTLS_ENTROPY_PTR and ALTCP_MBEDTLS_ENTROPY_LEN to something providing * - define ALTCP_MBEDTLS_ENTROPY_PTR and ALTCP_MBEDTLS_ENTROPY_LEN to something providing
* GOOD custom entropy * GOOD custom entropy
*
* Missing things / @todo:
* - RX data is acknowledged after receiving (tcp_recved is called when enqueueing
* the pbuf for mbedTLS receive, not when processed by mbedTLS or the inner
* connection; altcp_recved() from inner connection does nothing)
* - Client connections starting with 'connect()' are not handled yet...
* - some unhandled things are caught by LWIP_ASSERTs...
*/ */
#include "lwip/opt.h" #include "lwip/opt.h"
@ -107,7 +108,7 @@ struct altcp_tls_config
static err_t altcp_mbedtls_setup(void *conf, struct altcp_pcb *conn, struct altcp_pcb *inner_conn); static err_t altcp_mbedtls_setup(void *conf, struct altcp_pcb *conn, struct altcp_pcb *inner_conn);
static void altcp_mbedtls_dealloc(struct altcp_pcb *conn); static void altcp_mbedtls_dealloc(struct altcp_pcb *conn);
static err_t altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn); static err_t altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state);
static int altcp_mbedtls_bio_send(void* ctx, const unsigned char* dataptr, size_t size); static int altcp_mbedtls_bio_send(void* ctx, const unsigned char* dataptr, size_t size);
@ -157,6 +158,17 @@ altcp_mbedtls_lower_connected(void *arg, struct altcp_pcb *inner_conn, err_t err
return ERR_VAL; return ERR_VAL;
} }
/* Call recved for possibly more than an u16_t */
static void
altcp_mbedtls_lower_recved(struct altcp_pcb *inner_conn, int recvd_cnt)
{
while (recvd_cnt > 0) {
u16_t recvd_part = (u16_t)LWIP_MIN(recvd_cnt, 0xFFFF);
altcp_recved(inner_conn, recvd_part);
recvd_cnt -= recvd_part;
}
}
/** Recv callback from lower connection (i.e. TCP) /** Recv callback from lower connection (i.e. TCP)
* This one mainly differs between connection setup/handshake (data is fed into mbedTLS only) * This one mainly differs between connection setup/handshake (data is fed into mbedTLS only)
* and application phase (data is decoded by mbedTLS and passed on to the application). * and application phase (data is decoded by mbedTLS and passed on to the application).
@ -231,14 +243,18 @@ altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p
if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) { if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
/* handle connection setup (handshake not done) */ /* handle connection setup (handshake not done) */
int ret; int ret = mbedtls_ssl_handshake(&state->ssl_context);
/* try to send data... */
altcp_output(conn->inner_conn);
if (state->bio_bytes_read) {
/* acknowledge all bytes read */
altcp_mbedtls_lower_recved(inner_conn, state->bio_bytes_read);
state->bio_bytes_read = 0;
}
/* during handshake: mark all data as received */
altcp_recved(conn->inner_conn, p->tot_len);
ret = mbedtls_ssl_handshake(&state->ssl_context);
if(ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) { if(ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
/* handshake not done, wait for more recv calls */ /* handshake not done, wait for more recv calls */
LWIP_ASSERT("in this state, the rx chain should be empty", state->rx == NULL);
return ERR_OK; return ERR_OK;
} }
if (ret != 0) { if (ret != 0) {
@ -260,21 +276,48 @@ altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p
return ERR_OK; return ERR_OK;
} else { } else {
/* handle application data */ /* handle application data */
/* @todo: call recved for unencrypted overhead only */ return altcp_mbedtls_handle_rx_data(conn, state);
altcp_recved(conn->inner_conn, p->tot_len);
return altcp_mbedtls_handle_rx_data(conn);
} }
} }
/* Pass queued decoded rx data to application */
static err_t
altcp_mbedtls_pass_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state)
{
struct pbuf *buf;
LWIP_ASSERT("conn != NULL", conn != NULL);
LWIP_ASSERT("state != NULL", state != NULL);
buf = state->rx_app;
if (buf) {
if (conn->recv) {
u16_t tot_len = state->rx_app->tot_len;
err_t err;
/* this needs to be increased first because the 'recved' call may come nested */
state->rx_passed_unrecved += tot_len;
err = conn->recv(conn->arg, conn, state->rx_app, ERR_OK);
if (err != ERR_OK) {
/* not received, leave the pbuf(s) queued (and decrease 'unrecved' again) */
state->rx_passed_unrecved -= tot_len;
LWIP_ASSERT("state->rx_passed_unrecved >= 0", state->rx_passed_unrecved >= 0);
if (state->rx_passed_unrecved < 0) {
state->rx_passed_unrecved = 0;
}
return err;
}
} else {
pbuf_free(buf);
}
state->rx_app = NULL;
}
return ERR_OK;
}
/* Helper function that processes rx application data stored in rx pbuf chain */ /* Helper function that processes rx application data stored in rx pbuf chain */
static err_t static err_t
altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn) altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state)
{ {
int ret; int ret;
altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state; LWIP_ASSERT("state != NULL", state != NULL);
if (!state) {
return ERR_VAL;
}
if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) { if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
/* handshake not done yet */ /* handshake not done yet */
return ERR_VAL; return ERR_VAL;
@ -311,19 +354,37 @@ altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn)
altcp_abort(conn); altcp_abort(conn);
return ERR_ABRT; return ERR_ABRT;
} else { } else {
LWIP_ASSERT("bogus receive length", ret <= 0xFFFF && ret <= PBUF_POOL_BUFSIZE); err_t err;
LWIP_ASSERT("bogus receive length", ret <= PBUF_POOL_BUFSIZE);
/* trim pool pbuf to actually decoded length */ /* trim pool pbuf to actually decoded length */
pbuf_realloc(buf, (uint16_t)ret); pbuf_realloc(buf, (uint16_t)ret);
if (conn->recv) { state->bio_bytes_appl += ret;
err_t err; if (mbedtls_ssl_get_bytes_avail(&state->ssl_context) == 0) {
state->rx_passed_unrecved += buf->tot_len; /* Record is done, now we know the share between application and protocol bytes
err = conn->recv(conn->arg, conn, buf, ERR_OK); and can adjust the RX window by the protocol bytes.
if (err == ERR_ABRT) { The rest is 'recved' by the application calling our 'recved' fn. */
return ERR_ABRT; int overhead_bytes;
} LWIP_ASSERT("bogus byte counts", state->bio_bytes_read > state->bio_bytes_appl);
overhead_bytes = state->bio_bytes_read - state->bio_bytes_appl;
altcp_mbedtls_lower_recved(conn->inner_conn, overhead_bytes);
state->bio_bytes_read = 0;
state->bio_bytes_appl = 0;
}
if (state->rx_app == NULL) {
state->rx_app = buf;
} else { } else {
pbuf_free(buf); pbuf_cat(state->rx_app, buf);
}
err = altcp_mbedtls_pass_rx_data(conn, state);
if (err != ERR_OK) {
if (err == ERR_ABRT) {
/* recv callback needs to return this as the pcb is deallocated */
return err;
}
/* we hide all other errors as we retry feeding the pbuf to the app later */
return ERR_OK;
} }
} }
} }
@ -341,8 +402,7 @@ altcp_mbedtls_bio_recv(void *ctx, unsigned char *buf, size_t len)
altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state; altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
struct pbuf* p; struct pbuf* p;
u16_t ret; u16_t ret;
/* limit number of byts to copy to fit into an s16_t for pbuf_header */ u16_t copy_len;
u16_t copy_len = (u16_t)LWIP_MIN(len, 0x7FFF);
err_t err; err_t err;
if (state == NULL) { if (state == NULL) {
@ -350,46 +410,49 @@ altcp_mbedtls_bio_recv(void *ctx, unsigned char *buf, size_t len)
} }
p = state->rx; p = state->rx;
LWIP_ASSERT("len is too big", len <= 0xFFFF);
if (p == NULL) { if (p == NULL) {
return MBEDTLS_ERR_SSL_WANT_READ; return MBEDTLS_ERR_SSL_WANT_READ;
} }
/* limit number of bytes to copy to fit into an s16_t for pbuf_header */
copy_len = (u16_t)LWIP_MIN(len, 0x7FFF);
/* limit number of bytes again to copy from first pbuf in a chain only */
copy_len = LWIP_MIN(copy_len, p->len); copy_len = LWIP_MIN(copy_len, p->len);
/* copy the data */
ret = pbuf_copy_partial(p, buf, copy_len, 0); ret = pbuf_copy_partial(p, buf, copy_len, 0);
LWIP_ASSERT("ret <= p->len", ret <= p->len); LWIP_ASSERT("ret == copy_len", ret == copy_len);
/* hide the copied bytes from the pbuf */
err = pbuf_header(p, -(s16_t)ret); err = pbuf_header(p, -(s16_t)ret);
LWIP_ASSERT("error", err == ERR_OK); LWIP_ASSERT("error", err == ERR_OK);
if(p->len == 0) { if (p->len == 0) {
/* the first pbuf has been fully read, free it */
state->rx = p->next; state->rx = p->next;
p->next = NULL; p->next = NULL;
pbuf_free(p); pbuf_free(p);
} }
state->bio_bytes_read += (int)ret;
return ret; return ret;
} }
/** Sent callback from lower connection (i.e. TCP) /** Sent callback from lower connection (i.e. TCP)
* @todo: Pass on the correct number of bytes to the application. * This only informs the upper layer to try to send more, not about
* This is somewhat tricky as we don't know the data/overhead ratio... * the number of ACKed bytes.
*/ */
static err_t static err_t
altcp_mbedtls_lower_sent(void *arg, struct altcp_pcb *inner_conn, u16_t len) altcp_mbedtls_lower_sent(void *arg, struct altcp_pcb *inner_conn, u16_t len)
{ {
struct altcp_pcb *conn = (struct altcp_pcb *)arg; struct altcp_pcb *conn = (struct altcp_pcb *)arg;
LWIP_UNUSED_ARG(len);
if (conn) { if (conn) {
u16_t sent_upper;
altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state; altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn); LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn);
if (!state || !(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) { if (!state || !(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
/* @todo: do something here? */ /* @todo: do something here? */
return ERR_OK; return ERR_OK;
} }
/* @todo: this is not accurate yet, need to fix byte counting to upper and lower conn */ /* call upper sent with len==0 if the application already sent data */
sent_upper = (u16_t)LWIP_MIN(len, state->tx_unacked); if ((state->flags & ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT) && conn->sent) {
state->tx_unacked -= sent_upper; return conn->sent(conn->arg, conn, 0);
if (conn->sent && sent_upper) {
return conn->sent(conn->arg, conn, len);
} }
} }
return ERR_OK; return ERR_OK;
@ -406,7 +469,9 @@ altcp_mbedtls_lower_poll(void *arg, struct altcp_pcb *inner_conn)
if (conn) { if (conn) {
LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn); LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn);
/* check if there's unreceived rx data */ /* check if there's unreceived rx data */
altcp_mbedtls_handle_rx_data(conn); if (conn->state) {
altcp_mbedtls_handle_rx_data(conn, (altcp_mbedtls_state_t *)conn->state);
}
if (conn->poll) { if (conn->poll) {
return conn->poll(conn->arg, conn); return conn->poll(conn->arg, conn);
} }
@ -624,6 +689,7 @@ altcp_mbedtls_set_poll(struct altcp_pcb *conn, u8_t interval)
static void static void
altcp_mbedtls_recved(struct altcp_pcb *conn, u16_t len) altcp_mbedtls_recved(struct altcp_pcb *conn, u16_t len)
{ {
u16_t lower_recved;
altcp_mbedtls_state_t *state; altcp_mbedtls_state_t *state;
if (conn == NULL) { if (conn == NULL) {
return; return;
@ -635,11 +701,15 @@ altcp_mbedtls_recved(struct altcp_pcb *conn, u16_t len)
if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) { if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
return; return;
} }
LWIP_ASSERT("recved mismatch", state->rx_passed_unrecved >= len); lower_recved = len;
state->rx_passed_unrecved -= len; if (lower_recved > state->rx_passed_unrecved) {
LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("bogus recved count (len > state->rx_passed_unrecved / %d / %d)",
len, state->rx_passed_unrecved));
lower_recved = (u16_t)state->rx_passed_unrecved;
}
state->rx_passed_unrecved -= lower_recved;
/* to pass this down, we need to convert 'altcp_recved' handling in lower_recv first altcp_recved(conn->inner_conn, lower_recved);
altcp_recved(conn->inner_conn, len);*/
} }
static err_t static err_t
@ -741,8 +811,10 @@ altcp_mbedtls_write(struct altcp_pcb *conn, const void *dataptr, u16_t len, u8_t
} }
ret = mbedtls_ssl_write(&state->ssl_context, (const unsigned char *)dataptr, len); ret = mbedtls_ssl_write(&state->ssl_context, (const unsigned char *)dataptr, len);
/* try to send data... */
altcp_output(conn->inner_conn);
if(ret == len) { if(ret == len) {
state->tx_unacked += len; state->flags |= ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT;
return ERR_OK; return ERR_OK;
} else if (ret <= 0) { } else if (ret <= 0) {
/* @todo: convert error to err_t */ /* @todo: convert error to err_t */

View File

@ -53,19 +53,22 @@ extern "C" {
#endif #endif
#define ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE 0x01 #define ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE 0x01
#define ALTCP_MBEDTLS_FLAGS_RX_CLOSED 0x02 #define ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT 0x02
#define ALTCP_MBEDTLS_FLAGS_TX_CLOSED 0x04 #define ALTCP_MBEDTLS_FLAGS_RX_CLOSED 0x04
#define ALTCP_MBEDTLS_FLAGS_TX_CLOSED 0x08
#define ALTCP_MBEDTLS_FLAGS_CLOSED (ALTCP_MBEDTLS_FLAGS_RX_CLOSED|ALTCP_MBEDTLS_FLAGS_TX_CLOSED) #define ALTCP_MBEDTLS_FLAGS_CLOSED (ALTCP_MBEDTLS_FLAGS_RX_CLOSED|ALTCP_MBEDTLS_FLAGS_TX_CLOSED)
#define ALTCP_MBEDTLS_FLAGS_UPPER_CALLED 0x08 #define ALTCP_MBEDTLS_FLAGS_UPPER_CALLED 0x10
typedef struct altcp_mbedtls_state_s { typedef struct altcp_mbedtls_state_s {
void *conf; void *conf;
mbedtls_ssl_context ssl_context; mbedtls_ssl_context ssl_context;
/* chain of rx pbufs (before decryption) */ /* chain of rx pbufs (before decryption) */
struct pbuf* rx; struct pbuf* rx;
struct pbuf* rx_app;
u8_t flags; u8_t flags;
size_t rx_passed_unrecved; int rx_passed_unrecved;
size_t tx_unacked; int bio_bytes_read;
int bio_bytes_appl;
} altcp_mbedtls_state_t; } altcp_mbedtls_state_t;
#ifdef __cplusplus #ifdef __cplusplus