diff --git a/tests/include/test/threading_helpers.h b/tests/include/test/threading_helpers.h index 3c4d44126a..9b7ced519b 100644 --- a/tests/include/test/threading_helpers.h +++ b/tests/include/test/threading_helpers.h @@ -15,6 +15,66 @@ #if defined MBEDTLS_THREADING_C +#include "mbedtls/private_access.h" +#include "mbedtls/build_info.h" + +/* Most fields of publicly available structs are private and are wrapped with + * MBEDTLS_PRIVATE macro. This define allows tests to access the private fields + * directly (without using the MBEDTLS_PRIVATE wrapper). */ +#define MBEDTLS_ALLOW_PRIVATE_ACCESS + +#define MBEDTLS_ERR_THREADING_THREAD_ERROR -0x001F + +#if defined(MBEDTLS_THREADING_PTHREAD) +#include + +typedef struct mbedtls_test_thread_t { + pthread_t MBEDTLS_PRIVATE(thread); +} mbedtls_test_thread_t; + +#endif /* MBEDTLS_THREADING_PTHREAD */ + +#if defined(MBEDTLS_THREADING_ALT) +/* You should define the mbedtls_test_thread_t type in your header */ +#include "threading_alt.h" + +/** + * \brief Set your alternate threading implementation + * function pointers fgr test threads. If used, + * this function must be called once in the main thread + * before any other MbedTLS function is called. + * + * \note These functions are part of the testing API only and + * thus not considered part of the public API of + * MbedTLS and thus may change without notice. + * + * \param thread_create The thread create function implementation + * \param thread_join The thread join function implementation + + */ +void mbedtls_test_thread_set_alt(int (*thread_create)(mbedtls_test_thread_t *thread, + void *(*thread_func)( + void *), + void *thread_data), + int (*thread_join)(mbedtls_test_thread_t *thread)); + +#endif /* MBEDTLS_THREADING_ALT*/ + +/** + * \brief The function pointers for thread create and thread + * join. + * + * \note These functions are part of the testing API only and + * thus not considered part of the public API of + * MbedTLS and thus may change without notice. + * + * \note All these functions are expected to work or + * the result will be undefined. + */ +extern int (*mbedtls_test_thread_create)(mbedtls_test_thread_t *thread, + void *(*thread_func)(void *), void *thread_data); +extern int (*mbedtls_test_thread_join)(mbedtls_test_thread_t *thread); + #if defined(MBEDTLS_THREADING_PTHREAD) && defined(MBEDTLS_TEST_HOOKS) #define MBEDTLS_TEST_MUTEX_USAGE #endif @@ -42,4 +102,3 @@ void mbedtls_test_mutex_usage_check(void); #endif /* MBEDTLS_THREADING_C */ #endif /* THREADING_HELPERS_H */ - diff --git a/tests/src/threading_helpers.c b/tests/src/threading_helpers.c index 38059343d8..5a871e102d 100644 --- a/tests/src/threading_helpers.c +++ b/tests/src/threading_helpers.c @@ -9,6 +9,71 @@ #include #include +#include "mbedtls/threading.h" + +#if defined(MBEDTLS_THREADING_C) + +#if defined(MBEDTLS_THREADING_PTHREAD) + +static int threading_thread_create_pthread(mbedtls_test_thread_t *thread, void *(*thread_func)( + void *), void *thread_data) +{ + if (thread == NULL || thread_func == NULL) { + return MBEDTLS_ERR_THREADING_BAD_INPUT_DATA; + } + + if (pthread_create(&thread->thread, NULL, thread_func, thread_data)) { + return MBEDTLS_ERR_THREADING_THREAD_ERROR; + } + + return 0; +} + +static int threading_thread_join_pthread(mbedtls_test_thread_t *thread) +{ + if (thread == NULL) { + return MBEDTLS_ERR_THREADING_BAD_INPUT_DATA; + } + + if (pthread_join(thread->thread, NULL) != 0) { + return MBEDTLS_ERR_THREADING_THREAD_ERROR; + } + + return 0; +} + +int (*mbedtls_test_thread_create)(mbedtls_test_thread_t *thread, void *(*thread_func)(void *), + void *thread_data) = threading_thread_create_pthread; +int (*mbedtls_test_thread_join)(mbedtls_test_thread_t *thread) = threading_thread_join_pthread; + +#endif /* MBEDTLS_THREADING_PTHREAD */ + +#if defined(MBEDTLS_THREADING_ALT) + +static int threading_thread_create_fail(mbedtls_test_thread_t *thread, + void *(*thread_func)(void *), + void *thread_data) +{ + (void) thread; + (void) thread_func; + (void) thread_data; + + return MBEDTLS_ERR_THREADING_BAD_INPUT_DATA; +} + +static int threading_thread_join_fail(mbedtls_test_thread_t *thread) +{ + (void) thread; + + return MBEDTLS_ERR_THREADING_BAD_INPUT_DATA; +} + +int (*mbedtls_test_thread_create)(mbedtls_test_thread_t *thread, void *(*thread_func)(void *), + void *thread_data) = threading_thread_create_fail; +int (*mbedtls_test_thread_join)(mbedtls_test_thread_t *thread) = threading_thread_join_fail; + +#endif /* MBEDTLS_THREADING_ALT */ + #if defined(MBEDTLS_TEST_MUTEX_USAGE) #include "mbedtls/threading.h" @@ -258,3 +323,5 @@ void mbedtls_test_mutex_usage_end(void) } #endif /* MBEDTLS_TEST_MUTEX_USAGE */ + +#endif /* MBEDTLS_THREADING_C */