diff --git a/library/psa_crypto.c b/library/psa_crypto.c index f3a22588d5..e1f1e7b896 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -5299,31 +5299,58 @@ static psa_status_t psa_tls12_prf_psk_to_ms_set_key( size_t data_length ) { psa_status_t status; - uint8_t pms[ 4 + 2 * PSA_TLS12_PSK_TO_MS_PSK_MAX_SIZE ]; - uint8_t *cur = pms; + const size_t pms_len = ( prf->state == PSA_TLS12_PRF_STATE_OTHER_KEY_SET ? + 4 + data_length + prf->other_secret_length : + 4 + 2 * data_length ); if( data_length > PSA_TLS12_PSK_TO_MS_PSK_MAX_SIZE ) return( PSA_ERROR_INVALID_ARGUMENT ); - /* Quoting RFC 4279, Section 2: + uint8_t *pms = mbedtls_calloc( 1, pms_len ); + uint8_t *cur = pms; + + /* pure-PSK: + * Quoting RFC 4279, Section 2: * * The premaster secret is formed as follows: if the PSK is N octets * long, concatenate a uint16 with the value N, N zero octets, a second * uint16 with the value N, and the PSK itself. + * + * mixed-PSK: + * In a DHE-PSK, RSA-PSK, ECDHE-PSK the premaster secret is formed as + * follows: concatenate a uint16 with the length of the other secret, + * the other secret itself, uint16 with the length of PSK, and the + * PSK itself. + * For details please check: + * - RFC 4279, Section 4 for the definition of RSA-PSK, + * - RFC 4279, Section 3 for the definition of DHE-PSK, + * - RFC 5489 for the definition of ECDHE-PSK. */ + if ( prf->state == PSA_TLS12_PRF_STATE_OTHER_KEY_SET ) + { + *cur++ = MBEDTLS_BYTE_1( prf->other_secret_length ); + *cur++ = MBEDTLS_BYTE_0( prf->other_secret_length ); + memcpy( cur, prf->other_secret, prf->other_secret_length ); + cur += prf->other_secret_length; + } + else + { + *cur++ = MBEDTLS_BYTE_1( data_length ); + *cur++ = MBEDTLS_BYTE_0( data_length ); + memset( cur, 0, data_length ); + cur += data_length; + } + *cur++ = MBEDTLS_BYTE_1( data_length ); *cur++ = MBEDTLS_BYTE_0( data_length ); - memset( cur, 0, data_length ); - cur += data_length; - *cur++ = pms[0]; - *cur++ = pms[1]; memcpy( cur, data, data_length ); cur += data_length; status = psa_tls12_prf_set_key( prf, pms, cur - pms ); - mbedtls_platform_zeroize( pms, sizeof( pms ) ); + mbedtls_platform_zeroize( pms, pms_len ); + mbedtls_free( pms ); return( status ); }