Skip to content
This repository has been archived by the owner on Sep 19, 2024. It is now read-only.

Commit

Permalink
Initial impl of cert bundle caching (#81)
Browse files Browse the repository at this point in the history
* Initial impl of cert bundle caching

* Add certificate validation for expiration

* Remove code dup around cert_cache handling

* Use id from opa.h insted of keyid

* Remove global rsa_key and use per connection one

* Fix memory leak for rsa_pub rsa_priv key

* Fix memory leak around cache and generated certificate

* Add global rsa_private key back since it is causing additional errors

* Fix cache key mismatch happened because of freeing the char* id after socket close

* Fix typo in function free_rsa_public_key

* Wip, todo need to rearange list initialization

* Use proper lock handling for remove cert from cache

* Allign with linux kernel naming on locked function

* Fix word typo

* refactor cert mem representation

* remove unnecessary socket closed logs

* passthrough if the agent is not running

* Optimize certificate decode, and store validity in the struct

* Beautify function and field names, return if cert is null

---------

Co-authored-by: Zsolt Varga <[email protected]>
  • Loading branch information
baluchicken and waynz0r authored Nov 3, 2023
1 parent 07dbbfc commit 806a8c4
Show file tree
Hide file tree
Showing 11 changed files with 513 additions and 152 deletions.
1 change: 1 addition & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ nasp-objs := third-party/wasm3/source/m3_api_libc.o \
main.o \
csr.o \
rsa_tools.o \
cert_tools.o \
wasm.o \
opa.o \
proxywasm.o \
Expand Down
242 changes: 242 additions & 0 deletions cert_tools.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,242 @@
/*
* Copyright (c) 2023 Cisco and/or its affiliates. All rights reserved.
*
* SPDX-License-Identifier: MIT OR GPL-2.0-only
*
* Licensed under the MIT license <LICENSE.MIT or https://opensource.org/licenses/MIT> or the GPLv2 license
* <LICENSE.GPL or https://opensource.org/license/gpl-2-0>, at your option. This file may not be copied,
* modified, or distributed except according to those terms.
*/

#include <linux/list.h>
#include <linux/slab.h>

#include "cert_tools.h"
#include "string.h"
#include "rsa_tools.h"

// Define the maximum number of elements inside the cache
#define MAX_CACHE_LENGTH 64

// certs that are in use or used once by a workload
static LIST_HEAD(cert_cache);

// lock for the above list to make it thread safe
static DEFINE_MUTEX(certificate_cache_lock);

static void cert_cache_lock(void)
{
mutex_lock(&certificate_cache_lock);
}

static void cert_cache_unlock(void)
{
mutex_unlock(&certificate_cache_lock);
}

static size_t linkedlist_length(struct list_head *head)
{
struct list_head *pos;
int length = 0;

list_for_each(pos, head)
{
length++;
}

return length;
}

// add_cert_to_cache adds a certificate chain with a given trust anchor to a linked list. The key will identify this entry.
// the function is thread safe.
void add_cert_to_cache(char *key, x509_certificate *cert)
{
if (!key)
{
pr_err("nasp: provided key is null");
return;
}

cert_with_key *new_entry = kzalloc(sizeof(cert_with_key), GFP_KERNEL);
if (!new_entry)
{
pr_err("nasp: memory allocation error");
return;
}
new_entry->key = strdup(key);
new_entry->cert = cert;

cert_cache_lock();
INIT_LIST_HEAD(&new_entry->list);
list_add(&new_entry->list, &cert_cache);
cert_cache_unlock();
}

// remove_unused_expired_certs_from_cache iterates over the whole cache and tries to clean up the unused/expired certificates.
// it works like a garbage collection which now runs before every add.
// TODO handle cases when cache length is maxed out but no expired certificate
void remove_unused_expired_certs_from_cache()
{
cert_with_key *cert_bundle, *cert_bundle_tmp;

if (linkedlist_length(&cert_cache) >= MAX_CACHE_LENGTH)
{
pr_warn("nasp: cache is full removing the oldest element");
cert_with_key *last_entry = list_last_entry(&cert_cache, cert_with_key, list);
pr_warn("nasp: removing key:%s from the cache", last_entry->key);
remove_cert_from_cache(last_entry);
return;
}

cert_cache_lock();
list_for_each_entry_safe_reverse(cert_bundle, cert_bundle_tmp, &cert_cache, list)
{
if (!validate_cert(cert_bundle->cert->validity))
{
remove_cert_from_cache_locked(cert_bundle);
}
}
cert_cache_unlock();
}

// find_cert_from_cache tries to find a certificate bundle for the given key. In case of failure it returns a NULL.
// this function also runs a garbage collection on the cache.
// the function is thread safe
cert_with_key *find_cert_from_cache(char *key)
{

remove_unused_expired_certs_from_cache();

cert_with_key *cert_bundle;
cert_cache_lock();
list_for_each_entry(cert_bundle, &cert_cache, list)
{
if (strncmp(cert_bundle->key, key, strlen(key)) == 0)
{
cert_cache_unlock();
return cert_bundle;
}
}
cert_cache_unlock();
return 0;
}

// remove_cert_from_cache_locked removes a given certificate bundle from the cache
// the function is thread safe
void remove_cert_from_cache(cert_with_key *cert_bundle)
{
if (cert_bundle)
{
cert_cache_lock();
remove_cert_from_cache_locked(cert_bundle);
cert_cache_unlock();
}
}

// remove_cert_from_cache removes a given certificate bundle from the cache
void remove_cert_from_cache_locked(cert_with_key *cert_bundle)
{
if (cert_bundle)
{
list_del(&cert_bundle->list);
x509_certificate_put(cert_bundle->cert);
kfree(cert_bundle);
}
}

// set_cert_validity decodes the provided certificate and filling the validity seconds and days.
// if the decode fails it returns -1
int set_cert_validity(x509_certificate *x509_cert)
{
br_x509_decoder_context dc;

br_x509_decoder_init(&dc, 0, 0);
br_x509_decoder_push(&dc, x509_cert->chain->data, x509_cert->chain->data_len);
int err = br_x509_decoder_last_error(&dc);
if (err != 0)
{
pr_err("nasp: cert decode faild during setting cert validity: %d", err);
return -1;
}
x509_cert->validity.notbefore_seconds = dc.notbefore_seconds;
x509_cert->validity.notbefore_days = dc.notbefore_days;

x509_cert->validity.notafter_seconds = dc.notafter_seconds;
x509_cert->validity.notafter_days = dc.notafter_days;

return 0;
}

// validate_cert validates the given certificate if it has expired or not.
bool validate_cert(x509_certificate_validity cert_validity)
{
bool result = false;

time64_t x = ktime_get_real_seconds();
uint32_t vd = (uint32_t)(x / 86400) + 719528;
uint32_t vs = (uint32_t)(x % 86400);

if (vd < cert_validity.notbefore_days || (vd == cert_validity.notbefore_days && vs < cert_validity.notbefore_seconds))
{
pr_warn("nasp: cert expired");
}
else if (vd > cert_validity.notafter_days || (vd == cert_validity.notafter_days && vs > cert_validity.notafter_seconds))
{
pr_warn("nasp: cert not valid yet");
}
else
{
result = true;
}

return result;
}

x509_certificate *x509_certificate_init(void)
{
x509_certificate *cert = kzalloc(sizeof(x509_certificate), GFP_KERNEL);

kref_init(&cert->kref);

return cert;
}

static void x509_certificate_free(x509_certificate *cert)
{
pr_info("nasp: x509_certificate_free");

if (!cert)
{
return;
}

free_br_x509_certificate(cert->chain, cert->chain_len);
free_br_x509_trust_anchors(cert->trust_anchors, cert->trust_anchors_len);

kfree(cert);
}

static void x509_certificate_release(struct kref *kref)
{
x509_certificate *cert = container_of(kref, x509_certificate, kref);

x509_certificate_free(cert);
}

void x509_certificate_get(x509_certificate *cert)
{
if (!cert)
{
return;
}
kref_get(&cert->kref);
}

void x509_certificate_put(x509_certificate *cert)
{
if (!cert)
{
return;
}
kref_put(&cert->kref, x509_certificate_release);
}
57 changes: 57 additions & 0 deletions cert_tools.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/*
* Copyright (c) 2023 Cisco and/or its affiliates. All rights reserved.
*
* SPDX-License-Identifier: MIT OR GPL-2.0-only
*
* Licensed under the MIT license <LICENSE.MIT or https://opensource.org/licenses/MIT> or the GPLv2 license
* <LICENSE.GPL or https://opensource.org/license/gpl-2-0>, at your option. This file may not be copied,
* modified, or distributed except according to those terms.
*/

#ifndef cert_tools_h
#define cert_tools_h

#include <linux/kref.h>

#include "bearssl.h"

typedef struct
{
uint32_t notbefore_seconds;
uint32_t notbefore_days;
uint32_t notafter_seconds;
uint32_t notafter_days;
} x509_certificate_validity;

typedef struct
{
struct kref kref;
br_x509_certificate *chain;
size_t chain_len;
br_x509_trust_anchor *trust_anchors;
size_t trust_anchors_len;

x509_certificate_validity validity;
} x509_certificate;

typedef struct
{
char *key;
x509_certificate *cert;
struct list_head list;
} cert_with_key;

x509_certificate *x509_certificate_init(void);
void x509_certificate_get(x509_certificate *cert);
void x509_certificate_put(x509_certificate *cert);

void add_cert_to_cache(char *key, x509_certificate *cert);
cert_with_key *find_cert_from_cache(char *key);
void remove_cert_from_cache(cert_with_key *cert);
void remove_cert_from_cache_locked(cert_with_key *cert);
void remove_unused_expired_certs_from_cache(void);

bool validate_cert(x509_certificate_validity cert_validity);
int set_cert_validity(x509_certificate *x509_cert);

#endif
41 changes: 25 additions & 16 deletions commands.c
Original file line number Diff line number Diff line change
Expand Up @@ -237,46 +237,48 @@ csr_sign_answer *send_csrsign_command(unsigned char *csr)
goto error;
}

