From 96fb4d8991dc9e7a7f6e5fba154b3c770ee72297 Mon Sep 17 00:00:00 2001 From: rdarbinyan Date: Fri, 26 Jan 2024 20:52:40 +0400 Subject: [PATCH] Enhance WILDSSubset to support flexible data returns --- wilds/datasets/wilds_dataset.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/wilds/datasets/wilds_dataset.py b/wilds/datasets/wilds_dataset.py index d4f301e3..266bf122 100644 --- a/wilds/datasets/wilds_dataset.py +++ b/wilds/datasets/wilds_dataset.py @@ -495,13 +495,15 @@ def __init__(self, dataset, indices, transform, do_transform_y=False): self.do_transform_y = do_transform_y def __getitem__(self, idx): - x, y, metadata = self.dataset[self.indices[idx]] + dataset_item = self.dataset[self.indices[idx]] + + x, y, *_ = dataset_item # Unpacks the first two items; expects dataset items to have at least 2 elements if self.transform is not None: if self.do_transform_y: x, y = self.transform(x, y) else: x = self.transform(x) - return x, y, metadata + return x, y, *dataset_item[2:] # Returns additional elements beyond x and y, if any def __len__(self): return len(self.indices)