Skip to content

Commit 7675869

Browse files
authored
Extend ExtendedReferenceEvaluator (#75)
* update requirements * add more operator to the reference evaluator * extend unit test copverage
1 parent a070da3 commit 7675869

File tree

6 files changed

+222
-0
lines changed

6 files changed

+222
-0
lines changed

CHANGELOGS.rst

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
78
* :pr:`71`: adds tools to compare two onnx graphs
89
* :pr:`61`: adds function to plot onnx model as graphs
910
* :pr:`60`: supports translation of local functions

_unittests/ut_reference/test_reference_ops.py

+82
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,88 @@ def test_fused_matmul11(self):
5959
got = ref.run(None, {"X": a, "Y": a})
6060
self.assertEqualArray(a.T @ a.T, got[0])
6161

62+
def test_memcpy(self):
63+
model = make_model(
64+
make_graph(
65+
[
66+
make_node("MemcpyToHost", ["X"], ["Z"]),
67+
make_node("MemcpyFromHost", ["X"], ["Z"]),
68+
],
69+
"name",
70+
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
71+
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
72+
),
73+
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
74+
ir_version=9,
75+
)
76+
a = np.arange(4).reshape(-1, 2).astype(np.float32)
77+
ref = ExtendedReferenceEvaluator(model)
78+
got = ref.run(None, {"X": a})
79+
self.assertEqualArray(a, got[0])
80+
81+
def test_quick_gelu(self):
82+
from onnxruntime import InferenceSession
83+
84+
for alpha in [0.0, 2.0]:
85+
model = make_model(
86+
make_graph(
87+
[
88+
make_node(
89+
"QuickGelu",
90+
["X"],
91+
["Z"],
92+
domain="com.microsoft",
93+
alpha=alpha,
94+
)
95+
],
96+
"name",
97+
[make_tensor_value_info("X", TensorProto.FLOAT, None)],
98+
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
99+
),
100+
opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)],
101+
ir_version=9,
102+
)
103+
sess = InferenceSession(
104+
model.SerializeToString(), providers=["CPUExecutionProvider"]
105+
)
106+
a = np.arange(4).reshape(-1, 2).astype(np.float32)
107+
expected = sess.run(None, {"X": a})
108+
ref = ExtendedReferenceEvaluator(model)
109+
got = ref.run(None, {"X": a})
110+
self.assertEqualArray(expected[0], got[0])
111+
112+
def test_scatter_elements(self):
113+
model = make_model(
114+
make_graph(
115+
[
116+
make_node(
117+
"ScatterElements",
118+
["data", "indices", "updates"],
119+
["Z"],
120+
axis=3,
121+
reduction="add",
122+
)
123+
],
124+
"name",
125+
[
126+
make_tensor_value_info("data", TensorProto.FLOAT, None),
127+
make_tensor_value_info("indices", TensorProto.INT64, None),
128+
make_tensor_value_info("updates", TensorProto.FLOAT, None),
129+
],
130+
[make_tensor_value_info("Z", TensorProto.FLOAT, None)],
131+
),
132+
opset_imports=[make_opsetid("", 18)],
133+
)
134+
data = np.zeros(2**4, dtype=np.float32).reshape((2, 2, 2, 2))
135+
indices = np.array([[[[0]]]], dtype=np.int64)
136+
updates = np.array([[[[1]]]], dtype=np.float32)
137+
y = np.array(
138+
[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], dtype=np.float32
139+
).reshape((2, 2, 2, 2))
140+
ref = ExtendedReferenceEvaluator(model)
141+
got = ref.run(None, {"data": data, "indices": indices, "updates": updates})
142+
self.assertEqualArray(y, got[0])
143+
62144

63145
if __name__ == "__main__":
64146
unittest.main(verbosity=2)

onnx_array_api/reference/evaluator.py

+7
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
from .ops.op_concat import Concat
99
from .ops.op_constant_of_shape import ConstantOfShape
1010
from .ops.op_fused_matmul import FusedMatMul
11+
from .ops.op_memcpy_host import MemcpyFromHost, MemcpyToHost
12+
from .ops.op_quick_gelu import QuickGelu
13+
from .ops.op_scatter_elements import ScatterElements
1114

1215

1316
logger = getLogger("onnx-array-api-eval")
@@ -34,6 +37,10 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator):
3437
CastLike_19,
3538
ConstantOfShape,
3639
FusedMatMul,
40+
MemcpyFromHost,
41+
MemcpyToHost,
42+
QuickGelu,
43+
ScatterElements,
3744
]
3845

