From 73d7d4e60e427b5669d370af39c5b4ffc05d6459 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Fri, 2 May 2025 13:46:50 -0400 Subject: [PATCH 01/11] disentangle priors from dynamics --- SSMProblems/Project.toml | 4 +- .../examples/general-filters/Project.toml | 9 + .../examples/general-filters/script.jl | 211 ++++++++++++++++++ SSMProblems/examples/kalman-filter/script.jl | 87 ++++---- SSMProblems/src/SSMProblems.jl | 142 +++--------- SSMProblems/src/utils/forward_simulation.jl | 24 +- 6 files changed, 316 insertions(+), 161 deletions(-) create mode 100644 SSMProblems/examples/general-filters/Project.toml create mode 100644 SSMProblems/examples/general-filters/script.jl diff --git a/SSMProblems/Project.toml b/SSMProblems/Project.toml index bfbebff..11713fc 100644 --- a/SSMProblems/Project.toml +++ b/SSMProblems/Project.toml @@ -4,19 +4,21 @@ authors = [ "FredericWantiez ", "THargreaves ", "Hong Ge", - "Charles Knipp" + "Charles Knipp " ] version = "0.5.0" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [compat] AbstractMCMC = "5" Aqua = "0.8" Distributions = "0.25" +OffsetArrays = "1.17.0" Random = "1" Test = "1" julia = "1.6" diff --git a/SSMProblems/examples/general-filters/Project.toml b/SSMProblems/examples/general-filters/Project.toml new file mode 100644 index 0000000..9e3f4f8 --- /dev/null +++ b/SSMProblems/examples/general-filters/Project.toml @@ -0,0 +1,9 @@ +[deps] +AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" +Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48" +StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed" diff 
--git a/SSMProblems/examples/general-filters/script.jl b/SSMProblems/examples/general-filters/script.jl new file mode 100644 index 0000000..dd37333 --- /dev/null +++ b/SSMProblems/examples/general-filters/script.jl @@ -0,0 +1,211 @@ +using SSMProblems +using Distributions +using StatsFuns +using Random +using AbstractMCMC +using LinearAlgebra +using UnPack +using Plots + +## PARTICLE DISTRIBUTIONS ################################################################## + +abstract type ParticleDistribution end + +mutable struct Particles{PT} <: ParticleDistribution + particles::Vector{PT} +end + +mutable struct WeightedParticles{PT,WT<:Real} <: ParticleDistribution + particles::Vector{PT} + log_weights::Vector{WT} +end + +mutable struct GaussianDistribution{XT,ΣT} <: ParticleDistribution + x::XT + Σ::ΣT +end + +function ParticleDistribution(particles::AbstractVector) + return Particles(particles) +end + +function ParticleDistribution(particles::AbstractVector, log_weights::Vector{<:Real}) + return WeightedParticles(particles, log_weights) +end + +weights(state::Particles) = zeros(length(state.particles)) +weights(state::WeightedParticles) = softmax(state.log_weights) + +function update_weights(state::Particles, log_weights) + return WeightedParticles(state.particles, log_weights) +end + +function update_weights(state::WeightedParticles, log_weights) + state.log_weights += log_weights + return state +end + +## BOOTSTRAP FILTER ######################################################################## + +abstract type AbstractFilter end + +struct BootstrapFilter <: AbstractFilter + N::Int +end + +const BF = BootstrapFilter + +function initialize(rng::AbstractRNG, model::StateSpaceModel, algo::BF; kwargs...) + particles = map(1:(algo.N)) do _ + SSMProblems.simulate(rng, model.prior; kwargs...) 
+ end + + return ParticleDistribution(particles) +end + +function resample(rng::AbstractRNG, algo::BF, state::ParticleDistribution) + ws = weights(state) + ess = inv(sum(abs2, ws)) + if ess <= algo.N * 0.5 + idx = rand(rng, Distributions.Categorical(ws), algo.N) + + # creates a new particle distribution in GF anyways + return ParticleDistribution(state.particles[idx]) + else + return state + end +end + +function predict( + rng::AbstractRNG, model::StateSpaceModel, algo::BF, iter::Int, state; kwargs... +) + state.particles = map(state.particles) do particle + SSMProblems.simulate(rng, model.dyn, iter, particle; kwargs...) + end + return state +end + +function update(model::StateSpaceModel, algo::BF, iter::Int, state, observation; kwargs...) + log_increments = map(state.particles) do particle + SSMProblems.logdensity(model.obs, iter, particle, observation; kwargs...) + end + return update_weights(state, log_increments) +end + +function AbstractMCMC.sample( + rng::AbstractRNG, + model::StateSpaceModel, + algo::AbstractFilter, + observations::AbstractVector; + kwargs..., +) + T = length(observations) + state = initialize(rng, model, algo; kwargs...) + filtererd_states = [] + + for t in 1:T + state = resample(rng, algo, state) + state = predict(rng, model, algo, t, state; kwargs...) + state = update(model, algo, t, state, observations[t]; kwargs...) + push!(filtererd_states, deepcopy(state)) + end + + return filtererd_states +end + +## KALMAN FILTER ########################################################################### + +struct KalmanFilter <: AbstractFilter end + +const KF = KalmanFilter + +function initialize(rng::AbstractRNG, model::StateSpaceModel, algo::KF; kwargs...) + @unpack x, Σ = model.prior + return GaussianDistribution(x, Σ) +end + +function resample(::AbstractRNG, ::KF, state::GaussianDistribution) + return state +end + +function predict( + ::AbstractRNG, model::StateSpaceModel, algo::KF, iter::Int, state; kwargs... 
+) + @unpack Φ, b, Q = model.dyn + x = Φ * state.x + b + Σ = Φ * state.Σ * Φ' + Q + return GaussianDistribution(x, Σ) +end + +function update(model::StateSpaceModel, algo::KF, iter::Int, state, observation; kwargs...) + @unpack H, R = model.obs + K = state.Σ * H' / (H * state.Σ * H' + R) + state.x = state.x + K * (observation - H * state.x) + state.Σ = state.Σ - K * H * state.Σ + return state +end + +## STATE SPACE MODEL ####################################################################### + +struct GaussianPrior{XT<:AbstractVector,ΣT<:AbstractMatrix} <: StatePrior + x::XT + Σ::ΣT +end + +struct LinearGaussianLatentDynamics{ + ΦT<:AbstractMatrix,bT<:AbstractArray,QT<:AbstractMatrix +} <: LatentDynamics + Φ::ΦT + b::bT + Q::QT +end + +struct LinearGaussianObservationProcess{HT<:AbstractMatrix,RT<:AbstractMatrix} <: + ObservationProcess + H::HT + R::RT +end + +function SSMProblems.distribution(prior::GaussianPrior; kwargs...) + return MvNormal(prior.x, prior.Σ) +end + +function SSMProblems.distribution( + model::LinearGaussianLatentDynamics, step::Int, prev_state::AbstractVector; kwargs... +) + return MvNormal(model.Φ * prev_state + model.b, model.Q) +end + +function SSMProblems.distribution( + model::LinearGaussianObservationProcess, step::Int, state::AbstractVector; kwargs... 
+) + return MvNormal(model.H * state, model.R) +end + +const LinearGaussianSSM = StateSpaceModel{ + <:GaussianPrior,<:LinearGaussianLatentDynamics,<:LinearGaussianObservationProcess +} + +## DEMONSTRATION ########################################################################### + +toy_model = StateSpaceModel( + GaussianPrior([-1.0, 1.0], Matrix(1.0I, 2, 2)), + LinearGaussianLatentDynamics([0.8 0.2; -0.1 0.8], zeros(2), [0.2 0.0; 0.0 0.5]), + LinearGaussianObservationProcess([1.0 0.0;], Matrix(0.3I, 1, 1)), +); + +rng = MersenneTwister(1234); +xs, ys = sample(rng, toy_model, 100); + +kf_states = AbstractMCMC.sample(rng, toy_model, KalmanFilter(), ys); +bf_states = AbstractMCMC.sample(rng, toy_model, BootstrapFilter(1024), ys); + +begin + plt = plot(; title="First Dimension Filtered Estimates", xlabel="Step", ylabel="Value") + scatter!(plt, first.(ys); label="Observations", ms=2) + plot!(plt, first.(getproperty.(kf_states, :x)); label="Kalman Filtered") + plot!( + plt, first.(mean.(getproperty.(bf_states, :particles))); label="Bootstrap Filtered" + ) + plt +end \ No newline at end of file diff --git a/SSMProblems/examples/kalman-filter/script.jl b/SSMProblems/examples/kalman-filter/script.jl index 4cf7114..88d9635 100644 --- a/SSMProblems/examples/kalman-filter/script.jl +++ b/SSMProblems/examples/kalman-filter/script.jl @@ -25,12 +25,17 @@ using SSMProblems # # We store all of these paramaters in a struct. -struct LinearGaussianLatentDynamics{T<:Real} <: LatentDynamics{T,Vector{T}} - z::Vector{T} - P::Matrix{T} - Φ::Matrix{T} - b::Vector{T} - Q::Matrix{T} +struct GaussianPrior{XT<:AbstractVector,ΣT<:AbstractMatrix} <: StatePrior + x::XT + Σ::ΣT +end + +struct LinearGaussianLatentDynamics{ + ΦT<:AbstractMatrix,bT<:AbstractArray,QT<:AbstractMatrix +} <: LatentDynamics + Φ::ΦT + b::bT + Q::QT end # Note, that our specific dynamics should be subtypes of the abstract `LatentDynamics` type. 
@@ -45,9 +50,10 @@ end # y[k] = Hx[k] + v[k], v[k] ∼ N(0, R) # ``` -struct LinearGaussianObservationProcess{T<:Real} <: ObservationProcess{T,Vector{T}} - H::Matrix{T} - R::Matrix{T} +struct LinearGaussianObservationProcess{HT<:AbstractMatrix,RT<:AbstractMatrix} <: + ObservationProcess + H::HT + R::RT end # We then define general transition and observation distributions to be used in forward @@ -60,17 +66,19 @@ end # be preferred in this linear Gaussian case, it may be of interest to compare the sampling # performance with a general particle filter. -function SSMProblems.distribution(model::LinearGaussianLatentDynamics; kwargs...) - return MvNormal(model.z, model.P) +function SSMProblems.distribution(prior::GaussianPrior; kwargs...) + return MvNormal(prior.x, prior.Σ) end + function SSMProblems.distribution( - model::LinearGaussianLatentDynamics{T}, step::Int, prev_state::Vector{T}; kwargs... -) where {T} + model::LinearGaussianLatentDynamics, step::Int, prev_state::AbstractVector; kwargs... +) return MvNormal(model.Φ * prev_state + model.b, model.Q) end + function SSMProblems.distribution( - model::LinearGaussianObservationProcess{T}, step::Int, state::Vector{T}; kwargs... -) where {T} + model::LinearGaussianObservationProcess, step::Int, state::AbstractVector; kwargs... +) return MvNormal(model.H * state, model.R) end @@ -85,8 +93,8 @@ struct KalmanFilter end # alias for an SSM with linear Gaussian latent dynamics and observation process, which will # be used to dispatch to the correct method. -const LinearGaussianSSM{T} = StateSpaceModel{ - T,<:LinearGaussianLatentDynamics{T},<:LinearGaussianObservationProcess{T} +const LinearGaussianSSM = StateSpaceModel{ + <:GaussianPrior,<:LinearGaussianLatentDynamics,<:LinearGaussianObservationProcess }; # We then define a method for the `sample` function. This is a standardised interface which @@ -100,35 +108,33 @@ const LinearGaussianSSM{T} = StateSpaceModel{ # `calc_A(model, t; kwargs...)` etc. 
function AbstractMCMC.sample( - model::LinearGaussianSSM{MT}, ::KalmanFilter, observations::AbstractVector; kwargs... -) where {MT} - T = length(observations) - x_filts = Vector{Vector{MT}}(undef, T) - P_filts = Vector{Matrix{MT}}(undef, T) - - @unpack z, P, Φ, b, Q = model.dyn ## Extract parameters + model::LinearGaussianSSM, ::KalmanFilter, observations::AbstractVector; kwargs... +) + ## Extract parameters + @unpack Φ, b, Q = model.dyn @unpack H, R = model.obs + @unpack x, Σ = model.prior - ## Initialise the filter - x = z - P = P + T = length(observations) + x_filts = Vector{typeof(x)}(undef, T) + Σ_filts = Vector{typeof(Σ)}(undef, T) for t in 1:T ## Prediction step x = Φ * x + b - P = Φ * P * Φ' + Q + Σ = Φ * Σ * Φ' + Q ## Update step y = observations[t] - K = P * H' / (H * P * H' + R) + K = Σ * H' / (H * Σ * H' + R) x = x + K * (y - H * x) - P = P - K * H * P + Σ = Σ - K * H * Σ x_filts[t] = x - P_filts[t] = P + Σ_filts[t] = Σ end - return x_filts, P_filts + return x_filts, Σ_filts end # ## Simulation and Filtering @@ -137,8 +143,8 @@ end SEED = 1; T = 100; -z = [-1.0, 1.0]; -P = Matrix(1.0I, 2, 2); +x = [-1.0, 1.0]; +Σ = Matrix(1.0I, 2, 2); Φ = [0.8 0.2; -0.1 0.8]; b = zeros(2); Q = [0.2 0.0; 0.0 0.5]; @@ -148,25 +154,28 @@ R = Matrix(0.3I, 1, 1); # We can then construct the latent dynamics and observation process, before combining these # to form a state space model. -dyn = LinearGaussianLatentDynamics(z, P, Φ, b, Q); +prior = GaussianPrior(x, Σ); +dyn = LinearGaussianLatentDynamics(Φ, b, Q); obs = LinearGaussianObservationProcess(H, R); -model = StateSpaceModel(dyn, obs); +model = StateSpaceModel(prior, dyn, obs); # Synthetic data can be generated by directly sampling from the model. This calls a utility # function from the `SSMProblems` package, which in turn, calls the three distribution # functions we defined above. 
rng = MersenneTwister(SEED); -x0, xs, ys = sample(rng, model, T); +xs, ys = sample(rng, model, T); +# @code_warntype sample(rng, model, T) # We can then run the Kalman filter and plot the filtering results against the ground truth. x_filts, P_filts = AbstractMCMC.sample(model, KalmanFilter(), ys); +# @code_warntype AbstractMCMC.sample(model, KalmanFilter(), ys); # Plot trajectory for first dimension p = plot(; title="First Dimension Kalman Filter Estimates", xlabel="Step", ylabel="Value") -plot!(p, 1:T, first.(xs); label="Truth") -scatter!(p, 1:T, first.(ys); label="Observations") +plot!(p, first.(xs); label="Truth") +scatter!(p, first.(ys); label="Observations") plot!( p, 1:T, diff --git a/SSMProblems/src/SSMProblems.jl b/SSMProblems/src/SSMProblems.jl index 5544bde..2c7815b 100644 --- a/SSMProblems/src/SSMProblems.jl +++ b/SSMProblems/src/SSMProblems.jl @@ -8,48 +8,34 @@ import Base: eltype import Random: AbstractRNG, default_rng import Distributions: logpdf -export LatentDynamics, ObservationProcess, AbstractStateSpaceModel, StateSpaceModel +export StatePrior, + LatentDynamics, ObservationProcess, AbstractStateSpaceModel, StateSpaceModel """ -Latent dynamics of a state space model. - -Any concrete subtype of `LatentDynamics` should implement the functions `logdensity` and -`simulate`, by defining two methods as documented below, one for initialisation and one -for transitioning. Whether each of these functions need to be implemented depends on the -exact inference algorithm that is intended to be used. +Initial state prior of a state space model. -Alternatively, you may specify methods for the function `distribution` which will be -used to define the above methods. - -All of these methods should accept keyword arguments through `kwargs...` to facilitate -inference-time dependencies of the dynamics as explained in [Control Variables and Keyword Arguments](@ref). 
+Any concrete subtype of `StatePrior` should implement the functions `logdensity` and +`simulate` as defined below. -The latent states should be of type `ET` which should be a composed from `T`, the -arithmetic type used for the dynamics (e.g. Float32, ForwardDiff.Dual). - -# Parameters -- `T`: The arithmetic type of the latent dynamics. -- `ET`: The element type of the latent dynamics. +Alternatively, you may specify a method for the function `distribution` which will be used +to define the above methods. """ -abstract type LatentDynamics{T<:Real,ET} end +abstract type StatePrior end """ - arithmetic_type(::Type{<:LatentDynamics}) - arithmetic_type(dyn::LatentDynamics) +Latent dynamics of a state space model. -Return the arithmetic type of the latent dynamics. -""" -arithmetic_type(::Type{<:LatentDynamics{T}}) where {T} = T -arithmetic_type(dyn::LatentDynamics) = arithmetic_type(typeof(dyn)) +Any concrete subtype of `LatentDynamics` should implement the functions `logdensity` and +`simulate` for transition dynamics. Whether each of these functions need to be implemented +depends on the exact inference algorithm that is intended to be used. -""" - eltype(::Type{<:LatentDynamics}) - eltype(dyn::LatentDynamics) +Alternatively, you may specify a method for the function `distribution` which will be used +to define the above methods. -Return the type of the state of the latent dynamics. +All of these methods should accept keyword arguments through `kwargs...` to facilitate +inference-time dependencies of the dynamics as explained in [Control Variables and Keyword Arguments](@ref). """ -Base.eltype(::Type{<:LatentDynamics{T,ET}}) where {T,ET} = ET -Base.eltype(dyn::LatentDynamics) = eltype(typeof(dyn)) +abstract type LatentDynamics end """ Observation process of a state space model. @@ -63,50 +49,25 @@ both of the above methods. 
All of these methods should accept keyword arguments through `kwargs...` to facilitate inference-time dependencies of the observations as explained in [Control Variables and Keyword Arguments](@ref). - -The observations should be of type `ET` which should be a composed from `T`, the -arithmetic type used for the observations (e.g. Float32, ForwardDiff.Dual). - -# Parameters -- `T`: The arithmetic type of the observation process. -- `ET`: The element type of the observation process. """ -abstract type ObservationProcess{T<:Real,ET} end +abstract type ObservationProcess end """ - arithmetic_type(::Type{<:ObservationProcess}) - arithmetic_type(obs::ObservationProcess) + distribution(prior::StatePrior; kwargs...) -Return the arithmetic type of the observation process. -""" -arithmetic_type(::Type{<:ObservationProcess{T}}) where {T} = T -arithmetic_type(obs::ObservationProcess) = arithmetic_type(typeof(obs)) - -""" - eltype(::Type{<:ObservationProcess}) - eltype(obs::ObservationProcess) - -Return the type of the observations of the observation process. -""" -Base.eltype(::Type{<:ObservationProcess{T,ET}}) where {T,ET} = ET -Base.eltype(obs::ObservationProcess) = eltype(typeof(obs)) - -""" - distribution(dyn::LatentDynamics; kwargs...) - -Return the initialisation distribution for the latent dynamics. +Return the initial state distribution for the latent dynamics. -The method should return the distribution of the initial state of the latent dynamics. -The returned value should be a `Distributions.Distribution` object that implements -sampling and log-density calculations. +The method should return the distribution of the initial state of the latent dynamics. +The returned value should be a `Distributions.Distribution` object that implements sampling +and log-density calculations. -See also [`LatentDynamics`](@ref). +See also [`StatePrior`](@ref). # Returns - `Distributions.Distribution`: The distribution of the initial state. 
""" -function distribution(dyn::LatentDynamics; kwargs...) - throw(MethodError(distribution, (dyn))) +function distribution(prior::StatePrior; kwargs...) + throw(MethodError(distribution, (prior))) end """ @@ -146,7 +107,7 @@ function distribution(obs::ObservationProcess, step::Integer, state; kwargs...) end """ - simulate([rng::AbstractRNG], dyn::LatentDynamics; kwargs...) + simulate([rng::AbstractRNG], prior::StatePrior; kwargs...) Simulate an initial state for the latent dynamics. @@ -156,12 +117,12 @@ dynamics. The default behaviour is generate a random sample from distribution returned by the corresponding `distribution()` method. -See also [`LatentDynamics`](@ref). +See also [`StatePrior`](@ref). """ -function simulate(rng::AbstractRNG, dyn::LatentDynamics; kwargs...) - return rand(rng, distribution(dyn; kwargs...)) +function simulate(rng::AbstractRNG, prior::StatePrior; kwargs...) + return rand(rng, distribution(prior; kwargs...)) end -simulate(dyn::LatentDynamics; kwargs...) = simulate(default_rng(), dyn; kwargs...) +simulate(prior::StatePrior; kwargs...) = simulate(default_rng(), prior; kwargs...) """ simulate([rng::AbstractRNG], dyn::LatentDynamics, step::Integer, prev_state; kwargs...) @@ -207,23 +168,6 @@ function simulate(obs::ObservationProcess, step::Integer, state; kwargs...) return simulate(default_rng(), obs, step, state; kwargs...) end -""" - logdensity(dyn::LatentDynamics, new_state; kwargs...) - -Compute the log-density of an initial state for the latent dynamics. - -The method should return the log-density of the initial state `new_state` for the -initial time step of the latent dynamics. - -The default behaviour is to compute the log-density of the distribution return by the -corresponding `distribution()` method. - -See also [`LatentDynamics`](@ref). -""" -function logdensity(dyn::LatentDynamics, new_state; kwargs...) 
- return logpdf(distribution(dyn; kwargs...), new_state) -end - """ logdensity(dyn::LatentDynamics, step::Integer, prev_state, new_state; kwargs...) @@ -273,37 +217,25 @@ abstract type AbstractStateSpaceModel <: AbstractMCMC.AbstractModel end """ A state space model. -A vanilla implementation of a state space model, composed of a latent dynamics and an -observation process. +A vanilla implementation of a state space model, composed of an initial state prior, latent +dynamics and an observation process. # Fields +- `prior::PT`: The initial state prior of the state space model. - `dyn::LD`: The latent dynamics of the state space model. - `obs::OP`: The observation process of the state space model. # Parameters -- `T`: The arithmetic type of the state space model, which the latent dynamics and - observation process should be consistent with. +- `PT`: The type of the initial state prior. - `LD`: The type of the latent dynamics. - `OP`: The type of the observation process. """ -struct StateSpaceModel{T<:Real,LD<:LatentDynamics{T},OP<:ObservationProcess{T}} <: - AbstractStateSpaceModel +struct StateSpaceModel{PT,LD,OP} <: AbstractStateSpaceModel + prior::PT dyn::LD obs::OP - function StateSpaceModel(dyn::LatentDynamics{T}, obs::ObservationProcess{T}) where {T} - return new{T,typeof(dyn),typeof(obs)}(dyn, obs) - end end -""" - arithmetic_type(::Type{<:StateSpaceModel}) - arithmetic_type(model::StateSpaceModel) - -Return the arithmetic type of the state space model. 
-""" -arithmetic_type(model::StateSpaceModel) = arithmetic_type(typeof(model)) -arithmetic_type(::Type{<:StateSpaceModel{T}}) where {T} = T - include("batch_methods.jl") include("utils/forward_simulation.jl") diff --git a/SSMProblems/src/utils/forward_simulation.jl b/SSMProblems/src/utils/forward_simulation.jl index cdafc10..bfcf0eb 100644 --- a/SSMProblems/src/utils/forward_simulation.jl +++ b/SSMProblems/src/utils/forward_simulation.jl @@ -1,5 +1,7 @@ """Forward simulation of state space models.""" +using OffsetArrays: OffsetVector + import AbstractMCMC: sample export sample @@ -8,25 +10,15 @@ export sample Simulate a trajectory of length `T` from the state space model. -Returns a tuple `(x0, xs, ys)` where `x0` is the initial state, `xs` is a vector of latent states, -and `ys` is a vector of observations. +Returns a tuple `(xs, ys)` where `xs` is a vector of latent states (including the initial +state) and `ys` is a vector of observations. """ -function sample( - rng::AbstractRNG, model::StateSpaceModel{<:Real,LD,OP}, T::Integer; kwargs... -) where {LD,OP} - T_dyn = eltype(LD) - T_obs = eltype(OP) - - xs = Vector{T_dyn}(undef, T) - ys = Vector{T_obs}(undef, T) - - x0 = simulate(rng, model.dyn; kwargs...) +function sample(rng::AbstractRNG, model::StateSpaceModel, T::Integer; kwargs...) + xs = OffsetVector(fill(simulate(rng, model.prior), T + 1), -1) for t in 1:T - xs[t] = simulate(rng, model.dyn, t, t == 1 ? x0 : xs[t - 1]; kwargs...) - ys[t] = simulate(rng, model.obs, t, xs[t]; kwargs...) + xs[t] = simulate(rng, model.dyn, t, xs[t - 1]; kwargs...) 
end - - return x0, xs, ys + return xs, map(t -> simulate(rng, model.obs, t, xs[t]; kwargs...), 1:T) end """ From c7b51041d9c0a84530e6f1f0b3b614ce9e3f0672 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Mon, 12 May 2025 11:20:36 -0400 Subject: [PATCH 02/11] update SSMProblems demo --- .../examples/general-filters/script.jl | 33 +++++++++---------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/SSMProblems/examples/general-filters/script.jl b/SSMProblems/examples/general-filters/script.jl index dd37333..6a0da21 100644 --- a/SSMProblems/examples/general-filters/script.jl +++ b/SSMProblems/examples/general-filters/script.jl @@ -9,19 +9,19 @@ using Plots ## PARTICLE DISTRIBUTIONS ################################################################## -abstract type ParticleDistribution end +abstract type ParticleDistribution{PT} end -mutable struct Particles{PT} <: ParticleDistribution +mutable struct Particles{PT} <: ParticleDistribution{PT} particles::Vector{PT} end -mutable struct WeightedParticles{PT,WT<:Real} <: ParticleDistribution +mutable struct WeightedParticles{PT,WT<:Real} <: ParticleDistribution{PT} particles::Vector{PT} log_weights::Vector{WT} end -mutable struct GaussianDistribution{XT,ΣT} <: ParticleDistribution - x::XT +struct GaussianDistribution{PT,ΣT} <: ParticleDistribution{PT} + μ::PT Σ::ΣT end @@ -51,6 +51,7 @@ abstract type AbstractFilter end struct BootstrapFilter <: AbstractFilter N::Int + threshold::Float64 end const BF = BootstrapFilter @@ -66,7 +67,7 @@ end function resample(rng::AbstractRNG, algo::BF, state::ParticleDistribution) ws = weights(state) ess = inv(sum(abs2, ws)) - if ess <= algo.N * 0.5 + if ess <= algo.N * algo.threshold idx = rand(rng, Distributions.Categorical(ws), algo.N) # creates a new particle distribution in GF anyways @@ -120,8 +121,8 @@ struct KalmanFilter <: AbstractFilter end const KF = KalmanFilter function initialize(rng::AbstractRNG, model::StateSpaceModel, algo::KF; kwargs...) 
- @unpack x, Σ = model.prior - return GaussianDistribution(x, Σ) + @unpack μ, Σ = model.prior + return GaussianDistribution(μ, Σ) end function resample(::AbstractRNG, ::KF, state::GaussianDistribution) @@ -132,28 +133,26 @@ function predict( ::AbstractRNG, model::StateSpaceModel, algo::KF, iter::Int, state; kwargs... ) @unpack Φ, b, Q = model.dyn - x = Φ * state.x + b - Σ = Φ * state.Σ * Φ' + Q - return GaussianDistribution(x, Σ) + return GaussianDistribution(Φ * state.μ + b, Φ * state.Σ * Φ' + Q) end function update(model::StateSpaceModel, algo::KF, iter::Int, state, observation; kwargs...) @unpack H, R = model.obs K = state.Σ * H' / (H * state.Σ * H' + R) - state.x = state.x + K * (observation - H * state.x) - state.Σ = state.Σ - K * H * state.Σ - return state + return GaussianDistribution( + state.μ + K * (observation - H * state.μ), state.Σ - K * H * state.Σ + ) end ## STATE SPACE MODEL ####################################################################### struct GaussianPrior{XT<:AbstractVector,ΣT<:AbstractMatrix} <: StatePrior - x::XT + μ::XT Σ::ΣT end struct LinearGaussianLatentDynamics{ - ΦT<:AbstractMatrix,bT<:AbstractArray,QT<:AbstractMatrix + ΦT<:AbstractMatrix,bT<:AbstractVector,QT<:AbstractMatrix } <: LatentDynamics Φ::ΦT b::bT @@ -167,7 +166,7 @@ struct LinearGaussianObservationProcess{HT<:AbstractMatrix,RT<:AbstractMatrix} < end function SSMProblems.distribution(prior::GaussianPrior; kwargs...) 
- return MvNormal(prior.x, prior.Σ) + return MvNormal(prior.μ, prior.Σ) end function SSMProblems.distribution( From 1bc17ef86ef9586e99bf383075f6525740289bde Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Mon, 12 May 2025 11:21:01 -0400 Subject: [PATCH 03/11] add kwargs to prior --- SSMProblems/src/SSMProblems.jl | 2 +- SSMProblems/src/utils/forward_simulation.jl | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/SSMProblems/src/SSMProblems.jl b/SSMProblems/src/SSMProblems.jl index 2c7815b..fd58778 100644 --- a/SSMProblems/src/SSMProblems.jl +++ b/SSMProblems/src/SSMProblems.jl @@ -67,7 +67,7 @@ See also [`StatePrior`](@ref). - `Distributions.Distribution`: The distribution of the initial state. """ function distribution(prior::StatePrior; kwargs...) - throw(MethodError(distribution, (prior))) + throw(MethodError(distribution, (prior, kwargs...))) end """ diff --git a/SSMProblems/src/utils/forward_simulation.jl b/SSMProblems/src/utils/forward_simulation.jl index bfcf0eb..64b34a7 100644 --- a/SSMProblems/src/utils/forward_simulation.jl +++ b/SSMProblems/src/utils/forward_simulation.jl @@ -14,7 +14,7 @@ Returns a tuple `(xs, ys)` where `xs` is a vector of latent states (including th state) and `ys` is a vector of observations. """ function sample(rng::AbstractRNG, model::StateSpaceModel, T::Integer; kwargs...) - xs = OffsetVector(fill(simulate(rng, model.prior), T + 1), -1) + xs = OffsetVector(fill(simulate(rng, model.prior; kwargs...), T + 1), -1) for t in 1:T xs[t] = simulate(rng, model.dyn, t, xs[t - 1]; kwargs...) 
end From 322458ee05d24d313ac640703e55801044defeab Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Mon, 12 May 2025 11:22:22 -0400 Subject: [PATCH 04/11] update GF to account for interface changes --- .../GFTest/models/dummy_linear_gaussian.jl | 87 +++--- .../src/GFTest/models/linear_gaussian.jl | 3 +- GeneralisedFilters/src/GeneralisedFilters.jl | 48 ++-- GeneralisedFilters/src/algorithms/forward.jl | 6 +- GeneralisedFilters/src/algorithms/kalman.jl | 54 ++-- .../src/algorithms/particles.jl | 53 ++-- GeneralisedFilters/src/algorithms/rbpf.jl | 249 +++++++++--------- GeneralisedFilters/src/callbacks.jl | 45 ++-- GeneralisedFilters/src/containers.jl | 154 ++++++----- GeneralisedFilters/src/models/discrete.jl | 36 ++- GeneralisedFilters/src/models/hierarchical.jl | 110 ++++---- .../src/models/linear_gaussian.jl | 151 +++++------ GeneralisedFilters/src/resamplers.jl | 26 +- GeneralisedFilters/test/runtests.jl | 40 ++- 14 files changed, 534 insertions(+), 528 deletions(-) diff --git a/GeneralisedFilters/src/GFTest/models/dummy_linear_gaussian.jl b/GeneralisedFilters/src/GFTest/models/dummy_linear_gaussian.jl index fc6a35e..4968afc 100644 --- a/GeneralisedFilters/src/GFTest/models/dummy_linear_gaussian.jl +++ b/GeneralisedFilters/src/GFTest/models/dummy_linear_gaussian.jl @@ -22,20 +22,22 @@ export InnerDynamics, create_dummy_linear_gaussian_model Linear Gaussian dynamics conditonal on the (previous) outer state (u_t), defined by: x_{t+1} = A x_t + b + C u_t + w_t """ -struct InnerDynamics{ - T,TMat<:AbstractMatrix{T},TVec<:AbstractVector{T},TCov<:AbstractMatrix{T} -} <: LinearGaussianLatentDynamics{T} - μ0::TVec - Σ0::TCov +struct InnerDynamics{TMat<:AbstractMatrix,TVec<:AbstractVector,TCov<:AbstractMatrix} <: + LinearGaussianLatentDynamics A::TMat b::TVec C::TMat Q::TCov end +struct InnerPrior{TVec<:AbstractVector,TCov<:AbstractMatrix} <: GaussianPrior + μ0::TVec + Σ0::TCov +end + # CPU methods -GeneralisedFilters.calc_μ0(dyn::InnerDynamics; kwargs...) 
= dyn.μ0 -GeneralisedFilters.calc_Σ0(dyn::InnerDynamics; kwargs...) = dyn.Σ0 +GeneralisedFilters.calc_μ0(prior::InnerPrior; kwargs...) = prior.μ0 +GeneralisedFilters.calc_Σ0(prior::InnerPrior; kwargs...) = prior.Σ0 GeneralisedFilters.calc_A(dyn::InnerDynamics, ::Integer; kwargs...) = dyn.A function GeneralisedFilters.calc_b(dyn::InnerDynamics, ::Integer; prev_outer, kwargs...) return dyn.b + dyn.C * prev_outer @@ -43,39 +45,41 @@ end GeneralisedFilters.calc_Q(dyn::InnerDynamics, ::Integer; kwargs...) = dyn.Q # GPU methods -function GeneralisedFilters.batch_calc_μ0s(dyn::InnerDynamics{T}, N; kwargs...) where {T} - μ0s = CuArray{T}(undef, length(dyn.μ0), N) - return μ0s[:, :] .= cu(dyn.μ0) +function GeneralisedFilters.batch_calc_μ0s(prior::InnerPrior, N; kwargs...) + # μ0s = CuArray{T}(undef, length(prior.μ0), N) + # return μ0s[:, :] .= cu(prior.μ0) + return repeat(cu(prior.μ0), 1, N) end -function GeneralisedFilters.batch_calc_Σ0s( - dyn::InnerDynamics{T}, N::Integer; kwargs... -) where {T} - Σ0s = CuArray{T}(undef, size(dyn.Σ0)..., N) - return Σ0s[:, :, :] .= cu(dyn.Σ0) +function GeneralisedFilters.batch_calc_Σ0s(prior::InnerPrior, N::Integer; kwargs...) + # Σ0s = CuArray{T}(undef, size(dyn.Σ0)..., N) + # return Σ0s[:, :, :] .= cu(dyn.Σ0) + return repeat(cu(prior.Σ0), 1, N) end function GeneralisedFilters.batch_calc_As( - dyn::InnerDynamics{T}, ::Integer, N::Integer; kwargs... -) where {T} - As = CuArray{T}(undef, size(dyn.A)..., N) - As[:, :, :] .= cu(dyn.A) - return As + dyn::InnerDynamics, ::Integer, N::Integer; kwargs... +) + # As = CuArray{T}(undef, size(dyn.A)..., N) + # As[:, :, :] .= cu(dyn.A) + return repeat(cu(dyn.A), 1, N) end function GeneralisedFilters.batch_calc_bs( - dyn::InnerDynamics{T}, ::Integer, N::Integer; prev_outer, kwargs... -) where {T} - Cs = CuArray{T}(undef, size(dyn.C)..., N) - Cs[:, :, :] .= cu(dyn.C) + dyn::InnerDynamics, ::Integer, N::Integer; prev_outer, kwargs... 
+) + # Cs = CuArray{T}(undef, size(dyn.C)..., N) + # Cs[:, :, :] .= cu(dyn.C) + Cs = repeat(cu(dyn.C), 1, N) return NNlib.batched_vec(Cs, prev_outer) .+ cu(dyn.b) end function GeneralisedFilters.batch_calc_Qs( - dyn::InnerDynamics{T}, ::Integer, N::Integer; kwargs... -) where {T} - Q = CuArray{T}(undef, size(dyn.Q)..., N) - return Q[:, :, :] .= cu(dyn.Q) + dyn::InnerDynamics, ::Integer, N::Integer; kwargs... +) + # Q = CuArray{T}(undef, size(dyn.Q)..., N) + # return Q[:, :, :] .= cu(dyn.Q) + return repeat(cu(dyn.Q), 1, N) end function create_dummy_linear_gaussian_model( @@ -113,36 +117,41 @@ function create_dummy_linear_gaussian_model( full_model = create_homogeneous_linear_gaussian_model(μ0, Σ0s, A, b, Q, H, c, R) # Create hierarchical model + outer_prior = GeneralisedFilters.HomogeneousGaussianPrior( + μ0[1:D_outer], Σ0s[1:D_outer, 1:D_outer] + ) + outer_dyn = GeneralisedFilters.HomogeneousLinearGaussianLatentDynamics( - μ0[1:D_outer], - Σ0s[1:D_outer, 1:D_outer], - A[1:D_outer, 1:D_outer], - b[1:D_outer], - Q[1:D_outer, 1:D_outer], + A[1:D_outer, 1:D_outer], b[1:D_outer], Q[1:D_outer, 1:D_outer] ) - inner_dyn = if static_arrays - InnerDynamics( + + inner_prior, inner_dyn = if static_arrays + prior = InnerPrior( SVector{D_inner,T}(μ0[(D_outer + 1):end]), SMatrix{D_inner,D_inner,T}(Σ0s[(D_outer + 1):end, (D_outer + 1):end]), + ) + dyn = InnerDynamics( SMatrix{D_inner,D_outer,T}(A[(D_outer + 1):end, (D_outer + 1):end]), SVector{D_inner,T}(b[(D_outer + 1):end]), SMatrix{D_inner,D_outer,T}(A[(D_outer + 1):end, 1:D_outer]), SMatrix{D_inner,D_inner,T}(Q[(D_outer + 1):end, (D_outer + 1):end]), ) + prior, dyn else - InnerDynamics( - μ0[(D_outer + 1):end], - Σ0s[(D_outer + 1):end, (D_outer + 1):end], + prior = InnerPrior(μ0[(D_outer + 1):end], Σ0s[(D_outer + 1):end, (D_outer + 1):end]) + dyn = InnerDynamics( A[(D_outer + 1):end, (D_outer + 1):end], b[(D_outer + 1):end], A[(D_outer + 1):end, 1:D_outer], Q[(D_outer + 1):end, (D_outer + 1):end], ) + prior, dyn end + obs = 
GeneralisedFilters.HomogeneousLinearGaussianObservationProcess( H[:, (D_outer + 1):end], c, R ) - hier_model = HierarchicalSSM(outer_dyn, inner_dyn, obs) + hier_model = HierarchicalSSM(outer_prior, outer_dyn, inner_prior, inner_dyn, obs) return full_model, hier_model end diff --git a/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl b/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl index 8c4fff6..bb483fb 100644 --- a/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl +++ b/GeneralisedFilters/src/GFTest/models/linear_gaussian.jl @@ -19,7 +19,8 @@ function create_linear_gaussian_model( end function _compute_joint(model, T::Integer) - (; μ0, Σ0, A, b, Q) = model.dyn + (; μ0, Σ0) = model.prior + (; A, b, Q) = model.dyn (; H, c, R) = model.obs Dy, Dx = size(H) diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index 6dc97c6..937acb6 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -23,57 +23,57 @@ abstract type AbstractFilter <: AbstractSampler end abstract type AbstractBatchFilter <: AbstractFilter end """ - initialise([rng,] model, alg; kwargs...) + initialise([rng,] model, algo; kwargs...) Propose an initial state distribution. """ function initialise end """ - step([rng,] model, alg, iter, state, observation; kwargs...) + step([rng,] model, algo, iter, state, observation; kwargs...) Perform a combined predict and update call of the filtering on the state. """ function step end """ - predict([rng,] model, alg, iter, filtered; kwargs...) + predict([rng,] model, algo, iter, filtered; kwargs...) Propagate the filtered distribution forward in time. """ function predict end """ - update(model, alg, iter, proposed, observation; kwargs...) + update(model, algo, iter, proposed, observation; kwargs...) Update beliefs on the propagated distribution given an observation. """ function update end -function initialise(model, alg; kwargs...) 
- return initialise(default_rng(), model, alg; kwargs...) +function initialise(model, algo; kwargs...) + return initialise(default_rng(), model, algo; kwargs...) end -function predict(model, alg, step, filtered, observation; kwargs...) - return predict(default_rng(), model, alg, step, filtered; kwargs...) +function predict(model, algo, step, filtered, observation; kwargs...) + return predict(default_rng(), model, algo, step, filtered; kwargs...) end function filter( rng::AbstractRNG, model::AbstractStateSpaceModel, - alg::AbstractFilter, + algo::AbstractFilter, observations::AbstractVector; callback::Union{AbstractCallback,Nothing}=nothing, kwargs..., ) - state = initialise(rng, model, alg; kwargs...) - isnothing(callback) || callback(model, alg, state, observations, PostInit; kwargs...) + state = initialise(rng, model, algo; kwargs...) + isnothing(callback) || callback(model, algo, state, observations, PostInit; kwargs...) - log_evidence = initialise_log_evidence(alg, model) + log_evidence = initialise_log_evidence(algo, model) for t in eachindex(observations) state, ll_increment = step( - rng, model, alg, t, state, observations[t]; callback, kwargs... + rng, model, algo, t, state, observations[t]; callback, kwargs... 
) log_evidence += ll_increment end @@ -82,39 +82,39 @@ function filter( end function initialise_log_evidence(::AbstractFilter, model::AbstractStateSpaceModel) - return zero(SSMProblems.arithmetic_type(model)) + return 0 end -function initialise_log_evidence(alg::AbstractBatchFilter, model::AbstractStateSpaceModel) - return CUDA.zeros(SSMProblems.arithmetic_type(model), alg.batch_size) -end +# function initialise_log_evidence(alg::AbstractBatchFilter, model::AbstractStateSpaceModel) +# return CUDA.zeros(SSMProblems.arithmetic_type(model), alg.batch_size) +# end function filter( model::AbstractStateSpaceModel, - alg::AbstractFilter, + algo::AbstractFilter, observations::AbstractVector; kwargs..., ) - return filter(default_rng(), model, alg, observations; kwargs...) + return filter(default_rng(), model, algo, observations; kwargs...) end function step( rng::AbstractRNG, model::AbstractStateSpaceModel, - alg::AbstractFilter, + algo::AbstractFilter, iter::Integer, state, observation; callback::Union{AbstractCallback,Nothing}=nothing, kwargs..., ) - state = predict(rng, model, alg, iter, state, observation; kwargs...) + state = predict(rng, model, algo, iter, state, observation; kwargs...) isnothing(callback) || - callback(model, alg, iter, state, observation, PostPredict; kwargs...) + callback(model, algo, iter, state, observation, PostPredict; kwargs...) - state, ll_increment = update(model, alg, iter, state, observation; kwargs...) + state, ll_increment = update(model, algo, iter, state, observation; kwargs...) isnothing(callback) || - callback(model, alg, iter, state, observation, PostUpdate; kwargs...) + callback(model, algo, iter, state, observation, PostUpdate; kwargs...) 
return state, ll_increment end diff --git a/GeneralisedFilters/src/algorithms/forward.jl b/GeneralisedFilters/src/algorithms/forward.jl index 81a5f6f..5da6bc1 100644 --- a/GeneralisedFilters/src/algorithms/forward.jl +++ b/GeneralisedFilters/src/algorithms/forward.jl @@ -6,18 +6,18 @@ const FW = ForwardAlgorithm function initialise( rng::AbstractRNG, model::DiscreteStateSpaceModel, ::ForwardAlgorithm; kwargs... ) - return calc_α0(model.dyn; kwargs...) + return calc_α0(model.prior; kwargs...) end function predict( rng::AbstractRNG, - model::DiscreteStateSpaceModel{T}, + model::DiscreteStateSpaceModel, filter::ForwardAlgorithm, step::Integer, states::AbstractVector, observation; kwargs..., -) where {T} +) P = calc_P(model.dyn, step; kwargs...) return (states' * P)' end diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl index 878ea84..2f4e8d8 100644 --- a/GeneralisedFilters/src/algorithms/kalman.jl +++ b/GeneralisedFilters/src/algorithms/kalman.jl @@ -12,7 +12,7 @@ KF() = KalmanFilter() function initialise( rng::AbstractRNG, model::LinearGaussianStateSpaceModel, filter::KalmanFilter; kwargs... ) - μ0, Σ0 = calc_initial(model.dyn; kwargs...) + μ0, Σ0 = calc_initial(model.prior; kwargs...) return Gaussian(μ0, Σ0) end @@ -61,23 +61,23 @@ end function initialise( rng::AbstractRNG, - model::LinearGaussianStateSpaceModel{T}, + model::LinearGaussianStateSpaceModel, algo::BatchKalmanFilter; kwargs..., -) where {T} +) μ0s, Σ0s = batch_calc_initial(model.dyn, algo.batch_size; kwargs...) return BatchGaussianDistribution(μ0s, Σ0s) end function predict( rng::AbstractRNG, - model::LinearGaussianStateSpaceModel{T}, + model::LinearGaussianStateSpaceModel, algo::BatchKalmanFilter, iter::Integer, state::BatchGaussianDistribution, observation; kwargs..., -) where {T} +) μs, Σs = state.μs, state.Σs As, bs, Qs = batch_calc_params(model.dyn, iter, algo.batch_size; kwargs...) 
μ̂s = NNlib.batched_vec(As, μs) .+ bs @@ -86,13 +86,14 @@ function predict( end function update( - model::LinearGaussianStateSpaceModel{T}, + model::LinearGaussianStateSpaceModel, algo::BatchKalmanFilter, iter::Integer, state::BatchGaussianDistribution, observation; kwargs..., -) where {T} +) + # T = Float32 # temporary fix!!! μs, Σs = state.μs, state.Σs Hs, cs, Rs = batch_calc_params(model.obs, iter, algo.batch_size; kwargs...) D = size(observation, 1) @@ -120,8 +121,8 @@ function update( Σ_filt = Σs .- NNlib.batched_mul(K, NNlib.batched_mul(Hs, Σs)) inv_term = NNlib.batched_vec(S_inv, y_res) - log_likes = -T(0.5) * NNlib.batched_vec(reshape(y_res, 1, D, size(S, 3)), inv_term) - log_likes = log_likes .- T(0.5) * (log_dets .+ D * log(T(2π))) + log_likes = -NNlib.batched_vec(reshape(y_res, 1, D, size(S, 3)), inv_term) + log_likes = (log_likes .- (log_dets .+ D * convert(eltype(log_likes), log(2π)))) ./ 2 # HACK: only errors seems to be from numerical stability so will just overwrite log_likes[isnan.(log_likes)] .= -Inf @@ -135,15 +136,23 @@ struct KalmanSmoother <: AbstractSmoother end const KS = KalmanSmoother() -struct StateCallback{T} <: AbstractCallback - proposed_states::Vector{Gaussian{Vector{T},Matrix{T}}} - filtered_states::Vector{Gaussian{Vector{T},Matrix{T}}} +mutable struct StateCallback <: AbstractCallback + proposed_states + filtered_states end -function StateCallback(N::Integer, T::Type) - return StateCallback{T}( - Vector{Gaussian{Vector{T},Matrix{T}}}(undef, N), - Vector{Gaussian{Vector{T},Matrix{T}}}(undef, N), - ) + +function (callback::StateCallback)( + model::LinearGaussianStateSpaceModel, + algo::KalmanFilter, + state::T, + observations, + ::PostInitCallback; + kwargs..., +) where {T} + N = length(observations) + callback.proposed_states = Vector{T}(undef, N) + callback.filtered_states = Vector{T}(undef, N) + return nothing end function (callback::StateCallback)( @@ -174,15 +183,14 @@ end function smooth( rng::AbstractRNG, - 
model::LinearGaussianStateSpaceModel{T}, + model::LinearGaussianStateSpaceModel, algo::KalmanSmoother, observations::AbstractVector; t_smooth=1, callback=nothing, kwargs..., -) where {T} - cache = StateCallback(length(observations), T) - +) + cache = StateCallback(nothing, nothing) filtered, ll = filter( rng, model, KalmanFilter(), observations; callback=cache, kwargs... ) @@ -199,14 +207,14 @@ end function backward( rng::AbstractRNG, - model::LinearGaussianStateSpaceModel{T}, + model::LinearGaussianStateSpaceModel, algo::KalmanSmoother, iter::Integer, back_state, obs; states_cache, kwargs..., -) where {T} +) μ, Σ = GaussianDistributions.pair(back_state) μ_pred, Σ_pred = GaussianDistributions.pair(states_cache.proposed_states[iter + 1]) μ_filt, Σ_filt = GaussianDistributions.pair(states_cache.filtered_states[iter]) diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index 879e157..b46b983 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -77,7 +77,7 @@ function step( ) # capture the marginalized log-likelihood state = resample(rng, algo.resampler, state; ref_state) - marginalization_term = logsumexp(state.log_weights) + marginalization_term = log_marginal_likelihood(state) isnothing(callback) || callback(model, algo, iter, state, observation, PostResample; kwargs...) @@ -95,27 +95,26 @@ end function initialise( rng::AbstractRNG, - model::StateSpaceModel{T}, - filter::ParticleFilter; + model::StateSpaceModel, + algo::ParticleFilter; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., -) where {T} - particles = map(1:(filter.N)) do i +) + particles = map(1:(algo.N)) do i if !isnothing(ref_state) && i == 1 ref_state[0] else - SSMProblems.simulate(rng, model.dyn; kwargs...) + SSMProblems.simulate(rng, model.prior; kwargs...) 
end end - log_ws = zeros(T, filter.N) - return ParticleDistribution(particles, log_ws) + return ParticleDistribution(particles) end function predict( rng::AbstractRNG, model::StateSpaceModel, - filter::ParticleFilter, + algo::ParticleFilter, iter::Integer, state::ParticleDistribution, observation; @@ -126,44 +125,39 @@ function predict( if !isnothing(ref_state) && i == 1 ref_state[iter] else - simulate(rng, model, filter.proposal, iter, particle, observation; kwargs...) + simulate(rng, model, algo.proposal, iter, particle, observation; kwargs...) end end - state.log_weights += - map(zip(proposed_particles, state.particles)) do (new_state, prev_state) - log_f = SSMProblems.logdensity( - model.dyn, iter, prev_state, new_state; kwargs... - ) + log_weights = map(zip(proposed_particles, state.particles)) do (new_state, prev_state) + log_f = SSMProblems.logdensity(model.dyn, iter, prev_state, new_state; kwargs...) - log_q = SSMProblems.logdensity( - model, filter.proposal, iter, prev_state, new_state, observation; kwargs... - ) + log_q = SSMProblems.logdensity( + model, algo.proposal, iter, prev_state, new_state, observation; kwargs... 
+ ) - (log_f - log_q) - end + (log_f - log_q) + end state.particles = proposed_particles - - return state + return update_weights(state, log_weights) end function update( - model::StateSpaceModel{T}, - filter::ParticleFilter, + model::StateSpaceModel, + algo::ParticleFilter, iter::Integer, state::ParticleDistribution, observation; kwargs..., -) where {T} +) log_increments = map( x -> SSMProblems.logdensity(model.obs, iter, x, observation; kwargs...), state.particles, ) - state.log_weights += log_increments - - return state, logsumexp(state.log_weights) + state = update_weights(state, log_increments) + return state, log_marginal_likelihood(state) end struct LatentProposal <: AbstractProposal end @@ -200,7 +194,7 @@ end function predict( rng::AbstractRNG, model::StateSpaceModel, - filter::BootstrapFilter, + algo::BootstrapFilter, iter::Integer, state::ParticleDistribution, observation=nothing; @@ -228,6 +222,7 @@ function filter( kwargs..., ) ssm = StateSpaceModel( + HierarchicalPrior(model.outer_prior, model.inner_model.prior), HierarchicalDynamics(model.outer_dyn, model.inner_model.dyn), HierarchicalObservations(model.inner_model.obs), ) diff --git a/GeneralisedFilters/src/algorithms/rbpf.jl b/GeneralisedFilters/src/algorithms/rbpf.jl index e73bc72..aa29d62 100644 --- a/GeneralisedFilters/src/algorithms/rbpf.jl +++ b/GeneralisedFilters/src/algorithms/rbpf.jl @@ -31,15 +31,13 @@ function initialise( x = if !isnothing(ref_state) && i == 1 ref_state[0] else - SSMProblems.simulate(rng, model.outer_dyn; kwargs...) + SSMProblems.simulate(rng, model.outer_prior; kwargs...) end z = initialise(rng, model.inner_model, algo.inner_algo; new_outer=x, kwargs...) - RaoBlackwellisedParticle(x, z) end - log_ws = zeros(T, algo.N) - return ParticleDistribution(particles, log_ws) + return ParticleDistribution(particles) end function predict( @@ -77,10 +75,10 @@ function predict( end function update( - model::HierarchicalSSM{T}, algo::RBPF, iter::Integer, state, observation; kwargs... 
-) where {T} - for i in 1:(algo.N) - state.particles[i].z, log_increments = update( + model::HierarchicalSSM, algo::RBPF, iter::Integer, state, observation; kwargs... +) + log_increments = map(1:(algo.N)) do i + state.particles[i].z, log_increment = update( model.inner_model, algo.inner_algo, iter, @@ -89,128 +87,129 @@ function update( new_outer=state.particles[i].x, kwargs..., ) - state.log_weights[i] += log_increments + log_increment end - return state, logsumexp(state.log_weights) + state = update_weights(state, log_increments) + return state, log_marginal_likelihood(state) end ################################# #### GPU-ACCELERATED VERSION #### ################################# -struct BatchRBPF{F<:AbstractFilter,RS<:AbstractResampler} <: AbstractParticleFilter - inner_algo::F - N::Int - resampler::RS -end -function BatchRBPF( - inner_algo, n_particles; threshold::Real=1.0, resampler::AbstractResampler=Systematic() -) - return BatchRBPF(inner_algo, n_particles, ESSResampler(threshold, resampler)) -end - -function searchsorted!(ws_cdf, us, idxs) - index = (blockIdx().x - 1) * blockDim().x + threadIdx().x - stride = blockDim().x * gridDim().x - for i in index:stride:length(us) - # Binary search - left = 1 - right = length(ws_cdf) - while left < right - mid = (left + right) ÷ 2 - if ws_cdf[mid] < us[i] - left = mid + 1 - else - right = mid - end - end - idxs[i] = left - end -end - -function initialise( - rng::AbstractRNG, - model::HierarchicalSSM{T}, - algo::BatchRBPF; - ref_state::Union{Nothing,AbstractVector}=nothing, - kwargs..., -) where {T} - N = algo.N - outer_dyn, inner_model = model.outer_dyn, model.inner_model - - xs = SSMProblems.batch_simulate(rng, outer_dyn, N; ref_state, kwargs...) - - # Set reference trajectory - if ref_state !== nothing - xs[:, 1] = ref_state[0] - end - - zs = initialise(rng, inner_model, algo.inner_algo; new_outer=xs, kwargs...) 
- log_ws = CUDA.zeros(T, N) - - return RaoBlackwellisedParticleDistribution( - BatchRaoBlackwellisedParticles(xs, zs), log_ws - ) -end - -function predict( - rng::AbstractRNG, - model::HierarchicalSSM, - algo::BatchRBPF, - iter::Integer, - state::RaoBlackwellisedParticleDistribution, - observation; - ref_state::Union{Nothing,AbstractVector}=nothing, - kwargs..., -) - outer_dyn, inner_model = model.outer_dyn, model.inner_model - - new_xs = SSMProblems.batch_simulate( - rng, outer_dyn, iter, state.particles.xs; ref_state, kwargs... - ) - - # Set reference trajectory - if ref_state !== nothing - new_xs[:, [1]] = ref_state[iter] - end - - new_zs = predict( - rng, - inner_model, - algo.inner_algo, - iter, - state.particles.zs, - observation; - prev_outer=state.particles.xs, - new_outer=new_xs, - kwargs..., - ) - state.particles = BatchRaoBlackwellisedParticles(new_xs, new_zs) - - return state -end - -function update( - model::HierarchicalSSM, - algo::BatchRBPF, - iter::Integer, - state::RaoBlackwellisedParticleDistribution, - obs; - kwargs..., -) - new_zs, inner_lls = update( - model.inner_model, - algo.inner_algo, - iter, - state.particles.zs, - obs; - new_outer=state.particles.xs, - kwargs..., - ) - - state.log_weights += inner_lls - state.particles.zs = new_zs - - return state, logsumexp(state.log_weights) -end +# struct BatchRBPF{F<:AbstractFilter,RS<:AbstractResampler} <: AbstractParticleFilter +# inner_algo::F +# N::Int +# resampler::RS +# end +# function BatchRBPF( +# inner_algo, n_particles; threshold::Real=1.0, resampler::AbstractResampler=Systematic() +# ) +# return BatchRBPF(inner_algo, n_particles, ESSResampler(threshold, resampler)) +# end + +# function searchsorted!(ws_cdf, us, idxs) +# index = (blockIdx().x - 1) * blockDim().x + threadIdx().x +# stride = blockDim().x * gridDim().x +# for i in index:stride:length(us) +# # Binary search +# left = 1 +# right = length(ws_cdf) +# while left < right +# mid = (left + right) ÷ 2 +# if ws_cdf[mid] < us[i] +# left 
= mid + 1 +# else +# right = mid +# end +# end +# idxs[i] = left +# end +# end + +# function initialise( +# rng::AbstractRNG, +# model::HierarchicalSSM, +# algo::BatchRBPF; +# ref_state::Union{Nothing,AbstractVector}=nothing, +# kwargs..., +# ) +# N = algo.N +# outer_dyn, inner_model = model.outer_dyn, model.inner_model + +# xs = SSMProblems.batch_simulate(rng, outer_dyn, N; ref_state, kwargs...) + +# # Set reference trajectory +# if ref_state !== nothing +# xs[:, 1] = ref_state[0] +# end + +# zs = initialise(rng, inner_model, algo.inner_algo; new_outer=xs, kwargs...) +# # log_ws = CUDA.zeros(T, N) + +# return RaoBlackwellisedParticleDistribution( +# BatchRaoBlackwellisedParticles(xs, zs) +# ) +# end + +# function predict( +# rng::AbstractRNG, +# model::HierarchicalSSM, +# algo::BatchRBPF, +# iter::Integer, +# state::RaoBlackwellisedParticleDistribution, +# observation; +# ref_state::Union{Nothing,AbstractVector}=nothing, +# kwargs..., +# ) +# outer_dyn, inner_model = model.outer_dyn, model.inner_model + +# new_xs = SSMProblems.batch_simulate( +# rng, outer_dyn, iter, state.particles.xs; ref_state, kwargs... 
+# ) + +# # Set reference trajectory +# if ref_state !== nothing +# new_xs[:, [1]] = ref_state[iter] +# end + +# new_zs = predict( +# rng, +# inner_model, +# algo.inner_algo, +# iter, +# state.particles.zs, +# observation; +# prev_outer=state.particles.xs, +# new_outer=new_xs, +# kwargs..., +# ) +# state.particles = BatchRaoBlackwellisedParticles(new_xs, new_zs) + +# return state +# end + +# function update( +# model::HierarchicalSSM, +# algo::BatchRBPF, +# iter::Integer, +# state::RaoBlackwellisedParticleDistribution, +# obs; +# kwargs..., +# ) +# new_zs, inner_lls = update( +# model.inner_model, +# algo.inner_algo, +# iter, +# state.particles.zs, +# obs; +# new_outer=state.particles.xs, +# kwargs..., +# ) + +# state.log_weights += inner_lls +# state.particles.zs = new_zs + +# return state, logsumexp(state.log_weights) +# end diff --git a/GeneralisedFilters/src/callbacks.jl b/GeneralisedFilters/src/callbacks.jl index acfdecf..d540037 100644 --- a/GeneralisedFilters/src/callbacks.jl +++ b/GeneralisedFilters/src/callbacks.jl @@ -70,20 +70,16 @@ end A callback for dense ancestry storage, which fills a `DenseParticleContainer`. """ -struct DenseAncestorCallback{T} <: AbstractCallback - container::DenseParticleContainer{T} - - function DenseAncestorCallback(::Type{T}) where {T} - particles = OffsetVector(Vector{T}[], -1) - ancestors = Vector{Int}[] - return new{T}(DenseParticleContainer(particles, ancestors)) - end +mutable struct DenseAncestorCallback <: AbstractCallback + container end function (c::DenseAncestorCallback)( model, filter, state, data, ::PostInitCallback; kwargs... ) - push!(c.container.particles, deepcopy(state.particles)) + c.container = DenseParticleContainer( + OffsetVector([deepcopy(state.particles)], -1), Vector{Int}[] + ) return nothing end @@ -366,18 +362,12 @@ end A callback for sparse ancestry storage, which preallocates and returns a populated `ParticleTree` object. 
""" -struct AncestorCallback{T} <: AbstractCallback - tree::ParticleTree{T} -end - -function AncestorCallback(::Type{T}, N::Integer, C::Real=1.0) where {T} - M = floor(Int64, C * N * log(N)) - nodes = Vector{T}(undef, N) - return new{T}(ParticleTree(nodes, M)) +mutable struct AncestorCallback <: AbstractCallback + tree end function (c::AncestorCallback)(model, filter, state, data, ::PostInitCallback; kwargs...) - @inbounds c.tree.states[1:(filter.N)] = deepcopy(state.particles) + c.tree = ParticleTree(state.particles, floor(Int64, filter.N * log(filter.N))) return nothing end @@ -395,22 +385,19 @@ end A callback which follows the resampling indices over the filtering algorithm. This is more of a debug tool and visualizer for various resapmling algorithms. """ -struct ResamplerCallback <: AbstractCallback - tree::ParticleTree +mutable struct ResamplerCallback <: AbstractCallback + tree +end - function ResamplerCallback(N::Integer, C::Real=1.0) - M = floor(Int64, C * N * log(N)) - nodes = collect(1:N) - return new(ParticleTree(nodes, M)) - end +function (c::ResamplerCallback)(model, filter, state, data, ::PostInitCallback; kwargs...) + c.tree = ParticleTree(collect(1:N), floor(Int64, filter.N * log(filter.N))) + return nothing end function (c::ResamplerCallback)( model, filter, step, state, data, ::PostResampleCallback; kwargs... 
) - if step != 1 - prune!(c.tree, get_offspring(state.ancestors)) - insert!(c.tree, collect(1:(filter.N)), state.ancestors) - end + prune!(c.tree, get_offspring(state.ancestors)) + insert!(c.tree, collect(1:(filter.N)), state.ancestors) return nothing end diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index 9920b9b..f9dd1db 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -8,41 +8,57 @@ A container for particle filters which composes the weighted sample into a distibution-like object, with the states (or particles) distributed accoring to their log-weights. """ -mutable struct ParticleDistribution{PT,WT<:Real} +abstract type ParticleDistribution{PT} end + +Base.collect(state::ParticleDistribution) = state.particles +Base.length(state::ParticleDistribution) = length(state.particles) +Base.keys(state::ParticleDistribution) = LinearIndices(state.particles) + +# not sure if this is kosher, since it doesn't follow the convention of Base.getindex +Base.@propagate_inbounds Base.getindex(state::ParticleDistribution, i) = state.particles[i] + +mutable struct Particles{PT} <: ParticleDistribution{PT} + particles::Vector{PT} + ancestors::Vector{Int} +end + +mutable struct WeightedParticles{PT,WT<:Real} <: ParticleDistribution{PT} particles::Vector{PT} ancestors::Vector{Int} log_weights::Vector{WT} end -function ParticleDistribution(particles::Vector{PT}, log_weights::Vector{WT}) where {PT,WT} + +# TODO: replace Gaussian with custom internals +# struct GaussianDistribution{PT,ΣT} <: ParticleDistribution{PT} +# μ::PT +# Σ::ΣT +# end + +function ParticleDistribution(particles::AbstractVector) N = length(particles) - return ParticleDistribution(particles, Vector{Int}(1:N), log_weights) + return Particles(particles, Vector{Int}(1:N)) end -StatsBase.weights(state::ParticleDistribution) = softmax(state.log_weights) - -Base.collect(state::ParticleDistribution) = state.particles 
-Base.length(state::ParticleDistribution) = length(state.particles) -Base.keys(state::ParticleDistribution) = LinearIndices(state.particles) +function ParticleDistribution(particles::AbstractVector, log_weights::Vector{<:Real}) + N = length(particles) + return WeightedParticles(particles, Vector{Int}(1:N), log_weights) +end -# not sure if this is kosher, since it doesn't follow the convention of Base.getindex -Base.@propagate_inbounds Base.getindex(state::ParticleDistribution, i) = state.particles[i] -# Base.@propagate_inbounds Base.getindex(state::ParticleDistribution, i::Vector{Int}) = state.particles[i] +StatsBase.weights(state::Particles) = fill(1 / length(state), length(state)) +StatsBase.weights(state::WeightedParticles) = softmax(state.log_weights) -function reset_weights!(state::ParticleDistribution{T,WT}) where {T,WT<:Real} - fill!(state.log_weights, zero(WT)) - return state.log_weights +function update_weights(state::Particles, log_weights) + return WeightedParticles(state.particles, state.ancestors, log_weights) end -function update_ref!( - state::ParticleDistribution, ref_state::Union{Nothing,AbstractVector}, step::Integer=0 -) - if !isnothing(ref_state) - state.particles[1] = ref_state[step] - state.ancestors[1] = 1 - end - return proposed +function update_weights(state::WeightedParticles, log_weights) + state.log_weights += log_weights + return state end +log_marginal_likelihood(state::Particles) = log(length(state)) +log_marginal_likelihood(state::WeightedParticles) = logsumexp(state.log_weights) + ## RAO-BLACKWELLISED PARTICLE ############################################################## """ @@ -63,26 +79,26 @@ mutable struct BatchRaoBlackwellisedParticles{XT,ZT} zs::ZT end -mutable struct RaoBlackwellisedParticleDistribution{ - T,M<:CUDA.AbstractMemory,PT<:BatchRaoBlackwellisedParticles -} - particles::PT - ancestors::CuVector{Int,M} - log_weights::CuVector{T,M} -end -function RaoBlackwellisedParticleDistribution( - particles::PT, 
log_weights::CuVector{T,M} -) where {T,M,PT} - N = length(log_weights) - return RaoBlackwellisedParticleDistribution(particles, CuVector{Int}(1:N), log_weights) -end - -function StatsBase.weights(state::RaoBlackwellisedParticleDistribution) - return softmax(state.log_weights) -end -function Base.length(state::RaoBlackwellisedParticleDistribution) - return length(state.log_weights) -end +# mutable struct RaoBlackwellisedParticleDistribution{ +# T,M<:CUDA.AbstractMemory,PT<:BatchRaoBlackwellisedParticles +# } +# particles::PT +# ancestors::CuVector{Int,M} +# log_weights::CuVector{T,M} +# end +# function RaoBlackwellisedParticleDistribution( +# particles::PT, log_weights::CuVector{T,M} +# ) where {T,M,PT} +# N = length(log_weights) +# return RaoBlackwellisedParticleDistribution(particles, CuVector{Int}(1:N), log_weights) +# end + +# function StatsBase.weights(state::RaoBlackwellisedParticleDistribution) +# return softmax(state.log_weights) +# end +# function Base.length(state::RaoBlackwellisedParticleDistribution) +# return length(state.log_weights) +# end # Allow particle to be get and set via tree_states[:, 1:N] = states function Base.getindex(state::BatchRaoBlackwellisedParticles, i) @@ -100,33 +116,33 @@ function Base.setindex!( end Base.length(state::BatchRaoBlackwellisedParticles) = size(state.xs, 2) -function expand(particles::CuArray{T,2,Mem}, M::Integer) where {T,Mem<:CUDA.AbstractMemory} - new_particles = CuArray(zeros(eltype(particles), size(particles, 1), M)) - new_particles[:, 1:size(particles, 2)] = particles - return new_particles -end - -# Method for increasing size of particle container -function expand(p::BatchRaoBlackwellisedParticles, M::Integer) - new_x = expand(p.xs, M) - new_z = expand(p.zs, M) - return BatchRaoBlackwellisedParticles(new_x, new_z) -end - -function update_ref!( - state::RaoBlackwellisedParticleDistribution, - ref_state::Union{Nothing,AbstractVector}, - step::Integer=0, -) - if !isnothing(ref_state) - CUDA.@allowscalar begin - 
state.particles.xs[:, 1] = ref_state[step].xs - state.particles.zs[1] = ref_state[step].zs - state.ancestors[1] = 1 - end - end - return proposed -end +# function expand(particles::CuArray{T,2,Mem}, M::Integer) where {T,Mem<:CUDA.AbstractMemory} +# new_particles = CuArray(zeros(eltype(particles), size(particles, 1), M)) +# new_particles[:, 1:size(particles, 2)] = particles +# return new_particles +# end + +# # Method for increasing size of particle container +# function expand(p::BatchRaoBlackwellisedParticles, M::Integer) +# new_x = expand(p.xs, M) +# new_z = expand(p.zs, M) +# return BatchRaoBlackwellisedParticles(new_x, new_z) +# end + +# function update_ref!( +# state::RaoBlackwellisedParticleDistribution, +# ref_state::Union{Nothing,AbstractVector}, +# step::Integer=0, +# ) +# if !isnothing(ref_state) +# CUDA.@allowscalar begin +# state.particles.xs[:, 1] = ref_state[step].xs +# state.particles.zs[1] = ref_state[step].zs +# state.ancestors[1] = 1 +# end +# end +# return proposed +# end ## BATCH GAUSSIAN DISTRIBUTION ############################################################# diff --git a/GeneralisedFilters/src/models/discrete.jl b/GeneralisedFilters/src/models/discrete.jl index c5cf5bc..dfc9879 100644 --- a/GeneralisedFilters/src/models/discrete.jl +++ b/GeneralisedFilters/src/models/discrete.jl @@ -1,32 +1,26 @@ export DiscreteLatentDynamics export DiscreteStateSpaceModel -export HomogeneousDiscreteLatentDynamics +export HomogenousDiscretePrior, HomogeneousDiscreteLatentDynamics import SSMProblems: distribution import Distributions: Categorical -abstract type DiscreteLatentDynamics{T_state<:Integer,T_prob<:Real} <: - LatentDynamics{T_prob,T_state} end +abstract type DiscretePrior <: StatePrior end +abstract type DiscreteLatentDynamics <: LatentDynamics end function calc_α0 end function calc_P end -const DiscreteStateSpaceModel{T} = SSMProblems.StateSpaceModel{ - T,LD,OD -} where {T,LD<:DiscreteLatentDynamics{<:Integer,T},OD<:ObservationProcess{T}} - 
-function rb_eltype( - ::DiscreteStateSpaceModel{LD} -) where {T_state,T_prob,LD<:DiscreteLatentDynamics{T_state,T_prob}} - return Vector{T_prob} -end +const DiscreteStateSpaceModel = SSMProblems.StateSpaceModel{ + <:DiscretePrior,<:DiscreteLatentDynamics,<:ObservationProcess +} ####################### #### DISTRIBUTIONS #### ####################### -function SSMProblems.distribution(dyn::DiscreteLatentDynamics; kwargs...) - α0 = calc_α0(dyn; kwargs...) +function SSMProblems.distribution(prior::DiscretePrior; kwargs...) + α0 = calc_α0(prior; kwargs...) return Categorical(α0) end @@ -41,11 +35,13 @@ end #### HOMOGENEOUS DISCRETE MODEL #### #################################### -# TODO: likewise, where do these types come from? -struct HomogeneousDiscreteLatentDynamics{T_state<:Integer,T_prob<:Real} <: - DiscreteLatentDynamics{T_state,T_prob} - α0::Vector{T_prob} - P::Matrix{T_prob} +struct HomogenousDiscretePrior{AT<:AbstractVector} <: StatePrior + α0::AT end -calc_α0(dyn::HomogeneousDiscreteLatentDynamics; kwargs...) = dyn.α0 + +struct HomogeneousDiscreteLatentDynamics{PT<:AbstractMatrix} <: DiscreteLatentDynamics + P::PT +end + +calc_α0(prior::HomogenousDiscretePrior; kwargs...) = prior.α0 calc_P(dyn::HomogeneousDiscreteLatentDynamics, ::Integer; kwargs...) 
= dyn.P diff --git a/GeneralisedFilters/src/models/hierarchical.jl b/GeneralisedFilters/src/models/hierarchical.jl index ecdc8a6..625200a 100644 --- a/GeneralisedFilters/src/models/hierarchical.jl +++ b/GeneralisedFilters/src/models/hierarchical.jl @@ -1,83 +1,96 @@ import SSMProblems: LatentDynamics, ObservationProcess, simulate export HierarchicalSSM -struct HierarchicalSSM{T<:Real,OD<:LatentDynamics{T},M<:StateSpaceModel{T}} <: +struct HierarchicalSSM{PT<:StatePrior,LD<:LatentDynamics,MT<:StateSpaceModel} <: AbstractStateSpaceModel - outer_dyn::OD - inner_model::M - function HierarchicalSSM( - outer_dyn::LatentDynamics{T}, inner_model::StateSpaceModel{T} - ) where {T} - return new{T,typeof(outer_dyn),typeof(inner_model)}(outer_dyn, inner_model) - end + outer_prior::PT + outer_dyn::LD + inner_model::MT end function HierarchicalSSM( - outer_dyn::LatentDynamics{T}, inner_dyn::LatentDynamics{T}, obs::ObservationProcess{T} -) where {T} - inner_model = StateSpaceModel(inner_dyn, obs) - return HierarchicalSSM(outer_dyn, inner_model) + outer_prior::StatePrior, + outer_dyn::LatentDynamics, + inner_prior::StatePrior, + inner_dyn::LatentDynamics, + obs::ObservationProcess, +) + inner_model = StateSpaceModel(inner_prior, inner_dyn, obs) + return HierarchicalSSM(outer_prior, outer_dyn, inner_model) end -SSMProblems.arithmetic_type(::Type{<:HierarchicalSSM{T}}) where {T} = T -function SSMProblems.arithmetic_type(model::HierarchicalSSM) - return SSMProblems.arithmetic_type(model.outer_dyn) -end +# SSMProblems.arithmetic_type(::Type{<:HierarchicalSSM{T}}) where {T} = T +# function SSMProblems.arithmetic_type(model::HierarchicalSSM) +# return SSMProblems.arithmetic_type(model.outer_dyn) +# end function AbstractMCMC.sample( rng::AbstractRNG, model::HierarchicalSSM, T::Integer; kwargs... 
) outer_dyn, inner_model = model.outer_dyn, model.inner_model - zs = Vector{eltype(inner_model.dyn)}(undef, T) - xs = Vector{eltype(outer_dyn)}(undef, T) - ys = Vector{eltype(inner_model.obs)}(undef, T) + # zs = OffsetVector(Vector{eltype(inner_model.dyn)}(undef, T + 1), -1) + # xs = OffsetVector(Vector{eltype(outer_dyn)}(undef, T + 1), -1) + xs = OffsetVector(fill(simulate(rng, model.outer_prior; kwargs...), T + 1), -1) + zs = OffsetVector( + fill(simulate(rng, inner_model.prior; new_outer=xs[0], kwargs...), T + 1), -1 + ) + # ys = Vector{eltype(inner_model.obs)}(undef, T) # Simulate outer dynamics - x0 = simulate(rng, outer_dyn; kwargs...) - z0 = simulate(rng, inner_model.dyn; new_outer=x0, kwargs...) + xs[0] = simulate(rng, outer_dyn; kwargs...) + zs[0] = simulate(rng, inner_model.dyn; new_outer=xs[0], kwargs...) for t in 1:T - prev_x = t == 1 ? x0 : xs[t - 1] - prev_z = t == 1 ? z0 : zs[t - 1] - xs[t] = simulate(rng, model.outer_dyn, t, prev_x; kwargs...) + xs[t] = simulate(rng, model.outer_dyn, t, xs[t - 1]; kwargs...) zs[t] = simulate( - rng, inner_model.dyn, t, prev_z; prev_outer=prev_x, new_outer=xs[t], kwargs... + rng, + inner_model.dyn, + t, + zs[t - 1]; + prev_outer=xs[t - 1], + new_outer=xs[t], + kwargs..., ) - ys[t] = simulate(rng, inner_model.obs, t, zs[t]; new_outer=xs[t], kwargs...) + # ys[t] = simulate(rng, inner_model.obs, t, zs[t]; new_outer=xs[t], kwargs...) 
end - return x0, z0, xs, zs, ys + ys = map(t -> simulate(rng, inner_model.obs, t, zs[t]; new_outer=xs[t], kwargs...)) + return xs, zs, ys end ## Methods to make HierarchicalSSM compatible with the bootstrap filter -struct HierarchicalDynamics{T<:Real,ET,D1<:LatentDynamics{T},D2<:LatentDynamics{T}} <: - LatentDynamics{T,ET} +struct HierarchicalDynamics{D1<:LatentDynamics,D2<:LatentDynamics} <: LatentDynamics outer_dyn::D1 inner_dyn::D2 - function HierarchicalDynamics( - outer_dyn::D1, inner_dyn::D2 - ) where {D1<:LatentDynamics,D2<:LatentDynamics} - ET = RaoBlackwellisedParticle{eltype(outer_dyn),eltype(inner_dyn)} - T = SSMProblems.arithmetic_type(outer_dyn) - return new{T,ET,D1,D2}(outer_dyn, inner_dyn) - end + # function HierarchicalDynamics( + # outer_dyn::D1, inner_dyn::D2 + # ) where {D1<:LatentDynamics,D2<:LatentDynamics} + # ET = RaoBlackwellisedParticle{eltype(outer_dyn),eltype(inner_dyn)} + # T = SSMProblems.arithmetic_type(outer_dyn) + # return new{T,ET,D1,D2}(outer_dyn, inner_dyn) + # end +end + +struct HierarchicalPrior{P1<:StatePrior,P2<:StatePrior} <: StatePrior + outer_prior::P1 + inner_prior::P2 end -function SSMProblems.simulate(rng::AbstractRNG, dyn::HierarchicalDynamics; kwargs...) - outer_dyn, inner_dyn = dyn.outer_dyn, dyn.inner_dyn - x0 = simulate(rng, outer_dyn; kwargs...) - z0 = simulate(rng, inner_dyn; new_outer=x0, kwargs...) +function SSMProblems.simulate(rng::AbstractRNG, prior::HierarchicalPrior; kwargs...) + outer_prior, inner_prior = prior.outer_prior, prior.inner_prior + x0 = simulate(rng, outer_prior; kwargs...) + z0 = simulate(rng, inner_prior; new_outer=x0, kwargs...) 
return RaoBlackwellisedParticle(x0, z0) end function SSMProblems.simulate( rng::AbstractRNG, - dyn::HierarchicalDynamics, + proc::HierarchicalDynamics, step::Integer, prev_state::RaoBlackwellisedParticle; kwargs..., ) - outer_dyn, inner_dyn = dyn.outer_dyn, dyn.inner_dyn + outer_dyn, inner_dyn = proc.outer_dyn, proc.inner_dyn x = simulate(rng, outer_dyn, step, prev_state.x; kwargs...) z = simulate( rng, inner_dyn, step, prev_state.z; prev_outer=prev_state.x, new_outer=x, kwargs... @@ -85,14 +98,13 @@ function SSMProblems.simulate( return RaoBlackwellisedParticle(x, z) end -struct HierarchicalObservations{T<:Real,ET,OP<:ObservationProcess{T}} <: - ObservationProcess{T,ET} +struct HierarchicalObservations{OP<:ObservationProcess} <: ObservationProcess obs::OP - function HierarchicalObservations(obs::OP) where {OP<:ObservationProcess} - T = SSMProblems.arithmetic_type(obs) - ET = eltype(obs) - return new{T,ET,OP}(obs) - end + # function HierarchicalObservations(obs::OP) where {OP<:ObservationProcess} + # T = SSMProblems.arithmetic_type(obs) + # ET = eltype(obs) + # return new{T,ET,OP}(obs) + # end end function SSMProblems.distribution( diff --git a/GeneralisedFilters/src/models/linear_gaussian.jl b/GeneralisedFilters/src/models/linear_gaussian.jl index da4a2ef..eec2e16 100644 --- a/GeneralisedFilters/src/models/linear_gaussian.jl +++ b/GeneralisedFilters/src/models/linear_gaussian.jl @@ -1,3 +1,4 @@ +export GaussianPrior export LinearGaussianLatentDynamics export LinearGaussianObservationProcess export LinearGaussianStateSpaceModel @@ -7,14 +8,16 @@ import SSMProblems: distribution import Distributions: MvNormal import LinearAlgebra: cholesky -abstract type LinearGaussianLatentDynamics{T} <: SSMProblems.LatentDynamics{T,Vector{T}} end +abstract type GaussianPrior <: StatePrior end function calc_μ0 end function calc_Σ0 end -function calc_initial(dyn::LinearGaussianLatentDynamics; kwargs...) - return calc_μ0(dyn; kwargs...), calc_Σ0(dyn; kwargs...) 
+function calc_initial(prior::GaussianPrior; kwargs...) + return calc_μ0(prior; kwargs...), calc_Σ0(prior; kwargs...) end +abstract type LinearGaussianLatentDynamics <: LatentDynamics end + function calc_A end function calc_b end function calc_Q end @@ -26,8 +29,7 @@ function calc_params(dyn::LinearGaussianLatentDynamics, step::Integer; kwargs... ) end -abstract type LinearGaussianObservationProcess{T} <: - SSMProblems.ObservationProcess{T,Vector{T}} end +abstract type LinearGaussianObservationProcess <: ObservationProcess end function calc_H end function calc_c end @@ -40,20 +42,21 @@ function calc_params(obs::LinearGaussianObservationProcess, step::Integer; kwarg ) end -const LinearGaussianStateSpaceModel{T} = SSMProblems.StateSpaceModel{ - T,D,O -} where {T,D<:LinearGaussianLatentDynamics{T},O<:LinearGaussianObservationProcess{T}} +const LinearGaussianStateSpaceModel = StateSpaceModel{ + <:GaussianPrior,<:LinearGaussianLatentDynamics,<:LinearGaussianObservationProcess +} -function rb_eltype(::LinearGaussianStateSpaceModel{T}) where {T} - return Gaussian{Vector{T},Matrix{T}} +function rb_eltype(model::LinearGaussianStateSpaceModel) + μ0, Σ0 = calc_initial(model.prior) + return Gaussian{typeof(μ0),typeof(Σ0)} end ####################### #### DISTRIBUTIONS #### ####################### -function SSMProblems.distribution(dyn::LinearGaussianLatentDynamics; kwargs...) - μ0, Σ0 = calc_initial(dyn; kwargs...) +function SSMProblems.distribution(prior::GaussianPrior; kwargs...) + μ0, Σ0 = calc_initial(prior; kwargs...) 
return MvNormal(μ0, Σ0) end @@ -75,26 +78,29 @@ end #### HOMOGENEOUS LINEAR GAUSSIAN MODEL #### ########################################### -struct HomogeneousLinearGaussianLatentDynamics{ - T<:Real,ΣT<:AbstractMatrix{T},AT<:AbstractMatrix{T},QT<:AbstractMatrix{T} -} <: LinearGaussianLatentDynamics{T} - μ0::Vector{T} +struct HomogeneousGaussianPrior{XT<:AbstractVector,ΣT<:AbstractMatrix} <: GaussianPrior + μ0::XT Σ0::ΣT +end +calc_μ0(prior::HomogeneousGaussianPrior; kwargs...) = prior.μ0 +calc_Σ0(prior::HomogeneousGaussianPrior; kwargs...) = prior.Σ0 + +struct HomogeneousLinearGaussianLatentDynamics{ + AT<:AbstractMatrix,bT<:AbstractVector,QT<:AbstractMatrix +} <: LinearGaussianLatentDynamics A::AT - b::Vector{T} + b::bT Q::QT end -calc_μ0(dyn::HomogeneousLinearGaussianLatentDynamics; kwargs...) = dyn.μ0 -calc_Σ0(dyn::HomogeneousLinearGaussianLatentDynamics; kwargs...) = dyn.Σ0 calc_A(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) = dyn.A calc_b(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) = dyn.b calc_Q(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) = dyn.Q struct HomogeneousLinearGaussianObservationProcess{ - T<:Real,HT<:AbstractMatrix{T},RT<:AbstractMatrix{T} -} <: LinearGaussianObservationProcess{T} + HT<:AbstractMatrix,cT<:AbstractVector,RT<:AbstractMatrix +} <: LinearGaussianObservationProcess H::HT - c::Vector{T} + c::cT R::RT end calc_H(obs::HomogeneousLinearGaussianObservationProcess, ::Integer; kwargs...) = obs.H @@ -103,7 +109,8 @@ calc_R(obs::HomogeneousLinearGaussianObservationProcess, ::Integer; kwargs...) 
= function create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) return SSMProblems.StateSpaceModel( - HomogeneousLinearGaussianLatentDynamics(μ0, Σ0, A, b, Q), + HomogeneousGaussianPrior(μ0, Σ0), + HomogeneousLinearGaussianLatentDynamics(A, b, Q), HomogeneousLinearGaussianObservationProcess(H, c, R), ) end @@ -122,8 +129,8 @@ function batch_calc_cs end function batch_calc_Rs end # TODO: can we remove batch size argument? -function batch_calc_initial(dyn::LinearGaussianLatentDynamics, N::Integer; kwargs...) - return batch_calc_μ0s(dyn, N; kwargs...), batch_calc_Σ0s(dyn, N; kwargs...) +function batch_calc_initial(prior::HomogeneousGaussianPrior, N::Integer; kwargs...) + return batch_calc_μ0s(prior, N; kwargs...), batch_calc_Σ0s(prior, N; kwargs...) end function batch_calc_params( @@ -147,82 +154,64 @@ function batch_calc_params( end function SSMProblems.batch_simulate( - ::AbstractRNG, dyn::HomogeneousLinearGaussianLatentDynamics{T}, N::Integer; kwargs... -) where {T} - μ0, Σ0 = GeneralisedFilters.calc_initial(dyn; kwargs...) - D = length(μ0) - L = cholesky(Σ0).L - Ls = CuArray{T}(undef, size(Σ0)..., N) - Ls[:, :, :] .= cu(L) - return cu(μ0) .+ NNlib.batched_vec(Ls, CUDA.randn(T, D, N)) + ::AbstractRNG, prior::HomogeneousGaussianPrior, N::Integer; kwargs... +) + μ0, Σ0 = GeneralisedFilters.calc_initial(prior; kwargs...) + Ls = repeat(cholesky(Σ0).L, 1, N) + noise = CUDA.randn(T, length(μ0), N) + return cu(μ0) .+ NNlib.batched_vec(Ls, noise) end function SSMProblems.batch_simulate( ::AbstractRNG, - dyn::GeneralisedFilters.HomogeneousLinearGaussianLatentDynamics{T}, + dyn::HomogeneousLinearGaussianLatentDynamics, step::Integer, prev_state; kwargs..., -) where {T} +) N = size(prev_state, 2) - A, b, Q = GeneralisedFilters.calc_params(dyn, step; kwargs...) 
- D = length(b) - L = cholesky(Q).L - Ls = CuArray{T}(undef, size(Q)..., N) - Ls[:, :, :] .= cu(L) - As = CuArray{T}(undef, size(A)..., N) - As[:, :, :] .= cu(A) - return (NNlib.batched_vec(As, prev_state) .+ cu(b)) + - NNlib.batched_vec(Ls, CUDA.randn(T, D, N)) -end - -function batch_calc_μ0s( - dyn::HomogeneousLinearGaussianLatentDynamics{T}, N::Integer; kwargs... -) where {T} - μ0s = CuArray{T}(undef, length(dyn.μ0), N) - return μ0s[:, :] .= cu(dyn.μ0) -end -function batch_calc_Σ0s( - dyn::HomogeneousLinearGaussianLatentDynamics{T}, N::Integer; kwargs... -) where {T} - Σ0s = CuArray{T}(undef, size(dyn.Σ0)..., N) - return Σ0s[:, :, :] .= cu(dyn.Σ0) + A, b, Q = calc_params(dyn, step; kwargs...) + Ls = repeat(cholesky(Q).L, 1, N) + As = repeat(A, 1, N) + noise = CUDA.randn(T, length(b), N) + return (NNlib.batched_vec(As, prev_state) .+ cu(b)) + NNlib.batched_vec(Ls, noise) +end + +function batch_calc_μ0s(prior::HomogeneousGaussianPrior, N::Integer; kwargs...) + return repeat(cu(prior.μ0), 1, N) +end +function batch_calc_Σ0s(prior::HomogeneousGaussianPrior, N::Integer; kwargs...) + return repeat(cu(prior.Σ0), 1, N) end function batch_calc_As( - dyn::HomogeneousLinearGaussianLatentDynamics{T}, ::Integer, N::Integer; kwargs... -) where {T} - As = CuArray{T}(undef, size(dyn.A)..., N) - return As[:, :, :] .= cu(dyn.A) + dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer, N::Integer; kwargs... +) + return repeat(cu(dyn.A), 1, N) end function batch_calc_bs( - dyn::HomogeneousLinearGaussianLatentDynamics{T}, ::Integer, N::Integer; kwargs... -) where {T} - bs = CuArray{T}(undef, size(dyn.b)..., N) - return bs[:, :] .= cu(dyn.b) + dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer, N::Integer; kwargs... +) + return repeat(cu(dyn.b), 1, N) end function batch_calc_Qs( - dyn::HomogeneousLinearGaussianLatentDynamics{T}, ::Integer, N::Integer; kwargs... 
-) where {T} - Qs = CuArray{T}(undef, size(dyn.Q)..., N) - return Qs[:, :, :] .= cu(dyn.Q) + dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer, N::Integer; kwargs... +) + return repeat(cu(dyn.Q), 1, N) end function batch_calc_Hs( - obs::HomogeneousLinearGaussianObservationProcess{T}, ::Integer, N::Integer; kwargs... -) where {T} - Hs = CuArray{T}(undef, size(obs.H)..., N) - return Hs[:, :, :] .= cu(obs.H) + obs::HomogeneousLinearGaussianObservationProcess, ::Integer, N::Integer; kwargs... +) + return repeat(cu(obs.H), 1, N) end function batch_calc_cs( - obs::HomogeneousLinearGaussianObservationProcess{T}, ::Integer, N::Integer; kwargs... -) where {T} - cs = CuArray{T}(undef, size(obs.c)..., N) - return cs[:, :] .= cu(obs.c) + obs::HomogeneousLinearGaussianObservationProcess, ::Integer, N::Integer; kwargs... +) + return repeat(cu(obs.c), 1, N) end function batch_calc_Rs( - obs::HomogeneousLinearGaussianObservationProcess{T}, ::Integer, N::Integer; kwargs... -) where {T} - Rs = CuArray{T}(undef, size(obs.R)..., N) - return Rs[:, :, :] .= cu(obs.R) + obs::HomogeneousLinearGaussianObservationProcess, ::Integer, N::Integer; kwargs... 
+) + return repeat(cu(obs.R), 1, N) end diff --git a/GeneralisedFilters/src/resamplers.jl b/GeneralisedFilters/src/resamplers.jl index 143a2da..3ffca3a 100644 --- a/GeneralisedFilters/src/resamplers.jl +++ b/GeneralisedFilters/src/resamplers.jl @@ -22,21 +22,21 @@ function resample( return construct_new_state(states, idxs) end -function construct_new_state(states::ParticleDistribution{PT,WT}, idxs) where {PT,WT} - return ParticleDistribution(states.particles[idxs], idxs, zeros(WT, length(states))) +function construct_new_state(states::ParticleDistribution{PT}, idxs) where {PT} + return Particles(states.particles[idxs], idxs) end -function construct_new_state( - states::RaoBlackwellisedParticleDistribution{T}, idxs -) where {T} - return RaoBlackwellisedParticleDistribution( - BatchRaoBlackwellisedParticles( - states.particles.xs[:, idxs], states.particles.zs[idxs] - ), - idxs, - CUDA.zeros(T, length(states)), - ) -end +# function construct_new_state( +# states::RaoBlackwellisedParticleDistribution{T}, idxs +# ) where {T} +# return RaoBlackwellisedParticleDistribution( +# BatchRaoBlackwellisedParticles( +# states.particles.xs[:, idxs], states.particles.zs[idxs] +# ), +# idxs, +# CUDA.zeros(T, length(states)), +# ) +# end ## CONDITIONAL RESAMPLING ################################################################## diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index ea8dfa2..3d71d48 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -21,7 +21,7 @@ include("resamplers.jl") for Dy in Dys rng = StableRNG(1234) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy) - _, _, ys = sample(rng, model, 1) + _, ys = sample(rng, model, 1) filtered, ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys) @@ -61,7 +61,7 @@ end for Dy in Dys rng = StableRNG(1234) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy) - _, _, ys = sample(rng, model, 2) + 
_, ys = sample(rng, model, 2) states, ll = GeneralisedFilters.smooth(rng, model, KalmanSmoother(), ys) @@ -88,7 +88,7 @@ end rng = StableRNG(1234) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) - _, _, ys = sample(rng, model, 10) + _, ys = sample(rng, model, 10) bf = BF(2^12; threshold=0.8) bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, ys) @@ -128,7 +128,7 @@ end rng = StableRNG(1234) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) - _, _, ys = sample(rng, model, 10) + _, ys = sample(rng, model, 10) algo = PF(2^10, LinearGaussianProposal(); threshold=0.6) kf_states, kf_ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys) @@ -153,8 +153,8 @@ end P = rand(rng, 3, 3) P = P ./ sum(P; dims=2) - struct MixtureModelObservation{T} <: SSMProblems.ObservationProcess{T,T} - μs::Vector{T} + struct MixtureModelObservation{T<:Real,MT<:AbstractVector{T}} <: ObservationProcess + μs::MT end function SSMProblems.logdensity( @@ -169,9 +169,10 @@ end μs = [0.0, 1.0, 2.0] - dyn = HomogeneousDiscreteLatentDynamics{Int,Float64}(α0, P) + prior = HomogenousDiscretePrior(α0) + dyn = HomogeneousDiscreteLatentDynamics(P) obs = MixtureModelObservation(μs) - model = StateSpaceModel(dyn, obs) + model = StateSpaceModel(prior, dyn, obs) observations = [rand(rng)] @@ -212,7 +213,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, D_outer, D_inner, D_obs ) - _, _, ys = sample(rng, full_model, T) + _, ys = sample(rng, full_model, T) # Ground truth Kalman filtering kf_states, kf_ll = GeneralisedFilters.filter(rng, full_model, KalmanFilter(), ys) @@ -290,17 +291,10 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, 1, 1, 1 ) - _, _, ys = sample(rng, full_model, T) - - # Manually create tree to force expansion on second step - particle_type = GeneralisedFilters.RaoBlackwellisedParticle{ - 
eltype(hier_model.outer_dyn),GeneralisedFilters.rb_eltype(hier_model.inner_model) - } - nodes = Vector{particle_type}(undef, N_particles) - tree = GeneralisedFilters.ParticleTree(nodes, N_particles + 1) + _, ys = sample(rng, full_model, T) + cb = GeneralisedFilters.AncestorCallback(nothing) rbpf = RBPF(KalmanFilter(), N_particles) - cb = GeneralisedFilters.AncestorCallback(tree) GeneralisedFilters.filter(rng, hier_model, rbpf, ys; callback=cb) # TODO: add proper test comparing to dense storage @@ -325,7 +319,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, D_outer, D_inner, D_obs ) - _, _, ys = sample(rng, full_model, T) + _, ys = sample(rng, full_model, T) # Ground truth Kalman filtering kf_states, kf_ll = GeneralisedFilters.filter(rng, full_model, KalmanFilter(), ys) @@ -373,12 +367,12 @@ end rng = StableRNG(SEED) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) - _, _, ys = sample(rng, model, K) + _, ys = sample(rng, model, K) ref_traj = OffsetVector([rand(rng, 1) for _ in 0:K], -1) bf = BF(N_particles; threshold=1.0, resampler=DummyResampler()) - cb = GeneralisedFilters.DenseAncestorCallback(Vector{Float64}) + cb = GeneralisedFilters.DenseAncestorCallback(nothing) bf_state, _ = GeneralisedFilters.filter( rng, model, bf, ys; ref_state=ref_traj, callback=cb ) @@ -413,7 +407,7 @@ end rng = StableRNG(SEED) model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy) - _, _, ys = sample(rng, model, K) + _, ys = sample(rng, model, K) # Kalman smoother state, ks_ll = GeneralisedFilters.smooth( @@ -477,7 +471,7 @@ end full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( rng, D_outer, D_inner, D_obs, T; static_arrays=true ) - _, _, ys = sample(rng, full_model, K) + _, ys = sample(rng, full_model, K) # Kalman smoother state, _ = GeneralisedFilters.smooth( From e199b3b7387c09a4acd0f1021e9dc8af0102108f Mon Sep 17 00:00:00 2001 From: Charles Knipp 
Date: Mon, 12 May 2025 13:48:51 -0400 Subject: [PATCH 05/11] fix discrete model types --- GeneralisedFilters/src/models/discrete.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/GeneralisedFilters/src/models/discrete.jl b/GeneralisedFilters/src/models/discrete.jl index dfc9879..3784c9c 100644 --- a/GeneralisedFilters/src/models/discrete.jl +++ b/GeneralisedFilters/src/models/discrete.jl @@ -35,7 +35,7 @@ end #### HOMOGENEOUS DISCRETE MODEL #### #################################### -struct HomogenousDiscretePrior{AT<:AbstractVector} <: StatePrior +struct HomogenousDiscretePrior{AT<:AbstractVector} <: DiscretePrior α0::AT end From 67e3ef597f5d491c661e3df99d2bb6ebd3615cd5 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Mon, 12 May 2025 14:01:54 -0400 Subject: [PATCH 06/11] update callback unit tests --- GeneralisedFilters/test/runtests.jl | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index 3d71d48..fe69281 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -356,8 +356,8 @@ end struct DummyResampler <: GeneralisedFilters.AbstractResampler end function GeneralisedFilters.sample_ancestors( - rng::AbstractRNG, resampler::DummyResampler, weights::Vector{T} - ) where {T} + rng::AbstractRNG, resampler::DummyResampler, weights::AbstractVector + ) return [mod1(a - 1, length(weights)) for a in 1:length(weights)] end @@ -417,11 +417,11 @@ end N_steps = N_burnin + N_sample bf = BF(N_particles; threshold=0.6) ref_traj = nothing - trajectory_samples = Vector{OffsetVector{Vector{T},Vector{Vector{T}}}}(undef, N_sample) - lls = Vector{T}(undef, N_sample) + trajectory_samples = [] + lls = [] for i in 1:N_steps - cb = GeneralisedFilters.DenseAncestorCallback(Vector{T}) + cb = GeneralisedFilters.DenseAncestorCallback(nothing) bf_state, ll = GeneralisedFilters.filter( rng, model, bf, ys; 
ref_state=ref_traj, callback=cb ) @@ -429,8 +429,8 @@ end sampled_idx = sample(rng, 1:length(weights), Weights(weights)) global ref_traj = GeneralisedFilters.get_ancestry(cb.container, sampled_idx) if i > N_burnin - trajectory_samples[i - N_burnin] = ref_traj - lls[i - N_burnin] = ll + push!(trajectory_samples, ref_traj) + push!(lls, ll) end end @@ -478,19 +478,13 @@ end rng, full_model, KalmanSmoother(), ys; t_smooth=t_smooth ) - particle_type = GeneralisedFilters.RaoBlackwellisedParticle{ - Vector{T},Gaussian{SVector{D_inner,T},SMatrix{D_inner,D_inner,T,D_inner^2}} - } - N_steps = N_burnin + N_sample rbpf = RBPF(KalmanFilter(), N_particles; threshold=0.6) ref_traj = nothing - trajectory_samples = Vector{OffsetVector{particle_type,Vector{particle_type}}}( - undef, N_sample - ) + trajectory_samples = [] for i in 1:N_steps - cb = GeneralisedFilters.DenseAncestorCallback(particle_type) + cb = GeneralisedFilters.DenseAncestorCallback(nothing) bf_state, _ = GeneralisedFilters.filter( rng, hier_model, rbpf, ys; ref_state=ref_traj, callback=cb ) @@ -499,7 +493,7 @@ end global ref_traj = GeneralisedFilters.get_ancestry(cb.container, sampled_idx) if i > N_burnin - trajectory_samples[i - N_burnin] = deepcopy(ref_traj) + push!(trajectory_samples, deepcopy(ref_traj)) end # Reference trajectory should only be nonlinear state for RBPF ref_traj = getproperty.(ref_traj, :x) From 03aa20a5175ff03119702b96408ea9f3b1ebe2b4 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 14 May 2025 13:56:11 -0400 Subject: [PATCH 07/11] ensure better type stability --- GeneralisedFilters/src/GeneralisedFilters.jl | 44 +++++++++++-------- .../src/algorithms/particles.jl | 25 +++-------- GeneralisedFilters/src/algorithms/rbpf.jl | 2 +- GeneralisedFilters/src/callbacks.jl | 10 +++-- GeneralisedFilters/src/containers.jl | 10 ++++- GeneralisedFilters/src/resamplers.jl | 6 ++- GeneralisedFilters/test/runtests.jl | 4 +- 7 files changed, 54 insertions(+), 47 deletions(-) diff --git 
a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index 937acb6..6028acb 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -63,15 +63,20 @@ function filter( model::AbstractStateSpaceModel, algo::AbstractFilter, observations::AbstractVector; - callback::Union{AbstractCallback,Nothing}=nothing, + callback::CallbackType=nothing, kwargs..., ) - state = initialise(rng, model, algo; kwargs...) - isnothing(callback) || callback(model, algo, state, observations, PostInit; kwargs...) + # draw from the prior + init_state = initialise(rng, model, algo; kwargs...) + callback(model, algo, init_state, observations, PostInit; kwargs...) - log_evidence = initialise_log_evidence(algo, model) + # iterations starts here for type stability + state, log_evidence = step( + rng, model, algo, 1, init_state, observations[1]; callback, kwargs... + ) - for t in eachindex(observations) + # subsequent iteration + for t in 2:length(observations) state, ll_increment = step( rng, model, algo, t, state, observations[t]; callback, kwargs... ) @@ -81,14 +86,6 @@ function filter( return state, log_evidence end -function initialise_log_evidence(::AbstractFilter, model::AbstractStateSpaceModel) - return 0 -end - -# function initialise_log_evidence(alg::AbstractBatchFilter, model::AbstractStateSpaceModel) -# return CUDA.zeros(SSMProblems.arithmetic_type(model), alg.batch_size) -# end - function filter( model::AbstractStateSpaceModel, algo::AbstractFilter, @@ -105,16 +102,27 @@ function step( iter::Integer, state, observation; - callback::Union{AbstractCallback,Nothing}=nothing, + kwargs..., +) + # generalised to fit analytical filters + return move(rng, model, algo, iter, state, observation; kwargs...) 
+end + +function move( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + algo::AbstractFilter, + iter::Integer, + state, + observation; + callback::CallbackType=nothing, kwargs..., ) state = predict(rng, model, algo, iter, state, observation; kwargs...) - isnothing(callback) || - callback(model, algo, iter, state, observation, PostPredict; kwargs...) + callback(model, algo, iter, state, observation, PostPredict; kwargs...) state, ll_increment = update(model, algo, iter, state, observation; kwargs...) - isnothing(callback) || - callback(model, algo, iter, state, observation, PostUpdate; kwargs...) + callback(model, algo, iter, state, observation, PostUpdate; kwargs...) return state, ll_increment end diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index b46b983..d206798 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -72,25 +72,12 @@ function step( state, observation; ref_state::Union{Nothing,AbstractVector}=nothing, - callback::Union{AbstractCallback,Nothing}=nothing, + callback::CallbackType=nothing, kwargs..., ) - # capture the marginalized log-likelihood state = resample(rng, algo.resampler, state; ref_state) - marginalization_term = log_marginal_likelihood(state) - isnothing(callback) || - callback(model, algo, iter, state, observation, PostResample; kwargs...) - - state = predict(rng, model, algo, iter, state, observation; ref_state, kwargs...) - isnothing(callback) || - callback(model, algo, iter, state, observation, PostPredict; kwargs...) - - # TODO: ll_increment is no longer consistent with the Kalman filter - state, ll_increment = update(model, algo, iter, state, observation; kwargs...) - isnothing(callback) || - callback(model, algo, iter, state, observation, PostUpdate; kwargs...) - - return state, (ll_increment - marginalization_term) + callback(model, algo, iter, state, observation, PostResample; kwargs...) 
+ return move(rng, model, algo, iter, state, observation; ref_state, callback, kwargs...) end function initialise( @@ -157,7 +144,7 @@ function update( ) state = update_weights(state, log_increments) - return state, log_marginal_likelihood(state) + return state, logmeanexp(log_increments) end struct LatentProposal <: AbstractProposal end @@ -166,7 +153,7 @@ const BootstrapFilter{RS} = ParticleFilter{RS,LatentProposal} const BF = BootstrapFilter BootstrapFilter(N::Integer; kwargs...) = ParticleFilter(N, LatentProposal(); kwargs...) -function simulate( +function SSMProblems.simulate( rng::AbstractRNG, model::AbstractStateSpaceModel, prop::LatentProposal, @@ -178,7 +165,7 @@ function simulate( return SSMProblems.simulate(rng, model.dyn, iter, state; kwargs...) end -function logdensity( +function SSMProblems.logdensity( model::AbstractStateSpaceModel, prop::LatentProposal, iter::Integer, diff --git a/GeneralisedFilters/src/algorithms/rbpf.jl b/GeneralisedFilters/src/algorithms/rbpf.jl index aa29d62..3f7de8e 100644 --- a/GeneralisedFilters/src/algorithms/rbpf.jl +++ b/GeneralisedFilters/src/algorithms/rbpf.jl @@ -91,7 +91,7 @@ function update( end state = update_weights(state, log_increments) - return state, log_marginal_likelihood(state) + return state, logmeanexp(log_increments) end ################################# diff --git a/GeneralisedFilters/src/callbacks.jl b/GeneralisedFilters/src/callbacks.jl index d540037..f017ba4 100644 --- a/GeneralisedFilters/src/callbacks.jl +++ b/GeneralisedFilters/src/callbacks.jl @@ -16,15 +16,17 @@ abstract type AbstractCallback end abstract type CallbackTrigger end +const CallbackType = Union{Nothing,<:AbstractCallback} + struct PostInitCallback <: CallbackTrigger end const PostInit = PostInitCallback() -function (c::AbstractCallback)(model, filter, state, data, ::PostInitCallback; kwargs...) +function (c::CallbackType)(model, filter, state, data, ::PostInitCallback; kwargs...) 
return nothing end struct PostResampleCallback <: CallbackTrigger end const PostResample = PostResampleCallback() -function (c::AbstractCallback)( +function (c::CallbackType)( model, filter, step, state, data, ::PostResampleCallback; kwargs... ) return nothing @@ -32,7 +34,7 @@ end struct PostPredictCallback <: CallbackTrigger end const PostPredict = PostPredictCallback() -function (c::AbstractCallback)( +function (c::CallbackType)( model, filter, step, state, data, ::PostPredictCallback; kwargs... ) return nothing @@ -40,7 +42,7 @@ end struct PostUpdateCallback <: CallbackTrigger end const PostUpdate = PostUpdateCallback() -function (c::AbstractCallback)( +function (c::CallbackType)( model, filter, step, state, data, ::PostUpdateCallback; kwargs... ) return nothing diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index f9dd1db..5c0f862 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -56,8 +56,14 @@ function update_weights(state::WeightedParticles, log_weights) return state end -log_marginal_likelihood(state::Particles) = log(length(state)) -log_marginal_likelihood(state::WeightedParticles) = logsumexp(state.log_weights) +function fast_maximum(x::AbstractArray{T}; dims)::T where {T} + @fastmath reduce(max, x; dims, init = float(T)(-Inf)) +end + +function logmeanexp(x::AbstractArray{T}; dims = :)::T where {T} + max_ = fast_maximum(x; dims) + @fastmath max_ .+ log.(mean(exp.(x .- max_); dims)) +end ## RAO-BLACKWELLISED PARTICLE ############################################################## diff --git a/GeneralisedFilters/src/resamplers.jl b/GeneralisedFilters/src/resamplers.jl index 3ffca3a..1d2ae95 100644 --- a/GeneralisedFilters/src/resamplers.jl +++ b/GeneralisedFilters/src/resamplers.jl @@ -22,10 +22,14 @@ function resample( return construct_new_state(states, idxs) end -function construct_new_state(states::ParticleDistribution{PT}, idxs) where {PT} +function 
construct_new_state(states::Particles, idxs) return Particles(states.particles[idxs], idxs) end +function construct_new_state(states::WeightedParticles{PT,WT}, idxs) where {PT,WT} + return WeightedParticles(states.particles[idxs], idxs, zeros(WT, length(states))) +end + # function construct_new_state( # states::RaoBlackwellisedParticleDistribution{T}, idxs # ) where {T} diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index fe69281..f946f26 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -99,7 +99,7 @@ end # Compare log-likelihood and states @test first(kf_state.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-2 - @test llkf ≈ llbf atol = 1e-1 + @test llkf ≈ llbf atol = 1e-2 end @testitem "Guided filter test" begin @@ -138,7 +138,7 @@ end # Compare log-likelihood and states @test first(kf_states.μ) ≈ sum(first.(xs) .* ws) rtol = 1e-2 - @test kf_ll ≈ pf_ll rtol = 1e-1 + @test kf_ll ≈ pf_ll rtol = 1e-2 end @testitem "Forward algorithm test" begin From 7bc10ccff89e3d902ba554e4541ab0c1a2d72183 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 14 May 2025 13:57:41 -0400 Subject: [PATCH 08/11] formatter --- GeneralisedFilters/src/containers.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index 5c0f862..de7ad76 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -57,10 +57,10 @@ function update_weights(state::WeightedParticles, log_weights) end function fast_maximum(x::AbstractArray{T}; dims)::T where {T} - @fastmath reduce(max, x; dims, init = float(T)(-Inf)) + @fastmath reduce(max, x; dims, init=float(T)(-Inf)) end -function logmeanexp(x::AbstractArray{T}; dims = :)::T where {T} +function logmeanexp(x::AbstractArray{T}; dims=:)::T where {T} max_ = fast_maximum(x; dims) @fastmath max_ .+ log.(mean(exp.(x .- max_); dims)) end From 
f27c0dc0554ef11cc5471b20f07c909b431a39eb Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 14 May 2025 16:05:28 -0400 Subject: [PATCH 09/11] fix particle containers --- .../src/algorithms/particles.jl | 16 +++++----- GeneralisedFilters/src/algorithms/rbpf.jl | 2 +- GeneralisedFilters/src/containers.jl | 30 ++++++++++++------- GeneralisedFilters/src/resamplers.jl | 6 ++-- 4 files changed, 31 insertions(+), 23 deletions(-) diff --git a/GeneralisedFilters/src/algorithms/particles.jl b/GeneralisedFilters/src/algorithms/particles.jl index d206798..2cb2e02 100644 --- a/GeneralisedFilters/src/algorithms/particles.jl +++ b/GeneralisedFilters/src/algorithms/particles.jl @@ -95,7 +95,7 @@ function initialise( end end - return ParticleDistribution(particles) + return Particles(particles) end function predict( @@ -103,12 +103,12 @@ function predict( model::StateSpaceModel, algo::ParticleFilter, iter::Integer, - state::ParticleDistribution, + state, observation; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) - proposed_particles = map(enumerate(state.particles)) do (i, particle) + proposed_particles = map(enumerate(state)) do (i, particle) if !isnothing(ref_state) && i == 1 ref_state[iter] else @@ -116,7 +116,7 @@ function predict( end end - log_weights = map(zip(proposed_particles, state.particles)) do (new_state, prev_state) + log_increments = map(zip(proposed_particles, state)) do (new_state, prev_state) log_f = SSMProblems.logdensity(model.dyn, iter, prev_state, new_state; kwargs...) 
log_q = SSMProblems.logdensity( @@ -127,14 +127,14 @@ function predict( end state.particles = proposed_particles - return update_weights(state, log_weights) + return update_weights(state, log_increments) end function update( model::StateSpaceModel, algo::ParticleFilter, iter::Integer, - state::ParticleDistribution, + state, observation; kwargs..., ) @@ -183,12 +183,12 @@ function predict( model::StateSpaceModel, algo::BootstrapFilter, iter::Integer, - state::ParticleDistribution, + state, observation=nothing; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., ) - state.particles = map(enumerate(state.particles)) do (i, particle) + state.particles = map(enumerate(state)) do (i, particle) if !isnothing(ref_state) && i == 1 ref_state[iter] else diff --git a/GeneralisedFilters/src/algorithms/rbpf.jl b/GeneralisedFilters/src/algorithms/rbpf.jl index 3f7de8e..6ff7d7e 100644 --- a/GeneralisedFilters/src/algorithms/rbpf.jl +++ b/GeneralisedFilters/src/algorithms/rbpf.jl @@ -37,7 +37,7 @@ function initialise( RaoBlackwellisedParticle(x, z) end - return ParticleDistribution(particles) + return Particles(particles) end function predict( diff --git a/GeneralisedFilters/src/containers.jl b/GeneralisedFilters/src/containers.jl index de7ad76..6a8a33e 100644 --- a/GeneralisedFilters/src/containers.jl +++ b/GeneralisedFilters/src/containers.jl @@ -10,9 +10,12 @@ object, with the states (or particles) distributed accoring to their log-weights """ abstract type ParticleDistribution{PT} end -Base.collect(state::ParticleDistribution) = state.particles -Base.length(state::ParticleDistribution) = length(state.particles) -Base.keys(state::ParticleDistribution) = LinearIndices(state.particles) +Base.collect(state::PT) where {PT<:ParticleDistribution} = state.particles +Base.length(state::PT) where {PT<:ParticleDistribution} = length(state.particles) +Base.keys(state::PT) where {PT<:ParticleDistribution} = LinearIndices(state.particles) + +Base.iterate(state::ParticleDistribution, 
i) = Base.iterate(state.particles, i) +Base.iterate(state::ParticleDistribution) = Base.iterate(state.particles) # not sure if this is kosher, since it doesn't follow the convention of Base.getindex Base.@propagate_inbounds Base.getindex(state::ParticleDistribution, i) = state.particles[i] @@ -28,18 +31,12 @@ mutable struct WeightedParticles{PT,WT<:Real} <: ParticleDistribution{PT} log_weights::Vector{WT} end -# TODO: replace Gaussian with custom internals -# struct GaussianDistribution{PT,ΣT} <: ParticleDistribution{PT} -# μ::PT -# Σ::ΣT -# end - -function ParticleDistribution(particles::AbstractVector) +function Particles(particles::AbstractVector) N = length(particles) return Particles(particles, Vector{Int}(1:N)) end -function ParticleDistribution(particles::AbstractVector, log_weights::Vector{<:Real}) +function WeightedParticles(particles, log_weights) N = length(particles) return WeightedParticles(particles, Vector{Int}(1:N), log_weights) end @@ -65,6 +62,17 @@ function logmeanexp(x::AbstractArray{T}; dims=:)::T where {T} @fastmath max_ .+ log.(mean(exp.(x .- max_); dims)) end +## GAUSSIAN STATES ######################################################################### + +struct GaussianDistribution{PT,ΣT} <: ParticleDistribution{PT} + μ::PT + Σ::ΣT +end + +function mean_cov(state::GaussianDistribution) + return state.μ, state.Σ +end + ## RAO-BLACKWELLISED PARTICLE ############################################################## """ diff --git a/GeneralisedFilters/src/resamplers.jl b/GeneralisedFilters/src/resamplers.jl index 1d2ae95..1f393d6 100644 --- a/GeneralisedFilters/src/resamplers.jl +++ b/GeneralisedFilters/src/resamplers.jl @@ -22,12 +22,12 @@ function resample( return construct_new_state(states, idxs) end -function construct_new_state(states::Particles, idxs) - return Particles(states.particles[idxs], idxs) +function construct_new_state(states::Particles{PT}, idxs) where {PT} + return Particles{PT}(states.particles[idxs], idxs) end function 
construct_new_state(states::WeightedParticles{PT,WT}, idxs) where {PT,WT} - return WeightedParticles(states.particles[idxs], idxs, zeros(WT, length(states))) + return WeightedParticles{PT,WT}(states.particles[idxs], idxs, zeros(WT, length(states))) end # function construct_new_state( From 4d7e75bfa5bd4b3384ca6ab6687465a26bd10b1e Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 14 May 2025 16:08:14 -0400 Subject: [PATCH 10/11] remove dependence of GaussianDistributions --- GeneralisedFilters/Project.toml | 2 -- GeneralisedFilters/src/GeneralisedFilters.jl | 1 - GeneralisedFilters/src/algorithms/kalman.jl | 23 ++++++++++---------- GeneralisedFilters/test/runtests.jl | 3 +-- 4 files changed, 12 insertions(+), 17 deletions(-) diff --git a/GeneralisedFilters/Project.toml b/GeneralisedFilters/Project.toml index 56c2f0a..440fd2c 100644 --- a/GeneralisedFilters/Project.toml +++ b/GeneralisedFilters/Project.toml @@ -9,7 +9,6 @@ AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1" CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" -GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" @@ -28,7 +27,6 @@ Aqua = "0.8" CUDA = "5" DataStructures = "0.18.20" Distributions = "0.25" -GaussianDistributions = "0.5.2" LogExpFunctions = "0.3" NNlib = "0.9" OffsetArrays = "1.14.1" diff --git a/GeneralisedFilters/src/GeneralisedFilters.jl b/GeneralisedFilters/src/GeneralisedFilters.jl index 6028acb..7e712c8 100644 --- a/GeneralisedFilters/src/GeneralisedFilters.jl +++ b/GeneralisedFilters/src/GeneralisedFilters.jl @@ -3,7 +3,6 @@ module GeneralisedFilters using AbstractMCMC: AbstractMCMC, AbstractSampler import Distributions: MvNormal import Random: AbstractRNG, default_rng, rand -using 
GaussianDistributions: pairs, Gaussian using OffsetArrays using SSMProblems using StatsBase diff --git a/GeneralisedFilters/src/algorithms/kalman.jl b/GeneralisedFilters/src/algorithms/kalman.jl index 2f4e8d8..4f2a188 100644 --- a/GeneralisedFilters/src/algorithms/kalman.jl +++ b/GeneralisedFilters/src/algorithms/kalman.jl @@ -1,5 +1,4 @@ export KalmanFilter, filter, BatchKalmanFilter -using GaussianDistributions using CUDA: i32 import LinearAlgebra: hermitianpart @@ -13,7 +12,7 @@ function initialise( rng::AbstractRNG, model::LinearGaussianStateSpaceModel, filter::KalmanFilter; kwargs... ) μ0, Σ0 = calc_initial(model.prior; kwargs...) - return Gaussian(μ0, Σ0) + return GaussianDistribution(μ0, Σ0) end function predict( @@ -21,24 +20,24 @@ function predict( model::LinearGaussianStateSpaceModel, algo::KalmanFilter, iter::Integer, - state::Gaussian, + state::GaussianDistribution, observation=nothing; kwargs..., ) - μ, Σ = GaussianDistributions.pair(state) + μ, Σ = mean_cov(state) A, b, Q = calc_params(model.dyn, iter; kwargs...) - return Gaussian(A * μ + b, A * Σ * A' + Q) + return GaussianDistribution(A * μ + b, A * Σ * A' + Q) end function update( model::LinearGaussianStateSpaceModel, algo::KalmanFilter, iter::Integer, - state::Gaussian, + state::GaussianDistribution, observation::AbstractVector; kwargs..., ) - μ, Σ = GaussianDistributions.pair(state) + μ, Σ = mean_cov(state) H, c, R = calc_params(model.obs, iter; kwargs...) 
# Update state @@ -47,7 +46,7 @@ function update( S = hermitianpart(H * Σ * H' + R) K = Σ * H' / S - state = Gaussian(μ + K * y, Σ - K * H * Σ) + state = GaussianDistribution(μ + K * y, Σ - K * H * Σ) # Compute log-likelihood ll = logpdf(MvNormal(m, S), observation) @@ -215,13 +214,13 @@ function backward( states_cache, kwargs..., ) - μ, Σ = GaussianDistributions.pair(back_state) - μ_pred, Σ_pred = GaussianDistributions.pair(states_cache.proposed_states[iter + 1]) - μ_filt, Σ_filt = GaussianDistributions.pair(states_cache.filtered_states[iter]) + μ, Σ = mean_cov(back_state) + μ_pred, Σ_pred = mean_cov(states_cache.proposed_states[iter + 1]) + μ_filt, Σ_filt = mean_cov(states_cache.filtered_states[iter]) G = Σ_filt * model.dyn.A' * inv(Σ_pred) μ = μ_filt .+ G * (μ .- μ_pred) Σ = Σ_filt .+ G * (Σ .- Σ_pred) * G' - return Gaussian(μ, Σ) + return GaussianDistribution(μ, Σ) end diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index f946f26..c44b520 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -107,7 +107,6 @@ end using LogExpFunctions: softmax using StableRNGs using Distributions - using GaussianDistributions using LinearAlgebra struct LinearGaussianProposal <: GeneralisedFilters.AbstractProposal end @@ -121,7 +120,7 @@ end kwargs..., ) A, b, Q = GeneralisedFilters.calc_params(model.dyn, iter; kwargs...) - pred = Gaussian(A * state + b, Q) + pred = GeneralisedFilters.GaussianDistribution(A * state + b, Q) prop, _ = GeneralisedFilters.update(model, KF(), iter, pred, observation; kwargs...) 
return MvNormal(prop.μ, hermitianpart(prop.Σ)) end From 23ac9498d5ab4c0a6e8e9e8cafd9bde900dda185 Mon Sep 17 00:00:00 2001 From: Charles Knipp Date: Wed, 14 May 2025 16:11:31 -0400 Subject: [PATCH 11/11] fix unit tests --- GeneralisedFilters/test/runtests.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/GeneralisedFilters/test/runtests.jl b/GeneralisedFilters/test/runtests.jl index c44b520..33e12f7 100644 --- a/GeneralisedFilters/test/runtests.jl +++ b/GeneralisedFilters/test/runtests.jl @@ -444,7 +444,6 @@ end @testitem "RBCSMC test" begin using GeneralisedFilters - using GaussianDistributions using StableRNGs using PDMats using LinearAlgebra