Skip to content

Commit 5d56902

Browse files
sunxd3github-actions[bot]yebai
authored
Add getparams and setparams!! following AbstractMCMC v5.5 and v5.6 (#378)
* add `getparams` and `setparams!!` * undo formatting * Update test/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add some new tests * update `setparams!!` * Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * add comment * Update abstractmcmc.jl Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com> * format * update implementation * bump AbstractMCMC * update test * Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix method ambiguity * Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/abstractmcmc.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * fix test error * fix more test error --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: Hong Ge <3279477+yebai@users.noreply.github.com>
1 parent 47b212a commit 5d56902

File tree

4 files changed

+41
-3
lines changed

4 files changed

+41
-3
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "AdvancedHMC"
22
uuid = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
3-
version = "0.6.2"
3+
version = "0.6.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -30,7 +30,7 @@ AdvancedHMCMCMCChainsExt = "MCMCChains"
3030
AdvancedHMCOrdinaryDiffEqExt = "OrdinaryDiffEq"
3131

3232
[compat]
33-
AbstractMCMC = "5"
33+
AbstractMCMC = "5.6"
3434
ArgCheck = "1, 2"
3535
CUDA = "3, 4, 5"
3636
DocStringExtensions = "0.8, 0.9"

research/tests/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,6 @@ include("../src/riemannian_hmc.jl")
1111
include("relativistic_hmc.jl")
1212
include("riemannian_hmc.jl")
1313

14-
@main function runtests(patterns...; dry::Bool = false)
14+
Comonicon.@main function runtests(patterns...; dry::Bool = false)
1515
retest(patterns...; dry = dry, verbose = Inf)
1616
end

src/abstractmcmc.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,24 @@ getadaptor(state::HMCState) = state.adaptor
3030
getmetric(state::HMCState) = state.metric
3131
getintegrator(state::HMCState) = state.κ.τ.integrator
3232

33+
function AbstractMCMC.getparams(state::HMCState)
34+
return state.transition.z.θ
35+
end
36+
37+
function AbstractMCMC.setparams!!(
38+
model::AbstractMCMC.LogDensityModel,
39+
state::HMCState,
40+
params,
41+
)
42+
hamiltonian = AdvancedHMC.Hamiltonian(state.metric, model)
43+
return Setfield.@set state.transition.z = AdvancedHMC.phasepoint(
44+
hamiltonian,
45+
params,
46+
state.transition.z.r;
47+
ℓκ = state.transition.z.ℓκ,
48+
)
49+
end
50+
3351
"""
3452
$(TYPEDSIGNATURES)
3553

test/abstractmcmc.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,26 @@ using Statistics: mean
2121
LogDensityProblemsAD.ADgradient(Val(:ForwardDiff), ℓπ_gdemo),
2222
)
2323

24+
@testset "getparams and setparams!!" begin
25+
t, s = AbstractMCMC.step(rng, model, nuts;)
26+
27+
θ = AbstractMCMC.getparams(s)
28+
@test θ == t.z.θ
29+
new_state = AbstractMCMC.setparams!!(model, s, θ)
30+
@test new_state.transition.z.θ == θ
31+
new_state_logπ = new_state.transition.z.ℓπ
32+
@test new_state_logπ.value == s.transition.z.ℓπ.value
33+
@test new_state_logπ.gradient == s.transition.z.ℓπ.gradient
34+
new_state_logκ = new_state.transition.z.ℓκ
35+
@test new_state_logκ.value == s.transition.z.ℓκ.value
36+
@test new_state_logκ.gradient == s.transition.z.ℓκ.gradient
37+
@test new_state.transition.z.r == s.transition.z.r
38+
39+
new_θ = randn(rng, 2)
40+
new_state = AbstractMCMC.setparams!!(model, s, new_θ)
41+
@test AbstractMCMC.getparams(new_state) == new_θ
42+
end
43+
2444
samples_nuts = AbstractMCMC.sample(
2545
rng,
2646
model,

0 commit comments

Comments
 (0)