Skip to content

Commit

Permalink
use Vec<u8> instead of Vec<MaybeUninit<u8>> as buffers
Browse files Browse the repository at this point in the history
Because according to the documentation,
transmuting `&mut [MaybeUninit<u8>]` to `&mut [u8]`
is unsound.

The allocation is only once and
the initialization process impact shouldn't
be notable.

Fixes #5
  • Loading branch information
bczhc committed Jun 12, 2023
1 parent e2ed199 commit 9de4568
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 63 deletions.
22 changes: 0 additions & 22 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,6 @@
//! ```
extern crate core;

use std::mem;
use std::mem::MaybeUninit;
use std::{ffi::CStr, io::Read};

use bytesize::ByteSize;
Expand Down Expand Up @@ -80,26 +78,6 @@ where
}
}

fn init_buffer(size: usize) -> Vec<MaybeUninit<u8>> {
let mut buffer = Vec::<MaybeUninit<u8>>::with_capacity(size);
unsafe {
buffer.set_len(size);
}
buffer
}

#[inline(always)]
unsafe fn transmute_uninitialized_buffer(buffer: &mut [MaybeUninit<u8>]) -> &mut [u8] {
mem::transmute(buffer)
}

fn uninit_copy_from_slice(src: &[u8], dst: &mut [MaybeUninit<u8>]) {
unsafe {
let transmute: &[MaybeUninit<u8>] = mem::transmute(src);
dst.copy_from_slice(transmute);
}
}

pub fn version() -> &'static str {
unsafe { CStr::from_ptr(libbzip3_sys::bz3_version()) }
.to_str()
Expand Down
18 changes: 7 additions & 11 deletions src/read.rs
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
//! Read-based BZip3 compressor and decompressor.

use std::io::{Cursor, ErrorKind, Read, Write};
use std::mem::MaybeUninit;
use std::{io, slice};

use byteorder::{ReadBytesExt, WriteBytesExt, LE};

use libbzip3_sys::{bz3_decode_block, bz3_encode_block};

use crate::errors::*;
use crate::{init_buffer, transmute_uninitialized_buffer, Bz3State, TryReadExact, MAGIC_NUMBER};
use crate::{Bz3State, TryReadExact, MAGIC_NUMBER};

pub struct Bz3Encoder<R>
where
Expand All @@ -18,7 +17,7 @@ where
state: Bz3State,
reader: R,
/// Temporary buffer for [`Read::read`]
buffer: Vec<MaybeUninit<u8>>,
buffer: Vec<u8>,
buffer_pos: usize,
buffer_len: usize,
block_size: usize,
Expand All @@ -42,16 +41,13 @@ where
let state = Bz3State::new(block_size)?;

let buffer_size = block_size + block_size / 50 + 32 + MAGIC_NUMBER.len() + 4;
let mut buffer = Vec::<MaybeUninit<u8>>::with_capacity(buffer_size);
unsafe {
buffer.set_len(buffer_size);
}
let mut buffer = vec![0_u8; buffer_size];

let mut header = Cursor::new(Vec::new());
header.write_all(MAGIC_NUMBER).unwrap();
header.write_i32::<LE>(block_size as i32).unwrap();
for x in header.get_ref().iter().enumerate() {
buffer[x.0] = MaybeUninit::new(*x.1);
buffer[x.0] = *x.1;
}

