import asyncio
import base64
import datetime  # NOTE: immediately shadowed by `from datetime import datetime` below
import hashlib
import hmac
import json
from datetime import datetime
from time import mktime
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

import requests
import uvicorn
import websockets
from fastapi import FastAPI, WebSocket

app = FastAPI()

# Upstream Spark chat endpoints, keyed by the API version taken from the route path.
Spark_url = {
    "v1.1": "wss://spark-api.xf-yun.com/v1.1/chat",
    "v2.1": "wss://spark-api.xf-yun.com/v2.1/chat",
    "v3.1": "wss://spark-api.xf-yun.com/v3.1/chat",
}

# Local knowledge-base search endpoint used to build retrieval-augmented prompts.
KB_SEARCH_URL = "http://127.0.0.1:7861/knowledge_base/search_docs"

# Upper bound (in characters) on the summed `content` of the history sent upstream.
MAX_CONTEXT_CHARS = 8192

# Module-level message history appended to by getText(); shared by all callers
# in this process.
text = []


class Ws_Param(object):
    """Spark API credentials plus a signed-URL builder for one endpoint."""

    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = urlparse(Spark_url).netloc
        self.path = urlparse(Spark_url).path
        self.Spark_url = Spark_url

    def create_url(self):
        """Return ``Spark_url`` with HMAC-SHA256 authentication query parameters.

        The signed string is the ``host`` line, the ``date`` line and the GET
        request line, keyed with ``APISecret``. The result is appended as the
        ``authorization``, ``date`` and ``host`` query parameters.
        """
        # RFC 1123 timestamp (GMT) for the `date` parameter.
        now = datetime.now()
        date = format_date_time(mktime(now.timetuple()))

        # String to sign: host line, date line, request line.
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        signature_sha = hmac.new(
            self.APISecret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Authentication parameters appended to the endpoint URL.
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host,
        }
        url = self.Spark_url + '?' + urlencode(v)
        return url


def getText(role, content):
    """Append a ``{"role": ..., "content": ...}`` message to the module-level
    ``text`` list and return that (shared) list."""
    text.append({"role": role, "content": content})
    return text


def getlength(text):
    """Return the total number of characters in the ``content`` of all messages."""
    return sum(len(message["content"]) for message in text)


def checklen(text, max_len=MAX_CONTEXT_CHARS):
    """Drop the oldest messages in place until the total content length is
    below ``max_len``, and return the same list.

    The last message is always kept, even if it alone exceeds ``max_len``.
    Trimming it away would send an empty history upstream.
    """
    total = getlength(text)
    while total >= max_len and len(text) > 1:
        total -= len(text[0]["content"])
        del text[0]
    return text


def v1wsUrl(version, appid, api_secret, api_key):
    """Return the signed Spark WebSocket URL for ``version``.

    Raises KeyError if ``version`` is not a key of ``Spark_url``.
    """
    wsParam = Ws_Param(appid, api_key, api_secret, Spark_url[version])
    return wsParam.create_url()


def get_query(data, knowledge_base_name, score_threshold, top_k,
              kb_url=KB_SEARCH_URL, timeout=60):
    """Build a retrieval-augmented prompt for the last message in ``data``.

    The last entry of ``data["payload"]["message"]["text"]`` supplies the user
    query. That query is searched in the knowledge base at ``kb_url``. If no
    documents are returned, the query is returned unchanged. Otherwise the
    documents' ``page_content`` fields are joined and wrapped in a prompt
    template.

    Blocking (synchronous HTTP); async callers should run it in an executor.
    Raises ``requests.HTTPError`` on a non-2xx response and
    ``requests.Timeout`` after ``timeout`` seconds. The default timeout is a
    guess; tune it for the deployment.
    """
    query = data["payload"]["message"]["text"][-1]["content"]  # last message
    q = {
        "knowledge_base_name": knowledge_base_name,
        "query": query,
        "score_threshold": score_threshold,
        "top_k": top_k,
    }
    resp = requests.post(kb_url, json=q, timeout=timeout)
    # Fail loudly instead of trying to read an error body as a list of docs.
    resp.raise_for_status()
    docs = resp.json()
    if not docs:
        return query

    wiki_content = "\n".join(doc['page_content'] for doc in docs)
    prompt = "请将以下内容作为已知信息:\n" + wiki_content + (
        "\n请根据以上内容回答用户的问题。\n问题:\n") + query + "\n回答: "
    return prompt


def _is_final_frame(data):
    """Return True if an upstream Spark frame ends the exchange.

    A frame is final when it reports an error (non-zero ``header.code``; such
    frames carry no ``payload``) or when ``payload.choices.status`` is 2.
    """
    if data.get("header", {}).get("code", 0) != 0:
        return True
    return data.get("payload", {}).get("choices", {}).get("status") == 2


async def _safe_close(ws):
    """Close the client WebSocket, ignoring errors if it is already gone."""
    try:
        await ws.close()
    except Exception:
        # Best-effort cleanup: the socket may already be closed or the client
        # disconnected, in which case close() raises and nothing is left to do.
        pass


@app.websocket("/{version}/chat")
async def wsk(ws: WebSocket, version: str):
    """Proxy one chat request from the client to the Spark API, adding KB context.

    Query parameters:
      ``appid``, ``api_secret``, ``api_key`` are the Spark credentials.
      ``knowledge_base_name`` defaults to "threshold".
      ``score_threshold`` defaults to 0.5.
      ``top_k`` defaults to 2.

    The client sends one JSON request. Its last message is replaced with a
    retrieval-augmented prompt and the history is trimmed. Every upstream frame
    is then re-serialized and sent back to the client as JSON text.
    """
    try:
        await ws.accept()
        params = ws.query_params
        appid = params.get("appid")
        api_secret = params.get("api_secret")
        api_key = params.get("api_key")
        knowledge_base_name = params.get("knowledge_base_name") or "threshold"
        # Query parameters are strings. Convert them so the knowledge-base API
        # always receives numbers, matching the type of the defaults.
        score_threshold = float(params.get("score_threshold") or 0.5)
        top_k = int(params.get("top_k") or 2)
        wsUrl = v1wsUrl(version, appid, api_secret, api_key)

        r = await ws.receive_json()
        messages = r["payload"]["message"]["text"]
        # get_query does blocking HTTP; run it in the default thread pool so it
        # does not stall the event loop for other connections.
        loop = asyncio.get_running_loop()
        messages[-1]["content"] = await loop.run_in_executor(
            None, get_query, r, knowledge_base_name, score_threshold, top_k
        )
        r["payload"]["message"]["text"] = checklen(messages)

        try:
            async with websockets.connect(wsUrl) as websocket:
                await websocket.send(json.dumps(r))
                async for message in websocket:
                    data = json.loads(message)
                    await ws.send_text(json.dumps(data, ensure_ascii=False))
                    if _is_final_frame(data):
                        # Stop here rather than keep forwarding to a client
                        # socket that is about to be closed.
                        break
        except Exception as e:
            # Report upstream or forwarding failures to the client as plain text.
            await ws.send_text(str(e))
    except Exception as e:
        print(e)
    finally:
        await _safe_close(ws)


if __name__ == '__main__':
    uvicorn.run(app, host="0.0.0.0", port=8005)