Skip to content

Commit 52daca1

Browse files
fix: fix namespacing of AffectSystem
1 parent 77f85ea commit 52daca1

File tree

2 files changed

+40
-24
lines changed

2 files changed

+40
-24
lines changed

src/systems/callbacks.jl

Lines changed: 24 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -13,15 +13,12 @@ struct AffectSystem
1313
parameters::Vector
1414
"""Parameters of the parent ODESystem whose values are modified by the affect."""
1515
discretes::Vector
16-
"""Maps the symbols of unknowns/observed in the ImplicitDiscreteSystem to its corresponding unknown/parameter in the parent system."""
17-
aff_to_sys::Dict
1816
end
1917

2018
system(a::AffectSystem) = a.system
2119
discretes(a::AffectSystem) = a.discretes
2220
unknowns(a::AffectSystem) = a.unknowns
2321
parameters(a::AffectSystem) = a.parameters
24-
aff_to_sys(a::AffectSystem) = a.aff_to_sys
2522
all_equations(a::AffectSystem) = vcat(equations(system(a)), observed(system(a)))
2623

2724
function Base.show(iio::IO, aff::AffectSystem)
@@ -34,16 +31,14 @@ function Base.:(==)(a1::AffectSystem, a2::AffectSystem)
3431
isequal(system(a1), system(a2)) &&
3532
isequal(discretes(a1), discretes(a2)) &&
3633
isequal(unknowns(a1), unknowns(a2)) &&
37-
isequal(parameters(a1), parameters(a2)) &&
38-
isequal(aff_to_sys(a1), aff_to_sys(a2))
34+
isequal(parameters(a1), parameters(a2))
3935
end
4036

4137
function Base.hash(a::AffectSystem, s::UInt)
4238
s = hash(system(a), s)
4339
s = hash(unknowns(a), s)
4440
s = hash(parameters(a), s)
45-
s = hash(discretes(a), s)
46-
hash(aff_to_sys(a), s)
41+
hash(discretes(a), s)
4742
end
4843

4944
function vars!(vars, aff::AffectSystem; op = Differential)
@@ -251,14 +246,12 @@ function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
251246
for eq in alg_eqs
252247
collect_vars!(dvs, params, eq, iv)
253248
end
254-
255249
pre_params = filter(haspre value, params)
256250
sys_params = collect(setdiff(params, union(discrete_parameters, pre_params)))
257251
discretes = map(tovar, discrete_parameters)
258252
dvs = collect(dvs)
259253
_dvs = map(default_toterm, dvs)
260254

261-
aff_map = Dict(zip(discretes, discrete_parameters))
262255
rev_map = Dict(zip(discrete_parameters, discretes))
263256
subs = merge(rev_map, Dict(zip(dvs, _dvs)))
264257
affect = Symbolics.fast_substitute(affect, subs)
@@ -269,17 +262,14 @@ function make_affect(affect::Vector{Equation}; discrete_parameters = Any[],
269262
collect(union(pre_params, sys_params)))
270263
affectsys = mtkcompile(affectsys; fully_determined = nothing)
271264
# get accessed parameters p from Pre(p) in the callback parameters
272-
accessed_params = filter(isparameter, map(unPre, collect(pre_params)))
265+
accessed_params = Vector{Any}(filter(isparameter, map(unPre, collect(pre_params))))
273266
union!(accessed_params, sys_params)
274267

275268
# add scalarized unknowns to the map.
276269
_dvs = reduce(vcat, map(scalarize, _dvs), init = Any[])
277-
for u in _dvs
278-
aff_map[u] = u
279-
end
280270

281271
AffectSystem(affectsys, collect(_dvs), collect(accessed_params),
282-
collect(discrete_parameters), aff_map)
272+
collect(discrete_parameters))
283273
end
284274

