Commit f2a8883

updates, rm some Optimisers detail
1 parent cf2e7a9 commit f2a8883

4 files changed: +29 -83 lines changed

docs/src/training/optimisers.md (+8 -75)

@@ -4,53 +4,24 @@ CurrentModule = Flux
 
 # [Optimisers](@id man-optimisers)
 
-Consider a [simple linear regression](../models/basics.md). We create some dummy data, calculate a loss, and backpropagate to calculate gradients for the parameters `W` and `b`.
+Flux builds in many optimisation rules for use with [`train!`](@ref Flux.Optimise.train!) and
+other training functions.
 
-```julia
-using Flux
-
-W = rand(2, 5)
-b = rand(2)
-
-predict(x) = (W * x) .+ b
-loss(x, y) = sum((predict(x) .- y).^2)
+The mechanism by which these work is gradually being replaced as part of the change
+from "implicit" dictionary-based to "explicit" tree-like structures.
+At present, the same struct (such as `Adam`) can be used with either form,
+and will be automatically translated.
 
-x, y = rand(5), rand(2) # Dummy data
-l = loss(x, y) # ~ 3
-
-θ = Flux.params(W, b)
-grads = gradient(() -> loss(x, y), θ)
-```
+For full details of how the new "explicit" interface works, see the [Optimisers.jl documentation](https://fluxml.ai/Optimisers.jl/dev/).
 
-We want to update each parameter, using the gradient, in order to improve (reduce) the loss. Here's one way to do that:
-
-```julia
-η = 0.1 # Learning Rate
-for p in (W, b)
-  p .-= η * grads[p]
-end
-```
+For full details on how the "implicit" interface worked, see the [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/optimisers/#Optimiser-Interface).
 
-Running this will alter the parameters `W` and `b` and our loss should go down. Flux provides a more general way to do optimiser updates like this.
-
-```julia
-using Flux: update!
-
-opt = Descent(0.1) # Gradient descent with learning rate 0.1
-
-for p in (W, b)
-  update!(opt, p, grads[p])
-end
-```
-
-An optimiser `update!` accepts a parameter and a gradient, and updates the parameter according to the chosen rule. We can also pass `opt` to our [training loop](training.md), which will update all parameters of the model in a loop. However, we can now easily replace `Descent` with a more advanced optimiser such as `Adam`.
 
 ## Optimiser Reference
 
 All optimisers return an object that, when passed to `train!`, will update the parameters passed to it.
 
 ```@docs
-Flux.Optimise.update!
 Descent
 Momentum
 Nesterov
@@ -67,44 +38,6 @@ OAdam
 AdaBelief
 ```
 
-## Optimiser Interface
-
-Flux's optimisers are built around a `struct` that holds all the optimiser parameters along with a definition of how to apply the update rule associated with it. We do this via the `apply!` function which takes the optimiser as the first argument followed by the parameter and its corresponding gradient.
-
-In this manner Flux also allows one to create custom optimisers to be used seamlessly. Let's work on this with a simple example.
-
-```julia
-mutable struct Momentum
-  eta
-  rho
-  velocity
-end
-
-Momentum(eta::Real, rho::Real) = Momentum(eta, rho, IdDict())
-```
-
-The `Momentum` type will act as our optimiser in this case. Notice that we have added all the parameters as fields, along with the velocity which we will use as our state dictionary. Each parameter in our models will get an entry in there. We can now define the rule applied when this optimiser is invoked.
-
-```julia
-function Flux.Optimise.apply!(o::Momentum, x, Δ)
-  η, ρ = o.eta, o.rho
-  v = get!(o.velocity, x, zero(x))::typeof(x)
-  @. v = ρ * v - η * Δ
-  @. Δ = -v
-end
-```
-
-This is the basic definition of a Momentum update rule given by:
-
-```math
-v = ρ * v - η * Δ
-w = w - v
-```
-
-The `apply!` defines the update rules for an optimiser `opt`, given the parameters and gradients. It returns the updated gradients. Here, every parameter `x` is retrieved from the running state `v` and subsequently updates the state of the optimiser.
-
-Flux internally calls on this function via the `update!` function. It shares the API with `apply!` but ensures that multiple parameters are handled gracefully.
-
 ## Composing Optimisers
 
 Flux defines a special kind of optimiser simply called `Optimiser` which takes in arbitrary optimisers as input. Its behaviour is similar to the usual optimisers, but differs in that it acts by calling the optimisers listed in it sequentially. Each optimiser produces a modified gradient
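Illustration (not part of this commit): the new wording above says the same rule struct, such as `Adam`, can be used with both the implicit and the explicit interface. A minimal sketch of that claim, assuming the Flux 0.13.9-style `setup`/`update!` API and a made-up toy model:

```julia
using Flux

model = Dense(5 => 2)                       # toy model, invented for this sketch
x, y = rand(Float32, 5), rand(Float32, 2)
loss(m, x, y) = sum(abs2, m(x) .- y)

rule = Adam(0.01)                           # one rule struct, used both ways below

# Explicit style: per-model state from setup, then an update on the model tree
opt_state = Flux.setup(rule, model)
grad = Flux.gradient(m -> loss(m, x, y), model)[1]
Flux.update!(opt_state, model, grad)

# Implicit style (Flux <= 0.13): parameters collected into Params, state kept in an IdDict
ps = Flux.params(model)
gs = Flux.gradient(() -> loss(model, x, y), ps)
Flux.Optimise.update!(rule, ps, gs)
```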

docs/src/training/train_api.md (+12 -7)

@@ -1,10 +1,16 @@
 # Training API
 
-
 ```@docs
 Flux.Train.setup
-Flux.Train.update!
-Flux.Train.train!
+Flux.Optimise.train!(loss, model, data, opt; cb)
+```
+
+The new version of Flux's training code was written as an independent package, called Optimisers.jl.
+However, at present all Flux models contain parameter arrays (such as `Array`s and `CuArray`s)
+which can be updated in-place. Thus objects returned by `update!` can be ignored.
+
+```@docs
+Optimisers.update!
 ```
 
 ## Implicit style
@@ -15,14 +21,12 @@ Flux 0.13 is the transitional version which supports both.
 
 For full details on how to use the implicit style, see [Flux 0.13.6 manual](https://fluxml.ai/Flux.jl/v0.13.6/training/training/).
 
-
 ```@docs
 Flux.params
-Flux.Optimise.update!
-Flux.Optimise.train!
+Optimisers.update!(opt::Flux.Optimise.AbstractOptimiser, xs::Params, gs)
+Flux.Optimise.train!(loss, ps::Flux.Params, data, opt::Flux.Optimise.AbstractOptimiser; cb)
 ```
 
-
 Note that, by default, `train!` only loops over the data once (a single "epoch").
 A convenient way to run multiple epochs from the REPL is provided by `@epochs`.
 
@@ -69,3 +73,4 @@ cb = function ()
   accuracy() > 0.9 && Flux.stop()
 end
 ```
+
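Illustration (not from the commit): the paragraph added above says the objects returned by `update!` can be ignored because Flux models hold mutable parameter arrays. A rough sketch of that point, again assuming the Flux 0.13.9-style API and an invented model:

```julia
using Flux

model = Dense(3 => 1)                       # invented example model
opt_state = Flux.setup(Descent(0.1), model)

x, y = rand(Float32, 3), rand(Float32, 1)
grad = Flux.gradient(m -> sum(abs2, m(x) .- y), model)[1]

# update! returns (state, model), but because the parameter arrays are ordinary
# mutable Arrays they are modified in place, so the return value can be discarded:
Flux.update!(opt_state, model, grad)
```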

src/Flux.jl (+3 -1)

@@ -34,9 +34,11 @@ export Descent, Adam, Momentum, Nesterov, RMSProp,
   AdamW, RAdam, AdaBelief, InvDecay, ExpDecay,
   WeightDecay, ClipValue, ClipNorm
 
+export ClipGrad, OptimiserChain # these are const defined in deprecations, for ClipValue, Optimiser
+
 include("train.jl")
 using .Train
-# using .Train: setup, @train_autodiff
+using .Train: setup
 
 using CUDA
 const use_cuda = Ref{Union{Nothing,Bool}}(nothing)
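Note on the new exports (an assumption based on the diff comment, not code from this commit): if the deprecations file defines `const ClipGrad = ClipValue` and `const OptimiserChain = Optimiser`, then the Optimisers.jl-style names simply construct the existing `Flux.Optimise` types:

```julia
using Flux

# Sketch only: relies on the aliases the diff comment describes.
opt_old = Flux.Optimise.Optimiser(ClipValue(1f-3), Adam())
opt_new = OptimiserChain(ClipGrad(1f-3), Adam())   # same types under the assumed aliases
```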

src/optimise/optimisers.jl (+6 -0)

@@ -564,6 +564,9 @@ end
 Combine several optimisers into one; each optimiser produces a modified gradient
 that will be fed into the next, and this is finally applied to the parameter as
 usual.
+
+!!! note
+    This will be replaced by `Optimisers.OptimiserChain` in Flux 0.14.
 """
 mutable struct Optimiser <: AbstractOptimiser
   os::Vector{Any}
@@ -699,6 +702,9 @@ end
     ClipValue(thresh)
 
 Clip gradients when their absolute value exceeds `thresh`.
+
+!!! note
+    This will be replaced by `Optimisers.ClipGrad` in Flux 0.14.
 """
 mutable struct ClipValue{T} <: AbstractOptimiser
   thresh::T
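Illustration of the two added notes (not part of the commit): the Flux rules and the Optimisers.jl rules that the notes say will replace them in Flux 0.14 are constructed along these lines, assuming Optimisers.jl is available in the environment:

```julia
using Flux
import Optimisers

# Flux <= 0.13 style, as described by the `Optimiser` and `ClipValue` docstrings:
opt_013 = Flux.Optimise.Optimiser(ClipValue(1f-2), Adam(3f-4))

# The Optimisers.jl counterparts named in the added notes:
opt_014 = Optimisers.OptimiserChain(Optimisers.ClipGrad(1f-2), Optimisers.Adam(3f-4))
```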
