Add test thread create/join abstraction

Signed-off-by: Paul Elliott <paul.elliott@arm.com>
This commit is contained in:
Paul Elliott 2023-12-08 20:49:47 +00:00
parent 17c119a5e3
commit 3a4d2f14a8
2 changed files with 127 additions and 1 deletions

View File

@ -15,6 +15,66 @@
#if defined MBEDTLS_THREADING_C
#include "mbedtls/private_access.h"
#include "mbedtls/build_info.h"
/* Most fields of publicly available structs are private and are wrapped with
* MBEDTLS_PRIVATE macro. This define allows tests to access the private fields
* directly (without using the MBEDTLS_PRIVATE wrapper). */
#define MBEDTLS_ALLOW_PRIVATE_ACCESS
#define MBEDTLS_ERR_THREADING_THREAD_ERROR -0x001F
#if defined(MBEDTLS_THREADING_PTHREAD)
#include <pthread.h>
typedef struct mbedtls_test_thread_t {
pthread_t MBEDTLS_PRIVATE(thread);
} mbedtls_test_thread_t;
#endif /* MBEDTLS_THREADING_PTHREAD */
#if defined(MBEDTLS_THREADING_ALT)
/* You should define the mbedtls_test_thread_t type in your header */
#include "threading_alt.h"
/**
* \brief Set your alternate threading implementation
* function pointers fgr test threads. If used,
* this function must be called once in the main thread
* before any other MbedTLS function is called.
*
* \note These functions are part of the testing API only and
* thus not considered part of the public API of
* MbedTLS and thus may change without notice.
*
* \param thread_create The thread create function implementation
* \param thread_join The thread join function implementation
*/
void mbedtls_test_thread_set_alt(int (*thread_create)(mbedtls_test_thread_t *thread,
void *(*thread_func)(
void *),
void *thread_data),
int (*thread_join)(mbedtls_test_thread_t *thread));
#endif /* MBEDTLS_THREADING_ALT*/
/**
* \brief The function pointers for thread create and thread
* join.
*
* \note These functions are part of the testing API only and
* thus not considered part of the public API of
* MbedTLS and thus may change without notice.
*
* \note All these functions are expected to work or
* the result will be undefined.
*/
extern int (*mbedtls_test_thread_create)(mbedtls_test_thread_t *thread,
void *(*thread_func)(void *), void *thread_data);
extern int (*mbedtls_test_thread_join)(mbedtls_test_thread_t *thread);
#if defined(MBEDTLS_THREADING_PTHREAD) && defined(MBEDTLS_TEST_HOOKS)
#define MBEDTLS_TEST_MUTEX_USAGE
#endif
@ -42,4 +102,3 @@ void mbedtls_test_mutex_usage_check(void);
#endif /* MBEDTLS_THREADING_C */
#endif /* THREADING_HELPERS_H */

View File

@ -9,6 +9,71 @@
#include <test/threading_helpers.h>
#include <test/macros.h>
#include "mbedtls/threading.h"
#if defined(MBEDTLS_THREADING_C)
#if defined(MBEDTLS_THREADING_PTHREAD)
static int threading_thread_create_pthread(mbedtls_test_thread_t *thread, void *(*thread_func)(
void *), void *thread_data)
{
if (thread == NULL || thread_func == NULL) {
return MBEDTLS_ERR_THREADING_BAD_INPUT_DATA;
}
if (pthread_create(&thread->thread, NULL, thread_func, thread_data)) {
return MBEDTLS_ERR_THREADING_THREAD_ERROR;
}
return 0;
}
static int threading_thread_join_pthread(mbedtls_test_thread_t *thread)
{
if (thread == NULL) {
return MBEDTLS_ERR_THREADING_BAD_INPUT_DATA;
}
if (pthread_join(thread->thread, NULL) != 0) {
return MBEDTLS_ERR_THREADING_THREAD_ERROR;
}
return 0;
}
int (*mbedtls_test_thread_create)(mbedtls_test_thread_t *thread, void *(*thread_func)(void *),
void *thread_data) = threading_thread_create_pthread;
int (*mbedtls_test_thread_join)(mbedtls_test_thread_t *thread) = threading_thread_join_pthread;
#endif /* MBEDTLS_THREADING_PTHREAD */
#if defined(MBEDTLS_THREADING_ALT)
static int threading_thread_create_fail(mbedtls_test_thread_t *thread,
void *(*thread_func)(void *),
void *thread_data)
{
(void) thread;
(void) thread_func;
(void) thread_data;
return MBEDTLS_ERR_THREADING_BAD_INPUT_DATA;
}
static int threading_thread_join_fail(mbedtls_test_thread_t *thread)
{
(void) thread;
return MBEDTLS_ERR_THREADING_BAD_INPUT_DATA;
}
int (*mbedtls_test_thread_create)(mbedtls_test_thread_t *thread, void *(*thread_func)(void *),
void *thread_data) = threading_thread_create_fail;
int (*mbedtls_test_thread_join)(mbedtls_test_thread_t *thread) = threading_thread_join_fail;
#endif /* MBEDTLS_THREADING_ALT */
#if defined(MBEDTLS_TEST_MUTEX_USAGE)
#include "mbedtls/threading.h"
@ -258,3 +323,5 @@ void mbedtls_test_mutex_usage_end(void)
}
#endif /* MBEDTLS_TEST_MUTEX_USAGE */
#endif /* MBEDTLS_THREADING_C */