149 lines
4.9 KiB
Python
149 lines
4.9 KiB
Python
|
import requests
|
|||
|
import uvicorn
|
|||
|
from fastapi import FastAPI, WebSocket
|
|||
|
import base64
|
|||
|
import datetime
|
|||
|
import hashlib
|
|||
|
import hmac
|
|||
|
import json
|
|||
|
from urllib.parse import urlparse
|
|||
|
from datetime import datetime
|
|||
|
from time import mktime
|
|||
|
from urllib.parse import urlencode
|
|||
|
from wsgiref.handlers import format_date_time
|
|||
|
|
|||
|
import websockets
|
|||
|
|
|||
|
# ASGI application object; routes are attached to it with decorators such as
# @app.websocket, and it is what uvicorn.run serves when run as a script.
app = FastAPI()
|
|||
|
# iFlytek Spark chat WebSocket endpoints, keyed by the API version that
# clients put in the route path (/{version}/chat).
Spark_url = {
    "v1.1": "wss://spark-api.xf-yun.com/v1.1/chat",
    "v2.1": "wss://spark-api.xf-yun.com/v2.1/chat",
    "v3.1": "wss://spark-api.xf-yun.com/v3.1/chat",
}

# Module-level message list that getText appends to.
text = list()
|
|||
|
|
|||
|
|
|||
|
class Ws_Param(object):
    """Sign WebSocket URLs for the iFlytek Spark chat API.

    Authentication travels in the query string. An HMAC-SHA256 signature over
    the ``host`` line, the ``date`` line and the HTTP request line is wrapped in
    a base64 ``authorization`` value. That value is appended together with
    ``date`` and ``host``.
    """

    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        """Store the credentials and split *Spark_url* into host and path.

        Args:
            APPID: Application id; stored only, not used when signing.
            APIKey: Key placed in the ``api_key`` field of the authorization.
            APISecret: Secret used as the HMAC-SHA256 key.
            Spark_url: Full endpoint URL, e.g. ``wss://host/v3.1/chat``.
        """
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        parsed = urlparse(Spark_url)  # parse once for both host and path
        self.host = parsed.netloc
        self.path = parsed.path
        self.Spark_url = Spark_url

    def create_url(self, now=None):
        """Return ``Spark_url`` with signed ``authorization``/``date``/``host`` query params.

        Args:
            now: ``datetime`` to sign with; defaults to ``datetime.now()``.
                Naive values are interpreted as local time and aware values
                honour their tzinfo. A fixed value makes the URL reproducible.

        Returns:
            str: ``Spark_url + "?" + urlencode({"authorization", "date", "host"})``.
        """
        if now is None:
            now = datetime.now()
        # RFC 1123 timestamp in GMT, e.g. "Sun, 01 Jan 2023 12:00:00 GMT".
        # datetime.timestamp() avoids the tz-blind mktime(timetuple()) round trip.
        date = format_date_time(now.timestamp())

        # String to sign: host line, date line and the HTTP request line.
        signature_origin = (
            "host: " + self.host + "\n"
            + "date: " + date + "\n"
            + "GET " + self.path + " HTTP/1.1"
        )

        signature_sha = hmac.new(
            self.APISecret.encode('utf-8'),
            signature_origin.encode('utf-8'),
            digestmod=hashlib.sha256,
        ).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature_sha_base64}"'
        )
        authorization = base64.b64encode(
            authorization_origin.encode('utf-8')
        ).decode(encoding='utf-8')

        # Authentication parameters, appended to the endpoint as a query string.
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host,
        }
        return self.Spark_url + '?' + urlencode(v)
|
|||
|
|
|||
|
|
|||
|
def getText(role, content, history=None):
    """Append a ``{"role": ..., "content": ...}`` message to a history and return it.

    Args:
        role: Value stored under ``"role"``.
        content: Value stored under ``"content"``.
        history: List to append to. Defaults to the module-level ``text``
            list, which is shared by every caller in the process. Pass your
            own list to keep separate conversations apart.

    Returns:
        list: The list that was appended to (the same object, mutated in place).
    """
    if history is None:
        history = text  # original behaviour: the process-wide history
    history.append({"role": role, "content": content})
    return history
|
|||
|
|
|||
|
|
|||
|
def getlength(text):
    """Return the combined character count of every message's ``"content"``.

    *text* is an iterable of mappings that each have a ``"content"`` entry;
    an empty iterable yields 0.
    """
    return sum(len(message["content"]) for message in text)
|
|||
|
|
|||
|
|
|||
|
def checklen(text, max_length=8192):
    """Drop the oldest messages until the total ``"content"`` length is below *max_length*.

    *text* is trimmed in place and returned. The newest message is never
    dropped, because it carries the question being asked. Deleting it would
    send an empty conversation upstream. If it alone is too long, it is kept
    and the upstream service decides how to handle it.

    Args:
        text: List of message dicts, each with a ``"content"`` entry.
        max_length: Exclusive upper bound on the combined content length.

    Returns:
        list: The same list object, with leading messages removed.
    """
    total = sum(len(message["content"]) for message in text)
    drop = 0
    # Count how many leading messages to remove, stopping before the newest one.
    while total >= max_length and drop < len(text) - 1:
        total -= len(text[drop]["content"])
        drop += 1
    del text[:drop]  # one slice delete instead of repeated del text[0]
    return text
