|
| 1 | +# Copyright (c) Meta Platforms, Inc. and affiliates. |
| 2 | +# All rights reserved. |
| 3 | + |
| 4 | +# This source code is licensed under the license found in the |
| 5 | +# LICENSE file in the root directory of this source tree. |
| 6 | +from dataclasses import dataclass |
| 7 | +from typing import Optional |
| 8 | + |
| 9 | +import torch |
| 10 | +import torch.nn as nn |
| 11 | +from torch import Tensor |
| 12 | +from torch.nn import functional as F |
| 13 | + |
| 14 | + |
| 15 | +def find_multiple(n: int, k: int) -> int: |
| 16 | + if n % k == 0: |
| 17 | + return n |
| 18 | + return n + k - (n % k) |
| 19 | + |
| 20 | +@dataclass |
| 21 | +class ModelArgs: |
| 22 | + block_size: int = 2048 |
| 23 | + vocab_size: int = 32000 |
| 24 | + n_layer: int = 32 |
| 25 | + n_head: int = 32 |
| 26 | + dim: int = 4096 |
| 27 | + intermediate_size: int = None |
| 28 | + n_local_heads: int = -1 |
| 29 | + head_dim: int = 64 |
| 30 | + rope_base: float = 10000 |
| 31 | + norm_eps: float = 1e-5 |
| 32 | + num_experts: int = 8 |
| 33 | + num_activated_experts: int = 2 |
| 34 | + |
| 35 | + def __post_init__(self): |
| 36 | + if self.n_local_heads == -1: |
| 37 | + self.n_local_heads = self.n_head |
| 38 | + if self.intermediate_size is None: |
| 39 | + hidden_dim = 4 * self.dim |
| 40 | + n_hidden = int(2 * hidden_dim / 3) |
| 41 | + self.intermediate_size = find_multiple(n_hidden, 256) |
| 42 | + self.head_dim = self.dim // self.n_head |
| 43 | + |
| 44 | + @classmethod |
| 45 | + def from_name(cls, name: str): |
| 46 | + if name in transformer_configs: |
| 47 | + return cls(**transformer_configs[name]) |
| 48 | + # fuzzy search |
| 49 | + config = [config for config in transformer_configs if config in str(name).upper() or config in str(name)] |
| 50 | + assert len(config) == 1, name |
| 51 | + return cls(**transformer_configs[config[0]]) |
| 52 | + |
| 53 | + |
| 54 | +transformer_configs = { |
| 55 | + "Mixtral-8x7B-Instruct-v0.1": dict(block_size=32768, n_layer=32, n_head=32, n_local_heads=8, dim=4096, intermediate_size=14336, rope_base=1000000.0, num_experts=8, num_activated_experts=2), |
| 56 | +} |
| 57 | + |
| 58 | +class KVCache(nn.Module): |
| 59 | + def __init__(self, max_batch_size, max_seq_length, n_heads, head_dim, dtype=torch.bfloat16): |
| 60 | + super().__init__() |
| 61 | + cache_shape = (max_batch_size, n_heads, max_seq_length, head_dim) |
| 62 | + self.register_buffer('k_cache', torch.zeros(cache_shape, dtype=dtype)) |
| 63 | + self.register_buffer('v_cache', torch.zeros(cache_shape, dtype=dtype)) |
| 64 | + |
| 65 | + def update(self, input_pos, k_val, v_val): |
| 66 | + # input_pos: [S], k_val: [B, H, S, D] |
| 67 | + assert input_pos.shape[0] == k_val.shape[2] |
| 68 | + |
| 69 | + k_out = self.k_cache |
| 70 | + v_out = self.v_cache |
| 71 | + k_out[:, :, input_pos] = k_val |
| 72 | + v_out[:, :, input_pos] = v_val |
| 73 | + |
| 74 | + return k_out, v_out |
| 75 | + |
| 76 | +class Transformer(nn.Module): |
| 77 | + def __init__(self, config: ModelArgs) -> None: |
| 78 | + super().__init__() |
| 79 | + self.config = config |
| 80 | + |
| 81 | + self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim) |
| 82 | + self.layers = nn.ModuleList(TransformerBlock(config) for _ in range(config.n_layer)) |
| 83 | + self.norm = RMSNorm(config.dim, eps=config.norm_eps) |
| 84 | + self.output = nn.Linear(config.dim, config.vocab_size, bias=False) |
| 85 | + |
| 86 | + self.freqs_cis: Optional[Tensor] = None |
| 87 | + self.mask_cache: Optional[Tensor] = None |
| 88 | + self.max_batch_size = -1 |
| 89 | + self.max_seq_length = -1 |
| 90 | + |
| 91 | + def setup_caches(self, max_batch_size, max_seq_length): |
| 92 | + if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size: |
| 93 | + return |
| 94 | + head_dim = self.config.dim // self.config.n_head |
| 95 | + max_seq_length = find_multiple(max_seq_length, 8) |
| 96 | + self.max_seq_length = max_seq_length |
| 97 | + self.max_batch_size = max_batch_size |
| 98 | + for b in self.layers: |
| 99 | + b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim) |
| 100 | + |
| 101 | + self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base) |
| 102 | + self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool)) |
| 103 | + |
| 104 | + def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: |
| 105 | + assert self.freqs_cis is not None, "Caches must be initialized first" |
| 106 | + mask = self.causal_mask[None, None, input_pos] |
| 107 | + freqs_cis = self.freqs_cis[input_pos] |
| 108 | + x = self.tok_embeddings(idx) |
| 109 | + |
| 110 | + for i, layer in enumerate(self.layers): |
| 111 | + x = layer(x, input_pos, freqs_cis, mask) |
| 112 | + x = self.norm(x) |
| 113 | + logits = self.output(x) |
| 114 | + return logits |
| 115 | + |
| 116 | + @classmethod |
| 117 | + def from_name(cls, name: str): |
| 118 | + return cls(ModelArgs.from_name(name)) |
| 119 | + |
| 120 | + |
| 121 | +class TransformerBlock(nn.Module): |
| 122 | + def __init__(self, config: ModelArgs) -> None: |
| 123 | + super().__init__() |
| 124 | + self.attention = Attention(config) |
| 125 | + self.block_sparse_moe = MOEFeedForwardAOQuantizable(config) |
| 126 | + self.ffn_norm = RMSNorm(config.dim, config.norm_eps) |
| 127 | + self.attention_norm = RMSNorm(config.dim, config.norm_eps) |
| 128 | + |
| 129 | + def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor: |
| 130 | + h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos) |
| 131 | + out = h + self.block_sparse_moe(self.ffn_norm(h)) |
| 132 | + return out |
| 133 | + |
| 134 | + |
| 135 | +class Attention(nn.Module): |
| 136 | + def __init__(self, config: ModelArgs): |
| 137 | + super().__init__() |
| 138 | + assert config.dim % config.n_head == 0 |
| 139 | + |
| 140 | + total_head_dim = (config.n_head + 2 * config.n_local_heads) * config.head_dim |
| 141 | + # key, query, value projections for all heads, but in a batch |
| 142 | + self.wqkv = nn.Linear(config.dim, total_head_dim, bias=False) |
| 143 | + self.wo = nn.Linear(config.dim, config.dim, bias=False) |
| 144 | + self.kv_cache = None |
| 145 | + |
| 146 | + self.n_head = config.n_head |
| 147 | + self.head_dim = config.head_dim |
| 148 | + self.n_local_heads = config.n_local_heads |
| 149 | + self.dim = config.dim |
| 150 | + self._register_load_state_dict_pre_hook(self.load_hook) |
| 151 | + |
| 152 | + def load_hook(self, state_dict, prefix, *args): |
| 153 | + if prefix + "wq.weight" in state_dict: |
| 154 | + wq = state_dict.pop(prefix + "wq.weight") |
| 155 | + wk = state_dict.pop(prefix + "wk.weight") |
| 156 | + wv = state_dict.pop(prefix + "wv.weight") |
| 157 | + state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv]) |
| 158 | + |
| 159 | + def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor: |
| 160 | + bsz, seqlen, _ = x.shape |
| 161 | + |
| 162 | + kv_size = self.n_local_heads * self.head_dim |
| 163 | + q, k, v = self.wqkv(x).split([self.dim, kv_size, kv_size], dim=-1) |
| 164 | + |
| 165 | + q = q.view(bsz, seqlen, self.n_head, self.head_dim) |
| 166 | + k = k.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
| 167 | + v = v.view(bsz, seqlen, self.n_local_heads, self.head_dim) |
| 168 | + |
| 169 | + q = apply_rotary_emb(q, freqs_cis) |
| 170 | + k = apply_rotary_emb(k, freqs_cis) |
| 171 | + |
| 172 | + q, k, v = map(lambda x: x.transpose(1, 2), (q, k, v)) |
| 173 | + |
| 174 | + if self.kv_cache is not None: |
| 175 | + k, v = self.kv_cache.update(input_pos, k, v) |
| 176 | + |
| 177 | + k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1) |
| 178 | + v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1) |
| 179 | + y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0) |
| 180 | + |
| 181 | + y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim) |
| 182 | + |
| 183 | + y = self.wo(y) |
| 184 | + return y |
| 185 | + |
| 186 | + |
| 187 | +class ConditionalFeedForward(nn.Module): |
| 188 | + def __init__(self, config): |
| 189 | + super().__init__() |
| 190 | + self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) |
| 191 | + self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) |
| 192 | + self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) |
| 193 | + |
| 194 | + def forward(self, x: Tensor, expert_indices: Tensor) -> Tensor: |
| 195 | + w1_weights = self.w1[expert_indices] # [T, A, D, D] |
| 196 | + w3_weights = self.w3[expert_indices] # [T, A, D, D] |
| 197 | + w2_weights = self.w2[expert_indices] # [T, A, D, D] |
| 198 | + x1 = F.silu(torch.einsum('ti,taoi -> tao', x, w1_weights)) |
| 199 | + x3 = torch.einsum('ti, taoi -> tao', x, w3_weights) |
| 200 | + expert_outs = torch.einsum('tao, taio -> tai', (x1 * x3), w2_weights) |
| 201 | + return expert_outs |
| 202 | + |
| 203 | + |
| 204 | +class MOEFeedForward(nn.Module): |
| 205 | + def __init__(self, config) -> None: |
| 206 | + super().__init__() |
| 207 | + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) |
| 208 | + self.cond_ffn = ConditionalFeedForward(config) |
| 209 | + self.dim = config.dim |
| 210 | + self.num_activated_experts = config.num_activated_experts |
| 211 | + def forward(self, x: Tensor) -> Tensor: |
| 212 | + x = x.view(-1, self.dim) |
| 213 | + # T = num_tokens, E = num_experts, D = hidden dim, A = activated experts |
| 214 | + # x: [T, D] |
| 215 | + scores = self.gate(x) # [T, E] |
| 216 | + expert_weights = F.softmax(scores, dim=-1) |
| 217 | + expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] |
| 218 | + expert_weights /= expert_weights.sum(dim=-1, keepdim=True) # [T, A] |
| 219 | + expert_outs = self.cond_ffn(x, expert_indices) |
| 220 | + return torch.einsum('tai,ta -> ti', expert_outs, expert_weights) |
| 221 | + |
| 222 | + |
| 223 | +class RMSNorm(nn.Module): |
| 224 | + def __init__(self, dim: int, eps: float = 1e-5): |
| 225 | + super().__init__() |
| 226 | + self.eps = eps |
| 227 | + self.weight = nn.Parameter(torch.ones(dim)) |
| 228 | + |
| 229 | + def _norm(self, x): |
| 230 | + return x * torch.rsqrt(torch.mean(x * x, dim=-1, keepdim=True) + self.eps) |
| 231 | + |
| 232 | + def forward(self, x: Tensor) -> Tensor: |
| 233 | + output = self._norm(x.float()).type_as(x) |
| 234 | + return output * self.weight |
| 235 | + |
| 236 | + |
| 237 | +def precompute_freqs_cis( |
| 238 | + seq_len: int, n_elem: int, base: int = 10000 |
| 239 | +) -> Tensor: |
| 240 | + freqs = 1.0 / (base ** (torch.arange(0, n_elem, 2)[: (n_elem // 2)].float() / n_elem)) |
| 241 | + t = torch.arange(seq_len, device=freqs.device) |
| 242 | + freqs = torch.outer(t, freqs) |
| 243 | + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) |
| 244 | + cache = torch.stack([freqs_cis.real, freqs_cis.imag], dim=-1) |
| 245 | + return cache.to(dtype=torch.bfloat16) |
| 246 | + |
| 247 | + |
| 248 | +def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor: |
| 249 | + xshaped = x.float().reshape(*x.shape[:-1], -1, 2) |
| 250 | + freqs_cis = freqs_cis.view(1, xshaped.size(1), 1, xshaped.size(3), 2) |
| 251 | + x_out2 = torch.stack( |
| 252 | + [ |
| 253 | + xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1], |
| 254 | + xshaped[..., 1] * freqs_cis[..., 0] + xshaped[..., 0] * freqs_cis[..., 1], |
| 255 | + ], |
| 256 | + -1, |
| 257 | + ) |
| 258 | + |
| 259 | + x_out2 = x_out2.flatten(3) |
| 260 | + return x_out2.type_as(x) |
| 261 | + |
| 262 | + |
| 263 | +# T tokens |
| 264 | +# E experts |
| 265 | +# D dim |
| 266 | +# I intermediate dim |
| 267 | +# A activated experts |
| 268 | +# T'(e) tokens for expert e |
| 269 | + |
| 270 | +class MOEFeedForwardAOQuantizable(nn.Module): |
| 271 | + def __init__(self, config) -> None: |
| 272 | + super().__init__() |
| 273 | + self.gate = nn.Linear(config.dim, config.num_experts, bias=False) |
| 274 | + self.cond_ffn = ConditionalFeedForwardAOQuantizable(config) |
| 275 | + self.dim = config.dim |
| 276 | + self.num_activated_experts = config.num_activated_experts |
| 277 | + def forward(self, x: Tensor) -> Tensor: |
| 278 | + batch_size = x.shape[0] |
| 279 | + x = x.view(-1, self.dim) # x: [T, D] |
| 280 | + scores = self.gate(x) # [T, E] |
| 281 | + expert_weights = F.softmax(scores, dim=-1) |
| 282 | + expert_weights, expert_indices = torch.topk(expert_weights, self.num_activated_experts, dim=-1) # [T, A], [T, A] |
| 283 | + expert_weights /= expert_weights.sum(dim=-1, keepdim=True).to(x.dtype) # [T, A] |
| 284 | + out = self.cond_ffn(x, expert_indices, expert_weights, self.num_activated_experts) |
| 285 | + return out.reshape(batch_size, -1, self.dim) |
| 286 | + |
| 287 | + |
| 288 | +class ConditionalFeedForwardAOQuantizable(nn.Module): |
| 289 | + def __init__(self, config): |
| 290 | + super().__init__() |
| 291 | + self.config = config |
| 292 | + self.w1 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) # E, I, D |
| 293 | + self.w2 = nn.Parameter(torch.empty(config.num_experts, config.dim, config.intermediate_size)) # E, D, I |
| 294 | + self.w3 = nn.Parameter(torch.empty(config.num_experts, config.intermediate_size, config.dim)) # E, I, D |
| 295 | + self.num_experts = config.num_experts |
| 296 | + def forward( |
| 297 | + self, x: Tensor, # T, D |
| 298 | + expert_indices: Tensor, # T, A |
| 299 | + expert_weights: Tensor, # T, A |
| 300 | + num_activated_experts: int, |
| 301 | + ) -> Tensor: |
| 302 | + num_tokens, dim = x.shape |
| 303 | + num_token_activations = num_tokens * num_activated_experts |
| 304 | + |
| 305 | + if x.shape[0]==1: #only 1 token (can be done without graph breaks when compiled) |
| 306 | + outs = [] |
| 307 | + expert_indices=expert_indices.squeeze() |
| 308 | + # collect used experts |
| 309 | + w1 = self.w1[expert_indices] |
| 310 | + w2 = self.w2[expert_indices] |
| 311 | + w3 = self.w3[expert_indices] |
| 312 | + |
| 313 | + # run token through each expert |
| 314 | + for index in range(num_activated_experts): |
| 315 | + y1 = F.silu(F.linear(x, w1[index])) |
| 316 | + y3 = F.linear(x, w3[index]) |
| 317 | + y2 = w2[index] |
| 318 | + cur_out = F.linear( y1 * y3, y2) |
| 319 | + outs.append(cur_out) |
| 320 | + |
| 321 | + # combine outputs |
| 322 | + final_out = (torch.cat(outs, dim=0) * expert_weights.view(-1,1)).sum(dim=0).unsqueeze(-1) |
| 323 | + return final_out |
| 324 | + else: |
| 325 | + expert_list = [x for x in range(self.num_experts)] |
| 326 | + |
| 327 | + # shuffle tokens into groups for each expert |
| 328 | + ordered_token_activations = expert_indices.view(-1).argsort(stable=True) # [A] |
| 329 | + ordered_token_indices = ordered_token_activations.div(num_activated_experts).floor().to(torch.int64) # [T] |
| 330 | + |
| 331 | + num_tokens_per_expert = torch.histc(expert_indices, bins=self.num_experts+1, min=-1, max=self.num_experts) # [E+1] (added leading 0 so can be used for indexing) |
| 332 | + cum_tokens_per_expert = num_tokens_per_expert.cumsum(0).to(torch.int64) # [E+1] |
| 333 | + |
| 334 | + @torch._dynamo.disable() |
| 335 | + def group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list): |
| 336 | + token_indices_per_expert = [ordered_token_indices[cum_tokens_per_expert[expert]:cum_tokens_per_expert[expert+1]] for expert in expert_list] # [T'(e1)], [T'(e2)] ... |
| 337 | + return token_indices_per_expert |
| 338 | + token_indices_per_expert = group_tokens_by_expert(ordered_token_indices, cum_tokens_per_expert, expert_list) |
| 339 | + tokens_grouped_by_expert = [x[indices] for indices in token_indices_per_expert] |
| 340 | + |
| 341 | + # calculate outputs for each expert |
| 342 | + outs = [] |
| 343 | + for cur_x, expert in zip(tokens_grouped_by_expert,expert_list): |
| 344 | + |
| 345 | + w1=self.w1[expert] # I, D |
| 346 | + w2=self.w2[expert] # D, I |
| 347 | + w3=self.w3[expert] # I, D |
| 348 | + |
| 349 | + cur_out = F.linear( F.silu(F.linear(cur_x, w1)) * F.linear(cur_x, w3), w2) # [T'(e), D] |
| 350 | + outs.append(cur_out) |
| 351 | + |
| 352 | + # weigh outputs |
| 353 | + ordered_outs = torch.cat(outs, dim=0) # [T*A, D] |
| 354 | + ordered_token_activation_weights = expert_weights.view(-1,1)[ordered_token_activations].view(-1,1) # [T*A, 1] |
| 355 | + weighted_ordered_outs = ordered_outs*ordered_token_activation_weights # [T*A, D] |
| 356 | + |
| 357 | + # sum weighted token-activation outputs together for each token |
| 358 | + final_out = torch.zeros_like(x) # [T, D] |
| 359 | + final_out = final_out.scatter_add(dim=0, index=ordered_token_indices.unsqueeze(-1).expand(num_token_activations,dim).to(torch.int64), src=weighted_ordered_outs) |
| 360 | + return final_out |
0 commit comments