diff --git a/README.md b/README.md index 697cef18..240ed6dd 100644 --- a/README.md +++ b/README.md @@ -75,9 +75,13 @@ with safe_open("model.safetensors", framework="pt", device="cpu") as f: ### Format -- 8 bytes: `N`, a u64 int, containing the size of the header -- N bytes: a JSON utf-8 string representing the header. - - The header is a dict like `{"TENSOR_NAME": {"dtype": "F16", "shape": [1, 16, 256], "data_offsets": [BEGIN, END]}, "NEXT_TENSOR_NAME": {...}, ...}`, where offsets point to the tensor data relative to the beginning of the byte buffer, with `BEGIN` as the starting offset and `END` as the one-past offset (so total tensor byte size = `END - BEGIN`). +- 8 bytes: `N`, an unsigned little-endian 64-bit integer, containing the size of the header +- N bytes: a JSON UTF-8 string representing the header. + - The header data MUST begin with a `{` character (0x7B). + - The header data MAY be trailing padded with whitespace (0x20). + - The header is a dict like `{"TENSOR_NAME": {"dtype": "F16", "shape": [1, 16, 256], "data_offsets": [BEGIN, END]}, "NEXT_TENSOR_NAME": {...}, ...}`, + - `data_offsets` point to the tensor data relative to the beginning of the byte buffer (i.e. not an absolute position in the file), + with `BEGIN` as the starting offset and `END` as the one-past offset (so total tensor byte size = `END - BEGIN`). - A special key `__metadata__` is allowed to contain free form string-to-string map. Arbitrary JSON is not allowed, all values must be strings. - Rest of the file: byte-buffer. diff --git a/safetensors/README.md b/safetensors/README.md index 697cef18..240ed6dd 100644 --- a/safetensors/README.md +++ b/safetensors/README.md @@ -75,9 +75,13 @@ with safe_open("model.safetensors", framework="pt", device="cpu") as f: ### Format -- 8 bytes: `N`, a u64 int, containing the size of the header -- N bytes: a JSON utf-8 string representing the header. - - The header is a dict like `{"TENSOR_NAME": {"dtype": "F16", "shape": [1, 16, 256], "data_offsets": [BEGIN, END]}, "NEXT_TENSOR_NAME": {...}, ...}`, where offsets point to the tensor data relative to the beginning of the byte buffer, with `BEGIN` as the starting offset and `END` as the one-past offset (so total tensor byte size = `END - BEGIN`). +- 8 bytes: `N`, an unsigned little-endian 64-bit integer, containing the size of the header +- N bytes: a JSON UTF-8 string representing the header. + - The header data MUST begin with a `{` character (0x7B). + - The header data MAY be trailing padded with whitespace (0x20). + - The header is a dict like `{"TENSOR_NAME": {"dtype": "F16", "shape": [1, 16, 256], "data_offsets": [BEGIN, END]}, "NEXT_TENSOR_NAME": {...}, ...}`, + - `data_offsets` point to the tensor data relative to the beginning of the byte buffer (i.e. not an absolute position in the file), + with `BEGIN` as the starting offset and `END` as the one-past offset (so total tensor byte size = `END - BEGIN`). - A special key `__metadata__` is allowed to contain free form string-to-string map. Arbitrary JSON is not allowed, all values must be strings. - Rest of the file: byte-buffer. diff --git a/safetensors/src/lib.rs b/safetensors/src/lib.rs index a66b1e7a..6a3fe70c 100644 --- a/safetensors/src/lib.rs +++ b/safetensors/src/lib.rs @@ -52,9 +52,13 @@ //! //!## Format //! -//! - 8 bytes: `N`, a u64 int, containing the size of the header -//! - N bytes: a JSON utf-8 string representing the header. -//! - The header is a dict like `{"TENSOR_NAME": {"dtype": "F16", "shape": [1, 16, 256], "data_offsets": [BEGIN, END]}, "NEXT_TENSOR_NAME": {...}, ...}`, where offsets point to the tensor data relative to the beginning of the byte buffer, with `BEGIN` as the starting offset and `END` as the one-past offset (so total tensor byte size = `END - BEGIN`). +//! - 8 bytes: `N`, an unsigned little-endian 64-bit integer, containing the size of the header +//! - N bytes: a JSON UTF-8 string representing the header. +//! - The header data MUST begin with a `{` character (0x7B). +//! - The header data MAY be trailing padded with whitespace (0x20). +//! - The header is a dict like `{"TENSOR_NAME": {"dtype": "F16", "shape": [1, 16, 256], "data_offsets": [BEGIN, END]}, "NEXT_TENSOR_NAME": {...}, ...}`, +//! - `data_offsets` point to the tensor data relative to the beginning of the byte buffer (i.e. not an absolute position in the file), +//! with `BEGIN` as the starting offset and `END` as the one-past offset (so total tensor byte size = `END - BEGIN`). //! - A special key `__metadata__` is allowed to contain free form string-to-string map. Arbitrary JSON is not allowed, all values must be strings. //! - Rest of the file: byte-buffer. //! diff --git a/safetensors/src/tensor.rs b/safetensors/src/tensor.rs index a2c2dce2..a05f49a0 100644 --- a/safetensors/src/tensor.rs +++ b/safetensors/src/tensor.rs @@ -15,6 +15,8 @@ const MAX_HEADER_SIZE: usize = 100_000_000; pub enum SafeTensorError { /// The header is an invalid UTF-8 string and cannot be read. InvalidHeader, + /// The header's first byte is not the expected `{`. + InvalidHeaderStart, /// The header does contain a valid string, but it is not valid JSON. InvalidHeaderDeserialization, /// The header is large than 100Mo which is considered too large (Might evolve in the future). @@ -302,6 +304,10 @@ impl<'data> SafeTensors<'data> { } let string = std::str::from_utf8(&buffer[8..stop]).map_err(|_| SafeTensorError::InvalidHeader)?; + // Assert the string starts with { + if !string.starts_with('{') { + return Err(SafeTensorError::InvalidHeaderStart); + } let metadata: Metadata = serde_json::from_str(string) .map_err(|_| SafeTensorError::InvalidHeaderDeserialization)?; let buffer_end = metadata.validate()?; @@ -1087,6 +1093,25 @@ mod tests { } } + #[test] + /// Test that the JSON header may be trailing-padded with JSON whitespace characters. + fn test_whitespace_padded_header() { + let serialized = b"\x06\x00\x00\x00\x00\x00\x00\x00{}\x0D\x20\x09\x0A"; + let loaded = SafeTensors::deserialize(serialized).unwrap(); + assert_eq!(loaded.len(), 0); + } + + #[test] + /// Test that the JSON header must begin with a `{` character. + fn test_whitespace_start_padded_header_is_not_allowed() { + let serialized = b"\x06\x00\x00\x00\x00\x00\x00\x00\x09\x0A{}\x0D\x20"; + match SafeTensors::deserialize(serialized) { + Err(SafeTensorError::InvalidHeaderStart) => { + // Correct error + }, + _ => panic!("This should not be able to be deserialized"), + } + } #[test] fn test_zero_sized_tensor() { let serialized = b"<\x00\x00\x00\x00\x00\x00\x00{\"test\":{\"dtype\":\"I32\",\"shape\":[2,0],\"data_offsets\":[0, 0]}}";