
Commit dbd57da ("update")
1 parent: b0f174d

10 files changed: +137, -40 lines

Extension.cs
Lines changed: 1 addition & 1 deletion

```diff
@@ -48,7 +48,7 @@ public static string Peek(this Tensor tensor, string id, int n = 10)
         avg = avg.round(4);
         var str = $"{id}: sum: {avg.ToSingle()} dtype: {dtype} shape: [{shapeString}]";
 
-        //Console.WriteLine(str);
+        Console.WriteLine(str);
 
         return str;
     }
```
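
With the commented-out line restored, `Peek` now always echoes a one-line tensor summary, which is handy for diffing intermediate activations against a reference run. A minimal usage sketch, assuming the namespace hosting the extension (not shown in this diff) is imported:

```csharp
using TorchSharp;
using static TorchSharp.torch;

var t = torch.randn(2, 3);
// Prints and returns something like "t: sum: 0.1234 dtype: Float32 shape: [2,3]".
var summary = t.Peek("t");
Console.WriteLine(summary);
```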

Module/Phi2Attention.cs
Lines changed: 2 additions & 2 deletions

```diff
@@ -129,8 +129,8 @@ public override (Tensor, Tensor?, Tensor?) forward(
         this.cache_v[..batchSize, .., pastKeyValueLength..kvSeqLen, ..] = valueStates;
         keyStates = this.cache_k[..batchSize, .., ..kvSeqLen, ..];
         valueStates = this.cache_v[..batchSize, .., ..kvSeqLen, ..];
-        var keyStates2 = Utils.RepeatKV(keyStates, this.numKeyValueGroups).transpose(2, 3);
-        var valueStates2 = Utils.RepeatKV(valueStates, this.numKeyValueGroups);
+        var keyStates2 = Utils.Phi2RepeatKV(keyStates, this.numKeyValueGroups).transpose(2, 3);
+        var valueStates2 = Utils.Phi2RepeatKV(valueStates, this.numKeyValueGroups);
         // Queries and keys upcast to fp32 is required by Phi-2 to avoid overflow
         var attnWeights = torch.matmul(queryStates.to_type(float32), keyStates2.to_type(float32));
         attnWeights = attnWeights / Math.Sqrt(this.headDim);
```

Module/Phi3Attention.cs
Lines changed: 12 additions & 9 deletions

```diff
@@ -123,10 +123,9 @@ public override Phi3AttentionOutput forward(Phi3AttentionInput input)
 
         var qkv = this.qkv_proj.forward(hidden_states);
         var query_pos = this.num_heads * this.head_dim;
-        var query_states = qkv[.., ..query_pos];
-        var key_states = qkv[.., query_pos .. (query_pos + this.num_key_value_heads * this.head_dim)];
-        var value_states = qkv[.., (query_pos + this.num_key_value_heads * this.head_dim)..];
-
+        var query_states = qkv[.., .., ..query_pos];
+        var key_states = qkv[.., .., query_pos .. (query_pos + this.num_key_value_heads * this.head_dim)];
+        var value_states = qkv[.., .., (query_pos + this.num_key_value_heads * this.head_dim)..];
         query_states = query_states.view(bsz, q_len, this.num_heads, this.head_dim).transpose(1, 2);
         key_states = key_states.view(bsz, q_len, this.num_key_value_heads, this.head_dim).transpose(1, 2);
         value_states = value_states.view(bsz, q_len, this.num_key_value_heads, this.head_dim).transpose(1, 2);
@@ -138,19 +137,23 @@ public override Phi3AttentionOutput forward(Phi3AttentionInput input)
             kv_seq_len += past_key_value.GetUsableLength(kv_seq_len, this.layer_idx);
         }
 
-        var embOutput = this.rotary_emb.forward(new Phi3RotaryEmbeddingInput(q_len, kv_seq_len, this.layer_idx));
+        var embOutput = this.rotary_emb.forward(new Phi3RotaryEmbeddingInput(value_states, positionIds, kv_seq_len));
         (var cos, var sin) = (embOutput.Cos, embOutput.Sin);
 
-        (query_states, key_states) = Utils.ApplyRotaryPosEmb(query_states, key_states, cos, sin, positionIds);
+        query_states.Peek("query_states");
+        key_states.Peek("key_states");
+        cos.Peek("cos");
+        sin.Peek("sin");
+        (query_states, key_states) = Utils.ApplyRotaryPosEmb(query_states, key_states, cos, sin);
 
         if (past_key_value is not null)
         {
             (key_states, value_states) = past_key_value.UpdateKVCache(key_states, value_states, this.layer_idx);
         }
 
         // repeat k/v heads if n_kv_heads < n_heads
-        key_states = Utils.RepeatKV(key_states, this.num_key_value_heads);
-        value_states = Utils.RepeatKV(value_states, this.num_key_value_heads);
+        key_states = Utils.Phi3RepeatKV(key_states, this.num_key_value_groups);
+        value_states = Utils.Phi3RepeatKV(value_states, this.num_key_value_groups);
 
         var attn_weights = torch.matmul(query_states, key_states.transpose(2, 3));
         attn_weights = attn_weights / Math.Sqrt(this.head_dim);
@@ -160,7 +163,7 @@ public override Phi3AttentionOutput forward(Phi3AttentionInput input)
         var attention_mask = input.attention_mask;
         if (attention_mask is not null)
        {
-            attention_mask.shape.Should().BeEquivalentTo(new long[] { bsz, 1, 1, kv_seq_len });
+            attention_mask.shape.Should().BeEquivalentTo(new long[] { bsz, 1, q_len, kv_seq_len });
            attn_weights = attn_weights + attention_mask;
         }
 
```
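
The first hunk fixes the slicing of the fused QKV projection: `qkv_proj` runs on `hidden_states` of shape `[bsz, q_len, hidden]`, so its output is 3D and the slices need two leading `..` to keep the batch and sequence axes. The later hunks also pass `num_key_value_groups` (the query-heads to KV-heads ratio) to the repeat-KV helper rather than `num_key_value_heads`. A standalone shape sketch of the slicing, with toy sizes and hypothetical variable names (not Phi-3's real config values):

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Toy sizes for illustration only.
int bsz = 2, qLen = 5, numHeads = 4, numKvHeads = 2, headDim = 8;

// Fused projection output: [bsz, qLen, (numHeads + 2 * numKvHeads) * headDim].
var qkv = torch.randn(bsz, qLen, (numHeads + 2 * numKvHeads) * headDim);

int queryPos = numHeads * headDim;
var queryStates = qkv[.., .., ..queryPos];
var keyStates = qkv[.., .., queryPos..(queryPos + numKvHeads * headDim)];
var valueStates = qkv[.., .., (queryPos + numKvHeads * headDim)..];

// Prints "2,5,32 | 2,5,16 | 2,5,16": query, key and value each keep batch and sequence dims.
Console.WriteLine($"{string.Join(",", queryStates.shape)} | {string.Join(",", keyStates.shape)} | {string.Join(",", valueStates.shape)}");
```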

Module/Phi3MLP.cs
Lines changed: 5 additions & 2 deletions

```diff
@@ -23,7 +23,10 @@ public Phi3MLP(Phi3Config config)
     public override Tensor forward(Tensor input)
     {
         using var input1 = this.gate_up_proj.forward(input);
-        using var input2 = this.activation_fn.forward(input1);
-        return this.down_proj.forward(input2);
+        var chunks = input1.chunk(2, dim: -1);
+        var gate = chunks[0];
+        var up_status = chunks[1];
+        up_status = up_status * this.activation_fn.forward(gate);
+        return this.down_proj.forward(up_status);
     }
 }
```
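
This rewrites the MLP as the gated (SwiGLU-style) feed-forward Phi-3 uses: `gate_up_proj` emits the gate and up projections concatenated on the last dimension, the activation is applied to the gate half only, and the result scales the up half before `down_proj`. A self-contained sketch of the same pattern, with toy sizes for illustration:

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Toy sizes; the real model's hidden and intermediate sizes are much larger.
int hiddenSize = 16, intermediateSize = 32;

// One linear layer produces [gate | up] concatenated on the last dimension.
var gateUpProj = nn.Linear(hiddenSize, 2 * intermediateSize);
var downProj = nn.Linear(intermediateSize, hiddenSize);
var activation = nn.SiLU();

var input = torch.randn(2, 5, hiddenSize);
var gateUp = gateUpProj.forward(input);                  // [2, 5, 2 * intermediateSize]
var chunks = gateUp.chunk(2, dim: -1);                   // split into gate and up halves
var hidden = chunks[1] * activation.forward(chunks[0]);  // up * SiLU(gate)
var output = downProj.forward(hidden);                   // back to [2, 5, hiddenSize]
Console.WriteLine(string.Join(",", output.shape));       // 2,5,16
```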

Module/Phi3Model.cs
Lines changed: 2 additions & 1 deletion

```diff
@@ -151,10 +151,11 @@ public override CasualLMModelOutput forward(CasualLMModelInput input)
         }
         else
         {
-            attention_mask = AttentionMaskConverter.Create4DCasualAttentionMask(attention_mask, [batch_size, seq_length], inputs_embeds.dtype, device, past_key_values_length, this.config.SlidingWindow);
+            attention_mask = AttentionMaskConverter.Create4DCausalAttentionMask(attention_mask, [batch_size, seq_length], inputs_embeds.dtype, device, past_key_values_length, this.config.SlidingWindow);
         }
 
         var hidden_states = inputs_embeds;
+        hidden_states.Peek("hidden_states");
 
         var all_hidden_states = new List<Tensor>();
         var all_attentions = new List<Tensor>();
```

Module/Phi3RotaryEmbedding.cs
Lines changed: 7 additions & 2 deletions

```diff
@@ -59,10 +59,15 @@ public override Phi3RotaryEmbeddingOutput forward(Phi3RotaryEmbeddingInput input
         // TODO
         // can be calculated once and cached
         var inv_freq = this.get_buffer("inv_freq").to(x.device);
-        var inv_freq_expanded = inv_freq.unsqueeze(0).unsqueeze(-1).expand(position_ids.shape[0], -1, 1);
-        //position_ids_expanded = position_ids[:, None, :].float()
+        var inv_freq_expanded = inv_freq.unsqueeze(0).unsqueeze(-1);
+        inv_freq_expanded.Peek("inv_freq_expanded");
+        position_ids.Peek("position_ids");
+        inv_freq_expanded.Peek("inv_freq_expanded");
+        inv_freq_expanded = inv_freq_expanded.expand(new long[] { position_ids.shape[0], -1, 1 });
+
         var position_ids_expanded = position_ids.unsqueeze(1).to(torch.float32);
         var freqs = inv_freq_expanded * position_ids_expanded;
+        freqs = freqs.transpose(1, 2);
         var emb = torch.cat([freqs, freqs], dim: -1);
 
         var cos = torch.cos(emb);
```
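
The shape bookkeeping here is the point of the change: `inv_freq` of shape `[head_dim/2]` is broadcast against `position_ids` of shape `[bsz, seq]`, and the added `transpose(1, 2)` turns the frequencies into `[bsz, seq, head_dim/2]` so that cos/sin come out as `[bsz, seq, head_dim]`, which is what the new positionIds-free `ApplyRotaryPosEmb` path expects (its `unsqueezeDim = 1` then adds the head axis). A standalone sketch of the broadcast, with toy sizes and `inv_freq` built inline rather than read from a buffer:

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Toy sizes for illustration only.
int bsz = 2, seqLen = 5, headDim = 8;

// inv_freq: [headDim / 2], the standard RoPE frequency schedule.
var invFreq = torch.tensor(10000.0f)
    .pow(torch.arange(0, headDim, 2).to_type(float32) / (float)headDim)
    .reciprocal();

var positionIds = torch.arange(seqLen).unsqueeze(0).expand(bsz, -1);  // [bsz, seq]

var invFreqExpanded = invFreq.unsqueeze(0).unsqueeze(-1)
    .expand(bsz, -1, 1);                                              // [bsz, headDim/2, 1]
var positionIdsExpanded = positionIds.unsqueeze(1).to(float32);       // [bsz, 1, seq]

var freqs = (invFreqExpanded * positionIdsExpanded).transpose(1, 2);  // [bsz, seq, headDim/2]
var emb = torch.cat([freqs, freqs], dim: -1);                         // [bsz, seq, headDim]
var cos = torch.cos(emb);
var sin = torch.sin(emb);
Console.WriteLine(string.Join(",", cos.shape));                       // 2,5,8
```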

Phi3/Phi3ForCasualLM.cs
Lines changed: 2 additions & 2 deletions

```diff
@@ -40,7 +40,7 @@ public override CasualLMModelOutput forward(CasualLMModelInput input)
     public static Phi3ForCasualLM FromPretrained(
         string modelFolder,
         string configName = "config.json",
-        string modelWeightName = "model.safetensors.index.json",
+        string checkPointName = "model.safetensors.index.json",
         ScalarType torchDtype = ScalarType.BFloat16,
         string device = "cpu")
     {
@@ -49,7 +49,7 @@ public static Phi3ForCasualLM FromPretrained(
         modelConfig.DType = torchDtype;
         var phi = new Phi3ForCasualLM(modelConfig);
         var loadedParameters = new Dictionary<string, bool>();
-        phi.load_checkpoint(path: modelFolder, checkpointName: modelWeightName, strict: false, loadedParameters: loadedParameters);
+        phi.load_checkpoint(path: modelFolder, checkpointName: checkPointName, strict: false, loadedParameters: loadedParameters);
         phi = phi.to(device);
         phi.eval();
 
```

Program.cs
Lines changed: 7 additions & 8 deletions

```diff
@@ -1,13 +1,14 @@
 using System.Runtime.InteropServices;
 using FluentAssertions;
+using Phi;
 using Phi.Pipeline;
 using TorchSharp;
 using static TorchSharp.torch;
 
 // Dynamic loading libtorch because Cuda 12 only support GPU driver >= 520
 // And I can't upgrade GPU driver because it's a cloud machine.
 
-var phi2Folder = @"C:\Users\xiaoyuz\source\repos\phi-2";
+var phi2Folder = @"C:\Users\xiaoyuz\source\repos\Phi-3-mini-4k-instruct";
 var device = "cuda";
 
 if (device == "cuda")
@@ -16,15 +17,14 @@
     torch.cuda.is_available().Should().BeTrue();
 }
 
-var defaultType = ScalarType.Float16;
-torch.set_default_dtype(defaultType);
+var defaultType = ScalarType.BFloat16;
 torch.manual_seed(1);
 
 Console.WriteLine("Loading Phi2 from huggingface model weight folder");
 var timer = System.Diagnostics.Stopwatch.StartNew();
-var phi2 = Phi2ForCasualLM.FromPretrained(phi2Folder, device: device, torchDtype: defaultType, checkPointName: "model.safetensors.index.json");
-var tokenizer = Phi2Tokenizer.FromPretrained(phi2Folder);
-var pipeline = new CasualLMPipeline(tokenizer, phi2, device);
+var model = Phi3ForCasualLM.FromPretrained(phi2Folder, device: device, torchDtype: defaultType, checkPointName: "model.safetensors.index.json");
+var tokenizer = LLama2Tokenizer.FromPretrained(phi2Folder);
+var pipeline = new CasualLMPipeline(tokenizer, model, device);
 
 
 timer.Stop();
@@ -34,11 +34,10 @@
 int maxLen = 512;
 float temperature = 0.0f;
 Console.WriteLine($"QA Format: maxLen: {maxLen} temperature: {temperature}");
-var prompt = "Instruct: A skier slides down a frictionless slope of height 40m and length 80m, what's the skier's speed at the bottom, think step by step.\nOutput:";
+var prompt = "Can you provide ways to eat combinations of bananas and dragonfruits?";
 // wait for user to press enter
 Console.WriteLine($"Prompt: {prompt}");
 Console.WriteLine("Press enter to continue inferencing QA format");
-Console.ReadLine();
 
 Console.WriteLine(prompt);
 pipeline.Generate(prompt, maxLen: maxLen, temperature: temperature, device: device);
```

Utils.cs
Lines changed: 30 additions & 5 deletions

```diff
@@ -133,7 +133,7 @@ public static Tensor RotateHalf(Tensor x)
     // k_embed = (k * cos) + (rotate_half(k) * sin)
     // return q_embed, k_embed
 
-    public static (Tensor, Tensor) ApplyRotaryPosEmb(Tensor q, Tensor k, Tensor cos, Tensor sin, Tensor positionIds, int unsqueezeDim = 1)
+    public static (Tensor, Tensor) ApplyRotaryPosEmb(Tensor q, Tensor k, Tensor cos, Tensor sin, Tensor? positionIds = null, int unsqueezeDim = 1)
     {
         // The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
         // sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -142,8 +142,17 @@ public static (Tensor, Tensor) ApplyRotaryPosEmb(Tensor q, Tensor k, Tensor cos,
         // cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
         // the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
 
-        cos = cos[positionIds].unsqueeze(unsqueezeDim);
-        sin = sin[positionIds].unsqueeze(unsqueezeDim);
+        if (positionIds is not null)
+        {
+            cos = cos[positionIds!].unsqueeze(unsqueezeDim);
+            sin = sin[positionIds!].unsqueeze(unsqueezeDim);
+        }
+        else
+        {
+            cos = cos.unsqueeze(unsqueezeDim);
+            sin = sin.unsqueeze(unsqueezeDim);
+        }
+
         var qEmbed = q * cos;
         qEmbed += RotateHalf(q) * sin;
 
@@ -162,12 +171,12 @@ public static Module<Tensor, Tensor> GetActivation(string act_fn)
             "gelu" => nn.GELU(),
             "tanh" => nn.Tanh(),
             "swish" => nn.SiLU(),
-            _ => throw new ArgumentException("Invalid activation function", nameof(act_fn)),
+            _ => throw new ArgumentException("Invalid activation function", act_fn),
         };
     }
 
 
-    public static Tensor RepeatKV(Tensor x, int nRep)
+    public static Tensor Phi2RepeatKV(Tensor x, int nRep)
     {
         var batchSize = x.shape[0];
         var seqLen = x.shape[1];
@@ -183,4 +192,20 @@ public static Tensor RepeatKV(Tensor x, int nRep)
             .view(batchSize, seqLen, nKVHeads * nRep, headDim);
     }
 
+    public static Tensor Phi3RepeatKV(Tensor x, int nRep)
+    {
+        var batchSize = x.shape[0];
+        var nKVHeads = x.shape[1];
+        var seqLen = x.shape[2];
+        var headDim = x.shape[3];
+        if (nRep == 1)
+        {
+            return x;
+        }
+
+        return x.unsqueeze(3)
+            .expand(batchSize, nKVHeads, nRep, seqLen, headDim)
+            .view(batchSize, nKVHeads * nRep, seqLen, headDim);
+    }
+
 }
```
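
The split into two helpers reflects two tensor layouts: `Phi2RepeatKV` expects `[batch, seq, n_kv_heads, head_dim]` (heads on dim 2), while `Phi3RepeatKV` expects `[batch, n_kv_heads, seq, head_dim]` (heads on dim 1), matching where each model applies it. A minimal shape check of the grouped-query-attention repeat for the Phi-3 layout; note this sketch follows the Hugging Face reference `repeat_kv` (insert the repeat axis at dim 2, then `reshape`, since `view` is not valid on an expanded, non-contiguous tensor):

```csharp
using TorchSharp;
using static TorchSharp.torch;

// Grouped-query attention: 8 query heads share 2 KV heads, so nRep = 4.
int batch = 1, nKvHeads = 2, seqLen = 5, headDim = 8, nRep = 4;

var kv = torch.randn(batch, nKvHeads, seqLen, headDim);   // Phi-3 layout: heads on dim 1

var repeated = kv.unsqueeze(2)                            // [batch, nKvHeads, 1, seq, dim]
    .expand(batch, nKvHeads, nRep, seqLen, headDim)       // broadcast the repeat axis
    .reshape(batch, nKvHeads * nRep, seqLen, headDim);    // merge it into the head axis

Console.WriteLine(string.Join(",", repeated.shape));      // 1,8,5,8
```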

Utils/AttentionMaskConverter.cs
Lines changed: 69 additions & 8 deletions

```diff
@@ -2,6 +2,7 @@
 using static TorchSharp.torch;
 using TorchSharp.Modules;
 using TorchSharp;
+using System.Threading.Tasks;
 
 namespace Phi;
 
@@ -10,12 +11,65 @@ public class AttentionMaskConverter
     private readonly bool is_casual;
     private readonly int? sliding_window;
 
-    public AttentionMaskConverter(bool is_casual, int? sliding_window)
+    public AttentionMaskConverter(bool is_causal, int? sliding_window)
     {
-        this.is_casual = is_casual;
+        this.is_casual = is_causal;
         this.sliding_window = sliding_window;
     }
 
+    /// <summary>
+    /// Converts 2D attention mask to 4D attention mask by expanding mask to (bsz, head_dim=1, query_length,
+    /// key_value_length) shape and by adding a large negative bias to not-attended positions. If attention_mask is
+    /// causal, a causal mask will be added.
+    /// </summary>
+    /// <param name="attention_mask_2d"></param>
+    /// <param name="query_length"></param>
+    /// <param name="dtype"></param>
+    /// <param name="key_value_length"></param>
+    /// <returns></returns>
+    public Tensor To4D(
+        Tensor attention_mask_2d,
+        int query_length,
+        ScalarType dtype,
+        int? key_value_length = null)
+    {
+        long[] input_shape = [attention_mask_2d.shape[0], query_length];
+
+        // create causal mask
+        // [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
+        Tensor? casual_4d_mask = null;
+        if ((input_shape[^1] > 1 || this.sliding_window is not null) && this.is_casual)
+        {
+            if (key_value_length is null)
+            {
+                throw new ArgumentException("key_value_length should be provided when attention_mask is causal");
+            }
+
+            var past_key_values_length = key_value_length.Value - query_length;
+            casual_4d_mask = MakeCasualMask(input_shape, dtype, attention_mask_2d.device, past_key_values_length, this.sliding_window);
+        }
+        else if (this.sliding_window is not null)
+        {
+            throw new NotImplementedException("Sliding window is not supported for non-causal masks");
+        }
+
+        var expanded_attn_mask = ExpandMask(attention_mask_2d, dtype, query_length).to(attention_mask_2d.device);
+        if (casual_4d_mask is not null)
+        {
+            var min = dtype switch
+            {
+                ScalarType.Float32 => torch.finfo(dtype).min,
+                ScalarType.Float64 => torch.finfo(dtype).min,
+                ScalarType.Float16 => -65504.0,
+                ScalarType.BFloat16 => -65504.0,
+                _ => throw new ArgumentException("Invalid dtype"),
+            };
+            expanded_attn_mask = casual_4d_mask.masked_fill(expanded_attn_mask.to(ScalarType.Bool), min);
+        }
+
+        return expanded_attn_mask;
+    }
+
     public Tensor? ToCasual4D(
         int batch_size,
         int query_length,
@@ -57,6 +111,7 @@ public static Tensor MakeCasualMask(
             ScalarType.Float32 => torch.finfo(dtype).min,
             ScalarType.Float64 => torch.finfo(dtype).min,
             ScalarType.Float16 => -65504.0,
+            ScalarType.BFloat16 => -65504.0,
             _ => throw new ArgumentException("Invalid dtype"),
         };
         var mask = torch.full([tgt_len, tgt_len], min, dtype: dtype, device: device);
@@ -86,23 +141,28 @@ public static Tensor MakeCasualMask(
     /// Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)`
     /// </summary>
     /// <param name="input_shape">The input shape should be a tuple that defines `(batch_size, query_length)`.</param>
-    public static Tensor? Create4DCasualAttentionMask(
+    public static Tensor? Create4DCausalAttentionMask(
         Tensor? attention_mask,
         long[] input_shape,
         ScalarType dtype,
         Device device,
         int past_key_values_length = 0,
         int? sliding_window = null)
     {
+        var converter = new AttentionMaskConverter(is_causal: true, sliding_window: sliding_window);
+        var batch_size = (int)input_shape[0];
+        var query_length = (int)input_shape[1];
+        var key_value_length = past_key_values_length + query_length;
         if (attention_mask is not null)
         {
-            throw new ArgumentException("This is not a casual mask");
+            if (attention_mask.ndim != 2)
+            {
+                throw new ArgumentException("Attention mask should be 2D");
+            }
+            return converter.To4D(attention_mask, (int)input_shape[1], dtype, key_value_length);
         }
 
-        var batch_size = (int)input_shape[0];
-        var query_length = (int)input_shape[1];
-        var converter = new AttentionMaskConverter(is_casual: true, sliding_window: sliding_window);
-        var key_value_length = past_key_values_length + query_length;
+
         return converter.ToCasual4D(batch_size, query_length, key_value_length, dtype, device);
     }
 
@@ -122,6 +182,7 @@ public static Tensor ExpandMask(
             ScalarType.Float32 => torch.finfo(dtype).min,
             ScalarType.Float64 => torch.finfo(dtype).min,
             ScalarType.Float16 => -65504.0,
+            ScalarType.BFloat16 => -65504.0,
             _ => throw new ArgumentException("Invalid dtype"),
         };
 
```
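
Before this change, passing an explicit 2D padding mask threw; now it is routed through the new `To4D`, which combines the causal bias with the padding mask into the `[bsz, 1, q_len, kv_seq_len]` shape asserted in Phi3Attention. A usage sketch under those assumptions (shapes as the diff's assertions suggest, not independently verified):

```csharp
using TorchSharp;
using static TorchSharp.torch;
using Phi;

// A batch of one, five tokens, with the last key position padded out.
var mask2d = torch.tensor(new float[] { 1, 1, 1, 1, 0 }).reshape(1, 5);

// 2D mask -> should come back as [1, 1, 5, 5]: causal bias plus a large
// negative value on the padded key column.
var mask4d = AttentionMaskConverter.Create4DCausalAttentionMask(
    mask2d, [1, 5], ScalarType.Float32, torch.CPU);

// With no mask the call falls through to ToCasual4D and builds a plain causal mask.
var causalOnly = AttentionMaskConverter.Create4DCausalAttentionMask(
    null, [1, 5], ScalarType.Float32, torch.CPU);
```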
