
Commit 6d478c8

Add static quant tutorial
1 parent 31f119e commit 6d478c8

File tree

3 files changed: +269 −1 lines changed

docs/source/index.rst

+1
@@ -39,3 +39,4 @@ for an overall introduction to the library and recent highlight and updates.
    serialization
    subclass_basic
    subclass_advanced
+   static_quantization

docs/source/static_quantization.rst

+262
@@ -0,0 +1,262 @@

Static Quantization
--------------------

Static quantization uses a fixed quantization range for all inputs during inference or generation. Unlike dynamic quantization, which computes new quantization ranges for each new input batch, static quantization typically results in more efficient computation, potentially at the cost of lower quantized accuracy, since we cannot adapt to changes in the input distribution on-the-fly.

In static quantization, this fixed quantization range is typically calibrated on similar inputs before quantizing the model. During the calibration phase, we first insert observers into the model to "observe" the distribution of the inputs to be quantized, then use this distribution to decide what scales and zero points to ultimately use when quantizing the model.

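To make the observer's job concrete, here is a minimal sketch of how asymmetric quantization parameters can be derived from an observed min/max range. This is a simplified illustration, not torchao's exact implementation, which additionally handles edge cases such as ranges that do not span zero and `eps` clamping:

.. code:: py

    import torch

    def qparams_from_min_max(min_val, max_val, qmin=0, qmax=255):
        # Map the observed float range [min_val, max_val] onto the
        # integer range [qmin, qmax] (e.g. uint8)
        scale = (max_val - min_val) / (qmax - qmin)
        zero_point = qmin - torch.round(min_val / scale)
        return scale, zero_point

    # Example: values observed in [-1.0, 3.0] during calibration map
    # -1.0 -> 0 and 3.0 -> 255
    scale, zero_point = qparams_from_min_max(torch.tensor(-1.0), torch.tensor(3.0))
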
In this tutorial, we walk through an example of how to achieve this in torchao. All code can be found in this `example script <https://github.com/pytorch/ao/tree/main/tutorials/calibration_flow/static_quant.py>`__. Let's start with our toy linear model:

.. code:: py

    import copy
    import torch

    class ToyLinearModel(torch.nn.Module):
        def __init__(self, m=64, n=32, k=64):
            super().__init__()
            self.linear1 = torch.nn.Linear(m, k, bias=False)
            self.linear2 = torch.nn.Linear(k, n, bias=False)

        def example_inputs(self, batch_size=1, dtype=torch.float32, device="cpu"):
            return (
                torch.randn(
                    batch_size, self.linear1.in_features, dtype=dtype, device=device
                ),
            )

        def forward(self, x):
            x = self.linear1(x)
            x = self.linear2(x)
            return x

    dtype = torch.bfloat16
    m = ToyLinearModel().eval().to(dtype).to("cuda")
    m = torch.compile(m, mode="max-autotune")

Calibration Phase
~~~~~~~~~~~~~~~~~

torchao comes with a simple observer implementation, `AffineQuantizedMinMaxObserver`, that records the min and max values that have flowed through the observer during the calibration phase. Users are welcome to implement their own, more advanced observation techniques, such as those relying on moving averages or histograms, and these may be added to torchao in the future.

.. code:: py

    from torchao.quantization.granularity import PerAxis, PerTensor
    from torchao.quantization.observer import AffineQuantizedMinMaxObserver
    from torchao.quantization.quant_primitives import MappingType

    # per-tensor asymmetric quantization for input activations
    act_obs = AffineQuantizedMinMaxObserver(
        MappingType.ASYMMETRIC,
        torch.uint8,
        granularity=PerTensor(),
        eps=torch.finfo(torch.float32).eps,
        scale_dtype=torch.float32,
        zero_point_dtype=torch.float32,
    )

    # per-channel asymmetric quantization for weights
    weight_obs = AffineQuantizedMinMaxObserver(
        MappingType.ASYMMETRIC,
        torch.uint8,
        granularity=PerAxis(axis=0),
        eps=torch.finfo(torch.float32).eps,
        scale_dtype=torch.float32,
        zero_point_dtype=torch.float32,
    )

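Observers are regular modules: calling one on a tensor records statistics, and `calculate_qparams` turns those statistics into a scale and zero point (the same method our quantized linear will call later). As an optional sketch, we can exercise a copy of the activation observer on some random data:

.. code:: py

    # Exercise a deep copy so we don't pollute the statistics of the
    # observer that will be used for real calibration below
    test_obs = copy.deepcopy(act_obs)
    test_obs(torch.randn(16, 64))  # records the min/max of this batch
    scale, zero_point = test_obs.calculate_qparams()
    print(scale, zero_point)  # one scale/zero point pair, since granularity is PerTensor
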
Next, we define the observed linear module that will replace our `torch.nn.Linear`. This is a high precision (e.g. fp32) linear module with the above observers inserted to record the input activation and weight values during calibration:

.. code:: py

    import torch.nn.functional as F

    class ObservedLinear(torch.nn.Linear):
        def __init__(
            self,
            in_features: int,
            out_features: int,
            act_obs: torch.nn.Module,
            weight_obs: torch.nn.Module,
            bias: bool = True,
            device=None,
            dtype=None,
        ):
            super().__init__(in_features, out_features, bias, device, dtype)
            self.act_obs = act_obs
            self.weight_obs = weight_obs

        def forward(self, input: torch.Tensor):
            observed_input = self.act_obs(input)
            observed_weight = self.weight_obs(self.weight)
            return F.linear(observed_input, observed_weight, self.bias)

        @classmethod
        def from_float(cls, float_linear, act_obs, weight_obs):
            observed_linear = cls(
                float_linear.in_features,
                float_linear.out_features,
                act_obs,
                weight_obs,
                False,  # skip allocating a new bias; we reuse the float one below
                device=float_linear.weight.device,
                dtype=float_linear.weight.dtype,
            )
            observed_linear.weight = float_linear.weight
            observed_linear.bias = float_linear.bias
            return observed_linear

To actually insert these observers into our toy model:

.. code:: py

    from torchao.quantization.quant_api import (
        _replace_with_custom_fn_if_matches_filter,
    )

    def insert_observers_(model, act_obs, weight_obs):
        _is_linear = lambda m, fqn: isinstance(m, torch.nn.Linear)

        def replacement_fn(m):
            # deep copy the observers so each linear records its own statistics
            copied_act_obs = copy.deepcopy(act_obs)
            copied_weight_obs = copy.deepcopy(weight_obs)
            return ObservedLinear.from_float(m, copied_act_obs, copied_weight_obs)

        _replace_with_custom_fn_if_matches_filter(model, replacement_fn, _is_linear)

    insert_observers_(m, act_obs, weight_obs)

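As a quick optional check (not part of the original script), we can confirm the swap took effect. Attribute access on the compiled model is forwarded to the underlying module, so we can reach the linears directly:

.. code:: py

    # both linears should now be ObservedLinear instances
    assert isinstance(m.linear1, ObservedLinear)
    assert isinstance(m.linear2, ObservedLinear)
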
Now we are ready to calibrate the model, which populates the observers we inserted with statistics recorded during the calibration. We can do this simply by feeding some example inputs to our "observed" model:

.. code:: py

    for _ in range(10):
        example_inputs = m.example_inputs(dtype=dtype, device="cuda")
        m(*example_inputs)

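After calibration, each observer holds the min/max statistics it recorded. As an optional sketch, we can peek at the quantization parameters those statistics translate to before quantizing the model:

.. code:: py

    # scale and zero point the first linear will use for its input activations
    act_scale, act_zero_point = m.linear1.act_obs.calculate_qparams()
    print(act_scale, act_zero_point)
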
Quantization Phase
~~~~~~~~~~~~~~~~~~

There are multiple ways to actually quantize the model. Here we walk through the simpler alternative, which is to define a `QuantizedLinear` class to replace our `ObservedLinear`. Defining this new class isn't strictly necessary. For an alternative method that simply uses the existing `torch.nn.Linear`, please see the full `example script <https://github.com/pytorch/ao/tree/main/tutorials/calibration_flow/static_quant.py>`__.

.. code:: py

    from torchao.dtypes import to_affine_quantized_intx_static

    class QuantizedLinear(torch.nn.Module):
        def __init__(
            self,
            in_features: int,
            out_features: int,
            act_obs: torch.nn.Module,
            weight_obs: torch.nn.Module,
            weight: torch.Tensor,
            bias: torch.Tensor,
            target_dtype: torch.dtype,
        ):
            super().__init__()
            # freeze the calibrated quantization parameters
            self.act_scale, self.act_zero_point = act_obs.calculate_qparams()
            weight_scale, weight_zero_point = weight_obs.calculate_qparams()
            assert weight.dim() == 2
            # one quantization group per output channel (row), matching PerAxis(axis=0)
            block_size = (1, weight.shape[1])
            self.target_dtype = target_dtype
            self.bias = bias
            self.qweight = to_affine_quantized_intx_static(
                weight, weight_scale, weight_zero_point, block_size, self.target_dtype
            )

        def forward(self, input: torch.Tensor):
            # per-tensor activation quantization: the whole input is one block
            block_size = input.shape
            qinput = to_affine_quantized_intx_static(
                input,
                self.act_scale,
                self.act_zero_point,
                block_size,
                self.target_dtype,
            )
            return F.linear(qinput, self.qweight, self.bias)

        @classmethod
        def from_observed(cls, observed_linear, target_dtype):
            quantized_linear = cls(
                observed_linear.in_features,
                observed_linear.out_features,
                observed_linear.act_obs,
                observed_linear.weight_obs,
                observed_linear.weight,
                observed_linear.bias,
                target_dtype,
            )
            return quantized_linear

This linear class computes the scales and zero points for both input activations and weights up front, effectively fixing the quantization range for future forward calls. Now, to actually quantize the model using this linear class, we can define the following config and pass it to torchao's main `quantize_` API:

.. code:: py

    from dataclasses import dataclass

    from torchao.core.config import AOBaseConfig
    from torchao.quantization import quantize_
    from torchao.quantization.transform_module import (
        register_quantize_module_handler,
    )

    @dataclass
    class ApplyStaticQuantConfig(AOBaseConfig):
        target_dtype: torch.dtype

    @register_quantize_module_handler(ApplyStaticQuantConfig)
    def _apply_static_quant(
        module: torch.nn.Module,
        config: ApplyStaticQuantConfig,
    ):
        """
        Define a transformation associated with `ApplyStaticQuantConfig`.
        This is called by `quantize_`, not by the user directly.
        """
        return QuantizedLinear.from_observed(module, config.target_dtype)

    # filter function to identify which modules to swap
    is_observed_linear = lambda m, fqn: isinstance(m, ObservedLinear)

    # perform static quantization
    quantize_(m, ApplyStaticQuantConfig(torch.uint8), is_observed_linear)

Now we can see that the linear layers in our model have been replaced by our `QuantizedLinear` class, each with a fixed input activation scale and a fixed quantized weight:

.. code:: py

    >>> m
    OptimizedModule(
      (_orig_mod): ToyLinearModel(
        (linear1): QuantizedLinear()
        (linear2): QuantizedLinear()
      )
    )
    >>> m.linear1.act_scale
    tensor([0.0237], device='cuda:0')
    >>> m.linear1.qweight
    AffineQuantizedTensor(tensor_impl=PlainAQTTensorImpl(data=tensor([[142,  31,  42,  ..., 113, 157,  57],
            [ 59, 160,  70,  ...,  23, 150,  67],
            [ 44,  49, 241,  ..., 238,  69, 235],
            ...,
            [228, 255, 201,  ..., 114, 236,  73],
            [ 50,  88,  83,  ..., 109, 209,  92],
            [184, 141,  35,  ..., 224, 110,  66]], device='cuda:0',
           dtype=torch.uint8)... , scale=tensor([0.0009, 0.0010, 0.0009, 0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0010,
            0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010, 0.0010,
            0.0010, 0.0010, 0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0010,
            0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010, 0.0009,
            0.0010, 0.0009, 0.0010, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010,
            0.0009, 0.0010, 0.0009, 0.0009, 0.0009, 0.0010, 0.0010, 0.0009, 0.0009,
            0.0010, 0.0009, 0.0010, 0.0010, 0.0009, 0.0009, 0.0009, 0.0009, 0.0009,
            0.0010], device='cuda:0')... , zero_point=tensor([130., 128., 122., 130., 132., 128., 125., 130., 126., 128., 129., 126.,
            128., 128., 128., 128., 129., 127., 130., 125., 128., 133., 126., 126.,
            128., 124., 127., 128., 128., 128., 129., 124., 126., 133., 129., 127.,
            126., 124., 130., 126., 127., 129., 124., 125., 127., 130., 128., 132.,
            128., 129., 128., 129., 131., 132., 127., 135., 126., 130., 124., 136.,
            131., 124., 130., 129.], device='cuda:0')... , _layout=PlainLayout()), block_size=(1, 64), shape=torch.Size([64, 64]), device=cuda:0, dtype=torch.bfloat16, requires_grad=False)

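Finally, as an optional sanity check, we can compare the quantized model's outputs against a float reference. The sketch below assumes a hypothetical `m_float`, a float copy of the model saved (e.g. with `copy.deepcopy`) before observers were inserted:

.. code:: py

    def sqnr(ref, actual):
        # signal-to-quantized-noise ratio in dB; higher means closer to the reference
        return 20 * torch.log10(ref.norm() / (ref - actual).norm())

    # m_float: hypothetical float copy of the model, saved before calibration
    example_inputs = m_float.example_inputs(dtype=dtype, device="cuda")
    print(sqnr(m_float(*example_inputs), m(*example_inputs)))
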
In this tutorial, we walked through a basic example of how to perform integer static quantization in torchao. We also have an example of how to perform the same static quantization in float8. Please see the full `example script <https://github.com/pytorch/ao/tree/main/tutorials/calibration_flow/static_quant.py>`__ for more detail!

torchao/dtypes/affine_quantized_tensor.py

+6 −1
@@ -337,7 +337,12 @@ def from_hp_to_intx_static(
         zero_point_domain,
     )
 
-    int_data, scale, zero_point = _layout.post_process(int_data, scale, zero_point)
+    int_data, scale, zero_point = _layout.post_process(
+        int_data,
+        scale,
+        zero_point,
+        block_size,
+    )
 
     tensor_impl_ctr = get_tensor_impl_constructor(type(_layout))
     tensor_impl = tensor_impl_ctr(int_data, scale, zero_point, _layout)
