Skip to content

Commit

Permalink
Add the ability to provide a safe context for encoders to use
Browse files Browse the repository at this point in the history
  • Loading branch information
DouglasDwyer committed Apr 15, 2024
1 parent e470f00 commit 90a526d
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 6 deletions.
68 changes: 62 additions & 6 deletions src/stream/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
//!
//! They are mostly thin wrappers around `zstd_safe::{DCtx, CCtx}`.
use std::io;
use std::marker::PhantomData;
use std::ops::{Deref, DerefMut};

pub use zstd_safe::{CParameter, DParameter, InBuffer, OutBuffer, WriteBuf};

Expand Down Expand Up @@ -132,7 +134,7 @@ pub struct Status {

/// An in-memory decoder for streams of data.
pub struct Decoder<'a> {
context: zstd_safe::DCtx<'a>,
context: MaybeOwned<'a, zstd_safe::DCtx<'a>>,
}

impl Decoder<'static> {
Expand All @@ -148,11 +150,16 @@ impl Decoder<'static> {
context
.load_dictionary(dictionary)
.map_err(map_error_code)?;
Ok(Decoder { context })
Ok(Decoder { context: MaybeOwned::owned(context) })
}
}

impl<'a> Decoder<'a> {
/// Creates a new decoder which employs the provided context for deserialization.
pub fn with_context<'b: 'a>(context: &'a mut zstd_safe::DCtx<'b>) -> Self {
Self { context: MaybeOwned::borrowed(context) }
}

/// Creates a new decoder, using an existing `DecoderDictionary`.
pub fn with_prepared_dictionary<'b>(
dictionary: &DecoderDictionary<'b>,
Expand All @@ -164,7 +171,7 @@ impl<'a> Decoder<'a> {
context
.ref_ddict(dictionary.as_ddict())
.map_err(map_error_code)?;
Ok(Decoder { context })
Ok(Decoder { context: MaybeOwned::owned(context) })
}

/// Sets a decompression parameter for this decoder.
Expand Down Expand Up @@ -229,7 +236,7 @@ impl Operation for Decoder<'_> {

/// An in-memory encoder for streams of data.
pub struct Encoder<'a> {
context: zstd_safe::CCtx<'a>,
context: MaybeOwned<'a, zstd_safe::CCtx<'a>>,
}

impl Encoder<'static> {
Expand All @@ -250,11 +257,16 @@ impl Encoder<'static> {
.load_dictionary(dictionary)
.map_err(map_error_code)?;

Ok(Encoder { context })
Ok(Encoder { context: MaybeOwned::owned(context) })
}
}

impl<'a> Encoder<'a> {
/// Creates a new encoder that uses the provided context for serialization.
pub fn with_context<'b: 'a>(context: &'a mut zstd_safe::CCtx<'b>) -> Self {
Self { context: MaybeOwned::borrowed(context) }
}

/// Creates a new encoder using an existing `EncoderDictionary`.
pub fn with_prepared_dictionary<'b>(
dictionary: &EncoderDictionary<'b>,
Expand All @@ -266,7 +278,7 @@ impl<'a> Encoder<'a> {
context
.ref_cdict(dictionary.as_cdict())
.map_err(map_error_code)?;
Ok(Encoder { context })
Ok(Encoder { context: MaybeOwned::owned(context) })
}

/// Sets a compression parameter for this encoder.
Expand Down Expand Up @@ -330,6 +342,50 @@ impl<'a> Operation for Encoder<'a> {
}
}

struct MaybeOwned<'a, T>(MaybeOwnedInner<'a, T>);

impl<'a, T> MaybeOwned<'a, T> {
pub fn owned(value: T) -> Self {
Self(MaybeOwnedInner::Owned(value))
}

pub fn borrowed(value: &'a mut T) -> Self {
Self(MaybeOwnedInner::Borrowed((value as *mut T) as *mut _, PhantomData))
}
}

impl<'a, T> Deref for MaybeOwned<'a, T> {
type Target = T;

fn deref(&self) -> &Self::Target {
unsafe {
match &self.0 {
MaybeOwnedInner::Owned(x) => x,
MaybeOwnedInner::Borrowed(x, _) => &*(*x as *mut _)
}
}
}
}

impl<'a, T> DerefMut for MaybeOwned<'a, T> {
fn deref_mut(&mut self) -> &mut Self::Target {
unsafe {
match &mut self.0 {
MaybeOwnedInner::Owned(x) => x,
MaybeOwnedInner::Borrowed(x, _) => &mut *(*x as *mut _)
}
}
}
}

enum MaybeOwnedInner<'a, T> {
Owned(T),
Borrowed(*mut (), PhantomData<&'a ()>)
}

unsafe impl<'a, T: Send> Send for MaybeOwned<'a, T> {}
unsafe impl<'a, T: Sync> Sync for MaybeOwned<'a, T> {}

#[cfg(test)]
mod tests {

Expand Down
5 changes: 5 additions & 0 deletions src/stream/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ impl<R: BufRead> Decoder<'static, R> {
}
}
impl<'a, R: BufRead> Decoder<'a, R> {
/// Creates a new decoder which employs the provided context for deserialization.
pub fn with_context<'b: 'a>(reader: R, context: &'a mut zstd_safe::DCtx<'b>) -> Self {
Self { reader: zio::Reader::new(reader, raw::Decoder::with_context(context)) }
}

/// Sets this `Decoder` to stop after the first frame.
///
/// By default, it keeps concatenating frames until EOF is reached.
Expand Down
5 changes: 5 additions & 0 deletions src/stream/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,11 @@ impl<W: Write> Encoder<'static, W> {
}

impl<'a, W: Write> Encoder<'a, W> {
/// Creates an encoder that uses the provided context to compress a stream.
pub fn with_context<'b: 'a>(writer: W, context: &'a mut zstd_safe::CCtx<'b>) -> Self {
Self { writer: zio::Writer::new(writer, raw::Encoder::with_context(context)) }
}

/// Creates a new encoder, using an existing prepared `EncoderDictionary`.
///
/// (Provides better compression ratio for small files,
Expand Down

0 comments on commit 90a526d

Please sign in to comment.