| 1 | +# Pytorch-Transformer Implementation |
| 2 | +This repository contains a PyTorch implementation of the [Transformer model (Attention Is All You Need)](https://arxiv.org/abs/1706.03762). The model is trained and tested on a dummy dataset of token sequences built from `<sos>=0`, `<eos>=1`, `<pad>=2`, and two word tokens, `3` and `4`. The core architecture is located in the `model/` directory. |
| 3 | + |
| 4 | +<img src="./assets/transformer.png" alt="transformer" width="50%"> |
| 5 | + |
| 6 | +## Guide |
| 7 | +Before you run the commands, modify the configurations in `data/config.yaml` as per your requirements. |
| 8 | +### Training |
| 9 | +Run the following command to start training the model: |
| 10 | +``` |
| 11 | +python main.py --output ${OUTPUT_PATH} --log ${LOG_PATH} --cfg ${CFG_PATH} |
| 12 | +``` |
| 13 | + |
| 14 | +### Testing |
| 15 | +Run the following command to start testing the model: |
| 16 | +``` |
| 17 | +python test.py --model ${MODEL_PATH} --cfg ${CFG_PATH} |
| 18 | +``` |
| 19 | +## Implementation |
| 20 | +Other components, such as the Encoder and Decoder, are implemented separately in the `model/` directory. The main Transformer architecture is defined as follows: |
| 21 | +``` |
| 22 | +class Transformer(nn.Module): |
| 23 | + def __init__(self, enc_vsize, dec_vsize, d_model, max_len, dropout_p=0.1, n_heads=8, n_layers=6, d_ff=2048, device=None, |
| 24 | + src_pad_idx=0, tgt_pad_idx=0): |
| 25 | + super(Transformer, self).__init__() |
| 26 | + self.device = device |
| 27 | +
|
| 28 | + self.encoder = Encoder(vocab_size=enc_vsize, |
| 29 | + d_model=d_model, |
| 30 | + max_len=max_len, |
| 31 | + dropout_p=dropout_p, |
| 32 | + n_heads=n_heads, |
| 33 | + n_layers=n_layers, |
| 34 | + d_ff=d_ff, |
| 35 | + device=device) |
| 36 | + |
| 37 | + self.decoder = Decoder(vocab_size=dec_vsize, |
| 38 | + d_model=d_model, |
| 39 | + max_len = max_len, |
| 40 | + dropout_p=dropout_p, |
| 41 | + n_heads=n_heads, |
| 42 | + n_layers=n_layers, |
| 43 | + d_ff=d_ff, |
| 44 | + device=device) |
| 45 | + self.src_pad_idx = src_pad_idx |
| 46 | + self.tgt_pad_idx = tgt_pad_idx |
| 47 | +
|
| 48 | + def make_src_mask(self, source) -> torch.Tensor: |
| 49 | + """Padding mask""" |
| 50 | + src_mask = (source != self.src_pad_idx).unsqueeze(1).unsqueeze(2) # batch_size x seq_len -> batch_size x 1 x 1 x seq_len |
| 51 | + return src_mask |
| 52 | + |
| 53 | + def make_target_mask(self, target) -> torch.Tensor: |
| 54 | + """ |
| 55 | + 1) padding mask - finds padding token and assigns False |
| 56 | + 2) look-ahead mask (target mask) - hides positions after the current one, so each token attends only to itself and earlier tokens |
| 57 | + """ |
| 58 | + padding_mask = (target != self.tgt_pad_idx).unsqueeze(1).unsqueeze(3) |
| 59 | + target_seq_len = target.size(1) |
| 60 | + nopeak_mask = (1 - torch.triu(torch.ones(1, target_seq_len, target_seq_len), diagonal=1)).bool().to(self.device) |
| 61 | + target_mask = nopeak_mask & padding_mask |
| 62 | + |
| 63 | + return target_mask |
| 64 | + |
| 65 | + def forward(self, src, tgt): |
| 66 | + src_mask = self.make_src_mask(src) # batch_size x 1 x 1 x src_seq_len |
| 67 | + tgt_mask = self.make_target_mask(tgt) # batch_size x 1 x tgt_seq_len x tgt_seq_len |
| 68 | +
|
| 69 | + enc_emb = self.encoder(src, src_mask) # batch_size x src_seq_len x d_model |
| 70 | + tgt_emb = self.decoder(enc_emb, tgt, src_mask, tgt_mask) # batch_size x tgt_seq_len x tgt_vocab_size |
| 71 | + return tgt_emb # No softmax here: nn.CrossEntropyLoss applies log-softmax internally |
| 72 | +
|
| 73 | +``` |
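
To make the target mask easier to picture, here is a minimal sketch in plain Python (no PyTorch) that rebuilds the same boolean grid for a single sequence. The standalone function and the example sequence are illustrative assumptions, not code from the repository; the sketch only mirrors the broadcast of the look-ahead mask (`1 x T x T`) against the padding mask (`batch_size x 1 x T x 1`).

```python
def make_target_mask(target, pad_idx):
    """Return a seq_len x seq_len grid of booleans for ONE target sequence.

    mask[i][j] is True when query position i may attend to key position j.
    """
    seq_len = len(target)
    # Look-ahead part: position i may only see positions j <= i.
    no_peek = [[j <= i for j in range(seq_len)] for i in range(seq_len)]
    # Padding part: mirrors `.unsqueeze(1).unsqueeze(3)`, so it is applied
    # per query row i, not per key column j.
    not_pad = [token != pad_idx for token in target]
    return [[no_peek[i][j] and not_pad[i] for j in range(seq_len)]
            for i in range(seq_len)]


# Hypothetical example: <sos>=0, two word tokens, then one <pad>=2.
mask = make_target_mask([0, 3, 4, 2], pad_idx=2)
for row in mask:
    print(" ".join("x" if allowed else "." for allowed in row))
```

Each printed row is one query position. The rows form a lower triangle, and the last row, which belongs to the padding token, is masked entirely.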
| 74 | + |
| 75 | +## Dataset |
| 76 | +**Tokens**: |
| 77 | +- `SOS` token: `0` |
| 78 | +- `EOS` token: `1` |
| 79 | +- `PAD` token: `2` (not used when generating the patterns below) |
| 80 | +- `WORDS`: `3`, `4` (used to generate patterns) |
| 81 | + |
| 82 | +**Patterns**: |
| 83 | +- Sequence of all 3s: `[0, 3, 3, 3, 3, 3, 3, 3, 3, 1]` |
| 84 | +- Sequence of all 4s: `[0, 4, 4, 4, 4, 4, 4, 4, 4, 1]` |
| 85 | +- Alternating 3s and 4s starting with 3: `[0, 3, 4, 3, 4, 3, 4, 3, 4, 1]` |
| 86 | +- Alternating 3s and 4s starting with 4: `[0, 4, 3, 4, 3, 4, 3, 4, 3, 1]` |
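
The four patterns above could be generated with something like the following sketch. The function name `make_patterns` and the `length` parameter are hypothetical; the repository's own data code may differ.

```python
SOS, EOS, PAD = 0, 1, 2  # PAD is unused here: every sequence has the same length


def make_patterns(length=8):
    """Build the four sequence patterns, each wrapped in SOS ... EOS."""
    bodies = [
        [3] * length,                                     # all 3s
        [4] * length,                                     # all 4s
        [3 if i % 2 == 0 else 4 for i in range(length)],  # 3, 4, 3, 4, ...
        [4 if i % 2 == 0 else 3 for i in range(length)],  # 4, 3, 4, 3, ...
    ]
    return [[SOS] + body + [EOS] for body in bodies]


patterns = make_patterns()
for pattern in patterns:
    print(pattern)
```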
| 87 | + |
| 88 | +## Results |
| 89 | +### Training |
| 90 | +The graph below shows the training run of a model trained for 20 epochs with 5 warmup steps. You can download the trained model [here](https://drive.google.com/file/d/1R-JXH_cFMXFKrfejEqrBj36gUigDgIyX/view?usp=sharing). |
| 91 | + |
| 92 | +<img src="./logs/best_model.png" alt="log" width="80%"/> |
| 93 | + |
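
This README does not say which schedule the warmup steps feed into. Assuming it follows the schedule from the original paper (section 5.3), the learning rate rises linearly for `warmup_steps` steps and then decays with the inverse square root of the step number. The sketch below shows that formula; the function name `noam_lr` and the default `d_model=512` are assumptions for illustration.

```python
def noam_lr(step, d_model=512, warmup_steps=5):
    """Learning rate from "Attention Is All You Need", section 5.3.

    lrate = d_model^-0.5 * min(step^-0.5, step * warmup_steps^-1.5)
    """
    step = max(step, 1)  # avoid 0 ** -0.5 when called with step 0
    return d_model ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)


for step in range(1, 11):
    print(step, round(noam_lr(step), 5))
```

With these defaults the rate peaks at step 5, the point where the two terms inside `min` are equal.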
| 94 | +### Inference |
| 95 | +``` |
| 96 | +Example 0 |
| 97 | +Input: [3, 3, 3, 3, 3, 3, 3, 3] |
| 98 | +Continuation: [3, 3, 3, 3, 3, 3, 3, 3] |
| 99 | +
|
| 100 | +Example 1 |
| 101 | +Input: [4, 4, 4, 4, 4, 4, 4, 4] |
| 102 | +Continuation: [4, 4, 4, 4, 4, 4, 4, 4] |
| 103 | +
|
| 104 | +Example 2 |
| 105 | +Input: [3, 4, 3, 4, 3, 4, 3, 4] |
| 106 | +Continuation: [3, 4, 3, 4, 3, 4, 3, 4] |
| 107 | +
|
| 108 | +Example 3 |
| 109 | +Input: [4, 3, 4, 3, 4, 3, 4, 3] |
| 110 | +Continuation: [3, 4, 3, 4, 3, 4, 3, 4] |
| 111 | +
|
| 112 | +Example 4 |
| 113 | +Input: [3, 4, 3] |
| 114 | +Continuation: [3, 4, 3, 4, 3, 4, 3, 4] |
| 115 | +``` |
| 116 | + |
| 117 | +## Configuration Structure |
| 118 | + |
| 119 | +``` |
| 120 | +train: |
| 121 | + batch_size: |
| 122 | + epochs: |
| 123 | + learning_rate: |
| 124 | + d_model: |
| 125 | + n_heads: |
| 126 | + n_layers: |
| 127 | + d_ff: |
| 128 | + dropout_p: |
| 129 | + max_len: |
| 130 | + warmup_steps: |
| 131 | +
|
| 132 | +test: |
| 133 | + d_model: |
| 134 | + n_heads: |
| 135 | + n_layers: |
| 136 | + d_ff: |
| 137 | + dropout_p: |
| 138 | + max_len: |
| 139 | +
|
| 140 | +``` |
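
Because the template above lists keys without values, a small check that nothing was left out can catch mistakes before training starts. The following is a minimal sketch that works on an already-parsed config (a plain `dict`, for example the result of loading the YAML file). The helper name `missing_keys` and the example values are assumptions, not the repository's settings.

```python
REQUIRED_KEYS = {
    "train": {"batch_size", "epochs", "learning_rate", "d_model", "n_heads",
              "n_layers", "d_ff", "dropout_p", "max_len", "warmup_steps"},
    "test": {"d_model", "n_heads", "n_layers", "d_ff", "dropout_p", "max_len"},
}


def missing_keys(cfg):
    """Return {section: sorted list of missing keys} for a parsed config dict."""
    missing = {}
    for section, keys in REQUIRED_KEYS.items():
        # An empty YAML section parses to None, so fall back to an empty dict.
        absent = keys - set(cfg.get(section) or {})
        if absent:
            missing[section] = sorted(absent)
    return missing


# Illustrative values only; they are not the repository's settings.
example_cfg = {
    "train": {"batch_size": 32, "epochs": 20, "learning_rate": 1e-4,
              "d_model": 512, "n_heads": 8, "n_layers": 6, "d_ff": 2048,
              "dropout_p": 0.1, "max_len": 16, "warmup_steps": 5},
    "test": {"d_model": 512, "n_heads": 8, "n_layers": 6, "d_ff": 2048,
             "dropout_p": 0.1, "max_len": 16},
}
print(missing_keys(example_cfg))
```

The architecture keys under `test` (`d_model`, `n_heads`, `n_layers`, `d_ff`, `max_len`) should match the values used for training; otherwise the saved weights will not fit the rebuilt model.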
| 141 | + |
| 142 | +## References |
| 143 | +- [Attention Is All You Need](https://arxiv.org/abs/1706.03762) |
| 144 | +- [Transformer: PyTorch Implementation of "Attention Is All You Need"](https://github.com/hyunwoongko/transformer/tree/master) |
| 145 | +- [A detailed guide to PyTorch’s nn.Transformer() module](https://towardsdatascience.com/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1) |
| 146 | + |
| 147 | +## TO-DO |
| 148 | +- [x] Add Encoder, Decoder |
| 149 | +- [x] Training/Validation logic with dataset |
| 150 | +- [x] Refactoring |
| 151 | +- [x] Add other parts |
| 152 | + - [ ] Label smoothing |
| 153 | + - [ ] Add BLEU & PPL (https://brunch.co.kr/@leadbreak/11) |