Disentangle Priors from Dynamics #101


Draft · wants to merge 11 commits into main

Conversation

charlesknipp
Member

See #100 for details.

Summary

Upon separating the initial state prior from the latent dynamics, we can determine the particle type when drawing from the prior. Assuming one's code is type stable, there is then no need to carry arithmetic type parameters on both LatentDynamics and ObservationProcess.

The only issue is with preallocating the weights. I propose a solution (see the general-filters example in SSMProblems), which defines ParticleDistribution as an abstract type. This has clear downsides (especially with respect to callbacks), but it keeps things type agnostic enough that we no longer need to determine the weight type before evaluating update().
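A minimal sketch of what such a hierarchy might look like (the names here are illustrative, not the actual GeneralisedFilters API):

```julia
# Hypothetical sketch: an abstract ParticleDistribution with unweighted
# and weighted variants. Names are illustrative only.
abstract type ParticleDistribution{PT} end

# Unweighted particles, e.g. freshly drawn from the prior or resampled
struct Particles{PT} <: ParticleDistribution{PT}
    particles::Vector{PT}
end

# Weighted particles; the weight type WT only exists once weights do
struct WeightedParticles{PT,WT<:Real} <: ParticleDistribution{PT}
    particles::Vector{PT}
    log_weights::Vector{WT}
end

# The weight type is determined by update(), not declared up front
weight(state::Particles{PT}, logw::Vector{WT}) where {PT,WT<:Real} =
    WeightedParticles{PT,WT}(state.particles, logw)
```

Because `WT` appears only on the weighted variant, the weight type never has to be known before the first `update()` call produces weights.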

This self-contained exploration of the new interface serves as a feasibility check for GeneralisedFilters. It is not feature complete and should only be used as a means of experimentation. Feel free to change whatever you want in /SSMProblems/examples/general-filters.

Contributor

SSMProblems.jl/GeneralisedFilters documentation for PR #101 is available at:
https://TuringLang.github.io/SSMProblems.jl/GeneralisedFilters/previews/PR101/

@charlesknipp
Member Author

Summary of Changes

  • Update the GeneralisedFilters module to account for the inclusion of the initial state prior.

  • Define ParticleDistribution as an abstract type with weighted and unweighted variants, which encourages a type stable (in terms of the lowered code) alternative for managing weighted samples.

  • Update the batching to remove the reliance on the arithmetic type. This may be more expensive computationally, but I think @THargreaves may have a solution for this.

Why Remove the Arithmetic Type?

Previously, the type signature of a state space model included the arithmetic type, which was needed to preallocate the model's weights and log evidence. While this seemed like a necessary element of estimation, it was merely a hacky way for the user to tell the compiler what type to expect from logdensity.

The arithmetic type would inform the algorithm what to expect, but not necessarily enforce it. A model could therefore be defined with an arithmetic type different from the return type of logdensity. At best, this would yield an error; more dangerously, it could induce an expensive type conversion unbeknownst to the user.
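To illustrate the hazard with a toy example (not the package's actual code): a model declared with Float32 arithmetic whose logdensity accidentally returns Float64 silently promotes the accumulated log evidence.

```julia
# Toy logdensity whose Float64 literal promotes every result to Float64
logdensity(x) = -0.5 * abs2(x)

# The declared "arithmetic type" says Float32, but nothing enforces it
function accumulate_evidence(xs)
    log_evidence = zero(Float32)        # preallocated from the declared type
    for x in xs
        log_evidence += logdensity(x)   # silent Float32 -> Float64 promotion
    end
    return log_evidence
end
```

The accumulator starts as Float32 but rebinds to Float64 on the first iteration, exactly the kind of hidden conversion described above.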

Caveats

Due to the removal of the type signatures (both eltype and arithmetic_type), callbacks are in rough shape. Instead of preallocating the object before passing it to the inference algorithm, callbacks must be mutable and untyped such that the PostInit trigger preallocates the necessary information.

This is something I hope to resolve more gracefully; currently it is less than ideal. If anyone has an idea for how to resolve it, that would greatly improve my confidence in this PR.

Note to Reviewers

This passes all the unit tests on the CPU, but the GPU implementations are less than stellar. I also haven't benchmarked the new changes so I can't honestly say whether speed has changed.

@THargreaves: I need some help on the GPU side of things, since it is clearly broken.

@mhauru: Since you have the software background, I would love for you to take a look at the ParticleDistribution side of things.

In Conclusion

This is still very much a work in progress, but I figured pushing some of my changes would make it a little easier to see the vision.

As always, I welcome any and all changes, suggestions, and counterpoints. Especially since this is a huge, but necessary amendment to the interface.

@THargreaves
Collaborator

I'll jot down a few thoughts now, though I'll have a proper think through this and do a more comprehensive write-up covering the GPU and batched cases when I have a moment. I think this is a crucial design decision, so I'm very keen to get it right (allow general batching/GPU, speed + type stability, autodiff).

Previously, the type signature of a state space model included the arithmetic type, which was needed to preallocate the model's weights and log evidence. While this seemed like a necessary element of estimation, it was merely a hacky way for the user to tell the compiler what type to expect from logdensity.

I agree that the arithmetic type likely should be removed, though I'm approaching it from a slightly different perspective. I didn't like how the filtering was always initialised with a T(0.0). In the batched GPU case this should be a CuVector{T}, and a Vector{T} in the CPU case.

I like that this new approach seems to work more abstractly in the sense that the filter steps seem to care less about what their inputs are and just operate on what they've been given through dispatch.

Due to the removal of the type signatures (both eltype and arithmetic_type), callbacks are in rough shape. Instead of preallocating the object before passing it to the inference algorithm, callbacks must be mutable and untyped such that the PostInit trigger preallocates the necessary information.

Or the user has to figure out by themselves what type the future states are going to be? (which would be a hassle but doable in most cases) Am I understanding that right?

@charlesknipp
Member Author

Or the user has to figure out by themselves what type the future states are going to be? (which would be a hassle but doable in most cases) Am I understanding that right?

This is correct, but it seems a little hostile to force the user to preallocate types. This is especially true when they want to run the same callback/model combination with different filtering algorithms.

While my current execution is lackluster and unstable, it is technically proof that this could work. Granted, this is not a hill I'm willing to die on. If callbacks must be preallocated by hand, I'll persevere.

@charlesknipp
Member Author

Type Stability

I finally have the module in a fully type stable format by borrowing from the iterator philosophy. The first iteration initializes the state, and subsequent passes operate on those already instantiated objects.

The idea is to use the first iteration, given the uninitialized state (unweighted particles drawn from the prior), to begin iteration; this lets us allocate the state with the proper type information and initialize log_evidence with the return type of update.
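As a rough sketch of the idea (a toy random-walk filter, not the actual module code): the first draw from the prior fixes the particle type, and everything downstream is inferred from it rather than declared by the model.

```julia
using Random

# Toy filter: the eltype of the state and log_evidence are fixed by the
# first draw from the prior, never declared up front.
function toy_filter(rng::AbstractRNG, n_steps::Int, n_particles::Int)
    particles = randn(rng, n_particles)        # prior draw fixes the eltype
    log_evidence = zero(eltype(particles))     # inferred, not declared
    logw = similar(particles)
    for _ in 1:n_steps
        particles .+= 0.1 .* randn(rng, n_particles)  # move
        logw .= -0.5 .* abs2.(particles)              # update
        log_evidence += log(sum(exp, logw) / n_particles)
    end
    return log_evidence
end
```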

Minor Cleanup

I was bothered by some duplicated code in particles, which redefined step for particle filters. Instead, I treat the particle filter as a resample-move process, where step comprises resample and move (predict + update).

I also removed the GaussianDistributions dependency, since we only used Gaussian as a container. As an alternative, I defined GaussianDistribution, which is almost identical.

I added a Base.iterate to the ParticleDistribution which cleans up all the map statements.

The boilerplate required for callbacks is also much cleaner, since I added a default method for the type union of nothing and <:AbstractCallback.
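The default method presumably takes a form like the following (an assumed sketch with made-up names, not the exact PR code):

```julia
abstract type AbstractCallback end

# No-op fallback: filtering code can fire the callback unconditionally,
# and passing `nothing` costs nothing.
callback!(::Nothing, state) = nothing

# A user callback only needs to subtype AbstractCallback and add a method
struct TraceCallback <: AbstractCallback
    trace::Vector{Any}
end
callback!(cb::TraceCallback, state) = push!(cb.trace, state)
```

This removes the `callback === nothing &&` guards from the algorithm code itself.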

The Sacrifice

To get properly type stable log evidence, I had to drop the marginalization term. For most applications this is no big deal. I defined a custom logmeanexp, which is called on the incremental observation weights. Since this neglects the transition and proposal kernels, it will lead to incorrect likelihoods for the guided filter.
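An assumed form of such a helper (the PR's actual logmeanexp may differ):

```julia
# Numerically stable log of the mean of exp.(logw); subtracting the
# maximum avoids overflow for large-magnitude log-weights.
function logmeanexp(logw::AbstractVector{<:Real})
    m = maximum(logw)
    return m + log(sum(exp.(logw .- m)) / length(logw))
end
```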

Funnily enough, the guided filter unit test works remarkably well. It passes with a higher tolerance than before. However, when running VSMC, the likelihood surpasses the Kalman filter after 50ish iterations. When adding the transition weights back into the computation, this issue disappears.

What's Left?

We need to fix the log-likelihood computation. Even though the guided filter is commonly neglected, this is a hill I am willing to die on. I have a couple of ideas on how this could work:

  • return an additional value from predict representing the log increments
  • define the weights via a LazyArray somehow caching the incremental weights (only those computed at time $t$)

I gravitate more towards the latter, since it mirrors the returns of the Kalman/particle filters. However, I don't know how costly it would be to store additional elements in the particle distributions. There is likely a brutal trade-off unless someone wants to get creative.

@mhauru (Member) left a comment

I'm not qualified to really review this PR, but I gave it a superficial skim and threw around a few ideas in case they are helpful.

A couple other thoughts:

  • I love me a docstring, and if I was trying to properly understand what is going on, more of them, even for internal functions, would be very helpful.
  • Since it sounds like type stability is important and non-trivial to achieve, you could consider testing for it using @inferred in your tests.
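For example (illustrative functions, not from the PR):

```julia
using Test

stable(x) = x > 0 ? x : zero(x)   # always returns typeof(x)
unstable(x) = x > 0 ? x : 0       # Float64 or Int, depending on x

@inferred stable(1.0)             # passes: the return type is inferrable
# @inferred unstable(1.0)         # would throw: Union{Float64,Int64} inferred
```

Sprinkling `@inferred` over the filtering entry points would turn any future type-stability regression into a test failure.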

Comment on lines +63 to +64
# As = CuArray{T}(undef, size(dyn.A)..., N)
# As[:, :, :] .= cu(dyn.A)
Member

This may well be on your todo list already, but would be good to remove commented out snippets like this one.

Member Author

I left those in for when @THargreaves eventually fixes the GPU elements of this PR, but there are definitely some left over which need to be removed.

end

function predict(
rng::AbstractRNG,
model::StateSpaceModel,
filter::ParticleFilter,
algo::ParticleFilter,
Member

If this is a bit of an omnibus PR that changes a bazillion things and there's no avoiding that, then this is fine; but otherwise it could be helpful in the future if renamings like these had their own PRs, so the meaty PRs that change behaviour would be lighter to review.

Member Author

Also true. I just decided to change some argument names for homogeneity's sake. That probably could have been left to a separate PR.

mutable struct ParticleDistribution{PT,WT<:Real}
abstract type ParticleDistribution{PT} end

Base.collect(state::PT) where {PT<:ParticleDistribution} = state.particles
Member

Suggested change
Base.collect(state::PT) where {PT<:ParticleDistribution} = state.particles
Base.collect(state::ParticleDistribution) = state.particles

I think this is equivalent. Same for several below.

Also, since ParticleDistribution is an abstract type, and here we are getting a literal field of it called .particles, would be good to document that all subtypes of ParticleDistribution must have such a field. Or, maybe even better, wrap it in a getparticles function, and then document that all subtypes must implement getparticles.

# not sure if this is kosher, since it doesn't follow the convention of Base.getindex
Base.@propagate_inbounds Base.getindex(state::ParticleDistribution, i) = state.particles[i]

mutable struct Particles{PT} <: ParticleDistribution{PT}
Member

I notice a lot of mutable structs in this PR. There may be a very good reason for them, but could be worth considering whether some of them could be made immutable, and if yes, check how it affects performance. Immutable ones are often faster, and I think most times easier to reason about.

If you end up using a mixture of mutable and immutable types, and e.g. have abstract types that have both immutable and mutable subtypes, consider using the !! convention from https://github.com/juliafolds/bangbang.jl: A function whose name ends in !! may or may not mutate its input, but is always guaranteed to return the necessary output as its return value. I find it very helpful for not having to worry about the mutable/immutable distinction when calling, and letting the underlying implementations worry about when mutation makes sense. (The downside is that sometimes !! forces you to use copy unnecessarily.)
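A tiny illustration of the !! pattern (made-up function name, assuming weights stored either in a mutable Vector or an immutable Tuple):

```julia
# set_weights!!: mutates when the container is mutable, otherwise returns
# a fresh object; callers always use the return value.
set_weights!!(w::Vector, logw) = (copyto!(w, logw); w)  # in place
set_weights!!(w::Tuple, logw) = Tuple(logw)             # new immutable value
```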

Also, and sorry if I'm telling you something you already know, but if all you want to do is to mutate the particles and ancestors vectors, but not reassign entirely new vectors to an existing Particles object, then you don't need Particles to be mutable.

function reset_weights!(state::ParticleDistribution{T,WT}) where {T,WT<:Real}
fill!(state.log_weights, zero(WT))
return state.log_weights
function update_weights(state::WeightedParticles, log_weights)
Member

Suggested change
function update_weights(state::WeightedParticles, log_weights)
function update_weights!(state::WeightedParticles, log_weights)

By convention, functions that mutate their inputs should have their name end in !, per https://docs.julialang.org/en/v1/manual/style-guide/#bang-convention.

There may be more instances of this in the PR that I haven't noticed, just spotted this one.

Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SSMProblems = "26aad666-b158-4e64-9d35-0e672562fa48"
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
UnPack = "3a884ed6-31ef-47d7-9d2a-63182c4928ed"
Member

I know it's just an example, but could consider compat bounds for these.

@@ -0,0 +1,210 @@
using SSMProblems
Member

Could this file have a docstring or a README to accompany it?

@@ -273,37 +217,25 @@ abstract type AbstractStateSpaceModel <: AbstractMCMC.AbstractModel end
"""
A state space model.

A vanilla implementation of a state space model, composed of a latent dynamics and an
observation process.
A vanilla implementation of a state space model, composed of an intiail state prior, latent
Member

Suggested change
A vanilla implementation of a state space model, composed of an intiail state prior, latent
A vanilla implementation of a state space model, composed of an initial state prior, latent
