Skip to content
This repository was archived by the owner on Sep 28, 2024. It is now read-only.

Commit c0e1a75

Browse files
authored
Merge pull request #10 from foldfelis/nD
extend SpectralConv to n-dim
2 parents db8a1b6 + 01167c5 commit c0e1a75

File tree

8 files changed

+173
-51
lines changed

8 files changed

+173
-51
lines changed

Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
1212
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
1313
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
1414
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
15+
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1516
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
1617
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1718

@@ -20,9 +21,11 @@ CUDA = "3.3"
2021
CUDAKernels = "0.3"
2122
DataDeps = "0.7"
2223
FFTW = "1.4"
24+
Fetch = "0.1"
2325
Flux = "0.12"
2426
KernelAbstractions = "0.7"
2527
MAT = "0.10"
28+
StatsBase = "0.33"
2629
Tullio = "0.3"
2730
Zygote = "0.6"
2831
julia = "1.6"

docs/Manifest.toml

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -73,10 +73,10 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
7373
version = "0.4.1"
7474

7575
[[CUDA]]
76-
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
77-
git-tree-sha1 = "889889f1c13467406a126cd2789b4844487ddfc1"
76+
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
77+
git-tree-sha1 = "9303b20dfa74e4bcb4da425d351d551fbb5850be"
7878
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
79-
version = "3.3.5"
79+
version = "3.4.0"
8080

8181
[[CUDAKernels]]
8282
deps = ["Adapt", "CUDA", "Cassette", "KernelAbstractions", "SpecialFunctions", "StaticArrays"]
@@ -91,9 +91,9 @@ version = "0.3.7"
9191

9292
[[ChainRules]]
9393
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"]
94-
git-tree-sha1 = "11567f2471013449c2fcf119f674c681484a130e"
94+
git-tree-sha1 = "6615deb51db68c3fa0cc7b34e5399c15f63fa97e"
9595
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
96-
version = "1.5.1"
96+
version = "1.7.0"
9797

9898
[[ChainRulesCore]]
9999
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
@@ -168,9 +168,9 @@ version = "1.0.3"
168168

169169
[[DiffRules]]
170170
deps = ["NaNMath", "Random", "SpecialFunctions"]
171-
git-tree-sha1 = "85d2d9e2524da988bffaf2a381864e20d2dae08d"
171+
git-tree-sha1 = "3ed8fa7178a10d1cd0f1ca524f249ba6937490c0"
172172
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
173-
version = "1.2.1"
173+
version = "1.3.0"
174174

175175
[[Distributed]]
176176
deps = ["Random", "Serialization", "Sockets"]
@@ -211,9 +211,9 @@ version = "3.3.9+8"
211211

212212
[[Fetch]]
213213
deps = ["Base64", "HTTP", "JSON3", "Random", "StructTypes"]
214-
git-tree-sha1 = "805a7f0edd71138f053b572613e918ef147625f0"
214+
git-tree-sha1 = "84ba4219db49572bc3020589e77db293707aad51"
215215
uuid = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
216-
version = "0.1.0"
216+
version = "0.1.1"
217217

218218
[[FillArrays]]
219219
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
@@ -247,15 +247,15 @@ version = "0.2.3"
247247

