1
-
2
1
import jax
3
2
import keras
4
3
@@ -13,7 +12,14 @@ def train_step(self, *args, **kwargs):
13
12
def test_step (self , * args , ** kwargs ):
14
13
return self .stateless_test_step (* args , ** kwargs )
15
14
16
- def stateless_compute_metrics (self , trainable_variables : any , non_trainable_variables : any , metrics_variables : any , data : dict [str , Tensor ], stage : str = "training" ) -> (Tensor , tuple ):
15
+ def stateless_compute_metrics (
16
+ self ,
17
+ trainable_variables : any ,
18
+ non_trainable_variables : any ,
19
+ metrics_variables : any ,
20
+ data : dict [str , Tensor ],
21
+ stage : str = "training" ,
22
+ ) -> (Tensor , tuple ):
17
23
"""
18
24
Things we do for jax:
19
25
1. Accept trainable variables as the first argument
@@ -47,11 +53,13 @@ def stateless_train_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[s
47
53
48
54
grad_fn = jax .value_and_grad (self .stateless_compute_metrics , has_aux = True )
49
55
50
- (loss , aux ), grads = grad_fn (trainable_variables , non_trainable_variables , metrics_variables , data , stage = "training" )
56
+ (loss , aux ), grads = grad_fn (
57
+ trainable_variables , non_trainable_variables , metrics_variables , data , stage = "training"
58
+ )
51
59
metrics , non_trainable_variables , metrics_variables = aux
52
60
53
- trainable_variables , optimizer_variables = (
54
- self . optimizer . stateless_apply ( optimizer_variables , grads , trainable_variables )
61
+ trainable_variables , optimizer_variables = self . optimizer . stateless_apply (
62
+ optimizer_variables , grads , trainable_variables
55
63
)
56
64
57
65
metrics_variables = self ._update_loss (loss , metrics_variables )
@@ -62,7 +70,9 @@ def stateless_train_step(self, state: tuple, data: dict[str, Tensor]) -> (dict[s
62
70
def stateless_test_step (self , state : tuple , data : dict [str , Tensor ]) -> (dict [str , Tensor ], tuple ):
63
71
trainable_variables , non_trainable_variables , metrics_variables = state
64
72
65
- loss , aux = self .stateless_compute_metrics (trainable_variables , non_trainable_variables , metrics_variables , data , stage = "validation" )
73
+ loss , aux = self .stateless_compute_metrics (
74
+ trainable_variables , non_trainable_variables , metrics_variables , data , stage = "validation"
75
+ )
66
76
metrics , non_trainable_variables , metrics_variables = aux
67
77
68
78
metrics_variables = self ._update_loss (loss , metrics_variables )
0 commit comments