ai/openai_api.py

423 lines
15 KiB
Python
Raw Permalink Normal View History

2023-12-14 14:26:13 +08:00
# coding=utf-8
import json
import time
import uuid
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import List, Union
from typing import Literal, Optional
from colorama import init, Fore
import requests
import torch
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModel
from api import get_complete_docs
from tool_using.tool_register import get_tools, dispatch_tool
from utils import process_response, generate_chatglm3, generate_stream_chatglm3
# Default system prompt (Chinese). `init` prepends it as a system message when a
# request's `system_prompt` is empty, and it is the default of
# ChatCompletionRequest.system_prompt. Roughly: "You are an AI assistant for the
# agriculture domain; answer the human's question using the known information and
# the conversation history. The known information is not necessarily the correct
# answer, so reason over the history as well; if no answer can be found, answer
# from your own knowledge, and never make things up. Give the answer directly, in
# Chinese."
SYSTEM_PROMPT = """你是一个农业领域的AI助手你需要通过已知信息和历史对话来回答人类的问题。
特别注意已知信息不一定是正确答案还要根据对话的历史问题和答案进行思考回答如果找不到答案请用自己的知识进行回答不可以乱回答
直接给出答案即可答案使用中文\n"""
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan hook that releases cached GPU memory on shutdown.

    Nothing is done on startup. On shutdown, when CUDA is available, PyTorch's
    cached allocator memory is emptied and CUDA IPC handles are collected.

    Args:
        app: The FastAPI application (unused; required by the lifespan protocol).
    """
    try:
        yield
    finally:
        # `finally` makes the cleanup run even when the context is exited because
        # of an exception; with a bare `yield` it would be skipped in that case.
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
# ASGI application; `lifespan` frees cached CUDA memory when the server shuts down.
app = FastAPI(lifespan=lifespan)
# Allow cross-origin requests from any origin, with any method and header.
# NOTE(review): wildcard origins combined with allow_credentials=True likely lets
# any website send credentialed cross-origin requests to this API. Confirm the
# behaviour against the Starlette CORSMiddleware docs, and restrict
# allow_origins before exposing the server publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class ClearHistory(BaseModel):
    """Response body of GET /v1/clear/history."""
    code: int  # status code echoed in the body (the endpoint sends 200 on success)
    message: str  # status text, e.g. "success"
class ModelCard(BaseModel):
    """OpenAI-style description of one model, as listed by GET /v1/models."""
    id: str  # model name reported to clients
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))  # Unix time at which the card is built
    owned_by: str = "owner"
    # The remaining fields are never set in this file; they exist for OpenAI schema compatibility.
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None
class ModelList(BaseModel):
    """Response body of GET /v1/models."""
    object: str = "list"
    # NOTE(review): a mutable [] default; relies on pydantic copying defaults per
    # instance. Confirm for the pinned pydantic version.
    data: List[ModelCard] = []
class FunctionCallResponse(BaseModel):
    """A tool call requested by the model."""
    name: Optional[str] = None  # tool name, passed to dispatch_tool
    arguments: Optional[str] = None  # JSON-encoded arguments; decoded with json.loads before dispatch
class ChatMessage(BaseModel):
    """One message of a chat conversation in OpenAI chat format."""
    role: Literal["user", "assistant", "system", "function"]
    # Declared Optional because the default is None and OpenAI-style assistant
    # turns that only carry a function_call may send `"content": null`. A plain
    # `str` annotation would make pydantic v2 reject such messages.
    content: Optional[str] = None
    name: Optional[str] = None  # e.g. the tool name on role="function" messages
    function_call: Optional[FunctionCallResponse] = None  # tool call made by an assistant message
class DeltaMessage(BaseModel):
    """Incremental piece of an assistant message in a streamed response."""
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None  # text generated since the previous chunk
    function_call: Optional[FunctionCallResponse] = None  # set on the chunk that finishes with a tool call
class ChatCompletionRequest(BaseModel):
    """Request body of POST /v1/chat/completions: OpenAI fields plus knowledge-base options."""
    model: str
    messages: List[ChatMessage]
    kb_name: str  # knowledge base to retrieve from (required)
    stream: Optional[bool] = False
    # Tool schemas. When not None, the model receives the original messages
    # instead of the knowledge-base-augmented ones.
    functions: Optional[Union[dict, List[dict]]] = None
    # Additional parameters
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None  # create_chat_completion falls back to 1024
    max_length: Optional[int] = None  # create_chat_completion falls back to 8192
    repetition_penalty: Optional[float] = 1.1
    top_k: int = 3  # number of knowledge-base documents to retrieve
    score_threshold: float = 1.0  # knowledge-base strictness (retrieval score threshold)
    system_prompt: str = SYSTEM_PROMPT
class ChatCompletionResponseChoice(BaseModel):
    """A complete generated message (non-streaming reply, or the final trailer chunk of a stream)."""
    index: int
    message: ChatMessage
    # "finish" is not an OpenAI value; in this file it is used only by the
    # history-carrying chunk that predict emits after "[DONE]".
    finish_reason: Literal["stop", "length", "function_call", "finish"]
class ChatCompletionResponseStreamChoice(BaseModel):
    """The choice carried by one streamed chunk."""
    index: int
    delta: DeltaMessage
    # None while the stream is still in progress (e.g. the first, role-only chunk).
    finish_reason: Optional[Literal["stop", "length", "function_call", "finish"]]
class UsageInfo(BaseModel):
    """Token counts for a completion; filled from the generator's "usage" dict."""
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0
class History(BaseModel):
    """Record of one exchange, returned to the client alongside the answer."""
    id: str  # random UUID4 string
    question: str  # content of the user's last message
    messages: List[ChatMessage]  # conversation including the generated assistant reply
    # Both defaults are None, so the annotations say Optional. A bare `str` /
    # `List[str]` makes pydantic v2 reject an explicit None (e.g. no retrieval result).
    query: Optional[str] = None  # text used for knowledge-base retrieval
    wiki_content: Optional[List[str]] = None  # retrieved knowledge-base passages
