ai/openai_api.bak.py

# coding=utf-8
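"""OpenAI-style chat completions API for ChatGLM3-6B-32K with knowledge-base retrieval.

Exposes /v1/models, /v1/chat/completions (streaming and non-streaming) and
/v1/clear/history. Each chat request first retrieves agriculture documents from
the vector store (api.get_complete_docs) and injects them into the last user
message before generation; the retrieved context is also returned in `history`.
"""
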
import time
import uuid
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import List, Literal, Optional, Union
import torch
import uvicorn
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModel
from api import get_complete_docs
from utils import process_response, generate_chatglm3, generate_stream_chatglm3

# System prompt (in Chinese): "You are an AI assistant for the agriculture domain. Answer the
# user's question from the retrieved context and the conversation history. Note that the
# retrieved context is not necessarily the correct answer, so also reason over the previous
# questions and answers; if no answer can be found, answer from your own knowledge rather
# than making something up. Answer in Chinese."
system_prompt = "你是一个农业领域的AI助手你需要通过已知信息和历史对话来回答人类的问题。特别注意已知信息不一定是正确答案还要根据对话的历史问题和答案进行思考回答如果找不到答案请用自己的知识进行回答不可以乱回答答案使用中文。\n"

@asynccontextmanager
async def lifespan(app: FastAPI):  # collects GPU memory
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class ClearHistory(BaseModel):
    code: int
    message: str

class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None

class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []

class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None

class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "function"]
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None

class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    stream: Optional[bool] = False
    functions: Optional[Union[dict, List[dict]]] = None
    # Additional parameters
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    max_length: Optional[int] = None
    repetition_penalty: Optional[float] = 1.1
    top_k: int = 3  # number of knowledge-base documents to retrieve
    score_threshold: float = 1.0  # knowledge-base score threshold (strictness)
    system_prompt: str = system_prompt

class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call", "finish"]

class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "function_call", "finish"]]

class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0

class History(BaseModel):
    id: str
    question: str
    messages: List[ChatMessage]
    query: Optional[str] = None
    wiki_content: Optional[List[str]] = None

class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None
    history: Optional[History] = None

def get_contexts(query, top_k, score_threshold):
    kb_name = "agriculture"
    return get_complete_docs(query, top_k, score_threshold, kb_name)
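
# NOTE: get_complete_docs comes from the local `api` module. Judging by how the result is
# used in init() ("\n".join(contexts)), it is assumed to return a list of document strings
# retrieved from the "agriculture" knowledge base.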

# Bearer-token check applied to every /v1/chat/completions request
async def verify_token(request: Request):
    auth_header = request.headers.get('Authorization')
    if auth_header:
        token_type, _, token = auth_header.partition(' ')
        if (
            token_type.lower() == "bearer"
            and token == "sk-chatglm3-6b-32k"
        ):  # configure your token here
            return True
    raise HTTPException(
        status_code=401,
        detail="Invalid authorization credentials",
    )
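
# Example request (illustrative sketch): the server listens on localhost:8000
# (see uvicorn.run at the bottom) and expects the bearer token configured above.
#
#   curl http://localhost:8000/v1/chat/completions \
#     -H "Authorization: Bearer sk-chatglm3-6b-32k" \
#     -H "Content-Type: application/json" \
#     -d '{"model": "chatglm3-6b-32k", "messages": [{"role": "user", "content": "水稻稻瘟病如何防治?"}]}'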
@app.get("/v1/clear/history")
async def clear_history():
try:
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.ipc_collect()
except:
raise HTTPException(status_code=500, detail="Internal 500 Error")
return ClearHistory(code=200, message="success")
@app.get("/v1/models", response_model=ModelList)
async def list_models():
model_card = ModelCard(id="chatglm3-6b-32k")
return ModelList(data=[model_card])

def init(request):
    # TODO: handle this differently if the user supplies their own SYSTEM message
    r = deepcopy(request)
    # query string for the vector store: all message contents joined together
    query = ",".join([m.content for m in r.messages])
    # the user's current question
    question = r.messages[-1].content
    # retrieval parameters
    top_k = r.top_k                      # number of documents to recall
    score_threshold = r.score_threshold  # score threshold
    # retrieve context from the knowledge base (the kb name is fixed inside get_contexts)
    contexts = get_contexts(query, top_k, score_threshold)
    # prepend the system prompt
    system_prompt = r.system_prompt
    r.messages.insert(0, ChatMessage(role='system', content=system_prompt))
    # rebuild the messages: wrap the last user message in the knowledge-base template
    new_messages = r.messages
    wiki_content = "\n".join(contexts)
    prompt = f"根据以下信息:\n{wiki_content}\n问答问题:\n{question}\n"
    new_messages[-1].content = prompt
    return query, question, contexts, prompt, new_messages
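
# Illustrative sketch of what init() returns as new_messages (the actual wiki_content
# depends on what the knowledge base retrieves):
#   [ChatMessage(role='system', content=system_prompt),
#    ... earlier turns unchanged ...,
#    ChatMessage(role='user', content="根据以下信息:\n<retrieved docs>\n问答问题:\n<question>\n")]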

@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest, token: bool = Depends(verify_token)):
    print("token:", token)
    global model, tokenizer

    if len(request.messages) < 1 or request.messages[-1].role != "user":
        raise HTTPException(status_code=400, detail="Invalid request")

    # retrieve knowledge-base context and prepend the system prompt
    query, question, contexts, prompt, new_messages = init(request)

    gen_params = dict(
        messages=new_messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 8192,
        max_length=request.max_length or 32768,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        functions=request.functions,
        system_prompt=request.system_prompt or system_prompt,
        top_k=request.top_k or 3,
        score_threshold=request.score_threshold or 1.0,
    )
    logger.debug(f"==== request ====\n{gen_params}")

    # streaming output
    if request.stream:
        generate = predict(request.model, gen_params, request.messages, question, query, contexts)
        return EventSourceResponse(generate, media_type="text/event-stream")

    # non-streaming output
    response = generate_chatglm3(model, tokenizer, gen_params)
    usage = UsageInfo()

    function_call, finish_reason = None, "stop"
    if request.functions:
        try:
            function_call = process_response(response["text"], use_tool=True)
        except Exception:
            logger.warning("Failed to parse tool call")
    if isinstance(function_call, dict):
        finish_reason = "function_call"
        function_call = FunctionCallResponse(**function_call)

    message = ChatMessage(
        role="assistant",
        content=response["text"],
        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
    )

    # record the conversation history
    request.messages.append(message)
    h = History(id=str(uuid.uuid4()), messages=request.messages, question=question, query=query,
                wiki_content=contexts)

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )
    task_usage = UsageInfo.parse_obj(response["usage"])
    for usage_key, usage_value in task_usage.dict().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion",
                                  usage=usage, history=h)

async def predict(model_id: str, params: dict, messages: List[ChatMessage], question, query, contexts):
    global model, tokenizer
    all_text = ""

    # the first chunk carries the role only, with no content
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    previous_text = ""
    f = None
    for new_response in generate_stream_chatglm3(model, tokenizer, params):
        decoded_unicode = new_response["text"]
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode

        finish_reason = new_response["finish_reason"]
        if len(delta_text) == 0 and finish_reason != "function_call":
            continue

        function_call = None
        if finish_reason == "function_call":
            try:
                function_call = process_response(decoded_unicode, use_tool=True)
            except Exception:
                logger.warning("Failed to parse tool call")
        if isinstance(function_call, dict):
            function_call = FunctionCallResponse(**function_call)
            f = function_call

        delta = DeltaMessage(
            content=delta_text,
            role="assistant",
            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
        )
        all_text += delta_text
        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=delta,
            finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    # record the conversation history
    messages.append(ChatMessage(role='assistant', content=all_text, name='', function_call=f))

    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    h = History(id=str(uuid.uuid4()), question=question, messages=messages, query=query, wiki_content=contexts)
    yield '[DONE]'

    # one extra chunk after [DONE] carrying the full answer and the retrieval history
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=all_text, function_call=f),
        finish_reason="finish"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk", history=h)
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b-32k", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm3-6b-32k", trust_remote_code=True).cuda()
    # For multi-GPU support, use the two lines below instead of the line above,
    # and set num_gpus to the actual number of GPUs:
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
    model = model.eval()
    uvicorn.run(app, host='localhost', port=8000, workers=1)
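
# Minimal client sketch (an assumption for illustration: the server above is running on
# localhost:8000 and the `requests` package is installed); kept as a comment so the module
# itself is unchanged:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/v1/chat/completions",
#       headers={"Authorization": "Bearer sk-chatglm3-6b-32k"},
#       json={
#           "model": "chatglm3-6b-32k",
#           "messages": [{"role": "user", "content": "小麦赤霉病怎么防治?"}],
#           "stream": False,
#       },
#       timeout=300,
#   )
#   data = resp.json()
#   print(data["choices"][0]["message"]["content"])  # model answer
#   print(data["history"]["wiki_content"])           # retrieved knowledge-base documents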