Skip to content

Commit b96c639

Browse files
[inference] Fix running time of test_continuous_batching (#5750)
1 parent 5f8c0a0 commit b96c639

File tree

1 file changed

+26
-58
lines changed

1 file changed

+26
-58
lines changed

tests/test_infer/test_continuous_batching.py

Lines changed: 26 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,10 @@
33
import numpy as np
44
import pytest
55
import torch
6-
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
6+
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
77

88
import colossalai
9-
from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
9+
from colossalai.inference.config import InferenceConfig
1010
from colossalai.inference.core.engine import InferenceEngine
1111
from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
1212

@@ -28,69 +28,37 @@ def generate_inputs(num_sequences, min_length, max_length):
2828
return sequences
2929

3030

31-
@parameterize(
32-
"test_config",
33-
[
34-
{
35-
"max_batch_size": 8,
36-
"max_output_len": 512,
37-
"max_input_len": 64,
38-
"do_sample": False,
39-
}
40-
],
41-
)
42-
def check_inference_engine(test_config, use_engine=False, prompt_template=None):
31+
@parameterize("n_multiple", [10])
32+
@parameterize("max_batch_size", [8])
33+
@parameterize("max_input_len", [128])
34+
@parameterize("max_output_len", [128])
35+
def check_inference_engine(n_multiple, max_batch_size, max_input_len, max_output_len):
4336
setup_seed(20)
44-
max_batch_size = test_config["max_batch_size"]
45-
max_input_len = test_config["max_input_len"]
46-
max_output_len = test_config["max_output_len"]
47-
do_sample = test_config["do_sample"]
48-
top_p = 0.5
49-
top_k = 50
50-
tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
51-
model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
37+
38+
tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
39+
model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2)).cuda()
5240
model = model.eval()
5341

54-
inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)
55-
56-
if use_engine:
57-
inference_config = InferenceConfig(
58-
max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
59-
)
60-
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
61-
assert inference_engine.generation_config.max_new_tokens == max_output_len
62-
inference_engine.add_request(prompts_token_ids=inputs_token_ids)
63-
assert inference_engine.request_handler._has_waiting()
64-
generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
65-
outputs = inference_engine.generate(generation_config=generation_config)
66-
else:
67-
if prompt_template:
68-
# apply prompt template
69-
inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
70-
tokenizer.pad_token = tokenizer.eos_token
71-
tokenizer.pad_token_id = tokenizer.eos_token_id
72-
inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
73-
inputs = inputs.cuda()
74-
generation_config = GenerationConfig(
75-
do_sample=do_sample,
76-
top_p=top_p,
77-
top_k=top_k,
78-
pad_token_id=tokenizer.pad_token_id,
79-
max_new_tokens=max_output_len,
80-
)
81-
outputs = model.generate(inputs, generation_config=generation_config)
82-
outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
83-
assert len(outputs) == 10 * max_batch_size
84-
85-
86-
@parameterize("prompt_template", [None, "llama"])
87-
def check_continuous_batching(prompt_template):
88-
check_inference_engine(use_engine=True, prompt_template=prompt_template)
42+
inputs_token_ids = generate_inputs(
43+
n_multiple * max_batch_size, min_length=max_input_len // 2, max_length=max_input_len
44+
)
45+
inference_config = InferenceConfig(
46+
max_batch_size=max_batch_size, max_input_len=max_input_len, max_output_len=max_output_len
47+
)
48+
inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
49+
assert inference_engine.generation_config.max_new_tokens == max_output_len
50+
51+
inference_engine.add_request(prompts_token_ids=inputs_token_ids)
52+
assert inference_engine.request_handler._has_waiting()
53+
54+
outputs = inference_engine.generate()
55+
assert not inference_engine.request_handler._has_waiting()
56+
assert len(outputs) == n_multiple * max_batch_size
8957

9058

9159
def run_dist(rank, world_size, port):
9260
colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
93-
check_continuous_batching()
61+
check_inference_engine()
9462

9563

9664
@pytest.mark.dist

0 commit comments

Comments
 (0)