1
+ #define _CRT_SECURE_NO_WARNINGS
2
+ #include < iostream>
3
+ #include < fstream>
4
+ #include < string>
5
+
6
+ #include < opencv2/imgproc.hpp>
7
+ #include < opencv2/highgui.hpp>
8
+ #include < opencv2/opencv.hpp>
9
+ #include < opencv2/features2d.hpp>
10
+
11
+ // #include <cuda_provider_factory.h>
12
+ #include < onnxruntime_cxx_api.h>
13
+
14
+ using namespace cv ;
15
+ using namespace std ;
16
+ using namespace Ort ;
17
+
18
+
19
+ class DeDoDeRunner_end2end
20
+ {
21
+ public:
22
+ DeDoDeRunner_end2end (string model_path);
23
+ void detect (Mat image_a, Mat image_b, vector<cv::KeyPoint>& points_A, vector<cv::KeyPoint>& points_B);
24
+ private:
25
+ const int inpWidth = 256 ;
26
+ const int inpHeight = 256 ;
27
+ const float mean_[3 ] = { 0.485 , 0.456 , 0.406 };
28
+ const float std_[3 ] = { 0.229 , 0.224 , 0.225 };
29
+ vector<float > input_images;
30
+ void preprocess (Mat image_a, Mat image_b);
31
+
32
+ Env env = Env(ORT_LOGGING_LEVEL_ERROR, " cv::KeyPoints detect and match" );
33
+ Ort::Session* ort_session = nullptr ;
34
+ SessionOptions sessionOptions = SessionOptions();
35
+ vector<char *> input_names;
36
+ vector<char *> output_names;
37
+ vector<vector<int64_t >> output_node_dims; // >=1 outputs
38
+ };
39
+
40
+ DeDoDeRunner_end2end::DeDoDeRunner_end2end (string model_path)
41
+ {
42
+ std::wstring widestr = std::wstring (model_path.begin (), model_path.end ());
43
+ // OrtStatus* status = OrtSessionOptionsAppendExecutionProvider_CUDA(sessionOptions, 0);
44
+ sessionOptions.SetGraphOptimizationLevel (ORT_ENABLE_BASIC);
45
+ ort_session = new Session (env, widestr.c_str (), sessionOptions);
46
+ size_t numInputNodes = ort_session->GetInputCount ();
47
+ size_t numOutputNodes = ort_session->GetOutputCount ();
48
+ AllocatorWithDefaultOptions allocator;
49
+ for (int i = 0 ; i < numInputNodes; i++)
50
+ {
51
+ input_names.push_back (ort_session->GetInputName (i, allocator));
52
+ }
53
+ for (int i = 0 ; i < numOutputNodes; i++)
54
+ {
55
+ output_names.push_back (ort_session->GetOutputName (i, allocator));
56
+ Ort::TypeInfo output_type_info = ort_session->GetOutputTypeInfo (i);
57
+ auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo ();
58
+ auto output_dims = output_tensor_info.GetShape ();
59
+ output_node_dims.push_back (output_dims);
60
+ }
61
+ }
62
+
63
+ void DeDoDeRunner_end2end::preprocess (Mat image_a, Mat image_b)
64
+ {
65
+ Mat dstimg;
66
+ cvtColor (image_a, dstimg, COLOR_BGR2RGB);
67
+ Size target_size = Size (this ->inpWidth , this ->inpHeight );
68
+ resize (dstimg, dstimg, target_size, INTER_LINEAR);
69
+ this ->input_images .resize (2 * target_size.area () * 3 );
70
+ for (int c = 0 ; c < 3 ; c++)
71
+ {
72
+ for (int i = 0 ; i < this ->inpHeight ; i++)
73
+ {
74
+ for (int j = 0 ; j < this ->inpWidth ; j++)
75
+ {
76
+ float pix = dstimg.ptr <uchar>(i)[j * 3 + c];
77
+ this ->input_images [c * target_size.area () + i * this ->inpWidth + j] = (pix / 255.0 - this ->mean_ [c]) / this ->std_ [c];
78
+ }
79
+ }
80
+ }
81
+
82
+ cvtColor (image_b, dstimg, COLOR_BGR2RGB);
83
+ resize (dstimg, dstimg, target_size, INTER_LINEAR);
84
+ for (int c = 0 ; c < 3 ; c++)
85
+ {
86
+ for (int i = 0 ; i < this ->inpHeight ; i++)
87
+ {
88
+ for (int j = 0 ; j < this ->inpWidth ; j++)
89
+ {
90
+ float pix = dstimg.ptr <uchar>(i)[j * 3 + c];
91
+ this ->input_images [(3 + c) * target_size.area () + i * this ->inpWidth + j] = (pix / 255.0 - this ->mean_ [c]) / this ->std_ [c];
92
+ }
93
+ }
94
+ }
95
+ }
96
+
97
+
98
+ void DeDoDeRunner_end2end::detect (Mat image_a, Mat image_b, vector<cv::KeyPoint>& points_A, vector<cv::KeyPoint>& points_B)
99
+ {
100
+ this ->preprocess (image_a, image_b);
101
+ array<int64_t , 4 > input_shape_{ 2 , 3 , this ->inpHeight , this ->inpWidth };
102
+
103
+ auto allocator_info = MemoryInfo::CreateCpu (OrtDeviceAllocator, OrtMemTypeCPU);
104
+ Value input_tensor_ = Value::CreateTensor<float >(allocator_info, input_images.data (), input_images.size (), input_shape_.data (), input_shape_.size ());
105
+
106
+ // 开始推理
107
+ vector<Value> ort_outputs = ort_session->Run (RunOptions{ nullptr }, &input_names[0 ], &input_tensor_, 1 , output_names.data (), output_names.size ());
108
+
109
+ // /Postprocessing
110
+ const float * matches_A = ort_outputs[0 ].GetTensorMutableData <float >();
111
+ const float * matches_B = ort_outputs[1 ].GetTensorMutableData <float >();
112
+ int num_points = ort_outputs[0 ].GetTensorTypeAndShapeInfo ().GetShape ()[0 ];
113
+ // /cout << "tensor total element = " << ort_outputs[0].GetTensorTypeAndShapeInfo().GetElementCount() << endl;
114
+ points_A.resize (num_points);
115
+ for (int i = 0 ; i < num_points; i++)
116
+ {
117
+ points_A[i].pt .x = (matches_A[i * 2 ] + 1 ) * 0.5 * image_a.cols ;
118
+ points_A[i].pt .y = (matches_A[i * 2 + 1 ] + 1 ) * 0.5 * image_a.rows ;
119
+ points_A[i].size = 1 .f ;
120
+ }
121
+
122
+ num_points = ort_outputs[1 ].GetTensorTypeAndShapeInfo ().GetShape ()[0 ];
123
+ points_B.resize (num_points);
124
+ for (int i = 0 ; i < num_points; i++)
125
+ {
126
+ points_B[i].pt .x = (matches_B[i * 2 ] + 1 ) * 0.5 * image_b.cols ;
127
+ points_B[i].pt .y = (matches_B[i * 2 + 1 ] + 1 ) * 0.5 * image_b.rows ;
128
+ points_B[i].size = 1 .f ;
129
+ }
130
+ }
131
+
132
+ int main ()
133
+ {
134
+ DeDoDeRunner_end2end mynet (" weights/dedode_end2end_1024.onnx" );
135
+ string imgpath_a = " images/im_A.jpg" ;
136
+ string imgpath_b = " images/im_B.jpg" ;
137
+ Mat image_a = imread (imgpath_a);
138
+ Mat image_b = imread (imgpath_b);
139
+
140
+ vector<cv::KeyPoint> points_A;
141
+ vector<cv::KeyPoint> points_B;
142
+ mynet.detect (image_a, image_b, points_A, points_B);
143
+
144
+ // 匹配结果放在matches里面
145
+ const int num_points = points_A.size ();
146
+ vector<DMatch> matches (num_points);
147
+ for (int i = 0 ; i < num_points; i++)
148
+ {
149
+ matches[i] = DMatch (i, i, 0 .f );
150
+ }
151
+
152
+ // 按照匹配关系将图画出来,背景图为match_img
153
+ Mat match_img;
154
+ drawMatches (image_a, points_A, image_b, points_B, matches, match_img);
155
+
156
+ // -- Show detected matches
157
+ static const string kWinName = " Image Matches in ONNXRuntime" ;
158
+ namedWindow (kWinName , WINDOW_NORMAL);
159
+ imshow (kWinName , match_img);
160
+ waitKey (0 );
161
+ destroyAllWindows ();
162
+ }
0 commit comments