Skip to content

Commit

Permalink
Add the ability to provide a safe context for encoders to use
Browse files Browse the repository at this point in the history
Remove unsafe methods and clean up MaybeOwned enums
  • Loading branch information
DouglasDwyer authored and gyscos committed Jul 5, 2024
1 parent a3738d6 commit 6688207
Show file tree
Hide file tree
Showing 3 changed files with 121 additions and 29 deletions.
124 changes: 95 additions & 29 deletions src/stream/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub struct Status {

/// An in-memory decoder for streams of data.
pub struct Decoder<'a> {
context: zstd_safe::DCtx<'a>,
context: MaybeOwnedDCtx<'a>,
}

impl Decoder<'static> {
Expand All @@ -148,11 +148,20 @@ impl Decoder<'static> {
context
.load_dictionary(dictionary)
.map_err(map_error_code)?;
Ok(Decoder { context })
Ok(Decoder {
context: MaybeOwnedDCtx::Owned(context),
})
}
}

impl<'a> Decoder<'a> {
/// Creates a new decoder which employs the provided context for deserialization.
pub fn with_context(context: &'a mut zstd_safe::DCtx<'static>) -> Self {
Self {
context: MaybeOwnedDCtx::Borrowed(context),
}
}

/// Creates a new decoder, using an existing `DecoderDictionary`.
pub fn with_prepared_dictionary<'b>(
dictionary: &DecoderDictionary<'b>,
Expand All @@ -164,7 +173,9 @@ impl<'a> Decoder<'a> {
context
.ref_ddict(dictionary.as_ddict())
.map_err(map_error_code)?;
Ok(Decoder { context })
Ok(Decoder {
context: MaybeOwnedDCtx::Owned(context),
})
}

/// Creates a new decoder, using a ref prefix
Expand All @@ -183,9 +194,11 @@ impl<'a> Decoder<'a> {

/// Sets a decompression parameter for this decoder.
pub fn set_parameter(&mut self, parameter: DParameter) -> io::Result<()> {
self.context
.set_parameter(parameter)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedDCtx::Owned(x) => x.set_parameter(parameter),
MaybeOwnedDCtx::Borrowed(x) => x.set_parameter(parameter),
}
.map_err(map_error_code)?;
Ok(())
}
}
Expand All @@ -196,9 +209,11 @@ impl Operation for Decoder<'_> {
input: &mut InBuffer<'_>,
output: &mut OutBuffer<'_, C>,
) -> io::Result<usize> {
self.context
.decompress_stream(output, input)
.map_err(map_error_code)
match &mut self.context {
MaybeOwnedDCtx::Owned(x) => x.decompress_stream(output, input),
MaybeOwnedDCtx::Borrowed(x) => x.decompress_stream(output, input),
}
.map_err(map_error_code)
}

fn flush<C: WriteBuf + ?Sized>(
Expand All @@ -219,9 +234,15 @@ impl Operation for Decoder<'_> {
}

fn reinit(&mut self) -> io::Result<()> {
self.context
.reset(zstd_safe::ResetDirective::SessionOnly)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedDCtx::Owned(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
MaybeOwnedDCtx::Borrowed(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
}
.map_err(map_error_code)?;
Ok(())
}

Expand All @@ -243,7 +264,7 @@ impl Operation for Decoder<'_> {

/// An in-memory encoder for streams of data.
pub struct Encoder<'a> {
context: zstd_safe::CCtx<'a>,
context: MaybeOwnedCCtx<'a>,
}

impl Encoder<'static> {
Expand All @@ -264,11 +285,20 @@ impl Encoder<'static> {
.load_dictionary(dictionary)
.map_err(map_error_code)?;

Ok(Encoder { context })
Ok(Encoder {
context: MaybeOwnedCCtx::Owned(context),
})
}
}

