Skip to content

Commit 3277095

Browse files
committed
Merge branch 'master' into 23Q4/exp/par_tree
2 parents 7bbd5f4 + a7a024c commit 3277095

File tree

5 files changed

+84
-59
lines changed

5 files changed

+84
-59
lines changed

Project.toml

+1
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ RecursiveArrayTools = "2.31.1"
9191
Reexport = "1"
9292
SparseDiffTools = "2"
9393
StaticArrays = "1"
94+
Statistics = "1"
9495
StatsBase = "0.32, 0.33, 0.34"
9596
StructTypes = "1"
9697
TensorCast = "0.3.3, 0.4"

ext/IncrInfrDiffEqFactorExt.jl

+73-51
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ module IncrInfrDiffEqFactorExt
22

33
@info "IncrementalInference.jl is loading extensions related to DifferentialEquations.jl"
44

5+
import Base: show
6+
57
using DifferentialEquations
68
import DifferentialEquations: solve
79

@@ -15,10 +17,30 @@ using DocStringExtensions
1517

1618
export DERelative
1719

20+
import Manifolds: allocate, compose, hat, Identity, vee, log
1821

1922

2023
getManifold(de::DERelative{T}) where {T} = getManifold(de.domain)
2124

25+
26+
function Base.show(
27+
io::IO,
28+
::Union{<:DERelative{T,O},Type{<:DERelative{T,O}}}
29+
) where {T,O}
30+
println(io, " DERelative{")
31+
println(io, " ", T)
32+
println(io, " ", O.name.name)
33+
println(io, " }")
34+
nothing
35+
end
36+
37+
Base.show(
38+
io::IO,
39+
::MIME"text/plain",
40+
der::DERelative
41+
) = show(io, der)
42+
43+
2244
"""
2345
$SIGNATURES
2446
@@ -28,7 +50,9 @@ DevNotes
2850
- TODO does not yet incorporate Xi.nanosecond field.
2951
- TODO does not handle timezone crossing properly yet.
3052
"""
31-
function _calcTimespan(Xi::AbstractVector{<:DFGVariable})
53+
function _calcTimespan(
54+
Xi::AbstractVector{<:DFGVariable}
55+
)
3256
#
3357
tsmps = getTimestamp.(Xi[1:2]) .|> DateTime .|> datetime2unix
3458
# toffs = (tsmps .- tsmps[1]) .|> x-> elemType(x.value*1e-3)
@@ -47,10 +71,10 @@ function DERelative(
4771
f::Function,
4872
data = () -> ();
4973
dt::Real = 1,
50-
state0::AbstractVector{<:Real} = zeros(getDimension(domain)),
51-
state1::AbstractVector{<:Real} = zeros(getDimension(domain)),
74+
state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)),
75+
state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), # zeros(getDimension(domain)),
5276
tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi),
53-
problemType = DiscreteProblem,
77+
problemType = ODEProblem, # DiscreteProblem,
5478
)
5579
#
5680
datatuple = if 2 < length(Xi)
@@ -60,11 +84,11 @@ function DERelative(
6084
data
6185
end
6286
# forward time problem
63-
fproblem = problemType(f, state0, tspan, datatuple; dt = dt)
87+
fproblem = problemType(f, state0, tspan, datatuple; dt)
6488
# backward time problem
6589
bproblem = problemType(f, state1, (tspan[2], tspan[1]), datatuple; dt = -dt)
6690
# build the IIF recognizable object
67-
return DERelative(domain, fproblem, bproblem, datatuple, getSample)
91+
return DERelative(domain, fproblem, bproblem, datatuple) #, getSample)
6892
end
6993

7094
function DERelative(
@@ -75,8 +99,8 @@ function DERelative(
7599
data = () -> ();
76100
Xi::AbstractArray{<:DFGVariable} = getVariable.(dfg, labels),
77101
dt::Real = 1,
78-
state0::AbstractVector{<:Real} = zeros(getDimension(domain)),
79-
state1::AbstractVector{<:Real} = zeros(getDimension(domain)),
102+
state1::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)),
103+
state0::AbstractVector{<:Real} = allocate(getPointIdentity(domain)), #zeros(getDimension(domain)),
80104
tspan::Tuple{<:Real, <:Real} = _calcTimespan(Xi),
81105
problemType = DiscreteProblem,
82106
)
@@ -85,26 +109,32 @@ function DERelative(
85109
domain,
86110
f,
87111
data;
88-
dt = dt,
89-
state0 = state0,
90-
state1 = state1,
91-
tspan = tspan,
92-
problemType = problemType,
112+
dt,
113+
state0,
114+
state1,
115+
tspan,
116+
problemType,
93117
)
94118
end
95119
#
96120
#
97121

98122
# n-ary factor: Xtra splat are variable points (X3::Matrix, X4::Matrix,...)
99-
function _solveFactorODE!(measArr, prob, u0pts, Xtra...)
123+
function _solveFactorODE!(
124+
measArr,
125+
prob,
126+
u0pts,
127+
Xtra...
128+
)
100129
# happens when more variables (n-ary) must be included in DE solve
101130
for (xid, xtra) in enumerate(Xtra)
102131
# update the data register before ODE solver calls the function
103-
prob.p[xid + 1][:] = xtra[:]
132+
prob.p[xid + 1][:] = xtra[:] # FIXME, unlikely to work with ArrayPartition, maybe use MArray and `.=`
104133
end
105134

