Skip to content

Commit 4cf9dcc

Browse files
authored
Adds a mode to compare models without execution (#76)
* update requirements * Add a mode to compare model without execution * changelogs * improve initializer * fix display * fix side
1 parent 7675869 commit 4cf9dcc

File tree

4 files changed

+255
-38
lines changed

4 files changed

+255
-38
lines changed

CHANGELOGS.rst

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`76`: add a mode to compare models without execution
78
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
89
* :pr:`71`: adds tools to compare two onnx graphs
910
* :pr:`61`: adds function to plot onnx model as graphs

_unittests/ut_reference/test_evaluator_yield.py

+71-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import unittest
22
import numpy as np
33
from onnx import TensorProto
4+
from onnx.checker import check_model
45
from onnx.helper import (
56
make_function,
67
make_graph,
@@ -9,6 +10,7 @@
910
make_opsetid,
1011
make_tensor_value_info,
1112
)
13+
from onnx.numpy_helper import from_array
1214
from onnx.parser import parse_model
1315
from onnx_array_api.ext_test_case import ExtTestCase
1416
from onnx_array_api.reference import (
@@ -422,13 +424,13 @@ def test_distance_sequence_str(self):
422424
text = dc.to_str(s1, s2, align)
423425
self.assertIn("OUTPUT", text)
424426
expected = """
425-
001=|INPUTfloat322x2ABCDA|INPUTfloat322x2ABCDA
426-
002=|INPUTfloat322x2ABCDB|INPUTfloat322x2ABCDB
427-
003~|INPUTfloat322x3ABCDX|INPUTfloat322x2ABCDX
428-
004-|RESULTfloat322x2CEIOExpH|
429-
005=|RESULTfloat322x2CEIOLinearRegrY1|RESULTfloat322x2CEIOLinearRegrY1
430-
006~|RESULTfloat322x2CEIOAbsY|RESULTfloat322x3CEIPAbsZ
431-
007~|OUTPUTfloat322x2CEIOY|OUTPUTfloat322x2CEIPY
427+
001=|INPUTfloat322:2x2ABCDA|INPUTfloat322:2x2ABCDA
428+
002=|INPUTfloat322:2x2ABCDB|INPUTfloat322:2x2ABCDB
429+
003~|INPUTfloat322:2x3ABCDX|INPUTfloat322:2x2ABCDX
430+
004-|RESULTfloat322:2x2CEIOExpH|
431+
005=|RESULTfloat322:2x2CEIOLinearRegressioY1|RESULTfloat322:2x2CEIOLinearRegressioY1
432+
006~|RESULTfloat322:2x2CEIOAbsY|RESULTfloat322:2x3CEIPAbsZ
433+
007~|OUTPUTfloat322:2x2CEIOY|OUTPUTfloat322:2x2CEIPY
432434
""".replace(
433435
" ", ""
434436
).strip(
@@ -460,6 +462,68 @@ def test_compare_execution(self):
460462
self.assertIn("CAAA Constant", text)
461463
self.assertEqual(len(align), 5)
462464

465+
def test_no_execution(self):
466+
model = make_model(
467+
make_graph(
468+
[
469+
make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
470+
make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
471+
make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
472+
make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
473+
make_node("Cast", ["xm2c"], ["xm2"], to=1),
474+
make_node("MatMul", ["xm1", "xm2"], ["xm"]),
475+
make_node("Reshape", ["xm", "shape3"], ["Z"]),
476+
],
477+
"dummy",
478+
[
479+
make_tensor_value_info("X", TensorProto.FLOAT, [32, 128]),
480+
make_tensor_value_info("Y", TensorProto.FLOAT, [3, 5, 128, 64]),
481+
],
482+
[make_tensor_value_info("Z", TensorProto.FLOAT, [3, 5, 32, "N"])],
483+
[
484+
from_array(np.array([0], dtype=np.int64), name="zero"),
485+
from_array(np.array([1], dtype=np.int64), name="un"),
486+
from_array(np.array([1, 32, 128], dtype=np.int64), name="shape1"),
487+
from_array(np.array([15, 128, 64], dtype=np.int64), name="shape2"),
488+
from_array(np.array([3, 5, 32, 64], dtype=np.int64), name="shape3"),
489+
],
490+
)
491+
)
492+
check_model(model)
493+
res1, res2, align, dc = compare_onnx_execution(model, model, mode="nodes")
494+
text = dc.to_str(res1, res2, align)
495+
self.assertIn("012 = | NODE", text)
496+
497+
model2 = make_model(
498+
make_graph(
499+
[
500+
make_node("Unsqueeze", ["X", "zero"], ["xu1"]),
501+
make_node("Unsqueeze", ["xu1", "un"], ["xu2"]),
502+
make_node("Reshape", ["xu2", "shape1"], ["xm1"]),
503+
make_node("Reshape", ["Y", "shape2"], ["xm2c"]),
504+
make_node("MatMul", ["xm1", "xm2c"], ["xm"]),
505+
make_node("Reshape", ["xm", "shape3"], ["Z"]),
506+
],
507+
"dummy",
508+
[
509+
make_tensor_value_info("X", TensorProto.FLOAT, [32, 128]),
510+
make_tensor_value_info("Y", TensorProto.FLOAT, [3, 5, 128, 64]),
511+
],
512+
[make_tensor_value_info("Z", TensorProto.FLOAT, [3, 5, 32, "N"])],
513+
[
514+
from_array(np.array([0], dtype=np.int64), name="zero"),
515+
from_array(np.array([1], dtype=np.int64), name="un"),
516+
from_array(np.array([1, 32, 128], dtype=np.int64), name="shape1"),
517+
from_array(np.array([15, 128, 64], dtype=np.int64), name="shape2"),
518+
from_array(np.array([3, 5, 32, 64], dtype=np.int64), name="shape3"),
519+
],
520+
)
521+
)
522+
check_model(model2)
523+
res1, res2, align, dc = compare_onnx_execution(model, model2, mode="nodes")
524+
text = dc.to_str(res1, res2, align)
525+
self.assertIn("012 = | NODE", text)
526+
463527

464528
if __name__ == "__main__":
465529
unittest.main(verbosity=2)

onnx_array_api/_command_lines_parser.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ def get_main_parser() -> ArgumentParser:
2020
Selects a command.
2121
2222
'translate' exports an onnx graph into a piece of code replicating it,
23-
'compares' compares the execution of two onnx models
23+
'compare' compares the execution of two onnx models
2424
"""
2525
),
2626
)
@@ -90,6 +90,13 @@ def get_parser_compare() -> ArgumentParser:
9090
required=True,
9191
help="second onnx model",
9292
)
93+
parser.add_argument(
94+
"-m",
95+
"--mode",
96+
choices=["execute", "nodes"],
97+
default="execute",
98+
help="compare the execution ('execute') or the nodes only ('nodes')",
99+
)
93100
parser.add_argument(
94101
"-v",
95102
"--verbose",
@@ -112,8 +119,10 @@ def _cmd_compare(argv: List[Any]):
112119
args = parser.parse_args(argv[1:])
113120
onx1 = onnx.load(args.model1)
114121
onx2 = onnx.load(args.model2)
115-
res1, res2, align, dc = compare_onnx_execution(onx1, onx2, verbose=args.verbose)
116-
text = dc.to_str(res1, res2, align, column_size=args.column_size)
122+
res1, res2, align, dc = compare_onnx_execution(
123+
onx1, onx2, verbose=args.verbose, mode=args.mode
124+
)
125+
text = dc.to_str(res1, res2, align, column_size=int(args.column_size))
117126
print(text)
118127

119128

@@ -127,7 +136,7 @@ def main(argv: Optional[List[Any]] = None):
127136
parser = get_main_parser()
128137
parser.parse_args(argv)
129138
else:
130-
parsers = dict(translate=get_parser_translate)
139+
parsers = dict(translate=get_parser_translate, compare=get_parser_compare)
131140
cmd = argv[0]
132141
if cmd not in parsers:
133142
raise ValueError(

0 commit comments

Comments
 (0)