Skip to content

Disentangle Priors from Dynamics #101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions GeneralisedFilters/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ AcceleratedKernels = "6a4ca0a5-0e36-4168-a932-d9be78d558f1"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
GaussianDistributions = "43dcc890-d446-5863-8d1a-14597580bb8d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Expand All @@ -28,7 +27,6 @@ Aqua = "0.8"
CUDA = "5"
DataStructures = "0.18.20"
Distributions = "0.25"
GaussianDistributions = "0.5.2"
LogExpFunctions = "0.3"
NNlib = "0.9"
OffsetArrays = "1.14.1"
Expand Down
87 changes: 48 additions & 39 deletions GeneralisedFilters/src/GFTest/models/dummy_linear_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,60 +22,64 @@ export InnerDynamics, create_dummy_linear_gaussian_model
Linear Gaussian dynamics conditonal on the (previous) outer state (u_t), defined by:
x_{t+1} = A x_t + b + C u_t + w_t
"""
struct InnerDynamics{
T,TMat<:AbstractMatrix{T},TVec<:AbstractVector{T},TCov<:AbstractMatrix{T}
} <: LinearGaussianLatentDynamics{T}
μ0::TVec
Σ0::TCov
struct InnerDynamics{TMat<:AbstractMatrix,TVec<:AbstractVector,TCov<:AbstractMatrix} <:
LinearGaussianLatentDynamics
A::TMat
b::TVec
C::TMat
Q::TCov
end

struct InnerPrior{TVec<:AbstractVector,TCov<:AbstractMatrix} <: GaussianPrior
μ0::TVec
Σ0::TCov
end

# CPU methods
GeneralisedFilters.calc_μ0(dyn::InnerDynamics; kwargs...) = dyn.μ0
GeneralisedFilters.calc_Σ0(dyn::InnerDynamics; kwargs...) = dyn.Σ0
GeneralisedFilters.calc_μ0(prior::InnerPrior; kwargs...) = prior.μ0
GeneralisedFilters.calc_Σ0(prior::InnerPrior; kwargs...) = prior.Σ0
GeneralisedFilters.calc_A(dyn::InnerDynamics, ::Integer; kwargs...) = dyn.A
function GeneralisedFilters.calc_b(dyn::InnerDynamics, ::Integer; prev_outer, kwargs...)
return dyn.b + dyn.C * prev_outer
end
GeneralisedFilters.calc_Q(dyn::InnerDynamics, ::Integer; kwargs...) = dyn.Q

# GPU methods
function GeneralisedFilters.batch_calc_μ0s(dyn::InnerDynamics{T}, N; kwargs...) where {T}
μ0s = CuArray{T}(undef, length(dyn.μ0), N)
return μ0s[:, :] .= cu(dyn.μ0)
function GeneralisedFilters.batch_calc_μ0s(prior::InnerPrior, N; kwargs...)
# μ0s = CuArray{T}(undef, length(prior.μ0), N)
# return μ0s[:, :] .= cu(prior.μ0)
return repeat(cu(prior.μ0), 1, N)
end

function GeneralisedFilters.batch_calc_Σ0s(
dyn::InnerDynamics{T}, N::Integer; kwargs...
) where {T}
Σ0s = CuArray{T}(undef, size(dyn.Σ0)..., N)
return Σ0s[:, :, :] .= cu(dyn.Σ0)
function GeneralisedFilters.batch_calc_Σ0s(prior::InnerPrior, N::Integer; kwargs...)
# Σ0s = CuArray{T}(undef, size(dyn.Σ0)..., N)
# return Σ0s[:, :, :] .= cu(dyn.Σ0)
return repeat(cu(prior.Σ0), 1, N)
end

function GeneralisedFilters.batch_calc_As(
dyn::InnerDynamics{T}, ::Integer, N::Integer; kwargs...
) where {T}
As = CuArray{T}(undef, size(dyn.A)..., N)
As[:, :, :] .= cu(dyn.A)
return As
dyn::InnerDynamics, ::Integer, N::Integer; kwargs...
)
# As = CuArray{T}(undef, size(dyn.A)..., N)
# As[:, :, :] .= cu(dyn.A)
Comment on lines +63 to +64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may well be on your todo list already, but would be good to remove commented out snippets like this one.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left those in for when @THargreaves eventually fixes the GPU elements of this PR, but there are definitely some left over which need to be removed.

return repeat(cu(dyn.A), 1, N)
end

function GeneralisedFilters.batch_calc_bs(
dyn::InnerDynamics{T}, ::Integer, N::Integer; prev_outer, kwargs...
) where {T}
Cs = CuArray{T}(undef, size(dyn.C)..., N)
Cs[:, :, :] .= cu(dyn.C)
dyn::InnerDynamics, ::Integer, N::Integer; prev_outer, kwargs...
)
# Cs = CuArray{T}(undef, size(dyn.C)..., N)
# Cs[:, :, :] .= cu(dyn.C)
Cs = repeat(cu(dyn.C), 1, N)
return NNlib.batched_vec(Cs, prev_outer) .+ cu(dyn.b)
end

function GeneralisedFilters.batch_calc_Qs(
dyn::InnerDynamics{T}, ::Integer, N::Integer; kwargs...
) where {T}
Q = CuArray{T}(undef, size(dyn.Q)..., N)
return Q[:, :, :] .= cu(dyn.Q)
dyn::InnerDynamics, ::Integer, N::Integer; kwargs...
)
# Q = CuArray{T}(undef, size(dyn.Q)..., N)
# return Q[:, :, :] .= cu(dyn.Q)
return repeat(cu(dyn.Q), 1, N)
end

function create_dummy_linear_gaussian_model(
Expand Down Expand Up @@ -113,36 +117,41 @@ function create_dummy_linear_gaussian_model(
full_model = create_homogeneous_linear_gaussian_model(μ0, Σ0s, A, b, Q, H, c, R)

# Create hierarchical model
outer_prior = GeneralisedFilters.HomogeneousGaussianPrior(
μ0[1:D_outer], Σ0s[1:D_outer, 1:D_outer]
)

outer_dyn = GeneralisedFilters.HomogeneousLinearGaussianLatentDynamics(
μ0[1:D_outer],
Σ0s[1:D_outer, 1:D_outer],
A[1:D_outer, 1:D_outer],
b[1:D_outer],
Q[1:D_outer, 1:D_outer],
A[1:D_outer, 1:D_outer], b[1:D_outer], Q[1:D_outer, 1:D_outer]
)
inner_dyn = if static_arrays
InnerDynamics(

inner_prior, inner_dyn = if static_arrays
prior = InnerPrior(
SVector{D_inner,T}(μ0[(D_outer + 1):end]),
SMatrix{D_inner,D_inner,T}(Σ0s[(D_outer + 1):end, (D_outer + 1):end]),
)
dyn = InnerDynamics(
SMatrix{D_inner,D_outer,T}(A[(D_outer + 1):end, (D_outer + 1):end]),
SVector{D_inner,T}(b[(D_outer + 1):end]),
SMatrix{D_inner,D_outer,T}(A[(D_outer + 1):end, 1:D_outer]),
SMatrix{D_inner,D_inner,T}(Q[(D_outer + 1):end, (D_outer + 1):end]),
)
prior, dyn
else
InnerDynamics(
μ0[(D_outer + 1):end],
Σ0s[(D_outer + 1):end, (D_outer + 1):end],
prior = InnerPrior(μ0[(D_outer + 1):end], Σ0s[(D_outer + 1):end, (D_outer + 1):end])
dyn = InnerDynamics(
A[(D_outer + 1):end, (D_outer + 1):end],
b[(D_outer + 1):end],
A[(D_outer + 1):end, 1:D_outer],
Q[(D_outer + 1):end, (D_outer + 1):end],
)
prior, dyn
end

obs = GeneralisedFilters.HomogeneousLinearGaussianObservationProcess(
H[:, (D_outer + 1):end], c, R
)
hier_model = HierarchicalSSM(outer_dyn, inner_dyn, obs)
hier_model = HierarchicalSSM(outer_prior, outer_dyn, inner_prior, inner_dyn, obs)

return full_model, hier_model
end
3 changes: 2 additions & 1 deletion GeneralisedFilters/src/GFTest/models/linear_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ function create_linear_gaussian_model(
end

function _compute_joint(model, T::Integer)
(; μ0, Σ0, A, b, Q) = model.dyn
(; μ0, Σ0) = model.prior
(; A, b, Q) = model.dyn
(; H, c, R) = model.obs
Dy, Dx = size(H)

Expand Down
75 changes: 41 additions & 34 deletions GeneralisedFilters/src/GeneralisedFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ module GeneralisedFilters
using AbstractMCMC: AbstractMCMC, AbstractSampler
import Distributions: MvNormal
import Random: AbstractRNG, default_rng, rand
using GaussianDistributions: pairs, Gaussian
using OffsetArrays
using SSMProblems
using StatsBase
Expand All @@ -23,98 +22,106 @@ abstract type AbstractFilter <: AbstractSampler end
abstract type AbstractBatchFilter <: AbstractFilter end

"""
initialise([rng,] model, alg; kwargs...)
initialise([rng,] model, algo; kwargs...)

