From c86c5df0814946c3e1835c1103eaf46aeb1ef689 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Manuel=20P=C3=A9gouri=C3=A9-Gonnard?=
 <manuel.pegourie-gonnard@arm.com>
Date: Mon, 15 Jul 2019 11:23:03 +0200
Subject: [PATCH] Add saved fields from top-level structure

---
 library/ssl_tls.c | 174 ++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 174 insertions(+)

diff --git a/library/ssl_tls.c b/library/ssl_tls.c
index 2d068faadb..d20cce1ed9 100644
--- a/library/ssl_tls.c
+++ b/library/ssl_tls.c
@@ -11430,6 +11430,88 @@ int mbedtls_ssl_context_save( mbedtls_ssl_context *ssl,
     }
 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
 
+    /*
+     * Saved fields from top-level ssl_context structure
+     */
+#if defined(MBEDTLS_SSL_DTLS_BADMAC_LIMIT)
+    used += 4;
+    if( used <= buf_len )
+    {
+        *p++ = (unsigned char)( ( ssl->badmac_seen >> 24 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->badmac_seen >> 16 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->badmac_seen >>  8 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->badmac_seen       ) & 0xFF );
+    }
+#endif /* MBEDTLS_SSL_DTLS_BADMAC_LIMIT */
+
+#if defined(MBEDTLS_SSL_DTLS_ANTI_REPLAY)
+    used += 16;
+    if( used <= buf_len )
+    {
+        *p++ = (unsigned char)( ( ssl->in_window_top >> 56 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window_top >> 48 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window_top >> 40 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window_top >> 32 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window_top >> 24 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window_top >> 16 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window_top >>  8 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window_top       ) & 0xFF );
+
+        *p++ = (unsigned char)( ( ssl->in_window >> 56 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window >> 48 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window >> 40 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window >> 32 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window >> 24 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window >> 16 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window >>  8 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->in_window       ) & 0xFF );
+    }
+#endif /* MBEDTLS_SSL_DTLS_ANTI_REPLAY */
+
+#if defined(MBEDTLS_SSL_PROTO_DTLS)
+    used += 1;
+    if( used <= buf_len )
+    {
+        *p++ = ssl->disable_datagram_packing;
+    }
+#endif /* MBEDTLS_SSL_PROTO_DTLS */
+
+    used += 8;
+    if( used <= buf_len )
+    {
+        memcpy( p, ssl->cur_out_ctr, 8 );
+        p += 8;
+    }
+
+#if defined(MBEDTLS_SSL_PROTO_DTLS)
+    used += 2;
+    if( used <= buf_len )
+    {
+        *p++ = (unsigned char)( ( ssl->mtu >>  8 ) & 0xFF );
+        *p++ = (unsigned char)( ( ssl->mtu       ) & 0xFF );
+    }
+#endif /* MBEDTLS_SSL_PROTO_DTLS */
+
+#if defined(MBEDTLS_SSL_ALPN)
+    {
+        const uint8_t alpn_len = ssl->alpn_chosen
+                               ? strlen( ssl->alpn_chosen )
+                               : 0;
+
+        used += 1 + alpn_len;
+        if( used <= buf_len )
+        {
+            *p++ = alpn_len;
+
+            if( ssl->alpn_chosen != NULL )
+            {
+                memcpy( p, ssl->alpn_chosen, alpn_len );
+                p += alpn_len;
+            }
+        }
+    }
+#endif /* MBEDTLS_SSL_ALPN */
+
     /*
      * Done
      */
@@ -11610,6 +11692,98 @@ static int ssl_context_load( mbedtls_ssl_context *ssl,
     p += ssl->transform->out_cid_len;
 #endif /* MBEDTLS_SSL_DTLS_CONNECTION_ID */
 
+    /*
+     * Saved fields from top-level ssl_context structure
+     */
+#if defined(MBEDTLS_SSL_DTLS_BADMAC_LIMIT)
+    if( (size_t)( end - p ) < 4 )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    ssl->badmac_seen = ( (uint32_t) p[0] << 24 ) |
+                       ( (uint32_t) p[1] << 16 ) |
+                       ( (uint32_t) p[2] <<  8 ) |
+                       ( (uint32_t) p[3]       );
+    p += 4;
+#endif /* MBEDTLS_SSL_DTLS_BADMAC_LIMIT */
+
+#if defined(MBEDTLS_SSL_DTLS_ANTI_REPLAY)
+    if( (size_t)( end - p ) < 16 )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    ssl->in_window_top = ( (uint64_t) p[0] << 56 ) |
+                         ( (uint64_t) p[1] << 48 ) |
+                         ( (uint64_t) p[2] << 40 ) |
+                         ( (uint64_t) p[3] << 32 ) |
+                         ( (uint64_t) p[4] << 24 ) |
+                         ( (uint64_t) p[5] << 16 ) |
+                         ( (uint64_t) p[6] <<  8 ) |
+                         ( (uint64_t) p[7]       );
+    p += 8;
+
+    ssl->in_window = ( (uint64_t) p[0] << 56 ) |
+                     ( (uint64_t) p[1] << 48 ) |
+                     ( (uint64_t) p[2] << 40 ) |
+                     ( (uint64_t) p[3] << 32 ) |
+                     ( (uint64_t) p[4] << 24 ) |
+                     ( (uint64_t) p[5] << 16 ) |
+                     ( (uint64_t) p[6] <<  8 ) |
+                     ( (uint64_t) p[7]       );
+    p += 8;
+#endif /* MBEDTLS_SSL_DTLS_ANTI_REPLAY */
+
+#if defined(MBEDTLS_SSL_PROTO_DTLS)
+    if( (size_t)( end - p ) < 1 )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    ssl->disable_datagram_packing = *p++;
+#endif /* MBEDTLS_SSL_PROTO_DTLS */
+
+    if( (size_t)( end - p ) < 8 )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    memcpy( ssl->cur_out_ctr, p, 8 );
+    p += 8;
+
+#if defined(MBEDTLS_SSL_PROTO_DTLS)
+    if( (size_t)( end - p ) < 2 )
+        return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+    ssl->mtu = ( p[0] << 8 ) | p[1];
+    p += 2;
+#endif /* MBEDTLS_SSL_PROTO_DTLS */
+
+#if defined(MBEDTLS_SSL_ALPN)
+    {
+        uint8_t alpn_len;
+        const char **cur;
+
+        if( (size_t)( end - p ) < 1 )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+        alpn_len = *p++;
+
+        if( alpn_len != 0 && ssl->conf->alpn_list != NULL )
+        {
+            /* alpn_chosen should point to an item in the configured list */
+            for( cur = ssl->conf->alpn_list; *cur != NULL; cur++ )
+            {
+                if( strlen( *cur ) == alpn_len &&
+                    memcmp( p, cur, alpn_len ) == 0 )
+                {
+                    ssl->alpn_chosen = *cur;
+                    break;
+                }
+            }
+        }
+
+        /* can only happen on conf mismatch */
+        if( alpn_len != 0 && ssl->alpn_chosen == NULL )
+            return( MBEDTLS_ERR_SSL_BAD_INPUT_DATA );
+
+        p += alpn_len;
+    }
+#endif /* MBEDTLS_SSL_ALPN */
+
     /*
      * Done - should have consumed entire buffer
      */