
Commit e372822

[Executorch][llama] bug fix for custom sdpa for attention bias
Pull Request resolved: #10284. When using an attention bias, don't override the sequence length used for causal attention. ghstack-source-id: 279292323. @exported-using-ghexport. Differential Revision: [D73222733](https://our.internmc.facebook.com/intern/diff/D73222733/)
1 parent: db5e40b · commit: e372822
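The fix concerns which keys the custom SDPA may attend to. Below is a minimal pure-PyTorch sketch of the intended semantics; it is an illustration, not the ExecuTorch kernel, and the helper name sdpa_over_cache and the [batch, heads, seq, dim] layout are assumptions. In the causal path the valid keys are exactly the first start_pos + seq_len cache entries, but when an explicit attention bias is passed, the bias itself determines the attended keys and that causal cap must not be re-applied.

import torch
import torch.nn.functional as F

def sdpa_over_cache(q, k_cache, v_cache, start_pos, attn_mask=None):
    # q: [batch, heads, seq_len, dim]; caches: [batch, heads, max_seq_len, dim].
    seq_len = q.size(2)
    if attn_mask is None:
        # Causal path: only the first start_pos + seq_len cache entries are valid
        # keys, and query position i may look at keys 0 .. start_pos + i.
        num_keys = start_pos + seq_len
        mask = torch.tril(
            torch.ones(seq_len, num_keys, dtype=torch.bool, device=q.device),
            diagonal=start_pos,
        )
    else:
        # Attention-bias path (what this commit fixes): the mask decides which
        # keys are visible; do not re-impose the causal cap of start_pos + seq_len.
        num_keys = attn_mask.size(-1)
        mask = attn_mask
    k = k_cache[:, :, :num_keys, :]
    v = v_cache[:, :, :num_keys, :]
    return F.scaled_dot_product_attention(q, k, v, attn_mask=mask)

On this reading, the one-line change in op_sdpa.cpp below disables the causal cap whenever a mask is supplied, and the Python tests below slice the caches so the op only ever sees valid keys.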

3 files changed: +18, -292 lines

extension/llm/custom_ops/op_sdpa.cpp (+2, -1)

@@ -400,7 +400,8 @@ Tensor& custom_sdpa_out_impl(
 
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
-  const int64_t num_keys_for_causal_attention = start_pos + seq_len;
+  const int64_t num_keys_for_causal_attention =
+      attn_mask.has_value() ? -1 : start_pos + seq_len;
 
   ET_KERNEL_CHECK(
       ctx,
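Read as pseudocode, the change selects the key count like this (a hedged Python restatement; taking -1 as the kernel's "no causal cap" sentinel is an assumption inferred from the commit message):

def num_keys_for_causal_attention(start_pos: int, seq_len: int, has_attn_mask: bool) -> int:
    # With an explicit attention bias, do not force the causal key count to
    # start_pos + seq_len; -1 is assumed to mean "no causal cap" downstream.
    return -1 if has_attn_mask else start_pos + seq_len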

extension/llm/custom_ops/op_sdpa_with_kv_cache_test.cpp (-283)

@@ -524,289 +524,6 @@ TEST(OpScaledDotProductAttentionTest, LargerTest) {
   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_5, 1e-4, 1e-4);
 }
 
-TEST(OpScaledDotProductAttentionTest, BasicTestWithAttnMask) {
-  TensorFactory<executorch::aten::ScalarType::Float> tfFloat;
-
-  executorch::aten::Tensor query = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8823, 0.9150, 0.3829, 0.9593, 0.3904, 0.6009, 0.2566, 0.7936,
-       0.9408, 0.1332, 0.9346, 0.5936, 0.8694, 0.5677, 0.7411, 0.4294});
-  executorch::aten::Tensor key = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8854, 0.5739, 0.2666, 0.6274, 0.2696, 0.4414, 0.2969, 0.8317,
-       0.1053, 0.2695, 0.3588, 0.1994, 0.5472, 0.0062, 0.9516, 0.0753});
-  executorch::aten::Tensor value = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
-       0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});
-  executorch::aten::Tensor attn_mask = tfFloat.make({1, 1}, {0});
-  executorch::aten::Tensor key_cache_0 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor value_cache_0 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor key_cache_1 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor value_cache_1 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor key_cache_2 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor value_cache_2 = tfFloat.zeros({1, 5, 4, 4});
-  double dropout_p = 0;
-  bool is_causal = false;
-  executorch::aten::optional<double> scale;
-
-  // start pos: 0 layer id 0
-  executorch::aten::Tensor ret_expected_0 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
-       0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});
-
-  std::vector<int32_t> out_size = {1, 1, 4, 4};
-  executorch::aten::Tensor out = tfFloat.zeros(out_size);
-  executorch::aten::Tensor ret = op_sdpa_with_kv_cache(
-      query, key, value, key_cache_0, value_cache_0, 0, 1, attn_mask,
-      dropout_p, is_causal, scale, out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_0, 1e-4, 1e-4);
-
-  // start pos: 0 layer id 2
-  executorch::aten::Tensor ret_expected_1 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
-       0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query, key, value, key_cache_2, value_cache_2, 0, 1, attn_mask,
-      dropout_p, is_causal, scale, out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_1, 1e-4, 1e-4);
-
-  attn_mask = tfFloat.make({1, 2}, {0, 0});
-  // start pos: 1 layer id 0
-  executorch::aten::Tensor ret_expected_2 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860, 0.5832, 0.3376, 0.8090, 0.5779, 0.9040, 0.5547, 0.3423,
-       0.6343, 0.3644, 0.7104, 0.9464, 0.7890, 0.2814, 0.7886, 0.5895});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query, key, value, key_cache_0, value_cache_0, 1, 1, attn_mask,
-      dropout_p, is_causal, scale, out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_2, 1e-4, 1e-4);
-
-  // start pos: 1 layer id 1
-  executorch::aten::Tensor ret_expected_3 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.6486, 0.4270, 0.2472, 0.5922, 0.3669, 0.5740, 0.3522, 0.2173,
-       0.3635, 0.2088, 0.4071, 0.5423, 0.5110, 0.1822, 0.5107, 0.3817});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query, key, value, key_cache_1, value_cache_1, 1, 1, attn_mask,
-      dropout_p, is_causal, scale, out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_3, 1e-4, 1e-4);
-
-  attn_mask = tfFloat.make({1, 3}, {0, 0, 0});
-  // start pos: 2 layer id 1
-  executorch::aten::Tensor ret_expected_4 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.7490, 0.4930, 0.2854, 0.6838, 0.4489, 0.7021, 0.4308, 0.2659,
-       0.4622, 0.2655, 0.5176, 0.6895, 0.6202, 0.2212, 0.6199, 0.4634});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query, key, value, key_cache_1, value_cache_1, 2, 1, attn_mask,
-      dropout_p, is_causal, scale, out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_4, 1e-4, 1e-4);
-
-  // start pos: 2 layer id 2
-  executorch::aten::Tensor ret_expected_5 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.7490, 0.4930, 0.2854, 0.6838, 0.4489, 0.7021, 0.4308, 0.2659,
-       0.4622, 0.2655, 0.5176, 0.6895, 0.6202, 0.2212, 0.6199, 0.4634});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query, key, value, key_cache_2, value_cache_2, 2, 1, attn_mask,
-      dropout_p, is_causal, scale, out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_5, 1e-4, 1e-4);
-}
-
 TEST(OpScaledDotProductAttentionTest, SequenceTest) {
   TensorFactory<executorch::aten::ScalarType::Float> tfFloat;

extension/llm/custom_ops/test_sdpa_with_kv_cache.py (+16, -8)

@@ -67,12 +67,14 @@ def test_sdpa_with_cache_no_mqa_1(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,

@@ -108,12 +110,14 @@ def test_sdpa_with_cache_no_mqa_2(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,

@@ -150,12 +154,14 @@ def test_sdpa_with_cache_no_mqa_3(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,

@@ -191,12 +197,14 @@ def test_sdpa_with_cache_no_mqa_4(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,
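The test-side counterpart: because the masked path no longer caps the key count at start_pos + seq_len, the Python tests pass only the valid prefix of the cache so that zero-padded cache rows cannot leak into attention. A minimal sketch of that pattern follows; the [batch, max_seq_len, heads, dim] cache layout is taken from the tests above, while the all-zero additive bias of width start_pos + seq_len mirrors the removed C++ BasicTestWithAttnMask cases and is otherwise an assumption.

import torch

def masked_sdpa_inputs(k_cache, v_cache, start_pos, seq_len):
    # Only the first start_pos + seq_len cache rows hold real keys/values.
    sliced_k_cache = k_cache[:, : start_pos + seq_len, :, :]
    sliced_v_cache = v_cache[:, : start_pos + seq_len, :, :]
    # A zero additive bias over that width attends to every valid key.
    attn_mask = torch.zeros(seq_len, start_pos + seq_len)
    return sliced_k_cache, sliced_v_cache, attn_mask.contiguous()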
