Commit 64dc9ff

Add RoBERTa model for question answering using the SQuAD v2.0 dataset.

1 parent aeee780

14 files changed: +2633, -0 lines changed
@@ -0,0 +1,194 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using Examples.Utils;
using TorchSharp;
using TorchSharp.Examples;

namespace CSharpExamples
{
    public class QuestionAnsweringInference
    {
        internal static void Run()
        {
            // Configure the run
            var config = new QuestionAnsweringConfig
            {
                LoadModelPath = "roberta-bertformat-model_weights.dat",
                DataDir = "data",
                TrainFile = "mixed_train.json",
                ValidFile = "mixed_valid.json",
                TestFile = "test.json",
                VocabDir = "vocab_files",

                BatchSize = 8,
                OptimizeSteps = 1,
                MaxSequence = 384,
                Cuda = true,
                SaveDir = "saved_models",

                LearningRate = 3e-5,
                LogEveryNSteps = 10,
                ValidateEveryNSteps = 2000,
                TopK = 5
            };
            Directory.CreateDirectory(config.SaveDir);

            // Initialize the model and data pre-processors
            var runner = new QuestionAnsweringInference(config);

            // Load pre-trained general-purpose model weights
            runner.LoadModel(config.LoadModelPath);

            // Load corpus from disk
            var corpus = runner.LoadCorpus(Path.Join(config.DataDir, config.TestFile));

            // Start the interactive inference loop
            runner.SearchOverCorpus(corpus);
        }

        private static readonly Logger<QuestionAnsweringInference> _logger = new();
        private const string _exit = "exit";

        private QuestionAnsweringConfig Config { get; }
        private RobertaForQuestionAnswering Model { get; }
        private RobertaTokenizer Tokenizer { get; }
        private RobertaInputBuilder InputBuilder { get; }

        private QuestionAnsweringInference(QuestionAnsweringConfig config)
        {
            Config = config;

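            // The hyperparameters below match the roberta-base architecture:
            // 12 layers, 12 attention heads, 768-dimensional hidden states,
            // and a 50265-token vocabulary.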
            Model = new RobertaForQuestionAnswering(
                numLayers: 12,
                numAttentionHeads: 12,
                numEmbeddings: 50265,
                embeddingSize: 768,
                hiddenSize: 768,
                outputSize: 768,
                ffnHiddenSize: 3072,
                maxPositions: 512,
                maxTokenTypes: 2,
                layerNormEps: 1e-12,
                embeddingDropoutRate: 0.1,
                attentionDropoutRate: 0.1,
                attentionOutputDropoutRate: 0.1,
                outputDropoutRate: 0.1);
            if (config.Cuda) Model.cuda();

            Tokenizer = new RobertaTokenizer(config.VocabDir);
            InputBuilder = new RobertaInputBuilder(Tokenizer, config.MaxSequence);
        }

        public void LoadModel(string path)
        {
            _logger.Log($"Loading model from {path}...", newline: false);
            Model.load(path, false);
            if (Config.Cuda) Model.cuda();
            _logger.LogAppend("Done.");
        }

        public SquadCorpus LoadCorpus(string path)
        {
            return new SquadCorpus(path, Tokenizer, InputBuilder);
        }

        private void ModelForward(SquadSampleBatch batch, bool applyPredictMasks,
            out int trueBatchSize, out torch.Tensor startLogits, out torch.Tensor endLogits,
            out torch.Tensor startPositions, out torch.Tensor endPositions)
        {
            trueBatchSize = (int)batch.Tokens.size(0);
            (startLogits, endLogits) = Model.forward(batch.Tokens, batch.Positions, batch.Segments, batch.AttentionMasks);
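            // Adding the predict masks in place biases the logits so that positions outside
            // the context cannot be chosen as answer boundaries (the masks are expected to
            // hold 0 for valid positions and a large negative value elsewhere).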
            if (applyPredictMasks)
            {
                startLogits = startLogits.add_(batch.PredictMasks);
                endLogits = endLogits.add_(batch.PredictMasks);
            }

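            // Gold start/end positions are only present for labeled batches; clamp them
            // into range so answers truncated out of the window map to the ignore index.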
            startPositions = null;
            endPositions = null;
            if (batch.Starts.IsNotNull())
            {
                var ignoreIndex = startLogits.size(-1);
                startPositions = batch.Starts.view(-1).clamp(0, ignoreIndex);
                endPositions = batch.Ends.view(-1).clamp(0, ignoreIndex);
            }
        }

        /// <summary>
        /// Reduce GPU memory usage by pre-allocating memory with a maximal dummy batch, following:
        /// https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#pre-allocate-memory-in-case-of-variable-input-length
        /// </summary>
        private void GpuMemoryWarmupOnlyForward()
        {
            using var disposeScope = torch.NewDisposeScope();
            Model.eval();
            var batch = SquadCorpus.GetMaxDummyBatch(Config);
            ModelForward(batch, false, out var trueBatchSize, out var startLogits, out var endLogits,
                out var startPositions, out var endPositions);
        }

        public void SearchOverCorpus(SquadCorpus corpus)
        {
            var serializerOptions = new JsonSerializerOptions
            {
                WriteIndented = true
            };
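            // Rank corpus documents by TF-IDF similarity to each incoming question.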
            var selector = new TfIdfDocumentSelector(corpus.Documents, Tokenizer);

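            // Inference only: disable gradient tracking, then warm up GPU memory
            // with one maximal forward pass before entering the loop.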
            using var _ = torch.no_grad();
            GpuMemoryWarmupOnlyForward();
            Model.eval();
            while (true)
            {
                Console.Clear();
                Console.Write($"\nType your question (\"{_exit}\" to exit): ");
                var question = Console.ReadLine();
                if (question == _exit) break;

                var questionTokenIds = Tokenizer.TokenizeToId(question);
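                // The +2 appears to account for the special tokens the input builder
                // wraps around the question segment.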
                var questionLength = questionTokenIds.Count + 2;

                var answers = new List<PredictionAnswer>();
                var bestMatch = selector.TopK(question, Config.TopK);
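                // Note: although TopK candidate documents are retrieved, only the single
                // best-matching one is batched for inference here.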
                foreach (var batch in corpus.GetBatches(Config, questionTokenIds, bestMatch.Take(1).ToArray()))
                {
                    using var disposeScope = torch.NewDisposeScope();
                    ModelForward(batch, true, out var trueBatchSize, out var startLogits, out var endLogits,
                        out var startPositions, out var endPositions);
                    for (var i = 0; i < trueBatchSize; ++i)
                    {
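                        // Take the TopK start and end logits and combine them into the
                        // highest-scoring valid answer spans.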
164+
var (predictStartScores, predictStarts) = startLogits[i].topk(Config.TopK);
165+
var (predictEndScores, predictEnds) = endLogits[i].topk(Config.TopK);
166+
var topKSpans = SquadMetric.ComputeTopKSpansWithScore(predictStartScores, predictStarts, predictEndScores, predictEnds, Config.TopK);
167+
var predictStart = topKSpans[0].start;
168+
var predictEnd = topKSpans[0].end;
169+
170+
// Restore predicted answer text
171+
var document = bestMatch[i];
172+
var contextText = Tokenizer.Untokenize(document.ContextTokens);
173+
foreach (var (start, end, score) in topKSpans)
174+
{
175+
var answerText = Tokenizer.Untokenize(
176+
document.ContextTokens.ToArray()[(start - questionLength)..(end - questionLength + 1)]);
177+
answers.Add(new PredictionAnswer { Score = score, Text = answerText });
178+
}
179+
}
180+
181+
answers = answers.OrderByDescending(answer => answer.Score).Take(Config.TopK).ToList();
182+
var outputString = JsonSerializer.Serialize(answers, serializerOptions);
183+
Console.WriteLine($"Predictions:\n{outputString}");
184+
} // end foreach
185+
} // end while
186+
}
187+
}
188+
189+
internal struct PredictionAnswer
190+
{
191+
public string Text { get; set; }
192+
public double Score { get; set; }
193+
}
194+
}
