@@ -68,8 +68,7 @@ def __init__(
68
68
filter_class : Union [type [DataFilter ], type [ColumnFilter ], type [MultiGPUDataFilter ]],
69
69
filter_kwargs : dict [str , Any ],
70
70
processor_apply_kwargs : Optional [dict [str , Any ]] = None ,
71
- skip_if_columns_exist : bool = True ,
72
- constant_gpu : bool = False
71
+ skip_if_columns_exist : bool = True
73
72
):
74
73
self .filter_type = filter_type
75
74
self .filter_class = filter_class
@@ -81,20 +80,12 @@ def __init__(
81
80
82
81
self .skip_if_columns_exist = skip_if_columns_exist
83
82
84
- self .constant_gpu = constant_gpu
85
- if constant_gpu :
86
- self .filter_obj = self .filter_class (** self .filter_kwargs )
87
-
88
-
89
83
@property
90
84
def stage_name (self ) -> str :
91
85
return f"FilterPipelineStage(filter_class={ self .filter_class } , filter_kwargs={ self .filter_kwargs } )"
92
86
93
87
def run (self , processor : DatasetProcessor , logger : logging .Logger ) -> None :
94
- if self .constant_gpu :
95
- filter_obj = self .filter_obj
96
- else :
97
- filter_obj = self .filter_class (** self .filter_kwargs )
88
+ filter_obj = self .filter_class (** self .filter_kwargs )
98
89
99
90
columns_to_be_added = filter_obj .result_columns
100
91
columns_intersection = set (processor .columns ).intersection (set (columns_to_be_added ))
@@ -141,4 +132,4 @@ def stage_name(self) -> str:
141
132
def run (self , processor : DatasetProcessor , logger : logging .Logger ) -> None :
142
133
transforms = self .transforms_class (** self .transforms_kwargs )
143
134
144
- processor .apply_transform (transforms , ** self .processor_apply_kwargs ) # type: ignore
135
+ processor .apply_transform (transforms , ** self .processor_apply_kwargs ) # type: ignore
0 commit comments