Skip to content

Create wrapper for square-root Kalman filter #97

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions GeneralisedFilters/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,12 @@ CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
KalmanFilters = "272a6111-cf0e-4c1b-a056-8d658cb314ee"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
PDMats = "90014a1f-27ba-587c-ab20-58faa44d9150"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Expand Down
56 changes: 55 additions & 1 deletion GeneralisedFilters/src/algorithms/kalman.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
export KalmanFilter, filter, BatchKalmanFilter
using GaussianDistributions
using CUDA: i32
using PDMats
using KalmanFilters
import LinearAlgebra: hermitianpart

export KalmanFilter, KF, KalmanSmoother, KS
export KalmanFilter, KF, KalmanSmoother, KS, SRKF

struct KalmanFilter <: AbstractFilter end

Expand Down Expand Up @@ -129,6 +131,58 @@ function update(
return BatchGaussianDistribution(μ_filt, Σ_filt), dropdims(log_likes; dims=1)
end

## SQUARE-ROOT KALMAN FILTER ###############################################################

"""
SRKF()

A square-root Kalman filter.

Implemented by wrapping KalmanFilters.jl.
"""
struct SRKF <: AbstractFilter end

function initialise(
rng::AbstractRNG, model::LinearGaussianStateSpaceModel, filter::SRKF; kwargs...
)
μ0, Σ0 = calc_initial(model.dyn; kwargs...)
return Gaussian(μ0, Σ0)
end

function predict(
rng::AbstractRNG,
model::LinearGaussianStateSpaceModel,
::SRKF,
iter::Integer,
state::Gaussian,
observation=nothing;
kwargs...,
)
μ, Σ = GaussianDistributions.pair(state)
A, b, Q = calc_params(model.dyn, iter; kwargs...)
!all(b .== 0) && error("SKRF doesn't current support non-zero b")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
!all(b .== 0) && error("SKRF doesn't current support non-zero b")
b != zero) && error("SKRF doesn't current support non-zero b")

same as before


tu = KalmanFilters.time_update(μ, cholesky(Σ), A, cholesky(Q))
return Gaussian(tu.state, PDMat(tu.covariance))
end

function update(
model::LinearGaussianStateSpaceModel,
algo::SRKF,
iter::Integer,
state::Gaussian,
observation::AbstractVector;
kwargs...,
)
μ, Σ = GaussianDistributions.pair(state)
H, c, R = calc_params(model.obs, iter; kwargs...)
!all(c .== 0) && error("SKRF doesn't current support non-zero c")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
!all(c .== 0) && error("SKRF doesn't current support non-zero c")
c != zero) && error("SKRF doesn't current support non-zero c")

I doubt this gets us any speed up, but should guarantee type stability in 99% of cases.


mu = KalmanFilters.measurement_update(μ, cholesky(Σ), observation, H, cholesky(R))
ll = logpdf(MvNormal(mu.innovation, PDMat(mu.innovation_covariance)), zero(observation))
return Gaussian(mu.state, PDMat(mu.covariance)), ll
end

## KALMAN SMOOTHER #########################################################################

struct KalmanSmoother <: AbstractSmoother end
Expand Down
22 changes: 13 additions & 9 deletions GeneralisedFilters/src/models/hierarchical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,24 +29,28 @@ function AbstractMCMC.sample(
)
outer_dyn, inner_model = model.outer_dyn, model.inner_model

zs = Vector{eltype(inner_model.dyn)}(undef, T)
xs = Vector{eltype(outer_dyn)}(undef, T)
zs = OffsetVector(Vector{eltype(inner_model.dyn)}(undef, T + 1), -1)
xs = OffsetVector(Vector{eltype(outer_dyn)}(undef, T + 1), -1)
ys = Vector{eltype(inner_model.obs)}(undef, T)

