Skip to content

Commit ada0b51

Browse files
committed
lint
1 parent 8aa1f28 commit ada0b51

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+218
-183
lines changed

_doc/examples/plot_benchmark_rf.py

+6-6
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,7 @@ def skl2onnx_convert_lightgbm(scope, operator, container):
4040
)
4141

4242
options = scope.get_options(operator.raw_operator)
43-
if "split" in options:
44-
operator.split = options["split"]
45-
else:
46-
operator.split = None
43+
operator.split = options.get("split", None)
4744
convert_lightgbm(scope, operator, container)
4845

4946

@@ -103,7 +100,7 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
103100
:return: number of runs, sum of the time, average, median
104101
"""
105102
times = []
106-
for n in range(repeat):
103+
for _n in range(repeat):
107104
perf = time.perf_counter()
108105
fct(X)
109106
delta = time.perf_counter() - perf
@@ -241,7 +238,10 @@ def measure_inference(fct, X, repeat, max_time=5, quantile=1):
241238
# onnxruntime
242239
bar.set_description(f"J={n_j} E={n_estimators} D={max_depth} predictO")
243240
r, t, mean, med = measure_inference(
244-
lambda x: sess.run(None, {"X": x}), X, repeat=repeat, max_time=max_time
241+
lambda x, sess=sess: sess.run(None, {"X": x}),
242+
X,
243+
repeat=repeat,
244+
max_time=max_time,
245245
)
246246
o2 = obs.copy()
247247
o2.update(dict(avg=mean, med=med, n_runs=r, ttime=t, name="ort_"))

_doc/examples/plot_onnxruntime.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -87,14 +87,14 @@ def loop(n=1000):
8787
x = np.random.randn(n, 2).astype(np.float32)
8888
y = np.random.randn(n, 2).astype(np.float32)
8989

90-
obs = measure_time(lambda: myloss(x, y))
90+
obs = measure_time(lambda x=x, y=y: myloss(x, y))
9191
obs["name"] = "numpy"
9292
obs["n"] = n
9393
data.append(obs)
9494

9595
xort = OrtTensor.from_array(x)
9696
yort = OrtTensor.from_array(y)
97-
obs = measure_time(lambda: ort_myloss(xort, yort))
97+
obs = measure_time(lambda xort=xort, yort=yort: ort_myloss(xort, yort))
9898
obs["name"] = "ort"
9999
obs["n"] = n
100100
data.append(obs)

_unittests/ut_light_api/test_backend_export.py

+5-4
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@
3131

3232
class ReferenceImplementationError(RuntimeError):
3333
"Fails, export cannot be compared."
34-
pass
3534

3635

3736
class ExportWrapper:
@@ -64,7 +63,8 @@ def run(
6463
expected = self.expected_sess.run(names, feeds)
6564
except (RuntimeError, AssertionError, TypeError, KeyError) as e:
6665
raise ReferenceImplementationError(
67-
f"ReferenceImplementation fails with {onnx_simple_text_plot(self.model)}"
66+
f"ReferenceImplementation fails with "
67+
f"{onnx_simple_text_plot(self.model)}"
6868
f"\n--RAW--\n{self.model}"
6969
) from e
7070

@@ -85,7 +85,7 @@ def run(
8585
new_code = "\n".join(
8686
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
8787
)
88-
raise AssertionError(f"ERROR {e}\n{new_code}")
88+
raise AssertionError(f"ERROR {e}\n{new_code}") # noqa: B904
8989

9090
locs = {
9191
"np": numpy,
@@ -154,7 +154,8 @@ def run(
154154
):
155155
if a.tolist() != b.tolist():
156156
raise AssertionError(
157-
f"Text discrepancies for api {api!r} with a.dtype={a.dtype} "
157+
f"Text discrepancies for api {api!r} "
158+
f"with a.dtype={a.dtype} "
158159
f"and b.dtype={b.dtype}"
159160
f"\n--BASE--\n{onnx_simple_text_plot(self.model)}"
160161
f"\n--EXP[{api}]--\n{onnx_simple_text_plot(export_model)}"

_unittests/ut_light_api/test_light_api.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -484,7 +484,7 @@ def g(self):
484484
def ah(self):
485485
return True
486486

487-
setattr(A, "h", ah)
487+
setattr(A, "h", ah) # noqa: B010
488488

489489
self.assertTrue(A().h())
490490
self.assertIn("(self)", str(inspect.signature(A.h)))

_unittests/ut_plotting/test_dot_plot.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
import os
32
import unittest
43

_unittests/ut_plotting/test_text_plot.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# -*- coding: utf-8 -*-
21
import os
32
import textwrap
43
import unittest

_unittests/ut_translate_api/test_translate.py

+8-2
Original file line numberDiff line numberDiff line change
@@ -160,8 +160,14 @@ def test_export_if(self):
160160
self.assertEqualArray(np.array([1], dtype=np.int64), got[0])
161161

162162
code = translate(onx)
163-
selse = "g().cst(np.array([0], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
164-
sthen = "g().cst(np.array([1], dtype=np.int64)).rename('Z').bring('Z').vout(elem_type=TensorProto.FLOAT)"
163+
selse = (
164+
"g().cst(np.array([0], dtype=np.int64)).rename('Z')."
165+
"bring('Z').vout(elem_type=TensorProto.FLOAT)"
166+
)
167+
sthen = (
168+
"g().cst(np.array([1], dtype=np.int64)).rename('Z')."
169+
"bring('Z').vout(elem_type=TensorProto.FLOAT)"
170+
)
165171
expected = dedent(
166172
f"""
167173
(

_unittests/ut_translate_api/test_translate_classic.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -221,8 +221,12 @@ def test_topk_reverse(self):
221221
sorted=1
222222
)
223223
)
224-
outputs.append(make_tensor_value_info('Values', TensorProto.FLOAT, shape=[]))
225-
outputs.append(make_tensor_value_info('Indices', TensorProto.FLOAT, shape=[]))
224+
outputs.append(
225+
make_tensor_value_info('Values', TensorProto.FLOAT, shape=[])
226+
)
227+
outputs.append(
228+
make_tensor_value_info('Indices', TensorProto.FLOAT, shape=[])
229+
)
226230
graph = make_graph(
227231
nodes,
228232
'light_api',
@@ -252,7 +256,7 @@ def test_fft(self):
252256
new_code = "\n".join(
253257
[f"{i+1:04} {line}" for i, line in enumerate(code.split("\n"))]
254258
)
255-
raise AssertionError(f"ERROR {e}\n{new_code}")
259+
raise AssertionError(f"ERROR {e}\n{new_code}") # noqa: B904
256260

257261
def test_aionnxml(self):
258262
onx = (

_unittests/ut_validation/test_f8.py

+11-13
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def test_fe5m2_to_float32_paper(self):
8888
self.assertEqual(fe5m2_to_float32(int("11111100", 2)), -numpy.inf)
8989

9090
def test_fe4m3fn_to_float32_all(self):
91-
for i in range(0, 256):
91+
for i in range(256):
9292
a = fe4m3_to_float32_float(i)
9393
b = fe4m3_to_float32(i)
9494
if numpy.isnan(a):
@@ -97,7 +97,7 @@ def test_fe4m3fn_to_float32_all(self):
9797
self.assertEqual(a, b)
9898

9999
def test_fe4m3fn_to_float32_all_ml_types(self):
100-
for i in range(0, 256):
100+
for i in range(256):
101101
a = fe4m3_to_float32_float(i)
102102
b = fe4m3_to_float32(i)
103103
c = new_cvt_float32_to_e4m3fn(b)
@@ -188,7 +188,7 @@ def test_search_float32_into_fe5m2_simple(self):
188188
self.assertEqual(b1, b2)
189189

190190
def test_search_float32_into_fe4m3fn_equal(self):
191-
values = [(fe4m3_to_float32_float(i), i) for i in range(0, 256)]
191+
values = [(fe4m3_to_float32_float(i), i) for i in range(256)]
192192
values.sort()
193193

194194
for value, expected in values:
@@ -208,7 +208,7 @@ def test_search_float32_into_fe4m3fn_equal(self):
208208
self.assertIn(nf, (0, 128))
209209

210210
def test_search_float32_into_fe5m2_equal(self):
211-
values = [(fe5m2_to_float32_float(i), i) for i in range(0, 256)]
211+
values = [(fe5m2_to_float32_float(i), i) for i in range(256)]
212212
values.sort()
213213

214214
for value, expected in values:
@@ -233,7 +233,7 @@ def test_search_float32_into_fe5m2_equal(self):
233233
self.assertEqual(fe5m2_to_float32(nf), float(cf))
234234

235235
def test_search_float32_into_fe4m3fn(self):
236-
values = [(fe4m3_to_float32_float(i), i) for i in range(0, 256)]
236+
values = [(fe4m3_to_float32_float(i), i) for i in range(256)]
237237
values.sort()
238238

239239
obs = []
@@ -308,7 +308,7 @@ def test_search_float32_into_fe4m3fn(self):
308308
)
309309

310310
def test_search_float32_into_fe5m2(self):
311-
values = [(fe5m2_to_float32_float(i), i) for i in range(0, 256)]
311+
values = [(fe5m2_to_float32_float(i), i) for i in range(256)]
312312
values.sort()
313313

314314
obs = []
@@ -651,7 +651,7 @@ def test_search_float32_into_fe5m2fnuz_simple(self):
651651
self.assertEqual(expected, got)
652652

653653
def test_fe4m3fnuz_to_float32_all(self):
654-
for i in range(0, 256):
654+
for i in range(256):
655655
a = fe4m3_to_float32_float(i, uz=True)
656656
b = fe4m3_to_float32(i, uz=True)
657657
if numpy.isnan(a):
@@ -660,7 +660,7 @@ def test_fe4m3fnuz_to_float32_all(self):
660660
self.assertEqual(a, b)
661661

662662
def test_fe5m2fnuz_to_float32_all(self):
663-
for i in range(0, 256):
663+
for i in range(256):
664664
a = fe5m2_to_float32_float(i, fn=True, uz=True)
665665
b = fe5m2_to_float32(i, fn=True, uz=True)
666666
if numpy.isnan(a):
@@ -669,7 +669,7 @@ def test_fe5m2fnuz_to_float32_all(self):
669669
self.assertEqual(a, b)
670670

671671
def test_search_float32_into_fe4m3fnuz(self):
672-
values = [(fe4m3_to_float32_float(i, uz=True), i) for i in range(0, 256)]
672+
values = [(fe4m3_to_float32_float(i, uz=True), i) for i in range(256)]
673673
values.sort()
674674

675675
obs = []
@@ -715,9 +715,7 @@ def test_search_float32_into_fe4m3fnuz(self):
715715
)
716716

