diff --git a/include/polarssl/ssl.h b/include/polarssl/ssl.h index 2c2bedab07..bd7f1f775b 100644 --- a/include/polarssl/ssl.h +++ b/include/polarssl/ssl.h @@ -1682,6 +1682,13 @@ int ssl_close_notify( ssl_context *ssl ); */ void ssl_free( ssl_context *ssl ); +/** + * \brief Initialize SSL session structure + * + * \param session SSL session + */ +void ssl_session_init( ssl_session *session ); + /** * \brief Free referenced items in an SSL session including the * peer certificate and clear memory diff --git a/library/ssl_srv.c b/library/ssl_srv.c index 1e75408eeb..9d2507a3b6 100644 --- a/library/ssl_srv.c +++ b/library/ssl_srv.c @@ -343,6 +343,8 @@ static int ssl_parse_ticket( ssl_context *ssl, ssl_session_free( ssl->session_negotiate ); memcpy( ssl->session_negotiate, &session, sizeof( ssl_session ) ); + + /* Zeroize instead of free as we copied the content */ polarssl_zeroize( &session, sizeof( ssl_session ) ); return( 0 ); diff --git a/library/ssl_tls.c b/library/ssl_tls.c index a6761adce4..d3bfab5188 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -3270,80 +3270,108 @@ int ssl_parse_finished( ssl_context *ssl ) return( 0 ); } +static void ssl_handshake_params_init( ssl_handshake_params *handshake, + ssl_key_cert *key_cert ) +{ + memset( handshake, 0, sizeof( ssl_handshake_params ) ); + +#if defined(POLARSSL_SSL_PROTO_SSL3) || defined(POLARSSL_SSL_PROTO_TLS1) || \ + defined(POLARSSL_SSL_PROTO_TLS1_1) + md5_init( &handshake->fin_md5 ); + sha1_init( &handshake->fin_sha1 ); + md5_starts( &handshake->fin_md5 ); + sha1_starts( &handshake->fin_sha1 ); +#endif +#if defined(POLARSSL_SSL_PROTO_TLS1_2) +#if defined(POLARSSL_SHA256_C) + sha256_init( &handshake->fin_sha256 ); + sha256_starts( &handshake->fin_sha256, 0 ); +#endif +#if defined(POLARSSL_SHA512_C) + sha512_init( &handshake->fin_sha512 ); + sha512_starts( &handshake->fin_sha512, 1 ); +#endif +#endif /* POLARSSL_SSL_PROTO_TLS1_2 */ + + handshake->update_checksum = ssl_update_checksum_start; + handshake->sig_alg = SSL_HASH_SHA1; + +#if defined(POLARSSL_DHM_C) + dhm_init( &handshake->dhm_ctx ); +#endif +#if defined(POLARSSL_ECDH_C) + ecdh_init( &handshake->ecdh_ctx ); +#endif + +#if defined(POLARSSL_X509_CRT_PARSE_C) + handshake->key_cert = key_cert; +#endif +} + +static void ssl_transform_init( ssl_transform *transform ) +{ + memset( transform, 0, sizeof(ssl_transform) ); +} + +void ssl_session_init( ssl_session *session ) +{ + memset( session, 0, sizeof(ssl_session) ); +} + static int ssl_handshake_init( ssl_context *ssl ) { + /* Clear old handshake information if present */ if( ssl->transform_negotiate ) ssl_transform_free( ssl->transform_negotiate ); - else + if( ssl->session_negotiate ) + ssl_session_free( ssl->session_negotiate ); + if( ssl->handshake ) + ssl_handshake_free( ssl->handshake ); + + /* + * Either the pointers are now NULL or cleared properly and can be freed. + * Now allocate missing structures. + */ + if( ssl->transform_negotiate == NULL ) { ssl->transform_negotiate = (ssl_transform *) polarssl_malloc( sizeof(ssl_transform) ); - - if( ssl->transform_negotiate != NULL ) - memset( ssl->transform_negotiate, 0, sizeof(ssl_transform) ); } - if( ssl->session_negotiate ) - ssl_session_free( ssl->session_negotiate ); - else + if( ssl->session_negotiate == NULL ) { ssl->session_negotiate = (ssl_session *) polarssl_malloc( sizeof(ssl_session) ); - - if( ssl->session_negotiate != NULL ) - memset( ssl->session_negotiate, 0, sizeof(ssl_session) ); } - if( ssl->handshake ) - ssl_handshake_free( ssl->handshake ); - else + if( ssl->handshake == NULL) { ssl->handshake = (ssl_handshake_params *) polarssl_malloc( sizeof(ssl_handshake_params) ); - - if( ssl->handshake != NULL ) - memset( ssl->handshake, 0, sizeof(ssl_handshake_params) ); } + /* All pointers should exist and can be directly freed without issue */ if( ssl->handshake == NULL || ssl->transform_negotiate == NULL || ssl->session_negotiate == NULL ) { SSL_DEBUG_MSG( 1, ( "malloc() of ssl sub-contexts failed" ) ); + + polarssl_free( ssl->handshake ); + polarssl_free( ssl->transform_negotiate ); + polarssl_free( ssl->session_negotiate ); + + ssl->handshake = NULL; + ssl->transform_negotiate = NULL; + ssl->session_negotiate = NULL; + return( POLARSSL_ERR_SSL_MALLOC_FAILED ); } -#if defined(POLARSSL_SSL_PROTO_SSL3) || defined(POLARSSL_SSL_PROTO_TLS1) || \ - defined(POLARSSL_SSL_PROTO_TLS1_1) - md5_init( &ssl->handshake->fin_md5 ); - sha1_init( &ssl->handshake->fin_sha1 ); - md5_starts( &ssl->handshake->fin_md5 ); - sha1_starts( &ssl->handshake->fin_sha1 ); -#endif -#if defined(POLARSSL_SSL_PROTO_TLS1_2) -#if defined(POLARSSL_SHA256_C) - sha256_init( &ssl->handshake->fin_sha256 ); - sha256_starts( &ssl->handshake->fin_sha256, 0 ); -#endif -#if defined(POLARSSL_SHA512_C) - sha512_init( &ssl->handshake->fin_sha512 ); - sha512_starts( &ssl->handshake->fin_sha512, 1 ); -#endif -#endif /* POLARSSL_SSL_PROTO_TLS1_2 */ - - ssl->handshake->update_checksum = ssl_update_checksum_start; - ssl->handshake->sig_alg = SSL_HASH_SHA1; - -#if defined(POLARSSL_DHM_C) - dhm_init( &ssl->handshake->dhm_ctx ); -#endif -#if defined(POLARSSL_ECDH_C) - ecdh_init( &ssl->handshake->ecdh_ctx ); -#endif - -#if defined(POLARSSL_X509_CRT_PARSE_C) - ssl->handshake->key_cert = ssl->key_cert; -#endif + /* Initialize structures */ + ssl_session_init( ssl->session_negotiate ); + ssl_transform_init( ssl->transform_negotiate ); + ssl_handshake_params_init( ssl->handshake, ssl->key_cert ); return( 0 ); } @@ -4470,6 +4498,9 @@ int ssl_close_notify( ssl_context *ssl ) void ssl_transform_free( ssl_transform *transform ) { + if( transform == NULL ) + return; + #if defined(POLARSSL_ZLIB_SUPPORT) deflateEnd( &transform->ctx_deflate ); inflateEnd( &transform->ctx_inflate ); @@ -4507,6 +4538,9 @@ static void ssl_key_cert_free( ssl_key_cert *key_cert ) void ssl_handshake_free( ssl_handshake_params *handshake ) { + if( handshake == NULL ) + return; + #if defined(POLARSSL_DHM_C) dhm_free( &handshake->dhm_ctx ); #endif @@ -4543,6 +4577,9 @@ void ssl_handshake_free( ssl_handshake_params *handshake ) void ssl_session_free( ssl_session *session ) { + if( session == NULL ) + return; + #if defined(POLARSSL_X509_CRT_PARSE_C) if( session->peer_cert != NULL ) { @@ -4563,6 +4600,9 @@ void ssl_session_free( ssl_session *session ) */ void ssl_free( ssl_context *ssl ) { + if( ssl == NULL ) + return; + SSL_DEBUG_MSG( 2, ( "=> free" ) ); if( ssl->out_ctr != NULL )