Skip to content

Commit 01e0fac

Browse files
authoredJun 5, 2024··
Add command line to replace constants in a model (#87)
* example * Add command line to replace constant * doc * ut * doc
1 parent 381d829 commit 01e0fac

File tree

9 files changed

+482
-8
lines changed

9 files changed

+482
-8
lines changed
 

‎CHANGELOGS.rst

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.3.0
55
+++++
66

7+
* :pr:`87`: add command line to replace contant by ConstantOfShape
78
* :pr:`79`: first draft to export to GraphBuilder
89
* :pr:`77`: supports ConcatOfShape and Slice with the light API
910

‎_doc/api/tools.rst

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@ Benchmark
66

77
.. autofunction:: onnx_array_api.ext_test_case.measure_time
88

9+
Manipulations
10+
+++++++++++++
11+
12+
.. autofunction:: onnx_array_api.tools.replace_constants.replace_initializer_by_constant_of_shape
13+
914
Examples
1015
++++++++
1116

Lines changed: 160 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,160 @@
1+
import unittest
2+
import numpy as np
3+
import onnx
4+
import onnx.helper as oh
5+
import onnx.numpy_helper as onh
6+
from onnx import TensorProto
7+
from onnx_array_api.ext_test_case import ExtTestCase
8+
from onnx_array_api.reference import (
9+
ExtendedReferenceEvaluator as ReferenceEvaluator,
10+
)
11+
from onnx_array_api.tools.replace_constants import (
12+
replace_initializer_by_constant_of_shape,
13+
)
14+
15+
16+
class TestReplaceConstants(ExtTestCase):
17+
18+
def test_replace_initializer(self):
19+
dtype = np.float32
20+
value = np.random.randn(2, 100).astype(dtype)
21+
A = onh.from_array(value, name="A")
22+
value = np.array([1], dtype=dtype)
23+
C = onh.from_array(value, name="C")
24+
25+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
26+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
27+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
28+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
29+
graph = oh.make_graph([node1, node2], "lr", [X], [Y], [A, C])
30+
model_def = oh.make_model(graph)
31+
32+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
33+
oinf1 = ReferenceEvaluator(model_def)
34+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
35+
repl = replace_initializer_by_constant_of_shape(model_def)
36+
node_types = {n.op_type for n in repl.graph.node}
37+
self.assertIn("ConstantOfShape", node_types)
38+
oinf2 = ReferenceEvaluator(repl)
39+
y1[:, :] = 3.5
40+
y1[0, :] = 0.5
41+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
42+
self.assertEqualArray(y1, y2)
43+
44+
def test_replace_constant(self):
45+
dtype = np.float32
46+
value = np.random.randn(2, 10).astype(dtype)
47+
A = onh.from_array(value, name="A")
48+
value = np.array([1], dtype=dtype)
49+
C = onh.from_array(value, name="C")
50+
51+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
52+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
53+
node0 = oh.make_node("Constant", [], ["A"], value=A)
54+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
55+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
56+
graph = oh.make_graph([node0, node1, node2], "lr", [X], [Y], [C])
57+
model_def = oh.make_model(graph)
58+
59+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
60+
oinf1 = ReferenceEvaluator(model_def)
61+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
62+
repl = replace_initializer_by_constant_of_shape(model_def, threshold=0)
63+
node_types = {n.op_type for n in repl.graph.node}
64+
self.assertIn("ConstantOfShape", node_types)
65+
oinf2 = ReferenceEvaluator(repl)
66+
y1[:, :] = 4
67+
y1[0, :] = 1
68+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
69+
self.assertEqualArray(y1, y2)
70+
71+
def test_replace_constant_function(self):
72+
dtype = np.float32
73+
value = np.random.randn(2, 100).astype(dtype)
74+
A = onh.from_array(value, name="A")
75+
value = np.array([1], dtype=dtype)
76+
C = onh.from_array(value, name="C")
77+
78+
X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
79+
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
80+
nodeC = oh.make_node("Constant", [], ["C"], value=C)
81+
node0 = oh.make_node("Constant", [], ["A"], value=A)
82+
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
83+
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
84+
opset_imports = [
85+
oh.make_opsetid("", onnx.defs.onnx_opset_version()),
86+
oh.make_opsetid("custom", 1),
87+
]
88+
fct = oh.make_function(
89+
"custom",
90+
"unittest",
91+
["X"],
92+
["Y"],
93+
[nodeC, node0, node1, node2],
94+
opset_imports,
95+
)
96+
97+
node = oh.make_node("unittest", ["X"], ["Y"], domain="custom")
98+
graph = oh.make_graph([node], "lr", [X], [Y], [C])
99+
model_def = oh.make_model(graph, functions=[fct], opset_imports=opset_imports)
100+
101+
x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
102+
oinf1 = ReferenceEvaluator(model_def)
103+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
104+
repl = replace_initializer_by_constant_of_shape(model_def)
105+
node_types = {n.op_type for n in repl.functions[0].node}
106+
self.assertIn("ConstantOfShape", node_types)
107+
oinf2 = ReferenceEvaluator(repl)
108+
y1[:, :] = 3.5
109+
y1[0, :] = 0.5
110+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
111+
self.assertEqualArray(y1, y2)
112+
113+
def test_replace_constant_graph(self):
114+
value = np.array([0], dtype=np.float32)
115+
zero = onh.from_array(value, name="zero")
116+
117+
X = oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])
118+
Y = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None])
119+
120+
rsum = oh.make_node("ReduceSum", ["X"], ["rsum"])
121+
cond = oh.make_node("Greater", ["rsum", "zero"], ["cond"])
122+
123+
then_out = oh.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, None)
124+
then_cst = onh.from_array(np.array([1] * 129).astype(np.float32))
125+
126+
then_const_node = oh.make_node(
127+
"Constant", inputs=[], outputs=["then_out"], value=then_cst, name="cst1"
128+
)
129+
then_body = oh.make_graph([then_const_node], "then_body", [], [then_out])
130+
131+
else_out = oh.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, None)
132+
else_cst = onh.from_array(np.array([-1] * 129).astype(np.float32))
133+
else_const_node = oh.make_node(
134+
"Constant", inputs=[], outputs=["else_out"], value=else_cst, name="cst2"
135+
)
136+
else_body = oh.make_graph([else_const_node], "else_body", [], [else_out])
137+
138+
if_node = oh.make_node(
139+
"If", ["cond"], ["Y"], then_branch=then_body, else_branch=else_body
140+
)
141+
graph = oh.make_graph([rsum, cond, if_node], "if", [X], [Y], [zero])
142+
onnx_model = oh.make_model(
143+
graph, opset_imports=[oh.make_opsetid("", onnx.defs.onnx_opset_version())]
144+
)
145+
self.assertNotIn("ConstantOfShape", str(onnx_model))
146+
147+
x = np.ones((3, 2), dtype=np.float32)
148+
oinf1 = ReferenceEvaluator(onnx_model)
149+
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
150+
repl = replace_initializer_by_constant_of_shape(onnx_model)
151+
self.assertIn("ConstantOfShape", str(repl))
152+
oinf2 = ReferenceEvaluator(repl)
153+
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
154+
y1 = y1.copy()
155+
y1[:] = 0.5
156+
self.assertEqualArray(y1, y2)
157+
158+
159+
if __name__ == "__main__":
160+
unittest.main(verbosity=2)