impl<'a> Encoder<'a> {
/// Creates a new encoder that uses the provided context for serialization.
pub fn with_context(context: &'a mut zstd_safe::CCtx<'static>) -> Self {
Self {
context: MaybeOwnedCCtx::Borrowed(context),
}
}

/// Creates a new encoder using an existing `EncoderDictionary`.
pub fn with_prepared_dictionary<'b>(
dictionary: &EncoderDictionary<'b>,
Expand All @@ -280,7 +310,9 @@ impl<'a> Encoder<'a> {
context
.ref_cdict(dictionary.as_cdict())
.map_err(map_error_code)?;
Ok(Encoder { context })
Ok(Encoder {
context: MaybeOwnedCCtx::Owned(context),
})
}

/// Creates a new encoder initialized with the given ref prefix.
Expand All @@ -306,9 +338,11 @@ impl<'a> Encoder<'a> {

/// Sets a compression parameter for this encoder.
pub fn set_parameter(&mut self, parameter: CParameter) -> io::Result<()> {
self.context
.set_parameter(parameter)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.set_parameter(parameter),
MaybeOwnedCCtx::Borrowed(x) => x.set_parameter(parameter),
}
.map_err(map_error_code)?;
Ok(())
}

Expand All @@ -324,9 +358,15 @@ impl<'a> Encoder<'a> {
&mut self,
pledged_src_size: Option<u64>,
) -> io::Result<()> {
self.context
.set_pledged_src_size(pledged_src_size)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => {
x.set_pledged_src_size(pledged_src_size)
}
MaybeOwnedCCtx::Borrowed(x) => {
x.set_pledged_src_size(pledged_src_size)
}
}
.map_err(map_error_code)?;
Ok(())
}
}
Expand All @@ -337,34 +377,60 @@ impl<'a> Operation for Encoder<'a> {
input: &mut InBuffer<'_>,
output: &mut OutBuffer<'_, C>,
) -> io::Result<usize> {
self.context
.compress_stream(output, input)
.map_err(map_error_code)
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.compress_stream(output, input),
MaybeOwnedCCtx::Borrowed(x) => x.compress_stream(output, input),
}
.map_err(map_error_code)
}

fn flush<C: WriteBuf + ?Sized>(
&mut self,
output: &mut OutBuffer<'_, C>,
) -> io::Result<usize> {
self.context.flush_stream(output).map_err(map_error_code)
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.flush_stream(output),
MaybeOwnedCCtx::Borrowed(x) => x.flush_stream(output),
}
.map_err(map_error_code)
}

fn finish<C: WriteBuf + ?Sized>(
&mut self,
output: &mut OutBuffer<'_, C>,
_finished_frame: bool,
) -> io::Result<usize> {
self.context.end_stream(output).map_err(map_error_code)
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.end_stream(output),
MaybeOwnedCCtx::Borrowed(x) => x.end_stream(output),
}
.map_err(map_error_code)
}

fn reinit(&mut self) -> io::Result<()> {
self.context
.reset(zstd_safe::ResetDirective::SessionOnly)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
MaybeOwnedCCtx::Borrowed(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
}
.map_err(map_error_code)?;
Ok(())
}
}

enum MaybeOwnedCCtx<'a> {
Owned(zstd_safe::CCtx<'a>),
Borrowed(&'a mut zstd_safe::CCtx<'static>),
}

enum MaybeOwnedDCtx<'a> {
Owned(zstd_safe::DCtx<'a>),
Borrowed(&'a mut zstd_safe::DCtx<'static>),
}

#[cfg(test)]
mod tests {

Expand Down
13 changes: 13 additions & 0 deletions src/stream/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,19 @@ impl<R: BufRead> Decoder<'static, R> {
}
}
impl<'a, R: BufRead> Decoder<'a, R> {
/// Creates a new decoder which employs the provided context for deserialization.
pub fn with_context(
reader: R,
context: &'a mut zstd_safe::DCtx<'static>,
) -> Self {
Self {
reader: zio::Reader::new(
reader,
raw::Decoder::with_context(context),
),
}
}

/// Sets this `Decoder` to stop after the first frame.
///
/// By default, it keeps concatenating frames until EOF is reached.
Expand Down
13 changes: 13 additions & 0 deletions src/stream/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,19 @@ impl<W: Write> Encoder<'static, W> {
}

impl<'a, W: Write> Encoder<'a, W> {
/// Creates an encoder that uses the provided context to compress a stream.
pub fn with_context(
writer: W,
context: &'a mut zstd_safe::CCtx<'static>,
) -> Self {
Self {
writer: zio::Writer::new(
writer,
raw::Encoder::with_context(context),
),
}
}

/// Creates a new encoder, using an existing prepared `EncoderDictionary`.
///
/// (Provides better compression ratio for small files,
Expand Down

0 comments on commit 6688207

Please sign in to comment.