Skip to content

Support microbenchmarking for low precision training #2101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 22 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
8b22a68
Update
jainapurva Apr 4, 2025
04f39ef
Add profiler
jainapurva Apr 8, 2025
4b7ea5d
Add support for different models and different shapes
jainapurva Apr 10, 2025
33fa3ca
Add ruff fixes
jainapurva Apr 10, 2025
5ee6b58
Updates
jainapurva Apr 10, 2025
345a00c
Updates
jainapurva Apr 10, 2025
6e88306
Merge remote-tracking branch 'origin/bench-gpu-profiling' into model_…
jainapurva Apr 10, 2025
5895b7e
Updates
jainapurva Apr 10, 2025
bbcba36
Updates
jainapurva Apr 10, 2025
62a1e70
Memory profiler
jainapurva Apr 14, 2025
f0709a8
Updates to memory_profiler
jainapurva Apr 14, 2025
d5bdb4a
updates
jainapurva Apr 14, 2025
7677902
Merge remote-tracking branch 'origin/bench-gpu-profiling' into model_…
jainapurva Apr 14, 2025
7c15006
Updates to memory_profiler
jainapurva Apr 14, 2025
06f5ee7
Merge remote-tracking branch 'origin/main' into model_shapes_config
jainapurva Apr 18, 2025
784ec94
Added a future todo
jainapurva Apr 18, 2025
ceded86
Merge remote-tracking branch 'origin/model_shapes_config' into memory…
jainapurva Apr 18, 2025
92b1e3b
Merge remote-tracking branch 'origin/memory_profiler' into training_fbp
jainapurva Apr 18, 2025
1cff42d
Add training benchmarking
jainapurva Apr 21, 2025
abc9ef5
Use baseline bf16
jainapurva Apr 22, 2025
44f564d
Add TOPS calculation
jainapurva Apr 22, 2025
f12ece0
Simplified the logic for linear
jainapurva Apr 22, 2025
6970052
Calculating ref with every run
jainapurva Apr 22, 2025
b4423e9
Add better prints
jainapurva Apr 25, 2025
dbc188d
Merge remote-tracking branch 'origin/main' into training_fbp
jainapurva Apr 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
194 changes: 193 additions & 1 deletion benchmarks/microbenchmarks/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ The microbenchmarking system works as follows:

