Skip to content

Commit 79f0f16

Browse files
committed
Add command line to replace constant
1 parent 32fc52e commit 79f0f16

File tree

6 files changed

+477
-4
lines changed

6 files changed

+477
-4
lines changed

_doc/api/tools.rst

+5
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ Benchmark
66

77
.. autofunction:: onnx_array_api.ext_test_case.measure_time
88

9+
Manipulations
10+
+++++++++++++
11+
12+
.. autofunction:: onnx_array_api.tools.replace_initializer_by_constant_of_shape
13+
914
Examples
1015
++++++++
1116

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import unittest
2+
import numpy as np
3+
import onnx
4+
import onnx.helper as oh
5+
import onnx.numpy_helper as onh
6+
from onnx import TensorProto
7+
from onnx_array_api.ext_test_case import ExtTestCase
8+
from onnx_array_api.reference import (
9+
ExtendedReferenceEvaluator as ReferenceEvaluator,
10+
)
11+
from onnx_array_api.tools.replace_constants import (
12+
replace_initializer_by_constant_of_shape,
13+
)
14+
15+
16+
class TestReplaceConstants(ExtTestCase):
17+
18+
def test_replace_initializer(self):
19+
dtype = np.float32
20+
value = np.random.randn(2, 100).astype(dtype)
21+
A = onh.from_array(value, name="A")
22+
value = np.array([1], dtype=dtype)
23+
C = onh.from_array(value, name="C")
24+
25+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
26+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
27+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
28+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
29+
graph = oh.make_graph([node1, node2], "lr", [X], [Y], [A, C])
30+
model_def = oh.make_model(graph)
31+
32+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
33+
oinf1 = ReferenceEvaluator(model_def)
34+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
35+
repl = replace_initializer_by_constant_of_shape(model_def)
36+
node_types = {n.op_type for n in repl.graph.node}
37+
self.assertIn("ConstantOfShape", node_types)
38+
oinf2 = ReferenceEvaluator(repl)
39+
y1[:, :] = 3.5
40+
y1[0, :] = 0.5
41+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
42+
self.assertEqualArray(y1, y2)
43+
44+
def test_replace_constant(self):
45+
dtype = np.float32
46+
value = np.random.randn(2, 10).astype(dtype)
47+
A = onh.from_array(value, name="A")
48+
value = np.array([1], dtype=dtype)
49+
C = onh.from_array(value, name="C")
50+
51+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
52+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
53+
node0 = oh.make_node("Constant", [], ["A"], value=A)
54+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
55+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
56+
graph = oh.make_graph([node0, node1, node2], "lr", [X], [Y], [C])
57+
model_def = oh.make_model(graph)
58+
59+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
60+
oinf1 = ReferenceEvaluator(model_def)
61+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
62+
repl = replace_initializer_by_constant_of_shape(model_def, threshold=0)
63+
node_types = {n.op_type for n in repl.graph.node}
64+
self.assertIn("ConstantOfShape", node_types)
65+
oinf2 = ReferenceEvaluator(repl)
66+
y1[:, :] = 4
67+
y1[0, :] = 1
68+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
69+
self.assertEqualArray(y1, y2)
70+
71+
def test_replace_constant_function(self):
72+
dtype = np.float32
73+
value = np.random.randn(2, 100).astype(dtype)
74+
A = onh.from_array(value, name="A")
75+
value = np.array([1], dtype=dtype)
76+
C = onh.from_array(value, name="C")
77+
78+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
79+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
80+
nodeC = oh.make_node("Constant", [], ["C"], value=C)
81+
node0 = oh.make_node("Constant", [], ["A"], value=A)
82+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
83+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
84+
opset_imports = [
85+
oh.make_opsetid("", onnx.defs.onnx_opset_version()),
86+
oh.make_opsetid("custom", 1),
87+
]
88+
fct = oh.make_function(
89+
"custom",
90+
"unittest",
91+
["X"],
92+
["Y"],
93+
[nodeC, node0, node1, node2],
94+
opset_imports,
95+
)
96+
97+
node = oh.make_node("unittest", ["X"], ["Y"], domain="custom")
98+
graph = oh.make_graph([node], "lr", [X], [Y], [C])
99+
model_def = oh.make_model(graph, functions=[fct], opset_imports=opset_imports)
100+
101+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
102+
oinf1 = ReferenceEvaluator(model_def)
103+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
104+
repl = replace_initializer_by_constant_of_shape(model_def)
105+
node_types = {n.op_type for n in repl.functions[0].node}
106+
self.assertIn("ConstantOfShape", node_types)
107+
oinf2 = ReferenceEvaluator(repl)
108+
y1[:, :] = 3.5
109+
y1[0, :] = 0.5
110+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
111+
self.assertEqualArray(y1, y2)
112+
113+
def test_replace_constant_graph(self):
114+
value = np.array([0], dtype=np.float32)
115+
zero = onh.from_array(value, name="zero")
116+
117+
X = oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])
118+
Y = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None])
119+
120+
rsum = oh.make_node("ReduceSum", ["X"], ["rsum"])
121+
cond = oh.make_node("Greater", ["rsum", "zero"], ["cond"])
122+
123+
then_out = oh.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, None)
124+
then_cst = onh.from_array(np.array([1] * 129).astype(np.float32))
125+
126+
then_const_node = oh.make_node(
127+
"Constant", inputs=[], outputs=["then_out"], value=then_cst, name="cst1"
128+
)
129+
then_body = oh.make_graph([then_const_node], "then_body", [], [then_out])
130+
131+
else_out = oh.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, None)
132+
else_cst = onh.from_array(np.array([-1] * 129).astype(np.float32))
133+
else_const_node = oh.make_node(
134+
"Constant", inputs=[], outputs=["else_out"], value=else_cst, name="cst2"
135+
)
136+
else_body = oh.make_graph([else_const_node], "else_body", [], [else_out])
137+
138+
if_node = oh.make_node(
139+
"If", ["cond"], ["Y"], then_branch=then_body, else_branch=else_body
140+
)
141+
graph = oh.make_graph([rsum, cond, if_node], "if", [X], [Y], [zero])
142+
onnx_model = oh.make_model(
143+
graph, opset_imports=[oh.make_opsetid("", onnx.defs.onnx_opset_version())]
144+
)
145+
self.assertNotIn("ConstantOfShape", str(onnx_model))
146+
147+
x = np.ones((3, 2), dtype=np.float32)
148+
oinf1 = ReferenceEvaluator(onnx_model)
149+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
150+
repl = replace_initializer_by_constant_of_shape(onnx_model)
151+
self.assertIn("ConstantOfShape", str(repl))
152+
oinf2 = ReferenceEvaluator(repl)
153+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
154+
y1 = y1.copy()
155+
y1[:] = 0.5
156+
self.assertEqualArray(y1, y2)
157+
158+
159+
if __name__ == "__main__":
160+
unittest.main(verbosity=2)

