From 806a8c48ce811af33b25345762bd533423267f97 Mon Sep 17 00:00:00 2001 From: Balint Molnar Date: Fri, 3 Nov 2023 13:46:10 +0100 Subject: [PATCH] Initial impl of cert bundle caching (#81) * Initial impl of cert bundle caching * Add certificate validation for expiration * Remove code dup around cert_cache handling * Use id from opa.h insted of keyid * Remove global rsa_key and use per connection one * Fix memory leak for rsa_pub rsa_priv key * Fix memory leak around cache and generated certificate * Add global rsa_private key back since it is causing additional errors * Fix cache key mismatch happened because of freeing the char* id after socket close * Fix typo in function free_rsa_public_key * Wip, todo need to rearange list initialization * Use proper lock handling for remove cert from cache * Allign with linux kernel naming on locked function * Fix word typo * refactor cert mem representation * remove unnecessary socket closed logs * passthrough if the agent is not running * Optimize certificate decode, and store validity in the struct * Beautify function and field names, return if cert is null --------- Co-authored-by: Zsolt Varga --- Makefile | 1 + cert_tools.c | 242 +++++++++++++++++++++++++++++++++++++++++ cert_tools.h | 57 ++++++++++ commands.c | 41 ++++--- commands.h | 6 +- device_driver.c | 8 +- device_driver.h | 8 ++ rsa_tools.c | 40 +++++-- rsa_tools.h | 2 + socket.c | 258 ++++++++++++++++++++++++-------------------- third-party/BearSSL | 2 +- 11 files changed, 513 insertions(+), 152 deletions(-) create mode 100644 cert_tools.c create mode 100644 cert_tools.h diff --git a/Makefile b/Makefile index b1f08f7a..2aa4b3f9 100644 --- a/Makefile +++ b/Makefile @@ -60,6 +60,7 @@ nasp-objs := third-party/wasm3/source/m3_api_libc.o \ main.o \ csr.o \ rsa_tools.o \ + cert_tools.o \ wasm.o \ opa.o \ proxywasm.o \ diff --git a/cert_tools.c b/cert_tools.c new file mode 100644 index 00000000..3b7c515f --- /dev/null +++ b/cert_tools.c @@ -0,0 +1,242 @@ +/* + * Copyright (c) 2023 Cisco and/or its affiliates. All rights reserved. + * + * SPDX-License-Identifier: MIT OR GPL-2.0-only + * + * Licensed under the MIT license or the GPLv2 license + * , at your option. This file may not be copied, + * modified, or distributed except according to those terms. + */ + +#include +#include + +#include "cert_tools.h" +#include "string.h" +#include "rsa_tools.h" + +// Define the maximum number of elements inside the cache +#define MAX_CACHE_LENGTH 64 + +// certs that are in use or used once by a workload +static LIST_HEAD(cert_cache); + +// lock for the above list to make it thread safe +static DEFINE_MUTEX(certificate_cache_lock); + +static void cert_cache_lock(void) +{ + mutex_lock(&certificate_cache_lock); +} + +static void cert_cache_unlock(void) +{ + mutex_unlock(&certificate_cache_lock); +} + +static size_t linkedlist_length(struct list_head *head) +{ + struct list_head *pos; + int length = 0; + + list_for_each(pos, head) + { + length++; + } + + return length; +} + +// add_cert_to_cache adds a certificate chain with a given trust anchor to a linked list. The key will identify this entry. +// the function is thread safe. +void add_cert_to_cache(char *key, x509_certificate *cert) +{ + if (!key) + { + pr_err("nasp: provided key is null"); + return; + } + + cert_with_key *new_entry = kzalloc(sizeof(cert_with_key), GFP_KERNEL); + if (!new_entry) + { + pr_err("nasp: memory allocation error"); + return; + } + new_entry->key = strdup(key); + new_entry->cert = cert; + + cert_cache_lock(); + INIT_LIST_HEAD(&new_entry->list); + list_add(&new_entry->list, &cert_cache); + cert_cache_unlock(); +} + +// remove_unused_expired_certs_from_cache iterates over the whole cache and tries to clean up the unused/expired certificates. +// it works like a garbage collection which now runs before every add. +// TODO handle cases when cache length is maxed out but no expired certificate +void remove_unused_expired_certs_from_cache() +{ + cert_with_key *cert_bundle, *cert_bundle_tmp; + + if (linkedlist_length(&cert_cache) >= MAX_CACHE_LENGTH) + { + pr_warn("nasp: cache is full removing the oldest element"); + cert_with_key *last_entry = list_last_entry(&cert_cache, cert_with_key, list); + pr_warn("nasp: removing key:%s from the cache", last_entry->key); + remove_cert_from_cache(last_entry); + return; + } + + cert_cache_lock(); + list_for_each_entry_safe_reverse(cert_bundle, cert_bundle_tmp, &cert_cache, list) + { + if (!validate_cert(cert_bundle->cert->validity)) + { + remove_cert_from_cache_locked(cert_bundle); + } + } + cert_cache_unlock(); +} + +// find_cert_from_cache tries to find a certificate bundle for the given key. In case of failure it returns a NULL. +// this function also runs a garbage collection on the cache. +// the function is thread safe +cert_with_key *find_cert_from_cache(char *key) +{ + + remove_unused_expired_certs_from_cache(); + + cert_with_key *cert_bundle; + cert_cache_lock(); + list_for_each_entry(cert_bundle, &cert_cache, list) + { + if (strncmp(cert_bundle->key, key, strlen(key)) == 0) + { + cert_cache_unlock(); + return cert_bundle; + } + } + cert_cache_unlock(); + return 0; +} + +// remove_cert_from_cache_locked removes a given certificate bundle from the cache +// the function is thread safe +void remove_cert_from_cache(cert_with_key *cert_bundle) +{ + if (cert_bundle) + { + cert_cache_lock(); + remove_cert_from_cache_locked(cert_bundle); + cert_cache_unlock(); + } +} + +// remove_cert_from_cache removes a given certificate bundle from the cache +void remove_cert_from_cache_locked(cert_with_key *cert_bundle) +{ + if (cert_bundle) + { + list_del(&cert_bundle->list); + x509_certificate_put(cert_bundle->cert); + kfree(cert_bundle); + } +} + +// set_cert_validity decodes the provided certificate and filling the validity seconds and days. +// if the decode fails it returns -1 +int set_cert_validity(x509_certificate *x509_cert) +{ + br_x509_decoder_context dc; + + br_x509_decoder_init(&dc, 0, 0); + br_x509_decoder_push(&dc, x509_cert->chain->data, x509_cert->chain->data_len); + int err = br_x509_decoder_last_error(&dc); + if (err != 0) + { + pr_err("nasp: cert decode faild during setting cert validity: %d", err); + return -1; + } + x509_cert->validity.notbefore_seconds = dc.notbefore_seconds; + x509_cert->validity.notbefore_days = dc.notbefore_days; + + x509_cert->validity.notafter_seconds = dc.notafter_seconds; + x509_cert->validity.notafter_days = dc.notafter_days; + + return 0; +} + +// validate_cert validates the given certificate if it has expired or not. +bool validate_cert(x509_certificate_validity cert_validity) +{ + bool result = false; + + time64_t x = ktime_get_real_seconds(); + uint32_t vd = (uint32_t)(x / 86400) + 719528; + uint32_t vs = (uint32_t)(x % 86400); + + if (vd < cert_validity.notbefore_days || (vd == cert_validity.notbefore_days && vs < cert_validity.notbefore_seconds)) + { + pr_warn("nasp: cert expired"); + } + else if (vd > cert_validity.notafter_days || (vd == cert_validity.notafter_days && vs > cert_validity.notafter_seconds)) + { + pr_warn("nasp: cert not valid yet"); + } + else + { + result = true; + } + + return result; +} + +x509_certificate *x509_certificate_init(void) +{ + x509_certificate *cert = kzalloc(sizeof(x509_certificate), GFP_KERNEL); + + kref_init(&cert->kref); + + return cert; +} + +static void x509_certificate_free(x509_certificate *cert) +{ + pr_info("nasp: x509_certificate_free"); + + if (!cert) + { + return; + } + + free_br_x509_certificate(cert->chain, cert->chain_len); + free_br_x509_trust_anchors(cert->trust_anchors, cert->trust_anchors_len); + + kfree(cert); +} + +static void x509_certificate_release(struct kref *kref) +{ + x509_certificate *cert = container_of(kref, x509_certificate, kref); + + x509_certificate_free(cert); +} + +void x509_certificate_get(x509_certificate *cert) +{ + if (!cert) + { + return; + } + kref_get(&cert->kref); +} + +void x509_certificate_put(x509_certificate *cert) +{ + if (!cert) + { + return; + } + kref_put(&cert->kref, x509_certificate_release); +} diff --git a/cert_tools.h b/cert_tools.h new file mode 100644 index 00000000..65ef44bb --- /dev/null +++ b/cert_tools.h @@ -0,0 +1,57 @@ +/* + * Copyright (c) 2023 Cisco and/or its affiliates. All rights reserved. + * + * SPDX-License-Identifier: MIT OR GPL-2.0-only + * + * Licensed under the MIT license or the GPLv2 license + * , at your option. This file may not be copied, + * modified, or distributed except according to those terms. + */ + +#ifndef cert_tools_h +#define cert_tools_h + +#include + +#include "bearssl.h" + +typedef struct +{ + uint32_t notbefore_seconds; + uint32_t notbefore_days; + uint32_t notafter_seconds; + uint32_t notafter_days; +} x509_certificate_validity; + +typedef struct +{ + struct kref kref; + br_x509_certificate *chain; + size_t chain_len; + br_x509_trust_anchor *trust_anchors; + size_t trust_anchors_len; + + x509_certificate_validity validity; +} x509_certificate; + +typedef struct +{ + char *key; + x509_certificate *cert; + struct list_head list; +} cert_with_key; + +x509_certificate *x509_certificate_init(void); +void x509_certificate_get(x509_certificate *cert); +void x509_certificate_put(x509_certificate *cert); + +void add_cert_to_cache(char *key, x509_certificate *cert); +cert_with_key *find_cert_from_cache(char *key); +void remove_cert_from_cache(cert_with_key *cert); +void remove_cert_from_cache_locked(cert_with_key *cert); +void remove_unused_expired_certs_from_cache(void); + +bool validate_cert(x509_certificate_validity cert_validity); +int set_cert_validity(x509_certificate *x509_cert); + +#endif \ No newline at end of file diff --git a/commands.c b/commands.c index 7983d3ba..3927c58f 100644 --- a/commands.c +++ b/commands.c @@ -237,28 +237,30 @@ csr_sign_answer *send_csrsign_command(unsigned char *csr) goto error; } - csr_sign_answer->trust_anchors_len = json_array_get_count(trust_anchors); + csr_sign_answer->cert = x509_certificate_init(); + + csr_sign_answer->cert->trust_anchors_len = json_array_get_count(trust_anchors); size_t srclen; - if (csr_sign_answer->trust_anchors_len > 0) + if (csr_sign_answer->cert->trust_anchors_len > 0) { - csr_sign_answer->trust_anchors = kmalloc(csr_sign_answer->trust_anchors_len * sizeof *csr_sign_answer->trust_anchors, GFP_KERNEL); + csr_sign_answer->cert->trust_anchors = kmalloc(csr_sign_answer->cert->trust_anchors_len * sizeof *csr_sign_answer->cert->trust_anchors, GFP_KERNEL); size_t u; - for (u = 0; u < csr_sign_answer->trust_anchors_len; u++) + for (u = 0; u < csr_sign_answer->cert->trust_anchors_len; u++) { JSON_Object *ta = json_array_get_object(trust_anchors, u); - csr_sign_answer->trust_anchors[u].flags = BR_X509_TA_CA; - csr_sign_answer->trust_anchors[u].pkey.key_type = BR_KEYTYPE_RSA; + csr_sign_answer->cert->trust_anchors[u].flags = BR_X509_TA_CA; + csr_sign_answer->cert->trust_anchors[u].pkey.key_type = BR_KEYTYPE_RSA; // RAW (DN) const char *raw_subject = json_object_get_string(ta, "rawSubject"); if (raw_subject != NULL) { srclen = strlen(raw_subject); - csr_sign_answer->trust_anchors[u].dn.data = kmalloc(srclen, GFP_KERNEL); - csr_sign_answer->trust_anchors[u].dn.len = base64_decode(csr_sign_answer->trust_anchors[u].dn.data, srclen, raw_subject, srclen); + csr_sign_answer->cert->trust_anchors[u].dn.data = kmalloc(srclen, GFP_KERNEL); + csr_sign_answer->cert->trust_anchors[u].dn.len = base64_decode(csr_sign_answer->cert->trust_anchors[u].dn.data, srclen, raw_subject, srclen); } // RSA_N @@ -266,8 +268,8 @@ csr_sign_answer *send_csrsign_command(unsigned char *csr) if (rsa_n != NULL) { srclen = strlen(rsa_n); - csr_sign_answer->trust_anchors[u].pkey.key.rsa.n = kmalloc(srclen, GFP_KERNEL); - csr_sign_answer->trust_anchors[u].pkey.key.rsa.nlen = base64_decode(csr_sign_answer->trust_anchors[u].pkey.key.rsa.n, srclen, rsa_n, srclen); + csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.n = kmalloc(srclen, GFP_KERNEL); + csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.nlen = base64_decode(csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.n, srclen, rsa_n, srclen); } // RSA_E @@ -275,8 +277,8 @@ csr_sign_answer *send_csrsign_command(unsigned char *csr) if (rsa_e != NULL) { srclen = strlen(rsa_e); - csr_sign_answer->trust_anchors[u].pkey.key.rsa.e = kmalloc(srclen, GFP_KERNEL); - csr_sign_answer->trust_anchors[u].pkey.key.rsa.elen = base64_decode(csr_sign_answer->trust_anchors[u].pkey.key.rsa.e, srclen, rsa_e, srclen); + csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.e = kmalloc(srclen, GFP_KERNEL); + csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.elen = base64_decode(csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.e, srclen, rsa_e, srclen); } } } @@ -288,12 +290,19 @@ csr_sign_answer *send_csrsign_command(unsigned char *csr) goto error; } - csr_sign_answer->chain = kzalloc(1 * sizeof *csr_sign_answer->chain, GFP_KERNEL); - csr_sign_answer->chain_len = 1; + csr_sign_answer->cert->chain = kzalloc(1 * sizeof *csr_sign_answer->cert->chain, GFP_KERNEL); + csr_sign_answer->cert->chain_len = 1; srclen = strlen(raw); - csr_sign_answer->chain[0].data = kmalloc(srclen, GFP_KERNEL); - csr_sign_answer->chain[0].data_len = base64_decode(csr_sign_answer->chain[0].data, srclen, raw, srclen); + csr_sign_answer->cert->chain->data = kmalloc(srclen, GFP_KERNEL); + csr_sign_answer->cert->chain->data_len = base64_decode(csr_sign_answer->cert->chain->data, srclen, raw, srclen); + + int result = set_cert_validity(csr_sign_answer->cert); + if (result < 0) + { + errormsg = "could not decode generated certificate"; + goto error; + } } return csr_sign_answer; diff --git a/commands.h b/commands.h index 1f634e26..1d908bdc 100644 --- a/commands.h +++ b/commands.h @@ -14,6 +14,7 @@ #include "task_context.h" #include "bearssl.h" #include "socket.h" +#include "cert_tools.h" #define COMMAND_TIMEOUT_SECONDS 1 @@ -26,10 +27,7 @@ typedef struct command_answer typedef struct csr_sign_answer { char *error; - br_x509_certificate *chain; - size_t chain_len; - br_x509_trust_anchor *trust_anchors; - size_t trust_anchors_len; + x509_certificate *cert; } csr_sign_answer; void free_command_answer(command_answer *cmd_answer); diff --git a/device_driver.c b/device_driver.c index d77e6b8a..49791730 100644 --- a/device_driver.c +++ b/device_driver.c @@ -29,14 +29,8 @@ static int major; /* major number assigned to our device driver */ -enum -{ - CDEV_NOT_USED = 0, - CDEV_EXCLUSIVE_OPEN = 1, -}; - /* Is device open? Used to prevent multiple access to device */ -static atomic_t already_open = ATOMIC_INIT(CDEV_NOT_USED); +atomic_t already_open = ATOMIC_INIT(CDEV_NOT_USED); static char device_buffer[DEVICE_BUFFER_SIZE]; static size_t device_buffer_size = 0; diff --git a/device_driver.h b/device_driver.h index bed22465..27b7ba62 100644 --- a/device_driver.h +++ b/device_driver.h @@ -22,6 +22,14 @@ wasm_vm_result load_module(const char *name, const char *code, unsigned length, #define DEVICE_NAME "nasp" /* Dev name as it appears in /dev/devices */ #define DEVICE_BUFFER_SIZE 2 * 1024 * 1024 /* Max length of the message from the device */ +enum +{ + CDEV_NOT_USED = 0, + CDEV_EXCLUSIVE_OPEN = 1, +}; + +extern atomic_t already_open; + int chardev_init(void); void chardev_exit(void); diff --git a/rsa_tools.c b/rsa_tools.c index e27aa67b..74772d6c 100644 --- a/rsa_tools.c +++ b/rsa_tools.c @@ -15,7 +15,7 @@ static br_hmac_drbg_context hmac_drbg_ctx; -#define RSA_BIT_LENGHT 2048 +#define RSA_BIT_LENGTH 2048 #define RSA_PUB_EXP 3 // BearSSL RSA Keygen related functions @@ -39,10 +39,10 @@ uint32_t generate_rsa_keys(br_rsa_private_key *rsa_priv, br_rsa_public_key *rsa_ { br_rsa_keygen rsa_keygen = br_rsa_keygen_get_default(); - unsigned char *raw_priv_key = kmalloc(BR_RSA_KBUF_PRIV_SIZE(RSA_BIT_LENGHT), GFP_KERNEL); - unsigned char *raw_pub_key = kmalloc(BR_RSA_KBUF_PUB_SIZE(RSA_BIT_LENGHT), GFP_KERNEL); + unsigned char *raw_priv_key = kmalloc(BR_RSA_KBUF_PRIV_SIZE(RSA_BIT_LENGTH), GFP_KERNEL); + unsigned char *raw_pub_key = kmalloc(BR_RSA_KBUF_PUB_SIZE(RSA_BIT_LENGTH), GFP_KERNEL); - return rsa_keygen(&hmac_drbg_ctx.vtable, rsa_priv, raw_priv_key, rsa_pub, raw_pub_key, RSA_BIT_LENGHT, RSA_PUB_EXP); + return rsa_keygen(&hmac_drbg_ctx.vtable, rsa_priv, raw_priv_key, rsa_pub, raw_pub_key, RSA_BIT_LENGTH, RSA_PUB_EXP); } void free_rsa_private_key(br_rsa_private_key *key) @@ -57,8 +57,36 @@ void free_rsa_public_key(br_rsa_public_key *key) kfree(key); } +void free_br_x509_certificate(br_x509_certificate *chain, size_t chain_len) +{ + if (chain_len > 0) + { + size_t i; + for (i = 0; i < chain_len; i++) + { + kfree(chain[i].data); + } + } + kfree(chain); +} + +void free_br_x509_trust_anchors(br_x509_trust_anchor *trust_anchors, size_t trust_anchor_len) +{ + if (trust_anchor_len > 0) + { + size_t i; + for (i = 0; i < trust_anchor_len; i++) + { + kfree(trust_anchors[i].dn.data); + kfree(trust_anchors[i].pkey.key.rsa.n); + kfree(trust_anchors[i].pkey.key.rsa.e); + } + } + kfree(trust_anchors); +} + // BearSSL RSA Keygen related functions -// Encodes rsa private key to pkcs8 der format and returns it's lenght. +// Encodes rsa private key to pkcs8 der format and returns it's length. // If the der parameter is set to NULL then it computes only the length int encode_rsa_priv_key_to_der(unsigned char *der, br_rsa_private_key *rsa_priv, br_rsa_public_key *rsa_pub) { @@ -66,7 +94,7 @@ int encode_rsa_priv_key_to_der(unsigned char *der, br_rsa_private_key *rsa_priv, size_t priv_exponent_size = rsa_priv_exp_comp(NULL, rsa_priv, RSA_PUB_EXP); if (priv_exponent_size == 0) { - pr_err("rsa_tools: error happened during priv_exponent lenght calculation"); + pr_err("rsa_tools: error happened during priv_exponent length calculation"); return -1; } unsigned char priv_exponent[priv_exponent_size]; diff --git a/rsa_tools.h b/rsa_tools.h index 032cbf93..1c0a8d95 100644 --- a/rsa_tools.h +++ b/rsa_tools.h @@ -17,6 +17,8 @@ int init_rnd_gen(void); uint32_t generate_rsa_keys(br_rsa_private_key *rsa_priv, br_rsa_public_key *rsa_pub); void free_rsa_private_key(br_rsa_private_key *key); void free_rsa_public_key(br_rsa_public_key *key); +void free_br_x509_certificate(br_x509_certificate *chain, size_t chain_len); +void free_br_x509_trust_anchors(br_x509_trust_anchor *trust_anchors, size_t trust_anchor_len); int encode_rsa_priv_key_to_der(unsigned char *der, br_rsa_private_key *rsa_priv, br_rsa_public_key *rsa_pub); #endif diff --git a/socket.c b/socket.c index 0dbbe5c8..fb6739cd 100644 --- a/socket.c +++ b/socket.c @@ -28,6 +28,7 @@ #include "socket.h" #include "tls.h" #include "string.h" +#include "cert_tools.h" const char *ALPNs[] = { "istio-peer-exchange", @@ -65,13 +66,9 @@ struct nasp_socket br_rsa_private_key *rsa_priv; br_rsa_public_key *rsa_pub; - br_x509_certificate *cert; csr_parameters *parameters; - br_x509_certificate *chain; - size_t chain_len; - br_x509_trust_anchor *trust_anchors; - size_t trust_anchors_len; + x509_certificate *cert; proxywasm *p; proxywasm_context *pc; @@ -295,22 +292,13 @@ static void nasp_socket_free(nasp_socket *s) br_x509_nasp_free(&s->xc); opa_socket_context_free(s->opa_socket_ctx); - - // if (c->rsa_priv != NULL) - // { - // kfree(c->rsa_priv->p); - // } - // if (c->rsa_pub != NULL) - // { - // kfree(c->rsa_pub->n); - // } - buffer_free(s->read_buffer); buffer_free(s->write_buffer); kfree(s->rsa_priv); kfree(s->rsa_pub); - kfree(s->cert); + x509_certificate_put(s->cert); + kfree(s->parameters); kfree(s); } @@ -348,9 +336,6 @@ static nasp_socket *nasp_socket_accept(struct sock *sock) s->sc = kzalloc(sizeof(br_ssl_server_context), GFP_KERNEL); s->rsa_priv = kzalloc(sizeof(br_rsa_private_key), GFP_KERNEL); s->rsa_pub = kzalloc(sizeof(br_rsa_public_key), GFP_KERNEL); - s->cert = kzalloc(sizeof(br_x509_certificate), GFP_KERNEL); - s->chain = kzalloc(sizeof(br_x509_certificate), GFP_KERNEL); - s->trust_anchors = kzalloc(sizeof(br_x509_trust_anchor), GFP_KERNEL); s->parameters = kzalloc(sizeof(csr_parameters), GFP_KERNEL); s->read_buffer = buffer_new(16 * 1024); s->write_buffer = buffer_new(16 * 1024); @@ -383,7 +368,6 @@ static nasp_socket *nasp_socket_connect(struct sock *sock) s->cc = kzalloc(sizeof(br_ssl_client_context), GFP_KERNEL); s->rsa_priv = kzalloc(sizeof(br_rsa_private_key), GFP_KERNEL); s->rsa_pub = kzalloc(sizeof(br_rsa_public_key), GFP_KERNEL); - s->cert = kzalloc(sizeof(br_x509_certificate), GFP_KERNEL); s->parameters = kzalloc(sizeof(csr_parameters), GFP_KERNEL); s->read_buffer = buffer_new(16 * 1024); s->write_buffer = buffer_new(16 * 1024); @@ -789,111 +773,141 @@ int (*connect_v6)(struct sock *sk, struct sockaddr *uaddr, int addr_len); static int handle_cert_gen(nasp_socket *sc) { - // We should not only check for empty cert but we must check the certs validity - // TODO must set the certificate to avoid new cert generation every time - if (sc->chain_len == 0) + // Generating certificate signing request + if (sc->rsa_priv->plen == 0 || sc->rsa_pub->elen == 0) { - // generating certificate signing request - if (sc->rsa_priv->plen == 0 || sc->rsa_pub->elen == 0) + u_int32_t result = generate_rsa_keys(sc->rsa_priv, sc->rsa_pub); + if (result == 0) { - u_int32_t result = generate_rsa_keys(sc->rsa_priv, sc->rsa_pub); - if (result == 0) - { - pr_err("nasp: generate_csr error generating rsa keys"); - return -1; - } - } - - int len = encode_rsa_priv_key_to_der(NULL, sc->rsa_priv, sc->rsa_pub); - if (len <= 0) - { - pr_err("nasp: generate_csr error during rsa private der key length calculation"); + pr_err("nasp: generate_csr error generating rsa keys"); return -1; } + } - unsigned char *csr_ptr; + int len = encode_rsa_priv_key_to_der(NULL, sc->rsa_priv, sc->rsa_pub); + if (len <= 0) + { + pr_err("nasp: generate_csr error during rsa private der key length calculation"); + return -1; + } - csr_module *csr = this_cpu_csr(); - csr_lock(csr); - { - // Allocate memory inside the wasm vm since this data must be available inside the module - wasm_vm_result malloc_result = csr_malloc(csr, len); - if (malloc_result.err) - { - pr_err("nasp: generate_csr wasm_vm_csr_malloc error: %s", malloc_result.err); - csr_unlock(csr); - return -1; - } + unsigned char *csr_ptr; - uint8_t *mem = wasm_vm_memory(get_csr_module(csr)); - i32 addr = malloc_result.data->i32; + csr_module *csr = this_cpu_csr(); + csr_lock(csr); + // Allocate memory inside the wasm vm since this data must be available inside the module + wasm_vm_result malloc_result = csr_malloc(csr, len); + if (malloc_result.err) + { + pr_err("nasp: generate_csr wasm_vm_csr_malloc error: %s", malloc_result.err); + csr_unlock(csr); + return -1; + } - unsigned char *der = mem + addr; + uint8_t *mem = wasm_vm_memory(get_csr_module(csr)); + i32 addr = malloc_result.data->i32; - int error = encode_rsa_priv_key_to_der(der, sc->rsa_priv, sc->rsa_pub); - if (error <= 0) - { - pr_err("nasp: generate_csr error during rsa private key der encoding"); - csr_unlock(csr); - return -1; - } + unsigned char *der = mem + addr; - sc->parameters->subject = "CN=nasp-protected-workload"; + int error = encode_rsa_priv_key_to_der(der, sc->rsa_priv, sc->rsa_pub); + if (error <= 0) + { + pr_err("nasp: generate_csr error during rsa private key der encoding"); + csr_unlock(csr); + return -1; + } - if (sc->opa_socket_ctx.dns) - { - sc->parameters->dns = sc->opa_socket_ctx.dns; - } - if (sc->opa_socket_ctx.uri) - { - sc->parameters->uri = sc->opa_socket_ctx.uri; - } + sc->parameters->subject = "CN=nasp-protected-workload"; - csr_result generated_csr = csr_gen(csr, addr, len, sc->parameters); - if (generated_csr.err) - { - pr_err("nasp: generate_csr wasm_vm_csr_gen error: %s", generated_csr.err); - csr_unlock(csr); - return -1; - } + if (sc->opa_socket_ctx.dns) + { + sc->parameters->dns = sc->opa_socket_ctx.dns; + } + if (sc->opa_socket_ctx.uri) + { + sc->parameters->uri = sc->opa_socket_ctx.uri; + } - wasm_vm_result free_result = csr_free(csr, addr); - if (free_result.err) - { - pr_err("nasp: generate_csr wasm_vm_csr_free error: %s", free_result.err); - csr_unlock(csr); - return -1; - } + csr_result generated_csr = csr_gen(csr, addr, len, sc->parameters); + if (generated_csr.err) + { + pr_err("nasp: generate_csr wasm_vm_csr_gen error: %s", generated_csr.err); + csr_unlock(csr); + return -1; + } - csr_ptr = strndup(generated_csr.csr_ptr + mem, generated_csr.csr_len); - free_result = csr_free(csr, generated_csr.csr_ptr); - if (free_result.err) - { - pr_err("nasp: generate_csr wasm_vm_csr_free error: %s", free_result.err); - csr_unlock(csr); - return -1; - } - } + wasm_vm_result free_result = csr_free(csr, addr); + if (free_result.err) + { + pr_err("nasp: generate_csr wasm_vm_csr_free error: %s", free_result.err); + csr_unlock(csr); + return -1; + } + + csr_ptr = strndup(generated_csr.csr_ptr + mem, generated_csr.csr_len); + free_result = csr_free(csr, generated_csr.csr_ptr); + if (free_result.err) + { + pr_err("nasp: generate_csr wasm_vm_csr_free error: %s", free_result.err); csr_unlock(csr); + return -1; + } + csr_unlock(csr); + + csr_sign_answer *csr_sign_answer; + csr_sign_answer = send_csrsign_command(csr_ptr); + if (csr_sign_answer->error) + { + pr_err("nasp: generate_csr csr sign answer error: %s", csr_sign_answer->error); + kfree(csr_sign_answer->error); + kfree(csr_sign_answer); + return -1; + } + else + { + x509_certificate_get(csr_sign_answer->cert); + sc->cert = csr_sign_answer->cert; + } + kfree(csr_sign_answer); + return 0; +} + +static int cache_and_validate_cert(nasp_socket *sc, char *key) +{ + // Check if cert gen is required or we already have a cached certificate for this socket. + u16 cert_validation_err_no = 0; - csr_sign_answer *csr_sign_answer; - csr_sign_answer = send_csrsign_command(csr_ptr); - if (csr_sign_answer->error) + cert_with_key *cached_cert_bundle = find_cert_from_cache(key); + if (!cached_cert_bundle) + { + regen_cert: + int err = handle_cert_gen(sc); + if (err == -1) { - pr_err("nasp: generate_csr csr sign answer error: %s", csr_sign_answer->error); - kfree(csr_sign_answer->error); - kfree(csr_sign_answer); return -1; } - else + add_cert_to_cache(key, sc->cert); + } + // Cert found in the cache use that + else + { + x509_certificate_get(cached_cert_bundle->cert); + sc->cert = cached_cert_bundle->cert; + } + // Validate the cached or the generated cert + if (!validate_cert(sc->cert->validity)) + { + pr_warn("nasp: provided certificate is invalid"); + remove_cert_from_cache(cached_cert_bundle); + cert_validation_err_no++; + if (cert_validation_err_no == 1) { - sc->trust_anchors = csr_sign_answer->trust_anchors; - sc->trust_anchors_len = csr_sign_answer->trust_anchors_len; - sc->chain = csr_sign_answer->chain; - sc->chain_len = csr_sign_answer->chain_len; + goto regen_cert; + } + else if (cert_validation_err_no == 2) + { + return -1; } - - kfree(csr_sign_answer); } return 0; } @@ -946,6 +960,12 @@ struct sock *nasp_accept(struct sock *sk, int flags, int *err, bool kern) goto error; } + // return if the agent is not running + if (atomic_read(&already_open) == CDEV_NOT_USED) + { + return client; + } + u16 port = (u16)(sk->sk_portpair >> 16); sc = nasp_socket_accept(client); @@ -978,11 +998,12 @@ struct sock *nasp_accept(struct sock *sk, int flags, int *err, bool kern) memcpy(sc->rsa_priv, rsa_priv, sizeof *sc->rsa_priv); memcpy(sc->rsa_pub, rsa_pub, sizeof *sc->rsa_pub); - int result = handle_cert_gen(sc); + int result = cache_and_validate_cert(sc, sc->opa_socket_ctx.id); if (result == -1) { goto error; } + /* * Initialise the context with the cipher suites and * algorithms. This depends on the server key type @@ -997,13 +1018,13 @@ struct sock *nasp_accept(struct sock *sk, int flags, int *err, bool kern) * EC key, cert signed with RSA: ECDH_RSA or ECDHE_ECDSA */ pr_info("nasp: accept use cert from agent"); - br_ssl_server_init_full_rsa(sc->sc, sc->chain, sc->chain_len, sc->rsa_priv); + br_ssl_server_init_full_rsa(sc->sc, sc->cert->chain, sc->cert->chain_len, sc->rsa_priv); // mTLS enablement if (sc->opa_socket_ctx.mtls) { - br_x509_minimal_init_full(&sc->xc.ctx, sc->trust_anchors, sc->trust_anchors_len); - br_ssl_server_set_trust_anchor_names_alt(sc->sc, sc->trust_anchors, sc->trust_anchors_len); + br_x509_minimal_init_full(&sc->xc.ctx, sc->cert->trust_anchors, sc->cert->trust_anchors_len); + br_ssl_server_set_trust_anchor_names_alt(sc->sc, sc->cert->trust_anchors, sc->cert->trust_anchors_len); br_x509_nasp_init(&sc->xc, &sc->sc->eng, &sc->opa_socket_ctx); br_ssl_engine_set_default_rsavrfy(&sc->sc->eng); @@ -1045,8 +1066,6 @@ struct sock *nasp_accept(struct sock *sk, int flags, int *err, bool kern) if (client) client->sk_prot->close(client, 0); - pr_err("nasp: [%s] accept error, socket closed", current->comm); - return NULL; } @@ -1075,6 +1094,12 @@ int nasp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) goto error; } + // return if the agent is not running + if (atomic_read(&already_open) == CDEV_NOT_USED) + { + return err; + } + pr_info("nasp: nasp_connect uid: %d app: %s to port: %d", current_uid().val, current->comm, port); sc = nasp_socket_connect(sk); @@ -1107,12 +1132,11 @@ int nasp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) memcpy(sc->rsa_priv, rsa_priv, sizeof *sc->rsa_priv); memcpy(sc->rsa_pub, rsa_pub, sizeof *sc->rsa_pub); - int result = handle_cert_gen(sc); + int result = cache_and_validate_cert(sc, sc->opa_socket_ctx.id); if (result == -1) { goto error; } - /* * Initialise the context with the cipher suites and * algorithms. This depends on the server key type @@ -1127,14 +1151,14 @@ int nasp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) * EC key, cert signed with RSA: ECDH_RSA or ECDHE_ECDSA */ pr_info("nasp: connect use cert from agent"); - br_ssl_client_init_full(sc->cc, &sc->xc.ctx, sc->trust_anchors, sc->trust_anchors_len); + br_ssl_client_init_full(sc->cc, &sc->xc.ctx, sc->cert->trust_anchors, sc->cert->trust_anchors_len); br_x509_nasp_init(&sc->xc, &sc->cc->eng, &sc->opa_socket_ctx); // mTLS enablement if (sc->opa_socket_ctx.mtls) { - br_ssl_client_set_single_rsa(sc->cc, sc->chain, sc->chain_len, sc->rsa_priv, br_rsa_pkcs1_sign_get_default()); + br_ssl_client_set_single_rsa(sc->cc, sc->cert->chain, sc->cert->chain_len, sc->rsa_priv, br_rsa_pkcs1_sign_get_default()); } /* @@ -1179,8 +1203,6 @@ int nasp_connect(struct sock *sk, struct sockaddr *uaddr, int addr_len) sk->sk_prot->close(sk, 0); release_sock(sk); - pr_err("nasp: [%s] connect error, socket closed", current->comm); - return err; } @@ -1247,8 +1269,8 @@ void socket_exit(void) tcpv6_prot.connect = connect_v6; //- free global tls key - kfree(rsa_priv); - kfree(rsa_pub); + free_rsa_private_key(rsa_priv); + free_rsa_public_key(rsa_pub); pr_info("nasp: socket support unloaded."); } diff --git a/third-party/BearSSL b/third-party/BearSSL index 57d61c2c..941f96a1 160000 --- a/third-party/BearSSL +++ b/third-party/BearSSL @@ -1 +1 @@ -Subproject commit 57d61c2cb14255ad01e40b4b8b8a3d67fcd109a0 +Subproject commit 941f96a139e271ba9bdfd867655e3e03ea4b509b