248248
[[GPUArrays]]
249249
deps = ["Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
250-
git-tree-sha1 = "8034b1a19f7a19743c53cda450fcc65d1b8f7ab5"
250+
git-tree-sha1 = "8fac1cf7d6ce0f2249c7acaf25d22e1e85c4a07f"
251251
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
252-
version = "8.0.1"
252+
version = "8.0.2"
253253

254254
[[GPUCompiler]]
255255
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
256-
git-tree-sha1 = "f26f15d9c353f7091065390ea826df9e03917e58"
256+
git-tree-sha1 = "4ed2616d5e656c8716736b64da86755467f26cf5"
257257
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
258-
version = "0.12.8"
258+
version = "0.12.9"
259259

260260
[[HDF5]]
261261
deps = ["Blosc", "Compat", "HDF5_jll", "Libdl", "Mmap", "Random", "Requires"]
@@ -468,7 +468,7 @@ version = "0.3.5"
468468
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"
469469

470470
[[NeuralOperators]]
471-
deps = ["CUDA", "CUDAKernels", "DataDeps", "FFTW", "Fetch", "Flux", "KernelAbstractions", "MAT", "Tullio", "Zygote"]
471+
deps = ["CUDA", "CUDAKernels", "DataDeps", "FFTW", "Fetch", "Flux", "KernelAbstractions", "MAT", "StatsBase", "Tullio", "Zygote"]
472472
path = ".."
473473
uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
474474
version = "0.1.0"

src/NeuralOperators.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ module NeuralOperators
22
using DataDeps
33
using Fetch
44
using MAT
5+
using StatsBase
56

67
using Flux
78
using FFTW

src/data.jl

Lines changed: 61 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,61 @@
11
export
2-
get_burgers_data
2+
UnitGaussianNormalizer,
3+
encode,
4+
decode,
5+
get_burgers_data,
6+
get_darcy_flow_data
37

4-
function register_datasets()
8+
struct UnitGaussianNormalizer{T}
9+
mean::Array{T}
10+
std::Array{T}
11+
ϵ::T
12+
end
13+
14+
function UnitGaussianNormalizer(𝐱; ϵ=1f-5)
15+
dims = 1:ndims(𝐱)-1
16+
17+
return UnitGaussianNormalizer(mean(𝐱, dims=dims), StatsBase.std(𝐱, dims=dims), ϵ)
18+
end
19+
20+
encode(n::UnitGaussianNormalizer, 𝐱::AbstractArray) = @. (𝐱-n.mean) / (n.std+n.ϵ)
21+
decode(n::UnitGaussianNormalizer, 𝐱::AbstractArray) = @. 𝐱 * (n.std+n.ϵ) + n.mean
22+
23+
24+
function register_burgers()
525
register(DataDep(
6-
"BurgersR10",
26+
"Burgers",
727
"""
828
Burgers' equation dataset from
929
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
1030
""",
11-
"https://drive.google.com/file/d/16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe/view?usp=sharing",
31+
"https://drive.google.com/file/d/17MYsKzxUQVaLMWodzPbffR8hhDHoadPp/view?usp=sharing",
1232
"9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd",
1333
fetch_method=gdownload,
1434
post_fetch_method=unpack
1535
))
1636
end
1737

38+
function register_darcy_flow()
39+
register(DataDep(
40+
"DarcyFlow",
41+
"""
42+
Darcy flow dataset from
43+
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
44+
""",
45+
"https://drive.google.com/file/d/1zzVMuGhOG70EnR5L24LWqmX9-Wh_H5Wu/view?usp=sharing",
46+
"802825de9da7398407296c99ca9ceb2371c752f6a3bdd1801172e02ce19edda4",
47+
fetch_method=gdownload,
48+
post_fetch_method=unpack
49+
))
50+
end
51+
52+
function register_datasets()
53+
register_burgers()
54+
register_darcy_flow()
55+
end
56+
1857
function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)
19-
file = matopen(joinpath(datadep"BurgersR10", "burgers_data_R10.mat"))
58+
file = matopen(joinpath(datadep"Burgers", "burgers_data_R10.mat"))
2059
x_data = T.(collect(read(file, "a")[1:n, 1:Δsamples:end]'))
2160
y_data = T.(collect(read(file, "u")[1:n, 1:Δsamples:end]'))
2261
close(file)
@@ -27,3 +66,20 @@ function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples
2766

2867
return x_loc_data, y_data
2968
end
69+
70+
function get_darcy_flow_data(; n=1024, Δsamples=5, T=Float32, test_data=false)
71+
# size(training_data) == size(testing_data) == (1024, 421, 421)
72+
file = test_data ? "piececonst_r421_N1024_smooth2.mat" : "piececonst_r421_N1024_smooth1.mat"
73+
file = matopen(joinpath(datadep"DarcyFlow", file))
74+
x_data = T.(permutedims(read(file, "coeff")[1:n, 1:Δsamples:end, 1:Δsamples:end], (3, 2, 1)))
75+
y_data = T.(permutedims(read(file, "sol")[1:n, 1:Δsamples:end, 1:Δsamples:end], (3, 2, 1)))
76+
close(file)
77+
78+
x_dims = pushfirst!([size(x_data)...], 1)
79+
y_dims = pushfirst!([size(y_data)...], 1)
80+
x_data, y_data = reshape(x_data, x_dims...), reshape(y_data, y_dims...)
81+
82+
x_normalizer, y_normalizer = UnitGaussianNormalizer(x_data), UnitGaussianNormalizer(y_data)
83+
84+
return encode(x_normalizer, x_data), encode(y_normalizer, y_data), x_normalizer, y_normalizer
85+
end

src/fourier.jl

Lines changed: 29 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,62 @@
11
export
2-
SpectralConv1d,
2+
SpectralConv,
33
FourierOperator
44

5-
struct SpectralConv1d{T, S}
5+
struct SpectralConv{N, T, S}
66
weight::T
77
in_channel::S
88
out_channel::S
9-
modes::S
9+
modes::NTuple{N, S}
10+
ndim::S
1011
σ
1112
end
1213

1314
c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...)*im
1415

15-
function SpectralConv1d(
16-
ch::Pair{<:Integer, <:Integer},
17-
modes::Integer,
16+
function SpectralConv(
17+
ch::Pair{S, S},
18+
modes::NTuple{N, S},
1819
σ=identity;
1920
init=c_glorot_uniform,
2021
T::DataType=ComplexF32
21-
)
22+
) where {S<:Integer, N}
2223
in_chs, out_chs = ch
2324
scale = one(T) / (in_chs * out_chs)
24-
weights = scale * init(out_chs, in_chs, modes)
25+
weights = scale * init(out_chs, in_chs, prod(modes))
2526

26-
return SpectralConv1d(weights, in_chs, out_chs, modes, σ)
27+
return SpectralConv(weights, in_chs, out_chs, modes, N, σ)
2728
end
2829

29-
Flux.@functor SpectralConv1d
30+
Flux.@functor SpectralConv
3031

32+
Base.ndims(::SpectralConv{N}) where {N} = N
33+
34+
# [prod(m.modes), out_chs, batch] <- [prod(m.modes), in_chs, batch] * [out_chs, in_chs, prod(m.modes)]
3135
spectral_conv(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]
3236

33-
function (m::SpectralConv1d)(𝐱::AbstractArray)
34-
𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (2, 1, 3)) # [x, in_chs, batch] <- [in_chs, x, batch]
35-
𝐱_fft = fft(𝐱ᵀ, 1) # [x, in_chs, batch]
37+
function (m::SpectralConv)(𝐱::AbstractArray)
38+
n_dims = ndims(𝐱)
39+
40+
𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (ntuple(i->i+1, ndims(m))..., 1, ndims(m)+2)) # [x, in_chs, batch] <- [in_chs, x, batch]
41+
𝐱_fft = fft(𝐱ᵀ, 1:ndims(m)) # [x, in_chs, batch]
42+
43+
𝐱_flattened = reshape(view(𝐱_fft, map(d->1:d, m.modes)..., :, :), :, size(𝐱_fft, n_dims-1), size(𝐱_fft, n_dims))
44+
𝐱_weighted = spectral_conv(𝐱_flattened, m.weight) # [prod(m.modes), out_chs, batch], only 3-dims
45+
𝐱_shaped = reshape(𝐱_weighted, m.modes..., size(𝐱_weighted, 2), size(𝐱_weighted, 3))
3646

37-
# [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes]
38-
𝐱_weighted = spectral_conv(view(𝐱_fft, 1:m.modes, :, :), m.weight)
3947
# [x, out_chs, batch] <- [modes, out_chs, batch]
40-
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, size(𝐱_fft, 1)-m.modes, Base.tail(size(𝐱_weighted))...), dims=1)
48+
pad = zeros(ComplexF32, ntuple(i->size(𝐱_fft, i)-m.modes[i], ndims(m))..., size(𝐱_shaped, n_dims-1), size(𝐱_shaped, n_dims))
49+
𝐱_padded = cat(𝐱_shaped, pad, dims=1:ndims(m))
4150

