 import numpy as np
 import pytest
 import torch
-from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
+from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM
 
 import colossalai
-from colossalai.inference.config import _DEFAULT_PROMPT_TEMPLATES, InferenceConfig
+from colossalai.inference.config import InferenceConfig
 from colossalai.inference.core.engine import InferenceEngine
 from colossalai.testing import parameterize, rerun_if_address_is_in_use, spawn
 
@@ -28,69 +28,37 @@ def generate_inputs(num_sequences, min_length, max_length):
     return sequences
 
 
-@parameterize(
-    "test_config",
-    [
-        {
-            "max_batch_size": 8,
-            "max_output_len": 512,
-            "max_input_len": 64,
-            "do_sample": False,
-        }
-    ],
-)
-def check_inference_engine(test_config, use_engine=False, prompt_template=None):
+@parameterize("n_multiple", [10])
+@parameterize("max_batch_size", [8])
+@parameterize("max_input_len", [128])
+@parameterize("max_output_len", [128])
+def check_inference_engine(n_multiple, max_batch_size, max_input_len, max_output_len):
     setup_seed(20)
-    max_batch_size = test_config["max_batch_size"]
-    max_input_len = test_config["max_input_len"]
-    max_output_len = test_config["max_output_len"]
-    do_sample = test_config["do_sample"]
-    top_p = 0.5
-    top_k = 50
-    tokenizer = AutoTokenizer.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
-    model = LlamaForCausalLM.from_pretrained("TinyLlama/TinyLlama-1.1B-Chat-v1.0").cuda().half()
+
+    tokenizer = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")
+    model = LlamaForCausalLM(LlamaConfig(num_hidden_layers=2)).cuda()
     model = model.eval()
 
-    inputs_token_ids = generate_inputs(10 * max_batch_size, min_length=10, max_length=max_input_len)
-
-    if use_engine:
-        inference_config = InferenceConfig(
-            max_batch_size=max_batch_size, max_output_len=max_output_len, prompt_template=prompt_template
-        )
-        inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
-        assert inference_engine.generation_config.max_new_tokens == max_output_len
-        inference_engine.add_request(prompts_token_ids=inputs_token_ids)
-        assert inference_engine.request_handler._has_waiting()
-        generation_config = GenerationConfig(do_sample=do_sample, top_p=top_p, top_k=top_k)
-        outputs = inference_engine.generate(generation_config=generation_config)
-    else:
-        if prompt_template:
-            # apply prompt template
-            inputs = [_DEFAULT_PROMPT_TEMPLATES[prompt_template].format(input_text=input_text) for input_text in inputs]
-        tokenizer.pad_token = tokenizer.eos_token
-        tokenizer.pad_token_id = tokenizer.eos_token_id
-        inputs = tokenizer.batch_encode_plus(inputs, padding=True, return_tensors="pt")["input_ids"]
-        inputs = inputs.cuda()
-        generation_config = GenerationConfig(
-            do_sample=do_sample,
-            top_p=top_p,
-            top_k=top_k,
-            pad_token_id=tokenizer.pad_token_id,
-            max_new_tokens=max_output_len,
-        )
-        outputs = model.generate(inputs, generation_config=generation_config)
-        outputs = tokenizer.batch_decode(outputs, skip_special_tokens=True)
-    assert len(outputs) == 10 * max_batch_size
-
-
-@parameterize("prompt_template", [None, "llama"])
-def check_continuous_batching(prompt_template):
-    check_inference_engine(use_engine=True, prompt_template=prompt_template)
+    inputs_token_ids = generate_inputs(
+        n_multiple * max_batch_size, min_length=max_input_len // 2, max_length=max_input_len
+    )
+    inference_config = InferenceConfig(
+        max_batch_size=max_batch_size, max_input_len=max_input_len, max_output_len=max_output_len
+    )
+    inference_engine = InferenceEngine(model, tokenizer, inference_config, verbose=True)
+    assert inference_engine.generation_config.max_new_tokens == max_output_len
+
+    inference_engine.add_request(prompts_token_ids=inputs_token_ids)
+    assert inference_engine.request_handler._has_waiting()
+
+    outputs = inference_engine.generate()
+    assert not inference_engine.request_handler._has_waiting()
+    assert len(outputs) == n_multiple * max_batch_size
 
 
 def run_dist(rank, world_size, port):
     colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost")
-    check_continuous_batching()
+    check_inference_engine()
 
 
 @pytest.mark.dist