Skip to content
This repository was archived by the owner on Aug 22, 2024. It is now read-only.

Commit b0f174d

Browse files
update
1 parent 8f0927c commit b0f174d

13 files changed: +180 −243 lines changed

Extension.cs

Lines changed: 12 additions & 7 deletions
Original file line number · Diff line number · Diff line change
@@ -1,25 +1,30 @@
1+
using Phi.Pipeline;
12
using System.Text;
23
using TorchSharp;
34
using static TorchSharp.torch;
45

56
public static class Extension
67
{
78
public static string Generate(
8-
this Phi2ForCasualLM phi,
9+
this CasualLMPipeline pipeline,
910
string prompt,
1011
int maxLen = 128,
1112
float temperature = 0.7f,
1213
float topP = 0.9f,
13-
string[]? stopSequences = null)
14+
string[]? stopSequences = null,
15+
string device = "cpu",
16+
bool bos = true,
17+
bool eos = false,
18+
bool echo = false)
1419
{
15-
var inputIds = phi.Tokenizer.Encode(prompt);
16-
var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: phi.Device).unsqueeze(0);
20+
var inputIds = pipeline.Tokenizer.Encode(prompt, bos, eos);
21+
var inputTensor = torch.tensor(inputIds.ToArray(), dtype: ScalarType.Int64, device: device).unsqueeze(0);
1722
var attentionMask = torch.ones_like(inputTensor);
18-
var stopTokenIds = stopSequences == null ? [[ phi.Tokenizer.EosId ]] : stopSequences.Select(x => phi.Tokenizer.Encode(x)).ToArray();
19-
(var token, var _) = phi.Generate(inputTensor, attentionMask, temperature: temperature, maxLen: maxLen, topP: topP, stopTokenSequence: stopTokenIds);
23+
var stopTokenIds = stopSequences == null ? [[ pipeline.Tokenizer.EosId ]] : stopSequences.Select(x => pipeline.Tokenizer.Encode(x, bos, eos)).ToArray();
24+
(var token, var _) = pipeline.Generate(inputTensor, attentionMask, temperature: temperature, maxLen: maxLen, topP: topP, stopTokenSequence: stopTokenIds, echo: echo);
2025

2126
var tokenIds = token[0].to_type(ScalarType.Int32).data<int>().ToArray();
22-
var output = phi.Tokenizer.Decode(tokenIds);
27+
var output = pipeline.Tokenizer.Decode(tokenIds);
2328
return output;
2429

2530
}

Module/Phi3MLP.cs

Lines changed: 1 addition & 1 deletion
Original file line number · Diff line number · Diff line change
@@ -18,7 +18,7 @@ public Phi3MLP(Phi3Config config)
1818
{
1919
this.gate_up_proj = torch.nn.Linear(config.HiddenSize, 2 * config.IntermediateSize, hasBias: false, dtype: config.DType);
2020
this.down_proj = torch.nn.Linear(config.IntermediateSize, config.HiddenSize, hasBias: false, dtype: config.DType);
21-
this.activation_fn = new NewGELUActivation();
21+
this.activation_fn = Utils.GetActivation(config.HiddenAct);
2222
}
2323
public override Tensor forward(Tensor input)
2424
{

Module/Phi3Model.cs

Lines changed: 2 additions & 2 deletions
Original file line number · Diff line number · Diff line change
@@ -55,12 +55,12 @@ public CasualLMModelOutput(
5555
{
5656
this.last_hidden_state = last_hidden_state;
5757
this.hidden_states = hidden_states;
58-
this.legits = legits;
58+
this.logits = legits;
5959
this.attentions = attentions;
6060
this.past_key_values = past_key_values;
6161
}
6262

63-
public Tensor legits { get; set; }
63+
public Tensor logits { get; set; }
6464

6565
public Tensor last_hidden_state { get; set; }
6666

Phi2/Model.cs

Lines changed: 0 additions & 113 deletions
This file was deleted.

0 commit comments

Comments (0)