42-
𝐱_out = ifft(𝐱_padded, 1) # [x, out_chs, batch]
43-
𝐱_outᵀ = permutedims(real(𝐱_out), (2, 1, 3)) # [out_chs, x, batch] <- [x, out_chs, batch]
51+
𝐱_out = ifft(𝐱_padded, 1:ndims(m)) # [x, out_chs, batch]
52+
𝐱_outᵀ = permutedims(real(𝐱_out), (ndims(m)+1, 1:ndims(m)..., ndims(m)+2)) # [out_chs, x, batch] <- [x, out_chs, batch]
4453

4554
return m.σ.(𝐱_outᵀ)
4655
end
4756

48-
function FourierOperator(ch::Pair{<:Integer, <:Integer}, modes::Integer, σ=identity)
57+
function FourierOperator(ch::Pair{S, S}, modes::NTuple{N, S}, σ=identity) where {S<:Integer, N}
4958
return Chain(
50-
Parallel(+, Dense(ch.first, ch.second), SpectralConv1d(ch, modes)),
59+
Parallel(+, Dense(ch.first, ch.second), SpectralConv(ch, modes)),
5160
x -> σ.(x)
5261
)
5362
end

src/model.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ export
22
FourierNeuralOperator
33

44
function FourierNeuralOperator()
5-
modes = 16
5+
modes = (16, )
66
ch = 64 => 64
77
σ = relu
88

