diff --git a/dig/threedgraph/method/dimenetpp/dimenetpp.py b/dig/threedgraph/method/dimenetpp/dimenetpp.py index 6a2ef500..59e0fd2b 100644 --- a/dig/threedgraph/method/dimenetpp/dimenetpp.py +++ b/dig/threedgraph/method/dimenetpp/dimenetpp.py @@ -131,7 +131,7 @@ def reset_parameters(self): glorot_orthogonal(self.lin_rbf.weight, scale=2.0) def forward(self, x, emb, idx_kj, idx_ji): - rbf0, sbf = emb + rbf0, sbf = emb ## e_{SBF}(d,tau) = sqrt(2/c^3 j_{l+1}^{3}.... x1,_ = x x_ji = self.act(self.lin_ji(x1))