Skip to content

Commit dcc2ddd

Browse files
authored
Add discrepancies when comparing the execution of two models (#79)
* update requirements * add discrepancies figures * fix command line * doc
1 parent a906010 commit dcc2ddd

File tree

4 files changed

+81
-9
lines changed

4 files changed

+81
-9
lines changed

CHANGELOGS.rst

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ Change Logs
55
+++++
66

77
* :pr:`77`: supports ConcatOfShape and Slice with the light API
8-
* :pr:`76`: add a mode to compare models without execution
8+
* :pr:`76`, :pr:`79`: add a mode to compare models without execution
99
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
1010
* :pr:`71`: adds tools to compare two onnx graphs
1111
* :pr:`61`: adds function to plot onnx model as graphs

_unittests/ut_reference/test_evaluator_yield.py

+25
Original file line numberDiff line numberDiff line change
@@ -462,6 +462,31 @@ def test_compare_execution(self):
462462
self.assertIn("CAAA Constant", text)
463463
self.assertEqual(len(align), 5)
464464

465+
def test_compare_execution_discrepancies(self):
466+
m1 = parse_model(
467+
"""
468+
<ir_version: 8, opset_import: [ "": 18]>
469+
agraph (float[N] x) => (float[N] z) {
470+
two = Constant <value_float=2.0> ()
471+
four = Add(two, two)
472+
z = Mul(x, x)
473+
}"""
474+
)
475+
m2 = parse_model(
476+
"""
477+
<ir_version: 8, opset_import: [ "": 18]>
478+
agraph (float[N] x) => (float[N] z) {
479+
two = Constant <value_float=2.0> ()
480+
z = Mul(x, x)
481+
}"""
482+
)
483+
res1, res2, align, dc = compare_onnx_execution(m1, m2, keep_tensor=True)
484+
text = dc.to_str(res1, res2, align)
485+
print(text)
486+
self.assertIn("CAAA Constant", text)
487+
self.assertIn("| a=", text)
488+
self.assertIn(" r=", text)
489+
465490
def test_no_execution(self):
466491
model = make_model(
467492
make_graph(

onnx_array_api/_command_lines_parser.py

+12-2
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,15 @@ def get_parser_compare() -> ArgumentParser:
106106
parser.add_argument(
107107
"-c",
108108
"--column-size",
109-
default=50,
109+
default=60,
110110
help="column size when displaying the results",
111111
)
112+
parser.add_argument(
113+
"-d",
114+
"--discrepancies",
115+
default=0,
116+
help="show precise discrepancies when mode is execution",
117+
)
112118
return parser
113119

114120

@@ -120,7 +126,11 @@ def _cmd_compare(argv: List[Any]):
120126
onx1 = onnx.load(args.model1)
121127
onx2 = onnx.load(args.model2)
122128
res1, res2, align, dc = compare_onnx_execution(
123-
onx1, onx2, verbose=args.verbose, mode=args.mode
129+
onx1,
130+
onx2,
131+
verbose=args.verbose,
132+
mode=args.mode,
133+
keep_tensor=args.discrepancies in (1, "1", "True", True),
124134
)
125135
text = dc.to_str(res1, res2, align, column_size=int(args.column_size))
126136
print(text)

onnx_array_api/reference/evaluator_yield.py

+43-6
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ class ResultExecution:
5757
summary: str
5858
op_type: str
5959
name: str
60+
value: Optional[Any] = None
6061

6162
def __len__(self) -> int:
6263
return 6
@@ -122,9 +123,11 @@ def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str:
122123
else:
123124
value2 = value.flatten().astype(np.float64)
124125
value4 = value2.reshape((4, -1)).sum(axis=1)
125-
value4i = value4.astype(np.int64) % modulo
126-
s = "".join([chr(65 + i) for i in value4i])
127-
return s
126+
value4 = np.where(np.abs(value4) < 1e10, value4, np.nan)
127+
s = []
128+
for v in value4:
129+
s.append("?" if np.isnan(v) else (chr(65 + int(v) % modulo)))
130+
return "".join(s)
128131

129132

130133
class YieldEvaluator:
@@ -228,6 +231,7 @@ def enumerate_summarized(
228231
output_names: Optional[List[str]] = None,
229232
feed_inputs: Optional[Dict[str, Any]] = None,
230233
raise_exc: bool = True,
234+
keep_tensor: bool = False,
231235
) -> Iterator[ResultExecution]:
232236
"""
233237
Executes the onnx model and enumerate intermediate results without their names.
@@ -236,17 +240,40 @@ def enumerate_summarized(
236240
:param feed_inputs: dictionary `{ input name: input value }`
237241
:param raise_exc: raises an exception if the execution fails or stop
238242
where it is
243+
:param keep_tensor: keeps the tensor in order to compute precise distances
239244
:return: iterator on ResultExecution
240245
"""
241246
for kind, name, value, op_type in self.enumerate_results(
242247
output_names, feed_inputs, raise_exc=raise_exc
243248
):
244249
summary = make_summary(value)
245250
yield ResultExecution(
246-
kind, value.dtype, value.shape, summary, op_type, name
251+
kind,
252+
value.dtype,
253+
value.shape,
254+
summary,
255+
op_type,
256+
name,
257+
value=value if keep_tensor else None,
247258
)
248259

249260

261+
def discrepancies(
    expected: np.ndarray, value: np.ndarray, eps: float = 1e-7
) -> Dict[str, float]:
    """
    Computes absolute error and relative error between two matrices.

    Both arrays are flattened and cast to float32 before being compared,
    so only the number of elements has to match, not the shape.

    :param expected: reference values
    :param value: values compared to the reference
    :param eps: added to ``abs(expected)`` so that the relative error
        does not divide by zero
    :return: dictionary with two keys, `aerr` the maximum absolute error
        and `rerr` the maximum relative error, both are 0 for empty arrays
    """
    assert (
        expected.size == value.size
    ), f"Incompatible shapes v1.shape={expected.shape}, v2.shape={value.shape}"
    if expected.size == 0:
        # max() cannot reduce a zero-size array and raises ValueError,
        # two empty tensors have no element that can differ.
        return dict(aerr=0.0, rerr=0.0)
    expected = expected.ravel().astype(np.float32)
    value = value.ravel().astype(np.float32)
    diff = np.abs(expected - value)
    rel = diff / (np.abs(expected) + eps)
    return dict(aerr=float(diff.max()), rerr=float(rel.max()))
275+
276+
250277
class DistanceExecution:
251278
"""
252279
Computes a distance between two results.
@@ -403,6 +430,14 @@ def to_str(
403430
d = self.distance_pair(d1, d2)
404431
symbol = "=" if d == 0 else "~"
405432
line = f"{symbol} | {_align(str(d1), column_size)} | {_align(str(d2), column_size)}"
433+
if (
434+
d1.value is not None
435+
and d2.value is not None
436+
and d1.value.size == d2.value.size
437+
):
438+
disc = discrepancies(d1.value, d2.value)
439+
a, r = disc["aerr"], disc["rerr"]
440+
line += f" | a={a:.3f} r={r:.3f}"
406441
elif i == last[0]:
407442
d2 = s2[j]
408443
line = (
@@ -551,6 +586,7 @@ def compare_onnx_execution(
551586
verbose: int = 0,
552587
raise_exc: bool = True,
553588
mode: str = "execute",
589+
keep_tensor: bool = False,
554590
) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]:
555591
"""
556592
Compares the execution of two onnx models.
@@ -566,6 +602,7 @@ def compare_onnx_execution(
566602
:param raise_exc: raise exception if the execution fails or stop at the error
567603
:param mode: the model should be executed but the function can be executed
568604
but the comparison may happen on nodes only
605+
:param keep_tensor: keeps the tensor in order to compute a precise distance
569606
:return: four results, a sequence of results for the first model and the second model,
570607
the alignment between the two, DistanceExecution
571608
"""
@@ -589,15 +626,15 @@ def compare_onnx_execution(
589626
print("[compare_onnx_execution] execute first model")
590627
res1 = list(
591628
YieldEvaluator(model1).enumerate_summarized(
592-
None, feeds1, raise_exc=raise_exc
629+
None, feeds1, raise_exc=raise_exc, keep_tensor=keep_tensor
593630
)
594631
)
595632
if verbose:
596633
print(f"[compare_onnx_execution] got {len(res1)} results")
597634
print("[compare_onnx_execution] execute second model")
598635
res2 = list(
599636
YieldEvaluator(model2).enumerate_summarized(
600-
None, feeds2, raise_exc=raise_exc
637+
None, feeds2, raise_exc=raise_exc, keep_tensor=keep_tensor
601638
)
602639
)
603640
elif mode == "nodes":

0 commit comments

Comments
 (0)