diff --git a/src/CSharp/CSharpExamples/QuestionAnsweringInference.cs b/src/CSharp/CSharpExamples/QuestionAnsweringInference.cs new file mode 100644 index 0000000..e00c7f7 --- /dev/null +++ b/src/CSharp/CSharpExamples/QuestionAnsweringInference.cs @@ -0,0 +1,197 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text.Json; +using Examples.Utils; +using TorchSharp; +using TorchSharp.Examples; + +namespace CSharpExamples +{ + public class QuestionAnsweringInference + { + internal static void Run() + { + // Configure the run + var config = new QuestionAnsweringConfig + { + LoadModelPath = "roberta-bertformat-model_weights.dat", + DataDir = "data", + TrainFile = "mixed_train.json", + ValidFile = "mixed_valid.json", + TestFile = "test.json", + VocabDir = "vocab_files", + + BatchSize = 8, + OptimizeSteps = 1, + MaxSequence = 384, + Cuda = true, + SaveDir = "saved_models", + + LearningRate = 3e-5, + LogEveryNSteps = 10, + ValidateEveryNSteps = 2000, + TopK = 5 + }; + Directory.CreateDirectory(config.SaveDir); + + // Initialize Model, Optimizer and Data Pre-processors + var runner = new QuestionAnsweringInference(config); + + // Load Pre-trained General Purpose Model + runner.LoadModel(config.LoadModelPath); + + // Load Corpus from Disk + var corpus = runner.LoadCorpus(Path.Join(config.DataDir, config.TestFile)); + + // Start Inference Loop + runner.SearchOverCorpus(corpus); + } + + private static readonly Logger _logger = new(); + private const string _exit = "exit"; + + private QuestionAnsweringConfig Config { get; } + private RobertaForQuestionAnswering Model { get; } + private RobertaTokenizer Tokenizer { get; } + private RobertaInputBuilder InputBuilder { get; } + + private QuestionAnsweringInference(QuestionAnsweringConfig config) + { + Config = config; + + Model = new RobertaForQuestionAnswering( + numLayers: 12, + numAttentionHeads: 12, + numEmbeddings: 50265, + embeddingSize: 768, + hiddenSize: 768, + outputSize: 768, + ffnHiddenSize: 3072, + maxPositions: 512, + maxTokenTypes: 2, + layerNormEps: 1e-12, + embeddingDropoutRate: 0.1, + attentionDropoutRate: 0.1, + attentionOutputDropoutRate: 0.1, + outputDropoutRate: 0.1); + if (config.Cuda) Model.cuda(); + + Tokenizer = new RobertaTokenizer(config.VocabDir); + InputBuilder = new RobertaInputBuilder(Tokenizer, config.MaxSequence); + } + + public void LoadModel(string path) + { + _logger.Log($"Loading model from {path}...", newline: false); + Model.load(path, false); + if (Config.Cuda) Model.cuda(); + _logger.LogAppend("Done."); + } + + public SquadCorpus LoadCorpus(string path) + { + return new SquadCorpus(path, Tokenizer, InputBuilder); + } + + private void ModelForward(SquadSampleBatch batch, bool applyPredictMasks, + out int trueBatchSize, out torch.Tensor startLogits, out torch.Tensor endLogits, + out torch.Tensor startPositions, out torch.Tensor endPositions) + { + trueBatchSize = (int)batch.Tokens.size(0); + (startLogits, endLogits) = Model.forward(batch.Tokens, batch.Positions, batch.Segments, batch.AttentionMasks); + if (applyPredictMasks) + { + startLogits = startLogits.add_(batch.PredictMasks); + endLogits = endLogits.add_(batch.PredictMasks); + } + + startPositions = null; + endPositions = null; + if (batch.Starts.IsNotNull()) + { + var ignoreIndex = startLogits.size(-1); + startPositions = batch.Starts.view(-1).clamp(0, ignoreIndex); + endPositions 
= batch.Ends.view(-1).clamp(0, ignoreIndex); + } + } + + /// + /// Save GPU memory usage following this passage: + /// https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#pre-allocate-memory-in-case-of-variable-input-length + /// + private void GpuMemoryWarmupOnlyForward() + { + using var disposeScope = torch.NewDisposeScope(); + Model.eval(); + var batch = SquadCorpus.GetMaxDummyBatch(Config); + ModelForward(batch, false, out var trueBatchSize, out var startLogits, out var endLogits, + out var startPositions, out var endPositions); + } + + public void SearchOverCorpus(SquadCorpus corpus) + { + var serializerOptions = new JsonSerializerOptions + { + WriteIndented = true + }; + var selector = new TfIdfDocumentSelector(corpus.Documents, Tokenizer); + + using var _ = torch.no_grad(); + GpuMemoryWarmupOnlyForward(); + Model.eval(); + while (true) + { + Console.Clear(); + Console.Write($"Type your question (\"{_exit}\" to exit): "); + var question = Console.ReadLine(); + if (question == _exit) break; + + var questionTokenIds = Tokenizer.TokenizeToId(question); + var questionLength = questionTokenIds.Count + 2; + + var answers = new List(); + var bestMatch = selector.TopK(question, Config.TopK); + foreach (var batch in corpus.GetBatches(Config, questionTokenIds, bestMatch.Take(1).ToArray())) + { + using var disposeScope = torch.NewDisposeScope(); + ModelForward(batch, true, out var trueBatchSize, out var startLogits, out var endLogits, + out var startPositions, out var endPositions); + for (var i = 0; i < trueBatchSize; ++i) + { + var (predictStartScores, predictStarts) = startLogits[i].topk(Config.TopK); + var (predictEndScores, predictEnds) = endLogits[i].topk(Config.TopK); + var topKSpans = SquadMetric.ComputeTopKSpansWithScore(predictStartScores, predictStarts, predictEndScores, predictEnds, Config.TopK); + var predictStart = topKSpans[0].start; + var predictEnd = topKSpans[0].end; + + // Restore predicted answer text + var document = bestMatch[i]; + var contextText = Tokenizer.Untokenize(document.ContextTokens); + foreach (var (start, end, score) in topKSpans) + { + var answerText = Tokenizer.Untokenize( + document.ContextTokens.ToArray()[(start - questionLength)..(end - questionLength + 1)]); + answers.Add(new PredictionAnswer { Score = score, Text = answerText }); + } + } + + answers = answers.OrderByDescending(answer => answer.Score).Take(Config.TopK).ToList(); + var outputString = JsonSerializer.Serialize(answers, serializerOptions); + Console.WriteLine($"Predictions:\n{outputString}"); + + Console.Write("\nHit Enter Key to ask next question."); + Console.ReadLine(); + } // end foreach + } // end while + } + } + + internal struct PredictionAnswer + { + public string Text { get; set; } + public double Score { get; set; } + } +} \ No newline at end of file diff --git a/src/CSharp/CSharpExamples/QuestionAnsweringTraining.cs b/src/CSharp/CSharpExamples/QuestionAnsweringTraining.cs new file mode 100644 index 0000000..50fea75 --- /dev/null +++ b/src/CSharp/CSharpExamples/QuestionAnsweringTraining.cs @@ -0,0 +1,284 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. 
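+// Fine-tuning counterpart of QuestionAnsweringInference: it loads the same pre-trained
+// RobertaForQuestionAnswering weights, batches SQuAD-format data through RobertaTokenizer and
+// RobertaInputBuilder, and trains with AdamW on the averaged start/end-position cross-entropy loss,
+// logging every LogEveryNSteps steps and validating/saving a checkpoint every ValidateEveryNSteps steps.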
+using System; +using System.IO; +using System.Linq; +using Examples.Utils; +using TorchSharp; +using TorchSharp.Examples; +using TorchSharp.Modules; + +namespace CSharpExamples +{ + public class QuestionAnsweringTraining + { + internal static void Run() + { + // Configure the run + var config = new QuestionAnsweringConfig + { + LoadModelPath = "roberta-bertformat-model_weights.dat", + DataDir = "data", + TrainFile = "mixed_train.json", + ValidFile = "mixed_valid.json", + TestFile = "test.json", + VocabDir = "vocab_files", + + BatchSize = 8, + OptimizeSteps = 1, + MaxSequence = 384, + Cuda = true, + SaveDir = "saved_models", + + LearningRate = 3e-5, + LogEveryNSteps = 10, + ValidateEveryNSteps = 2000, + TopK = 5 + }; + Directory.CreateDirectory(config.SaveDir); + + // Initialize Model, Optimizer and Data Pre-processors + var runner = new QuestionAnsweringTraining(config); + + // Load Pre-trained General Purpose Model + runner.LoadModel(config.LoadModelPath); + + // Load Dataset from Disk + var trainDataset = runner.LoadDataset(Path.Join(config.DataDir, config.TrainFile)); + var validDataset = runner.LoadDataset(Path.Join(config.DataDir, config.ValidFile)); + var testDataset = runner.LoadDataset(Path.Join(config.DataDir, config.TestFile)); + + // Start Training (Finetuning) + runner.Train(trainDataset, validDataset, testDataset); + } + + private static readonly Logger _logger = new(); + + private QuestionAnsweringConfig Config { get; } + private RobertaForQuestionAnswering Model { get; } + private RobertaTokenizer Tokenizer { get; } + private RobertaInputBuilder InputBuilder { get; } + + private AdamW Optimizer { get; } + + private QuestionAnsweringTraining(QuestionAnsweringConfig config) + { + Config = config; + + Model = new RobertaForQuestionAnswering( + numLayers: 12, + numAttentionHeads: 12, + numEmbeddings: 50265, + embeddingSize: 768, + hiddenSize: 768, + outputSize: 768, + ffnHiddenSize: 3072, + maxPositions: 512, + maxTokenTypes: 2, + layerNormEps: 1e-12, + embeddingDropoutRate: 0.1, + attentionDropoutRate: 0.1, + attentionOutputDropoutRate: 0.1, + outputDropoutRate: 0.1); + if (config.Cuda) Model.cuda(); + Optimizer = torch.optim.AdamW(Model.parameters(), Config.LearningRate); + + Tokenizer = new RobertaTokenizer(config.VocabDir); + InputBuilder = new RobertaInputBuilder(Tokenizer, config.MaxSequence); + } + + private void LoadModel(string path) + { + _logger.Log($"Loading model from {path}...", newline: false); + Model.load(path, false); + if (Config.Cuda) Model.cuda(); + _logger.LogAppend("Done."); + } + + private SquadDataset LoadDataset(string path) + { + return new SquadDataset(path, Tokenizer, InputBuilder); + } + + private void ModelForward(SquadSampleBatch batch, bool applyPredictMasks, + out int trueBatchSize, out torch.Tensor startLogits, out torch.Tensor endLogits, + out torch.Tensor startPositions, out torch.Tensor endPositions) + { + trueBatchSize = (int)batch.Tokens.size(0); + (startLogits, endLogits) = Model.forward(batch.Tokens, batch.Positions, batch.Segments, batch.AttentionMasks); + if (applyPredictMasks) + { + startLogits = startLogits.add_(batch.PredictMasks); + endLogits = endLogits.add_(batch.PredictMasks); + } + + startPositions = null; + endPositions = null; + if (batch.Starts.IsNotNull()) + { + var ignoreIndex = startLogits.size(-1); + startPositions = batch.Starts.view(-1).clamp(0, ignoreIndex); + endPositions = batch.Ends.view(-1).clamp(0, ignoreIndex); + } + } + + /// + /// Save GPU memory usage following this passage: + /// 
https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html#pre-allocate-memory-in-case-of-variable-input-length + /// + private void GpuMemoryWarmup() + { + using var disposeScope = torch.NewDisposeScope(); + + Model.train(); + var batch = SquadDataset.GetMaxDummyBatch(Config); + ModelForward(batch, false, out var trueBatchSize, out var startLogits, out var endLogits, + out var startPositions, out var endPositions); + var lossFunc = torch.nn.functional.cross_entropy_loss(); + var lossStart = lossFunc(startLogits, startPositions); + var lossEnd = lossFunc(endLogits, endPositions); + var loss = ((lossStart + lossEnd) / 2); + loss.backward(); + Optimizer.zero_grad(); + } + + public void Train(SquadDataset trainDataset, SquadDataset validDataset, SquadDataset testDataset) + { + var batchSize = Config.BatchSize; + var optimizeSteps = Config.OptimizeSteps; + + var correctSamples = 0; + var startCorrectSamples = 0; + var endCorrectSamples = 0; + var totalSamples = 0; + var all = (int)Math.Ceiling(1.0 * trainDataset.Count / batchSize); + var totalLoss = 0.0; + + var step = 1; + var startTime = DateTime.Now; + GpuMemoryWarmup(); + for (var epoch = 0; epoch < 10; epoch++) + { + foreach (var batch in trainDataset.GetBatches(Config, shuffle: true)) + { + using var disposeScope = torch.NewDisposeScope(); + Model.train(); + Optimizer.zero_grad(); + + ModelForward(batch, false, out var trueBatchSize, out var startLogits, out var endLogits, + out var startPositions, out var endPositions); + + var lossFunc = torch.nn.functional.cross_entropy_loss(); + var lossStart = lossFunc(startLogits, startPositions); + var lossEnd = lossFunc(endLogits, endPositions); + var loss = ((lossStart + lossEnd) / 2); + // For N-step optimization + totalLoss += loss.ToItem(); + loss /= optimizeSteps; + loss.backward(); + if (step % optimizeSteps == 0) + { + var temp = Optimizer.parameters().ToArray().Select(p => p.grad()).ToArray(); + Optimizer.step(); + } + + var predictionStarts = startLogits.argmax(-1).view(-1).ToArray(); + var predictionEnds = endLogits.argmax(-1).view(-1).ToArray(); + var groundStarts = startPositions.ToArray(); + var groundEnds = endPositions.ToArray(); + for (var i = 0; i < trueBatchSize; ++i) + { + if (predictionStarts[i] == groundStarts[i] && predictionEnds[i] == groundEnds[i]) + { + ++correctSamples; + ++startCorrectSamples; + ++endCorrectSamples; + } + else if (predictionStarts[i] == groundStarts[i]) + { + ++startCorrectSamples; + } + else if (predictionEnds[i] == groundEnds[i]) + { + ++endCorrectSamples; + } + } + totalSamples += trueBatchSize; + + if (step % Config.LogEveryNSteps == 0) + { + var endTime = DateTime.Now; + _logger.LogLoop($" " + + $"\rEpoch {epoch} - {step % all}/{all}," + + $"\t Time {endTime - startTime}," + + $"\t Acc {(totalSamples == 1 ? 0.0 : 1.0 * correctSamples / (totalSamples - 1)):N4}" + + $"\t StartAcc {(totalSamples == 1 ? 0.0 : 1.0 * startCorrectSamples / (totalSamples - 1)):N4}" + + $"\t EndAcc {(totalSamples == 1 ? 
0.0 : 1.0 * endCorrectSamples / (totalSamples - 1)):N4}"); + } + + if (step % Config.ValidateEveryNSteps == 0) + { + _logger.LogAppend($"\t Loss {totalLoss / Config.ValidateEveryNSteps:N4}"); + Validate(validDataset); + Validate(testDataset); + correctSamples = startCorrectSamples = endCorrectSamples = totalSamples = 0; + totalLoss = 0.0; + Model.save(Path.Join(Config.SaveDir, $"model_{step}.tsm")); + } + ++step; + } + } + } + + public void Validate(SquadDataset validDataset) + { + _logger.Log($"Evaluating on {validDataset.FilePath}...", newline: false); + + Model.eval(); + var correct = 0; + var startCorrect = 0; + var endCorrect = 0; + var f1score = 0.0; + var total = 0; + using (torch.no_grad()) + { + foreach (var batch in validDataset.GetBatches(Config, shuffle: false)) + { + using var disposeScope = torch.NewDisposeScope(); + + ModelForward(batch, false, out var trueBatchSize, out var startLogits, out var endLogits, + out var startPositions, out var endPositions); + + var predictionStarts = startLogits.argmax(-1).view(-1).ToArray(); + var predictionEnds = endLogits.argmax(-1).view(-1).ToArray(); + for (var i = 0; i < trueBatchSize; ++i) + { + var predictionStart = predictionStarts[i]; + var predictionEnd = predictionEnds[i]; + var zip = Enumerable.Zip(startPositions[i].ToArray(), endPositions[i].ToArray()).ToArray(); + if (zip.Any(pair => predictionStart == pair.First && predictionEnd == pair.Second)) + { + ++correct; + ++startCorrect; + ++endCorrect; + } + else if (zip.Any(pair => predictionStart == pair.First)) + { + ++startCorrect; + } + else if (zip.Any(pair => predictionEnd == pair.Second)) + { + ++endCorrect; + } + f1score += zip.Max(pair => SquadMetric.ComputeF1(predictionStart, predictionEnd, pair.First, pair.Second)); + } + total += trueBatchSize; + } + } + _logger.LogAppend( + $"\rAccuracy: {1.0 * correct / total:N4}, {correct}/{total} ---- " + + $"F1Score: {f1score / total:N4} ---- " + + $"Start: {1.0 * startCorrect / total:N4}, {startCorrect} ---- " + + $"End: {1.0 * endCorrect / total:N4}, {endCorrect}."); + } + } +} \ No newline at end of file diff --git a/src/CSharp/Models/Roberta.cs b/src/CSharp/Models/Roberta.cs new file mode 100644 index 0000000..05c6110 --- /dev/null +++ b/src/CSharp/Models/Roberta.cs @@ -0,0 +1,294 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. 
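+// Minimal RoBERTa encoder used by RobertaForQuestionAnswering. Embeddings sums word, position and
+// token-type embeddings (followed by LayerNorm and dropout); Encoder stacks identical transformer
+// layers, each consisting of multi-head self-attention (AttentionSelf/AttentionOutput) and a GELU
+// feed-forward block (Intermediate/Output), both with residual connections and LayerNorm.
+// Field names (word_embeddings, LayerNorm, query/key/value, ...) follow the usual BERT/RoBERTa
+// parameter naming, presumably so the "roberta-bertformat" weight file used by the examples loads by name.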
+using System; +using System.Linq; +using TorchSharp.Modules; + +namespace TorchSharp.Examples +{ + public class Roberta : torch.nn.Module + { + private readonly Embeddings embeddings; + private readonly Encoder encoder; + + public Roberta(int numLayers, int numAttentionHeads, + long numEmbeddings, long embeddingSize, long hiddenSize, long outputSize, long ffnHiddenSize, + long maxPositions, long maxTokenTypes, double layerNormEps, + double embeddingDropoutRate, double attentionDropoutRate, double attentionOutputDropoutRate, double outputDropoutRate) + : base(nameof(Roberta)) + { + embeddings = new Embeddings(numEmbeddings, embeddingSize, maxPositions, maxTokenTypes, + layerNormEps, embeddingDropoutRate); + encoder = new Encoder(numLayers, numAttentionHeads, embeddingSize, hiddenSize, outputSize, ffnHiddenSize, + layerNormEps, attentionDropoutRate, attentionOutputDropoutRate, outputDropoutRate); + RegisterComponents(); + } + + public override torch.Tensor forward(torch.Tensor t) + { + throw new NotImplementedException(); + } + + public override torch.Tensor forward(torch.Tensor x, torch.Tensor y) + { + throw new NotImplementedException(); + } + + public torch.Tensor forward(torch.Tensor tokens, torch.Tensor positions, torch.Tensor tokenTypes, torch.Tensor attentionMask) + { + using var disposeScope = torch.NewDisposeScope(); + var x = embeddings.forward(tokens, positions, tokenTypes); + var sequenceOutput = encoder.forward(x, attentionMask); + return sequenceOutput.MoveToOuterDisposeScope(); + } + } + + internal class Embeddings : torch.nn.Module + { + public readonly Embedding word_embeddings; + public readonly Embedding position_embeddings; + public readonly Embedding token_type_embeddings; + public readonly LayerNorm LayerNorm; + public readonly Dropout dropout; + + public Embeddings(long numEmbeddings, long embeddingSize, long maxPositions, long maxTokenTypes, + double layerNormEps, double dropoutRate) + : base(nameof(Embeddings)) + { + word_embeddings = torch.nn.Embedding(numEmbeddings, embeddingSize, padding_idx: 1); + position_embeddings = torch.nn.Embedding(maxPositions, embeddingSize); + token_type_embeddings = torch.nn.Embedding(maxTokenTypes, embeddingSize); + LayerNorm = torch.nn.LayerNorm(new long[] { embeddingSize }, eps: layerNormEps); + dropout = torch.nn.Dropout(dropoutRate); + + RegisterComponents(); + } + + public torch.Tensor forward(torch.Tensor tokens, torch.Tensor positions, torch.Tensor segments) + { + using var disposeScope = torch.NewDisposeScope(); + var tokenEmbedding = word_embeddings.forward(tokens); + var positionEmbedding = position_embeddings.forward(positions); + var tokenTypeEmbedding = token_type_embeddings.forward(segments); + var embedding = tokenEmbedding + positionEmbedding + tokenTypeEmbedding; + var output = LayerNorm.forward(embedding); + output = dropout.forward(output); + return output.MoveToOuterDisposeScope(); + } + } + + internal class Encoder : torch.nn.Module + { + public readonly int NumLayers; + public readonly ModuleList layer; + + public Encoder(int numLayers, int numAttentionHeads, long embeddingSize, long hiddenSize, long outputSize, long ffnHiddenSize, + double layerNormEps, double dropoutRate, double attentionDropoutRate, double outputDropoutRate) + : base(nameof(Encoder)) + { + NumLayers = numLayers; + layer = new ModuleList(Enumerable.Range(0, numLayers) + .Select(_ => new Layer(numAttentionHeads, hiddenSize, ffnHiddenSize, + layerNormEps, dropoutRate, attentionDropoutRate, outputDropoutRate)) + .ToArray()); + 
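+            // RegisterComponents (called in every module here) registers the module-typed fields with
+            // TorchSharp so that parameters(), save() and load() can discover them by field name.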
RegisterComponents(); + } + + public override torch.Tensor forward(torch.Tensor x, torch.Tensor attentionMask) + { + using var disposeScope = torch.NewDisposeScope(); + foreach (var lyr in layer) + { + x = lyr.forward(x, attentionMask); + } + return x.MoveToOuterDisposeScope(); + } + } + + internal class Layer : torch.nn.Module + { + public readonly Attention attention; + public readonly Intermediate intermediate; + public readonly Output output; + + public Layer(int numAttentionHeads, long hiddenSize, long ffnHiddenSize, double layerNormEps, + double dropoutRate, double attentionDropoutRate, double outputDropoutRate) + : base(nameof(Layer)) + { + attention = new Attention(numAttentionHeads, hiddenSize, layerNormEps, attentionDropoutRate, outputDropoutRate); + intermediate = new Intermediate(hiddenSize, ffnHiddenSize); + output = new Output(ffnHiddenSize, hiddenSize, dropoutRate); + RegisterComponents(); + } + + public override torch.Tensor forward(torch.Tensor input, torch.Tensor attentionMask) + { + using var disposeScope = torch.NewDisposeScope(); + var attentionOutput = attention.forward(input, attentionMask); + var intermediateOutput = intermediate.forward(attentionOutput); + var layerOutput = output.forward(intermediateOutput, attentionOutput); + return layerOutput.MoveToOuterDisposeScope(); + } + } + + internal class Attention : torch.nn.Module + { + public readonly AttentionSelf self; + public readonly AttentionOutput output; + + public Attention(int numAttentionHeads, long hiddenSize, double layerNormEps, double attentionDropoutRate, double outputDropoutRate) + : base(nameof(Attention)) + { + self = new AttentionSelf(numAttentionHeads, hiddenSize, layerNormEps, attentionDropoutRate); + output = new AttentionOutput(hiddenSize, layerNormEps, outputDropoutRate); + RegisterComponents(); + } + + public override torch.Tensor forward(torch.Tensor hiddenStates, torch.Tensor attentionMask) + { + using var disposeScope = torch.NewDisposeScope(); + var x = self.forward(hiddenStates, attentionMask); + x = output.forward(x, hiddenStates); + return x.MoveToOuterDisposeScope(); + } + } + + internal class AttentionSelf : torch.nn.Module + { + public readonly int NumAttentionHeads; + public readonly int AttentionHeadSize; + + public readonly Linear query; + public readonly Linear key; + public readonly Linear value; + public readonly Dropout attention_dropout; + + public AttentionSelf(int numAttentionHeads, long hiddenSize, double layerNormEps, double attentionDropoutRate) + : base(nameof(AttentionSelf)) + { + NumAttentionHeads = numAttentionHeads; + AttentionHeadSize = (int)hiddenSize / numAttentionHeads; + if (NumAttentionHeads * AttentionHeadSize != hiddenSize) + { + throw new ArgumentException($"NumAttentionHeads must be a factor of hiddenSize, got {numAttentionHeads} and {hiddenSize}."); + } + + query = torch.nn.Linear(hiddenSize, hiddenSize, true); + key = torch.nn.Linear(hiddenSize, hiddenSize, true); + value = torch.nn.Linear(hiddenSize, hiddenSize, true); + attention_dropout = torch.nn.Dropout(attentionDropoutRate); + + RegisterComponents(); + } + + public override torch.Tensor forward(torch.Tensor hiddenStates, torch.Tensor attentionMask) + { + using var disposeScope = torch.NewDisposeScope(); + var mixedQueryLayer = query.forward(hiddenStates); + var mixedKeyLayer = key.forward(hiddenStates); + var mixedValueLayer = value.forward(hiddenStates); + + var queryLayer = TransposeForScores(mixedQueryLayer); + var keyLayer = TransposeForScores(mixedKeyLayer); + var valueLayer = 
TransposeForScores(mixedValueLayer); + + // Attention + queryLayer.div_(Math.Sqrt(AttentionHeadSize)); + var attentionScores = torch.matmul(queryLayer, keyLayer.transpose_(-1, -2)); + if (attentionMask is not null && !attentionMask.IsInvalid) + { + attentionScores.add_(attentionMask); + } + + var attentionProbs = torch.nn.functional.softmax(attentionScores, dim: -1); + attentionProbs = attention_dropout.forward(attentionProbs); + + var contextLayer = torch.matmul(attentionProbs, valueLayer); + contextLayer = contextLayer.permute(0, 2, 1, 3).contiguous(); + var contextShape = contextLayer.shape[..^2].Append(NumAttentionHeads * AttentionHeadSize).ToArray(); + contextLayer = contextLayer.view(contextShape); + return contextLayer.MoveToOuterDisposeScope(); + } + + /// + /// [B x T x C] -> [B x Head x T x C_Head] + /// + private torch.Tensor TransposeForScores(torch.Tensor x) + { + using var disposeScope = torch.NewDisposeScope(); + var newShape = x.shape[..^1].Append(NumAttentionHeads).Append(AttentionHeadSize).ToArray(); + x = x.view(newShape); + x = x.permute(0, 2, 1, 3).contiguous(); + return x.MoveToOuterDisposeScope(); + } + } + + internal class AttentionOutput : torch.nn.Module + { + public readonly Linear dense; + public readonly Dropout dropout; + public readonly LayerNorm LayerNorm; + + public AttentionOutput(long hiddenSize, double layerNormEps, double outputDropoutRate) + : base(nameof(AttentionOutput)) + { + dense = torch.nn.Linear(hiddenSize, hiddenSize, true); + dropout = torch.nn.Dropout(outputDropoutRate); + LayerNorm = torch.nn.LayerNorm(new long[] { hiddenSize }); + RegisterComponents(); + } + + public override torch.Tensor forward(torch.Tensor hiddenStates, torch.Tensor inputTensor) + { + using var disposeScope = torch.NewDisposeScope(); + hiddenStates = dense.forward(hiddenStates); + hiddenStates = dropout.forward(hiddenStates); + hiddenStates = LayerNorm.forward(hiddenStates + inputTensor); + return hiddenStates.MoveToOuterDisposeScope(); + } + } + + internal class Intermediate : torch.nn.Module + { + public readonly Linear dense; + public readonly GELU gelu; + + public Intermediate(long hiddenSize, long ffnHiddenSize) : base(nameof(Intermediate)) + { + dense = torch.nn.Linear(hiddenSize, ffnHiddenSize, true); + gelu = torch.nn.GELU(); + RegisterComponents(); + } + + public override torch.Tensor forward(torch.Tensor t) + { + using var disposeScope = torch.NewDisposeScope(); + t = dense.forward(t); + t = gelu.forward(t); + return t.MoveToOuterDisposeScope(); + } + } + + internal class Output : torch.nn.Module + { + public readonly Linear dense; + public readonly LayerNorm LayerNorm; + public readonly Dropout Dropout; + + public Output(long ffnHiddenSize, long hiddenSize, double outputDropoutRate) : base(nameof(Output)) + { + dense = torch.nn.Linear(ffnHiddenSize, hiddenSize, true); + Dropout = torch.nn.Dropout(outputDropoutRate); + LayerNorm = torch.nn.LayerNorm(new long[] { hiddenSize }); + RegisterComponents(); + } + + public override torch.Tensor forward(torch.Tensor hiddenStates, torch.Tensor inputTensor) + { + using var disposeScope = torch.NewDisposeScope(); + hiddenStates = dense.forward(hiddenStates); + hiddenStates = Dropout.forward(hiddenStates); + hiddenStates = LayerNorm.forward(hiddenStates + inputTensor); + return hiddenStates.MoveToOuterDisposeScope(); + } + } +} \ No newline at end of file diff --git a/src/CSharp/Models/RobertaForQuestionAnswering.cs b/src/CSharp/Models/RobertaForQuestionAnswering.cs new file mode 100644 index 0000000..f1381d0 --- /dev/null 
+++ b/src/CSharp/Models/RobertaForQuestionAnswering.cs @@ -0,0 +1,74 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using TorchSharp.Modules; + +namespace TorchSharp.Examples +{ + public class RobertaForQuestionAnswering : torch.nn.Module + { + public readonly Roberta bert; + public readonly Linear qa_outputs; + + public RobertaForQuestionAnswering(int numLayers, int numAttentionHeads, + long numEmbeddings, long embeddingSize, long hiddenSize, long outputSize, long ffnHiddenSize, + long maxPositions, long maxTokenTypes, double layerNormEps, + double embeddingDropoutRate, double attentionDropoutRate, double attentionOutputDropoutRate, double outputDropoutRate) + : base(nameof(RobertaForQuestionAnswering)) + { + bert = new Roberta( + numLayers: numLayers, + numAttentionHeads: numAttentionHeads, + numEmbeddings: numEmbeddings, + embeddingSize: embeddingSize, + hiddenSize: hiddenSize, + outputSize: outputSize, + ffnHiddenSize: ffnHiddenSize, + maxPositions: maxPositions, + maxTokenTypes: maxTokenTypes, + layerNormEps: layerNormEps, + embeddingDropoutRate: embeddingDropoutRate, + attentionDropoutRate: attentionDropoutRate, + attentionOutputDropoutRate: attentionOutputDropoutRate, + outputDropoutRate: outputDropoutRate); + qa_outputs = torch.nn.Linear(inputSize: outputSize, outputSize: 2); + + apply(InitWeights); + + RegisterComponents(); + } + + public (torch.Tensor startLogits, torch.Tensor endLogits) forward( + torch.Tensor tokens, torch.Tensor positions, torch.Tensor tokenTypes, torch.Tensor attentionMask) + { + using var disposeScope = torch.NewDisposeScope(); + var encodedVector = bert.forward(tokens, positions, tokenTypes, attentionMask); + var logits = qa_outputs.forward(encodedVector); + var splitLogits = logits.split(1, dimension: -1); + var startLogits = splitLogits[0].squeeze(-1).contiguous(); + var endLogits = splitLogits[1].squeeze(-1).contiguous(); + return (startLogits.MoveToOuterDisposeScope(), endLogits.MoveToOuterDisposeScope()); + } + + private static void InitWeights(torch.nn.Module module) + { + using var disposeScope = torch.NewDisposeScope(); + if (module is Linear linearModule) + { + linearModule.weight.normal_(mean: 0.0, stddev: 0.02); + if (linearModule.bias is not null && !linearModule.bias.IsInvalid) + { + linearModule.bias.zero_(); + } + } + else if (module is Embedding embeddingModule) + { + embeddingModule.weight.normal_(mean: 0.0, stddev: 0.02); + embeddingModule.weight[1].zero_(); // padding_idx + } + else if (module is LayerNorm layerNormModule) + { + layerNormModule.weight.fill_(1.0); + layerNormModule.bias.zero_(); + } + } + } +} \ No newline at end of file diff --git a/src/Utils/DefaultDictionary.cs b/src/Utils/DefaultDictionary.cs new file mode 100644 index 0000000..119c519 --- /dev/null +++ b/src/Utils/DefaultDictionary.cs @@ -0,0 +1,100 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. 
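+// Dictionary that materializes missing keys on first read using the supplied factory, similar to
+// Python's collections.defaultdict. The tokenizer uses it for BPE merge ranks with int.MaxValue as
+// the default value, so that unseen bigram pairs rank last during merging.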
+using System; +using System.Collections; +using System.Collections.Generic; +using System.Linq; + +namespace Examples.Utils +{ + [Serializable] + public class DefaultDictionary : IDictionary + { + private readonly Func _init; + private readonly Dictionary _dictionary; + + public DefaultDictionary(Func init) + { + _init = init; + _dictionary = new Dictionary(); + } + + public TValue this[TKey key] + { + get + { + if (!ContainsKey(key)) Add(key, _init()); + return _dictionary[key]; + } + set + { + if (!ContainsKey(key)) Add(key, value); + else _dictionary[key] = value; + } + } + + /*** Below are auto-implemented methods ***/ + + public IEnumerator> GetEnumerator() + { + return _dictionary.GetEnumerator(); + } + + IEnumerator IEnumerable.GetEnumerator() + { + return ((IEnumerable)_dictionary).GetEnumerator(); + } + + public void Add(KeyValuePair item) + { + _dictionary.Add(item.Key, item.Value); + } + + public void Clear() + { + _dictionary.Clear(); + } + + public bool Contains(KeyValuePair item) + { + return _dictionary.Contains(item); + } + + public void CopyTo(KeyValuePair[] array, int index) + { + throw new NotImplementedException(); + } + + public bool Remove(KeyValuePair item) + { + return _dictionary.Remove(item.Key); + } + + public int Count => _dictionary.Count; + + public bool IsReadOnly => false; + + public void Add(TKey key, TValue value) + { + _dictionary.Add(key, value); + } + + public bool ContainsKey(TKey key) + { + return _dictionary.ContainsKey(key); + } + + public bool Remove(TKey key) + { + return _dictionary.Remove(key); + } + + public bool TryGetValue(TKey key, out TValue value) + { + return _dictionary.TryGetValue(key, out value); + } + + public ICollection Keys => _dictionary.Keys; + + public ICollection Values => _dictionary.Values; + } +} \ No newline at end of file diff --git a/src/Utils/Logger.cs b/src/Utils/Logger.cs new file mode 100644 index 0000000..5fa17f3 --- /dev/null +++ b/src/Utils/Logger.cs @@ -0,0 +1,36 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; + +namespace Examples.Utils +{ + public class Logger + { + private readonly Type Type = typeof(T); + + public void Log(string text = "", bool newline = true, bool carriageReturn = false, LogLevel logLevel = LogLevel.Info) + { + Console.Write( + $"{(carriageReturn ? "\r" : "")}" + + $"{DateTime.Now:u} - {Type.Name} {logLevel}: {text}" + + $"{(newline ? "\n" : "")}"); + } + + public void LogLoop(string text = "") + { + Console.Write($"\r{Type.Name}: {text}"); + } + + public void LogAppend(string text = "") + { + Console.WriteLine(text); + } + } + + public enum LogLevel + { + Info, + Debug, + Warning, + Error + } +} \ No newline at end of file diff --git a/src/Utils/QuestionAnsweringConfig.cs b/src/Utils/QuestionAnsweringConfig.cs new file mode 100644 index 0000000..0131c54 --- /dev/null +++ b/src/Utils/QuestionAnsweringConfig.cs @@ -0,0 +1,27 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. 
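+// Plain options bag shared by the training and inference examples: data/model/vocab locations,
+// batch size, gradient-accumulation steps (OptimizeSteps), maximum sequence length, device selection,
+// learning rate, logging/validation cadence, and TopK for ranking predicted answer spans.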
+ +namespace Examples.Utils +{ + public class QuestionAnsweringConfig + { + public string LoadModelPath { get; set; } + public string DataDir { get; set; } + public string TrainFile { get; set; } + public string ValidFile { get; set; } + public string TestFile { get; set; } + public string VocabDir { get; set; } + + public int BatchSize { get; set; } + public int OptimizeSteps { get; set; } + public int MaxSequence { get; set; } + public bool Cuda { get; set; } + public string SaveDir { get; set; } + + public double LearningRate { get; set; } + public int LogEveryNSteps { get; set; } + public int ValidateEveryNSteps { get; set; } + + public int TopK { get; set; } + } + +} \ No newline at end of file diff --git a/src/Utils/RobertaInputBuilder.cs b/src/Utils/RobertaInputBuilder.cs new file mode 100644 index 0000000..5a7a630 --- /dev/null +++ b/src/Utils/RobertaInputBuilder.cs @@ -0,0 +1,153 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System.Collections.Generic; +using System.Linq; + +namespace Examples.Utils +{ + public class RobertaInputBuilder + { + private readonly RobertaTokenizer Tokenizer; + public readonly int MaxPositions; + private readonly int[] Positions; + private readonly int[] Zeros; + private readonly int[] Ones; + + public RobertaInputBuilder(RobertaTokenizer tokenizer, int maxPositions) + { + Tokenizer = tokenizer; + MaxPositions = maxPositions; + Positions = Enumerable.Range(0, maxPositions).ToArray(); + Zeros = Enumerable.Repeat(0, maxPositions).ToArray(); + Ones = Enumerable.Repeat(1, maxPositions).ToArray(); + } + + /// + /// Build Tensor input for the model according to the actual length of input tokens. + /// [CLS] and [SEP] tokens will be added. + /// + /// + /// This method is only for input with one segment. + /// + /// A list of tokenized input tokens. + /// A triplet of (token, position, token type). + public (int[] tokens, int[] positions, int[] segments) Build(IList tokens) + { + var tokenizedId = Tokenizer.TokenizeToId(tokens).ToList(); + tokenizedId.Insert(0, Tokenizer.BosIndex); + tokenizedId.Add(Tokenizer.EosIndex); + return ( + tokenizedId.ToArray(), + Positions[0..tokenizedId.Count], + Zeros[0..tokenizedId.Count]); + } + + /// + /// Build Tensor input for the model according to the actual length of input tokens. + /// [CLS] and [SEP] tokens will be added. + /// + /// + /// This method is only for input with two segments. + /// + /// A list of tokenized input tokens. + /// A triplet of (token, position, token type). + public (int[] tokens, int[] positions, int[] segments) Build(IList tokens0, IList tokens1) + { + var tokenizedId0 = Tokenizer.TokenizeToId(tokens0); + var tokenizedId1 = Tokenizer.TokenizeToId(tokens1); + var tokenizedId = tokenizedId0.ToList(); + tokenizedId.Insert(0, Tokenizer.BosIndex); + tokenizedId.Add(Tokenizer.EosIndex); + tokenizedId.AddRange(tokenizedId1); + tokenizedId = tokenizedId.Take(MaxPositions - 1).ToList(); + tokenizedId.Add(Tokenizer.EosIndex); + var segments = Zeros[0..(tokenizedId0.Count + 2)] + .Concat(Ones[0..(tokenizedId.Count - tokenizedId0.Count - 3)]).ToArray(); + return ( + tokenizedId.ToArray(), + Positions[0..tokenizedId.Count], + segments); + } + + /// + /// Build Tensor input for the model according to the actual length of input tokens. + /// [CLS] and [SEP] tokens will be added. + /// + /// + /// This method is only for input with one segment. + /// + /// A list of tokenized input token IDs. 
+ /// A triplet of (token, position, token type). + public (int[] tokens, int[] positions, int[] segments) Build(IList tokensId) + { + tokensId.Insert(0, Tokenizer.BosIndex); + tokensId.Add(Tokenizer.EosIndex); + return ( + tokensId.ToArray(), + Positions[0..tokensId.Count], + Zeros[0..tokensId.Count]); + } + + /// + /// Build Tensor input for the model according to the actual length of input tokens. + /// [CLS] and [SEP] tokens will be added. + /// + /// + /// This method is only for input with two segments. + /// + /// A list of tokenized input token IDs. + /// A triplet of (token, position, token type). + public (int[] tokens, int[] positions, int[] segments, int questionSegmentLength) + Build(IList tokensId0, IList tokensId1) + { + var tokensId = tokensId0.ToList(); + tokensId.Insert(0, Tokenizer.BosIndex); + tokensId.Add(Tokenizer.EosIndex); + tokensId.AddRange(tokensId1); + tokensId = tokensId.Take(MaxPositions - 1).ToList(); + tokensId.Add(Tokenizer.EosIndex); + var segments = Zeros[0..(tokensId0.Count + 2)] + .Concat(Ones[0..(tokensId.Count - tokensId0.Count - 2)]).ToArray(); + return ( + tokensId.ToArray(), + Positions[0..tokensId.Count], + segments, + tokensId0.Count + 2); + } + + /// + /// Build Tensor input for the model according to the actual length of input tokens. + /// [CLS] and [SEP] tokens will be added. + /// + /// + /// This method is only for input with two segments. + /// + /// A list of tokenized input token IDs. + /// A tuple of (token, position, token type, startPosition, endPosition). + public (int[] tokens, int[] positions, int[] segments, int[] startPositions, int[] endPositions, int questionSegmentLength) + Build(IList tokensId0, IList tokensId1, IList<(int, int)> answerPositions) + { + var (tokens, positions, segments, questionSegmentLength) = Build(tokensId0, tokensId1); + + // There are only one answer in training set. + // Filter out samples with the answer in truncated part. + var groundTruths = new HashSet<(int, int)>(); + foreach (var answer in answerPositions) + { + var groundStart = answer.Item1 + questionSegmentLength; + var groundEnd = answer.Item2 + questionSegmentLength; + if (groundEnd < MaxPositions - 1) + { + groundTruths.Add((groundStart, groundEnd)); + } + } + if (groundTruths.Count == 0) + { + groundTruths.Add((0, 0)); + } + var startPositions = groundTruths.Select(p => p.Item1).ToArray(); + var endPositions = groundTruths.Select(p => p.Item2).ToArray(); + + return (tokens, positions, segments, startPositions, endPositions, questionSegmentLength); + } + } +} \ No newline at end of file diff --git a/src/Utils/RobertaTokenizer.cs b/src/Utils/RobertaTokenizer.cs new file mode 100644 index 0000000..120aab7 --- /dev/null +++ b/src/Utils/RobertaTokenizer.cs @@ -0,0 +1,461 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. 
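+// GPT-2 style byte-level BPE tokenizer for RoBERTa. On first use it downloads encoder.json, vocab.bpe
+// and dict.txt from the fairseq public URLs into the given directory; it then normalizes quotes,
+// splits sentences on whitespace, prefixes words with '\u0120', applies the BPE merges, and finally
+// maps the GPT-2 BPE ids onto fairseq dictionary indices.
+// Illustrative usage (a sketch only; the question text below is made up, see the QA examples for the real flow):
+//   var tokenizer = new RobertaTokenizer("vocab_files");       // downloads the vocab files if missing
+//   var ids = tokenizer.TokenizeToId("What is TorchSharp?");   // sentence -> subtoken ids
+//   var text = tokenizer.Untokenize(ids);                      // ids -> (normalized) text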
+using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Net; +using System.Text.Json; +using System.Text.RegularExpressions; + +namespace Examples.Utils +{ + public class RobertaTokenizer + { + private const string EncoderJsonName = "encoder.json"; + private const string MergeName = "vocab.bpe"; + private const string DictName = "dict.txt"; + + private static readonly Uri EncoderJsonUrl = new("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json"); + private static readonly Uri MergeUrl = new("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe"); + private static readonly Uri DictUrl = new("https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt"); + + private readonly string _path; + private readonly RobertaVocabulary Vocabulary; + private readonly IReadOnlyDictionary Encoder; + private readonly IReadOnlyDictionary Decoder; + private readonly (string, string)[] Merges; + private readonly DefaultDictionary<(string, string), int> MergeRanks; + public readonly IReadOnlyDictionary ByteToUnicode; + private readonly IReadOnlyDictionary UnicodeToByte; + //private const string Pattern = @"'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+"; + public readonly char StartChar; + + public int VocabSize => Vocabulary.Count; + public int PadIndex => Vocabulary.PadIndex; + public int UnkIndex => Vocabulary.UnkIndex; + public int BosIndex => Vocabulary.BosIndex; + public int EosIndex => Vocabulary.EosIndex; + public int MaskIndex => Vocabulary.MaskIndex; + + public string PadToken => Vocabulary.PadWord; + public string UnkToken => Vocabulary.UnkWord; + public string BosToken => Vocabulary.BosWord; + public string EosToken => Vocabulary.EosWord; + public string MaskToken => Vocabulary.MaskWord; + + public RobertaTokenizer(string path, char startChar = '\u0120') + { + _path = path; + Directory.CreateDirectory(_path); + + StartChar = startChar; + Vocabulary = GetVocabulary(); + Encoder = GetEncoder(); + Decoder = Encoder.Reverse(); + Merges = GetMerges(); + MergeRanks = GetMergeRanks(); + ByteToUnicode = GetByteToUnicode(); + UnicodeToByte = ByteToUnicode.Reverse(); + } + + #region Public API + /// + /// Tokenize a sequence of words into a sequence of subtokens. + /// + public IList Tokenize(IEnumerable words) + { + return words + .Select(word => Bpe(word)) + .Aggregate((all, tokens) => all.Concat(tokens)) + .ToArray(); + } + + /// + /// Tokenize a sentence into a sequence of subtokens. + /// The sentence will first be split into a sequence of words according to the given Regex pattern + /// before applying the tokenization algorithm. + /// + public IList Tokenize(string sentence, string splitPattern = @"\s+") + { + sentence = Normalize(sentence); + var words = Regex.Split(sentence, splitPattern).Select(word => StartChar + word).ToArray(); + //words[0] = words[0][1..]; // Do not add the StartChar to the first word + return Tokenize(words); + } + + /// + /// Tokenize a sequence of words into a sequence of subtoken IDs. + /// + public IList TokenizeToId(IEnumerable words) + { + return words + .Select(word => TokensToIds(Bpe(word)) as IEnumerable) + .Aggregate((all, tokens) => all.Concat(tokens)) + .ToArray(); + } + + /// + /// Tokenize a sentence into a sequence of subtoken IDs. + /// The sentence will first be split into a sequence of words according to the given Regex pattern + /// before applying the tokenization algorithm. 
+ /// + public IList TokenizeToId(string sentence, string splitPattern = @"\s+") + { + sentence = Normalize(sentence); + var words = Regex.Split(sentence, splitPattern).Select(word => StartChar + word).ToArray(); + //words[0] = words[0][1..]; // Do not add the StartChar to the first word + return TokenizeToId(words); + } + + /// + /// Convert a sequence of subtokens into their IDs. + /// + public IList TokensToIds(IEnumerable tokens) + { + return tokens.Select(token => + { + var id = Encoder[token]; + return id <= 0 ? -id : Vocabulary.IndexOf($"{id}"); + }) + .ToArray(); + } + + /// + /// Untokenize a sequence of subtoken IDs into the corresponding sentence text. + /// + public string Untokenize(IEnumerable tokenIds) + { + return DecodeFromConverted(tokenIds); + } + + /// + /// Untokenize a sequence of subtoken IDs into the corresponding tokens text. + /// + public IList UntokenizeToTokens(IEnumerable tokenIds) + { + return DecodeConvertedToTokens(tokenIds).ToArray(); + } + #endregion + + #region Normalization + /// + /// Modify some symbols to normalize a string text. + /// + private static string Normalize(string input) + { + var processed = input + .Trim() + .Replace("``", "\"") + .Replace("''", "\"") + .Replace("\\\"", "\""); + return processed; + } + + #endregion + + #region Initialization + private RobertaVocabulary GetVocabulary() + { + _ = LoadFromFileOrDownloadFromWeb(_path, DictName, DictUrl); + return new RobertaVocabulary(Path.Join(_path, DictName)); + } + + private Dictionary GetEncoder() + { + var contents = LoadFromFileOrDownloadFromWeb(_path, EncoderJsonName, EncoderJsonUrl); + + // Parse JSON + try + { + var jsonResult = (Dictionary)JsonSerializer.Deserialize( + contents, typeof(Dictionary)); + + jsonResult[Vocabulary.BosWord] = -Vocabulary.BosIndex; + jsonResult[Vocabulary.EosWord] = -Vocabulary.EosIndex; + jsonResult[Vocabulary.UnkWord] = -Vocabulary.UnkIndex; + jsonResult[Vocabulary.PadWord] = -Vocabulary.PadIndex; + + return jsonResult; + } + catch (JsonException e) + { + throw new JsonException($"Problems met when parsing JSON object in {EncoderJsonName}.\n" + + $"Error message: {e.Message}"); + } + } + + private (string, string)[] GetMerges() + { + var contents = LoadFromFileOrDownloadFromWeb(_path, MergeName, MergeUrl); + + // Parse merge info + try + { + var merges = contents.Split('\n')[1..^1].Select(line => + { + var split = line.Split(' '); + if (split[0] == "" || split[1] == "") + { + throw new Exception("Invalid format of merge file: \"{line}\""); + } + return (split[0], split[1]); + }).ToArray(); + return merges; + } + catch (Exception e) + { + throw new Exception($"Problems met when parsing records in {MergeName}.\n" + + $"Error message: {e.Message}"); + } + } + + private DefaultDictionary<(string, string), int> GetMergeRanks() + { + var mergeRanks = new DefaultDictionary<(string, string), int>(() => int.MaxValue); + for (var i = 0; i < Merges.Length; ++i) + { + mergeRanks.Add(Merges[i], i); + } + + return mergeRanks; + } + + /// + /// Returns list of utf-8 bytes and a corresponding list of unicode chars. + /// This mapping is to make unseen characters (such as control characters) displayable. + /// + private static Dictionary GetByteToUnicode() + { + var byteToUnicode = Enumerable.Range('!', '~' - '!' 
+ 1) + .Concat(Enumerable.Range('¡', '¬' - '¡' + 1)) + .Concat(Enumerable.Range('®', 'ÿ' - '®' + 1)) + .ToDictionary(b => (char)b, b => (char)b); + + const int numChars = 256; + var n = 0; + foreach (var b in Enumerable.Range(0, numChars)) + { + if (byteToUnicode.ContainsKey((char)b)) continue; + byteToUnicode.Add((char)b, (char)(numChars + n)); + ++n; + } + + byteToUnicode.Add('Ġ', 'Ġ'); // Space char + + return byteToUnicode; + } + + public static string LoadFromFileOrDownloadFromWeb(string path, string fileName, Uri url) + { + var contents = string.Empty; + var filePath = Path.Join(path, fileName); + if (!File.Exists(filePath)) + { + try + { + using var webClient = new WebClient(); + contents = webClient.DownloadString(url); + } + catch (WebException e) + { + throw new WebException($"File {fileName} not found and cannot be downloaded from {url}.\n" + + $"Error message: {e.Message}"); + } + + try + { + File.WriteAllText(filePath, contents); + Console.WriteLine($"File {fileName} successfully downloaded from {url} and saved to {path}."); + } + catch (Exception e) + { + Console.WriteLine($"{DateTime.Now} - WARNING: File {fileName} successfully downloaded from {url}, " + + $"but error occurs when saving file {fileName} into {path}.\n" + + $"Error message: {e.Message}"); + } + + } + else + { + try + { + contents = File.ReadAllText(filePath); + } + catch (Exception e) + { + throw new IOException($"Problems met when reading {filePath}.\n" + + $"Error message: {e.Message}"); + } + } + + return contents; + } + #endregion + + #region Wrapper for BPE algorithm + /// + /// Decode converted token IDs and return the corresponding string. + /// Origin token ID is the token ID after BPE processing, and defines a mapping + /// between origin token IDs and converted token IDs. + /// + private string DecodeFromConverted(IEnumerable tokenIds) + { + var tokens = DecodeConvertedToTokens(tokenIds); + var text = string.Join(string.Empty, tokens).Replace(StartChar, ' ').Trim(); + return text; + } + + /// + /// Decode converted token IDs and return the corresponding tokens. + /// Origin token ID is the token ID after BPE processing, and defines a mapping + /// between origin token IDs and converted token IDs. + /// + private IEnumerable DecodeConvertedToTokens(IEnumerable tokenIds) + { + // 1. not to decode padding tokens + // 2. special tokens (BOS, EOS, PAD, UNK) in vocabulary are presented as strings rather than integers, + // so they will cause parsing failure. We treat their IDs as negative integers to avoid conflict with + // normal tokens. Other unrecognized tokens will be treated as token#13, which is ".". + var tokenArray = tokenIds + .Where(token => token != Vocabulary.PadIndex) + .Select(token => int.TryParse(Vocabulary[token], out var result) ? result : + token < RobertaVocabulary.NumSpecialSymbols ? -token : + 13) + .ToArray(); + var tokens = tokenArray.Select(id => + new string(Decoder[id] + .Where(c => UnicodeToByte.ContainsKey(c)) + .Select(c => UnicodeToByte[c]).ToArray())); + return tokens; + } + #endregion + + #region BPE algorithm + /// + /// Encode text with several tokens into BPE-ed sub-tokens. + /// + private IEnumerable Bpe(string word) + { + var convertedToken = string.Join("", word + .Where(ByteToUnicode.ContainsKey) + .Select(b => ByteToUnicode[b])); + if (convertedToken.Length == 0) return Array.Empty(); + return BpeToken(convertedToken); + } + + /// + /// Encode a token into BPE-ed sub-tokens. E.g., "playing" into ["play", "ing"]. 
+ /// + private IEnumerable BpeToken(string token) + { + var word = token.Select(c => c.ToString()).ToList(); + var pairs = WordToPairs(word); + + if (pairs.Count == 0) + { + return new List { token }; + } + + while (true) + { + /* while conditions */ + // if only one element left, merge is finished (with the whole word merged) + if (word.Count == 1) + { + break; + } + + // get the most frequent bi-gram pair + var (first, second) = pairs.ArgMin(pair => MergeRanks[pair]); + if (!MergeRanks.ContainsKey((first, second))) + { + break; + } + /* end while conditions */ + + // search and merge all (first, second) pairs in {word} + var newWord = new List(); + var i = 0; + while (i < word.Count) + { + // find the next occurrence of {first} and add the elements before into {newWord} + var j = word.IndexOf(first, i); + if (j == -1) + { + newWord.AddRange(word.Skip(i)); + break; + } + else + { + newWord.AddRange(word.Skip(i).Take(j - i)); + i = j; + } + + // check the next element is {second} or not + if (i < word.Count - 1 && word[i + 1] == second) + { + newWord.Add(first + second); + i += 2; + } + else + { + newWord.Add(word[i]); + i += 1; + } + } + + word = newWord; + + // otherwise, continue merging + pairs = WordToPairs(word); + } + + return word; + } + + /// + /// Extract element pairs in an aggregating word. E.g. [p, l, ay] into [(p,l), (l,ay)]. + /// If word contains 0 or 1 element, an empty HashSet will be returned. + /// + private static HashSet<(string, string)> WordToPairs(IReadOnlyList word) + { + var pairs = new HashSet<(string, string)>(); + if (word.Count <= 1) return pairs; + + var prevElem = word[0]; + foreach (var elem in word.Skip(1)) + { + pairs.Add((prevElem, elem)); + prevElem = elem; + } + + return pairs; + } + #endregion + } + + internal static class IEnumerableExtension + { + public static T ArgMin(this IEnumerable source, Func getValue) + { + var keys = source.ToList(); // avoid enumerate twice + var values = keys.Select(getValue); + var (minSource, minValue) = keys.Zip(values).Aggregate((min, x) => min.Second <= x.Second ? min : x); + return minValue < int.MaxValue ? minSource : default; + } + } + + internal static class IReadOnlyDictionaryExtension + { + public static IReadOnlyDictionary Reverse(this IReadOnlyDictionary source) + { + var dictionary = new Dictionary(); + foreach (var (key, value) in source) + { + dictionary[value] = key; + } + return dictionary; + } + } + +} \ No newline at end of file diff --git a/src/Utils/RobertaVocabulary.cs b/src/Utils/RobertaVocabulary.cs new file mode 100644 index 0000000..7abd4b4 --- /dev/null +++ b/src/Utils/RobertaVocabulary.cs @@ -0,0 +1,168 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections.Generic; +using System.IO; +using System.Text; + +namespace Examples.Utils +{ + /// + /// A mapping from symbols to consecutive integers. 
+ /// + public class RobertaVocabulary + { + public const int NumSpecialSymbols = 4; + + public string PadWord { get; } + public string EosWord { get; } + public string UnkWord { get; } + public string BosWord { get; } + + public int PadIndex { get; } + public int EosIndex { get; } + public int UnkIndex { get; } + public int BosIndex { get; } + + public string MaskWord { get; private set; } + public int MaskIndex { get; private set; } + + private readonly List _symbols; + private readonly Dictionary _indices; + private readonly List _counter; + + public IReadOnlyDictionary Indices => _indices; + + /// + /// Loads the vocabulary from a text file with the format: + /// + /// + /// ... + /// + /// Any of `pad`, `eos`, `unk` and `bos` is `null`. + public RobertaVocabulary(string fileName, string pad = "", string eos = "", string unk = "", string bos = "", + string[] extraSpecialSymbols = null) + { + _indices = new Dictionary(); + _counter = new List(); + _symbols = new List(); + + PadWord = pad; + EosWord = eos; + UnkWord = unk; + BosWord = bos; + BosIndex = AddSymbol(bos); + PadIndex = AddSymbol(pad); + EosIndex = AddSymbol(eos); + UnkIndex = AddSymbol(unk); + + if (extraSpecialSymbols != null) + { + foreach (var symbol in extraSpecialSymbols) + { + AddSymbol(symbol); + } + } + + AddFromFile(fileName); + } + + /// + /// Add a word to the vocabulary. + /// + /// `word` is `null`. + private int AddSymbol(string word, int n = 1) + { + if (word == null) + { + throw new ArgumentNullException(nameof(word), $"argument {nameof(word)} should not be null."); + } + + int idx; + if (_indices.ContainsKey(word)) + { + idx = _indices[word]; + _counter[idx] += n; + } + else + { + idx = _symbols.Count; + _indices[word] = idx; + _symbols.Add(word); + _counter.Add(n); + } + + return idx; + } + + public int AddMaskSymbol(string mask = "") + { + MaskWord = mask; + MaskIndex = AddSymbol(mask); + return MaskIndex; + } + + /// `idx` is negative. + public string this[int idx] + { + get + { + if (idx < 0) + { + throw new ArgumentOutOfRangeException(nameof(idx), $"Index should be non-negative, got {idx}."); + } + + return idx < _symbols.Count ? _symbols[idx] : UnkWord; + } + } + + public int Count => _symbols.Count; + + public bool Contains(string symbol) => symbol != null && _indices.ContainsKey(symbol); + + /// `symbol` is `null`. + public int IndexOf(string symbol) => _indices.ContainsKey(symbol) ? _indices[symbol] : UnkIndex; + + /// + /// Loads a pre-existing vocabulary from a text file and adds its symbols to this instance. + /// + private void AddFromFile(string fileName) + { + var lines = File.ReadAllLines(fileName, Encoding.UTF8); + + foreach (var line in lines) + { + var splitLine = line.Trim().Split(' '); + if (splitLine.Length != 2) + { + throw new FileFormatException("Incorrect vocabulary format, expected \" \""); + } + + var word = splitLine[0]; + if (int.TryParse(splitLine[1], out var count)) + { + AddSymbol(word, count); + } + else + { + throw new FileFormatException($"Cannot parse {splitLine[1]} as an integer. 
File line: \"{line}\"."); + } + } + } + } + + public class FileFormatException : Exception + { + public FileFormatException() + { + } + + public FileFormatException(string message) : base(message) + { + } + + public FileFormatException(string message, Exception innerException) : base(message, innerException) + { + } + } + +} \ No newline at end of file diff --git a/src/Utils/SquadDatasets.cs b/src/Utils/SquadDatasets.cs new file mode 100644 index 0000000..3ae2219 --- /dev/null +++ b/src/Utils/SquadDatasets.cs @@ -0,0 +1,527 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections.Generic; +using System.IO; +using System.Linq; +using System.Text; +using System.Text.Json; +using TorchSharp; + +namespace Examples.Utils +{ + /// + /// Store the extracted samples of the SQuAD v2.0 data set. + /// The training and development data can be downloaded from the page at: https://rajpurkar.github.io/SQuAD-explorer/ + /// + public class SquadDataset + { + private static readonly Logger _logger = new(); + + private readonly int[] ZeroPad; + private readonly int[] NegBillionPad; + private readonly int[] TokenPad; + + private readonly RobertaTokenizer robertaTokenizer; + private readonly RobertaInputBuilder RobertaInputBuilder; + public string FilePath { get; } + public List Samples { get; } + + public int Count => Samples.Count; + + public SquadDataset(string filePath, RobertaTokenizer robertaTokenizer, RobertaInputBuilder inputBuilder) + { + const int negBillion = (int)-1e9; + ZeroPad = Enumerable.Repeat(0, inputBuilder.MaxPositions).ToArray(); + NegBillionPad = Enumerable.Repeat(negBillion, inputBuilder.MaxPositions).ToArray(); + TokenPad = Enumerable.Repeat(robertaTokenizer.PadIndex, inputBuilder.MaxPositions).ToArray(); + + FilePath = filePath; + this.robertaTokenizer = robertaTokenizer; + RobertaInputBuilder = inputBuilder; + Samples = LoadDataset(filePath); + } + + private List LoadDataset(string filePath) + { + var samples = new List(); + var dataset = JsonSerializer.Deserialize(File.ReadAllText(filePath)); + var start = DateTime.Now; + var count = 1; + foreach (var document in dataset.data) + { + var end = DateTime.Now; + _logger.LogLoop($"Processing document {count}/{dataset.data.Count}, time consumed {end - start}"); + + foreach (var paragraph in document.paragraphs) + { + var contextTokens = robertaTokenizer.Tokenize(paragraph.context); + var contextTokenIds = robertaTokenizer.TokensToIds(contextTokens); + var mapping = AlignAnswerPosition(contextTokens, paragraph.context); + if (mapping == null) + { + continue; + } + + foreach (var qapair in paragraph.qas) + { + var questionTokenIds = robertaTokenizer.TokenizeToId(qapair.question); + + // Treat plausible answers as normal answers, which is not a normal behavior. + //var answerList = qapair.is_impossible ? qapair.plausible_answers : qapair.answers; + //var answerPositions = new List<(int, int)>(); + //foreach (var answer in answerList) + //{ + // // inclusive [start, end] + // var startPosition = mapping[answer.answer_start]; + // var endPosition = mapping[answer.answer_start + answer.text.Length - 1]; + // answerPositions.Add((startPosition, endPosition)); + //} + + // Treat plausible answers as "no answers". 
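+                        // For SQuAD v2.0 "impossible" questions the answer list therefore stays empty,
+                        // and RobertaInputBuilder.Build later maps the missing answer to the (0, 0) span.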
+ var answerPositions = new List<(int, int)>(); + if (!qapair.is_impossible) + { + foreach (var answer in qapair.answers) + { + // inclusive [start, end] + var startPosition = mapping[answer.answer_start]; + var endPosition = mapping[answer.answer_start + answer.text.Length - 1]; + answerPositions.Add((startPosition, endPosition)); + } + } + + samples.Add(new SquadSample( + title: document.title, + url: document.url, + context: paragraph.context, + contextTokens: contextTokenIds, + id: qapair.id, + question: qapair.question, + questionTokens: questionTokenIds, + answers: answerPositions)); + } + } + + count++; + } + _logger.LogAppend(); + + return samples; + } + + private Dictionary AlignAnswerPosition(IList tokens, string text) + { + var mapping = new Dictionary(); + int surrogateDeduce = 0; + for (var (i, j, tid) = (0, 0, 0); i < text.Length && tid < tokens.Count;) + { + // Move to a new token + if (j >= tokens[tid].Length) + { + ++tid; + j = 0; + } + // There are a few UTF-32 chars in corpus, which is considered one char in position + else if (i + 1 < text.Length && char.IsSurrogatePair(text[i], text[i + 1])) + { + i += 2; + ++surrogateDeduce; + } + // White spaces are not included in tokens + else if (char.IsWhiteSpace(text[i])) + { + ++i; + } + // Chars not included in tokenizer will not appear in tokens + else if (!robertaTokenizer.ByteToUnicode.ContainsKey(text[i])) + { + mapping[i - surrogateDeduce] = tid; + ++i; + } + // "\\\"", "``" and "''" converted to "\"" in normalizer + else if (i + 1 < text.Length && tokens[tid][j] == '"' + && ((text[i] == '`' && text[i + 1] == '`') + || (text[i] == '\'' && text[i + 1] == '\'') + || (text[i] == '\\' && text[i + 1] == '"'))) + { + mapping[i - surrogateDeduce] = mapping[i + 1 - surrogateDeduce] = tid; + i += 2; + j += 1; + } + // Normal match + else if (text[i] == tokens[tid][j]) + { + mapping[i - surrogateDeduce] = tid; + ++i; + ++j; + } + // There are a few real \u0120 chars in the corpus, so this rule has to be later than text[i] == tokens[tid][j]. + else if (tokens[tid][j] == '\u0120' && j == 0) + { + ++j; + } + else + { + throw new Exception("unmatched!"); + } + } + + return mapping; + } + + public IEnumerable GetBatches(QuestionAnsweringConfig config, bool shuffle) + { + var buffer = new List<(int[], int[], int[], int[], int[], int)>(); + var samples = Samples; + if (shuffle) + { + var random = new Random(); + samples = samples.OrderBy(_ => random.Next()).ToList(); + } + foreach (var sample in samples) + { + (int[] tokens, int[] positions, int[] segments, int[] startPositions, int[] endPositions, int questionSegmentLength) ioArrays + = RobertaInputBuilder.Build(sample.QuestionTokens, sample.ContextTokens, sample.Answers); + + buffer.Add(ioArrays); + if (buffer.Count == config.BatchSize) + { + yield return BufferToBatch(buffer, config.Cuda); + buffer.Clear(); + } + } + if (buffer.Count > 0) + { + yield return BufferToBatch(buffer, config.Cuda); + buffer.Clear(); + } + } + + public static SquadSampleBatch GetMaxDummyBatch(QuestionAnsweringConfig config) + { + var batchSize = config.BatchSize; + var maxLength = config.MaxSequence; + var maxAnswer = 1; + var device = config.Cuda ? 
torch.CUDA : torch.CPU; + + return new SquadSampleBatch + { + Tokens = torch.zeros(batchSize, maxLength, dtype: torch.int64, device: device), + Positions = torch.zeros(batchSize, maxLength, dtype: torch.int64, device: device), + Segments = torch.zeros(batchSize, maxLength, dtype: torch.int64, device: device), + Starts = torch.zeros(batchSize, maxAnswer, dtype: torch.int64, device: device), + Ends = torch.zeros(batchSize, maxAnswer, dtype: torch.int64, device: device), + AttentionMasks = torch.zeros(batchSize, 1, 1, maxLength, dtype: torch.float32, device: device), + PredictMasks = null, + }; + } + + private SquadSampleBatch BufferToBatch(List<(int[], int[], int[], int[], int[], int)> buffer, bool cuda) + { + using var disposeScope = torch.NewDisposeScope(); + var device = cuda ? torch.CUDA : torch.CPU; + var maxLength = buffer.Max(tensors => tensors.Item1.Length); + var maxAnswer = buffer.Max(tensors => tensors.Item4.Length); + var tokens = new torch.Tensor[buffer.Count]; + var positions = new torch.Tensor[buffer.Count]; + var segments = new torch.Tensor[buffer.Count]; + var starts = new torch.Tensor[buffer.Count]; + var ends = new torch.Tensor[buffer.Count]; + var attentionMasks = new torch.Tensor[buffer.Count]; + var predictMasks = new torch.Tensor[buffer.Count]; + for (var i = 0; i < buffer.Count; ++i) + { + var arrays = buffer[i]; + var length = arrays.Item1.Length; + var questionSegmentLength = arrays.Item6; + var token = torch.tensor(arrays.Item1.Concat(TokenPad[..(maxLength - length)]).ToArray(), + 1, maxLength, dtype: torch.int64, device: device); + var position = torch.tensor(arrays.Item2.Concat(ZeroPad[..(maxLength - length)]).ToArray(), + 1, maxLength, dtype: torch.int64, device: device); + var segment = torch.tensor(arrays.Item3.Concat(ZeroPad[..(maxLength - length)]).ToArray(), + 1, maxLength, dtype: torch.int64, device: device); + var attentionMask = torch.tensor(ZeroPad[..length].Concat(NegBillionPad[..(maxLength - length)]).ToArray(), + new long[] { 1, 1, 1, maxLength }, dtype: torch.float32, device: device); + var predictMask = torch.tensor( + NegBillionPad[..questionSegmentLength].Concat(ZeroPad[..(length - questionSegmentLength - 1)]) + .Concat(NegBillionPad[..(maxLength - length + 1)]).ToArray(), + 1, maxLength, dtype: torch.float32, device: device); + + var answer = arrays.Item4.Length; + var start = torch.tensor(arrays.Item4.Concat(Enumerable.Repeat(arrays.Item4[^1], maxAnswer - answer)).ToArray(), + 1, maxAnswer, dtype: torch.int64, device: device); + var end = torch.tensor(arrays.Item5.Concat(Enumerable.Repeat(arrays.Item5[^1], maxAnswer - answer)).ToArray(), + 1, maxAnswer, dtype: torch.int64, device: device); + + tokens[i] = token; + positions[i] = position; + segments[i] = segment; + starts[i] = start; + ends[i] = end; + attentionMasks[i] = attentionMask; + predictMasks[i] = predictMask; + } + + return new SquadSampleBatch + { + Tokens = torch.cat(tokens, dimension: 0).MoveToOuterDisposeScope(), + Positions = torch.cat(positions, dimension: 0).MoveToOuterDisposeScope(), + Segments = torch.cat(segments, dimension: 0).MoveToOuterDisposeScope(), + Starts = torch.cat(starts, dimension: 0).MoveToOuterDisposeScope(), + Ends = torch.cat(ends, dimension: 0).MoveToOuterDisposeScope(), + AttentionMasks = torch.cat(attentionMasks, dimension: 0).MoveToOuterDisposeScope(), + PredictMasks = torch.cat(predictMasks, dimension: 0).MoveToOuterDisposeScope(), + }; + } + } + + /// + /// Store the extracted document corpus from a data set file with the same format as SQuAD v2.0. 
+ /// The SQuAD v2.0 data can be downloaded from the page at: https://rajpurkar.github.io/SQuAD-explorer/ + /// + public class SquadCorpus + { + private static readonly Logger _logger = new(); + + private readonly int[] ZeroPad; + private readonly int[] NegBillionPad; + private readonly int[] TokenPad; + + private readonly RobertaTokenizer robertaTokenizer; + private readonly RobertaInputBuilder RobertaInputBuilder; + public readonly string FilePath; + public readonly IReadOnlyList Documents; + + public int Count => Documents.Count; + + public SquadCorpus(string filePath, RobertaTokenizer robertaTokenizer, RobertaInputBuilder inputBuilder) + { + int negBillion = (int)-1e9; + ZeroPad = Enumerable.Repeat(0, inputBuilder.MaxPositions).ToArray(); + NegBillionPad = Enumerable.Repeat(negBillion, inputBuilder.MaxPositions).ToArray(); + TokenPad = Enumerable.Repeat(robertaTokenizer.PadIndex, inputBuilder.MaxPositions).ToArray(); + + FilePath = filePath; + this.robertaTokenizer = robertaTokenizer; + RobertaInputBuilder = inputBuilder; + Documents = LoadDocuments(filePath); + } + + private List LoadDocuments(string filePath) + { + var documents = new List(); + var dataset = JsonSerializer.Deserialize(File.ReadAllText(filePath)); + var start = DateTime.Now; + var count = 1; + foreach (var document in dataset.data) + { + var end = DateTime.Now; + _logger.LogLoop($"Processing document {count}/{dataset.data.Count}, time consumed {end - start}"); + + foreach (var paragraph in document.paragraphs) + { + var contextTokenIds = robertaTokenizer.TokenizeToId(paragraph.context, @"\s+"); + + documents.Add(new SquadDocument( + title: document.title, + url: document.url, + context: paragraph.context, + contextTokens: contextTokenIds)); + } + + count++; + } + _logger.LogAppend(); + + return documents; + } + + public static SquadSampleBatch GetMaxDummyBatch(QuestionAnsweringConfig config) + { + using var disposeScope = torch.NewDisposeScope(); + var batchSize = config.BatchSize; + var maxLength = config.MaxSequence; + var device = config.Cuda ? 
torch.CUDA : torch.CPU; + + var token = torch.zeros(batchSize, maxLength, dtype: torch.int64, device: device); + var position = torch.zeros(batchSize, maxLength, dtype: torch.int64, device: device); + var segment = torch.zeros(batchSize, maxLength, dtype: torch.int64, device: device); + var attentionMask = torch.zeros(batchSize, 1, 1, maxLength, dtype: torch.float32, device: device); + + return new SquadSampleBatch + { + Tokens = token.MoveToOuterDisposeScope(), + Positions = position.MoveToOuterDisposeScope(), + Segments = segment.MoveToOuterDisposeScope(), + Starts = null, + Ends = null, + AttentionMasks = attentionMask.MoveToOuterDisposeScope(), + PredictMasks = null + }; + } + + public IEnumerable GetBatches(QuestionAnsweringConfig config, string questionText) + { + // parse question + var questionTokenIds = robertaTokenizer.TokenizeToId(questionText); + + foreach (var batch in GetBatches(config, questionTokenIds, Documents)) + { + yield return batch; + } + } + + public IEnumerable GetBatches(QuestionAnsweringConfig config, IList questionTokens, IReadOnlyList documents) + { + var buffer = new List<(int[], int[], int[], int)>(); + foreach (var document in documents) + { + var ioArrays = RobertaInputBuilder.Build(questionTokens, document.ContextTokens); + buffer.Add(ioArrays); + if (buffer.Count == config.BatchSize) + { + yield return BufferToBatch(buffer, config.Cuda); + buffer.Clear(); + } + } + if (buffer.Count > 0) + { + yield return BufferToBatch(buffer, config.Cuda); + buffer.Clear(); + } + } + + private SquadSampleBatch BufferToBatch(List<(int[], int[], int[], int)> buffer, bool cuda) + { + using var disposeScope = torch.NewDisposeScope(); + var device = cuda ? torch.CUDA : torch.CPU; + var maxLength = buffer.Max(tensors => tensors.Item1.Length); + var tokens = new torch.Tensor[buffer.Count]; + var positions = new torch.Tensor[buffer.Count]; + var segments = new torch.Tensor[buffer.Count]; + var attentionMasks = new torch.Tensor[buffer.Count]; + var predictMasks = new torch.Tensor[buffer.Count]; + for (var i = 0; i < buffer.Count; ++i) + { + var arrays = buffer[i]; + var length = arrays.Item1.Length; + var questionSegmentLength = arrays.Item4; + var token = torch.tensor(arrays.Item1.Concat(TokenPad[..(maxLength - length)]).ToArray(), + 1, maxLength, dtype: torch.int64, device: device); + var position = torch.tensor(arrays.Item2.Concat(ZeroPad[..(maxLength - length)]).ToArray(), + 1, maxLength, dtype: torch.int64, device: device); + var segment = torch.tensor(arrays.Item3.Concat(ZeroPad[..(maxLength - length)]).ToArray(), + 1, maxLength, dtype: torch.int64, device: device); + var attentionMask = torch.tensor(ZeroPad[..length].Concat(NegBillionPad[..(maxLength - length)]).ToArray(), + new long[] { 1, 1, 1, maxLength }, dtype: torch.float32, device: device); + var predictMask = torch.tensor( + NegBillionPad[..questionSegmentLength].Concat(ZeroPad[..(length - questionSegmentLength - 1)]) + .Concat(NegBillionPad[..(maxLength - length + 1)]).ToArray(), + 1, maxLength, dtype: torch.float32, device: device); + + tokens[i] = token; + positions[i] = position; + segments[i] = segment; + attentionMasks[i] = attentionMask; + predictMasks[i] = predictMask; + } + + return new SquadSampleBatch + { + Tokens = torch.cat(tokens, dimension: 0).MoveToOuterDisposeScope(), + Positions = torch.cat(positions, dimension: 0).MoveToOuterDisposeScope(), + Segments = torch.cat(segments, dimension: 0).MoveToOuterDisposeScope(), + Starts = null, + Ends = null, + AttentionMasks = torch.cat(attentionMasks, 
dimension: 0).MoveToOuterDisposeScope(),
+                PredictMasks = torch.cat(predictMasks, dimension: 0).MoveToOuterDisposeScope(),
+            };
+        }
+    }
+
+    public class SquadDocument
+    {
+        public string Title { get; }
+        public string Url { get; }
+        public string Context { get; }
+        public IList<int> ContextTokens { get; }
+
+        public SquadDocument(string title, string url, string context, IList<int> contextTokens)
+        {
+            Title = title;
+            Url = url;
+            Context = context;
+            ContextTokens = contextTokens;
+        }
+    }
+
+    public class SquadSample
+    {
+        public string Title { get; }
+        public string Url { get; }
+        public string Context { get; }
+        public IList<int> ContextTokens { get; }
+        public string Id { get; }
+        public string Question { get; }
+        public IList<int> QuestionTokens { get; }
+        public IList<(int, int)> Answers { get; }
+
+        public SquadSample(string title, string url, string context, IList<int> contextTokens, string id,
+            string question, IList<int> questionTokens, IList<(int, int)> answers)
+        {
+            Title = title;
+            Url = url;
+            Context = context;
+            ContextTokens = contextTokens;
+            Id = id;
+            Question = question;
+            QuestionTokens = questionTokens;
+            Answers = answers;
+        }
+    }
+
+    public struct SquadSampleBatch
+    {
+        public torch.Tensor Tokens;
+        public torch.Tensor Positions;
+        public torch.Tensor Segments;
+        public torch.Tensor Starts;
+        public torch.Tensor Ends;
+        public torch.Tensor AttentionMasks;
+        public torch.Tensor PredictMasks;
+    }
+
+    internal struct SquadDataFromFile
+    {
+        public string version { get; set; }
+        public IList<SquadData> data { get; set; }
+    }
+
+    internal struct SquadData
+    {
+        public string title { get; set; }
+        public string url { get; set; }
+        public IList<SquadParagraph> paragraphs { get; set; }
+    }
+
+    internal struct SquadParagraph
+    {
+        public string context { get; set; }
+        public IList<SquadQAPair> qas { get; set; }
+    }
+
+    internal struct SquadQAPair
+    {
+        public string question { get; set; }
+        public string id { get; set; }
+        public IList<SquadAnswer> answers { get; set; }
+        public IList<SquadAnswer> plausible_answers { get; set; }
+        public bool is_impossible { get; set; }
+    }
+
+    internal struct SquadAnswer
+    {
+        public string text { get; set; }
+        public int answer_start { get; set; }
+    }
+}
diff --git a/src/Utils/SquadMetric.cs b/src/Utils/SquadMetric.cs
new file mode 100644
index 0000000..ec105e6
--- /dev/null
+++ b/src/Utils/SquadMetric.cs
@@ -0,0 +1,87 @@
+// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
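+// Worked example of the span-level F1 defined below: a predicted span [3, 7] against a gold span [5, 9]
+// overlaps on 3 token positions (5..7), so precision = recall = 3/5 and F1 = 0.6. A gold span of (0, 0)
+// denotes "no answer", which scores 1.0 only when the prediction is also (0, 0).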
+using System; +using System.Collections.Generic; +using System.Linq; +using TorchSharp; + +namespace Examples.Utils +{ + public static class SquadMetric + { + public static double CosineSimilarity(IList vector1, IList vector2) + { + if (vector1.Count != vector2.Count) + { + throw new ArgumentException( + $"Vectors to compute cosine similarity must have the same length, got {vector1.Count} and {vector2.Count}"); + } + + var xMax = vector1.Max(); + var yMax = vector2.Max(); + var xEnum = vector1.Select(x => x / xMax); + var yEnum = vector2.Select(y => y / yMax); + var nominator = Enumerable.Zip(xEnum, yEnum).Sum(pair => pair.First * pair.Second); + var denominator = Math.Sqrt(xEnum.Sum(x => x * x) * yEnum.Sum(y => y * y)); + return nominator / denominator; + } + + public static double ComputeF1(long predictionStart, long predictionEnd, long truthStart, long truthEnd) + { + // No answer + if ((truthStart | truthEnd) == 0) + { + if ((predictionStart | predictionEnd) == 0) return 1.0; + else return 0.0; + } + // Has answer: prediction invalid + if (predictionStart > predictionEnd) + { + return 0.0; + } + // Has answer: prediction valid + var prediction = predictionEnd - predictionStart + 1; + var truth = truthEnd - truthStart + 1; + var overlap = ComputeOverlap(predictionStart, predictionEnd, truthStart, truthEnd); + if (overlap == 0) + { + return 0.0; + } + var precision = 1.0 * overlap / prediction; + var recall = 1.0 * overlap / truth; + var f1 = 2 * precision * recall / (precision + recall); + return f1; + } + + private static long ComputeOverlap(long predictionStart, long predictionEnd, long truthStart, long truthEnd) + { + var arr = new long[] { predictionStart, predictionEnd, truthStart, truthEnd }; + var max = arr.Max(); + var min = arr.Min(); + //var overlap = (truthEnd - truthStart + 1) + (predictionEnd - predictionStart + 1) - (max - min + 1); + var overlap = (truthEnd - truthStart) + (predictionEnd - predictionStart) - (max - min) + 1; + return Math.Max(overlap, 0); + } + + public static IList<(int start, int end, double score)> ComputeTopKSpansWithScore( + torch.Tensor predictStartScores, torch.Tensor predictStarts, torch.Tensor predictEndScores, torch.Tensor predictEnds, int k) + { + var startScores = predictStartScores.ToArray(); + var endScores = predictEndScores.ToArray(); + var starts = predictStarts.ToArray(); + var ends = predictEnds.ToArray(); + var topK = new List<(int, int, double)>(); + for (var i = 0; i < starts.Length; ++i) + { + for (var j = 0; j < ends.Length; ++j) + { + if (starts[i] <= ends[j]) + { + topK.Add(((int)starts[i], (int)ends[j], startScores[i] * endScores[j])); + } + } + } + topK = topK.OrderByDescending(tuple => tuple.Item3).Take(k).ToList(); + return topK; + } + } +} \ No newline at end of file diff --git a/src/Utils/TensorExtension.cs b/src/Utils/TensorExtension.cs new file mode 100644 index 0000000..8b04691 --- /dev/null +++ b/src/Utils/TensorExtension.cs @@ -0,0 +1,31 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System.Diagnostics.CodeAnalysis; +using TorchSharp; + +namespace Examples.Utils +{ + public static class TensorExtension + { +#nullable enable + public static bool IsNull([NotNullWhen(false)] this torch.Tensor? tensor) + { + return tensor is null || tensor.IsInvalid; + } + + public static bool IsNotNull([NotNullWhen(true)] this torch.Tensor? 
tensor) + { + return !tensor.IsNull(); + } +#nullable disable + + public static T[] ToArray(this torch.Tensor tensor) where T : unmanaged + { + return tensor.cpu().data().ToArray(); + } + + public static T ToItem(this torch.Tensor tensor) where T : unmanaged + { + return tensor.cpu().item(); + } + } +} \ No newline at end of file diff --git a/src/Utils/TfIdfDocumentSelector.cs b/src/Utils/TfIdfDocumentSelector.cs new file mode 100644 index 0000000..3c48bc7 --- /dev/null +++ b/src/Utils/TfIdfDocumentSelector.cs @@ -0,0 +1,197 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections.Generic; +using System.Linq; + +namespace Examples.Utils +{ + /// + /// A class for selecting the best matched documents given a query using a variant of TF-IDF algorithm. + /// Visit https://en.wikipedia.org/wiki/Tf%E2%80%93idf for more information about the algorithm. + /// + public class TfIdfDocumentSelector + { + /// + /// Smoothing factor for max-normalization of tf. + /// + private const double MaxNormSmoothing = 0.01; + + /// + /// Weight of title. + /// + private const double TitleWeight = 0.1; + + /// + /// Consider some tokens as stop words and do not count them in cosine similarity: + /// the first N tokens of the highest frequency in the vocab. + /// + private const int StopWordFilterIndex = 500; + + /// + /// Consider some tokens as stop words and do not count them in cosine similarity: + /// the tokens with document frequency >= RATIO * total # documents. + /// + private const double StopWordFilterRatio = 0.25; + + /// + /// SquadDocument frequency of each token. + /// + private readonly int[] DocumentFrequency; + + private readonly SquadDocument[] Documents; + private readonly int DocumentCount; + private readonly int VocabularySize; + + public IReadOnlyDictionary DocumentVectors { get; } + + private readonly RobertaTokenizer Tokenizer; + private readonly HashSet StopWordsTokenIds; + + public TfIdfDocumentSelector(IEnumerable documents, RobertaTokenizer tokenizer, bool useTitle = true) + { + Documents = documents.ToArray(); + Tokenizer = tokenizer; + VocabularySize = tokenizer.VocabSize; + + DocumentCount = Documents.Length; + DocumentFrequency = new int[VocabularySize]; + foreach (var doc in Documents) + { + var contextTokens = Tokenizer.TokenizeToId(doc.Context.ToLower()); + foreach (var token in contextTokens.ToHashSet()) + { + ++DocumentFrequency[token]; + } + } + + var stopWordsTokenIds = Enumerable.Range(0, StopWordFilterIndex); // high frequency + for (var i = 0; i < VocabularySize; ++i) + { + if (DocumentFrequency[i] >= StopWordFilterRatio * Documents.Length) + { + stopWordsTokenIds = stopWordsTokenIds.Append(i); + } + } + StopWordsTokenIds = stopWordsTokenIds.ToHashSet(); + + if (useTitle) + { + var documentVectors = new Dictionary(); + foreach (var doc in Documents) + { + var titleTokens = Tokenizer.TokenizeToId(doc.Title.ToLower()); + var titleVector = TfIdf(titleTokens); + var contextTokens = Tokenizer.TokenizeToId(doc.Context.ToLower()); + var contextVector = TfIdf(contextTokens, maxNormTf: true); + var combined = Enumerable.Zip(titleVector, contextVector) + .Select(pair => TitleWeight * pair.First + (1 - TitleWeight) * pair.Second) + .ToArray(); + documentVectors[doc] = combined; + } + DocumentVectors = documentVectors; + } + else + { + DocumentVectors = Documents.ToDictionary( + doc => doc, + doc => TfIdf(Tokenizer.TokenizeToId(doc.Context.ToLower()), maxNormTf: 
true)); + } + } + + public SquadDocument Top1(string question) + { + return TopK(question, 1)[0]; + } + + public IList TopK(string question, int k) + { + var questionTokens = Tokenizer.TokenizeToId(question.ToLower()); + var weight = TfIdf(questionTokens); + var matchedDocuments = DocumentVectors + .Select(pair => (Doc: pair.Key, Sim: CosineSimilarityOfQuestion(weight, pair.Value))) + .OrderByDescending(pair => pair.Sim) + .Take(k) + .Select(pair => pair.Doc) + .ToArray(); + return matchedDocuments; + ; + } + + private double Tf(double termFrequency, bool sublinear) + { + return sublinear ? Math.Log(termFrequency) + 1 : termFrequency; + } + + private double Idf(int token, bool smooth) + { + var ifreq = smooth + ? 1.0 * (1 + DocumentCount) / (1 + DocumentFrequency[token]) + : 1.0 * DocumentCount / DocumentFrequency[token]; + return 1 + Math.Log(ifreq); + } + + /// + /// TF-IDF value of one token given the term frequency. + /// + private double TfIdf(double termFrequency, int token, bool sublinearTf, bool smoothIdf, bool noIdf) + { + return noIdf + ? Tf(termFrequency, sublinearTf) + : Tf(termFrequency, sublinearTf) * Idf(token, smoothIdf); + } + + /// + /// TF-IDF value of all tokens in a document. + /// + /// The tf-idf vector over the whole vocabulary. + private double[] TfIdf(IList documentTokens, bool sublinearTf = false, bool smoothIdf = true, bool noIdf = false, + bool maxNormTf = false) + { + // Remove + documentTokens = documentTokens.SkipWhile(token => token < 4).ToArray(); + + // Compute term frequency + var frequency = new int[VocabularySize]; + foreach (var token in documentTokens) + { + ++frequency[token]; + } + + var weight = new double[VocabularySize]; + var maxFrequency = frequency.Max(); + foreach (var token in documentTokens.ToHashSet()) + { + var freq = maxNormTf ? MaxNormSmoothing + (1 - MaxNormSmoothing) * frequency[token] / maxFrequency : frequency[token]; + weight[token] = TfIdf(freq, token, sublinearTf, smoothIdf, noIdf); + } + return weight; + } + + private double CosineSimilarityOfQuestion(IList question, IList document) + { + if (question.Count != document.Count) + { + throw new ArgumentException( + $"Vectors to compute cosine similarity must have the same length, got {question.Count} and {document.Count}"); + } + + var qMax = question.Max(); + var dMax = document.Max(); + var nominator = 0.0; + var qSum = 0.0; + var dSum = 0.0; + for (var i = 0; i < question.Count; ++i) + { + var q = question[i] / qMax; + var d = document[i] / dMax; + if (!StopWordsTokenIds.Contains(i)) + { + nominator += q * d; + qSum += q * q; + dSum += d * d; + } + } + return nominator / Math.Sqrt(qSum * dSum); + } + } +} \ No newline at end of file
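
For reference, the pieces added in this diff are intended to compose roughly as follows at inference time: a SquadCorpus supplies the tokenized paragraphs, a TfIdfDocumentSelector ranks them against a question, and SquadCorpus.GetBatches turns the top-ranked documents into padded, masked batches for the QA model. The sketch below is illustrative only, not part of the diff; it assumes the RobertaTokenizer(vocabDir) and RobertaInputBuilder(tokenizer, maxSequence) constructors used elsewhere in this PR, and the file paths, config values and variable names are hypothetical.

// Illustrative sketch (assumed wiring, not part of the changed files).
var tokenizer = new RobertaTokenizer("vocab_files");                    // assumed constructor from this PR
var inputBuilder = new RobertaInputBuilder(tokenizer, 384);             // assumed constructor from this PR
var config = new QuestionAnsweringConfig { BatchSize = 8, MaxSequence = 384, Cuda = false };

// Load the document corpus and rank its paragraphs against the question with the TF-IDF selector.
var corpus = new SquadCorpus("data/test.json", tokenizer, inputBuilder);
var selector = new TfIdfDocumentSelector(corpus.Documents, tokenizer);
var question = "Who developed the theory of relativity?";
var topDocuments = selector.TopK(question, k: 5);

// Batch the retrieved documents together with the tokenized question for the model.
var questionTokens = tokenizer.TokenizeToId(question);
foreach (var batch in corpus.GetBatches(config, questionTokens, topDocuments.ToList()))
{
    // Feed batch.Tokens / batch.Positions / batch.Segments / batch.AttentionMasks to the model,
    // then pick answer spans from the logits with SquadMetric.ComputeTopKSpansWithScore(...).
}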