Skip to content

Commit 9826d8c

Browse files
Chia-Jung Changfacebook-github-bot
Chia-Jung Chang
authored andcommitted
Enable InputRecorder to run in offline mode and cuda device (#2067)
Summary: Pull Request resolved: #2067 Enable `InputRecorder` to run in offline mode and cuda device. Specificlaly `InputRecorder` is called during Sandcastle CI tests, which do not allow us to access HuggingFace website. Additionally, there is a bug in BUCK file for missing `lm_eval` dependency, which was called by `torchao`. Differential Revision: D73179650
1 parent b195c57 commit 9826d8c

File tree

1 file changed

+18
-9
lines changed

1 file changed

+18
-9
lines changed

torchao/_models/_eval.py

+18-9
Original file line numberDiff line numberDiff line change
@@ -183,12 +183,20 @@ def __init__(
183183
vocab_size=32000,
184184
pad_token=0,
185185
device="cpu",
186+
online_access=True,
186187
):
187-
try:
188-
super().__init__()
189-
except TypeError:
190-
# lm_eval 0.4.2 removed the default init
191-
super().__init__("gpt2", device="cpu")
188+
if online_access:
189+
try:
190+
super().__init__()
191+
except TypeError:
192+
# lm_eval 0.4.2 removed the default init
193+
super().__init__("gpt2", device=device)
194+
else:
195+
# Create a minimal implementation when run offline
196+
self._device = torch.device(device)
197+
self._max_length = 2048
198+
self._max_gen_toks = 256
199+
self._batch_size = 1
192200

193201
self.tokenizer = tokenizer
194202
self._device = torch.device(device)
@@ -257,10 +265,11 @@ def record_inputs(
257265
calibration_tasks,
258266
calibration_limit,
259267
):
260-
try:
261-
lm_eval.tasks.initialize_tasks()
262-
except:
263-
pass
268+
if self.online_access:
269+
try:
270+
lm_eval.tasks.initialize_tasks()
271+
except:
272+
pass
264273

265274
task_dict = get_task_dict(calibration_tasks)
266275
print("Obtaining GPTQ calibration inputs on: ", calibration_tasks)

0 commit comments

Comments
 (0)