285275
function make_affect(affect; kwargs...)
@@ -448,11 +438,23 @@ end
448438
########## Namespacing Utilities ###########
449439
############################################
450440
function namespace_affects(affect::AffectSystem, s)
451-
AffectSystem(renamespace(s, system(affect)),
441+
affsys = system(affect)
442+
old_ts = get_tearing_state(affsys)
443+
# if we just `renamespace` the system, it updates the name. However, this doesn't
444+
# namespace the returned values from `equations(affsys)`, etc. which we need. So we
445+
# need to manually namespace everything. This is done by renaming the system to the
446+
# namespace, putting it as a subsystem of an empty system called `affectsys`, and then
447+
# flatten the system. The resultant system has everything namespaced, and is still
448+
# called `affectsys` for further namespacing
449+
affsys = rename(affsys, nameof(s))
450+
affsys = toggle_namespacing(affsys, true)
451+
affsys = System(Equation[], get_iv(affsys); systems = [affsys], name = :affectsys)
452+
affsys = complete(affsys)
453+
@set! affsys.tearing_state = old_ts
454+
AffectSystem(affsys,
452455
renamespace.((s,), unknowns(affect)),
453456
renamespace.((s,), parameters(affect)),
454-
renamespace.((s,), discretes(affect)),
455-
Dict([k => renamespace(s, v) for (k, v) in aff_to_sys(affect)]))
457+
renamespace.((s,), discretes(affect)))
456458
end
457459
namespace_affects(af::Nothing, s) = nothing
458460

@@ -808,15 +810,13 @@ function compile_equational_affect(
808810
affsys = system(aff)
809811
ps_to_update = discretes(aff)
810812
dvs_to_update = setdiff(unknowns(aff), getfield.(observed(sys), :lhs))
811-
aff_map = aff_to_sys(aff)
812-
sys_map = Dict([v => k for (k, v) in aff_map])
813813

814814
obseqs, eqs = unhack_observed(observed(affsys), equations(affsys))
815815
if isempty(equations(affsys))
816816
update_eqs = Symbolics.fast_substitute(
817817
obseqs, Dict([p => unPre(p) for p in parameters(affsys)]))
818818
rhss = map(x -> x.rhs, update_eqs)
819-
lhss = map(x -> aff_map[x.lhs], update_eqs)
819+
lhss = map(x -> x.lhs, update_eqs)
820820
is_p = [lhs Set(ps_to_update) for lhs in lhss]
821821
is_u = [lhs Set(dvs_to_update) for lhs in lhss]
822822
dvs = unknowns(sys)
@@ -854,11 +854,11 @@ function compile_equational_affect(
854854
end
855855
end
856856
else
857-
return let dvs_to_update = dvs_to_update, aff_map = aff_map, sys_map = sys_map,
857+
return let dvs_to_update = dvs_to_update,
858858
affsys = affsys, ps_to_update = ps_to_update, aff = aff, sys = sys,
859859
reset_jumps = reset_jumps
860860

861-
dvs_to_access = [aff_map[u] for u in unknowns(affsys)]
861+
dvs_to_access = unknowns(affsys)
862862
ps_to_access = [unPre(p) for p in parameters(affsys)]
863863

864864
affu_getter = getsym(sys, dvs_to_access)
@@ -867,8 +867,8 @@ function compile_equational_affect(
867867
affp_setter! = setsym(affsys, parameters(affsys))
868868
u_setter! = setsym(sys, dvs_to_update)
869869
p_setter! = setsym(sys, ps_to_update)
870-
u_getter = getsym(affsys, [sys_map[u] for u in dvs_to_update])
871-
p_getter = getsym(affsys, [sys_map[p] for p in ps_to_update])
870+
u_getter = getsym(affsys, dvs_to_update)
871+
p_getter = getsym(affsys, ps_to_update)
872872

873873
affprob = ImplicitDiscreteProblem(
874874
affsys, Pair[unknowns(affsys) .=> 0; parameters(affsys) .=> 0],

test/symbolic_events.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1331,3 +1331,19 @@ end
13311331
sys = mtkcompile(sys)
13321332
sol = solve(ODEProblem(sys, [], (0.0, 1.0)), Tsit5())
13331333
end
1334+
1335+
@testset "non-floating-point discretes and namespaced affects" begin
1336+
function Inner(; name)
1337+
@parameters p(t)::Int
1338+
@variables x(t)
1339+
cevs = ModelingToolkit.SymbolicContinuousCallback(
1340+
[x ~ 1.0], [p ~ Pre(p) + 1]; iv = t, discrete_parameters = [p])
1341+
System([D(x) ~ 1], t, [x], [p]; continuous_events = [cevs], name)
1342+
end
1343+
@named inner = Inner()
1344+
@mtkcompile sys = System(Equation[], t; systems = [inner])
1345+
prob = ODEProblem(sys, [inner.x => 0.0, inner.p => 0], (0.0, 5.0))
1346+
sol = solve(prob, Tsit5())
1347+
@test SciMLBase.successful_retcode(sol)
1348+
@test sol[inner.p][end] 1.0
1349+
end

0 commit comments

Comments
 (0)