diff --git a/src/stream/raw.rs b/src/stream/raw.rs index 32d1b14e..048a8ac3 100644 --- a/src/stream/raw.rs +++ b/src/stream/raw.rs @@ -5,6 +5,8 @@ //! //! They are mostly thin wrappers around `zstd_safe::{DCtx, CCtx}`. use std::io; +use std::marker::PhantomData; +use std::ops::{Deref, DerefMut}; pub use zstd_safe::{CParameter, DParameter, InBuffer, OutBuffer, WriteBuf}; @@ -132,7 +134,7 @@ pub struct Status { /// An in-memory decoder for streams of data. pub struct Decoder<'a> { - context: zstd_safe::DCtx<'a>, + context: MaybeOwned<'a, zstd_safe::DCtx<'a>>, } impl Decoder<'static> { @@ -148,11 +150,16 @@ impl Decoder<'static> { context .load_dictionary(dictionary) .map_err(map_error_code)?; - Ok(Decoder { context }) + Ok(Decoder { context: MaybeOwned::owned(context) }) } } impl<'a> Decoder<'a> { + /// Creates a new decoder which employs the provided context for deserialization. + pub fn with_context<'b: 'a>(context: &'a mut zstd_safe::DCtx<'b>) -> Self { + Self { context: MaybeOwned::borrowed(context) } + } + /// Creates a new decoder, using an existing `DecoderDictionary`. pub fn with_prepared_dictionary<'b>( dictionary: &DecoderDictionary<'b>, @@ -164,7 +171,7 @@ impl<'a> Decoder<'a> { context .ref_ddict(dictionary.as_ddict()) .map_err(map_error_code)?; - Ok(Decoder { context }) + Ok(Decoder { context: MaybeOwned::owned(context) }) } /// Sets a decompression parameter for this decoder. @@ -229,7 +236,7 @@ impl Operation for Decoder<'_> { /// An in-memory encoder for streams of data. pub struct Encoder<'a> { - context: zstd_safe::CCtx<'a>, + context: MaybeOwned<'a, zstd_safe::CCtx<'a>>, } impl Encoder<'static> { @@ -250,11 +257,16 @@ impl Encoder<'static> { .load_dictionary(dictionary) .map_err(map_error_code)?; - Ok(Encoder { context }) + Ok(Encoder { context: MaybeOwned::owned(context) }) } } impl<'a> Encoder<'a> { + /// Creates a new encoder that uses the provided context for serialization. + pub fn with_context<'b: 'a>(context: &'a mut zstd_safe::CCtx<'b>) -> Self { + Self { context: MaybeOwned::borrowed(context) } + } + /// Creates a new encoder using an existing `EncoderDictionary`. pub fn with_prepared_dictionary<'b>( dictionary: &EncoderDictionary<'b>, @@ -266,7 +278,7 @@ impl<'a> Encoder<'a> { context .ref_cdict(dictionary.as_cdict()) .map_err(map_error_code)?; - Ok(Encoder { context }) + Ok(Encoder { context: MaybeOwned::owned(context) }) } /// Sets a compression parameter for this encoder. @@ -330,6 +342,50 @@ impl<'a> Operation for Encoder<'a> { } } +struct MaybeOwned<'a, T>(MaybeOwnedInner<'a, T>); + +impl<'a, T> MaybeOwned<'a, T> { + pub fn owned(value: T) -> Self { + Self(MaybeOwnedInner::Owned(value)) + } + + pub fn borrowed(value: &'a mut T) -> Self { + Self(MaybeOwnedInner::Borrowed((value as *mut T) as *mut _, PhantomData)) + } +} + +impl<'a, T> Deref for MaybeOwned<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { + match &self.0 { + MaybeOwnedInner::Owned(x) => x, + MaybeOwnedInner::Borrowed(x, _) => &*(*x as *mut _) + } + } + } +} + +impl<'a, T> DerefMut for MaybeOwned<'a, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { + match &mut self.0 { + MaybeOwnedInner::Owned(x) => x, + MaybeOwnedInner::Borrowed(x, _) => &mut *(*x as *mut _) + } + } + } +} + +enum MaybeOwnedInner<'a, T> { + Owned(T), + Borrowed(*mut (), PhantomData<&'a ()>) +} + +unsafe impl<'a, T: Send> Send for MaybeOwned<'a, T> {} +unsafe impl<'a, T: Sync> Sync for MaybeOwned<'a, T> {} + #[cfg(test)] mod tests { diff --git a/src/stream/read/mod.rs b/src/stream/read/mod.rs index a3a947b0..412d9c71 100644 --- a/src/stream/read/mod.rs +++ b/src/stream/read/mod.rs @@ -46,6 +46,11 @@ impl Decoder<'static, R> { } } impl<'a, R: BufRead> Decoder<'a, R> { + /// Creates a new decoder which employs the provided context for deserialization. + pub fn with_context<'b: 'a>(reader: R, context: &'a mut zstd_safe::DCtx<'b>) -> Self { + Self { reader: zio::Reader::new(reader, raw::Decoder::with_context(context)) } + } + /// Sets this `Decoder` to stop after the first frame. /// /// By default, it keeps concatenating frames until EOF is reached. diff --git a/src/stream/write/mod.rs b/src/stream/write/mod.rs index 9103c2d9..05eecb5e 100644 --- a/src/stream/write/mod.rs +++ b/src/stream/write/mod.rs @@ -193,6 +193,11 @@ impl Encoder<'static, W> { } impl<'a, W: Write> Encoder<'a, W> { + /// Creates an encoder that uses the provided context to compress a stream. + pub fn with_context<'b: 'a>(writer: W, context: &'a mut zstd_safe::CCtx<'b>) -> Self { + Self { writer: zio::Writer::new(writer, raw::Encoder::with_context(context)) } + } + /// Creates a new encoder, using an existing prepared `EncoderDictionary`. /// /// (Provides better compression ratio for small files,