Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ProvenanceTensor #543

Merged
merged 25 commits into from
Aug 24, 2021
Merged

ProvenanceTensor #543

merged 25 commits into from
Aug 24, 2021

Conversation

ordabayevy
Copy link
Member

@ordabayevy ordabayevy commented Jul 16, 2021

This is an implementation of Provenance Tracking (https://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.361.7132&rep=rep1&type=pdf) in Pytorch. The main idea is that provenance of the output tensor is the union of provenances of input tensors.

Tests:

  • unary and binary ops
  • simple indexing and advanced indexing with Vindex

@ordabayevy ordabayevy added the WIP label Jul 16, 2021
@ordabayevy
Copy link
Member Author

ordabayevy commented Jul 16, 2021

@eb8680 is this along the lines what you were suggesting? Is this new tensor type supposed to be wrapped by funsor.Tensor?

if kwargs is None:
kwargs = {}
meta = frozenset().union(
*tuple(a._metadata for a in args if hasattr(a, "_metadata"))
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Provenance of the output tensor is the union of provenances of input tensors.

@ordabayevy ordabayevy changed the title MetadataTensor ProvenanceTensor Jul 26, 2021
return super(ConstantMeta, cls).__call__(const_inputs, arg)


class Constant(Funsor, metaclass=ConstantMeta):
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting! It would probably be easiest for us to go over this PR and pyro-ppl/pyro#2893 over Zoom, but one thing that would help me beforehand is if you could add a docstring here explaining how Constant behaves differently from Delta wrt Reduce/Contraction/Integrate

def __repr__(self):
return "Provenance:\n{}\nTensor:\n{}".format(self._provenance, self._t)

def __torch_function__(self, func, types, args=(), kwargs=None):
Copy link
Member

@eb8680 eb8680 Aug 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ordabayevy now that you've had a chance to play around with __torch_function__, I'm curious about whether you think we should add a Funsor.__torch_function__ method and attempt to use it in Pyro more directly in lieu of the combination of ProvenanceTensor and to_data/to_funsor. I opened #546 to discuss.

Copy link
Member

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Implementation seems reasonable, and nicely separated from the rest of the code.

provenance = frozenset()
# extract ProvenanceTensor._t data from args
_args = []
for arg in args:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic is a bit convoluted. Maybe it could be simplified with some of the helpers in torch.overrides?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the helpers in torch.overrides?

That might be useful, I look more into torch.overrides functionality.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants