Skip to content

Commit 2ab9c31

Browse files
Merge pull request #3692 from AayushSabharwal/as/concrete-getu
[v9] fix: fix major compile time regression due to `concrete_getu`
2 parents 62b0b66 + c595416 commit 2ab9c31

File tree

1 file changed

+30
-20
lines changed

1 file changed

+30
-20
lines changed

src/systems/problem_utils.jl

Lines changed: 30 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -617,30 +617,40 @@ struct ReconstructInitializeprob{GP, GU}
617617
ugetter::GU
618618
end
619619

620+
"""
621+
$(TYPEDEF)
622+
623+
A wrapper over an observed function which allows calling it on a problem-like object.
624+
`TD` determines whether the getter function is `(u, p, t)` (if `true`) or `(u, p)` (if
625+
`false`).
626+
"""
627+
struct ObservedWrapper{TD, F}
628+
f::F
629+
end
630+
631+
ObservedWrapper{TD}(f::F) where {TD, F} = ObservedWrapper{TD, F}(f)
632+
633+
function (ow::ObservedWrapper{true})(prob)
634+
ow.f(state_values(prob), parameter_values(prob), current_time(prob))
635+
end
636+
637+
function (ow::ObservedWrapper{false})(prob)
638+
ow.f(state_values(prob), parameter_values(prob))
639+
end
640+
620641
"""
621642
$(TYPEDSIGNATURES)
622643
623644
Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter
624-
function by splitting `syms` into contiguous buffers where the getter of each buffer
625-
is type-stable and constructing a function that calls and concatenates the results.
626-
"""
627-
function concrete_getu(indp, syms::AbstractVector)
628-
# a list of contiguous buffer
629-
split_syms = [Any[syms[1]]]
630-
# the type of the getter of the last buffer
631-
current = typeof(getu(indp, syms[1]))
632-
for sym in syms[2:end]
633-
getter = getu(indp, sym)
634-
if typeof(getter) != current
635-
# if types don't match, build a new buffer
636-
push!(split_syms, [])
637-
current = typeof(getter)
638-
end
639-
push!(split_syms[end], sym)
640-
end
641-
split_syms = Tuple(split_syms)
642-
# the getter is now type-stable, and we can vcat it to get the full buffer
643-
return Base.Fix1(reduce, vcat) getu(indp, split_syms)
645+
function.
646+
647+
Note that the getter ONLY works for problem-like objects, since it generates an observed
648+
function. It does NOT work for solutions.
649+
"""
650+
Base.@nospecializeinfer function concrete_getu(indp, syms::AbstractVector)
651+
@nospecialize
652+
obsfn = SymbolicIndexingInterface.observed(indp, syms)
653+
return ObservedWrapper{is_time_dependent(indp)}(obsfn)
644654
end
645655

646656
"""

0 commit comments

Comments
 (0)