106135
# set the initial condition
107-
prob.u0[:] = u0pts[:]
136+
prob.u0 .= u0pts
137+
108138
sol = DifferentialEquations.solve(prob)
109139

110140
# extract solution from solved ode
@@ -155,21 +185,21 @@ end
155185

156186

157187
# NOTE see #1025, CalcFactor should fix `multihypo=` in `cf.__` fields; OBSOLETE
158-
function (cf::CalcFactor{<:DERelative})(measurement, X...)
188+
function (cf::CalcFactor{<:DERelative})(
189+
measurement,
190+
X...
191+
)
159192
#
193+
# numerical measurement values
160194
meas1 = measurement[1]
161-
diffOp = measurement[2]
162-
195+
# work on-manifold via sampleFactor piggy back of particular manifold definition
196+
M = measurement[2]
197+
# lazy factor pointer
163198
oderel = cf.factor
164-
165-
# work on-manifold
166-
# diffOp = meas[2]
167-
# if backwardSolve else forward
168-
169199
# check direction
170-
171200
solveforIdx = cf.solvefor
172-
201+
202+
# if backwardSolve else forward
173203
if solveforIdx > 2
174204
# need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable)
175205
solveforIdx = 2
@@ -185,16 +215,10 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
185215
end
186216

187217
# find the difference between measured and predicted.
188-
## assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
189-
## FIXME, obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon!
190-
# @show cf._sampleIdx, solveforIdx, meas1
191-
192-
#FIXME
193-
res = zeros(size(X[2], 1))
194-
for i = 1:size(X[2], 1)
195-
# diffop( reference?, test? ) <===> ΔX = test \ reference
196-
res[i] = diffOp[i](X[solveforIdx][i], meas1[i])
197-
end
218+
# assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
219+
res_ = compose(M, inv(M, X[solveforIdx]), meas1)
220+
res = vee(M, Identity(M), log(M, Identity(M), res_))
221+
198222
return res
199223
end
200224

@@ -249,28 +273,32 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
249273
oder = cf.factor
250274

251275
# how many trajectories to propagate?
252-
# @show getLabel(cf.fullvariables[2]), getDimension(cf.fullvariables[2])
253-
meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N]
276+
#
277+
v2T = getVariableType(cf.fullvariables[2])
278+
meas = [allocate(getPointIdentity(v2T)) for _ = 1:N]
279+
# meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N]
254280

255281
# pick forward or backward direction
256282
# set boundary condition
257-
u0pts = if cf.solvefor == 1
283+
u0pts, M = if cf.solvefor == 1
258284
# backward direction
259285
prob = oder.backwardProblem
286+
M_ = getManifold(getVariableType(cf.fullvariables[1]))
260287
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
261-
convert(Tuple, getManifold(getVariableType(cf.fullvariables[1]))),
288+
convert(Tuple, M_),
262289
)
263290
# getBelief(cf.fullvariables[2]) |> getPoints
264-
cf._legacyParams[2]
291+
cf._legacyParams[2], M_
265292
else
266293
# forward backward
267294
prob = oder.forwardProblem
295+
M_ = getManifold(getVariableType(cf.fullvariables[2]))
268296
# buffer manifold operations for use during factor evaluation
269297
addOp, diffOp, _, _ = AMP.buildHybridManifoldCallbacks(
270-
convert(Tuple, getManifold(getVariableType(cf.fullvariables[2]))),
298+
convert(Tuple, M_),
271299
)
272300
# getBelief(cf.fullvariables[1]) |> getPoints
273-
cf._legacyParams[1]
301+
cf._legacyParams[1], M_
274302
end
275303

276304
# solve likely elements
@@ -281,17 +309,11 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
281309
# _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...)
282310
end
283311

284-
return map(x -> (x, diffOp), meas)
312+
# return meas, M
313+
return map(x -> (x, M), meas)
285314
end
286315
# getDimension(oderel.domain)
287316

288317

289318

290-
291-
292-
## the function
293-
# ode.problem.f.f
294-
295-
#
296-
297319
end # module

src/ExportAPI.jl

+9
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,14 @@
11
# the IncrementalInference API
22

3+
4+
# reexport
5+
export ℝ, AbstractManifold
6+
export Identity, hat , vee, ArrayPartition, exp!, exp, log!, log
7+
# common groups -- preferred defaults at this time.
8+
export TranslationGroup, RealCircleGroup
9+
# common non-groups -- TODO still teething problems to sort out in IIF v0.25-v0.26.
10+
export Euclidean, Circle
11+
312
# DFG SpecialDefinitions
413
export AbstractDFG,
514
getSolverParams,

src/IncrementalInference.jl

-7
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,6 @@ using FiniteDifferences
1919

2020
using OrderedCollections: OrderedDict
2121

22-
export ℝ, AbstractManifold
23-
# export ProductRepr
24-
# common groups -- preferred defaults at this time.
25-
export TranslationGroup, RealCircleGroup
26-
# common non-groups -- TODO still teething problems to sort out in IIF v0.25-v0.26.
27-
export Euclidean, Circle
28-
2922
import Optim
3023

3124
using Dates,

src/entities/ExtFactors.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ struct DERelative{T <: InferenceVariable, P, D} <: AbstractManifoldMinimize # Ab
2525
backwardProblem::P
2626
""" second element of this data tuple is additional variables that will be passed down as a parameter """
2727
data::D
28-
specialSampler::Function
28+
# specialSampler::Function
2929
end

0 commit comments

Comments
 (0)