[fix] Fix the activation checkpointing when using SwiGLUPackedFusedOp
The if conditional on the x.requires_grad state (used to switch behavior between inference and training modes) also changes the behavior of the forward() recomputation, which breaks activation checkpointing: during the recomputation phase x is detached with requires_grad == False, so a different number of tensors is saved by save_for_backward().
warpuv committed Oct 17, 2024
1 parent 46d2823 commit 4829d7e
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions xformers/csrc/swiglu/swiglu_packedw.cpp
@@ -211,7 +211,12 @@ at::Tensor swiglu_packedw_cuda(
    const std::optional<at::Tensor> b1b2,
    const at::Tensor w3,
    const std::optional<at::Tensor> b3) {
  if (torch::GradMode::is_enabled()) {
    return SwiGLUPackedWeights::apply(x, w1w2, b1b2, w3, b3);
  } else {
    return SwiGLUPackedWeights::forward(
        /* ctx */ nullptr, x, w1w2, b1b2, w3, b3);
  }
}
} // namespace

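For context, the sketch below is a rough Python approximation, not part of this commit; the names PackedOp, packed_op_pre_fix, and packed_op_post_fix are made up for illustration. It contrasts the two dispatch patterns: the pre-fix branch keys off x.requires_grad, which activation checkpointing can flip between the original forward and the recomputation, while the fixed branch keys off the global grad mode, mirroring the torch::GradMode::is_enabled() check in the C++ hunk above.

# Rough Python approximation of the C++ dispatch in swiglu_packedw.cpp.
# PackedOp stands in for SwiGLUPackedWeights; a plain matmul keeps the
# sketch self-contained instead of the fused packed-SwiGLU kernel.
import torch


class PackedOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        if ctx is not None:
            # Tensors are saved only on the autograd path. The number of
            # tensors saved here is what checkpointing compares between
            # the original forward and the recomputation.
            ctx.save_for_backward(x, w)
        return x @ w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        return grad_out @ w.t(), x.t() @ grad_out


def packed_op_pre_fix(x, w):
    # Pre-fix dispatch: keyed on x.requires_grad. Per the commit message,
    # the recomputation phase can see x detached (requires_grad == False),
    # so it takes the other branch and saves a different number of tensors
    # than the original forward did, breaking activation checkpointing.
    if x.requires_grad:
        return PackedOp.apply(x, w)
    return PackedOp.forward(None, x, w)


def packed_op_post_fix(x, w):
    # Post-fix dispatch: keyed on the global grad mode (the Python analogue
    # of torch::GradMode::is_enabled()), which is the same in the original
    # forward and in the recomputation, so the autograd bookkeeping stays
    # consistent under activation checkpointing.
    if torch.is_grad_enabled():
        return PackedOp.apply(x, w)
    return PackedOp.forward(None, x, w)

Keying the dispatch on grad mode keeps inference calls (e.g. under torch.no_grad()) on the cheap non-autograd path, while training-time behavior no longer depends on whether the input tensor arrives detached from a checkpoint recomputation.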