# Simulate outer dynamics
x0 = simulate(rng, outer_dyn; kwargs...)
z0 = simulate(rng, inner_model.dyn; new_outer=x0, kwargs...)
xs[0] = simulate(rng, outer_dyn; kwargs...)
zs[0] = simulate(rng, inner_model.dyn; new_outer=xs[0], kwargs...)
for t in 1:T
prev_x = t == 1 ? x0 : xs[t - 1]
prev_z = t == 1 ? z0 : zs[t - 1]
xs[t] = simulate(rng, model.outer_dyn, t, prev_x; kwargs...)
xs[t] = simulate(rng, model.outer_dyn, t, xs[t - 1]; kwargs...)
zs[t] = simulate(
rng, inner_model.dyn, t, prev_z; prev_outer=prev_x, new_outer=xs[t], kwargs...
rng,
inner_model.dyn,
t,
zs[t - 1];
prev_outer=xs[t - 1],
new_outer=xs[t],
kwargs...,
)
ys[t] = simulate(rng, inner_model.obs, t, zs[t]; new_outer=xs[t], kwargs...)
end

return x0, z0, xs, zs, ys
return xs, zs, ys
end

## Methods to make HierarchicalSSM compatible with the bootstrap filter
Expand Down
25 changes: 15 additions & 10 deletions GeneralisedFilters/src/models/linear_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,18 @@ end
###########################################

struct HomogeneousLinearGaussianLatentDynamics{
T<:Real,ΣT<:AbstractMatrix{T},AT<:AbstractMatrix{T},QT<:AbstractMatrix{T}
T<:Real,
T_μ0<:AbstractVector{T},
T_Σ0<:AbstractMatrix{T},
T_A<:AbstractMatrix{T},
T_b<:AbstractVector{T},
T_Q<:AbstractMatrix{T},
} <: LinearGaussianLatentDynamics{T}
μ0::Vector{T}
Σ0::ΣT
A::AT
b::Vector{T}
Q::QT
μ0::T_μ0
Σ0::T_Σ0
A::T_A
b::T_b
Q::T_Q
end
calc_μ0(dyn::HomogeneousLinearGaussianLatentDynamics; kwargs...) = dyn.μ0
calc_Σ0(dyn::HomogeneousLinearGaussianLatentDynamics; kwargs...) = dyn.Σ0
Expand All @@ -91,11 +96,11 @@ calc_b(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) = dyn
calc_Q(dyn::HomogeneousLinearGaussianLatentDynamics, ::Integer; kwargs...) = dyn.Q

struct HomogeneousLinearGaussianObservationProcess{
T<:Real,HT<:AbstractMatrix{T},RT<:AbstractMatrix{T}
T<:Real,T_H<:AbstractMatrix{T},T_c<:AbstractVector{T},T_R<:AbstractMatrix{T}
} <: LinearGaussianObservationProcess{T}
H::HT
c::Vector{T}
R::RT
H::T_H
c::T_c
R::T_R
end
calc_H(obs::HomogeneousLinearGaussianObservationProcess, ::Integer; kwargs...) = obs.H
calc_c(obs::HomogeneousLinearGaussianObservationProcess, ::Integer; kwargs...) = obs.c
Expand Down
84 changes: 70 additions & 14 deletions GeneralisedFilters/test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ include("resamplers.jl")
for Dy in Dys
rng = StableRNG(1234)
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy)
_, _, ys = sample(rng, model, 1)
_, ys = sample(rng, model, 1)

filtered, ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys)

Expand All @@ -48,6 +48,62 @@ include("resamplers.jl")
end
end

@testitem "Square root Kalman filter test" begin
using GeneralisedFilters
using LinearAlgebra
using PDMats
using StableRNGs
using StaticArrays

