Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PSK and DTLS support #202

Merged
merged 11 commits into from
Oct 6, 2022
10 changes: 9 additions & 1 deletion mbedtls/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ cc = "1.0"
default = ["std", "aesni", "time", "padlock"]
std = ["mbedtls-sys-auto/std", "serde/std", "yasna"]
debug = ["mbedtls-sys-auto/debug"]
no_std_deps = ["core_io", "spin"]
no_std_deps = ["core_io", "spin", "serde/alloc"]
zugzwang marked this conversation as resolved.
Show resolved Hide resolved
force_aesni_support = ["mbedtls-sys-auto/custom_has_support", "mbedtls-sys-auto/aes_alt", "aesni"]
mpi_force_c_code = ["mbedtls-sys-auto/mpi_force_c_code"]
rdrand = []
Expand All @@ -73,6 +73,14 @@ pkcs12_rc2 = ["pkcs12", "rc2", "block-modes"]
name = "client"
required-features = ["std"]

[[example]]
name = "client_dtls"
required-features = ["std"]

[[example]]
name = "client_psk"
required-features = ["std"]

[[example]]
name = "server"
required-features = ["std"]
Expand Down
57 changes: 57 additions & 0 deletions mbedtls/examples/client_dtls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
/* Copyright (c) Fortanix, Inc.
*
* Licensed under the GNU General Public License, version 2 <LICENSE-GPL or
* https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version
* 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your
* option. This file may not be copied, modified, or distributed except
* according to those terms. */

// needed to have common code for `mod support` in unit and integrations tests
extern crate mbedtls;

use std::io::{self, stdin, stdout, Write};
use std::net::UdpSocket;
use std::sync::Arc;

use mbedtls::rng::CtrDrbg;
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
use mbedtls::ssl::{Config, Context};
use mbedtls::x509::Certificate;
use mbedtls::Result as TlsResult;

#[path = "../tests/support/mod.rs"]
mod support;
use support::entropy::entropy_new;
use support::keys;

fn result_main(addr: &str) -> TlsResult<()> {
let entropy = Arc::new(entropy_new());
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
let cert = Arc::new(Certificate::from_pem_multiple(keys::ROOT_CA_CERT.as_bytes())?);
let mut config = Config::new(Endpoint::Client, Transport::Datagram, Preset::Default);
config.set_rng(rng);
config.set_ca_list(cert, None);
let mut ctx = Context::new(Arc::new(config));
ctx.set_timer_callback(Box::new(mbedtls::ssl::context::Timer::new()));

let sock = UdpSocket::bind("localhost:12345").unwrap();
let sock = mbedtls::ssl::context::ConnectedUdpSocket::connect(sock, addr).unwrap();
ctx.establish(sock, None).unwrap();

let mut line = String::new();
stdin().read_line(&mut line).unwrap();
ctx.write_all(line.as_bytes()).unwrap();
io::copy(&mut ctx, &mut stdout()).unwrap();
Ok(())
}

fn main() {
let mut args = std::env::args();
args.next();
result_main(
&args
.next()
.expect("supply destination in command-line argument"),
)
.unwrap();
}
52 changes: 52 additions & 0 deletions mbedtls/examples/client_psk.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/* Copyright (c) Fortanix, Inc.
*
* Licensed under the GNU General Public License, version 2 <LICENSE-GPL or
* https://www.gnu.org/licenses/gpl-2.0.html> or the Apache License, Version
* 2.0 <LICENSE-APACHE or http://www.apache.org/licenses/LICENSE-2.0>, at your
* option. This file may not be copied, modified, or distributed except
* according to those terms. */

// needed to have common code for `mod support` in unit and integrations tests
extern crate mbedtls;

use std::io::{self, stdin, stdout, Write};
use std::net::TcpStream;
use std::sync::Arc;

use mbedtls::rng::CtrDrbg;
use mbedtls::ssl::config::{Endpoint, Preset, Transport};
use mbedtls::ssl::{Config, Context};
use mbedtls::Result as TlsResult;

