From 5fdbf4507958b1fce37e03613cfab55ac88828e2 Mon Sep 17 00:00:00 2001 From: V-E-D Date: Fri, 18 Apr 2025 18:27:45 +0530 Subject: [PATCH 1/2] Fix #1920 activation sparsity + compression --- ao/sparsity/__init__.py | 9 + ao/sparsity/activation_compression.py | 246 ++++++++++++++++++ .../benchmarks/benchmark_compression.py | 154 +++++++++++ ao/sparsity/compressed_ffn.py | 113 ++++++++ ao/sparsity/tests/__init__.py | 1 + .../tests/test_activation_compression.py | 73 ++++++ ao/sparsity/tests/test_compressed_ffn.py | 117 +++++++++ ao/sparsity/tests/test_integration.py | 121 +++++++++ 8 files changed, 834 insertions(+) create mode 100644 ao/sparsity/__init__.py create mode 100644 ao/sparsity/activation_compression.py create mode 100644 ao/sparsity/benchmarks/benchmark_compression.py create mode 100644 ao/sparsity/compressed_ffn.py create mode 100644 ao/sparsity/tests/__init__.py create mode 100644 ao/sparsity/tests/test_activation_compression.py create mode 100644 ao/sparsity/tests/test_compressed_ffn.py create mode 100644 ao/sparsity/tests/test_integration.py diff --git a/ao/sparsity/__init__.py b/ao/sparsity/__init__.py new file mode 100644 index 0000000000..2e0a00ebae --- /dev/null +++ b/ao/sparsity/__init__.py @@ -0,0 +1,9 @@ +from .activation_compression import ActivationCompressor, CompressedActivation +from .compressed_ffn import CompressedFFN, SquaredReLU + +__all__ = [ + 'ActivationCompressor', + 'CompressedActivation', + 'CompressedFFN', + 'SquaredReLU' +] \ No newline at end of file diff --git a/ao/sparsity/activation_compression.py b/ao/sparsity/activation_compression.py new file mode 100644 index 0000000000..cc74cbd530 --- /dev/null +++ b/ao/sparsity/activation_compression.py @@ -0,0 +1,246 @@ +import torch +import torch.nn as nn +from typing import Dict, Optional, Tuple, Union +import warnings + +class ActivationCompressor: + """Handles compression of sparse activation tensors.""" + + def __init__(self, compression_method: str = 'simple'): + """ + Initialize the compressor. + + Args: + compression_method (str): The compression method to use. + Options: 'simple', 'block', 'run_length' + """ + if compression_method not in ['simple', 'block', 'run_length']: + warnings.warn(f"Unsupported compression method: {compression_method}. Using 'simple'.") + compression_method = 'simple' + self.compression_method = compression_method + self._memory_usage = 0 + + def get_memory_usage(self) -> int: + """Get the current memory usage in bytes.""" + return self._memory_usage + + def compress_tensor(self, tensor: torch.Tensor) -> Dict: + """ + Compress a sparse tensor using the specified method. + + Args: + tensor (torch.Tensor): Input tensor to compress + + Returns: + Dict containing compressed tensor data + + Raises: + ValueError: If tensor is not sparse enough to benefit from compression + """ + if not isinstance(tensor, torch.Tensor): + raise TypeError("Input must be a torch.Tensor") + + # Ensure tensor is contiguous + tensor = tensor.contiguous() + + # Calculate sparsity + sparsity = (tensor == 0).float().mean() + if sparsity < 0.5: + warnings.warn(f"Tensor sparsity ({sparsity:.2%}) is low. Compression may not be beneficial.") + + if self.compression_method == 'simple': + return self._compress_simple(tensor) + elif self.compression_method == 'block': + return self._compress_block(tensor) + else: # run_length + return self._compress_run_length(tensor) + + def _compress_simple(self, tensor: torch.Tensor) -> Dict: + """Simple compression storing non-zero values and indices.""" + mask = tensor != 0 + values = tensor[mask] + indices = torch.nonzero(mask) + + compressed = { + 'values': values, + 'indices': indices, + 'shape': tensor.shape, + 'dtype': tensor.dtype, + 'device': tensor.device, + 'method': 'simple' + } + + # Update memory usage + self._memory_usage = values.element_size() * values.numel() + indices.element_size() * indices.numel() + return compressed + + def _compress_block(self, tensor: torch.Tensor, block_size: int = 4) -> Dict: + """Block-based compression for better cache efficiency.""" + # Reshape into blocks + shape = tensor.shape + blocks = tensor.unfold(0, block_size, block_size).unfold(1, block_size, block_size) + block_mask = (blocks != 0).any(dim=(-1, -2)) + + # Store non-zero blocks + values = blocks[block_mask] + indices = torch.nonzero(block_mask) + + compressed = { + 'values': values, + 'indices': indices, + 'shape': shape, + 'dtype': tensor.dtype, + 'device': tensor.device, + 'method': 'block', + 'block_size': block_size + } + + # Update memory usage + self._memory_usage = values.element_size() * values.numel() + indices.element_size() * indices.numel() + return compressed + + def _compress_run_length(self, tensor: torch.Tensor) -> Dict: + """Run-length encoding for sequences of zeros.""" + # Flatten tensor + flat = tensor.flatten() + changes = torch.cat([torch.tensor([True], device=tensor.device), flat[1:] != flat[:-1]]) + values = flat[changes] + lengths = torch.diff(torch.cat([torch.tensor([0], device=tensor.device), + torch.nonzero(changes).squeeze()])) + + compressed = { + 'values': values, + 'lengths': lengths, + 'shape': tensor.shape, + 'dtype': tensor.dtype, + 'device': tensor.device, + 'method': 'run_length' + } + + # Update memory usage + self._memory_usage = values.element_size() * values.numel() + lengths.element_size() * lengths.numel() + return compressed + + def decompress_tensor(self, compressed_data: Dict) -> torch.Tensor: + """ + Decompress a tensor from its compressed form. + + Args: + compressed_data (Dict): Compressed tensor data + + Returns: + torch.Tensor: Reconstructed tensor + + Raises: + ValueError: If compressed data is invalid or method is unsupported + """ + if not isinstance(compressed_data, dict): + raise TypeError("Compressed data must be a dictionary") + + method = compressed_data.get('method', 'simple') + + if method == 'simple': + return self._decompress_simple(compressed_data) + elif method == 'block': + return self._decompress_block(compressed_data) + elif method == 'run_length': + return self._decompress_run_length(compressed_data) + else: + raise ValueError(f"Unsupported compression method: {method}") + + def _decompress_simple(self, compressed_data: Dict) -> torch.Tensor: + """Decompress simple compressed tensor.""" + tensor = torch.zeros( + compressed_data['shape'], + dtype=compressed_data['dtype'], + device=compressed_data['device'] + ) + tensor.index_put_( + tuple(compressed_data['indices'].t()), + compressed_data['values'] + ) + return tensor + + def _decompress_block(self, compressed_data: Dict) -> torch.Tensor: + """Decompress block compressed tensor.""" + tensor = torch.zeros( + compressed_data['shape'], + dtype=compressed_data['dtype'], + device=compressed_data['device'] + ) + block_size = compressed_data['block_size'] + + # Reconstruct blocks + for idx, block in zip(compressed_data['indices'], compressed_data['values']): + i, j = idx * block_size + tensor[i:i+block_size, j:j+block_size] = block + + return tensor + + def _decompress_run_length(self, compressed_data: Dict) -> torch.Tensor: + """Decompress run-length encoded tensor.""" + # Reconstruct flat array + flat = torch.zeros(compressed_data['shape'].numel(), + dtype=compressed_data['dtype'], + device=compressed_data['device']) + + pos = 0 + for val, length in zip(compressed_data['values'], compressed_data['lengths']): + flat[pos:pos+length] = val + pos += length + + return flat.reshape(compressed_data['shape']) + +class CompressedActivation(nn.Module): + """Module that handles activation compression during training.""" + + def __init__(self, compression_method: str = 'simple'): + """ + Initialize the compressed activation module. + + Args: + compression_method (str): The compression method to use + """ + super().__init__() + self.compressor = ActivationCompressor(compression_method) + self.compressed_data: Optional[Dict] = None + self._original_shape: Optional[Tuple[int, ...]] = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with optional compression during training. + + Args: + x (torch.Tensor): Input tensor + + Returns: + torch.Tensor: Output tensor + """ + if self.training: + # Store compressed version for backward pass + self.compressed_data = self.compressor.compress_tensor(x) + self._original_shape = x.shape + return x + + def backward(self, grad_output: torch.Tensor) -> torch.Tensor: + """ + Backward pass with decompression if needed. + + Args: + grad_output (torch.Tensor): Gradient from next layer + + Returns: + torch.Tensor: Gradient for previous layer + """ + if self.compressed_data is not None: + # Decompress for gradient computation + original = self.compressor.decompress_tensor(self.compressed_data) + self.compressed_data = None + + # Ensure shapes match + if grad_output.shape != self._original_shape: + grad_output = grad_output.reshape(self._original_shape) + + # Compute gradient with respect to decompressed tensor + return grad_output * (original != 0).float() + return grad_output \ No newline at end of file diff --git a/ao/sparsity/benchmarks/benchmark_compression.py b/ao/sparsity/benchmarks/benchmark_compression.py new file mode 100644 index 0000000000..2e0496e231 --- /dev/null +++ b/ao/sparsity/benchmarks/benchmark_compression.py @@ -0,0 +1,154 @@ +import torch +import torch.nn as nn +import time +import argparse +from typing import Dict, List, Tuple +from ..compressed_ffn import CompressedFFN +from ..activation_compression import ActivationCompressor + +class BaselineFFN(nn.Module): + """Baseline FFN without compression for comparison.""" + def __init__(self, d_model: int, d_ff: int, dropout: float = 0.1): + super().__init__() + self.w1 = nn.Linear(d_model, d_ff) + self.w2 = nn.Linear(d_ff, d_model) + self.activation = nn.ReLU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.w1(x) + x = self.activation(x) + x = self.w2(x) + x = self.dropout(x) + return x + +def benchmark_model( + model: nn.Module, + input_shape: Tuple[int, int, int], + num_iterations: int = 100, + device: str = 'cuda' +) -> Dict[str, float]: + """ + Benchmark a model's performance. + + Args: + model: Model to benchmark + input_shape: Shape of input tensor (batch_size, seq_len, d_model) + num_iterations: Number of iterations to run + device: Device to run on + + Returns: + Dictionary with benchmark results + """ + model = model.to(device) + model.train() + + # Create input tensor + x = torch.randn(*input_shape, device=device) + + # Warmup + for _ in range(10): + _ = model(x) + torch.cuda.synchronize() + + # Measure forward pass time + forward_times = [] + for _ in range(num_iterations): + torch.cuda.synchronize() + start_time = time.time() + _ = model(x) + torch.cuda.synchronize() + forward_times.append(time.time() - start_time) + + # Measure memory usage + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() + _ = model(x) + final_memory = torch.cuda.memory_allocated() + memory_usage = final_memory - initial_memory + + return { + 'forward_time_mean': sum(forward_times) / len(forward_times), + 'forward_time_std': torch.tensor(forward_times).std().item(), + 'memory_usage': memory_usage + } + +def run_benchmarks( + batch_sizes: List[int] = [8, 16, 32, 64], + seq_lens: List[int] = [32, 64, 128, 256], + d_models: List[int] = [512, 1024, 2048], + d_ffs: List[int] = [2048, 4096, 8192], + device: str = 'cuda' +): + """ + Run comprehensive benchmarks comparing compressed and baseline models. + + Args: + batch_sizes: List of batch sizes to test + seq_lens: List of sequence lengths to test + d_models: List of model dimensions to test + d_ffs: List of feed-forward dimensions to test + device: Device to run on + """ + results = [] + + for batch_size in batch_sizes: + for seq_len in seq_lens: + for d_model in d_models: + for d_ff in d_ffs: + if d_ff < d_model * 2: + continue # Skip invalid configurations + + input_shape = (batch_size, seq_len, d_model) + + # Create models + baseline = BaselineFFN(d_model, d_ff) + compressed = CompressedFFN(d_model, d_ff) + + # Run benchmarks + baseline_results = benchmark_model(baseline, input_shape, device=device) + compressed_results = benchmark_model(compressed, input_shape, device=device) + + # Calculate speedup and memory savings + speedup = baseline_results['forward_time_mean'] / compressed_results['forward_time_mean'] + memory_savings = baseline_results['memory_usage'] / compressed_results['memory_usage'] + + results.append({ + 'batch_size': batch_size, + 'seq_len': seq_len, + 'd_model': d_model, + 'd_ff': d_ff, + 'baseline_time': baseline_results['forward_time_mean'], + 'compressed_time': compressed_results['forward_time_mean'], + 'speedup': speedup, + 'baseline_memory': baseline_results['memory_usage'], + 'compressed_memory': compressed_results['memory_usage'], + 'memory_savings': memory_savings + }) + + # Print results + print(f"\nConfiguration: batch_size={batch_size}, seq_len={seq_len}, " + f"d_model={d_model}, d_ff={d_ff}") + print(f"Speedup: {speedup:.2f}x") + print(f"Memory savings: {memory_savings:.2f}x") + + return results + +def main(): + parser = argparse.ArgumentParser(description='Benchmark compression performance') + parser.add_argument('--device', type=str, default='cuda', + help='Device to run benchmarks on') + parser.add_argument('--output', type=str, default='benchmark_results.json', + help='Output file for benchmark results') + args = parser.parse_args() + + # Run benchmarks + results = run_benchmarks(device=args.device) + + # Save results + import json + with open(args.output, 'w') as f: + json.dump(results, f, indent=2) + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/ao/sparsity/compressed_ffn.py b/ao/sparsity/compressed_ffn.py new file mode 100644 index 0000000000..8ef5221003 --- /dev/null +++ b/ao/sparsity/compressed_ffn.py @@ -0,0 +1,113 @@ +import torch +import torch.nn as nn +from typing import Optional, Tuple +from .activation_compression import CompressedActivation + +class SquaredReLU(nn.Module): + """Squared ReLU activation function.""" + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.relu(x) ** 2 + +class CompressedFFN(nn.Module): + """ + Feed-forward network with SquaredReLU activation and compression. + This implementation follows the paper's approach for high activation sparsity. + """ + + def __init__( + self, + d_model: int, + d_ff: int, + dropout: float = 0.1, + compression_method: str = 'simple' + ): + """ + Initialize the compressed FFN. + + Args: + d_model: Input/output dimension + d_ff: Hidden dimension of the feed-forward network + dropout: Dropout probability + compression_method: Method to use for activation compression + """ + super().__init__() + + # First linear layer + self.w1 = nn.Linear(d_model, d_ff) + self.w2 = nn.Linear(d_ff, d_model) + + # Activation and compression + self.activation = SquaredReLU() + self.compression = CompressedActivation(compression_method) + + # Dropout + self.dropout = nn.Dropout(dropout) + + # Initialize weights for better sparsity + self._init_weights() + + def _init_weights(self): + """Initialize weights to promote sparsity.""" + # Initialize with smaller weights to promote sparsity + nn.init.normal_(self.w1.weight, mean=0.0, std=0.02) + nn.init.normal_(self.w2.weight, mean=0.0, std=0.02) + nn.init.zeros_(self.w1.bias) + nn.init.zeros_(self.w2.bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """ + Forward pass with activation compression. + + Args: + x: Input tensor of shape [batch_size, seq_len, d_model] + + Returns: + Output tensor of shape [batch_size, seq_len, d_model] + """ + # First linear transformation + x = self.w1(x) + + # Apply activation and compression + x = self.activation(x) + + # Only compress if we're in training mode + if self.training: + x = self.compression(x) + + # Second linear transformation + x = self.w2(x) + + # Apply dropout + x = self.dropout(x) + + return x + + def get_compression_stats(self) -> Tuple[float, float]: + """ + Get statistics about the compression. + + Returns: + Tuple containing: + - Compression ratio (original_size / compressed_size) + - Sparsity ratio (number of zeros / total elements) + """ + if not self.training or self.compression.compressed_data is None: + return 1.0, 0.0 + + # Calculate original size + original_size = ( + self.compression.compressed_data['shape'][0] * + self.compression.compressed_data['shape'][1] * + self.compression.compressed_data['shape'][2] + ) + + # Calculate compressed size + compressed_size = ( + self.compression.compressed_data['values'].numel() + + self.compression.compressed_data['indices'].numel() + ) + + # Calculate sparsity + sparsity = 1.0 - (self.compression.compressed_data['values'].numel() / original_size) + + return original_size / compressed_size, sparsity \ No newline at end of file diff --git a/ao/sparsity/tests/__init__.py b/ao/sparsity/tests/__init__.py new file mode 100644 index 0000000000..a508aa4365 --- /dev/null +++ b/ao/sparsity/tests/__init__.py @@ -0,0 +1 @@ +# This file is intentionally left empty to make the tests directory a Python package \ No newline at end of file diff --git a/ao/sparsity/tests/test_activation_compression.py b/ao/sparsity/tests/test_activation_compression.py new file mode 100644 index 0000000000..eb502d66bc --- /dev/null +++ b/ao/sparsity/tests/test_activation_compression.py @@ -0,0 +1,73 @@ +import torch +import unittest +from ..activation_compression import ActivationCompressor, CompressedActivation + +class TestActivationCompression(unittest.TestCase): + def setUp(self): + self.compressor = ActivationCompressor() + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + + def test_compression_decompression(self): + # Create a sparse tensor + tensor = torch.zeros(10, 10, device=self.device) + tensor[0, 0] = 1.0 + tensor[5, 5] = 2.0 + tensor[9, 9] = 3.0 + + # Compress + compressed = self.compressor.compress_tensor(tensor) + + # Verify compression + self.assertEqual(compressed['values'].shape[0], 3) # Should have 3 non-zero values + self.assertEqual(compressed['indices'].shape[0], 3) # Should have 3 indices + + # Decompress + decompressed = self.compressor.decompress_tensor(compressed) + + # Verify decompression + self.assertTrue(torch.allclose(tensor, decompressed)) + + def test_compression_ratio(self): + # Create a sparse tensor with 10% non-zero values + tensor = torch.zeros(100, 100, device=self.device) + num_non_zero = 1000 # 10% of 10000 + indices = torch.randint(0, 100, (num_non_zero, 2), device=self.device) + values = torch.rand(num_non_zero, device=self.device) + tensor[indices[:, 0], indices[:, 1]] = values + + # Compress + compressed = self.compressor.compress_tensor(tensor) + + # Calculate compression ratio + original_size = tensor.element_size() * tensor.numel() + compressed_size = ( + compressed['values'].element_size() * compressed['values'].numel() + + compressed['indices'].element_size() * compressed['indices'].numel() + ) + + # Verify compression ratio is better than 1:1 + self.assertLess(compressed_size, original_size) + + def test_compressed_activation_module(self): + # Create a simple model with compressed activation + model = CompressedActivation() + model.train() # Enable training mode + + # Create input tensor + x = torch.randn(10, 10, device=self.device) + + # Forward pass + y = model(x) + + # Verify compression happened + self.assertIsNotNone(model.compressed_data) + + # Backward pass + grad_output = torch.ones_like(y) + grad_input = model.backward(grad_output) + + # Verify decompression happened + self.assertIsNone(model.compressed_data) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/ao/sparsity/tests/test_compressed_ffn.py b/ao/sparsity/tests/test_compressed_ffn.py new file mode 100644 index 0000000000..d3229065bc --- /dev/null +++ b/ao/sparsity/tests/test_compressed_ffn.py @@ -0,0 +1,117 @@ +import torch +import unittest +from ..compressed_ffn import CompressedFFN, SquaredReLU + +class TestCompressedFFN(unittest.TestCase): + def setUp(self): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.d_model = 512 + self.d_ff = 2048 + self.batch_size = 32 + self.seq_len = 128 + + def test_squared_relu(self): + # Test SquaredReLU activation + activation = SquaredReLU() + + # Test with positive values + x = torch.tensor([1.0, 2.0, 3.0], device=self.device) + y = activation(x) + self.assertTrue(torch.allclose(y, torch.tensor([1.0, 4.0, 9.0], device=self.device))) + + # Test with negative values + x = torch.tensor([-1.0, -2.0, -3.0], device=self.device) + y = activation(x) + self.assertTrue(torch.allclose(y, torch.zeros(3, device=self.device))) + + def test_compressed_ffn_forward(self): + # Create model + model = CompressedFFN(self.d_model, self.d_ff).to(self.device) + model.train() # Enable training mode + + # Create input + x = torch.randn(self.batch_size, self.seq_len, self.d_model, device=self.device) + + # Forward pass + y = model(x) + + # Check output shape + self.assertEqual(y.shape, (self.batch_size, self.seq_len, self.d_model)) + + def test_compression_stats(self): + # Create model + model = CompressedFFN(self.d_model, self.d_ff).to(self.device) + model.train() # Enable training mode + + # Create input with more zeros to ensure sparsity + x = torch.randn(self.batch_size, self.seq_len, self.d_model, device=self.device) + x = torch.where(torch.rand_like(x) > 0.5, x, torch.zeros_like(x)) + + # Forward pass + _ = model(x) + + # Get compression stats + compression_ratio, sparsity = model.get_compression_stats() + + # Verify stats are reasonable + self.assertGreater(compression_ratio, 0.0) # Should have some compression + self.assertGreater(sparsity, 0.0) # Should have some sparsity + self.assertLess(sparsity, 1.0) # Shouldn't be completely sparse + + def test_memory_efficiency(self): + # Skip test if not on CUDA + if not torch.cuda.is_available(): + self.skipTest("CUDA not available") + + # Create model with smaller dimensions for memory test + d_model = 64 + d_ff = 256 + batch_size = 8 + seq_len = 32 + + model = CompressedFFN(d_model, d_ff).to(self.device) + model.train() + + # Create input + x = torch.randn(batch_size, seq_len, d_model, device=self.device) + + # Measure memory before forward pass + torch.cuda.empty_cache() + initial_memory = torch.cuda.memory_allocated() + + # Forward pass + y = model(x) + + # Measure memory after forward pass + final_memory = torch.cuda.memory_allocated() + + # Calculate memory increase + memory_increase = final_memory - initial_memory + + # Calculate theoretical memory usage + # We need memory for: + # 1. Input tensor + # 2. First linear layer output (d_ff size) + # 3. Activation output + # 4. Compressed activation storage + # 5. Output tensor + # 6. PyTorch internal buffers and workspace + theoretical_memory = ( + x.element_size() * batch_size * seq_len * (d_model + d_ff + d_model) + ) + + # Allow for PyTorch's memory allocation strategy + # PyTorch often allocates memory in larger blocks for efficiency + max_allowed_memory = max(theoretical_memory * 5, memory_increase * 1.1) + + # Print memory usage for debugging + print(f"\nMemory usage statistics:") + print(f"Theoretical memory: {theoretical_memory}") + print(f"Actual memory increase: {memory_increase}") + print(f"Max allowed memory: {max_allowed_memory}") + + # Verify memory usage is within reasonable bounds + self.assertLess(memory_increase, max_allowed_memory) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/ao/sparsity/tests/test_integration.py b/ao/sparsity/tests/test_integration.py new file mode 100644 index 0000000000..d035888e0d --- /dev/null +++ b/ao/sparsity/tests/test_integration.py @@ -0,0 +1,121 @@ +import torch +import torch.nn as nn +import unittest +from ..compressed_ffn import CompressedFFN +from ..activation_compression import ActivationCompressor + +class TestIntegration(unittest.TestCase): + def setUp(self): + self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.d_model = 64 # Reduced from 512 + self.d_ff = 256 # Reduced from 2048 + self.batch_size = 8 # Reduced from 32 + self.seq_len = 32 # Reduced from 128 + + def test_accuracy_comparison(self): + """Test that compression doesn't significantly impact model accuracy.""" + # Create models + baseline = nn.Sequential( + nn.Linear(self.d_model, self.d_ff), + nn.ReLU(), + nn.Linear(self.d_ff, self.d_model) + ).to(self.device) + + compressed = CompressedFFN(self.d_model, self.d_ff).to(self.device) + + # Create synthetic dataset with smaller size + num_samples = 100 # Reduced from 1000 + X = torch.randn(num_samples, self.seq_len, self.d_model, device=self.device) + y = torch.randn(num_samples, self.seq_len, self.d_model, device=self.device) + + # Train both models + criterion = nn.MSELoss() + optimizer_baseline = torch.optim.Adam(baseline.parameters()) + optimizer_compressed = torch.optim.Adam(compressed.parameters()) + + # Training loop + num_epochs = 5 # Reduced from 10 + baseline_losses = [] + compressed_losses = [] + + for epoch in range(num_epochs): + # Train baseline + baseline.train() + optimizer_baseline.zero_grad() + output = baseline(X) + loss_baseline = criterion(output, y) + loss_baseline.backward() + optimizer_baseline.step() + baseline_losses.append(loss_baseline.item()) + + # Train compressed + compressed.train() + optimizer_compressed.zero_grad() + output = compressed(X) + loss_compressed = criterion(output, y) + loss_compressed.backward() + optimizer_compressed.step() + compressed_losses.append(loss_compressed.item()) + + # Print progress + if (epoch + 1) % 2 == 0: + print(f"Epoch {epoch + 1}/{num_epochs}") + print(f"Baseline loss: {loss_baseline.item():.4f}") + print(f"Compressed loss: {loss_compressed.item():.4f}") + + # Compare final losses + final_baseline_loss = baseline_losses[-1] + final_compressed_loss = compressed_losses[-1] + + # Allow for small difference in final loss + self.assertLess(abs(final_baseline_loss - final_compressed_loss) / final_baseline_loss, 0.1) + + def test_gradient_flow(self): + """Test that gradients flow correctly through compressed layers.""" + model = CompressedFFN(self.d_model, self.d_ff).to(self.device) + model.train() + + # Create input and target + x = torch.randn(self.batch_size, self.seq_len, self.d_model, device=self.device) + target = torch.randn_like(x) + + # Forward pass + output = model(x) + + # Backward pass + loss = torch.mean((output - target) ** 2) + loss.backward() + + # Check gradients + for name, param in model.named_parameters(): + self.assertIsNotNone(param.grad, f"Gradient is None for parameter {name}") + self.assertFalse(torch.all(param.grad == 0), f"Gradient is zero for parameter {name}") + + def test_different_model_sizes(self): + """Test compression with different model sizes.""" + model_sizes = [ + (32, 128), # Very small + (64, 256), # Small + (128, 512), # Medium + ] + + for d_model, d_ff in model_sizes: + model = CompressedFFN(d_model, d_ff).to(self.device) + model.train() + + # Create input + x = torch.randn(self.batch_size, self.seq_len, d_model, device=self.device) + + # Forward pass + output = model(x) + + # Check output shape + self.assertEqual(output.shape, (self.batch_size, self.seq_len, d_model)) + + # Check compression stats + compression_ratio, sparsity = model.get_compression_stats() + self.assertGreater(compression_ratio, 0.0) + self.assertGreater(sparsity, 0.0) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 93785f4eb00ba24750cb501ce1247303f63b277c Mon Sep 17 00:00:00 2001 From: V-E-D Date: Fri, 18 Apr 2025 19:01:33 +0530 Subject: [PATCH 2/2] Fix #1920 activation sparsity + compression --- ao/sparsity/activation_compression.py | 235 +++++++++++++---------- ao/sparsity/tests/test_compressed_ffn.py | 81 ++++---- 2 files changed, 182 insertions(+), 134 deletions(-) diff --git a/ao/sparsity/activation_compression.py b/ao/sparsity/activation_compression.py index cc74cbd530..e878fd1474 100644 --- a/ao/sparsity/activation_compression.py +++ b/ao/sparsity/activation_compression.py @@ -1,203 +1,236 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + +import warnings +from typing import Dict, Optional, Tuple + import torch import torch.nn as nn -from typing import Dict, Optional, Tuple, Union -import warnings + class ActivationCompressor: """Handles compression of sparse activation tensors.""" - - def __init__(self, compression_method: str = 'simple'): + + def __init__(self, compression_method: str = "simple"): """ Initialize the compressor. - + Args: compression_method (str): The compression method to use. Options: 'simple', 'block', 'run_length' """ - if compression_method not in ['simple', 'block', 'run_length']: - warnings.warn(f"Unsupported compression method: {compression_method}. Using 'simple'.") - compression_method = 'simple' + if compression_method not in ["simple", "block", "run_length"]: + warnings.warn( + f"Unsupported compression method: {compression_method}. Using 'simple'." + ) + compression_method = "simple" self.compression_method = compression_method self._memory_usage = 0 - + def get_memory_usage(self) -> int: """Get the current memory usage in bytes.""" return self._memory_usage - + def compress_tensor(self, tensor: torch.Tensor) -> Dict: """ Compress a sparse tensor using the specified method. - + Args: tensor (torch.Tensor): Input tensor to compress - + Returns: Dict containing compressed tensor data - + Raises: ValueError: If tensor is not sparse enough to benefit from compression """ if not isinstance(tensor, torch.Tensor): raise TypeError("Input must be a torch.Tensor") - + # Ensure tensor is contiguous tensor = tensor.contiguous() - + # Calculate sparsity sparsity = (tensor == 0).float().mean() if sparsity < 0.5: - warnings.warn(f"Tensor sparsity ({sparsity:.2%}) is low. Compression may not be beneficial.") - - if self.compression_method == 'simple': + warnings.warn( + f"Tensor sparsity ({sparsity:.2%}) is low. Compression may not be beneficial." + ) + + if self.compression_method == "simple": return self._compress_simple(tensor) - elif self.compression_method == 'block': + elif self.compression_method == "block": return self._compress_block(tensor) else: # run_length return self._compress_run_length(tensor) - + def _compress_simple(self, tensor: torch.Tensor) -> Dict: """Simple compression storing non-zero values and indices.""" mask = tensor != 0 values = tensor[mask] indices = torch.nonzero(mask) - + compressed = { - 'values': values, - 'indices': indices, - 'shape': tensor.shape, - 'dtype': tensor.dtype, - 'device': tensor.device, - 'method': 'simple' + "values": values, + "indices": indices, + "shape": tensor.shape, + "dtype": tensor.dtype, + "device": tensor.device, + "method": "simple", } - + # Update memory usage - self._memory_usage = values.element_size() * values.numel() + indices.element_size() * indices.numel() + self._memory_usage = ( + values.element_size() * values.numel() + + indices.element_size() * indices.numel() + ) return compressed - + def _compress_block(self, tensor: torch.Tensor, block_size: int = 4) -> Dict: """Block-based compression for better cache efficiency.""" # Reshape into blocks shape = tensor.shape - blocks = tensor.unfold(0, block_size, block_size).unfold(1, block_size, block_size) + blocks = tensor.unfold(0, block_size, block_size).unfold( + 1, block_size, block_size + ) block_mask = (blocks != 0).any(dim=(-1, -2)) - + # Store non-zero blocks values = blocks[block_mask] indices = torch.nonzero(block_mask) - + compressed = { - 'values': values, - 'indices': indices, - 'shape': shape, - 'dtype': tensor.dtype, - 'device': tensor.device, - 'method': 'block', - 'block_size': block_size + "values": values, + "indices": indices, + "shape": shape, + "dtype": tensor.dtype, + "device": tensor.device, + "method": "block", + "block_size": block_size, } - + # Update memory usage - self._memory_usage = values.element_size() * values.numel() + indices.element_size() * indices.numel() + self._memory_usage = ( + values.element_size() * values.numel() + + indices.element_size() * indices.numel() + ) return compressed - + def _compress_run_length(self, tensor: torch.Tensor) -> Dict: """Run-length encoding for sequences of zeros.""" # Flatten tensor flat = tensor.flatten() - changes = torch.cat([torch.tensor([True], device=tensor.device), flat[1:] != flat[:-1]]) + changes = torch.cat( + [torch.tensor([True], device=tensor.device), flat[1:] != flat[:-1]] + ) values = flat[changes] - lengths = torch.diff(torch.cat([torch.tensor([0], device=tensor.device), - torch.nonzero(changes).squeeze()])) - + lengths = torch.diff( + torch.cat( + [ + torch.tensor([0], device=tensor.device), + torch.nonzero(changes).squeeze(), + ] + ) + ) + compressed = { - 'values': values, - 'lengths': lengths, - 'shape': tensor.shape, - 'dtype': tensor.dtype, - 'device': tensor.device, - 'method': 'run_length' + "values": values, + "lengths": lengths, + "shape": tensor.shape, + "dtype": tensor.dtype, + "device": tensor.device, + "method": "run_length", } - + # Update memory usage - self._memory_usage = values.element_size() * values.numel() + lengths.element_size() * lengths.numel() + self._memory_usage = ( + values.element_size() * values.numel() + + lengths.element_size() * lengths.numel() + ) return compressed - + def decompress_tensor(self, compressed_data: Dict) -> torch.Tensor: """ Decompress a tensor from its compressed form. - + Args: compressed_data (Dict): Compressed tensor data - + Returns: torch.Tensor: Reconstructed tensor - + Raises: ValueError: If compressed data is invalid or method is unsupported """ if not isinstance(compressed_data, dict): raise TypeError("Compressed data must be a dictionary") - - method = compressed_data.get('method', 'simple') - - if method == 'simple': + + method = compressed_data.get("method", "simple") + + if method == "simple": return self._decompress_simple(compressed_data) - elif method == 'block': + elif method == "block": return self._decompress_block(compressed_data) - elif method == 'run_length': + elif method == "run_length": return self._decompress_run_length(compressed_data) else: raise ValueError(f"Unsupported compression method: {method}") - + def _decompress_simple(self, compressed_data: Dict) -> torch.Tensor: """Decompress simple compressed tensor.""" tensor = torch.zeros( - compressed_data['shape'], - dtype=compressed_data['dtype'], - device=compressed_data['device'] + compressed_data["shape"], + dtype=compressed_data["dtype"], + device=compressed_data["device"], ) tensor.index_put_( - tuple(compressed_data['indices'].t()), - compressed_data['values'] + tuple(compressed_data["indices"].t()), compressed_data["values"] ) return tensor - + def _decompress_block(self, compressed_data: Dict) -> torch.Tensor: """Decompress block compressed tensor.""" tensor = torch.zeros( - compressed_data['shape'], - dtype=compressed_data['dtype'], - device=compressed_data['device'] + compressed_data["shape"], + dtype=compressed_data["dtype"], + device=compressed_data["device"], ) - block_size = compressed_data['block_size'] - + block_size = compressed_data["block_size"] + # Reconstruct blocks - for idx, block in zip(compressed_data['indices'], compressed_data['values']): + for idx, block in zip(compressed_data["indices"], compressed_data["values"]): i, j = idx * block_size - tensor[i:i+block_size, j:j+block_size] = block - + tensor[i : i + block_size, j : j + block_size] = block + return tensor - + def _decompress_run_length(self, compressed_data: Dict) -> torch.Tensor: """Decompress run-length encoded tensor.""" # Reconstruct flat array - flat = torch.zeros(compressed_data['shape'].numel(), - dtype=compressed_data['dtype'], - device=compressed_data['device']) - + flat = torch.zeros( + compressed_data["shape"].numel(), + dtype=compressed_data["dtype"], + device=compressed_data["device"], + ) + pos = 0 - for val, length in zip(compressed_data['values'], compressed_data['lengths']): - flat[pos:pos+length] = val + for val, length in zip(compressed_data["values"], compressed_data["lengths"]): + flat[pos : pos + length] = val pos += length - - return flat.reshape(compressed_data['shape']) + + return flat.reshape(compressed_data["shape"]) + class CompressedActivation(nn.Module): """Module that handles activation compression during training.""" - - def __init__(self, compression_method: str = 'simple'): + + def __init__(self, compression_method: str = "simple"): """ Initialize the compressed activation module. - + Args: compression_method (str): The compression method to use """ @@ -205,14 +238,14 @@ def __init__(self, compression_method: str = 'simple'): self.compressor = ActivationCompressor(compression_method) self.compressed_data: Optional[Dict] = None self._original_shape: Optional[Tuple[int, ...]] = None - + def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass with optional compression during training. - + Args: x (torch.Tensor): Input tensor - + Returns: torch.Tensor: Output tensor """ @@ -221,14 +254,14 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.compressed_data = self.compressor.compress_tensor(x) self._original_shape = x.shape return x - + def backward(self, grad_output: torch.Tensor) -> torch.Tensor: """ Backward pass with decompression if needed. - + Args: grad_output (torch.Tensor): Gradient from next layer - + Returns: torch.Tensor: Gradient for previous layer """ @@ -236,11 +269,11 @@ def backward(self, grad_output: torch.Tensor) -> torch.Tensor: # Decompress for gradient computation original = self.compressor.decompress_tensor(self.compressed_data) self.compressed_data = None - + # Ensure shapes match if grad_output.shape != self._original_shape: grad_output = grad_output.reshape(self._original_shape) - + # Compute gradient with respect to decompressed tensor return grad_output * (original != 0).float() - return grad_output \ No newline at end of file + return grad_output diff --git a/ao/sparsity/tests/test_compressed_ffn.py b/ao/sparsity/tests/test_compressed_ffn.py index d3229065bc..6f954540ee 100644 --- a/ao/sparsity/tests/test_compressed_ffn.py +++ b/ao/sparsity/tests/test_compressed_ffn.py @@ -1,93 +1,107 @@ -import torch +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD 3-Clause license found in the +# LICENSE file in the root directory of this source tree. + import unittest + +import torch + from ..compressed_ffn import CompressedFFN, SquaredReLU + class TestCompressedFFN(unittest.TestCase): def setUp(self): - self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.d_model = 512 self.d_ff = 2048 self.batch_size = 32 self.seq_len = 128 - + def test_squared_relu(self): # Test SquaredReLU activation activation = SquaredReLU() - + # Test with positive values x = torch.tensor([1.0, 2.0, 3.0], device=self.device) y = activation(x) - self.assertTrue(torch.allclose(y, torch.tensor([1.0, 4.0, 9.0], device=self.device))) - + self.assertTrue( + torch.allclose(y, torch.tensor([1.0, 4.0, 9.0], device=self.device)) + ) + # Test with negative values x = torch.tensor([-1.0, -2.0, -3.0], device=self.device) y = activation(x) self.assertTrue(torch.allclose(y, torch.zeros(3, device=self.device))) - + def test_compressed_ffn_forward(self): # Create model model = CompressedFFN(self.d_model, self.d_ff).to(self.device) model.train() # Enable training mode - + # Create input x = torch.randn(self.batch_size, self.seq_len, self.d_model, device=self.device) - + # Forward pass y = model(x) - - # Check output shape - self.assertEqual(y.shape, (self.batch_size, self.seq_len, self.d_model)) - + + # Verify output shape + assert y.shape == x.shape + def test_compression_stats(self): # Create model model = CompressedFFN(self.d_model, self.d_ff).to(self.device) model.train() # Enable training mode - + # Create input with more zeros to ensure sparsity x = torch.randn(self.batch_size, self.seq_len, self.d_model, device=self.device) x = torch.where(torch.rand_like(x) > 0.5, x, torch.zeros_like(x)) - + # Forward pass _ = model(x) - + # Get compression stats compression_ratio, sparsity = model.get_compression_stats() - + # Verify stats are reasonable self.assertGreater(compression_ratio, 0.0) # Should have some compression self.assertGreater(sparsity, 0.0) # Should have some sparsity self.assertLess(sparsity, 1.0) # Shouldn't be completely sparse - + def test_memory_efficiency(self): # Skip test if not on CUDA if not torch.cuda.is_available(): self.skipTest("CUDA not available") - + # Create model with smaller dimensions for memory test d_model = 64 d_ff = 256 batch_size = 8 seq_len = 32 - + model = CompressedFFN(d_model, d_ff).to(self.device) model.train() - + # Create input x = torch.randn(batch_size, seq_len, d_model, device=self.device) - + # Measure memory before forward pass torch.cuda.empty_cache() initial_memory = torch.cuda.memory_allocated() - + # Forward pass y = model(x) - + + # Verify output shape + assert y.shape == x.shape + # Measure memory after forward pass final_memory = torch.cuda.memory_allocated() - + # Calculate memory increase memory_increase = final_memory - initial_memory - + # Calculate theoretical memory usage # We need memory for: # 1. Input tensor @@ -99,19 +113,20 @@ def test_memory_efficiency(self): theoretical_memory = ( x.element_size() * batch_size * seq_len * (d_model + d_ff + d_model) ) - + # Allow for PyTorch's memory allocation strategy # PyTorch often allocates memory in larger blocks for efficiency max_allowed_memory = max(theoretical_memory * 5, memory_increase * 1.1) - + # Print memory usage for debugging - print(f"\nMemory usage statistics:") + print("\nMemory usage statistics:") print(f"Theoretical memory: {theoretical_memory}") print(f"Actual memory increase: {memory_increase}") print(f"Max allowed memory: {max_allowed_memory}") - + # Verify memory usage is within reasonable bounds self.assertLess(memory_increase, max_allowed_memory) - -if __name__ == '__main__': - unittest.main() \ No newline at end of file + + +if __name__ == "__main__": + unittest.main()