diff --git a/src/apps/altcp_tls/altcp_tls_mbedtls.c b/src/apps/altcp_tls/altcp_tls_mbedtls.c index e2902ae8..1363f00d 100644 --- a/src/apps/altcp_tls/altcp_tls_mbedtls.c +++ b/src/apps/altcp_tls/altcp_tls_mbedtls.c @@ -106,9 +106,11 @@ struct altcp_tls_config #endif }; +static err_t altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p, err_t err); static err_t altcp_mbedtls_setup(void *conf, struct altcp_pcb *conn, struct altcp_pcb *inner_conn); static void altcp_mbedtls_dealloc(struct altcp_pcb *conn); -static err_t altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state); +static err_t altcp_mbedtls_lower_recv_process(struct altcp_pcb *conn, altcp_mbedtls_state_t *state); +static err_t altcp_mbedtls_handle_rx_appldata(struct altcp_pcb *conn, altcp_mbedtls_state_t *state); static int altcp_mbedtls_bio_send(void* ctx, const unsigned char* dataptr, size_t size); @@ -148,12 +150,18 @@ static err_t altcp_mbedtls_lower_connected(void *arg, struct altcp_pcb *inner_conn, err_t err) { struct altcp_pcb *conn = (struct altcp_pcb *)arg; - if (conn) { + if (conn && conn->state) { LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn); /* upper connected is called when handshake is done */ - LWIP_UNUSED_ARG(err); - LWIP_ASSERT("TODO: implement active connect", 0); - return ERR_OK; + if (err != ERR_OK) { + if (conn->connected) { + if (conn->connected(conn->arg, conn, err) == ERR_ABRT) { + return ERR_ABRT; + } + return ERR_OK; + } + } + return altcp_mbedtls_lower_recv_process(conn, (altcp_mbedtls_state_t*)conn->state); } return ERR_VAL; } @@ -203,29 +211,34 @@ altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p if (p == NULL) { /* remote host sent FIN, remember this (SSL state is destroyed when both sides are closed only!) */ - state->flags |= ALTCP_MBEDTLS_FLAGS_RX_CLOSED; + state->flags |= ALTCP_MBEDTLS_FLAGS_RX_CLOSE_QUEUED; } - if (state->flags & ALTCP_MBEDTLS_FLAGS_UPPER_CALLED) { + if ((state->flags & (ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE|ALTCP_MBEDTLS_FLAGS_UPPER_CALLED)) == + (ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE|ALTCP_MBEDTLS_FLAGS_UPPER_CALLED)) { /* need to notify upper layer (e.g. 'accept' called or 'connect' succeeded) */ + if ((err == ERR_OK) && ((state->rx != NULL) || (state->rx_app != NULL))) { + LWIP_ASSERT("p == NULL", p == NULL); + /* this is a normal close (FIN) but we have unprocessed data */ + altcp_mbedtls_handle_rx_appldata(conn, state); + return ERR_OK; + } if (conn->recv) { local_err = conn->recv(conn->arg, conn, p, err); - } else { - /* no recv callback? close connection */ - if (p) { + if ((local_err != ERR_OK) && (p != NULL)) { pbuf_free(p); } - altcp_close(conn); + p = NULL; } } else { /* before connection setup is done: call 'err' */ - if (p) { - pbuf_free(p); - } if (conn->err) { conn->err(conn->arg, ERR_CLSD); } - altcp_close(conn); } + if (p) { + pbuf_free(p); + } + altcp_close(conn); if (conn->state && ((state->flags & ALTCP_MBEDTLS_FLAGS_CLOSED) == ALTCP_MBEDTLS_FLAGS_CLOSED)) { altcp_mbedtls_dealloc(conn); } @@ -240,7 +253,12 @@ altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p LWIP_ASSERT("rx pbuf overflow", (int)p->tot_len + (int)p->len <= 0xFFFF); pbuf_cat(state->rx, p); } + return altcp_mbedtls_lower_recv_process(conn, state); +} +static err_t +altcp_mbedtls_lower_recv_process(struct altcp_pcb *conn, altcp_mbedtls_state_t *state) +{ if (!(state->flags & ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE)) { /* handle connection setup (handshake not done) */ int ret = mbedtls_ssl_handshake(&state->ssl_context); @@ -248,7 +266,7 @@ altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p altcp_output(conn->inner_conn); if (state->bio_bytes_read) { /* acknowledge all bytes read */ - altcp_mbedtls_lower_recved(inner_conn, state->bio_bytes_read); + altcp_mbedtls_lower_recved(conn->inner_conn, state->bio_bytes_read); state->bio_bytes_read = 0; } @@ -276,7 +294,7 @@ altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p return ERR_OK; } else { /* handle application data */ - return altcp_mbedtls_handle_rx_data(conn, state); + return altcp_mbedtls_handle_rx_appldata(conn, state); } } @@ -284,6 +302,7 @@ altcp_mbedtls_lower_recv(void *arg, struct altcp_pcb *inner_conn, struct pbuf *p static err_t altcp_mbedtls_pass_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state) { + err_t err; struct pbuf *buf; LWIP_ASSERT("conn != NULL", conn != NULL); LWIP_ASSERT("state != NULL", state != NULL); @@ -291,9 +310,9 @@ altcp_mbedtls_pass_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state) if (buf) { if (conn->recv) { u16_t tot_len = state->rx_app->tot_len; - err_t err; /* this needs to be increased first because the 'recved' call may come nested */ state->rx_passed_unrecved += tot_len; + state->flags |= ALTCP_MBEDTLS_FLAGS_UPPER_CALLED; err = conn->recv(conn->arg, conn, state->rx_app, ERR_OK); if (err != ERR_OK) { /* not received, leave the pbuf(s) queued (and decrease 'unrecved' again) */ @@ -308,13 +327,23 @@ altcp_mbedtls_pass_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state) pbuf_free(buf); } state->rx_app = NULL; + } else if ((state->flags & (ALTCP_MBEDTLS_FLAGS_RX_CLOSE_QUEUED|ALTCP_MBEDTLS_FLAGS_RX_CLOSED)) == + ALTCP_MBEDTLS_FLAGS_RX_CLOSE_QUEUED) { + state->flags |= ALTCP_MBEDTLS_FLAGS_RX_CLOSED; + if (conn->recv) { + err = conn->recv(conn->arg, conn, NULL, ERR_OK); + if (err == ERR_ABRT) { + return ERR_ABRT; + } + } } + return ERR_OK; } /* Helper function that processes rx application data stored in rx pbuf chain */ static err_t -altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *state) +altcp_mbedtls_handle_rx_appldata(struct altcp_pcb *conn, altcp_mbedtls_state_t *state) { int ret; LWIP_ASSERT("state != NULL", state != NULL); @@ -355,27 +384,32 @@ altcp_mbedtls_handle_rx_data(struct altcp_pcb *conn, altcp_mbedtls_state_t *stat return ERR_ABRT; } else { err_t err; - LWIP_ASSERT("bogus receive length", ret <= PBUF_POOL_BUFSIZE); - /* trim pool pbuf to actually decoded length */ - pbuf_realloc(buf, (uint16_t)ret); + if (ret) { + LWIP_ASSERT("bogus receive length", ret <= PBUF_POOL_BUFSIZE); + /* trim pool pbuf to actually decoded length */ + pbuf_realloc(buf, (uint16_t)ret); - state->bio_bytes_appl += ret; - if (mbedtls_ssl_get_bytes_avail(&state->ssl_context) == 0) { - /* Record is done, now we know the share between application and protocol bytes - and can adjust the RX window by the protocol bytes. - The rest is 'recved' by the application calling our 'recved' fn. */ - int overhead_bytes; - LWIP_ASSERT("bogus byte counts", state->bio_bytes_read > state->bio_bytes_appl); - overhead_bytes = state->bio_bytes_read - state->bio_bytes_appl; - altcp_mbedtls_lower_recved(conn->inner_conn, overhead_bytes); - state->bio_bytes_read = 0; - state->bio_bytes_appl = 0; - } + state->bio_bytes_appl += ret; + if (mbedtls_ssl_get_bytes_avail(&state->ssl_context) == 0) { + /* Record is done, now we know the share between application and protocol bytes + and can adjust the RX window by the protocol bytes. + The rest is 'recved' by the application calling our 'recved' fn. */ + int overhead_bytes; + LWIP_ASSERT("bogus byte counts", state->bio_bytes_read > state->bio_bytes_appl); + overhead_bytes = state->bio_bytes_read - state->bio_bytes_appl; + altcp_mbedtls_lower_recved(conn->inner_conn, overhead_bytes); + state->bio_bytes_read = 0; + state->bio_bytes_appl = 0; + } - if (state->rx_app == NULL) { - state->rx_app = buf; + if (state->rx_app == NULL) { + state->rx_app = buf; + } else { + pbuf_cat(state->rx_app, buf); + } } else { - pbuf_cat(state->rx_app, buf); + pbuf_free(buf); + buf = NULL; } err = altcp_mbedtls_pass_rx_data(conn, state); if (err != ERR_OK) { @@ -410,7 +444,16 @@ altcp_mbedtls_bio_recv(void *ctx, unsigned char *buf, size_t len) } p = state->rx; - if (p == NULL) { + if ((p == NULL) || ((p->len == 0) && (p->next == NULL))) { + if (p) { + pbuf_free(p); + } + state->rx = NULL; + if ((state->flags & (ALTCP_MBEDTLS_FLAGS_RX_CLOSE_QUEUED|ALTCP_MBEDTLS_FLAGS_RX_CLOSED)) == + ALTCP_MBEDTLS_FLAGS_RX_CLOSE_QUEUED) { + /* close queued but not passed up yet */ + return 0; + } return MBEDTLS_ERR_SSL_WANT_READ; } /* limit number of bytes to copy to fit into an s16_t for pbuf_header */ @@ -470,7 +513,7 @@ altcp_mbedtls_lower_poll(void *arg, struct altcp_pcb *inner_conn) LWIP_ASSERT("pcb mismatch", conn->inner_conn == inner_conn); /* check if there's unreceived rx data */ if (conn->state) { - altcp_mbedtls_handle_rx_data(conn, (altcp_mbedtls_state_t *)conn->state); + altcp_mbedtls_handle_rx_appldata(conn, (altcp_mbedtls_state_t *)conn->state); } if (conn->poll) { return conn->poll(conn->arg, conn); @@ -585,8 +628,8 @@ dummy_rng(void *ctx, unsigned char *buffer , size_t len) /** Create new TLS configuration * ATTENTION: Server certificate and private key have to be added outside this function! */ -struct altcp_tls_config* -altcp_tls_create_config(void) +static struct altcp_tls_config * +altcp_tls_create_config(int is_server) { int ret; struct altcp_tls_config *conf; @@ -611,13 +654,14 @@ altcp_tls_create_config(void) } /* Setup ssl context (@todo: what's different for a client here? -> might better be done on listen/connect) */ - ret = mbedtls_ssl_config_defaults(&conf->conf, MBEDTLS_SSL_IS_SERVER, + ret = mbedtls_ssl_config_defaults(&conf->conf, is_server ? MBEDTLS_SSL_IS_SERVER : MBEDTLS_SSL_IS_CLIENT, MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT); if (ret != 0) { LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_ssl_config_defaults failed: %d", ret)); altcp_mbedtls_free_config(conf); return NULL; } + mbedtls_ssl_conf_authmode(&conf->conf, MBEDTLS_SSL_VERIFY_OPTIONAL); mbedtls_ssl_conf_rng(&conf->conf, mbedtls_ctr_drbg_random, &conf->ctr_drbg); #if ALTCP_MBEDTLS_DEBUG != LWIP_DBG_OFF @@ -636,15 +680,15 @@ altcp_tls_create_config(void) * This is a suboptimal version that gets the encrypted private key and its password, * as well as the server certificate. */ -struct altcp_tls_config* -altcp_tls_create_config_privkey_cert(const u8_t *privkey, size_t privkey_len, +struct altcp_tls_config * +altcp_tls_create_config_server_privkey_cert(const u8_t *privkey, size_t privkey_len, const u8_t *privkey_pass, size_t privkey_pass_len, const u8_t *cert, size_t cert_len) { int ret; static mbedtls_x509_crt srvcert; static mbedtls_pk_context pkey; - struct altcp_tls_config *conf = altcp_tls_create_config(); + struct altcp_tls_config *conf = altcp_tls_create_config(1); if (conf == NULL) { return NULL; } @@ -677,6 +721,30 @@ altcp_tls_create_config_privkey_cert(const u8_t *privkey, size_t privkey_len, return conf; } +struct altcp_tls_config * +altcp_tls_create_config_client(const u8_t *cert, size_t cert_len) +{ + int ret; + static mbedtls_x509_crt acc_cert; + struct altcp_tls_config *conf = altcp_tls_create_config(0); + if (conf == NULL) { + return NULL; + } + + mbedtls_x509_crt_init(&acc_cert); + + /* Load the certificates */ + ret = mbedtls_x509_crt_parse(&acc_cert, cert, cert_len); + if (ret != 0) { + LWIP_DEBUGF(ALTCP_MBEDTLS_DEBUG, ("mbedtls_x509_crt_parse failed: %d", ret)); + altcp_mbedtls_free_config(conf); + return NULL; + } + + mbedtls_ssl_conf_ca_chain(&conf->conf, &acc_cert, NULL); + return conf; +} + /* "virtual" functions */ static void altcp_mbedtls_set_poll(struct altcp_pcb *conn, u8_t interval) diff --git a/src/apps/altcp_tls/altcp_tls_mbedtls_structs.h b/src/apps/altcp_tls/altcp_tls_mbedtls_structs.h index 12170b02..732fabd6 100644 --- a/src/apps/altcp_tls/altcp_tls_mbedtls_structs.h +++ b/src/apps/altcp_tls/altcp_tls_mbedtls_structs.h @@ -54,10 +54,11 @@ extern "C" { #define ALTCP_MBEDTLS_FLAGS_HANDSHAKE_DONE 0x01 #define ALTCP_MBEDTLS_FLAGS_APPLDATA_SENT 0x02 -#define ALTCP_MBEDTLS_FLAGS_RX_CLOSED 0x04 -#define ALTCP_MBEDTLS_FLAGS_TX_CLOSED 0x08 +#define ALTCP_MBEDTLS_FLAGS_RX_CLOSE_QUEUED 0x04 +#define ALTCP_MBEDTLS_FLAGS_RX_CLOSED 0x08 +#define ALTCP_MBEDTLS_FLAGS_TX_CLOSED 0x10 #define ALTCP_MBEDTLS_FLAGS_CLOSED (ALTCP_MBEDTLS_FLAGS_RX_CLOSED|ALTCP_MBEDTLS_FLAGS_TX_CLOSED) -#define ALTCP_MBEDTLS_FLAGS_UPPER_CALLED 0x10 +#define ALTCP_MBEDTLS_FLAGS_UPPER_CALLED 0x20 typedef struct altcp_mbedtls_state_s { void *conf; diff --git a/src/include/lwip/apps/altcp_tls.h b/src/include/lwip/apps/altcp_tls.h index b5c532c4..3a0bc9bd 100644 --- a/src/include/lwip/apps/altcp_tls.h +++ b/src/include/lwip/apps/altcp_tls.h @@ -55,10 +55,11 @@ extern "C" { struct altcp_tls_config; -struct altcp_tls_config* altcp_tls_create_config(void); -struct altcp_tls_config* altcp_tls_create_config_privkey_cert(const u8_t *privkey, size_t privkey_len, +struct altcp_tls_config *altcp_tls_create_config_server_privkey_cert(const u8_t *privkey, size_t privkey_len, const u8_t *privkey_pass, size_t privkey_pass_len, const u8_t *cert, size_t cert_len); +struct altcp_tls_config *altcp_tls_create_config_client(const u8_t *cert, size_t cert_len); + struct altcp_pcb *altcp_tls_new(struct altcp_tls_config* config, struct altcp_pcb *inner_pcb); #ifdef __cplusplus