@@ -28,32 +28,36 @@ using torch::executor::util::FileDataLoader;
28
28
29
29
class FlatTensorDataMapTest : public ::testing::Test {
30
30
protected:
31
+ void create_loader (const char * path, const char * module_name) {
32
+ // Create a loader for the serialized data map.
33
+ Result<FileDataLoader> loader = FileDataLoader::from (path);
34
+ ASSERT_EQ (loader.error (), Error::Ok);
35
+ loaders_.insert (
36
+ {module_name,
37
+ std::make_unique<FileDataLoader>(std::move (loader.get ()))});
38
+ }
31
39
void SetUp () override {
32
40
// Since these tests cause ET_LOG to be called, the PAL must be initialized
33
41
// first.
34
42
executorch::runtime::runtime_init ();
35
43
36
- // Load data map. The eager linear model is defined at:
37
- // //executorch/test/models/linear_model.py
38
- const char * path = std::getenv (" ET_MODULE_LINEAR_DATA_PATH" );
39
- Result<FileDataLoader> loader = FileDataLoader::from (path);
40
- ASSERT_EQ (loader.error (), Error::Ok);
41
-
42
- data_map_loader_ =
43
- std::make_unique<FileDataLoader>(std::move (loader.get ()));
44
+ // Model defined in //executorch/test/models/linear_model.py
45
+ create_loader (std::getenv (" ET_MODULE_LINEAR_DATA_PATH" ), " linear" );
46
+ // Model defined in //executorch/test/models/export_delegated_program.py
47
+ create_loader (std::getenv (" ET_MODULE_LINEAR_XNN_DATA_PATH" ), " linear_xnn" );
44
48
}
45
- std::unique_ptr<FileDataLoader> data_map_loader_ ;
49
+ std::unordered_map<std::string, std:: unique_ptr<FileDataLoader>> loaders_ ;
46
50
};
47
51
48
52
TEST_F (FlatTensorDataMapTest, LoadFlatTensorDataMap) {
49
53
Result<FlatTensorDataMap> data_map =
50
- FlatTensorDataMap::load (data_map_loader_ .get ());
54
+ FlatTensorDataMap::load (loaders_[ " linear " ] .get ());
51
55
EXPECT_EQ (data_map.error (), Error::Ok);
52
56
}
53
57
54
58
TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
55
59
Result<FlatTensorDataMap> data_map =
56
- FlatTensorDataMap::load (data_map_loader_ .get ());
60
+ FlatTensorDataMap::load (loaders_[ " linear " ] .get ());
57
61
EXPECT_EQ (data_map.error (), Error::Ok);
58
62
59
63
// Check tensor layouts are correct.
@@ -95,7 +99,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
95
99
96
100
TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
97
101
Result<FlatTensorDataMap> data_map =
98
- FlatTensorDataMap::load (data_map_loader_ .get ());
102
+ FlatTensorDataMap::load (loaders_[ " linear " ] .get ());
99
103
EXPECT_EQ (data_map.error (), Error::Ok);
100
104
101
105
// Check tensor data sizes are correct.
@@ -116,7 +120,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
116
120
117
121
TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
118
122
Result<FlatTensorDataMap> data_map =
119
- FlatTensorDataMap::load (data_map_loader_ .get ());
123
+ FlatTensorDataMap::load (loaders_[ " linear " ] .get ());
120
124
EXPECT_EQ (data_map.error (), Error::Ok);
121
125
122
126
// Check num tensors is 2.
@@ -140,7 +144,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
140
144
141
145
TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
142
146
Result<FlatTensorDataMap> data_map =
143
- FlatTensorDataMap::load (data_map_loader_ .get ());
147
+ FlatTensorDataMap::load (loaders_[ " linear " ] .get ());
144
148
EXPECT_EQ (data_map.error (), Error::Ok);
145
149
146
150
// get the metadata
@@ -160,3 +164,62 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
160
164
}
161
165
free (data);
162
166
}
167
+
168
+ TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_GetData_Xnnpack) {
169
+ Result<FlatTensorDataMap> data_map =
170
+ FlatTensorDataMap::load (loaders_[" linear_xnn" ].get ());
171
+ EXPECT_EQ (data_map.error (), Error::Ok);
172
+
173
+ // Check tensor data sizes are correct.
174
+ // 64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885 is the
175
+ // hash of the 3*3 identity matrix
176
+ Result<FreeableBuffer> data_weight_res = data_map->get_data (
177
+ " 64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885" );
178
+ ASSERT_EQ (Error::Ok, data_weight_res.error ());
179
+ FreeableBuffer data_a = std::move (data_weight_res.get ());
180
+ EXPECT_EQ (data_a.size (), 36 ); // 3*3*4 (3*3 matrix, 4 bytes per float)
181
+
182
+ // 15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b is the
183
+ // hash of the 3*1 vector [1, 1, 1]
184
+ Result<FreeableBuffer> data_bias_res = data_map->get_data (
185
+ " 15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b" );
186
+ ASSERT_EQ (Error::Ok, data_bias_res.error ());
187
+ FreeableBuffer data_b = std::move (data_bias_res.get ());
188
+ EXPECT_EQ (data_b.size (), 12 ); // 3*4 (3*1 vector, 4 bytes per float)
189
+
190
+ // Check get_data fails when key is not found.
191
+ Result<FreeableBuffer> data_c_res = data_map->get_data (" c" );
192
+ EXPECT_EQ (data_c_res.error (), Error::NotFound);
193
+ }
194
+
195
+ TEST_F (FlatTensorDataMapTest, FlatTensorDataMap_Keys_Xnnpack) {
196
+ Result<FlatTensorDataMap> data_map =
197
+ FlatTensorDataMap::load (loaders_[" linear_xnn" ].get ());
198
+ EXPECT_EQ (data_map.error (), Error::Ok);
199
+
200
+ // Check num tensors is 2.
201
+ Result<size_t > num_tensors_res = data_map->get_num_keys ();
202
+ ASSERT_EQ (Error::Ok, num_tensors_res.error ());
203
+ EXPECT_EQ (num_tensors_res.get (), 2 );
204
+
205
+ // Check get_key returns the correct keys.
206
+ Result<const char *> key0_res = data_map->get_key (0 );
207
+ ASSERT_EQ (Error::Ok, key0_res.error ());
208
+ EXPECT_EQ (
209
+ strcmp (
210
+ key0_res.get (),
211
+ " 64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885" ),
212
+ 0 );
213
+
214
+ Result<const char *> key1_res = data_map->get_key (1 );
215
+ ASSERT_EQ (Error::Ok, key1_res.error ());
216
+ EXPECT_EQ (
217
+ strcmp (
218
+ key1_res.get (),
219
+ " 15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b" ),
220
+ 0 );
221
+
222
+ // Check get_key fails when out of bounds.
223
+ Result<const char *> key2_res = data_map->get_key (2 );
224
+ EXPECT_EQ (key2_res.error (), Error::InvalidArgument);
225
+ }
0 commit comments