Fixing boolean + numpy > 1.20 (#326)
* Fixing boolean + numpy > 1.20

* Adding bool numpy test.
Narsil committed Aug 17, 2023
1 parent 698dd6e commit f4a6df0
Showing 2 changed files with 48 additions and 15 deletions.
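For context: the failure this commit addresses comes from NumPy deprecating the `np.bool` alias in 1.20 and removing it in 1.24, which breaks resolving the boolean dtype by attribute lookup on the numpy module. A minimal Python sketch of that NumPy behaviour (illustration only, not code from this repository; assumes NumPy >= 1.24):

import numpy as np

# `np.bool` was deprecated in NumPy 1.20 and removed in NumPy 1.24,
# so looking the dtype up as an attribute of the numpy module now fails.
try:
    np.bool  # DeprecationWarning on 1.20-1.23, AttributeError on >= 1.24
except AttributeError as err:
    print(err)

# The builtin `bool` is the documented replacement and works as a dtype.
arr = np.asarray(False, dtype=bool)
print(arr.dtype)  # bool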
50 changes: 35 additions & 15 deletions bindings/python/src/lib.rs
@@ -473,8 +473,8 @@ impl Open {
             Storage::TorchStorage(storage) => {
                 Python::with_gil(|py| -> PyResult<PyObject> {
                     let torch = get_module(py, &TORCH_MODULE)?;
-                    let dtype: PyObject = get_pydtype(torch, info.dtype)?;
-                    let torch_uint8: PyObject = get_pydtype(torch, Dtype::U8)?;
+                    let dtype: PyObject = get_pydtype(torch, info.dtype, false)?;
+                    let torch_uint8: PyObject = get_pydtype(torch, Dtype::U8, false)?;
                     let kwargs = [(intern!(py, "dtype"), torch_uint8)].into_py_dict(py);
                     let view_kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict(py);
                     let shape = info.shape.to_vec();
@@ -504,7 +504,7 @@ impl Open {
                     let inplace_kwargs =
                         [(intern!(py, "inplace"), false.into_py(py))].into_py_dict(py);
                     if info.dtype == Dtype::BF16 {
-                        let torch_f16: PyObject = get_pydtype(torch, Dtype::F16)?;
+                        let torch_f16: PyObject = get_pydtype(torch, Dtype::F16, false)?;
                         tensor = tensor.getattr(intern!(py, "to"))?.call(
                             (),
                             Some([(intern!(py, "dtype"), torch_f16)].into_py_dict(py)),
@@ -519,7 +519,7 @@
                 tensor = torch.getattr(intern!(py, "from_numpy"))?.call1((numpy,))?;
 
                 if info.dtype == Dtype::BF16 {
-                    let torch_bf16: PyObject = get_pydtype(torch, Dtype::BF16)?;
+                    let torch_bf16: PyObject = get_pydtype(torch, Dtype::BF16, false)?;
                     tensor = tensor.getattr(intern!(py, "to"))?.call(
                         (),
                         Some([(intern!(py, "dtype"), torch_bf16)].into_py_dict(py)),
@@ -796,8 +796,8 @@ impl PySafeSlice {
             }
             Storage::TorchStorage(storage) => Python::with_gil(|py| -> PyResult<PyObject> {
                 let torch = get_module(py, &TORCH_MODULE)?;
-                let dtype: PyObject = get_pydtype(torch, self.info.dtype)?;
-                let torch_uint8: PyObject = get_pydtype(torch, Dtype::U8)?;
+                let dtype: PyObject = get_pydtype(torch, self.info.dtype, false)?;
+                let torch_uint8: PyObject = get_pydtype(torch, Dtype::U8, false)?;
                 let kwargs = [(intern!(py, "dtype"), torch_uint8)].into_py_dict(py);
                 let view_kwargs = [(intern!(py, "dtype"), dtype)].into_py_dict(py);
                 let shape = self.info.shape.to_vec();
@@ -873,13 +873,27 @@ fn create_tensor(
     device: &Device,
 ) -> PyResult<PyObject> {
     Python::with_gil(|py| -> PyResult<PyObject> {
-        let module: &PyModule = match framework {
-            Framework::Pytorch => TORCH_MODULE.get(py),
-            _ => NUMPY_MODULE.get(py),
-        }
-        .ok_or_else(|| SafetensorError::new_err(format!("Could not find module {framework:?}",)))?
-        .as_ref(py);
-        let dtype: PyObject = get_pydtype(module, dtype)?;
+        let (module, is_numpy): (&PyModule, bool) = match framework {
+            Framework::Pytorch => (
+                TORCH_MODULE
+                    .get(py)
+                    .ok_or_else(|| {
+                        SafetensorError::new_err(format!("Could not find module {framework:?}",))
+                    })?
+                    .as_ref(py),
+                false,
+            ),
+            _ => (
+                NUMPY_MODULE
+                    .get(py)
+                    .ok_or_else(|| {
+                        SafetensorError::new_err(format!("Could not find module {framework:?}",))
+                    })?
+                    .as_ref(py),
+                true,
+            ),
+        };
+        let dtype: PyObject = get_pydtype(module, dtype, is_numpy)?;
         let count: usize = shape.iter().product();
         let shape = shape.to_vec();
         let shape: PyObject = shape.into_py(py);
@@ -939,7 +953,7 @@ fn create_tensor(
     })
 }
 
-fn get_pydtype(module: &PyModule, dtype: Dtype) -> PyResult<PyObject> {
+fn get_pydtype(module: &PyModule, dtype: Dtype, is_numpy: bool) -> PyResult<PyObject> {
     Python::with_gil(|py| {
         let dtype: PyObject = match dtype {
             Dtype::F64 => module.getattr(intern!(py, "float64"))?.into(),
@@ -954,7 +968,13 @@ fn get_pydtype(module: &PyModule, dtype: Dtype) -> PyResult<PyObject> {
             Dtype::I16 => module.getattr(intern!(py, "int16"))?.into(),
             Dtype::U8 => module.getattr(intern!(py, "uint8"))?.into(),
             Dtype::I8 => module.getattr(intern!(py, "int8"))?.into(),
-            Dtype::BOOL => module.getattr(intern!(py, "bool"))?.into(),
+            Dtype::BOOL => {
+                if is_numpy {
+                    py.import("builtins")?.getattr(intern!(py, "bool"))?.into()
+                } else {
+                    module.getattr(intern!(py, "bool"))?.into()
+                }
+            }
             dtype => {
                 return Err(SafetensorError::new_err(format!(
                     "Dtype not understood: {dtype:?}"
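In short, the change above threads an `is_numpy` flag into `get_pydtype` so that `Dtype::BOOL` resolves to Python's builtin `bool` when the target framework is numpy, while the torch path keeps looking `bool` up on the framework module. Roughly the same lookup expressed in Python, as a sketch for illustration (not code from the bindings; the helper name is hypothetical):

import builtins

import numpy as np
import torch

def bool_pydtype(is_numpy: bool):
    # Mirrors the patched Dtype::BOOL branch: the builtin for numpy,
    # an attribute lookup on the framework module for torch.
    return builtins.bool if is_numpy else torch.bool

# NumPy accepts the builtin as a dtype, so boolean tensors load again.
np.empty((2, 2), dtype=bool_pydtype(True))
torch.empty((2, 2), dtype=bool_pydtype(False))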
13 changes: 13 additions & 0 deletions bindings/python/tests/test_simple.py
@@ -198,6 +198,19 @@ def test_numpy_example(self):
         loaded = load(out)
         self.assertTensorEqual(tensors, loaded, np.allclose)
 
+    def test_numpy_bool(self):
+        tensors = {"a": np.asarray(False)}
+
+        save_file(tensors, "./out_bool.safetensors")
+        out = save(tensors)
+
+        # Now loading
+        loaded = load_file("./out_bool.safetensors")
+        self.assertTensorEqual(tensors, loaded, np.allclose)
+
+        loaded = load(out)
+        self.assertTensorEqual(tensors, loaded, np.allclose)
+
     def test_torch_example(self):
         tensors = {
             "a": torch.zeros((2, 2)),
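The new test round-trips a boolean array through both the file and in-memory paths. A user-level equivalent, assuming the helpers used above come from `safetensors.numpy` (the imports are not shown in this diff):

import numpy as np
from safetensors.numpy import load_file, save_file

tensors = {"mask": np.asarray([True, False, True])}
save_file(tensors, "./bool_example.safetensors")

loaded = load_file("./bool_example.safetensors")
assert loaded["mask"].dtype == np.dtype(bool)
assert (loaded["mask"] == tensors["mask"]).all()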
