
Compute attributions w.r.t. the predicted logit, not the predicted loss
sarahwie committed Dec 22, 2020
1 parent 1fff7ca commit 9c43e97
Showing 2 changed files with 5 additions and 5 deletions.
@@ -17,7 +17,7 @@ class SimpleGradient(SaliencyInterpreter):

def saliency_interpret_from_json(self, inputs: JsonDict) -> JsonDict:
"""
- Interprets the model's prediction for inputs. Gets the gradients of the loss with respect
+ Interprets the model's prediction for inputs. Gets the gradients of the logits with respect
to the input and returns those gradients normalized and sanitized.
"""
labeled_instances = self.predictor.json_to_labeled_instances(inputs)
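
For context, a rough sketch (not AllenNLP's exact implementation) of what "gradients ... normalized" can mean for a simple-gradient saliency map: score each token by the dot product of its embedding with the gradient flowing back from the predicted logit, then normalize the absolute scores so they sum to one. The function name, shapes, and tensors below are illustrative only.

```python
import torch

def normalize_token_scores(embeddings: torch.Tensor, grads: torch.Tensor) -> torch.Tensor:
    # embeddings, grads: [num_tokens, embedding_dim]
    # Per-token score = |embedding . gradient|, normalized to sum to 1.
    scores = (embeddings * grads).sum(dim=-1).abs()
    return scores / scores.sum()

emb = torch.randn(4, 8)    # 4 tokens, 8-dim embeddings (made up)
grad = torch.randn(4, 8)   # gradient of the predicted logit w.r.t. each embedding
print(normalize_token_scores(emb, grad))  # 4 saliency scores summing to 1
```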
8 changes: 4 additions & 4 deletions allennlp/predictors/predictor.py
@@ -73,7 +73,7 @@ def json_to_labeled_instances(self, inputs: JsonDict) -> List[Instance]:

def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
"""
- Gets the gradients of the loss with respect to the model inputs.
+ Gets the gradients of the logits with respect to the model inputs.
# Parameters
@@ -91,7 +91,7 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict
Takes a `JsonDict` representing the inputs of the model and converts
them to [`Instances`](../data/instance.md)), sends these through
the model [`forward`](../models/model.md#forward) function after registering hooks on the embedding
- layer of the model. Calls `backward` on the loss and then removes the
+ layer of the model. Calls `backward` on the logits and then removes the
hooks.
"""
# set requires_grad to true for all parameters, but save original values to
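
One way to realize the hook-and-backward pattern this docstring describes (a hypothetical sketch, not the library's code): capture the embedding layer's output in a forward hook, attach a tensor hook to record its gradient, call `backward()` on the predicted logit, and remove the hook afterwards. The toy model and shapes below are invented for illustration.

```python
import torch
import torch.nn as nn

embedding = nn.Embedding(100, 16)
classifier = nn.Linear(16, 3)
embedding_gradients = []

def forward_hook(module, inputs, output):
    # Record the gradient w.r.t. the embedding output when backward() runs.
    output.register_hook(lambda grad: embedding_gradients.append(grad.detach()))

handle = embedding.register_forward_hook(forward_hook)

token_ids = torch.tensor([[3, 7, 11]])
logits = classifier(embedding(token_ids).mean(dim=1))  # [1, 3]
predicted_logit = logits.squeeze(0)[int(torch.argmax(logits))]
predicted_logit.backward()

handle.remove()  # hooks are removed once the gradients have been captured
print(embedding_gradients[0].shape)  # torch.Size([1, 3, 16])
```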
@@ -113,13 +113,13 @@ def get_gradients(self, instances: List[Instance]) -> Tuple[Dict[str, Any], Dict
self._model.forward(**dataset_tensor_dict) # type: ignore
)

- loss = outputs["loss"]
+ predicted_logit = outputs["logits"].squeeze(0)[int(torch.argmax(outputs['probs']))]
# Zero gradients.
# NOTE: this is actually more efficient than calling `self._model.zero_grad()`
# because it avoids a read op when the gradients are first updated below.
for p in self._model.parameters():
p.grad = None
- loss.backward()
+ predicted_logit.backward()

for hook in hooks:
hook.remove()
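
To make the change concrete, here is a small hypothetical illustration of the new selection logic: `backward()` now starts from the single logit of the predicted class, so the gradients reflect only that class's score rather than the loss (which would require a gold label and mix in the other classes through the softmax). The tensors below are invented for the example.

```python
import torch

# Invented outputs dict standing in for the model's return value.
logits = torch.tensor([[1.2, -0.3, 0.7]], requires_grad=True)
outputs = {"logits": logits, "probs": torch.softmax(logits, dim=-1)}

# Same selection as in the diff: the logit of the argmax class, batch size 1.
predicted_logit = outputs["logits"].squeeze(0)[int(torch.argmax(outputs["probs"]))]
predicted_logit.backward()

# The gradient is 1 for the predicted class and 0 elsewhere.
print(logits.grad)  # tensor([[1., 0., 0.]])
```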
