test_ai/knownledge_api.py

175 lines
5.7 KiB
Python
Executable File
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import *
import nltk
import sys
import os
import pydantic
from pydantic import BaseModel
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
from configs import VERSION
from configs.model_config import NLTK_DATA_PATH
from configs.server_config import OPEN_CROSS_DOMAIN
import argparse
import uvicorn
from fastapi.middleware.cors import CORSMiddleware
from starlette.responses import RedirectResponse
from fastapi import FastAPI
# Make NLTK search the configured NLTK_DATA_PATH before its default locations.
nltk.data.path = [NLTK_DATA_PATH, *nltk.data.path]
class BaseResponse(BaseModel):
    """Generic JSON envelope used as the ``response_model`` of the API routes.

    Carries a numeric status ``code``, a status message ``msg`` and an
    arbitrary ``data`` payload (``None`` unless a route fills it in).
    """

    code: int = pydantic.Field(default=200, description="API status code")
    msg: str = pydantic.Field(default="success", description="API status message")
    data: Any = pydantic.Field(default=None, description="API data")

    class Config:
        # Example payload shown in the generated OpenAPI schema.
        schema_extra = {
            "example": dict(code=200, msg="success"),
        }
class ListResponse(BaseResponse):
    """``BaseResponse`` variant whose ``data`` is a required list of names."""

    # ``...`` as the default marks the field as required.
    data: List[str] = pydantic.Field(default=..., description="List of names")

    class Config:
        # Example payload shown in the generated OpenAPI schema.
        schema_extra = {
            "example": dict(
                code=200,
                msg="success",
                data=["doc1.docx", "doc2.pdf", "doc3.txt"],
            ),
        }
async def document():
    """Redirect the caller to the interactive API docs at ``/docs``."""
    docs_url = "/docs"
    return RedirectResponse(url=docs_url)
def create_app(run_mode: Optional[str] = None):
    """Create the FastAPI application for the Langchain-Chatchat API server.

    The returned app has a ``/`` route that redirects to the Swagger UI at
    ``/docs`` and, when ``OPEN_CROSS_DOMAIN`` is enabled, a CORS middleware
    that allows every origin, method and header. Knowledge-base routes are
    NOT mounted here; call ``mount_knowledge_routes(app)`` for that.

    Args:
        run_mode: Accepted for compatibility with existing callers; it is
            currently not used by this function.

    Returns:
        The configured ``FastAPI`` instance.
    """
    app = FastAPI(
        title="Langchain-Chatchat API Server",
        version=VERSION,
    )

    # Redirect the bare root URL to the Swagger UI so "/" does not 404.
    app.get("/",
            response_model=BaseResponse,
            summary="swagger 文档")(document)

    # Set OPEN_CROSS_DOMAIN=True in configs/server_config.py to allow
    # cross-origin requests from any origin.
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )
    return app
def mount_knowledge_routes(app: FastAPI):
    """Register the knowledge-base management endpoints on ``app``.

    The endpoint implementations are imported at call time, so importing this
    module does not import ``server.knowledge_base``. Every route is tagged
    "Knowledge Base Management" and routes are registered in table order.

    Args:
        app: The application to mount the routes on (modified in place).
    """
    from server.knowledge_base.kb_api import create_kb, delete_kb, list_kbs
    from server.knowledge_base.kb_doc_api import (
        DocumentWithScore,
        delete_docs,
        download_doc,
        list_files,
        recreate_vector_store,
        search_docs,
        update_docs,
        update_info,
        upload_docs,
    )

    tag = "Knowledge Base Management"
    # Rows are (HTTP verb, path, endpoint, response_model, summary).
    # A ``None`` response_model means the keyword is left out entirely, so
    # FastAPI keeps its own default for that route (download / rebuild).
    route_table = [
        ("get", "/knowledge_base/list_knowledge_bases", list_kbs,
         ListResponse, "获取知识库列表"),
        ("post", "/knowledge_base/create_knowledge_base", create_kb,
         BaseResponse, "创建知识库"),
        ("post", "/knowledge_base/delete_knowledge_base", delete_kb,
         BaseResponse, "删除知识库"),
        ("get", "/knowledge_base/list_files", list_files,
         ListResponse, "获取知识库内的文件列表"),
        ("post", "/knowledge_base/search_docs", search_docs,
         List[DocumentWithScore], "搜索知识库"),
        ("post", "/knowledge_base/upload_docs", upload_docs,
         BaseResponse, "上传文件到知识库,并/或进行向量化"),
        ("post", "/knowledge_base/delete_docs", delete_docs,
         BaseResponse, "删除知识库内指定文件"),
        ("post", "/knowledge_base/update_info", update_info,
         BaseResponse, "更新知识库介绍"),
        ("post", "/knowledge_base/update_docs", update_docs,
         BaseResponse, "更新现有文件到知识库"),
        ("get", "/knowledge_base/download_doc", download_doc,
         None, "下载对应的知识文件"),
        ("post", "/knowledge_base/recreate_vector_store", recreate_vector_store,
         None, "根据content中文档重建向量库流式输出处理进度。"),
    ]

    for verb, path, endpoint, model, summary in route_table:
        options = {"tags": [tag], "summary": summary}
        if model is not None:
            options["response_model"] = model
        register = getattr(app, verb)
        register(path, **options)(endpoint)
def run_api(host, port, *, app: Optional[FastAPI] = None, **kwargs):
    """Serve the API with uvicorn, enabling TLS only when both SSL files are set.

    Args:
        host: Interface to bind to.
        port: TCP port to listen on.
        app: Application to serve. Defaults to the module-level ``app``
            (created by the ``__main__`` entry point); if that does not exist
            either, a fresh app is built with ``create_app`` and
            ``mount_knowledge_routes``, mirroring the script entry point.
        **kwargs: Only ``ssl_keyfile`` and ``ssl_certfile`` are read; TLS is
            enabled when both are truthy, otherwise the server runs plain HTTP.
    """
    if app is None:
        # Previously this function read the global ``app`` unconditionally,
        # which raised NameError whenever the module was imported rather
        # than executed as a script.
        app = globals().get("app")
        if app is None:
            app = create_app()
            mount_knowledge_routes(app)

    ssl_keyfile = kwargs.get("ssl_keyfile")
    ssl_certfile = kwargs.get("ssl_certfile")
    if ssl_keyfile and ssl_certfile:
        uvicorn.run(app,
                    host=host,
                    port=port,
                    ssl_keyfile=ssl_keyfile,
                    ssl_certfile=ssl_certfile,
                    )
    else:
        uvicorn.run(app, host=host, port=port)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                                                 ' 基于本地知识库的 ChatGLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="interface to bind the API server to")
    parser.add_argument("--port", type=int, default=7861,
                        help="TCP port to listen on")
    parser.add_argument("--ssl_keyfile", type=str,
                        help="SSL private key file; TLS is enabled only together with --ssl_certfile")
    parser.add_argument("--ssl_certfile", type=str,
                        help="SSL certificate file; TLS is enabled only together with --ssl_keyfile")
    # Parse the command line, build the app with its knowledge-base routes,
    # then serve it.
    args = parser.parse_args()

    # ``app`` stays module-global: run_api looks it up there.
    app = create_app()
    mount_knowledge_routes(app)

    run_api(host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            )