class ChatCompletionResponse(BaseModel):
    """Chat completion response body, or one chunk of a streamed response."""
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))  # Unix time of creation
    usage: Optional[UsageInfo] = None
    # Non-OpenAI extension: the exchange record for the client to keep. Optional
    # because most stream chunks carry no history.
    history: Optional[History] = None
def get_contexts(query, top_k, score_threshold, kb_name):
    """Fetch knowledge-base passages relevant to *query* from *kb_name*.

    Delegates to ``api.get_complete_docs``, forwarding all four arguments
    positionally in the same order. ``init`` joins the result with newlines,
    so it is presumably a list of strings (confirm against ``api``).
    """
    lookup = (query, top_k, score_threshold, kb_name)
    return get_complete_docs(*lookup)
from fastapi import Depends, HTTPException, Request
# Authentication
async def verify_token(request: Request):
    """Check that *request* carries the expected ``Authorization: Bearer`` token.

    NOTE(review): no route visible here uses this dependency (``Depends`` is
    imported but unused), so authentication is not enforced. Attach it with
    ``dependencies=[Depends(verify_token)]`` on the routes to protect.

    Returns:
        True when the header holds the expected bearer token.

    Raises:
        HTTPException: 401 when the header is missing, malformed, or the token
            does not match.
    """
    import secrets

    expected_token = "sk-chatglm3-6b"  # configure your token here
    auth_header = request.headers.get('Authorization')
    if auth_header:
        token_type, _, token = auth_header.partition(' ')
        # compare_digest runs in constant time, so response timing does not
        # reveal how many leading characters of a guessed token were right.
        # Encoding to bytes also accepts non-ASCII input without raising.
        if token_type.lower() == "bearer" and secrets.compare_digest(
            token.encode("utf-8"), expected_token.encode("utf-8")
        ):
            return True
    raise HTTPException(
        status_code=401,
        detail="Invalid authorization credentials",
    )
@app.get("/v1/clear/history")
async def clear_history():
    """Release cached CUDA memory held by PyTorch.

    Despite the route name, this handler only empties the CUDA cache; no
    conversation history is stored server-side in the visible code.

    Returns:
        ClearHistory(code=200, message="success").

    Raises:
        HTTPException: 500 if releasing CUDA memory fails.
    """
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    except Exception as err:
        # Catch Exception rather than everything: a bare `except:` would also turn
        # task cancellation and KeyboardInterrupt into an HTTP 500.
        logger.exception("Failed to release CUDA memory")
        raise HTTPException(status_code=500, detail="Internal 500 Error") from err
    return ClearHistory(code=200, message="success")
@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """Return the OpenAI-style list of models served by this API.

    NOTE(review): the advertised id "chatglm3-6b-32k" differs from the
    checkpoint loaded in ``__main__`` ("THUDM/chatglm3-6b"). Confirm which
    name clients should see.
    """
    cards = [ModelCard(id="chatglm3-6b-32k")]
    return ModelList(data=cards)
def _post_chat_completion(params):
    """POST *params* to this server's /v1/chat/completions and parse the JSON reply.

    Uses the module-level ``base_url`` that the ``__main__`` block assigns.
    Raises ``requests.HTTPError`` on a non-2xx reply instead of trying to
    parse an error body as a ChatCompletionResponse.
    """
    reply = requests.post(f"{base_url}/v1/chat/completions", json=params, stream=False)
    reply.raise_for_status()
    return ChatCompletionResponse(**reply.json())


