Skip to content

Commit 126b855

Browse files
hertschuh and tensorflower-gardener
authored and committed
Call super().build(input_shape) instead of self.built = True in all Keras layers.
Within `build()`, some Keras layers were calling `super().build(input_shape)` while some were calling `self.built = True`. This would result in a different config when serializing, whereby layers doing `self.built = True` would not have a `build_config`. This change makes it consistent between all the layers as well as consistent with Keras 3. Note that some layers need to call `Layer.build(self, input_shape)` directly to bypass some classes' `build()` but still populate the information for the `build_config`. PiperOrigin-RevId: 678454186
1 parent ed3f017 commit 126b855

File tree

3 files changed

+20
-11
lines changed

3 files changed

+20
-11
lines changed

tensorflow_model_optimization/python/core/clustering/keras/cluster_test.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -669,10 +669,13 @@ def testStripClusteringSequentialModel(self):
669669

670670
self.assertEqual(self._count_clustered_layers(stripped_model), 0)
671671
model_config = model.get_config()
672+
stripped_model_config = stripped_model.get_config()
673+
# New serialization format includes `build_config` in all layers
672674
for layer in model_config['layers']:
673-
# New serialization format includes `build_config` in wrapper
674675
layer.pop('build_config', None)
675-
self.assertEqual(model_config, stripped_model.get_config())
676+
for layer in stripped_model_config['layers']:
677+
layer.pop('build_config', None)
678+
self.assertEqual(model_config, stripped_model_config)
676679

677680
def testClusterStrippingFunctionalModel(self):
678681
"""Verifies that stripping the clustering wrappers from a functional model produces the expected config."""

tensorflow_model_optimization/python/core/common/keras/compression/internal/optimize.py

+6
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,9 @@ def create_layer_for_training(layer, algorithm):
404404
if compressible_weights:
405405
# Set pretrained weight values.
406406
wrapped_layer.build(input_shape)
407+
# Clear `_build_input_shape` so that `build()` is not immediately called
408+
# during reloading. We want the wrapper layer to initiate `build()`.
409+
wrapped_layer.layer._build_input_shape = None # pylint: disable=protected-access
407410
training_weights = _map_to_training_weights(
408411
algorithm,
409412
layer,
@@ -445,6 +448,9 @@ def create_layer_for_inference(layer: _TrainingWrapper, algorithm):
445448
layer_for_inference = _InferenceWrapper(cloned_layer, algorithm,
446449
compressible_training_tensors)
447450
layer_for_inference.build(input_shape)
451+
# Clear `_build_input_shape` so that `build()` is not immediately called
452+
# during reloading. We want the wrapper layer to initiate `build()`.
453+
layer_for_inference.layer._build_input_shape = None # pylint: disable=protected-access
448454

449455
if layer.get_weights():
450456
# Set weights of layer for inference according to what was trained.

tensorflow_model_optimization/python/core/sparsity/keras/pruning_wrapper_test.py

+9-9
Original file line numberDiff line numberDiff line change
@@ -122,17 +122,17 @@ def testPruneModel(self):
122122

123123
# Test serialization
124124
model_config = self.model.get_config()
125-
for layer in model_config['layers']:
126-
layer.pop('build_config', None)
127-
self.assertEqual(
125+
pruned_model = self.model.__class__.from_config(
128126
model_config,
129-
self.model.__class__.from_config(
130-
self.model.get_config(),
131-
custom_objects={
132-
'PruneLowMagnitude': pruning_wrapper.PruneLowMagnitude
133-
},
134-
).get_config(),
127+
custom_objects={'PruneLowMagnitude': pruning_wrapper.PruneLowMagnitude},
135128
)
129+
pruned_model_config = pruned_model.get_config()
130+
# New serialization format includes `build_config` in all layers
131+
for layer in model_config['layers']:
132+
layer.pop('build_config', None)
133+
for layer in pruned_model_config['layers']:
134+
layer.pop('build_config', None)
135+
self.assertEqual(model_config, pruned_model_config)
136136

137137
def testCustomLayerNonPrunable(self):
138138
layer = CustomLayer(input_dim=16, output_dim=32)

0 commit comments

Comments
 (0)