Commit b987a58 (1 parent: e899d71)

restructure imports + add decoding submodule + implement greedy

27 files changed: +1217 -870 lines

README.md (56 additions, 6 deletions)
@@ -5,10 +5,11 @@
 <p align="center">
 <sup>
 <b>Contents</b>:&nbsp;
-<a href="#restrictions">Restrictions</a> ·
+<a href="#features">Features</a> ·
+<a href="#example">Example</a> ·
 <a href="#details">Details</a> ·
 <a href="#datasets">Datasets</a> ·
-<a href="#models-and-examples">Models and examples</a> ·
+<a href="#models-and-notebooks">Models and notebooks</a> ·
 <a href="#repository-structure">Repository structure</a> ·
 <a href="#installation">Installation</a> ·
 <a href="#running">Running</a> ·
@@ -21,12 +22,19 @@ The repository contains a modular Python implementation of transformer architect
 - The seminal paper _Attention Is All You Need_ by Vaswani et al.<sup><a href="#references">[1]</a></sup> that details the novel attention-based transformer architecture and its application to sequence-to-sequence tasks, demonstrating its effectiveness by achieving state-of-the-art performance in machine translation, surpassing previous LSTM and CNN based neural machine translation architectures.
 - The chapter on _Transformers and Large Language Models_ from _Speech and Language Processing_ by Jurafsky & Martin<sup><a href="#references">[2]</a></sup> which provides a more comprehensive and illustrative look into some of the high-level details discussed in _Attention Is All You Need_.

-## Restrictions
+## Features

-This project is implemented using [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
+- Generic encoder-only, decoder-only and encoder-decoder transformer architectures.
+- Wrappers for causal language modelling, sequence-to-sequence generation and classification/regression tasks.
+- Various decoding methods for causal/sequence-to-sequence generation:
+  - Search-based (greedy and beam search)
+  - Sampling-based (nucleus, temperature and top-k sampling)
+- Example applications to real-world datasets.

 ### PyTorch restrictions

+This project is implemented using [PyTorch](https://pytorch.org/) and [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/).
+
 As PyTorch provides a number of transformer and attention related layers in its [`torch.nn`](https://pytorch.org/docs/stable/nn.html) submodule, this project explicitly avoids the use of:

 - [`torch.nn.Transformer`](https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html#torch.nn.Transformer)
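
The commit message notes that greedy decoding is implemented as part of the new `transformer.decoding` submodule. As a minimal sketch of the technique in general (the function name and model interface below are hypothetical, not the repository's actual API), greedy decoding simply appends the highest-probability next token at each step:

```python
import torch

def greedy_decode(model, input_ids, max_length, eos_token_id):
    # Hypothetical interface: `model` maps a (1, seq_len) tensor of
    # token ids to (1, seq_len, vocab_size) next-token logits.
    for _ in range(max_length - input_ids.shape[1]):
        logits = model(input_ids)           # (1, seq_len, vocab_size)
        next_id = logits[0, -1].argmax()    # most likely next token
        input_ids = torch.cat([input_ids, next_id.view(1, 1)], dim=1)
        if next_id.item() == eos_token_id:  # stop at end-of-sequence
            break
    return input_ids
```

Greedy search is deterministic and cheap but can settle into repetitive continuations, which is why the sampling-based methods listed under Features are offered alongside it.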
@@ -47,6 +55,47 @@ All other layers provided by `torch.nn` are allowed, including:
 - No existing _"x from scratch"_ resources were used, such as the famous _Let's build GPT: from scratch, in code, spelled out._ by Andrej Karpathy<sup><a href="#references">[3]</a></sup>.
 - No other online resources were used, apart from official documentation for packages such as [PyTorch](https://pytorch.org/docs/stable/index.html), [PyTorch Lightning](https://lightning.ai/docs/pytorch/stable/) and [Huggingface Tokenizers](https://huggingface.co/docs/transformers/en/main_classes/tokenizer).

+## Example
+
+Training a causal language model to generate "Florida man"-style news headlines.
+
+```python
+from transformers import LlamaTokenizer
+
+from transformer.params import TransformerParams, TemperatureSamplingParams
+from transformer.models import CausalLM
+from transformer.decoding import TemperatureSamplingDecoder
+
+# initialize HuggingFace tokenizer
+tokenizer = LlamaTokenizer.from_pretrained(
+    "huggyllama/llama-7b", add_eos_token=True, legacy=False
+)
+tokenizer.add_special_tokens({"pad_token": "<pad>"})
+
+# initialize the causal language model
+model = CausalLM(
+    params=TransformerParams(context_length=64),
+    tokenizer=tokenizer,
+)
+
+# train the language model
+model.train(...)
+
+# initialize decoder for sequence generation
+decoder = TemperatureSamplingDecoder(
+    params=TemperatureSamplingParams(max_length=100, temperature=0.5),
+    model=model,
+)
+
+# generation without context
+decoder.generate()
+'Florida man arrested after baby alligator, guns, drugs found inside truck'
+
+# generation with context
+decoder.generate("Florida man shot")
+'Florida man shot and killed while attempting to steal pizza and Pokemon cards from Target'
+```
+
 ## Details

 While the original architecture described in _Attention Is All You Need_ is an encoder-decoder based architecture using transformers for neural machine translation which is a sequence-to-sequence learning task, this project was designed to be more general, allowing for a variety of natural language tasks by implementing encoder-only, decoder-only and encoder-decoder architectures.
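
The example above decodes with `TemperatureSamplingDecoder(temperature=0.5)`. As a rough sketch of temperature sampling in general (not necessarily how this repository implements it), the logits are divided by the temperature before sampling from the resulting softmax distribution:

```python
import torch

def sample_next_token(logits, temperature=0.5):
    # logits: 1-D tensor of unnormalized next-token scores.
    # temperature < 1 sharpens the distribution (more conservative);
    # temperature > 1 flattens it (more diverse).
    probs = torch.softmax(logits / temperature, dim=-1)
    return torch.multinomial(probs, num_samples=1).item()

# e.g. with a dummy vocabulary of 10 tokens:
print(sample_next_token(torch.randn(10)))
```

At `temperature=0.5` the decoder therefore favours high-probability headline continuations while still varying its output between calls.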
@@ -104,7 +153,7 @@ The following datasets were used to test the above transformer implementations o
 - [Reddit r/FloridaMan](https://www.kaggle.com/datasets/bcruise/reddit-rfloridaman): News headlines about various (often funny and irrational) actions performed by Florida men and women.
 - [Europarl](https://www.kaggle.com/datasets/nltkdata/europarl): Transcriptions of European Parliament proceedings between 1996-2006, collected in 11 languages.

-## Models and examples
+## Models and notebooks

 ### Encoder-only models

@@ -129,14 +178,15 @@
 - [**`notebooks/`**](notebooks/): Notebooks applying the models in [`transformer.models`](transformer/models/) to various datasets.
 - [**`transformer/`**](transformer/): Core package containing the transformer implementations.
   - [**`dataloaders/`**](transformer/dataloaders/): [`LightningDataModule`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html)s for each model in [`transformer.models`](transformer/models/).
+  - [**`decoding/`**](transformer/decoding/): Decoding method implementations for causal and sequence-to-sequence LMs.
   - [**`models/`**](transformer/models/): Task-specific transformers implemented using [`transformer.modules.transformers`](transformer/modules/transformers/).
   - [**`modules/`**](transformer/modules/): [`LightningModule`](https://lightning.ai/docs/pytorch/stable/common/lightning_module.html)s used within the transformers in [`transformer.models`](transformer/models/).
     - [**`transformers/`**](transformer/modules/transformers/): Encoder-only, decoder-only and encoder-decoder transformer definitions.
     - [`attention.py`](transformer/modules/attention.py): Masked/unmasked multi-head self-attention definition.
     - [`block.py`](transformer/modules/block.py): Transformer block definition.
     - [`embedding.py`](transformer/modules/embedding.py): Positional encoding and input embedding definition.
+  - [**`params/`**](transformer/params/): Pydantic hyper-parameter classes.
   - [**`utils/`**](transformer/utils/): Supporting custom layers, functions and constants.
-  - [`params.py`](transformer/params.py): Pydantic hyper-parameter classes for modules in [`transformer.modules`](transformer/modules/).

 ## Installation
