diff --git a/library/lmots.c b/library/lmots.c index 059d6c8eff..9168ef189d 100644 --- a/library/lmots.c +++ b/library/lmots.c @@ -409,6 +409,11 @@ void mbedtls_lmots_public_free( mbedtls_lmots_public_t *ctx ) int mbedtls_lmots_import_public_key( mbedtls_lmots_public_t *ctx, const unsigned char *key, size_t key_len ) { + if( key_len < MBEDTLS_LMOTS_SIG_TYPE_OFFSET + MBEDTLS_LMOTS_TYPE_LEN ) + { + return( MBEDTLS_ERR_LMS_BAD_INPUT_DATA ); + } + ctx->params.type = mbedtls_lms_network_bytes_to_unsigned_int( MBEDTLS_LMOTS_TYPE_LEN, key + MBEDTLS_LMOTS_SIG_TYPE_OFFSET ); diff --git a/library/lms.c b/library/lms.c index a4235addc2..fba5d88480 100644 --- a/library/lms.c +++ b/library/lms.c @@ -235,11 +235,6 @@ int mbedtls_lms_import_public_key( mbedtls_lms_public_t *ctx, mbedtls_lms_algorithm_type_t type; mbedtls_lmots_algorithm_type_t otstype; - if( key_size != MBEDTLS_LMS_PUBLIC_KEY_LEN(ctx->params.type) ) - { - return( MBEDTLS_ERR_LMS_BAD_INPUT_DATA ); - } - type = mbedtls_lms_network_bytes_to_unsigned_int( MBEDTLS_LMS_TYPE_LEN, key + PUBLIC_KEY_TYPE_OFFSET ); if( type != MBEDTLS_LMS_SHA256_M32_H10 ) @@ -248,6 +243,11 @@ int mbedtls_lms_import_public_key( mbedtls_lms_public_t *ctx, } ctx->params.type = type; + if( key_size != MBEDTLS_LMS_PUBLIC_KEY_LEN(ctx->params.type) ) + { + return( MBEDTLS_ERR_LMS_BAD_INPUT_DATA ); + } + otstype = mbedtls_lms_network_bytes_to_unsigned_int( MBEDTLS_LMOTS_TYPE_LEN, key + PUBLIC_KEY_OTSTYPE_OFFSET ); if( otstype != MBEDTLS_LMOTS_SHA256_N32_W8 )