Skip to content

Commit 0749c98

Browse files
authored
update from keras-core to keras 3 (#970)
1 parent cb145a5 commit 0749c98

File tree

8 files changed

+37
-97
lines changed

8 files changed

+37
-97
lines changed

.github/workflows/actions.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ jobs:
3232
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
3333
- name: Install dependencies
3434
run: |
35-
pip install -e ".[tensorflow-cpu,tests]" --progress-bar off --upgrade
35+
pip install -e ".[tests]" --progress-bar off --upgrade
36+
pip install tensorflow-cpu==2.14.0
3637
pip install jax[cpu]
3738
- name: Test with pytest
3839
run: |
@@ -64,10 +65,11 @@ jobs:
6465
key: ${{ runner.os }}-pip-${{ hashFiles('setup.py') }}
6566
- name: Install dependencies
6667
run: |
67-
pip install -e ".[tensorflow-cpu,tests]" --progress-bar off --upgrade
68+
pip install -e ".[tests]" --progress-bar off --upgrade
6869
pip install torch>=2.0.1+cpu --progress-bar off
6970
pip install jax[cpu] --progress-bar off
70-
pip install -e ".[tensorflow-cpu,tests]" --progress-bar off --upgrade
71+
pip uninstall keras -y
72+
pip install tf-nightly==2.16.0.dev20231103
7173
- name: Test with pytest
7274
env:
7375
KERAS_BACKEND: ${{ matrix.backend }}

keras_tuner/applications/efficientnet.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
from keras_tuner.api_export import keras_tuner_export
1818
from keras_tuner.backend import keras
19-
from keras_tuner.backend import ops
2019
from keras_tuner.backend.keras import layers
2120
from keras_tuner.engine import hypermodel
2221

@@ -121,12 +120,7 @@ def build(self, hp):
121120
)
122121
img_size = EFFICIENTNET_IMG_SIZE[version]
123122

124-
x = ops.image.resize(
125-
x,
126-
(img_size, img_size),
127-
interpolation="bilinear",
128-
data_format=keras.backend.image_data_format(),
129-
)
123+
x = layers.Resizing(img_size, img_size, interpolation="bilinear")(x)
130124
efficientnet_model = EFFICIENTNET_MODELS[version](
131125
include_top=False, input_tensor=x
132126
)

