Skip to content

Commit

Permalink
Support tensorflow 2.13
Browse files Browse the repository at this point in the history
  • Loading branch information
Shelnutt2 committed Jul 24, 2023
1 parent d7d67fd commit 2be225a
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 3 deletions.
11 changes: 10 additions & 1 deletion tests/models/test_tensorflow_keras_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,16 @@
get_small_sequential_mlp,
)
except ImportError:
from keras.testing_utils import get_small_functional_mlp, get_small_sequential_mlp
try:
from keras.testing_utils import (
get_small_functional_mlp,
get_small_sequential_mlp,
)
except ImportError:
from keras.src.testing_utils import (
get_small_functional_mlp,
get_small_sequential_mlp,
)

# Suppress all Tensorflow messages
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
Expand Down
13 changes: 11 additions & 2 deletions tiledb/ml/models/tensorflow_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,33 @@

from ._base import Meta, TileDBArtifact, Timestamp

FunctionalOrSequential = (keras.models.Functional, keras.models.Sequential)
keras_major, keras_minor, keras_patch = keras.__version__.split(".")
FunctionalOrSequential = keras.models.Sequential
# Handle keras <=v2.10
if int(keras_major) <= 2 and int(keras_minor) <= 10:
FunctionalOrSequential = (keras.models.Functional, keras.models.Sequential)
TFOptimizer = keras.optimizers.TFOptimizer
get_json_type = keras.saving.saved_model.json_utils.get_json_type
preprocess_weights_for_loading = (
keras.saving.hdf5_format.preprocess_weights_for_loading
)
saving_utils = keras.saving.saving_utils
# Handle keras >=v2.11
else:
elif int(keras_major) <= 2 and int(keras_minor) <= 12:
FunctionalOrSequential = (keras.models.Functional, keras.models.Sequential)
TFOptimizer = tf.keras.optimizers.legacy.Optimizer
get_json_type = keras.saving.legacy.saved_model.json_utils.get_json_type
preprocess_weights_for_loading = (
keras.saving.legacy.hdf5_format.preprocess_weights_for_loading
)
saving_utils = keras.saving.legacy.saving_utils
else:
TFOptimizer = tf.keras.optimizers.legacy.Optimizer
get_json_type = keras.src.saving.legacy.saved_model.json_utils.get_json_type
preprocess_weights_for_loading = (
keras.src.saving.legacy.hdf5_format.preprocess_weights_for_loading
)
saving_utils = keras.src.saving.legacy.saving_utils


class TensorflowKerasTileDBModel(TileDBArtifact[tf.keras.Model]):
Expand Down

0 comments on commit 2be225a

Please sign in to comment.