Skip to content

Commit a906010

Browse files
authored
Documentation (#78)
* update requirements * Add ConstantOfShape to light API * add slice * changelogs * k
1 parent 2dd0686 commit a906010

File tree

4 files changed

+38
-2
lines changed

4 files changed

+38
-2
lines changed

CHANGELOGS.rst

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ Change Logs
44
0.2.0
55
+++++
66

7+
* :pr:`77`: supports ConcatOfShape and Slice with the light API
78
* :pr:`76`: add a mode to compare models without execution
89
* :pr:`75`: add QuickGelu to ExtendedReferenceEvaluator
910
* :pr:`71`: adds tools to compare two onnx graphs

_unittests/ut_light_api/test_light_api.py

+29-1
Original file line numberDiff line numberDiff line change
@@ -538,7 +538,35 @@ def test_constant_of_shape(self):
538538
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
539539
self.assertEqualArray(np.zeros((2, 3), dtype=np.float32), got)
540540

541+
def test_constant_of_shape_value(self):
542+
onx = (
543+
start()
544+
.vin("X", TensorProto.INT64, shape=[None, None])
545+
.ConstantOfShape(value=np.array([1], dtype=np.float32))
546+
.vout(shape=[])
547+
.to_onnx()
548+
)
549+
ref = ReferenceEvaluator(onx)
550+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
551+
self.assertEqualArray(np.ones((2, 3), dtype=np.float32), got)
552+
553+
def test_slice(self):
554+
onx = (
555+
start(opset=18, ir_version=9)
556+
.cst(np.array([1], dtype=np.int64), name="one")
557+
.cst(np.array([2], dtype=np.int64), name="two")
558+
.vin("X", TensorProto.INT64, shape=[None, None])
559+
.ConstantOfShape(value=np.array([1], dtype=np.float32))
560+
.rename("CX")
561+
.bring("CX", "one", "two", "one")
562+
.Slice()
563+
.vout(shape=[])
564+
.to_onnx()
565+
)
566+
ref = ReferenceEvaluator(onx)
567+
got = ref.run(None, {"X": np.array([2, 3], dtype=np.int64)})[0]
568+
self.assertEqualArray(np.ones((2, 1), dtype=np.float32), got)
569+
541570

542571
if __name__ == "__main__":
543-
TestLightApi().test_add()
544572
unittest.main(verbosity=2)

onnx_array_api/light_api/_op_var.py

+7
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,13 @@ def Selu(
314314
def Shrink(self, bias: float = 0.0, lambd: float = 0.5) -> "Var":
315315
return self.make_node("Shrink", self, bias=bias, lambd=lambd)
316316

317+
def Slice(
318+
self, starts: "Var", ends: "Var", axes: "Var", steps: Optional["Var"] = None
319+
) -> "Var":
320+
if steps is None:
321+
return self.make_node("Slice", self, starts, ends, axes)
322+
return self.make_node("Slice", self, starts, ends, axes, steps)
323+
317324
def Softmax(self, axis: int = -1) -> "Var":
318325
return self.make_node("Softmax", self, axis=axis)
319326

onnx_array_api/light_api/model.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -406,7 +406,7 @@ def to_onnx(self) -> GRAPH_PROTO:
406406
return graph
407407
model = make_model(graph, opset_imports=opsets)
408408
if self.ir_version:
409-
model.ir_version = ir_version
409+
model.ir_version = self.ir_version
410410
if not is_windows() or not is_azure():
411411
# check_model fails sometimes on Windows
412412
check_model(model)

0 commit comments

Comments
 (0)