From 013874c0aecb02fcd46c67581b2351f515a7fdb4 Mon Sep 17 00:00:00 2001
From: Kimish Patel
Date: Thu, 17 Apr 2025 16:05:59 -0700
Subject: [PATCH] [Executorch][llama] Allow custom sdpa op replacement pass to
 leverage attention mask

Previously we assumed that the custom sdpa op always performs causal
attention. This diff adds an option to the module swap pass so that the
custom sdpa op can leverage an explicit attention mask instead of causal
masking.

Differential Revision: [D73222736](https://our.internmc.facebook.com/intern/diff/D73222736/)

[ghstack-poisoned]
---
 .../llama/source_transformation/sdpa.py       | 64 +++++++++++++++----
 1 file changed, 50 insertions(+), 14 deletions(-)

diff --git a/examples/models/llama/source_transformation/sdpa.py b/examples/models/llama/source_transformation/sdpa.py
index a50c6aeea22..8df16a4c1b4 100644
--- a/examples/models/llama/source_transformation/sdpa.py
+++ b/examples/models/llama/source_transformation/sdpa.py
@@ -22,9 +22,15 @@ class SDPACustom(torch.nn.Module):
     def __init__(
         self,
         dim: int,
+        max_context_len,
+        enable_dynamic_shape,
+        use_attention_mask: bool = False,
     ):
         super().__init__()
         self.dim = dim
+        self.max_context_len = max_context_len
+        self.use_attention_mask = use_attention_mask
+        self.enable_dynamic_shape = enable_dynamic_shape
 
     def forward(
         self,
@@ -36,6 +42,16 @@ def forward(
         seqlen,
         mask,
     ):
+        if self.enable_dynamic_shape:
+            start_pos = input_pos[-1].item()
+            torch._check_is_size(start_pos)
+            torch._check(start_pos < self.max_context_len)
+            seq_length = q.size(2)
+            # pyre-ignore: Incompatible parameter type [6]
+            mask = mask.narrow(0, start_pos, seq_length)
+        else:
+            mask = mask[input_pos]
+
         q = q.transpose(1, 2)  # (bs, seqlen, n_local_heads, head_dim)
         k = k.transpose(1, 2)
         v = v.transpose(1, 2)
@@ -47,34 +63,54 @@ def forward(
         k = k.to(dtype=torch.float)
         v = v.to(dtype=torch.float)
 
-        output = torch.ops.llama.custom_sdpa(
-            q,
-            k,
-            v,
-            input_pos[0].item(),
-            None,  # Attention mask
-            0,  # dropout probability. Ignored by the code
-            True,  # is_causal
-        )
+        if self.use_attention_mask:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                mask,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                False,  # is_causal
+            )
+        else:
+            output = torch.ops.llama.custom_sdpa(
+                q,
+                k,
+                v,
+                input_pos[0].item(),
+                None,  # Attention mask
+                0,  # dropout probability. Ignored by the code
+                True,  # is_causal
+            )
         return output.view(bsz, seqlen, self.dim).to(dtype=input_dtype)
 
 
-def _replace_sdpa_with_custom_op(module: torch.nn.Module):
+def _replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+):
     for name, child in module.named_children():
         if isinstance(child, SDPA):
             setattr(
                 module,
                 name,
-                SDPACustom(child.dim),
+                SDPACustom(
+                    child.dim,
+                    child.max_context_len,
+                    child.enable_dynamic_shape,
+                    use_attention_mask=use_attention_mask,
+                ),
             )
         else:
-            _replace_sdpa_with_custom_op(child)
+            _replace_sdpa_with_custom_op(child, use_attention_mask=use_attention_mask)
 
 
-def replace_sdpa_with_custom_op(module: torch.nn.Module) -> torch.nn.Module:
+def replace_sdpa_with_custom_op(
+    module: torch.nn.Module, use_attention_mask: bool = False
+) -> torch.nn.Module:
     from executorch.extension.llm.custom_ops import custom_ops  # noqa
 
-    _replace_sdpa_with_custom_op(module)
+    _replace_sdpa_with_custom_op(module, use_attention_mask=use_attention_mask)
     return module
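
Reviewer note: a minimal usage sketch of the new flag. The wrapper function `lower_sdpa` is hypothetical, and the import path is an assumption based on the file location (`examples/models/llama/source_transformation/sdpa.py`); only `replace_sdpa_with_custom_op` and its `use_attention_mask` argument come from this diff.

```python
# Usage sketch, not part of this diff. The import path below is an assumption
# based on the patched file's location in the executorch repo.
import torch

from executorch.examples.models.llama.source_transformation.sdpa import (
    replace_sdpa_with_custom_op,
)


def lower_sdpa(  # hypothetical wrapper, for illustration only
    model: torch.nn.Module, use_attention_mask: bool = False
) -> torch.nn.Module:
    """Swap eager SDPA modules for SDPACustom before export/lowering.

    With use_attention_mask=True, SDPACustom forwards the sliced mask
    (mask[input_pos], or mask.narrow(...) under dynamic shapes) to
    torch.ops.llama.custom_sdpa with is_causal=False. With the default
    False, the op is still called with mask=None and is_causal=True,
    matching the previous behavior.
    """
    return replace_sdpa_with_custom_op(model, use_attention_mask=use_attention_mask)
```

Keeping `use_attention_mask=False` as the default means existing callers of `replace_sdpa_with_custom_op` retain the causal-attention path unchanged; only callers that opt in get the mask-driven variant.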