# coding=utf-8
"""OpenAI-compatible chat API for ChatGLM3 with knowledge-base retrieval.

Every chat request is augmented with documents retrieved from a knowledge base
(via ``get_complete_docs``) before being handed to the ChatGLM3 model. The
``/v1/tool`` endpoint demonstrates a single function-calling round trip against
this server's own ``/v1/chat/completions`` endpoint.
"""
import json
import time
import uuid
from contextlib import asynccontextmanager
from copy import deepcopy
from typing import List, Union
from typing import Literal, Optional

# NOTE: colorama's ``init`` is shadowed by the module-level ``init()`` defined
# below; it is kept imported only for backward compatibility.
from colorama import init, Fore
import requests
import torch
import uvicorn
from fastapi import FastAPI
from fastapi import Depends, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger
from pydantic import BaseModel, Field
from sse_starlette.sse import EventSourceResponse
from transformers import AutoTokenizer, AutoModel

from api import get_complete_docs
from tool_using.tool_register import get_tools, dispatch_tool
from utils import process_response, generate_chatglm3, generate_stream_chatglm3

# Address the server binds to. Defined at module level (not only under
# ``__main__``) so ``run_conversation`` can reach ``base_url`` even when the app
# is imported and served by an external uvicorn process.
host = "localhost"
port = 8000
base_url = f"http://{host}:{port}"

