ssl_tls.c: Factorize save/load of endpoint and ciphersuite

Move the save/load of session endpoint and
ciphersuite that are common to TLS 1.2 and
TLS 1.3 serialized data from the
specialized ssl_{tls12,tls13}_session_{save,load}
functions to ssl__session_{save,load}.

Signed-off-by: Ronald Cron <ronald.cron@arm.com>
This commit is contained in:
Ronald Cron 2024-01-15 10:21:30 +01:00
parent 3c0072b58e
commit 40a4ab0e0c

View File

@ -2449,8 +2449,6 @@ mbedtls_ssl_mode_t mbedtls_ssl_get_mode_from_ciphersuite(
* } ClientOnlyData;
*
* struct {
* uint8 endpoint;
* uint8 ciphersuite[2];
* uint32 ticket_age_add;
* uint8 ticket_flags;
* opaque resumption_key<0..255>;
@ -2476,11 +2474,9 @@ static int ssl_tls13_session_save(const mbedtls_ssl_session *session,
size_t hostname_len = (session->hostname == NULL) ?
0 : strlen(session->hostname) + 1;
#endif
size_t needed = 1 /* endpoint */
+ 2 /* ciphersuite */
+ 4 /* ticket_age_add */
+ 1 /* ticket_flags */
+ 1; /* resumption_key length */
size_t needed = 4 /* ticket_age_add */
+ 1 /* ticket_flags */
+ 1; /* resumption_key length */
*olen = 0;
if (session->resumption_key_len > MBEDTLS_SSL_TLS1_3_TICKET_RESUMPTION_KEY_LEN) {
@ -2523,14 +2519,12 @@ static int ssl_tls13_session_save(const mbedtls_ssl_session *session,
return MBEDTLS_ERR_SSL_BUFFER_TOO_SMALL;
}
p[0] = session->endpoint;
MBEDTLS_PUT_UINT16_BE(session->ciphersuite, p, 1);
MBEDTLS_PUT_UINT32_BE(session->ticket_age_add, p, 3);
p[7] = session->ticket_flags;
MBEDTLS_PUT_UINT32_BE(session->ticket_age_add, p, 0);
p[4] = session->ticket_flags;
/* save resumption_key */
p[8] = session->resumption_key_len;
p += 9;
p[5] = session->resumption_key_len;
p += 6;
memcpy(p, session->resumption_key, session->resumption_key_len);
p += session->resumption_key_len;
@ -2589,17 +2583,15 @@ static int ssl_tls13_session_load(mbedtls_ssl_session *session,
const unsigned char *p = buf;
const unsigned char *end = buf + len;
if (end - p < 9) {
if (end - p < 6) {
return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
}
session->endpoint = p[0];
session->ciphersuite = MBEDTLS_GET_UINT16_BE(p, 1);
session->ticket_age_add = MBEDTLS_GET_UINT32_BE(p, 3);
session->ticket_flags = p[7];
session->ticket_age_add = MBEDTLS_GET_UINT32_BE(p, 0);
session->ticket_flags = p[4];
/* load resumption_key */
session->resumption_key_len = p[8];
p += 9;
session->resumption_key_len = p[5];
p += 6;
if (end - p < session->resumption_key_len) {
return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
@ -3777,11 +3769,16 @@ static int ssl_session_save(const mbedtls_ssl_session *session,
}
/*
* TLS version identifier
* TLS version identifier, endpoint, ciphersuite
*/
used += 1;
used += 1 /* TLS version */
+ 1 /* endpoint */
+ 2; /* ciphersuite */
if (used <= buf_len) {
*p++ = MBEDTLS_BYTE_0(session->tls_version);
*p++ = session->endpoint;
MBEDTLS_PUT_UINT16_BE(session->ciphersuite, p, 0);
p += 2;
}
/* Forward to version-specific serialization routine. */
@ -3864,12 +3861,15 @@ static int ssl_session_load(mbedtls_ssl_session *session,
}
/*
* TLS version identifier
* TLS version identifier, endpoint, ciphersuite
*/
if (1 > (size_t) (end - p)) {
if (4 > (size_t) (end - p)) {
return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
}
session->tls_version = (mbedtls_ssl_protocol_version) (0x0300 | *p++);
session->endpoint = *p++;
session->ciphersuite = MBEDTLS_GET_UINT16_BE(p, 0);
p += 2;
/* Dispatch according to TLS version. */
remaining_len = (size_t) (end - p);
@ -8942,8 +8942,6 @@ unsigned int mbedtls_ssl_tls12_get_preferred_hash_for_sig_alg(
*
* struct {
* uint64 start_time;
* uint8 endpoint;
* uint8 ciphersuite[2]; // defined by the standard
* uint8 session_id_len; // at most 32
* opaque session_id[32];
* opaque master[48]; // fixed length in the standard
@ -8990,18 +8988,12 @@ static size_t ssl_tls12_session_save(const mbedtls_ssl_session *session,
/*
* Basic mandatory fields
*/
used += 1 /* endpoint */
+ 2 /* ciphersuite */
+ 1 /* id_len */
used += 1 /* id_len */
+ sizeof(session->id)
+ sizeof(session->master)
+ 4; /* verify_result */
if (used <= buf_len) {
*p++ = session->endpoint;
MBEDTLS_PUT_UINT16_BE(session->ciphersuite, p, 0);
p += 2;
*p++ = MBEDTLS_BYTE_0(session->id_len);
memcpy(p, session->id, 32);
p += 32;
@ -9147,14 +9139,10 @@ static int ssl_tls12_session_load(mbedtls_ssl_session *session,
/*
* Basic mandatory fields
*/
if (1 + 2 + 1 + 32 + 48 + 4 > (size_t) (end - p)) {
if (1 + 32 + 48 + 4 > (size_t) (end - p)) {
return MBEDTLS_ERR_SSL_BAD_INPUT_DATA;
}
session->endpoint = *p++;
session->ciphersuite = MBEDTLS_GET_UINT16_BE(p, 0);
p += 2;
session->id_len = *p++;
memcpy(session->id, p, 32);
p += 32;