14
14
#include <zephyr/ztest.h>
15
15
16
16
#include <mbedtls/x509.h>
17
+ #include <mbedtls/x509_crt.h>
17
18
18
19
LOG_MODULE_REGISTER (tls_test , CONFIG_NET_SOCKETS_LOG_LEVEL );
19
20
@@ -150,6 +151,7 @@ static void server_thread_fn(void *arg0, void *arg1, void *arg2)
150
151
{
151
152
const int server_fd = POINTER_TO_INT (arg0 );
152
153
const int echo = POINTER_TO_INT (arg1 );
154
+ const int expect_failure = POINTER_TO_INT (arg2 );
153
155
154
156
int r ;
155
157
int client_fd ;
@@ -168,6 +170,10 @@ static void server_thread_fn(void *arg0, void *arg1, void *arg2)
168
170
NET_DBG ("Accepting client connection.." );
169
171
k_sem_give (& server_sem );
170
172
r = accept (server_fd , (struct sockaddr * )& sa , & addrlen );
173
+ if (expect_failure ) {
174
+ zassert_equal (r , -1 , "accept() should've failed" );
175
+ return ;
176
+ }
171
177
zassert_not_equal (r , -1 , "accept() failed (%d)" , r );
172
178
client_fd = r ;
173
179
@@ -199,7 +205,7 @@ static void server_thread_fn(void *arg0, void *arg1, void *arg2)
199
205
}
200
206
201
207
static int test_configure_server (k_tid_t * server_thread_id , int peer_verify ,
202
- int echo )
208
+ int echo , int expect_failure )
203
209
{
204
210
static const sec_tag_t server_tag_list_verify_none [] = {
205
211
SERVER_CERTIFICATE_TAG ,
@@ -282,7 +288,8 @@ static int test_configure_server(k_tid_t *server_thread_id, int peer_verify,
282
288
* server_thread_id = k_thread_create (& server_thread , server_stack ,
283
289
STACK_SIZE , server_thread_fn ,
284
290
INT_TO_POINTER (server_fd ),
285
- INT_TO_POINTER (echo ), NULL ,
291
+ INT_TO_POINTER (echo ),
292
+ INT_TO_POINTER (expect_failure ),
286
293
K_PRIO_PREEMPT (8 ), 0 , K_NO_WAIT );
287
294
288
295
r = k_sem_take (& server_sem , K_MSEC (TIMEOUT ));
@@ -380,7 +387,8 @@ static void test_common(int peer_verify)
380
387
/*
381
388
* Server socket setup
382
389
*/
383
- server_fd = test_configure_server (& server_thread_id , peer_verify , true);
390
+ server_fd = test_configure_server (& server_thread_id , peer_verify , true,
391
+ false);
384
392
385
393
/*
386
394
* Client socket setup
@@ -444,7 +452,7 @@ static void test_tls_cert_verify_result_opt_common(uint32_t expect)
444
452
}
445
453
446
454
server_fd = test_configure_server (& server_thread_id , TLS_PEER_VERIFY_NONE ,
447
- false);
455
+ false, false );
448
456
client_fd = test_configure_client (& sa , false, hostname );
449
457
450
458
ret = zsock_setsockopt (client_fd , SOL_TLS , TLS_PEER_VERIFY ,
@@ -473,6 +481,71 @@ ZTEST(net_socket_tls_api_extension, test_tls_cert_verify_result_opt_bad_cn)
473
481
test_tls_cert_verify_result_opt_common (MBEDTLS_X509_BADCERT_CN_MISMATCH );
474
482
}
475
483
484
+ struct test_cert_verify_ctx {
485
+ bool cb_called ;
486
+ int result ;
487
+ };
488
+
489
+ static int cert_verify_cb (void * ctx , mbedtls_x509_crt * crt , int depth ,
490
+ uint32_t * flags )
491
+ {
492
+ struct test_cert_verify_ctx * test_ctx = (struct test_cert_verify_ctx * )ctx ;
493
+
494
+ test_ctx -> cb_called = true;
495
+
496
+ if (test_ctx -> result == 0 ) {
497
+ * flags = 0 ;
498
+ } else {
499
+ * flags |= MBEDTLS_X509_BADCERT_NOT_TRUSTED ;
500
+ }
501
+
502
+ return test_ctx -> result ;
503
+ }
504
+
505
+ static void test_tls_cert_verify_cb_opt_common (int result )
506
+ {
507
+ int server_fd , client_fd , ret ;
508
+ k_tid_t server_thread_id ;
509
+ struct sockaddr_in sa ;
510
+ struct test_cert_verify_ctx ctx = {
511
+ .cb_called = false,
512
+ .result = result ,
513
+ };
514
+ struct tls_cert_verify_cb cb = {
515
+ .cb = cert_verify_cb ,
516
+ .ctx = & ctx ,
517
+ };
518
+
519
+ server_fd = test_configure_server (& server_thread_id , TLS_PEER_VERIFY_NONE ,
520
+ false, result == 0 ? false : true);
521
+ client_fd = test_configure_client (& sa , false, "localhost" );
522
+
523
+ ret = zsock_setsockopt (client_fd , SOL_TLS , TLS_CERT_VERIFY_CALLBACK ,
524
+ & cb , sizeof (cb ));
525
+ zassert_ok (ret , "failed to set TLS_CERT_VERIFY_CALLBACK (%d)" , errno );
526
+
527
+ ret = zsock_connect (client_fd , (struct sockaddr * )& sa , sizeof (sa ));
528
+ zassert_true (ctx .cb_called , "callback not called" );
529
+ if (result == 0 ) {
530
+ zassert_equal (ret , 0 , "failed to connect (%d)" , errno );
531
+ } else {
532
+ zassert_equal (ret , -1 , "connect() should fail" );
533
+ zassert_equal (errno , ECONNABORTED , "invalid errno" );
534
+ }
535
+
536
+ test_shutdown (client_fd , server_fd , server_thread_id );
537
+ }
538
+
539
+ ZTEST (net_socket_tls_api_extension , test_tls_cert_verify_cb_opt_ok )
540
+ {
541
+ test_tls_cert_verify_cb_opt_common (0 );
542
+ }
543
+
544
+ ZTEST (net_socket_tls_api_extension , test_tls_cert_verify_cb_opt_bad_cert )
545
+ {
546
+ test_tls_cert_verify_cb_opt_common (MBEDTLS_ERR_X509_CERT_VERIFY_FAILED );
547
+ }
548
+
476
549
static void * setup (void )
477
550
{
478
551
int r ;
0 commit comments