rng = StableRNG(1234)
μ0 = rand(rng, 2)
Σ0 = rand(rng, 2, 2)
Σ0 = Σ0 * Σ0' # make Σ0 positive definite
Σ0 = PDMat(Σ0)
A = rand(rng, 2, 2)
b = zeros(2)
Q = rand(rng, 2, 2)
Q = Q * Q' # make Q positive definite
Q = PDMat(Q)
H = rand(rng, 2, 2)
c = zeros(2)
R = rand(rng, 2, 2)
R = R * R' # make R positive definite
R = PDMat(R)

model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R)

T = 3
observations = [rand(rng, 2) for _ in 1:T]

# kf_state, kf_ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), observations)
# srkf_state, srkf_ll = GeneralisedFilters.filter(rng, model, SRKF(), observations)

# @test isa(srkf_state.Σ, PDMat)
# @test kf_state.μ ≈ srkf_state.μ
# @test kf_state.Σ ≈ srkf_state.Σ
# @test kf_ll ≈ srkf_ll

# Test with StaticArrays.jl
μ0_static = SVector{2}(μ0)
Σ0_static = PDMat(SMatrix{2,2}(Σ0.mat))
A_static = SMatrix{2,2}(A)
b_static = SVector{2}(b)
Q_static = PDMat(SMatrix{2,2}(Q.mat))
H_static = SMatrix{2,2}(H)
c_static = SVector{2}(c)
R_static = PDMat(SMatrix{2,2}(R.mat))
model_static = create_homogeneous_linear_gaussian_model(
μ0_static, Σ0_static, A_static, b_static, Q_static, H_static, c_static, R_static
)
observations_static = [SVector{2}(rand(rng, 2)) for _ in 1:T]
srkf_state, _ = GeneralisedFilters.filter(
rng, model_static, SRKF(), observations_static
)
# @test isa(srkf_state.μ, SVector)
# @test isa(srkf_state.Σ.mat, SMatrix)
Comment on lines +103 to +104
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what's up with the unit testing?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I need to fix the upstream issue with StaticArrays getting converted to regular arrays. Then I can comment these back out. They'll fail currently.

end

@testitem "Kalman smoother test" begin
using GeneralisedFilters
using Distributions
Expand All @@ -61,7 +117,7 @@ end
for Dy in Dys
rng = StableRNG(1234)
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy)
_, _, ys = sample(rng, model, 2)
_, ys = sample(rng, model, 2)

states, ll = GeneralisedFilters.smooth(rng, model, KalmanSmoother(), ys)

Expand All @@ -88,7 +144,7 @@ end

rng = StableRNG(1234)
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1)
_, _, ys = sample(rng, model, 10)
_, ys = sample(rng, model, 10)

bf = BF(2^12; threshold=0.8)
bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, ys)
Expand Down Expand Up @@ -128,7 +184,7 @@ end

rng = StableRNG(1234)
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1)
_, _, ys = sample(rng, model, 10)
_, ys = sample(rng, model, 10)

algo = PF(2^10, LinearGaussianProposal(); threshold=0.6)
kf_states, kf_ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys)
Expand Down Expand Up @@ -212,7 +268,7 @@ end
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model(
rng, D_outer, D_inner, D_obs
)
_, _, ys = sample(rng, full_model, T)
_, ys = sample(rng, full_model, T)

# Ground truth Kalman filtering
kf_states, kf_ll = GeneralisedFilters.filter(rng, full_model, KalmanFilter(), ys)
Expand Down Expand Up @@ -255,7 +311,7 @@ end
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model(
rng, D_outer, D_inner, D_obs, ET
)
_, _, ys = sample(rng, full_model, T)
_, ys = sample(rng, full_model, T)

# Ground truth Kalman filtering
kf_state, kf_ll = GeneralisedFilters.filter(full_model, KalmanFilter(), ys)
Expand Down Expand Up @@ -290,7 +346,7 @@ end
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model(
rng, 1, 1, 1
)
_, _, ys = sample(rng, full_model, T)
_, ys = sample(rng, full_model, T)

