Skip to content

Commit

Permalink
InteractionFeatures input transform (#2560)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: #2560

InteractionFeatures input transform to compute first-order interactions between inputs.

Used for feature importance work in conjunction with (warped) linear models.

Reviewed By: sdaulton

Differential Revision: D63673008

fbshipit-source-id: 1e57431b92f55cf25b711d5a35b8606f77a58c69
  • Loading branch information
Carl Hvarfner authored and facebook-github-bot committed Oct 2, 2024
1 parent 6d327b9 commit 68faeff
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 1 deletion.
26 changes: 25 additions & 1 deletion botorch/models/transforms/input.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import torch
from botorch.exceptions.errors import BotorchTensorDimensionError
from botorch.exceptions.warnings import UserInputWarning
from botorch.models.transforms.utils import subset_transform
from botorch.models.transforms.utils import interaction_features, subset_transform
from botorch.models.utils import fantasize
from botorch.utils.rounding import approximate_round, OneHotArgmaxSTE, RoundSTE
from gpytorch import Module as GPyTorchModule
Expand Down Expand Up @@ -1370,6 +1370,30 @@ def transform(self, X: Tensor) -> Tensor:
return appended_X.view(*X.shape[:-2], -1, appended_X.shape[-1])


class InteractionFeatures(AppendFeatures):
r"""A transform that appends the first-order interaction terms $x_i * x_j, i < j$,
for all or a subset of the input variables."""

def __init__(
self,
indices: Optional[list[int]] = None,
) -> None:
r"""Initializes the InteractionFeatures transform.
Args:
indices: Indices of the subset of dimensions to compute interaction
features on.
"""

super().__init__(
f=interaction_features,
indices=indices,
transform_on_train=True,
transform_on_eval=True,
transform_on_fantasize=True,
)


class FilterFeatures(InputTransform, Module):
r"""A transform that filters the input with a given set of features indices.
Expand Down
15 changes: 15 additions & 0 deletions botorch/models/transforms/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,18 @@ def f(self, X: Tensor) -> Tensor:
return Y

return f


def interaction_features(X: Tensor) -> Tensor:
"""Computes the interaction features between the inputs.
Args:
X: A `batch_shape x q x d`-dim tensor of inputs.
indices: The input dimensions to generate interaction features for.
Returns:
A `n x q x 1 x (d * (d-1) / 2))`-dim tensor of interaction features.
"""
dim = X.shape[-1]
row_idcs, col_idcs = torch.triu_indices(dim, dim, offset=1)
return (X.unsqueeze(-1) @ X.unsqueeze(-2))[..., row_idcs, col_idcs].unsqueeze(-2)
40 changes: 40 additions & 0 deletions test/models/transforms/test_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
InputPerturbation,
InputStandardize,
InputTransform,
InteractionFeatures,
Log10,
Normalize,
OneHotToNumeric,
Expand Down Expand Up @@ -1629,6 +1630,45 @@ def f2(x: Tensor, n_f: int = 1) -> Tensor:
self.assertEqual(X_transformed.shape, torch.Size((10, 4)))


class TestInteractionFeatures(BotorchTestCase):
def test_interaction_features(self) -> None:
interaction = InteractionFeatures()
X = torch.arange(6, dtype=torch.float).reshape(2, 3)
X_tf = interaction(X)
self.assertTrue(X_tf.shape, torch.Size([2, 6]))

# test correct output values
self.assertTrue(
torch.equal(
X_tf,
torch.tensor(
[[0.0, 1.0, 2.0, 0.0, 0.0, 2.0], [3.0, 4.0, 5.0, 12.0, 15.0, 20.0]]
),
)
)
X = torch.arange(6, dtype=torch.float).reshape(2, 3)
interaction = InteractionFeatures(indices=[1, 2])
X_tf = interaction(X)
self.assertTrue(
torch.equal(
X_tf,
torch.tensor([[0.0, 1.0, 2.0, 2.0], [3.0, 4.0, 5.0, 20.0]]),
)
)
with self.assertRaisesRegex(
IndexError, "index 2 is out of bounds for dimension 0 with size 2"
):
interaction(torch.rand(4, 2))

# test batched evaluation
interaction = InteractionFeatures()
X_tf = interaction(torch.rand(4, 2, 4))
self.assertTrue(X_tf.shape, torch.Size([4, 2, 10]))

X_tf = interaction(torch.rand(5, 7, 3, 4))
self.assertTrue(X_tf.shape, torch.Size([5, 7, 3, 10]))


class TestFilterFeatures(BotorchTestCase):
def test_filter_features(self) -> None:
with self.assertRaises(ValueError):
Expand Down

0 comments on commit 68faeff

Please sign in to comment.