@@ -524,289 +524,6 @@ TEST(OpScaledDotProductAttentionTest, LargerTest) {
524
524
EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_5, 1e-4 , 1e-4 );
525
525
}
526
526
527
- TEST (OpScaledDotProductAttentionTest, BasicTestWithAttnMask) {
528
- TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
529
-
530
- executorch::aten::Tensor query = tfFloat.make (
531
- {1 , 1 , 4 , 4 },
532
- {0.8823 ,
533
- 0.9150 ,
534
- 0.3829 ,
535
- 0.9593 ,
536
- 0.3904 ,
537
- 0.6009 ,
538
- 0.2566 ,
539
- 0.7936 ,
540
- 0.9408 ,
541
- 0.1332 ,
542
- 0.9346 ,
543
- 0.5936 ,
544
- 0.8694 ,
545
- 0.5677 ,
546
- 0.7411 ,
547
- 0.4294 });
548
- executorch::aten::Tensor key = tfFloat.make (
549
- {1 , 1 , 4 , 4 },
550
- {0.8854 ,
551
- 0.5739 ,
552
- 0.2666 ,
553
- 0.6274 ,
554
- 0.2696 ,
555
- 0.4414 ,
556
- 0.2969 ,
557
- 0.8317 ,
558
- 0.1053 ,
559
- 0.2695 ,
560
- 0.3588 ,
561
- 0.1994 ,
562
- 0.5472 ,
563
- 0.0062 ,
564
- 0.9516 ,
565
- 0.0753 });
566
- executorch::aten::Tensor value = tfFloat.make (
567
- {1 , 1 , 4 , 4 },
568
- {0.8860 ,
569
- 0.5832 ,
570
- 0.3376 ,
571
- 0.8090 ,
572
- 0.5779 ,
573
- 0.9040 ,
574
- 0.5547 ,
575
- 0.3423 ,
576
- 0.6343 ,
577
- 0.3644 ,
578
- 0.7104 ,
579
- 0.9464 ,
580
- 0.7890 ,
581
- 0.2814 ,
582
- 0.7886 ,
583
- 0.5895 });
584
- executorch::aten::Tensor attn_mask = tfFloat.make ({1 , 1 }, {0 });
585
- executorch::aten::Tensor key_cache_0 = tfFloat.zeros ({1 , 5 , 4 , 4 });
586
- executorch::aten::Tensor value_cache_0 = tfFloat.zeros ({1 , 5 , 4 , 4 });
587
- executorch::aten::Tensor key_cache_1 = tfFloat.zeros ({1 , 5 , 4 , 4 });
588
- executorch::aten::Tensor value_cache_1 = tfFloat.zeros ({1 , 5 , 4 , 4 });
589
- executorch::aten::Tensor key_cache_2 = tfFloat.zeros ({1 , 5 , 4 , 4 });
590
- executorch::aten::Tensor value_cache_2 = tfFloat.zeros ({1 , 5 , 4 , 4 });
591
- double dropout_p = 0 ;
592
- bool is_causal = false ;
593
- executorch::aten::optional<double > scale;
594
-
595
- // start pos: 0 layer id 0
596
- executorch::aten::Tensor ret_expected_0 = tfFloat.make (
597
- {1 , 1 , 4 , 4 },
598
- {0.8860 ,
599
- 0.5832 ,
600
- 0.3376 ,
601
- 0.8090 ,
602
- 0.5779 ,
603
- 0.9040 ,
604
- 0.5547 ,
605
- 0.3423 ,
606
- 0.6343 ,
607
- 0.3644 ,
608
- 0.7104 ,
609
- 0.9464 ,
610
- 0.7890 ,
611
- 0.2814 ,
612
- 0.7886 ,
613
- 0.5895 });
614
-
615
- std::vector<int32_t > out_size = {1 , 1 , 4 , 4 };
616
- executorch::aten::Tensor out = tfFloat.zeros (out_size);
617
- executorch::aten::Tensor ret = op_sdpa_with_kv_cache (
618
- query,
619
- key,
620
- value,
621
- key_cache_0,
622
- value_cache_0,
623
- 0 ,
624
- 1 ,
625
- attn_mask,
626
- dropout_p,
627
- is_causal,
628
- scale,
629
- out);
630
- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_0, 1e-4 , 1e-4 );
631
-
632
- // start pos: 0 layer id 2
633
- executorch::aten::Tensor ret_expected_1 = tfFloat.make (
634
- {1 , 1 , 4 , 4 },
635
- {0.8860 ,
636
- 0.5832 ,
637
- 0.3376 ,
638
- 0.8090 ,
639
- 0.5779 ,
640
- 0.9040 ,
641
- 0.5547 ,
642
- 0.3423 ,
643
- 0.6343 ,
644
- 0.3644 ,
645
- 0.7104 ,
646
- 0.9464 ,
647
- 0.7890 ,
648
- 0.2814 ,
649
- 0.7886 ,
650
- 0.5895 });
651
- out = tfFloat.zeros (out_size);
652
- ret = op_sdpa_with_kv_cache (
653
- query,
654
- key,
655
- value,
656
- key_cache_2,
657
- value_cache_2,
658
- 0 ,
659
- 1 ,
660
- attn_mask,
661
- dropout_p,
662
- is_causal,
663
- scale,
664
- out);
665
- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_1, 1e-4 , 1e-4 );
666
-
667
- attn_mask = tfFloat.make ({1 , 2 }, {0 , 0 });
668
- // start pos: 1 layer id 0
669
- executorch::aten::Tensor ret_expected_2 = tfFloat.make (
670
- {1 , 1 , 4 , 4 },
671
- {0.8860 ,
672
- 0.5832 ,
673
- 0.3376 ,
674
- 0.8090 ,
675
- 0.5779 ,
676
- 0.9040 ,
677
- 0.5547 ,
678
- 0.3423 ,
679
- 0.6343 ,
680
- 0.3644 ,
681
- 0.7104 ,
682
- 0.9464 ,
683
- 0.7890 ,
684
- 0.2814 ,
685
- 0.7886 ,
686
- 0.5895 });
687
- out = tfFloat.zeros (out_size);
688
- ret = op_sdpa_with_kv_cache (
689
- query,
690
- key,
691
- value,
692
- key_cache_0,
693
- value_cache_0,
694
- 1 ,
695
- 1 ,
696
- attn_mask,
697
- dropout_p,
698
- is_causal,
699
- scale,
700
- out);
701
- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_2, 1e-4 , 1e-4 );
702
-
703
- // start pos: 1 layer id 1
704
- executorch::aten::Tensor ret_expected_3 = tfFloat.make (
705
- {1 , 1 , 4 , 4 },
706
- {0.6486 ,
707
- 0.4270 ,
708
- 0.2472 ,
709
- 0.5922 ,
710
- 0.3669 ,
711
- 0.5740 ,
712
- 0.3522 ,
713
- 0.2173 ,
714
- 0.3635 ,
715
- 0.2088 ,
716
- 0.4071 ,
717
- 0.5423 ,
718
- 0.5110 ,
719
- 0.1822 ,
720
- 0.5107 ,
721
- 0.3817 });
722
- out = tfFloat.zeros (out_size);
723
- ret = op_sdpa_with_kv_cache (
724
- query,
725
- key,
726
- value,
727
- key_cache_1,
728
- value_cache_1,
729
- 1 ,
730
- 1 ,
731
- attn_mask,
732
- dropout_p,
733
- is_causal,
734
- scale,
735
- out);
736
- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_3, 1e-4 , 1e-4 );
737
-
738
- attn_mask = tfFloat.make ({1 , 3 }, {0 , 0 , 0 });
739
- // start pos: 2 layer id 1
740
- executorch::aten::Tensor ret_expected_4 = tfFloat.make (
741
- {1 , 1 , 4 , 4 },
742
- {0.7490 ,
743
- 0.4930 ,
744
- 0.2854 ,
745
- 0.6838 ,
746
- 0.4489 ,
747
- 0.7021 ,
748
- 0.4308 ,
749
- 0.2659 ,
750
- 0.4622 ,
751
- 0.2655 ,
752
- 0.5176 ,
753
- 0.6895 ,
754
- 0.6202 ,
755
- 0.2212 ,
756
- 0.6199 ,
757
- 0.4634 });
758
- out = tfFloat.zeros (out_size);
759
- ret = op_sdpa_with_kv_cache (
760
- query,
761
- key,
762
- value,
763
- key_cache_1,
764
- value_cache_1,
765
- 2 ,
766
- 1 ,
767
- attn_mask,
768
- dropout_p,
769
- is_causal,
770
- scale,
771
- out);
772
- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_4, 1e-4 , 1e-4 );
773
-
774
- // start pos: 2 layer id 2
775
- executorch::aten::Tensor ret_expected_5 = tfFloat.make (
776
- {1 , 1 , 4 , 4 },
777
- {0.7490 ,
778
- 0.4930 ,
779
- 0.2854 ,
780
- 0.6838 ,
781
- 0.4489 ,
782
- 0.7021 ,
783
- 0.4308 ,
784
- 0.2659 ,
785
- 0.4622 ,
786
- 0.2655 ,
787
- 0.5176 ,
788
- 0.6895 ,
789
- 0.6202 ,
790
- 0.2212 ,
791
- 0.6199 ,
792
- 0.4634 });
793
- out = tfFloat.zeros (out_size);
794
- ret = op_sdpa_with_kv_cache (
795
- query,
796
- key,
797
- value,
798
- key_cache_2,
799
- value_cache_2,
800
- 2 ,
801
- 1 ,
802
- attn_mask,
803
- dropout_p,
804
- is_causal,
805
- scale,
806
- out);
807
- EXPECT_TENSOR_CLOSE_WITH_TOL (ret, ret_expected_5, 1e-4 , 1e-4 );
808
- }
809
-
810
527
TEST (OpScaledDotProductAttentionTest, SequenceTest) {
811
528
TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
812
529
0 commit comments