Ok(Self {
Expand Down Expand Up @@ -163,7 +159,7 @@ where
state: Bz3State,
reader: R,
/// Temporary buffer for [`Read::read`]
buffer: Vec<MaybeUninit<u8>>,
buffer: Vec<u8>,
buffer_pos: usize,
buffer_len: usize,
block_size: usize,
Expand Down Expand Up @@ -198,7 +194,7 @@ where
let state = Bz3State::new(block_size)?;

let buffer_size = block_size + block_size / 50 + 32;
let buffer = init_buffer(buffer_size);
let buffer = vec![0_u8; buffer_size];

Ok(Self {
state,
Expand Down Expand Up @@ -248,7 +244,7 @@ where

debug_assert!(self.buffer.len() >= read_size);

let buffer = unsafe { transmute_uninitialized_buffer(&mut self.buffer) };
let buffer = &mut self.buffer;
self.reader.read_exact(&mut buffer[..(new_size as usize)])?;

unsafe {
Expand Down
47 changes: 17 additions & 30 deletions src/write.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,23 @@
//! Write-based BZip3 compressor and decompressor.

use std::io::{Cursor, Read, Write};
use std::mem::{size_of, MaybeUninit};
use std::mem::size_of;
use std::{io, mem};

use byteorder::{ReadBytesExt, WriteBytesExt, LE};

use libbzip3_sys::{bz3_decode_block, bz3_encode_block};

use crate::errors::*;
use crate::{
init_buffer, transmute_uninitialized_buffer, uninit_copy_from_slice, Bz3State, MAGIC_NUMBER,
};
use crate::{Bz3State, MAGIC_NUMBER};

pub struct Bz3Encoder<W>
where
W: Write,
{
writer: W,
state: Bz3State,
buffer: Vec<MaybeUninit<u8>>,
buffer: Vec<u8>,
buffer_pos: usize,
block_size: usize,
}
Expand All @@ -43,7 +41,7 @@ where
writer.write_all(header.get_ref())?;

let buffer_size = block_size + block_size / 50 + 32;
let buffer = init_buffer(buffer_size as usize);
let buffer = vec![0; buffer_size as usize];

Ok(Self {
writer,
Expand All @@ -59,11 +57,8 @@ where
let data_size = self.buffer_pos;
debug_assert!(data_size <= self.block_size);
unsafe {
let new_size = bz3_encode_block(
self.state.raw,
transmute_uninitialized_buffer(&mut self.buffer).as_mut_ptr(),
data_size as i32,
);
let new_size =
bz3_encode_block(self.state.raw, self.buffer.as_mut_ptr(), data_size as i32);
if new_size == -1 {
return Err(Error::ProcessBlock(self.state.error().into()));
}
Expand Down Expand Up @@ -99,10 +94,8 @@ where
write_size = remaining_size;
}

uninit_copy_from_slice(
&buf[..write_size],
&mut self.buffer[self.buffer_pos..(self.buffer_pos + write_size)],
);
self.buffer[self.buffer_pos..(self.buffer_pos + write_size)]
.copy_from_slice(&buf[..write_size]);

self.buffer_pos += write_size;

Expand Down Expand Up @@ -131,7 +124,7 @@ where
{
writer: W,
state: Option<Bz3State>,
buffer: Vec<MaybeUninit<u8>>,
buffer: Vec<u8>,
buffer_pos: usize,
header_len: usize,
block_header_buf: [u8; BLOCK_HEADER_SIZE], /* (i32, i32) */
Expand Down Expand Up @@ -166,7 +159,7 @@ where
Self {
state: None, /* here can't get the block size */
writer,
buffer: init_buffer(header_len), /* need header data to initialize first */
buffer: vec![0_u8; header_len], /* need header data to initialize first */
buffer_pos: 0,
header_len,
block_header_buf: [0_u8; 8],
Expand All @@ -176,8 +169,7 @@ where
}

fn initialize(&mut self) -> Result<()> {
let buffer = unsafe { transmute_uninitialized_buffer(&mut self.buffer) };
let mut cursor = Cursor::new(buffer);
let mut cursor = Cursor::new(&mut self.buffer);
let mut magic = [0_u8; MAGIC_NUMBER.len()];
cursor.read_exact(&mut magic).unwrap();
if &magic != MAGIC_NUMBER {
Expand All @@ -186,7 +178,7 @@ where
let block_size = cursor.read_i32::<LE>().unwrap();
// reinitialize the buffer
let buffer_size = block_size + block_size / 50 + 32;
self.buffer = init_buffer(buffer_size as usize);
self.buffer = vec![0_u8; buffer_size as usize];
self.state = Some(Bz3State::new(block_size as usize)?);
Ok(())
}
Expand All @@ -197,10 +189,9 @@ where

let Some(block_header) = &self.block_header else { unreachable!() };
unsafe {
let buffer = transmute_uninitialized_buffer(&mut self.buffer);
let result = bz3_decode_block(
state.raw,
buffer.as_mut_ptr(),
self.buffer.as_mut_ptr(),
block_header.new_size,
block_header.read_size,
);
Expand Down Expand Up @@ -228,10 +219,8 @@ where
if write_size > needed_size {
write_size = needed_size;
}
uninit_copy_from_slice(
&buf[..write_size],
&mut self.buffer[self.buffer_pos..(self.buffer_pos + write_size)],
);
self.buffer[self.buffer_pos..(self.buffer_pos + write_size)]
.copy_from_slice(&buf[..write_size]);
self.buffer_pos += write_size;
if self.buffer_pos == self.header_len {
// header prepared
Expand Down Expand Up @@ -269,10 +258,8 @@ where
if write_size > needed_size {
write_size = needed_size;
}
uninit_copy_from_slice(
&buf[..write_size],
&mut self.buffer[self.buffer_pos..(self.buffer_pos + write_size)],
);
self.buffer[self.buffer_pos..(self.buffer_pos + write_size)]
.copy_from_slice(&buf[..write_size]);
self.buffer_pos += write_size;
if self.buffer_pos == block_header.new_size as usize {
self.decompress_block().map_err(Error::into_io_error)?;
Expand Down

0 comments on commit 9de4568

Please sign in to comment.