#[path = "../tests/support/mod.rs"]
mod support;
use support::entropy::entropy_new;

fn result_main(addr: &str) -> TlsResult<()> {
let entropy = Arc::new(entropy_new());
let rng = Arc::new(CtrDrbg::new(entropy, None)?);
let mut config = Config::new(Endpoint::Client, Transport::Stream, Preset::Default);
config.set_rng(rng);
config.set_psk(&[0x12, 0x34, 0x56, 0x78], "client").unwrap();
let mut ctx = Context::new(Arc::new(config));

let conn = TcpStream::connect(addr).unwrap();
ctx.establish(conn, None)?;

let mut line = String::new();
stdin().read_line(&mut line).unwrap();
ctx.write_all(line.as_bytes()).unwrap();
io::copy(&mut ctx, &mut stdout()).unwrap();
Ok(())
}

fn main() {
let mut args = std::env::args();
args.next();
result_main(
&args
.next()
.expect("supply destination in command-line argument"),
)
.unwrap();
}
16 changes: 15 additions & 1 deletion mbedtls/src/ssl/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,6 +457,21 @@ impl Config {
self.dbg_callback = Some(Arc::new(cb));
unsafe { ssl_conf_dbg(self.into(), Some(dbg_callback::<F>), &**self.dbg_callback.as_mut().unwrap() as *const _ as *mut c_void) }
}

/// Sets the PSK and the PSK-Identity
///
/// Only a single entry is supported at the moment. If another one was set before, it will be
/// overridden.
pub fn set_psk(&mut self, psk: &[u8], psk_identity: &str) -> Result<()> {
unsafe {
// This allocates and copies the buffers and does not store any pointer to them
let psk_identity = psk_identity.as_bytes();
ssl_conf_psk(self.into(), psk.as_ptr(), psk.len(), psk_identity.as_ptr(), psk_identity.len())
.into_result()
.map(|_| ())?;
}
Ok(())
DrTobe marked this conversation as resolved.
Show resolved Hide resolved
}
}

// TODO
Expand All @@ -466,7 +481,6 @@ impl Config {
// ssl_conf_dtls_badmac_limit
// ssl_conf_handshake_timeout
// ssl_conf_session_cache
// ssl_conf_psk
// ssl_conf_psk_cb
// ssl_conf_sig_hashes
// ssl_conf_alpn_protocols
Expand Down
131 changes: 128 additions & 3 deletions mbedtls/src/ssl/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use core::result::Result as StdResult;

#[cfg(feature = "std")]
use {
std::io::{Read, Write, Result as IoResult},
std::io::{Read, Write, Result as IoResult, Error as IoError},
std::sync::Arc,
};

Expand Down Expand Up @@ -67,6 +67,121 @@ impl<IO: Read + Write> IoCallback for IO {
}
}

#[cfg(feature = "std")]
pub struct ConnectedUdpSocket {
socket: std::net::UdpSocket,
}

#[cfg(feature = "std")]
impl ConnectedUdpSocket {
pub fn connect<A: std::net::ToSocketAddrs>(socket: std::net::UdpSocket, addr: A) -> StdResult<Self, (IoError, std::net::UdpSocket)> {
match socket.connect(addr) {
Ok(_) => Ok(ConnectedUdpSocket {
socket,
}),
Err(e) => Err((e, socket)),
}
}
}

#[cfg(feature = "std")]
impl IoCallback for ConnectedUdpSocket {
unsafe extern "C" fn call_recv(user_data: *mut c_void, data: *mut c_uchar, len: size_t) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.recv(::core::slice::from_raw_parts_mut(data, len)) {
Ok(i) => i as c_int,
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => 0,
Err(_) => ::mbedtls_sys::ERR_NET_RECV_FAILED,
}
}