# Manually create tree to force expansion on second step
particle_type = GeneralisedFilters.RaoBlackwellisedParticle{
Expand Down Expand Up @@ -325,7 +381,7 @@ end
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model(
rng, D_outer, D_inner, D_obs
)
_, _, ys = sample(rng, full_model, T)
_, ys = sample(rng, full_model, T)

# Ground truth Kalman filtering
kf_states, kf_ll = GeneralisedFilters.filter(rng, full_model, KalmanFilter(), ys)
Expand Down Expand Up @@ -373,7 +429,7 @@ end

rng = StableRNG(SEED)
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1)
_, _, ys = sample(rng, model, K)
_, ys = sample(rng, model, K)

ref_traj = OffsetVector([rand(rng, 1) for _ in 0:K], -1)

Expand Down Expand Up @@ -413,7 +469,7 @@ end

rng = StableRNG(SEED)
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy)
_, _, ys = sample(rng, model, K)
_, ys = sample(rng, model, K)

# Kalman smoother
state, ks_ll = GeneralisedFilters.smooth(
Expand Down Expand Up @@ -477,7 +533,7 @@ end
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model(
rng, D_outer, D_inner, D_obs, T; static_arrays=true
)
_, _, ys = sample(rng, full_model, K)
_, ys = sample(rng, full_model, K)

# Kalman smoother
state, _ = GeneralisedFilters.smooth(
Expand Down Expand Up @@ -562,7 +618,7 @@ end
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model(
rng, D_outer, D_inner, D_obs, T
)
_, _, ys = sample(rng, full_model, K)
_, ys = sample(rng, full_model, K)

# Generate random reference trajectory
ref_trajectory = [CuArray(rand(rng, T, D_outer, 1)) for _ in 0:K]
Expand Down Expand Up @@ -595,7 +651,7 @@ end
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model(
rng, D_outer, D_inner, D_obs, T
)
_, _, ys = sample(rng, full_model, K)
_, ys = sample(rng, full_model, K)

# Manually create tree to force expansion on second step
M = N_particles * 2 - 1
Expand Down Expand Up @@ -642,7 +698,7 @@ end
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model(
rng, D_outer, D_inner, D_obs, T
)
_, _, ys = sample(rng, full_model, K)
_, ys = sample(rng, full_model, K)

# Kalman smoother
state, _ = GeneralisedFilters.smooth(
Expand Down
8 changes: 3 additions & 5 deletions SSMProblems/Project.toml
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
name = "SSMProblems"
uuid = "26aad666-b158-4e64-9d35-0e672562fa48"
authors = [
"FredericWantiez <frederic.wantiez@gmail.com>",
"THargreaves <tim.hargreaves@icloud.com>"
]
version = "0.5.2"
authors = ["FredericWantiez <frederic.wantiez@gmail.com>", "THargreaves <tim.hargreaves@icloud.com>"]
version = "0.6.0"

[deps]
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"

[extras]
Expand Down
6 changes: 3 additions & 3 deletions SSMProblems/examples/kalman-filter/script.jl
Original file line number Diff line number Diff line change
Expand Up @@ -157,16 +157,16 @@ model = StateSpaceModel(dyn, obs);
# functions we defined above.

rng = MersenneTwister(SEED);
x0, xs, ys = sample(rng, model, T);
xs, ys = sample(rng, model, T);

# We can then run the Kalman filter and plot the filtering results against the ground truth.

x_filts, P_filts = AbstractMCMC.sample(model, KalmanFilter(), ys);

# Plot trajectory for first dimension
p = plot(; title="First Dimension Kalman Filter Estimates", xlabel="Step", ylabel="Value")
plot!(p, 1:T, first.(xs); label="Truth")
scatter!(p, 1:T, first.(ys); label="Observations")
plot!(p, first.(xs); label="Truth")
scatter!(p, first.(ys); label="Observations")
plot!(
p,
1:T,
Expand Down
Loading
Loading