Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow reusing the same context across multiple Encoders and Decoders #247

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
124 changes: 95 additions & 29 deletions src/stream/raw.rs
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ pub struct Status {

/// An in-memory decoder for streams of data.
pub struct Decoder<'a> {
context: zstd_safe::DCtx<'a>,
context: MaybeOwnedDCtx<'a>,
}

impl Decoder<'static> {
Expand All @@ -148,11 +148,20 @@ impl Decoder<'static> {
context
.load_dictionary(dictionary)
.map_err(map_error_code)?;
Ok(Decoder { context })
Ok(Decoder {
context: MaybeOwnedDCtx::Owned(context),
})
}
}

impl<'a> Decoder<'a> {
/// Creates a new decoder which employs the provided context for deserialization.
pub fn with_context(context: &'a mut zstd_safe::DCtx<'static>) -> Self {
Self {
context: MaybeOwnedDCtx::Borrowed(context),
}
}

/// Creates a new decoder, using an existing `DecoderDictionary`.
pub fn with_prepared_dictionary<'b>(
dictionary: &DecoderDictionary<'b>,
Expand All @@ -164,14 +173,18 @@ impl<'a> Decoder<'a> {
context
.ref_ddict(dictionary.as_ddict())
.map_err(map_error_code)?;
Ok(Decoder { context })
Ok(Decoder {
context: MaybeOwnedDCtx::Owned(context),
})
}

/// Sets a decompression parameter for this decoder.
pub fn set_parameter(&mut self, parameter: DParameter) -> io::Result<()> {
self.context
.set_parameter(parameter)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedDCtx::Owned(x) => x.set_parameter(parameter),
MaybeOwnedDCtx::Borrowed(x) => x.set_parameter(parameter),
}
.map_err(map_error_code)?;
Ok(())
}
}
Expand All @@ -182,9 +195,11 @@ impl Operation for Decoder<'_> {
input: &mut InBuffer<'_>,
output: &mut OutBuffer<'_, C>,
) -> io::Result<usize> {
self.context
.decompress_stream(output, input)
.map_err(map_error_code)
match &mut self.context {
MaybeOwnedDCtx::Owned(x) => x.decompress_stream(output, input),
MaybeOwnedDCtx::Borrowed(x) => x.decompress_stream(output, input),
}
.map_err(map_error_code)
}

fn flush<C: WriteBuf + ?Sized>(
Expand All @@ -205,9 +220,15 @@ impl Operation for Decoder<'_> {
}

fn reinit(&mut self) -> io::Result<()> {
self.context
.reset(zstd_safe::ResetDirective::SessionOnly)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedDCtx::Owned(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
MaybeOwnedDCtx::Borrowed(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
}
.map_err(map_error_code)?;
Ok(())
}

Expand All @@ -229,7 +250,7 @@ impl Operation for Decoder<'_> {

/// An in-memory encoder for streams of data.
pub struct Encoder<'a> {
context: zstd_safe::CCtx<'a>,
context: MaybeOwnedCCtx<'a>,
}

impl Encoder<'static> {
Expand All @@ -250,11 +271,20 @@ impl Encoder<'static> {
.load_dictionary(dictionary)
.map_err(map_error_code)?;

Ok(Encoder { context })
Ok(Encoder {
context: MaybeOwnedCCtx::Owned(context),
})
}
}

impl<'a> Encoder<'a> {
/// Creates a new encoder that uses the provided context for serialization.
pub fn with_context(context: &'a mut zstd_safe::CCtx<'static>) -> Self {
Self {
context: MaybeOwnedCCtx::Borrowed(context),
}
}

/// Creates a new encoder using an existing `EncoderDictionary`.
pub fn with_prepared_dictionary<'b>(
dictionary: &EncoderDictionary<'b>,
Expand All @@ -266,14 +296,18 @@ impl<'a> Encoder<'a> {
context
.ref_cdict(dictionary.as_cdict())
.map_err(map_error_code)?;
Ok(Encoder { context })
Ok(Encoder {
context: MaybeOwnedCCtx::Owned(context),
})
}

/// Sets a compression parameter for this encoder.
pub fn set_parameter(&mut self, parameter: CParameter) -> io::Result<()> {
self.context
.set_parameter(parameter)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.set_parameter(parameter),
MaybeOwnedCCtx::Borrowed(x) => x.set_parameter(parameter),
}
.map_err(map_error_code)?;
Ok(())
}

