diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index dbc92b3878..226485add8 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -1827,6 +1827,15 @@ void ssl_write_version( int major, int minor, int transport, void ssl_read_version( int *major, int *minor, int transport, const unsigned char ver[2] ); +static inline size_t ssl_hdr_len( const ssl_context *ssl ) +{ +#if defined(POLARSSL_SSL_PROTO_DTLS) + if( ssl->transport == SSL_TRANSPORT_DATAGRAM ) + return( 13 ); +#endif + return( 5 ); +} + /* constant-time buffer comparison */ static inline int safer_memcmp( const void *a, const void *b, size_t n ) { diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 617db9d714..5ab626a673 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -1134,7 +1134,7 @@ static int ssl_parse_client_hello( ssl_context *ssl ) SSL_DEBUG_MSG( 2, ( "=> parse client hello" ) ); if( ssl->renegotiation == SSL_INITIAL_HANDSHAKE && - ( ret = ssl_fetch_input( ssl, 5 ) ) != 0 ) + ( ret = ssl_fetch_input( ssl, ssl_hdr_len( ssl ) ) ) != 0 ) { SSL_DEBUG_RET( 1, "ssl_fetch_input", ret ); return( ret ); @@ -1147,7 +1147,7 @@ static int ssl_parse_client_hello( ssl_context *ssl ) return ssl_parse_client_hello_v2( ssl ); #endif - SSL_DEBUG_BUF( 4, "record header", buf, 5 ); // TODO: 13 for DTLS + SSL_DEBUG_BUF( 4, "record header", buf, ssl_hdr_len( ssl ) ); SSL_DEBUG_MSG( 3, ( "client hello v3, message type: %d", buf[0] ) ); @@ -1191,7 +1191,7 @@ static int ssl_parse_client_hello( ssl_context *ssl ) } if( ssl->renegotiation == SSL_INITIAL_HANDSHAKE && - ( ret = ssl_fetch_input( ssl, 5 + n ) ) != 0 ) + ( ret = ssl_fetch_input( ssl, ssl_hdr_len( ssl ) + n ) ) != 0 ) { SSL_DEBUG_RET( 1, "ssl_fetch_input", ret ); return( ret ); @@ -1199,7 +1199,7 @@ static int ssl_parse_client_hello( ssl_context *ssl ) buf = ssl->in_msg; if( !ssl->renegotiation ) - n = ssl->in_left - 5; + n = ssl->in_left - ssl_hdr_len( ssl ); else n = ssl->in_msglen; diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 58fb306f83..5ad244d33f 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -1284,18 +1284,6 @@ static int ssl_encrypt_buf( ssl_context *ssl ) return( POLARSSL_ERR_SSL_INTERNAL_ERROR ); } - // TODO: adapt for DTLS (start from i = 6) - for( i = 8; i > 0; i-- ) - if( ++ssl->out_ctr[i - 1] != 0 ) - break; - - /* The loops goes to its end iff the counter is wrapping */ - if( i == 0 ) - { - SSL_DEBUG_MSG( 1, ( "outgoing message counter would wrap" ) ); - return( POLARSSL_ERR_SSL_COUNTER_WRAPPING ); - } - SSL_DEBUG_MSG( 2, ( "<= encrypt buf" ) ); return( 0 ); @@ -1702,16 +1690,19 @@ static int ssl_decrypt_buf( ssl_context *ssl ) else ssl->nb_zero = 0; - // TODO: DTLS: i = 6 - for( i = 8; i > 0; i-- ) - if( ++ssl->in_ctr[i - 1] != 0 ) - break; - - /* The loops goes to its end iff the counter is wrapping */ - if( i == 0 ) + /* For DTLS we don't maintain our own incoming counter (for now) */ + if( ssl->transport == SSL_TRANSPORT_STREAM ) { - SSL_DEBUG_MSG( 1, ( "incoming message counter would wrap" ) ); - return( POLARSSL_ERR_SSL_COUNTER_WRAPPING ); + for( i = 8; i > 0; i-- ) + if( ++ssl->in_ctr[i - 1] != 0 ) + break; + + /* The loop goes to its end iff the counter is wrapping */ + if( i == 0 ) + { + SSL_DEBUG_MSG( 1, ( "incoming message counter would wrap" ) ); + return( POLARSSL_ERR_SSL_COUNTER_WRAPPING ); + } } SSL_DEBUG_MSG( 2, ( "<= decrypt buf" ) ); @@ -1860,17 +1851,25 @@ int ssl_fetch_input( ssl_context *ssl, size_t nb_want ) */ int ssl_flush_output( ssl_context *ssl ) { - int ret; + int ret, i; unsigned char *buf; SSL_DEBUG_MSG( 2, ( "=> flush output" ) ); + /* Avoid incrementing counter if data is flushed */ + if( ssl->out_left == 0 ) + { + SSL_DEBUG_MSG( 2, ( "<= flush output" ) ); + return( 0 ); + } + while( ssl->out_left > 0 ) { SSL_DEBUG_MSG( 2, ( "message length: %d, out_left: %d", - 5 + ssl->out_msglen, ssl->out_left ) ); + ssl_hdr_len( ssl ) + ssl->out_msglen, ssl->out_left ) ); - buf = ssl->out_hdr + 5 + ssl->out_msglen - ssl->out_left; + buf = ssl->out_hdr + ssl_hdr_len( ssl ) + + ssl->out_msglen - ssl->out_left; ret = ssl->f_send( ssl->p_send, buf, ssl->out_left ); SSL_DEBUG_RET( 2, "ssl->f_send", ret ); @@ -1881,6 +1880,18 @@ int ssl_flush_output( ssl_context *ssl ) ssl->out_left -= ret; } + // TODO: adapt for DTLS (start from i = 6) + for( i = 8; i > 0; i-- ) + if( ++ssl->out_ctr[i - 1] != 0 ) + break; + + /* The loop goes to its end iff the counter is wrapping */ + if( i == 0 ) + { + SSL_DEBUG_MSG( 1, ( "outgoing message counter would wrap" ) ); + return( POLARSSL_ERR_SSL_COUNTER_WRAPPING ); + } + SSL_DEBUG_MSG( 2, ( "<= flush output" ) ); return( 0 ); @@ -1958,7 +1969,7 @@ int ssl_write_record( ssl_context *ssl ) ssl->out_len[1] = (unsigned char)( len ); } - ssl->out_left = 5 + ssl->out_msglen; + ssl->out_left = ssl_hdr_len( ssl ) + ssl->out_msglen; SSL_DEBUG_MSG( 3, ( "output record: msgtype = %d, " "version = [%d:%d], msglen = %d", @@ -1966,7 +1977,7 @@ int ssl_write_record( ssl_context *ssl ) ( ssl->out_len[0] << 8 ) | ssl->out_len[1] ) ); SSL_DEBUG_BUF( 4, "output record sent to network", - ssl->out_hdr, 5 + ssl->out_msglen ); + ssl->out_hdr, ssl_hdr_len( ssl ) + ssl->out_msglen ); } if( ( ret = ssl_flush_output( ssl ) ) != 0 ) @@ -2028,7 +2039,7 @@ int ssl_read_record( ssl_context *ssl ) /* * Read the record header and validate it */ - if( ( ret = ssl_fetch_input( ssl, 5 ) ) != 0 ) + if( ( ret = ssl_fetch_input( ssl, ssl_hdr_len( ssl ) ) ) != 0 ) { SSL_DEBUG_RET( 1, "ssl_fetch_input", ret ); return( ret ); @@ -2110,14 +2121,15 @@ int ssl_read_record( ssl_context *ssl ) /* * Read and optionally decrypt the message contents */ - if( ( ret = ssl_fetch_input( ssl, 5 + ssl->in_msglen ) ) != 0 ) + if( ( ret = ssl_fetch_input( ssl, + ssl_hdr_len( ssl ) + ssl->in_msglen ) ) != 0 ) { SSL_DEBUG_RET( 1, "ssl_fetch_input", ret ); return( ret ); } SSL_DEBUG_BUF( 4, "input record from network", - ssl->in_hdr, 5 + ssl->in_msglen ); + ssl->in_hdr, ssl_hdr_len( ssl ) + ssl->in_msglen ); #if defined(POLARSSL_SSL_HW_RECORD_ACCEL) if( ssl_hw_record_read != NULL ) @@ -3417,39 +3429,27 @@ int ssl_init( ssl_context *ssl ) #endif /* - * Prepare base structures (assume TLS for now) + * Prepare base structures */ ssl->in_buf = (unsigned char *) polarssl_malloc( len ); - ssl->in_ctr = ssl->in_buf; - ssl->in_hdr = ssl->in_buf + 8; - ssl->in_len = ssl->in_buf + 11; - ssl->in_iv = ssl->in_buf + 13; - ssl->in_msg = ssl->in_buf + 13; - - if( ssl->in_buf == NULL ) - { - SSL_DEBUG_MSG( 1, ( "malloc(%d bytes) failed", len ) ); - return( POLARSSL_ERR_SSL_MALLOC_FAILED ); - } - ssl->out_buf = (unsigned char *) polarssl_malloc( len ); - ssl->out_ctr = ssl->out_buf; - ssl->out_hdr = ssl->out_buf + 8; - ssl->out_len = ssl->out_buf + 11; - ssl->out_iv = ssl->out_buf + 13; - ssl->out_msg = ssl->out_buf + 13; - if( ssl->out_buf == NULL ) + if( ssl->in_buf == NULL || ssl->out_buf == NULL ) { SSL_DEBUG_MSG( 1, ( "malloc(%d bytes) failed", len ) ); polarssl_free( ssl->in_buf ); + polarssl_free( ssl->out_buf ); ssl->in_buf = NULL; + ssl->out_buf = NULL; return( POLARSSL_ERR_SSL_MALLOC_FAILED ); } memset( ssl-> in_buf, 0, SSL_BUFFER_LEN ); memset( ssl->out_buf, 0, SSL_BUFFER_LEN ); + /* No error is possible, SSL_TRANSPORT_STREAM always valid */ + (void) ssl_set_transport( ssl, SSL_TRANSPORT_STREAM ); + #if defined(POLARSSL_SSL_SESSION_TICKETS) ssl->ticket_lifetime = SSL_DEFAULT_TICKET_LIFETIME; #endif @@ -3617,6 +3617,18 @@ int ssl_set_transport( ssl_context *ssl, int transport ) { ssl->transport = transport; + ssl->out_hdr = ssl->out_buf; + ssl->out_ctr = ssl->out_buf + 3; + ssl->out_len = ssl->out_buf + 11; + ssl->out_iv = ssl->out_buf + 13; + ssl->out_msg = ssl->out_buf + 13; + + ssl->in_hdr = ssl->in_buf; + ssl->in_ctr = ssl->in_buf + 3; + ssl->in_len = ssl->in_buf + 11; + ssl->in_iv = ssl->in_buf + 13; + ssl->in_msg = ssl->in_buf + 13; + /* DTLS starts with TLS1.1 */ if( ssl->min_minor_ver < SSL_MINOR_VERSION_2 ) ssl->min_minor_ver = SSL_MINOR_VERSION_2; @@ -3631,6 +3643,19 @@ int ssl_set_transport( ssl_context *ssl, int transport ) if( transport == SSL_TRANSPORT_STREAM ) { ssl->transport = transport; + + ssl->out_ctr = ssl->out_buf; + ssl->out_hdr = ssl->out_buf + 8; + ssl->out_len = ssl->out_buf + 11; + ssl->out_iv = ssl->out_buf + 13; + ssl->out_msg = ssl->out_buf + 13; + + ssl->in_ctr = ssl->in_buf; + ssl->in_hdr = ssl->in_buf + 8; + ssl->in_len = ssl->in_buf + 11; + ssl->in_iv = ssl->in_buf + 13; + ssl->in_msg = ssl->in_buf + 13; + return( 0 ); }