diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index a50c6aeea22..1bc54198fba 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -22,9 +22,15 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         dim: int,
+        max_context_len,
+        enable_dynamic_shape,
+        use_attention_mask: bool = False,
     ):
         super().__init__()
         self.dim = dim
+        self.max_context_len = max_context_len
+        self.use_attention_mask = use_attention_mask
+        self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
         self,
@@ -36,6 +42,16 @@ def forward(
         seqlen,
         mask,
     ):
+        if self.use_attention_mask:
+            if self.enable_dynamic_shape:
+                start_pos = input_pos[-1].item()
+                torch._check_is_size(start_pos)
+                torch._check(start_pos < self.max_context_len)
+                seq_length = q.size(2)
+                mask = mask.narrow(0, start_pos, seq_length)
+            else:
+                mask = mask[input_pos]
+
         q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -47,34 +63,54 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        output = torch.ops.llama.custom_sdpa(
-            q,
-            k,
-            v,
-            input_pos[0].item(),
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+        if self.use_attention_mask:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                mask,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                False,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
-def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+def _replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
             setattr(
                 module,
                 name,
-                SDPACustom(child.dim),
+                SDPACustom(
+                    child.dim,
+                    child.max_context_len,
+                    child.enable_dynamic_shape,
+                    use_attention_mask=use_attention_mask,
+                ),
             )
         else:
-            _replace_sdpa_with_custom_op(child)
+            _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask)
 
 
-def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+def replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops  # noqa
 
-    _replace_sdpa_with_custom_op(module)
+    _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask)
     return module
 
 
diff --git a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
index b2c93d7d93d..e5e278f8ce8 100644
--- a/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
+++ b/examples/models/llama/source_transformation/test_sdpa_with_quantized_kv_cache.py
@@ -71,8 +71,8 @@ def test_simple(self, is_dynamic_shape=False):
         self.seq_len = 3
         self._init_cache()
         q, k_val, v_val = self._init_kv()
-        self.float_sdpa = SDPACustom(self.dim)
-        self.quantized_sdpa = SDPACustom(self.dim)
+        self.float_sdpa = SDPACustom(self.dim, self.max_context_len, True)
+        self.quantized_sdpa = SDPACustom(self.dim, self.max_context_len, True)
         k, v = self.custom_kv_cache.update(input_pos, k_val, v_val)
         float_out = self.float_sdpa(input_pos, q, k, v, 1, self.seq_len, None)
         k, v = self.quantized_kv_cache.update(input_pos, k_val, v_val)
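
For reviewers, a minimal usage sketch of the new flag (not part of this diff; it assumes the import path `executorch.examples.models.llama.source_transformation.sdpa` and a model whose attention blocks use the `SDPA` module):

```python
# Hypothetical usage sketch: opting into the explicit attention-mask path when
# swapping SDPA modules for the custom op.
import torch

from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)

model: torch.nn.Module = ...  # a Llama-style model containing SDPA submodules

# With use_attention_mask=True, each SDPACustom slices the caller-provided mask
# at input_pos and calls custom_sdpa with is_causal=False; with the default
# False, the custom op keeps relying on its implicit causal masking.
model = replace_sdpa_with_custom_op(model, use_attention_mask=True)
```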