Add Context Parallel tutorial #3319


Merged
merged 32 commits
Apr 18, 2025
Commits
32 commits
864aedb
[DO NOT MERGE] 2.7 RC Test
svekars Mar 18, 2025
490e3b1
Update .jenkins/build.sh
svekars Mar 18, 2025
a7fcb92
Update .jenkins/build.sh
svekars Mar 18, 2025
1daf409
Update build.sh
svekars Mar 18, 2025
6e218dc
Update build.sh
svekars Mar 18, 2025
6359756
Update build.sh
svekars Mar 18, 2025
edd7240
Update onnxscript in requirements (#3300)
justinchuby Mar 19, 2025
9a649ea
Update build.sh
svekars Mar 21, 2025
83a8781
Update .jenkins/validate_tutorials_built.py
svekars Mar 21, 2025
cfb2719
Update build.sh
svekars Mar 21, 2025
29f4c56
Update .jenkins/build.sh
svekars Mar 22, 2025
4eb24e1
Update build.sh
svekars Mar 24, 2025
45f2bd5
Apply suggestions from code review
svekars Mar 24, 2025
c309e11
Update build.sh
svekars Mar 24, 2025
3674238
Update requirements.txt
svekars Mar 24, 2025
d40f855
Update .jenkins/build.sh
svekars Mar 24, 2025
dbfe3da
Update .jenkins/build.sh
svekars Mar 24, 2025
3885455
Fix the AOTI example (#3306)
desertfire Mar 25, 2025
d7d29fe
Update build.sh
svekars Mar 26, 2025
f2fcf6f
Disable rl tutorials again
svekars Mar 26, 2025
b87d98d
Add Context Parallel tutorial
XilunWu Apr 8, 2025
f75c9fd
fix typo
XilunWu Apr 9, 2025
dcf02de
fix: address comment
XilunWu Apr 10, 2025
4275c42
fix: typos
XilunWu Apr 10, 2025
c6d8dfa
address review comments
XilunWu Apr 14, 2025
a6938ff
address comments: improve pass-KV description
XilunWu Apr 14, 2025
b74e6cc
address comments: improve API description
XilunWu Apr 14, 2025
02d419c
address comments: improve API description
XilunWu Apr 14, 2025
80f228c
fix indentation
XilunWu Apr 14, 2025
0432a23
address review comments
XilunWu Apr 15, 2025
ada3e08
manually fix rebase issues
XilunWu Apr 16, 2025
5872433
manually fix rebase issues
XilunWu Apr 16, 2025
228 changes: 228 additions & 0 deletions prototype_source/context_parallel.rst
@@ -0,0 +1,228 @@
Introduction to Context Parallel
======================================
**Authors**: `Xilun Wu <https://github.com/XilunWu>`_, `Chien-Chin Huang <https://github.com/fegin>`__

.. note::
   |edit| View and edit this tutorial in `GitHub <https://github.com/pytorch/tutorials/blob/main/prototype_source/context_parallel.rst>`__.

.. grid:: 2

    .. grid-item-card:: :octicon:`mortar-board;1em;` What you will learn
        :class-card: card-prerequisites

        * `Context Parallel APIs <https://pytorch.org/docs/stable/distributed.tensor.html#torch.distributed.tensor.experimental.context_parallel>`__
        * `1M sequence training in TorchTitan with Context Parallel <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__


    .. grid-item-card:: :octicon:`list-unordered;1em;` Prerequisites
        :class-card: card-prerequisites

        * PyTorch 2.7 or later


Introduction
------------

Context Parallel is an approach used in large language model training to reduce peak activation size by sharding the long input sequence across multiple devices.
It removes the constraint on input sequence length imposed by the peak memory needed to store activations in Transformer blocks.
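
As a rough back-of-the-envelope illustration (the shapes and the bf16 assumption below are ours, not taken from this tutorial), sharding the sequence divides each sequence-length-proportional activation linearly across the Context Parallel devices:

.. code:: python

    # Assumption-level arithmetic: memory of one (batch, seq, hidden) activation
    # in bf16, before and after sharding the sequence across cp_degree devices.
    batch, seq, hidden, bytes_per_elem = 1, 1_000_000, 8192, 2  # hypothetical 1M-token setup
    cp_degree = 8

    full = batch * seq * hidden * bytes_per_elem  # one full-sequence activation tensor
    per_device = full / cp_degree                 # after sharding the sequence dimension

    print(f"full: {full / 2**30:.1f} GiB, per device with CP={cp_degree}: {per_device / 2**30:.1f} GiB")
    # prints roughly: full: 15.3 GiB, per device with CP=8: 1.9 GiB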

Ring Attention, a novel parallel implementation of the Attention layer, is critical to performant Context Parallel.
Ring Attention shuffles the KV shards across devices and computes partial attention scores, repeating until every KV shard has been used on each device; a single-process sketch of this rotate-and-merge pattern follows the list below.
Two Ring Attention variants have been implemented: `the all-gather based pass-KV <https://arxiv.org/abs/2407.21783>`__ and `the all-to-all based pass-KV <https://openreview.net/forum?id=WsRHpHH4s0>`__:

1. The all-gather based pass-KV algorithm is used in Llama3 training. It initially performs an all-gather on the key and value tensors, followed by computing the attention output for the
   local query tensor chunk. Our modified all-gather based pass-KV algorithm concurrently all-gathers KV shards and computes attention output for the local query tensor chunk
   using the local key and value tensor chunks, followed by a final computation of attention output for the local query tensor and the remaining KV shards. This allows some degree of
   overlap between the attention computation and the all-gather collective. For example, in the case of Llama3 training, we also shard ``freq_cis`` over the sequence dimension.
2. The all-to-all approach uses interleaved all-to-all collectives to ring-shuffle KV shards, overlapping the SDPA (Scaled Dot Product Attention) computation with the all-to-all communication
   needed for the next SDPA step.
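
The following single-process sketch is ours, not part of this tutorial or of the PyTorch API: it imitates one query shard iterating over KV shards and merging the partial results with an online softmax, which is the numerical core that both pass-KV variants rely on. For simplicity it omits the causal mask, load balancing, and the communication overlap that the real implementation adds:

.. code:: python

    import torch
    import torch.nn.functional as F


    def ring_attention_sim(q, k, v, num_shards):
        # q stays local; k and v are split into `num_shards` chunks along the
        # sequence dim, standing in for the KV shards that rotate around the ring.
        # Online-softmax accumulators merge the partial attention results.
        d = q.size(-1)
        k_shards = k.chunk(num_shards, dim=2)
        v_shards = v.chunk(num_shards, dim=2)

        m = torch.full(q.shape[:-1] + (1,), float("-inf"))  # running row-wise max
        l = torch.zeros(q.shape[:-1] + (1,))                # running softmax denominator
        acc = torch.zeros_like(q)                           # unnormalized output

        for k_i, v_i in zip(k_shards, v_shards):      # one "rotation step" per shard
            s = q @ k_i.transpose(-2, -1) / d**0.5    # partial attention scores
            m_new = torch.maximum(m, s.amax(dim=-1, keepdim=True))
            p = torch.exp(s - m_new)
            scale = torch.exp(m - m_new)              # rescale previously merged results
            acc = acc * scale + p @ v_i
            l = l * scale + p.sum(dim=-1, keepdim=True)
            m = m_new
        return acc / l


    torch.manual_seed(0)
    q, k, v = (torch.randn(2, 4, 128, 32) for _ in range(3))
    ref = F.scaled_dot_product_attention(q, k, v)     # full (non-causal) attention
    out = ring_attention_sim(q, k, v, num_shards=4)
    assert torch.allclose(out, ref, atol=1e-5)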

The Context Parallel APIs consist of two parts:

1. ``context_parallel()`` allows users to create a Python context in which the SDPA function (``torch.nn.functional.scaled_dot_product_attention``)
   is automatically replaced with Ring Attention. To shard Tensors along a dimension, pass the Tensors and their sharding dimensions to the
   ``buffers`` and ``buffer_seq_dims`` arguments, respectively. We recommend adding every tensor that is computed along the sequence dimension to ``buffers``
   and sharding it along that dimension. In Llama3 training, for example, omitting ``freq_cis`` from ``buffers`` results in a miscalculated rotary embedding (see the plain-tensor sketch after this list).
2. ``set_rotate_method()`` allows users to choose between the all-gather based pass-KV approach and the all-to-all based pass-KV approach.
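
Below is a plain-tensor sketch (ours; it uses ordinary ``chunk`` calls, not the ``context_parallel()`` API, and ignores any load-balancing reordering the real implementation may apply) of what sharding the registered ``buffers`` along ``buffer_seq_dims`` amounts to, and why a sequence-indexed table such as ``freq_cis`` must be registered as well:

.. code:: python

    import torch

    world_size, rank = 4, 2                  # pretend we are rank 2 of 4
    seq_len, dim = 16, 8

    tokens = torch.randn(2, seq_len, dim)    # (batch, seq, dim) -> sequence dim is 1
    freq_cis = torch.randn(seq_len, dim)     # (seq, dim)        -> sequence dim is 0

    # What registering buffers=(tokens, freq_cis), buffer_seq_dims=(1, 0) amounts to:
    local_tokens = tokens.chunk(world_size, dim=1)[rank]
    local_freq = freq_cis.chunk(world_size, dim=0)[rank]

    # The local rotary rows line up with the local token positions (rows 8..11 here)...
    start, stop = rank * seq_len // world_size, (rank + 1) * seq_len // world_size
    assert torch.equal(local_freq, freq_cis[start:stop])
    # ...whereas keeping the full table and naively taking its first rows would
    # apply the wrong positions to this rank's tokens.
    assert not torch.equal(local_freq, freq_cis[: seq_len // world_size])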


Setup
---------------------

With ``torch.distributed.tensor.experimental.context_parallel()``, users can easily shard the Tensor input and parallelize the execution of the SDPA function.
To better demonstrate the usage of this API, we start with a simple code snippet that runs SDPA and then parallelize it using the API:

.. code:: python

    import torch
    import torch.nn.functional as F

    from torch.nn.attention import sdpa_kernel, SDPBackend


    def sdpa_example():
        assert torch.cuda.is_available()
        torch.cuda.set_device("cuda:0")
        torch.cuda.manual_seed(0)

        batch = 8
        nheads = 8
        qkv_len = 8192
        dim = 32
        backend = SDPBackend.FLASH_ATTENTION
        dtype = (
            torch.bfloat16
            if backend == SDPBackend.FLASH_ATTENTION
            or backend == SDPBackend.CUDNN_ATTENTION
            else torch.float32
        )

        qkv = [
            torch.rand(
                (batch, nheads, qkv_len, dim),
                dtype=dtype,
                requires_grad=True,
                device='cuda',
            )
            for _ in range(3)
        ]
        # specify the SDPBackend to use
        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(*qkv, is_causal=True)


    if __name__ == "__main__":
        sdpa_example()


Enable Context Parallel
-----------------------

Now, let's first adapt it to a distributed program where each rank has the same tensor input. Then we apply the context parallel API to
shard the input and distribute the computation across ranks:

.. code:: python

    # file: cp_sdpa_example.py
    import os

    import torch
    import torch.distributed as dist
    import torch.nn.functional as F
    from torch.distributed.device_mesh import init_device_mesh
    from torch.distributed.tensor.experimental import context_parallel
    from torch.distributed.tensor.experimental._attention import context_parallel_unshard
    from torch.nn.attention import sdpa_kernel, SDPBackend


    def context_parallel_sdpa_example(world_size: int, rank: int):
        assert torch.cuda.is_available()
        assert dist.is_nccl_available()
        torch.cuda.set_device(f"cuda:{rank}")
        torch.cuda.manual_seed(0)

        dist.init_process_group(
            backend="nccl",
            init_method="env://",
            world_size=world_size,
            rank=rank,
        )
        device_mesh = init_device_mesh(
            device_type="cuda", mesh_shape=(world_size,), mesh_dim_names=("cp",)
        )

        batch = 8
        nheads = 8
        qkv_len = 64
        dim = 32
        backend = SDPBackend.FLASH_ATTENTION
        dtype = (
            torch.bfloat16
            if backend == SDPBackend.FLASH_ATTENTION
            or backend == SDPBackend.CUDNN_ATTENTION
            else torch.float32
        )

        qkv = [
            torch.rand(
                (batch, nheads, qkv_len, dim),
                dtype=dtype,
                requires_grad=True,
                device='cuda',
            )
            for _ in range(3)
        ]
        # specify the SDPBackend to use
        with sdpa_kernel(backend):
            out = F.scaled_dot_product_attention(*qkv, is_causal=True)

        # make a clean copy of QKV for output comparison
        cp_qkv = [t.detach().clone() for t in qkv]
Review comment (Contributor): Ahh, so this is not even needed for CP, just for the reference?

I wonder if it's better to delete the reference. That's more appropriate for a unit test, but for an example people usually want something minimal and copyable, and in this case they might be distracted by this line.


        with sdpa_kernel(backend):
            # This `context_parallel()` performs two actions:
            # 1. Shard the tensor objects in `buffers` in-place along the dimension
            #    specified in `buffer_seq_dims`; the tensors in `buffers` and their
            #    sharding dims in `buffer_seq_dims` are organized in the same order.
            # 2. Replace the execution of `F.scaled_dot_product_attention` with a
            #    context-parallel-enabled Ring Attention.
            with context_parallel(
                device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
            ):
                cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)

            # The output `cp_out` is still sharded in the same way as QKV;
            # the `context_parallel_unshard` API allows users to easily
            # unshard it and obtain the full tensor.
            (cp_out,) = context_parallel_unshard(device_mesh, [cp_out], [2])

        assert torch.allclose(
            cp_out,
            out,
            atol=(1e-08 if dtype == torch.float32 else 1e-03 * world_size),
        )


    if __name__ == "__main__":
        rank = int(os.environ["RANK"])
        world_size = int(os.environ["WORLD_SIZE"])

        try:
            context_parallel_sdpa_example(world_size, rank)
        finally:
            dist.barrier()
            dist.destroy_process_group()


You can use the command ``torchrun --standalone --nnodes=1 --nproc-per-node=4 cp_sdpa_example.py`` to launch the above context parallel
SDPA on 4 GPUs. We demonstrate the numeric correctness by comparing the output of Ring Attention to that of SDPA on a single GPU.


Select Rotation Approach
------------------------

You can choose the desired shard rotation approach in Ring Attention by using ``torch.distributed.tensor.experimental._attention.set_rotate_method()``:

.. code:: python

    # file: cp_sdpa_example.py
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    set_rotate_method("alltoall")  # rotate shards using all-to-all

    with sdpa_kernel(backend):
        with context_parallel(
            device_mesh, buffers=tuple(cp_qkv), buffer_seq_dims=(2, 2, 2)
        ):
            cp_out = F.scaled_dot_product_attention(*cp_qkv, is_causal=True)


The default rotation approach is the all-gather based pass-KV.
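
To switch back explicitly, the same setter accepts the all-gather variant. To our knowledge the accepted strings are ``"allgather"`` and ``"alltoall"``, but treat the exact values as an assumption and check the ``set_rotate_method`` docstring in your PyTorch version:

.. code:: python

    # The "allgather" string is our assumption for the default rotation method;
    # verify it against the docstring of `set_rotate_method` in your PyTorch build.
    from torch.distributed.tensor.experimental._attention import set_rotate_method

    set_rotate_method("allgather")  # rotate shards using all-gather (the default)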


Conclusion
----------

In this tutorial, we have learned how to parallelize the SDPA computation along the sequence dimension easily with our Context Parallel APIs. For
design and implementation details, performance analysis, and an end-to-end training example in `TorchTitan <https://github.com/pytorch/torchtitan>`__,
see our post on `PyTorch native long-context training <https://discuss.pytorch.org/t/distributed-w-torchtitan-breaking-barriers-training-long-context-llms-with-1m-sequence-length-in-pytorch-using-context-parallel/215082>`__.
8 changes: 8 additions & 0 deletions prototype_source/prototype_index.rst
@@ -239,6 +239,13 @@ Prototype features are not available as part of binary distributions like PyPI o
   :link: ../prototype/flight_recorder_tutorial.html
   :tags: Distributed, Debugging, FlightRecorder

.. customcarditem::
   :header: Context Parallel Tutorial
   :card_description: Parallelize the attention computation along the sequence dimension
   :image: ../_static/img/thumbnails/cropped/generic-pytorch-logo.png
   :link: ../prototype/context_parallel.html
   :tags: Distributed, Context Parallel

.. Integration
.. customcarditem::
   :header: Out-of-tree extension autoloading in Python
@@ -265,6 +272,7 @@ Prototype features are not available as part of binary distributions like PyPI o
.. toctree::
   :hidden:

   prototype/context_parallel.html
   prototype/fx_graph_mode_quant_guide.html
   prototype/fx_graph_mode_ptq_dynamic.html
   prototype/fx_graph_mode_ptq_static.html