Skip to content

Add command line to replace constants in a model #87

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Jun 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOGS.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Change Logs
0.3.0
+++++

* :pr:`87`: add command line to replace contant by ConstantOfShape
* :pr:`79`: first draft to export to GraphBuilder
* :pr:`77`: supports ConcatOfShape and Slice with the light API

Expand Down
5 changes: 5 additions & 0 deletions _doc/api/tools.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,11 @@ Benchmark

.. autofunction:: onnx_array_api.ext_test_case.measure_time

Manipulations
+++++++++++++

.. autofunction:: onnx_array_api.tools.replace_constants.replace_initializer_by_constant_of_shape

Examples
++++++++

Expand Down
160 changes: 160 additions & 0 deletions _unittests/ut_tools/test_replace_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,160 @@
import unittest
import numpy as np
import onnx
import onnx.helper as oh
import onnx.numpy_helper as onh
from onnx import TensorProto
from onnx_array_api.ext_test_case import ExtTestCase
from onnx_array_api.reference import (
ExtendedReferenceEvaluator as ReferenceEvaluator,
)
from onnx_array_api.tools.replace_constants import (
replace_initializer_by_constant_of_shape,
)


class TestReplaceConstants(ExtTestCase):

def test_replace_initializer(self):
dtype = np.float32
value = np.random.randn(2, 100).astype(dtype)
A = onh.from_array(value, name="A")
value = np.array([1], dtype=dtype)
C = onh.from_array(value, name="C")

X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
graph = oh.make_graph([node1, node2], "lr", [X], [Y], [A, C])
model_def = oh.make_model(graph)

x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
oinf1 = ReferenceEvaluator(model_def)
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(model_def)
node_types = {n.op_type for n in repl.graph.node}
self.assertIn("ConstantOfShape", node_types)
oinf2 = ReferenceEvaluator(repl)
y1[:, :] = 3.5
y1[0, :] = 0.5
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
self.assertEqualArray(y1, y2)

def test_replace_constant(self):
dtype = np.float32
value = np.random.randn(2, 10).astype(dtype)
A = onh.from_array(value, name="A")
value = np.array([1], dtype=dtype)
C = onh.from_array(value, name="C")

X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
node0 = oh.make_node("Constant", [], ["A"], value=A)
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
graph = oh.make_graph([node0, node1, node2], "lr", [X], [Y], [C])
model_def = oh.make_model(graph)

x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
oinf1 = ReferenceEvaluator(model_def)
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(model_def, threshold=0)
node_types = {n.op_type for n in repl.graph.node}
self.assertIn("ConstantOfShape", node_types)
oinf2 = ReferenceEvaluator(repl)
y1[:, :] = 4
y1[0, :] = 1
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
self.assertEqualArray(y1, y2)

def test_replace_constant_function(self):
dtype = np.float32
value = np.random.randn(2, 100).astype(dtype)
A = onh.from_array(value, name="A")
value = np.array([1], dtype=dtype)
C = onh.from_array(value, name="C")

X = oh.make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Y = oh.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
nodeC = oh.make_node("Constant", [], ["C"], value=C)
node0 = oh.make_node("Constant", [], ["A"], value=A)
node1 = oh.make_node("MatMul", ["X", "A"], ["AX"])
node2 = oh.make_node("Sub", ["AX", "C"], ["Y"])
opset_imports = [
oh.make_opsetid("", onnx.defs.onnx_opset_version()),
oh.make_opsetid("custom", 1),
]
fct = oh.make_function(
"custom",
"unittest",
["X"],
["Y"],
[nodeC, node0, node1, node2],
opset_imports,
)

node = oh.make_node("unittest", ["X"], ["Y"], domain="custom")
graph = oh.make_graph([node], "lr", [X], [Y], [C])
model_def = oh.make_model(graph, functions=[fct], opset_imports=opset_imports)

x = np.array([1, 2, 4, 5, 5, 4]).astype(np.float32).reshape((3, 2))
oinf1 = ReferenceEvaluator(model_def)
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(model_def)
node_types = {n.op_type for n in repl.functions[0].node}
self.assertIn("ConstantOfShape", node_types)
oinf2 = ReferenceEvaluator(repl)
y1[:, :] = 3.5
y1[0, :] = 0.5
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
self.assertEqualArray(y1, y2)

def test_replace_constant_graph(self):
value = np.array([0], dtype=np.float32)
zero = onh.from_array(value, name="zero")

X = oh.make_tensor_value_info("X", onnx.TensorProto.FLOAT, [None, None])
Y = oh.make_tensor_value_info("Y", onnx.TensorProto.FLOAT, [None])

rsum = oh.make_node("ReduceSum", ["X"], ["rsum"])
cond = oh.make_node("Greater", ["rsum", "zero"], ["cond"])

then_out = oh.make_tensor_value_info("then_out", onnx.TensorProto.FLOAT, None)
then_cst = onh.from_array(np.array([1] * 129).astype(np.float32))

then_const_node = oh.make_node(
"Constant", inputs=[], outputs=["then_out"], value=then_cst, name="cst1"
)
then_body = oh.make_graph([then_const_node], "then_body", [], [then_out])

else_out = oh.make_tensor_value_info("else_out", onnx.TensorProto.FLOAT, None)
else_cst = onh.from_array(np.array([-1] * 129).astype(np.float32))
else_const_node = oh.make_node(
"Constant", inputs=[], outputs=["else_out"], value=else_cst, name="cst2"
)
else_body = oh.make_graph([else_const_node], "else_body", [], [else_out])

