-
Notifications
You must be signed in to change notification settings - Fork 3
Create wrapper for square-root Kalman filter #97
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
c69814e
68fe012
e26d838
f50717e
d92f00f
5291ea6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
@@ -1,9 +1,11 @@ | ||||||
export KalmanFilter, filter, BatchKalmanFilter | ||||||
using GaussianDistributions | ||||||
using CUDA: i32 | ||||||
using PDMats | ||||||
using KalmanFilters | ||||||
import LinearAlgebra: hermitianpart | ||||||
|
||||||
export KalmanFilter, KF, KalmanSmoother, KS | ||||||
export KalmanFilter, KF, KalmanSmoother, KS, SRKF | ||||||
|
||||||
struct KalmanFilter <: AbstractFilter end | ||||||
|
||||||
|
@@ -129,6 +131,58 @@ function update( | |||||
return BatchGaussianDistribution(μ_filt, Σ_filt), dropdims(log_likes; dims=1) | ||||||
end | ||||||
|
||||||
## SQUARE-ROOT KALMAN FILTER ############################################################### | ||||||
|
||||||
""" | ||||||
SRKF() | ||||||
|
||||||
A square-root Kalman filter. | ||||||
|
||||||
Implemented by wrapping KalmanFilters.jl. | ||||||
""" | ||||||
struct SRKF <: AbstractFilter end | ||||||
|
||||||
function initialise( | ||||||
rng::AbstractRNG, model::LinearGaussianStateSpaceModel, filter::SRKF; kwargs... | ||||||
) | ||||||
μ0, Σ0 = calc_initial(model.dyn; kwargs...) | ||||||
return Gaussian(μ0, Σ0) | ||||||
end | ||||||
|
||||||
function predict( | ||||||
rng::AbstractRNG, | ||||||
model::LinearGaussianStateSpaceModel, | ||||||
::SRKF, | ||||||
iter::Integer, | ||||||
state::Gaussian, | ||||||
observation=nothing; | ||||||
kwargs..., | ||||||
) | ||||||
μ, Σ = GaussianDistributions.pair(state) | ||||||
A, b, Q = calc_params(model.dyn, iter; kwargs...) | ||||||
!all(b .== 0) && error("SKRF doesn't current support non-zero b") | ||||||
|
||||||
tu = KalmanFilters.time_update(μ, cholesky(Σ), A, cholesky(Q)) | ||||||
return Gaussian(tu.state, PDMat(tu.covariance)) | ||||||
end | ||||||
|
||||||
function update( | ||||||
model::LinearGaussianStateSpaceModel, | ||||||
algo::SRKF, | ||||||
iter::Integer, | ||||||
state::Gaussian, | ||||||
observation::AbstractVector; | ||||||
kwargs..., | ||||||
) | ||||||
μ, Σ = GaussianDistributions.pair(state) | ||||||
H, c, R = calc_params(model.obs, iter; kwargs...) | ||||||
!all(c .== 0) && error("SKRF doesn't current support non-zero c") | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
I doubt this gets us any speed up, but should guarantee type stability in 99% of cases. |
||||||
|
||||||
mu = KalmanFilters.measurement_update(μ, cholesky(Σ), observation, H, cholesky(R)) | ||||||
ll = logpdf(MvNormal(mu.innovation, PDMat(mu.innovation_covariance)), zero(observation)) | ||||||
return Gaussian(mu.state, PDMat(mu.covariance)), ll | ||||||
end | ||||||
|
||||||
## KALMAN SMOOTHER ######################################################################### | ||||||
|
||||||
struct KalmanSmoother <: AbstractSmoother end | ||||||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,7 +21,7 @@ include("resamplers.jl") | |
for Dy in Dys | ||
rng = StableRNG(1234) | ||
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy) | ||
_, _, ys = sample(rng, model, 1) | ||
_, ys = sample(rng, model, 1) | ||
|
||
filtered, ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys) | ||
|
||
|
@@ -48,6 +48,62 @@ include("resamplers.jl") | |
end | ||
end | ||
|
||
@testitem "Square root Kalman filter test" begin | ||
using GeneralisedFilters | ||
using LinearAlgebra | ||
using PDMats | ||
using StableRNGs | ||
using StaticArrays | ||
|
||
rng = StableRNG(1234) | ||
μ0 = rand(rng, 2) | ||
Σ0 = rand(rng, 2, 2) | ||
Σ0 = Σ0 * Σ0' # make Σ0 positive definite | ||
Σ0 = PDMat(Σ0) | ||
A = rand(rng, 2, 2) | ||
b = zeros(2) | ||
Q = rand(rng, 2, 2) | ||
Q = Q * Q' # make Q positive definite | ||
Q = PDMat(Q) | ||
H = rand(rng, 2, 2) | ||
c = zeros(2) | ||
R = rand(rng, 2, 2) | ||
R = R * R' # make R positive definite | ||
R = PDMat(R) | ||
|
||
model = create_homogeneous_linear_gaussian_model(μ0, Σ0, A, b, Q, H, c, R) | ||
|
||
T = 3 | ||
observations = [rand(rng, 2) for _ in 1:T] | ||
|
||
# kf_state, kf_ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), observations) | ||
# srkf_state, srkf_ll = GeneralisedFilters.filter(rng, model, SRKF(), observations) | ||
|
||
# @test isa(srkf_state.Σ, PDMat) | ||
# @test kf_state.μ ≈ srkf_state.μ | ||
# @test kf_state.Σ ≈ srkf_state.Σ | ||
# @test kf_ll ≈ srkf_ll | ||
|
||
# Test with StaticArrays.jl | ||
μ0_static = SVector{2}(μ0) | ||
Σ0_static = PDMat(SMatrix{2,2}(Σ0.mat)) | ||
A_static = SMatrix{2,2}(A) | ||
b_static = SVector{2}(b) | ||
Q_static = PDMat(SMatrix{2,2}(Q.mat)) | ||
H_static = SMatrix{2,2}(H) | ||
c_static = SVector{2}(c) | ||
R_static = PDMat(SMatrix{2,2}(R.mat)) | ||
model_static = create_homogeneous_linear_gaussian_model( | ||
μ0_static, Σ0_static, A_static, b_static, Q_static, H_static, c_static, R_static | ||
) | ||
observations_static = [SVector{2}(rand(rng, 2)) for _ in 1:T] | ||
srkf_state, _ = GeneralisedFilters.filter( | ||
rng, model_static, SRKF(), observations_static | ||
) | ||
# @test isa(srkf_state.μ, SVector) | ||
# @test isa(srkf_state.Σ.mat, SMatrix) | ||
Comment on lines
+103
to
+104
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. what's up with the unit testing? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I need to fix the upstream issue with StaticArrays getting converted to regular arrays. Then I can comment these back out. They'll fail currently. |
||
end | ||
|
||
@testitem "Kalman smoother test" begin | ||
using GeneralisedFilters | ||
using Distributions | ||
|
@@ -61,7 +117,7 @@ end | |
for Dy in Dys | ||
rng = StableRNG(1234) | ||
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy) | ||
_, _, ys = sample(rng, model, 2) | ||
_, ys = sample(rng, model, 2) | ||
|
||
states, ll = GeneralisedFilters.smooth(rng, model, KalmanSmoother(), ys) | ||
|
||
|
@@ -88,7 +144,7 @@ end | |
|
||
rng = StableRNG(1234) | ||
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) | ||
_, _, ys = sample(rng, model, 10) | ||
_, ys = sample(rng, model, 10) | ||
|
||
bf = BF(2^12; threshold=0.8) | ||
bf_state, llbf = GeneralisedFilters.filter(rng, model, bf, ys) | ||
|
@@ -128,7 +184,7 @@ end | |
|
||
rng = StableRNG(1234) | ||
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) | ||
_, _, ys = sample(rng, model, 10) | ||
_, ys = sample(rng, model, 10) | ||
|
||
algo = PF(2^10, LinearGaussianProposal(); threshold=0.6) | ||
kf_states, kf_ll = GeneralisedFilters.filter(rng, model, KalmanFilter(), ys) | ||
|
@@ -212,7 +268,7 @@ end | |
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( | ||
rng, D_outer, D_inner, D_obs | ||
) | ||
_, _, ys = sample(rng, full_model, T) | ||
_, ys = sample(rng, full_model, T) | ||
|
||
# Ground truth Kalman filtering | ||
kf_states, kf_ll = GeneralisedFilters.filter(rng, full_model, KalmanFilter(), ys) | ||
|
@@ -255,7 +311,7 @@ end | |
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( | ||
rng, D_outer, D_inner, D_obs, ET | ||
) | ||
_, _, ys = sample(rng, full_model, T) | ||
_, ys = sample(rng, full_model, T) | ||
|
||
# Ground truth Kalman filtering | ||
kf_state, kf_ll = GeneralisedFilters.filter(full_model, KalmanFilter(), ys) | ||
|
@@ -290,7 +346,7 @@ end | |
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( | ||
rng, 1, 1, 1 | ||
) | ||
_, _, ys = sample(rng, full_model, T) | ||
_, ys = sample(rng, full_model, T) | ||
|
||
# Manually create tree to force expansion on second step | ||
particle_type = GeneralisedFilters.RaoBlackwellisedParticle{ | ||
|
@@ -325,7 +381,7 @@ end | |
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( | ||
rng, D_outer, D_inner, D_obs | ||
) | ||
_, _, ys = sample(rng, full_model, T) | ||
_, ys = sample(rng, full_model, T) | ||
|
||
# Ground truth Kalman filtering | ||
kf_states, kf_ll = GeneralisedFilters.filter(rng, full_model, KalmanFilter(), ys) | ||
|
@@ -373,7 +429,7 @@ end | |
|
||
rng = StableRNG(SEED) | ||
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, 1, 1) | ||
_, _, ys = sample(rng, model, K) | ||
_, ys = sample(rng, model, K) | ||
|
||
ref_traj = OffsetVector([rand(rng, 1) for _ in 0:K], -1) | ||
|
||
|
@@ -413,7 +469,7 @@ end | |
|
||
rng = StableRNG(SEED) | ||
model = GeneralisedFilters.GFTest.create_linear_gaussian_model(rng, Dx, Dy) | ||
_, _, ys = sample(rng, model, K) | ||
_, ys = sample(rng, model, K) | ||
|
||
# Kalman smoother | ||
state, ks_ll = GeneralisedFilters.smooth( | ||
|
@@ -477,7 +533,7 @@ end | |
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( | ||
rng, D_outer, D_inner, D_obs, T; static_arrays=true | ||
) | ||
_, _, ys = sample(rng, full_model, K) | ||
_, ys = sample(rng, full_model, K) | ||
|
||
# Kalman smoother | ||
state, _ = GeneralisedFilters.smooth( | ||
|
@@ -562,7 +618,7 @@ end | |
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( | ||
rng, D_outer, D_inner, D_obs, T | ||
) | ||
_, _, ys = sample(rng, full_model, K) | ||
_, ys = sample(rng, full_model, K) | ||
|
||
# Generate random reference trajectory | ||
ref_trajectory = [CuArray(rand(rng, T, D_outer, 1)) for _ in 0:K] | ||
|
@@ -595,7 +651,7 @@ end | |
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( | ||
rng, D_outer, D_inner, D_obs, T | ||
) | ||
_, _, ys = sample(rng, full_model, K) | ||
_, ys = sample(rng, full_model, K) | ||
|
||
# Manually create tree to force expansion on second step | ||
M = N_particles * 2 - 1 | ||
|
@@ -642,7 +698,7 @@ end | |
full_model, hier_model = GeneralisedFilters.GFTest.create_dummy_linear_gaussian_model( | ||
rng, D_outer, D_inner, D_obs, T | ||
) | ||
_, _, ys = sample(rng, full_model, K) | ||
_, ys = sample(rng, full_model, K) | ||
|
||
# Kalman smoother | ||
state, _ = GeneralisedFilters.smooth( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
same as before