From 04ecb4eec657c62edb957be1b1608bd98af2d9cc Mon Sep 17 00:00:00 2001
From: Charles Knipp
Date: Wed, 8 Jan 2025 13:04:25 -0500
Subject: [PATCH 01/33] fixed type stability of linear filter

---
 src/algorithms/kalman.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/algorithms/kalman.jl b/src/algorithms/kalman.jl
index 2702b1f..825aedd 100644
--- a/src/algorithms/kalman.jl
+++ b/src/algorithms/kalman.jl
@@ -11,7 +11,7 @@ function initialise(
     rng::AbstractRNG, model::LinearGaussianStateSpaceModel, filter::KalmanFilter; kwargs...
 )
     μ0, Σ0 = calc_initial(model.dyn; kwargs...)
-    return Gaussian(μ0, Σ0)
+    return Gaussian(μ0, Matrix(Σ0))
 end
 
 function predict(

From 152917b6a02c8f61238ffae22912493d02243ff9 Mon Sep 17 00:00:00 2001
From: Charles Knipp
Date: Wed, 8 Jan 2025 13:05:10 -0500
Subject: [PATCH 02/33] added MLE demonstration

---
 research/maximum_likelihood/Project.toml |  7 +++
 research/maximum_likelihood/mle_demo.jl  | 72 ++++++++++++++++++++++++
 2 files changed, 79 insertions(+)
 create mode 100644 research/maximum_likelihood/Project.toml
 create mode 100644 research/maximum_likelihood/mle_demo.jl

diff --git a/research/maximum_likelihood/Project.toml b/research/maximum_likelihood/Project.toml
new file mode 100644
index 0000000..abf67da
--- /dev/null
+++ b/research/maximum_likelihood/Project.toml
@@ -0,0 +1,7 @@
+[deps]
+DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
+ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
+LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
+Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
diff --git a/research/maximum_likelihood/mle_demo.jl b/research/maximum_likelihood/mle_demo.jl
new file mode 100644
index 0000000..c438da6
--- /dev/null
+++ b/research/maximum_likelihood/mle_demo.jl
@@ -0,0 +1,72 @@
+using GeneralisedFilters
+using SSMProblems
+using LinearAlgebra
+using Random
+
+## TOY MODEL ###############################################################################
+
+# this is taken from an example in Kalman.jl
+function toy_model(θ::T) where {T<:Real}
+    μ0 = T[1.0, 0.0]
+    Σ0 = Diagonal(ones(T, 2))
+
+    A = T[0.8 θ/2; -0.1 0.8]
+    Q = Diagonal(T[0.2, 1.0])
+    b = zeros(T, 2)
+
+    H = Matrix{T}(I, 1, 2)
+    R = Diagonal(T[0.2])
+    c = zeros(T, 1)
+
+    return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R)
+end
+
+# data generation process
+rng = MersenneTwister(1234)
+true_model = toy_model(1.0)
+_, _, ys = sample(rng, true_model, 10000)
+
+# evaluate and return the log evidence
+function logℓ(θ, data)
+    rng = MersenneTwister(1234)
+    _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), KF(), data)
+    return ll
+end
+
+# check type stability (important for use with Enzyme)
+@code_warntype logℓ([1.0], ys)
+
+## MLE #####################################################################################
+
+using DifferentiationInterface
+using ForwardDiff
+using Optimisers
+
+# initial value
+θ = [0.7]
+
+# setup optimiser (feel free to use other backends)
+state = Optimisers.setup(Optimisers.Descent(0.5), θ)
+backend = AutoForwardDiff()
+num_epochs = 1000
+
+# prepare gradients for faster AD
+grad_prep = prepare_gradient(logℓ, backend, θ, Constant(ys))
+hess_prep = prepare_hessian(logℓ, backend, θ, Constant(ys))
+
+for epoch in 1:num_epochs
+    # calculate gradients
+    val, ∇logℓ = DifferentiationInterface.value_and_gradient(
+        logℓ, grad_prep, backend, θ, Constant(ys)
+    )
+
+    # adjust the learning
rate for a hacky Newton's method + H = DifferentiationInterface.hessian(logℓ, hess_prep, backend, θ, Constant(ys)) + Optimisers.update!(state, θ, inv(H)*∇logℓ) + + # stopping condition and printer + (epoch % 5) == 1 && println("$(epoch-1):\t $(θ[])") + if (∇logℓ'*∇logℓ) < 1e-12 + break + end +end From efe1573c95be91dd9bd76ed327ce21c66951b1e9 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 8 Jan 2025 16:11:02 -0500 Subject: [PATCH 03/33] flipped sign of objective function --- research/maximum_likelihood/mle_demo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/research/maximum_likelihood/mle_demo.jl b/research/maximum_likelihood/mle_demo.jl index c438da6..c7f86fa 100644 --- a/research/maximum_likelihood/mle_demo.jl +++ b/research/maximum_likelihood/mle_demo.jl @@ -30,7 +30,7 @@ _, _, ys = sample(rng, true_model, 10000) function logℓ(θ, data) rng = MersenneTwister(1234) _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), KF(), data) - return ll + return -ll end # check type stability (important for use with Enzyme) From acbc4e11e793ff9bb7400eb6046ed4850e99a0a3 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Fri, 7 Mar 2025 15:34:04 -0500 Subject: [PATCH 04/33] reorganized and added Mooncake MWE --- research/maximum_likelihood/Project.toml | 4 ++ research/maximum_likelihood/mle_demo.jl | 4 +- research/maximum_likelihood/mooncake_test.jl | 63 ++++++++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 research/maximum_likelihood/mooncake_test.jl diff --git a/research/maximum_likelihood/Project.toml b/research/maximum_likelihood/Project.toml index abf67da..d2b1a8a 100644 --- a/research/maximum_likelihood/Project.toml +++ b/research/maximum_likelihood/Project.toml @@ -1,7 +1,11 @@ [deps] DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GeneralisedFilters = "3ef92589-7ab8-43f9-b5b9-a3a0c86ecbb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" diff --git a/research/maximum_likelihood/mle_demo.jl b/research/maximum_likelihood/mle_demo.jl index c7f86fa..7f5e595 100644 --- a/research/maximum_likelihood/mle_demo.jl +++ b/research/maximum_likelihood/mle_demo.jl @@ -36,10 +36,10 @@ end # check type stability (important for use with Enzyme) @code_warntype logℓ([1.0], ys) -## MLE ##################################################################################### +## NEWTONS METHOD ########################################################################## using DifferentiationInterface -using ForwardDiff +import ForwardDiff using Optimisers # initial value diff --git a/research/maximum_likelihood/mooncake_test.jl b/research/maximum_likelihood/mooncake_test.jl new file mode 100644 index 0000000..f2ae5af --- /dev/null +++ b/research/maximum_likelihood/mooncake_test.jl @@ -0,0 +1,63 @@ +using GeneralisedFilters +using SSMProblems +using LinearAlgebra +using Random + +## TOY MODEL ############################################################################### + +# this is taken from an example in Kalman.jl +function toy_model(θ::T) where {T<:Real} + μ0 = T[1.0, 0.0] + Σ0 = Diagonal(ones(T, 2)) + + A = T[0.8 θ/2; -0.1 0.8] + Q = Diagonal(T[0.2, 1.0]) + b = zeros(T, 
2)
+
+    H = Matrix{T}(I, 1, 2)
+    R = Diagonal(T[0.2])
+    c = zeros(T, 1)
+
+    return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R)
+end
+
+# data generation process with small sample
+rng = MersenneTwister(1234)
+true_model = toy_model(1.0)
+_, _, ys = sample(rng, true_model, 20)
+
+## RUN MOONCAKE TESTS ######################################################################
+
+using DifferentiationInterface
+import Mooncake
+using DistributionsAD
+
+function build_objective(rng, θ, algo, data)
+    _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), algo, data)
+    return -ll
+end
+
+# kalman filter likelihood testing (works, but is slow)
+logℓ1 = θ -> build_objective(rng, θ, KF(), ys)
+Mooncake.TestUtils.test_rule(rng, logℓ1, [0.7]; is_primitive=false, debug_mode=true)
+
+# bootstrap filter likelihood testing (shouldn't work)
+logℓ2 = θ -> build_objective(rng, θ, BF(512), ys)
+Mooncake.TestUtils.test_rule(rng, logℓ2, [0.7]; is_primitive=false, debug_mode=true)
+
+## FOR USE WITH DIFFERENTIATION INTERFACE ##################################################
+
+# data should be part of the objective, but be held constant by DifferentiationInterface
+logℓ3 = (θ, data) -> build_objective(rng, θ, KF(), data)
+
+# set the backend with default configuration
+backend = AutoMooncake(; config=nothing)
+
+# prepare the gradient for faster subsequent iteration
+grad_prep = prepare_gradient(logℓ3, backend, [0.7], Constant(ys))
+
+# evaluate gradients and iterate to show proof of concept
+DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [0.7], Constant(ys))
+DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [0.8], Constant(ys))
+DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [0.9], Constant(ys))
+DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [1.0], Constant(ys))

From 25eb4d63a5b080d05eea224ce68dfe557b622a07 Mon Sep 17 00:00:00 2001
From: Charles Knipp
Date: Mon, 10 Mar 2025 14:30:20 -0400
Subject: [PATCH 05/33] fixed KF type stability in Enzyme

---
 GeneralisedFilters/src/algorithms/kalman.jl | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl
index 711669d..571f037 100644
--- a/GeneralisedFilters/src/algorithms/kalman.jl
+++ b/GeneralisedFilters/src/algorithms/kalman.jl
@@ -1,6 +1,7 @@
 export KalmanFilter, filter, BatchKalmanFilter
 using GaussianDistributions
 using CUDA: i32
+import LinearAlgebra: Symmetric
 
 export KalmanFilter, KF, KalmanSmoother, KS
 
@@ -42,12 +43,9 @@ function update(
     # Update state
     m = H * μ + c
     y = obs - m
-    S = H * Σ * H' + R
+    S = Symmetric(H * Σ * H' + R)
     K = Σ * H' / S
 
-    # HACK: force the covariance to be positive definite
-    S = (S + S') / 2
-
     filtered = Gaussian(μ + K * y, Σ - K * H * Σ)
 
     # Compute log-likelihood

From 6d061b0f5308f0ab9cbc91bb09f43c37f7708f58 Mon Sep 17 00:00:00 2001
From: Charles Knipp
Date: Tue, 11 Mar 2025 10:34:25 -0400
Subject: [PATCH 06/33] add MWE for Kalman filtering

---
 .../minimum_working_example.jl | 181 ++++++++++++++++++
 1 file changed, 181 insertions(+)
 create mode 100644 research/maximum_likelihood/minimum_working_example.jl

diff --git a/research/maximum_likelihood/minimum_working_example.jl b/research/maximum_likelihood/minimum_working_example.jl
new file mode 100644
index 0000000..5f017b9
--- /dev/null
+++ b/research/maximum_likelihood/minimum_working_example.jl
@@ -0,0 +1,181 @@
+using LinearAlgebra
+using GaussianDistributions
+using Random
+
+using DistributionsAD
+using
Distributions + +using Enzyme + +## MODEL DEFINITION ######################################################################## + +struct LinearGaussianProcess{ + T<:Real, + ΦT<:AbstractMatrix{T}, + ΣT<:AbstractMatrix{T}, + μT<:AbstractVector{T} + } + ϕ::ΦT + Σ::ΣT + μ::μT + function LinearGaussianProcess(ϕ::ΦT, Σ::ΣT, μ::μT) where { + T<:Real, + ΦT<:AbstractMatrix{T}, + ΣT<:AbstractMatrix{T}, + μT<:AbstractVector{T} + } + @assert size(ϕ,1) == size(Σ,1) == size(Σ,2) == size(μ,1) + return new{T, ΦT, ΣT, μT}(ϕ, Σ, μ) + end +end + +# a rather simplified version of GeneralisedFilters.LinearGaussianStateSpaceModel +struct LinearGaussianModel{ + ΘT<:Real, + TT<:LinearGaussianProcess{ΘT}, + OT<:LinearGaussianProcess{ΘT} + } + transition::TT + observation::OT +end + +## KALMAN FILTER ########################################################################### + +# this is based on the algorithm of GeneralisedFilters.jl +function kalman_filter( + model::LinearGaussianModel, + init_state::Gaussian, + observations::Vector{T} + ) where {T<:Real} + log_evidence = zero(T) + filtered = init_state + + # calc_params(model.dyn) + A = model.transition.ϕ + Q = model.transition.Σ + b = model.transition.μ + + # calc_params(model.obs) + H = model.observation.ϕ + R = model.observation.Σ + c = model.observation.μ + + for obs in observations + # predict step + μ, Σ = GaussianDistributions.pair(filtered) + proposed = Gaussian(A*μ + b, A*Σ*A' + Q) + + # update step + μ, Σ = GaussianDistributions.pair(proposed) + m = H*μ + c + residual = [obs] - m + + S = Symmetric(H*Σ*H' + R) + gain = Σ*H' / S + + filtered = Gaussian(μ + gain*residual, Σ - gain*H*Σ) + log_evidence += logpdf(MvNormal(m, S), [obs]) + end + + return log_evidence +end + +## DEMONSTRATION ########################################################################### + +# model constructor +function build_model(θ::T) where {T<:Real} + trans = LinearGaussianProcess( + T[0.8 θ/2; -0.1 0.8], + Diagonal(T[0.2, 1.0]), + zeros(T, 2) + ) + + obs = LinearGaussianProcess( + Matrix{T}(I, 1, 2), + Diagonal(T[0.2]), + zeros(T, 1) + ) + + return LinearGaussianModel(trans, obs) +end + +# log likelihood function +function logℓ(θ::Vector{T}, data) where {T<:Real} + model = build_model(θ[]) + init_state = Gaussian(T[1.0, 0.0], diagm(ones(T, 2))) + return kalman_filter(model, init_state, data) +end + +# refer to data globally (not preferred) +function logℓ_nodata(θ) + return logℓ(θ, data) +end + +# data generation (with unit covariance) +rng = MersenneTwister(1234) +data = cumsum(randn(rng, 100)) .+ randn(rng, 100) + +# ensure that log likelihood looks stable +logℓ([1.0], data) + +## SYNTACTICAL SUGAR ####################################################################### + +# this has no issue behaving well +grad_test, _ = Enzyme.gradient(Enzyme.Reverse, logℓ, [1.0], Const(data)) + +# this error is unlegible (at least to my untrained eye) +Enzyme.hvp(logℓ_nodata, [1.0], [1.0]) + +## FROM SCRATCH ############################################################################ + +function generate_perturbations(::Type{T}, n::Int) where {T<:Real} + perturbation_mat = Matrix{T}(I, n, n) + return tuple(collect.(eachslice(perturbation_mat, dims=1))...) +end + +generate_perturbations(n::Int) = generate_perturbations(Float64, n) +generate_perturbations(x::Vector{T}) where {T<:Real} = generate_perturbations(T, length(x)) + +function make_zeros(::Type{T}, n::Int) where {T<:Real} + return tuple(collect.(zeros(T, n) for _ in 1:n)...) 
+end + +make_zeros(n::Int) = make_zeros(Float64, n) +make_zeros(x::Vector{T}) where {T<:Real} = make_zeros(T, length(x)) + +function ∇logℓ(θ, args...) + ∂θ = Enzyme.make_zero(θ) + ∇logℓ!(θ, ∂θ, args...) + return ∂θ +end + +function ∇logℓ!(θ, ∂θ, args...) + Enzyme.autodiff(Enzyme.Reverse, logℓ, Active, Duplicated(θ, ∂θ), args...) + return nothing +end + +# ensure I'm doing the right thing +@assert grad_test == ∇logℓ([1.0], Const(data)) + +# see https://enzyme.mit.edu/julia/stable/generated/autodiff/#Vector-forward-over-reverse +function hessian(θ::Vector{T}) where {T<:Real} + # generate impulse and record second order responses + dθ = Enzyme.make_zero(θ) + vθ = generate_perturbations(θ) + H = make_zeros(θ) + + # take derivatives + Enzyme.autodiff( + Enzyme.Forward, + ∇logℓ!, + Enzyme.BatchDuplicated(θ, vθ), + Enzyme.BatchDuplicated(dθ, H), + Const(data), + ) + + # stack appropriately + return vcat(H...) +end + +# errors and I don't know Enzyme well enough to figure out why +hessian([1.0]) From f8f0e59bff8c54fe08a0f219cf40167bc63b2326 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Tue, 11 Mar 2025 12:36:23 -0400 Subject: [PATCH 07/33] replaced second order optimizer and added backend testing --- research/maximum_likelihood/Project.toml | 1 + research/maximum_likelihood/mle_demo.jl | 69 ++++++++++++++---------- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/research/maximum_likelihood/Project.toml b/research/maximum_likelihood/Project.toml index d2b1a8a..f4304f6 100644 --- a/research/maximum_likelihood/Project.toml +++ b/research/maximum_likelihood/Project.toml @@ -9,3 +9,4 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/research/maximum_likelihood/mle_demo.jl b/research/maximum_likelihood/mle_demo.jl index 7f5e595..765c00a 100644 --- a/research/maximum_likelihood/mle_demo.jl +++ b/research/maximum_likelihood/mle_demo.jl @@ -2,13 +2,14 @@ using GeneralisedFilters using SSMProblems using LinearAlgebra using Random +using DistributionsAD ## TOY MODEL ############################################################################### # this is taken from an example in Kalman.jl function toy_model(θ::T) where {T<:Real} μ0 = T[1.0, 0.0] - Σ0 = Diagonal(ones(T, 2)) + Σ0 = diagm(ones(T, 2)) A = T[0.8 θ/2; -0.1 0.8] Q = Diagonal(T[0.2, 1.0]) @@ -24,12 +25,11 @@ end # data generation process rng = MersenneTwister(1234) true_model = toy_model(1.0) -_, _, ys = sample(rng, true_model, 10000) +_, _, ys = sample(rng, true_model, 1000) # evaluate and return the log evidence function logℓ(θ, data) - rng = MersenneTwister(1234) - _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), KF(), data) + _, ll = GeneralisedFilters.filter(toy_model(θ[]), KF(), data) return -ll end @@ -39,34 +39,45 @@ end ## NEWTONS METHOD ########################################################################## using DifferentiationInterface -import ForwardDiff +import ForwardDiff, Zygote, Mooncake, Enzyme using Optimisers -# initial value -θ = [0.7] - -# setup optimiser (feel free to use other backends) -state = Optimisers.setup(Optimisers.Descent(0.5), θ) -backend = AutoForwardDiff() -num_epochs = 1000 - -# prepare gradients for faster AD -grad_prep = prepare_gradient(logℓ, backend, θ, Constant(ys)) -hess_prep = prepare_hessian(logℓ, backend, θ, Constant(ys)) - -for epoch in 1:num_epochs - # 
calculate gradients - val, ∇logℓ = DifferentiationInterface.value_and_gradient( - logℓ, grad_prep, backend, θ, Constant(ys) - ) +# Zygote will fail due to the model constructor, not because of the filtering algorithm +backends = [ + AutoZygote(), AutoForwardDiff(), AutoMooncake(;config=nothing), AutoEnzyme() +] + +function gradient_descent(backend, θ_init, num_epochs=1000) + θ = deepcopy(θ_init) + state = Optimisers.setup(Optimisers.Descent(1/length(ys)), θ) + grad_prep = prepare_gradient(logℓ, backend, θ, Constant(ys)) + + for epoch in 1:num_epochs + val, ∇logℓ = DifferentiationInterface.value_and_gradient( + logℓ, grad_prep, backend, θ, Constant(ys) + ) + Optimisers.update!(state, θ, ∇logℓ) + + (epoch % 5) == 1 && println("$(epoch-1):\t -$(val)") + if (∇logℓ'*∇logℓ) < 1e-12 + break + end + end - # adjust the learning rate for a hacky Newton's method - H = DifferentiationInterface.hessian(logℓ, hess_prep, backend, θ, Constant(ys)) - Optimisers.update!(state, θ, inv(H)*∇logℓ) + return θ +end - # stopping condition and printer - (epoch % 5) == 1 && println("$(epoch-1):\t $(θ[])") - if (∇logℓ'*∇logℓ) < 1e-12 - break +θ_init = rand(rng, 1) +for backend in backends + println("\n",backend) + local θ_mle + try + θ_mle = gradient_descent(backend, θ_init) + catch err + # TODO: more sophistocated exception handling + @warn "automatic differentiation failed!" exception = (err) + else + # check that the solution converged to the correct value + @assert isapprox(θ_mle, [1.0]; rtol=1e-1) end end From da6a54b6fab14a63ef1d3a58f883b86b5f2b9784 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 8 Jan 2025 13:04:25 -0500 Subject: [PATCH 08/33] fixed type stability of linear filter --- GeneralisedFilters/src/algorithms/kalman.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl index 51c5aca..bd722b6 100644 --- a/GeneralisedFilters/src/algorithms/kalman.jl +++ b/GeneralisedFilters/src/algorithms/kalman.jl @@ -12,7 +12,7 @@ function initialise( rng::AbstractRNG, model::LinearGaussianStateSpaceModel, filter::KalmanFilter; kwargs... ) μ0, Σ0 = calc_initial(model.dyn; kwargs...) 
- return Gaussian(μ0, Σ0) + return Gaussian(μ0, Matrix(Σ0)) end function predict( From 39d9f32e2a6a96630d9844ee1562bad31289cc70 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 8 Jan 2025 13:05:10 -0500 Subject: [PATCH 09/33] added MLE demonstration --- research/maximum_likelihood/Project.toml | 7 +++ research/maximum_likelihood/mle_demo.jl | 72 ++++++++++++++++++++++++ 2 files changed, 79 insertions(+) create mode 100644 research/maximum_likelihood/Project.toml create mode 100644 research/maximum_likelihood/mle_demo.jl diff --git a/research/maximum_likelihood/Project.toml b/research/maximum_likelihood/Project.toml new file mode 100644 index 0000000..abf67da --- /dev/null +++ b/research/maximum_likelihood/Project.toml @@ -0,0 +1,7 @@ +[deps] +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" diff --git a/research/maximum_likelihood/mle_demo.jl b/research/maximum_likelihood/mle_demo.jl new file mode 100644 index 0000000..c438da6 --- /dev/null +++ b/research/maximum_likelihood/mle_demo.jl @@ -0,0 +1,72 @@ +using GeneralisedFilters +using SSMProblems +using LinearAlgebra +using Random + +## TOY MODEL ############################################################################### + +# this is taken from an example in Kalman.jl +function toy_model(θ::T) where {T<:Real} + μ0 = T[1.0, 0.0] + Σ0 = Diagonal(ones(T, 2)) + + A = T[0.8 θ/2; -0.1 0.8] + Q = Diagonal(T[0.2, 1.0]) + b = zeros(T, 2) + + H = Matrix{T}(I, 1, 2) + R = Diagonal(T[0.2]) + c = zeros(T, 1) + + return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) +end + +# data generation process +rng = MersenneTwister(1234) +true_model = toy_model(1.0) +_, _, ys = sample(rng, true_model, 10000) + +# evaluate and return the log evidence +function logℓ(θ, data) + rng = MersenneTwister(1234) + _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), KF(), data) + return ll +end + +# check type stability (important for use with Enzyme) +@code_warntype logℓ([1.0], ys) + +## MLE ##################################################################################### + +using DifferentiationInterface +using ForwardDiff +using Optimisers + +# initial value +θ = [0.7] + +# setup optimiser (feel free to use other backends) +state = Optimisers.setup(Optimisers.Descent(0.5), θ) +backend = AutoForwardDiff() +num_epochs = 1000 + +# prepare gradients for faster AD +grad_prep = prepare_gradient(logℓ, backend, θ, Constant(ys)) +hess_prep = prepare_hessian(logℓ, backend, θ, Constant(ys)) + +for epoch in 1:num_epochs + # calculate gradients + val, ∇logℓ = DifferentiationInterface.value_and_gradient( + logℓ, grad_prep, backend, θ, Constant(ys) + ) + + # adjust the learning rate for a hacky Newton's method + H = DifferentiationInterface.hessian(logℓ, hess_prep, backend, θ, Constant(ys)) + Optimisers.update!(state, θ, inv(H)*∇logℓ) + + # stopping condition and printer + (epoch % 5) == 1 && println("$(epoch-1):\t $(θ[])") + if (∇logℓ'*∇logℓ) < 1e-12 + break + end +end From 30a4be355cc88d2dc582dc2b6cab7498e5925533 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 8 Jan 2025 16:11:02 -0500 Subject: [PATCH 10/33] flipped sign of objective function --- research/maximum_likelihood/mle_demo.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git 
a/research/maximum_likelihood/mle_demo.jl b/research/maximum_likelihood/mle_demo.jl index c438da6..c7f86fa 100644 --- a/research/maximum_likelihood/mle_demo.jl +++ b/research/maximum_likelihood/mle_demo.jl @@ -30,7 +30,7 @@ _, _, ys = sample(rng, true_model, 10000) function logℓ(θ, data) rng = MersenneTwister(1234) _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), KF(), data) - return ll + return -ll end # check type stability (important for use with Enzyme) From 935df10335790b5bdd713e3dd5b4a587ca04e0ee Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Fri, 7 Mar 2025 15:34:04 -0500 Subject: [PATCH 11/33] reorganized and added Mooncake MWE --- research/maximum_likelihood/Project.toml | 4 ++ research/maximum_likelihood/mle_demo.jl | 4 +- research/maximum_likelihood/mooncake_test.jl | 63 ++++++++++++++++++++ 3 files changed, 69 insertions(+), 2 deletions(-) create mode 100644 research/maximum_likelihood/mooncake_test.jl diff --git a/research/maximum_likelihood/Project.toml b/research/maximum_likelihood/Project.toml index abf67da..d2b1a8a 100644 --- a/research/maximum_likelihood/Project.toml +++ b/research/maximum_likelihood/Project.toml @@ -1,7 +1,11 @@ [deps] DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GeneralisedFilters = "3ef92589-7ab8-43f9-b5b9-a3a0c86ecbb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" diff --git a/research/maximum_likelihood/mle_demo.jl b/research/maximum_likelihood/mle_demo.jl index c7f86fa..7f5e595 100644 --- a/research/maximum_likelihood/mle_demo.jl +++ b/research/maximum_likelihood/mle_demo.jl @@ -36,10 +36,10 @@ end # check type stability (important for use with Enzyme) @code_warntype logℓ([1.0], ys) -## MLE ##################################################################################### +## NEWTONS METHOD ########################################################################## using DifferentiationInterface -using ForwardDiff +import ForwardDiff using Optimisers # initial value diff --git a/research/maximum_likelihood/mooncake_test.jl b/research/maximum_likelihood/mooncake_test.jl new file mode 100644 index 0000000..f2ae5af --- /dev/null +++ b/research/maximum_likelihood/mooncake_test.jl @@ -0,0 +1,63 @@ +using GeneralisedFilters +using SSMProblems +using LinearAlgebra +using Random + +## TOY MODEL ############################################################################### + +# this is taken from an example in Kalman.jl +function toy_model(θ::T) where {T<:Real} + μ0 = T[1.0, 0.0] + Σ0 = Diagonal(ones(T, 2)) + + A = T[0.8 θ/2; -0.1 0.8] + Q = Diagonal(T[0.2, 1.0]) + b = zeros(T, 2) + + H = Matrix{T}(I, 1, 2) + R = Diagonal(T[0.2]) + c = zeros(T, 1) + + return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) +end + +# data generation process with small sample +rng = MersenneTwister(1234) +true_model = toy_model(1.0) +_, _, ys = sample(rng, true_model, 20) + +## RUN MOONCKAE TESTS ###################################################################### + +using DifferentiationInterface +import Mooncake +using DistributionsAD + +function build_objective(rng, θ, algo, data) + _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), algo, data) + 
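+    # the log-evidence is negated so that minimising this objective maximises the likelihood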
return -ll +end + +# kalman filter likelihood testing (works, but is slow) +logℓ1 = θ -> build_objective(rng, θ, KF(), ys) +Mooncake.TestUtils.test_rule(rng, logℓ1, [0.7]; is_primitive=false, debug_mode=true) + +# bootstrap filter likelihood testing (shouldn't work) +logℓ2 = θ -> build_objective(rng, θ, BF(512), ys) +Mooncake.TestUtils.test_rule(rng, logℓ2, [0.7]; is_primitive=false, debug_mode=true) + +## FOR USE WITH DIFFERENTIATION INTERFACE ################################################## + +# data should be part of the objective, but be held constant by DifferentiationInterface +logℓ3 = (θ, data) -> build_objective(rng, θ, KF(), data) + +# set the backend with default configuration +backend = AutoMooncake(; config=nothing) + +# prepare the gradient for faster subsequent iteration +grad_prep = prepare_gradient(logℓ3, backend, [0.7], Constant(ys)) + +# evaluate gradients and iterate to show proof of concept +DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [0.7], Constant(ys)) +DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [0.8], Constant(ys)) +DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [0.9], Constant(ys)) +DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [1.0], Constant(ys)) From 2cc7d636f9dbce7bf0954b00b3cf3a8c427844cd Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Mon, 10 Mar 2025 14:30:20 -0400 Subject: [PATCH 12/33] fixed KF type stability in Enzyme --- GeneralisedFilters/src/algorithms/kalman.jl | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl index bd722b6..ddc279e 100644 --- a/GeneralisedFilters/src/algorithms/kalman.jl +++ b/GeneralisedFilters/src/algorithms/kalman.jl @@ -1,6 +1,7 @@ export KalmanFilter, filter, BatchKalmanFilter using GaussianDistributions using CUDA: i32 +import LinearAlgebra: Symmetric export KalmanFilter, KF, KalmanSmoother, KS @@ -42,12 +43,9 @@ function update( # Update state m = H * μ + c y = obs - m - S = H * Σ * H' + R + S = Symmetric(H * Σ * H' + R) K = Σ * H' / S - # HACK: force the covariance to be positive definite - S = (S + S') / 2 - filtered = Gaussian(μ + K * y, Σ - K * H * Σ) # Compute log-likelihood From 5c9149e44b079de3252cb723de846fe967b87bb2 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Tue, 11 Mar 2025 10:34:25 -0400 Subject: [PATCH 13/33] add MWE for Kalman filtering --- .../minimum_working_example.jl | 181 ++++++++++++++++++ 1 file changed, 181 insertions(+) create mode 100644 research/maximum_likelihood/minimum_working_example.jl diff --git a/research/maximum_likelihood/minimum_working_example.jl b/research/maximum_likelihood/minimum_working_example.jl new file mode 100644 index 0000000..5f017b9 --- /dev/null +++ b/research/maximum_likelihood/minimum_working_example.jl @@ -0,0 +1,181 @@ +using LinearAlgebra +using GaussianDistributions +using Random + +using DistributionsAD +using Distributions + +using Enzyme + +## MODEL DEFINITION ######################################################################## + +struct LinearGaussianProcess{ + T<:Real, + ΦT<:AbstractMatrix{T}, + ΣT<:AbstractMatrix{T}, + μT<:AbstractVector{T} + } + ϕ::ΦT + Σ::ΣT + μ::μT + function LinearGaussianProcess(ϕ::ΦT, Σ::ΣT, μ::μT) where { + T<:Real, + ΦT<:AbstractMatrix{T}, + ΣT<:AbstractMatrix{T}, + μT<:AbstractVector{T} + } + @assert size(ϕ,1) == size(Σ,1) == size(Σ,2) == size(μ,1) + return new{T, ΦT, ΣT, μT}(ϕ, Σ, μ) + end +end + +# a rather simplified version of 
GeneralisedFilters.LinearGaussianStateSpaceModel +struct LinearGaussianModel{ + ΘT<:Real, + TT<:LinearGaussianProcess{ΘT}, + OT<:LinearGaussianProcess{ΘT} + } + transition::TT + observation::OT +end + +## KALMAN FILTER ########################################################################### + +# this is based on the algorithm of GeneralisedFilters.jl +function kalman_filter( + model::LinearGaussianModel, + init_state::Gaussian, + observations::Vector{T} + ) where {T<:Real} + log_evidence = zero(T) + filtered = init_state + + # calc_params(model.dyn) + A = model.transition.ϕ + Q = model.transition.Σ + b = model.transition.μ + + # calc_params(model.obs) + H = model.observation.ϕ + R = model.observation.Σ + c = model.observation.μ + + for obs in observations + # predict step + μ, Σ = GaussianDistributions.pair(filtered) + proposed = Gaussian(A*μ + b, A*Σ*A' + Q) + + # update step + μ, Σ = GaussianDistributions.pair(proposed) + m = H*μ + c + residual = [obs] - m + + S = Symmetric(H*Σ*H' + R) + gain = Σ*H' / S + + filtered = Gaussian(μ + gain*residual, Σ - gain*H*Σ) + log_evidence += logpdf(MvNormal(m, S), [obs]) + end + + return log_evidence +end + +## DEMONSTRATION ########################################################################### + +# model constructor +function build_model(θ::T) where {T<:Real} + trans = LinearGaussianProcess( + T[0.8 θ/2; -0.1 0.8], + Diagonal(T[0.2, 1.0]), + zeros(T, 2) + ) + + obs = LinearGaussianProcess( + Matrix{T}(I, 1, 2), + Diagonal(T[0.2]), + zeros(T, 1) + ) + + return LinearGaussianModel(trans, obs) +end + +# log likelihood function +function logℓ(θ::Vector{T}, data) where {T<:Real} + model = build_model(θ[]) + init_state = Gaussian(T[1.0, 0.0], diagm(ones(T, 2))) + return kalman_filter(model, init_state, data) +end + +# refer to data globally (not preferred) +function logℓ_nodata(θ) + return logℓ(θ, data) +end + +# data generation (with unit covariance) +rng = MersenneTwister(1234) +data = cumsum(randn(rng, 100)) .+ randn(rng, 100) + +# ensure that log likelihood looks stable +logℓ([1.0], data) + +## SYNTACTICAL SUGAR ####################################################################### + +# this has no issue behaving well +grad_test, _ = Enzyme.gradient(Enzyme.Reverse, logℓ, [1.0], Const(data)) + +# this error is unlegible (at least to my untrained eye) +Enzyme.hvp(logℓ_nodata, [1.0], [1.0]) + +## FROM SCRATCH ############################################################################ + +function generate_perturbations(::Type{T}, n::Int) where {T<:Real} + perturbation_mat = Matrix{T}(I, n, n) + return tuple(collect.(eachslice(perturbation_mat, dims=1))...) +end + +generate_perturbations(n::Int) = generate_perturbations(Float64, n) +generate_perturbations(x::Vector{T}) where {T<:Real} = generate_perturbations(T, length(x)) + +function make_zeros(::Type{T}, n::Int) where {T<:Real} + return tuple(collect.(zeros(T, n) for _ in 1:n)...) +end + +make_zeros(n::Int) = make_zeros(Float64, n) +make_zeros(x::Vector{T}) where {T<:Real} = make_zeros(T, length(x)) + +function ∇logℓ(θ, args...) + ∂θ = Enzyme.make_zero(θ) + ∇logℓ!(θ, ∂θ, args...) + return ∂θ +end + +function ∇logℓ!(θ, ∂θ, args...) + Enzyme.autodiff(Enzyme.Reverse, logℓ, Active, Duplicated(θ, ∂θ), args...) 
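+    # Duplicated(θ, ∂θ) has Enzyme accumulate the reverse-mode gradient into ∂θ in place,
+    # which is why this helper mutates its second argument and returns nothing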
+ return nothing +end + +# ensure I'm doing the right thing +@assert grad_test == ∇logℓ([1.0], Const(data)) + +# see https://enzyme.mit.edu/julia/stable/generated/autodiff/#Vector-forward-over-reverse +function hessian(θ::Vector{T}) where {T<:Real} + # generate impulse and record second order responses + dθ = Enzyme.make_zero(θ) + vθ = generate_perturbations(θ) + H = make_zeros(θ) + + # take derivatives + Enzyme.autodiff( + Enzyme.Forward, + ∇logℓ!, + Enzyme.BatchDuplicated(θ, vθ), + Enzyme.BatchDuplicated(dθ, H), + Const(data), + ) + + # stack appropriately + return vcat(H...) +end + +# errors and I don't know Enzyme well enough to figure out why +hessian([1.0]) From 4936c521216b5652484cba9cd9ac840b60b597af Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Tue, 11 Mar 2025 12:36:23 -0400 Subject: [PATCH 14/33] replaced second order optimizer and added backend testing --- research/maximum_likelihood/Project.toml | 1 + research/maximum_likelihood/mle_demo.jl | 69 ++++++++++++++---------- 2 files changed, 41 insertions(+), 29 deletions(-) diff --git a/research/maximum_likelihood/Project.toml b/research/maximum_likelihood/Project.toml index d2b1a8a..f4304f6 100644 --- a/research/maximum_likelihood/Project.toml +++ b/research/maximum_likelihood/Project.toml @@ -9,3 +9,4 @@ Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/research/maximum_likelihood/mle_demo.jl b/research/maximum_likelihood/mle_demo.jl index 7f5e595..765c00a 100644 --- a/research/maximum_likelihood/mle_demo.jl +++ b/research/maximum_likelihood/mle_demo.jl @@ -2,13 +2,14 @@ using GeneralisedFilters using SSMProblems using LinearAlgebra using Random +using DistributionsAD ## TOY MODEL ############################################################################### # this is taken from an example in Kalman.jl function toy_model(θ::T) where {T<:Real} μ0 = T[1.0, 0.0] - Σ0 = Diagonal(ones(T, 2)) + Σ0 = diagm(ones(T, 2)) A = T[0.8 θ/2; -0.1 0.8] Q = Diagonal(T[0.2, 1.0]) @@ -24,12 +25,11 @@ end # data generation process rng = MersenneTwister(1234) true_model = toy_model(1.0) -_, _, ys = sample(rng, true_model, 10000) +_, _, ys = sample(rng, true_model, 1000) # evaluate and return the log evidence function logℓ(θ, data) - rng = MersenneTwister(1234) - _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), KF(), data) + _, ll = GeneralisedFilters.filter(toy_model(θ[]), KF(), data) return -ll end @@ -39,34 +39,45 @@ end ## NEWTONS METHOD ########################################################################## using DifferentiationInterface -import ForwardDiff +import ForwardDiff, Zygote, Mooncake, Enzyme using Optimisers -# initial value -θ = [0.7] - -# setup optimiser (feel free to use other backends) -state = Optimisers.setup(Optimisers.Descent(0.5), θ) -backend = AutoForwardDiff() -num_epochs = 1000 - -# prepare gradients for faster AD -grad_prep = prepare_gradient(logℓ, backend, θ, Constant(ys)) -hess_prep = prepare_hessian(logℓ, backend, θ, Constant(ys)) - -for epoch in 1:num_epochs - # calculate gradients - val, ∇logℓ = DifferentiationInterface.value_and_gradient( - logℓ, grad_prep, backend, θ, Constant(ys) - ) +# Zygote will fail due to the model constructor, not because of the filtering algorithm +backends = [ + AutoZygote(), AutoForwardDiff(), AutoMooncake(;config=nothing), AutoEnzyme() +] + +function 
gradient_descent(backend, θ_init, num_epochs=1000) + θ = deepcopy(θ_init) + state = Optimisers.setup(Optimisers.Descent(1/length(ys)), θ) + grad_prep = prepare_gradient(logℓ, backend, θ, Constant(ys)) + + for epoch in 1:num_epochs + val, ∇logℓ = DifferentiationInterface.value_and_gradient( + logℓ, grad_prep, backend, θ, Constant(ys) + ) + Optimisers.update!(state, θ, ∇logℓ) + + (epoch % 5) == 1 && println("$(epoch-1):\t -$(val)") + if (∇logℓ'*∇logℓ) < 1e-12 + break + end + end - # adjust the learning rate for a hacky Newton's method - H = DifferentiationInterface.hessian(logℓ, hess_prep, backend, θ, Constant(ys)) - Optimisers.update!(state, θ, inv(H)*∇logℓ) + return θ +end - # stopping condition and printer - (epoch % 5) == 1 && println("$(epoch-1):\t $(θ[])") - if (∇logℓ'*∇logℓ) < 1e-12 - break +θ_init = rand(rng, 1) +for backend in backends + println("\n",backend) + local θ_mle + try + θ_mle = gradient_descent(backend, θ_init) + catch err + # TODO: more sophistocated exception handling + @warn "automatic differentiation failed!" exception = (err) + else + # check that the solution converged to the correct value + @assert isapprox(θ_mle, [1.0]; rtol=1e-1) end end From 56c6dc0817cf60192caa270be8cfa2c29e72b392 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Thu, 13 Mar 2025 10:49:01 -0400 Subject: [PATCH 15/33] fixed Mooncake errors for Bootstrap filter --- GeneralisedFilters/src/resamplers.jl | 1 - research/maximum_likelihood/mooncake_test.jl | 30 ++++++-------------- 2 files changed, 8 insertions(+), 23 deletions(-) diff --git a/GeneralisedFilters/src/resamplers.jl b/GeneralisedFilters/src/resamplers.jl index 524c807..d87d545 100644 --- a/GeneralisedFilters/src/resamplers.jl +++ b/GeneralisedFilters/src/resamplers.jl @@ -47,7 +47,6 @@ function resample(rng::AbstractRNG, cond_resampler::ESSResampler, state) # TODO: computing weights twice. 
Should create a wrapper to avoid this weights = StatsBase.weights(state) ess = inv(sum(abs2, weights)) - @debug "ESS: $ess" if cond_resampler.threshold * n ≥ ess return resample(rng, cond_resampler.resampler, state) diff --git a/research/maximum_likelihood/mooncake_test.jl b/research/maximum_likelihood/mooncake_test.jl index f2ae5af..8590793 100644 --- a/research/maximum_likelihood/mooncake_test.jl +++ b/research/maximum_likelihood/mooncake_test.jl @@ -32,32 +32,18 @@ using DifferentiationInterface import Mooncake using DistributionsAD -function build_objective(rng, θ, algo, data) +function build_objective(θ, algo, data) + rng = Xoshiro(1234) _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), algo, data) return -ll end -# kalman filter likelihood testing (works, but is slow) -logℓ1 = θ -> build_objective(rng, θ, KF(), ys) +# kalman filter likelihood testing (is slow) +logℓ1 = θ -> build_objective(θ, KF(), ys) Mooncake.TestUtils.test_rule(rng, logℓ1, [0.7]; is_primitive=false, debug_mode=true) +DifferentiationInterface.gradient(logℓ1, AutoMooncake(; config=nothing), [0.7]) -# bootstrap filter likelihood testing (shouldn't work) -logℓ2 = θ -> build_objective(rng, θ, BF(512), ys) +# bootstrap filter likelihood testing (is even slower) +logℓ2 = θ -> build_objective(θ, BF(512; threshold=0.1), ys) Mooncake.TestUtils.test_rule(rng, logℓ2, [0.7]; is_primitive=false, debug_mode=true) - -## FOR USE WITH DIFFERENTIATION INTERFACE ################################################## - -# data should be part of the objective, but be held constant by DifferentiationInterface -logℓ3 = (θ, data) -> build_objective(rng, θ, KF(), data) - -# set the backend with default configuration -backend = AutoMooncake(; config=nothing) - -# prepare the gradient for faster subsequent iteration -grad_prep = prepare_gradient(logℓ3, backend, [0.7], Constant(ys)) - -# evaluate gradients and iterate to show proof of concept -DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [0.7], Constant(ys)) -DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [0.8], Constant(ys)) -DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [0.9], Constant(ys)) -DifferentiationInterface.gradient(logℓ3, grad_prep, backend, [1.0], Constant(ys)) +DifferentiationInterface.gradient(logℓ2, AutoMooncake(; config=nothing), [0.7]) From 7da17327b7a9abe601c2133ace1b711bbf1dafba Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Mon, 17 Mar 2025 18:51:10 -0400 Subject: [PATCH 16/33] Add guided filter draft --- GeneralisedFilters/src/GeneralisedFilters.jl | 1 + .../src/algorithms/bootstrap.jl | 15 +-- GeneralisedFilters/src/algorithms/guided.jl | 126 ++++++++++++++++++ GeneralisedFilters/src/algorithms/rbpf.jl | 11 +- 4 files changed, 136 insertions(+), 17 deletions(-) create mode 100644 GeneralisedFilters/src/algorithms/guided.jl diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index fe47e37..eeb8222 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -136,6 +136,7 @@ include("algorithms/bootstrap.jl") include("algorithms/kalman.jl") include("algorithms/forward.jl") include("algorithms/rbpf.jl") +include("algorithms/guided.jl") # Unit-testing helper module include("GFTest/GFTest.jl") diff --git a/GeneralisedFilters/src/algorithms/bootstrap.jl b/GeneralisedFilters/src/algorithms/bootstrap.jl index c5314be..21463b7 100644 --- a/GeneralisedFilters/src/algorithms/bootstrap.jl +++ 
b/GeneralisedFilters/src/algorithms/bootstrap.jl @@ -13,11 +13,13 @@ function step( callback::Union{AbstractCallback,Nothing}=nothing, kwargs..., ) + # capture the marginalized log-likelihood state = resample(rng, alg.resampler, state) + marginalization_term = logsumexp(state.log_weights) isnothing(callback) || callback(model, alg, iter, state, observation, PostResample; kwargs...) - state = predict(rng, model, alg, iter, state; ref_state=ref_state, kwargs...) + state = predict(rng, model, alg, iter, state, observation; ref_state=ref_state, kwargs...) # TODO: this is quite inelegant and should be refactored. It also might introduce bugs # with callbacks that track the ancestry (and use PostResample) @@ -31,7 +33,7 @@ function step( isnothing(callback) || callback(model, alg, iter, state, observation, PostUpdate; kwargs...) - return state, ll_increment + return state, (ll_increment - marginalization_term) end struct BootstrapFilter{RS<:AbstractResampler} <: AbstractParticleFilter @@ -67,7 +69,8 @@ function predict( model::StateSpaceModel, filter::BootstrapFilter, step::Integer, - state::ParticleDistribution; + state::ParticleDistribution, + observation; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) @@ -86,8 +89,6 @@ function update( observation; kwargs..., ) where {T} - old_ll = logsumexp(state.log_weights) - log_increments = map( x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), collect(state), @@ -95,9 +96,7 @@ function update( state.log_weights += log_increments - ll_increment = logsumexp(state.log_weights) - old_ll - - return state, ll_increment + return state, logsumexp(state.log_weights) end # Application of bootstrap filter to hierarchical models diff --git a/GeneralisedFilters/src/algorithms/guided.jl b/GeneralisedFilters/src/algorithms/guided.jl new file mode 100644 index 0000000..a14f1d5 --- /dev/null +++ b/GeneralisedFilters/src/algorithms/guided.jl @@ -0,0 +1,126 @@ +export GuidedFilter, GPF, AbstractProposal +# import SSMProblems: distribution, simulate, logdensity + + +""" + AbstractProposal +""" +abstract type AbstractProposal end + +# TODO: improve this and ensure that there are no conflicts with SSMProblems +function distribution( + model::AbstractStateSpaceModel, + prop::AbstractProposal, + step::Integer, + state, + observation; + kwargs..., +) + return throw( + MethodError(distribution, (model, prop, step, state, observation, kwargs...)) + ) +end + +function simulate( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + prop::AbstractProposal, + step::Integer, + state, + observation; + kwargs..., +) + return rand(rng, distribution(model, prop, step, state, observation; kwargs...)) +end + +function logdensity( + model::AbstractStateSpaceModel, + prop::AbstractProposal, + step::Integer, + prev_state, + new_state, + observation; + kwargs..., +) + return logpdf( + distribution(model, prop, step, prev_state, observation; kwargs...), new_state + ) +end + +struct GuidedFilter{RS<:AbstractResampler,P<:AbstractProposal} <: AbstractParticleFilter + N::Integer + resampler::RS + proposal::P +end + +function GuidedFilter( + N::Integer, proposal::PT; threshold::Real=1.0, resampler::AbstractResampler=Systematic() +) where {PT<:AbstractProposal} + conditional_resampler = ESSResampler(threshold, resampler) + return GuidedFilter{ESSResampler,PT}(N, conditional_resampler, proposal) +end + +"""Shorthand for `GuidedFilter`""" +const GPF = GuidedFilter + +function initialise( + rng::AbstractRNG, + model::StateSpaceModel{T}, + filter::GuidedFilter; + 
ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) where {T} + particles = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N)) + weights = zeros(T, filter.N) + + return update_ref!(ParticleDistribution(particles, weights), ref_state) +end + +function predict( + rng::AbstractRNG, + model::StateSpaceModel, + filter::GuidedFilter, + step::Integer, + state::ParticleDistribution, + observation; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) + proposed_particles = map( + x -> simulate(rng, model, filter.proposal, step, x, observation; kwargs...), + collect(state), + ) + + log_increments = map(zip(proposed_particles, state.particles)) do (new_state, prev_state) + log_f = SSMProblems.logdensity(model.dyn, step, prev_state, new_state; kwargs...) + log_q = logdensity( + model, filter.proposal, step, prev_state, new_state, observation; kwargs... + ) + + (log_f - log_q) + end + + proposed_state = ParticleDistribution( + proposed_particles, state.log_weights + log_increments + ) + + return update_ref!(proposed_state, ref_state, step) +end + +function update( + model::StateSpaceModel{T}, + filter::GuidedFilter, + step::Integer, + state::ParticleDistribution, + observation; + kwargs..., +) where {T} + log_increments = map( + x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), + collect(state), + ) + + state.log_weights += log_increments + + return state, logsumexp(state.log_weights) +end diff --git a/GeneralisedFilters/src/algorithms/rbpf.jl b/GeneralisedFilters/src/algorithms/rbpf.jl index 585f8f5..72cc6cd 100644 --- a/GeneralisedFilters/src/algorithms/rbpf.jl +++ b/GeneralisedFilters/src/algorithms/rbpf.jl @@ -78,8 +78,6 @@ end function update( model::HierarchicalSSM{T}, algo::RBPF, t::Integer, state, obs; kwargs... 
) where {T} - old_ll = logsumexp(state.log_weights) - for i in 1:(algo.N) state.particles[i].z, log_increments = update( model.inner_model, @@ -93,9 +91,7 @@ function update( state.log_weights[i] += log_increments end - ll_increment = logsumexp(state.log_weights) - old_ll - - return state, ll_increment + return state, logsumexp(state.log_weights) end function marginal_update( @@ -199,8 +195,6 @@ function update( obs; kwargs..., ) - old_ll = logsumexp(state.log_weights) - new_zs, inner_lls = update( model.inner_model, filter.inner_algo, @@ -214,6 +208,5 @@ function update( state.log_weights += inner_lls state.particles.zs = new_zs - step_ll = logsumexp(state.log_weights) - old_ll - return state, step_ll + return state, logsumexp(state.log_weights) end From 262650d5ea1852be7a88b42b04fae8ecd3386741 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Mon, 17 Mar 2025 18:51:37 -0400 Subject: [PATCH 17/33] add VSMC replication --- research/variational_filter/Project.toml | 16 +++ research/variational_filter/script.jl | 141 +++++++++++++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 research/variational_filter/Project.toml create mode 100644 research/variational_filter/script.jl diff --git a/research/variational_filter/Project.toml b/research/variational_filter/Project.toml new file mode 100644 index 0000000..5c9c0be --- /dev/null +++ b/research/variational_filter/Project.toml @@ -0,0 +1,16 @@ +[deps] +DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" +Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" +Fluxperimental = "3102ee7a-c841-4564-8f7f-ec69bd4fd658" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" +GeneralisedFilters = "3ef92589-7ab8-43f9-b5b9-a3a0c86ecbb7" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" +Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" +PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" +Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/research/variational_filter/script.jl b/research/variational_filter/script.jl new file mode 100644 index 0000000..b179ff6 --- /dev/null +++ b/research/variational_filter/script.jl @@ -0,0 +1,141 @@ +using GeneralisedFilters +using SSMProblems +using PDMats +using LinearAlgebra +using Random +using Distributions + +## LINEAR GAUSSIAN PROPOSAL ################################################################ + +# this is a pseudo optimal proposal kernel for linear Gaussian models +struct LinearGaussianProposal{T<:Real} <: GeneralisedFilters.AbstractProposal + φ::Vector{T} +end + +# a lot of computations done at each step +function GeneralisedFilters.distribution( + model::AbstractStateSpaceModel, + kernel::LinearGaussianProposal, + step::Integer, + state, + observation; + kwargs..., +) + # get model dimensions + dx = length(state) + dy = length(observation) + + # see (Corenflos et al, 2021) for details + A = GeneralisedFilters.calc_A(model.dyn, step; kwargs...) 
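+    # the parameter vector φ is split below: its first dx entries form the diagonal
+    # proposal covariance Σ, and the remaining entries fill the diagonal of the dx×dy
+    # gain Γ that weights the current observation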
+ Γ = diagm(dx, dy, kernel.φ[(dx + 1):end]) + Σ = PDiagMat(kernel.φ[1:dx]) + + return MvNormal(inv(Σ) * A * state + inv(Σ) * Γ * observation, Σ) +end + +## DEEP GAUSSIAN PROPOSAL ################################################################## + +using Flux, Fluxperimental + +struct DeepGaussianProposal{T1,T2} <: GeneralisedFilters.AbstractProposal + μ_net::T1 + Σ_net::T2 +end + +function DeepGaussianProposal(model_dims::NTuple{2, Int}, depths::NTuple{2, Int}) + input_dim = sum(model_dims) + + μnet = Chain( + Dense(input_dim => depths[1], relu), + Dense(depths[1] => model_dims[1]) + ) + + Σnet = Chain( + Dense(input_dim => depths[2], relu), + Dense(depths[2] => model_dims[1], softplus) + ) + + return DeepGaussianProposal(μnet, Σnet) +end + +Flux.@layer DeepGaussianProposal + +function (kernel::DeepGaussianProposal)(x) + kernel.μ_net(x), kernel.Σ_net(x) +end + +function GeneralisedFilters.distribution( + model::AbstractStateSpaceModel, + kernel::DeepGaussianProposal, + step::Integer, + state, + observation; + kwargs..., +) + input = cat(state, observation; dims=1) + μ, σ = kernel(input) + return MvNormal(μ, σ) +end + +## VSMC #################################################################################### + +using DifferentiationInterface +using Optimisers +using DistributionsAD +import ForwardDiff, Mooncake + +function toy_model(::Type{T}, dx, dy) where {T<:Real} + A = begin + a = collect(1:dx) + @. convert(T, 0.42)^(abs(a - a') + 1) + end + b = zeros(T, dx) + Q = PDiagMat(ones(T, dx)) + + H = diagm(dy, dx, ones(T, dy)) + c = zeros(T, dy) + R = PDiagMat(convert(T, 0.5)*ones(T, dy)) + + μ0 = zeros(T, dx) + Σ0 = PDiagMat(ones(T, dx)) + + return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) +end + +rng = MersenneTwister(1234) +true_model = toy_model(Float32, 10, 10) +_, _, ys = sample(rng, true_model, 100) + +function logℓ(θ, data) + # algo = GPF(4, LinearGaussianProposal(θ); threshold=1.0) + algo = GPF(4, θ; threshold=1.0) + _, ll = GeneralisedFilters.filter(true_model, algo, data) + return -ll +end + +num_epochs = 500 +# θ = rand(rng, Float64, 20) .+ 1.0 +θ = DeepGaussianProposal((10,10), (16,16)) +opt = Optimisers.setup(Adam(0.01), θ) + +backend = AutoMooncake(;config=nothing) +grad_prep = prepare_gradient( + logℓ, backend, θ, Constant(ys) +) + +DifferentiationInterface.value_and_gradient( + logℓ, AutoMooncake(;config=nothing), θ, Constant(ys) +) + +@time for epoch in 1:num_epochs + val, ∇logℓ = DifferentiationInterface.value_and_gradient( + logℓ, grad_prep, backend, θ, Constant(ys) + ) + + Optimisers.update!(opt, θ, ∇logℓ) + if (epoch % 25 == 0) + println("\r$(epoch):\t -$(val)") + else + print("\r$(epoch):\t -$(val)") + end +end From ebc7843b6155a64a020dea8c15e88276fe94cb48 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Mon, 17 Mar 2025 19:07:35 -0400 Subject: [PATCH 18/33] switch proposals for demonstration --- research/variational_filter/script.jl | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/research/variational_filter/script.jl b/research/variational_filter/script.jl index b179ff6..fa6ed30 100644 --- a/research/variational_filter/script.jl +++ b/research/variational_filter/script.jl @@ -107,26 +107,22 @@ true_model = toy_model(Float32, 10, 10) _, _, ys = sample(rng, true_model, 100) function logℓ(θ, data) - # algo = GPF(4, LinearGaussianProposal(θ); threshold=1.0) - algo = GPF(4, θ; threshold=1.0) + algo = GPF(4, LinearGaussianProposal(θ); threshold=1.0) + # algo = GPF(4, θ; threshold=1.0) _, ll = 
GeneralisedFilters.filter(true_model, algo, data) return -ll end num_epochs = 500 -# θ = rand(rng, Float64, 20) .+ 1.0 -θ = DeepGaussianProposal((10,10), (16,16)) +θ = rand(rng, Float64, 20) .+ 1.0 +# θ = DeepGaussianProposal((10,10), (16,16)) opt = Optimisers.setup(Adam(0.01), θ) -backend = AutoMooncake(;config=nothing) +backend = AutoForwardDiff() grad_prep = prepare_gradient( logℓ, backend, θ, Constant(ys) ) -DifferentiationInterface.value_and_gradient( - logℓ, AutoMooncake(;config=nothing), θ, Constant(ys) -) - @time for epoch in 1:num_epochs val, ∇logℓ = DifferentiationInterface.value_and_gradient( logℓ, grad_prep, backend, θ, Constant(ys) From 016a336082c2f838cb630ff1d71b49baa4d9b2ff Mon Sep 17 00:00:00 2001 From: Charles Knipp <32943413+charlesknipp@users.noreply.github.com> Date: Mon, 17 Mar 2025 19:09:59 -0400 Subject: [PATCH 19/33] fix formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- .../src/algorithms/bootstrap.jl | 4 +++- GeneralisedFilters/src/algorithms/guided.jl | 20 ++++++++++--------- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/bootstrap.jl b/GeneralisedFilters/src/algorithms/bootstrap.jl index 21463b7..83f8a7f 100644 --- a/GeneralisedFilters/src/algorithms/bootstrap.jl +++ b/GeneralisedFilters/src/algorithms/bootstrap.jl @@ -19,7 +19,9 @@ function step( isnothing(callback) || callback(model, alg, iter, state, observation, PostResample; kwargs...) - state = predict(rng, model, alg, iter, state, observation; ref_state=ref_state, kwargs...) + state = predict( + rng, model, alg, iter, state, observation; ref_state=ref_state, kwargs... + ) # TODO: this is quite inelegant and should be refactored. It also might introduce bugs # with callbacks that track the ancestry (and use PostResample) diff --git a/GeneralisedFilters/src/algorithms/guided.jl b/GeneralisedFilters/src/algorithms/guided.jl index a14f1d5..0441159 100644 --- a/GeneralisedFilters/src/algorithms/guided.jl +++ b/GeneralisedFilters/src/algorithms/guided.jl @@ -1,7 +1,6 @@ export GuidedFilter, GPF, AbstractProposal # import SSMProblems: distribution, simulate, logdensity - """ AbstractProposal """ @@ -91,14 +90,17 @@ function predict( collect(state), ) - log_increments = map(zip(proposed_particles, state.particles)) do (new_state, prev_state) - log_f = SSMProblems.logdensity(model.dyn, step, prev_state, new_state; kwargs...) - log_q = logdensity( - model, filter.proposal, step, prev_state, new_state, observation; kwargs... - ) - - (log_f - log_q) - end + log_increments = + map(zip(proposed_particles, state.particles)) do (new_state, prev_state) + log_f = SSMProblems.logdensity( + model.dyn, step, prev_state, new_state; kwargs... + ) + log_q = logdensity( + model, filter.proposal, step, prev_state, new_state, observation; kwargs... 
+ ) + + (log_f - log_q) + end proposed_state = ParticleDistribution( proposed_particles, state.log_weights + log_increments From 6715cb68893d585449ff12ee395ec99e370733ac Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Tue, 18 Mar 2025 10:02:06 -0400 Subject: [PATCH 20/33] add fix for Flux and Mooncake --- research/variational_filter/script.jl | 109 ++++++++++---------------- 1 file changed, 42 insertions(+), 67 deletions(-) diff --git a/research/variational_filter/script.jl b/research/variational_filter/script.jl index fa6ed30..74fbfd8 100644 --- a/research/variational_filter/script.jl +++ b/research/variational_filter/script.jl @@ -1,41 +1,37 @@ -using GeneralisedFilters -using SSMProblems -using PDMats -using LinearAlgebra -using Random -using Distributions - -## LINEAR GAUSSIAN PROPOSAL ################################################################ - -# this is a pseudo optimal proposal kernel for linear Gaussian models -struct LinearGaussianProposal{T<:Real} <: GeneralisedFilters.AbstractProposal - φ::Vector{T} -end +using GeneralisedFilters, SSMProblems +using PDMats, LinearAlgebra +using Random, Distributions -# a lot of computations done at each step -function GeneralisedFilters.distribution( - model::AbstractStateSpaceModel, - kernel::LinearGaussianProposal, - step::Integer, - state, - observation; - kwargs..., -) - # get model dimensions - dx = length(state) - dy = length(observation) +using Flux, Fluxperimental +using DifferentiationInterface, Optimisers +import Mooncake - # see (Corenflos et al, 2021) for details - A = GeneralisedFilters.calc_A(model.dyn, step; kwargs...) - Γ = diagm(dx, dy, kernel.φ[(dx + 1):end]) - Σ = PDiagMat(kernel.φ[1:dx]) +## TOY MODEL ############################################################################### - return MvNormal(inv(Σ) * A * state + inv(Σ) * Γ * observation, Σ) +# adapted from (Naesseth, 2016) +function toy_model(::Type{T}, dx, dy) where {T<:Real} + A = begin + a = collect(1:dx) + @. convert(T, 0.42)^(abs(a - a') + 1) + end + b = zeros(T, dx) + Q = PDiagMat(ones(T, dx)) + + H = diagm(dy, dx, ones(T, dy)) + c = zeros(T, dy) + R = PDiagMat(convert(T, 0.5)*ones(T, dy)) + + μ0 = zeros(T, dx) + Σ0 = PDiagMat(ones(T, dx)) + + return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) end -## DEEP GAUSSIAN PROPOSAL ################################################################## +rng = MersenneTwister(1234) +true_model = toy_model(Float32, 10, 10) +_, _, ys = sample(rng, true_model, 100) -using Flux, Fluxperimental +## DEEP GAUSSIAN PROPOSAL ################################################################## struct DeepGaussianProposal{T1,T2} <: GeneralisedFilters.AbstractProposal μ_net::T1 @@ -79,56 +75,35 @@ end ## VSMC #################################################################################### -using DifferentiationInterface -using Optimisers -using DistributionsAD -import ForwardDiff, Mooncake - -function toy_model(::Type{T}, dx, dy) where {T<:Real} - A = begin - a = collect(1:dx) - @. convert(T, 0.42)^(abs(a - a') + 1) - end - b = zeros(T, dx) - Q = PDiagMat(ones(T, dx)) - - H = diagm(dy, dx, ones(T, dy)) - c = zeros(T, dy) - R = PDiagMat(convert(T, 0.5)*ones(T, dy)) - - μ0 = zeros(T, dx) - Σ0 = PDiagMat(ones(T, dx)) - - return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) +# fix for Optimisers.update! 
with Flux support +function Optimisers.update!(opt_state, params, grad::Mooncake.Tangent) + return Optimisers.update!(opt_state, params, Fluxperimental._moonstrip(grad)) end -rng = MersenneTwister(1234) -true_model = toy_model(Float32, 10, 10) -_, _, ys = sample(rng, true_model, 100) - -function logℓ(θ, data) - algo = GPF(4, LinearGaussianProposal(θ); threshold=1.0) - # algo = GPF(4, θ; threshold=1.0) +function logℓ(φ, data) + algo = GPF(4, φ; threshold=0.8) _, ll = GeneralisedFilters.filter(true_model, algo, data) return -ll end num_epochs = 500 -θ = rand(rng, Float64, 20) .+ 1.0 -# θ = DeepGaussianProposal((10,10), (16,16)) -opt = Optimisers.setup(Adam(0.01), θ) +φ = DeepGaussianProposal((10,10), (16,16)) +opt = Optimisers.setup(Adam(0.01), φ) +vsmc_ll = zeros(Float32, num_epochs) -backend = AutoForwardDiff() +backend = AutoMooncake(;config=nothing) grad_prep = prepare_gradient( - logℓ, backend, θ, Constant(ys) + logℓ, backend, φ, Constant(ys) ) @time for epoch in 1:num_epochs val, ∇logℓ = DifferentiationInterface.value_and_gradient( - logℓ, grad_prep, backend, θ, Constant(ys) + logℓ, grad_prep, backend, φ, Constant(ys) ) - Optimisers.update!(opt, θ, ∇logℓ) + Optimisers.update!(opt, φ, ∇logℓ) + vsmc_ll[epoch] = val + if (epoch % 25 == 0) println("\r$(epoch):\t -$(val)") else From 2c3be921191655071068cffefd26bc0199884acd Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Tue, 18 Mar 2025 11:01:32 -0400 Subject: [PATCH 21/33] add plots and fix formatting --- research/variational_filter/Project.toml | 5 +- research/variational_filter/script.jl | 59 ++++++++++++++---------- 2 files changed, 36 insertions(+), 28 deletions(-) diff --git a/research/variational_filter/Project.toml b/research/variational_filter/Project.toml index 5c9c0be..eea83d5 100644 --- a/research/variational_filter/Project.toml +++ b/research/variational_filter/Project.toml @@ -1,11 +1,9 @@ [deps] +CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0" DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c" Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" Fluxperimental = "3102ee7a-c841-4564-8f7f-ec69bd4fd658" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196" GeneralisedFilters = "3ef92589-7ab8-43f9-b5b9-a3a0c86ecbb7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6" @@ -13,4 +11,3 @@ Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2" PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" -Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" diff --git a/research/variational_filter/script.jl b/research/variational_filter/script.jl index 74fbfd8..04d0f35 100644 --- a/research/variational_filter/script.jl +++ b/research/variational_filter/script.jl @@ -1,14 +1,22 @@ +# # Variational Sequential Monte Carlo + +#= +This example demonstrates the extensibility of GeneralisedFilters with an adaptation of VSMC +with a tunable proposal ([Naesseth et al, 2016](https://arxiv.org/pdf/1705.11140)). 
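+The proposal parameters are fitted by stochastic gradient descent on the negative log-evidence
+reported by the particle filter; since the SMC evidence estimate is unbiased, the expectation of
+its logarithm lower-bounds the true log-likelihood (an ELBO), which is what makes the objective
+variational.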
+=# + using GeneralisedFilters, SSMProblems using PDMats, LinearAlgebra using Random, Distributions using Flux, Fluxperimental using DifferentiationInterface, Optimisers -import Mooncake +using Mooncake: Mooncake -## TOY MODEL ############################################################################### +using CairoMakie + +# ## Model Definition -# adapted from (Naesseth, 2016) function toy_model(::Type{T}, dx, dy) where {T<:Real} A = begin a = collect(1:dx) @@ -19,7 +27,7 @@ function toy_model(::Type{T}, dx, dy) where {T<:Real} H = diagm(dy, dx, ones(T, dy)) c = zeros(T, dy) - R = PDiagMat(convert(T, 0.5)*ones(T, dy)) + R = PDiagMat(convert(T, 0.5) * ones(T, dy)) μ0 = zeros(T, dx) Σ0 = PDiagMat(ones(T, dx)) @@ -31,24 +39,20 @@ rng = MersenneTwister(1234) true_model = toy_model(Float32, 10, 10) _, _, ys = sample(rng, true_model, 100) -## DEEP GAUSSIAN PROPOSAL ################################################################## +# ## Deep Gaussian Proposal struct DeepGaussianProposal{T1,T2} <: GeneralisedFilters.AbstractProposal μ_net::T1 Σ_net::T2 end -function DeepGaussianProposal(model_dims::NTuple{2, Int}, depths::NTuple{2, Int}) +function DeepGaussianProposal(model_dims::NTuple{2,Int}, depths::NTuple{2,Int}) input_dim = sum(model_dims) - μnet = Chain( - Dense(input_dim => depths[1], relu), - Dense(depths[1] => model_dims[1]) - ) + μnet = Chain(Dense(input_dim => depths[1], relu), Dense(depths[1] => model_dims[1])) Σnet = Chain( - Dense(input_dim => depths[2], relu), - Dense(depths[2] => model_dims[1], softplus) + Dense(input_dim => depths[2], relu), Dense(depths[2] => model_dims[1], softplus) ) return DeepGaussianProposal(μnet, Σnet) @@ -57,7 +61,7 @@ end Flux.@layer DeepGaussianProposal function (kernel::DeepGaussianProposal)(x) - kernel.μ_net(x), kernel.Σ_net(x) + return kernel.μ_net(x), kernel.Σ_net(x) end function GeneralisedFilters.distribution( @@ -73,7 +77,7 @@ function GeneralisedFilters.distribution( return MvNormal(μ, σ) end -## VSMC #################################################################################### +# ## Designing the VSMC Algorithm # fix for Optimisers.update! 
with Flux support function Optimisers.update!(opt_state, params, grad::Mooncake.Tangent) @@ -87,26 +91,33 @@ function logℓ(φ, data) end num_epochs = 500 -φ = DeepGaussianProposal((10,10), (16,16)) +φ = DeepGaussianProposal((10, 10), (16, 16)) opt = Optimisers.setup(Adam(0.01), φ) vsmc_ll = zeros(Float32, num_epochs) -backend = AutoMooncake(;config=nothing) -grad_prep = prepare_gradient( - logℓ, backend, φ, Constant(ys) -) +backend = AutoMooncake(; config=nothing) +grad_prep = prepare_gradient(logℓ, backend, φ, Constant(ys)) @time for epoch in 1:num_epochs - val, ∇logℓ = DifferentiationInterface.value_and_gradient( - logℓ, grad_prep, backend, φ, Constant(ys) - ) + ∇logℓ = DifferentiationInterface.gradient(logℓ, grad_prep, backend, φ, Constant(ys)) Optimisers.update!(opt, φ, ∇logℓ) + _, val = GeneralisedFilters.filter(true_model, GPF(25, φ; threshold=0.8), ys) vsmc_ll[epoch] = val if (epoch % 25 == 0) - println("\r$(epoch):\t -$(val)") + println("\r$(epoch):\t $(val)") else - print("\r$(epoch):\t -$(val)") + print("\r$(epoch):\t $(val)") end end + +begin + fig = Figure(; size=(500, 400), fontsize=16) + ax = Axis(fig[1, 1]; limits=((0, 500), nothing), ylabel="ELBO", xlabel="Epochs") + _, kf_ll = GeneralisedFilters.filter(true_model, KF(), ys) + hlines!(ax, kf_ll; linewidth=3, color=:black, label="KF") + lines!(ax, vsmc_ll; linewidth=3, color=:red, label="VSMC") + axislegend(ax; position=:rb) + fig +end From ada830600a7ed111761d4fec12fe196a5a6af4a8 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Tue, 25 Mar 2025 15:39:29 -0400 Subject: [PATCH 22/33] restructure particle filters --- GeneralisedFilters/src/GeneralisedFilters.jl | 9 +- GeneralisedFilters/src/algorithms/kalman.jl | 16 +- .../src/algorithms/particles.jl | 235 ++++++++++++++++++ GeneralisedFilters/src/algorithms/rbpf.jl | 59 +++-- 4 files changed, 287 insertions(+), 32 deletions(-) create mode 100644 GeneralisedFilters/src/algorithms/particles.jl diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index eeb8222..2a2b953 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -54,7 +54,7 @@ function initialise(model, alg; kwargs...) return initialise(default_rng(), model, alg; kwargs...) end -function predict(model, alg, step, filtered; kwargs...) +function predict(model, alg, step, filtered, observation; kwargs...) return predict(default_rng(), model, alg, step, filtered; kwargs...) end @@ -108,7 +108,7 @@ function step( callback::Union{AbstractCallback,Nothing}=nothing, kwargs..., ) - state = predict(rng, model, alg, iter, state; kwargs...) + state = predict(rng, model, alg, iter, state, observation; kwargs...) isnothing(callback) || callback(model, alg, iter, state, observation, PostPredict; kwargs...) 
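For orientation: threading `observation` through the generic `predict` call above is what lets
proposal-based filters condition on the incoming data point. A minimal sketch of how a user could
hook into the `AbstractProposal` interface added in `particles.jl` below, mirroring the pattern used
in the variational-filter script; the `PriorProposal` type and the `model`/`ys` bindings in the
usage comment are illustrative assumptions, not part of this PR:

    using GeneralisedFilters, SSMProblems

    # Propose from the prior dynamics, ignoring the observation (bootstrap-like behaviour, but
    # routed through the proposal interface). Assumes the latent dynamics implement
    # `SSMProblems.distribution` rather than being simulation-only.
    struct PriorProposal <: GeneralisedFilters.AbstractProposal end

    function GeneralisedFilters.distribution(
        model::AbstractStateSpaceModel, ::PriorProposal, step::Integer, state, observation; kwargs...
    )
        return SSMProblems.distribution(model.dyn, step, state; kwargs...)
    end

    # hypothetical usage, given a `model` and observations `ys`:
    # _, ll = GeneralisedFilters.filter(model, ParticleFilter(256, PriorProposal(); threshold=0.8), ys)
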
@@ -132,11 +132,12 @@ include("models/discrete.jl") include("models/hierarchical.jl") # Filtering/smoothing algorithms -include("algorithms/bootstrap.jl") +include("algorithms/particles.jl") +# include("algorithms/bootstrap.jl") include("algorithms/kalman.jl") include("algorithms/forward.jl") include("algorithms/rbpf.jl") -include("algorithms/guided.jl") +# include("algorithms/guided.jl") # Unit-testing helper module include("GFTest/GFTest.jl") diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl index ddc279e..670da79 100644 --- a/GeneralisedFilters/src/algorithms/kalman.jl +++ b/GeneralisedFilters/src/algorithms/kalman.jl @@ -19,9 +19,10 @@ end function predict( rng::AbstractRNG, model::LinearGaussianStateSpaceModel, - filter::KalmanFilter, + algo::KalmanFilter, step::Integer, - filtered::Gaussian; + filtered::Gaussian, + observation=nothing; kwargs..., ) μ, Σ = GaussianDistributions.pair(filtered) @@ -31,7 +32,7 @@ end function update( model::LinearGaussianStateSpaceModel, - filter::KalmanFilter, + algo::KalmanFilter, step::Integer, proposed::Gaussian, obs::AbstractVector; @@ -73,7 +74,8 @@ function predict( model::LinearGaussianStateSpaceModel{T}, algo::BatchKalmanFilter, step::Integer, - state::BatchGaussianDistribution; + state::BatchGaussianDistribution, + observation; kwargs..., ) where {T} μs, Σs = state.μs, state.Σs @@ -173,7 +175,7 @@ end function smooth( rng::AbstractRNG, model::LinearGaussianStateSpaceModel{T}, - alg::KalmanSmoother, + algo::KalmanSmoother, observations::AbstractVector; t_smooth=1, callback=nothing, @@ -188,7 +190,7 @@ function smooth( back_state = filtered for t in (length(observations) - 1):-1:t_smooth back_state = backward( - rng, model, alg, t, back_state, observations[t]; states_cache=cache, kwargs... + rng, model, algo, t, back_state, observations[t]; states_cache=cache, kwargs... ) end @@ -198,7 +200,7 @@ end function backward( rng::AbstractRNG, model::LinearGaussianStateSpaceModel{T}, - alg::KalmanSmoother, + algo::KalmanSmoother, iter::Integer, back_state, obs; diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl new file mode 100644 index 0000000..92d3dcb --- /dev/null +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -0,0 +1,235 @@ +export BootstrapFilter, BF +export ParticleFilter, PF, AbstractProposal + +import SSMProblems: distribution, simulate, logdensity + +abstract type AbstractProposal end + +function SSMProblems.distribution( + model::AbstractStateSpaceModel, + prop::AbstractProposal, + step::Integer, + state, + observation; + kwargs..., +) + return throw( + MethodError( + SSMProblems.distribution, (model, prop, step, state, observation, kwargs...) + ), + ) +end + +function SSMProblems.simulate( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + prop::AbstractProposal, + step::Integer, + state, + observation; + kwargs..., +) + return rand( + rng, SSMProblems.distribution(model, prop, step, state, observation; kwargs...) 
+ ) +end + +function SSMProblems.logdensity( + model::AbstractStateSpaceModel, + prop::AbstractProposal, + step::Integer, + prev_state, + new_state, + observation; + kwargs..., +) + return logpdf( + SSMProblems.distribution(model, prop, step, prev_state, observation; kwargs...), + new_state, + ) +end + +abstract type AbstractParticleFilter <: AbstractFilter end + +struct ParticleFilter{RS,PT} <: AbstractParticleFilter + N::Int + resampler::RS + proposal::PT +end + +const PF = ParticleFilter + +function ParticleFilter( + N::Integer, proposal::PT; threshold::Real=1.0, resampler::AbstractResampler=Systematic() +) where {PT<:AbstractProposal} + conditional_resampler = ESSResampler(threshold, resampler) + return ParticleFilter{ESSResampler,PT}(N, conditional_resampler, proposal) +end + +function step( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + alg::AbstractParticleFilter, + iter::Integer, + state, + observation; + ref_state::Union{Nothing,AbstractVector}=nothing, + callback::Union{AbstractCallback,Nothing}=nothing, + kwargs..., +) + # capture the marginalized log-likelihood + state = resample(rng, alg.resampler, state) + marginalization_term = logsumexp(state.log_weights) + isnothing(callback) || + callback(model, alg, iter, state, observation, PostResample; kwargs...) + + state = predict( + rng, model, alg, iter, state, observation; ref_state=ref_state, kwargs... + ) + + # TODO: this is quite inelegant and should be refactored. It also might introduce bugs + # with callbacks that track the ancestry (and use PostResample) + if !isnothing(ref_state) + CUDA.@allowscalar state.ancestors[1] = 1 + end + isnothing(callback) || + callback(model, alg, iter, state, observation, PostPredict; kwargs...) + + state, ll_increment = update(model, alg, iter, state, observation; kwargs...) + isnothing(callback) || + callback(model, alg, iter, state, observation, PostUpdate; kwargs...) + + return state, (ll_increment - marginalization_term) +end + +function initialise( + rng::AbstractRNG, + model::StateSpaceModel{T}, + filter::ParticleFilter; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) where {T} + particles = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N)) + weights = zeros(T, filter.N) + + return update_ref!(ParticleDistribution(particles, weights), ref_state) +end + +function predict( + rng::AbstractRNG, + model::StateSpaceModel, + filter::ParticleFilter, + step::Integer, + state::ParticleDistribution, + observation; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) + proposed_particles = map( + x -> SSMProblems.simulate( + rng, model, filter.proposal, step, x, observation; kwargs... + ), + collect(state), + ) + + log_increments = + map(zip(proposed_particles, state.particles)) do (new_state, prev_state) + log_f = SSMProblems.logdensity( + model.dyn, step, prev_state, new_state; kwargs... + ) + log_q = SSMProblems.logdensity( + model, filter.proposal, step, prev_state, new_state, observation; kwargs... 
+ ) + + (log_f - log_q) + end + + proposed_state = ParticleDistribution( + proposed_particles, state.log_weights + log_increments + ) + + return update_ref!(proposed_state, ref_state, step) +end + +function update( + model::StateSpaceModel{T}, + filter::ParticleFilter, + step::Integer, + state::ParticleDistribution, + observation; + kwargs..., +) where {T} + log_increments = map( + x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), + collect(state), + ) + + state.log_weights += log_increments + + return state, logsumexp(state.log_weights) +end + +# Default to latent dynamics +struct LatentProposal <: AbstractProposal end + +const BootstrapFilter{RS} = ParticleFilter{RS,LatentProposal} +const BF = BootstrapFilter +BootstrapFilter(N::Integer; kwargs...) = ParticleFilter(N, LatentProposal(); kwargs...) + +function simulate( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + prop::AbstractProposal, + step::Integer, + state, + observation; + kwargs..., +) + return SSMProblems.simulate(rng, model.dyn, step, state; kwargs...) +end + +function logdensity( + model::AbstractStateSpaceModel, + prop::AbstractProposal, + step::Integer, + prev_state, + new_state, + observation; + kwargs..., +) + return SSMProblems.logdensity(model.dyn, step, prev_state, new_state; kwargs...) +end + +# overwrite predict for the bootstrap filter to remove redundant computation +function predict( + rng::AbstractRNG, + model::StateSpaceModel, + filter::BootstrapFilter, + step::Integer, + state::ParticleDistribution, + observation; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) + state.particles = map( + x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), collect(state) + ) + + return update_ref!(state, ref_state, step) +end + +# Application of bootstrap filter to hierarchical models +function filter( + rng::AbstractRNG, + model::HierarchicalSSM, + alg::BootstrapFilter, + observations::AbstractVector; + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) + ssm = StateSpaceModel( + HierarchicalDynamics(model.outer_dyn, model.inner_model.dyn), + HierarchicalObservations(model.inner_model.obs), + ) + return filter(rng, ssm, alg, observations; ref_state=ref_state, kwargs...) 
+end diff --git a/GeneralisedFilters/src/algorithms/rbpf.jl b/GeneralisedFilters/src/algorithms/rbpf.jl index 72cc6cd..78766b0 100644 --- a/GeneralisedFilters/src/algorithms/rbpf.jl +++ b/GeneralisedFilters/src/algorithms/rbpf.jl @@ -29,7 +29,7 @@ function initialise( ) where {T} particles = map( x -> RaoBlackwellisedParticle( - simulate(rng, model.outer_dyn; kwargs...), + SSMProblems.simulate(rng, model.outer_dyn; kwargs...), initialise(model.inner_model, algo.inner_algo; new_outer=x, kwargs...), ), 1:(algo.N), @@ -43,30 +43,39 @@ function predict( rng::AbstractRNG, model::HierarchicalSSM, algo::RBPF, - t::Integer, - filtered; + step::Integer, + state, + observation; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) new_particles = map( - x -> marginal_predict(rng, model, algo, t, x; kwargs...), filtered.particles + x -> marginal_predict(rng, model, algo, step, x, observation; kwargs...), + collect(state), ) # Don't need to deep copy weights as filtered will be overwritten in the update step - proposed = ParticleDistribution(new_particles, filtered.log_weights) + proposed = ParticleDistribution(new_particles, state.log_weights) - return update_ref!(proposed, ref_state, t) + return update_ref!(proposed, ref_state, step) end function marginal_predict( - rng::AbstractRNG, model::HierarchicalSSM, algo::RBPF, t::Integer, state; kwargs... + rng::AbstractRNG, + model::HierarchicalSSM, + algo::RBPF, + step::Integer, + state, + observation; + kwargs..., ) - proposed_x = simulate(rng, model.outer_dyn, t, state.x; kwargs...) + proposed_x = SSMProblems.simulate(rng, model.outer_dyn, step, state.x; kwargs...) proposed_z = predict( rng, model.inner_model, algo.inner_algo, - t, - state.z; + step, + state.z, + observation; prev_outer=state.x, new_outer=proposed_x, kwargs..., @@ -76,15 +85,15 @@ function marginal_predict( end function update( - model::HierarchicalSSM{T}, algo::RBPF, t::Integer, state, obs; kwargs... + model::HierarchicalSSM{T}, algo::RBPF, step::Integer, state, observation; kwargs... ) where {T} for i in 1:(algo.N) state.particles[i].z, log_increments = update( model.inner_model, algo.inner_algo, - t, + step, state.particles[i].z, - obs; + observation; new_outer=state.particles[i].x, kwargs..., ) @@ -95,10 +104,16 @@ function update( end function marginal_update( - model::HierarchicalSSM, algo::RBPF, t::Integer, state, obs; kwargs... + model::HierarchicalSSM, algo::RBPF, step::Integer, state, observation; kwargs... ) filtered_z, log_increment = update( - model.inner_model, algo.inner_algo, t, state.z, obs; new_outer=state.x, kwargs... + model.inner_model, + algo.inner_algo, + step, + state.z, + observation; + new_outer=state.x, + kwargs..., ) return RaoBlackwellisedParticle(state.x, filtered_z), log_increment @@ -164,9 +179,10 @@ end function predict( rng::AbstractRNG, model::HierarchicalSSM, - filter::BatchRBPF, + algo::BatchRBPF, step::Integer, - state::RaoBlackwellisedParticleDistribution; + state::RaoBlackwellisedParticleDistribution, + observation; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) @@ -175,9 +191,10 @@ function predict( new_xs = SSMProblems.batch_simulate(rng, outer_dyn, step, state.particles.xs; kwargs...) 
new_zs = predict( inner_model, - filter.inner_algo, + algo.inner_algo, step, - state.particles.zs; + state.particles.zs, + observation; prev_outer=state.particles.xs, new_outer=new_xs, kwargs..., @@ -189,7 +206,7 @@ end function update( model::HierarchicalSSM, - filter::BatchRBPF, + algo::BatchRBPF, step::Integer, state::RaoBlackwellisedParticleDistribution, obs; @@ -197,7 +214,7 @@ function update( ) new_zs, inner_lls = update( model.inner_model, - filter.inner_algo, + algo.inner_algo, step, state.particles.zs, obs; From afb3021e502591e42a0e2d65acf86b0ce0c594bb Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Tue, 25 Mar 2025 15:39:46 -0400 Subject: [PATCH 23/33] update example --- research/variational_filter/script.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/research/variational_filter/script.jl b/research/variational_filter/script.jl index 04d0f35..4f37501 100644 --- a/research/variational_filter/script.jl +++ b/research/variational_filter/script.jl @@ -85,7 +85,7 @@ function Optimisers.update!(opt_state, params, grad::Mooncake.Tangent) end function logℓ(φ, data) - algo = GPF(4, φ; threshold=0.8) + algo = PF(4, φ; threshold=0.8) _, ll = GeneralisedFilters.filter(true_model, algo, data) return -ll end @@ -102,7 +102,7 @@ grad_prep = prepare_gradient(logℓ, backend, φ, Constant(ys)) ∇logℓ = DifferentiationInterface.gradient(logℓ, grad_prep, backend, φ, Constant(ys)) Optimisers.update!(opt, φ, ∇logℓ) - _, val = GeneralisedFilters.filter(true_model, GPF(25, φ; threshold=0.8), ys) + _, val = GeneralisedFilters.filter(true_model, PF(25, φ; threshold=0.8), ys) vsmc_ll[epoch] = val if (epoch % 25 == 0) From cc92ad8f144b7d989ce835085b9e52c1b5cf62a8 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Thu, 27 Mar 2025 17:23:48 -0400 Subject: [PATCH 24/33] fixed type signatures for bootstrap filter --- .../src/algorithms/particles.jl | 22 +++++++++---------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 92d3dcb..3b5c500 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -1,11 +1,11 @@ export BootstrapFilter, BF export ParticleFilter, PF, AbstractProposal -import SSMProblems: distribution, simulate, logdensity +# import SSMProblems: distribution, simulate, logdensity abstract type AbstractProposal end -function SSMProblems.distribution( +function distribution( model::AbstractStateSpaceModel, prop::AbstractProposal, step::Integer, @@ -15,12 +15,12 @@ function SSMProblems.distribution( ) return throw( MethodError( - SSMProblems.distribution, (model, prop, step, state, observation, kwargs...) + distribution, (model, prop, step, state, observation, kwargs...) ), ) end -function SSMProblems.simulate( +function simulate( rng::AbstractRNG, model::AbstractStateSpaceModel, prop::AbstractProposal, @@ -30,11 +30,11 @@ function SSMProblems.simulate( kwargs..., ) return rand( - rng, SSMProblems.distribution(model, prop, step, state, observation; kwargs...) + rng, distribution(model, prop, step, state, observation; kwargs...) 
) end -function SSMProblems.logdensity( +function logdensity( model::AbstractStateSpaceModel, prop::AbstractProposal, step::Integer, @@ -44,7 +44,7 @@ function SSMProblems.logdensity( kwargs..., ) return logpdf( - SSMProblems.distribution(model, prop, step, prev_state, observation; kwargs...), + distribution(model, prop, step, prev_state, observation; kwargs...), new_state, ) end @@ -126,7 +126,7 @@ function predict( kwargs..., ) proposed_particles = map( - x -> SSMProblems.simulate( + x -> simulate( rng, model, filter.proposal, step, x, observation; kwargs... ), collect(state), @@ -137,7 +137,7 @@ function predict( log_f = SSMProblems.logdensity( model.dyn, step, prev_state, new_state; kwargs... ) - log_q = SSMProblems.logdensity( + log_q = logdensity( model, filter.proposal, step, prev_state, new_state, observation; kwargs... ) @@ -179,7 +179,7 @@ BootstrapFilter(N::Integer; kwargs...) = ParticleFilter(N, LatentProposal(); kwa function simulate( rng::AbstractRNG, model::AbstractStateSpaceModel, - prop::AbstractProposal, + prop::LatentProposal, step::Integer, state, observation; @@ -190,7 +190,7 @@ end function logdensity( model::AbstractStateSpaceModel, - prop::AbstractProposal, + prop::LatentProposal, step::Integer, prev_state, new_state, From 55aaa30775d6b5d8b6167a82a4a9c75f1a33ec42 Mon Sep 17 00:00:00 2001 From: Charles Knipp <32943413+charlesknipp@users.noreply.github.com> Date: Fri, 28 Mar 2025 10:23:32 -0400 Subject: [PATCH 25/33] guess who forgot to run the formatter?? it was me Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- GeneralisedFilters/src/algorithms/particles.jl | 15 ++++----------- 1 file changed, 4 insertions(+), 11 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 3b5c500..14bf3f8 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -14,9 +14,7 @@ function distribution( kwargs..., ) return throw( - MethodError( - distribution, (model, prop, step, state, observation, kwargs...) - ), + MethodError(distribution, (model, prop, step, state, observation, kwargs...)) ) end @@ -29,9 +27,7 @@ function simulate( observation; kwargs..., ) - return rand( - rng, distribution(model, prop, step, state, observation; kwargs...) - ) + return rand(rng, distribution(model, prop, step, state, observation; kwargs...)) end function logdensity( @@ -44,8 +40,7 @@ function logdensity( kwargs..., ) return logpdf( - distribution(model, prop, step, prev_state, observation; kwargs...), - new_state, + distribution(model, prop, step, prev_state, observation; kwargs...), new_state ) end @@ -126,9 +121,7 @@ function predict( kwargs..., ) proposed_particles = map( - x -> simulate( - rng, model, filter.proposal, step, x, observation; kwargs... 
- ), + x -> simulate(rng, model, filter.proposal, step, x, observation; kwargs...), collect(state), ) From a92f45569237218df507564241bf62e91a5e9dd9 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Thu, 3 Apr 2025 11:06:11 -0400 Subject: [PATCH 26/33] update forward algorithm --- GeneralisedFilters/src/algorithms/forward.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/GeneralisedFilters/src/algorithms/forward.jl b/GeneralisedFilters/src/algorithms/forward.jl index b8182b3..81a5f6f 100644 --- a/GeneralisedFilters/src/algorithms/forward.jl +++ b/GeneralisedFilters/src/algorithms/forward.jl @@ -14,7 +14,8 @@ function predict( model::DiscreteStateSpaceModel{T}, filter::ForwardAlgorithm, step::Integer, - states::AbstractVector; + states::AbstractVector, + observation; kwargs..., ) where {T} P = calc_P(model.dyn, step; kwargs...) From 265ed380d4a5ef9e9a794b87b18c79fca124ff7e Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Thu, 3 Apr 2025 14:49:42 -0400 Subject: [PATCH 27/33] additional merge fixes --- .../src/algorithms/bootstrap.jl | 12 +--- GeneralisedFilters/src/algorithms/kalman.jl | 2 +- .../src/algorithms/particles.jl | 56 ++++++++++--------- 3 files changed, 33 insertions(+), 37 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/bootstrap.jl b/GeneralisedFilters/src/algorithms/bootstrap.jl index 8137ff1..0b48871 100644 --- a/GeneralisedFilters/src/algorithms/bootstrap.jl +++ b/GeneralisedFilters/src/algorithms/bootstrap.jl @@ -14,20 +14,12 @@ function step( kwargs..., ) # capture the marginalized log-likelihood - state = resample(rng, alg.resampler, state) + state = resample(rng, alg.resampler, state; ref_state) marginalization_term = logsumexp(state.log_weights) isnothing(callback) || callback(model, alg, iter, state, observation, PostResample; kwargs...) - state = predict( - rng, model, alg, iter, state, observation; ref_state=ref_state, kwargs... - ) - - # TODO: this is quite inelegant and should be refactored. It also might introduce bugs - # with callbacks that track the ancestry (and use PostResample) - if !isnothing(ref_state) - CUDA.@allowscalar state.ancestors[1] = 1 - end + state = predict(rng, model, alg, iter, state, observation; ref_state, kwargs...) isnothing(callback) || callback(model, alg, iter, state, observation, PostPredict; kwargs...) diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl index 670da79..8527cd1 100644 --- a/GeneralisedFilters/src/algorithms/kalman.jl +++ b/GeneralisedFilters/src/algorithms/kalman.jl @@ -13,7 +13,7 @@ function initialise( rng::AbstractRNG, model::LinearGaussianStateSpaceModel, filter::KalmanFilter; kwargs... ) μ0, Σ0 = calc_initial(model.dyn; kwargs...) - return Gaussian(μ0, Matrix(Σ0)) + return Gaussian(μ0, Σ0) end function predict( diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 14bf3f8..b1e2830 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -73,20 +73,12 @@ function step( kwargs..., ) # capture the marginalized log-likelihood - state = resample(rng, alg.resampler, state) + state = resample(rng, alg.resampler, state; ref_state) marginalization_term = logsumexp(state.log_weights) isnothing(callback) || callback(model, alg, iter, state, observation, PostResample; kwargs...) - state = predict( - rng, model, alg, iter, state, observation; ref_state=ref_state, kwargs... 
- ) - - # TODO: this is quite inelegant and should be refactored. It also might introduce bugs - # with callbacks that track the ancestry (and use PostResample) - if !isnothing(ref_state) - CUDA.@allowscalar state.ancestors[1] = 1 - end + state = predict(rng, model, alg, iter, state, observation; ref_state, kwargs...) isnothing(callback) || callback(model, alg, iter, state, observation, PostPredict; kwargs...) @@ -104,10 +96,16 @@ function initialise( ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) where {T} - particles = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N)) - weights = zeros(T, filter.N) + particles = map(1:(filter.N)) do i + if !isnothing(ref_state) && i == 1 + ref_state[0] + else + SSMProblems.simulate(rng, model.dyn; kwargs...) + end + end + log_ws = zeros(T, filter.N) - return update_ref!(ParticleDistribution(particles, weights), ref_state) + return ParticleDistribution(particles, log_ws) end function predict( @@ -120,16 +118,20 @@ function predict( ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) - proposed_particles = map( - x -> simulate(rng, model, filter.proposal, step, x, observation; kwargs...), - collect(state), - ) + proposed_particles = map(enumerate(state.particles)) do (i, particle) + if !isnothing(ref_state) && i == 1 + ref_state[step] + else + simulate(rng, model, filter.proposal, step, particle, observation; kwargs...) + end + end - log_increments = + state.log_weights += map(zip(proposed_particles, state.particles)) do (new_state, prev_state) log_f = SSMProblems.logdensity( model.dyn, step, prev_state, new_state; kwargs... ) + log_q = logdensity( model, filter.proposal, step, prev_state, new_state, observation; kwargs... ) @@ -137,11 +139,9 @@ function predict( (log_f - log_q) end - proposed_state = ParticleDistribution( - proposed_particles, state.log_weights + log_increments - ) + state.particles = proposed_particles - return update_ref!(proposed_state, ref_state, step) + return state end function update( @@ -204,11 +204,15 @@ function predict( ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) - state.particles = map( - x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), collect(state) - ) + state.particles = map(enumerate(state.particles)) do (i, particle) + if !isnothing(ref_state) && i == 1 + ref_state[step] + else + SSMProblems.simulate(rng, model.dyn, step, particle; kwargs...) 
+ end + end - return update_ref!(state, ref_state, step) + return state end # Application of bootstrap filter to hierarchical models From c90d2394300285b060a309ffeb6d22242d777a0e Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Thu, 3 Apr 2025 14:55:52 -0400 Subject: [PATCH 28/33] remove redundant files --- .../src/algorithms/bootstrap.jl | 120 ---------------- GeneralisedFilters/src/algorithms/guided.jl | 128 ------------------ 2 files changed, 248 deletions(-) delete mode 100644 GeneralisedFilters/src/algorithms/bootstrap.jl delete mode 100644 GeneralisedFilters/src/algorithms/guided.jl diff --git a/GeneralisedFilters/src/algorithms/bootstrap.jl b/GeneralisedFilters/src/algorithms/bootstrap.jl deleted file mode 100644 index 0b48871..0000000 --- a/GeneralisedFilters/src/algorithms/bootstrap.jl +++ /dev/null @@ -1,120 +0,0 @@ -export BootstrapFilter, BF - -abstract type AbstractParticleFilter <: AbstractFilter end - -function step( - rng::AbstractRNG, - model::AbstractStateSpaceModel, - alg::AbstractParticleFilter, - iter::Integer, - state, - observation; - ref_state::Union{Nothing,AbstractVector}=nothing, - callback::Union{AbstractCallback,Nothing}=nothing, - kwargs..., -) - # capture the marginalized log-likelihood - state = resample(rng, alg.resampler, state; ref_state) - marginalization_term = logsumexp(state.log_weights) - isnothing(callback) || - callback(model, alg, iter, state, observation, PostResample; kwargs...) - - state = predict(rng, model, alg, iter, state, observation; ref_state, kwargs...) - isnothing(callback) || - callback(model, alg, iter, state, observation, PostPredict; kwargs...) - - state, ll_increment = update(model, alg, iter, state, observation; kwargs...) - isnothing(callback) || - callback(model, alg, iter, state, observation, PostUpdate; kwargs...) - - return state, (ll_increment - marginalization_term) -end - -struct BootstrapFilter{RS<:AbstractResampler} <: AbstractParticleFilter - N::Int - resampler::RS -end - -"""Shorthand for `BootstrapFilter`""" -const BF = BootstrapFilter - -function BootstrapFilter( - N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic() -) - conditional_resampler = ESSResampler(threshold, resampler) - return BootstrapFilter{ESSResampler}(N, conditional_resampler) -end - -function initialise( - rng::AbstractRNG, - model::StateSpaceModel{T}, - filter::BootstrapFilter; - ref_state::Union{Nothing,AbstractVector}=nothing, - kwargs..., -) where {T} - particles = map(1:(filter.N)) do i - if !isnothing(ref_state) && i == 1 - ref_state[0] - else - SSMProblems.simulate(rng, model.dyn; kwargs...) - end - end - log_ws = zeros(T, filter.N) - - return ParticleDistribution(particles, log_ws) -end - -function predict( - rng::AbstractRNG, - model::StateSpaceModel, - filter::BootstrapFilter, - step::Integer, - state::ParticleDistribution, - observation; - ref_state::Union{Nothing,AbstractVector}=nothing, - kwargs..., -) - state.particles = map(enumerate(state.particles)) do (i, particle) - if !isnothing(ref_state) && i == 1 - ref_state[step] - else - SSMProblems.simulate(rng, model.dyn, step, particle; kwargs...) 
- end - end - - return state -end - -function update( - model::StateSpaceModel{T}, - filter::BootstrapFilter, - step::Integer, - state::ParticleDistribution, - observation; - kwargs..., -) where {T} - log_increments = map( - x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), - collect(state), - ) - - state.log_weights += log_increments - - return state, logsumexp(state.log_weights) -end - -# Application of bootstrap filter to hierarchical models -function filter( - rng::AbstractRNG, - model::HierarchicalSSM, - alg::BootstrapFilter, - observations::AbstractVector; - ref_state::Union{Nothing,AbstractVector}=nothing, - kwargs..., -) - ssm = StateSpaceModel( - HierarchicalDynamics(model.outer_dyn, model.inner_model.dyn), - HierarchicalObservations(model.inner_model.obs), - ) - return filter(rng, ssm, alg, observations; ref_state=ref_state, kwargs...) -end diff --git a/GeneralisedFilters/src/algorithms/guided.jl b/GeneralisedFilters/src/algorithms/guided.jl deleted file mode 100644 index 0441159..0000000 --- a/GeneralisedFilters/src/algorithms/guided.jl +++ /dev/null @@ -1,128 +0,0 @@ -export GuidedFilter, GPF, AbstractProposal -# import SSMProblems: distribution, simulate, logdensity - -""" - AbstractProposal -""" -abstract type AbstractProposal end - -# TODO: improve this and ensure that there are no conflicts with SSMProblems -function distribution( - model::AbstractStateSpaceModel, - prop::AbstractProposal, - step::Integer, - state, - observation; - kwargs..., -) - return throw( - MethodError(distribution, (model, prop, step, state, observation, kwargs...)) - ) -end - -function simulate( - rng::AbstractRNG, - model::AbstractStateSpaceModel, - prop::AbstractProposal, - step::Integer, - state, - observation; - kwargs..., -) - return rand(rng, distribution(model, prop, step, state, observation; kwargs...)) -end - -function logdensity( - model::AbstractStateSpaceModel, - prop::AbstractProposal, - step::Integer, - prev_state, - new_state, - observation; - kwargs..., -) - return logpdf( - distribution(model, prop, step, prev_state, observation; kwargs...), new_state - ) -end - -struct GuidedFilter{RS<:AbstractResampler,P<:AbstractProposal} <: AbstractParticleFilter - N::Integer - resampler::RS - proposal::P -end - -function GuidedFilter( - N::Integer, proposal::PT; threshold::Real=1.0, resampler::AbstractResampler=Systematic() -) where {PT<:AbstractProposal} - conditional_resampler = ESSResampler(threshold, resampler) - return GuidedFilter{ESSResampler,PT}(N, conditional_resampler, proposal) -end - -"""Shorthand for `GuidedFilter`""" -const GPF = GuidedFilter - -function initialise( - rng::AbstractRNG, - model::StateSpaceModel{T}, - filter::GuidedFilter; - ref_state::Union{Nothing,AbstractVector}=nothing, - kwargs..., -) where {T} - particles = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N)) - weights = zeros(T, filter.N) - - return update_ref!(ParticleDistribution(particles, weights), ref_state) -end - -function predict( - rng::AbstractRNG, - model::StateSpaceModel, - filter::GuidedFilter, - step::Integer, - state::ParticleDistribution, - observation; - ref_state::Union{Nothing,AbstractVector}=nothing, - kwargs..., -) - proposed_particles = map( - x -> simulate(rng, model, filter.proposal, step, x, observation; kwargs...), - collect(state), - ) - - log_increments = - map(zip(proposed_particles, state.particles)) do (new_state, prev_state) - log_f = SSMProblems.logdensity( - model.dyn, step, prev_state, new_state; kwargs... 
- ) - log_q = logdensity( - model, filter.proposal, step, prev_state, new_state, observation; kwargs... - ) - - (log_f - log_q) - end - - proposed_state = ParticleDistribution( - proposed_particles, state.log_weights + log_increments - ) - - return update_ref!(proposed_state, ref_state, step) -end - -function update( - model::StateSpaceModel{T}, - filter::GuidedFilter, - step::Integer, - state::ParticleDistribution, - observation; - kwargs..., -) where {T} - log_increments = map( - x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), - collect(state), - ) - - state.log_weights += log_increments - - return state, logsumexp(state.log_weights) -end From 57622956149cdc64d6cc4073a202a68f23c09126 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Thu, 3 Apr 2025 14:58:38 -0400 Subject: [PATCH 29/33] consolidate MLE examples --- .../minimum_working_example.jl | 181 ------------------ research/maximum_likelihood/mooncake_test.jl | 49 ----- .../{mle_demo.jl => script.jl} | 0 3 files changed, 230 deletions(-) delete mode 100644 research/maximum_likelihood/minimum_working_example.jl delete mode 100644 research/maximum_likelihood/mooncake_test.jl rename research/maximum_likelihood/{mle_demo.jl => script.jl} (100%) diff --git a/research/maximum_likelihood/minimum_working_example.jl b/research/maximum_likelihood/minimum_working_example.jl deleted file mode 100644 index 5f017b9..0000000 --- a/research/maximum_likelihood/minimum_working_example.jl +++ /dev/null @@ -1,181 +0,0 @@ -using LinearAlgebra -using GaussianDistributions -using Random - -using DistributionsAD -using Distributions - -using Enzyme - -## MODEL DEFINITION ######################################################################## - -struct LinearGaussianProcess{ - T<:Real, - ΦT<:AbstractMatrix{T}, - ΣT<:AbstractMatrix{T}, - μT<:AbstractVector{T} - } - ϕ::ΦT - Σ::ΣT - μ::μT - function LinearGaussianProcess(ϕ::ΦT, Σ::ΣT, μ::μT) where { - T<:Real, - ΦT<:AbstractMatrix{T}, - ΣT<:AbstractMatrix{T}, - μT<:AbstractVector{T} - } - @assert size(ϕ,1) == size(Σ,1) == size(Σ,2) == size(μ,1) - return new{T, ΦT, ΣT, μT}(ϕ, Σ, μ) - end -end - -# a rather simplified version of GeneralisedFilters.LinearGaussianStateSpaceModel -struct LinearGaussianModel{ - ΘT<:Real, - TT<:LinearGaussianProcess{ΘT}, - OT<:LinearGaussianProcess{ΘT} - } - transition::TT - observation::OT -end - -## KALMAN FILTER ########################################################################### - -# this is based on the algorithm of GeneralisedFilters.jl -function kalman_filter( - model::LinearGaussianModel, - init_state::Gaussian, - observations::Vector{T} - ) where {T<:Real} - log_evidence = zero(T) - filtered = init_state - - # calc_params(model.dyn) - A = model.transition.ϕ - Q = model.transition.Σ - b = model.transition.μ - - # calc_params(model.obs) - H = model.observation.ϕ - R = model.observation.Σ - c = model.observation.μ - - for obs in observations - # predict step - μ, Σ = GaussianDistributions.pair(filtered) - proposed = Gaussian(A*μ + b, A*Σ*A' + Q) - - # update step - μ, Σ = GaussianDistributions.pair(proposed) - m = H*μ + c - residual = [obs] - m - - S = Symmetric(H*Σ*H' + R) - gain = Σ*H' / S - - filtered = Gaussian(μ + gain*residual, Σ - gain*H*Σ) - log_evidence += logpdf(MvNormal(m, S), [obs]) - end - - return log_evidence -end - -## DEMONSTRATION ########################################################################### - -# model constructor -function build_model(θ::T) where {T<:Real} - trans = LinearGaussianProcess( - T[0.8 θ/2; -0.1 0.8], 
- Diagonal(T[0.2, 1.0]), - zeros(T, 2) - ) - - obs = LinearGaussianProcess( - Matrix{T}(I, 1, 2), - Diagonal(T[0.2]), - zeros(T, 1) - ) - - return LinearGaussianModel(trans, obs) -end - -# log likelihood function -function logℓ(θ::Vector{T}, data) where {T<:Real} - model = build_model(θ[]) - init_state = Gaussian(T[1.0, 0.0], diagm(ones(T, 2))) - return kalman_filter(model, init_state, data) -end - -# refer to data globally (not preferred) -function logℓ_nodata(θ) - return logℓ(θ, data) -end - -# data generation (with unit covariance) -rng = MersenneTwister(1234) -data = cumsum(randn(rng, 100)) .+ randn(rng, 100) - -# ensure that log likelihood looks stable -logℓ([1.0], data) - -## SYNTACTICAL SUGAR ####################################################################### - -# this has no issue behaving well -grad_test, _ = Enzyme.gradient(Enzyme.Reverse, logℓ, [1.0], Const(data)) - -# this error is unlegible (at least to my untrained eye) -Enzyme.hvp(logℓ_nodata, [1.0], [1.0]) - -## FROM SCRATCH ############################################################################ - -function generate_perturbations(::Type{T}, n::Int) where {T<:Real} - perturbation_mat = Matrix{T}(I, n, n) - return tuple(collect.(eachslice(perturbation_mat, dims=1))...) -end - -generate_perturbations(n::Int) = generate_perturbations(Float64, n) -generate_perturbations(x::Vector{T}) where {T<:Real} = generate_perturbations(T, length(x)) - -function make_zeros(::Type{T}, n::Int) where {T<:Real} - return tuple(collect.(zeros(T, n) for _ in 1:n)...) -end - -make_zeros(n::Int) = make_zeros(Float64, n) -make_zeros(x::Vector{T}) where {T<:Real} = make_zeros(T, length(x)) - -function ∇logℓ(θ, args...) - ∂θ = Enzyme.make_zero(θ) - ∇logℓ!(θ, ∂θ, args...) - return ∂θ -end - -function ∇logℓ!(θ, ∂θ, args...) - Enzyme.autodiff(Enzyme.Reverse, logℓ, Active, Duplicated(θ, ∂θ), args...) - return nothing -end - -# ensure I'm doing the right thing -@assert grad_test == ∇logℓ([1.0], Const(data)) - -# see https://enzyme.mit.edu/julia/stable/generated/autodiff/#Vector-forward-over-reverse -function hessian(θ::Vector{T}) where {T<:Real} - # generate impulse and record second order responses - dθ = Enzyme.make_zero(θ) - vθ = generate_perturbations(θ) - H = make_zeros(θ) - - # take derivatives - Enzyme.autodiff( - Enzyme.Forward, - ∇logℓ!, - Enzyme.BatchDuplicated(θ, vθ), - Enzyme.BatchDuplicated(dθ, H), - Const(data), - ) - - # stack appropriately - return vcat(H...) 
-end - -# errors and I don't know Enzyme well enough to figure out why -hessian([1.0]) diff --git a/research/maximum_likelihood/mooncake_test.jl b/research/maximum_likelihood/mooncake_test.jl deleted file mode 100644 index 8590793..0000000 --- a/research/maximum_likelihood/mooncake_test.jl +++ /dev/null @@ -1,49 +0,0 @@ -using GeneralisedFilters -using SSMProblems -using LinearAlgebra -using Random - -## TOY MODEL ############################################################################### - -# this is taken from an example in Kalman.jl -function toy_model(θ::T) where {T<:Real} - μ0 = T[1.0, 0.0] - Σ0 = Diagonal(ones(T, 2)) - - A = T[0.8 θ/2; -0.1 0.8] - Q = Diagonal(T[0.2, 1.0]) - b = zeros(T, 2) - - H = Matrix{T}(I, 1, 2) - R = Diagonal(T[0.2]) - c = zeros(T, 1) - - return create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) -end - -# data generation process with small sample -rng = MersenneTwister(1234) -true_model = toy_model(1.0) -_, _, ys = sample(rng, true_model, 20) - -## RUN MOONCKAE TESTS ###################################################################### - -using DifferentiationInterface -import Mooncake -using DistributionsAD - -function build_objective(θ, algo, data) - rng = Xoshiro(1234) - _, ll = GeneralisedFilters.filter(rng, toy_model(θ[]), algo, data) - return -ll -end - -# kalman filter likelihood testing (is slow) -logℓ1 = θ -> build_objective(θ, KF(), ys) -Mooncake.TestUtils.test_rule(rng, logℓ1, [0.7]; is_primitive=false, debug_mode=true) -DifferentiationInterface.gradient(logℓ1, AutoMooncake(; config=nothing), [0.7]) - -# bootstrap filter likelihood testing (is even slower) -logℓ2 = θ -> build_objective(θ, BF(512; threshold=0.1), ys) -Mooncake.TestUtils.test_rule(rng, logℓ2, [0.7]; is_primitive=false, debug_mode=true) -DifferentiationInterface.gradient(logℓ2, AutoMooncake(; config=nothing), [0.7]) diff --git a/research/maximum_likelihood/mle_demo.jl b/research/maximum_likelihood/script.jl similarity index 100% rename from research/maximum_likelihood/mle_demo.jl rename to research/maximum_likelihood/script.jl From 1e95b3308bc7cd149762f8fc95156d0caf7ec93b Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Fri, 4 Apr 2025 11:43:44 -0400 Subject: [PATCH 30/33] suggested changes and house cleaning --- GeneralisedFilters/src/GeneralisedFilters.jl | 2 - GeneralisedFilters/src/algorithms/kalman.jl | 48 +++++------ .../src/algorithms/particles.jl | 79 ++++++++++--------- GeneralisedFilters/src/algorithms/rbpf.jl | 24 +++--- research/maximum_likelihood/script.jl | 15 ++-- 5 files changed, 85 insertions(+), 83 deletions(-) diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index 2a2b953..6dc97c6 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -133,11 +133,9 @@ include("models/hierarchical.jl") # Filtering/smoothing algorithms include("algorithms/particles.jl") -# include("algorithms/bootstrap.jl") include("algorithms/kalman.jl") include("algorithms/forward.jl") include("algorithms/rbpf.jl") -# include("algorithms/guided.jl") # Unit-testing helper module include("GFTest/GFTest.jl") diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl index 8527cd1..878ea84 100644 --- a/GeneralisedFilters/src/algorithms/kalman.jl +++ b/GeneralisedFilters/src/algorithms/kalman.jl @@ -1,7 +1,7 @@ export KalmanFilter, filter, BatchKalmanFilter using GaussianDistributions using CUDA: i32 -import LinearAlgebra: 
Symmetric +import LinearAlgebra: hermitianpart export KalmanFilter, KF, KalmanSmoother, KS @@ -20,39 +20,39 @@ function predict( rng::AbstractRNG, model::LinearGaussianStateSpaceModel, algo::KalmanFilter, - step::Integer, - filtered::Gaussian, + iter::Integer, + state::Gaussian, observation=nothing; kwargs..., ) - μ, Σ = GaussianDistributions.pair(filtered) - A, b, Q = calc_params(model.dyn, step; kwargs...) + μ, Σ = GaussianDistributions.pair(state) + A, b, Q = calc_params(model.dyn, iter; kwargs...) return Gaussian(A * μ + b, A * Σ * A' + Q) end function update( model::LinearGaussianStateSpaceModel, algo::KalmanFilter, - step::Integer, - proposed::Gaussian, - obs::AbstractVector; + iter::Integer, + state::Gaussian, + observation::AbstractVector; kwargs..., ) - μ, Σ = GaussianDistributions.pair(proposed) - H, c, R = calc_params(model.obs, step; kwargs...) + μ, Σ = GaussianDistributions.pair(state) + H, c, R = calc_params(model.obs, iter; kwargs...) # Update state m = H * μ + c - y = obs - m - S = Symmetric(H * Σ * H' + R) + y = observation - m + S = hermitianpart(H * Σ * H' + R) K = Σ * H' / S - filtered = Gaussian(μ + K * y, Σ - K * H * Σ) + state = Gaussian(μ + K * y, Σ - K * H * Σ) # Compute log-likelihood - ll = logpdf(MvNormal(m, S), obs) + ll = logpdf(MvNormal(m, S), observation) - return filtered, ll + return state, ll end struct BatchKalmanFilter <: AbstractBatchFilter @@ -73,13 +73,13 @@ function predict( rng::AbstractRNG, model::LinearGaussianStateSpaceModel{T}, algo::BatchKalmanFilter, - step::Integer, + iter::Integer, state::BatchGaussianDistribution, observation; kwargs..., ) where {T} μs, Σs = state.μs, state.Σs - As, bs, Qs = batch_calc_params(model.dyn, step, algo.batch_size; kwargs...) + As, bs, Qs = batch_calc_params(model.dyn, iter, algo.batch_size; kwargs...) μ̂s = NNlib.batched_vec(As, μs) .+ bs Σ̂s = NNlib.batched_mul(NNlib.batched_mul(As, Σs), NNlib.batched_transpose(As)) .+ Qs return BatchGaussianDistribution(μ̂s, Σ̂s) @@ -88,17 +88,17 @@ end function update( model::LinearGaussianStateSpaceModel{T}, algo::BatchKalmanFilter, - step::Integer, + iter::Integer, state::BatchGaussianDistribution, - obs; + observation; kwargs..., ) where {T} μs, Σs = state.μs, state.Σs - Hs, cs, Rs = batch_calc_params(model.obs, step, algo.batch_size; kwargs...) - D = size(obs, 1) + Hs, cs, Rs = batch_calc_params(model.obs, iter, algo.batch_size; kwargs...) 
+ D = size(observation, 1) m = NNlib.batched_vec(Hs, μs) .+ cs - y_res = cu(obs) .- m + y_res = cu(observation) .- m S = NNlib.batched_mul(Hs, NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs))) .+ Rs ΣH_T = NNlib.batched_mul(Σs, NNlib.batched_transpose(Hs)) @@ -151,7 +151,7 @@ function (callback::StateCallback)( algo::KalmanFilter, iter::Integer, state, - obs, + observation, ::PostPredictCallback; kwargs..., ) @@ -164,7 +164,7 @@ function (callback::StateCallback)( algo::KalmanFilter, iter::Integer, state, - obs, + observation, ::PostUpdateCallback; kwargs..., ) diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index b1e2830..879e157 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -1,46 +1,49 @@ export BootstrapFilter, BF export ParticleFilter, PF, AbstractProposal -# import SSMProblems: distribution, simulate, logdensity +import SSMProblems: distribution, simulate, logdensity abstract type AbstractProposal end -function distribution( +function SSMProblems.distribution( model::AbstractStateSpaceModel, prop::AbstractProposal, - step::Integer, + iter::Integer, state, observation; kwargs..., ) return throw( - MethodError(distribution, (model, prop, step, state, observation, kwargs...)) + MethodError(distribution, (model, prop, iter, state, observation, kwargs...)) ) end -function simulate( +function SSMProblems.simulate( rng::AbstractRNG, model::AbstractStateSpaceModel, prop::AbstractProposal, - step::Integer, + iter::Integer, state, observation; kwargs..., ) - return rand(rng, distribution(model, prop, step, state, observation; kwargs...)) + return rand( + rng, SSMProblems.distribution(model, prop, iter, state, observation; kwargs...) + ) end -function logdensity( +function SSMProblems.logdensity( model::AbstractStateSpaceModel, prop::AbstractProposal, - step::Integer, + iter::Integer, prev_state, new_state, observation; kwargs..., ) return logpdf( - distribution(model, prop, step, prev_state, observation; kwargs...), new_state + SSMProblems.distribution(model, prop, iter, prev_state, observation; kwargs...), + new_state, ) end @@ -64,7 +67,7 @@ end function step( rng::AbstractRNG, model::AbstractStateSpaceModel, - alg::AbstractParticleFilter, + algo::AbstractParticleFilter, iter::Integer, state, observation; @@ -73,18 +76,19 @@ function step( kwargs..., ) # capture the marginalized log-likelihood - state = resample(rng, alg.resampler, state; ref_state) + state = resample(rng, algo.resampler, state; ref_state) marginalization_term = logsumexp(state.log_weights) isnothing(callback) || - callback(model, alg, iter, state, observation, PostResample; kwargs...) + callback(model, algo, iter, state, observation, PostResample; kwargs...) - state = predict(rng, model, alg, iter, state, observation; ref_state, kwargs...) + state = predict(rng, model, algo, iter, state, observation; ref_state, kwargs...) isnothing(callback) || - callback(model, alg, iter, state, observation, PostPredict; kwargs...) + callback(model, algo, iter, state, observation, PostPredict; kwargs...) - state, ll_increment = update(model, alg, iter, state, observation; kwargs...) + # TODO: ll_increment is no longer consistent with the Kalman filter + state, ll_increment = update(model, algo, iter, state, observation; kwargs...) isnothing(callback) || - callback(model, alg, iter, state, observation, PostUpdate; kwargs...) + callback(model, algo, iter, state, observation, PostUpdate; kwargs...) 
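+
+    # the per-step increment below is logsumexp(weights after update) minus logsumexp(weights
+    # after resampling), i.e. the log of the weighted average of the observation densities under
+    # the normalised weights; summed over steps this recovers the usual SMC log-evidence estimate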
return state, (ll_increment - marginalization_term) end @@ -112,7 +116,7 @@ function predict( rng::AbstractRNG, model::StateSpaceModel, filter::ParticleFilter, - step::Integer, + iter::Integer, state::ParticleDistribution, observation; ref_state::Union{Nothing,AbstractVector}=nothing, @@ -120,20 +124,20 @@ function predict( ) proposed_particles = map(enumerate(state.particles)) do (i, particle) if !isnothing(ref_state) && i == 1 - ref_state[step] + ref_state[iter] else - simulate(rng, model, filter.proposal, step, particle, observation; kwargs...) + simulate(rng, model, filter.proposal, iter, particle, observation; kwargs...) end end state.log_weights += map(zip(proposed_particles, state.particles)) do (new_state, prev_state) log_f = SSMProblems.logdensity( - model.dyn, step, prev_state, new_state; kwargs... + model.dyn, iter, prev_state, new_state; kwargs... ) - log_q = logdensity( - model, filter.proposal, step, prev_state, new_state, observation; kwargs... + log_q = SSMProblems.logdensity( + model, filter.proposal, iter, prev_state, new_state, observation; kwargs... ) (log_f - log_q) @@ -147,14 +151,14 @@ end function update( model::StateSpaceModel{T}, filter::ParticleFilter, - step::Integer, + iter::Integer, state::ParticleDistribution, observation; kwargs..., ) where {T} log_increments = map( - x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), - collect(state), + x -> SSMProblems.logdensity(model.obs, iter, x, observation; kwargs...), + state.particles, ) state.log_weights += log_increments @@ -162,7 +166,6 @@ function update( return state, logsumexp(state.log_weights) end -# Default to latent dynamics struct LatentProposal <: AbstractProposal end const BootstrapFilter{RS} = ParticleFilter{RS,LatentProposal} @@ -173,24 +176,24 @@ function simulate( rng::AbstractRNG, model::AbstractStateSpaceModel, prop::LatentProposal, - step::Integer, + iter::Integer, state, observation; kwargs..., ) - return SSMProblems.simulate(rng, model.dyn, step, state; kwargs...) + return SSMProblems.simulate(rng, model.dyn, iter, state; kwargs...) end function logdensity( model::AbstractStateSpaceModel, prop::LatentProposal, - step::Integer, + iter::Integer, prev_state, new_state, observation; kwargs..., ) - return SSMProblems.logdensity(model.dyn, step, prev_state, new_state; kwargs...) + return SSMProblems.logdensity(model.dyn, iter, prev_state, new_state; kwargs...) end # overwrite predict for the bootstrap filter to remove redundant computation @@ -198,28 +201,28 @@ function predict( rng::AbstractRNG, model::StateSpaceModel, filter::BootstrapFilter, - step::Integer, + iter::Integer, state::ParticleDistribution, - observation; + observation=nothing; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) state.particles = map(enumerate(state.particles)) do (i, particle) if !isnothing(ref_state) && i == 1 - ref_state[step] + ref_state[iter] else - SSMProblems.simulate(rng, model.dyn, step, particle; kwargs...) + SSMProblems.simulate(rng, model.dyn, iter, particle; kwargs...) 
end end return state end -# Application of bootstrap filter to hierarchical models +# Application of particle filter to hierarchical models function filter( rng::AbstractRNG, model::HierarchicalSSM, - alg::BootstrapFilter, + algo::ParticleFilter, observations::AbstractVector; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., @@ -228,5 +231,5 @@ function filter( HierarchicalDynamics(model.outer_dyn, model.inner_model.dyn), HierarchicalObservations(model.inner_model.obs), ) - return filter(rng, ssm, alg, observations; ref_state=ref_state, kwargs...) + return filter(rng, ssm, algo, observations; ref_state=ref_state, kwargs...) end diff --git a/GeneralisedFilters/src/algorithms/rbpf.jl b/GeneralisedFilters/src/algorithms/rbpf.jl index 70f9d56..e73bc72 100644 --- a/GeneralisedFilters/src/algorithms/rbpf.jl +++ b/GeneralisedFilters/src/algorithms/rbpf.jl @@ -46,7 +46,7 @@ function predict( rng::AbstractRNG, model::HierarchicalSSM, algo::RBPF, - t::Integer, + iter::Integer, state, observation; ref_state::Union{Nothing,AbstractVector}=nothing, @@ -54,15 +54,15 @@ function predict( ) state.particles = map(enumerate(state.particles)) do (i, particle) new_x = if !isnothing(ref_state) && i == 1 - ref_state[t] + ref_state[iter] else - SSMProblems.simulate(rng, model.outer_dyn, t, particle.x; kwargs...) + SSMProblems.simulate(rng, model.outer_dyn, iter, particle.x; kwargs...) end new_z = predict( rng, model.inner_model, algo.inner_algo, - t, + iter, particle.z, observation; prev_outer=particle.x, @@ -77,13 +77,13 @@ function predict( end function update( - model::HierarchicalSSM{T}, algo::RBPF, step::Integer, state, observation; kwargs... + model::HierarchicalSSM{T}, algo::RBPF, iter::Integer, state, observation; kwargs... ) where {T} for i in 1:(algo.N) state.particles[i].z, log_increments = update( model.inner_model, algo.inner_algo, - step, + iter, state.particles[i].z, observation; new_outer=state.particles[i].x, @@ -158,7 +158,7 @@ function predict( rng::AbstractRNG, model::HierarchicalSSM, algo::BatchRBPF, - step::Integer, + iter::Integer, state::RaoBlackwellisedParticleDistribution, observation; ref_state::Union{Nothing,AbstractVector}=nothing, @@ -167,19 +167,19 @@ function predict( outer_dyn, inner_model = model.outer_dyn, model.inner_model new_xs = SSMProblems.batch_simulate( - rng, outer_dyn, step, state.particles.xs; ref_state, kwargs... + rng, outer_dyn, iter, state.particles.xs; ref_state, kwargs... 
) # Set reference trajectory if ref_state !== nothing - new_xs[:, [1]] = ref_state[step] + new_xs[:, [1]] = ref_state[iter] end new_zs = predict( rng, inner_model, algo.inner_algo, - step, + iter, state.particles.zs, observation; prev_outer=state.particles.xs, @@ -194,7 +194,7 @@ end function update( model::HierarchicalSSM, algo::BatchRBPF, - step::Integer, + iter::Integer, state::RaoBlackwellisedParticleDistribution, obs; kwargs..., @@ -202,7 +202,7 @@ function update( new_zs, inner_lls = update( model.inner_model, algo.inner_algo, - step, + iter, state.particles.zs, obs; new_outer=state.particles.xs, diff --git a/research/maximum_likelihood/script.jl b/research/maximum_likelihood/script.jl index 765c00a..41404a6 100644 --- a/research/maximum_likelihood/script.jl +++ b/research/maximum_likelihood/script.jl @@ -39,17 +39,18 @@ end ## NEWTONS METHOD ########################################################################## using DifferentiationInterface -import ForwardDiff, Zygote, Mooncake, Enzyme +using ForwardDiff: ForwardDiff +using Zygote: Zygote +using Mooncake: Mooncake +using Enzyme: Enzyme using Optimisers # Zygote will fail due to the model constructor, not because of the filtering algorithm -backends = [ - AutoZygote(), AutoForwardDiff(), AutoMooncake(;config=nothing), AutoEnzyme() -] +backends = [AutoZygote(), AutoForwardDiff(), AutoMooncake(; config=nothing), AutoEnzyme()] function gradient_descent(backend, θ_init, num_epochs=1000) θ = deepcopy(θ_init) - state = Optimisers.setup(Optimisers.Descent(1/length(ys)), θ) + state = Optimisers.setup(Optimisers.Descent(1 / length(ys)), θ) grad_prep = prepare_gradient(logℓ, backend, θ, Constant(ys)) for epoch in 1:num_epochs @@ -59,7 +60,7 @@ function gradient_descent(backend, θ_init, num_epochs=1000) Optimisers.update!(state, θ, ∇logℓ) (epoch % 5) == 1 && println("$(epoch-1):\t -$(val)") - if (∇logℓ'*∇logℓ) < 1e-12 + if (∇logℓ' * ∇logℓ) < 1e-12 break end end @@ -69,7 +70,7 @@ end θ_init = rand(rng, 1) for backend in backends - println("\n",backend) + println("\n", backend) local θ_mle try θ_mle = gradient_descent(backend, θ_init) From 93c86e2c15f393dbdbf4af3a0675e6d6bd5153e4 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Fri, 4 Apr 2025 11:58:53 -0400 Subject: [PATCH 31/33] update proposal definition --- research/variational_filter/script.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/research/variational_filter/script.jl b/research/variational_filter/script.jl index 4f37501..b53c134 100644 --- a/research/variational_filter/script.jl +++ b/research/variational_filter/script.jl @@ -64,7 +64,7 @@ function (kernel::DeepGaussianProposal)(x) return kernel.μ_net(x), kernel.Σ_net(x) end -function GeneralisedFilters.distribution( +function SSMProblems.distribution( model::AbstractStateSpaceModel, kernel::DeepGaussianProposal, step::Integer, From 243f86a95abb733e457094dfb210371eed1f0c80 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Fri, 4 Apr 2025 14:29:29 -0400 Subject: [PATCH 32/33] add unit testing --- GeneralisedFilters/test/runtests.jl | 82 +++++++++++++++++------------ 1 file changed, 48 insertions(+), 34 deletions(-) diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index 777b721..c9fd1ad 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -80,52 +80,66 @@ end end @testitem "Bootstrap filter test" begin - using GeneralisedFilters using SSMProblems using StableRNGs - using PDMats - using LinearAlgebra using LogExpFunctions: softmax 
- using Random: randexp - T = Float64 rng = StableRNG(1234) - σx², σy² = randexp(rng, T, 2) - - # initial state distribution - μ0 = zeros(T, 2) - Σ0 = PDMat(T[1 0; 0 1]) - - # state transition equation - A = T[1 1; 0 1] - b = T[0; 0] - Q = PDiagMat([σx²; 0]) - - # observation equation - H = T[1 0] - c = T[0;] - R = [σy²;;] - - # when working with PDMats, the Kalman filter doesn't play nicely without this - function Base.convert(::Type{PDMat{T,MT}}, mat::MT) where {MT<:AbstractMatrix,T<:Real} - return PDMat(Symmetric(mat)) - end - - model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) - _, _, data = sample(rng, model, 10) + model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) + _, _, ys = sample(rng, model, 10) - bf = BF(2^16; threshold=0.8) - bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data) - kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data) + bf = BF(2^12; threshold=0.8) + bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, ys) + kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), ys) xs = bf_state.particles ws = softmax(bf_state.log_weights) - # Compare filtered states + # Compare log-likelihood and states @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-2 + @test llkf ≈ llbf atol = 1e-1 +end - # since this is log valued, we can up the tolerance - @test llkf ≈ llbf atol = 0.1 +@testitem "Guided filter test" begin + using SSMProblems + using LogExpFunctions: softmax + using StableRNGs + using Distributions + using GaussianDistributions + using LinearAlgebra + + struct LinearGaussianProposal <: GeneralisedFilters.AbstractProposal end + + function SSMProblems.distribution( + model::AbstractStateSpaceModel, + kernel::LinearGaussianProposal, + iter::Integer, + state, + observation; + kwargs..., + ) + A, b, Q = GeneralisedFilters.calc_params(model.dyn, iter; kwargs...) + pred = Gaussian(A * state + b, Q) + prop, _ = GeneralisedFilters.update( + model, KF(), iter, pred, observation; kwargs... + ) + return MvNormal(prop.μ, hermitianpart(prop.Σ)) + end + + rng = StableRNG(1234) + model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) + _, _, ys = sample(rng, model, 10) + + algo = PF(2^10, LinearGaussianProposal(); threshold=0.6) + kf_states, kf_ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys) + pf_states, pf_ll = GeneralisedFilters.filter(rng, model, algo, ys) + + xs = pf_states.particles + ws = softmax(pf_states.log_weights) + + # Compare log-likelihood and states + @test first(kf_states.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-2 + @test kf_ll ≈ pf_ll rtol = 1e-1 end @testitem "Forward algorithm test" begin From 0dcb20f7a501f7f8d75604847385fbb0ae46a3e5 Mon Sep 17 00:00:00 2001 From: Charles Knipp <32943413+charlesknipp@users.noreply.github.com> Date: Fri, 4 Apr 2025 14:34:42 -0400 Subject: [PATCH 33/33] formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- GeneralisedFilters/test/runtests.jl | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index c9fd1ad..c1ace04 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -120,9 +120,7 @@ end ) A, b, Q = GeneralisedFilters.calc_params(model.dyn, iter; kwargs...) pred = Gaussian(A * state + b, Q) - prop, _ = GeneralisedFilters.update( - model, KF(), iter, pred, observation; kwargs... 
- ) + prop, _ = GeneralisedFilters.update(model, KF(), iter, pred, observation; kwargs...) return MvNormal(prop.μ, hermitianpart(prop.Σ)) end @@ -133,7 +131,6 @@ end algo = PF(2^10, LinearGaussianProposal(); threshold=0.6) kf_states, kf_ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys) pf_states, pf_ll = GeneralisedFilters.filter(rng, model, algo, ys) - xs = pf_states.particles ws = softmax(pf_states.log_weights)
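
For readers who want to exercise the reworked proposal interface from patches 31-33 outside the test suite, a minimal sketch follows. Everything in it is illustrative rather than part of any patch: `PriorInflatedProposal`, the covariance inflation factor, and the `my_model`/`ys`/`rng` bindings are assumptions made here for demonstration only. A user-defined proposal subtypes `GeneralisedFilters.AbstractProposal` and overloads `SSMProblems.distribution`; `simulate` and `logdensity` then fall back to `rand`/`logpdf` of the returned distribution, as set up in the particles.jl changes above.

using GeneralisedFilters
using SSMProblems
using Distributions
using StableRNGs

# Hypothetical proposal: sample from the prior dynamics but with an inflated
# covariance so the proposal has heavier tails than the transition kernel.
struct PriorInflatedProposal <: GeneralisedFilters.AbstractProposal end

function SSMProblems.distribution(
    model::AbstractStateSpaceModel,
    ::PriorInflatedProposal,
    iter::Integer,
    state,
    observation;
    kwargs...,
)
    # reuse the model's linear-Gaussian dynamics parameters, as the guided
    # filter test does above
    A, b, Q = GeneralisedFilters.calc_params(model.dyn, iter; kwargs...)
    return MvNormal(A * state + b, 2 * Matrix(Q))
end

# assuming `my_model` is a LinearGaussianStateSpaceModel and `ys` its observations:
# rng = StableRNG(1234)
# states, ll = GeneralisedFilters.filter(rng, my_model, PF(512, PriorInflatedProposal()), ys)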