SYSTEM_PROMPT = """你是一个农业领域的AI助手,你需要通过已知信息和历史对话来回答人类的问题。
特别注意:已知信息不一定是正确答案,还要根据对话的历史问题和答案进行思考回答,如果找不到答案,请用自己的知识进行回答,不可以乱回答!!!
直接给出答案即可,答案使用中文。\n"""


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Release cached CUDA memory when the application shuts down."""
    yield
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()


app = FastAPI(lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class ClearHistory(BaseModel):
    code: int
    message: str


class ModelCard(BaseModel):
    id: str
    object: str = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "owner"
    root: Optional[str] = None
    parent: Optional[str] = None
    permission: Optional[list] = None


class ModelList(BaseModel):
    object: str = "list"
    data: List[ModelCard] = []


class FunctionCallResponse(BaseModel):
    name: Optional[str] = None
    arguments: Optional[str] = None


class ChatMessage(BaseModel):
    role: Literal["user", "assistant", "system", "function"]
    content: Optional[str] = None
    name: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class DeltaMessage(BaseModel):
    role: Optional[Literal["user", "assistant", "system"]] = None
    content: Optional[str] = None
    function_call: Optional[FunctionCallResponse] = None


class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    kb_name: str
    stream: Optional[bool] = False
    functions: Optional[Union[dict, List[dict]]] = None

    # Additional parameters
    temperature: Optional[float] = 0.8
    top_p: Optional[float] = 0.8
    max_tokens: Optional[int] = None
    max_length: Optional[int] = None
    repetition_penalty: Optional[float] = 1.1
    top_k: int = 3  # number of knowledge-base documents to retrieve
    score_threshold: float = 1.0  # knowledge-base matching strictness
    system_prompt: str = SYSTEM_PROMPT


class ChatCompletionResponseChoice(BaseModel):
    index: int
    message: ChatMessage
    finish_reason: Literal["stop", "length", "function_call", "finish"]


class ChatCompletionResponseStreamChoice(BaseModel):
    index: int
    delta: DeltaMessage
    finish_reason: Optional[Literal["stop", "length", "function_call", "finish"]]


class UsageInfo(BaseModel):
    prompt_tokens: int = 0
    total_tokens: int = 0
    completion_tokens: Optional[int] = 0


class History(BaseModel):
    id: str
    question: str
    messages: List[ChatMessage]
    query: Optional[str] = None
    wiki_content: Optional[List[str]] = None


class ChatCompletionResponse(BaseModel):
    model: str
    object: Literal["chat.completion", "chat.completion.chunk"]
    choices: List[Union[ChatCompletionResponseChoice, ChatCompletionResponseStreamChoice]]
    created: Optional[int] = Field(default_factory=lambda: int(time.time()))
    usage: Optional[UsageInfo] = None
    history: Optional[History] = None


def get_contexts(query, top_k, score_threshold, kb_name):
    """Retrieve knowledge-base documents for ``query`` via ``get_complete_docs``."""
    return get_complete_docs(query, top_k, score_threshold, kb_name)


async def verify_token(request: Request):
    """Check the ``Authorization: Bearer <token>`` header.

    NOTE(review): not attached to any route via ``Depends`` in this file.

    Raises:
        HTTPException: 401 when the header is missing or the token is wrong.
    """
    auth_header = request.headers.get('Authorization')
    if auth_header:
        token_type, _, token = auth_header.partition(' ')
        if (
                token_type.lower() == "bearer"
                and token == "sk-chatglm3-6b"
        ):  # configure your token here
            return True
    raise HTTPException(
        status_code=401,
        detail="Invalid authorization credentials",
    )


@app.get("/v1/clear/history")
async def clear_history():
    """Free cached CUDA memory and report success.

    Despite its name, this endpoint does not modify any stored conversation
    history; it only calls ``torch.cuda.empty_cache`` / ``ipc_collect``.

    Raises:
        HTTPException: 500 if releasing CUDA memory fails.
    """
    try:
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    except Exception as exc:
        logger.exception("Failed to release CUDA memory")
        raise HTTPException(status_code=500, detail="Internal 500 Error") from exc
    return ClearHistory(code=200, message="success")


@app.get("/v1/models", response_model=ModelList)
async def list_models():
    """List the single model served by this API."""
    model_card = ModelCard(id="chatglm3-6b-32k")
    return ModelList(data=[model_card])


@app.get("/v1/tool", response_model=ChatCompletionResponse)
def run_conversation(query: str, stream: bool = False, functions=None):
    """Run one function-calling round trip against the local chat endpoint.

    The user ``query`` is sent to ``/v1/chat/completions`` with ``functions``
    (or the registered tools from ``get_tools()`` when none are given). If the
    model requests a function call, the tool is executed via ``dispatch_tool``,
    its result is appended as a ``function`` message, and a second completion
    is requested and returned. Otherwise the first completion is returned.

    Args:
        query: The user question.
        stream: Accepted for backward compatibility. The local endpoint is
            always called non-streaming, because its reply is decoded as a
            single JSON document (an SSE stream cannot be parsed that way).
        functions: Optional function definitions; defaults to ``get_tools()``.

    Returns:
        The parsed ``ChatCompletionResponse``.

    Raises:
        requests.HTTPError: if the local completion endpoint returns an error.
    """
    params = dict(model="chatglm3-6b",
                  messages=[{"role": "user", "content": query}],
                  kb_name="ceshi",
                  stream=False)
    params["functions"] = functions if functions else get_tools()

    response = requests.post(f"{base_url}/v1/chat/completions", json=params)
    response.raise_for_status()
    completion = ChatCompletionResponse(**response.json())
    logger.debug(completion)

    message = completion.choices[0].message
    function_call = message.function_call
    if not function_call:
        # No tool requested: the first completion is the final answer.
        logger.info(f"Final Reply: \n{message.content}")
        return completion

    logger.info(f"Function Call Response: {function_call.json(ensure_ascii=False)}")
    function_args = json.loads(function_call.arguments)
    tool_response = dispatch_tool(function_call.name, function_args)
    logger.info(f"Tool Call Response: {tool_response}")

    params["messages"].append({
        'role': message.role,
        'content': message.content
    })
    params["messages"].append({
        "role": "function",
        "name": function_call.name,
        "content": tool_response,  # result returned by the called function
    })
    logger.debug(params)

    res = requests.post(f"{base_url}/v1/chat/completions", json=params)
    res.raise_for_status()
    return ChatCompletionResponse(**res.json())


def init(request):
    """Build the retrieval-augmented prompt for a chat request.

    The request is deep-copied, so the caller's ``request.messages`` is not
    modified. Contexts are retrieved from knowledge base ``request.kb_name``
    using the newline-joined content of all messages. A system message is then
    inserted at the front, and the last message's content is replaced by a
    prompt combining the retrieved contexts and the user question.

    Returns:
        Tuple ``(query, question, contexts, prompt, new_messages)``:
        the retrieval query text, the original last-message content, the
        retrieved contexts, the built prompt, and the rewritten message list.
    """
    # TODO: if the user already supplied a SYSTEM message, handle it instead of
    # inserting a second one.
    r = deepcopy(request)
    # Vector-store query text. Messages without content are skipped, because
    # joining ``None`` would raise TypeError.
    query = "\n".join(m.content for m in r.messages if m.content is not None)
    # The user's question (last message).
    question = r.messages[-1].content
    # Number of documents to retrieve.
    top_k = r.top_k
    # Score threshold for matching.
    score_threshold = r.score_threshold
    # Knowledge-base name.
    kb_name = r.kb_name
    # Retrieved knowledge-base contents.
    contexts = get_contexts(query, top_k, score_threshold, kb_name)
    # Prepend the system prompt.
    system_prompt = r.system_prompt or SYSTEM_PROMPT
    r.messages.insert(0, ChatMessage(role='system', content=system_prompt))
    # Rebuilt message list.
    new_messages = r.messages
    # Knowledge-base prompt template.
    wiki_content = "\n".join(contexts)
    prompt = f"已知信息:\n{wiki_content}\n问题:\n{question}\n"
    new_messages[-1].content = prompt
    return query, question, contexts, prompt, new_messages


@app.post("/v1/chat/completions", response_model=ChatCompletionResponse)
async def create_chat_completion(request: ChatCompletionRequest):
    """OpenAI-style chat completion, augmented with knowledge-base contexts.

    When ``request.functions`` is set, the client's original messages are sent
    to the model; otherwise the retrieval-augmented messages from ``init`` are
    used. With ``request.stream`` an SSE response driven by ``predict`` is
    returned; otherwise a single ``ChatCompletionResponse`` including a
    ``History`` record.

    Raises:
        HTTPException: 400 if there are no messages or the last one is from
            the assistant.
    """
    # TODO: for agent calls, check whether the requested function exists.
    global model, tokenizer

    if not request.messages or request.messages[-1].role == "assistant":
        raise HTTPException(status_code=400, detail="Invalid request")

    # TODO: add the wiki template prompt.
    query, question, contexts, prompt, new_messages = init(request)

    gen_params = dict(
        messages=request.messages if request.functions is not None else new_messages,
        temperature=request.temperature,
        top_p=request.top_p,
        max_tokens=request.max_tokens or 1024,
        max_length=request.max_length or 8192,
        echo=False,
        stream=request.stream,
        repetition_penalty=request.repetition_penalty,
        functions=request.functions,
        system_prompt=request.system_prompt or SYSTEM_PROMPT,
        top_k=request.top_k or 3,
        score_threshold=request.score_threshold or 1.0,
        kb_name=request.kb_name,
    )

    logger.debug(f"==== request ====\n{gen_params}")

    # Streaming output.
    if request.stream:
        generate = predict(request.model, gen_params, request.messages, question, query, contexts)
        return EventSourceResponse(generate, media_type="text/event-stream")

    # Non-streaming output.
    response = generate_chatglm3(model, tokenizer, gen_params)
    usage = UsageInfo()

    function_call, finish_reason = None, "stop"
    if request.functions:
        try:
            function_call = process_response(response["text"], use_tool=True)
        except Exception:
            logger.warning("Failed to parse tool call")

    if isinstance(function_call, dict):
        finish_reason = "function_call"
        function_call = FunctionCallResponse(**function_call)

    message = ChatMessage(
        role="assistant",
        content=response["text"],
        function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
    )
    # Record the reply in the history.
    request.messages.append(message)
    h = History(id=str(uuid.uuid4()), messages=request.messages, question=question, query=query,
                wiki_content=contexts)

    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=message,
        finish_reason=finish_reason,
    )

    task_usage = UsageInfo.parse_obj(response["usage"])
    for usage_key, usage_value in task_usage.dict().items():
        setattr(usage, usage_key, getattr(usage, usage_key) + usage_value)

    return ChatCompletionResponse(model=request.model, choices=[choice_data], object="chat.completion", usage=usage,
                                  history=h)


async def predict(model_id: str, params: dict, messages: List[ChatMessage], question, query, contexts):
    """Yield SSE payloads for a streaming chat completion.

    The payloads are emitted in this order:

    1. an opening chunk carrying only the assistant role;
    2. one chunk per non-empty text delta (or per function-call event);
    3. a ``stop`` chunk;
    4. the ``'[DONE]'`` sentinel;
    5. a final chunk with the full message and a ``History`` record.

    Note that the ``History`` chunk is sent after ``'[DONE]'``, so clients that
    stop reading at the sentinel will not see it. The assistant reply is
    appended to ``messages`` in place.
    """
    global model, tokenizer

    all_text = ""
    # The first message has no content.
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(role="assistant"),
        finish_reason=None
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    previous_text = ""
    f = None  # last successfully parsed FunctionCallResponse, if any
    for new_response in generate_stream_chatglm3(model, tokenizer, params):
        decoded_unicode = new_response["text"]
        # The generator yields the cumulative text; emit only the new suffix.
        delta_text = decoded_unicode[len(previous_text):]
        previous_text = decoded_unicode

        finish_reason = new_response["finish_reason"]
        if len(delta_text) == 0 and finish_reason != "function_call":
            continue

        function_call = None
        if finish_reason == "function_call":
            try:
                function_call = process_response(decoded_unicode, use_tool=True)
            except Exception:
                logger.warning("Failed to parse tool call")

            if isinstance(function_call, dict):
                function_call = FunctionCallResponse(**function_call)
                # Only a validated FunctionCallResponse is recorded in history.
                f = function_call

        delta = DeltaMessage(
            content=delta_text,
            role="assistant",
            function_call=function_call if isinstance(function_call, FunctionCallResponse) else None,
        )
        all_text += delta_text

        choice_data = ChatCompletionResponseStreamChoice(
            index=0,
            delta=delta,
            finish_reason=finish_reason
        )
        chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))

    # Record the full assistant reply in the history.
    messages.append(ChatMessage(role='assistant', content=all_text, name='', function_call=f))
    choice_data = ChatCompletionResponseStreamChoice(
        index=0,
        delta=DeltaMessage(),
        finish_reason="stop"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
    h = History(id=str(uuid.uuid4()), question=question, messages=messages, query=query, wiki_content=contexts)
    yield '[DONE]'
    choice_data = ChatCompletionResponseChoice(
        index=0,
        message=ChatMessage(role='assistant', content=all_text, function_call=f),
        finish_reason="finish"
    )
    chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk", history=h)
    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))


if __name__ == "__main__":
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm3-6b", trust_remote_code=True).cuda()
    # Multi-GPU support: replace the line above with the two lines below and
    # set num_gpus to the number of GPUs you actually have.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
    model = model.eval()

    uvicorn.run(app, host=host, port=port, workers=1)