# coding=utf-8
import json
import time
import uuid
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import List, Literal, Optional, Union

import requests
import torch
import uvicorn
from colorama import Fore
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModel

from api import get_complete_docs
from tool_using.tool_register import get_tools, dispatch_tool
from utils import process_response, generate_chatglm3, generate_stream_chatglm3

# System prompt (kept in Chinese, since the service answers in Chinese): "You are an AI assistant for the
# agriculture domain. Answer using the retrieved context and the conversation history; the context is not
# guaranteed to be correct, so also reason over the history, and fall back to your own knowledge rather
# than making things up. Give the answer directly, in Chinese."
SYSTEM_PROMPT = """你是一个农业领域的AI助手,你需要通过已知信息和历史对话来回答人类的问题。
特别注意:已知信息不一定是正确答案,还要根据对话的历史问题和答案进行思考回答,如果找不到答案,请用自己的知识进行回答,不可以乱回答!!!
直接给出答案即可,答案使用中文。\n"""


@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory on shutdown
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ClearHistory(BaseModel):
    code: int
    message: str


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "function"]
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    kb_name: str
    stream: Optional[bool] = False
    functions: Optional[Union[dict, List[dict]]] = None
    # Additional parameters
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    max_length: Optional[int] = None
    repetition_penalty: Optional[float] = 1.1
    top_k: int = 3  # number of knowledge-base documents to recall
    score_threshold: float = 1.0  # knowledge-base strictness (retrieval score threshold)
    system_prompt: str = SYSTEM_PROMPT
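
# A sketch of a request body this schema accepts (values below are illustrative placeholders,
# not names defined elsewhere in this project):
# {
#     "model": "chatglm3-6b",
#     "kb_name": "my_kb",
#     "stream": false,
#     "messages": [{"role": "user", "content": "your question here"}],
#     "top_k": 3,
#     "score_threshold": 1.0
# }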


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call", "finish"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "function_call", "finish"]] = None


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class History(BaseModel):
    id: str
    question: str
    messages: List[ChatMessage]
    query: Optional[str] = None
    wiki_content: Optional[List[str]] = None


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None
    history: Optional[History] = None


def get_contexts(query, top_k, score_threshold, kb_name):
    """Retrieve matching documents from the knowledge base for the given query."""
    return get_complete_docs(query, top_k, score_threshold, kb_name)


# Authentication
async def verify_token(request: Request):
    auth_header = request.headers.get('Authorization')
    if auth_header:
        token_type, _, token = auth_header.partition(' ')
        if (
            token_type.lower() == "bearer"
            and token == "sk-chatglm3-6b"
        ):  # configure your token here
            return True
    raise HTTPException(
        status_code=401,
        detail="Invalid authorization credentials",
    )
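
# Note: verify_token is defined but not attached to any route below. If you want to enforce it,
# one option (standard FastAPI usage) is to declare it as a route dependency, e.g.:
#   @app.post("/v1/chat/completions", dependencies=[Depends(verify_token)])
# Clients would then send the header:  Authorization: Bearer sk-chatglm3-6b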


@app.get("/v1/clear/history")
async def clear_history():
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    except Exception:
        raise HTTPException(status_code=500, detail="Internal 500 Error")
    return ClearHistory(code=200, message="success")
@app.get("/v1/models", response_model=ModelList)
|
||
async def list_models():
|
||
model_card = ModelCard(id="chatglm3-6b-32k")
|
||
return ModelList(data=[model_card])
|
||
|
||
|
||
@app.get("/v1/tool", response_model=ChatCompletionResponse)
|
||
def run_conversation(query: str, stream=False, functions=None):
|
||
params = dict(model="chatglm3-6b", messages=[{"role": "user", "content": query}], kb_name="ceshi", stream=stream)
|
||
if functions:
|
||
params["functions"] = functions
|
||
else:
|
||
params["functions"] = get_tools()
|
||
response = requests.post(f"{base_url}/v1/chat/completions", json=params, stream=False)
|
||
max_retry = 1
|
||
response = ChatCompletionResponse(**response.json())
|
||
print(response)
|
||
for _ in range(max_retry):
|
||
if not stream:
|
||
if response.choices[0].message.function_call:
|
||
function_call = response.choices[0].message.function_call
|
||
logger.info(f"Function Call Response: {function_call.json(ensure_ascii=False)}")
|
||
function_args = json.loads(function_call.arguments)
|
||
tool_response = dispatch_tool(function_call.name, function_args)
|
||
logger.info(f"Tool Call Response: {tool_response}")
|
||
msg = response.choices[0].message
|
||
params["messages"].append({
|
||
'role': msg.role,
|
||
'content': msg.content
|
||
})
|
||
params["messages"].append({
|
||
"role": "function",
|
||
"name": function_call.name,
|
||
"content": tool_response, # 调用函数返回结果
|
||
})
|
||
else:
|
||
reply = response.choices[0].message.content
|
||
logger.info(f"Final Reply: \n{reply}")
|
||
return
|
||
|
||
else:
|
||
output = ""
|
||
for chunk in response:
|
||
content = chunk.choices[0].delta.content or ""
|
||
print(Fore.BLUE + content, end="", flush=True)
|
||
output += content
|
||
|
||
if chunk.choices[0].finish_reason == "stop":
|
||
return
|
||
|
||
elif chunk.choices[0].finish_reason == "function_call":
|
||
print("\n")
|
||
|
||
function_call = chunk.choices[0].delta.function_call
|
||
logger.info(f"Function Call Response: {function_call.model_dump()}")
|
||
|
||
function_args = json.loads(function_call.arguments)
|
||
tool_response = dispatch_tool(function_call.name, function_args)
|
||
logger.info(f"Tool Call Response: {tool_response}")
|
||
|
||
params["messages"].append(
|
||
{
|
||
"role": "assistant",
|
||
"content": output
|
||
}
|
||
)
|
||
params["messages"].append(
|
||
{
|
||
"role": "function",
|
||
"name": function_call.name,
|
||
"content": tool_response, # 调用函数返回结果
|
||
}
|
||
)
|
||
|
||
break
|
||
print(params)
|
||
res = requests.post(f"{base_url}/v1/chat/completions", json=params)
|
||
return ChatCompletionResponse(**res.json())
|
||
|
||
|
||
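
# A minimal sketch of exercising this route once the server is running
# (assumes the default host/port configured in __main__ below):
#   curl "http://localhost:8000/v1/tool?query=<your question>"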


def init(request):
    """Build the RAG prompt: retrieve knowledge-base context and rewrite the last user message."""
    # TODO: if the user supplied their own SYSTEM message, handle it separately
    r = deepcopy(request)
    # query string for the vector store
    query = "\n".join([m.content for m in r.messages])
    # the user's question
    question = r.messages[-1].content
    # number of documents to recall
    top_k = r.top_k
    # retrieval score threshold
    score_threshold = r.score_threshold
    # knowledge-base name
    kb_name = r.kb_name
    # context retrieved from the vector store
    contexts = get_contexts(query, top_k, score_threshold, kb_name)
    # prepend the system prompt
    system_prompt = r.system_prompt or SYSTEM_PROMPT
    r.messages.insert(0, ChatMessage(role='system', content=system_prompt))
    # rebuilt message list
    new_messages = r.messages
    # knowledge-base template: "Known information: <retrieved docs>  Question: <user question>"
    wiki_content = "\n".join(contexts)
    prompt = f"已知信息:\n{wiki_content}\n问题:\n{question}\n"
    new_messages[-1].content = prompt
    return query, question, contexts, prompt, new_messages
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
||
async def create_chat_completion(request: ChatCompletionRequest):
|
||
# TODO agent调用 需要判断function是否存在
|
||
global model, tokenizer
|
||
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
|
||
raise HTTPException(status_code=400, detail="Invalid request")
|
||
# TODO 添加wiki模板提示词
|
||
query, question, contexts, prompt, new_messages = init(request)
|
||
###########################################################
|
||
gen_params = dict(
|
||
messages=request.messages if request.functions is not None else new_messages,
|
||
temperature=request.temperature,
|
||
top_p=request.top_p,
|
||
max_tokens=request.max_tokens or 1024,
|
||
max_length=request.max_length or 8192,
|
||
echo=False,
|
||
stream=request.stream,
|
||
repetition_penalty=request.repetition_penalty,
|
||
functions=request.functions,
|
||
system_prompt=request.system_prompt or SYSTEM_PROMPT,
|
||
top_k=request.top_k or 3,
|
||
score_threshold=request.score_threshold or 1.0,
|
||
kb_name=request.kb_name,
|
||
)
|
||
logger.debug(f"==== request ====\n{gen_params}")
|
||
# 流式输出
|
||
if request.stream:
|
||
generate = predict(request.model, gen_params, request.messages, question, query, contexts)
|
||
return EventSourceResponse(generate, media_type="text/event-stream")
|
||
#############################################################################################
|
||
# 非流式输出
|
||
response = generate_chatglm3(model, tokenizer, gen_params)
|
||
usage = UsageInfo()
|
||
|
||
function_call, finish_reason = None, "stop"
|
||
if request.functions:
|
||
try:
|
||
function_call = process_response(response["text"], use_tool=True)
|
||
except:
|
||
logger.warning("Failed to parse tool call")
|
||
|
||
if isinstance(function_call, dict):
|
||
finish_reason = "function_call"
|
||
function_call = FunctionCallResponse(**function_call)
|
||
|
||
message = ChatMessage(
|
||
role="assistant",
|
||
content=response["text"],
|
||
function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
|
||
)
|
||
# 添加历史记录
|
||
request.messages.append(message)
|
||
h = History(id=str(uuid.uuid4()), messages=request.messages, question=question, query=query, wiki_content=contexts)
|
||
choice_data = ChatCompletionResponseChoice(
|
||
index=0,
|
||
message=message,
|
||
finish_reason=finish_reason,
|
||
)
|
||
task_usage = UsageInfo.parse_obj(response["usage"])
|
||
for usage_key, usage_value in task_usage.dict().items():
|
||
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
||
|
||
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage,
|
||
history=h)
|
||
|
||
|
||


async def predict(model_id: str, params: dict, messages: List[ChatMessage], question, query, contexts):
    global model, tokenizer
    all_text = ""
    # the first chunk carries only the role, no content
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    previous_text = ""
    f = None
    for new_response in generate_stream_chatglm3(model, tokenizer, params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode

        finish_reason = new_response["finish_reason"]
        if len(delta_text) == 0 and finish_reason != "function_call":
            continue

        function_call = None
        if finish_reason == "function_call":
            try:
                function_call = process_response(decoded_unicode, use_tool=True)
            except Exception:
                logger.warning("Failed to parse tool call")

        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)
            f = function_call

        delta = DeltaMessage(
            content=delta_text,
            role="assistant",
            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
        )
        all_text += delta_text
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=delta,
            finish_reason=finish_reason
        )
        # add history here
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    messages.append(ChatMessage(role='assistant', content=all_text, name='', function_call=f))
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    h = History(id=str(uuid.uuid4()), question=question, messages=messages, query=query, wiki_content=contexts)
    yield '[DONE]'
    # one final chunk after [DONE] that carries the accumulated history
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=all_text, function_call=f),
        finish_reason="finish"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk", history=h)
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
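
# How a client might consume the stream (a sketch, not part of this server): sse-starlette wraps each
# yielded string as an SSE "data:" line, so with requests one can do roughly:
#   resp = requests.post(f"{base_url}/v1/chat/completions", json=payload, stream=True)
#   for line in resp.iter_lines():
#       if line.startswith(b"data: "):
#           body = line[len(b"data: "):]
#           if body == b"[DONE]":
#               pass  # the final history-bearing chunk still follows after [DONE] (see above)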


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).cuda()
    # Multi-GPU support: use the two lines below instead of the line above, and set num_gpus
    # to the number of GPUs you actually have.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
    model = model.eval()
    host = "localhost"
    port = 8000
    base_url = f"http://{host}:{port}"
    uvicorn.run(app, host=host, port=port, workers=1)
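
# Example client call (a sketch; start the server first, then run from another process;
# "my_kb" is an illustrative knowledge-base name, not one defined in this project):
#   import requests
#   payload = {
#       "model": "chatglm3-6b",
#       "kb_name": "my_kb",
#       "messages": [{"role": "user", "content": "your question here"}],
#   }
#   r = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
#   print(r.json()["choices"][0]["message"]["content"])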