Commit 4bf0a3d

Merge pull request #232 from FluxML/metalhead-update
Bump compat for Metalhead
2 parents 19dc08b + fa7133b commit 4bf0a3d

9 files changed: +70 -50 lines changed

Project.toml (+3 -2)

@@ -21,12 +21,13 @@ ColorTypes = "0.10.3, 0.11"
 ComputationalResources = "0.3.2"
 Flux = "0.13, 0.14"
 MLJModelInterface = "1.1.1"
-Metalhead = "0.7"
+Metalhead = "0.8"
 ProgressMeter = "1.7.1"
 Tables = "1.0"
 julia = "1.6"
 
 [extras]
+cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -36,4 +37,4 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["LinearAlgebra", "MLJBase", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"]
+test = ["cuDNN", "LinearAlgebra", "MLJBase", "Random", "StableRNGs", "Statistics", "StatsBase", "Test"]

src/metalhead.jl (+22 -16)

@@ -4,10 +4,10 @@ TODO: After https://github.com/FluxML/Metalhead.jl/issues/176:
 
 - Export and externally document `image_builder` method
 
-- Delete definition of `ResNetHack` below
+- Delete definition of `VGGHack` below
 
 - Change default builder in ImageClassifier (see /src/types.jl) from
-  `image_builder(ResNetHack)` to `image_builder(Metalhead.ResNet)`.
+  `image_builder(VGGHack)` to `image_builder(Metalhead.VGG)`.
 
 =#
@@ -51,7 +51,7 @@ Base.show(io::IO, w::MetalheadBuilder) =
 
 Return an MLJFlux builder object based on the Metalhead.jl constructor/type
 `metalhead_constructor` (eg, `Metalhead.ResNet`). Here `args` and `kwargs` are
-passed to the `MetalheadType` constructor at "build time", along with
+passed as arguments to `metalhead_constructor` at "build time", along with
 the extra keyword specifiers `imsize=...`, `inchannels=...` and
 `nclasses=...`, with values inferred from the data.
 
@@ -61,14 +61,14 @@ If in Metalhead.jl you would do
 
 ```julia
 using Metalhead
-model = ResNet(50, pretrain=true, inchannels=1, nclasses=10)
+model = ResNet(50, pretrain=false, inchannels=1, nclasses=10)
 ```
 
 then in MLJFlux, it suffices to do
 
 ```julia
 using MLJFlux, Metalhead
-builder = image_builder(ResNet, 50, pretrain=true)
+builder = image_builder(ResNet, 50, pretrain=false)
 ```
 
 which can be used in `ImageClassifier` as in
@@ -122,25 +122,31 @@ function VGGHack(
     pretrain=false,
 )
 
-    # Adapted from
-    # https://github.com/FluxML/Metalhead.jl/blob/9edff63222720ff84671b8087dd71eb370a6c35a/src/convnets/vgg.jl#L165
+    # Adapted from
+    # https://github.com/FluxML/Metalhead.jl/blob/4e5b8f16964468518eeb6eb8d7e5f85af4ecf959/src/convnets/vgg.jl#L161
     # But we do not ignore `imsize`.
 
     @assert(
-        depth in keys(Metalhead.vgg_config),
-        "depth must be from one in $(sort(collect(keys(Metalhead.vgg_config))))"
+        depth in keys(Metalhead.VGG_CONFIGS),
+        "depth must be from one in $(sort(collect(keys(Metalhead.VGG_CONFIGS))))"
     )
     model = Metalhead.VGG(imsize;
-        config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]],
+        config = Metalhead.VGG_CONV_CONFIGS[Metalhead.VGG_CONFIGS[depth]],
         inchannels,
         batchnorm,
         nclasses,
-        fcsize = 4096,
-        dropout = 0.5)
-    if pretrain && !batchnorm
-        Metalhead.loadpretrain!(model, string("VGG", depth))
-    elseif pretrain
-        Metalhead.loadpretrain!(model, "VGG$(depth)-BN)")
+        dropout_prob = 0.5)
+    if pretrain
+        imsize == (224, 224) || @warn "Using `pretrain=true` may not work unless "*
+            "image size is `(224, 224)`, which it is not. "
+        artifact_name = string("vgg", depth)
+        if batchnorm
+            artifact_name *= "_bn"
+        else
+            artifact_name *= "-IMAGENET1K_V1"
+        end
+        loadpretrain!(model, artifact_name)
     end
+
     return model
 end
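The updated docstring's example cuts off at "which can be used in `ImageClassifier` as in". A minimal sketch of that usage, assuming the `ImageClassifier` keyword API and qualifying `image_builder` (per the TODO above, it is not yet exported):

```julia
using MLJFlux, Metalhead

# Builder wrapping Metalhead's ResNet-50; `imsize`, `inchannels` and
# `nclasses` are inferred from the training data at fit time.
builder = MLJFlux.image_builder(Metalhead.ResNet, 50, pretrain=false)

clf = MLJFlux.ImageClassifier(builder=builder, epochs=5, batch_size=32)

# Hypothetical data: `images` a vector of images, `labels` categorical,
# e.g. as produced by `MLJFlux.make_images` in the test suite.
# using MLJBase: machine, fit!
# mach = machine(clf, images, labels)
# fit!(mach)
```

Note the new artifact naming in `VGGHack`: with `pretrain=true`, depth 16 requests the `"vgg16_bn"` weights when `batchnorm=true`, and `"vgg16-IMAGENET1K_V1"` otherwise.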

src/mlj_model_interface.jl (+5 -4)

@@ -40,7 +40,7 @@ end
 
 # # FIT AND UPDATE
 
-const ERR_BUILDER = 
+const ERR_BUILDER =
     "Builder does not appear to build an architecture compatible with supplied data. "
 
 true_rng(model) = model.rng isa Integer ? MersenneTwister(model.rng) : model.rng
@@ -60,17 +60,18 @@ function MLJModelInterface.fit(model::MLJFluxModel,
     catch ex
         @error ERR_BUILDER
     end
-    
+
     penalty = Penalty(model)
     data = move.(collate(model, X, y))
 
-    x = data |> first |> first
+    x = data[1][1]
+
     try
         chain(x)
     catch ex
         @error ERR_BUILDER
         throw(ex)
-    end 
+    end
 
     optimiser = deepcopy(model.optimiser)
 
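The `x = data[1][1]` change swaps pipe syntax for direct indexing: both pull the first input mini-batch out of the collated data so the freshly built chain can be smoke-tested before training. A self-contained sketch, with hypothetical batch shapes standing in for the output of `move.(collate(model, X, y))`:

```julia
using Flux

# Stand-in for the collated data: an (x_batches, y_batches) pair holding,
# say, five mini-batches of four samples with three features each.
x_batches = [rand(Float32, 3, 4) for _ in 1:5]
y_batches = [rand(Float32, 1, 4) for _ in 1:5]
data = (x_batches, y_batches)

x = data[1][1]                          # first input mini-batch
@assert x === (data |> first |> first)  # the two spellings agree

chain = Flux.Dense(3 => 1)
chain(x)  # the smoke test: an error here signals a builder/data mismatch
```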
test/builders.jl (+1 -1)

@@ -47,7 +47,7 @@ end
 
     # reproducibility (without dropout):
     chain2 = MLJFlux.build(builder, StableRNGs.StableRNG(123), 5, 3)
-    x = rand(5)
+    x = rand(Float32, 5)
     @test chain(x) ≈ chain2(x)
 end
 
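The switch to `rand(Float32, 5)` keeps the test input aligned with Flux's default `Float32` layer parameters; feeding `Float64` data into a `Float32` chain triggers Flux's eltype-mismatch warning and a slower mixed-precision path. A small illustration (not from the test suite):

```julia
using Flux

chain = Flux.Dense(5 => 3)      # parameters are Float32 by default
y = chain(rand(Float32, 5))     # clean: eltypes match throughout
y_mixed = chain(rand(5))        # Float64 input: Flux warns about the
                                # Float32/Float64 mismatch
```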

test/core.jl (+5 -5)

@@ -7,7 +7,7 @@ rowvec(y::Vector) = reshape(y, 1, length(y))
 @test MLJFlux.MLJModelInterface.istransparent(Flux.Adam(0.1))
 
 @testset "nrows" begin
-    Xmatrix = rand(stable_rng, 10, 3)
+    Xmatrix = rand(stable_rng, Float32, 10, 3)
     X = MLJBase.table(Xmatrix)
     @test MLJFlux.nrows(X) == 10
     @test MLJFlux.nrows(Tables.columntable(X)) == 10
@@ -19,7 +19,7 @@ end
     # convert to a column table:
     X = MLJBase.table(Xmatrix)
 
-    y = rand(stable_rng, 10)
+    y = rand(stable_rng, Float32, 10)
     model = MLJFlux.NeuralNetworkRegressor()
     model.batch_size= 3
     @test MLJFlux.collate(model, X, y) ==
@@ -38,7 +38,7 @@ end
                            reshape([1; 0], (2,1))]))
 
     # MultitargetNeuralNetworRegressor:
-    ymatrix = rand(stable_rng, 10, 2)
+    ymatrix = rand(stable_rng, Float32, 10, 2)
     y = MLJBase.table(ymatrix) # a rowaccess table
     model = MLJFlux.NeuralNetworkRegressor()
     model.batch_size= 3
@@ -54,7 +54,7 @@ end
                            ymatrix'[:,7:9], ymatrix'[:,10:10]]))
 
     # ImageClassifier
-    Xmatrix = coerce(rand(stable_rng, 6, 6, 1, 10), GrayImage)
+    Xmatrix = coerce(rand(stable_rng, Float32, 6, 6, 1, 10), GrayImage)
     y = categorical(['a', 'b', 'a', 'a', 'b', 'a', 'a', 'a', 'b', 'a'])
     model = MLJFlux.ImageClassifier(batch_size=2)
 
@@ -69,7 +69,7 @@ end
 
 end
 
-Xmatrix = rand(stable_rng, 100, 5)
+Xmatrix = rand(stable_rng, Float32, 100, 5)
 X = MLJBase.table(Xmatrix)
 y = Xmatrix[:, 1] + Xmatrix[:, 2] + Xmatrix[:, 3] +
     Xmatrix[:, 4] + Xmatrix[:, 5]
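All five changes in this file are the same fix: generating `Float32` test data so inputs match Flux's default parameter eltype. The `ImageClassifier` case additionally coerces the raw array to an image scitype; a short sketch of what that line produces, assuming `coerce` and `GrayImage` are in scope as the test's `using` statements provide:

```julia
using MLJBase

raw = rand(Float32, 6, 6, 1, 10)  # width × height × channels × n_images
images = coerce(raw, GrayImage)   # a length-10 vector of 6×6 grey images
@assert length(images) == 10
```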

test/image.jl (+23 -12)

@@ -8,17 +8,29 @@ mutable struct MyNeuralNetwork <: MLJFlux.Builder
     kernel2
 end
 
-function MLJFlux.build(model::MyNeuralNetwork, rng, ip, op, n_channels)
+# to get a matrix whose last dimension matches that of the array input (the batch size):
+function make2d(x)
+    l = length(x)
+    b = size(x)[end]
+    reshape(x, div(l, b), b)
+end
+
+function MLJFlux.build(builder::MyNeuralNetwork, rng, ip, op, n_channels)
     init = Flux.glorot_uniform(rng)
-    Flux.Chain(
-        Flux.Conv(model.kernel1, n_channels=>2, init=init),
-        Flux.Conv(model.kernel2, 2=>1, init=init),
-        x->reshape(x, :, size(x)[end]),
-        Flux.Dense(16, op, init=init))
+    front = Flux.Chain(
+        Flux.Conv(builder.kernel1, n_channels=>2, init=init),
+        Flux.Conv(builder.kernel2, 2=>1, init=init),
+        make2d,
+    )
+    d = Flux.outputsize(front, (ip..., n_channels, 1))[1]
+    return Flux.Chain(
+        front,
+        Flux.Dense(d, op, init=init)
+    )
 end
 
 builder = MyNeuralNetwork((2,2), (2,2))
-images, labels = MLJFlux.make_images(stable_rng)
+images, labels = MLJFlux.make_images(stable_rng);
 losses = []
 
 @testset_accelerated "ImageClassifier basic tests" accel begin
@@ -69,8 +81,6 @@ reference = losses[1]
 
 # # BASIC IMAGE TESTS COLOR
 
-# In this case we use the default ResNet builder
-
 builder = MyNeuralNetwork((2,2), (2,2))
 images, labels = MLJFlux.make_images(stable_rng, color=true)
 losses = []
@@ -112,12 +122,13 @@ reference = losses[1]
 @test all(x->abs(x - reference)/reference < 1e-5, losses[2:end])
 
 
-# # SMOKE TEST FOR DEFAULT BUILDER 
+# # SMOKE TEST FOR DEFAULT BUILDER
 
-images, labels = MLJFlux.make_images(stable_rng, image_size=(32, 32), n_images=12, noise=0.2, color=true);
+images, labels = MLJFlux.make_images(stable_rng, image_size=(32, 32), n_images=12,
+                                     noise=0.2, color=true);
 
 @testset_accelerated "ImageClassifier basic tests" accel begin
-    model = MLJFlux.ImageClassifier(epochs=10,
+    model = MLJFlux.ImageClassifier(epochs=5,
                                     batch_size=4,
                                     acceleration=accel,
                                     rng=stable_rng)
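The rewritten test builder no longer hardcodes `Flux.Dense(16, op, ...)`: it derives the flattened width with `Flux.outputsize`, which propagates shapes through the convolutional front end without evaluating real data. A standalone sketch of the pattern, using the test's 2×2 kernels and a 6×6 single-channel input (the size for which the old hardcoded 16 was correct):

```julia
using Flux

front = Flux.Chain(
    Flux.Conv((2, 2), 1 => 2),         # 6×6 -> 5×5, 2 channels
    Flux.Conv((2, 2), 2 => 1),         # 5×5 -> 4×4, 1 channel
    x -> reshape(x, :, size(x)[end]),  # flatten to (features, batch)
)

# Shape inference over a dummy (width, height, channels, batch) signature:
d = Flux.outputsize(front, (6, 6, 1, 1))[1]  # 4*4*1 = 16
head = Flux.Dense(d => 3)                    # no magic constant
```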

test/metalhead.jl (+9 -10)

@@ -42,16 +42,15 @@ end
     @test builder.metalhead_constructor == Metalhead.VGG
     @test builder.args == (depth, )
     @test (; builder.kwargs...) == (; batchnorm=true)
-    ref_chain = Metalhead.VGG(
-        imsize;
-        config = Metalhead.vgg_conv_config[Metalhead.vgg_config[depth]],
-        inchannels,
-        batchnorm=true,
-        nclasses,
-        fcsize = 4096,
-        dropout = 0.5
-    )
-    # needs https://github.com/FluxML/Metalhead.jl/issues/176
+
+    ## needs https://github.com/FluxML/Metalhead.jl/issues/176:
+    # ref_chain = Metalhead.VGG(
+    #     imsize;
+    #     config = Metalhead.VGG_CONV_CONFIGS[Metalhead.VGG_CONFIGS[depth]],
+    #     inchannels,
+    #     batchnorm=true,
+    #     nclasses,
+    # )
     # chain =
     #     MLJFlux.build(builder, StableRNGs.StableRNG(123), imsize, nclasses, inchannels)
     # @test length.(MLJFlux.Flux.params(ref_chain)) ==

test/mlj_model_interface.jl (+1)

@@ -44,6 +44,7 @@ end
 
     # integration test:
     X, y = MLJBase.make_regression(10)
+    X = Float32.(MLJBase.Tables.matrix(X)) |> MLJBase.Tables.table
     mach = MLJBase.machine(model, X, y)
     MLJBase.fit!(mach, verbosity=0)
     losses = MLJBase.training_losses(mach)

test/runtests.jl (+1)

@@ -10,6 +10,7 @@ import Random.seed!
 using Statistics
 import StatsBase
 using StableRNGs
+using cuDNN
 
 using ComputationalResources
 using ComputationalResources: CPU1, CUDALibs