|
|||
|
|
|||
|
|
|||
|
def v1wsUrl(version, appid, api_secret, api_key):
    """Return a signed Spark WebSocket URL for *version* using the given credentials.

    Raises:
        KeyError: if *version* is not a key of ``Spark_url``.
    """
    endpoint = Spark_url[version]
    return Ws_Param(appid, api_key, api_secret, endpoint).create_url()
|
|||
|
|
|||
|
|
|||
|
def get_query(data, knowledge_base_name, score_threshold, top_k,
              search_url="http://127.0.0.1:7861/knowledge_base/search_docs",
              timeout=30):
    """Build a retrieval-augmented prompt for the newest message of a Spark request.

    Args:
        data: Spark chat request; the question is read from
            ``data["payload"]["message"]["text"][-1]["content"]``.
        knowledge_base_name: Knowledge base to search.
        score_threshold: Relevance threshold forwarded to the search endpoint.
        top_k: Number of documents to request.
        search_url: Search endpoint; defaults to a local service on port 7861.
        timeout: Seconds to wait for the search service (``None`` waits forever).

    Returns:
        str: The question unchanged if no documents come back. Otherwise, a
        prompt that lists every document's ``page_content`` as known
        information, then asks the question.

    Raises:
        requests.RequestException: if the search request fails, times out, or
            returns an HTTP error status.
    """
    query = data["payload"]["message"]["text"][-1]["content"]  # newest message
    q = {
        "knowledge_base_name": knowledge_base_name,
        "query": query,
        "score_threshold": score_threshold,
        "top_k": top_k,
    }
    # A timeout keeps a stuck search service from hanging the request forever.
    response = requests.post(search_url, json=q, timeout=timeout)
    # Fail loudly on HTTP errors instead of treating an error body as documents.
    response.raise_for_status()
    docs = response.json()
    if not docs:
        return query

    wiki_content = "\n".join(doc['page_content'] for doc in docs)
    return ("请将以下内容作为已知信息:\n" + wiki_content
            + "\n请根据以上内容回答用户的问题。\n问题:\n" + query + "\n回答: ")
|
|||
|
|
|||
|
|
|||
|
def _is_last_spark_frame(data):
    """Return True when a Spark response frame ends the stream.

    The stream ends on a non-zero ``header.code``, which the Spark WebSocket
    API uses to report errors. It also ends when ``payload.choices.status`` is
    2, which marks the final chunk. Missing keys are tolerated so that an error
    frame without ``payload`` does not raise.
    """
    if data.get("header", {}).get("code", 0) != 0:
        return True
    return data.get("payload", {}).get("choices", {}).get("status") == 2


@app.websocket("/{version}/chat")
async def wsk(ws: WebSocket, version: str):
    """Relay one Spark chat request, augmented with knowledge-base context.

    Query parameters:
        ``appid``, ``api_secret``, ``api_key``: Spark credentials.
        ``knowledge_base_name`` (default ``"threshold"``), ``score_threshold``
        (default 0.5) and ``top_k`` (default 2): knowledge-base search settings.

    Flow:
        1. Accept the connection and read one JSON request from the client.
        2. Replace the newest message with a retrieval-augmented prompt and
           trim the history with ``checklen``.
        3. Forward the request to the Spark endpoint for *version*.
        4. Stream each upstream frame back to the client until the final
           frame arrives.

    The client connection is always closed when the handler finishes.
    """
    import asyncio  # local import: only needed to offload the blocking search

    try:
        await ws.accept()
        params = ws.query_params
        appid = params.get("appid")
        api_secret = params.get("api_secret")
        api_key = params.get("api_key")
        knowledge_base_name = params.get("knowledge_base_name") or "threshold"
        # Query-string values arrive as str; convert so the search gets numbers.
        score_threshold = float(params.get("score_threshold") or 0.5)
        top_k = int(params.get("top_k") or 2)

        r = await ws.receive_json()
        # get_query performs a synchronous HTTP call; run it in the default
        # executor so it does not stall the event loop for other connections.
        loop = asyncio.get_running_loop()
        prompt = await loop.run_in_executor(
            None, get_query, r, knowledge_base_name, score_threshold, top_k)
        r["payload"]["message"]["text"][-1]["content"] = prompt
        r["payload"]["message"]["text"] = checklen(r["payload"]["message"]["text"])

        try:
            # Sign just before connecting so the signed date is as fresh as
            # possible (presumably checked against the server clock).
            wsUrl = v1wsUrl(version, appid, api_secret, api_key)
            async with websockets.connect(wsUrl) as websocket:
                await websocket.send(json.dumps(r))
                async for message in websocket:
                    data = json.loads(message)
                    await ws.send_text(json.dumps(data, ensure_ascii=False))
                    if _is_last_spark_frame(data):
                        break
        except Exception as e:
            # Report upstream/relay failures to the client before closing.
            await ws.send_text(str(e))
    except Exception as e:
        print(e)
    finally:
        try:
            await ws.close()
        except Exception:
            # Best effort: the socket may already be closed or the client gone.
            pass
|
|||
|
|
|||
|
|
|||
|
if __name__ == '__main__':
    # Serve the app with uvicorn on every interface, port 8005.
    uvicorn.run(app, port=8005, host="0.0.0.0")
|