717717
def test_search_float32_into_fe5m2fnuz(self):
718-
values = [
719-
(fe5m2_to_float32_float(i, fn=True, uz=True), i) for i in range(0, 256)
720-
]
718+
values = [(fe5m2_to_float32_float(i, fn=True, uz=True), i) for i in range(256)]
721719
values.sort()
722720

723721
obs = []
@@ -1235,7 +1233,7 @@ def test_nan(self):
12351233
expected,
12361234
)
12371235
]
1238-
for i in range(0, 23):
1236+
for i in range(23):
12391237
v = 0x7F800000 | (1 << i)
12401238
f = numpy.uint32(v).view(numpy.float32)
12411239
values.append((i, v, f, expected))

_unittests/ut_xrun_doc/test_documentation_examples.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ def run_test(self, fold: str, name: str, verbose=0) -> int:
4949
if verbose:
5050
print(f"failed: {name!r} due to missing dot.")
5151
return 0
52-
raise AssertionError(
52+
raise AssertionError( # noqa: B904
5353
"Example '{}' (cmd: {} - exec_prefix='{}') "
5454
"failed due to\n{}"
5555
"".format(name, cmds, sys.exec_prefix, st)

onnx_array_api/__init__.py

-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
# coding: utf-8
21
"""
32
APIs to create ONNX Graphs.
43
"""

onnx_array_api/_command_lines_parser.py

+5-3
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ def get_main_parser() -> ArgumentParser:
1818
help=dedent(
1919
"""
2020
Selects a command.
21-
21+
2222
'translate' exports an onnx graph into a piece of code replicating it,
2323
'compare' compares the execution of two onnx models,
24-
'replace' replaces constant and initliazers by ConstantOfShape to make the model lighter
24+
'replace' replaces constant and initliazers by ConstantOfShape
25+
to make the model lighter
2526
"""
2627
),
2728
)
@@ -75,7 +76,8 @@ def get_parser_compare() -> ArgumentParser:
7576
Compares the execution of two onnx models.
7677
"""
7778
),
78-
epilog="This is used when two models are different but should produce the same results.",
79+
epilog="This is used when two models are different but "
80+
"should produce the same results.",
7981
)
8082
parser.add_argument(
8183
"-m1",

onnx_array_api/_helpers.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -41,5 +41,5 @@ def np_dtype_to_tensor_dtype(dtype: Any):
4141
elif dtype is float:
4242
dt = TensorProto.DOUBLE
4343
else:
44-
raise KeyError(f"Unable to guess type for dtype={dtype}.")
44+
raise KeyError(f"Unable to guess type for dtype={dtype}.") # noqa: B904
4545
return dt

onnx_array_api/array_api/__init__.py

+12-9
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,8 @@ def _finfo(dtype):
5151
d[k] = v
5252
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
5353
nres = type("finfo", (res.__class__,), d)
54-
setattr(nres, "smallest_normal", float(res.smallest_normal))
55-
setattr(nres, "tiny", float(res.tiny))
54+
setattr(nres, "smallest_normal", float(res.smallest_normal)) # noqa: B010
55+
setattr(nres, "tiny", float(res.tiny)) # noqa: B010
5656
return nres
5757

5858

@@ -84,8 +84,8 @@ def _iinfo(dtype):
8484
d[k] = v
8585
d["dtype"] = DType(np_dtype_to_tensor_dtype(dt))
8686
nres = type("iinfo", (res.__class__,), d)
87-
setattr(nres, "min", int(res.min))
88-
setattr(nres, "max", int(res.max))
87+
setattr(nres, "min", int(res.min)) # noqa: B010
88+
setattr(nres, "max", int(res.max)) # noqa: B010
8989
return nres
9090

9191

@@ -133,10 +133,10 @@ def _finalize_array_api(module, function_names, TEagerTensor):
133133
module.uint32 = DType(TensorProto.UINT32)
134134
module.uint64 = DType(TensorProto.UINT64)
135135
module.bfloat16 = DType(TensorProto.BFLOAT16)
136-
setattr(module, "bool", DType(TensorProto.BOOL))
137-
setattr(module, "str", DType(TensorProto.STRING))
138-
setattr(module, "finfo", _finfo)
139-
setattr(module, "iinfo", _iinfo)
136+
setattr(module, "bool", DType(TensorProto.BOOL)) # noqa: B010
137+
setattr(module, "str", DType(TensorProto.STRING)) # noqa: B010
138+
setattr(module, "finfo", _finfo) # noqa: B010
139+
setattr(module, "iinfo", _iinfo) # noqa: B010
140140

141141
if function_names is None:
142142
function_names = supported_functions
@@ -146,7 +146,10 @@ def _finalize_array_api(module, function_names, TEagerTensor):
146146
if f is None:
147147
f2 = getattr(npx_functions, name, None)
148148
if f2 is None:
149-
warnings.warn(f"Function {name!r} is not available in {module!r}.")
149+
warnings.warn(
150+
f"Function {name!r} is not available in {module!r}.",
151+
stacklevel=0,
152+
)
150153
continue
151154
f = lambda TEagerTensor, *args, _f=f2, **kwargs: _f( # noqa: E731
152155
*args, **kwargs

onnx_array_api/array_api/_onnx_common.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,13 @@ def asarray(
9292
elif isinstance(a, str):
9393
v = TEagerTensor(np.array(a, dtype=np.str_))
9494
elif isinstance(a, list):
95-
if all(map(lambda x: isinstance(x, bool), a)):
95+
if all(isinstance(x, bool) for x in a):
9696
v = TEagerTensor(np.array(a, dtype=np.bool_))
97-
elif all(map(lambda x: isinstance(x, int), a)):
97+
elif all(isinstance(x, int) for x in a):
9898
try:
9999
cvt = np.array(a, dtype=np.int64)
100100
except OverflowError as e:
101-
if all(map(lambda x: x >= 0, a)):
101+
if all(x >= 0 for x in a):
102102
cvt = np.array(a, dtype=np.uint64)
103103
else:
104104
raise e
@@ -127,9 +127,7 @@ def arange(
127127
step: EagerTensor[OptTensorType[ElemType.int64, "I", (1,)]] = None,
128128
dtype: OptParType[DType] = None,
129129
) -> EagerTensor[TensorType[ElemType.numerics, "T"]]:
130-
use_float = any(
131-
map(lambda x: isinstance(x, float), [start_or_stop, stop_or_step, step])
132-
)
130+
use_float = any(isinstance(x, float) for x in [start_or_stop, stop_or_step, step])
133131
if isinstance(start_or_stop, int):
134132
start_or_stop = TEagerTensor(
135133
np.array([start_or_stop], dtype=np.float64 if use_float else np.int64)
@@ -207,7 +205,7 @@ def eye(
207205
/,
208206
*,
209207
k: ParType[int] = 0,
210-
dtype: ParType[DType] = DType(TensorProto.DOUBLE),
208+
dtype: ParType[DType] = DType(TensorProto.DOUBLE), # noqa: B008
211209
):
212210
if isinstance(n_rows, int):
213211
n_rows = TEagerTensor(np.array(n_rows, dtype=np.int64))
@@ -245,7 +243,7 @@ def linspace(
245243
dtype: OptParType[DType] = None,
246244
endpoint: ParType[int] = 1,
247245
) -> EagerTensor[TensorType[ElemType.numerics, "T"]]:
248-
use_float = any(map(lambda x: isinstance(x, float), [start, stop]))
246+
use_float = any(isinstance(x, float) for x in [start, stop])
249247
if isinstance(start, int):
250248
start = TEagerTensor(
251249
np.array(start, dtype=np.float64 if use_float else np.int64)

0 commit comments

Comments
 (0)