@@ -410,6 +410,12 @@ def _contraction_list_from_path(
410
410
return contraction_list
411
411
412
412
413
+ def _right_to_left_path (n : int ) -> tuple [tuple [int , int ], ...]:
414
+ # Create a right to left contraction path
415
+ # if n = 5, out = ((4, 3), (3, 2), (2, 1), (1, 0))
416
+ return tuple (pairwise (reversed (range (n ))))
417
+
418
+
413
419
def einsum (subscripts : str , * operands : "TensorLike" , optimize = None ) -> TensorVariable :
414
420
"""
415
421
Multiplication and summation of tensors using the Einstein summation convention.
@@ -563,7 +569,7 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
563
569
else :
564
570
# By default, we try right to left because we assume that most graphs
565
571
# have a lower dimensional rightmost operand
566
- path = tuple ( pairwise ( reversed ( range ( len (tensor_operands ))) ))
572
+ path = _right_to_left_path ( len (tensor_operands ))
567
573
contraction_list = _contraction_list_from_path (
568
574
subscripts , tensor_operands , path
569
575
)
@@ -581,7 +587,18 @@ def einsum(subscripts: str, *operands: "TensorLike", optimize=None) -> TensorVar
581
587
einsum_call = True , # Not part of public API
582
588
optimize = "optimal" ,
583
589
) # type: ignore
584
- path = tuple (contraction [0 ] for contraction in contraction_list )
590
+ np_path = tuple (contraction [0 ] for contraction in contraction_list )
591
+
592
+ if len (np_path ) == 1 and len (np_path [0 ]) > 2 :
593
+ # When there's nothing to optimize, einsum_path reduces all entries simultaneously instead of doing
594
+ # pairwise reductions, which our implementation below demands.
595
+ path = _right_to_left_path (len (tensor_operands ))
596
+ contraction_list = _contraction_list_from_path (
597
+ subscripts , tensor_operands , path
598
+ )
599
+ else :
600
+ path = np_path
601
+
585
602
optimized = True
586
603
587
604
def removechars (s , chars ):
@@ -744,7 +761,7 @@ def filter_singleton_dims(operand, names, other_operand, other_names):
744
761
)
745
762
else :
746
763
raise ValueError (
747
- f"Each step of einsum must have 1 or 2 operands, got { len (operand_indices )} "
764
+ f"Each step of einsum must have 1 or 2 operands, got { len (operand_indices )} , { path = } . "
748
765
)
749
766
750
767
# the resulting 'operand' with axis labels 'names' should be a permutation of the desired result
0 commit comments