diff --git a/library/psa_crypto.c b/library/psa_crypto.c index 3b3c116252..d21c13ea54 100644 --- a/library/psa_crypto.c +++ b/library/psa_crypto.c @@ -3165,15 +3165,27 @@ psa_status_t psa_sign_message_builtin( psa_status_t psa_sign_message(mbedtls_svc_key_id_t key, psa_algorithm_t alg, - const uint8_t *input, + const uint8_t *input_external, size_t input_length, - uint8_t *signature, + uint8_t *signature_external, size_t signature_size, size_t *signature_length) { - return psa_sign_internal( - key, 1, alg, input, input_length, - signature, signature_size, signature_length); + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + LOCAL_INPUT_DECLARE(input_external, input); + LOCAL_OUTPUT_DECLARE(signature_external, signature); + + LOCAL_INPUT_ALLOC(input_external, input_length, input); + LOCAL_OUTPUT_ALLOC(signature_external, signature_size, signature); + status = psa_sign_internal(key, 1, alg, input, input_length, signature, + signature_size, signature_length); + +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) +exit: +#endif + LOCAL_INPUT_FREE(input_external, input); + LOCAL_OUTPUT_FREE(signature_external, signature); + return status; } psa_status_t psa_verify_message_builtin( @@ -3212,14 +3224,27 @@ psa_status_t psa_verify_message_builtin( psa_status_t psa_verify_message(mbedtls_svc_key_id_t key, psa_algorithm_t alg, - const uint8_t *input, + const uint8_t *input_external, size_t input_length, - const uint8_t *signature, + const uint8_t *signature_external, size_t signature_length) { - return psa_verify_internal( - key, 1, alg, input, input_length, - signature, signature_length); + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + LOCAL_INPUT_DECLARE(input_external, input); + LOCAL_INPUT_DECLARE(signature_external, signature); + + LOCAL_INPUT_ALLOC(input_external, input_length, input); + LOCAL_INPUT_ALLOC(signature_external, signature_length, signature); + status = psa_verify_internal(key, 1, alg, input, input_length, signature, + signature_length); + +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) +exit: +#endif + LOCAL_INPUT_FREE(input_external, input); + LOCAL_INPUT_FREE(signature_external, signature); + + return status; } psa_status_t psa_sign_hash_builtin( @@ -3272,15 +3297,28 @@ psa_status_t psa_sign_hash_builtin( psa_status_t psa_sign_hash(mbedtls_svc_key_id_t key, psa_algorithm_t alg, - const uint8_t *hash, + const uint8_t *hash_external, size_t hash_length, - uint8_t *signature, + uint8_t *signature_external, size_t signature_size, size_t *signature_length) { - return psa_sign_internal( - key, 0, alg, hash, hash_length, - signature, signature_size, signature_length); + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + LOCAL_INPUT_DECLARE(hash_external, hash); + LOCAL_OUTPUT_DECLARE(signature_external, signature); + + LOCAL_INPUT_ALLOC(hash_external, hash_length, hash); + LOCAL_OUTPUT_ALLOC(signature_external, signature_size, signature); + status = psa_sign_internal(key, 0, alg, hash, hash_length, signature, + signature_size, signature_length); + +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) +exit: +#endif + LOCAL_INPUT_FREE(hash_external, hash); + LOCAL_OUTPUT_FREE(signature_external, signature); + + return status; } psa_status_t psa_verify_hash_builtin( @@ -3332,14 +3370,27 @@ psa_status_t psa_verify_hash_builtin( psa_status_t psa_verify_hash(mbedtls_svc_key_id_t key, psa_algorithm_t alg, - const uint8_t *hash, + const uint8_t *hash_external, size_t hash_length, - const uint8_t *signature, + const uint8_t *signature_external, size_t signature_length) { - return psa_verify_internal( - key, 0, alg, hash, hash_length, - signature, signature_length); + psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED; + LOCAL_INPUT_DECLARE(hash_external, hash); + LOCAL_INPUT_DECLARE(signature_external, signature); + + LOCAL_INPUT_ALLOC(hash_external, hash_length, hash); + LOCAL_INPUT_ALLOC(signature_external, signature_length, signature); + status = psa_verify_internal(key, 0, alg, hash, hash_length, signature, + signature_length); + +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) +exit: +#endif + LOCAL_INPUT_FREE(hash_external, hash); + LOCAL_INPUT_FREE(signature_external, signature); + + return status; } psa_status_t psa_asymmetric_encrypt(mbedtls_svc_key_id_t key, diff --git a/tests/scripts/generate_psa_wrappers.py b/tests/scripts/generate_psa_wrappers.py index e5b4256f5e..005a324116 100755 --- a/tests/scripts/generate_psa_wrappers.py +++ b/tests/scripts/generate_psa_wrappers.py @@ -145,6 +145,11 @@ class PSAWrapperGenerator(c_wrapper_generator.Base): # Proof-of-concept: just instrument one function for now if function_name == 'psa_cipher_encrypt': return True + if function_name in ('psa_sign_message', + 'psa_verify_message', + 'psa_sign_hash', + 'psa_verify_hash'): + return True return False def _write_function_call(self, out: typing_util.Writable, diff --git a/tests/src/psa_test_wrappers.c b/tests/src/psa_test_wrappers.c index 3a3aaade9a..460d4535f5 100644 --- a/tests/src/psa_test_wrappers.c +++ b/tests/src/psa_test_wrappers.c @@ -873,7 +873,15 @@ psa_status_t mbedtls_test_wrap_psa_sign_hash( size_t arg5_signature_size, size_t *arg6_signature_length) { +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) + MBEDTLS_TEST_MEMORY_POISON(arg2_hash, arg3_hash_length); + MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_size); +#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */ psa_status_t status = (psa_sign_hash)(arg0_key, arg1_alg, arg2_hash, arg3_hash_length, arg4_signature, arg5_signature_size, arg6_signature_length); +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) + MBEDTLS_TEST_MEMORY_UNPOISON(arg2_hash, arg3_hash_length); + MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_size); +#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */ return status; } @@ -918,7 +926,15 @@ psa_status_t mbedtls_test_wrap_psa_sign_message( size_t arg5_signature_size, size_t *arg6_signature_length) { +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) + MBEDTLS_TEST_MEMORY_POISON(arg2_input, arg3_input_length); + MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_size); +#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */ psa_status_t status = (psa_sign_message)(arg0_key, arg1_alg, arg2_input, arg3_input_length, arg4_signature, arg5_signature_size, arg6_signature_length); +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) + MBEDTLS_TEST_MEMORY_UNPOISON(arg2_input, arg3_input_length); + MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_size); +#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */ return status; } @@ -931,7 +947,15 @@ psa_status_t mbedtls_test_wrap_psa_verify_hash( const uint8_t *arg4_signature, size_t arg5_signature_length) { +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) + MBEDTLS_TEST_MEMORY_POISON(arg2_hash, arg3_hash_length); + MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_length); +#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */ psa_status_t status = (psa_verify_hash)(arg0_key, arg1_alg, arg2_hash, arg3_hash_length, arg4_signature, arg5_signature_length); +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) + MBEDTLS_TEST_MEMORY_UNPOISON(arg2_hash, arg3_hash_length); + MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_length); +#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */ return status; } @@ -974,7 +998,15 @@ psa_status_t mbedtls_test_wrap_psa_verify_message( const uint8_t *arg4_signature, size_t arg5_signature_length) { +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) + MBEDTLS_TEST_MEMORY_POISON(arg2_input, arg3_input_length); + MBEDTLS_TEST_MEMORY_POISON(arg4_signature, arg5_signature_length); +#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */ psa_status_t status = (psa_verify_message)(arg0_key, arg1_alg, arg2_input, arg3_input_length, arg4_signature, arg5_signature_length); +#if defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) + MBEDTLS_TEST_MEMORY_UNPOISON(arg2_input, arg3_input_length); + MBEDTLS_TEST_MEMORY_UNPOISON(arg4_signature, arg5_signature_length); +#endif /* defined(MBEDTLS_PSA_COPY_CALLER_BUFFERS) */ return status; }