|
| 1 | +Static Quantization |
| 2 | +-------------------- |
| 3 | + |
| 4 | +Static quantization refers to using a fixed quantization range for all inputs during inference or generation. Unlike dynamic quantization, which dynamically computes new quantization ranges for each new input batch, static quantization typically results in more efficient computation, potentially at the cost of lower quantized accuracy since we cannot adapt to changes in the input distribution on-the-fly. |
| 5 | + |
| 6 | +In static quantization, this fixed quantization range is typically calibrated on similar inputs before quantizing the model. During the calibration phase, we first insert observers into the model to "observe" the distribution of the inputs to be quantized, and use this distribution to decide what scales and zero points to ultimately use when quantizing the model. |
| 7 | + |
| 8 | +In this tutorial, we walk through an example of how to achieve this in torchao. All code can be found in this `example script <https://github.com/pytorch/ao/tree/main/tutorials/calibration_flow/static_quant.py>`__. Let's start with our toy linear model: |
| 9 | + |
| 10 | +.. code:: py |
| 11 | +
|
| 12 | + import copy |
| 13 | + import torch |
| 14 | +
|
| 15 | + class ToyLinearModel(torch.nn.Module): |
| 16 | + def __init__(self, m=64, n=32, k=64): |
| 17 | + super().__init__() |
| 18 | + self.linear1 = torch.nn.Linear(m, k, bias=False) |
| 19 | + self.linear2 = torch.nn.Linear(k, n, bias=False) |
| 20 | + |
| 21 | + def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"): |
| 22 | + return ( |
| 23 | + torch.randn( |
| 24 | + batch_size, self.linear1.in_features, dtype=dtype, device=device |
| 25 | + ), |
| 26 | + ) |
| 27 | + |
| 28 | + def forward(self, x): |
| 29 | + x = self.linear1(x) |
| 30 | + x = self.linear2(x) |
| 31 | + return x |
| 32 | +
|
| 33 | + dtype = torch.bfloat16 |
| 34 | + m = ToyLinearModel().eval().to(dtype).to("cuda") |
| 35 | + m = torch.compile(m, mode="max-autotune") |
| 36 | +
|
| 37 | +
|
| 38 | +Calibration Phase |
| 39 | +~~~~~~~~~~~~~~~~~ |
| 40 | + |
| 41 | +torchao comes with a a simple observer implementation, `AffineQuantizedMinMaxObserver`, that records the min and max values that have flowed through the observer during the calibration phase. Users are welcome to implement their own desired, more advanced observation techniques, such as those relying on moving averages or histograms, and these may be added to torchao in the future. |
| 42 | + |
| 43 | +.. code:: py |
| 44 | +
|
| 45 | + from torchao.quantization.granularity import PerAxis, PerTensor |
| 46 | + from torchao.quantization.observer import AffineQuantizedMinMaxObserver |
| 47 | + from torchao.quantization.quant_primitives import MappingType |
| 48 | +
|
| 49 | + # per tensor input activation asymmetric quantization |
| 50 | + act_obs = AffineQuantizedMinMaxObserver( |
| 51 | + MappingType.ASYMMETRIC, |
| 52 | + torch.uint8, |
| 53 | + granularity=PerTensor(), |
| 54 | + eps=torch.finfo(torch.float32).eps, |
| 55 | + scale_dtype=torch.float32, |
| 56 | + zero_point_dtype=torch.float32, |
| 57 | + ) |
| 58 | +
|
| 59 | + # per channel weight asymmetric quantization |
| 60 | + weight_obs = AffineQuantizedMinMaxObserver( |
| 61 | + MappingType.ASYMMETRIC, |
| 62 | + torch.uint8, |
| 63 | + granularity=PerAxis(axis=0), |
| 64 | + eps=torch.finfo(torch.float32).eps, |
| 65 | + scale_dtype=torch.float32, |
| 66 | + zero_point_dtype=torch.float32, |
| 67 | + ) |
| 68 | +
|
| 69 | +Next, we define our observed linear that we will swap our `torch.nn.Linear` with. This is a high precision (e.g. fp32) linear module with the above observers inserted to record the input activation and weight values during calibration: |
| 70 | + |
| 71 | +.. code:: py |
| 72 | +
|
| 73 | + import torch.nn.functional as F |
| 74 | +
|
| 75 | + class ObservedLinear(torch.nn.Linear): |
| 76 | + def __init__( |
| 77 | + self, |
| 78 | + in_features: int, |
| 79 | + out_features: int, |
| 80 | + act_obs: torch.nn.Module, |
| 81 | + weight_obs: torch.nn.Module, |
| 82 | + bias: bool = True, |
| 83 | + device=None, |
| 84 | + dtype=None, |
| 85 | + ): |
| 86 | + super().__init__(in_features, out_features, bias, device, dtype) |
| 87 | + self.act_obs = act_obs |
| 88 | + self.weight_obs = weight_obs |
| 89 | + |
| 90 | + def forward(self, input: torch.Tensor): |
| 91 | + observed_input = self.act_obs(input) |
| 92 | + observed_weight = self.weight_obs(self.weight) |
| 93 | + return F.linear(observed_input, observed_weight, self.bias) |
| 94 | + |
| 95 | + @classmethod |
| 96 | + def from_float(cls, float_linear, act_obs, weight_obs): |
| 97 | + observed_linear = cls( |
| 98 | + float_linear.in_features, |
| 99 | + float_linear.out_features, |
| 100 | + act_obs, |
| 101 | + weight_obs, |
| 102 | + False, |
| 103 | + device=float_linear.weight.device, |
| 104 | + dtype=float_linear.weight.dtype, |
| 105 | + ) |
| 106 | + observed_linear.weight = float_linear.weight |
| 107 | + observed_linear.bias = float_linear.bias |
| 108 | + return observed_linear |
| 109 | +
|
| 110 | +To actually insert these observers into our toy model: |
| 111 | + |
| 112 | +.. code:: py |
| 113 | +
|
| 114 | + from torchao.quantization.quant_api import ( |
| 115 | + _replace_with_custom_fn_if_matches_filter, |
| 116 | + ) |
| 117 | +
|
| 118 | + def insert_observers_(model, act_obs, weight_obs): |
| 119 | + _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear) |
| 120 | +
|
| 121 | + def replacement_fn(m): |
| 122 | + copied_act_obs = copy.deepcopy(act_obs) |
| 123 | + copied_weight_obs = copy.deepcopy(weight_obs) |
| 124 | + return ObservedLinear.from_float(m, copied_act_obs, copied_weight_obs) |
| 125 | +
|
| 126 | + _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear) |
| 127 | +
|
| 128 | + insert_observers_(m, act_obs, weight_obs) |
| 129 | +
|
| 130 | +Now we are ready to calibrate the model, which populates the observers we inserted with statistics recorded during the calibration. We can do this simply by feeding some example inputs to our "observed" model: |
| 131 | + |
| 132 | +.. code:: py |
| 133 | +
|
| 134 | + for _ in range(10): |
| 135 | + example_inputs = m.example_inputs(dtype=dtype, device="cuda") |
| 136 | + m(*example_inputs) |
| 137 | +
|
| 138 | +
|
| 139 | +Quantization Phase |
| 140 | +~~~~~~~~~~~~~~~~~~ |
| 141 | + |
| 142 | +There are multiple ways to actually quantize the model. Here we walk through the simpler alternative, which is to define a `QuantizedLinear` class that we will swap our `ObservedLinear` to. Defining this new class isn't strictly necessary. For an alternative method that simply uses the existing `torch.nn.Linear`, please see the full `example script <https://github.com/pytorch/ao/tree/main/tutorials/calibration_flow/static_quant.py>`__. |
| 143 | + |
| 144 | +.. code:: py |
| 145 | +
|
| 146 | + from torchao.dtypes import to_affine_quantized_intx_static |
| 147 | +
|
| 148 | + class QuantizedLinear(torch.nn.Module): |
| 149 | + def __init__( |
| 150 | + self, |
| 151 | + in_features: int, |
| 152 | + out_features: int, |
| 153 | + act_obs: torch.nn.Module, |
| 154 | + weight_obs: torch.nn.Module, |
| 155 | + weight: torch.Tensor, |
| 156 | + bias: torch.Tensor, |
| 157 | + target_dtype: torch.dtype, |
| 158 | + ): |
| 159 | + super().__init__() |
| 160 | + self.act_scale, self.act_zero_point = act_obs.calculate_qparams() |
| 161 | + weight_scale, weight_zero_point = weight_obs.calculate_qparams() |
| 162 | + assert weight.dim() == 2 |
| 163 | + block_size = (1, weight.shape[1]) |
| 164 | + self.target_dtype = target_dtype |
| 165 | + self.bias = bias |
| 166 | + self.qweight = to_affine_quantized_intx_static( |
| 167 | + weight, weight_scale, weight_zero_point, block_size, self.target_dtype |
| 168 | + ) |
| 169 | + |
| 170 | + def forward(self, input: torch.Tensor): |
| 171 | + block_size = input.shape |
| 172 | + qinput = to_affine_quantized_intx_static( |
| 173 | + input, |
| 174 | + self.act_scale, |
| 175 | + self.act_zero_point, |
| 176 | + block_size, |
| 177 | + self.target_dtype, |
| 178 | + ) |
| 179 | + return F.linear(qinput, self.qweight, self.bias) |
| 180 | + |
| 181 | + @classmethod |
| 182 | + def from_observed(cls, observed_linear, target_dtype): |
| 183 | + quantized_linear = cls( |
| 184 | + observed_linear.in_features, |
| 185 | + observed_linear.out_features, |
| 186 | + observed_linear.act_obs, |
| 187 | + observed_linear.weight_obs, |
| 188 | + observed_linear.weight, |
| 189 | + observed_linear.bias, |
| 190 | + target_dtype, |
| 191 | + ) |
| 192 | + return quantized_linear |
| 193 | +
|
| 194 | +This linear class computes the scales and zero points for both input activations and weights in the beginning, effectively fixing the quantization range for future forward calls. Now, to actually quantize the model using this linear class, we can define the following config and pass it to torchao's main `quantize_` API: |
| 195 | + |
| 196 | +.. code:: py |
| 197 | +
|
| 198 | + from dataclasses import dataclass |
| 199 | +
|
| 200 | + from torchao.core.config import AOBaseConfig |
| 201 | + from torchao.quantization import quantize_ |
| 202 | + from torchao.quantization.transform_module import ( |
| 203 | + register_quantize_module_handler, |
| 204 | + ) |
| 205 | + |
| 206 | + @dataclass |
| 207 | + class ApplyStaticQuantConfig(AOBaseConfig): |
| 208 | + target_dtype: torch.dtype |
| 209 | + |
| 210 | + @register_quantize_module_handler(ApplyStaticQuantConfig) |
| 211 | + def _apply_static_quant( |
| 212 | + module: torch.nn.Module, |
| 213 | + config: ApplyStaticQuantConfig, |
| 214 | + ): |
| 215 | + """ |
| 216 | + Define a transformation associated with `ApplyStaticQuantConfig`. |
| 217 | + This is called by `quantize_`, not by the user directly. |
| 218 | + """ |
| 219 | + return QuantizedLinear.from_observed(module, config.target_dtype) |
| 220 | +
|
| 221 | + # filter function to identify which modules to swap |
| 222 | + is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear) |
| 223 | +
|
| 224 | + # perform static quantization |
| 225 | + quantize_(m, ApplyStaticQuantConfig(torch.uint8), is_observed_linear) |
| 226 | +
|
| 227 | +Now, we will see that the linear layers in our model are swapped to our `QuantizedLinear` class, with a fixed input activation scale and a fixed quantized weight: |
| 228 | + |
| 229 | +.. code:: py |
| 230 | +
|
| 231 | + >>> m |
| 232 | + OptimizedModule( |
| 233 | + (_orig_mod): ToyLinearModel( |
| 234 | + (linear1): QuantizedLinear() |
| 235 | + (linear2): QuantizedLinear() |
| 236 | + ) |
| 237 | + ) |
| 238 | + >>> m.linear1.act_scale |
| 239 | + tensor([0.0237], device='cuda:0') |
| 240 | + >>> m.linear1.qweight |
| 241 | + AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=tensor([[142, 31, 42, ..., 113, 157, 57], |
| 242 | + [ 59, 160, 70, ..., 23, 150, 67], |
| 243 | + [ 44, 49, 241, ..., 238, 69, 235], |
| 244 | + ..., |
| 245 | + [228, 255, 201, ..., 114, 236, 73], |
| 246 | + [ 50, 88, 83, ..., 109, 209, 92], |
| 247 | + [184, 141, 35, ..., 224, 110, 66]], device='cuda:0', |
| 248 | + dtype=torch.uint8)... , scale=tensor([0.0009, 0.0010, 0.0009, 0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0010, |
| 249 | + 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, |
| 250 | + 0.0010, 0.0010, 0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0010, |
| 251 | + 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010, 0.0009, |
| 252 | + 0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010, |
| 253 | + 0.0009, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010, 0.0010, 0.0009, 0.0009, |
| 254 | + 0.0010, 0.0009, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009, |
| 255 | + 0.0010], device='cuda:0')... , zero_point=tensor([130., 128., 122., 130., 132., 128., 125., 130., 126., 128., 129., 126., |
| 256 | + 128., 128., 128., 128., 129., 127., 130., 125., 128., 133., 126., 126., |
| 257 | + 128., 124., 127., 128., 128., 128., 129., 124., 126., 133., 129., 127., |
| 258 | + 126., 124., 130., 126., 127., 129., 124., 125., 127., 130., 128., 132., |
| 259 | + 128., 129., 128., 129., 131., 132., 127., 135., 126., 130., 124., 136., |
| 260 | + 131., 124., 130., 129.], device='cuda:0')... , _layout=PlainLayout()), block_size=(1, 64), shape=torch.Size([64, 64]), device=cuda:0, dtype=torch.bfloat16, requires_grad=False) |
| 261 | +
|
| 262 | +In this tutorial, we walked through a basic example of how to perform integer static quantization in torchao. We also have an example of how to perform the same static quantization in float8. Please see the full `example script <https://github.com/pytorch/ao/tree/main/tutorials/calibration_flow/static_quant.py>`__ for more detail! |
0 commit comments