diff --git a/ChangeLog b/ChangeLog index cc51067ae2..b926d077f5 100644 --- a/ChangeLog +++ b/ChangeLog @@ -37,6 +37,7 @@ Bugfix * Fix warnings from mingw64 in timing.c (found by kxjklele). * Fix potential unintended sign extension in asn1_get_len() on 64-bit platforms. + * Fix potential memory leak in ssl_set_psk() (found by Mansour Moufid). Changes * Move from SHA-1 to SHA-256 in example programs using signatures diff --git a/library/ssl_tls.c b/library/ssl_tls.c index ea621ae2ea..052a198577 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -5463,21 +5463,23 @@ int ssl_set_psk( ssl_context *ssl, const unsigned char *psk, size_t psk_len, if( psk_len > POLARSSL_PSK_MAX_LEN ) return( POLARSSL_ERR_SSL_BAD_INPUT_DATA ); - if( ssl->psk != NULL ) + if( ssl->psk != NULL || ssl->psk_identity != NULL ) { polarssl_free( ssl->psk ); polarssl_free( ssl->psk_identity ); } + if( ( ssl->psk = polarssl_malloc( psk_len ) ) == NULL || + ( ssl->psk_identity = polarssl_malloc( psk_identity_len ) ) == NULL ) + { + polarssl_free( ssl->psk ); + ssl->psk = NULL; + return( POLARSSL_ERR_SSL_MALLOC_FAILED ); + } + ssl->psk_len = psk_len; ssl->psk_identity_len = psk_identity_len; - ssl->psk = polarssl_malloc( ssl->psk_len ); - ssl->psk_identity = polarssl_malloc( ssl->psk_identity_len ); - - if( ssl->psk == NULL || ssl->psk_identity == NULL ) - return( POLARSSL_ERR_SSL_MALLOC_FAILED ); - memcpy( ssl->psk, psk, ssl->psk_len ); memcpy( ssl->psk_identity, psk_identity, ssl->psk_identity_len ); diff --git a/programs/ssl/ssl_server2.c b/programs/ssl/ssl_server2.c index 5319c7eb54..78198ff0d8 100644 --- a/programs/ssl/ssl_server2.c +++ b/programs/ssl/ssl_server2.c @@ -643,7 +643,7 @@ psk_entry *psk_parse( char *psk_string ) while( p <= end ) { if( ( new = polarssl_malloc( sizeof( psk_entry ) ) ) == NULL ) - return( NULL ); + goto error; memset( new, 0, sizeof( psk_entry ) ); diff --git a/scripts/find-mem-leak.cocci b/scripts/find-mem-leak.cocci new file mode 100644 index 0000000000..34cfd082d2 --- /dev/null +++ b/scripts/find-mem-leak.cocci @@ -0,0 +1,20 @@ +@@ +expression x, y; +statement S; +@@ + x = polarssl_malloc(...); + y = polarssl_malloc(...); + ... +* if (x == NULL || y == NULL) + S + +@@ +expression x, y; +statement S; +@@ + if ( +* (x = polarssl_malloc(...)) == NULL + || +* (y = polarssl_malloc(...)) == NULL + ) + S