fasttext_custom_embeddings_with_flair.py
"""
Flair framework has recently made an enhancement to load custom embeddings in Flair through 'WordEmbeddings' class.
However, this has a limitation. Internally, Flair uses gensim to load these custom embeddings. So, if we wish to
import our custom embeddings in gensim then we have to convert it to gensim format and then use it.
While doing so we lose the ability of fasttext to approximate vector for out of vocabulary words using the sub-word
information. As a solution to this problem, I present you this script.
The embedding object defined using this script can be used in the same way as any other Flair embedding objects.
We can use the 'embed' function to add embeddings to tokens in the 'Sentence' object. We can use this along with other
embeddings in 'StackedEmbeddings' and 'DocumentEmbeddings'.
Please install packages from requirements.txt to use this script. You can then import this script to your project
where you wish to use custom fasttext embeddings with Flair.
"""
import numpy as np
import gensim
import torch
from typing import List
from flair.embeddings import TokenEmbeddings
from flair.file_utils import cached_path
from flair.data import Sentence
from pathlib import Path
import fasttext as ft
class FastTextEmbeddings(TokenEmbeddings):
"""FastText Embeddings to use with Flair framework"""
def __init__(
self,
embeddings: str,
use_local: bool = True,
use_gensim: bool = False,
extension: str = ".bin",
field: str = None,
):
"""
Initializes fasttext word embeddings. Constructor downloads required embedding file and stores in cache
if use_local is False.
The ".bin" and ".vec" embeddings format are supported. However, ".vec" doesn't support OOV functionality.
In case of ".vec", if word is not in vocabulary then zero vector will be returned by default.
:param embeddings: path to your embeddings '.bin' file
:param use_local: set this to False if you are using embeddings from a remote source
:param use_gensim: set this to true if your fasttext embedding is trained with fasttext version below 0.9.1
:param extension: ".bin" or ".vec" are the supported extensions
"""
cache_dir = Path("embeddings")
if use_local:
if not Path(embeddings).exists():
raise ValueError(
f'The given embeddings "{embeddings}" is not available or is not a valid path.'
)
else:
embeddings = cached_path(f"{embeddings}", cache_dir=cache_dir)
self.embeddings = embeddings
self.name: str = str(embeddings)
self.static_embeddings = True
self.use_gensim = use_gensim
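        # A ".bin" model keeps sub-word information. Load it either with gensim (for models
        # trained with fastText < 0.9.1) or with the official fasttext library; ".vec" files
        # are plain word vectors and are loaded with gensim's KeyedVectors.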
if extension == ".bin":
if use_gensim:
self.precomputed_word_embeddings = gensim.models.FastText.load_fasttext_format(
str(embeddings)
)
self.__embedding_length: int = self.precomputed_word_embeddings.vector_size
else:
self.precomputed_word_embeddings = ft.load_model(str(embeddings))
self.__embedding_length: int = self.precomputed_word_embeddings.get_dimension()
elif extension == ".vec":
self.precomputed_word_embeddings = gensim.models.KeyedVectors.load_word2vec_format(
str(embeddings)
)
self.__embedding_length: int = self.precomputed_word_embeddings.vector_size
else:
raise ValueError("Unsupported extension: " + extension)
self.field = field
super().__init__()
@property
def embedding_length(self) -> int:
return self.__embedding_length
    def _add_embeddings_internal(self, sentences: List[Sentence]) -> List[Sentence]:
        for sentence in sentences:
            for token in sentence.tokens:
                if "field" not in self.__dict__ or self.field is None:
                    word = token.text
                else:
                    word = token.get_tag(self.field).value
                try:
                    word_embedding = self.precomputed_word_embeddings[word]
                except KeyError:
                    # Fall back to a zero vector when the backend cannot produce a vector
                    # for this word (e.g. an out-of-vocabulary word in the ".vec" format).
                    word_embedding = np.zeros(self.embedding_length, dtype="float")
                word_embedding = torch.FloatTensor(word_embedding)
                token.set_embedding(self.name, word_embedding)
        return sentences
def __str__(self):
return self.name
def extra_repr(self):
return f"'{self.embeddings}'"
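

# ---------------------------------------------------------------------------
# Usage sketch (not part of the class above). It is a minimal illustration of
# how FastTextEmbeddings plugs into Flair's Sentence and StackedEmbeddings;
# the path "custom_fasttext.bin" is a placeholder for your own fastText model.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from flair.embeddings import StackedEmbeddings, WordEmbeddings

    # Load a locally trained fastText model. The ".bin" format keeps sub-word
    # information, so out-of-vocabulary words still receive meaningful vectors.
    fasttext_embedding = FastTextEmbeddings("custom_fasttext.bin")

    # Embed the tokens of a sentence directly.
    sentence = Sentence("The grass is green .")
    fasttext_embedding.embed(sentence)
    for token in sentence:
        print(token.text, token.get_embedding().shape)

    # The same object can be stacked with other Flair token embeddings
    # (WordEmbeddings("glove") downloads pretrained GloVe vectors on first use).
    stacked = StackedEmbeddings([fasttext_embedding, WordEmbeddings("glove")])
    stacked.embed(Sentence("Stacked embeddings work the same way ."))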