
Fixing uniform empty tensor handling #283

Merged · 1 commit merged into main on Jun 30, 2023
Conversation

@Narsil (Collaborator) commented Jun 27, 2023

What does this PR do?

Empty tensors were accepted when there was only one of them, but rejected when there were several (their storages appear to overlap).
This PR fixes the inconsistency by simply ignoring empty tensors, the same way meta tensors are already ignored.

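As a rough reproduction sketch of the bug being fixed (hypothetical, not taken from the PR itself, using the public safetensors.torch API): before this change, a state dict containing more than one empty tensor was rejected because all empty tensors report the same null storage pointer and therefore looked like overlapping shared storage.

import torch
from safetensors.torch import load_file, save_file

# Two distinct empty tensors: each has a zero-sized, never-allocated storage,
# so they previously looked like they shared the same (null) storage pointer.
state_dict = {
    "a": torch.tensor([]),
    "b": torch.zeros((2, 0)),
}

# Before this PR this call failed as soon as more than one empty tensor was
# present; with the fix, empty tensors are skipped during shared-storage
# detection (like meta tensors) and the file saves and reloads normally.
save_file(state_dict, "empty.safetensors")
reloaded = load_file("empty.safetensors")
print({k: tuple(v.shape) for k, v in reloaded.items()})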

@Narsil requested review from sgugger and thomasw21 on June 27, 2023 at 14:02
@sgugger (Contributor) left a comment


Makes sense, thanks for adding this!

        self.assertTrue(torch.equal(data["test2"], reloaded["test2"]))

    def test_disjoint_tensors_shared_storage(self):
        A = torch.zeros((10, 10))
Contributor:

Nit: not a fan of capitalized names for variables. Not a fan of one-letter names either ;-)

Narsil (author):

Do you have a better suggestion?

Contributor:

tensor, array, matrix

@@ -36,7 +36,7 @@ def storage_size(tensor: torch.Tensor) -> int:
 def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
     tensors = defaultdict(set)
     for k, v in state_dict.items():
-        if v.device != torch.device("meta"):
+        if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
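
A small sketch of how the new condition plays out (assuming storage_ptr and storage_size, defined earlier in this file, reduce to the untyped storage's data pointer and byte size on a recent PyTorch):

import torch

def is_checked_for_sharing(v: torch.Tensor) -> bool:
    # Stand-in for the patched condition: skip meta tensors and tensors whose
    # storage was never allocated, keep everything else.
    return (
        v.device != torch.device("meta")
        and v.untyped_storage().data_ptr() != 0
        and v.untyped_storage().nbytes() != 0
    )

print(is_checked_for_sharing(torch.zeros((2, 2), device="meta")))  # False: meta tensor, ignored
print(is_checked_for_sharing(torch.tensor([])))                    # False: never-allocated storage, ignored
print(is_checked_for_sharing(torch.zeros((2, 1))[:, :0]))          # True: empty view of real storage, still checked
print(is_checked_for_sharing(torch.zeros((2, 2))))                 # True: ordinary tensor
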
Contributor:

Can't you just test that tensor.numel() != 0? Nothing against the current implementation, but somehow I don't understand why you need to check the pointer.

Narsil (author):

Because it's the storage pointer, not the tensor, that I'm looking at.
You can have an empty tensor that is still backed by a shared storage, in which case I still want to yell.

Contributor:

Hmm, so why not check storage_size(v) != 0 only? You should accept any tensor that has empty storage, regardless of where it is actually stored, no?

Narsil (author):

No, I'm also trying to safeguard users against doing the wrong thing.
Saving empty tensors should worry most users.

Here the storage was never allocated in the first place, so it looks intentional.

@thomasw21 (Contributor) commented Jun 28, 2023:

I don't understand. Am I right in the following?

torch.tensor([]) # want to support
torch.zeros((2,0)) # want to support
torch.zeros((2,1))[:, :0] # no support

If so, storage_size(v) != 0 should work.

Contributor:

>>> l = [torch.tensor([]), torch.zeros((2,0)), torch.zeros((2,1))[:, :0]]
>>> for elt in l:
...   print(elt.untyped_storage().size())
... 
0
0
8

Narsil (author):

torch.zeros((2,1))[:, :0] will actually work if it's alone in the state dict. But if two tensors share the same storage, then a crash will happen. I don't really want to start checking whether the slices actually overlap or not.
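
A hedged sketch of the case being described here (hypothetical file name; the exact error type and message depend on the safetensors version): two zero-length views of the same allocated storage pass a numel() == 0 test, yet they still share storage and are still flagged at save time.

import torch
from safetensors.torch import save_file

base = torch.zeros((2, 1))

# Both views have numel() == 0, but they are backed by the same allocated
# storage, which is exactly what the storage_ptr/storage_size checks keep
# catching after this change.
state_dict = {
    "a": base[:, :0],
    "b": base[:1, :0],
}

try:
    save_file(state_dict, "shared_empty.safetensors")
except Exception as exc:  # shared-storage tensors are rejected by save_file
    print(f"refused to save: {exc}")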

Contributor:

Hmm, what I mean is: if your storage size is 0, there's no overlap or anything.

Narsil (author):

torch.zeros((2,1))[:, :0] does not have a 0-sized storage.

Contributor:

Okay, I think I may be misunderstanding something. Anyway, it's not very important.

@Narsil merged commit e967126 into main on Jun 30, 2023.
9 of 10 checks passed.
@Narsil deleted the harmonize_empty_tensor_handling branch on June 30, 2023 at 07:53.
Successfully merging this pull request may close these issues.

Cannot save more than 1 empty pyTorch tensor in a safetensors file.