@@ -2,6 +2,8 @@ module IncrInfrDiffEqFactorExt
2
2
3
3
@info " IncrementalInference.jl is loading extensions related to DifferentialEquations.jl"
4
4
5
+ import Base: show
6
+
5
7
using DifferentialEquations
6
8
import DifferentialEquations: solve
7
9
@@ -15,10 +17,30 @@ using DocStringExtensions
15
17
16
18
export DERelative
17
19
20
+ import Manifolds: allocate, compose, hat, Identity, vee, log
18
21
19
22
20
23
getManifold (de:: DERelative{T} ) where {T} = getManifold (de. domain)
21
24
25
+
26
+ function Base. show (
27
+ io:: IO ,
28
+ :: Union{<:DERelative{T,O},Type{<:DERelative{T,O}}}
29
+ ) where {T,O}
30
+ println (io, " DERelative{" )
31
+ println (io, " " , T)
32
+ println (io, " " , O. name. name)
33
+ println (io, " }" )
34
+ nothing
35
+ end
36
+
37
+ Base. show (
38
+ io:: IO ,
39
+ :: MIME"text/plain" ,
40
+ der:: DERelative
41
+ ) = show (io, der)
42
+
43
+
22
44
"""
23
45
$SIGNATURES
24
46
@@ -28,7 +50,9 @@ DevNotes
28
50
- TODO does not yet incorporate Xi.nanosecond field.
29
51
- TODO does not handle timezone crossing properly yet.
30
52
"""
31
- function _calcTimespan (Xi:: AbstractVector{<:DFGVariable} )
53
+ function _calcTimespan (
54
+ Xi:: AbstractVector{<:DFGVariable}
55
+ )
32
56
#
33
57
tsmps = getTimestamp .(Xi[1 : 2 ]) .| > DateTime .| > datetime2unix
34
58
# toffs = (tsmps .- tsmps[1]) .|> x-> elemType(x.value*1e-3)
@@ -47,10 +71,10 @@ function DERelative(
47
71
f:: Function ,
48
72
data = () -> ();
49
73
dt:: Real = 1 ,
50
- state0:: AbstractVector{<:Real} = zeros (getDimension (domain)),
51
- state1:: AbstractVector{<:Real} = zeros (getDimension (domain)),
74
+ state0:: AbstractVector{<:Real} = allocate ( getPointIdentity (domain)), # zeros(getDimension(domain)),
75
+ state1:: AbstractVector{<:Real} = allocate ( getPointIdentity (domain)), # zeros(getDimension(domain)),
52
76
tspan:: Tuple{<:Real, <:Real} = _calcTimespan (Xi),
53
- problemType = DiscreteProblem,
77
+ problemType = ODEProblem, # DiscreteProblem,
54
78
)
55
79
#
56
80
datatuple = if 2 < length (Xi)
@@ -60,11 +84,11 @@ function DERelative(
60
84
data
61
85
end
62
86
# forward time problem
63
- fproblem = problemType (f, state0, tspan, datatuple; dt = dt )
87
+ fproblem = problemType (f, state0, tspan, datatuple; dt)
64
88
# backward time problem
65
89
bproblem = problemType (f, state1, (tspan[2 ], tspan[1 ]), datatuple; dt = - dt)
66
90
# build the IIF recognizable object
67
- return DERelative (domain, fproblem, bproblem, datatuple, getSample)
91
+ return DERelative (domain, fproblem, bproblem, datatuple) # , getSample)
68
92
end
69
93
70
94
function DERelative (
@@ -75,8 +99,8 @@ function DERelative(
75
99
data = () -> ();
76
100
Xi:: AbstractArray{<:DFGVariable} = getVariable .(dfg, labels),
77
101
dt:: Real = 1 ,
78
- state0 :: AbstractVector{<:Real} = zeros (getDimension (domain)),
79
- state1 :: AbstractVector{<:Real} = zeros (getDimension (domain)),
102
+ state1 :: AbstractVector{<:Real} = allocate ( getPointIdentity (domain)), # zeros(getDimension(domain)),
103
+ state0 :: AbstractVector{<:Real} = allocate ( getPointIdentity (domain)), # zeros(getDimension(domain)),
80
104
tspan:: Tuple{<:Real, <:Real} = _calcTimespan (Xi),
81
105
problemType = DiscreteProblem,
82
106
)
@@ -85,26 +109,32 @@ function DERelative(
85
109
domain,
86
110
f,
87
111
data;
88
- dt = dt ,
89
- state0 = state0 ,
90
- state1 = state1 ,
91
- tspan = tspan ,
92
- problemType = problemType ,
112
+ dt,
113
+ state0,
114
+ state1,
115
+ tspan,
116
+ problemType,
93
117
)
94
118
end
95
119
#
96
120
#
97
121
98
122
# n-ary factor: Xtra splat are variable points (X3::Matrix, X4::Matrix,...)
99
- function _solveFactorODE! (measArr, prob, u0pts, Xtra... )
123
+ function _solveFactorODE! (
124
+ measArr,
125
+ prob,
126
+ u0pts,
127
+ Xtra...
128
+ )
100
129
# happens when more variables (n-ary) must be included in DE solve
101
130
for (xid, xtra) in enumerate (Xtra)
102
131
# update the data register before ODE solver calls the function
103
- prob. p[xid + 1 ][:] = xtra[:]
132
+ prob. p[xid + 1 ][:] = xtra[:] # FIXME , unlikely to work with ArrayPartition, maybe use MArray and `.=`
104
133
end
105
134
106
135
# set the initial condition
107
- prob. u0[:] = u0pts[:]
136
+ prob. u0 .= u0pts
137
+
108
138
sol = DifferentialEquations. solve (prob)
109
139
110
140
# extract solution from solved ode
@@ -155,21 +185,21 @@ end
155
185
156
186
157
187
# NOTE see #1025, CalcFactor should fix `multihypo=` in `cf.__` fields; OBSOLETE
158
- function (cf:: CalcFactor{<:DERelative} )(measurement, X... )
188
+ function (cf:: CalcFactor{<:DERelative} )(
189
+ measurement,
190
+ X...
191
+ )
159
192
#
193
+ # numerical measurement values
160
194
meas1 = measurement[1 ]
161
- diffOp = measurement[2 ]
162
-
195
+ # work on-manifold via sampleFactor piggy back of particular manifold definition
196
+ M = measurement[2 ]
197
+ # lazy factor pointer
163
198
oderel = cf. factor
164
-
165
- # work on-manifold
166
- # diffOp = meas[2]
167
- # if backwardSolve else forward
168
-
169
199
# check direction
170
-
171
200
solveforIdx = cf. solvefor
172
-
201
+
202
+ # if backwardSolve else forward
173
203
if solveforIdx > 2
174
204
# need to recalculate new ODE (forward) for change in parameters (solving for 3rd or higher variable)
175
205
solveforIdx = 2
@@ -185,16 +215,10 @@ function (cf::CalcFactor{<:DERelative})(measurement, X...)
185
215
end
186
216
187
217
# find the difference between measured and predicted.
188
- # # assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
189
- # # FIXME , obviously this is not going to work for more compilcated groups/manifolds -- must fix this soon!
190
- # @show cf._sampleIdx, solveforIdx, meas1
191
-
192
- # FIXME
193
- res = zeros (size (X[2 ], 1 ))
194
- for i = 1 : size (X[2 ], 1 )
195
- # diffop( reference?, test? ) <===> ΔX = test \ reference
196
- res[i] = diffOp[i](X[solveforIdx][i], meas1[i])
197
- end
218
+ # assuming the ODE integrated from current X1 through to predicted X2 (ie `meas1[:,idx]`)
219
+ res_ = compose (M, inv (M, X[solveforIdx]), meas1)
220
+ res = vee (M, Identity (M), log (M, Identity (M), res_))
221
+
198
222
return res
199
223
end
200
224
@@ -249,28 +273,32 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
249
273
oder = cf. factor
250
274
251
275
# how many trajectories to propagate?
252
- # @show getLabel(cf.fullvariables[2]), getDimension(cf.fullvariables[2])
253
- meas = [zeros (getDimension (cf. fullvariables[2 ])) for _ = 1 : N]
276
+ #
277
+ v2T = getVariableType (cf. fullvariables[2 ])
278
+ meas = [allocate (getPointIdentity (v2T)) for _ = 1 : N]
279
+ # meas = [zeros(getDimension(cf.fullvariables[2])) for _ = 1:N]
254
280
255
281
# pick forward or backward direction
256
282
# set boundary condition
257
- u0pts = if cf. solvefor == 1
283
+ u0pts, M = if cf. solvefor == 1
258
284
# backward direction
259
285
prob = oder. backwardProblem
286
+ M_ = getManifold (getVariableType (cf. fullvariables[1 ]))
260
287
addOp, diffOp, _, _ = AMP. buildHybridManifoldCallbacks (
261
- convert (Tuple, getManifold ( getVariableType (cf . fullvariables[ 1 ])) ),
288
+ convert (Tuple, M_ ),
262
289
)
263
290
# getBelief(cf.fullvariables[2]) |> getPoints
264
- cf. _legacyParams[2 ]
291
+ cf. _legacyParams[2 ], M_
265
292
else
266
293
# forward backward
267
294
prob = oder. forwardProblem
295
+ M_ = getManifold (getVariableType (cf. fullvariables[2 ]))
268
296
# buffer manifold operations for use during factor evaluation
269
297
addOp, diffOp, _, _ = AMP. buildHybridManifoldCallbacks (
270
- convert (Tuple, getManifold ( getVariableType (cf . fullvariables[ 2 ])) ),
298
+ convert (Tuple, M_ ),
271
299
)
272
300
# getBelief(cf.fullvariables[1]) |> getPoints
273
- cf. _legacyParams[1 ]
301
+ cf. _legacyParams[1 ], M_
274
302
end
275
303
276
304
# solve likely elements
@@ -281,17 +309,11 @@ function IncrementalInference.sampleFactor(cf::CalcFactor{<:DERelative}, N::Int
281
309
# _solveFactorODE!(meas, prob, u0pts, i, _maketuplebeyond2args(cf._legacyParams...)...)
282
310
end
283
311
284
- return map (x -> (x, diffOp), meas)
312
+ # return meas, M
313
+ return map (x -> (x, M), meas)
285
314
end
286
315
# getDimension(oderel.domain)
287
316
288
317
289
318
290
-
291
-
292
- # # the function
293
- # ode.problem.f.f
294
-
295
- #
296
-
297
319
end # module
0 commit comments