diff --git a/wilds/datasets/wilds_dataset.py b/wilds/datasets/wilds_dataset.py index d4f301e3..266bf122 100644 --- a/wilds/datasets/wilds_dataset.py +++ b/wilds/datasets/wilds_dataset.py @@ -495,13 +495,15 @@ def __init__(self, dataset, indices, transform, do_transform_y=False): self.do_transform_y = do_transform_y def __getitem__(self, idx): - x, y, metadata = self.dataset[self.indices[idx]] + dataset_item = self.dataset[self.indices[idx]] + + x, y, *_ = dataset_item # Unpacks the first two items; expects dataset items to have at least 2 elements if self.transform is not None: if self.do_transform_y: x, y = self.transform(x, y) else: x = self.transform(x) - return x, y, metadata + return x, y, *dataset_item[2:] # Returns additional elements beyond x and y, if any def __len__(self): return len(self.indices)