259 lines
9.6 KiB
Python
259 lines
9.6 KiB
Python
"""
|
||
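Usage example / 调用示例: python llm_api_stale.py --model-path-address THUDM/chatglm2-6b@localhost@7650 THUDM/chatglm2-6b-32k@localhost@7651

Other fastchat.serve.controller/worker/openai_api_server arguments can be passed as described in the FastChat documentation.
其他fastchat.serve.controller/worker/openai_api_server参数可按照fastchat文档调用

A few non-critical arguments are not supported: --worker-address (derived from each model-path@host@port entry), --allowed-origins, --allowed-methods, --allowed-headers.
但少数非关键参数如--worker-address,--allowed-origins,--allowed-methods,--allowed-headers不支持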
|
||
|
||
"""
|
||
import sys
|
||
import os
|
||
|
||
sys.path.append(os.path.dirname(os.path.dirname(__file__)))
|
||
|
||
import subprocess
|
||
import re
|
||
import logging
|
||
import argparse
|
||
|
||
# Directory that receives one log file per launched FastChat service.
LOG_PATH = "./logs/"
# Record format for this launcher's own (root-logger) log messages.
LOG_FORMAT = "%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s"

# Install the default root handler with our format, then raise the root level to INFO.
logging.basicConfig(format=LOG_FORMAT)
logger = logging.getLogger()
logger.setLevel(logging.INFO)
|
||
|
||
parser = argparse.ArgumentParser()

# ------multi worker-----------------
# One model_worker is launched per entry. The default is a list so the attribute
# has the same type whether or not the option is given: nargs="+" yields a list
# for command-line values, but argparse does not wrap a plain-string default.
parser.add_argument(
    "--model-path-address",
    default=["THUDM/chatglm2-6b@localhost@20002"],
    nargs="+",
    type=str,
    help="model path, host, and port, formatted as model-path@host@port",
)

# ---------------controller-------------------------
parser.add_argument("--controller-host", type=str, default="localhost")
parser.add_argument("--controller-port", type=int, default=21001)
parser.add_argument(
    "--dispatch-method",
    type=str,
    choices=["lottery", "shortest_queue"],
    default="shortest_queue",
)
# Option names (dash-separated, without "--") forwarded to fastchat.serve.controller.
controller_args = ["controller-host", "controller-port", "dispatch-method"]
|
||
|
||
# ----------------------worker------------------------------------------
# NOTE: launch_worker() overwrites model_path, worker_host and worker_port on the
# namespace from each --model-path-address entry before building the worker's
# command line, so values given for --model-path/--worker-host/--worker-port do
# not reach the launched workers.
parser.add_argument("--worker-host", type=str, default="localhost")
parser.add_argument("--worker-port", type=int, default=21002)
parser.add_argument(
    "--model-path",
    type=str,
    default="lmsys/vicuna-7b-v1.3",
    help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
    "--revision",
    type=str,
    default="main",
    help="Hugging Face Hub model revision identifier",
)
parser.add_argument(
    "--device",
    type=str,
    choices=["cpu", "cuda", "mps", "xpu"],
    default="cuda",
    help="The device type",
)
parser.add_argument(
    "--gpus",
    type=str,
    default="0",
    help="A single GPU like 1 or multiple GPUs like 0,2",
)
parser.add_argument("--num-gpus", type=int, default=1)
parser.add_argument(
    "--max-gpu-memory",
    type=str,
    default="20GiB",
    help="The maximum memory per gpu. Use a string like '13Gib'",
)
parser.add_argument(
    "--load-8bit", action="store_true", help="Use 8-bit quantization"
)
parser.add_argument(
    "--cpu-offloading",
    action="store_true",
    help="Only when using 8-bit quantization: Offload excess weights to the CPU that don't fit on the GPU",
)
parser.add_argument(
    "--gptq-ckpt",
    type=str,
    default=None,
    help="Load quantized model. The path to the local GPTQ checkpoint.",
)
parser.add_argument(
    "--gptq-wbits",
    type=int,
    default=16,
    choices=[2, 3, 4, 8, 16],
    help="#bits to use for quantization",
)
parser.add_argument(
    "--gptq-groupsize",
    type=int,
    default=-1,
    help="Groupsize to use for quantization; default uses full row.",
)
parser.add_argument(
    "--gptq-act-order",
    action="store_true",
    help="Whether to apply the activation order GPTQ heuristic",
)
parser.add_argument(
    "--model-names",
    type=lambda s: s.split(","),
    help="Optional display comma separated names",
)
parser.add_argument(
    "--limit-worker-concurrency",
    type=int,
    default=5,
    help="Limit the model concurrency to prevent OOM.",
)
parser.add_argument("--stream-interval", type=int, default=2)
parser.add_argument("--no-register", action="store_true")

# Option names forwarded to fastchat.serve.model_worker. "controller-address" and
# "worker-address" are not parser options: they are derived from the host/port
# values and stored on the namespace at runtime, before string_args() is called.
worker_args = [
    "worker-host", "worker-port",
    "model-path", "revision", "device", "gpus", "num-gpus",
    "max-gpu-memory", "load-8bit", "cpu-offloading",
    "gptq-ckpt", "gptq-wbits", "gptq-groupsize",
    "gptq-act-order", "model-names", "limit-worker-concurrency",
    "stream-interval", "no-register",
    "controller-address", "worker-address",
]
|
||
# -----------------openai server---------------------------
parser.add_argument("--server-host", type=str, default="localhost", help="host name")
parser.add_argument("--server-port", type=int, default=8888, help="port number")
parser.add_argument(
    "--allow-credentials", action="store_true", help="allow credentials"
)
# --allowed-origins, --allowed-methods and --allowed-headers are intentionally not
# exposed by this launcher. They are never passed to the server, so it uses
# whatever defaults FastChat defines for them.
parser.add_argument(
    "--api-keys",
    type=lambda s: s.split(","),
    help="Optional list of comma separated API keys",
)
# Option names forwarded to fastchat.serve.openai_api_server. "controller-address"
# is not a parser option; it is stored on the namespace at runtime.
server_args = ["server-host", "server-port", "allow-credentials", "api-keys",
               "controller-address"]
|
||
|
||
# Shell command that starts one FastChat service in the background.
#   {0}: fastchat.serve sub-module (controller, model_worker, openai_api_server)
#   {1}: command-line options for that module
#   {2}: log directory (LOG_PATH)
#   {3}: log file name, without the ".log" extension
base_launch_sh = "nohup python3 -m fastchat.serve.{0} {1} >{2}/{3}.log 2>&1 &"

# Shell snippet that blocks until a service's log contains "Uvicorn running on".
#   {0}: log directory (LOG_PATH)
#   {1}: log file name -- must match the {3} passed to base_launch_sh
#   {2}: service name shown in the progress messages
# `until grep -q ...` keeps waiting while the log file is missing or unreadable
# (grep exits non-zero). The former `[ `grep -c ...` -eq '0' ]` form turned a
# missing file into a `[` syntax error, which ended the loop and reported success.
base_check_sh = """until grep -q "Uvicorn running on" {0}/{1}.log 2>/dev/null; do
    sleep 5s;
    echo "wait {2} running"
done
echo '{2} running' """
|
||
|
||
|
||
def string_args(args, args_list):
    """Render the selected namespace options as a FastChat command-line string.

    Args:
        args: argparse.Namespace with the parsed options. Attribute names may use
            "_" or "-" as word separators; both are normalised to "-".
        args_list: option names (dash-separated, without the leading "--") that
            should be forwarded, e.g. ["controller-host", "dispatch-method"].

    Returns:
        Space-separated, shell-quoted options such as "--host localhost --port 21001".
        - Options whose value is None, False or empty ("" / empty collection) are
          omitted, so the child process falls back to its own defaults.
        - A True value becomes a bare flag.
        - Lists/tuples/sets are joined with ",", matching the comma-separated
          parsing (type=lambda s: s.split(",")) used for --model-names and --api-keys.
        - Numeric 0 is a real value and is forwarded.
    """
    import shlex

    wanted = set(args_list)
    parts = []
    for raw_key, value in vars(args).items():
        # argparse stores "--foo-bar" as "foo_bar"; compare in dash form.
        key = raw_key.replace("_", "-")
        if key not in wanted:
            continue
        # FastChat's own options are plain --host / --port, so drop our
        # controller- / worker- / server- prefix for those two.
        if key.endswith(("-host", "-port")):
            key = key.rsplit("-", 1)[-1]

        if value is None or value is False:
            continue
        if value is True:
            parts.append(f"--{key}")
            continue
        if isinstance(value, (list, tuple, set)):
            value = ",".join(str(v) for v in value)
        value = str(value)
        if not value:
            continue
        # The result is embedded in a shell=True command line; quote each value.
        parts.append(f"--{key} {shlex.quote(value)}")
    return " ".join(parts)
|
||
|
||
|
||
def launch_worker(item, args, worker_args=worker_args):
    """Start one fastchat.serve.model_worker and block until it is serving.

    Args:
        item: one --model-path-address entry, formatted as model-path@host@port.
        args: parsed namespace. model_path, worker_host, worker_port and
            worker_address are overwritten in place from *item* so that
            string_args() forwards them to the worker.
        worker_args: option names forwarded to the model_worker.

    Raises:
        ValueError: if *item* is not of the form model-path@host@port.
        subprocess.CalledProcessError: if the launch or readiness command exits
            non-zero.
    """
    # Split from the right so only the last two "@" separate host and port.
    parts = item.rsplit("@", 2)
    if len(parts) != 3 or not all(parts):
        raise ValueError(
            f"Invalid --model-path-address entry {item!r}: "
            "expected model-path@host@port"
        )
    args.model_path, args.worker_host, args.worker_port = parts
    args.worker_address = f"http://{args.worker_host}:{args.worker_port}"

    # Filesystem-safe log name from the last path component,
    # e.g. "THUDM/chatglm2-6b@localhost@7650" -> "chatglm2_6b_localhost_7650".
    log_name = (
        item.split("/")[-1].split("\\")[-1]
        .replace("-", "_").replace("@", "_").replace(".", "_")
    )
    worker_log = f"worker_{log_name}"

    print("*" * 80)
    # Point the user at the file the worker actually writes to (worker_ prefix).
    print(f"如长时间未启动,请到{LOG_PATH}{worker_log}.log下查看日志")
    worker_str_args = string_args(args, worker_args)
    print(worker_str_args)

    worker_sh = base_launch_sh.format("model_worker", worker_str_args, LOG_PATH, worker_log)
    worker_check_sh = base_check_sh.format(LOG_PATH, worker_log, "model_worker")
    subprocess.run(worker_sh, shell=True, check=True)
    subprocess.run(worker_check_sh, shell=True, check=True)
|
||
|
||
|
||
def _launch_and_wait(module, str_args, log_name, display_name):
    """Start fastchat.serve.<module> in the background, then wait for its log to
    report "Uvicorn running on". Raises subprocess.CalledProcessError on failure."""
    launch_sh = base_launch_sh.format(module, str_args, LOG_PATH, log_name)
    check_sh = base_check_sh.format(LOG_PATH, log_name, display_name)
    subprocess.run(launch_sh, shell=True, check=True)
    subprocess.run(check_sh, shell=True, check=True)


def launch_all(args,
               controller_args=controller_args,
               worker_args=worker_args,
               server_args=server_args
               ):
    """Launch the full FastChat stack, waiting for each service to come up.

    Starts, in order:
    - the controller;
    - one model_worker per --model-path-address entry;
    - the openai_api_server.

    Args:
        args: parsed namespace; launch_worker() also mutates it per worker.
        controller_args, worker_args, server_args: option names forwarded to the
            respective fastchat.serve module.

    Raises:
        subprocess.CalledProcessError: if a launch or readiness command fails.
        ValueError: for a malformed model-path@host@port entry.
    """
    print(f"Launching llm service,logs are located in {LOG_PATH}...")
    print(f"开始启动LLM服务,请到{LOG_PATH}下监控各模块日志...")
    # Every launch command redirects output into LOG_PATH. If the directory is
    # missing, that redirect fails inside the backgrounded shell, the service
    # never starts, and the readiness check has no log file to inspect.
    os.makedirs(LOG_PATH, exist_ok=True)

    _launch_and_wait("controller", string_args(args, controller_args),
                     "controller", "controller")

    print("worker启动时间视设备不同而不同,约需3-10分钟,请耐心等待...")
    items = args.model_path_address
    if isinstance(items, str):
        items = [items]
    # Number models from 1 in the progress message.
    for idx, item in enumerate(items, start=1):
        print(f"开始加载第{idx}个模型:{item}")
        launch_worker(item, args=args, worker_args=worker_args)

    _launch_and_wait("openai_api_server", string_args(args, server_args),
                     "openai_api_server", "openai_api_server")

    print("Launching LLM service done!")
    print("LLM服务启动完毕。")
|
||
|
||
|
||
if __name__ == "__main__":
    args = parser.parse_args()
    # The "http://" scheme is required; without it requests raises
    # InvalidSchema ("No connection adapters were found").
    # Stored as a regular attribute; string_args() maps "controller_address"
    # to the --controller-address option.
    args.controller_address = f"http://{args.controller_host}:{args.controller_port}"

    if args.gpus:
        if len(args.gpus.split(",")) < args.num_gpus:
            raise ValueError(
                f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
            )
        # Inherited by every service launched below.
        os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus

    launch_all(args=args)
|