generation.py
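"""Generation utilities for the BigCode evaluation harness (`bigcode_eval`):
custom stopping criteria plus a `parallel_generations` helper that runs (or
loads from disk) model generations across the processes managed by
`accelerate`."""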
import json
from math import ceil
from typing import List, Optional

from accelerate.utils import set_seed
from torch.utils.data.dataloader import DataLoader
from transformers import StoppingCriteria, StoppingCriteriaList

from bigcode_eval.utils import TokenizedDataset, complete_code


class EndOfFunctionCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` which checks if all generated functions in the batch are completed."""

    def __init__(self, start_length, eof_strings, tokenizer, check_fn=None):
        self.start_length = start_length
        self.eof_strings = eof_strings
        self.tokenizer = tokenizer
        if check_fn is None:
            check_fn = lambda decoded_generation: any(
                [stop_string in decoded_generation for stop_string in self.eof_strings]
            )
        self.check_fn = check_fn

    def __call__(self, input_ids, scores, **kwargs):
        """Returns true if all generated sequences contain any of the end-of-function strings."""
        decoded_generations = self.tokenizer.batch_decode(input_ids[:, self.start_length :])
        return all(
            [self.check_fn(decoded_generation) for decoded_generation in decoded_generations]
        )


class TooLongFunctionCriteria(StoppingCriteria):
    """Custom `StoppingCriteria` which checks if the generated function is too long by a certain multiplier based on input length."""

    def __init__(self, input_length, multiplier):
        self.input_length = input_length
        self.multiplier = multiplier

    def __call__(self, input_ids, scores, **kwargs):
        """Returns true if generated sequence is too long."""
        return input_ids.shape[1] > int(self.input_length * self.multiplier)


def parallel_generations(
    task,
    dataset,
    accelerator,
    model,
    tokenizer,
    n_tasks,
    args,
    curr_sample_idx: int = 0,
    save_every_k_tasks: int = -1,
    intermediate_generations: Optional[List[Optional[List[Optional[str]]]]] = None,
    intermediate_save_generations_path: Optional[str] = None,
):
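    """Generate (or load) candidate solutions for `n_tasks` problems, distributed
    across the processes managed by `accelerator`.

    If `args.load_generations_path` is set, previously generated code is loaded
    from disk instead of running the model. Otherwise prompts are tokenized,
    wrapped in a DataLoader, and passed to `complete_code`, which returns a list
    of candidate generations per task.
    """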
    if args.load_generations_path:
        # load generated code
        with open(args.load_generations_path) as fp:
            generations = json.load(fp)
        if accelerator.is_main_process:
            print(
                f"generations loaded, {n_tasks} selected from {len(generations)} with {len(generations[0])} candidates"
            )
        return generations[:n_tasks]
    set_seed(args.seed, device_specific=True)

    # Setup generation settings
    gen_kwargs = {
        "do_sample": args.do_sample,
        "temperature": args.temperature,
        "top_p": args.top_p,
        "top_k": args.top_k,
        "max_new_tokens": args.max_length_generation,
    }
    stopping_criteria = []
    # input_length / start_length are set to 0 here and adjusted later,
    # once the actual prompt length is known.
    # Check if the task has a custom check_fn method for the stopping criteria
    if task.stop_words and tokenizer.eos_token:
        task.stop_words.append(tokenizer.eos_token)
    if hasattr(task, "check_fn"):
        stopping_criteria.append(
            EndOfFunctionCriteria(0, task.stop_words, tokenizer, task.check_fn)
        )
    elif task.stop_words:
        stopping_criteria.append(
            EndOfFunctionCriteria(0, task.stop_words, tokenizer)
        )
    if hasattr(task, "max_length_multiplier") and task.max_length_multiplier:
        stopping_criteria.append(
            TooLongFunctionCriteria(0, task.max_length_multiplier)
        )

    if stopping_criteria:
        gen_kwargs["stopping_criteria"] = StoppingCriteriaList(stopping_criteria)
    if args.instruction_tokens:
        instruction_tokens = args.instruction_tokens.split(",")
        if len(instruction_tokens) != 3:
            raise ValueError(
                "Instruction tokens should contain exactly 3 tokens separated by a comma. If a token is empty, represent it as ''"
            )
        for token in instruction_tokens:
            if token.strip() != "":
                task.stop_words.append(token)
    else:
        instruction_tokens = None
    if accelerator.is_main_process:
        print(f"number of problems for this task is {n_tasks}")
    n_copies = ceil(args.n_samples / args.batch_size)
    ds_tokenized = TokenizedDataset(
        task,
        dataset,
        tokenizer,
        num_devices=accelerator.state.num_processes,
        max_length=args.max_length_generation,
        limit_start=args.limit_start + curr_sample_idx,
        n_tasks=n_tasks,
        n_copies=n_copies,
        prefix=args.prefix,
        has_encoder=args.modeltype == "seq2seq",
        instruction_tokens=instruction_tokens,
    )
    # Note: args.batch_size is the number of return sequences per prompt
    # (num_return_sequences), not the DataLoader batch size, which is 1.
    ds_loader = DataLoader(ds_tokenized, batch_size=1)
    is_loaded_in_8bit = getattr(model, "is_loaded_in_8bit", False)
    is_loaded_in_4bit = getattr(model, "is_loaded_in_4bit", False)
    if args.max_memory_per_gpu is not None:
        # The model is already sharded across multiple GPUs
        ds_loader = accelerator.prepare(ds_loader)
    elif not is_loaded_in_8bit and not is_loaded_in_4bit:
        # we only wrap the data loader to avoid extra memory occupation
        model = model.to(accelerator.device)
        ds_loader = accelerator.prepare(ds_loader)
    else:
        # model.to() is not supported for 8bit and 4bit models
        model, ds_loader = accelerator.prepare(model, ds_loader)
    generations = complete_code(
        task,
        accelerator,
        model,
        tokenizer,
        ds_loader,
        n_tasks=n_tasks,
        limit_start=args.limit_start + curr_sample_idx,
        batch_size=args.batch_size,
        prefix=args.prefix,
        instruction_tokens=instruction_tokens,
        postprocess=args.postprocess,
        is_wrapped=is_loaded_in_8bit or is_loaded_in_4bit,
        save_every_k_tasks=save_every_k_tasks,
        intermediate_generations=intermediate_generations,
        intermediate_save_generations_path=intermediate_save_generations_path,
        **gen_kwargs,
    )
    return generations
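

# A minimal usage sketch (not part of the module): `parallel_generations` is
# typically driven by the harness's evaluator, with `args` supplying the flags
# referenced above (seed, do_sample, temperature, n_samples, batch_size, ...).
# The exact call site may differ; this only illustrates the signature:
#
#   generations = parallel_generations(
#       task, dataset, accelerator, model, tokenizer,
#       n_tasks=n_tasks, args=args,
#   )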