csr_sign_answer->trust_anchors_len = json_array_get_count(trust_anchors);
csr_sign_answer->cert = x509_certificate_init();

csr_sign_answer->cert->trust_anchors_len = json_array_get_count(trust_anchors);
size_t srclen;

if (csr_sign_answer->trust_anchors_len > 0)
if (csr_sign_answer->cert->trust_anchors_len > 0)
{
csr_sign_answer->trust_anchors = kmalloc(csr_sign_answer->trust_anchors_len * sizeof *csr_sign_answer->trust_anchors, GFP_KERNEL);
csr_sign_answer->cert->trust_anchors = kmalloc(csr_sign_answer->cert->trust_anchors_len * sizeof *csr_sign_answer->cert->trust_anchors, GFP_KERNEL);

size_t u;
for (u = 0; u < csr_sign_answer->trust_anchors_len; u++)
for (u = 0; u < csr_sign_answer->cert->trust_anchors_len; u++)
{
JSON_Object *ta = json_array_get_object(trust_anchors, u);

csr_sign_answer->trust_anchors[u].flags = BR_X509_TA_CA;
csr_sign_answer->trust_anchors[u].pkey.key_type = BR_KEYTYPE_RSA;
csr_sign_answer->cert->trust_anchors[u].flags = BR_X509_TA_CA;
csr_sign_answer->cert->trust_anchors[u].pkey.key_type = BR_KEYTYPE_RSA;

// RAW (DN)
const char *raw_subject = json_object_get_string(ta, "rawSubject");
if (raw_subject != NULL)
{
srclen = strlen(raw_subject);
csr_sign_answer->trust_anchors[u].dn.data = kmalloc(srclen, GFP_KERNEL);
csr_sign_answer->trust_anchors[u].dn.len = base64_decode(csr_sign_answer->trust_anchors[u].dn.data, srclen, raw_subject, srclen);
csr_sign_answer->cert->trust_anchors[u].dn.data = kmalloc(srclen, GFP_KERNEL);
csr_sign_answer->cert->trust_anchors[u].dn.len = base64_decode(csr_sign_answer->cert->trust_anchors[u].dn.data, srclen, raw_subject, srclen);
}

// RSA_N
const char *rsa_n = json_object_dotget_string(ta, "publicKey.RSA_N");
if (rsa_n != NULL)
{
srclen = strlen(rsa_n);
csr_sign_answer->trust_anchors[u].pkey.key.rsa.n = kmalloc(srclen, GFP_KERNEL);
csr_sign_answer->trust_anchors[u].pkey.key.rsa.nlen = base64_decode(csr_sign_answer->trust_anchors[u].pkey.key.rsa.n, srclen, rsa_n, srclen);
csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.n = kmalloc(srclen, GFP_KERNEL);
csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.nlen = base64_decode(csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.n, srclen, rsa_n, srclen);
}

// RSA_E
const char *rsa_e = json_object_dotget_string(ta, "publicKey.RSA_E");
if (rsa_e != NULL)
{
srclen = strlen(rsa_e);
csr_sign_answer->trust_anchors[u].pkey.key.rsa.e = kmalloc(srclen, GFP_KERNEL);
csr_sign_answer->trust_anchors[u].pkey.key.rsa.elen = base64_decode(csr_sign_answer->trust_anchors[u].pkey.key.rsa.e, srclen, rsa_e, srclen);
csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.e = kmalloc(srclen, GFP_KERNEL);
csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.elen = base64_decode(csr_sign_answer->cert->trust_anchors[u].pkey.key.rsa.e, srclen, rsa_e, srclen);
}
}
}
Expand All @@ -288,12 +290,19 @@ csr_sign_answer *send_csrsign_command(unsigned char *csr)
goto error;
}

csr_sign_answer->chain = kzalloc(1 * sizeof *csr_sign_answer->chain, GFP_KERNEL);
csr_sign_answer->chain_len = 1;
csr_sign_answer->cert->chain = kzalloc(1 * sizeof *csr_sign_answer->cert->chain, GFP_KERNEL);
csr_sign_answer->cert->chain_len = 1;

srclen = strlen(raw);
csr_sign_answer->chain[0].data = kmalloc(srclen, GFP_KERNEL);
csr_sign_answer->chain[0].data_len = base64_decode(csr_sign_answer->chain[0].data, srclen, raw, srclen);
csr_sign_answer->cert->chain->data = kmalloc(srclen, GFP_KERNEL);
csr_sign_answer->cert->chain->data_len = base64_decode(csr_sign_answer->cert->chain->data, srclen, raw, srclen);

int result = set_cert_validity(csr_sign_answer->cert);
if (result < 0)
{
errormsg = "could not decode generated certificate";
goto error;
}
}

return csr_sign_answer;
Expand Down
Loading

0 comments on commit 806a8c4

Please sign in to comment.