Skip to content

Commit 9f63c25

Browse files
authored
fix as (#348)
* fix `as` * add test * bump version
1 parent 80dfca9 commit 9f63c25

File tree

3 files changed

+20
-11
lines changed

3 files changed

+20
-11
lines changed

Project.toml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Soss"
22
uuid = "8ce77f84-9b61-11e8-39ff-d17a774bf41c"
33
author = ["Chad Scherrer <chad.scherrer@gmail.com>"]
4-
version = "0.21.1"
4+
version = "0.21.2"
55

66
[deps]
77
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -52,7 +52,7 @@ JuliaVariables = "0.2"
5252
MLStyle = "0.3,0.4"
5353
MacroTools = "0.5"
5454
MappedArrays = "0.3, 0.4"
55-
MeasureBase = "0.9"
55+
MeasureBase = "0.10, 0.11, 0.12"
5656
MeasureTheory = "0.16"
5757
NamedTupleTools = "0.12, 0.13, 0.14"
5858
NestedTuples = "0.3.9"
@@ -61,16 +61,16 @@ Reexport = "1"
6161
Requires = "1"
6262
RuntimeGeneratedFunctions = "0.5"
6363
SampleChains = "0.5"
64-
SimpleGraphs = "0.5, 0.6, 0.7"
64+
SimpleGraphs = "= 0.7.18"
6565
SimplePartitions = "0.2, 0.3"
66-
SimplePosets = "0.1"
66+
SimplePosets = "= 0.1.5"
6767
SpecialFunctions = "1, 2"
6868
Static = "0.5, 0.6"
6969
StatsBase = "0.33"
7070
StatsFuns = "0.9, 1"
7171
SymbolicUtils = "0.17, 0.18, 0.19"
7272
TransformVariables = "0.5, 0.6"
73-
TupleVectors = "0.1"
73+
TupleVectors = "0.1, 0.2"
7474
julia = "1.6"
7575

7676
[extras]

src/primitives/as.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,11 @@ function sourceXform(_data=NamedTuple())
4444
rhs = st.rhs
4545

4646
thecode = @q begin
47-
_t = Soss.as($rhs, get(_data, $xname, NamedTuple()))
48-
if !isnothing(_t)
49-
_result = merge(_result, ($x=_t,))
47+
_d = get(_data, $xname, nothing)
48+
if isnothing(_d) # xname is not defined in _data
49+
_result = merge(_result, ($x = Soss.as($rhs),))
50+
elseif _d isa NamedTuple
51+
_result = merge(_result, ($x = Soss.as($rhs, _d),))
5052
end
5153
end
5254

@@ -90,11 +92,9 @@ function asTransform(supp:: Dists.RealInterval)
9092
return ScaledShiftedLogistic(ub-lb, lb)
9193
end
9294

93-
as(d, _data) = nothing
94-
9595
as::AbstractMeasure, _data::NamedTuple) = as(μ)
9696

97-
as(d::Dists.AbstractMvNormal, _data::NamedTuple=NamedTuple()) = as(Array, size(d))
97+
as(d::Dists.AbstractMvNormal, _data::NamedTuple = NamedTuple()) = TV.as(Array, size(d))
9898

9999
@gg function _as(M::Type{<:TypeLevel}, _m::Model{Asub,B}, _args::A, _data) where {Asub,A,B}
100100
body = type2model(_m) |> sourceXform(_data) |> loadvals(_args, _data)

test/runtests.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Soss
22
using Test
3+
using LinearAlgebra
34
using MeasureTheory
45
import TransformVariables as TV
56
using TransformVariables: transform
@@ -194,4 +195,12 @@ include("examples-list.jl")
194195
base = basemeasure(post)
195196
@test logdensity_def(base, (p=0.2, x=post.obs.x)) isa Real
196197
end
198+
199+
@testset "https://github.com/cscherrer/Soss.jl/issues/342" begin
200+
m = Soss.@model () begin
201+
z ~ Dists.MvNormal(zeros(10), I)
202+
end
203+
t = Soss.as(m())
204+
@test TV.transform(t, zeros(10)) == (z = zeros(10), )
205+
end
197206
end

0 commit comments

Comments
 (0)