diff --git a/tiledb/ml/models/tensorflow_keras.py b/tiledb/ml/models/tensorflow_keras.py index d5546310..bb9d0755 100644 --- a/tiledb/ml/models/tensorflow_keras.py +++ b/tiledb/ml/models/tensorflow_keras.py @@ -279,6 +279,11 @@ def _serialize_optimizer_weights( assert self.artifact optimizer = self.artifact.optimizer if optimizer and not isinstance(optimizer, TFOptimizer): - optimizer_weights = tf.keras.backend.batch_get_value(optimizer.weights) + if hasattr(optimizer, "weights"): + optimizer_weights = tf.keras.backend.batch_get_value(optimizer.weights) + else: + optimizer_weights = [ # type:ignore + var.numpy() for var in optimizer.variables() + ] # type:ignore return pickle.dumps(optimizer_weights, protocol=4) return b""