import pydantic
from pydantic import BaseModel
from typing import List
from fastapi import FastAPI
from pathlib import Path
import asyncio
from configs import (LLM_MODELS, LLM_DEVICE, EMBEDDING_DEVICE,
                     MODEL_PATH, MODEL_ROOT_PATH, ONLINE_LLM_MODEL, logger, log_verbose,
                     FSCHAT_MODEL_WORKERS, HTTPX_DEFAULT_TIMEOUT)
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from langchain.chat_models import ChatOpenAI, AzureChatOpenAI, ChatAnthropic
from langchain.llms import OpenAI, AzureOpenAI, Anthropic
import httpx
from typing import Literal, Optional, Callable, Generator, Dict, Any, Awaitable, Union


async def wrap_done(fn: Awaitable, event: asyncio.Event):
    """Wrap an awaitable with an event to signal when it's done or an exception is raised."""
    try:
        await fn
    except Exception as e:
        # TODO: handle exception
        msg = f"Caught exception: {e}"
        logger.error(f'{e.__class__.__name__}: {msg}',
                     exc_info=e if log_verbose else None)
    finally:
        # Signal the aiter to stop.
        event.set()
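
# Usage sketch (editor's note, not part of the original module): wrap_done is typically paired
# with a LangChain AsyncIteratorCallbackHandler so a streaming endpoint knows when generation
# has finished. The model/message names below are illustrative assumptions.
#
#     callback = AsyncIteratorCallbackHandler()
#     task = asyncio.create_task(wrap_done(
#         model.agenerate(messages=[...], callbacks=[callback]),
#         callback.done,
#     ))
#     async for token in callback.aiter():
#         ...  # stream the token to the client
#     await task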


def get_ChatOpenAI(
        model_name: str,
        temperature: float,
        max_tokens: int = None,
        streaming: bool = True,
        callbacks: List[Callable] = [],
        verbose: bool = True,
        **kwargs: Any,
) -> ChatOpenAI:
    ## The following models are natively supported by Langchain and do not go through the Fschat wrapper
    config_models = list_config_llm_models()

    ## Models not natively supported by Langchain go through the Fschat wrapper
    config = get_model_worker_config(model_name)
    model = ChatOpenAI(
        streaming=streaming,
        verbose=verbose,
        callbacks=callbacks,
        openai_api_key=config.get("api_key", "EMPTY"),
        openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
        model_name=model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        openai_proxy=config.get("openai_proxy"),
        **kwargs
    )

    return model
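
# Example (a minimal sketch, assuming the local FastChat openai-api server is running and
# LLM_MODELS[0] names a configured model):
#
#     llm = get_ChatOpenAI(model_name=LLM_MODELS[0], temperature=0.7, max_tokens=1024)
#     print(llm.predict("Hello"))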


def get_OpenAI(
        model_name: str,
        temperature: float,
        max_tokens: int = None,
        streaming: bool = True,
        echo: bool = True,
        callbacks: List[Callable] = [],
        verbose: bool = True,
        **kwargs: Any,
) -> OpenAI:
    ## The following models are natively supported by Langchain and do not go through the Fschat wrapper
    config_models = list_config_llm_models()
    if model_name in config_models.get("langchain", {}):
        config = config_models["langchain"][model_name]
        if model_name == "Azure-OpenAI":
            model = AzureOpenAI(
                streaming=streaming,
                verbose=verbose,
                callbacks=callbacks,
                deployment_name=config.get("deployment_name"),
                model_version=config.get("model_version"),
                openai_api_type=config.get("openai_api_type"),
                openai_api_base=config.get("api_base_url"),
                openai_api_version=config.get("api_version"),
                openai_api_key=config.get("api_key"),
                openai_proxy=config.get("openai_proxy"),
                temperature=temperature,
                max_tokens=max_tokens,
                echo=echo,
            )

        elif model_name == "OpenAI":
            model = OpenAI(
                streaming=streaming,
                verbose=verbose,
                callbacks=callbacks,
                model_name=config.get("model_name"),
                openai_api_base=config.get("api_base_url"),
                openai_api_key=config.get("api_key"),
                openai_proxy=config.get("openai_proxy"),
                temperature=temperature,
                max_tokens=max_tokens,
                echo=echo,
            )
        elif model_name == "Anthropic":
            model = Anthropic(
                streaming=streaming,
                verbose=verbose,
                callbacks=callbacks,
                model_name=config.get("model_name"),
                anthropic_api_key=config.get("api_key"),
                echo=echo,
            )
        ## TODO: support other models natively supported by Langchain
    else:
        ## Models not natively supported by Langchain go through the Fschat wrapper
        config = get_model_worker_config(model_name)
        model = OpenAI(
            streaming=streaming,
            verbose=verbose,
            callbacks=callbacks,
            openai_api_key=config.get("api_key", "EMPTY"),
            openai_api_base=config.get("api_base_url", fschat_openai_api_address()),
            model_name=model_name,
            temperature=temperature,
            max_tokens=max_tokens,
            openai_proxy=config.get("openai_proxy"),
            echo=echo,
            **kwargs
        )

    return model
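
# Example (a minimal sketch): the completion-style counterpart of get_ChatOpenAI. Note that the
# "Azure-OpenAI"/"OpenAI"/"Anthropic" branches only trigger if list_config_llm_models() exposes a
# "langchain" section; with the dict it returns below ("local"/"online"/"worker"), every model
# name currently falls through to the FastChat openai-api wrapper.
#
#     llm = get_OpenAI(model_name=LLM_MODELS[0], temperature=0.7, echo=False)
#     print(llm("Hello"))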


class BaseResponse(BaseModel):
    code: int = pydantic.Field(200, description="API status code")
    msg: str = pydantic.Field("success", description="API status message")
    data: Any = pydantic.Field(None, description="API data")

    class Config:
        schema_extra = {
            "example": {
                "code": 200,
                "msg": "success",
            }
        }


class ListResponse(BaseResponse):
    data: List[str] = pydantic.Field(..., description="List of names")

    class Config:
        schema_extra = {
            "example": {
                "code": 200,
                "msg": "success",
                "data": ["doc1.docx", "doc2.pdf", "doc3.txt"],
            }
        }
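
# Example (a minimal sketch, hypothetical endpoint): these pydantic models are intended to be
# used as FastAPI response_model classes.
#
#     app = FastAPI()
#
#     @app.get("/knowledge_base/list_files", response_model=ListResponse)
#     def list_files():
#         return ListResponse(data=["doc1.docx", "doc2.pdf"])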


class ChatMessage(BaseModel):
    question: str = pydantic.Field(..., description="Question text")
    response: str = pydantic.Field(..., description="Response text")
    history: List[List[str]] = pydantic.Field(..., description="History text")
    source_documents: List[str] = pydantic.Field(
        ..., description="List of source documents and their scores"
    )

    class Config:
        schema_extra = {
            "example": {
                "question": "工伤保险如何办理?",
                "response": "根据已知信息,可以总结如下:\n\n1. 参保单位为员工缴纳工伤保险费,以保障员工在发生工伤时能够获得相应的待遇。\n"
                            "2. 不同地区的工伤保险缴费规定可能有所不同,需要向当地社保部门咨询以了解具体的缴费标准和规定。\n"
                            "3. 工伤从业人员及其近亲属需要申请工伤认定,确认享受的待遇资格,并按时缴纳工伤保险费。\n"
                            "4. 工伤保险待遇包括工伤医疗、康复、辅助器具配置费用、伤残待遇、工亡待遇、一次性工亡补助金等。\n"
                            "5. 工伤保险待遇领取资格认证包括长期待遇领取人员认证和一次性待遇领取人员认证。\n"
                            "6. 工伤保险基金支付的待遇项目包括工伤医疗待遇、康复待遇、辅助器具配置费用、一次性工亡补助金、丧葬补助金等。",
                "history": [
                    [
                        "工伤保险是什么?",
                        "工伤保险是指用人单位按照国家规定,为本单位的职工和用人单位的其他人员,缴纳工伤保险费,"
                        "由保险机构按照国家规定的标准,给予工伤保险待遇的社会保险制度。",
                    ]
                ],
                "source_documents": [
                    "出处 [1] 广州市单位从业的特定人员参加工伤保险办事指引.docx:\n\n\t"
                    "( 一) 从业单位 (组织) 按“自愿参保”原则, 为未建 立劳动关系的特定从业人员单项参加工伤保险 、缴纳工伤保 险费。",
                    "出处 [2] ...",
                    "出处 [3] ...",
                ],
            }
        }


def torch_gc():
    try:
        import torch
        if torch.cuda.is_available():
            # with torch.cuda.device(DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
        elif torch.backends.mps.is_available():
            try:
                from torch.mps import empty_cache
                empty_cache()
            except Exception as e:
                msg = ("If you are running on macOS, upgrading pytorch to 2.0.0 or later is recommended "
                       "so that memory allocated by torch can be released in time.")
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
    except Exception:
        ...


def run_async(cor):
    '''
    Run async code in a synchronous environment.
    '''
    try:
        loop = asyncio.get_event_loop()
    except:
        loop = asyncio.new_event_loop()
    return loop.run_until_complete(cor)


def iter_over_async(ait, loop=None):
    '''
    Wrap an async generator into a synchronous generator.
    '''
    ait = ait.__aiter__()

    async def get_next():
        try:
            obj = await ait.__anext__()
            return False, obj
        except StopAsyncIteration:
            return True, None

    if loop is None:
        try:
            loop = asyncio.get_event_loop()
        except:
            loop = asyncio.new_event_loop()

    while True:
        done, obj = loop.run_until_complete(get_next())
        if done:
            break
        yield obj
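
# Example (a minimal sketch): iter_over_async lets synchronous code (e.g. a Streamlit webui)
# consume an async generator such as a streaming chat response. agen() is a hypothetical
# async generator used only for illustration.
#
#     async def agen():
#         for i in range(3):
#             yield i
#
#     for item in iter_over_async(agen()):
#         print(item)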


def MakeFastAPIOffline(
        app: FastAPI,
        static_dir=Path(__file__).parent / "static",
        static_url="/static-offline-docs",
        docs_url: Optional[str] = "/docs",
        redoc_url: Optional[str] = "/redoc",
) -> None:
    """patch the FastAPI app so that its documentation pages do not rely on a CDN"""
    from fastapi import Request
    from fastapi.openapi.docs import (
        get_redoc_html,
        get_swagger_ui_html,
        get_swagger_ui_oauth2_redirect_html,
    )
    from fastapi.staticfiles import StaticFiles
    from starlette.responses import HTMLResponse

    openapi_url = app.openapi_url
    swagger_ui_oauth2_redirect_url = app.swagger_ui_oauth2_redirect_url

    def remove_route(url: str) -> None:
        '''
        remove original route from app
        '''
        index = None
        for i, r in enumerate(app.routes):
            if r.path.lower() == url.lower():
                index = i
                break
        if isinstance(index, int):
            app.routes.pop(index)

    # Set up static file mount
    app.mount(
        static_url,
        StaticFiles(directory=Path(static_dir).as_posix()),
        name="static-offline-docs",
    )

    if docs_url is not None:
        remove_route(docs_url)
        remove_route(swagger_ui_oauth2_redirect_url)

        # Define the doc and redoc pages, pointing at the right files
        @app.get(docs_url, include_in_schema=False)
        async def custom_swagger_ui_html(request: Request) -> HTMLResponse:
            root = request.scope.get("root_path")
            favicon = f"{root}{static_url}/favicon.png"
            return get_swagger_ui_html(
                openapi_url=f"{root}{openapi_url}",
                title=app.title + " - Swagger UI",
                oauth2_redirect_url=swagger_ui_oauth2_redirect_url,
                swagger_js_url=f"{root}{static_url}/swagger-ui-bundle.js",
                swagger_css_url=f"{root}{static_url}/swagger-ui.css",
                swagger_favicon_url=favicon,
            )

        @app.get(swagger_ui_oauth2_redirect_url, include_in_schema=False)
        async def swagger_ui_redirect() -> HTMLResponse:
            return get_swagger_ui_oauth2_redirect_html()

    if redoc_url is not None:
        remove_route(redoc_url)

        @app.get(redoc_url, include_in_schema=False)
        async def redoc_html(request: Request) -> HTMLResponse:
            root = request.scope.get("root_path")
            favicon = f"{root}{static_url}/favicon.png"

            return get_redoc_html(
                openapi_url=f"{root}{openapi_url}",
                title=app.title + " - ReDoc",
                redoc_js_url=f"{root}{static_url}/redoc.standalone.js",
                with_google_fonts=False,
                redoc_favicon_url=favicon,
            )
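
# Example (a minimal sketch): call once right after creating the app so /docs and /redoc are
# served from the local static directory instead of a CDN. Assumes swagger-ui-bundle.js,
# swagger-ui.css and redoc.standalone.js exist under ./static next to this file.
#
#     app = FastAPI(title="API Server")
#     MakeFastAPIOffline(app)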


# Get model information from model_config


def list_embed_models() -> List[str]:
    '''
    get names of configured embedding models
    '''
    return list(MODEL_PATH["embed_model"])


def list_config_llm_models() -> Dict[str, Dict]:
    '''
    get configured llm models grouped by type.
    return a dict keyed by config type: "local" and "online" map to model configs,
    "worker" maps to the list of configured worker names.
    '''
    workers = list(FSCHAT_MODEL_WORKERS)

    return {
        "local": MODEL_PATH["llm_model"],
        "online": ONLINE_LLM_MODEL,
        "worker": workers,
    }


def get_model_path(model_name: str, type: str = None) -> Optional[str]:
    if type in MODEL_PATH:
        paths = MODEL_PATH[type]
    else:
        paths = {}
        for v in MODEL_PATH.values():
            paths.update(v)

    if path_str := paths.get(model_name):  # taking "chatglm-6b": "THUDM/chatglm-6b-new" as an example, all of the following paths are supported
        path = Path(path_str)
        if path.is_dir():  # any absolute path
            return str(path)

        root_path = Path(MODEL_ROOT_PATH)
        if root_path.is_dir():
            path = root_path / model_name
            if path.is_dir():  # use key, {MODEL_ROOT_PATH}/chatglm-6b
                return str(path)
            path = root_path / path_str
            if path.is_dir():  # use value, {MODEL_ROOT_PATH}/THUDM/chatglm-6b-new
                return str(path)
            path = root_path / path_str.split("/")[-1]
            if path.is_dir():  # use value split by "/", {MODEL_ROOT_PATH}/chatglm-6b-new
                return str(path)
        return path_str  # THUDM/chatglm-6b-new
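
# Example (a minimal sketch, assuming "chatglm2-6b" is configured in MODEL_PATH["llm_model"]):
#
#     get_model_path("chatglm2-6b")               # search every MODEL_PATH section
#     get_model_path("chatglm2-6b", "llm_model")  # restrict the lookup to LLM paths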


# Get service information from server_config


def get_model_worker_config(model_name: str = None) -> dict:
    '''
    Load the configuration items for a model worker.
    Priority: FSCHAT_MODEL_WORKERS[model_name] > ONLINE_LLM_MODEL[model_name] > FSCHAT_MODEL_WORKERS["default"]
    '''
    from configs.model_config import ONLINE_LLM_MODEL, MODEL_PATH
    from configs.server_config import FSCHAT_MODEL_WORKERS
    from server import model_workers

    config = FSCHAT_MODEL_WORKERS.get("default", {}).copy()
    config.update(ONLINE_LLM_MODEL.get(model_name, {}).copy())
    config.update(FSCHAT_MODEL_WORKERS.get(model_name, {}).copy())

    if model_name in ONLINE_LLM_MODEL:
        config["online_api"] = True
        if provider := config.get("provider"):
            try:
                config["worker_class"] = getattr(model_workers, provider)
            except Exception as e:
                msg = f"provider of online model '{model_name}' is not configured correctly"
                logger.error(f'{e.__class__.__name__}: {msg}',
                             exc_info=e if log_verbose else None)
    # local model
    if model_name in MODEL_PATH["llm_model"]:
        config["model_path"] = get_model_path(model_name)
        print(config["model_path"])
        config["device"] = llm_device(config.get("device"))
    return config
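
# Example (a minimal sketch): the merged config drives how a worker is started. For a local
# model it gains "model_path" and "device"; for an online model it gains "online_api" and
# "worker_class".
#
#     cfg = get_model_worker_config(LLM_MODELS[0])
#     host, port = cfg.get("host"), cfg.get("port")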


def get_all_model_worker_configs() -> dict:
    result = {}
    model_names = set(FSCHAT_MODEL_WORKERS.keys())
    for name in model_names:
        if name != "default":
            result[name] = get_model_worker_config(name)
    return result


def fschat_controller_address() -> str:
    from configs.server_config import FSCHAT_CONTROLLER

    host = FSCHAT_CONTROLLER["host"]
    if host == "0.0.0.0":
        host = "127.0.0.1"
    port = FSCHAT_CONTROLLER["port"]
    return f"http://{host}:{port}"


def fschat_model_worker_address(model_name: str = LLM_MODELS[0]) -> str:
    if model := get_model_worker_config(model_name):  # TODO: depends on fastchat
        host = model["host"]
        if host == "0.0.0.0":
            host = "127.0.0.1"
        port = model["port"]
        return f"http://{host}:{port}"
    return ""


def fschat_openai_api_address() -> str:
    from configs.server_config import FSCHAT_OPENAI_API

    host = FSCHAT_OPENAI_API["host"]
    if host == "0.0.0.0":
        host = "127.0.0.1"
    port = FSCHAT_OPENAI_API["port"]
    return f"http://{host}:{port}/v1"


def api_address() -> str:
    from configs.server_config import API_SERVER

    host = API_SERVER["host"]
    if host == "0.0.0.0":
        host = "127.0.0.1"
    port = API_SERVER["port"]
    return f"http://{host}:{port}"


def webui_address() -> str:
    from configs.server_config import WEBUI_SERVER

    host = WEBUI_SERVER["host"]
    port = WEBUI_SERVER["port"]
    return f"http://{host}:{port}"


def get_prompt_template(type: str, name: str) -> Optional[str]:
    '''
    Load a template from prompt_config.
    type: one of "llm_chat", "agent_chat", "knowledge_base_chat", "search_engine_chat";
    new features should be added to this list.
    '''

    from configs import prompt_config
    import importlib
    importlib.reload(prompt_config)  # TODO: only reload when configs/prompt_config.py has actually been modified
    return prompt_config.PROMPT_TEMPLATES[type].get(name)
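
# Example (a minimal sketch, assuming PROMPT_TEMPLATES["llm_chat"] contains a "default" entry
# as shipped in the sample prompt config):
#
#     template = get_prompt_template("llm_chat", "default")
#     print(template)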


def set_httpx_config(
        timeout: float = HTTPX_DEFAULT_TIMEOUT,
        proxy: Union[str, Dict] = None,
):
    '''
    Set the default httpx timeout. The httpx default of 5 seconds is not enough when waiting for LLM answers.
    Add this project's own services to the no-proxy list to avoid request errors against the fastchat servers (has no effect on Windows).
    For online APIs such as chatgpt, a proxy must be configured manually if needed; how to handle proxies for search engines still needs consideration.
    '''

    import httpx
    import os

    httpx._config.DEFAULT_TIMEOUT_CONFIG.connect = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.read = timeout
    httpx._config.DEFAULT_TIMEOUT_CONFIG.write = timeout

    # set system-level proxies for the whole process
    proxies = {}
    if isinstance(proxy, str):
        for n in ["http", "https", "all"]:
            proxies[n + "_proxy"] = proxy
    elif isinstance(proxy, dict):
        for n in ["http", "https", "all"]:
            if p := proxy.get(n):
                proxies[n + "_proxy"] = p
            elif p := proxy.get(n + "_proxy"):
                proxies[n + "_proxy"] = p

    for k, v in proxies.items():
        os.environ[k] = v

    # set hosts that bypass the proxy
    no_proxy = [x.strip() for x in os.environ.get("no_proxy", "").split(",") if x.strip()]
    no_proxy += [
        # do not use proxy for localhost
        "http://127.0.0.1",
        "http://localhost",
    ]
    # do not use proxy for user-deployed fastchat servers
    for x in [
        fschat_controller_address(),
        fschat_model_worker_address(),
        fschat_openai_api_address(),
    ]:
        host = ":".join(x.split(":")[:2])
        if host not in no_proxy:
            no_proxy.append(host)
    os.environ["NO_PROXY"] = ",".join(no_proxy)

    # TODO: simply clearing the system proxy is not a good choice because it affects too much;
    # modifying the proxy server's bypass list seems better.
    # patch requests to use custom proxies instead of system settings
    def _get_proxies():
        return proxies

    import urllib.request
    urllib.request.getproxies = _get_proxies
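
# Example (a minimal sketch): call this once at process start-up, before creating servers or
# clients, e.g. at the top of the API/webui entry points. The proxy address is illustrative.
#
#     set_httpx_config()                                # longer timeout, local no-proxy list
#     set_httpx_config(proxy="http://127.0.0.1:7890")   # also route outbound traffic via a proxy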


# Automatically detect the device available to torch. In a distributed deployment,
# machines that do not run the LLM do not need torch installed.


def detect_device() -> Literal["cuda", "mps", "cpu"]:
    try:
        import torch
        if torch.cuda.is_available():
            return "cuda"
        if torch.backends.mps.is_available():
            return "mps"
    except:
        pass
    return "cpu"


def llm_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
    device = device or LLM_DEVICE
    if device not in ["cuda", "mps", "cpu"]:
        device = detect_device()
    return device


def embedding_device(device: str = None) -> Literal["cuda", "mps", "cpu"]:
    device = device or EMBEDDING_DEVICE
    if device not in ["cuda", "mps", "cpu"]:
        device = detect_device()
    return device


def run_in_thread_pool(
        func: Callable,
        params: List[Dict] = [],
) -> Generator:
    '''
    Run tasks in batches in a thread pool and yield the results as a generator.
    Make sure all operations in the tasks are thread-safe, and pass all task arguments as keyword arguments.
    '''
    tasks = []
    with ThreadPoolExecutor() as pool:
        for kwargs in params:
            thread = pool.submit(func, **kwargs)
            tasks.append(thread)

        for obj in as_completed(tasks):  # TODO: cannot be stopped with Ctrl+C
            yield obj.result()
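
# Example (a minimal sketch, hypothetical task function): results are yielded in completion
# order, not submission order.
#
#     def download(url: str, timeout: int = 10):
#         return url, timeout
#
#     params = [{"url": "http://a"}, {"url": "http://b"}]
#     for result in run_in_thread_pool(download, params=params):
#         print(result)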


def get_httpx_client(
        use_async: bool = False,
        proxies: Union[str, Dict] = None,
        timeout: float = HTTPX_DEFAULT_TIMEOUT,
        **kwargs,
) -> Union[httpx.Client, httpx.AsyncClient]:
    '''
    helper to get an httpx client with default proxies that bypass local addresses.
    '''
    default_proxies = {
        # do not use proxy for localhost
        "all://127.0.0.1": None,
        "all://localhost": None,
    }
    # do not use proxy for user-deployed fastchat servers
    for x in [
        fschat_controller_address(),
        fschat_model_worker_address(),
        fschat_openai_api_address(),
    ]:
        host = ":".join(x.split(":")[:2])
        default_proxies.update({host: None})

    # get proxies from the system environment
    # a proxy is only used if it is a non-empty string (not None, False, 0, [] or {})
    default_proxies.update({
        "http://": (os.environ.get("http_proxy")
                    if os.environ.get("http_proxy") and len(os.environ.get("http_proxy").strip())
                    else None),
        "https://": (os.environ.get("https_proxy")
                     if os.environ.get("https_proxy") and len(os.environ.get("https_proxy").strip())
                     else None),
        "all://": (os.environ.get("all_proxy")
                   if os.environ.get("all_proxy") and len(os.environ.get("all_proxy").strip())
                   else None),
    })
    for host in os.environ.get("no_proxy", "").split(","):
        if host := host.strip():
            default_proxies.update({host: None})

    # merge default proxies with user-provided proxies
    if isinstance(proxies, str):
        proxies = {"all://": proxies}

    if isinstance(proxies, dict):
        default_proxies.update(proxies)

    # construct the client
    kwargs.update(timeout=timeout, proxies=default_proxies)
    print(kwargs)
    if use_async:
        return httpx.AsyncClient(**kwargs)
    else:
        return httpx.Client(**kwargs)
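
# Example (a minimal sketch): prefer this helper over httpx.Client() directly so that local
# FastChat/API addresses bypass any system proxy and the project-wide timeout applies.
#
#     with get_httpx_client() as client:
#         resp = client.get(api_address() + "/docs")
#
#     async with get_httpx_client(use_async=True) as client:   # inside an async context
#         resp = await client.get(fschat_openai_api_address() + "/models")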


def get_server_configs() -> Dict:
    '''
    Return the raw configuration items from configs for use by the front end.
    '''
    from configs.kb_config import (
        DEFAULT_KNOWLEDGE_BASE,
        DEFAULT_SEARCH_ENGINE,
        DEFAULT_VS_TYPE,
        CHUNK_SIZE,
        OVERLAP_SIZE,
        SCORE_THRESHOLD,
        VECTOR_SEARCH_TOP_K,
        SEARCH_ENGINE_TOP_K,
        ZH_TITLE_ENHANCE,
        text_splitter_dict,
        TEXT_SPLITTER_NAME,
    )
    from configs.model_config import (
        LLM_MODELS,
        HISTORY_LEN,
        TEMPERATURE,
    )
    from configs.prompt_config import PROMPT_TEMPLATES

    _custom = {
        "controller_address": fschat_controller_address(),
        "openai_api_address": fschat_openai_api_address(),
        "api_address": api_address(),
    }

    return {**{k: v for k, v in locals().items() if k[0] != "_"}, **_custom}


def list_online_embed_models() -> List[str]:
    from server import model_workers

    ret = []
    for k, v in list_config_llm_models()["online"].items():
        if provider := v.get("provider"):
            worker_class = getattr(model_workers, provider, None)
            if worker_class is not None and worker_class.can_embedding():
                ret.append(k)
    return ret


def load_local_embeddings(model: str = None, device: str = embedding_device()):
    '''
    Load embeddings from the cache, which avoids competing loads from multiple threads.
    '''
    from server.knowledge_base.kb_cache.base import embeddings_pool
    from configs import EMBEDDING_MODEL

    model = model or EMBEDDING_MODEL
    return embeddings_pool.load_embeddings(model=model, device=device)
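
# Example (a minimal sketch, assuming the default EMBEDDING_MODEL has been downloaded and configured):
#
#     embeddings = load_local_embeddings()
#     vectors = embeddings.embed_documents(["工伤保险如何办理?"])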