From 251e86ae3f19c3866650b517f61a80cd881f44a3 Mon Sep 17 00:00:00 2001
From: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
Date: Fri, 17 Feb 2023 14:30:50 +0100
Subject: [PATCH] Adapt names to more suitable and fix conditional compilation
 flags

Signed-off-by: Przemek Stekiel <przemyslaw.stekiel@mobica.com>
---
 docs/proposed/psa-driver-interface.md         |  4 +--
 include/psa/crypto_builtin_composites.h       |  8 +++---
 .../psa/crypto_driver_contexts_composites.h   |  2 +-
 .../psa/crypto_driver_contexts_primitives.h   |  2 --
 include/psa/crypto_extra.h                    | 11 +++++---
 library/psa_crypto.c                          |  6 ++---
 library/psa_crypto_driver_wrappers.h          |  4 +--
 library/psa_crypto_pake.c                     | 26 ++++++++++---------
 library/psa_crypto_pake.h                     |  4 +--
 .../psa_crypto_driver_wrappers.c.jinja        |  4 +--
 tests/include/test/drivers/pake.h             |  8 +++---
 tests/src/drivers/test_driver_pake.c          |  8 +++---
 12 files changed, 46 insertions(+), 41 deletions(-)

diff --git a/docs/proposed/psa-driver-interface.md b/docs/proposed/psa-driver-interface.md
index 1b941cede0..07f198908d 100644
--- a/docs/proposed/psa-driver-interface.md
+++ b/docs/proposed/psa-driver-interface.md
@@ -410,7 +410,7 @@ The pointer output by `psa_crypto_driver_pake_get_password_key` is only valid un
 
 ```
 psa_status_t acme_pake_output(acme_pake_operation_t *operation,
-                              psa_pake_driver_step_t step,
+                              psa_crypto_driver_pake_step_t step,
                               uint8_t *output,
                               size_t output_size,
                               size_t *output_length);
@@ -437,7 +437,7 @@ For `PSA_ALG_JPAKE` the following steps are available for output operation:
 #### PAKE driver input
 ```
 psa_status_t acme_pake_input(acme_pake_operation_t *operation,
-                             psa_pake_driver_step_t step,
+                             psa_crypto_driver_pake_step_t step,
                              uint8_t *input,
                              size_t input_size);
 ```
diff --git a/include/psa/crypto_builtin_composites.h b/include/psa/crypto_builtin_composites.h
index 3221a64234..f331ec5f48 100644
--- a/include/psa/crypto_builtin_composites.h
+++ b/include/psa/crypto_builtin_composites.h
@@ -191,23 +191,25 @@ typedef struct {
 /* Note: the format for mbedtls_ecjpake_read/write function has an extra
  * length byte for each step, plus an extra 3 bytes for ECParameters in the
  * server's 2nd round. */
-#define MBEDTLS_PSA_PAKE_BUFFER_SIZE ((3 + 1 + 65 + 1 + 65 + 1 + 32) * 2)
+#define MBEDTLS_PSA_JPAKE_BUFFER_SIZE ((3 + 1 + 65 + 1 + 65 + 1 + 32) * 2)
 
 typedef struct {
     psa_algorithm_t MBEDTLS_PRIVATE(alg);
 
-#if defined(MBEDTLS_PSA_BUILTIN_PAKE)
     uint8_t *MBEDTLS_PRIVATE(password);
     size_t MBEDTLS_PRIVATE(password_len);
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
     uint8_t MBEDTLS_PRIVATE(role);
-    uint8_t MBEDTLS_PRIVATE(buffer[MBEDTLS_PSA_PAKE_BUFFER_SIZE]);
+    uint8_t MBEDTLS_PRIVATE(buffer[MBEDTLS_PSA_JPAKE_BUFFER_SIZE]);
     size_t MBEDTLS_PRIVATE(buffer_length);
     size_t MBEDTLS_PRIVATE(buffer_offset);
 #endif
     /* Context structure for the Mbed TLS EC-JPAKE implementation. */
     union {
         unsigned int MBEDTLS_PRIVATE(dummy);
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
         mbedtls_ecjpake_context MBEDTLS_PRIVATE(pake);
+#endif
     } MBEDTLS_PRIVATE(ctx);
 
 } mbedtls_psa_pake_operation_t;
diff --git a/include/psa/crypto_driver_contexts_composites.h b/include/psa/crypto_driver_contexts_composites.h
index 4d0e9848d3..6c56a51dbc 100644
--- a/include/psa/crypto_driver_contexts_composites.h
+++ b/include/psa/crypto_driver_contexts_composites.h
@@ -93,7 +93,7 @@ typedef mbedtls_psa_aead_operation_t
 
 typedef libtestdriver1_mbedtls_psa_pake_operation_t
     mbedtls_transparent_test_driver_pake_operation_t;
-typedef libtestdriver1_psa_pake_operation_t
+typedef libtestdriver1_mbedtls_psa_pake_operation_t
     mbedtls_opaque_test_driver_pake_operation_t;
 
 #define MBEDTLS_TRANSPARENT_TEST_DRIVER_PAKE_OPERATION_INIT \
diff --git a/include/psa/crypto_driver_contexts_primitives.h b/include/psa/crypto_driver_contexts_primitives.h
index f1463f34d0..620a4b3a77 100644
--- a/include/psa/crypto_driver_contexts_primitives.h
+++ b/include/psa/crypto_driver_contexts_primitives.h
@@ -45,8 +45,6 @@
 #include <libtestdriver1/include/psa/crypto.h>
 #endif
 
-#include "mbedtls/ecjpake.h"
-
 #if defined(PSA_CRYPTO_DRIVER_TEST)
 
 #if defined(MBEDTLS_TEST_LIBTESTDRIVER1) && \
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index 8b8cb042e7..39ef52cbec 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -429,7 +429,7 @@ psa_status_t mbedtls_psa_inject_entropy(const uint8_t *seed,
  */
 #define PSA_DH_FAMILY_CUSTOM             ((psa_dh_family_t) 0x7e)
 
-/** EC-JPAKE operation stages. */
+/** PAKE operation stages. */
 #define PSA_PAKE_OPERATION_STAGE_SETUP 0
 #define PSA_PAKE_OPERATION_STAGE_COLLECT_INPUTS 1
 #define PSA_PAKE_OPERATION_STAGE_COMPUTATION 2
@@ -1895,7 +1895,7 @@ psa_status_t psa_pake_abort(psa_pake_operation_t *operation);
  * psa_pake_operation_t.
  */
 #define PSA_PAKE_OPERATION_INIT { 0, PSA_ALG_NONE, PSA_PAKE_OPERATION_STAGE_SETUP, \
-                                  { { 0, 0, 0, 0 } }, { { 0 } } }
+                                  { 0 }, { { 0 } } }
 
 struct psa_pake_cipher_suite_s {
     psa_algorithm_t algorithm;
@@ -2002,7 +2002,7 @@ enum psa_jpake_sequence {
     PSA_PAKE_SEQ_END            = 7,
 };
 
-typedef enum psa_pake_driver_step {
+typedef enum psa_crypto_driver_pake_step {
     PSA_JPAKE_STEP_INVALID        = 0,  /* Invalid step */
     PSA_JPAKE_X1_STEP_KEY_SHARE   = 1,  /* Round 1: input/output key share (for ephemeral private key X1).*/
     PSA_JPAKE_X1_STEP_ZK_PUBLIC   = 2,  /* Round 1: input/output Schnorr NIZKP public key for the X1 key */
@@ -2016,7 +2016,7 @@ typedef enum psa_pake_driver_step {
     PSA_JPAKE_X4S_STEP_KEY_SHARE  = 10, /* Round 2: input X4S key (from peer) */
     PSA_JPAKE_X4S_STEP_ZK_PUBLIC  = 11, /* Round 2: input Schnorr NIZKP public key for the X4S key (from peer) */
     PSA_JPAKE_X4S_STEP_ZK_PROOF   = 12  /* Round 2: input Schnorr NIZKP proof for the X4S key (from peer) */
-} psa_pake_driver_step_t;
+} psa_crypto_driver_pake_step_t;
 
 
 struct psa_jpake_computation_stage_s {
@@ -2042,7 +2042,10 @@ struct psa_pake_operation_s {
     uint8_t MBEDTLS_PRIVATE(stage);
     /* Holds computation stage of the PAKE algorithms. */
     union {
+        uint8_t MBEDTLS_PRIVATE(dummy);
+#if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
         psa_jpake_computation_stage_t MBEDTLS_PRIVATE(jpake);
+#endif
     } MBEDTLS_PRIVATE(computation_stage);
     union {
         psa_driver_pake_context_t MBEDTLS_PRIVATE(ctx);
diff --git a/library/psa_crypto.c b/library/psa_crypto.c
index c57583aef3..2c1a910fbd 100644
--- a/library/psa_crypto.c
+++ b/library/psa_crypto.c
@@ -7407,7 +7407,7 @@ exit:
 }
 
 /* Auxiliary function to convert core computation stage(step, sequence, state) to single driver step. */
-static psa_pake_driver_step_t convert_jpake_computation_stage_to_driver_step(
+static psa_crypto_driver_pake_step_t convert_jpake_computation_stage_to_driver_step(
     psa_jpake_computation_stage_t *stage)
 {
     switch (stage->state) {
@@ -7843,7 +7843,7 @@ psa_status_t psa_pake_get_implicit_key(
 {
     psa_status_t status = PSA_ERROR_CORRUPTION_DETECTED;
     psa_status_t abort_status = PSA_ERROR_CORRUPTION_DETECTED;
-    uint8_t shared_key[MBEDTLS_PSA_PAKE_BUFFER_SIZE];
+    uint8_t shared_key[MBEDTLS_PSA_JPAKE_BUFFER_SIZE];
     size_t shared_key_len = 0;
 
     if (operation->stage != PSA_PAKE_OPERATION_STAGE_COMPUTATION) {
@@ -7874,7 +7874,7 @@ psa_status_t psa_pake_get_implicit_key(
                                             shared_key,
                                             shared_key_len);
 
-    mbedtls_platform_zeroize(shared_key, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
+    mbedtls_platform_zeroize(shared_key, MBEDTLS_PSA_JPAKE_BUFFER_SIZE);
 exit:
     abort_status = psa_pake_abort(operation);
     return status == PSA_SUCCESS ? abort_status : status;
diff --git a/library/psa_crypto_driver_wrappers.h b/library/psa_crypto_driver_wrappers.h
index 11a95e3a00..65d0d3f078 100644
--- a/library/psa_crypto_driver_wrappers.h
+++ b/library/psa_crypto_driver_wrappers.h
@@ -421,14 +421,14 @@ psa_status_t psa_driver_wrapper_pake_setup(
 
 psa_status_t psa_driver_wrapper_pake_output(
     psa_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length);
 
 psa_status_t psa_driver_wrapper_pake_input(
     psa_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     const uint8_t *input,
     size_t input_length);
 
diff --git a/library/psa_crypto_pake.c b/library/psa_crypto_pake.c
index fdfbd16fbf..73032c6a8a 100644
--- a/library/psa_crypto_pake.c
+++ b/library/psa_crypto_pake.c
@@ -163,6 +163,7 @@ static psa_status_t mbedtls_ecjpake_to_psa_error(int ret)
 }
 #endif
 
+#if defined(MBEDTLS_PSA_BUILTIN_PAKE)
 #if defined(MBEDTLS_PSA_BUILTIN_ALG_JPAKE)
 static psa_status_t psa_pake_ecjpake_setup(mbedtls_psa_pake_operation_t *operation)
 {
@@ -187,6 +188,7 @@ static psa_status_t psa_pake_ecjpake_setup(mbedtls_psa_pake_operation_t *operati
 
     return PSA_SUCCESS;
 }
+#endif
 
 psa_status_t mbedtls_psa_pake_setup(mbedtls_psa_pake_operation_t *operation,
                                     const psa_crypto_driver_pake_inputs_t *inputs)
@@ -237,7 +239,7 @@ psa_status_t mbedtls_psa_pake_setup(mbedtls_psa_pake_operation_t *operation,
         operation->role = role;
         operation->alg = cipher_suite.algorithm;
 
-        mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
+        mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_JPAKE_BUFFER_SIZE);
         operation->buffer_length = 0;
         operation->buffer_offset = 0;
 
@@ -259,7 +261,7 @@ psa_status_t mbedtls_psa_pake_setup(mbedtls_psa_pake_operation_t *operation,
 
 static psa_status_t mbedtls_psa_pake_output_internal(
     mbedtls_psa_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length)
@@ -288,7 +290,7 @@ static psa_status_t mbedtls_psa_pake_output_internal(
         if (step == PSA_JPAKE_X1_STEP_KEY_SHARE) {
             ret = mbedtls_ecjpake_write_round_one(&operation->ctx.pake,
                                                   operation->buffer,
-                                                  MBEDTLS_PSA_PAKE_BUFFER_SIZE,
+                                                  MBEDTLS_PSA_JPAKE_BUFFER_SIZE,
                                                   &operation->buffer_length,
                                                   mbedtls_psa_get_random,
                                                   MBEDTLS_PSA_RANDOM_STATE);
@@ -300,7 +302,7 @@ static psa_status_t mbedtls_psa_pake_output_internal(
         } else if (step == PSA_JPAKE_X2S_STEP_KEY_SHARE) {
             ret = mbedtls_ecjpake_write_round_two(&operation->ctx.pake,
                                                   operation->buffer,
-                                                  MBEDTLS_PSA_PAKE_BUFFER_SIZE,
+                                                  MBEDTLS_PSA_JPAKE_BUFFER_SIZE,
                                                   &operation->buffer_length,
                                                   mbedtls_psa_get_random,
                                                   MBEDTLS_PSA_RANDOM_STATE);
@@ -350,7 +352,7 @@ static psa_status_t mbedtls_psa_pake_output_internal(
         /* Reset buffer after ZK_PROOF sequence */
         if ((step == PSA_JPAKE_X2_STEP_ZK_PROOF) ||
             (step == PSA_JPAKE_X2S_STEP_ZK_PROOF)) {
-            mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
+            mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_JPAKE_BUFFER_SIZE);
             operation->buffer_length = 0;
             operation->buffer_offset = 0;
         }
@@ -367,7 +369,7 @@ static psa_status_t mbedtls_psa_pake_output_internal(
 }
 
 psa_status_t mbedtls_psa_pake_output(mbedtls_psa_pake_operation_t *operation,
-                                     psa_pake_driver_step_t step,
+                                     psa_crypto_driver_pake_step_t step,
                                      uint8_t *output,
                                      size_t output_size,
                                      size_t *output_length)
@@ -380,7 +382,7 @@ psa_status_t mbedtls_psa_pake_output(mbedtls_psa_pake_operation_t *operation,
 
 static psa_status_t mbedtls_psa_pake_input_internal(
     mbedtls_psa_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     const uint8_t *input,
     size_t input_length)
 {
@@ -441,7 +443,7 @@ static psa_status_t mbedtls_psa_pake_input_internal(
                                                  operation->buffer,
                                                  operation->buffer_length);
 
-            mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
+            mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_JPAKE_BUFFER_SIZE);
             operation->buffer_length = 0;
 
             if (ret != 0) {
@@ -452,7 +454,7 @@ static psa_status_t mbedtls_psa_pake_input_internal(
                                                  operation->buffer,
                                                  operation->buffer_length);
 
-            mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
+            mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_JPAKE_BUFFER_SIZE);
             operation->buffer_length = 0;
 
             if (ret != 0) {
@@ -471,7 +473,7 @@ static psa_status_t mbedtls_psa_pake_input_internal(
 }
 
 psa_status_t mbedtls_psa_pake_input(mbedtls_psa_pake_operation_t *operation,
-                                    psa_pake_driver_step_t step,
+                                    psa_crypto_driver_pake_step_t step,
                                     const uint8_t *input,
                                     size_t input_length)
 {
@@ -491,7 +493,7 @@ psa_status_t mbedtls_psa_pake_get_implicit_key(
     if (operation->alg == PSA_ALG_JPAKE) {
         ret = mbedtls_ecjpake_write_shared_key(&operation->ctx.pake,
                                                operation->buffer,
-                                               MBEDTLS_PSA_PAKE_BUFFER_SIZE,
+                                               MBEDTLS_PSA_JPAKE_BUFFER_SIZE,
                                                &operation->buffer_length,
                                                mbedtls_psa_get_random,
                                                MBEDTLS_PSA_RANDOM_STATE);
@@ -520,7 +522,7 @@ psa_status_t mbedtls_psa_pake_abort(mbedtls_psa_pake_operation_t *operation)
         operation->password = NULL;
         operation->password_len = 0;
         operation->role = PSA_PAKE_ROLE_NONE;
-        mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_PAKE_BUFFER_SIZE);
+        mbedtls_platform_zeroize(operation->buffer, MBEDTLS_PSA_JPAKE_BUFFER_SIZE);
         operation->buffer_length = 0;
         operation->buffer_offset = 0;
         mbedtls_ecjpake_free(&operation->ctx.pake);
diff --git a/library/psa_crypto_pake.h b/library/psa_crypto_pake.h
index dc6ad7b54f..365855601b 100644
--- a/library/psa_crypto_pake.h
+++ b/library/psa_crypto_pake.h
@@ -96,7 +96,7 @@ psa_status_t mbedtls_psa_pake_setup(mbedtls_psa_pake_operation_t *operation,
  *         results in this error code.
  */
 psa_status_t mbedtls_psa_pake_output(mbedtls_psa_pake_operation_t *operation,
-                                     psa_pake_driver_step_t step,
+                                     psa_crypto_driver_pake_step_t step,
                                      uint8_t *output,
                                      size_t output_size,
                                      size_t *output_length);
@@ -143,7 +143,7 @@ psa_status_t mbedtls_psa_pake_output(mbedtls_psa_pake_operation_t *operation,
  *         results in this error code.
  */
 psa_status_t mbedtls_psa_pake_input(mbedtls_psa_pake_operation_t *operation,
-                                    psa_pake_driver_step_t step,
+                                    psa_crypto_driver_pake_step_t step,
                                     const uint8_t *input,
                                     size_t input_length);
 
diff --git a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
index cf08794c60..b287b37a1d 100644
--- a/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
+++ b/scripts/data_files/driver_templates/psa_crypto_driver_wrappers.c.jinja
@@ -2865,7 +2865,7 @@ psa_status_t psa_driver_wrapper_pake_setup(
 }
 psa_status_t psa_driver_wrapper_pake_output(
     psa_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length )
@@ -2901,7 +2901,7 @@ psa_status_t psa_driver_wrapper_pake_output(
 
 psa_status_t psa_driver_wrapper_pake_input(
     psa_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     const uint8_t *input,
     size_t input_length )
 {
diff --git a/tests/include/test/drivers/pake.h b/tests/include/test/drivers/pake.h
index 23cb98aa43..d082d6e5ed 100644
--- a/tests/include/test/drivers/pake.h
+++ b/tests/include/test/drivers/pake.h
@@ -57,14 +57,14 @@ psa_status_t mbedtls_test_transparent_pake_setup(
 
 psa_status_t mbedtls_test_transparent_pake_output(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length);
 
 psa_status_t mbedtls_test_transparent_pake_input(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     const uint8_t *input,
     size_t input_length);
 
@@ -101,14 +101,14 @@ psa_status_t mbedtls_test_opaque_pake_set_role(
 
 psa_status_t mbedtls_test_opaque_pake_output(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length);
 
 psa_status_t mbedtls_test_opaque_pake_input(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     const uint8_t *input,
     size_t input_length);
 
diff --git a/tests/src/drivers/test_driver_pake.c b/tests/src/drivers/test_driver_pake.c
index 9d51ea10b7..615f7ef8a1 100644
--- a/tests/src/drivers/test_driver_pake.c
+++ b/tests/src/drivers/test_driver_pake.c
@@ -64,7 +64,7 @@ psa_status_t mbedtls_test_transparent_pake_setup(
 
 psa_status_t mbedtls_test_transparent_pake_output(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length)
@@ -112,7 +112,7 @@ psa_status_t mbedtls_test_transparent_pake_output(
 
 psa_status_t mbedtls_test_transparent_pake_input(
     mbedtls_transparent_test_driver_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     const uint8_t *input,
     size_t input_length)
 {
@@ -260,7 +260,7 @@ psa_status_t mbedtls_test_opaque_pake_set_role(
 
 psa_status_t mbedtls_test_opaque_pake_output(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     uint8_t *output,
     size_t output_size,
     size_t *output_length)
@@ -276,7 +276,7 @@ psa_status_t mbedtls_test_opaque_pake_output(
 
 psa_status_t mbedtls_test_opaque_pake_input(
     mbedtls_opaque_test_driver_pake_operation_t *operation,
-    psa_pake_driver_step_t step,
+    psa_crypto_driver_pake_step_t step,
     const uint8_t *input,
     size_t input_length)
 {