# coding=utf-8
import json
import time
import uuid
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import List, Literal, Optional, Union

from colorama import init, Fore
import requests
import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModel

from api import get_complete_docs
from tool_using.tool_register import get_tools, dispatch_tool
from utils import process_response, generate_chatglm3, generate_stream_chatglm3

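# Note: `api`, `tool_using.tool_register`, and `utils` above are project-local modules
# (document retrieval, tool registry/dispatch, and ChatGLM3 generation helpers respectively),
# not third-party packages.
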
SYSTEM_PROMPT = """你是一个农业领域的AI助手,你需要通过已知信息和历史对话来回答人类的问题。
|
|||
|
特别注意:已知信息不一定是正确答案,还要根据对话的历史问题和答案进行思考回答,如果找不到答案,请用自己的知识进行回答,不可以乱回答!!!
|
|||
|
直接给出答案即可,答案使用中文。\n"""
|
|||
|
|
|||
|
|
|||
|
@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ClearHistory(BaseModel):
    code: int
    message: str


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "function"]
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    kb_name: str
    stream: Optional[bool] = False
    functions: Optional[Union[dict, List[dict]]] = None
    # Additional parameters
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    max_length: Optional[int] = None
    repetition_penalty: Optional[float] = 1.1
    top_k: int = 3  # number of knowledge-base documents to retrieve
    score_threshold: float = 1.0  # knowledge-base score threshold (retrieval strictness)
    system_prompt: str = SYSTEM_PROMPT


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call", "finish"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "function_call", "finish"]] = None


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class History(BaseModel):
    id: str
    question: str
    messages: List[ChatMessage]
    query: Optional[str] = None
    wiki_content: Optional[List[str]] = None


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None
    history: Optional[History] = None


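# Illustrative shape of a non-streaming /v1/chat/completions response built from the models
# above (field values are placeholders, not real output):
#
#   {
#     "model": "chatglm3-6b",
#     "object": "chat.completion",
#     "choices": [
#       {"index": 0,
#        "message": {"role": "assistant", "content": "...", "function_call": null},
#        "finish_reason": "stop"}
#     ],
#     "created": 1700000000,
#     "usage": {"prompt_tokens": 0, "total_tokens": 0, "completion_tokens": 0},
#     "history": {"id": "<uuid>", "question": "...", "messages": [...],
#                 "query": "...", "wiki_content": ["..."]}
#   }
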
def get_contexts(query, top_k, score_threshold, kb_name):
    return get_complete_docs(query, top_k, score_threshold, kb_name)


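# Assumption based on how `contexts` is used below ("\n".join(contexts), History.wiki_content):
# get_complete_docs() is expected to return a list of document strings retrieved from the
# knowledge base `kb_name`, limited to `top_k` results and filtered by `score_threshold`.
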
# Token verification
async def verify_token(request: Request):
    auth_header = request.headers.get('Authorization')
    if auth_header:
        token_type, _, token = auth_header.partition(' ')
        if (
            token_type.lower() == "bearer"
            and token == "sk-chatglm3-6b"
        ):  # configure your token here
            return True
    raise HTTPException(
        status_code=401,
        detail="Invalid authorization credentials",
    )


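# verify_token is defined but not attached to any route in this file. A minimal sketch
# (an assumption, not part of the original API) of how it could be wired in via FastAPI's
# dependency injection:
#
#   @app.get("/v1/protected")
#   async def protected(_: bool = Depends(verify_token)):
#       return {"ok": True}
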
@app.get("/v1/clear/history")
|
|||
|
async def clear_history():
|
|||
|
try:
|
|||
|
if torch.cuda.is_available():
|
|||
|
torch.cuda.empty_cache()
|
|||
|
torch.cuda.ipc_collect()
|
|||
|
except:
|
|||
|
raise HTTPException(status_code=500, detail="Internal 500 Error")
|
|||
|
return ClearHistory(code=200, message="success")
|
|||
|
|
|||
|
|
|||
|
@app.get("/v1/models", response_model=ModelList)
|
|||
|
async def list_models():
|
|||
|
model_card = ModelCard(id="chatglm3-6b-32k")
|
|||
|
return ModelList(data=[model_card])
|
|||
|
|
|||
|
|
|||
|
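# Example request (assuming the server is running on the host/port configured in __main__):
#   curl http://localhost:8000/v1/models
# returns {"object": "list", "data": [{"id": "chatglm3-6b-32k", ...}]}
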
@app.get("/v1/tool", response_model=ChatCompletionResponse)
|
|||
|
def run_conversation(query: str, stream=False, functions=None):
|
|||
|
params = dict(model="chatglm3-6b", messages=[{"role": "user", "content": query}], kb_name="ceshi", stream=stream)
|
|||
|
if functions:
|
|||
|
params["functions"] = functions
|
|||
|
else:
|
|||
|
params["functions"] = get_tools()
|
|||
|
response = requests.post(f"{base_url}/v1/chat/completions", json=params, stream=False)
|
|||
|
max_retry = 1
|
|||
|
response = ChatCompletionResponse(**response.json())
|
|||
|
print(response)
|
|||
|
for _ in range(max_retry):
|
|||
|
if not stream:
|
|||
|
if response.choices[0].message.function_call:
|
|||
|
function_call = response.choices[0].message.function_call
|
|||
|
logger.info(f"Function Call Response: {function_call.json(ensure_ascii=False)}")
|
|||
|
function_args = json.loads(function_call.arguments)
|
|||
|
tool_response = dispatch_tool(function_call.name, function_args)
|
|||
|
logger.info(f"Tool Call Response: {tool_response}")
|
|||
|
msg = response.choices[0].message
|
|||
|
params["messages"].append({
|
|||
|
'role': msg.role,
|
|||
|
'content': msg.content
|
|||
|
})
|
|||
|
params["messages"].append({
|
|||
|
"role": "function",
|
|||
|
"name": function_call.name,
|
|||
|
"content": tool_response, # 调用函数返回结果
|
|||
|
})
|
|||
|
else:
|
|||
|
reply = response.choices[0].message.content
|
|||
|
logger.info(f"Final Reply: \n{reply}")
|
|||
|
return
|
|||
|
|
|||
|
else:
|
|||
|
output = ""
|
|||
|
for chunk in response:
|
|||
|
content = chunk.choices[0].delta.content or ""
|
|||
|
print(Fore.BLUE + content, end="", flush=True)
|
|||
|
output += content
|
|||
|
|
|||
|
if chunk.choices[0].finish_reason == "stop":
|
|||
|
return
|
|||
|
|
|||
|
elif chunk.choices[0].finish_reason == "function_call":
|
|||
|
print("\n")
|
|||
|
|
|||
|
function_call = chunk.choices[0].delta.function_call
|
|||
|
logger.info(f"Function Call Response: {function_call.model_dump()}")
|
|||
|
|
|||
|
function_args = json.loads(function_call.arguments)
|
|||
|
tool_response = dispatch_tool(function_call.name, function_args)
|
|||
|
logger.info(f"Tool Call Response: {tool_response}")
|
|||
|
|
|||
|
params["messages"].append(
|
|||
|
{
|
|||
|
"role": "assistant",
|
|||
|
"content": output
|
|||
|
}
|
|||
|
)
|
|||
|
params["messages"].append(
|
|||
|
{
|
|||
|
"role": "function",
|
|||
|
"name": function_call.name,
|
|||
|
"content": tool_response, # 调用函数返回结果
|
|||
|
}
|
|||
|
)
|
|||
|
|
|||
|
break
|
|||
|
print(params)
|
|||
|
res = requests.post(f"{base_url}/v1/chat/completions", json=params)
|
|||
|
return ChatCompletionResponse(**res.json())
|
|||
|
|
|||
|
|
|||
|
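# Illustrative flow of run_conversation (the question text is a made-up example; "ceshi" is the
# knowledge-base name hard-coded above):
#
#   run_conversation("What is the weather like in Beijing today?", functions=get_tools())
#
# If the first /v1/chat/completions call answers with finish_reason == "function_call", the tool
# is dispatched locally, its result is appended as a "function" message, and the request is
# replayed once (max_retry == 1) to obtain the final natural-language reply.
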
def init(request):
    # TODO: if the user supplied their own SYSTEM message, handle it separately
    r = deepcopy(request)
    # query string for the vector store
    query = "\n".join([m.content for m in r.messages])
    # the user's question
    question = r.messages[-1].content
    # number of documents to retrieve
    top_k = r.top_k
    # retrieval score threshold
    score_threshold = r.score_threshold
    # knowledge-base name
    kb_name = r.kb_name
    # documents retrieved from the vector store
    contexts = get_contexts(query, top_k, score_threshold, kb_name)
    # prepend the system prompt
    system_prompt = r.system_prompt or SYSTEM_PROMPT
    r.messages.insert(0, ChatMessage(role='system', content=system_prompt))
    # rebuild the message list
    new_messages = r.messages
    # knowledge-base prompt template (Chinese): "Known information:\n<docs>\nQuestion:\n<question>\n"
    wiki_content = "\n".join(contexts)
    prompt = f"已知信息:\n{wiki_content}\n问题:\n{question}\n"
    new_messages[-1].content = prompt
    return query, question, contexts, prompt, new_messages


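# Illustrative effect of init() on a request (values are placeholders): a system message with
# SYSTEM_PROMPT is prepended, and the last user message
#   {"role": "user", "content": "<question>"}
# is rewritten to
#   "已知信息:\n<retrieved docs>\n问题:\n<question>\n"
# so the model sees the retrieved context together with the original question.
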
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
|
|||
|
async def create_chat_completion(request: ChatCompletionRequest):
|
|||
|
# TODO agent调用 需要判断function是否存在
|
|||
|
global model, tokenizer
|
|||
|
if len(request.messages) < 1 or request.messages[-1].role == "assistant":
|
|||
|
raise HTTPException(status_code=400, detail="Invalid request")
|
|||
|
# TODO 添加wiki模板提示词
|
|||
|
query, question, contexts, prompt, new_messages = init(request)
|
|||
|
###########################################################
|
|||
|
gen_params = dict(
|
|||
|
messages=request.messages if request.functions is not None else new_messages,
|
|||
|
temperature=request.temperature,
|
|||
|
top_p=request.top_p,
|
|||
|
max_tokens=request.max_tokens or 1024,
|
|||
|
max_length=request.max_length or 8192,
|
|||
|
echo=False,
|
|||
|
stream=request.stream,
|
|||
|
repetition_penalty=request.repetition_penalty,
|
|||
|
functions=request.functions,
|
|||
|
system_prompt=request.system_prompt or SYSTEM_PROMPT,
|
|||
|
top_k=request.top_k or 3,
|
|||
|
score_threshold=request.score_threshold or 1.0,
|
|||
|
kb_name=request.kb_name,
|
|||
|
)
|
|||
|
logger.debug(f"==== request ====\n{gen_params}")
|
|||
|
# 流式输出
|
|||
|
if request.stream:
|
|||
|
generate = predict(request.model, gen_params, request.messages, question, query, contexts)
|
|||
|
return EventSourceResponse(generate, media_type="text/event-stream")
|
|||
|
#############################################################################################
|
|||
|
# 非流式输出
|
|||
|
response = generate_chatglm3(model, tokenizer, gen_params)
|
|||
|
usage = UsageInfo()
|
|||
|
|
|||
|
function_call, finish_reason = None, "stop"
|
|||
|
if request.functions:
|
|||
|
try:
|
|||
|
function_call = process_response(response["text"], use_tool=True)
|
|||
|
except:
|
|||
|
logger.warning("Failed to parse tool call")
|
|||
|
|
|||
|
if isinstance(function_call, dict):
|
|||
|
finish_reason = "function_call"
|
|||
|
function_call = FunctionCallResponse(**function_call)
|
|||
|
|
|||
|
message = ChatMessage(
|
|||
|
role="assistant",
|
|||
|
content=response["text"],
|
|||
|
function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
|
|||
|
)
|
|||
|
# 添加历史记录
|
|||
|
request.messages.append(message)
|
|||
|
h = History(id=str(uuid.uuid4()), messages=request.messages, question=question, query=query, wiki_content=contexts)
|
|||
|
choice_data = ChatCompletionResponseChoice(
|
|||
|
index=0,
|
|||
|
message=message,
|
|||
|
finish_reason=finish_reason,
|
|||
|
)
|
|||
|
task_usage = UsageInfo.parse_obj(response["usage"])
|
|||
|
for usage_key, usage_value in task_usage.dict().items():
|
|||
|
setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)
|
|||
|
|
|||
|
return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage,
|
|||
|
history=h)
|
|||
|
|
|||
|
|
|||
|
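# Example non-streaming request body for /v1/chat/completions ("ceshi" is the sample
# knowledge-base name used elsewhere in this file; other values are placeholders):
#
#   {
#     "model": "chatglm3-6b",
#     "kb_name": "ceshi",
#     "messages": [{"role": "user", "content": "..."}],
#     "stream": false
#   }
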
async def predict(model_id: str, params: dict, messages: List[ChatMessage], question, query, contexts):
    global model, tokenizer
    all_text = ""
    # the first chunk carries only the role, no content
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    previous_text = ""
    f = None
    for new_response in generate_stream_chatglm3(model, tokenizer, params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode

        finish_reason = new_response["finish_reason"]
        if len(delta_text) == 0 and finish_reason != "function_call":
            continue

        function_call = None
        if finish_reason == "function_call":
            try:
                function_call = process_response(decoded_unicode, use_tool=True)
            except Exception:
                print("Failed to parse tool call")

        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)
            f = function_call

        delta = DeltaMessage(
            content=delta_text,
            role="assistant",
            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
        )
        all_text += delta_text
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=delta,
            finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    # the conversation history is recorded here
    messages.append(ChatMessage(role='assistant', content=all_text, name='', function_call=f))
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    h = History(id=str(uuid.uuid4()), question=question, messages=messages, query=query, wiki_content=contexts)
    yield '[DONE]'
    # one final chunk after '[DONE]' carries the accumulated history
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=all_text, function_call=f),
        finish_reason="finish"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk", history=h)
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))


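# Stream layout produced by predict(): one role-only chunk, one chunk per generated delta, a
# chunk with finish_reason="stop", the literal '[DONE]' sentinel, and finally one more chunk
# that carries the accumulated History. OpenAI-style SSE clients normally stop reading at
# '[DONE]', so only clients that keep consuming the stream will see that trailing history chunk.
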
if __name__ == "__main__":
|
|||
|
tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
|
|||
|
model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).cuda()
|
|||
|
# 多显卡支持,使用下面两行代替上面一行,将num_gpus改为你实际的显卡数量
|
|||
|
# from utils import load_model_on_gpus
|
|||
|
# model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
|
|||
|
model = model.eval()
|
|||
|
host = "localhost"
|
|||
|
port = 8000
|
|||
|
base_url = f"http://{host}:{port}"
|
|||
|
uvicorn.run(app, host=host, port=port, workers=1)
|