Skip to content

Commit c22e79e

Browse files
committed
Fix bug in einsum
When there is nothing to optimize, a shortcut in the NumPy implementation of einsum_path creates a default path that can combine more than 2 operands in a single step. Our implementation only supports steps that involve 1 or 2 operands. https://github.com/numpy/numpy/blob/cc5851e654bfd82a23f2758be4bd224be84fc1c3/numpy/_core/einsumfunc.py#L945-L951
1 parent 8bb2038 commit c22e79e

File tree

2 files changed

+39
-3
lines changed

2 files changed

+39
-3
lines changed

Diff for: pytensor/tensor/einsum.py

+20-3
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,12 @@ def _contraction_list_from_path(
410410
return contraction_list
411411

412412

413+
def _right_to_left_path(n: int) -> tuple[tuple[int, int], ...]:
414+
# Create a right to left contraction path
415+
# if n = 5, out = ((4, 3), (3, 2), (2, 1), (1, 0))
416+
return tuple(pairwise(reversed(range(n))))
417+
418+
413419
def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVariable:
414420
"""
415421
Multiplication and summation of tensors using the Einstein summation convention.
@@ -563,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
563569
else:
564570
# By default, we try right to left because we assume that most graphs
565571
# have a lower dimensional rightmost operand
566-
path = tuple(pairwise(reversed(range(len(tensor_operands)))))
572+
path = _right_to_left_path(len(tensor_operands))
567573
contraction_list = _contraction_list_from_path(
568574
subscripts, tensor_operands, path
569575
)
@@ -581,7 +587,18 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
581587
einsum_call=True, # Not part of public API
582588
optimize="optimal",
583589
) # type: ignore
584-
path = tuple(contraction[0] for contraction in contraction_list)
590+
np_path = tuple(contraction[0] for contraction in contraction_list)
591+
592+
if len(np_path) == 1 and len(np_path[0]) > 2:
593+
# When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
594+
# pairwise reductions, which our implementation below demands.
595+
path = _right_to_left_path(len(tensor_operands))
596+
contraction_list = _contraction_list_from_path(
597+
subscripts, tensor_operands, path
598+
)
599+
else:
600+
path = np_path
601+
585602
optimized = True
586603

587604
def removechars(s, chars):
@@ -744,7 +761,7 @@ def filter_singleton_dims(operand, names, other_operand, other_names):
744761
)
745762
else:
746763
raise ValueError(
747-
f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}"
764+
f"Each step of einsum must have 1 or 2 operands, got {len(operand_indices)}, {path=}."
748765
)
749766

750767
# the resulting 'operand' with axis labels 'names' should be a permutation of the desired result

Diff for: tests/tensor/test_einsum.py

+19
Original file line numberDiff line numberDiff line change
@@ -262,3 +262,22 @@ def test_broadcastable_dims():
262262
atol = 1e-12 if config.floatX == "float64" else 1e-2
263263
np.testing.assert_allclose(suboptimal_eval, np_eval, atol=atol)
264264
np.testing.assert_allclose(optimal_eval, np_eval, atol=atol)
265+
266+
267+
@pytest.mark.parametrize("static_length", [False, True])
def test_threeway_mul(static_length):
    """Regression test for https://github.com/pymc-devs/pytensor/issues/1184.

    Builds an ellipsis-only einsum over three 1-d operands, once with a
    static length of 3 and once with an unknown length, and checks that the
    evaluated result matches the elementwise product 1 * 2 * 3 = 6.
    """
    shape = (3,) if static_length else (None,)
    # Create the three symbolic inputs in x, y, z order.
    x, y, z = (tensor(name, shape=shape) for name in ("x", "y", "z"))
    out = einsum("..., ..., ... -> ...", x, y, z)

    # Concrete inputs: all ones, all twos and all threes.
    x_test = np.ones((3,), dtype=x.dtype)
    y_test = x_test + 1
    z_test = x_test + 2
    feed = {x: x_test, y: y_test, z: z_test}

    expected = np.full((3,), fill_value=6)
    np.testing.assert_allclose(out.eval(feed), expected)

0 commit comments

Comments
 (0)