diff --git a/include/polarssl/ssl_ciphersuites.h b/include/polarssl/ssl_ciphersuites.h
index 714cdcdfac..85392c177a 100644
--- a/include/polarssl/ssl_ciphersuites.h
+++ b/include/polarssl/ssl_ciphersuites.h
@@ -27,6 +27,7 @@
 #ifndef POLARSSL_SSL_CIPHERSUITES_H
 #define POLARSSL_SSL_CIPHERSUITES_H
 
+#include "pk.h"
 #include "cipher.h"
 #include "md.h"
 
@@ -197,6 +198,8 @@ const int *ssl_ciphersuites_list( void );
 const ssl_ciphersuite_t *ssl_ciphersuite_from_string( const char *ciphersuite_name );
 const ssl_ciphersuite_t *ssl_ciphersuite_from_id( int ciphersuite_id );
 
+pk_type_t ssl_get_ciphersuite_sig_pk_alg( const ssl_ciphersuite_t *info );
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/library/ssl_ciphersuites.c b/library/ssl_ciphersuites.c
index 63601f66f2..759845ee99 100644
--- a/library/ssl_ciphersuites.c
+++ b/library/ssl_ciphersuites.c
@@ -916,4 +916,20 @@ int ssl_get_ciphersuite_id( const char *ciphersuite_name )
     return( cur->id );
 }
 
+pk_type_t ssl_get_ciphersuite_sig_pk_alg( const ssl_ciphersuite_t *info )
+{
+    switch( info->key_exchange )
+    {
+        case POLARSSL_KEY_EXCHANGE_DHE_RSA:
+        case POLARSSL_KEY_EXCHANGE_ECDHE_RSA:
+            return( POLARSSL_PK_RSA );
+
+        case POLARSSL_KEY_EXCHANGE_ECDHE_ECDSA:
+            return( POLARSSL_PK_ECDSA );
+
+        default:
+            return( POLARSSL_PK_NONE );
+    }
+}
+
 #endif
diff --git a/library/ssl_cli.c b/library/ssl_cli.c
index 267e385952..605d4668d2 100644
--- a/library/ssl_cli.c
+++ b/library/ssl_cli.c
@@ -1394,6 +1394,19 @@ static int ssl_parse_server_key_exchange( ssl_context *ssl )
             return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
         }
 
+        if( pk_alg != POLARSSL_PK_NONE )
+        {
+            if( pk_alg != ssl_get_ciphersuite_sig_pk_alg( ciphersuite_info ) )
+            {
+                SSL_DEBUG_MSG( 1, ( "bad server key exchange message" ) );
+                return( POLARSSL_ERR_SSL_BAD_HS_SERVER_KEY_EXCHANGE );
+            }
+        }
+        else
+        {
+            pk_alg = ssl_get_ciphersuite_sig_pk_alg( ciphersuite_info );
+        }
+
         sig_len = ( p[0] << 8 ) | p[1];
         p += 2;