unsafe extern "C" fn call_send(user_data: *mut c_void, data: *const c_uchar, len: size_t) -> c_int {
let len = if len > (c_int::max_value() as size_t) {
c_int::max_value() as size_t
} else {
len
};
match (&mut *(user_data as *mut ConnectedUdpSocket)).socket.send(::core::slice::from_raw_parts(data, len)) {
Ok(i) => i as c_int,
Err(_) => ::mbedtls_sys::ERR_NET_SEND_FAILED,
}
}

fn data_ptr(&mut self) -> *mut c_void {
self as *mut ConnectedUdpSocket as *mut c_void
}
}

pub trait TimerCallback: Send + Sync {
unsafe extern "C" fn set_timer(
p_timer: *mut c_void,
int_ms: u32,
fin_ms: u32,
) where Self: Sized;

unsafe extern "C" fn get_timer(
p_timer: *mut c_void,
) -> c_int where Self: Sized;

fn data_ptr(&mut self) -> *mut c_void;
}

#[cfg(feature = "std")]
pub struct Timer {
timer_start: std::time::Instant,
timer_int_ms: u32,
timer_fin_ms: u32,
}

#[cfg(feature = "std")]
impl Timer {
pub fn new() -> Self {
Timer {
timer_start: std::time::Instant::now(),
timer_int_ms: 0,
timer_fin_ms: 0,
}
}
}

#[cfg(feature = "std")]
impl TimerCallback for Timer {
unsafe extern "C" fn set_timer(
p_timer: *mut c_void,
int_ms: u32,
fin_ms: u32,
) where Self: Sized {
let slf = (p_timer as *mut Timer).as_mut().unwrap();
slf.timer_start = std::time::Instant::now();
slf.timer_int_ms = int_ms;
slf.timer_fin_ms = fin_ms;
}

unsafe extern "C" fn get_timer(
p_timer: *mut c_void,
) -> c_int where Self: Sized {
let slf = (p_timer as *mut Timer).as_mut().unwrap();
if slf.timer_int_ms == 0 || slf.timer_fin_ms == 0 {
return 0;
}
let passed = std::time::Instant::now() - slf.timer_start;
if passed.as_millis() >= slf.timer_fin_ms.into() {
2
} else if passed.as_millis() >= slf.timer_int_ms.into() {
1
} else {
0
}
}

fn data_ptr(&mut self) -> *mut mbedtls_sys::types::raw_types::c_void {
self as *mut _ as *mut _
}
}

define!(
#[c_ty(ssl_context)]
Expand All @@ -89,11 +204,13 @@ pub struct Context<T> {
// Base structure used in SNI callback where we cannot determine the io type.
inner: HandshakeContext,

// config is used read-only for mutliple contexts and is immutable once configured.
// config is used read-only for multiple contexts and is immutable once configured.
config: Arc<Config>,

// Must be held in heap and pointer to it as pointer is sent to MbedSSL and can't be re-allocated.
io: Option<Box<T>>,

timer_callback: Option<Box<dyn TimerCallback>>,
}

impl<'a, T> Into<*const ssl_context> for &'a Context<T> {
Expand Down Expand Up @@ -128,6 +245,7 @@ impl<T> Context<T> {
},
config: config.clone(),
io: None,
timer_callback: None,
}
}

Expand Down Expand Up @@ -157,7 +275,7 @@ impl<T: IoCallback> Context<T> {
);

self.io = Some(io);
self.inner.reset_handshake();
self.inner.reset_handshake();
}

match self.handshake() {
Expand Down Expand Up @@ -298,6 +416,13 @@ impl<T> Context<T> {
}
}
}

pub fn set_timer_callback<F: TimerCallback + 'static>(&mut self, mut cb: Box<F>) {
unsafe {
ssl_set_timer_cb(self.into(), cb.data_ptr(), Some(F::set_timer), Some(F::get_timer));
}
self.timer_callback = Some(cb);
}
}

impl<T> Drop for Context<T> {
Expand Down