@@ -57,6 +57,7 @@ class ResultExecution:
57
57
summary : str
58
58
op_type : str
59
59
name : str
60
+ value : Optional [Any ] = None
60
61
61
62
def __len__ (self ) -> int :
62
63
return 6
@@ -122,9 +123,11 @@ def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str:
122
123
else :
123
124
value2 = value .flatten ().astype (np .float64 )
124
125
value4 = value2 .reshape ((4 , - 1 )).sum (axis = 1 )
125
- value4i = value4 .astype (np .int64 ) % modulo
126
- s = "" .join ([chr (65 + i ) for i in value4i ])
127
- return s
126
+ value4 = np .where (np .abs (value4 ) < 1e10 , value4 , np .nan )
127
+ s = []
128
+ for v in value4 :
129
+ s .append ("?" if np .isnan (v ) else (chr (65 + int (v ) % modulo )))
130
+ return "" .join (s )
128
131
129
132
130
133
class YieldEvaluator :
@@ -228,6 +231,7 @@ def enumerate_summarized(
228
231
output_names : Optional [List [str ]] = None ,
229
232
feed_inputs : Optional [Dict [str , Any ]] = None ,
230
233
raise_exc : bool = True ,
234
+ keep_tensor : bool = False ,
231
235
) -> Iterator [ResultExecution ]:
232
236
"""
233
237
Executes the onnx model and enumerate intermediate results without their names.
@@ -236,17 +240,40 @@ def enumerate_summarized(
236
240
:param feed_inputs: dictionary `{ input name: input value }`
237
241
:param raise_exc: raises an exception if the execution fails or stop
238
242
where it is
243
+ :param keep_tensor:keep the tensor in order to compute precise distances
239
244
:return: iterator on ResultExecution
240
245
"""
241
246
for kind , name , value , op_type in self .enumerate_results (
242
247
output_names , feed_inputs , raise_exc = raise_exc
243
248
):
244
249
summary = make_summary (value )
245
250
yield ResultExecution (
246
- kind , value .dtype , value .shape , summary , op_type , name
251
+ kind ,
252
+ value .dtype ,
253
+ value .shape ,
254
+ summary ,
255
+ op_type ,
256
+ name ,
257
+ value = value if keep_tensor else None ,
247
258
)
248
259
249
260
261
+ def discrepancies (
262
+ expected : np .ndarray , value : np .ndarray , eps : float = 1e-7
263
+ ) -> Dict [str , float ]:
264
+ """
265
+ Computes absolute error and relative error between two matrices.
266
+ """
267
+ assert (
268
+ expected .size == value .size
269
+ ), f"Incompatible shapes v1.shape={ expected .shape } , v2.shape={ value .shape } "
270
+ expected = expected .ravel ().astype (np .float32 )
271
+ value = value .ravel ().astype (np .float32 )
272
+ diff = np .abs (expected - value )
273
+ rel = diff / (np .abs (expected ) + eps )
274
+ return dict (aerr = float (diff .max ()), rerr = float (rel .max ()))
275
+
276
+
250
277
class DistanceExecution :
251
278
"""
252
279
Computes a distance between two results.
@@ -403,6 +430,14 @@ def to_str(
403
430
d = self .distance_pair (d1 , d2 )
404
431
symbol = "=" if d == 0 else "~"
405
432
line = f"{ symbol } | { _align (str (d1 ), column_size )} | { _align (str (d2 ), column_size )} "
433
+ if (
434
+ d1 .value is not None
435
+ and d2 .value is not None
436
+ and d1 .value .size == d2 .value .size
437
+ ):
438
+ disc = discrepancies (d1 .value , d2 .value )
439
+ a , r = disc ["aerr" ], disc ["rerr" ]
440
+ line += f" | a={ a :.3f} r={ r :.3f} "
406
441
elif i == last [0 ]:
407
442
d2 = s2 [j ]
408
443
line = (
@@ -551,6 +586,7 @@ def compare_onnx_execution(
551
586
verbose : int = 0 ,
552
587
raise_exc : bool = True ,
553
588
mode : str = "execute" ,
589
+ keep_tensor : bool = False ,
554
590
) -> Tuple [List [ResultExecution ], List [ResultExecution ], List [Tuple [int , int ]]]:
555
591
"""
556
592
Compares the execution of two onnx models.
@@ -566,6 +602,7 @@ def compare_onnx_execution(
566
602
:param raise_exc: raise exception if the execution fails or stop at the error
567
603
:param mode: the model should be executed but the function can be executed
568
604
but the comparison may append on nodes only
605
+ :param keep_tensor: keeps the tensor in order to compute a precise distance
569
606
:return: four results, a sequence of results for the first model and the second model,
570
607
the alignment between the two, DistanceExecution
571
608
"""
@@ -589,15 +626,15 @@ def compare_onnx_execution(
589
626
print ("[compare_onnx_execution] execute first model" )
590
627
res1 = list (
591
628
YieldEvaluator (model1 ).enumerate_summarized (
592
- None , feeds1 , raise_exc = raise_exc
629
+ None , feeds1 , raise_exc = raise_exc , keep_tensor = keep_tensor
593
630
)
594
631
)
595
632
if verbose :
596
633
print (f"[compare_onnx_execution] got { len (res1 )} results" )
597
634
print ("[compare_onnx_execution] execute second model" )
598
635
res2 = list (
599
636
YieldEvaluator (model2 ).enumerate_summarized (
600
- None , feeds2 , raise_exc = raise_exc
637
+ None , feeds2 , raise_exc = raise_exc , keep_tensor = keep_tensor
601
638
)
602
639
)
603
640
elif mode == "nodes" :
0 commit comments