@app.get("/v1/tool", response_model=ChatCompletionResponse)
def run_conversation(query: str, stream=False, functions=None):
    """Answer *query* through this server, resolving at most one tool call.

    Sends *query* to /v1/chat/completions with tool schemas attached. If the
    model answers with a function call, the tool is run via ``dispatch_tool``.
    The assistant turn and a ``function`` message holding the tool result are
    then appended, and the conversation is sent once more; that second reply
    is returned. Otherwise the first reply is returned directly.

    Args:
        query: The user question.
        stream: Accepted for backward compatibility but not honoured. This
            route returns a single ChatCompletionResponse, and a streamed
            (server-sent events) upstream reply cannot be parsed as JSON, so
            every upstream request uses ``stream=False``.
        functions: Tool schemas as a dict or list, or as a JSON string (query
            parameters arrive as strings). Defaults to ``get_tools()``.

    Returns:
        The ChatCompletionResponse of the last upstream call.

    Raises:
        HTTPException: 400 when *functions* is a string that is not valid JSON.
    """
    if isinstance(functions, str):
        try:
            functions = json.loads(functions)
        except json.JSONDecodeError as err:
            raise HTTPException(status_code=400, detail="functions must be valid JSON") from err

    params = dict(
        model="chatglm3-6b",
        messages=[{"role": "user", "content": query}],
        kb_name="ceshi",
        stream=False,
        functions=functions or get_tools(),
    )

    response = _post_chat_completion(params)
    message = response.choices[0].message
    function_call = message.function_call
    if not function_call:
        logger.info(f"Final Reply: \n{message.content}")
        return response

    # Log the fields directly; this works on both pydantic v1 and v2, unlike .json(ensure_ascii=...).
    logger.info(f"Function Call Response: name={function_call.name}, arguments={function_call.arguments}")
    # `arguments` is Optional on FunctionCallResponse; treat a missing value as "no arguments".
    function_args = json.loads(function_call.arguments or "{}")
    tool_response = dispatch_tool(function_call.name, function_args)
    logger.info(f"Tool Call Response: {tool_response}")

    params["messages"].append({
        "role": message.role,
        "content": message.content,
    })
    params["messages"].append({
        "role": "function",
        "name": function_call.name,
        "content": tool_response,  # result returned by the tool
    })
    logger.debug(f"Follow-up request: {params}")
    return _post_chat_completion(params)
def init(request):
    """Build the knowledge-base-augmented conversation for a chat request.

    Works on a deep copy, so *request* itself is not modified.

    Steps:
      1. Join the content of every message into a retrieval query.
      2. Take the last message's content as the user question.
      3. Retrieve passages with ``get_contexts`` using the request's
         ``top_k``, ``score_threshold`` and ``kb_name``.
      4. Prepend a system message (``request.system_prompt`` or SYSTEM_PROMPT).
      5. Replace the last message's content with a template containing the
         retrieved passages and the question.

    Args:
        request: A ChatCompletionRequest.

    Returns:
        Tuple ``(query, question, contexts, prompt, new_messages)``.
    """
    # TODO: if the user already supplied a system message, handle it instead of
    # prepending a second one.
    r = deepcopy(request)
    # Retrieval query: every message's text, newline-joined. Messages whose
    # content is None (e.g. function-call-only turns) are skipped, because
    # str.join raises TypeError on them.
    query = "\n".join(m.content for m in r.messages if m.content is not None)
    # The user's question
    question = r.messages[-1].content
    # Retrieved knowledge-base passages
    contexts = get_contexts(query, r.top_k, r.score_threshold, r.kb_name)
    # Add the system prompt
    system_prompt = r.system_prompt or SYSTEM_PROMPT
    r.messages.insert(0, ChatMessage(role='system', content=system_prompt))
    new_messages = r.messages
    # Knowledge-base template: "Known information: ... Question: ..."
    wiki_content = "\n".join(contexts)
    prompt = f"已知信息:\n{wiki_content}\n问题:\n{question}\n"
    new_messages[-1].content = prompt
    return query, question, contexts, prompt, new_messages
