Commit 9ba6ab1
Update flat tensor ndm to account for named delegate data
Pull Request resolved: #10330

Currently the flat_tensor named data map (ndm) only accounts for tensors in the get_data, get_num_keys, and get_key functions. Add support for returning named_data values from these functions as well.

TODO: consolidate tensors and named_data into one structure in the flatbuffer. This will simplify all of the serialization and runtime code. For now, we assume that a PTD file contains either tensors or named_data, not both; after the consolidation this won't be an issue.

Differential Revision: [D73380805](https://our.internmc.facebook.com/intern/diff/D73380805/)

ghstack-source-id: 279887671
Parent: 19535f8 · Commit: 9ba6ab1

File tree: 5 files changed, +155 −18 lines
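Before the per-file diffs, a minimal usage sketch (not part of this commit) of the data-map API the change extends. The include paths and namespace qualifiers below are assumptions and may need adjusting for a given checkout; `FileDataLoader::from`, `FlatTensorDataMap::load`, `get_num_keys`, `get_key`, and `get_data` are the entry points exercised in the diffs that follow.

```cpp
// Sketch only. Include paths and namespaces are assumptions; adjust to match
// your ExecuTorch checkout.
#include <executorch/extension/data_loader/file_data_loader.h>
#include <executorch/extension/flat_tensor/flat_tensor_data_map.h>

#include <cstdio>

using executorch::extension::FlatTensorDataMap;
using executorch::runtime::FreeableBuffer;
using executorch::runtime::Result;
using torch::executor::util::FileDataLoader;

int main() {
  // Open the PTD file that holds the external tensors / named delegate data.
  Result<FileDataLoader> loader = FileDataLoader::from("ModuleLinear.ptd");
  if (!loader.ok()) {
    return 1;
  }

  Result<FlatTensorDataMap> data_map = FlatTensorDataMap::load(&loader.get());
  if (!data_map.ok()) {
    return 1;
  }

  // With this change, get_num_keys/get_key/get_data also cover named_data
  // entries (assuming the PTD holds either tensors or named_data, not both).
  Result<size_t> num_keys = data_map->get_num_keys();
  if (!num_keys.ok()) {
    return 1;
  }
  for (size_t i = 0; i < num_keys.get(); ++i) {
    Result<const char*> key = data_map->get_key(i);
    if (!key.ok()) {
      continue;
    }
    Result<FreeableBuffer> data = data_map->get_data(key.get());
    if (data.ok()) {
      std::printf("%s: %zu bytes\n", key.get(), data.get().size());
    }
  }
  return 0;
}
```

Per the new tests below, a PTD produced by the XNNPACK-delegated export reports its named_data keys (content hashes) here, rather than fully qualified tensor names.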

extension/flat_tensor/flat_tensor_data_map.cpp (+71 −3)

@@ -65,6 +65,28 @@ Result<const flat_tensor_flatbuffer::TensorMetadata*> get_flat_tensor_metadata(
   return Error::NotFound;
 }
 
+Result<const flat_tensor_flatbuffer::NamedData*> get_named_data(
+    const char* key,
+    const flatbuffers::Vector<
+        flatbuffers::Offset<flat_tensor_flatbuffer::NamedData>>* named_data) {
+  // Linear search by name.
+  if (named_data == nullptr) {
+    return Error::NotFound;
+  }
+  for (int i = 0; i < named_data->size(); i++) {
+    if (std::strcmp(named_data->Get(i)->key()->c_str(), key) == 0) {
+      const auto* metadata = named_data->Get(i);
+      ET_CHECK_OR_RETURN_ERROR(
+          metadata->segment_index() >= 0,
+          InvalidExternalData,
+          "Invalid segment_index %d; malformed PTD file.",
+          metadata->segment_index());
+      return metadata;
+    }
+  }
+  return Error::NotFound;
+}
+
 Result<const TensorLayout> create_tensor_layout(
     const flat_tensor_flatbuffer::TensorMetadata* tensor_metadata) {
   ScalarType scalar_type =
@@ -109,6 +131,39 @@ ET_NODISCARD Result<const TensorLayout> FlatTensorDataMap::get_metadata(
 
 ET_NODISCARD Result<FreeableBuffer> FlatTensorDataMap::get_data(
     const char* key) const {
+  // TODO(lfq): consolidate named_data and tensors.
+  // Check named data.
+  Result<const flat_tensor_flatbuffer::NamedData*> named_data =
+      get_named_data(key, flat_tensor_->named_data());
+  if (named_data.ok()) {
+    size_t segment_index = named_data.get()->segment_index();
+    ET_CHECK_OR_RETURN_ERROR(
+        segment_index < flat_tensor_->segments()->size(),
+        InvalidExternalData,
+        "Invalid segment_index %zu; malformed PTD file.",
+        segment_index);
+
+    size_t segment_offset =
+        flat_tensor_->segments()->Get(segment_index)->offset();
+    size_t segment_size = flat_tensor_->segments()->Get(segment_index)->size();
+    ET_CHECK_OR_RETURN_ERROR(
+        segment_offset <
+            header_.segment_base_offset + header_.segment_data_size,
+        InvalidExternalData,
+        "Invalid segment offset %zu is larger than the segment_base_offset + segment_data_size %" PRIu64
+        "; malformed PTD file.",
+        segment_offset,
+        header_.segment_base_offset + header_.segment_data_size);
+    return loader_->load(
+        /*offset=*/header_.segment_base_offset + segment_offset,
+        segment_size,
+        DataLoader::SegmentInfo(DataLoader::SegmentInfo::Type::External));
+  }
+  if (named_data.error() != Error::NotFound) {
+    return named_data.error();
+  }
+
+  // Check tensors, if named data is not found.
   Result<const flat_tensor_flatbuffer::TensorMetadata*> metadata =
       get_flat_tensor_metadata(key, flat_tensor_->tensors());
   if (!metadata.ok()) {
@@ -179,16 +234,29 @@ ET_NODISCARD Error FlatTensorDataMap::load_data_into(
 }
 
 ET_NODISCARD Result<size_t> FlatTensorDataMap::get_num_keys() const {
-  return flat_tensor_->tensors()->size();
+  // TODO(lfq): consolidate named_data and tensors.
+  return flat_tensor_->tensors()->size() + flat_tensor_->named_data()->size();
 }
 
 ET_NODISCARD Result<const char*> FlatTensorDataMap::get_key(
     size_t index) const {
-  if (index < 0 || index >= flat_tensor_->tensors()->size()) {
+  // TODO(lfq): consolidate named_data and tensors.
+  // Currently, this assumes we either have tensors or named_data, but not both.
+  if (flat_tensor_->tensors()->size() > 0 &&
+      flat_tensor_->named_data()->size() > 0) {
+    return Error::NotImplemented;
+  }
+  if (index < 0) {
     return Error::InvalidArgument;
   }
+  if (index < flat_tensor_->tensors()->size()) {
+    return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str();
+  }
+  if (index < flat_tensor_->named_data()->size()) {
+    return flat_tensor_->named_data()->Get(index)->key()->c_str();
+  }
 
-  return flat_tensor_->tensors()->Get(index)->fully_qualified_name()->c_str();
+  return Error::InvalidArgument;
 }
 
 /* static */ Result<FlatTensorDataMap> FlatTensorDataMap::load(
extension/flat_tensor/test/flat_tensor_data_map_test.cpp (+77 −14)

@@ -28,32 +28,36 @@ using torch::executor::util::FileDataLoader;
 
 class FlatTensorDataMapTest : public ::testing::Test {
  protected:
+  void create_loader(const char* path, const char* module_name) {
+    // Create a loader for the serialized data map.
+    Result<FileDataLoader> loader = FileDataLoader::from(path);
+    ASSERT_EQ(loader.error(), Error::Ok);
+    loaders_.insert(
+        {module_name,
+         std::make_unique<FileDataLoader>(std::move(loader.get()))});
+  }
   void SetUp() override {
     // Since these tests cause ET_LOG to be called, the PAL must be initialized
     // first.
     executorch::runtime::runtime_init();
 
-    // Load data map. The eager linear model is defined at:
-    // //executorch/test/models/linear_model.py
-    const char* path = std::getenv("ET_MODULE_LINEAR_DATA_PATH");
-    Result<FileDataLoader> loader = FileDataLoader::from(path);
-    ASSERT_EQ(loader.error(), Error::Ok);
-
-    data_map_loader_ =
-        std::make_unique<FileDataLoader>(std::move(loader.get()));
+    // Model defined in //executorch/test/models/linear_model.py
+    create_loader(std::getenv("ET_MODULE_LINEAR_DATA_PATH"), "linear");
+    // Model defined in //executorch/test/models/export_delegated_program.py
+    create_loader(std::getenv("ET_MODULE_LINEAR_XNN_DATA_PATH"), "linear_xnn");
   }
-  std::unique_ptr<FileDataLoader> data_map_loader_;
+  std::unordered_map<std::string, std::unique_ptr<FileDataLoader>> loaders_;
 };
 
 TEST_F(FlatTensorDataMapTest, LoadFlatTensorDataMap) {
   Result<FlatTensorDataMap> data_map =
-      FlatTensorDataMap::load(data_map_loader_.get());
+      FlatTensorDataMap::load(loaders_["linear"].get());
   EXPECT_EQ(data_map.error(), Error::Ok);
 }
 
 TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
   Result<FlatTensorDataMap> data_map =
-      FlatTensorDataMap::load(data_map_loader_.get());
+      FlatTensorDataMap::load(loaders_["linear"].get());
   EXPECT_EQ(data_map.error(), Error::Ok);
 
   // Check tensor layouts are correct.
@@ -95,7 +99,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetMetadata) {
 
 TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
   Result<FlatTensorDataMap> data_map =
-      FlatTensorDataMap::load(data_map_loader_.get());
+      FlatTensorDataMap::load(loaders_["linear"].get());
   EXPECT_EQ(data_map.error(), Error::Ok);
 
   // Check tensor data sizes are correct.
@@ -116,7 +120,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData) {
 
 TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
   Result<FlatTensorDataMap> data_map =
-      FlatTensorDataMap::load(data_map_loader_.get());
+      FlatTensorDataMap::load(loaders_["linear"].get());
   EXPECT_EQ(data_map.error(), Error::Ok);
 
   // Check num tensors is 2.
@@ -140,7 +144,7 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys) {
 
 TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
   Result<FlatTensorDataMap> data_map =
-      FlatTensorDataMap::load(data_map_loader_.get());
+      FlatTensorDataMap::load(loaders_["linear"].get());
   EXPECT_EQ(data_map.error(), Error::Ok);
 
   // get the metadata
@@ -160,3 +164,62 @@ TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_LoadInto) {
   }
   free(data);
 }
+
+TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_GetData_Xnnpack) {
+  Result<FlatTensorDataMap> data_map =
+      FlatTensorDataMap::load(loaders_["linear_xnn"].get());
+  EXPECT_EQ(data_map.error(), Error::Ok);
+
+  // Check tensor data sizes are correct.
+  // 64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885 is the
+  // hash of the 3*3 identity matrix
+  Result<FreeableBuffer> data_weight_res = data_map->get_data(
+      "64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885");
+  ASSERT_EQ(Error::Ok, data_weight_res.error());
+  FreeableBuffer data_a = std::move(data_weight_res.get());
+  EXPECT_EQ(data_a.size(), 36); // 3*3*4 (3*3 matrix, 4 bytes per float)
+
+  // 15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b is the
+  // hash of the 3*1 vector [1, 1, 1]
+  Result<FreeableBuffer> data_bias_res = data_map->get_data(
+      "15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b");
+  ASSERT_EQ(Error::Ok, data_bias_res.error());
+  FreeableBuffer data_b = std::move(data_bias_res.get());
+  EXPECT_EQ(data_b.size(), 12); // 3*4 (3*1 vector, 4 bytes per float)
+
+  // Check get_data fails when key is not found.
+  Result<FreeableBuffer> data_c_res = data_map->get_data("c");
+  EXPECT_EQ(data_c_res.error(), Error::NotFound);
+}
+
+TEST_F(FlatTensorDataMapTest, FlatTensorDataMap_Keys_Xnnpack) {
+  Result<FlatTensorDataMap> data_map =
+      FlatTensorDataMap::load(loaders_["linear_xnn"].get());
+  EXPECT_EQ(data_map.error(), Error::Ok);
+
+  // Check num tensors is 2.
+  Result<size_t> num_tensors_res = data_map->get_num_keys();
+  ASSERT_EQ(Error::Ok, num_tensors_res.error());
+  EXPECT_EQ(num_tensors_res.get(), 2);
+
+  // Check get_key returns the correct keys.
+  Result<const char*> key0_res = data_map->get_key(0);
+  ASSERT_EQ(Error::Ok, key0_res.error());
+  EXPECT_EQ(
+      strcmp(
+          key0_res.get(),
+          "64eec129c8d3f58ee6b7ca145b25e312fa82d3d276db5adaedb59aaebb824885"),
+      0);
+
+  Result<const char*> key1_res = data_map->get_key(1);
+  ASSERT_EQ(Error::Ok, key1_res.error());
+  EXPECT_EQ(
+      strcmp(
          key1_res.get(),
+          "15ec7bf0b50732b49f8228e07d24365338f9e3ab994b00af08e5a3bffe55fd8b"),
+      0);
+
+  // Check get_key fails when out of bounds.
+  Result<const char*> key2_res = data_map->get_key(2);
+  EXPECT_EQ(key2_res.error(), Error::InvalidArgument);
+}
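A possible follow-up to the XNNPACK tests above: since export_delegated_program.py (below) pins the linear weight to the 3x3 identity, the 36-byte payload returned for the weight key should decode to that matrix. The following is a hedged inspection sketch, assuming FreeableBuffer::data() exposes the raw bytes, that the include path is as written, and that the payload is plain row-major float32 (which the 36-byte size suggests, but the delegate's exact layout isn't shown in this diff).

```cpp
#include <cstdio>

#include <executorch/runtime/core/freeable_buffer.h> // assumed include path

using executorch::runtime::FreeableBuffer;

// Sketch only, not part of the test suite.
void print_weight_matrix(const FreeableBuffer& buffer) {
  if (buffer.size() != 9 * sizeof(float)) {
    std::printf("unexpected payload size: %zu\n", buffer.size());
    return;
  }
  const float* values = static_cast<const float*>(buffer.data());
  for (int row = 0; row < 3; ++row) {
    for (int col = 0; col < 3; ++col) {
      // Expect 1.0 on the diagonal, 0.0 elsewhere (the pinned identity weight).
      std::printf("%.1f ", values[row * 3 + col]);
    }
    std::printf("\n");
  }
}
```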

extension/flat_tensor/test/targets.bzl (+1 −1)

@@ -35,8 +35,8 @@ def define_common_targets(is_fbcode=False):
             # The tests use this var to find the program file to load. This uses
             # an fbcode target path because the authoring/export tools
             # intentionally don't work in xplat (since they're host-only tools).
-            "ET_MODULE_LINEAR_PROGRAM_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.pte])",
             "ET_MODULE_LINEAR_DATA_PATH": "$(location fbcode//executorch/test/models:exported_program_and_data[ModuleLinear.ptd])",
+            "ET_MODULE_LINEAR_XNN_DATA_PATH": "$(location fbcode//executorch/test/models:exported_xnnpack_program_and_data[ModuleLinear.ptd])",
         }
 
     runtime.cxx_test(

test/models/export_delegated_program.py (+5)

@@ -99,6 +99,11 @@ class ModuleLinear(torch.nn.Module):
     def __init__(self):
         super().__init__()
         self.linear = torch.nn.Linear(3, 3)
+        # Make the linear deterministic.
+        self.linear.weight.data = torch.tensor(
+            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
+        )  # 3x3 identity matrix
+        self.linear.bias.data = torch.tensor([0.0, 0.0, 0.0])
 
     def forward(self, x: torch.Tensor):
         return self.linear(x)

test/models/targets.bzl (+1)

@@ -222,6 +222,7 @@ def define_common_targets():
         default_outs = ["."],
         visibility = [
             "//executorch/runtime/executor/test/...",
+            "//executorch/extension/flat_tensor/test/...",
             "//executorch/test/...",
         ],
     )