test/data.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,21 @@
44
@test size(xs) == (2, 1024, 1000)
55
@test size(ys) == (1024, 1000)
66
end
7+
8+
@testset "unit gaussian normalizer" begin
9+
dims = (3, 3, 5, 6)
10+
𝐱 = rand(Float32, dims)
11+
12+
n = UnitGaussianNormalizer(𝐱)
13+
14+
@test size(n.mean) == size(n.std)
15+
@test size(encode(n, 𝐱)) == dims
16+
@test size(decode(n, encode(n, 𝐱))) == dims
17+
end
18+
19+
@testset "get darcy flow data" begin
20+
xs, ys, _, _ = get_darcy_flow_data()
21+
22+
@test size(xs) == (1, 85, 85, 1024)
23+
@test size(ys) == (1, 85, 85, 1024)
24+
end

test/fourier.jl

Lines changed: 46 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,69 @@
11
@testset "SpectralConv1d" begin
2-
modes = 16
2+
modes = (16, )
33
ch = 64 => 64
44

55
m = Chain(
66
Dense(2, 64),
7-
SpectralConv1d(ch, modes)
7+
SpectralConv(ch, modes)
88
)
9+
@test ndims(SpectralConv(ch, modes)) == 1
910

10-
𝐱, _ = get_burgers_data(n=1000)
11-
@test size(m(𝐱)) == (64, 1024, 1000)
11+
𝐱, _ = get_burgers_data(n=5)
12+
@test size(m(𝐱)) == (64, 1024, 5)
1213

13-
T = Float32
1414
loss(x, y) = Flux.mse(m(x), y)
15-
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
15+
data = [(𝐱, rand(Float32, 64, 1024, 5))]
1616
Flux.train!(loss, params(m), data, Flux.ADAM())
1717
end
1818

19-
@testset "FourierOperator" begin
20-
modes = 16
19+
@testset "FourierOperator1d" begin
20+
modes = (16, )
2121
ch = 64 => 64
2222

2323
m = Chain(
2424
Dense(2, 64),
2525
FourierOperator(ch, modes)
2626
)
2727

28-
𝐱, _ = get_burgers_data(n=1000)
29-
@test size(m(𝐱)) == (64, 1024, 1000)
28+
𝐱, _ = get_burgers_data(n=5)
29+
@test size(m(𝐱)) == (64, 1024, 5)
3030

3131
loss(x, y) = Flux.mse(m(x), y)
32-
data = [(Float32.(𝐱[:, :, 1:5]), rand(Float32, 64, 1024, 5))]
32+
data = [(𝐱, rand(Float32, 64, 1024, 5))]
33+
Flux.train!(loss, params(m), data, Flux.ADAM())
34+
end
35+
36+
@testset "SpectralConv2d" begin
37+
modes = (16, 16)
38+
ch = 64 => 64
39+
40+
m = Chain(
41+
Dense(1, 64),
42+
SpectralConv(ch, modes)
43+
)
44+
@test ndims(SpectralConv(ch, modes)) == 2
45+
46+
𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20)
47+
@test size(m(𝐱)) == (64, 22, 22, 5)
48+
49+
loss(x, y) = Flux.mse(m(x), y)
50+
data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
51+
Flux.train!(loss, params(m), data, Flux.ADAM())
52+
end
53+
54+
@testset "FourierOperator2d" begin
55+
modes = (16, 16)
56+
ch = 64 => 64
57+
58+
m = Chain(
59+
Dense(1, 64),
60+
FourierOperator(ch, modes)
61+
)
62+
63+
𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20)
64+
@test size(m(𝐱)) == (64, 22, 22, 5)
65+
66+
loss(x, y) = Flux.mse(m(x), y)
67+
data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
3368
Flux.train!(loss, params(m), data, Flux.ADAM())
3469
end

0 commit comments

Comments
 (0)