diff --git a/nfp/layers/graph_layers.py b/nfp/layers/graph_layers.py index bc9de52..f442482 100644 --- a/nfp/layers/graph_layers.py +++ b/nfp/layers/graph_layers.py @@ -146,8 +146,6 @@ def __init__(self, units, num_heads, **kwargs): def build(self, input_shape): super().build(input_shape) dense_units = self.units * self.num_heads # N*H - # if self.use_global: - # assert input_shape[-1][-1] == dense_units self.query_layer = layers.Dense(self.num_heads, name='query') self.value_layer = layers.Dense(dense_units, name='value') @@ -160,7 +158,6 @@ def transpose_scores(self, input_tensor): def call(self, inputs, mask=None): if not self.use_global: atom_state, bond_state, connectivity = inputs - global_state = None else: atom_state, bond_state, connectivity, global_state = inputs