Skip to content

Disentangle Priors from Dynamics #101

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 11 commits into
base: main
Choose a base branch
from
87 changes: 48 additions & 39 deletions GeneralisedFilters/src/GFTest/models/dummy_linear_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,60 +22,64 @@ export InnerDynamics, create_dummy_linear_gaussian_model
Linear Gaussian dynamics conditonal on the (previous) outer state (u_t), defined by:
x_{t+1} = A x_t + b + C u_t + w_t
"""
struct InnerDynamics{
T,TMat<:AbstractMatrix{T},TVec<:AbstractVector{T},TCov<:AbstractMatrix{T}
} <: LinearGaussianLatentDynamics{T}
μ0::TVec
Σ0::TCov
struct InnerDynamics{TMat<:AbstractMatrix,TVec<:AbstractVector,TCov<:AbstractMatrix} <:
LinearGaussianLatentDynamics
A::TMat
b::TVec
C::TMat
Q::TCov
end

struct InnerPrior{TVec<:AbstractVector,TCov<:AbstractMatrix} <: GaussianPrior
μ0::TVec
Σ0::TCov
end

# CPU methods
GeneralisedFilters.calc_μ0(dyn::InnerDynamics; kwargs...) = dyn.μ0
GeneralisedFilters.calc_Σ0(dyn::InnerDynamics; kwargs...) = dyn.Σ0
GeneralisedFilters.calc_μ0(prior::InnerPrior; kwargs...) = prior.μ0
GeneralisedFilters.calc_Σ0(prior::InnerPrior; kwargs...) = prior.Σ0
GeneralisedFilters.calc_A(dyn::InnerDynamics, ::Integer; kwargs...) = dyn.A
function GeneralisedFilters.calc_b(dyn::InnerDynamics, ::Integer; prev_outer, kwargs...)
return dyn.b + dyn.C * prev_outer
end
GeneralisedFilters.calc_Q(dyn::InnerDynamics, ::Integer; kwargs...) = dyn.Q

# GPU methods
function GeneralisedFilters.batch_calc_μ0s(dyn::InnerDynamics{T}, N; kwargs...) where {T}
μ0s = CuArray{T}(undef, length(dyn.μ0), N)
return μ0s[:, :] .= cu(dyn.μ0)
function GeneralisedFilters.batch_calc_μ0s(prior::InnerPrior, N; kwargs...)
# μ0s = CuArray{T}(undef, length(prior.μ0), N)
# return μ0s[:, :] .= cu(prior.μ0)
return repeat(cu(prior.μ0), 1, N)
end

function GeneralisedFilters.batch_calc_Σ0s(
dyn::InnerDynamics{T}, N::Integer; kwargs...
) where {T}
Σ0s = CuArray{T}(undef, size(dyn.Σ0)..., N)
return Σ0s[:, :, :] .= cu(dyn.Σ0)
function GeneralisedFilters.batch_calc_Σ0s(prior::InnerPrior, N::Integer; kwargs...)
# Σ0s = CuArray{T}(undef, size(dyn.Σ0)..., N)
# return Σ0s[:, :, :] .= cu(dyn.Σ0)
return repeat(cu(prior.Σ0), 1, N)
end

function GeneralisedFilters.batch_calc_As(
dyn::InnerDynamics{T}, ::Integer, N::Integer; kwargs...
) where {T}
As = CuArray{T}(undef, size(dyn.A)..., N)
As[:, :, :] .= cu(dyn.A)
return As
dyn::InnerDynamics, ::Integer, N::Integer; kwargs...
)
# As = CuArray{T}(undef, size(dyn.A)..., N)
# As[:, :, :] .= cu(dyn.A)
Comment on lines +63 to +64
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This may well be on your todo list already, but would be good to remove commented out snippets like this one.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left those in for when @THargreaves eventually fixes the GPU elements of this PR, but there are definitely some left over which need to be removed.

return repeat(cu(dyn.A), 1, N)
end

function GeneralisedFilters.batch_calc_bs(
dyn::InnerDynamics{T}, ::Integer, N::Integer; prev_outer, kwargs...
) where {T}
Cs = CuArray{T}(undef, size(dyn.C)..., N)
Cs[:, :, :] .= cu(dyn.C)
dyn::InnerDynamics, ::Integer, N::Integer; prev_outer, kwargs...
)
# Cs = CuArray{T}(undef, size(dyn.C)..., N)
# Cs[:, :, :] .= cu(dyn.C)
Cs = repeat(cu(dyn.C), 1, N)
return NNlib.batched_vec(Cs, prev_outer) .+ cu(dyn.b)
end

function GeneralisedFilters.batch_calc_Qs(
dyn::InnerDynamics{T}, ::Integer, N::Integer; kwargs...
) where {T}
Q = CuArray{T}(undef, size(dyn.Q)..., N)
return Q[:, :, :] .= cu(dyn.Q)
dyn::InnerDynamics, ::Integer, N::Integer; kwargs...
)
# Q = CuArray{T}(undef, size(dyn.Q)..., N)
# return Q[:, :, :] .= cu(dyn.Q)
return repeat(cu(dyn.Q), 1, N)
end

function create_dummy_linear_gaussian_model(
Expand Down Expand Up @@ -113,36 +117,41 @@ function create_dummy_linear_gaussian_model(
full_model = create_homogeneous_linear_gaussian_model(μ0, Σ0s, A, b, Q, H, c, R)

# Create hierarchical model
outer_prior = GeneralisedFilters.HomogeneousGaussianPrior(
μ0[1:D_outer], Σ0s[1:D_outer, 1:D_outer]
)

outer_dyn = GeneralisedFilters.HomogeneousLinearGaussianLatentDynamics(
μ0[1:D_outer],
Σ0s[1:D_outer, 1:D_outer],
A[1:D_outer, 1:D_outer],
b[1:D_outer],
Q[1:D_outer, 1:D_outer],
A[1:D_outer, 1:D_outer], b[1:D_outer], Q[1:D_outer, 1:D_outer]
)
inner_dyn = if static_arrays
InnerDynamics(

inner_prior, inner_dyn = if static_arrays
prior = InnerPrior(
SVector{D_inner,T}(μ0[(D_outer + 1):end]),
SMatrix{D_inner,D_inner,T}(Σ0s[(D_outer + 1):end, (D_outer + 1):end]),
)
dyn = InnerDynamics(
SMatrix{D_inner,D_outer,T}(A[(D_outer + 1):end, (D_outer + 1):end]),
SVector{D_inner,T}(b[(D_outer + 1):end]),
SMatrix{D_inner,D_outer,T}(A[(D_outer + 1):end, 1:D_outer]),
SMatrix{D_inner,D_inner,T}(Q[(D_outer + 1):end, (D_outer + 1):end]),
)
prior, dyn
else
InnerDynamics(
μ0[(D_outer + 1):end],
Σ0s[(D_outer + 1):end, (D_outer + 1):end],
prior = InnerPrior(μ0[(D_outer + 1):end], Σ0s[(D_outer + 1):end, (D_outer + 1):end])
dyn = InnerDynamics(
A[(D_outer + 1):end, (D_outer + 1):end],
b[(D_outer + 1):end],
A[(D_outer + 1):end, 1:D_outer],
Q[(D_outer + 1):end, (D_outer + 1):end],
)
prior, dyn
end

obs = GeneralisedFilters.HomogeneousLinearGaussianObservationProcess(
H[:, (D_outer + 1):end], c, R
)
hier_model = HierarchicalSSM(outer_dyn, inner_dyn, obs)
hier_model = HierarchicalSSM(outer_prior, outer_dyn, inner_prior, inner_dyn, obs)

return full_model, hier_model
end
3 changes: 2 additions & 1 deletion GeneralisedFilters/src/GFTest/models/linear_gaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ function create_linear_gaussian_model(
end

function _compute_joint(model, T::Integer)
(; μ0, Σ0, A, b, Q) = model.dyn
(; μ0, Σ0) = model.prior
(; A, b, Q) = model.dyn
(; H, c, R) = model.obs
Dy, Dx = size(H)

Expand Down
48 changes: 24 additions & 24 deletions GeneralisedFilters/src/GeneralisedFilters.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,57 +23,57 @@ abstract type AbstractFilter <: AbstractSampler end
abstract type AbstractBatchFilter <: AbstractFilter end

"""
initialise([rng,] model, alg; kwargs...)
initialise([rng,] model, algo; kwargs...)