‎_unittests/ut_xrun_doc/test_command_lines1.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
get_main_parser,
1717
get_parser_compare,
1818
get_parser_translate,
19+
get_parser_replace,
1920
main,
2021
)
2122

@@ -35,6 +36,13 @@ def test_parser_translate(self):
3536
text = st.getvalue()
3637
self.assertIn("model", text)
3738

39+
def test_parser_replace(self):
40+
st = StringIO()
41+
with redirect_stdout(st):
42+
get_parser_replace().print_help()
43+
text = st.getvalue()
44+
self.assertIn("model", text)
45+
3846
def test_command_translate(self):
3947
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
4048
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])

‎onnx_array_api/_command_lines_parser.py

Lines changed: 76 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@ def get_main_parser() -> ArgumentParser:
1414
)
1515
parser.add_argument(
1616
"cmd",
17-
choices=["translate", "compare"],
17+
choices=["translate", "compare", "replace"],
1818
help=dedent(
1919
"""
2020
Selects a command.
2121
2222
'translate' exports an onnx graph into a piece of code replicating it,
23-
'compare' compares the execution of two onnx models
23+
'compare' compares the execution of two onnx models,
24+
'replace' replaces constant and initliazers by ConstantOfShape to make the model lighter
2425
"""
2526
),
2627
)
@@ -142,8 +143,75 @@ def _cmd_compare(argv: List[Any]):
142143
print(text)
143144