keras_tuner/backend/__init__.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,12 @@
1515
Keras backend module.
1616
1717
This module adds a temporarily Keras API surface that is fully under KerasTuner
18-
control. This allows us to switch between `keras_core` and `tf.keras`, as well
18+
control. This allows us to switch between Keras 3 and `tf.keras`, as well
1919
as add shims to support older version of `tf.keras`.
2020
2121
- `config`: check which backend is being run.
22-
- `keras`: The full `keras` API (via `keras_core` or `tf.keras`).
23-
- `ops`: `keras_core.ops`, always tf backed if using `tf.keras`.
24-
- `random`: `keras_core.random`, always tf backed if using `tf.keras`.
22+
- `keras`: The full `keras` API (via `keras` 3 or `tf.keras`).
23+
- `ops`: `keras.ops`, always tf backed if using `tf.keras`.
2524
"""
2625

2726
from keras_tuner.backend import config

keras_tuner/backend/config.py

Lines changed: 7 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -12,63 +12,22 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import json
16-
import os
15+
import keras
1716

18-
import keras_core
1917

20-
_MULTI_BACKEND = False
18+
def _multi_backend():
19+
version_fn = getattr(keras, "version", None)
20+
return version_fn and version_fn().startswith("3.")
2121

22-
# Set Keras base dir path given KERAS_HOME env variable, if applicable.
23-
# Otherwise either ~/.keras or /tmp.
24-
if "KERAS_HOME" in os.environ:
25-
_keras_dir = os.environ.get("KERAS_HOME")
26-
else:
27-
_keras_base_dir = os.path.expanduser("~")
28-
if not os.access(_keras_base_dir, os.W_OK):
29-
_keras_base_dir = "/tmp"
30-
_keras_dir = os.path.join(_keras_base_dir, ".keras")
3122

32-
# Attempt to read KerasTuner config file.
33-
_config_path = os.path.expanduser(os.path.join(_keras_dir, "keras_tuner.json"))
34-
if os.path.exists(_config_path):
35-
try:
36-
with open(_config_path) as f:
37-
_config = json.load(f)
38-
except ValueError:
39-
_config = {}
40-
_MULTI_BACKEND = _config.get("multi_backend", _MULTI_BACKEND)
41-
42-
# Save config file, if possible.
43-
if not os.path.exists(_keras_dir):
44-
try:
45-
os.makedirs(_keras_dir)
46-
except OSError:
47-
# Except permission denied and potential race conditions
48-
# in multi-threaded environments.
49-
pass
50-
51-
if not os.path.exists(_config_path):
52-
_config = {
53-
"multi_backend": _MULTI_BACKEND,
54-
}
55-
try:
56-
with open(_config_path, "w") as f:
57-
f.write(json.dumps(_config, indent=4))
58-
except IOError:
59-
# Except permission denied.
60-
pass
61-
62-
# Use keras-core if KERAS_BACKEND is set in the environment.
63-
if "KERAS_BACKEND" in os.environ and os.environ["KERAS_BACKEND"]:
64-
_MULTI_BACKEND = True
23+
_MULTI_BACKEND = _multi_backend()
6524

6625

6726
def multi_backend():
68-
"""Check if keras_core is enabled."""
27+
"""Check if multi-backend keras is enabled."""
6928
return _MULTI_BACKEND
7029

7130

7231
def backend():
7332
"""Check the backend framework."""
74-
return "tensorflow" if not multi_backend() else keras_core.config.backend()
33+
return "tensorflow" if not multi_backend() else keras.config.backend()

keras_tuner/backend/keras.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
from keras_tuner.backend.config import multi_backend
1818

1919
if multi_backend():
20-
from keras_core import * # noqa: F403, F401
20+
from keras import * # noqa: F403, F401
2121
else:
2222
import tensorflow as tf
2323
from tensorflow.keras import * # noqa: F403, F401

keras_tuner/backend/ops.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,41 +12,27 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15-
import keras_core
16-
1715
from keras_tuner.backend.config import multi_backend
1816

1917
if multi_backend():
20-
from keras_core.src.ops import * # noqa: F403, F401
18+
from keras.src.ops import * # noqa: F403, F401
2119
else:
22-
from keras_core.src.backend.tensorflow import * # noqa: F403, F401
23-
from keras_core.src.backend.tensorflow.core import * # noqa: F403, F401
24-
from keras_core.src.backend.tensorflow.math import * # noqa: F403, F401
25-
from keras_core.src.backend.tensorflow.nn import * # noqa: F403, F401
26-
from keras_core.src.backend.tensorflow.numpy import * # noqa: F403, F401
27-
28-
29-
if keras_core.config.backend() == "tensorflow" or not multi_backend():
20+
import tensorflow as tf
21+
from tensorflow import cast # noqa: F403, F401
3022

31-
def take_along_axis(x, indices, axis=None):
32-
import tensorflow as tf
23+
def any_symbolic_tensors(args=None, kwargs=None):
24+
args = args or ()
25+
kwargs = kwargs or {}
26+
for x in tf.nest.flatten((args, kwargs)):
27+
if "KerasTensor" in x.__class__.__name__:
28+
return True
29+
return False
3330

34-
# TODO: move this workaround for dynamic shapes into keras-core.
35-
if axis < 0:
36-
axis = axis + indices.shape.rank
37-
# If all shapes after axis are 1, squeeze them off and use tf.gather.
38-
# tf.gather plays nicer with dynamic shapes in compiled functions.
39-
leftover_axes = list(range(axis + 1, indices.shape.rank))
40-
static_shape = indices.shape.as_list()
41-
squeezable = True
42-
for i in leftover_axes:
43-
if static_shape[i] != 1:
44-
squeezable = False
45-
if squeezable:
46-
if leftover_axes:
47-
indices = tf.squeeze(indices, leftover_axes)
48-
return tf.gather(x, indices, batch_dims=axis)
49-
# Otherwise, fall back to the tfnp call.
50-
return keras_core.src.backend.tensorflow.numpy.take_along_axis(
51-
x, indices, axis=axis
31+
def shape(x):
32+
if any_symbolic_tensors((x,)):
33+
return x.shape
34+
dynamic = tf.shape(x)
35+
static = x.shape.as_list()
36+
return tuple(
37+
dynamic[i] if s is None else s for i, s in enumerate(static)
5238
)

keras_tuner/backend/random.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,6 @@
1515
from keras_tuner.backend.config import multi_backend
1616

1717
if multi_backend():
18-
from keras_core.random import * # noqa: F403, F401
18+
from keras.random import * # noqa: F403, F401
1919
else:
20-
from keras_core.src.backend.tensorflow.random import * # noqa: F403, F401
20+
from tensorflow.random import * # noqa: F403, F401

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_version(rel_path):
5252
license="Apache License 2.0",
5353
version=VERSION,
5454
install_requires=[
55-
"keras-core",
55+
"keras",
5656
"packaging",
5757
"requests",
5858
"kt-legacy",

0 commit comments

Comments
 (0)