diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 21bee7c6680..79a225232e0 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1227,10 +1227,22 @@ def _get_source_transforms(  # noqa
     if args.expand_rope_table:
         transforms.append(materialze_broadcast_of_rope_freq_cis)
 
+    use_attention_mask_for_custom_sdpa = False
+    if isinstance(args, argparse.Namespace):
+        if getattr(args, "use_custom_sdpa_with_attention_mask", None):
+            use_attention_mask_for_custom_sdpa = True
+
     if args.use_sdpa_with_kv_cache:
         transforms.append(replace_kv_cache_with_custom_kv_cache)
         # todo: do this optionally
-        transforms.append(replace_sdpa_with_custom_op)
+        # if use attention mask instead of causal attention
+        # then create partial function that sets use_attention_mask=True
+        if use_attention_mask_for_custom_sdpa:
+            transforms.append(
+                partial(replace_sdpa_with_custom_op, use_attention_mask=True)
+            )
+        else:
+            transforms.append(replace_sdpa_with_custom_op)
 
     if args.quantize_kv_cache:
         assert args.use_kv_cache, "quantize_kv_cache requires use_kv_cache=True"
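
The diff binds the new use_custom_sdpa_with_attention_mask flag via functools.partial so that every entry in transforms remains a single-argument callable (module in, module out). Below is a minimal, self-contained sketch of that pattern; toy_sdpa_transform is a hypothetical stand-in, not the real replace_sdpa_with_custom_op.

    # Sketch of the partial-binding pattern used above (stand-in names only).
    from functools import partial

    import torch.nn as nn


    def toy_sdpa_transform(module: nn.Module, use_attention_mask: bool = False) -> nn.Module:
        # A source transform takes a module and returns a (possibly rewritten) module.
        print(f"applied with use_attention_mask={use_attention_mask}")
        return module


    # As if --use_custom_sdpa_with_attention_mask had been passed on the CLI (assumed flag name).
    use_attention_mask_for_custom_sdpa = True

    transforms = []
    if use_attention_mask_for_custom_sdpa:
        # partial() pre-binds the kwarg, so the list still holds module -> module callables.
        transforms.append(partial(toy_sdpa_transform, use_attention_mask=True))
    else:
        transforms.append(toy_sdpa_transform)

    model = nn.Identity()
    for transform in transforms:
        model = transform(model)  # prints: applied with use_attention_mask=True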