|
| 1 | +@testset "transducer.jl" begin |
| 2 | + Random.seed!(1234) |
| 3 | + |
| 4 | + @testset "Basic sampling" begin |
| 5 | + N = 1_000 |
| 6 | + local chain |
| 7 | + Logging.with_logger(TerminalLogger()) do |
| 8 | + xf = AbstractMCMC.Sample(MyModel(), MySampler(); |
| 9 | + sleepy = true, logger = true) |
| 10 | + chain = collect(xf, withprogress(1:N; interval=1e-3)) |
| 11 | + end |
| 12 | + |
| 13 | + # test output type and size |
| 14 | + @test chain isa Vector{<:MyTransition} |
| 15 | + @test length(chain) == N |
| 16 | + |
| 17 | + # test some statistical properties |
| 18 | + tail_chain = @view chain[2:end] |
| 19 | + @test mean(x.a for x in tail_chain) ≈ 0.5 atol=6e-2 |
| 20 | + @test var(x.a for x in tail_chain) ≈ 1 / 12 atol=5e-3 |
| 21 | + @test mean(x.b for x in tail_chain) ≈ 0.0 atol=5e-2 |
| 22 | + @test var(x.b for x in tail_chain) ≈ 1 atol=6e-2 |
| 23 | + end |
| 24 | + |
| 25 | + @testset "drop" begin |
| 26 | + xf = AbstractMCMC.Sample(MyModel(), MySampler()) |
| 27 | + chain = collect(xf |> Drop(1), 1:10) |
| 28 | + @test chain isa Vector{MyTransition{Float64,Float64}} |
| 29 | + @test length(chain) == 9 |
| 30 | + end |
| 31 | + |
| 32 | + # Reproduce iterator example |
| 33 | + @testset "iterator example" begin |
| 34 | + # filter missing values and split transitions |
| 35 | + xf = AbstractMCMC.Sample(MyModel(), MySampler()) |> |
| 36 | + OfType(MyTransition{Float64,Float64}) |> Map(x -> (x.a, x.b)) |
| 37 | + as, bs = foldl(xf, 1:999; init = (Float64[], Float64[])) do (as, bs), (a, b) |
| 38 | + push!(as, a) |
| 39 | + push!(bs, b) |
| 40 | + as, bs |
| 41 | + end |
| 42 | + |
| 43 | + @test length(as) == length(bs) == 998 |
| 44 | + |
| 45 | + @test mean(as) ≈ 0.5 atol=1e-2 |
| 46 | + @test var(as) ≈ 1 / 12 atol=5e-3 |
| 47 | + @test mean(bs) ≈ 0.0 atol=5e-2 |
| 48 | + @test var(bs) ≈ 1 atol=5e-2 |
| 49 | + end |
| 50 | +end |
0 commit comments