import gc
import json
import os
from typing import Dict, Union, Optional, Tuple

import torch
from torch.nn import Module
from transformers import AutoModel, PreTrainedModel, PreTrainedTokenizer
from transformers.generation.logits_process import LogitsProcessor


def auto_configure_device_map(num_gpus: int) -> Dict[str, int]:
    # transformer.word_embeddings occupies 1 layer
    # transformer.final_layernorm and lm_head occupy 1 layer
    # transformer.layers occupies 28 layers
    # In total, 30 layers are distributed across num_gpus GPUs
    num_trans_layers = 28
    per_gpu_layers = 30 / num_gpus

    # bugfix: on Linux, the weight and input passed to torch.embedding ended up on different devices, causing a RuntimeError
    # On Windows, model.device is set to transformer.word_embeddings.device
    # On Linux, model.device is set to lm_head.device
    # When chat or stream_chat is called, input_ids is placed on model.device
    # If transformer.word_embeddings.device and model.device differ, a RuntimeError is raised
    # Therefore transformer.word_embeddings, transformer.final_layernorm and lm_head are all placed on the first GPU
    # This file is adapted from https://github.com/THUDM/ChatGLM-6B/blob/main/utils.py
    # with only minor changes here to support ChatGLM3
    device_map = {
        'transformer.embedding.word_embeddings': 0,
        'transformer.encoder.final_layernorm': 0,
        'transformer.output_layer': 0,
        'transformer.rotary_pos_emb': 0,
        'lm_head': 0
    }

    used = 2
    gpu_target = 0
    for i in range(num_trans_layers):
        if used >= per_gpu_layers:
            gpu_target += 1
            used = 0
        assert gpu_target < num_gpus
        device_map[f'transformer.encoder.layers.{i}'] = gpu_target
        used += 1

    return device_map
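
# Illustrative sketch (not part of the original file): with num_gpus=2,
# per_gpu_layers is 15, so GPU 0 receives the embedding/output modules plus
# encoder layers 0-12 (a "used" budget of 2 + 13 = 15) and GPU 1 receives
# encoder layers 13-27, e.g.
#   auto_configure_device_map(2)['transformer.encoder.layers.12']  # -> 0
#   auto_configure_device_map(2)['transformer.encoder.layers.13']  # -> 1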


def load_model_on_gpus(checkpoint_path: Union[str, os.PathLike], num_gpus: int = 2,
                       device_map: Optional[Dict[str, int]] = None, **kwargs) -> Module:
    if num_gpus < 2 and device_map is None:
        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half().cuda()
    else:
        from accelerate import dispatch_model

        model = AutoModel.from_pretrained(checkpoint_path, trust_remote_code=True, **kwargs).half()

        if device_map is None:
            device_map = auto_configure_device_map(num_gpus)

        model = dispatch_model(model, device_map=device_map)

    return model
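
# Usage sketch (illustrative; the checkpoint path and GPU count are
# placeholders, not taken from this file):
#   model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
#   model = model.eval()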


# If the logits contain NaN or Inf, reset all scores and force token id 5 so
# that sampling does not crash on an invalid distribution.
class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(
            self, input_ids: torch.LongTensor, scores: torch.FloatTensor
    ) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()
            scores[..., 5] = 5e4
        return scores


def process_response(output: str, use_tool: bool = False) -> Union[str, dict]:
    content = ""
    for response in output.split("<|assistant|>"):
        metadata, content = response.split("\n", maxsplit=1)
        if not metadata.strip():
            content = content.strip()
            content = content.replace("[[训练时间]]", "2023年")
        else:
            if use_tool:
                content = "\n".join(content.split("\n")[1:-1])

                def tool_call(**kwargs):
                    return kwargs

                parameters = eval(content)
                content = {
                    "name": metadata.strip(),
                    "arguments": json.dumps(parameters, ensure_ascii=False)
                }
            else:
                content = {
                    "name": metadata.strip(),
                    "content": content
                }
    return content
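
# Illustrative sketch (the reply text below is an assumed example, not taken
# from this file): for a tool call the model emits the function name as
# metadata followed by a fenced tool_call(...) expression, e.g.
#   process_response("get_weather\n```python\ntool_call(city='Beijing')\n```", use_tool=True)
# would return {"name": "get_weather", "arguments": '{"city": "Beijing"}'}.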


@torch.inference_mode()
def generate_stream_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    messages = params["messages"]
    functions = params["functions"]
    temperature = float(params.get("temperature", 1.0))
    repetition_penalty = float(params.get("repetition_penalty", 1.0))
    top_p = float(params.get("top_p", 1.0))
    max_new_tokens = int(params.get("max_tokens", 256))
    max_length = params.get("max_length", None)
    echo = params.get("echo", True)

    messages = process_chatglm_messages(messages, functions=functions)
    query, role = messages[-1]["content"], messages[-1]["role"]

    inputs = tokenizer.build_chat_input(query, history=messages[:-1], role=role)
    inputs = inputs.to(model.device)
    input_echo_len = len(inputs["input_ids"][0])

    if input_echo_len >= model.config.seq_length:
        print(f"Input length larger than {model.config.seq_length}")

    if max_length is None:
        max_length = min(max_new_tokens + input_echo_len, model.config.seq_length)

    eos_token_id = [
        tokenizer.eos_token_id,
        tokenizer.get_command("<|user|>"),
    ]

    gen_kwargs = {
        "max_length": max_length,
        "do_sample": True if temperature > 1e-5 else False,
        "top_p": top_p,
        "repetition_penalty": repetition_penalty,
        "logits_processor": [InvalidScoreLogitsProcessor()],
    }
    if temperature > 1e-5:
        gen_kwargs["temperature"] = temperature

    total_len = 0
    for total_ids in model.stream_generate(**inputs, eos_token_id=eos_token_id, **gen_kwargs):
        total_ids = total_ids.tolist()[0]
        total_len = len(total_ids)
        if echo:
            output_ids = total_ids[:-1]
        else:
            output_ids = total_ids[input_echo_len:-1]

        response = tokenizer.decode(output_ids)
        if response and response[-1] != "�":
            response, stop_found = apply_stopping_strings(response, ["<|observation|>"])

            yield {
                "text": response,
                "usage": {
                    "prompt_tokens": input_echo_len,
                    "completion_tokens": total_len - input_echo_len,
                    "total_tokens": total_len,
                },
                "finish_reason": "function_call" if stop_found else None,
            }

            if stop_found:
                break

    # Only the last streamed result carries a final finish_reason; set it to "stop".
    ret = {
        "text": response,
        "usage": {
            "prompt_tokens": input_echo_len,
            "completion_tokens": total_len - input_echo_len,
            "total_tokens": total_len,
        },
        "finish_reason": "stop",
    }
    yield ret

    gc.collect()
    torch.cuda.empty_cache()
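
# Usage sketch (illustrative, not from this file): each entry in "messages" is
# expected to expose .role/.content/.function_call attributes (for example a
# pydantic ChatMessage model from the accompanying API server), not a plain
# dict, because process_chatglm_messages reads those attributes:
#   params = {
#       "messages": [ChatMessage(role="user", content="Hello", function_call=None)],
#       "functions": None,
#       "temperature": 0.8,
#       "max_tokens": 128,
#   }
#   for chunk in generate_stream_chatglm3(model, tokenizer, params):
#       print(chunk["text"], chunk["finish_reason"])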


def process_chatglm_messages(messages, functions=None):
    _messages = messages
    messages = []

    if functions:
        messages.append(
            {
                "role": "system",
                "content": "Answer the following questions as best as you can. You have access to the following tools:",
                "tools": functions
            }
        )

    for m in _messages:
        role, content, func_call = m.role, m.content, m.function_call
        if role == "function":
            messages.append(
                {
                    "role": "observation",
                    "content": content
                }
            )

        elif role == "assistant" and func_call is not None:
            for response in content.split("<|assistant|>"):
                metadata, sub_content = response.split("\n", maxsplit=1)
                messages.append(
                    {
                        "role": role,
                        "metadata": metadata,
                        "content": sub_content.strip()
                    }
                )
        else:
            messages.append({"role": role, "content": content})
    return messages


def generate_chatglm3(model: PreTrainedModel, tokenizer: PreTrainedTokenizer, params: dict):
    # Drain the streaming generator and return only its final result.
    for response in generate_stream_chatglm3(model, tokenizer, params):
        pass
    return response


def apply_stopping_strings(reply, stop_strings) -> Tuple[str, bool]:
    stop_found = False
    for string in stop_strings:
        idx = reply.find(string)
        if idx != -1:
            reply = reply[:idx]
            stop_found = True
            break

    if not stop_found:
        # If something like "\nYo" is generated just before "\nYou:" is completed, trim it
        for string in stop_strings:
            for j in range(len(string) - 1, 0, -1):
                if reply[-j:] == string[:j]:
                    reply = reply[:-j]
                    break
            else:
                continue

            break

    return reply, stop_found
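
# Example (illustrative values):
#   apply_stopping_strings("Hi<|observation|>junk", ["<|observation|>"])  # -> ("Hi", True)
# A partially generated stop string such as a trailing "<|obs" is trimmed from
# the reply and stop_found remains False.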