144145

146+
def get_parser_replace() -> ArgumentParser:
147+
parser = ArgumentParser(
148+
prog="translate",
149+
description=dedent(
150+
"""
151+
Replaces constants and initializes by ConstOfShape or any other nodes
152+
to make the model smaller.
153+
"""
154+
),
155+
epilog="This is mostly used to write unit tests without adding "
156+
"a big file to the repository.",
157+
)
158+
parser.add_argument(
159+
"-m",
160+
"--model",
161+
type=str,
162+
required=True,
163+
help="onnx model to translate",
164+
)
165+
parser.add_argument(
166+
"-o",
167+
"--out",
168+
type=str,
169+
required=True,
170+
help="output file",
171+
)
172+
parser.add_argument(
173+
"-t",
174+
"--threshold",
175+
default=128,
176+
help="Threshold above which every constant is replaced",
177+
)
178+
parser.add_argument(
179+
"--type",
180+
default="ConstontOfShape",
181+
help="Inserts this operator type",
182+
)
183+
parser.add_argument(
184+
"--domain",
185+
default="",
186+
help="Inserts this domain",
187+
)
188+
parser.add_argument(
189+
"-v",
190+
"--verbose",
191+
default=0,
192+
help="verbosity",
193+
)
194+
return parser
195+
196+
197+
def _cmd_replace(argv: List[Any]):
198+
from .tools.replace_constants import replace_initializer_by_constant_of_shape
199+
200+
parser = get_parser_replace()
201+
args = parser.parse_args(argv[1:])
202+
if args.verbose in ("1", 1, "True", True):
203+
print(f"[compare] load model {args.model!r}")
204+
onx = onnx.load(args.model)
205+
new_onx = replace_initializer_by_constant_of_shape(
206+
onx, threshold=args.threshold, op_type=args.type, domain=args.domain
207+
)
208+
if args.verbose in ("1", 1, "True", True):
209+
print(f"[compare] save model {args.out!r}")
210+
onnx.save(new_onx, args.out)
211+
212+
145213
def main(argv: Optional[List[Any]] = None):
146-
fcts = dict(translate=_cmd_translate, compare=_cmd_compare)
214+
fcts = dict(translate=_cmd_translate, compare=_cmd_compare, replace=_cmd_replace)
147215