3946
@staticmethod
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from onnx.reference.op_run import OpRun
2+
3+
4+
class MemcpyFromHost(OpRun):
5+
def _run(self, x):
6+
return (x,)
7+
8+
9+
class MemcpyToHost(OpRun):
10+
def _run(self, x):
11+
return (x,)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
import numpy as np
2+
from onnx.reference.op_run import OpRun
3+
4+
5+
def sigmoid(x): # type: ignore
6+
if x > 0:
7+
return 1 / (1 + np.exp(-x))
8+
return np.exp(x) / (1 + np.exp(x))
9+
10+
11+
class QuickGelu(OpRun):
12+
op_domain = "com.microsoft"
13+
14+
def __init__(self, onnx_node, run_params): # type: ignore
15+
OpRun.__init__(self, onnx_node, run_params)
16+
self.vf = np.vectorize(sigmoid)
17+
18+
def _run(self, X, alpha=1.0):
19+
if len(X.shape) == 0:
20+
return ((X * sigmoid(X * alpha)).astype(X.dtype),)
21+
if X.size == 0:
22+
return (X,)
23+
return ((X * self.vf(X * alpha)).astype(X.dtype),)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
import numpy as np
2+
3+
from onnx.reference.op_run import OpRun
4+
5+
6+
def scatter_elements(data, indices, updates, axis=0, reduction=None): # type: ignore
7+
if reduction == "add":
8+
9+
def f(x, y):
10+
return x + y
11+
12+
elif reduction == "min":
13+
14+
def f(x, y):
15+
return min(x, y)
16+
17+
elif reduction == "max":
18+
19+
def f(x, y):
20+
return max(x, y)
21+
22+
else:
23+
24+
def f(x, y):
25+
return y
26+
27+
if axis < 0:
28+
axis = data.ndim + axis
29+
30+
if len(data.shape) == 1 and axis == 0:
31+
scattered = np.copy(data)
32+
for pos, up in zip(indices, updates):
33+
scattered[pos] = f(scattered[pos], up)
34+
return scattered
35+
36+
if len(indices.shape) == 2:
37+
scattered = np.copy(data)
38+
if axis == 0:
39+
for i in range(indices.shape[0]):
40+
for j in range(indices.shape[1]):
41+
scattered[indices[i, j], j] = f(
42+
scattered[indices[i, j], j], updates[i, j]
43+
)
44+
else:
45+
for i in range(indices.shape[0]):
46+
for j in range(indices.shape[1]):
47+
scattered[i, indices[i, j]] = f(
48+
scattered[i, indices[i, j]], updates[i, j]
49+
)
50+
return scattered
51+
52+
if len(indices.shape) == 3:
53+
scattered = np.copy(data)
54+
if axis == 0:
55+
for i in range(indices.shape[0]):
56+
for j in range(indices.shape[1]):
57+
for k in range(indices.shape[2]):
58+
scattered[indices[i, j, k], j, k] = f(
59+
scattered[indices[i, j, k], j, k], updates[i, j, k]
60+
)
61+
elif axis == 1:
62+
for i in range(indices.shape[0]):
63+
for j in range(indices.shape[1]):
64+
for k in range(indices.shape[2]):
65+
scattered[i, indices[i, j, k], k] = f(
66+
scattered[i, indices[i, j, k], k], updates[i, j, k]
67+
)
68+
elif axis == 2:
69+
for i in range(indices.shape[0]):
70+
for j in range(indices.shape[1]):
71+
for k in range(indices.shape[2]):
72+
scattered[i, j, indices[i, j, k]] = f(
73+
scattered[i, j, indices[i, j, k]], updates[i, j, k]
74+
)
75+
return scattered
76+
77+
if len(indices.shape) == 4:
78+
scattered = np.copy(data)
79+
if axis == 3:
80+
for a in range(indices.shape[0]):
81+
for i in range(indices.shape[1]):
82+
for j in range(indices.shape[2]):
83+
for k in range(indices.shape[3]):
84+
scattered[a, i, j, indices[a, i, j, k]] = f(
85+
scattered[a, i, j, indices[a, i, j, k]],
86+
updates[a, i, j, k],
87+
)
88+
return scattered
89+
90+
raise RuntimeError(
91+
f"Not implemented for indices.shape={indices.shape} and axis={axis}"
92+
)
93+
94+
95+
class ScatterElements(OpRun):
96+
def _run(self, data, indices, updates, axis=None, reduction=None): # type: ignore
97+
res = scatter_elements(data, indices, updates, axis=axis, reduction=reduction)
98+
return (res,)

0 commit comments

Comments
 (0)