
extend SpectralConv to n-dim #10

Merged
14 commits merged on Aug 15, 2021
3 changes: 3 additions & 0 deletions Project.toml
@@ -12,6 +12,7 @@ Fetch = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
MAT = "23992714-dd62-5051-b70f-ba57cb901cac"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
Tullio = "bc48ee85-29a4-5162-ae0b-a64e1601d4bc"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

@@ -20,9 +21,11 @@ CUDA = "3.3"
CUDAKernels = "0.3"
DataDeps = "0.7"
FFTW = "1.4"
Fetch = "0.1"
Flux = "0.12"
KernelAbstractions = "0.7"
MAT = "0.10"
StatsBase = "0.33"
Tullio = "0.3"
Zygote = "0.6"
julia = "1.6"
28 changes: 14 additions & 14 deletions docs/Manifest.toml
@@ -73,10 +73,10 @@ uuid = "fa961155-64e5-5f13-b03f-caf6b980ea82"
version = "0.4.1"

[[CUDA]]
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "DataStructures", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
git-tree-sha1 = "889889f1c13467406a126cd2789b4844487ddfc1"
deps = ["AbstractFFTs", "Adapt", "BFloat16s", "CEnum", "CompilerSupportLibraries_jll", "ExprTools", "GPUArrays", "GPUCompiler", "LLVM", "LazyArtifacts", "Libdl", "LinearAlgebra", "Logging", "Printf", "Random", "Random123", "RandomNumbers", "Reexport", "Requires", "SparseArrays", "SpecialFunctions", "TimerOutputs"]
git-tree-sha1 = "9303b20dfa74e4bcb4da425d351d551fbb5850be"
uuid = "052768ef-5323-5732-b1bb-66c8b64840ba"
version = "3.3.5"
version = "3.4.0"

[[CUDAKernels]]
deps = ["Adapt", "CUDA", "Cassette", "KernelAbstractions", "SpecialFunctions", "StaticArrays"]
@@ -91,9 +91,9 @@ version = "0.3.7"

[[ChainRules]]
deps = ["ChainRulesCore", "Compat", "LinearAlgebra", "Random", "Statistics"]
git-tree-sha1 = "11567f2471013449c2fcf119f674c681484a130e"
git-tree-sha1 = "6615deb51db68c3fa0cc7b34e5399c15f63fa97e"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.5.1"
version = "1.7.0"

[[ChainRulesCore]]
deps = ["Compat", "LinearAlgebra", "SparseArrays"]
@@ -168,9 +168,9 @@ version = "1.0.3"

[[DiffRules]]
deps = ["NaNMath", "Random", "SpecialFunctions"]
git-tree-sha1 = "85d2d9e2524da988bffaf2a381864e20d2dae08d"
git-tree-sha1 = "3ed8fa7178a10d1cd0f1ca524f249ba6937490c0"
uuid = "b552c78f-8df3-52c6-915a-8e097449b14b"
version = "1.2.1"
version = "1.3.0"

[[Distributed]]
deps = ["Random", "Serialization", "Sockets"]
@@ -211,9 +211,9 @@ version = "3.3.9+8"

[[Fetch]]
deps = ["Base64", "HTTP", "JSON3", "Random", "StructTypes"]
git-tree-sha1 = "805a7f0edd71138f053b572613e918ef147625f0"
git-tree-sha1 = "84ba4219db49572bc3020589e77db293707aad51"
uuid = "bb354801-46f6-40b6-9c3d-d42d7a74c775"
version = "0.1.0"
version = "0.1.1"

[[FillArrays]]
deps = ["LinearAlgebra", "Random", "SparseArrays", "Statistics"]
@@ -247,15 +247,15 @@ version = "0.2.3"

[[GPUArrays]]
deps = ["Adapt", "LinearAlgebra", "Printf", "Random", "Serialization", "Statistics"]
git-tree-sha1 = "8034b1a19f7a19743c53cda450fcc65d1b8f7ab5"
git-tree-sha1 = "8fac1cf7d6ce0f2249c7acaf25d22e1e85c4a07f"
uuid = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
version = "8.0.1"
version = "8.0.2"

[[GPUCompiler]]
deps = ["ExprTools", "InteractiveUtils", "LLVM", "Libdl", "Logging", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "f26f15d9c353f7091065390ea826df9e03917e58"
git-tree-sha1 = "4ed2616d5e656c8716736b64da86755467f26cf5"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.12.8"
version = "0.12.9"

[[HDF5]]
deps = ["Blosc", "Compat", "HDF5_jll", "Libdl", "Mmap", "Random", "Requires"]
@@ -468,7 +468,7 @@ version = "0.3.5"
uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908"

[[NeuralOperators]]
deps = ["CUDA", "CUDAKernels", "DataDeps", "FFTW", "Fetch", "Flux", "KernelAbstractions", "MAT", "Tullio", "Zygote"]
deps = ["CUDA", "CUDAKernels", "DataDeps", "FFTW", "Fetch", "Flux", "KernelAbstractions", "MAT", "StatsBase", "Tullio", "Zygote"]
path = ".."
uuid = "ea5c82af-86e5-48da-8ee1-382d6ad7af4b"
version = "0.1.0"
1 change: 1 addition & 0 deletions src/NeuralOperators.jl
@@ -2,6 +2,7 @@ module NeuralOperators
using DataDeps
using Fetch
using MAT
using StatsBase

using Flux
using FFTW
66 changes: 61 additions & 5 deletions src/data.jl
@@ -1,22 +1,61 @@
export
get_burgers_data
UnitGaussianNormalizer,
encode,
decode,
get_burgers_data,
get_darcy_flow_data

function register_datasets()
struct UnitGaussianNormalizer{T}
mean::Array{T}
std::Array{T}
ϵ::T
end

function UnitGaussianNormalizer(𝐱; ϵ=1f-5)
dims = 1:ndims(𝐱)-1

return UnitGaussianNormalizer(mean(𝐱, dims=dims), StatsBase.std(𝐱, dims=dims), ϵ)
end

encode(n::UnitGaussianNormalizer, 𝐱::AbstractArray) = @. (𝐱-n.mean) / (n.std+n.ϵ)
decode(n::UnitGaussianNormalizer, 𝐱::AbstractArray) = @. 𝐱 * (n.std+n.ϵ) + n.mean


function register_burgers()
register(DataDep(
"BurgersR10",
"Burgers",
"""
Burgers' equation dataset from
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
""",
"https://drive.google.com/file/d/16a8od4vidbiNR3WtaBPCSZ0T3moxjhYe/view?usp=sharing",
"https://drive.google.com/file/d/17MYsKzxUQVaLMWodzPbffR8hhDHoadPp/view?usp=sharing",
"9cbbe5070556c777b1ba3bacd49da5c36ea8ed138ba51b6ee76a24b971066ecd",
fetch_method=gdownload,
post_fetch_method=unpack
))
end

function register_darcy_flow()
register(DataDep(
"DarcyFlow",
"""
Darcy flow dataset from
[fourier_neural_operator](https://github.com/zongyi-li/fourier_neural_operator)
""",
"https://drive.google.com/file/d/1zzVMuGhOG70EnR5L24LWqmX9-Wh_H5Wu/view?usp=sharing",
"802825de9da7398407296c99ca9ceb2371c752f6a3bdd1801172e02ce19edda4",
fetch_method=gdownload,
post_fetch_method=unpack
))
end

function register_datasets()
register_burgers()
register_darcy_flow()
end

function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples), T=Float32)
file = matopen(joinpath(datadep"BurgersR10", "burgers_data_R10.mat"))
file = matopen(joinpath(datadep"Burgers", "burgers_data_R10.mat"))
x_data = T.(collect(read(file, "a")[1:n, 1:Δsamples:end]'))
y_data = T.(collect(read(file, "u")[1:n, 1:Δsamples:end]'))
close(file)
@@ -27,3 +27,66 @@ function get_burgers_data(; n=2048, Δsamples=2^3, grid_size=div(2^13, Δsamples

return x_loc_data, y_data
end

function get_darcy_flow_data(; n=1024, Δsamples=5, T=Float32, test_data=false)
# size(training_data) == size(testing_data) == (1024, 421, 421)
file = test_data ? "piececonst_r421_N1024_smooth2.mat" : "piececonst_r421_N1024_smooth1.mat"
file = matopen(joinpath(datadep"DarcyFlow", file))
x_data = T.(permutedims(read(file, "coeff")[1:n, 1:Δsamples:end, 1:Δsamples:end], (3, 2, 1)))
y_data = T.(permutedims(read(file, "sol")[1:n, 1:Δsamples:end, 1:Δsamples:end], (3, 2, 1)))
close(file)

x_dims = pushfirst!([size(x_data)...], 1)
y_dims = pushfirst!([size(y_data)...], 1)
x_data, y_data = reshape(x_data, x_dims...), reshape(y_data, y_dims...)

x_normalizer, y_normalizer = UnitGaussianNormalizer(x_data), UnitGaussianNormalizer(y_data)

return encode(x_normalizer, x_data), encode(y_normalizer, y_data), x_normalizer, y_normalizer
end
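
For context, a minimal usage sketch of the new loader and normalizer (shapes follow the defaults `n=1024`, `Δsamples=5`; the session below is illustrative and not part of the diff):

```julia
using NeuralOperators

# get_darcy_flow_data returns already-encoded inputs/targets plus their normalizers
xs, ys, x_normalizer, y_normalizer = get_darcy_flow_data()

size(xs)  # (1, 85, 85, 1024): [channel, x, y, sample]
size(ys)  # (1, 85, 85, 1024)

# decode maps model outputs (or the encoded targets) back to the original scale
ys_raw = decode(y_normalizer, ys)
```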
49 changes: 29 additions & 20 deletions src/fourier.jl
@@ -1,53 +1,62 @@
export
SpectralConv1d,
SpectralConv,
FourierOperator

struct SpectralConv1d{T, S}
struct SpectralConv{N, T, S}
weight::T
in_channel::S
out_channel::S
modes::S
modes::NTuple{N, S}
ndim::S
σ
end

c_glorot_uniform(dims...) = Flux.glorot_uniform(dims...) + Flux.glorot_uniform(dims...)*im

function SpectralConv1d(
ch::Pair{<:Integer, <:Integer},
modes::Integer,
function SpectralConv(
ch::Pair{S, S},
modes::NTuple{N, S},
σ=identity;
init=c_glorot_uniform,
T::DataType=ComplexF32
)
) where {S<:Integer, N}
in_chs, out_chs = ch
scale = one(T) / (in_chs * out_chs)
weights = scale * init(out_chs, in_chs, modes)
weights = scale * init(out_chs, in_chs, prod(modes))

return SpectralConv1d(weights, in_chs, out_chs, modes, σ)
return SpectralConv(weights, in_chs, out_chs, modes, N, σ)
end

Flux.@functor SpectralConv1d
Flux.@functor SpectralConv

Base.ndims(::SpectralConv{N}) where {N} = N

# [prod(m.modes), out_chs, batch] <- [prod(m.modes), in_chs, batch] * [out_chs, in_chs, prod(m.modes)]
spectral_conv(𝐱₁, 𝐱₂) = @tullio 𝐲[m, o, b] := 𝐱₁[m, i, b] * 𝐱₂[o, i, m]

function (m::SpectralConv1d)(𝐱::AbstractArray)
𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (2, 1, 3)) # [x, in_chs, batch] <- [in_chs, x, batch]
𝐱_fft = fft(𝐱ᵀ, 1) # [x, in_chs, batch]
function (m::SpectralConv)(𝐱::AbstractArray)
n_dims = ndims(𝐱)

𝐱ᵀ = permutedims(Zygote.hook(real, 𝐱), (ntuple(i->i+1, ndims(m))..., 1, ndims(m)+2)) # [x, in_chs, batch] <- [in_chs, x, batch]
𝐱_fft = fft(𝐱ᵀ, 1:ndims(m)) # [x, in_chs, batch]

𝐱_flattened = reshape(view(𝐱_fft, map(d->1:d, m.modes)..., :, :), :, size(𝐱_fft, n_dims-1), size(𝐱_fft, n_dims))
𝐱_weighted = spectral_conv(𝐱_flattened, m.weight) # [prod(m.modes), out_chs, batch], only 3-dims
𝐱_shaped = reshape(𝐱_weighted, m.modes..., size(𝐱_weighted, 2), size(𝐱_weighted, 3))

# [modes, out_chs, batch] <- [modes, in_chs, batch] * [out_chs, in_chs, modes]
𝐱_weighted = spectral_conv(view(𝐱_fft, 1:m.modes, :, :), m.weight)
# [x, out_chs, batch] <- [modes, out_chs, batch]
𝐱_padded = cat(𝐱_weighted, zeros(ComplexF32, size(𝐱_fft, 1)-m.modes, Base.tail(size(𝐱_weighted))...), dims=1)
pad = zeros(ComplexF32, ntuple(i->size(𝐱_fft, i)-m.modes[i], ndims(m))..., size(𝐱_shaped, n_dims-1), size(𝐱_shaped, n_dims))
𝐱_padded = cat(𝐱_shaped, pad, dims=1:ndims(m))

𝐱_out = ifft(𝐱_padded, 1) # [x, out_chs, batch]
𝐱_outᵀ = permutedims(real(𝐱_out), (2, 1, 3)) # [out_chs, x, batch] <- [x, out_chs, batch]
𝐱_out = ifft(𝐱_padded, 1:ndims(m)) # [x, out_chs, batch]
𝐱_outᵀ = permutedims(real(𝐱_out), (ndims(m)+1, 1:ndims(m)..., ndims(m)+2)) # [out_chs, x, batch] <- [x, out_chs, batch]

return m.σ.(𝐱_outᵀ)
end

function FourierOperator(ch::Pair{<:Integer, <:Integer}, modes::Integer, σ=identity)
function FourierOperator(ch::Pair{S, S}, modes::NTuple{N, S}, σ=identity) where {S<:Integer, N}
return Chain(
Parallel(+, Dense(ch.first, ch.second), SpectralConv1d(ch, modes)),
Parallel(+, Dense(ch.first, ch.second), SpectralConv(ch, modes)),
x -> σ.(x)
)
end
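
A short sketch of how the generalized constructors are meant to be called, assuming the signatures shown in this diff (the 2-d case mirrors the new tests below):

```julia
using Flux, NeuralOperators

# 1-d: the tuple form replaces the old SpectralConv1d(64 => 64, 16)
layer1d = SpectralConv(64 => 64, (16,))
ndims(layer1d)  # 1

# 2-d: keep the lowest 16×16 Fourier modes per channel
layer2d = SpectralConv(64 => 64, (16, 16), relu)

# FourierOperator pairs a Dense skip connection with the spectral convolution
block = Chain(Dense(1, 64), FourierOperator(64 => 64, (16, 16), relu))
x = rand(Float32, 1, 22, 22, 5)   # [in_chs, x, y, batch]
size(block(x))                    # expected (64, 22, 22, 5)
```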
2 changes: 1 addition & 1 deletion src/model.jl
@@ -2,7 +2,7 @@ export
FourierNeuralOperator

function FourierNeuralOperator()
modes = 16
modes = (16, )
ch = 64 => 64
σ = relu

18 changes: 18 additions & 0 deletions test/data.jl
@@ -4,3 +4,21 @@
@test size(xs) == (2, 1024, 1000)
@test size(ys) == (1024, 1000)
end

@testset "unit gaussian normalizer" begin
dims = (3, 3, 5, 6)
𝐱 = rand(Float32, dims)

n = UnitGaussianNormalizer(𝐱)

@test size(n.mean) == size(n.std)
@test size(encode(n, 𝐱)) == dims
@test size(decode(n, encode(n, 𝐱))) == dims
end

@testset "get darcy flow data" begin
xs, ys, _, _ = get_darcy_flow_data()

@test size(xs) == (1, 85, 85, 1024)
@test size(ys) == (1, 85, 85, 1024)
end
57 changes: 46 additions & 11 deletions test/fourier.jl
@@ -1,34 +1,69 @@
@testset "SpectralConv1d" begin
modes = 16
modes = (16, )
ch = 64 => 64

m = Chain(
Dense(2, 64),
SpectralConv1d(ch, modes)
SpectralConv(ch, modes)
)
@test ndims(SpectralConv(ch, modes)) == 1

𝐱, _ = get_burgers_data(n=1000)
@test size(m(𝐱)) == (64, 1024, 1000)
𝐱, _ = get_burgers_data(n=5)
@test size(m(𝐱)) == (64, 1024, 5)

T = Float32
loss(x, y) = Flux.mse(m(x), y)
data = [(T.(𝐱[:, :, 1:5]), rand(T, 64, 1024, 5))]
data = [(𝐱, rand(Float32, 64, 1024, 5))]
Flux.train!(loss, params(m), data, Flux.ADAM())
end

@testset "FourierOperator" begin
modes = 16
@testset "FourierOperator1d" begin
modes = (16, )
ch = 64 => 64

m = Chain(
Dense(2, 64),
FourierOperator(ch, modes)
)

𝐱, _ = get_burgers_data(n=1000)
@test size(m(𝐱)) == (64, 1024, 1000)
𝐱, _ = get_burgers_data(n=5)
@test size(m(𝐱)) == (64, 1024, 5)

loss(x, y) = Flux.mse(m(x), y)
data = [(Float32.(𝐱[:, :, 1:5]), rand(Float32, 64, 1024, 5))]
data = [(𝐱, rand(Float32, 64, 1024, 5))]
Flux.train!(loss, params(m), data, Flux.ADAM())
end

@testset "SpectralConv2d" begin
modes = (16, 16)
ch = 64 => 64

m = Chain(
Dense(1, 64),
SpectralConv(ch, modes)
)
@test ndims(SpectralConv(ch, modes)) == 2

𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20)
@test size(m(𝐱)) == (64, 22, 22, 5)

loss(x, y) = Flux.mse(m(x), y)
data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
Flux.train!(loss, params(m), data, Flux.ADAM())
end

@testset "FourierOperator2d" begin
modes = (16, 16)
ch = 64 => 64

m = Chain(
Dense(1, 64),
FourierOperator(ch, modes)
)

𝐱, _, _, _ = get_darcy_flow_data(n=5, Δsamples=20)
@test size(m(𝐱)) == (64, 22, 22, 5)

loss(x, y) = Flux.mse(m(x), y)
data = [(𝐱, rand(Float32, 64, 22, 22, 5))]
Flux.train!(loss, params(m), data, Flux.ADAM())
end