diff --git a/library/ssl_tls.c b/library/ssl_tls.c index 8c1e37251b..9fabda442d 100644 --- a/library/ssl_tls.c +++ b/library/ssl_tls.c @@ -8942,6 +8942,7 @@ unsigned int mbedtls_ssl_tls12_get_preferred_hash_for_sig_alg( * * struct { * uint64 start_time; + * uint8 endpoint; * uint8 ciphersuite[2]; // defined by the standard * uint8 session_id_len; // at most 32 * opaque session_id[32]; @@ -8988,13 +8989,15 @@ static size_t ssl_tls12_session_save(const mbedtls_ssl_session *session, /* * Basic mandatory fields */ - used += 2 /* ciphersuite */ + used += 1 /* endpoint */ + + 2 /* ciphersuite */ + 1 /* id_len */ + sizeof(session->id) + sizeof(session->master) + 4; /* verify_result */ if (used <= buf_len) { + *p++ = session->endpoint; MBEDTLS_PUT_UINT16_BE(session->ciphersuite, p, 0); p += 2; @@ -9129,10 +9132,11 @@ static int ssl_tls12_session_load(mbedtls_ssl_session *session, /* * Basic mandatory fields */ - if (2 + 1 + 32 + 48 + 4 > (size_t) (end - p)) { + if (1 + 2 + 1 + 32 + 48 + 4 > (size_t) (end - p)) { return MBEDTLS_ERR_SSL_BAD_INPUT_DATA; } + session->endpoint = *p++; session->ciphersuite = MBEDTLS_GET_UINT16_BE(p, 0); p += 2; diff --git a/tests/include/test/ssl_helpers.h b/tests/include/test/ssl_helpers.h index d03c62414b..ec00fd54dd 100644 --- a/tests/include/test/ssl_helpers.h +++ b/tests/include/test/ssl_helpers.h @@ -531,6 +531,7 @@ int mbedtls_test_ssl_prepare_record_mac(mbedtls_record *record, */ int mbedtls_test_ssl_tls12_populate_session(mbedtls_ssl_session *session, int ticket_len, + int endpoint_type, const char *crt_file); #if defined(MBEDTLS_SSL_PROTO_TLS1_3) diff --git a/tests/src/test_helpers/ssl_helpers.c b/tests/src/test_helpers/ssl_helpers.c index 3d8937da6d..8f20fa6d44 100644 --- a/tests/src/test_helpers/ssl_helpers.c +++ b/tests/src/test_helpers/ssl_helpers.c @@ -1636,12 +1636,15 @@ exit: #if defined(MBEDTLS_SSL_PROTO_TLS1_2) int mbedtls_test_ssl_tls12_populate_session(mbedtls_ssl_session *session, int ticket_len, + int endpoint_type, const char *crt_file) { #if defined(MBEDTLS_HAVE_TIME) session->start = mbedtls_time(NULL) - 42; #endif session->tls_version = MBEDTLS_SSL_VERSION_TLS1_2; + session->endpoint = endpoint_type == MBEDTLS_SSL_IS_CLIENT ? + MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER; session->ciphersuite = 0xabcd; session->id_len = sizeof(session->id); memset(session->id, 66, session->id_len); diff --git a/tests/suites/test_suite_ssl.function b/tests/suites/test_suite_ssl.function index 8a03d1b970..b116705b1b 100644 --- a/tests/suites/test_suite_ssl.function +++ b/tests/suites/test_suite_ssl.function @@ -1941,7 +1941,6 @@ void ssl_serialize_session_save_load(int ticket_len, char *crt_file, USE_PSA_INIT(); /* Prepare a dummy session to work on */ - ((void) endpoint_type); ((void) tls_version); ((void) ticket_len); ((void) crt_file); @@ -1955,7 +1954,7 @@ void ssl_serialize_session_save_load(int ticket_len, char *crt_file, #if defined(MBEDTLS_SSL_PROTO_TLS1_2) if (tls_version == MBEDTLS_SSL_VERSION_TLS1_2) { TEST_ASSERT(mbedtls_test_ssl_tls12_populate_session( - &original, ticket_len, crt_file) == 0); + &original, ticket_len, endpoint_type, crt_file) == 0); } #endif @@ -1995,6 +1994,7 @@ void ssl_serialize_session_save_load(int ticket_len, char *crt_file, #endif TEST_ASSERT(original.tls_version == restored.tls_version); + TEST_ASSERT(original.endpoint == restored.endpoint); TEST_ASSERT(original.ciphersuite == restored.ciphersuite); #if defined(MBEDTLS_SSL_PROTO_TLS1_2) if (tls_version == MBEDTLS_SSL_VERSION_TLS1_2) { @@ -2053,7 +2053,6 @@ void ssl_serialize_session_save_load(int ticket_len, char *crt_file, #if defined(MBEDTLS_SSL_PROTO_TLS1_3) if (tls_version == MBEDTLS_SSL_VERSION_TLS1_3) { - TEST_ASSERT(original.endpoint == restored.endpoint); TEST_ASSERT(original.ciphersuite == restored.ciphersuite); TEST_ASSERT(original.ticket_age_add == restored.ticket_age_add); TEST_ASSERT(original.ticket_flags == restored.ticket_flags); @@ -2123,7 +2122,6 @@ void ssl_serialize_session_load_save(int ticket_len, char *crt_file, USE_PSA_INIT(); /* Prepare a dummy session to work on */ - ((void) endpoint_type); ((void) ticket_len); ((void) crt_file); @@ -2138,7 +2136,7 @@ void ssl_serialize_session_load_save(int ticket_len, char *crt_file, #if defined(MBEDTLS_SSL_PROTO_TLS1_2) case MBEDTLS_SSL_VERSION_TLS1_2: TEST_ASSERT(mbedtls_test_ssl_tls12_populate_session( - &session, ticket_len, crt_file) == 0); + &session, ticket_len, endpoint_type, crt_file) == 0); break; #endif default: @@ -2197,7 +2195,6 @@ void ssl_serialize_session_save_buf_size(int ticket_len, char *crt_file, USE_PSA_INIT(); /* Prepare dummy session and get serialized size */ - ((void) endpoint_type); ((void) ticket_len); ((void) crt_file); @@ -2211,7 +2208,7 @@ void ssl_serialize_session_save_buf_size(int ticket_len, char *crt_file, #if defined(MBEDTLS_SSL_PROTO_TLS1_2) case MBEDTLS_SSL_VERSION_TLS1_2: TEST_ASSERT(mbedtls_test_ssl_tls12_populate_session( - &session, ticket_len, crt_file) == 0); + &session, ticket_len, endpoint_type, crt_file) == 0); break; #endif default: @@ -2257,7 +2254,6 @@ void ssl_serialize_session_load_buf_size(int ticket_len, char *crt_file, USE_PSA_INIT(); /* Prepare serialized session data */ - ((void) endpoint_type); ((void) ticket_len); ((void) crt_file); @@ -2272,7 +2268,7 @@ void ssl_serialize_session_load_buf_size(int ticket_len, char *crt_file, #if defined(MBEDTLS_SSL_PROTO_TLS1_2) case MBEDTLS_SSL_VERSION_TLS1_2: TEST_ASSERT(mbedtls_test_ssl_tls12_populate_session( - &session, ticket_len, crt_file) == 0); + &session, ticket_len, endpoint_type, crt_file) == 0); break; #endif @@ -2329,7 +2325,6 @@ void ssl_session_serialize_version_check(int corrupt_major, mbedtls_ssl_session_init(&session); USE_PSA_INIT(); - ((void) endpoint_type); switch (tls_version) { #if defined(MBEDTLS_SSL_PROTO_TLS1_3) @@ -2341,7 +2336,7 @@ void ssl_session_serialize_version_check(int corrupt_major, #if defined(MBEDTLS_SSL_PROTO_TLS1_2) case MBEDTLS_SSL_VERSION_TLS1_2: TEST_ASSERT(mbedtls_test_ssl_tls12_populate_session( - &session, 0, NULL) == 0); + &session, 0, endpoint_type, NULL) == 0); break; #endif