175 lines
5.7 KiB
Python
175 lines
5.7 KiB
Python
|
from typing import *
|
|||
|
|
|||
|
import nltk
|
|||
|
import sys
|
|||
|
import os
|
|||
|
|
|||
|
import pydantic
|
|||
|
from pydantic import BaseModel
|
|||
|
|
|||
|
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
|||
|
|
|||
|
from configs import VERSION
|
|||
|
from configs.model_config import NLTK_DATA_PATH
|
|||
|
from configs.server_config import OPEN_CROSS_DOMAIN
|
|||
|
import argparse
|
|||
|
import uvicorn
|
|||
|
from fastapi.middleware.cors import CORSMiddleware
|
|||
|
from starlette.responses import RedirectResponse
|
|||
|
from fastapi import FastAPI
|
|||
|
|
|||
|
# Rebind NLTK's search-path list with the configured NLTK_DATA_PATH in front,
# so it is consulted before any of NLTK's existing locations.
nltk.data.path = [NLTK_DATA_PATH, *nltk.data.path]
|
|||
|
|
|||
|
|
|||
|
# Response model with a numeric status ``code`` (default 200), a status
# ``msg`` (default "success") and an arbitrary ``data`` payload (default None).
class BaseResponse(BaseModel):
    # Status code of the API call.
    code: int = pydantic.Field(default=200, description="API status code")
    # Human-readable status message.
    msg: str = pydantic.Field(default="success", description="API status message")
    # Endpoint-specific payload; any type is accepted.
    data: Any = pydantic.Field(default=None, description="API data")

    class Config:
        # Example payload attached to the model's JSON schema
        # (``schema_extra`` is the pydantic v1 Config option name).
        schema_extra = dict(
            example=dict(
                code=200,
                msg="success",
            )
        )
|
|||
|
|
|||
|
|
|||
|
# BaseResponse variant whose ``data`` is a list of strings; it is used as the
# response model of the list-knowledge-bases and list-files endpoints.
class ListResponse(BaseResponse):
    # Overrides the parent's optional ``data``: ``...`` makes it required here.
    data: List[str] = pydantic.Field(default=..., description="List of names")

    class Config:
        # Example payload attached to the model's JSON schema
        # (``schema_extra`` is the pydantic v1 Config option name).
        schema_extra = dict(
            example=dict(
                code=200,
                msg="success",
                data=["doc1.docx", "doc2.pdf", "doc3.txt"],
            )
        )
|
|||
|
|
|||
|
|
|||
|
# Endpoint handler that redirects the caller to ``/docs`` (FastAPI's default
# Swagger UI path; create_app does not override it). A comment is used rather
# than a docstring because FastAPI would publish a docstring in the OpenAPI
# description.
async def document():
    docs_location = "/docs"
    return RedirectResponse(url=docs_location)
|
|||
|
|
|||
|
|
|||
|
def create_app(run_mode: Optional[str] = None):
    """Create the FastAPI application object for the API server.

    Sets the title and version and, when enabled in the config, installs a
    permissive CORS middleware. No routes are mounted here; the script entry
    point mounts them afterwards (see ``mount_knowledge_routes``).

    Args:
        run_mode: Accepted for interface compatibility; not used by this
            function.

    Returns:
        The newly created ``FastAPI`` instance.
    """
    app = FastAPI(
        title="Langchain-Chatchat API Server",
        version=VERSION,
    )

    # Cross-origin access is opt-in: set OPEN_CROSS_DOMAIN = True in
    # configs/server_config.py to allow every origin, method and header.
    # NOTE(review): combining allow_origins=["*"] with allow_credentials=True
    # makes Starlette echo the request's Origin for credentialed requests, so
    # any website can issue authenticated calls. Only enable this on trusted
    # networks, or list explicit origins instead.
    if OPEN_CROSS_DOMAIN:
        app.add_middleware(
            CORSMiddleware,
            allow_origins=["*"],
            allow_credentials=True,
            allow_methods=["*"],
            allow_headers=["*"],
        )

    return app
