
Commit c41c7d8

Update flat tensor ndm to account for named delegate data
Currently, flat_tensor ndm only accounts for tensors in the get_data, get_num_keys, and get_key functions. Add support for returning named_data values as well.

TODO: consolidate tensors and named_data into one structure in the flatbuffer. This will simplify all of the serialization and runtime code. Currently, we assume that a PTD file has either tensors or named_data, not both; after the consolidation, this won't be an issue.

Differential Revision: [D73380805](https://our.internmc.facebook.com/intern/diff/D73380805/)

ghstack-source-id: 279349252
Pull Request resolved: #10330
1 parent 3413ce5 commit c41c7d8
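
For context, a minimal usage sketch of the data map API this commit touches: a caller loads a PTD file, enumerates keys, and fetches payloads. After this change the keys come from either the tensors table (fully-qualified tensor names) or the named_data table (e.g. delegate blob keys). The include paths, namespaces, and environment variable below are assumptions inferred from the diff, not a verified snippet from the tree.

// Hedged sketch only; include paths and namespaces are assumptions.
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>

#include <cstdio>
#include <cstdlib>

using executorch::extension::FlatTensorDataMap; // assumed namespace
using executorch::runtime::FreeableBuffer;      // assumed namespace
using executorch::runtime::Result;              // assumed namespace
using torch::executor::util::FileDataLoader;    // alias used by the test below

int main() {
  // Reuses the env var name from the test targets; any .ptd path works here.
  const char* path = std::getenv("ET_MODULE_LINEAR_XNN_DATA_PATH");
  Result<FileDataLoader> loader = FileDataLoader::from(path);
  if (!loader.ok()) {
    return 1;
  }
  Result<FlatTensorDataMap> map = FlatTensorDataMap::load(&loader.get());
  if (!map.ok()) {
    return 1;
  }
  // get_num_keys() now counts tensors plus named_data entries.
  Result<size_t> num_keys = map->get_num_keys();
  if (!num_keys.ok()) {
    return 1;
  }
  for (size_t i = 0; i < num_keys.get(); ++i) {
    Result<const char*> key = map->get_key(i);
    if (!key.ok()) {
      continue;
    }
    // get_data() now resolves keys from named_data as well as tensors.
    Result<FreeableBuffer> data = map->get_data(key.get());
    if (data.ok()) {
      std::printf("%s: %zu bytes\n", key.get(), data.get().size());
    }
  }
  return 0;
}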


5 files changed: +151 -18 lines changed


extension/flat_tensor/flat_tensor_data_map.cpp

+67-3
@@ -65,6 +65,25 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
   return Error::NotFound;
 }
 
+Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
+    const char* key,
+    const flatbuffers::Vector<
+        flatbuffers::Offset<flat_tensor_flatbuffer::NamedData>>* named_data) {
+  // Linear search by name.
+  for (int i = 0; i < named_data->size(); i++) {
+    if (std::strcmp(named_data->Get(i)->key()->c_str(), key) == 0) {
+      const auto* metadata = named_data->Get(i);
+      ET_CHECK_OR_RETURN_ERROR(
+          metadata->segment_index() >= 0,
+          InvalidExternalData,
+          "Invalid segment_index %d; malformed PTD file.",
+          metadata->segment_index());
+      return metadata;
+    }
+  }
+  return Error::NotFound;
+}
+
 Result<const TensorLayout> create_tensor_layout(
     const flat_tensor_flatbuffer::TensorMetadata* tensor_metadata) {
   ScalarType scalar_type =
@@ -109,6 +128,39 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
 
 ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
     const char* key) const {
+  // TODO(lfq): consolidate named_data and tensors.
+  // Check named data.
+  Result<const flat_tensor_flatbuffer::NamedData*> named_data =
+      get_named_data(key, flat_tensor_->named_data());
+  if (named_data.ok()) {
+    size_t segment_index = named_data.get()->segment_index();
+    ET_CHECK_OR_RETURN_ERROR(
+        segment_index < flat_tensor_->segments()->size(),
+        InvalidExternalData,
+        "Invalid segment_index %zu; malformed PTD file.",
+        segment_index);
+
+    size_t segment_offset =
+        flat_tensor_->segments()->Get(segment_index)->offset();
+    size_t segment_size = flat_tensor_->segments()->Get(segment_index)->size();
+    ET_CHECK_OR_RETURN_ERROR(
+        segment_offset <
+            header_.segment_base_offset + header_.segment_data_size,
+        InvalidExternalData,
+        "Invalid segment offset %zu is larger than the segment_base_offset + segment_data_size %" PRIu64
+        "; malformed PTD file.",
+        segment_offset,
+        header_.segment_base_offset + header_.segment_data_size);
+    return loader_->load(
+        /*offset=*/header_.segment_base_offset + segment_offset,
+        segment_size,
+        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
+  }
+  if (named_data.error() != Error::NotFound) {
+    return named_data.error();
+  }
+
+  // Check tensors, if named data is not found.
   Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
       get_flat_tensor_metadata(key, flat_tensor_->tensors());
   if (!metadata.ok()) {
@@ -179,16 +231,28 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into(
 }
 
 ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
-  return flat_tensor_->tensors()->size();
+  // TODO(lfq): consolidate named_data and tensors.
+  return flat_tensor_->tensors()->size() + flat_tensor_->named_data()->size();
 }
 
 ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
     size_t index) const {
-  if (index < 0 || index >= flat_tensor_->tensors()->size()) {
+  // TODO(lfq): consolidate named_data and tensors.
+  // Currently, this assumes we either have tensors or named_data, but not both.
+  if (flat_tensor_->tensors()->size() > 0 && flat_tensor_->named_data()->size() > 0) {
+    return Error::NotImplemented;
+  }
+  if (index < 0) {
     return Error::InvalidArgument;
   }
+  if (index < flat_tensor_->tensors()->size()) {
+    return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str();
+  }
+  if (index < flat_tensor_->named_data()->size()) {
+    return flat_tensor_->named_data()->Get(index)->key()->c_str();
+  }
 
-  return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str();
+  return Error::InvalidArgument;
 }
 
 /* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
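
As a reading aid for the get_num_keys/get_key change above, here is a self-contained sketch of the indexing convention it implements, using plain std::vector stand-ins instead of the real flat_tensor_flatbuffer tables (all names below are illustrative only): tensor fully-qualified names occupy the first indices, named_data keys follow, and a file that populates both tables is rejected for now.

#include <cstdio>
#include <string>
#include <vector>

// Simplified stand-ins for the flatbuffer tables used by FlatTensorDataMap.
struct FakeFlatTensor {
  std::vector<std::string> tensor_fqns;      // tensors()->fully_qualified_name()
  std::vector<std::string> named_data_keys;  // named_data()->key()
};

// Mirrors the updated get_key logic: either tensors or named_data may be
// populated, but not both; indices cover tensors first, then named_data.
const char* get_key_or_null(const FakeFlatTensor& ft, size_t index) {
  if (!ft.tensor_fqns.empty() && !ft.named_data_keys.empty()) {
    return nullptr; // Error::NotImplemented in the real implementation.
  }
  if (index < ft.tensor_fqns.size()) {
    return ft.tensor_fqns[index].c_str();
  }
  if (index < ft.named_data_keys.size()) {
    return ft.named_data_keys[index].c_str();
  }
  return nullptr; // Error::InvalidArgument in the real implementation.
}

int main() {
  FakeFlatTensor ft;
  // e.g. an XNNPACK-delegated PTD file, which only carries named_data.
  ft.named_data_keys = {"weight_hash", "bias_hash"};
  // get_num_keys() is the sum of both table sizes.
  size_t num_keys = ft.tensor_fqns.size() + ft.named_data_keys.size();
  for (size_t i = 0; i < num_keys; ++i) {
    std::printf("key[%zu] = %s\n", i, get_key_or_null(ft, i));
  }
  return 0;
}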

extension/flat_tensor/test/flat_tensor_data_map_test.cpp

+77-14
@@ -28,32 +28,36 @@ using torch::executor::util::FileDataLoader;
 
 class FlatTensorDataMapTest : public ::testing::Test {
  protected:
+  void create_loader(const char* path, const char* module_name) {
+    // Create a loader for the serialized data map.
+    Result<FileDataLoader> loader = FileDataLoader::from(path);
+    ASSERT_EQ(loader.error(), Error::Ok);
+    loaders_.insert(
+        {module_name,
+         std::make_unique<FileDataLoader>(std::move(loader.get()))});
+  }
   void SetUp() override {
     // Since these tests cause ET_LOG to be called, the PAL must be initialized
     // first.
     executorch::runtime::runtime_init();
 
-    // Load data map. The eager linear model is defined at:
-    // //executorch/test/models/linear_model.py
-    const char* path = std::getenv("ET_MODULE_LINEAR_DATA_PATH");
-    Result<FileDataLoader> loader = FileDataLoader::from(path);
-    ASSERT_EQ(loader.error(), Error::Ok);
-
-    data_map_loader_ =
-        std::make_unique<FileDataLoader>(std::move(loader.get()));
+    // Model defined in //executorch/test/models/linear_model.py
+    create_loader(std::getenv("ET_MODULE_LINEAR_DATA_PATH"), "linear");
+    // Model defined in //executorch/test/models/export_delegated_program.py
+    create_loader(std::getenv("ET_MODULE_LINEAR_XNN_DATA_PATH"), "linear_xnn");
   }
-  std::unique_ptr<FileDataLoader> data_map_loader_;
+  std::unordered_map<std::string, std::unique_ptr<FileDataLoader>> loaders_;
 };
 
 TEST_F(FlatTensorDataMapTest, LoadFlatTensorDataMap) {
   Result<FlatTensorDataMap> data_map =
-      FlatTensorDataMap::load(data_map_loader_.get());
+      FlatTensorDataMap::load(loaders_["linear"].get());
   EXPECT_EQ(data_map.error(), Error::Ok);
 }
 
 TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
   Result<FlatTensorDataMap> data_map =
-      FlatTensorDataMap::load(data_map_loader_.get());
+      FlatTensorDataMap::load(loaders_["linear"].get());
   EXPECT_EQ(data_map.error(), Error::Ok);
 
   // Check tensor layouts are correct.
@@ -95,7 +99,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
 
 TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
   Result<FlatTensorDataMap> data_map =
-      FlatTensorDataMap::load(data_map_loader_.get());
+      FlatTensorDataMap::load(loaders_["linear"].get());
   EXPECT_EQ(data_map.error(), Error::Ok);
 
   // Check tensor data sizes are correct.
@@ -116,7 +120,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
 
 TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
   Result<FlatTensorDataMap> data_map =
-      FlatTensorDataMap::load(data_map_loader_.get());
+      FlatTensorDataMap::load(loaders_["linear"].get());
   EXPECT_EQ(data_map.error(), Error::Ok);
 
   // Check num tensors is 2.
@@ -140,7 +144,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
 
 TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
   Result<FlatTensorDataMap> data_map =
-      FlatTensorDataMap::load(data_map_loader_.get());
+      FlatTensorDataMap::load(loaders_["linear"].get());
   EXPECT_EQ(data_map.error(), Error::Ok);
 
   // get the metadata
@@ -160,3 +164,62 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
   }
   free(data);
 }
+
+TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData_Xnnpack) {
+  Result<FlatTensorDataMap> data_map =
+      FlatTensorDataMap::load(loaders_["linear_xnn"].get());
+  EXPECT_EQ(data_map.error(), Error::Ok);
+
+  // Check tensor data sizes are correct.
+  // 64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885 is the
+  // hash of the 3*3 identity matrix
+  Result<FreeableBuffer> data_weight_res = data_map->get_data(
+      "64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885");
+  ASSERT_EQ(Error::Ok, data_weight_res.error());
+  FreeableBuffer data_a = std::move(data_weight_res.get());
+  EXPECT_EQ(data_a.size(), 36); // 3*3*4 (3*3 matrix, 4 bytes per float)
+
+  // 15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b is the
+  // hash of the 3*1 vector [1, 1, 1]
+  Result<FreeableBuffer> data_bias_res = data_map->get_data(
+      "15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b");
+  ASSERT_EQ(Error::Ok, data_bias_res.error());
+  FreeableBuffer data_b = std::move(data_bias_res.get());
+  EXPECT_EQ(data_b.size(), 12); // 3*4 (3*1 vector, 4 bytes per float)
+
+  // Check get_data fails when key is not found.
+  Result<FreeableBuffer> data_c_res = data_map->get_data("c");
+  EXPECT_EQ(data_c_res.error(), Error::NotFound);
+}
+
+TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys_Xnnpack) {
+  Result<FlatTensorDataMap> data_map =
+      FlatTensorDataMap::load(loaders_["linear_xnn"].get());
+  EXPECT_EQ(data_map.error(), Error::Ok);
+
+  // Check num tensors is 2.
+  Result<size_t> num_tensors_res = data_map->get_num_keys();
+  ASSERT_EQ(Error::Ok, num_tensors_res.error());
+  EXPECT_EQ(num_tensors_res.get(), 2);
+
+  // Check get_key returns the correct keys.
+  Result<const char*> key0_res = data_map->get_key(0);
+  ASSERT_EQ(Error::Ok, key0_res.error());
+  EXPECT_EQ(
+      strcmp(
+          key0_res.get(),
+          "64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885"),
+      0);
+
+  Result<const char*> key1_res = data_map->get_key(1);
+  ASSERT_EQ(Error::Ok, key1_res.error());
+  EXPECT_EQ(
+      strcmp(
+          key1_res.get(),
+          "15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b"),
+      0);
+
+  // Check get_key fails when out of bounds.
+  Result<const char*> key2_res = data_map->get_key(2);
+  EXPECT_EQ(key2_res.error(), Error::InvalidArgument);
+}

extension/flat_tensor/test/targets.bzl

+1-1
@@ -35,8 +35,8 @@ def define_common_targets(is_fbcode=False):
         # The tests use this var to find the program file to load. This uses
         # an fbcode target path because the authoring/export tools
         # intentionally don't work in xplat (since they're host-only tools).
-        "ET_MODULE_LINEAR_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])",
         "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
+        "ET_MODULE_LINEAR_XNN_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_data[ModuleLinear-e.ptd])",
     }
 
     runtime.cxx_test(

test/models/export_delegated_program.py

+5
@@ -99,6 +99,11 @@ class ModuleLinear(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.linear = torch.nn.Linear(3, 3)
+        # Make the linear deterministic.
+        self.linear.weight.data = torch.tensor(
+            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
+        )  # 3x3 identity matrix
+        self.linear.bias.data = torch.tensor([0.0, 0.0, 0.0])
 
     def forward(self, x: torch.Tensor):
         return self.linear(x)

test/models/targets.bzl

+1
@@ -222,6 +222,7 @@ def define_common_targets():
         default_outs = ["."],
         visibility = [
             "//executorch/runtime/executor/test/...",
+            "//executorch/extension/flat_tensor/test/...",
             "//executorch/test/...",
         ],
     )