148216
if argv is None:
149217
argv = sys.argv[1:]
@@ -152,7 +220,11 @@ def main(argv: Optional[List[Any]] = None):
152220
parser = get_main_parser()
153221
parser.parse_args(argv)
154222
else:
155-
parsers = dict(translate=get_parser_translate, compare=get_parser_compare)
223+
parsers = dict(
224+
translate=get_parser_translate,
225+
compare=get_parser_compare,
226+
replace=get_parser_replace,
227+
)
156228
cmd = argv[0]
157229
if cmd not in parsers:
158230
raise ValueError(

‎onnx_array_api/array_api/_onnx_common.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,14 +46,13 @@ def asarray(
4646
dtype: Optional[DType] = None,
4747
order: Optional[str] = None,
4848
like: Any = None,
49+
device: Optional[str] = None,
4950
copy: bool = False,
5051
) -> EagerTensor:
5152
"""
5253
Converts anything into an array.
5354
"""
54-
"""
55-
Converts anything into an array.
56-
"""
55+
assert device is None, f"asarray not implemented yet for device={device!r}"
5756
if order not in ("C", None):
5857
raise NotImplementedError(f"asarray is not implemented for order={order!r}.")
5958
if like is not None:

‎onnx_array_api/npx/npx_functions.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -281,7 +281,8 @@ def astype(
281281
to = DType(TensorProto.STRING)
282282
else:
283283
raise TypeError(f"dtype must of type DType, not {type(dtype)}-{dtype}.")
284-
return var(a, op="Cast", to=to.code)
284+
return var(a, op="Cast", to=to.code)
285+
return var(a, op="Cast", to=dtype.code)
285286

286287

287288
@npxapi_inline

‎onnx_array_api/tools/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Lines changed: 227 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,227 @@
1+
import numpy as np
2+
from onnx import FunctionProto, ModelProto, GraphProto, AttributeProto
3+
from onnx.helper import (
4+
make_model,
5+
set_model_props,
6+
make_graph,
7+
make_node,
8+
make_attribute,
9+
make_function,
10+
tensor_dtype_to_np_dtype,
11+
)
12+
from onnx.numpy_helper import from_array
13+
14+
15+
def replace_initializer_by_constant_of_shape(
16+
onx, threshold=128, op_type="ConstantOfShape", domain=""
17+
):
18+
"""
19+
Replaces initializers by nodes *ConstantOfShape* to reduce
20+
the size and still write a unit test.
21+
22+
:param onx: ModelProto
23+
:param threshold: every initializer under this threshold is not impacted
24+
:param op_type: replace by this node
25+
:param domain: replace by this domain
26+
:return: onx, modified ModelProto
27+
"""
28+
if isinstance(onx, FunctionProto):
29+
modified = False
30+
new_nodes = []
31+
for node in onx.node:
32+
if node.op_type == "Constant":
33+
from onnx_array_api.reference import ExtendedReferenceEvaluator
34+
35+
ref = ExtendedReferenceEvaluator(node)
36+
cst = ref.run(None, {})[0]
37+
38+
size = np.prod(cst.shape)
39+
if size <= threshold:
40+
new_nodes.append(node)
41+
continue
42+
43+
new_name = f"{node.output[0]}__SHAPE"
44+
new_nodes.append(
45+
make_node(
46+
"Constant",
47+
[],
48+
[new_name],
49+
value=from_array(
50+
np.array(cst.shape, dtype=np.int64), name=new_name
51+
),
52+
)
53+
)
54+
dtype = cst.dtype
55+
assert op_type != "Constant"
56+
new_nodes.append(
57+
make_node(
58+
op_type,
59+
[new_name],
60+
node.output,
61+
value=from_array(np.array([0.5], dtype=dtype)),
62+
domain=domain,
63+
)
64+
)
65+
modified = True
66+
continue
67+
68+
new_nodes.append(node)
69+
70+
if not modified:
71+
return onx
72+
73+
onxf = make_function(
74+
domain=onx.domain,
75+
fname=onx.name,
76+
inputs=onx.input,
77+
outputs=onx.output,
78+
nodes=new_nodes,
79+
doc_string=onx.doc_string,
80+
overload=onx.overload,
81+
opset_imports=[],
82+
)
83+
if onx.opset_import:
84+
onxf.opset_import.extend(onx.opset_import)
85+
if onx.value_info:
86+
onxf.value_info.extend(onx.value_info)
87+
if onx.attribute:
88+
onxf.attribute.extend(onx.attribute)
89+
if onx.attribute_proto:
90+
onxf.attribute_proto.extend(onx.attribute_proto)
91+
return onxf
92+
93+
if isinstance(onx, ModelProto):
94+
new_graph = replace_initializer_by_constant_of_shape(
95+
onx.graph, threshold=threshold, op_type=op_type, domain=domain
96+
)
97+
new_functions = [
98+
replace_initializer_by_constant_of_shape(
99+
f, threshold=threshold, op_type=op_type, domain=domain
100+
)
101+
for f in onx.functions
102+
]
103+
model = make_model(
104+
new_graph,
105+
functions=new_functions,
106+
producer_name=onx.producer_name,
107+
producer_version=onx.producer_version,
108+
ir_version=onx.ir_version,
109+
doc_string=onx.doc_string,
110+
domain=onx.domain,
111+
model_version=onx.model_version,
112+
)
113+
if len(onx.metadata_props) > 0: # pragma: no cover
114+
values = {p.key: p.value for p in onx.metadata_props}
115+
set_model_props(model, values)
116+
117+
del model.opset_import[:] # pylint: disable=E1101
118+
for oimp in onx.opset_import:
119+
op_set = model.opset_import.add() # pylint: disable=E1101
120+
if oimp.domain == "" and oimp.version < 9:
121+
raise RuntimeError(
122+
f"ConstantOfShape was introduced in "
123+
f"opset 9 but opset is {oimp.version}."
124+
)
125+
op_set.domain = oimp.domain
126+
op_set.version = oimp.version
127+
return model
128+
129+
if not isinstance(onx, GraphProto):
130+
raise TypeError(f"onx should be a GraphProto as this stage not {type(onx)}.")
131+
132+
new_nodes = []
133+
removed = set()
134+
additional_inputs = []
135+
136+
new_inits = []
137+
for init in onx.initializer:
138+
dims = tuple(init.dims)
139+
size = np.prod(dims)
140+
if size <= threshold:
141+
new_inits.append(init)
142+
continue
143+
new_name = f"{init.name}__SHAPE"
144+
new_inits.append(
145+
from_array(np.array(list(dims), dtype=np.int64), name=new_name)
146+
)
147+
dtype = tensor_dtype_to_np_dtype(init.data_type)
148+
node = make_node(
149+
op_type,
150+
[new_name],
151+
[init.name],
152+
value=from_array(np.array([0.5], dtype=dtype)),
153+
domain=domain,
154+
)
155+
new_nodes.append(node)
156+
removed.add(init.name)
157+
158+
new_sparse_inits = []
159+
for init in onx.sparse_initializer:
160+
dims = tuple(init.dims)
161+
size = np.prod(dims)
162+
if size <= threshold:
163+
new_sparse_inits.append(init)
164+
continue
165+
raise NotImplementedError(
166+
f"This feature is not yet implemented for sparse initializer"
167+
f"(name={init.name!r})."
168+
)
169+
170+
for node in onx.node:
171+
if node.op_type == "Constant":
172+
from onnx_array_api.reference import ExtendedReferenceEvaluator
173+
174+
ref = ExtendedReferenceEvaluator(node)
175+
cst = ref.run(None, {})[0]
176+
177+
size = np.prod(cst.shape)
178+
if size <= threshold:
179+
new_nodes.append(node)
180+
continue
181+
182+
new_name = f"{node.output[0]}__SHAPE"
183+
new_inits.append(
184+
from_array(np.array(cst.shape, dtype=np.int64), name=new_name)
185+
)
186+
dtype = cst.dtype
187+
new_nodes.append(
188+
make_node(
189+
op_type,
190+
[new_name],
191+
node.output,
192+
value=from_array(np.array([0.5], dtype=dtype)),
193+
domain=domain,
194+
)
195+
)
196+
continue
197+
198+
modified = False
199+
atts = []
200+
for att in node.attribute:
201+
if (
202+
att.type == AttributeProto.GRAPH
203+
and hasattr(att, "g")
204+
and att.g is not None
205+
):
206+
modified = True
207+
g = replace_initializer_by_constant_of_shape(
208+
att.g, threshold=threshold, op_type=op_type, domain=domain
209+
)
210+
att = make_attribute(att.name, g)
211+
atts.append(att)
212+
if modified:
213+
new_node = make_node(node.op_type, node.input, node.output)
214+
new_node.attribute.extend(atts)
215+
new_nodes.append(new_node)
216+
else:
217+
new_nodes.append(node)
218+
219+
graph = make_graph(
220+
new_nodes,
221+
onx.name,
222+
[i for i in onx.input if i.name not in removed] + additional_inputs,
223+
onx.output,
224+
initializer=new_inits,
225+
sparse_initializer=new_sparse_inits,
226+
)
227+
return graph

0 commit comments

Comments
 (0)
Please sign in to comment.