diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp
index a6f80a0d66d..4a2c464eb56 100644
--- a/extension/llm/custom_ops/op_sdpa.cpp
+++ b/extension/llm/custom_ops/op_sdpa.cpp
@@ -400,7 +400,8 @@ Tensor& custom_sdpa_out_impl(
 
   ET_CHECK_MSG(q.dim() == 4, "query must be a 4D tensor");
 
-  const int64_t num_keys_for_causal_attention = start_pos + seq_len;
+  const int64_t num_keys_for_causal_attention =
+      attn_mask.has_value() ? -1 : start_pos + seq_len;
 
   ET_KERNEL_CHECK(
       ctx,
diff --git a/extension/llm/custom_ops/op_sdpa_with_kv_cache_test.cpp b/extension/llm/custom_ops/op_sdpa_with_kv_cache_test.cpp
index 435cf44e66f..6c0496af32d 100644
--- a/extension/llm/custom_ops/op_sdpa_with_kv_cache_test.cpp
+++ b/extension/llm/custom_ops/op_sdpa_with_kv_cache_test.cpp
@@ -524,289 +524,6 @@ TEST(OpScaledDotProductAttentionTest, LargerTest) {
   EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_5, 1e-4, 1e-4);
 }
 
-TEST(OpScaledDotProductAttentionTest, BasicTestWithAttnMask) {
-  TensorFactory<ScalarType::Float> tfFloat;
-
-  executorch::aten::Tensor query = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8823,
-       0.9150,
-       0.3829,
-       0.9593,
-       0.3904,
-       0.6009,
-       0.2566,
-       0.7936,
-       0.9408,
-       0.1332,
-       0.9346,
-       0.5936,
-       0.8694,
-       0.5677,
-       0.7411,
-       0.4294});
-  executorch::aten::Tensor key = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8854,
-       0.5739,
-       0.2666,
-       0.6274,
-       0.2696,
-       0.4414,
-       0.2969,
-       0.8317,
-       0.1053,
-       0.2695,
-       0.3588,
-       0.1994,
-       0.5472,
-       0.0062,
-       0.9516,
-       0.0753});
-  executorch::aten::Tensor value = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860,
-       0.5832,
-       0.3376,
-       0.8090,
-       0.5779,
-       0.9040,
-       0.5547,
-       0.3423,
-       0.6343,
-       0.3644,
-       0.7104,
-       0.9464,
-       0.7890,
-       0.2814,
-       0.7886,
-       0.5895});
-  executorch::aten::Tensor attn_mask = tfFloat.make({1, 1}, {0});
-  executorch::aten::Tensor key_cache_0 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor value_cache_0 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor key_cache_1 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor value_cache_1 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor key_cache_2 = tfFloat.zeros({1, 5, 4, 4});
-  executorch::aten::Tensor value_cache_2 = tfFloat.zeros({1, 5, 4, 4});
-  double dropout_p = 0;
-  bool is_causal = false;
-  executorch::aten::optional<double> scale;
-
-  // start pos: 0 layer id 0
-  executorch::aten::Tensor ret_expected_0 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860,
-       0.5832,
-       0.3376,
-       0.8090,
-       0.5779,
-       0.9040,
-       0.5547,
-       0.3423,
-       0.6343,
-       0.3644,
-       0.7104,
-       0.9464,
-       0.7890,
-       0.2814,
-       0.7886,
-       0.5895});
-
-  std::vector<int32_t> out_size = {1, 1, 4, 4};
-  executorch::aten::Tensor out = tfFloat.zeros(out_size);
-  executorch::aten::Tensor ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_0,
-      value_cache_0,
-      0,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_0, 1e-4, 1e-4);
-
-  // start pos: 0 layer id 2
-  executorch::aten::Tensor ret_expected_1 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860,
-       0.5832,
-       0.3376,
-       0.8090,
-       0.5779,
-       0.9040,
-       0.5547,
-       0.3423,
-       0.6343,
-       0.3644,
-       0.7104,
-       0.9464,
-       0.7890,
-       0.2814,
-       0.7886,
-       0.5895});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_2,
-      value_cache_2,
-      0,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_1, 1e-4, 1e-4);
-
-  attn_mask = tfFloat.make({1, 2}, {0, 0});
-  // start pos: 1 layer id 0
-  executorch::aten::Tensor ret_expected_2 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.8860,
-       0.5832,
-       0.3376,
-       0.8090,
-       0.5779,
-       0.9040,
-       0.5547,
-       0.3423,
-       0.6343,
-       0.3644,
-       0.7104,
-       0.9464,
-       0.7890,
-       0.2814,
-       0.7886,
-       0.5895});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_0,
-      value_cache_0,
-      1,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_2, 1e-4, 1e-4);
-
-  // start pos: 1 layer id 1
-  executorch::aten::Tensor ret_expected_3 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.6486,
-       0.4270,
-       0.2472,
-       0.5922,
-       0.3669,
-       0.5740,
-       0.3522,
-       0.2173,
-       0.3635,
-       0.2088,
-       0.4071,
-       0.5423,
-       0.5110,
-       0.1822,
-       0.5107,
-       0.3817});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_1,
-      value_cache_1,
-      1,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_3, 1e-4, 1e-4);
-
-  attn_mask = tfFloat.make({1, 3}, {0, 0, 0});
-  // start pos: 2 layer id 1
-  executorch::aten::Tensor ret_expected_4 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.7490,
-       0.4930,
-       0.2854,
-       0.6838,
-       0.4489,
-       0.7021,
-       0.4308,
-       0.2659,
-       0.4622,
-       0.2655,
-       0.5176,
-       0.6895,
-       0.6202,
-       0.2212,
-       0.6199,
-       0.4634});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_1,
-      value_cache_1,
-      2,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_4, 1e-4, 1e-4);
-
-  // start pos: 2 layer id 2
-  executorch::aten::Tensor ret_expected_5 = tfFloat.make(
-      {1, 1, 4, 4},
-      {0.7490,
-       0.4930,
-       0.2854,
-       0.6838,
-       0.4489,
-       0.7021,
-       0.4308,
-       0.2659,
-       0.4622,
-       0.2655,
-       0.5176,
-       0.6895,
-       0.6202,
-       0.2212,
-       0.6199,
-       0.4634});
-  out = tfFloat.zeros(out_size);
-  ret = op_sdpa_with_kv_cache(
-      query,
-      key,
-      value,
-      key_cache_2,
-      value_cache_2,
-      2,
-      1,
-      attn_mask,
-      dropout_p,
-      is_causal,
-      scale,
-      out);
-  EXPECT_TENSOR_CLOSE_WITH_TOL(ret, ret_expected_5, 1e-4, 1e-4);
-}
-
 TEST(OpScaledDotProductAttentionTest, SequenceTest) {
   TensorFactory<ScalarType::Float> tfFloat;
 
diff --git a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
index 41497b17a66..334e53c437f 100644
--- a/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
+++ b/extension/llm/custom_ops/test_sdpa_with_kv_cache.py
@@ -67,12 +67,14 @@ def test_sdpa_with_cache_no_mqa_1(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,
@@ -108,12 +110,14 @@ def test_sdpa_with_cache_no_mqa_2(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,
@@ -150,12 +154,14 @@ def test_sdpa_with_cache_no_mqa_3(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,
@@ -191,12 +197,14 @@ def test_sdpa_with_cache_no_mqa_4(self):
         )
         if self.use_mask_with_custom_op:
             attn_mask = attn_mask.contiguous()
+            sliced_k_cache = self.k_cache[:, : start_pos + seq_len, :, :]
+            sliced_v_cache = self.v_cache[:, : start_pos + seq_len, :, :]
             op_output = torch.ops.llama.sdpa_with_kv_cache(
                 q,
                 k,
                 v,
-                self.k_cache,
-                self.v_cache,
+                sliced_k_cache,
+                sliced_v_cache,
                 start_pos,
                 seq_len,
                 attn_mask,
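
With this change, custom_sdpa_out_impl sets num_keys_for_causal_attention to -1 whenever an explicit attn_mask is supplied, presumably so that the number of attended keys comes from the mask and the key/value tensors the caller passes in rather than from start_pos + seq_len; the updated Python tests accordingly slice the caches to the populated region before calling the op. Below is a minimal caller-side sketch of that pattern, assuming the [batch, seq, heads, dim] layout and the trailing (dropout_p, is_causal, scale) arguments used in the C++ test, with all shapes and names illustrative only and the ExecuTorch llama custom op library assumed to be loaded (e.g. via the test module's imports):

    import torch

    # Illustrative dimensions, matching the small C++ test above.
    bsz, n_heads, head_dim, max_seq_len = 1, 4, 4, 5
    start_pos, seq_len = 2, 1

    q = torch.rand(bsz, seq_len, n_heads, head_dim)
    k = torch.rand(bsz, seq_len, n_heads, head_dim)
    v = torch.rand(bsz, seq_len, n_heads, head_dim)
    k_cache = torch.zeros(bsz, max_seq_len, n_heads, head_dim)
    v_cache = torch.zeros(bsz, max_seq_len, n_heads, head_dim)

    # Zero (additive) mask over the keys valid at this step:
    # [seq_len, start_pos + seq_len].
    attn_mask = torch.zeros(seq_len, start_pos + seq_len).contiguous()

    # Slice the caches to the populated region, as the updated tests do,
    # so the cache length matches the mask width.
    sliced_k_cache = k_cache[:, : start_pos + seq_len, :, :]
    sliced_v_cache = v_cache[:, : start_pos + seq_len, :, :]

    out = torch.ops.llama.sdpa_with_kv_cache(
        q,
        k,
        v,
        sliced_k_cache,
        sliced_v_cache,
        start_pos,
        seq_len,
        attn_mask,
        0.0,    # dropout_p
        False,  # is_causal
        None,   # scale (left unset, as in the C++ test)
    )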