_unittests/ut_xrun_doc/test_command_lines1.py

+8
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get_main_parser,
1717
get_parser_compare,
1818
get_parser_translate,
19+
get_parser_replace,
1920
main,
2021
)
2122

@@ -35,6 +36,13 @@ def test_parser_translate(self):
3536
text = st.getvalue()
3637
self.assertIn("model", text)
3738

39+
def test_parser_replace(self):
40+
st = StringIO()
41+
with redirect_stdout(st):
42+
get_parser_replace().print_help()
43+
text = st.getvalue()
44+
self.assertIn("model", text)
45+
3846
def test_command_translate(self):
3947
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
4048
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])

onnx_array_api/_command_lines_parser.py

+76-4
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ def get_main_parser() -> ArgumentParser:
1414
)
1515
parser.add_argument(
1616
"cmd",
17-
choices=["translate", "compare"],
17+
choices=["translate", "compare", "replace"],
1818
help=dedent(
1919
"""
2020
Selects a command.
2121
2222
'translate' exports an onnx graph into a piece of code replicating it,
23-
'compare' compares the execution of two onnx models
23+
'compare' compares the execution of two onnx models,
24+
'replace' replaces constant and initliazers by ConstantOfShape to make the model lighter
2425
"""
2526
),
2627
)
@@ -142,8 +143,75 @@ def _cmd_compare(argv: List[Any]):
142143
print(text)
143144

144145

146+
def get_parser_replace() -> ArgumentParser:
147+
parser = ArgumentParser(
148+
prog="translate",
149+
description=dedent(
150+
"""
151+
Replaces constants and initializes by ConstOfShape or any other nodes
152+
to make the model smaller.
153+
"""
154+
),
155+
epilog="This is mostly used to write unit tests without adding "
156+
"a big file to the repository.",
157+
)
158+
parser.add_argument(
159+
"-m",
160+
"--model",
161+
type=str,
162+
required=True,
163+
help="onnx model to translate",
164+
)
165+
parser.add_argument(
166+
"-o",
167+
"--out",
168+
type=str,
169+
required=True,
170+
help="output file",
171+
)
172+
parser.add_argument(
173+
"-t",
174+
"--threshold",
175+
default=128,
176+
help="Threshold above which every constant is replaced",
177+
)
178+
parser.add_argument(
179+
"--type",
180+
default="ConstontOfShape",
181+
help="Inserts this operator type",
182+
)
183+
parser.add_argument(
184+
"--domain",
185+
default="",
186+
help="Inserts this domain",
187+
)
188+
parser.add_argument(
189+
"-v",
190+
"--verbose",
191+
default=0,
192+
help="verbosity",
193+
)
194+
return parser
195+
196+
197+
def _cmd_replace(argv: List[Any]):
198+
from .tools.replace_constants import replace_initializer_by_constant_of_shape
199+
200+
parser = get_parser_replace()
201+
args = parser.parse_args(argv[1:])
202+
if args.verbose in ("1", 1, "True", True):
203+
print(f"[compare] load model {args.model!r}")
204+
onx = onnx.load(args.model)
205+
new_onx = replace_initializer_by_constant_of_shape(
206+
onx, threshold=args.threshold, op_type=args.type, domain=args.domain
207+
)
208+
if args.verbose in ("1", 1, "True", True):
209+
print(f"[compare] save model {args.out!r}")
210+
onnx.save(new_onx, args.out)
211+
212+
145213
def main(argv: Optional[List[Any]] = None):
146-
fcts = dict(translate=_cmd_translate, compare=_cmd_compare)
214+
fcts = dict(translate=_cmd_translate, compare=_cmd_compare, replace=_cmd_replace)
147215

148216
if argv is None:
149217
argv = sys.argv[1:]
@@ -152,7 +220,11 @@ def main(argv: Optional[List[Any]] = None):
152220
parser = get_main_parser()
153221
parser.parse_args(argv)
154222
else:
155-
parsers = dict(translate=get_parser_translate, compare=get_parser_compare)
223+
parsers = dict(
224+
translate=get_parser_translate,
225+
compare=get_parser_compare,
226+
replace=get_parser_replace,
227+
)
156228
cmd = argv[0]
157229
if cmd not in parsers:
158230
raise ValueError(

onnx_array_api/tools/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+

0 commit comments

Comments
 (0)