if_node = oh.make_node(
"If", ["cond"], ["Y"], then_branch=then_body, else_branch=else_body
)
graph = oh.make_graph([rsum, cond, if_node], "if", [X], [Y], [zero])
onnx_model = oh.make_model(
graph, opset_imports=[oh.make_opsetid("", onnx.defs.onnx_opset_version())]
)
self.assertNotIn("ConstantOfShape", str(onnx_model))

x = np.ones((3, 2), dtype=np.float32)
oinf1 = ReferenceEvaluator(onnx_model)
y1 = oinf1.run(None, {"X": x})[0] # type: ignore[index]
repl = replace_initializer_by_constant_of_shape(onnx_model)
self.assertIn("ConstantOfShape", str(repl))
oinf2 = ReferenceEvaluator(repl)
y2 = oinf2.run(None, {"X": x})[0] # type: ignore[index]
y1 = y1.copy()
y1[:] = 0.5
self.assertEqualArray(y1, y2)


if __name__ == "__main__":
unittest.main(verbosity=2)
8 changes: 8 additions & 0 deletions _unittests/ut_xrun_doc/test_command_lines1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
get_main_parser,
get_parser_compare,
get_parser_translate,
get_parser_replace,
main,
)

Expand All @@ -35,6 +36,13 @@ def test_parser_translate(self):
text = st.getvalue()
self.assertIn("model", text)

def test_parser_replace(self):
st = StringIO()
with redirect_stdout(st):
get_parser_replace().print_help()
text = st.getvalue()
self.assertIn("model", text)

def test_command_translate(self):
X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None])
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6])
Expand Down
80 changes: 76 additions & 4 deletions onnx_array_api/_command_lines_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,14 @@ def get_main_parser() -> ArgumentParser:
)
parser.add_argument(
"cmd",
choices=["translate", "compare"],
choices=["translate", "compare", "replace"],
help=dedent(
"""
Selects a command.

'translate' exports an onnx graph into a piece of code replicating it,
'compare' compares the execution of two onnx models
'compare' compares the execution of two onnx models,
'replace' replaces constant and initliazers by ConstantOfShape to make the model lighter
"""
),
)
Expand Down Expand Up @@ -142,8 +143,75 @@ def _cmd_compare(argv: List[Any]):
print(text)


def get_parser_replace() -> ArgumentParser:
parser = ArgumentParser(
prog="translate",
description=dedent(
"""
Replaces constants and initializes by ConstOfShape or any other nodes
to make the model smaller.
"""
),
epilog="This is mostly used to write unit tests without adding "
"a big file to the repository.",
)
parser.add_argument(
"-m",
"--model",
type=str,
required=True,
help="onnx model to translate",
)
parser.add_argument(
"-o",
"--out",
type=str,
required=True,
help="output file",
)
parser.add_argument(
"-t",
"--threshold",
default=128,
help="Threshold above which every constant is replaced",
)
parser.add_argument(
"--type",
default="ConstontOfShape",
help="Inserts this operator type",
)
parser.add_argument(
"--domain",
default="",
help="Inserts this domain",
)
parser.add_argument(
"-v",
"--verbose",
default=0,
help="verbosity",
)
return parser


def _cmd_replace(argv: List[Any]):
from .tools.replace_constants import replace_initializer_by_constant_of_shape

parser = get_parser_replace()
args = parser.parse_args(argv[1:])
if args.verbose in ("1", 1, "True", True):
print(f"[compare] load model {args.model!r}")
onx = onnx.load(args.model)
new_onx = replace_initializer_by_constant_of_shape(
onx, threshold=args.threshold, op_type=args.type, domain=args.domain
)
if args.verbose in ("1", 1, "True", True):
print(f"[compare] save model {args.out!r}")
onnx.save(new_onx, args.out)


def main(argv: Optional[List[Any]] = None):
fcts = dict(translate=_cmd_translate, compare=_cmd_compare)
fcts = dict(translate=_cmd_translate, compare=_cmd_compare, replace=_cmd_replace)

if argv is None:
argv = sys.argv[1:]
Expand All @@ -152,7 +220,11 @@ def main(argv: Optional[List[Any]] = None):
parser = get_main_parser()
parser.parse_args(argv)
else:
parsers = dict(translate=get_parser_translate, compare=get_parser_compare)
parsers = dict(
translate=get_parser_translate,
compare=get_parser_compare,
replace=get_parser_replace,
)
cmd = argv[0]
if cmd not in parsers:
raise ValueError(
Expand Down
5 changes: 2 additions & 3 deletions onnx_array_api/array_api/_onnx_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,13 @@ def asarray(
dtype: Optional[DType] = None,
order: Optional[str] = None,
like: Any = None,
device: Optional[str] = None,
copy: bool = False,
) -> EagerTensor:
"""
Converts anything into an array.
"""
"""
Converts anything into an array.
"""
assert device is None, f"asarray not implemented yet for device={device!r}"
if order not in ("C", None):
raise NotImplementedError(f"asarray is not implemented for order={order!r}.")
if like is not None:
Expand Down
3 changes: 2 additions & 1 deletion onnx_array_api/npx/npx_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,7 +281,8 @@ def astype(
to = DType(TensorProto.STRING)
else:
raise TypeError(f"dtype must of type DType, not {type(dtype)}-{dtype}.")
return var(a, op="Cast", to=to.code)
return var(a, op="Cast", to=to.code)
return var(a, op="Cast", to=dtype.code)


@npxapi_inline
Expand Down
1 change: 1 addition & 0 deletions onnx_array_api/tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

Loading
Loading