// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text.Json;
using Examples.Utils;
using TorchSharp;
using TorchSharp.Examples;

namespace CSharpExamples
{
+ public class QuestionAnsweringInference
14
+ {
15
+ internal static void Run ( )
16
+ {
17
+ // Configure the run
18
+ var config = new QuestionAnsweringConfig
19
+ {
20
+ LoadModelPath = "roberta-bertformat-model_weights.dat" ,
21
+ DataDir = "data" ,
22
+ TrainFile = "mixed_train.json" ,
23
+ ValidFile = "mixed_valid.json" ,
24
+ TestFile = "test.json" ,
25
+ VocabDir = "vocab_files" ,
26
+
27
+ BatchSize = 8 ,
28
+ OptimizeSteps = 1 ,
29
+ MaxSequence = 384 ,
30
+ Cuda = true ,
31
+ SaveDir = "saved_models" ,
32
+
33
+ LearningRate = 3e-5 ,
34
+ LogEveryNSteps = 10 ,
35
+ ValidateEveryNSteps = 2000 ,
36
+ TopK = 5
37
+ } ;
38
+ Directory . CreateDirectory ( config . SaveDir ) ;
39
+
40
+ // Initialize Model, Optimizer and Data Pre-processors
41
+ var runner = new QuestionAnsweringInference ( config ) ;
42
+
43
+ // Load Pre-trained General Purpose Model
44
+ runner . LoadModel ( config . LoadModelPath ) ;
45
+
46
+ // Load Corpus from Disk
47
+ var corpus = runner . LoadCorpus ( Path . Join ( config . DataDir , config . TestFile ) ) ;
48
+
49
+ // Start Inference Loop
50
+ runner . SearchOverCorpus ( corpus ) ;
51
+ }
52
+
53
+ private static readonly Logger < QuestionAnsweringInference > _logger = new ( ) ;
54
+ private const string _exit = "exit" ;
55
+
56
+ private QuestionAnsweringConfig Config { get ; }
57
+ private RobertaForQuestionAnswering Model { get ; }
58
+ private RobertaTokenizer Tokenizer { get ; }
59
+ private RobertaInputBuilder InputBuilder { get ; }
60
+
61
+ private QuestionAnsweringInference ( QuestionAnsweringConfig config )
62
+ {
63
+ Config = config ;
64
+
65
+ Model = new RobertaForQuestionAnswering (
66
+ numLayers : 12 ,
67
+ numAttentionHeads : 12 ,
68
+ numEmbeddings : 50265 ,
69
+ embeddingSize : 768 ,
70
+ hiddenSize : 768 ,
71
+ outputSize : 768 ,
72
+ ffnHiddenSize : 3072 ,
73
+ maxPositions : 512 ,
74
+ maxTokenTypes : 2 ,
75
+ layerNormEps : 1e-12 ,
76
+ embeddingDropoutRate : 0.1 ,
77
+ attentionDropoutRate : 0.1 ,
78
+ attentionOutputDropoutRate : 0.1 ,
79
+ outputDropoutRate : 0.1 ) ;
80
+ if ( config . Cuda ) Model . cuda ( ) ;
81
+
82
+ Tokenizer = new RobertaTokenizer ( config . VocabDir ) ;
83
+ InputBuilder = new RobertaInputBuilder ( Tokenizer , config . MaxSequence ) ;
84
+ }
85
+
86
+ public void LoadModel ( string path )
87
+ {
88
+ _logger . Log ( $ "Loading model from { path } ...", newline : false ) ;
89
+ Model . load ( path , false ) ;
90
+ if ( Config . Cuda ) Model . cuda ( ) ;
91
+ _logger . LogAppend ( "Done." ) ;
92
+ }
93
+
94
+ public SquadCorpus LoadCorpus ( string path )
95
+ {
96
+ return new SquadCorpus ( path , Tokenizer , InputBuilder ) ;
97
+ }
98
+
99
+ private void ModelForward ( SquadSampleBatch batch , bool applyPredictMasks ,
100
+ out int trueBatchSize , out torch . Tensor startLogits , out torch . Tensor endLogits ,
101
+ out torch . Tensor startPositions , out torch . Tensor endPositions )
102
+ {
103
+ trueBatchSize = ( int ) batch . Tokens . size ( 0 ) ;
104
+ ( startLogits , endLogits ) = Model . forward ( batch . Tokens , batch . Positions , batch . Segments , batch . AttentionMasks ) ;
105
+ if ( applyPredictMasks )
106
+ {
107
+ startLogits = startLogits . add_ ( batch . PredictMasks ) ;
108
+ endLogits = endLogits . add_ ( batch . PredictMasks ) ;
109
+ }
110
+
111
+ startPositions = null ;
112
+ endPositions = null ;
113
+ if ( batch . Starts . IsNotNull ( ) )
114
+ {
115
+ var ignoreIndex = startLogits . size ( - 1 ) ;
116
+ startPositions = batch . Starts . view ( - 1 ) . clamp ( 0 , ignoreIndex ) ;
117
+ endPositions = batch . Ends . view ( - 1 ) . clamp ( 0 , ignoreIndex ) ;
118
+ }
119
+ }
120
+
121
+ /// <summary>
122
+ /// Save GPU memory usage following this passage:
123
+ /// https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#pre-allocate-memory-in-case-of-variable-input-length
124
+ /// </summary>
125
+ private void GpuMemoryWarmupOnlyForward ( )
126
+ {
127
+ using var disposeScope = torch . NewDisposeScope ( ) ;
128
+ Model . eval ( ) ;
129
+ var batch = SquadCorpus . GetMaxDummyBatch ( Config ) ;
130
+ ModelForward ( batch , false , out var trueBatchSize , out var startLogits , out var endLogits ,
131
+ out var startPositions , out var endPositions ) ;
132
+ }
133
+
134
+ public void SearchOverCorpus ( SquadCorpus corpus )
135
+ {
136
+ var serializerOptions = new JsonSerializerOptions
137
+ {
138
+ WriteIndented = true
139
+ } ;
140
+ var selector = new TfIdfDocumentSelector ( corpus . Documents , Tokenizer ) ;
141
+
142
+ using var _ = torch . no_grad ( ) ;
143
+ GpuMemoryWarmupOnlyForward ( ) ;
144
+ Model . eval ( ) ;
145
+ while ( true )
146
+ {
147
+ Console . Clear ( ) ;
148
+ Console . Write ( $ "\n Type your question (\" { _exit } \" to exit): ") ;
149
+ var question = Console . ReadLine ( ) ;
150
+ if ( question == _exit ) break ;
151
+
152
+ var questionTokenIds = Tokenizer . TokenizeToId ( question ) ;
153
+ var questionLength = questionTokenIds . Count + 2 ;
154
+
155
+ var answers = new List < PredictionAnswer > ( ) ;
156
+ var bestMatch = selector . TopK ( question , Config . TopK ) ;
157
+ foreach ( var batch in corpus . GetBatches ( Config , questionTokenIds , bestMatch . Take ( 1 ) . ToArray ( ) ) )
158
+ {
159
+ using var disposeScope = torch . NewDisposeScope ( ) ;
160
+ ModelForward ( batch , true , out var trueBatchSize , out var startLogits , out var endLogits ,
161
+ out var startPositions , out var endPositions ) ;
162
+ for ( var i = 0 ; i < trueBatchSize ; ++ i )
163
+ {
164
+ var ( predictStartScores , predictStarts ) = startLogits [ i ] . topk ( Config . TopK ) ;
165
+ var ( predictEndScores , predictEnds ) = endLogits [ i ] . topk ( Config . TopK ) ;
166
+ var topKSpans = SquadMetric . ComputeTopKSpansWithScore ( predictStartScores , predictStarts , predictEndScores , predictEnds , Config . TopK ) ;
167
+ var predictStart = topKSpans [ 0 ] . start ;
168
+ var predictEnd = topKSpans [ 0 ] . end ;
169
+
170
+ // Restore predicted answer text
171
+ var document = bestMatch [ i ] ;
172
+ var contextText = Tokenizer . Untokenize ( document . ContextTokens ) ;
173
+ foreach ( var ( start , end , score ) in topKSpans )
174
+ {
175
+ var answerText = Tokenizer . Untokenize (
176
+ document . ContextTokens . ToArray ( ) [ ( start - questionLength ) ..( end - questionLength + 1 ) ] ) ;
177
+ answers . Add ( new PredictionAnswer { Score = score , Text = answerText } ) ;
178
+ }
179
+ }
180
+
181
+ answers = answers . OrderByDescending ( answer => answer . Score ) . Take ( Config . TopK ) . ToList ( ) ;
182
+ var outputString = JsonSerializer . Serialize ( answers , serializerOptions ) ;
183
+ Console . WriteLine ( $ "Predictions:\n { outputString } ") ;
184
+ } // end foreach
185
+ } // end while
186
+ }
187
+ }
+ internal struct PredictionAnswer
190
+ {
191
+ public string Text { get ; set ; }
192
+ public double Score { get ; set ; }
193
+ }
}