|
|||
|
|
|||
|
|
|||
|
def mount_knowledge_routes(app: FastAPI):
    """Register the knowledge-base management endpoints on ``app``.

    The handlers are imported from ``server.knowledge_base`` when this
    function runs, not at module import time. Every route is tagged
    "Knowledge Base Management". Routes are registered in the order listed
    below.
    """
    from server.knowledge_base.kb_api import list_kbs, create_kb, delete_kb
    from server.knowledge_base.kb_doc_api import (list_files, upload_docs, delete_docs,
                                                  update_docs, download_doc, recreate_vector_store,
                                                  search_docs, DocumentWithScore, update_info)

    kb_tag = "Knowledge Base Management"

    # (HTTP verb, path, handler, response model or None, summary).
    # A None model means the decorator is called without ``response_model``.
    route_table = [
        ("get", "/knowledge_base/list_knowledge_bases", list_kbs, ListResponse,
         "获取知识库列表"),
        ("post", "/knowledge_base/create_knowledge_base", create_kb, BaseResponse,
         "创建知识库"),
        ("post", "/knowledge_base/delete_knowledge_base", delete_kb, BaseResponse,
         "删除知识库"),
        ("get", "/knowledge_base/list_files", list_files, ListResponse,
         "获取知识库内的文件列表"),
        ("post", "/knowledge_base/search_docs", search_docs, List[DocumentWithScore],
         "搜索知识库"),
        ("post", "/knowledge_base/upload_docs", upload_docs, BaseResponse,
         "上传文件到知识库,并/或进行向量化"),
        ("post", "/knowledge_base/delete_docs", delete_docs, BaseResponse,
         "删除知识库内指定文件"),
        ("post", "/knowledge_base/update_info", update_info, BaseResponse,
         "更新知识库介绍"),
        ("post", "/knowledge_base/update_docs", update_docs, BaseResponse,
         "更新现有文件到知识库"),
        ("get", "/knowledge_base/download_doc", download_doc, None,
         "下载对应的知识文件"),
        ("post", "/knowledge_base/recreate_vector_store", recreate_vector_store, None,
         "根据content中文档重建向量库,流式输出处理进度。"),
    ]

    for verb, path, handler, model, summary in route_table:
        # A fresh tags list per route, as the literal-per-call form produced.
        options = {"tags": [kb_tag], "summary": summary}
        if model is not None:
            options["response_model"] = model
        route_decorator = getattr(app, verb)(path, **options)
        route_decorator(handler)
|
|||
|
|
|||
|
|
|||
|
def run_api(host, port, **kwargs):
    """Serve the API application with uvicorn (blocks until shutdown).

    Args:
        host: Interface address to bind.
        port: TCP port to listen on.
        **kwargs: Optional settings:
            ssl_keyfile / ssl_certfile: TLS key and certificate paths. HTTPS
                is enabled only when BOTH are truthy.
            app: ASGI application to serve. If absent or None, the
                module-level ``app`` global is used; this file only defines
                that global when run as a script, so importers should pass it.
    """
    application = kwargs.get("app")
    if application is None:
        application = app  # module global created in the __main__ block

    ssl_keyfile = kwargs.get("ssl_keyfile")
    ssl_certfile = kwargs.get("ssl_certfile")

    if ssl_keyfile and ssl_certfile:
        uvicorn.run(application,
                    host=host,
                    port=port,
                    ssl_keyfile=ssl_keyfile,
                    ssl_certfile=ssl_certfile,
                    )
        return

    if ssl_keyfile or ssl_certfile:
        # Only half of the TLS pair was given. Keep serving over plain HTTP,
        # but make the downgrade visible instead of silent.
        print("WARNING: both ssl_keyfile and ssl_certfile are required for HTTPS; "
              "starting without TLS.", file=sys.stderr)
    uvicorn.run(application, host=host, port=port)
|
|||
|
|
|||
|
|
|||
|
if __name__ == "__main__":
    parser = argparse.ArgumentParser(prog='langchain-ChatGLM',
                                     description='About langchain-ChatGLM, local knowledge based ChatGLM with langchain'
                                                 ' | 基于本地知识库的 ChatGLM 问答')
    parser.add_argument("--host", type=str, default="0.0.0.0",
                        help="address to bind the API server to")
    parser.add_argument("--port", type=int, default=7861,
                        help="port for the API server")
    parser.add_argument("--ssl_keyfile", type=str,
                        help="TLS private key file (HTTPS requires --ssl_certfile too)")
    parser.add_argument("--ssl_certfile", type=str,
                        help="TLS certificate file (HTTPS requires --ssl_keyfile too)")

    # Parse command-line options.
    args = parser.parse_args()

    # Keep ``app`` at module level: run_api() falls back to the global ``app``.
    app = create_app()
    mount_knowledge_routes(app)

    run_api(host=args.host,
            port=args.port,
            ssl_keyfile=args.ssl_keyfile,
            ssl_certfile=args.ssl_certfile,
            )
|