
Commit 9569aec
Initial commit
0 parents, 19 files changed: +1,156 additions, 0 deletions

README.md

+153
# Pytorch-Transformer Implementation

This repository contains an implementation of the [Transformer model (Attention is All You Need)](https://arxiv.org/abs/1706.03762) in PyTorch. The model is trained and tested on a dummy dataset consisting of tokens `<sos>=0`, `<eos>=1`, `<pad>=2`, and additional tokens 3 and 4, representing sequences. The core architecture is located in the `model/` directory.

<img src="./assets/transformer.png" alt="transformer" width="50%">

## Guide
Before you run the commands, modify the configurations in `data/config.yaml` as per your requirements.

### Training
Run the following command to start training the model:
```
python main.py --output ${OUTPUT_PATH} --log ${LOG_PATH} --cfg ${CFG_PATH}
```

### Testing
Run the following command to start testing the model:
```
python test.py --model ${MODEL_PATH} --cfg ${CFG_PATH}
```

## Implementation
The main Transformer architecture is defined as follows. Other components, such as the Encoder and Decoder, are implemented separately in the `model/` directory:
```
class Transformer(nn.Module):
    def __init__(self, enc_vsize, dec_vsize, d_model, max_len, dropout_p=0.1, n_heads=8, n_layers=6, d_ff=2048,
                 device=None, src_pad_idx=0, tgt_pad_idx=0):
        super(Transformer, self).__init__()
        self.device = device

        self.encoder = Encoder(vocab_size=enc_vsize,
                               d_model=d_model,
                               max_len=max_len,
                               dropout_p=dropout_p,
                               n_heads=n_heads,
                               n_layers=n_layers,
                               d_ff=d_ff,
                               device=device)

        self.decoder = Decoder(vocab_size=dec_vsize,
                               d_model=d_model,
                               max_len=max_len,
                               dropout_p=dropout_p,
                               n_heads=n_heads,
                               n_layers=n_layers,
                               d_ff=d_ff,
                               device=device)
        self.src_pad_idx = src_pad_idx
        self.tgt_pad_idx = tgt_pad_idx

    def make_src_mask(self, source) -> torch.Tensor:
        """Padding mask"""
        src_mask = (source != self.src_pad_idx).unsqueeze(1).unsqueeze(2)  # batch_size x seq_len -> batch_size x 1 x 1 x seq_len
        return src_mask

    def make_target_mask(self, target) -> torch.Tensor:
        """
        1) padding mask - marks padding positions as False
        2) no-peak mask (target mask) - blocks attention to future positions
        """
        padding_mask = (target != self.tgt_pad_idx).unsqueeze(1).unsqueeze(3)
        target_seq_len = target.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, target_seq_len, target_seq_len), diagonal=1)).bool().to(self.device)
        target_mask = nopeak_mask & padding_mask

        return target_mask

    def forward(self, src, tgt):
        src_mask = self.make_src_mask(src)     # batch_size x 1 x 1 x src_seq_len
        tgt_mask = self.make_target_mask(tgt)  # batch_size x 1 x tgt_seq_len x tgt_seq_len

        enc_emb = self.encoder(src, src_mask)                     # batch_size x src_seq_len x d_model
        tgt_emb = self.decoder(enc_emb, tgt, src_mask, tgt_mask)  # batch_size x tgt_seq_len x tgt_vocab_size
        return tgt_emb  # No softmax here; it is applied inside CrossEntropyLoss
```

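To make the mask shapes concrete, here is a small standalone sketch that reproduces the two masking steps above on a toy batch. It uses only `torch`, not the repository code; the padding index `2` matches the dummy dataset:
```
import torch

pad_idx = 2
# Toy batch of 2 target sequences of length 5; the second one ends with padding.
tgt = torch.tensor([[0, 3, 4, 3, 1],
                    [0, 4, 4, 1, 2]])

# Padding mask: True where the token is not <pad>.
padding_mask = (tgt != pad_idx).unsqueeze(1).unsqueeze(3)               # 2 x 1 x 5 x 1

# No-peak (causal) mask: True on and below the diagonal.
L = tgt.size(1)
nopeak_mask = (1 - torch.triu(torch.ones(1, L, L), diagonal=1)).bool()  # 1 x 5 x 5

tgt_mask = nopeak_mask & padding_mask                                   # 2 x 1 x 5 x 5
print(tgt_mask.shape)   # torch.Size([2, 1, 5, 5])
print(tgt_mask[1, 0])   # the row for the padded position is entirely False
```
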
## Dataset
**Tokens**:
- `SOS` token: `0`
- `EOS` token: `1`
- `PAD` token: `2` (not used, since every generated sequence has the same length)
- `WORDS`: `3`, `4` (used to generate patterns)

**Patterns** (a short inspection sketch follows this list):
- Sequence of all 3s: `[0, 3, 3, 3, 3, 3, 3, 3, 3, 1]`
- Sequence of all 4s: `[0, 4, 4, 4, 4, 4, 4, 4, 4, 1]`
- Alternating 3s and 4s starting with 3: `[0, 3, 4, 3, 4, 3, 4, 3, 4, 1]`
- Alternating 3s and 4s starting with 4: `[0, 4, 3, 4, 3, 4, 3, 4, 3, 1]`

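These pairs are produced by `generate_random_data` in `dataset.py` (included in this commit). A quick way to inspect a few generated samples:
```
from dataset import generate_random_data

# 9 samples: 3 of each group (all 3s, all 4s, alternating)
data = generate_random_data(9, length=8)

for X, y in data[:3]:
    print("input :", X.astype(int).tolist())
    print("target:", y.astype(int).tolist())
```
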
## Results
### Training
The graph below shows the model trained for 20 epochs with 5 warmup steps. You can download the trained model [here](https://drive.google.com/file/d/1R-JXH_cFMXFKrfejEqrBj36gUigDgIyX/view?usp=sharing).

<img src="./logs/best_model.png" alt="log" width="80%"/>

### Inference
```
Example 0
Input: [3, 3, 3, 3, 3, 3, 3, 3]
Continuation: [3, 3, 3, 3, 3, 3, 3, 3]

Example 1
Input: [4, 4, 4, 4, 4, 4, 4, 4]
Continuation: [4, 4, 4, 4, 4, 4, 4, 4]

Example 2
Input: [3, 4, 3, 4, 3, 4, 3, 4]
Continuation: [3, 4, 3, 4, 3, 4, 3, 4]

Example 3
Input: [4, 3, 4, 3, 4, 3, 4, 3]
Continuation: [3, 4, 3, 4, 3, 4, 3, 4]

Example 4
Input: [3, 4, 3]
Continuation: [3, 4, 3, 4, 3, 4, 3, 4]
```

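`test.py` is referenced in the guide but is not part of this excerpt, so the exact decoding loop is not shown here. Continuations like the ones above are typically produced by greedy decoding; the sketch below is only an illustration of that idea, assuming a trained `model` with the forward signature shown earlier and the token ids of this repository:
```
import torch

def greedy_decode(model, src_tokens, max_len=10, sos_idx=0, eos_idx=1, device="cpu"):
    """Illustrative greedy decoding loop (not the repository's test.py)."""
    model.eval()
    src = torch.tensor([[sos_idx] + src_tokens + [eos_idx]], device=device)
    ys = torch.tensor([[sos_idx]], device=device)

    with torch.no_grad():
        for _ in range(max_len - 1):
            logits = model(src, ys)                     # 1 x cur_len x vocab_size
            next_token = logits[0, -1].argmax().item()  # most likely next token
            ys = torch.cat([ys, torch.tensor([[next_token]], device=device)], dim=1)
            if next_token == eos_idx:
                break
    return ys.squeeze(0).tolist()

# Hypothetical usage:
# print(greedy_decode(model, [3, 4, 3], device=model.device))
```
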
## Configurations Structure

```
train:
  batch_size:
  epochs:
  learning_rate:
  d_model:
  n_heads:
  n_layers:
  d_ff:
  dropout_p:
  max_len:
  warmup_steps:

test:
  d_model:
  n_heads:
  n_layers:
  d_ff:
  dropout_p:
  max_len:
```

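`main.py` reads this file through `load_yaml` from `utils.py`, which is not included in this excerpt. A minimal stand-in, assuming the standard PyYAML package, might look like:
```
import yaml

def load_yaml(path: str) -> dict:
    """Minimal stand-in for utils.load_yaml (assumed, not the repository code)."""
    with open(path, "r") as f:
        return yaml.safe_load(f)

cfg = load_yaml("data/config.yaml")["train"]
print(cfg["d_model"], cfg["warmup_steps"])  # 512 and 5 with the committed config
```
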
## References
- [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- [Transformer: PyTorch Implementation of "Attention Is All You Need"](https://github.com/hyunwoongko/transformer/tree/master)
- [A detailed guide to PyTorch’s nn.Transformer() module](https://towardsdatascience.com/a-detailed-guide-to-pytorchs-nn-transformer-module-c80afbc9ffb1)

## TO-DO
- [x] Add Encoder, Decoder
- [x] Training/Validation logic with dataset
- [x] Refactoring
- [x] Add other parts
- [ ] Label smoothing
- [ ] Add BLEU & PPL (https://brunch.co.kr/@leadbreak/11)

__init__.py

Whitespace-only changes.

assets/transformer.png

162 KB

data/config.yaml

+19
train:
  batch_size: 16        # 16 32
  epochs: 20            # 50 100
  learning_rate: 0.001  # 0.01 0.0005 0.0001
  d_model: 512
  n_heads: 8
  n_layers: 6
  d_ff: 2048
  dropout_p: 0.1
  max_len: 10           # 100 10
  warmup_steps: 5       # 4000 50

test:
  d_model: 512
  n_heads: 8
  n_layers: 6
  d_ff: 2048
  dropout_p: 0.1
  max_len: 10

dataset.py

+64
import numpy as np
import random

def generate_random_data(n: int, length: int = 8, sos_idx: int = 0, eos_idx: int = 1, pad_idx: int = 2) -> list:
    """
    Generate random sequences of data for training/testing.

    Each sequence starts with an SOS token (start of sequence) and ends with an EOS token (end of sequence).
    The sequence is filled with specific patterns of words (tokens), and no padding is used since the max length is set to 10 for convenience.

    Tokens:
        SOS token: 0
        EOS token: 1
        PAD token: 2 (not used in this function)
        WORDS: 3, 4 (used to generate patterns)

    Patterns:
        - Sequence of all 3s: [0, 3, 3, 3, 3, 3, 3, 3, 3, 1]
        - Sequence of all 4s: [0, 4, 4, 4, 4, 4, 4, 4, 4, 1]
        - Alternating 3s and 4s starting with 3: [0, 3, 4, 3, 4, 3, 4, 3, 4, 1]
        - Alternating 3s and 4s starting with 4: [0, 4, 3, 4, 3, 4, 3, 4, 3, 1]

    Args:
        n (int): Number of sequences to generate. Should be divisible by 3.
        length (int, optional): Length of the sequence excluding SOS and EOS tokens. Default is 8.
        sos_idx (int, optional): Index for the SOS token. Default is 0.
        eos_idx (int, optional): Index for the EOS token. Default is 1.
        pad_idx (int, optional): Index for the PAD token (not used in this function). Default is 2.

    Returns:
        list: A list of pairs, where each pair contains two numpy arrays representing the input and target sequences.
    """
    SOS_token = np.array([sos_idx])
    EOS_token = np.array([eos_idx])
    data = []

    # Sequences of all 3s
    for _ in range(n // 3):
        X = np.concatenate((SOS_token, 3 * np.ones(length), EOS_token))
        y = np.concatenate((SOS_token, 3 * np.ones(length), EOS_token))
        data.append([X, y])

    # Sequences of all 4s
    for _ in range(n // 3):
        X = np.concatenate((SOS_token, 4 * np.ones(length), EOS_token))
        y = np.concatenate((SOS_token, 4 * np.ones(length), EOS_token))
        data.append([X, y])

    # Alternating 3s and 4s, starting with either 3 or 4
    for _ in range(n // 3):
        X = np.ones(length) * 3
        start = random.randint(0, 1)
        X[start::2] = 4

        # The target mirrors the alternating pattern of X. X contains only 3s
        # and 4s, so the check is against 3 (the original comparison against 0
        # could never be true).
        y = np.ones(length) * 3
        if X[-1] == 3:
            y[::2] = 4
        else:
            y[1::2] = 4

        X = np.concatenate((SOS_token, X, EOS_token))
        y = np.concatenate((SOS_token, y, EOS_token))
        data.append([X, y])

    np.random.shuffle(data)
    return data
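# --- Illustration (not part of dataset.py in this commit) ---------------------
# train.py and utils.py (which provides batchify_data) are not included in this
# excerpt, so the actual batching code is not shown here. Purely as a sketch,
# this is one common way to turn the generated pairs into tensors; the
# teacher-forcing split below is an assumption, not confirmed by this excerpt.
import numpy as np
import torch
from dataset import generate_random_data

pairs = generate_random_data(6, length=8)
batch = pairs[:3]

src = torch.tensor(np.array([X for X, _ in batch]), dtype=torch.long)  # 3 x 10
tgt = torch.tensor(np.array([y for _, y in batch]), dtype=torch.long)  # 3 x 10

decoder_input = tgt[:, :-1]  # fed to the decoder
labels = tgt[:, 1:]          # compared against the output in CrossEntropyLoss
print(src.shape, decoder_input.shape, labels.shape)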

debug.ipynb

+175
Large diffs are not rendered by default.

logs/best_model.png

33.4 KB

main.py

+97
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from model.transformer import Transformer
from train import train
from dataset import generate_random_data
from utils import load_yaml, count_parameters, initialize_weights, batchify_data, save_logs
import os
import argparse
import numpy as np
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

def parse_args():
    parser = argparse.ArgumentParser("Implementation of Transformer in Pytorch")
    parser.add_argument("--output",
                        required=True,
                        type=str,
                        help="output path for the trained model")
    parser.add_argument("--log",
                        required=True,
                        type=str,
                        help="output path for saving the logs (including filename)")
    parser.add_argument("--cfg",
                        required=True,
                        type=str,
                        help="configuration path")
    return parser.parse_args()


def main():
    args = parse_args()
    log_save_path = args.log
    model_save_path = args.output
    cfg = load_yaml(args.cfg)['train']

    os.makedirs(model_save_path, exist_ok=True)

    device = torch.device('cuda:5' if torch.cuda.is_available() else 'cpu')  # note: hard-coded GPU index
    print(f'[INFO] Using device: {device}')

    print(f'[INFO] n_warmup: {cfg["warmup_steps"]} | max length : {cfg["max_len"]} | batch size : {cfg["batch_size"]} | epochs : {cfg["epochs"]} | lr : {cfg["learning_rate"]}')
    print(f'[INFO] d_model : {cfg["d_model"]} | n_heads : {cfg["n_heads"]} | n_layers : {cfg["n_layers"]} | d_ff : {cfg["d_ff"]} | dropout_p : {cfg["dropout_p"]}')

    print('[INFO] Load dataset ...')
    train_data = generate_random_data(20000, length=cfg['max_len'] - 2)  # 10000
    val_data = generate_random_data(6000, length=cfg['max_len'] - 2)     # 3000

    train_loader = batchify_data(train_data, batch_size=cfg['batch_size'])
    val_loader = batchify_data(val_data, batch_size=cfg['batch_size'])

    print('[INFO] Load model ...')
    # Vocabulary of size 5: sos, eos, padding, 3, 4
    model = Transformer(
        enc_vsize=5,
        dec_vsize=5,
        d_model=cfg['d_model'],
        max_len=cfg['max_len'],
        dropout_p=cfg['dropout_p'],
        n_heads=cfg['n_heads'],
        n_layers=cfg['n_layers'],
        d_ff=cfg['d_ff'],
        device=device,
        src_pad_idx=2,
        tgt_pad_idx=2
    ).to(device)

    print(f'[INFO] # of trainable parameters : {count_parameters(model):,}')
    model.apply(initialize_weights)

    criterion = nn.CrossEntropyLoss(ignore_index=2)
    optimizer = optim.Adam(model.parameters(),
                           betas=(0.9, 0.98),
                           lr=cfg['learning_rate'],  # default 0.001
                           eps=1e-9)

    def lr_scheduler(optimizer, warmup_steps, d_model):
        """Equation (3) of the paper:
        lrate = d_model^(-0.5) * min(step^(-0.5), step * warmup_steps^(-1.5));
        step + 1 is used so the learning rate is not zero at step 0.
        """
        def lrate(step):
            return (d_model ** -0.5) * min((step + 1) ** -0.5, (step + 1) * warmup_steps ** -1.5)
        return LambdaLR(optimizer, lr_lambda=lrate)

    scheduler = lr_scheduler(optimizer,
                             warmup_steps=cfg['warmup_steps'],
                             d_model=cfg['d_model'])

    tr_losses, val_losses = train(model, train_loader, val_loader,
                                  criterion, optimizer, scheduler,
                                  cfg['epochs'], device, model_save_path)

    save_logs(log_save_path, tr_losses, val_losses)
    print('[INFO] Successfully saved model!')

if __name__ == "__main__":
    main()
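# --- Illustration (not part of main.py in this commit) ------------------------
# The warmup behaviour of lr_scheduler above can be checked in isolation. This
# sketch builds a throwaway optimizer over a single parameter and prints the
# learning rates produced by the same rule (d_model=512, warmup_steps=5, as in
# data/config.yaml): the rate rises over the first warmup_steps steps, then decays.
import torch
from torch.optim.lr_scheduler import LambdaLR

d_model, warmup_steps = 512, 5

def lrate(step):
    return (d_model ** -0.5) * min((step + 1) ** -0.5, (step + 1) * warmup_steps ** -1.5)

opt = torch.optim.Adam([torch.nn.Parameter(torch.zeros(1))], lr=1.0)
sched = LambdaLR(opt, lr_lambda=lrate)

for step in range(8):
    print(step, round(opt.param_groups[0]['lr'], 6))
    opt.step()
    sched.step()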

model/__init__.py

+1
"""Codebase for implementation of transformer in PyTorch"""