Expand All @@ -289,9 +323,15 @@ impl<'a> Encoder<'a> {
&mut self,
pledged_src_size: Option<u64>,
) -> io::Result<()> {
self.context
.set_pledged_src_size(pledged_src_size)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => {
x.set_pledged_src_size(pledged_src_size)
}
MaybeOwnedCCtx::Borrowed(x) => {
x.set_pledged_src_size(pledged_src_size)
}
}
.map_err(map_error_code)?;
Ok(())
}
}
Expand All @@ -302,34 +342,60 @@ impl<'a> Operation for Encoder<'a> {
input: &mut InBuffer<'_>,
output: &mut OutBuffer<'_, C>,
) -> io::Result<usize> {
self.context
.compress_stream(output, input)
.map_err(map_error_code)
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.compress_stream(output, input),
MaybeOwnedCCtx::Borrowed(x) => x.compress_stream(output, input),
}
.map_err(map_error_code)
}

fn flush<C: WriteBuf + ?Sized>(
&mut self,
output: &mut OutBuffer<'_, C>,
) -> io::Result<usize> {
self.context.flush_stream(output).map_err(map_error_code)
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.flush_stream(output),
MaybeOwnedCCtx::Borrowed(x) => x.flush_stream(output),
}
.map_err(map_error_code)
}

fn finish<C: WriteBuf + ?Sized>(
&mut self,
output: &mut OutBuffer<'_, C>,
_finished_frame: bool,
) -> io::Result<usize> {
self.context.end_stream(output).map_err(map_error_code)
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => x.end_stream(output),
MaybeOwnedCCtx::Borrowed(x) => x.end_stream(output),
}
.map_err(map_error_code)
}

fn reinit(&mut self) -> io::Result<()> {
self.context
.reset(zstd_safe::ResetDirective::SessionOnly)
.map_err(map_error_code)?;
match &mut self.context {
MaybeOwnedCCtx::Owned(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
MaybeOwnedCCtx::Borrowed(x) => {
x.reset(zstd_safe::ResetDirective::SessionOnly)
}
}
.map_err(map_error_code)?;
Ok(())
}
}

enum MaybeOwnedCCtx<'a> {
Owned(zstd_safe::CCtx<'a>),
Borrowed(&'a mut zstd_safe::CCtx<'static>),
}

enum MaybeOwnedDCtx<'a> {
Owned(zstd_safe::DCtx<'a>),
Borrowed(&'a mut zstd_safe::DCtx<'static>),
}

#[cfg(test)]
mod tests {

Expand Down
13 changes: 13 additions & 0 deletions src/stream/read/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,19 @@ impl<R: BufRead> Decoder<'static, R> {
}
}
impl<'a, R: BufRead> Decoder<'a, R> {
/// Creates a new decoder which employs the provided context for deserialization.
pub fn with_context(
reader: R,
context: &'a mut zstd_safe::DCtx<'static>,
) -> Self {
Self {
reader: zio::Reader::new(
reader,
raw::Decoder::with_context(context),
),
}
}

/// Sets this `Decoder` to stop after the first frame.
///
/// By default, it keeps concatenating frames until EOF is reached.
Expand Down
13 changes: 13 additions & 0 deletions src/stream/write/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,19 @@ impl<W: Write> Encoder<'static, W> {
}

impl<'a, W: Write> Encoder<'a, W> {
/// Creates an encoder that uses the provided context to compress a stream.
pub fn with_context(
writer: W,
context: &'a mut zstd_safe::CCtx<'static>,
) -> Self {
Self {
writer: zio::Writer::new(
writer,
raw::Encoder::with_context(context),
),
}
}

/// Creates a new encoder, using an existing prepared `EncoderDictionary`.
///
/// (Provides better compression ratio for small files,
Expand Down
Loading