@@ -617,30 +617,40 @@ struct ReconstructInitializeprob{GP, GU}
617
617
ugetter:: GU
618
618
end
619
619
620
+ """
621
+ $(TYPEDEF)
622
+
623
+ A wrapper over an observed function which allows calling it on a problem-like object.
624
+ `TD` determines whether the getter function is `(u, p, t)` (if `true`) or `(u, p)` (if
625
+ `false`).
626
+ """
627
+ struct ObservedWrapper{TD, F}
628
+ f:: F
629
+ end
630
+
631
+ ObservedWrapper {TD} (f:: F ) where {TD, F} = ObservedWrapper {TD, F} (f)
632
+
633
+ function (ow:: ObservedWrapper{true} )(prob)
634
+ ow. f (state_values (prob), parameter_values (prob), current_time (prob))
635
+ end
636
+
637
+ function (ow:: ObservedWrapper{false} )(prob)
638
+ ow. f (state_values (prob), parameter_values (prob))
639
+ end
640
+
620
641
"""
621
642
$(TYPEDSIGNATURES)
622
643
623
644
Given an index provider `indp` and a vector of symbols `syms` return a type-stable getter
624
- function by splitting `syms` into contiguous buffers where the getter of each buffer
625
- is type-stable and constructing a function that calls and concatenates the results.
626
- """
627
- function concrete_getu (indp, syms:: AbstractVector )
628
- # a list of contiguous buffer
629
- split_syms = [Any[syms[1 ]]]
630
- # the type of the getter of the last buffer
631
- current = typeof (getu (indp, syms[1 ]))
632
- for sym in syms[2 : end ]
633
- getter = getu (indp, sym)
634
- if typeof (getter) != current
635
- # if types don't match, build a new buffer
636
- push! (split_syms, [])
637
- current = typeof (getter)
638
- end
639
- push! (split_syms[end ], sym)
640
- end
641
- split_syms = Tuple (split_syms)
642
- # the getter is now type-stable, and we can vcat it to get the full buffer
643
- return Base. Fix1 (reduce, vcat) ∘ getu (indp, split_syms)
645
+ function.
646
+
647
+ Note that the getter ONLY works for problem-like objects, since it generates an observed
648
+ function. It does NOT work for solutions.
649
+ """
650
+ Base. @nospecializeinfer function concrete_getu (indp, syms:: AbstractVector )
651
+ @nospecialize
652
+ obsfn = SymbolicIndexingInterface. observed (indp, syms)
653
+ return ObservedWrapper {is_time_dependent(indp)} (obsfn)
644
654
end
645
655
646
656
"""
0 commit comments