diff --git a/Project.toml b/Project.toml index 97bd214..6a105b7 100644 --- a/Project.toml +++ b/Project.toml @@ -21,6 +21,7 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" [compat] DataStructures = "0.18.20" GaussianDistributions = "0.5.2" +SSMProblems = "0.4.0" StatsBase = "0.34.3" [extras] diff --git a/src/GeneralisedFilters.jl b/src/GeneralisedFilters.jl index da9b8d2..d53589f 100644 --- a/src/GeneralisedFilters.jl +++ b/src/GeneralisedFilters.jl @@ -13,6 +13,8 @@ using NNlib abstract type AbstractFilter <: AbstractSampler end +abstract type AbstractParticleFilter{N} <: AbstractFilter end + """ predict([rng,] model, alg, iter, state; kwargs...) @@ -41,6 +43,22 @@ Perform a combined predict and update call on a single iteration of the filter. """ function step end +""" + reset_weights!(log_weights, filter) + +Reset container log-weights after a resampling step +""" +function reset_weights! end + +""" + update_weights! +""" +function update_weights! end + +function log_marginal end + +function update_ref! end + function initialise(model, alg; kwargs...) return initialise(default_rng(), model, alg; kwargs...) end @@ -106,6 +124,7 @@ include("models/hierarchical.jl") # Filtering/smoothing algorithms include("algorithms/bootstrap.jl") +include("algorithms/apf.jl") include("algorithms/kalman.jl") include("algorithms/forward.jl") include("algorithms/rbpf.jl") diff --git a/src/algorithms/apf.jl b/src/algorithms/apf.jl new file mode 100644 index 0000000..f4a9df6 --- /dev/null +++ b/src/algorithms/apf.jl @@ -0,0 +1,116 @@ +export AuxiliaryParticleFilter, APF + +mutable struct AuxiliaryParticleFilter{N,RS<:AbstractConditionalResampler} <: AbstractParticleFilter{N} + resampler::RS + aux::Vector # Auxiliary weights +end + +function AuxiliaryParticleFilter( + N::Integer; threshold::Real=0., resampler::AbstractResampler=Systematic() +) + conditional_resampler = ESSResampler(threshold, resampler) + return AuxiliaryParticleFilter{N,typeof(conditional_resampler)}(conditional_resampler, zeros(N)) +end + +const APF = AuxiliaryParticleFilter + +function initialise( + rng::AbstractRNG, + model::StateSpaceModel{T}, + filter::AuxiliaryParticleFilter{N}, + ref_state::Union{Nothing,AbstractVector}=nothing, + kwargs..., +) where {N,T} + initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N) + initial_weights = zeros(T, N) + + return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state, filter) +end + +function update_weights!( + rng::AbstractRNG, filter, model, step, states, observation; kwargs... +) + simulation_weights = eta(rng, model, step, states, observation) + return states.log_weights += simulation_weights +end + +function predict( + rng::AbstractRNG, + model::StateSpaceModel, + filter::AuxiliaryParticleFilter, + step::Integer, + states::ParticleContainer{T}, + observation; + ref_state::Union{Nothing,AbstractVector{T}}=nothing, + kwargs..., +) where {T} + # states = update_weights!(rng, filter.eta, model, step, states.filtered, observation; kwargs...) + + # Compute auxilary weights + # POC: use the simplest approximation to the predictive likelihood + # Ideally should be something like update_weights!(filter, ...) + predicted = map( + x -> mean(SSMProblems.distribution(model.dyn, step, x; kwargs...)), + states.filtered.particles, + ) + auxiliary_weights = map( + x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), predicted + ) + states.filtered.log_weights .+= auxiliary_weights + filter.aux = auxiliary_weights + + states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter) + states.proposed.particles = map( + x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), + states.proposed.particles, + ) + + return update_ref!(states, ref_state, filter, step) +end + +function update( + model::StateSpaceModel{T}, + filter::AuxiliaryParticleFilter, + step::Integer, + states::ParticleContainer, + observation; + kwargs..., +) where {T} + @debug "step $step" + log_increments = map( + x -> SSMProblems.logdensity(model.obs, step, x, observation; kwargs...), + collect(states.proposed.particles), + ) + + states.filtered.log_weights = states.proposed.log_weights + log_increments + states.filtered.particles = states.proposed.particles + + return states, logmarginal(states, filter) +end + +function step( + rng::AbstractRNG, + model::AbstractStateSpaceModel, + alg::AuxiliaryParticleFilter, + iter::Integer, + state, + observation; + kwargs..., +) + proposed_state = predict(rng, model, alg, iter, state, observation; kwargs...) + filtered_state, ll = update(model, alg, iter, proposed_state, observation; kwargs...) + + return filtered_state, ll +end + +function reset_weights!( + state::ParticleState{T,WT}, idxs, filter::AuxiliaryParticleFilter +) where {T,WT<:Real} + # From Choping: An Introduction to sequential monte carlo, section 10.3.3 + state.log_weights = state.log_weights[idxs] - filter.aux[idxs] + return state +end + +function logmarginal(states::ParticleContainer, ::AuxiliaryParticleFilter) + return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) +end diff --git a/src/algorithms/bootstrap.jl b/src/algorithms/bootstrap.jl index ac8ed1e..a508761 100644 --- a/src/algorithms/bootstrap.jl +++ b/src/algorithms/bootstrap.jl @@ -1,31 +1,32 @@ export BootstrapFilter, BF -struct BootstrapFilter{RS<:AbstractResampler} <: AbstractFilter - N::Integer +struct BootstrapFilter{N,RS<:AbstractResampler} <: AbstractParticleFilter{N} resampler::RS end -"""Shorthand for `BootstrapFilter`""" -const BF = BootstrapFilter - function BootstrapFilter( N::Integer; threshold::Real=1.0, resampler::AbstractResampler=Systematic() ) conditional_resampler = ESSResampler(threshold, resampler) - return BootstrapFilter(N, conditional_resampler) + return BootstrapFilter{N, typeof(conditional_resampler)}(conditional_resampler) end +"""Shorthand for `BootstrapFilter`""" +const BF = BootstrapFilter + function initialise( rng::AbstractRNG, model::StateSpaceModel{T}, - filter::BootstrapFilter; + filter::BootstrapFilter{N}; ref_state::Union{Nothing,AbstractVector}=nothing, kwargs..., -) where {T} - initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:(filter.N)) - initial_weights = zeros(T, filter.N) +) where {N,T} + initial_states = map(x -> SSMProblems.simulate(rng, model.dyn; kwargs...), 1:N) + initial_weights = zeros(T, N) - return update_ref!(ParticleContainer(initial_states, initial_weights), ref_state) + return update_ref!( + ParticleContainer(initial_states, initial_weights), ref_state, filter + ) end function predict( @@ -37,13 +38,13 @@ function predict( ref_state::Union{Nothing,AbstractVector{T}}=nothing, kwargs..., ) where {T} - states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered) + states.proposed, states.ancestors = resample(rng, filter.resampler, states.filtered, filter) states.proposed.particles = map( x -> SSMProblems.simulate(rng, model.dyn, step, x; kwargs...), collect(states.proposed), ) - return update_ref!(states, ref_state, step) + return update_ref!(states, ref_state, filter, step) end function update( @@ -62,5 +63,16 @@ function update( states.filtered.log_weights = states.proposed.log_weights + log_increments states.filtered.particles = states.proposed.particles - return states, logmarginal(states) + return states, logmarginal(states, filter) +end + +function reset_weights!( + state::ParticleState{T,WT}, idxs, filter::BootstrapFilter +) where {T,WT<:Real} + fill!(state.log_weights, -log(WT(length(state.particles)))) + return state +end + +function logmarginal(states::ParticleContainer, ::BootstrapFilter) + return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) end diff --git a/src/algorithms/rbpf.jl b/src/algorithms/rbpf.jl index 4b1b700..11b47d5 100644 --- a/src/algorithms/rbpf.jl +++ b/src/algorithms/rbpf.jl @@ -72,7 +72,7 @@ end function predict( rng::AbstractRNG, model::HierarchicalSSM, algo::RBPF, t::Integer, states; kwargs... ) - states.proposed, states.ancestors = resample(rng, algo.resampler, states.filtered) + states.proposed, states.ancestors = resample(rng, algo.resampler, states.filtered, algo) states.proposed.particles = map( x -> marginal_predict(rng, model, algo, t, x; kwargs...), @@ -108,7 +108,7 @@ function update( states.filtered.log_weights = states.proposed.log_weights + log_increments - return states, logmarginal(states) + return states, logmarginal(states, algo) end ################################# diff --git a/src/containers.jl b/src/containers.jl index e7e652b..e31c7f0 100644 --- a/src/containers.jl +++ b/src/containers.jl @@ -1,5 +1,5 @@ using DataStructures: Stack -using Random: rand +import Random: rand ## GAUSSIAN STATES ######################################################################### @@ -105,13 +105,20 @@ Base.keys(state::ParticleState) = LinearIndices(state.particles) Base.@propagate_inbounds Base.getindex(state::ParticleState, i) = state.particles[i] # Base.@propagate_inbounds Base.getindex(state::ParticleState, i::Vector{Int}) = state.particles[i] -function reset_weights!(state::ParticleState{T,WT}) where {T,WT<:Real} +function reset_weights!(state::ParticleState{T,WT}, idx, ::AbstractFilter) where {T,WT<:Real} fill!(state.log_weights, zero(WT)) return state.log_weights end +function logmarginal(states::ParticleContainer, ::AbstractFilter) + return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) +end + function update_ref!( - pc::ParticleContainer{T}, ref_state::Union{Nothing,AbstractVector{T}}, step::Integer=0 + pc::ParticleContainer{T}, + ref_state::Union{Nothing,AbstractVector{T}}, + ::AbstractFilter, + step::Integer=0, ) where {T} # this comes from Nicolas Chopin's package particles if !isnothing(ref_state) @@ -122,10 +129,6 @@ function update_ref!( return pc end -function logmarginal(states::ParticleContainer) - return logsumexp(states.filtered.log_weights) - logsumexp(states.proposed.log_weights) -end - ## SPARSE PARTICLE STORAGE ################################################################# Base.append!(s::Stack, a::AbstractVector) = map(x -> push!(s, x), a) @@ -180,7 +183,7 @@ function prune!(tree::ParticleTree, offspring::Vector{Int64}) end function insert!( - tree::ParticleTree{T}, states::Vector{T}, ancestors::AbstractVector{Int64} + tree::ParticleTree{T}, states::Vector{T}, ancestors::AbstractVector{<:Integer} ) where {T} # parents of new generation parents = getindex(tree.leaves, ancestors) @@ -213,7 +216,7 @@ function expand!(tree::ParticleTree) return tree end -function get_offspring(a::AbstractVector{Int64}) +function get_offspring(a::AbstractVector{<:Integer}) offspring = zero(a) for i in a offspring[i] += 1 diff --git a/src/resamplers.jl b/src/resamplers.jl index d0e74e5..60a7b36 100644 --- a/src/resamplers.jl +++ b/src/resamplers.jl @@ -8,13 +8,15 @@ export Multinomial, Systematic, Metropolis, Rejection abstract type AbstractResampler end function resample( - rng::AbstractRNG, resampler::AbstractResampler, states::ParticleState{PT,WT} + rng::AbstractRNG, + resampler::AbstractResampler, + states::ParticleState{PT,WT}, + filter::AbstractFilter; + weights::AbstractVector{WT}=StatsBase.weights(states) ) where {PT,WT} - weights = StatsBase.weights(states) idxs = sample_ancestors(rng, resampler, weights) - - new_state = ParticleState(deepcopy(states.particles[idxs]), zeros(WT, length(states))) - + new_state = ParticleState(deepcopy(states.particles[idxs]), zeros(WT, length(states))) + reset_weights!(new_state, idxs, filter) return new_state, idxs end @@ -23,8 +25,9 @@ function resample( rng::AbstractRNG, resampler::AbstractResampler, states::RaoBlackwellisedParticleState{T,M,ZT}, + ::AbstractFilter; + weights=StatsBase.weights(states) ) where {T,M,ZT} - weights = StatsBase.weights(states) idxs = sample_ancestors(rng, resampler, weights) new_state = RaoBlackwellisedParticleState( @@ -49,7 +52,7 @@ struct ESSResampler <: AbstractConditionalResampler end function resample( - rng::AbstractRNG, cond_resampler::ESSResampler, state::ParticleState{PT,WT} + rng::AbstractRNG, cond_resampler::ESSResampler, state::ParticleState{PT,WT}, filter::AbstractFilter ) where {PT,WT} n = length(state) # TODO: computing weights twice. Should create a wrapper to avoid this @@ -58,7 +61,7 @@ function resample( @debug "ESS: $ess" if cond_resampler.threshold * n ≥ ess - return resample(rng, cond_resampler.resampler, state) + return resample(rng, cond_resampler.resampler, state, filter; weights=weights) else return deepcopy(state), collect(1:n) end diff --git a/test/runtests.jl b/test/runtests.jl index 86c0f96..4d85a0c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -109,7 +109,9 @@ end _, _, data = sample(rng, model, 20) bf = BF(2^12; threshold=0.8) + apf = APF(2^10, threshold=1.) bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, data) + _, llapf= GeneralisedFilters.filter(rng, model, apf, data) kf_state, llkf = GeneralisedFilters.filter(rng, model, KF(), data) xs = bf_state.filtered.particles @@ -120,6 +122,7 @@ end # since this is log valued, we can up the tolerance @test llkf ≈ llbf atol = 0.1 + @test llkf ≈ llapf atol = 2 end @testitem "Forward algorithm test" begin