diff --git a/library/sha3.c b/library/sha3.c index 4b97a85c5f..982550419b 100644 --- a/library/sha3.c +++ b/library/sha3.c @@ -259,10 +259,13 @@ int mbedtls_sha3_update(mbedtls_sha3_context *ctx, int mbedtls_sha3_finish(mbedtls_sha3_context *ctx, uint8_t *output, size_t olen) { + int ret = MBEDTLS_ERR_ERROR_CORRUPTION_DETECTED; + /* Catch SHA-3 families, with fixed output length */ if (ctx->olen > 0) { if (ctx->olen > olen) { - return MBEDTLS_ERR_SHA3_BAD_INPUT_DATA; + ret = MBEDTLS_ERR_SHA3_BAD_INPUT_DATA; + goto exit; } olen = ctx->olen; } @@ -280,7 +283,11 @@ int mbedtls_sha3_finish(mbedtls_sha3_context *ctx, } } - return 0; + ret = 0; + +exit: + mbedtls_platform_zeroize(ctx, sizeof(mbedtls_sha3_context)); + return ret; } /* diff --git a/tests/suites/test_suite_shax.function b/tests/suites/test_suite_shax.function index 7dd9166658..629e281008 100644 --- a/tests/suites/test_suite_shax.function +++ b/tests/suites/test_suite_shax.function @@ -176,9 +176,12 @@ void sha3_invalid_param() TEST_EQUAL(mbedtls_sha3_starts(&ctx, MBEDTLS_SHA3_NONE), MBEDTLS_ERR_SHA3_BAD_INPUT_DATA); TEST_EQUAL(mbedtls_sha3_starts(&ctx, MBEDTLS_SHA3_256), 0); - TEST_EQUAL(mbedtls_sha3_finish(&ctx, output, 0), MBEDTLS_ERR_SHA3_BAD_INPUT_DATA); + + TEST_EQUAL(mbedtls_sha3_starts(&ctx, MBEDTLS_SHA3_256), 0); TEST_EQUAL(mbedtls_sha3_finish(&ctx, output, 31), MBEDTLS_ERR_SHA3_BAD_INPUT_DATA); + + TEST_EQUAL(mbedtls_sha3_starts(&ctx, MBEDTLS_SHA3_256), 0); TEST_EQUAL(mbedtls_sha3_finish(&ctx, output, 32), 0); exit: