
Commit 4583d99

Enabling MOE Quantization using linear decomposition [WIP]
Summary: This PR is a first step toward optimizing MoE inference using torchao. The goal for this step is to enable the existing quantization kernels and workflows to work for MoE quantization by decomposing the grouped GEMM into a sequence of unbalanced linear ops that can use the existing quantized kernels. To enable this, we had to add support for quantizing these 3D tensors, as well as for slicing and indexing them. The current tests are running locally and will be added to the repo once they pass. int8wo and int8dq currently work for both multi-token and single-token MoE inference, while int4wo is being finished up.

TODO: move the test set into ao, move the quantizable MoE module code into ao, test on the HF model definition.

Test Plan:
Reviewers:
Subscribers:
Tasks:
Tags:
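The core of the decomposition, as implemented in ConditionalFeedForwardAOQuantizable in model.py below: rather than one grouped einsum over the stacked expert weights, each expert's group of tokens goes through plain F.linear calls, so the existing per-linear quantization kernels can be reused. A simplified sketch of that inner loop (shapes and routing details elided; see the full diff below):

import torch
import torch.nn.functional as F

def experts_as_linears(tokens_grouped_by_expert, w1, w2, w3):
    # w1, w3: [E, I, D] and w2: [E, D, I] stacked expert weights; each 2D slice
    # behaves like a normal linear weight, which is what lets quantized kernels apply.
    outs = []
    for e, cur_x in enumerate(tokens_grouped_by_expert):  # cur_x: [T'(e), D]
        cur_out = F.linear(F.silu(F.linear(cur_x, w1[e])) * F.linear(cur_x, w3[e]), w2[e])
        outs.append(cur_out)  # [T'(e), D]
    return torch.cat(outs, dim=0)  # [T*A, D], still needs per-token weighting
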
1 parent a81322e commit 4583d99

19 files changed (+1857, -59 lines)

torchao/_models/mixtral-moe/generate.py

+396 lines changed (large diffs are not rendered by default)

torchao/_models/mixtral-moe/model.py

+360 lines changed
@@ -0,0 +1,360 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.

# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from dataclasses import dataclass
from typing import Optional

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import functional as F


def find_multiple(n: int, k: int) -> int:
    if n % k == 0:
        return n
    return n + k - (n % k)

@dataclass
class ModelArgs:
    block_size: int = 2048
    vocab_size: int = 32000
    n_layer: int = 32
    n_head: int = 32
    dim: int = 4096
    intermediate_size: int = None
    n_local_heads: int = -1
    head_dim: int = 64
    rope_base: float = 10000
    norm_eps: float = 1e-5
    num_experts: int = 8
    num_activated_experts: int = 2

    def __post_init__(self):
        if self.n_local_heads == -1:
            self.n_local_heads = self.n_head
        if self.intermediate_size is None:
            hidden_dim = 4 * self.dim
            n_hidden = int(2 * hidden_dim / 3)
            self.intermediate_size = find_multiple(n_hidden, 256)
        self.head_dim = self.dim // self.n_head

    @classmethod
    def from_name(cls, name: str):
        if name in transformer_configs:
            return cls(**transformer_configs[name])
        # fuzzy search
        config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)]
        assert len(config) == 1, name
        return cls(**transformer_configs[config[0]])


transformer_configs = {
    "Mixtral-8x7B-Instruct-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2),
}

class KVCache(nn.Module):
    def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16):
        super().__init__()
        cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim)
        self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype))
        self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype))

    def update(self, input_pos, k_val, v_val):
        # input_pos: [S], k_val: [B, H, S, D]
        assert input_pos.shape[0] == k_val.shape[2]

        k_out = self.k_cache
        v_out = self.v_cache
        k_out[:, :, input_pos] = k_val
        v_out[:, :, input_pos] = v_val

        return k_out, v_out

class Transformer(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.config = config

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
        self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer))
        self.norm = RMSNorm(config.dim, eps=config.norm_eps)
        self.output = nn.Linear(config.dim, config.vocab_size, bias=False)

        self.freqs_cis: Optional[Tensor] = None
        self.mask_cache: Optional[Tensor] = None
        self.max_batch_size = -1
        self.max_seq_length = -1

    def setup_caches(self, max_batch_size, max_seq_length):
        if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
            return
        head_dim = self.config.dim // self.config.n_head
        max_seq_length = find_multiple(max_seq_length, 8)
        self.max_seq_length = max_seq_length
        self.max_batch_size = max_batch_size
        for b in self.layers:
            b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim)

        self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base)
        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

    def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        assert self.freqs_cis is not None, "Caches must be initialized first"
        mask = self.causal_mask[None, None, input_pos]
        freqs_cis = self.freqs_cis[input_pos]
        x = self.tok_embeddings(idx)

        for i, layer in enumerate(self.layers):
            x = layer(x, input_pos, freqs_cis, mask)
        x = self.norm(x)
        logits = self.output(x)
        return logits

    @classmethod
    def from_name(cls, name: str):
        return cls(ModelArgs.from_name(name))


class TransformerBlock(nn.Module):
    def __init__(self, config: ModelArgs) -> None:
        super().__init__()
        self.attention = Attention(config)
        self.block_sparse_moe = MOEFeedForwardAOQuantizable(config)
        self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
        self.attention_norm = RMSNorm(config.dim, config.norm_eps)

    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
        h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
        out = h + self.block_sparse_moe(self.ffn_norm(h))
        return out


class Attention(nn.Module):
    def __init__(self, config: ModelArgs):
        super().__init__()
        assert config.dim % config.n_head == 0

        total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim
        # key, query, value projections for all heads, but in a batch
        self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False)
        self.wo = nn.Linear(config.dim, config.dim, bias=False)
        self.kv_cache = None

        self.n_head = config.n_head
        self.head_dim = config.head_dim
        self.n_local_heads = config.n_local_heads
        self.dim = config.dim
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(self, state_dict, prefix, *args):
        if prefix + "wq.weight" in state_dict:
            wq = state_dict.pop(prefix + "wq.weight")
            wk = state_dict.pop(prefix + "wk.weight")
            wv = state_dict.pop(prefix + "wv.weight")
            state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
        bsz, seqlen, _ = x.shape

        kv_size = self.n_local_heads * self.head_dim
        q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1)

        q = q.view(bsz, seqlen, self.n_head, self.head_dim)
        k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim)
        v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim)

        q = apply_rotary_emb(q, freqs_cis)
        k = apply_rotary_emb(k, freqs_cis)

        q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v))

        if self.kv_cache is not None:
            k, v = self.kv_cache.update(input_pos, k, v)

        k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)

        y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

        y = self.wo(y)
        return y


class ConditionalFeedForward(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))
        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))
        self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))

    def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor:
        w1_weights = self.w1[expert_indices]  # [T, A, I, D]
        w3_weights = self.w3[expert_indices]  # [T, A, I, D]
        w2_weights = self.w2[expert_indices]  # [T, A, D, I]
        x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights))
        x3 = torch.einsum('ti, taoi -> tao', x, w3_weights)
        expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights)
        return expert_outs


class MOEFeedForward(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
        self.cond_ffn = ConditionalFeedForward(config)
        self.dim = config.dim
        self.num_activated_experts = config.num_activated_experts

    def forward(self, x: Tensor) -> Tensor:
        x = x.view(-1, self.dim)
        # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts
        # x: [T, D]
        scores = self.gate(x)  # [T, E]
        expert_weights = F.softmax(scores, dim=-1)
        expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1)  # [T, A], [T, A]
        expert_weights /= expert_weights.sum(dim=-1, keepdim=True)  # [T, A]
        expert_outs = self.cond_ffn(x, expert_indices)
        return torch.einsum('tai,ta -> ti', expert_outs, expert_weights)


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-5):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps)

    def forward(self, x: Tensor) -> Tensor:
        output = self._norm(x.float()).type_as(x)
        return output * self.weight


def precompute_freqs_cis(
    seq_len: int, n_elem: int, base: int = 10000
) -> Tensor:
    freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem))
    t = torch.arange(seq_len, device=freqs.device)
    freqs = torch.outer(t, freqs)
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1)
    return cache.to(dtype=torch.bfloat16)


def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)
    freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2)
    x_out2 = torch.stack(
        [
            xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
            xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1],
        ],
        -1,
    )

    x_out2 = x_out2.flatten(3)
    return x_out2.type_as(x)


# T tokens
# E experts
# D dim
# I intermediate dim
# A activated experts
# T'(e) tokens for expert e

class MOEFeedForwardAOQuantizable(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.gate = nn.Linear(config.dim, config.num_experts, bias=False)
        self.cond_ffn = ConditionalFeedForwardAOQuantizable(config)
        self.dim = config.dim
        self.num_activated_experts = config.num_activated_experts

    def forward(self, x: Tensor) -> Tensor:
        batch_size = x.shape[0]
        x = x.view(-1, self.dim)  # x: [T, D]
        scores = self.gate(x)  # [T, E]
        expert_weights = F.softmax(scores, dim=-1)
        expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1)  # [T, A], [T, A]
        expert_weights /= expert_weights.sum(dim=-1, keepdim=True).to(x.dtype)  # [T, A]
        out = self.cond_ffn(x, expert_indices, expert_weights, self.num_activated_experts)
        return out.reshape(batch_size, -1, self.dim)


class ConditionalFeedForwardAOQuantizable(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))  # E, I, D
        self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size))  # E, D, I
        self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim))  # E, I, D
        self.num_experts = config.num_experts

    def forward(
        self, x: Tensor,  # T, D
        expert_indices: Tensor,  # T, A
        expert_weights: Tensor,  # T, A
        num_activated_experts: int,
    ) -> Tensor:
        num_tokens, dim = x.shape
        num_token_activations = num_tokens * num_activated_experts

        if x.shape[0] == 1:  # only 1 token (can be done without graph breaks when compiled)
            outs = []
            expert_indices = expert_indices.squeeze()
            # collect used experts
            w1 = self.w1[expert_indices]
            w2 = self.w2[expert_indices]
            w3 = self.w3[expert_indices]

            # run token through each expert
            for index in range(num_activated_experts):
                y1 = F.silu(F.linear(x, w1[index]))
                y3 = F.linear(x, w3[index])
                y2 = w2[index]
                cur_out = F.linear(y1 * y3, y2)
                outs.append(cur_out)

            # combine outputs
            final_out = (torch.cat(outs, dim=0) * expert_weights.view(-1, 1)).sum(dim=0).unsqueeze(-1)
            return final_out
        else:
            expert_list = [x for x in range(self.num_experts)]

            # shuffle tokens into groups for each expert
            ordered_token_activations = expert_indices.view(-1).argsort(stable=True)  # [T*A]
            ordered_token_indices = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64)  # [T*A]

            num_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts + 1, min=-1, max=self.num_experts)  # [E+1] (added leading 0 so can be used for indexing)
            cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(torch.int64)  # [E+1]

            @torch._dynamo.disable()
            def group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list):
                token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert + 1]] for expert in expert_list]  # [T'(e1)], [T'(e2)] ...
                return token_indices_per_expert
            token_indices_per_expert = group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list)
            tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert]

            # calculate outputs for each expert
            outs = []
            for cur_x, expert in zip(tokens_grouped_by_expert, expert_list):
                w1 = self.w1[expert]  # I, D
                w2 = self.w2[expert]  # D, I
                w3 = self.w3[expert]  # I, D

                cur_out = F.linear(F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2)  # [T'(e), D]
                outs.append(cur_out)

            # weigh outputs
            ordered_outs = torch.cat(outs, dim=0)  # [T*A, D]
            ordered_token_activation_weights = expert_weights.view(-1, 1)[ordered_token_activations].view(-1, 1)  # [T*A, 1]
            weighted_ordered_outs = ordered_outs * ordered_token_activation_weights  # [T*A, D]

            # sum weighted token-activation outputs together for each token
            final_out = torch.zeros_like(x)  # [T, D]
            final_out = final_out.scatter_add(dim=0, index=ordered_token_indices.unsqueeze(-1).expand(num_token_activations, dim).to(torch.int64), src=weighted_ordered_outs)
            return final_out
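
For illustration only (not part of this commit): a quick smoke test of the single-token fast path above on a toy, made-up configuration with random weights, using the ModelArgs and MOEFeedForwardAOQuantizable classes from this file.

import torch

# Hypothetical toy sizes: 4 experts, 2 activated, dim 16, intermediate 32.
args = ModelArgs(dim=16, intermediate_size=32, n_head=4, num_experts=4, num_activated_experts=2)
moe = MOEFeedForwardAOQuantizable(args)
for p in moe.parameters():
    torch.nn.init.normal_(p, std=0.02)  # the real weights come from the Mixtral checkpoint

x = torch.randn(1, 1, args.dim)  # a single token takes the no-graph-break branch
print(moe(x).shape)              # torch.Size([1, 1, 16])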

torchao/_models/mixtral-moe/run.sh

+6 lines changed

@@ -0,0 +1,6 @@

export MODEL_REPO=mistralai/Mixtral-8x7B-Instruct-v0.1
export CHECKPOINT_PATH=/data/users/cdhernandez/gpt-fast/checkpoints/

python generate.py --checkpoint_path $CHECKPOINT_PATH/$MODEL_REPO/model.pth --batch_size 1 --moe_quant int8wo-base --compile
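
For context, a hedged sketch (not part of this commit) of what an int8 weight-only flow could look like with torchao's top-level quantization API. The actual wiring behind the --moe_quant int8wo-base flag lives in generate.py, whose diff is not rendered above, so the exact usage below is an assumption.

import torch
from torchao.quantization import quantize_, int8_weight_only

# Sketch: build the model from model.py above; in practice the weights are
# loaded from $CHECKPOINT_PATH/$MODEL_REPO/model.pth.
model = Transformer.from_name("Mixtral-8x7B-Instruct-v0.1").to(dtype=torch.bfloat16)

# quantize_ swaps nn.Linear weights (attention projections, router gate, output
# head) for int8 weight-only quantized tensors. The per-expert weights can use
# the same kernels only because ConditionalFeedForwardAOQuantizable decomposes
# the grouped GEMM into ordinary F.linear calls, together with the new
# 3D-tensor quantization/slicing/indexing support this PR describes.
quantize_(model, int8_weight_only())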
