diff --git a/include/psa/crypto.h b/include/psa/crypto.h
index 487fce8222..08bdb84687 100644
--- a/include/psa/crypto.h
+++ b/include/psa/crypto.h
@@ -656,7 +656,8 @@ psa_status_t psa_destroy_key(psa_key_handle_t handle);
  *   and `PSA_ECC_CURVE_BRAINPOOL_PXXX`).
  *   This is the content of the `privateKey` field of the `ECPrivateKey`
  *   format defined by RFC 5915.
- * - For Diffie-Hellman key exchange key pairs (#PSA_KEY_TYPE_DH_KEYPAIR), the
+ * - For Diffie-Hellman key exchange key pairs (key types for which
+ *   #PSA_KEY_TYPE_IS_DH_KEYPAIR is true), the
  *   format is the representation of the private key `x` as a big-endian byte
  *   string. The length of the byte string is the private key size in bytes
  *   (leading zeroes are not stripped).
@@ -729,7 +730,8 @@ psa_status_t psa_export_key(psa_key_handle_t handle,
  *   representation of the public key `y = g^x mod p` as a big-endian byte
  *   string. The length of the byte string is the length of the base prime `p`
  *   in bytes.
- * - For Diffie-Hellman key exchange public keys (#PSA_KEY_TYPE_DH_PUBLIC_KEY),
+ * - For Diffie-Hellman key exchange public keys (key types for which
+ *   #PSA_KEY_TYPE_IS_DH_PUBLIC_KEY is true),
  *   the format is the representation of the public key `y = g^x mod p` as a
  *   big-endian byte string. The length of the byte string is the length of the
  *   base prime `p` in bytes.
@@ -3253,7 +3255,8 @@ psa_status_t psa_key_derivation_output_bytes(
  *       discard the first 8 bytes, use the next 8 bytes as the first key,
  *       and continue reading output from the operation to derive the other
  *       two keys).
- *     - Finite-field Diffie-Hellman keys (#PSA_KEY_TYPE_DH_KEYPAIR),
+ *     - Finite-field Diffie-Hellman keys (#PSA_KEY_TYPE_DH_KEYPAIR(\c group)
+ *       where \c group designates any Diffie-Hellman group),
  *       DSA keys (#PSA_KEY_TYPE_DSA_KEYPAIR), and
  *       ECC keys on a Weierstrass elliptic curve
  *       (#PSA_KEY_TYPE_ECC_KEYPAIR(\c curve) where \c curve designates a
diff --git a/include/psa/crypto_extra.h b/include/psa/crypto_extra.h
index 5016ba87ca..37d9b40b22 100644
--- a/include/psa/crypto_extra.h
+++ b/include/psa/crypto_extra.h
@@ -449,6 +449,16 @@ psa_status_t psa_generate_random_key_to_handle(psa_key_handle_t handle,
  * @{
  */
 
+/** Custom Diffie-Hellman group.
+ *
+ * For keys of type #PSA_KEY_TYPE_DH_PUBLIC_KEY(#PSA_DH_GROUP_CUSTOM) or
+ * #PSA_KEY_TYPE_DH_KEYPAIR(#PSA_DH_GROUP_CUSTOM), the group data comes
+ * from domain parameters set by psa_set_key_domain_parameters().
+ */
+/* This value is reserved for private use in the TLS named group registry. */
+#define PSA_DH_GROUP_CUSTOM             ((psa_dh_group_t) 0x01fc)
+
+
 /**
  * \brief Set domain parameters for a key.
  *
@@ -475,8 +485,9 @@ psa_status_t psa_generate_random_key_to_handle(psa_key_handle_t handle,
  *      g       INTEGER
  *   }
  *   ```
- * - For Diffie-Hellman key exchange keys (#PSA_KEY_TYPE_DH_PUBLIC_KEY or
- *   #PSA_KEY_TYPE_DH_KEYPAIR), the
+ * - For Diffie-Hellman key exchange keys
+ *   (#PSA_KEY_TYPE_DH_PUBLIC_KEY(#PSA_DH_GROUP_CUSTOM) or
+ *   #PSA_KEY_TYPE_DH_KEYPAIR(#PSA_DH_GROUP_CUSTOM)), the
  *   `DomainParameters` format as defined by RFC 3279 §2.3.3.
  *   ```
  *   DomainParameters ::= SEQUENCE {
diff --git a/include/psa/crypto_types.h b/include/psa/crypto_types.h
index ced42de1a6..02c26788f2 100644
--- a/include/psa/crypto_types.h
+++ b/include/psa/crypto_types.h
@@ -68,6 +68,9 @@ typedef uint32_t psa_key_type_t;
 /** The type of PSA elliptic curve identifiers. */
 typedef uint16_t psa_ecc_curve_t;
 
+/** The type of PSA Diffie-Hellman group identifiers. */
+typedef uint16_t psa_dh_group_t;
+
 /** \brief Encoding of a cryptographic algorithm.
  *
  * For algorithms that can be applied to multiple key types, this type
diff --git a/include/psa/crypto_values.h b/include/psa/crypto_values.h
index c54fc9a60f..6cd22c8401 100644
--- a/include/psa/crypto_values.h
+++ b/include/psa/crypto_values.h
@@ -492,14 +492,45 @@
 #define PSA_ECC_CURVE_CURVE25519        ((psa_ecc_curve_t) 0x001d)
 #define PSA_ECC_CURVE_CURVE448          ((psa_ecc_curve_t) 0x001e)
 
-/** Diffie-Hellman key exchange public key. */
-#define PSA_KEY_TYPE_DH_PUBLIC_KEY             ((psa_key_type_t)0x60040000)
-/** Diffie-Hellman key exchange key pair (private and public key). */
-#define PSA_KEY_TYPE_DH_KEYPAIR                ((psa_key_type_t)0x70040000)
-/** Whether a key type is a Diffie-Hellman key exchange key (pair or
- * public-only). */
-#define PSA_KEY_TYPE_IS_DH(type)                                       \
-    (PSA_KEY_TYPE_PUBLIC_KEY_OF_KEYPAIR(type) == PSA_KEY_TYPE_DH_PUBLIC_KEY)
+#define PSA_KEY_TYPE_DH_PUBLIC_KEY_BASE         ((psa_key_type_t)0x60040000)
+#define PSA_KEY_TYPE_DH_KEYPAIR_BASE            ((psa_key_type_t)0x70040000)
+#define PSA_KEY_TYPE_DH_GROUP_MASK              ((psa_key_type_t)0x0000ffff)
+/** Diffie-Hellman key pair. */
+#define PSA_KEY_TYPE_DH_KEYPAIR(group)          \
+    (PSA_KEY_TYPE_DH_KEYPAIR_BASE | (group))
+/** Diffie-Hellman public key. */
+#define PSA_KEY_TYPE_DH_PUBLIC_KEY(group)               \
+    (PSA_KEY_TYPE_DH_PUBLIC_KEY_BASE | (group))
+
+/** Whether a key type is a Diffie-Hellman key (pair or public-only). */
+#define PSA_KEY_TYPE_IS_DH(type)                                        \
+    ((PSA_KEY_TYPE_PUBLIC_KEY_OF_KEYPAIR(type) &                        \
+      ~PSA_KEY_TYPE_DH_GROUP_MASK) == PSA_KEY_TYPE_DH_PUBLIC_KEY_BASE)
+/** Whether a key type is a Diffie-Hellman key pair. */
+#define PSA_KEY_TYPE_IS_DH_KEYPAIR(type)                               \
+    (((type) & ~PSA_KEY_TYPE_DH_GROUP_MASK) ==                         \
+     PSA_KEY_TYPE_DH_KEYPAIR_BASE)
+/** Whether a key type is a Diffie-Hellman public key. */
+#define PSA_KEY_TYPE_IS_DH_PUBLIC_KEY(type)                            \
+    (((type) & ~PSA_KEY_TYPE_DH_GROUP_MASK) ==                         \
+     PSA_KEY_TYPE_DH_PUBLIC_KEY_BASE)
+
+/** Extract the group from a Diffie-Hellman key type. */
+#define PSA_KEY_TYPE_GET_GROUP(type)                            \
+    ((psa_dh_group_t) (PSA_KEY_TYPE_IS_DH(type) ?               \
+                       ((type) & PSA_KEY_TYPE_DH_GROUP_MASK) :  \
+                       0))
+
+/* The encoding of group identifiers is currently aligned with the
+ * TLS Supported Groups Registry (formerly known as the
+ * TLS EC Named Curve Registry)
+ * https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml#tls-parameters-8
+ * The values are defined by RFC 7919. */
+#define PSA_DH_GROUP_FFDHE2048          ((psa_dh_group_t) 0x0100)
+#define PSA_DH_GROUP_FFDHE3072          ((psa_dh_group_t) 0x0101)
+#define PSA_DH_GROUP_FFDHE4096          ((psa_dh_group_t) 0x0102)
+#define PSA_DH_GROUP_FFDHE6144          ((psa_dh_group_t) 0x0103)
+#define PSA_DH_GROUP_FFDHE8192          ((psa_dh_group_t) 0x0104)
 
 /** The block size of a block cipher.
  *
diff --git a/programs/psa/psa_constant_names.c b/programs/psa/psa_constant_names.c
index 5240b084a4..73692d0228 100644
--- a/programs/psa/psa_constant_names.c
+++ b/programs/psa/psa_constant_names.c
@@ -64,6 +64,7 @@ static void append_integer(char **buffer, size_t buffer_size,
 
 /* The code of these function is automatically generated and included below. */
 static const char *psa_ecc_curve_name(psa_ecc_curve_t curve);
+static const char *psa_dh_group_name(psa_dh_group_t group);
 static const char *psa_hash_algorithm_name(psa_algorithm_t hash_alg);
 
 static void append_with_curve(char **buffer, size_t buffer_size,
@@ -84,6 +85,24 @@ static void append_with_curve(char **buffer, size_t buffer_size,
     append(buffer, buffer_size, required_size, ")", 1);
 }
 
+static void append_with_group(char **buffer, size_t buffer_size,
+                              size_t *required_size,
+                              const char *string, size_t length,
+                              psa_dh_group_t group)
+{
+    const char *group_name = psa_dh_group_name(group);
+    append(buffer, buffer_size, required_size, string, length);
+    append(buffer, buffer_size, required_size, "(", 1);
+    if (group_name != NULL) {
+        append(buffer, buffer_size, required_size,
+               group_name, strlen(group_name));
+    } else {
+        append_integer(buffer, buffer_size, required_size,
+                       "0x%04x", group);
+    }
+    append(buffer, buffer_size, required_size, ")", 1);
+}
+
 typedef const char *(*psa_get_algorithm_name_func_ptr)(psa_algorithm_t alg);
 
 static void append_with_alg(char **buffer, size_t buffer_size,
@@ -137,6 +156,23 @@ static int psa_snprint_ecc_curve(char *buffer, size_t buffer_size,
     }
 }
 
+static int psa_snprint_dh_group(char *buffer, size_t buffer_size,
+                                psa_dh_group_t group)
+{
+    const char *name = psa_dh_group_name(group);
+    if (name == NULL) {
+        return snprintf(buffer, buffer_size, "0x%04x", (unsigned) group);
+    } else {
+        size_t length = strlen(name);
+        if (length < buffer_size) {
+            memcpy(buffer, name, length + 1);
+            return (int) length;
+        } else {
+            return (int) buffer_size;
+        }
+    }
+}
+
 static void usage(const char *program_name)
 {
     printf("Usage: %s TYPE VALUE [VALUE...]\n",
@@ -145,6 +181,7 @@ static void usage(const char *program_name)
     printf("Supported types (with = between aliases):\n");
     printf("  alg=algorithm         Algorithm (psa_algorithm_t)\n");
     printf("  curve=ecc_curve       Elliptic curve identifier (psa_ecc_curve_t)\n");
+    printf("  group=dh_group        Diffie-Hellman group identifier (psa_dh_group_t)\n");
     printf("  type=key_type         Key type (psa_key_type_t)\n");
     printf("  usage=key_usage       Key usage (psa_key_usage_t)\n");
     printf("  error=status          Status code (psa_status_t)\n");
@@ -188,6 +225,7 @@ int process_signed(signed_value_type type, long min, long max, char **argp)
 typedef enum {
     TYPE_ALGORITHM,
     TYPE_ECC_CURVE,
+    TYPE_DH_GROUP,
     TYPE_KEY_TYPE,
     TYPE_KEY_USAGE,
 } unsigned_value_type;
@@ -216,6 +254,10 @@ int process_unsigned(unsigned_value_type type, unsigned long max, char **argp)
                 psa_snprint_ecc_curve(buffer, sizeof(buffer),
                                       (psa_ecc_curve_t) value);
                 break;
+            case TYPE_DH_GROUP:
+                psa_snprint_dh_group(buffer, sizeof(buffer),
+                                     (psa_dh_group_t) value);
+                break;
             case TYPE_KEY_TYPE:
                 psa_snprint_key_type(buffer, sizeof(buffer),
                                      (psa_key_type_t) value);
@@ -252,6 +294,9 @@ int main(int argc, char *argv[])
     } else if (!strcmp(argv[1], "curve") || !strcmp(argv[1], "ecc_curve")) {
         return process_unsigned(TYPE_ECC_CURVE, (psa_ecc_curve_t) (-1),
                                 argv + 2);
+    } else if (!strcmp(argv[1], "group") || !strcmp(argv[1], "dh_group")) {
+        return process_unsigned(TYPE_DH_GROUP, (psa_dh_group_t) (-1),
+                                argv + 2);
     } else if (!strcmp(argv[1], "type") || !strcmp(argv[1], "key_type")) {
         return process_unsigned(TYPE_KEY_TYPE, (psa_key_type_t) (-1),
                                 argv + 2);
diff --git a/scripts/generate_psa_constants.py b/scripts/generate_psa_constants.py
index dac60034d4..ab7f1341f0 100755
--- a/scripts/generate_psa_constants.py
+++ b/scripts/generate_psa_constants.py
@@ -22,6 +22,14 @@ static const char *psa_ecc_curve_name(psa_ecc_curve_t curve)
     }
 }
 
+static const char *psa_dh_group_name(psa_dh_group_t group)
+{
+    switch (group) {
+    %(dh_group_cases)s
+    default: return NULL;
+    }
+}
+
 static const char *psa_hash_algorithm_name(psa_algorithm_t hash_alg)
 {
     switch (hash_alg) {
@@ -145,6 +153,12 @@ key_type_from_curve_template = '''if (%(tester)s(type)) {
                               PSA_KEY_TYPE_GET_CURVE(type));
         } else '''
 
+key_type_from_group_template = '''if (%(tester)s(type)) {
+            append_with_group(&buffer, buffer_size, &required_size,
+                              "%(builder)s", %(builder_length)s,
+                              PSA_KEY_TYPE_GET_GROUP(type));
+        } else '''
+
 algorithm_from_hash_template = '''if (%(tester)s(core_alg)) {
             append(&buffer, buffer_size, &required_size,
                    "%(builder)s(", %(builder_length)s + 1);
@@ -169,7 +183,9 @@ class MacroCollector:
         self.statuses = set()
         self.key_types = set()
         self.key_types_from_curve = {}
+        self.key_types_from_group = {}
         self.ecc_curves = set()
+        self.dh_groups = set()
         self.algorithms = set()
         self.hash_algorithms = set()
         self.ka_algorithms = set()
@@ -206,8 +222,12 @@ class MacroCollector:
             self.key_types.add(name)
         elif name.startswith('PSA_KEY_TYPE_') and parameter == 'curve':
             self.key_types_from_curve[name] = name[:13] + 'IS_' + name[13:]
+        elif name.startswith('PSA_KEY_TYPE_') and parameter == 'group':
+            self.key_types_from_group[name] = name[:13] + 'IS_' + name[13:]
         elif name.startswith('PSA_ECC_CURVE_') and not parameter:
             self.ecc_curves.add(name)
+        elif name.startswith('PSA_DH_GROUP_') and not parameter:
+            self.dh_groups.add(name)
         elif name.startswith('PSA_ALG_') and not parameter:
             if name in ['PSA_ALG_ECDSA_BASE',
                         'PSA_ALG_RSA_PKCS1V15_SIGN_BASE']:
@@ -265,6 +285,10 @@ class MacroCollector:
         return '\n    '.join(map(self.make_return_case,
                                  sorted(self.ecc_curves)))
 
+    def make_dh_group_cases(self):
+        return '\n    '.join(map(self.make_return_case,
+                                 sorted(self.dh_groups)))
+
     def make_key_type_cases(self):
         return '\n    '.join(map(self.make_append_case,
                                  sorted(self.key_types)))
@@ -274,11 +298,21 @@ class MacroCollector:
                                                'builder_length': len(builder),
                                                'tester': tester}
 
-    def make_key_type_code(self):
+    def make_key_type_from_group_code(self, builder, tester):
+        return key_type_from_group_template % {'builder': builder,
+                                               'builder_length': len(builder),
+                                               'tester': tester}
+
+    def make_ecc_key_type_code(self):
         d = self.key_types_from_curve
         make = self.make_key_type_from_curve_code
         return ''.join([make(k, d[k]) for k in sorted(d.keys())])
 
+    def make_dh_key_type_code(self):
+        d = self.key_types_from_group
+        make = self.make_key_type_from_group_code
+        return ''.join([make(k, d[k]) for k in sorted(d.keys())])
+
     def make_hash_algorithm_cases(self):
         return '\n    '.join(map(self.make_return_case,
                                  sorted(self.hash_algorithms)))
@@ -309,8 +343,10 @@ class MacroCollector:
         data = {}
         data['status_cases'] = self.make_status_cases()
         data['ecc_curve_cases'] = self.make_ecc_curve_cases()
+        data['dh_group_cases'] = self.make_dh_group_cases()
         data['key_type_cases'] = self.make_key_type_cases()
-        data['key_type_code'] = self.make_key_type_code()
+        data['key_type_code'] = (self.make_ecc_key_type_code() +
+                                 self.make_dh_key_type_code())
         data['hash_algorithm_cases'] = self.make_hash_algorithm_cases()
         data['ka_algorithm_cases'] = self.make_ka_algorithm_cases()
         data['algorithm_cases'] = self.make_algorithm_cases()
diff --git a/tests/scripts/test_psa_constant_names.py b/tests/scripts/test_psa_constant_names.py
index 421cf4e48a..cbe68b10d2 100755
--- a/tests/scripts/test_psa_constant_names.py
+++ b/tests/scripts/test_psa_constant_names.py
@@ -58,6 +58,7 @@ when applicable.'''
         self.statuses = set(['PSA_SUCCESS'])
         self.algorithms = set(['0xffffffff'])
         self.ecc_curves = set(['0xffff'])
+        self.dh_groups = set(['0xffff'])
         self.key_types = set(['0xffffffff'])
         self.key_usage_flags = set(['0x80000000'])
         # Hard-coded value for unknown algorithms
@@ -74,6 +75,7 @@ when applicable.'''
             'ERROR': self.statuses,
             'ALG': self.algorithms,
             'CURVE': self.ecc_curves,
+            'GROUP': self.dh_groups,
             'KEY_TYPE': self.key_types,
             'KEY_USAGE': self.key_usage_flags,
         }
@@ -94,6 +96,7 @@ Call this after parsing all the inputs.'''
         self.arguments_for['kdf_alg'] = sorted(self.kdf_algorithms)
         self.arguments_for['aead_alg'] = sorted(self.aead_algorithms)
         self.arguments_for['curve'] = sorted(self.ecc_curves)
+        self.arguments_for['group'] = sorted(self.dh_groups)
 
     def format_arguments(self, name, arguments):
         '''Format a macro call with arguments..'''
@@ -184,6 +187,8 @@ where each argument takes each possible value at least once.'''
             self.key_types.add(argument)
         elif function == 'ecc_key_types':
             self.ecc_curves.add(argument)
+        elif function == 'dh_key_types':
+            self.dh_groups.add(argument)
 
     # Regex matching a *.data line containing a test function call and
     # its arguments. The actual definition is partly positional, but this
@@ -299,6 +304,7 @@ not as expected.'''
     for type, names in [('status', inputs.statuses),
                         ('algorithm', inputs.algorithms),
                         ('ecc_curve', inputs.ecc_curves),
+                        ('dh_group', inputs.dh_groups),
                         ('key_type', inputs.key_types),
                         ('key_usage', inputs.key_usage_flags)]:
         c, e = do_test(options, inputs, type, names)
diff --git a/tests/suites/test_suite_psa_crypto_metadata.data b/tests/suites/test_suite_psa_crypto_metadata.data
index 94b80acdda..165b866543 100644
--- a/tests/suites/test_suite_psa_crypto_metadata.data
+++ b/tests/suites/test_suite_psa_crypto_metadata.data
@@ -454,3 +454,19 @@ ecc_key_types:PSA_ECC_CURVE_CURVE25519:255
 ECC key types: Curve448
 depends_on:MBEDTLS_ECP_DP_CURVE448_ENABLED
 ecc_key_types:PSA_ECC_CURVE_CURVE448:448
+
+DH group types: FFDHE2048
+dh_key_types:PSA_DH_GROUP_FFDHE2048:2048
+
+DH group types: FFDHE3072
+dh_key_types:PSA_DH_GROUP_FFDHE3072:2048
+
+DH group types: FFDHE4096
+dh_key_types:PSA_DH_GROUP_FFDHE4096:2048
+
+DH group types: FFDHE6144
+dh_key_types:PSA_DH_GROUP_FFDHE6144:2048
+
+DH group types: FFDHE8192
+dh_key_types:PSA_DH_GROUP_FFDHE8192:2048
+
diff --git a/tests/suites/test_suite_psa_crypto_metadata.function b/tests/suites/test_suite_psa_crypto_metadata.function
index e1eb1c526d..81b2937fae 100644
--- a/tests/suites/test_suite_psa_crypto_metadata.function
+++ b/tests/suites/test_suite_psa_crypto_metadata.function
@@ -49,6 +49,7 @@
 #define KEY_TYPE_IS_RSA                 ( 1u << 4 )
 #define KEY_TYPE_IS_DSA                 ( 1u << 5 )
 #define KEY_TYPE_IS_ECC                 ( 1u << 6 )
+#define KEY_TYPE_IS_DH                  ( 1u << 7 )
 
 #define TEST_CLASSIFICATION_MACRO( flag, alg, flags )           \
     TEST_ASSERT( PSA_##flag( alg ) == !! ( ( flags ) & flag ) )
@@ -91,6 +92,7 @@ void key_type_classification( psa_key_type_t type, unsigned flags )
     TEST_CLASSIFICATION_MACRO( KEY_TYPE_IS_KEYPAIR, type, flags );
     TEST_CLASSIFICATION_MACRO( KEY_TYPE_IS_RSA, type, flags );
     TEST_CLASSIFICATION_MACRO( KEY_TYPE_IS_ECC, type, flags );
+    TEST_CLASSIFICATION_MACRO( KEY_TYPE_IS_DH, type, flags );
 
     /* Macros with derived semantics */
     TEST_EQUAL( PSA_KEY_TYPE_IS_ASYMMETRIC( type ),
@@ -102,6 +104,12 @@ void key_type_classification( psa_key_type_t type, unsigned flags )
     TEST_EQUAL( PSA_KEY_TYPE_IS_ECC_PUBLIC_KEY( type ),
                 ( PSA_KEY_TYPE_IS_ECC( type ) &&
                   PSA_KEY_TYPE_IS_PUBLIC_KEY( type ) ) );
+    TEST_EQUAL( PSA_KEY_TYPE_IS_DH_KEYPAIR( type ),
+                ( PSA_KEY_TYPE_IS_DH( type ) &&
+                  PSA_KEY_TYPE_IS_KEYPAIR( type ) ) );
+    TEST_EQUAL( PSA_KEY_TYPE_IS_DH_PUBLIC_KEY( type ),
+                ( PSA_KEY_TYPE_IS_DH( type ) &&
+                  PSA_KEY_TYPE_IS_PUBLIC_KEY( type ) ) );
 
 exit: ;
 }
@@ -457,3 +465,22 @@ void ecc_key_types( int curve_arg, int curve_bits_arg )
     TEST_ASSERT( curve_bits <= PSA_VENDOR_ECC_MAX_CURVE_BITS );
 }
 /* END_CASE */
+
+/* BEGIN_CASE depends_on:MBEDTLS_DHM_C */
+void dh_key_types( int group_arg, int group_bits_arg )
+{
+    psa_dh_group_t group = group_arg;
+    size_t group_bits = group_bits_arg;
+    psa_key_type_t public_type = PSA_KEY_TYPE_DH_PUBLIC_KEY( group );
+    psa_key_type_t pair_type = PSA_KEY_TYPE_DH_KEYPAIR( group );
+
+    test_key_type( public_type, KEY_TYPE_IS_DH | KEY_TYPE_IS_PUBLIC_KEY );
+    test_key_type( pair_type, KEY_TYPE_IS_DH | KEY_TYPE_IS_KEYPAIR );
+
+    TEST_EQUAL( PSA_KEY_TYPE_GET_GROUP( public_type ), group );
+    TEST_EQUAL( PSA_KEY_TYPE_GET_GROUP( pair_type ), group );
+
+    /* We have nothing to validate about the group size yet. */
+    (void) group_bits;
+}
+/* END_CASE */