altcp_tls_mbedtls: improve sent/recved handling

This commit is contained in:
goldsimon 2017-03-23 22:04:36 +01:00
parent 1ddb125e2c
commit a2bc02d682
2 changed files with 133 additions and 58 deletions

View File

@ -35,21 +35,22 @@
*
* Author: Simon Goldschmidt <goldsimon@gmx.de>
*
* Missing things / @todo:
* - RX data is acknowledged after receiving (tcp_recved is called when enqueueing
* the pbuf for mbedTLS receive, not when processed by mbedTLS or the inner
* connection; altcp_recved() from inner connection does nothing)
* - TX data is marked as 'sent' (i.e. acknowledged; sent callback is called) right
* after enqueueing for transmission, not when actually ACKed be the remote host.
* - Client connections starting with 'connect()' are not handled yet...
* - some unhandled things are caught by LWIP_ASSERTs...
* - only one mbedTLS configuration is supported yet (i.e. one certificate, settings, etc.)
* Watch out:
* - 'sent' is always called with len==0 to the upper layer. This is because keeping
* track of the ratio of application data and TLS overhead would be too much.
*
* Configuration:
* Mandatory security-related configuration:
* - define ALTCP_MBEDTLS_RNG_FN to a custom GOOD rng function returning 0 on success:
* int my_rng_fn(void *ctx, unsigned char *buffer , size_t len)
* - define ALTCP_MBEDTLS_ENTROPY_PTR and ALTCP_MBEDTLS_ENTROPY_LEN to something providing
* GOOD custom entropy
*
* Missing things / @todo:
* - RX data is acknowledged after receiving (tcp_recved is called when enqueueing
* the pbuf for mbedTLS receive, not when processed by mbedTLS or the inner
* connection; altcp_recved() from inner connection does nothing)
* - Client connections starting with 'connect()' are not handled yet...
* - some unhandled things are caught by LWIP_ASSERTs...
*/
#include "lwip/opt.h"
@ -107,7 +108,7 @@ struct altcp_tls_config
static err_t altcp_mbedtls_setup(void *conf, struct altcp_pcb *conn, struct altcp_pcb *inner_conn);
static void altcp_mbedtls_dealloc(struct altcp_pcb *conn);
static err_t altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn);
static err_t altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state);
static int altcp_mbedtls_bio_send(void* ctx, const unsigned char* dataptr, size_t size);
@ -157,6 +158,17 @@ altcp_mbedtls_lower_connected(void *arg, struct altcp_pcb *inner_conn, err_t err
return ERR_VAL;
}
/* Call recved for possibly more than an u16_t */
static void
altcp_mbedtls_lower_recved(struct altcp_pcb *inner_conn, int recvd_cnt)
{
while (recvd_cnt > 0) {
u16_t recvd_part = (u16_t)LWIP_MIN(recvd_cnt, 0xFFFF);
altcp_recved(inner_conn, recvd_part);
recvd_cnt -= recvd_part;
}
}
/** Recv callback from lower connection (i.e. TCP)
* This one mainly differs between connection setup/handshake (data is fed into mbedTLS only)
* and application phase (data is decoded by mbedTLS and passed on to the application).
@ -231,14 +243,18 @@ altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p
if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
/* handle connection setup (handshake not done) */
int ret;
int ret = mbedtls_ssl_handshake(&state->ssl_context);
/* try to send data... */
altcp_output(conn->inner_conn);
if (state->bio_bytes_read) {
/* acknowledge all bytes read */
altcp_mbedtls_lower_recved(inner_conn, state->bio_bytes_read);
state->bio_bytes_read = 0;
}
/* during handshake: mark all data as received */
altcp_recved(conn->inner_conn, p->tot_len);
ret = mbedtls_ssl_handshake(&state->ssl_context);
if(ret == MBEDTLS_ERR_SSL_WANT_READ || ret == MBEDTLS_ERR_SSL_WANT_WRITE) {
/* handshake not done, wait for more recv calls */
LWIP_ASSERT("in this state, the rx chain should be empty", state->rx == NULL);
return ERR_OK;
}
if (ret != 0) {
@ -260,21 +276,48 @@ altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p
return ERR_OK;
} else {
/* handle application data */
/* @todo: call recved for unencrypted overhead only */
altcp_recved(conn->inner_conn, p->tot_len);
return altcp_mbedtls_handle_rx_data(conn);
return altcp_mbedtls_handle_rx_data(conn, state);
}
}
/* Pass queued decoded rx data to application */
static err_t
altcp_mbedtls_pass_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state)
{
struct pbuf *buf;
LWIP_ASSERT("conn != NULL", conn != NULL);
LWIP_ASSERT("state != NULL", state != NULL);
buf = state->rx_app;
if (buf) {
if (conn->recv) {
u16_t tot_len = state->rx_app->tot_len;
err_t err;
/* this needs to be increased first because the 'recved' call may come nested */
state->rx_passed_unrecved += tot_len;
err = conn->recv(conn->arg, conn, state->rx_app, ERR_OK);
if (err != ERR_OK) {
/* not received, leave the pbuf(s) queued (and decrease 'unrecved' again) */
state->rx_passed_unrecved -= tot_len;
LWIP_ASSERT("state->rx_passed_unrecved >= 0", state->rx_passed_unrecved >= 0);
if (state->rx_passed_unrecved < 0) {
state->rx_passed_unrecved = 0;
}
return err;
}
} else {
pbuf_free(buf);
}
state->rx_app = NULL;
}
return ERR_OK;
}
/* Helper function that processes rx application data stored in rx pbuf chain */
static err_t
altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn)
altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state)
{
int ret;
altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
if (!state) {
return ERR_VAL;
}
LWIP_ASSERT("state != NULL", state != NULL);
if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
/* handshake not done yet */
return ERR_VAL;
@ -311,19 +354,37 @@ altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn)
altcp_abort(conn);
return ERR_ABRT;
} else {
LWIP_ASSERT("bogus receive length", ret <= 0xFFFF && ret <= PBUF_POOL_BUFSIZE);
err_t err;
LWIP_ASSERT("bogus receive length", ret <= PBUF_POOL_BUFSIZE);
/* trim pool pbuf to actually decoded length */
pbuf_realloc(buf, (uint16_t)ret);
if (conn->recv) {
err_t err;
state->rx_passed_unrecved += buf->tot_len;
err = conn->recv(conn->arg, conn, buf, ERR_OK);
if (err == ERR_ABRT) {
return ERR_ABRT;
state->bio_bytes_appl += ret;
if (mbedtls_ssl_get_bytes_avail(&state->ssl_context) == 0) {
/* Record is done, now we know the share between application and protocol bytes
and can adjust the RX window by the protocol bytes.
The rest is 'recved' by the application calling our 'recved' fn. */
int overhead_bytes;
LWIP_ASSERT("bogus byte counts", state->bio_bytes_read > state->bio_bytes_appl);
overhead_bytes = state->bio_bytes_read - state->bio_bytes_appl;
altcp_mbedtls_lower_recved(conn->inner_conn, overhead_bytes);
state->bio_bytes_read = 0;
state->bio_bytes_appl = 0;
}
if (state->rx_app == NULL) {
state->rx_app = buf;
} else {
pbuf_free(buf);
pbuf_cat(state->rx_app, buf);
}
err = altcp_mbedtls_pass_rx_data(conn, state);
if (err != ERR_OK) {
if (err == ERR_ABRT) {
/* recv callback needs to return this as the pcb is deallocated */
return err;
}
/* we hide all other errors as we retry feeding the pbuf to the app later */
return ERR_OK;
}
}
}
@ -341,8 +402,7 @@ altcp_mbedtls_bio_recv(void *ctx, unsigned char *buf, size_t len)
altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
struct pbuf* p;
u16_t ret;
/* limit number of byts to copy to fit into an s16_t for pbuf_header */
u16_t copy_len = (u16_t)LWIP_MIN(len, 0x7FFF);
u16_t copy_len;
err_t err;
if (state == NULL) {
@ -350,46 +410,49 @@ altcp_mbedtls_bio_recv(void *ctx, unsigned char *buf, size_t len)
}
p = state->rx;
LWIP_ASSERT("len is too big", len <= 0xFFFF);
if (p == NULL) {
return MBEDTLS_ERR_SSL_WANT_READ;
}
/* limit number of bytes to copy to fit into an s16_t for pbuf_header */
copy_len = (u16_t)LWIP_MIN(len, 0x7FFF);
/* limit number of bytes again to copy from first pbuf in a chain only */
copy_len = LWIP_MIN(copy_len, p->len);
/* copy the data */
ret = pbuf_copy_partial(p, buf, copy_len, 0);
LWIP_ASSERT("ret <= p->len", ret <= p->len);
LWIP_ASSERT("ret == copy_len", ret == copy_len);
/* hide the copied bytes from the pbuf */
err = pbuf_header(p, -(s16_t)ret);
LWIP_ASSERT("error", err == ERR_OK);
if(p->len == 0) {
if (p->len == 0) {
/* the first pbuf has been fully read, free it */
state->rx = p->next;
p->next = NULL;
pbuf_free(p);
}
state->bio_bytes_read += (int)ret;
return ret;
}
/** Sent callback from lower connection (i.e. TCP)
* @todo: Pass on the correct number of bytes to the application.
* This is somewhat tricky as we don't know the data/overhead ratio...
* This only informs the upper layer to try to send more, not about
* the number of ACKed bytes.
*/
static err_t
altcp_mbedtls_lower_sent(void *arg, struct altcp_pcb *inner_conn, u16_t len)
{
struct altcp_pcb *conn = (struct altcp_pcb *)arg;
LWIP_UNUSED_ARG(len);
if (conn) {
u16_t sent_upper;
altcp_mbedtls_state_t *state = (altcp_mbedtls_state_t *)conn->state;
LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn);
if (!state || !(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
/* @todo: do something here? */
return ERR_OK;
}
/* @todo: this is not accurate yet, need to fix byte counting to upper and lower conn */
sent_upper = (u16_t)LWIP_MIN(len, state->tx_unacked);
state->tx_unacked -= sent_upper;
if (conn->sent && sent_upper) {
return conn->sent(conn->arg, conn, len);
/* call upper sent with len==0 if the application already sent data */
if ((state->flags & ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT) && conn->sent) {
return conn->sent(conn->arg, conn, 0);
}
}
return ERR_OK;
@ -406,7 +469,9 @@ altcp_mbedtls_lower_poll(void *arg, struct altcp_pcb *inner_conn)
if (conn) {
LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn);
/* check if there's unreceived rx data */
altcp_mbedtls_handle_rx_data(conn);
if (conn->state) {
altcp_mbedtls_handle_rx_data(conn, (altcp_mbedtls_state_t *)conn->state);
}
if (conn->poll) {
return conn->poll(conn->arg, conn);
}
@ -624,6 +689,7 @@ altcp_mbedtls_set_poll(struct altcp_pcb *conn, u8_t interval)
static void
altcp_mbedtls_recved(struct altcp_pcb *conn, u16_t len)
{
u16_t lower_recved;
altcp_mbedtls_state_t *state;
if (conn == NULL) {
return;
@ -635,11 +701,15 @@ altcp_mbedtls_recved(struct altcp_pcb *conn, u16_t len)
if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) {
return;
}
LWIP_ASSERT("recved mismatch", state->rx_passed_unrecved >= len);
state->rx_passed_unrecved -= len;
lower_recved = len;
if (lower_recved > state->rx_passed_unrecved) {
LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("bogus recved count (len > state->rx_passed_unrecved / %d / %d)",
len, state->rx_passed_unrecved));
lower_recved = (u16_t)state->rx_passed_unrecved;
}
state->rx_passed_unrecved -= lower_recved;
/* to pass this down, we need to convert 'altcp_recved' handling in lower_recv first
altcp_recved(conn->inner_conn, len);*/
altcp_recved(conn->inner_conn, lower_recved);
}
static err_t
@ -741,8 +811,10 @@ altcp_mbedtls_write(struct altcp_pcb *conn, const void *dataptr, u16_t len, u8_t
}
ret = mbedtls_ssl_write(&state->ssl_context, (const unsigned char *)dataptr, len);
/* try to send data... */
altcp_output(conn->inner_conn);
if(ret == len) {
state->tx_unacked += len;
state->flags |= ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT;
return ERR_OK;
} else if (ret <= 0) {
/* @todo: convert error to err_t */

View File

@ -53,19 +53,22 @@ extern "C" {
#endif
#define ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE 0x01
#define ALTCP_MBEDTLS_FLAGS_RX_CLOSED 0x02
#define ALTCP_MBEDTLS_FLAGS_TX_CLOSED 0x04
#define ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT 0x02
#define ALTCP_MBEDTLS_FLAGS_RX_CLOSED 0x04
#define ALTCP_MBEDTLS_FLAGS_TX_CLOSED 0x08
#define ALTCP_MBEDTLS_FLAGS_CLOSED (ALTCP_MBEDTLS_FLAGS_RX_CLOSED|ALTCP_MBEDTLS_FLAGS_TX_CLOSED)
#define ALTCP_MBEDTLS_FLAGS_UPPER_CALLED 0x08
#define ALTCP_MBEDTLS_FLAGS_UPPER_CALLED 0x10
typedef struct altcp_mbedtls_state_s {
void *conf;
mbedtls_ssl_context ssl_context;
/* chain of rx pbufs (before decryption) */
struct pbuf* rx;
struct pbuf* rx_app;
u8_t flags;
size_t rx_passed_unrecved;
size_t tx_unacked;
int rx_passed_unrecved;
int bio_bytes_read;
int bio_bytes_appl;
} altcp_mbedtls_state_t;
#ifdef __cplusplus