From 74a5076e505e23906885c3adbb1045bb83fa06fb Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Wed, 9 Apr 2025 22:51:27 -0700 Subject: [PATCH 01/13] Initial commit --- recipes_source/recipes_index.rst | 9 +++ .../torch_compile_torch_function_modes.py | 77 +++++++++++++++++++ 2 files changed, 86 insertions(+) create mode 100644 recipes_source/torch_compile_torch_function_modes.py diff --git a/recipes_source/recipes_index.rst b/recipes_source/recipes_index.rst index f136c4b9c6..a309bbd36c 100644 --- a/recipes_source/recipes_index.rst +++ b/recipes_source/recipes_index.rst @@ -317,6 +317,15 @@ Recipes are bite-sized, actionable examples of how to use specific PyTorch featu :link: ../recipes/amx.html :tags: Model-Optimization +.. (beta) Utilizing Torch Function modes with torch.compile + +.. customcarditem:: + :header: (beta) Utilizing Torch Function modes with torch.compile + :card_description: Override torch operators with Torch Function modes and torch.compile + :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png + :link: ../recipes/torch_compile_torch_function_modes.html + :tags: Model-Optimization + .. (beta) Compiling the Optimizer with torch.compile .. customcarditem:: diff --git a/recipes_source/torch_compile_torch_function_modes.py b/recipes_source/torch_compile_torch_function_modes.py new file mode 100644 index 0000000000..a0a344da0d --- /dev/null +++ b/recipes_source/torch_compile_torch_function_modes.py @@ -0,0 +1,77 @@ +""" +(beta) Utilizing Torch Function modes with torch.compile +============================================================ + +**Author:** `Michael Lazos `_ +""" + +######################################################### +# This tutorial covers how to use a key torch extensibility point, +# torch function modes, in tandem with torch.compile to override +# the behavior of torch ops at trace time, with no runtime overhead. +# +# .. note:: +# +# This tutorial requires PyTorch 2.7.0 or later. + + +##################################################################### +# Rewriting a torch op (torch.add -> torch.mul) +# ~~~~~~~~~~~~~~~~~~~~~ +# For this example, we'll use torch function modes to rewrite occurences +# of addition with multiply instead. This type of override can be common +# if a certain backend has a custom implementation that should be dispatched +# for a given op. +import torch + +# exit cleanly if we are on a device that doesn't support ``torch.compile`` +if torch.cuda.get_device_capability() < (7, 0): + print("Exiting because torch.compile is not supported on this device.") + import sys + sys.exit(0) + +from torch.overrides import BaseTorchFunctionMode + +# Define our mode, Note: BaseTorchFunctionMode +# implements the actual invocation of func(..) +class AddToMultiplyMode(BaseTorchFunctionMode): + def __torch_function__(self, func, types, args=(), kwargs=None): + if func == torch.Tensor.add: + func = torch.mul + + return super().__torch_function__(func, types, args, kwargs) + +@torch.compile() +def test_fn(x, y): + return x + y * x # Note: infix operators map to torch.Tensor.* methods + +x = torch.rand(2, 2) +y = torch.rand_like(x) + +with AddToMultiplyMode(): + z = test_fn(x, y) + +assert torch.allclose(z, x * y * x) + +# The mode can also be used within the compiled region as well like so + +@torch.compile() +def test_fn(x, y): + with AddToMultiplyMode(): + return x + y * x # Note: infix operators map to torch.Tensor.* methods + +x = torch.rand(2, 2) +y = torch.rand_like(x) +z = test_fn(x, y) + +assert torch.allclose(z, x * y * x) + +###################################################################### +# Conclusion +# ~~~~~~~~~~ +# In this tutorial we demonstrated how to override the behavior of torch.* operators +# using torch function modes from within torch.compile. This enables users to utilize +# the extensibility benefits of torch function modes without the runtime overhead +# of calling torch function on every op invocation. +# +# * `Extending Torch API with Modes `__ - Other examples and backgroun on Torch Function modes. From c51e070f0e2cc5d6cd3138b6ef66c6d8198cc72c Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Apr 2025 12:16:49 -0700 Subject: [PATCH 02/13] Update recipes_source/torch_compile_torch_function_modes.py Co-authored-by: Svetlana Karslioglu --- recipes_source/torch_compile_torch_function_modes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_torch_function_modes.py b/recipes_source/torch_compile_torch_function_modes.py index a0a344da0d..2c99170d82 100644 --- a/recipes_source/torch_compile_torch_function_modes.py +++ b/recipes_source/torch_compile_torch_function_modes.py @@ -6,7 +6,7 @@ """ ######################################################### -# This tutorial covers how to use a key torch extensibility point, +# This recipe covers how to use a key torch extensibility point, # torch function modes, in tandem with torch.compile to override # the behavior of torch ops at trace time, with no runtime overhead. # From 52be71fc2a46abd1778e82753b79a0b9e4e0c931 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Apr 2025 12:16:59 -0700 Subject: [PATCH 03/13] Update recipes_source/torch_compile_torch_function_modes.py Co-authored-by: Svetlana Karslioglu --- recipes_source/torch_compile_torch_function_modes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_torch_function_modes.py b/recipes_source/torch_compile_torch_function_modes.py index 2c99170d82..1f900fa2c4 100644 --- a/recipes_source/torch_compile_torch_function_modes.py +++ b/recipes_source/torch_compile_torch_function_modes.py @@ -7,7 +7,7 @@ ######################################################### # This recipe covers how to use a key torch extensibility point, -# torch function modes, in tandem with torch.compile to override +# torch function modes, in tandem with ``torch.compile`` to override # the behavior of torch ops at trace time, with no runtime overhead. # # .. note:: From 7014f3a780a9635f7c63da7c85c1ab5668a88a4f Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Apr 2025 12:17:08 -0700 Subject: [PATCH 04/13] Update recipes_source/torch_compile_torch_function_modes.py Co-authored-by: Svetlana Karslioglu --- recipes_source/torch_compile_torch_function_modes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_torch_function_modes.py b/recipes_source/torch_compile_torch_function_modes.py index 1f900fa2c4..7d429e5760 100644 --- a/recipes_source/torch_compile_torch_function_modes.py +++ b/recipes_source/torch_compile_torch_function_modes.py @@ -8,7 +8,7 @@ ######################################################### # This recipe covers how to use a key torch extensibility point, # torch function modes, in tandem with ``torch.compile`` to override -# the behavior of torch ops at trace time, with no runtime overhead. +# the behavior of torch operators, also know as **ops**, at trace time, with no runtime overhead. # # .. note:: # From ede8d603cf9890d46fd11126648205485a199a80 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Apr 2025 12:17:15 -0700 Subject: [PATCH 05/13] Update recipes_source/torch_compile_torch_function_modes.py Co-authored-by: Svetlana Karslioglu --- recipes_source/torch_compile_torch_function_modes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_torch_function_modes.py b/recipes_source/torch_compile_torch_function_modes.py index 7d429e5760..92cb83c160 100644 --- a/recipes_source/torch_compile_torch_function_modes.py +++ b/recipes_source/torch_compile_torch_function_modes.py @@ -12,7 +12,7 @@ # # .. note:: # -# This tutorial requires PyTorch 2.7.0 or later. +# This recipe requires PyTorch 2.7.0 or later. ##################################################################### From 2e7943ac92f8fedd033a75c41ef5bff4a9efbb7b Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Apr 2025 12:17:21 -0700 Subject: [PATCH 06/13] Update recipes_source/torch_compile_torch_function_modes.py Co-authored-by: Svetlana Karslioglu --- recipes_source/torch_compile_torch_function_modes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_torch_function_modes.py b/recipes_source/torch_compile_torch_function_modes.py index 92cb83c160..fbb3beccba 100644 --- a/recipes_source/torch_compile_torch_function_modes.py +++ b/recipes_source/torch_compile_torch_function_modes.py @@ -17,7 +17,7 @@ ##################################################################### # Rewriting a torch op (torch.add -> torch.mul) -# ~~~~~~~~~~~~~~~~~~~~~ +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # For this example, we'll use torch function modes to rewrite occurences # of addition with multiply instead. This type of override can be common # if a certain backend has a custom implementation that should be dispatched From 3d3a691dadb878b99eed1fe601a5ded01e346d0b Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Apr 2025 12:17:28 -0700 Subject: [PATCH 07/13] Update recipes_source/torch_compile_torch_function_modes.py Co-authored-by: Svetlana Karslioglu --- recipes_source/torch_compile_torch_function_modes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_torch_function_modes.py b/recipes_source/torch_compile_torch_function_modes.py index fbb3beccba..a1d51ef568 100644 --- a/recipes_source/torch_compile_torch_function_modes.py +++ b/recipes_source/torch_compile_torch_function_modes.py @@ -32,7 +32,7 @@ from torch.overrides import BaseTorchFunctionMode -# Define our mode, Note: BaseTorchFunctionMode +# Define our mode, Note: ``BaseTorchFunctionMode`` # implements the actual invocation of func(..) class AddToMultiplyMode(BaseTorchFunctionMode): def __torch_function__(self, func, types, args=(), kwargs=None): From f9ab2ebb0d300ace917a4f4b5315f396cfde731d Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Apr 2025 12:17:36 -0700 Subject: [PATCH 08/13] Update recipes_source/torch_compile_torch_function_modes.py Co-authored-by: Svetlana Karslioglu --- recipes_source/torch_compile_torch_function_modes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_torch_function_modes.py b/recipes_source/torch_compile_torch_function_modes.py index a1d51ef568..ebd53bfcde 100644 --- a/recipes_source/torch_compile_torch_function_modes.py +++ b/recipes_source/torch_compile_torch_function_modes.py @@ -53,7 +53,7 @@ def test_fn(x, y): assert torch.allclose(z, x * y * x) -# The mode can also be used within the compiled region as well like so +# The mode can also be used within the compiled region as well like this: @torch.compile() def test_fn(x, y): From 2a64b21e1c0ffecea380e1b2ddddf46822f633fc Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Apr 2025 12:17:42 -0700 Subject: [PATCH 09/13] Update recipes_source/torch_compile_torch_function_modes.py Co-authored-by: Svetlana Karslioglu --- recipes_source/torch_compile_torch_function_modes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_torch_function_modes.py b/recipes_source/torch_compile_torch_function_modes.py index ebd53bfcde..e9f7072f2c 100644 --- a/recipes_source/torch_compile_torch_function_modes.py +++ b/recipes_source/torch_compile_torch_function_modes.py @@ -69,7 +69,7 @@ def test_fn(x, y): ###################################################################### # Conclusion # ~~~~~~~~~~ -# In this tutorial we demonstrated how to override the behavior of torch.* operators +# In this recipe we demonstrated how to override the behavior of ``torch.*`` operators # using torch function modes from within torch.compile. This enables users to utilize # the extensibility benefits of torch function modes without the runtime overhead # of calling torch function on every op invocation. From 50ab48e7006392eb1c1afc49436ff515104e962c Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Apr 2025 12:17:49 -0700 Subject: [PATCH 10/13] Update recipes_source/torch_compile_torch_function_modes.py Co-authored-by: Svetlana Karslioglu --- recipes_source/torch_compile_torch_function_modes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_torch_function_modes.py b/recipes_source/torch_compile_torch_function_modes.py index e9f7072f2c..45fca3f67c 100644 --- a/recipes_source/torch_compile_torch_function_modes.py +++ b/recipes_source/torch_compile_torch_function_modes.py @@ -70,7 +70,7 @@ def test_fn(x, y): # Conclusion # ~~~~~~~~~~ # In this recipe we demonstrated how to override the behavior of ``torch.*`` operators -# using torch function modes from within torch.compile. This enables users to utilize +# using torch function modes from within ``torch.compile``. This enables users to utilize # the extensibility benefits of torch function modes without the runtime overhead # of calling torch function on every op invocation. # From 3c2efb649f325c31228952a8d08e8a102b38ee57 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Apr 2025 12:17:55 -0700 Subject: [PATCH 11/13] Update recipes_source/torch_compile_torch_function_modes.py Co-authored-by: Svetlana Karslioglu --- recipes_source/torch_compile_torch_function_modes.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/recipes_source/torch_compile_torch_function_modes.py b/recipes_source/torch_compile_torch_function_modes.py index 45fca3f67c..7808579563 100644 --- a/recipes_source/torch_compile_torch_function_modes.py +++ b/recipes_source/torch_compile_torch_function_modes.py @@ -74,4 +74,4 @@ def test_fn(x, y): # the extensibility benefits of torch function modes without the runtime overhead # of calling torch function on every op invocation. # -# * `Extending Torch API with Modes `__ - Other examples and backgroun on Torch Function modes. +# * See `Extending Torch API with Modes `__ for other examples and background on Torch Function modes. From 56dee0c17c2ec993862b71dcf59eaaa5426e5f06 Mon Sep 17 00:00:00 2001 From: Michael Lazos Date: Thu, 10 Apr 2025 16:26:22 -0700 Subject: [PATCH 12/13] Fix metadata --- .jenkins/metadata.json | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.jenkins/metadata.json b/.jenkins/metadata.json index 6e82d054b4..deb365ec1b 100644 --- a/.jenkins/metadata.json +++ b/.jenkins/metadata.json @@ -49,6 +49,9 @@ "recipes_source/compiling_optimizer_lr_scheduler.py": { "needs": "linux.g5.4xlarge.nvidia.gpu" }, + "recipes_source/torch_compile_torch_function_modes.py" { + "needs": "linux.g5.4xlarge.nvidia.gpu" + }, "intermediate_source/torch_compile_tutorial.py": { "needs": "linux.g5.4xlarge.nvidia.gpu" }, From 6fe3b64a4b60e82355d1bf606233d75782c9706b Mon Sep 17 00:00:00 2001 From: Svetlana Karslioglu Date: Wed, 16 Apr 2025 15:18:13 -0700 Subject: [PATCH 13/13] Apply suggestions from code review --- .jenkins/metadata.json | 3 --- 1 file changed, 3 deletions(-) diff --git a/.jenkins/metadata.json b/.jenkins/metadata.json index deb365ec1b..6e82d054b4 100644 --- a/.jenkins/metadata.json +++ b/.jenkins/metadata.json @@ -49,9 +49,6 @@ "recipes_source/compiling_optimizer_lr_scheduler.py": { "needs": "linux.g5.4xlarge.nvidia.gpu" }, - "recipes_source/torch_compile_torch_function_modes.py" { - "needs": "linux.g5.4xlarge.nvidia.gpu" - }, "intermediate_source/torch_compile_tutorial.py": { "needs": "linux.g5.4xlarge.nvidia.gpu" },