Propose an initial state distribution.
"""
function initialise end

"""
step([rng,] model, alg, iter, state, observation; kwargs...)
step([rng,] model, algo, iter, state, observation; kwargs...)

Perform a combined predict and update call of the filtering on the state.
"""
function step end

"""
predict([rng,] model, alg, iter, filtered; kwargs...)
predict([rng,] model, algo, iter, filtered; kwargs...)

Propagate the filtered distribution forward in time.
"""
function predict end

"""
update(model, alg, iter, proposed, observation; kwargs...)
update(model, algo, iter, proposed, observation; kwargs...)

Update beliefs on the propagated distribution given an observation.
"""
function update end

function initialise(model, alg; kwargs...)
return initialise(default_rng(), model, alg; kwargs...)
function initialise(model, algo; kwargs...)
return initialise(default_rng(), model, algo; kwargs...)
end

function predict(model, alg, step, filtered, observation; kwargs...)
return predict(default_rng(), model, alg, step, filtered; kwargs...)
function predict(model, algo, step, filtered, observation; kwargs...)
return predict(default_rng(), model, algo, step, filtered; kwargs...)
end

function filter(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
alg::AbstractFilter,
algo::AbstractFilter,
observations::AbstractVector;
callback::Union{AbstractCallback,Nothing}=nothing,
kwargs...,
)
state = initialise(rng, model, alg; kwargs...)
isnothing(callback) || callback(model, alg, state, observations, PostInit; kwargs...)
state = initialise(rng, model, algo; kwargs...)
isnothing(callback) || callback(model, algo, state, observations, PostInit; kwargs...)

log_evidence = initialise_log_evidence(alg, model)
log_evidence = initialise_log_evidence(algo, model)

for t in eachindex(observations)
state, ll_increment = step(
rng, model, alg, t, state, observations[t]; callback, kwargs...
rng, model, algo, t, state, observations[t]; callback, kwargs...
)
log_evidence += ll_increment
end
Expand All @@ -82,39 +82,39 @@ function filter(
end

function initialise_log_evidence(::AbstractFilter, model::AbstractStateSpaceModel)
return zero(SSMProblems.arithmetic_type(model))
return 0
end

function initialise_log_evidence(alg::AbstractBatchFilter, model::AbstractStateSpaceModel)
return CUDA.zeros(SSMProblems.arithmetic_type(model), alg.batch_size)
end
# function initialise_log_evidence(alg::AbstractBatchFilter, model::AbstractStateSpaceModel)
# return CUDA.zeros(SSMProblems.arithmetic_type(model), alg.batch_size)
# end

function filter(
model::AbstractStateSpaceModel,
alg::AbstractFilter,
algo::AbstractFilter,
observations::AbstractVector;
kwargs...,
)
return filter(default_rng(), model, alg, observations; kwargs...)
return filter(default_rng(), model, algo, observations; kwargs...)
end

function step(
rng::AbstractRNG,
model::AbstractStateSpaceModel,
alg::AbstractFilter,
algo::AbstractFilter,
iter::Integer,
state,
observation;
callback::Union{AbstractCallback,Nothing}=nothing,
kwargs...,
)
state = predict(rng, model, alg, iter, state, observation; kwargs...)
state = predict(rng, model, algo, iter, state, observation; kwargs...)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostPredict; kwargs...)
callback(model, algo, iter, state, observation, PostPredict; kwargs...)

state, ll_increment = update(model, alg, iter, state, observation; kwargs...)
state, ll_increment = update(model, algo, iter, state, observation; kwargs...)
isnothing(callback) ||
callback(model, alg, iter, state, observation, PostUpdate; kwargs...)
callback(model, algo, iter, state, observation, PostUpdate; kwargs...)

return state, ll_increment
end
Expand Down
6 changes: 3 additions & 3 deletions GeneralisedFilters/src/algorithms/forward.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,18 @@ const FW = ForwardAlgorithm
function initialise(
rng::AbstractRNG, model::DiscreteStateSpaceModel, ::ForwardAlgorithm; kwargs...
)
return calc_α0(model.dyn; kwargs...)
return calc_α0(model.prior; kwargs...)
end

function predict(
rng::AbstractRNG,
model::DiscreteStateSpaceModel{T},
model::DiscreteStateSpaceModel,
filter::ForwardAlgorithm,
step::Integer,
states::AbstractVector,
observation;
kwargs...,
) where {T}
)
P = calc_P(model.dyn, step; kwargs...)
return (states' * P)'
end
Expand Down
Loading
Loading