Skip to content

Commit 20baea4

Browse files
committed
fix: back to original version
1 parent c6281f4 commit 20baea4

File tree

1 file changed

+3
-12
lines changed

1 file changed

+3
-12
lines changed

DPF/pipelines/pipeline_stages.py

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -68,8 +68,7 @@ def __init__(
6868
filter_class: Union[type[DataFilter], type[ColumnFilter], type[MultiGPUDataFilter]],
6969
filter_kwargs: dict[str, Any],
7070
processor_apply_kwargs: Optional[dict[str, Any]] = None,
71-
skip_if_columns_exist: bool = True,
72-
constant_gpu: bool = False
71+
skip_if_columns_exist: bool = True
7372
):
7473
self.filter_type = filter_type
7574
self.filter_class = filter_class
@@ -81,20 +80,12 @@ def __init__(
8180

8281
self.skip_if_columns_exist = skip_if_columns_exist
8382

84-
self.constant_gpu = constant_gpu
85-
if constant_gpu:
86-
self.filter_obj = self.filter_class(**self.filter_kwargs)
87-
88-
8983
@property
9084
def stage_name(self) -> str:
9185
return f"FilterPipelineStage(filter_class={self.filter_class}, filter_kwargs={self.filter_kwargs})"
9286

9387
def run(self, processor: DatasetProcessor, logger: logging.Logger) -> None:
94-
if self.constant_gpu:
95-
filter_obj = self.filter_obj
96-
else:
97-
filter_obj = self.filter_class(**self.filter_kwargs)
88+
filter_obj = self.filter_class(**self.filter_kwargs)
9889

9990
columns_to_be_added = filter_obj.result_columns
10091
columns_intersection = set(processor.columns).intersection(set(columns_to_be_added))
@@ -141,4 +132,4 @@ def stage_name(self) -> str:
141132
def run(self, processor: DatasetProcessor, logger: logging.Logger) -> None:
142133
transforms = self.transforms_class(**self.transforms_kwargs)
143134

144-
processor.apply_transform(transforms, **self.processor_apply_kwargs) # type: ignore
135+
processor.apply_transform(transforms, **self.processor_apply_kwargs) # type: ignore

0 commit comments

Comments
 (0)