Skip to content

Commit 509b5cf

Browse files
authored
Merge pull request #39 from TuringLang/transducer
Add a `Sample` transducer
2 parents 74bac80 + 14831b8 commit 509b5cf

File tree

5 files changed

+96
-0
lines changed

5 files changed

+96
-0
lines changed

Project.toml

+2
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
1515
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1616
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1717
TerminalLoggers = "5d786b92-1e48-4d6f-9151-6b4477ca9bed"
18+
Transducers = "28d57a85-8fef-5791-bfe6-a80928e7c999"
1819

1920
[compat]
2021
BangBang = "0.3.19"
@@ -23,6 +24,7 @@ LoggingExtras = "0.4"
2324
ProgressLogging = "0.1"
2425
StatsBase = "0.32, 0.33"
2526
TerminalLoggers = "0.1"
27+
Transducers = "0.4.30"
2628
julia = "1"
2729

2830
[extras]

src/AbstractMCMC.jl

+2
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import LoggingExtras
66
import ProgressLogging
77
import StatsBase
88
import TerminalLoggers
9+
import Transducers
910

1011
import Distributed
1112
import Logging
@@ -74,5 +75,6 @@ include("logging.jl")
7475
include("interface.jl")
7576
include("sample.jl")
7677
include("stepper.jl")
78+
include("transducer.jl")
7779

7880
end # module AbstractMCMC

src/transducer.jl

+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
struct Sample{A<:Random.AbstractRNG,M<:AbstractModel,S<:AbstractSampler,K} <: Transducers.Transducer
2+
rng::A
3+
model::M
4+
sampler::S
5+
kwargs::K
6+
end
7+
8+
function Sample(model::AbstractModel, sampler::AbstractSampler; kwargs...)
9+
return Sample(Random.GLOBAL_RNG, model, sampler; kwargs...)
10+
end
11+
12+
function Sample(
13+
rng::Random.AbstractRNG,
14+
model::AbstractModel,
15+
sampler::AbstractSampler;
16+
kwargs...
17+
)
18+
sample_init!(rng, model, sampler, 0)
19+
return Sample(rng, model, sampler, kwargs)
20+
end
21+
22+
function Transducers.start(rf::Transducers.R_{<:Sample}, result)
23+
return Transducers.wrap(rf, nothing, Transducers.start(Transducers.inner(rf), result))
24+
end
25+
26+
function Transducers.next(rf::Transducers.R_{<:Sample}, result, input)
27+
t = Transducers.xform(rf)
28+
Transducers.wrapping(rf, result) do state, iresult
29+
transition = step!(t.rng, t.model, t.sampler, 1, state; t.kwargs...)
30+
iinput = transition
31+
iresult = Transducers.next(Transducers.inner(rf), iresult, transition)
32+
return transition, iresult
33+
end
34+
end
35+
36+
function Transducers.complete(rf::Transducers.R_{Sample}, result)
37+
_private_state, inner_result = Transducers.unwrap(rf, result)
38+
return Transducers.complete(Transducers.inner(rf), inner_result)
39+
end

test/runtests.jl

+3
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ using ConsoleProgressMonitor: ProgressLogger
55
using IJulia
66
using LoggingExtras: TeeLogger, EarlyFilteredLogger
77
using TerminalLoggers: TerminalLogger
8+
using Transducers
89

910
using Distributed
1011
import Logging
@@ -276,4 +277,6 @@ include("interface.jl")
276277
MySampler(), 10, 10;
277278
chain_type = MyChain)
278279
end
280+
281+
include("transducer.jl")
279282
end

test/transducer.jl

+50
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
@testset "transducer.jl" begin
2+
Random.seed!(1234)
3+
4+
@testset "Basic sampling" begin
5+
N = 1_000
6+
local chain
7+
Logging.with_logger(TerminalLogger()) do
8+
xf = AbstractMCMC.Sample(MyModel(), MySampler();
9+
sleepy = true, logger = true)
10+
chain = collect(xf, withprogress(1:N; interval=1e-3))
11+
end
12+
13+
# test output type and size
14+
@test chain isa Vector{<:MyTransition}
15+
@test length(chain) == N
16+
17+
# test some statistical properties
18+
tail_chain = @view chain[2:end]
19+
@test mean(x.a for x in tail_chain) 0.5 atol=6e-2
20+
@test var(x.a for x in tail_chain) 1 / 12 atol=5e-3
21+
@test mean(x.b for x in tail_chain) 0.0 atol=5e-2
22+
@test var(x.b for x in tail_chain) 1 atol=6e-2
23+
end
24+
25+
@testset "drop" begin
26+
xf = AbstractMCMC.Sample(MyModel(), MySampler())
27+
chain = collect(xf |> Drop(1), 1:10)
28+
@test chain isa Vector{MyTransition{Float64,Float64}}
29+
@test length(chain) == 9
30+
end
31+
32+
# Reproduce iterator example
33+
@testset "iterator example" begin
34+
# filter missing values and split transitions
35+
xf = AbstractMCMC.Sample(MyModel(), MySampler()) |>
36+
OfType(MyTransition{Float64,Float64}) |> Map(x -> (x.a, x.b))
37+
as, bs = foldl(xf, 1:999; init = (Float64[], Float64[])) do (as, bs), (a, b)
38+
push!(as, a)
39+
push!(bs, b)
40+
as, bs
41+
end
42+
43+
@test length(as) == length(bs) == 998
44+
45+
@test mean(as) 0.5 atol=1e-2
46+
@test var(as) 1 / 12 atol=5e-3
47+
@test mean(bs) 0.0 atol=5e-2
48+
@test var(bs) 1 atol=5e-2
49+
end
50+
end

0 commit comments

Comments
 (0)