Fixing uniform empty tensor handling #283
Conversation
Makes sense, thanks for adding this!
```python
        self.assertTrue(torch.equal(data["test2"], reloaded["test2"]))

    def test_disjoint_tensors_shared_storage(self):
        A = torch.zeros((10, 10))
```
Nit: not a fan of capitalized names for variables. Not a fan of one-letter names either ;-)
Do you have a better suggestion?
`tensor`, `array`, `matrix`
```diff
@@ -36,7 +36,7 @@ def storage_size(tensor: torch.Tensor) -> int:
 def _find_shared_tensors(state_dict: Dict[str, torch.Tensor]) -> List[Set[str]]:
     tensors = defaultdict(set)
     for k, v in state_dict.items():
-        if v.device != torch.device("meta"):
+        if v.device != torch.device("meta") and storage_ptr(v) != 0 and storage_size(v) != 0:
```
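For reference, a rough sketch (our assumption, not the PR's actual code) of what the two helpers checked here do, using the same untyped-storage API that appears later in this thread:

```python
import torch

def storage_ptr(tensor: torch.Tensor) -> int:
    # Address of the underlying allocation; 0 when nothing was ever allocated.
    return tensor.untyped_storage().data_ptr()

def storage_size(tensor: torch.Tensor) -> int:
    # Size in bytes of the whole allocation, independent of the view's shape.
    return tensor.untyped_storage().size()
```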
Can't you just test that `tensor.numel() != 0`? Nothing against the current implementation, but somehow I don't understand why you need to check the pointer.
Because it's the storage pointer I'm looking at, not the tensor. You can have an empty tensor that is still backed by shared storage, in which case I still want to yell.
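A minimal illustration of such a tensor (variable names are ours):

```python
import torch

base = torch.zeros((2, 1))
view = base[:, :0]  # empty: numel() == 0

print(view.numel())  # 0
# Yet it is backed by the same (non-empty) storage as `base`:
print(view.untyped_storage().data_ptr() == base.untyped_storage().data_ptr())  # True
```

So a `numel() != 0` check would wave this tensor through even though its storage overlaps another tensor's.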
Hmm, so why not check `storage_size(v) != 0` only? You should accept any tensor that has empty storage regardless of where it is actually stored, no?
No, I'm also trying to safeguard users from doing the wrong thing. Saving empty tensors should worry most users. Here the storage was never allocated in the first place, so it looks intentional.
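Concretely, the distinction being drawn (a sketch; the null pointer for never-allocated storage is exactly what the `storage_ptr(v) != 0` check in the diff relies on):

```python
import torch

# Never allocated: empty storage, null data pointer.
print(torch.tensor([]).untyped_storage().data_ptr())  # 0

# An empty *view* of a real allocation still points at live memory.
print(torch.zeros((2, 1))[:, :0].untyped_storage().data_ptr())  # non-zero
```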
I don't understand. Am I right in the following?

```python
torch.tensor([])            # want to support
torch.zeros((2, 0))         # want to support
torch.zeros((2, 1))[:, :0]  # no support
```

If so, `storage_size(v) != 0` should work.
```python
>>> l = [torch.tensor([]), torch.zeros((2, 0)), torch.zeros((2, 1))[:, :0]]
>>> for elt in l:
...     print(elt.untyped_storage().size())
...
0
0
8
```
`torch.zeros((2, 1))[:, :0]` will actually work if it's alone in the state dict. But if two tensors share the same storage, a crash will happen. I don't really want to start checking whether the slices actually overlap or not.
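A sketch of the crash scenario being described, assuming the shared-storage check in `save_file` (the key names and file name are ours):

```python
import torch
from safetensors.torch import save_file

A = torch.zeros((10, 10))
# Two disjoint views, but both point into A's single storage.
data = {"top": A[:2], "bottom": A[4:]}

try:
    save_file(data, "disjoint.safetensors")
except RuntimeError as e:
    # Shared storage is rejected rather than silently duplicated on disk.
    print(e)
```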
Hmm, what I mean is: if your storage size is 0, there's no overlap or anything.
`torch.zeros((2, 1))[:, :0]` does not have a 0-sized storage.
Okay, I think I may be misunderstanding something. Anyway, it's not very important.
What does this PR do?
Empty tensors were accepted when there was a single one, but disallowed when there were many (because all of them appeared to share storage). This fixes it by simply ignoring empty tensors, the same way meta tensors are ignored.
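A before/after sketch of the fixed behavior (our example; per the description above, every never-allocated empty tensor used to be grouped as "sharing" the null storage pointer, so saving more than one of them failed):

```python
import torch
from safetensors.torch import save_file, load_file

# Several never-allocated empty tensors in one state dict.
data = {"a": torch.tensor([]), "b": torch.zeros((2, 0))}

# Before this PR: rejected as shared tensors (both have storage_ptr == 0).
# After this PR: empty tensors are skipped like meta tensors, so this works.
save_file(data, "empty.safetensors")
reloaded = load_file("empty.safetensors")
print(reloaded["a"].shape, reloaded["b"].shape)  # torch.Size([0]) torch.Size([2, 0])
```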