Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ProvenanceTensor #543

Merged
merged 25 commits into from
Aug 24, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -53,13 +53,16 @@ ifeq (${FUNSOR_BACKEND}, torch)
python examples/adam.py --num-steps=21
@echo PASS
else ifeq (${FUNSOR_BACKEND}, jax)
pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi --ignore=test/test_distribution.py --ignore=test/test_distribution_generic.py
pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi \
--ignore=test/test_distribution.py --ignore=test/test_distribution_generic.py \
--ignore=test/torch
pytest -v -n auto test/test_distribution.py
pytest -v -n auto test/test_distribution_generic.py
@echo PASS
else
# default backend
pytest -v -n auto --ignore=test/examples --ignore=test/pyro --ignore=test/pyroapi
pytest -v -n auto --ignore=test/examples --ignore=test/pyro \
--ignore=test/pyroapi --ignore=test/torch
@echo PASS
endif

Expand Down
65 changes: 65 additions & 0 deletions funsor/torch/provenance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import torch


class ProvenanceTensor(torch.Tensor):
    """
    Provenance tracking implementation in PyTorch.

    Provenance of the output tensor is the union of provenances of input tensors.

    :param data: Value to wrap. If it is already a ``ProvenanceTensor`` it is
        unwrapped and its provenance is merged into ``provenance``.
    :param frozenset provenance: Labels to attach to ``data``. When this is
        empty, construction returns ``data`` itself instead of a wrapper.
    :param kwargs: Accepted for signature compatibility and otherwise ignored.
    :raises TypeError: If a wrapper is initialized with a ``provenance`` that
        is not a ``frozenset``.
    """

    def __new__(cls, data, provenance=frozenset(), **kwargs):
        # Nothing to track: hand the input back untouched.
        if not provenance:
            return data
        # Returning an instance of ``cls`` makes Python invoke ``__init__``
        # with the same arguments, so initialization is not repeated here.
        return torch.Tensor.__new__(cls)

    def __init__(self, data, provenance=frozenset(), **kwargs):
        # ``**kwargs`` mirrors ``__new__`` so that extra keyword arguments do
        # not fail on the implicit ``__init__`` call.
        if not isinstance(provenance, frozenset):
            # Raise explicitly: an ``assert`` would be stripped under ``-O``.
            raise TypeError(
                "provenance must be a frozenset, got {}".format(
                    type(provenance).__name__
                )
            )
        # This also runs when ``__new__`` returned an existing
        # ProvenanceTensor; merging the stored provenance keeps that
        # re-initialization idempotent.
        if isinstance(data, ProvenanceTensor):
            provenance = provenance | data._provenance
            data = data._t
        self._t = data
        self._provenance = provenance

    def __repr__(self):
        return "Provenance:\n{}\nTensor:\n{}".format(self._provenance, self._t)

def __torch_function__(self, func, types, args=(), kwargs=None):
Copy link
Member

@eb8680 eb8680 Aug 3, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@ordabayevy now that you've had a chance to play around with __torch_function__, I'm curious about whether you think we should add a Funsor.__torch_function__ method and attempt to use it in Pyro more directly in lieu of the combination of ProvenanceTensor and to_data/to_funsor. I opened #546 to discuss.

if kwargs is None:
kwargs = {}
# collect provenance information from args
provenance = frozenset()
# extract ProvenanceTensor._t data from args
_args = []
for arg in args:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic is a bit convoluted. Maybe it could be simplified with some of the helpers in torch.overrides?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the helpers in torch.overrides?

That might be useful; I'll look more into the torch.overrides functionality.

if isinstance(arg, ProvenanceTensor):
provenance |= arg._provenance
_args.append(arg._t)
elif isinstance(arg, tuple):
_arg = []
for a in arg:
if isinstance(a, ProvenanceTensor):
provenance |= a._provenance
_arg.append(a._t)
else:
_arg.append(a)
_args.append(tuple(_arg))
else:
_args.append(arg)
ret = func(*_args, **kwargs)
if isinstance(ret, torch.Tensor):
return ProvenanceTensor(ret, provenance=provenance)
if isinstance(ret, tuple):
_ret = []
for r in ret:
if isinstance(r, torch.Tensor):
_ret.append(ProvenanceTensor(r, provenance=provenance))
else:
_ret.append(r)
return tuple(_ret)
return ret
103 changes: 103 additions & 0 deletions test/torch/test_provenance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
# Copyright Contributors to the Pyro project.
# SPDX-License-Identifier: Apache-2.0

