
Commit e88e478

Merge pull request #147 from theabhirath/master
2 parents 8ba79d5 + c6e26ac

File tree

2 files changed (+5 / -5 lines)


Project.toml

Lines changed: 2 additions & 2 deletions
@@ -17,9 +17,9 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 BSON = "0.3.2"
 Flux = "0.13"
 Functors = "0.2"
-MLUtils = "0.1.2, 0.2"
+MLUtils = "0.2"
 NNlib = "0.7.34, 0.8"
-julia = "1.4"
+julia = "1.6"
 NeuralAttentionlib = "0.0"

 [extras]

src/layers/attention.jl

Lines changed: 3 additions & 3 deletions
@@ -40,9 +40,9 @@ end
 @functor MHAttention

 function (m::MHAttention)(x::AbstractArray{T, 3}) where T
-    B, C, N = size(x)
-    q, k, v = chunk(reshape(m.qkv_layer(x), B ÷ m.nheads, m.nheads, C, 3 * N), 3; dims = 4)
+    features, len_seq, batch_size = size(x)
+    q, k, v = chunk(reshape(m.qkv_layer(x), features ÷ m.nheads, m.nheads, len_seq, 3 * batch_size), 3; dims = 4)
     scale = convert(T, sqrt(size(q, 1) / m.nheads))
     attn = m.attn_drop(softmax(NeuralAttentionlib.matmul(q, permutedims(k, (2, 1, 3, 4))) * scale))
-    x = m.projection(reshape(NeuralAttentionlib.matmul(attn, v), (B, C, N)))
+    x = m.projection(reshape(NeuralAttentionlib.matmul(attn, v), (features, len_seq, batch_size)))
 end
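The renamed variables spell out the input layout the layer assumes: x is (features, len_seq, batch_size). Below is a minimal shape check of the reshape-and-chunk step, assuming MLUtils 0.2's chunk and a qkv projection that triples the feature dimension; the concrete sizes (features = 64, nheads = 8, len_seq = 16, batch_size = 4) are illustrative only, not taken from the package.

using MLUtils: chunk

features, len_seq, batch_size = 64, 16, 4   # illustrative sizes, not from the package
nheads = 8

# Stand-in for m.qkv_layer(x): assumed to project `features` up to `3 * features`.
qkv = rand(Float32, 3 * features, len_seq, batch_size)

# Same reshape-and-chunk as the new code: split the last dimension into q, k, v.
q, k, v = chunk(reshape(qkv, features ÷ nheads, nheads, len_seq, 3 * batch_size), 3; dims = 4)

size(q)   # (8, 8, 16, 4) == (features ÷ nheads, nheads, len_seq, batch_size)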