@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-compatible chat completion with knowledge-base retrieval.

    ``init`` always builds an augmented conversation (system prompt plus a
    template holding retrieved passages). It is sent to the model only when
    ``request.functions`` is None; otherwise the original messages are sent,
    so tool-calling prompts are left untouched.

    When ``request.stream`` is true, returns server-sent events produced by
    ``predict``. Otherwise returns one ChatCompletionResponse that also
    carries a History record of the exchange.

    Raises:
        HTTPException: 400 when there are no messages or the last message is
            from the assistant.
    """
    # TODO: for agent calls, check that the requested function exists.
    if len(request.messages) < 1 or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    # TODO: add the wiki template prompt.
    query, question, contexts, prompt, new_messages = init(request)

    gen_params = dict(
        messages=request.messages if request.functions is not None else new_messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        max_length=request.max_length or 8192,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        functions=request.functions,
        system_prompt=request.system_prompt or SYSTEM_PROMPT,
        top_k=request.top_k or 3,
        score_threshold=request.score_threshold or 1.0,
        kb_name=request.kb_name,
    )
    logger.debug(f"==== request ====\n{gen_params}")

    # Streaming output
    if request.stream:
        generate = predict(request.model, gen_params, request.messages, question, query, contexts)
        return EventSourceResponse(generate, media_type="text/event-stream")

    # Non-streaming output
    response = generate_chatglm3(model, tokenizer, gen_params)

    function_call, finish_reason = None, "stop"
    if request.functions:
        try:
            function_call = process_response(response["text"], use_tool=True)
        except Exception as err:
            # Output that is not a parsable tool call is returned as plain text.
            logger.warning(f"Failed to parse tool call: {err!r}")
        if isinstance(function_call, dict):
            finish_reason = "function_call"
            function_call = FunctionCallResponse(**function_call)

    message = ChatMessage(
        role="assistant",
        content=response["text"],
        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
    )
    # Add the reply to the conversation and package it as a history record.
    request.messages.append(message)
    h = History(id=str(uuid.uuid4()), messages=request.messages, question=question, query=query, wiki_content=contexts)
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
    # Only one generation contributes, so its usage is the total.
    usage = UsageInfo.parse_obj(response["usage"])
    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage,
                                  history=h)
async def predict(model_id: str, params: dict, messages: List[ChatMessage], question, query, contexts):
    """Stream a chat completion as server-sent-event payload strings.

    Yields, in order:
      1. a chunk whose delta only sets ``role="assistant"``;
      2. one chunk per non-empty text delta from ``generate_stream_chatglm3``,
         plus the chunk whose finish reason is ``"function_call"`` even if its
         delta is empty;
      3. an empty-delta chunk with ``finish_reason="stop"``;
      4. the literal ``'[DONE]'``;
      5. a non-standard trailer chunk (``finish_reason="finish"``) with the full
         reply and a History record. Clients that stop reading at ``[DONE]``
         never see it.

    Side effect: appends the generated assistant message to *messages*.

    Args:
        model_id: Model name echoed in every chunk.
        params: Generation parameters forwarded to ``generate_stream_chatglm3``.
        messages: Conversation to extend with the reply.
        question, query, contexts: Stored verbatim in the History record.
    """
    all_text = ""
    # The first chunk carries no content, only the role.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    previous_text = ""
    f = None  # most recent successfully parsed tool call, reported with the final message
    for new_response in generate_stream_chatglm3(model, tokenizer, params):
        decoded_unicode = new_response["text"]
        # Each item's text is treated as cumulative; only the new suffix is sent.
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode
        finish_reason = new_response["finish_reason"]
        if len(delta_text) == 0 and finish_reason != "function_call":
            continue

        function_call = None
        if finish_reason == "function_call":
            try:
                function_call = process_response(decoded_unicode, use_tool=True)
            except Exception as err:
                logger.warning(f"Failed to parse tool call: {err!r}")
        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)
            # Remember only real tool calls. process_response can return a
            # non-dict (hence the isinstance check), which would fail
            # ChatMessage's function_call validation below.
            f = function_call

        delta = DeltaMessage(
            content=delta_text,
            role="assistant",
            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
        )
        all_text += delta_text
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=delta,
            finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    # Record the complete assistant reply in the caller's conversation.
    messages.append(ChatMessage(role='assistant', content=all_text, name='', function_call=f))
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    h = History(id=str(uuid.uuid4()), question=question, messages=messages, query=query, wiki_content=contexts)
    yield '[DONE]'
    # Non-standard trailer after [DONE]: the full reply plus the history record.
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=all_text, function_call=f),
        finish_reason="finish"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk", history=h)
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
if __name__ == "__main__":
    # Load the ChatGLM3-6B tokenizer and weights, then move the model to the GPU
    # (.cuda() requires a CUDA device).
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).cuda()
    # For multi-GPU support, replace the line above with the two lines below and
    # set num_gpus to the number of GPUs you actually have.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
    model = model.eval()  # switch the model to evaluation mode
    host = "localhost"
    port = 8000
    # Module-level global read by run_conversation to call this same server.
    # tokenizer, model and base_url are defined only when the file runs as a
    # script; serving `app` by importing this module leaves them undefined.
    base_url = f"http://{host}:{port}"
    uvicorn.run(app, host=host, port=port, workers=1)