Propose an initial state distribution.
"""
function initialise end

"""
step([rng,] model, alg, iter, state, observation; kwargs...)
step([rng,] model, algo, iter, state, observation; kwargs...)

Perform a combined predict and update call of the filtering on the state.
"""
function step end

"""
predict([rng,] model, alg, iter, filtered; kwargs...)
predict([rng,] model, algo, iter, filtered; kwargs...)

Propagate the filtered distribution forward in time.
"""
function predict end

"""
update(model, alg, iter, proposed, observation; kwargs...)
update(model, algo, iter, proposed, observation; kwargs...)

Update beliefs on the propagated distribution given an observation.
"""
function update end

function initialise(model, alg; kwargs...)
return initialise(default_rng(), model, alg; kwargs...)
function initialise(model, algo; kwargs...)
return initialise(default_rng(), model, algo; kwargs...)
end

function predict(model, alg, step, filtered, observation; kwargs...)
return predict(default_rng(), model, alg, step, filtered; kwargs...)
function predict(model, algo, step, filtered, observation; kwargs...)
return predict(default_rng(), model, algo, step, filtered; kwargs...)
end

function filter(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
alg::AbstractFilter,
algo::AbstractFilter,
observations::AbstractVector;
callback::Union{AbstractCallback,Nothing}=nothing,
callback::CallbackType=nothing,
kwargs...,
)
state = initialise(rng, model, alg; kwargs...)
isnothing(callback) || callback(model, alg, state, observations, PostInit; kwargs...)
# draw from the prior
init_state = initialise(rng, model, algo; kwargs...)
callback(model, algo, init_state, observations, PostInit; kwargs...)

log_evidence = initialise_log_evidence(alg, model)
# iterations starts here for type stability
state, log_evidence = step(
rng, model, algo, 1, init_state, observations[1]; callback, kwargs...
)

for t in eachindex(observations)
# subsequent iteration
for t in 2:length(observations)
state, ll_increment = step(
rng, model, alg, t, state, observations[t]; callback, kwargs...
rng, model, algo, t, state, observations[t]; callback, kwargs...
)
log_evidence += ll_increment
end

return state, log_evidence
end

function initialise_log_evidence(::AbstractFilter, model::AbstractStateSpaceModel)
return zero(SSMProblems.arithmetic_type(model))
end

function initialise_log_evidence(alg::AbstractBatchFilter, model::AbstractStateSpaceModel)
return CUDA.zeros(SSMProblems.arithmetic_type(model), alg.batch_size)
end

function filter(
model::AbstractStateSpaceModel,
alg::AbstractFilter,
algo::AbstractFilter,
observations::AbstractVector;
kwargs...,
)
return filter(default_rng(), model, alg, observations; kwargs...)
return filter(default_rng(), model, algo, observations; kwargs...)
end

function step(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
alg::AbstractFilter,
algo::AbstractFilter,
iter::Integer,
state,
observation;
kwargs...,
)
# generalised to fit analytical filters
return move(rng, model, algo, iter, state, observation; kwargs...)
end

function move(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
algo::AbstractFilter,
iter::Integer,
state,
observation;
callback::Union{AbstractCallback,Nothing}=nothing,
callback::CallbackType=nothing,
kwargs...,
)
state = predict(rng, model, alg, iter, state, observation; kwargs...)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostPredict; kwargs...)
state = predict(rng, model, algo, iter, state, observation; kwargs...)
callback(model, algo, iter, state, observation, PostPredict; kwargs...)

state, ll_increment = update(model, alg, iter, state, observation; kwargs...)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostUpdate; kwargs...)
state, ll_increment = update(model, algo, iter, state, observation; kwargs...)
callback(model, algo, iter, state, observation, PostUpdate; kwargs...)

return state, ll_increment
end
Expand Down
6 changes: 3 additions & 3 deletions GeneralisedFilters/src/algorithms/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ const FW = ForwardAlgorithm
function initialise(
rng::AbstractRNG, model::DiscreteStateSpaceModel, ::ForwardAlgorithm; kwargs...
)
return calc_α0(model.dyn; kwargs...)
return calc_α0(model.prior; kwargs...)
end

function predict(
rng::AbstractRNG,
model::DiscreteStateSpaceModel{T},
model::DiscreteStateSpaceModel,
filter::ForwardAlgorithm,
step::Integer,
states::AbstractVector,
observation;
kwargs...,
) where {T}
)
P = calc_P(model.dyn, step; kwargs...)
return (states' * P)'
end
Expand Down
Loading
Loading