import pytest
import torch
from pyro.ops.indexing import Vindex

from funsor.torch.provenance import ProvenanceTensor


@pytest.mark.parametrize("op", ["log", "exp", "long"])
@pytest.mark.parametrize(
    "data,provenance",
    [
        (torch.tensor([1]), "ab"),
        (torch.tensor([1]), "a"),
    ],
)
def test_unary(op, data, provenance):
    """Check that a unary tensor method keeps the input's provenance unchanged."""
    labels = frozenset(provenance)
    tracked = ProvenanceTensor(data, labels)

    # ``op`` names a tensor method that is called with no arguments.
    result = getattr(tracked, op)()

    assert result._provenance == labels


@pytest.mark.parametrize("data1,provenance1", [(torch.tensor([1]), "a")])
@pytest.mark.parametrize(
    "data2,provenance2",
    [
        (torch.tensor([2]), "b"),
        (torch.tensor([2]), ""),
        (2, ""),
    ],
)
def test_binary_add(data1, provenance1, data2, provenance2):
    """Check that ``torch.add`` yields the union of both operands' provenance."""
    lhs = ProvenanceTensor(data1, frozenset(provenance1))
    # An empty label string leaves the second operand unwrapped.
    rhs = ProvenanceTensor(data2, frozenset(provenance2)) if provenance2 else data2

    total = torch.add(lhs, rhs)

    assert total._provenance == frozenset(provenance1 + provenance2)


@pytest.mark.parametrize(
    "data1,provenance1",
    [
        (torch.tensor([0, 1]), "a"),
        (torch.tensor([0, 1]), ""),
    ],
)
@pytest.mark.parametrize(
    "data2,provenance2",
    [
        (torch.tensor([0]), "b"),
        (torch.tensor([1]), ""),
    ],
)
def test_indexing(data1, provenance1, data2, provenance2):
    """Check that ``source[index]`` combines provenance from source and index."""
    # Operands with an empty label string are used unwrapped.
    source = ProvenanceTensor(data1, frozenset(provenance1)) if provenance1 else data1
    index = ProvenanceTensor(data2, frozenset(provenance2)) if provenance2 else data2

    picked = source[index]

    # A result without a ``_provenance`` attribute counts as empty provenance.
    found = getattr(picked, "_provenance", frozenset())
    assert found == frozenset(provenance1 + provenance2)


@pytest.mark.parametrize(
    "data1,provenance1",
    [
        (torch.tensor([[0, 1], [2, 3]]), "a"),
        (torch.tensor([[0, 1], [2, 3]]), ""),
    ],
)
@pytest.mark.parametrize(
    "data2,provenance2",
    [
        (torch.tensor([0.0, 1.0]), "b"),
        (torch.tensor([0.0, 1.0]), ""),
    ],
)
@pytest.mark.parametrize(
    "data3,provenance3",
    [
        (torch.tensor([0, 1]), "c"),
        (torch.tensor([0, 1]), ""),
    ],
)
def test_vindex(data1, provenance1, data2, provenance2, data3, provenance3):
    """Check that ``Vindex`` indexing unions provenance of source and both indices."""
    # Operands with an empty label string are used unwrapped.
    source = ProvenanceTensor(data1, frozenset(provenance1)) if provenance1 else data1
    rows = ProvenanceTensor(data2, frozenset(provenance2)) if provenance2 else data2
    cols = ProvenanceTensor(data3, frozenset(provenance3)) if provenance3 else data3

    # The float row index is cast to long and given a trailing dimension
    # before it is used as an index.
    row_index = rows.long().unsqueeze(-1)
    result = Vindex(source)[row_index, cols]

    # A result without a ``_provenance`` attribute counts as empty provenance.
    found = getattr(result, "_provenance", frozenset())
    assert found == frozenset(provenance1 + provenance2 + provenance3)