- **benchmark_runner.py**: Main entry point that orchestrates the benchmarking process
- **benchmark_inference.py**: Handles model creation and inference benchmarking
- **benchmark_training.py**: Manages the setup and execution of training benchmarks
- **utils.py**: Contains utility functions and configuration classes
- **test\/**: Test files and sample configurations

Expand Down Expand Up @@ -50,10 +51,20 @@ model_params:
compile: "max-autotune" # Options: "default", "max-autotune", "false"
device: "cuda" # Options: "cuda", "mps", "xpu", "cpu"
model_type: "linear" # Options: "linear", "ln_linear_sigmoid"
enable_profiler: true # Enable standard profiling
enable_memory_profiler: true # Enable CUDA memory profiling
```

## Configuration Options

### Profiling Options
- `enable_profiler`: Enable standard PyTorch profiling (default: false)
- `enable_memory_profiler`: Enable CUDA memory profiling (default: false)
- Only works when device is set to "cuda"
- Generates memory snapshots before and after inference
- Creates visualizations of memory usage
- Outputs are saved in the memory_profiler subdirectory

### Quantization Methods
Currently, quantization string is in same format as the one being passed in llama/generate.py.
- `baseline`: No quantization
Expand All @@ -63,14 +74,74 @@ Currently, quantization string is in same format as the one being passed in llam

### Model Types
- `linear`: Simple linear layer
- `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid
- `ln_linear_<activation>`: LayerNorm + Linear + Activation, where activation can be:
- `ln_linear_sigmoid`: LayerNorm + Linear + Sigmoid
- `ln_linear_relu`: LayerNorm + Linear + ReLU
- `ln_linear_leakyrelu`: LayerNorm + Linear + LeakyReLU
- `ln_linear_relu6`: LayerNorm + Linear + ReLU6
- `ln_linear_gelu`: LayerNorm + Linear + GELU
- `ln_linear_silu`: LayerNorm + Linear + SiLU
- `ln_linear_hardswish`: LayerNorm + Linear + Hardswish
- `transformer_block`: Transformer block with self-attention and MLP

### Device Options
- `cuda`: NVIDIA GPU
- `xpu`: Intel GPU
- `mps`: Apple Silicon GPU
- `cpu`: CPU fallback

### Shape Generation Options
- `custom`: Manually specify shapes as a list of [m, k, n] dimensions
```yaml
matrix_shapes:
- name: "custom"
shapes: [
[1024, 1024, 1024], # [m, k, n]
[2048, 4096, 1024]
]
```

- `llama`: Use LLaMa 2 70B single-node weight shapes (assumes fused attn.wqkv and ffn.w13)
- Generates shapes for: "attn.wqkv", "attn.w0", "ffn.w13", "ffn.w2"
```yaml
matrix_shapes:
- name: "llama"
```

- `pow2`: Generate shapes with dimensions that are powers of 2
- Parameters:
- `min_power`: Minimum power of 2 (default: 10, which is 1024)
- `max_power`: Maximum power of 2 (default: 14, which is 16,384)
```yaml
matrix_shapes:
- name: "pow2"
min_power: 10 # 2^10 = 1024
max_power: 12 # 2^12 = 4096
```

- `pow2_extended`: Generate shapes with dimensions that are powers of 2 and powers of 2 + half
- Parameters:
- `min_power`: Minimum power of 2 (default: 10, which is 1024)
- `max_power`: Maximum power of 2 (default: 14, which is 16,384)
```yaml
matrix_shapes:
- name: "pow2_extended"
min_power: 10 # Generates: 1024, 1536, 2048, 3072, etc.
max_power: 11
```

- `sweep`: Generate a sweep of shapes with different powers of 2 for M, K, N dimensions
- Parameters:
- `min_power`: Minimum power of 2 (default: 8, which is 256)
- `max_power`: Maximum power of 2 (default: 15, which is 32,768)
- Note: This generates all combinations of M, K, N dimensions, which can be a large number of shapes
```yaml
matrix_shapes:
- name: "sweep"
min_power: 8 # 2^8 = 256
max_power: 9 # 2^9 = 512
```

## Output

Results are saved to a CSV file in the specified output directory
Expand All @@ -82,3 +153,124 @@ To run the test suite:
```bash
python -m unittest discover benchmarks/microbenchmarks/test
```

# Training Microbenchmarks

This directory contains tools for benchmarking training performance with low precision datatypes like float8.

## Overview

The training microbenchmarking framework allows you to:

1. Benchmark forward and backward pass performance of models with different precision types
2. Compare float8 training performance against baseline implementations
3. Configure various float8 training parameters (scaling types, granularity, etc.)
4. Measure speedups and performance characteristics
5. Generate detailed profiling information

## Usage

### Running Training Benchmarks

To run training benchmarks, use the benchmark_runner.py script with a training configuration file:

```bash
python -m benchmarks.microbenchmarks.benchmark_runner --config benchmarks/microbenchmarks/test/training_benchmark_config.yml
```

### Configuration File Format

The training benchmark configuration file uses YAML format. Here's an example:

```yaml
# Configuration for training benchmarks with float8 and other low precision dtypes
benchmark_mode: "training"
output_dir: "benchmarks/microbenchmarks/results/training"

# Float8 training specific configurations
quantization_config_recipe_names:
- "float8dq-tensor"
- "float8dq-row"
- "baseline" # Always include baseline for comparison

# Float8 training specific configurations

# Training specific configurations
scaling_type_input: "dynamic"
scaling_type_weight: "dynamic"
scaling_type_grad_output: "dynamic"
scaling_granularity: "tensorwise"
use_fast_accum: true
repeat_n: 100 # Number of iterations for benchmarking

model_params:
- name: "float8_linear_training"
matrix_shapes:
- name: "custom"
shapes: [
[1024, 1024, 1024], # [m, k, n]
[2048, 4096, 1024],
[4096, 4096, 1024]
]
high_precision_dtype: "torch.bfloat16"
use_torch_compile: true
torch_compile_mode: "max-autotune"
device: "cuda"
model_type: "linear"
enable_profiler: true
enable_memory_profiler: true
```

### Configuration Options

#### General Options

- `benchmark_mode`: Must be set to "training" for training benchmarks
- `output_dir`: Directory where benchmark results will be saved
- `quantization_config_recipe_names`: List of quantization methods to benchmark

#### Float8 Training Specific Options

- `scaling_type_input`: Scaling type for input tensors ("dynamic" or "static")
- `scaling_type_weight`: Scaling type for weight tensors ("dynamic" or "static")
- `scaling_type_grad_output`: Scaling type for gradient output tensors ("dynamic" or "static")
- `scaling_granularity`: Scaling granularity ("tensorwise", "rowwise", or "columnwise")
- `use_fast_accum`: Whether to use fast accumulation (boolean)
- `repeat_n`: Number of iterations for benchmarking

#### Model Parameters

Each model configuration can include:

- `name`: Name of the benchmark
- `matrix_shapes`: Matrix shapes to benchmark
- `custom`: Custom shapes as [m, k, n] lists
- `llama`: Predefined LLaMA model shapes
- `pow2`: Powers of 2 shapes
- `high_precision_dtype`: High precision dtype (e.g., "torch.bfloat16")
- `use_torch_compile`: Whether to use torch.compile
- `torch_compile_mode`: Compilation mode
- `device`: Device to run on ("cuda" or "cpu")
- `model_type`: Type of model ("linear", "ln_linear_relu", "transformer_block", etc.)
- `enable_profiler`: Whether to enable profiling
- `enable_memory_profiler`: Whether to enable memory profiling

## Extending for New Datatypes

The framework is designed to be easily extended for new datatypes. To add support for a new datatype:

1. Update the `run` function in `benchmark_training.py` to handle the new datatype
2. Add any necessary configuration options to `TrainingBenchmarkConfig`
3. Update the result processing in `TrainingBenchmarkResult` if needed

## Output

The benchmark results include:

- Forward pass time (ms)
- Backward pass time (ms)
- Total training time (ms)
- Speedup compared to baseline
- Scaling configuration details

Results are displayed in a table format and saved to a CSV file in the specified output directory.
51 changes: 46 additions & 5 deletions benchmarks/microbenchmarks/benchmark_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,24 +10,29 @@
- run() function is the main entry point for running inference benchmarks.
"""

import os
from copy import deepcopy
from pathlib import Path

import torch

from benchmarks.microbenchmarks.profiler import (
generate_memory_profile,
generate_model_profile,
visualize_memory_profile,
)
from benchmarks.microbenchmarks.utils import (
BenchmarkConfig,
BenchmarkResult,
clean_caches,
create_model_and_input,
model_inference_time_in_ms,
string_to_config,
)
from torchao.quantization import quantize_
from torchao.sparsity.sparse_api import sparsify_
from torchao.testing.model_architectures import (
create_model_and_input_data,
)


def run(config: BenchmarkConfig) -> BenchmarkResult:
Expand All @@ -38,7 +43,7 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
# Create output directory if it doesn't exist
Path(config.output_dir).mkdir(parents=True, exist_ok=True)

base_model, input_data = create_model_and_input(
base_model, input_data = create_model_and_input_data(
config.model_type,
config.m,
config.k,
Expand Down Expand Up @@ -96,11 +101,47 @@ def run(config: BenchmarkConfig) -> BenchmarkResult:
if config.enable_profiler:
print("Running profiler...")
try:
result.profiler_json_path = generate_model_profile(
m_copy, input_data, config.profiler_file_name
profiler_json_path = generate_model_profile(
model=m_copy,
input_data=input_data,
profile_file_path=os.path.join(
config.output_dir, "profiler", f"{config.name}_profile.json"
),
)
result.profiler_json_path = profiler_json_path
except Exception as e:
print(f"Error running profiler for {config.name} with error: {e}")
print(f"Error running profiler: {e}")

# Run memory profiler if enabled
if config.enable_memory_profiler:
print("Running memory profiler...")
try:
result.memory_profile_path, result.memory_stats = (
generate_memory_profile(
model=m_copy,
input_data=input_data,
profile_file_path=os.path.join(
config.output_dir,
"memory_profiler/pickle",
f"{config.name}_quant_{config.quantization}_sparsity_{config.sparsity}_memory_profile.pickle",
),
)
)

if result.memory_profile_path:
result.memory_visualization_path = visualize_memory_profile(
result.memory_profile_path
)
except ValueError as e:
if "not enough values to unpack" in e:
print(
"Failed due to existing bugs, re-run the code to generate memory profile. Please raise an issue if it persists."
)
except Exception as e:
print(f"Error running memory profiler: {e}")
import traceback

traceback.print_exc()

return result
except Exception as e:
Expand Down
Loading
Loading