From ed6946c4c8ffcec0494b74ce6978bb4f13883c53 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:09:15 -0400 Subject: [PATCH 01/23] update to match the AdvancedVI@0.3 interface --- Project.toml | 4 +- src/variational/VariationalInference.jl | 185 ++++++++++++++++++++---- src/variational/advi.jl | 140 ------------------ src/variational/bijectors.jl | 70 +++++++++ 4 files changed, 227 insertions(+), 172 deletions(-) delete mode 100644 src/variational/advi.jl create mode 100644 src/variational/bijectors.jl diff --git a/Project.toml b/Project.toml index 459f11dcbd..6fa6daa2a1 100644 --- a/Project.toml +++ b/Project.toml @@ -38,6 +38,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" +UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" [weakdeps] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" @@ -54,7 +55,7 @@ Accessors = "0.1" AdvancedHMC = "0.3.0, 0.4.0, 0.5.2, 0.6" AdvancedMH = "0.8" AdvancedPS = "0.6.0" -AdvancedVI = "0.2" +AdvancedVI = "0.3.1" BangBang = "0.4.2" Bijectors = "0.14, 0.15" Compat = "4.15.0" @@ -85,6 +86,7 @@ Statistics = "1.6" StatsAPI = "1.6" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" +UnicodePlots = "3" julia = "1.10" [extras] diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 189d3f7001..db95093508 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -1,50 +1,173 @@ + module Variational -using DistributionsAD: DistributionsAD -using DynamicPPL: DynamicPPL -using StatsBase: StatsBase -using StatsFuns: StatsFuns -using LogDensityProblems: LogDensityProblems +using DynamicPPL +using ADTypes using Distributions +using LinearAlgebra +using LogDensityProblems +using Random +using UnicodePlots -using Random: Random +import ..Turing: DEFAULT_ADTYPE, PROGRESS import AdvancedVI import Bijectors # Reexports -using AdvancedVI: vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad -export vi, ADVI, ELBO, elbo, TruncatedADAGrad, DecayedADAGrad - -""" - make_logjoint(model::Model; weight = 1.0) -Constructs the logjoint as a function of latent variables, i.e. the map z → p(x ∣ z) p(z). -The weight used to scale the likelihood, e.g. when doing stochastic gradient descent one needs to -use `DynamicPPL.MiniBatch` context to run the `Model` with a weight `num_total_obs / batch_size`. -## Notes -- For sake of efficiency, the returned function is closes over an instance of `VarInfo`. This means that you *might* run into some weird behaviour if you call this method sequentially using different types; if that's the case, just generate a new one for each type using `make_logjoint`. 
-""" -function make_logjoint(model::DynamicPPL.Model; weight=1.0) - # setup +using AdvancedVI: RepGradELBO, ScoreGradELBO, DoG, DoWG +export vi, RepGradELBO, ScoreGradELBO, DoG, DoWG + +export meanfield_gaussian, fullrank_gaussian + +include("bijectors.jl") + +function make_logdensity(model::DynamicPPL.Model) + weight = 1.0 ctx = DynamicPPL.MiniBatchContext(DynamicPPL.DefaultContext(), weight) - f = DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) - return Base.Fix1(LogDensityProblems.logdensity, f) + return DynamicPPL.LogDensityFunction(model, DynamicPPL.VarInfo(model), ctx) +end + +function initialize_gaussian_scale( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + location::AbstractVector, + scale::AbstractMatrix; + num_samples::Int = 10, + num_max_trials::Int = 10, + reduce_factor = one(eltype(scale))/2 +) + prob = make_logdensity(model) + ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) + varinfo = DynamicPPL.VarInfo(model) + + n_trial = 0 + while true + q = AdvancedVI.MvLocationScale(location, scale, Normal()) + b = Bijectors.bijector(model; varinfo=varinfo) + q_trans = Bijectors.transformed(q, Bijectors.inverse(b)) + energy = mean(ℓπ, eachcol(rand(rng, q_trans, num_samples))) + + if isfinite(energy) + return scale + elseif n_trial == num_max_trials + error("Could not find an initial") + end + + scale = reduce_factor*scale + n_trial += 1 + end +end + +function meanfield_gaussian( + rng::Random.AbstractRNG, + model::DynamicPPL.Model, + location::Union{Nothing, <:AbstractVector} = nothing, + scale::Union{Nothing, <:Diagonal} = nothing; + kwargs... +) + varinfo = DynamicPPL.VarInfo(model) + # Use linked `varinfo` to determine the correct number of parameters. + # TODO: Replace with `length` once this is implemented for `VarInfo`. + varinfo_linked = DynamicPPL.link(varinfo, model) + num_params = length(varinfo_linked[:]) + + μ = if isnothing(location) + zeros(num_params) + else + @assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)." + location + end + + L = if isnothing(scale) + initialize_gaussian_scale(rng, model, μ, Diagonal(ones(num_params)); kwargs...) + else + @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)." + L = scale + end + + q = AdvancedVI.MeanFieldGaussian(μ, L) + b = Bijectors.bijector(model; varinfo=varinfo) + return Bijectors.transformed(q, Bijectors.inverse(b)) end -# objectives -function (elbo::ELBO)( +function meanfield_gaussian( + model::DynamicPPL.Model, + location::Union{Nothing, <:AbstractVector} = nothing, + scale::Union{Nothing, <:Diagonal} = nothing; + kwargs... +) + meanfield_gaussian(Random.default_rng(), model, location, scale; kwargs...) +end + +function fullrank_gaussian( rng::Random.AbstractRNG, - alg::AdvancedVI.VariationalInference, - q, model::DynamicPPL.Model, - num_samples; - weight=1.0, - kwargs..., + location::Union{Nothing, <:AbstractVector} = nothing, + scale::Union{Nothing, <:LowerTriangular} = nothing; + kwargs... ) - return elbo(rng, alg, q, make_logjoint(model; weight=weight), num_samples; kwargs...) + varinfo = DynamicPPL.VarInfo(model) + # Use linked `varinfo` to determine the correct number of parameters. + # TODO: Replace with `length` once this is implemented for `VarInfo`. 
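+    # NOTE: the *linked* `varinfo` is needed because the unconstrained dimension
+    # can differ from the constrained one (e.g. a simplex-valued variable has one
+    # fewer free parameter after linking).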
+ varinfo_linked = DynamicPPL.link(varinfo, model) + num_params = length(varinfo_linked[:]) + + μ = if isnothing(location) + zeros(num_params) + else + @assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)." + location + end + + L = if isnothing(scale) + L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params)) + initialize_gaussian_scale(rng, model, μ, L0; kwargs...) + else + @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)." + scale + end + + q = AdvancedVI.FullRankGaussian(μ, L) + b = Bijectors.bijector(model; varinfo=varinfo) + return Bijectors.transformed(q, Bijectors.inverse(b)) +end + +function fullrank_gaussian( + model::DynamicPPL.Model, + location::Union{Nothing, <:AbstractVector} = nothing, + scale::Union{Nothing, <:Diagonal} = nothing; + kwargs... +) + fullrank_gaussian(Random.default_rng(), model, location, scale; kwargs...) end -# VI algorithms -include("advi.jl") +function vi( + model::DynamicPPL.Model, + q::Bijectors.TransformedDistribution, + n_iterations::Int; + objective=RepGradELBO(10, entropy=AdvancedVI.ClosedFormEntropyZeroGradient()), + show_progress::Bool=PROGRESS[], + optimizer=AdvancedVI.DoWG(), + averager=AdvancedVI.PolynomialAveraging(), + operator=AdvancedVI.ProximalLocationScaleEntropy(), + adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, +) + q_avg_trans, _, stats, _ = AdvancedVI.optimize( + make_logdensity(model), + objective, + q, + n_iterations; + show_progress=show_progress, + adtype, + optimizer, + averager, + operator, + ) + if show_progress + lineplot([stat.elbo for stat in stats], ylabel="Objective", xlabel="Iteration") |> display + end + return q_avg_trans +end end diff --git a/src/variational/advi.jl b/src/variational/advi.jl deleted file mode 100644 index ec3e6552e3..0000000000 --- a/src/variational/advi.jl +++ /dev/null @@ -1,140 +0,0 @@ -# TODO: Move to Bijectors.jl if we find further use for this. -""" - wrap_in_vec_reshape(f, in_size) - -Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces -a vector of length `prod(Bijectors.output(f, in_size))`. -""" -function wrap_in_vec_reshape(f, in_size) - vec_in_length = prod(in_size) - reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) - out_size = Bijectors.output_size(f, in_size) - vec_out_length = prod(out_size) - reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) - return reshape_outer ∘ f ∘ reshape_inner -end - -""" - bijector(model::Model[, sym2ranges = Val(false)]) - -Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d` -denoting the dimensionality of the latent variables. -""" -function Bijectors.bijector( - model::DynamicPPL.Model, ::Val{sym2ranges}=Val(false); varinfo=DynamicPPL.VarInfo(model) -) where {sym2ranges} - num_params = sum([ - size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata) - ]) - - dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) 
- - num_ranges = sum([ - length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) - ]) - ranges = Vector{UnitRange{Int}}(undef, num_ranges) - idx = 0 - range_idx = 1 - - # ranges might be discontinuous => values are vectors of ranges rather than just ranges - sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() - for sym in keys(varinfo.metadata) - sym_lookup[sym] = Vector{UnitRange{Int}}() - for r in varinfo.metadata[sym].ranges - ranges[range_idx] = idx .+ r - push!(sym_lookup[sym], ranges[range_idx]) - range_idx += 1 - end - - idx += varinfo.metadata[sym].ranges[end][end] - end - - bs = map(tuple(dists...)) do d - b = Bijectors.bijector(d) - if d isa Distributions.UnivariateDistribution - b - else - wrap_in_vec_reshape(b, size(d)) - end - end - - if sym2ranges - return ( - Bijectors.Stacked(bs, ranges), - (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), - ) - else - return Bijectors.Stacked(bs, ranges) - end -end - -""" - meanfield([rng, ]model::Model) - -Creates a mean-field approximation with multivariate normal as underlying distribution. -""" -meanfield(model::DynamicPPL.Model) = meanfield(Random.default_rng(), model) -function meanfield(rng::Random.AbstractRNG, model::DynamicPPL.Model) - # Setup. - varinfo = DynamicPPL.VarInfo(model) - # Use linked `varinfo` to determine the correct number of parameters. - # TODO: Replace with `length` once this is implemented for `VarInfo`. - varinfo_linked = DynamicPPL.link(varinfo, model) - num_params = length(varinfo_linked[:]) - - # initial params - μ = randn(rng, num_params) - σ = StatsFuns.softplus.(randn(rng, num_params)) - - # Construct the base family. - d = DistributionsAD.TuringDiagMvNormal(μ, σ) - - # Construct the bijector constrained → unconstrained. - b = Bijectors.bijector(model; varinfo=varinfo) - - # We want to transform from unconstrained space to constrained, - # hence we need the inverse of `b`. - return Bijectors.transformed(d, Bijectors.inverse(b)) -end - -# Overloading stuff from `AdvancedVI` to specialize for Turing -function AdvancedVI.update(d::DistributionsAD.TuringDiagMvNormal, μ, σ) - return DistributionsAD.TuringDiagMvNormal(μ, σ) -end -function AdvancedVI.update(td::Bijectors.TransformedDistribution, θ...) - return Bijectors.transformed(AdvancedVI.update(td.dist, θ...), td.transform) -end -function AdvancedVI.update( - td::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal}, - θ::AbstractArray, -) - # `length(td.dist) != length(td)` if `td.transform` changes the dimensionality, - # so we need to use the length of the underlying distribution `td.dist` here. - # TODO: Check if we can get away with `view` instead of `getindex` for all AD backends. 
- μ, ω = θ[begin:(begin + length(td.dist) - 1)], θ[(begin + length(td.dist)):end] - return AdvancedVI.update(td, μ, StatsFuns.softplus.(ω)) -end - -function AdvancedVI.vi( - model::DynamicPPL.Model, alg::AdvancedVI.ADVI; optimizer=AdvancedVI.TruncatedADAGrad() -) - q = meanfield(model) - return AdvancedVI.vi(model, alg, q; optimizer=optimizer) -end - -function AdvancedVI.vi( - model::DynamicPPL.Model, - alg::AdvancedVI.ADVI, - q::Bijectors.TransformedDistribution{<:DistributionsAD.TuringDiagMvNormal}; - optimizer=AdvancedVI.TruncatedADAGrad(), -) - # Initial parameters for mean-field approx - μ, σs = StatsBase.params(q) - θ = vcat(μ, StatsFuns.invsoftplus.(σs)) - - # Optimize - AdvancedVI.optimize!(elbo, alg, q, make_logjoint(model), θ; optimizer=optimizer) - - # Return updated `Distribution` - return AdvancedVI.update(q, θ) -end diff --git a/src/variational/bijectors.jl b/src/variational/bijectors.jl new file mode 100644 index 0000000000..e0633493f6 --- /dev/null +++ b/src/variational/bijectors.jl @@ -0,0 +1,70 @@ + +# TODO: Move to Bijectors.jl if we find further use for this. +""" + wrap_in_vec_reshape(f, in_size) + +Wraps a bijector `f` such that it operates on vectors of length `prod(in_size)` and produces +a vector of length `prod(Bijectors.output(f, in_size))`. +""" +function wrap_in_vec_reshape(f, in_size) + vec_in_length = prod(in_size) + reshape_inner = Bijectors.Reshape((vec_in_length,), in_size) + out_size = Bijectors.output_size(f, in_size) + vec_out_length = prod(out_size) + reshape_outer = Bijectors.Reshape(out_size, (vec_out_length,)) + return reshape_outer ∘ f ∘ reshape_inner +end + +""" + bijector(model::Model[, sym2ranges = Val(false)]) + +Returns a `Stacked <: Bijector` which maps from the support of the posterior to ℝᵈ with `d` +denoting the dimensionality of the latent variables. +""" +function Bijectors.bijector( + model::DynamicPPL.Model, ::Val{sym2ranges}=Val(false); varinfo=DynamicPPL.VarInfo(model) +) where {sym2ranges} + num_params = sum([ + size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata) + ]) + + dists = vcat([varinfo.metadata[sym].dists for sym in keys(varinfo.metadata)]...) 
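+    # `dists` holds every variable's distribution, concatenated in `metadata`
+    # order; one bijector per distribution is constructed from it further below.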
+ + num_ranges = sum([ + length(varinfo.metadata[sym].ranges) for sym in keys(varinfo.metadata) + ]) + ranges = Vector{UnitRange{Int}}(undef, num_ranges) + idx = 0 + range_idx = 1 + + # ranges might be discontinuous => values are vectors of ranges rather than just ranges + sym_lookup = Dict{Symbol,Vector{UnitRange{Int}}}() + for sym in keys(varinfo.metadata) + sym_lookup[sym] = Vector{UnitRange{Int}}() + for r in varinfo.metadata[sym].ranges + ranges[range_idx] = idx .+ r + push!(sym_lookup[sym], ranges[range_idx]) + range_idx += 1 + end + + idx += varinfo.metadata[sym].ranges[end][end] + end + + bs = map(tuple(dists...)) do d + b = Bijectors.bijector(d) + if d isa Distributions.UnivariateDistribution + b + else + wrap_in_vec_reshape(b, size(d)) + end + end + + if sym2ranges + return ( + Bijectors.Stacked(bs, ranges), + (; collect(zip(keys(sym_lookup), values(sym_lookup)))...), + ) + else + return Bijectors.Stacked(bs, ranges) + end +end From a94269d5ec189826ebc89bb025520e35b1329198 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:21:33 -0400 Subject: [PATCH 02/23] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index db95093508..0d773a8ffc 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -33,9 +33,9 @@ function initialize_gaussian_scale( model::DynamicPPL.Model, location::AbstractVector, scale::AbstractMatrix; - num_samples::Int = 10, - num_max_trials::Int = 10, - reduce_factor = one(eltype(scale))/2 + num_samples::Int=10, + num_max_trials::Int=10, + reduce_factor=one(eltype(scale)) / 2, ) prob = make_logdensity(model) ℓπ = Base.Fix1(LogDensityProblems.logdensity, prob) From a4711a9a1e4493ee10be811a3ca85f48bcbcb58e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:21:44 -0400 Subject: [PATCH 03/23] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 0d773a8ffc..8d5084626a 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -93,9 +93,9 @@ end function meanfield_gaussian( model::DynamicPPL.Model, - location::Union{Nothing, <:AbstractVector} = nothing, - scale::Union{Nothing, <:Diagonal} = nothing; - kwargs... + location::Union{Nothing,<:AbstractVector}=nothing, + scale::Union{Nothing,<:Diagonal}=nothing; + kwargs..., ) meanfield_gaussian(Random.default_rng(), model, location, scale; kwargs...) 
end From 3f8068be6dfd1c0c782dba363c8c2898d152a083 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:21:50 -0400 Subject: [PATCH 04/23] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 8d5084626a..b9c112b3b5 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -54,7 +54,7 @@ function initialize_gaussian_scale( error("Could not find an initial") end - scale = reduce_factor*scale + scale = reduce_factor * scale n_trial += 1 end end From 222a638d44ff639895d348817c15cecc56fcc176 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:21:55 -0400 Subject: [PATCH 05/23] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index b9c112b3b5..992ca09607 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -97,7 +97,7 @@ function meanfield_gaussian( scale::Union{Nothing,<:Diagonal}=nothing; kwargs..., ) - meanfield_gaussian(Random.default_rng(), model, location, scale; kwargs...) + return meanfield_gaussian(Random.default_rng(), model, location, scale; kwargs...) end function fullrank_gaussian( From 57097f5c6c372b6dc7ada063f2a4ad0d6e27750d Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:22:00 -0400 Subject: [PATCH 06/23] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 992ca09607..9aad42e438 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -62,9 +62,9 @@ end function meanfield_gaussian( rng::Random.AbstractRNG, model::DynamicPPL.Model, - location::Union{Nothing, <:AbstractVector} = nothing, - scale::Union{Nothing, <:Diagonal} = nothing; - kwargs... + location::Union{Nothing,<:AbstractVector}=nothing, + scale::Union{Nothing,<:Diagonal}=nothing; + kwargs..., ) varinfo = DynamicPPL.VarInfo(model) # Use linked `varinfo` to determine the correct number of parameters. From a42eea8e4f6f197265061430799c60ef8c6f04ec Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:22:06 -0400 Subject: [PATCH 07/23] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 9aad42e438..cea01500cd 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -135,9 +135,9 @@ end function fullrank_gaussian( model::DynamicPPL.Model, - location::Union{Nothing, <:AbstractVector} = nothing, - scale::Union{Nothing, <:Diagonal} = nothing; - kwargs... 
+ location::Union{Nothing,<:AbstractVector}=nothing, + scale::Union{Nothing,<:Diagonal}=nothing; + kwargs..., ) fullrank_gaussian(Random.default_rng(), model, location, scale; kwargs...) end From 798f3198f681f23e3bd7605233ddf64508571e75 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:22:12 -0400 Subject: [PATCH 08/23] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index cea01500cd..0a676cf3ff 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -146,7 +146,7 @@ function vi( model::DynamicPPL.Model, q::Bijectors.TransformedDistribution, n_iterations::Int; - objective=RepGradELBO(10, entropy=AdvancedVI.ClosedFormEntropyZeroGradient()), + objective=RepGradELBO(10; entropy=AdvancedVI.ClosedFormEntropyZeroGradient()), show_progress::Bool=PROGRESS[], optimizer=AdvancedVI.DoWG(), averager=AdvancedVI.PolynomialAveraging(), From 69a49720d2c9b4f37504cd458e298bd6cd09be14 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:22:20 -0400 Subject: [PATCH 09/23] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 0a676cf3ff..63ab1540bc 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -139,7 +139,7 @@ function fullrank_gaussian( scale::Union{Nothing,<:Diagonal}=nothing; kwargs..., ) - fullrank_gaussian(Random.default_rng(), model, location, scale; kwargs...) + return fullrank_gaussian(Random.default_rng(), model, location, scale; kwargs...) 
end function vi( From cbcb8b5744604aad363c6d6ad8cf2d86f3a3bb2e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:22:26 -0400 Subject: [PATCH 10/23] run formatter Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 63ab1540bc..cbf4aab779 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -151,7 +151,7 @@ function vi( optimizer=AdvancedVI.DoWG(), averager=AdvancedVI.PolynomialAveraging(), operator=AdvancedVI.ProximalLocationScaleEntropy(), - adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, + adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, ) q_avg_trans, _, stats, _ = AdvancedVI.optimize( make_logdensity(model), From 081d6ff497daf6044fb93818a8c369d8dea519cb Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:22:48 -0400 Subject: [PATCH 11/23] remove plotting --- Project.toml | 2 -- src/variational/VariationalInference.jl | 8 +++----- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/Project.toml b/Project.toml index 6fa6daa2a1..ae940ca541 100644 --- a/Project.toml +++ b/Project.toml @@ -38,7 +38,6 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" StatsAPI = "82ae8749-77ed-4fe6-ae5f-f523153014b0" StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91" StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c" -UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228" [weakdeps] DynamicHMC = "bbc10e6e-7c05-544b-b16e-64fede858acb" @@ -86,7 +85,6 @@ Statistics = "1.6" StatsAPI = "1.6" StatsBase = "0.32, 0.33, 0.34" StatsFuns = "0.8, 0.9, 1" -UnicodePlots = "3" julia = "1.10" [extras] diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index db95093508..d532444256 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -152,8 +152,9 @@ function vi( averager=AdvancedVI.PolynomialAveraging(), operator=AdvancedVI.ProximalLocationScaleEntropy(), adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, + kwargs... ) - q_avg_trans, _, stats, _ = AdvancedVI.optimize( + return AdvancedVI.optimize( make_logdensity(model), objective, q, @@ -163,11 +164,8 @@ function vi( optimizer, averager, operator, + kwargs... ) - if show_progress - lineplot([stat.elbo for stat in stats], ylabel="Objective", xlabel="Iteration") |> display - end - return q_avg_trans end end From 1bcec3e18e15d0f8ff45fc72008b32060b1e7435 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:29:52 -0400 Subject: [PATCH 12/23] fix formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index f0b98bb314..4c1a0c289c 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -103,9 +103,9 @@ end function fullrank_gaussian( rng::Random.AbstractRNG, model::DynamicPPL.Model, - location::Union{Nothing, <:AbstractVector} = nothing, - scale::Union{Nothing, <:LowerTriangular} = nothing; - kwargs... 
+ location::Union{Nothing,<:AbstractVector}=nothing, + scale::Union{Nothing,<:LowerTriangular}=nothing; + kwargs..., ) varinfo = DynamicPPL.VarInfo(model) # Use linked `varinfo` to determine the correct number of parameters. From b142832c4d3cc0762bf2277d1d1905c0ac2a1c1a Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:30:04 -0400 Subject: [PATCH 13/23] fix formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 4c1a0c289c..c4d7fe4b04 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -152,7 +152,7 @@ function vi( averager=AdvancedVI.PolynomialAveraging(), operator=AdvancedVI.ProximalLocationScaleEntropy(), adtype::ADTypes.AbstractADType=DEFAULT_ADTYPE, - kwargs... + kwargs..., ) return AdvancedVI.optimize( make_logdensity(model), From 061ec35b66b8fed680bd67e17a3960b7f71166f9 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:30:10 -0400 Subject: [PATCH 14/23] fix formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/VariationalInference.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index c4d7fe4b04..5810239378 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -164,7 +164,7 @@ function vi( optimizer, averager, operator, - kwargs... + kwargs..., ) end From 736bd3e4bef5e2b08d09727293896fedff690b8e Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Fri, 14 Mar 2025 19:32:14 -0400 Subject: [PATCH 15/23] remove unused dependency --- src/variational/VariationalInference.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index f0b98bb314..022622e6f8 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -7,7 +7,6 @@ using Distributions using LinearAlgebra using LogDensityProblems using Random -using UnicodePlots import ..Turing: DEFAULT_ADTYPE, PROGRESS From 297c32a97bfbd118a5148d31769a365a9289046f Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Thu, 20 Mar 2025 21:24:52 +0000 Subject: [PATCH 16/23] Update Project.toml --- test/Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/Project.toml b/test/Project.toml index 36b7ebdec9..f1d829ae1f 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -43,7 +43,7 @@ AbstractMCMC = "5" AbstractPPL = "0.9, 0.10" AdvancedMH = "0.6, 0.7, 0.8" AdvancedPS = "=0.6.0" -AdvancedVI = "0.2" +AdvancedVI = "0.3" Aqua = "0.8" BangBang = "0.4" Bijectors = "0.14, 0.15" From 0c0443402853163e7ae1512c88111af09029fb61 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 25 Mar 2025 16:08:59 -0400 Subject: [PATCH 17/23] fix make some arugments of vi initializer to be optional kwargs --- src/variational/VariationalInference.jl | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index 5847f74afd..e449237212 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -60,9 
+60,9 @@ end function meanfield_gaussian( rng::Random.AbstractRNG, - model::DynamicPPL.Model, + model::DynamicPPL.Model; location::Union{Nothing,<:AbstractVector}=nothing, - scale::Union{Nothing,<:Diagonal}=nothing; + scale::Union{Nothing,<:Diagonal}=nothing, kwargs..., ) varinfo = DynamicPPL.VarInfo(model) @@ -91,19 +91,19 @@ function meanfield_gaussian( end function meanfield_gaussian( - model::DynamicPPL.Model, + model::DynamicPPL.Model; location::Union{Nothing,<:AbstractVector}=nothing, - scale::Union{Nothing,<:Diagonal}=nothing; + scale::Union{Nothing,<:Diagonal}=nothing, kwargs..., ) - return meanfield_gaussian(Random.default_rng(), model, location, scale; kwargs...) + return meanfield_gaussian(Random.default_rng(), model; location, scale, kwargs...) end function fullrank_gaussian( rng::Random.AbstractRNG, - model::DynamicPPL.Model, + model::DynamicPPL.Model; location::Union{Nothing,<:AbstractVector}=nothing, - scale::Union{Nothing,<:LowerTriangular}=nothing; + scale::Union{Nothing,<:LowerTriangular}=nothing, kwargs..., ) varinfo = DynamicPPL.VarInfo(model) @@ -133,12 +133,12 @@ function fullrank_gaussian( end function fullrank_gaussian( - model::DynamicPPL.Model, + model::DynamicPPL.Model; location::Union{Nothing,<:AbstractVector}=nothing, - scale::Union{Nothing,<:Diagonal}=nothing; + scale::Union{Nothing,<:LowerTriangular}=nothing, kwargs..., ) - return fullrank_gaussian(Random.default_rng(), model, location, scale; kwargs...) + return fullrank_gaussian(Random.default_rng(), model; location, scale, kwargs...) end function vi( From 626c5b5f0ae2927b1dc29f64af50b8e967e8cf9c Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 25 Mar 2025 16:26:41 -0400 Subject: [PATCH 18/23] remove tests for custom optimizers --- test/runtests.jl | 4 ---- 1 file changed, 4 deletions(-) diff --git a/test/runtests.jl b/test/runtests.jl index 47b714188e..75ad71d90b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -71,10 +71,6 @@ end end end - @testset "variational optimisers" begin - @timeit_include("variational/optimisers.jl") - end - @testset "stdlib" verbose = true begin @timeit_include("stdlib/distributions.jl") @timeit_include("stdlib/RandomMeasures.jl") From cb2c6181ada758fbd3bc364ab21f6262d3765212 Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 25 Mar 2025 16:27:36 -0400 Subject: [PATCH 19/23] remove unused file --- test/variational/optimisers.jl | 29 ----------------------------- 1 file changed, 29 deletions(-) delete mode 100644 test/variational/optimisers.jl diff --git a/test/variational/optimisers.jl b/test/variational/optimisers.jl deleted file mode 100644 index 6f64d5fb1f..0000000000 --- a/test/variational/optimisers.jl +++ /dev/null @@ -1,29 +0,0 @@ -module VariationalOptimisersTests - -using AdvancedVI: DecayedADAGrad, TruncatedADAGrad, apply! -import ForwardDiff -import ReverseDiff -using Test: @test, @testset -using Turing - -function test_opt(ADPack, opt) - θ = randn(10, 10) - θ_fit = randn(10, 10) - loss(x, θ_) = mean(sum(abs2, θ * x - θ_ * x; dims=1)) - for t in 1:(10^4) - x = rand(10) - Δ = ADPack.gradient(θ_ -> loss(x, θ_), θ_fit) - Δ = apply!(opt, θ_fit, Δ) - @. 
θ_fit = θ_fit - Δ - end - @test loss(rand(10, 100), θ_fit) < 0.01 - @test length(opt.acc) == 1 -end -for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)] - test_opt(ForwardDiff, opt) -end -for opt in [TruncatedADAGrad(), DecayedADAGrad(1e-2)] - test_opt(ReverseDiff, opt) -end - -end From c1533a863db45c251abc281a8fe53dc922973839 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Fri, 18 Apr 2025 21:10:02 +0100 Subject: [PATCH 20/23] Update src/variational/bijectors.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/variational/bijectors.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/variational/bijectors.jl b/src/variational/bijectors.jl index e0633493f6..86078efaa4 100644 --- a/src/variational/bijectors.jl +++ b/src/variational/bijectors.jl @@ -22,7 +22,9 @@ Returns a `Stacked <: Bijector` which maps from the support of the posterior to denoting the dimensionality of the latent variables. """ function Bijectors.bijector( - model::DynamicPPL.Model, ::Val{sym2ranges}=Val(false); varinfo=DynamicPPL.VarInfo(model) + model::DynamicPPL.Model, + (::Val{sym2ranges})=Val(false); + varinfo=DynamicPPL.VarInfo(model), ) where {sym2ranges} num_params = sum([ size(varinfo.metadata[sym].vals, 1) for sym in keys(varinfo.metadata) From 231d6e2d72aef59cf51da7810a89ddc2e3588b37 Mon Sep 17 00:00:00 2001 From: Hong Ge <3279477+yebai@users.noreply.github.com> Date: Mon, 21 Apr 2025 11:57:44 +0100 Subject: [PATCH 21/23] Update Turing.jl --- src/Turing.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Turing.jl b/src/Turing.jl index aa5fbe8500..4bd3058906 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -39,8 +39,6 @@ function setprogress!(progress::Bool) @info "[Turing]: progress logging is $(progress ? "enabled" : "disabled") globally" PROGRESS[] = progress AbstractMCMC.setprogress!(progress; silent=true) - # TODO: `AdvancedVI.turnprogress` is removed in AdvancedVI v0.3 - AdvancedVI.turnprogress(progress) return progress end From 69639ec4e3f12333cfecffc28ceea3369174df1b Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Apr 2025 11:55:22 -0400 Subject: [PATCH 22/23] fix remove call to `AdvancedVI.turnprogress`, which has been removed --- src/Turing.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/Turing.jl b/src/Turing.jl index aa5fbe8500..4bd3058906 100644 --- a/src/Turing.jl +++ b/src/Turing.jl @@ -39,8 +39,6 @@ function setprogress!(progress::Bool) @info "[Turing]: progress logging is $(progress ? 
"enabled" : "disabled") globally" PROGRESS[] = progress AbstractMCMC.setprogress!(progress; silent=true) - # TODO: `AdvancedVI.turnprogress` is removed in AdvancedVI v0.3 - AdvancedVI.turnprogress(progress) return progress end From ef9aeb1cc59396092a68e8e8082a8b5c60203f8f Mon Sep 17 00:00:00 2001 From: Kyurae Kim Date: Tue, 29 Apr 2025 12:14:30 -0400 Subject: [PATCH 23/23] apply comments from @yebai --- src/variational/VariationalInference.jl | 67 +++++++++---------------- 1 file changed, 23 insertions(+), 44 deletions(-) diff --git a/src/variational/VariationalInference.jl b/src/variational/VariationalInference.jl index e449237212..010b34a456 100644 --- a/src/variational/VariationalInference.jl +++ b/src/variational/VariationalInference.jl @@ -15,9 +15,9 @@ import Bijectors # Reexports using AdvancedVI: RepGradELBO, ScoreGradELBO, DoG, DoWG -export vi, RepGradELBO, ScoreGradELBO, DoG, DoWG +export RepGradELBO, ScoreGradELBO, DoG, DoWG -export meanfield_gaussian, fullrank_gaussian +export vi, q_init, q_meanfield_gaussian, q_fullrank_gaussian include("bijectors.jl") @@ -58,11 +58,13 @@ function initialize_gaussian_scale( end end -function meanfield_gaussian( +function q_init( rng::Random.AbstractRNG, model::DynamicPPL.Model; location::Union{Nothing,<:AbstractVector}=nothing, - scale::Union{Nothing,<:Diagonal}=nothing, + scale::Union{Nothing,<:Diagonal,<:LowerTriangular}=nothing, + meanfield::Bool=true, + basedist::Distributions.UnivariateDistribution=Normal(), kwargs..., ) varinfo = DynamicPPL.VarInfo(model) @@ -79,66 +81,43 @@ function meanfield_gaussian( end L = if isnothing(scale) - initialize_gaussian_scale(rng, model, μ, Diagonal(ones(num_params)); kwargs...) + if meanfield + initialize_gaussian_scale(rng, model, μ, Diagonal(ones(num_params)); kwargs...) + else + L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params)) + initialize_gaussian_scale(rng, model, μ, L0; kwargs...) + end else @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)." - L = scale + if meanfield + Diagonal(diag(scale)) + else + scale + end end - - q = AdvancedVI.MeanFieldGaussian(μ, L) + q = AdvancedVI.MvLocationScale(μ, L, basedist) b = Bijectors.bijector(model; varinfo=varinfo) return Bijectors.transformed(q, Bijectors.inverse(b)) end -function meanfield_gaussian( +function q_meanfield_gaussian( + rng::Random.AbstractRNG, model::DynamicPPL.Model; location::Union{Nothing,<:AbstractVector}=nothing, scale::Union{Nothing,<:Diagonal}=nothing, kwargs..., ) - return meanfield_gaussian(Random.default_rng(), model; location, scale, kwargs...) + return q_init(rng, model; location, scale, meanfield=true, basedist=Normal(), kwargs...) end -function fullrank_gaussian( +function q_fullrank_gaussian( rng::Random.AbstractRNG, model::DynamicPPL.Model; location::Union{Nothing,<:AbstractVector}=nothing, scale::Union{Nothing,<:LowerTriangular}=nothing, kwargs..., ) - varinfo = DynamicPPL.VarInfo(model) - # Use linked `varinfo` to determine the correct number of parameters. - # TODO: Replace with `length` once this is implemented for `VarInfo`. - varinfo_linked = DynamicPPL.link(varinfo, model) - num_params = length(varinfo_linked[:]) - - μ = if isnothing(location) - zeros(num_params) - else - @assert length(location) == num_params "Length of the provided location vector, $(length(location)), does not match dimension of the target distribution, $(num_params)." 
- location - end - - L = if isnothing(scale) - L0 = LowerTriangular(Matrix{Float64}(I, num_params, num_params)) - initialize_gaussian_scale(rng, model, μ, L0; kwargs...) - else - @assert size(scale) == (num_params, num_params) "Dimensions of the provided scale matrix, $(size(scale)), does not match the dimension of the target distribution, $(num_params)." - scale - end - - q = AdvancedVI.FullRankGaussian(μ, L) - b = Bijectors.bijector(model; varinfo=varinfo) - return Bijectors.transformed(q, Bijectors.inverse(b)) -end - -function fullrank_gaussian( - model::DynamicPPL.Model; - location::Union{Nothing,<:AbstractVector}=nothing, - scale::Union{Nothing,<:LowerTriangular}=nothing, - kwargs..., -) - return fullrank_gaussian(Random.default_rng(), model; location, scale, kwargs...) + return q_init(rng, model; location, scale, meanfield=false, basedist=Normal(), kwargs...) end function vi(