Skip to content

Commit afc1a6c

Browse files
committed
Fix get_canonical_form_slice when lengths are numpy integers
Introduced in f9dfe70
1 parent 781073b commit afc1a6c

File tree

2 files changed

+6
-3
lines changed

2 files changed

+6
-3
lines changed

pytensor/tensor/subtensor.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -325,7 +325,7 @@ def analyze(x):
325325
and is_step_constant
326326
and is_length_constant
327327
):
328-
assert isinstance(length, int)
328+
assert isinstance(length, int | np.integer)
329329
_start, _stop, _step = slice(start, stop, step).indices(length)
330330
if _start <= _stop and _step >= 1:
331331
return slice(_start, _stop, _step), 1

tests/tensor/test_subtensor.py

+5-2
Original file line numberDiff line numberDiff line change
@@ -154,8 +154,11 @@ def test_symbolic_tensor(self):
154154
assert isinstance(res[0].owner.op, ptb.ScalarFromTensor)
155155
assert res[1] == 1
156156

157-
def test_all_integer(self):
158-
res = get_canonical_form_slice(slice(1, 5, 2), 7)
157+
@pytest.mark.parametrize("int_fn", [int, np.int64, as_tensor, as_scalar])
158+
def test_all_integer(self, int_fn):
159+
res = get_canonical_form_slice(
160+
slice(int_fn(1), int_fn(5), int_fn(2)), int_fn(7)
161+
)
159162
assert isinstance(res[0], slice)
160163
assert res[1] == 1
161164

0 commit comments

Comments
 (0)