149 lines
4.9 KiB
Python
149 lines
4.9 KiB
Python
import requests
|
||
import uvicorn
|
||
from fastapi import FastAPI, WebSocket
|
||
import base64
|
||
import datetime
|
||
import hashlib
|
||
import hmac
|
||
import json
|
||
from urllib.parse import urlparse
|
||
from datetime import datetime
|
||
from time import mktime
|
||
from urllib.parse import urlencode
|
||
from wsgiref.handlers import format_date_time
|
||
|
||
import websockets
|
||
|
||
# ASGI application serving the proxy endpoint(s) below.
app = FastAPI()

# Spark chat WebSocket endpoint for each supported API version.
Spark_url = {
    version: f"wss://spark-api.xf-yun.com/{version}/chat"
    for version in ("v1.1", "v2.1", "v3.1")
}

# Module-wide message history appended to by getText (shared by every caller).
text = []
|
||
|
||
|
||
class Ws_Param(object):
    """Sign handshake URLs for the iFlytek Spark chat WebSocket API.

    Holds the application credentials and one endpoint URL, and builds the
    connection URL carrying the HMAC-SHA256 authorization query parameters.
    """

    def __init__(self, APPID, APIKey, APISecret, Spark_url):
        """Store the credentials and split ``Spark_url`` into host and path.

        Args:
            APPID: application id (stored; not used when signing).
            APIKey: API key embedded in the authorization value.
            APISecret: secret used as the HMAC-SHA256 key.
            Spark_url: endpoint URL, e.g. "wss://spark-api.xf-yun.com/v3.1/chat".
        """
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        parsed = urlparse(Spark_url)  # parse once, reuse for host and path
        self.host = parsed.netloc
        self.path = parsed.path
        self.Spark_url = Spark_url

    def create_url(self):
        """Return ``Spark_url`` with ``authorization``, ``date`` and ``host`` query params.

        The signature is HMAC-SHA256, keyed with ``APISecret``, over the three
        newline-separated lines "host: <host>", "date: <date>" and
        "GET <path> HTTP/1.1". It is base64-encoded and wrapped in an
        authorization string together with the API key. That string is
        base64-encoded again.
        """
        import time

        # RFC 1123 timestamp, e.g. "Mon, 01 Jan 2024 00:00:00 GMT". It is taken
        # straight from the epoch clock. Round-tripping a naive local datetime
        # through mktime() is ambiguous during the DST fall-back hour and could
        # shift the signed date by an hour.
        date = format_date_time(time.time())

        # String to sign: host, date and request line, newline separated.
        signature_origin = "host: " + self.host + "\n"
        signature_origin += "date: " + date + "\n"
        signature_origin += "GET " + self.path + " HTTP/1.1"

        # HMAC-SHA256 over the string to sign, keyed with the API secret.
        signature_sha = hmac.new(self.APISecret.encode('utf-8'), signature_origin.encode('utf-8'),
                                 digestmod=hashlib.sha256).digest()
        signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding='utf-8')

        authorization_origin = f'api_key="{self.APIKey}", algorithm="hmac-sha256", headers="host date request-line", signature="{signature_sha_base64}"'
        authorization = base64.b64encode(authorization_origin.encode('utf-8')).decode(encoding='utf-8')

        # Authentication parameters are passed as URL query parameters.
        v = {
            "authorization": authorization,
            "date": date,
            "host": self.host
        }
        return self.Spark_url + '?' + urlencode(v)
|
||
|
||
|
||
def getText(role, content):
    """Append a ``{"role": role, "content": content}`` message to the module-level
    ``text`` history and return that (shared) list."""
    text.append({"role": role, "content": content})
    return text
|
||
|
||
|
||
def getlength(text):
    """Return the summed ``len()`` of every message's ``"content"`` value."""
    return sum(len(message["content"]) for message in text)
|
||
|
||
|
||
def checklen(text, max_length=8192):
    """Drop the oldest messages until the total content length is below ``max_length``.

    Length is the summed ``len()`` of each message's ``"content"`` (the same
    measure as ``getlength``). The list is trimmed in place and also returned.

    The newest (last) message is never dropped. It carries the current
    question, so an oversized question is sent as-is rather than replaced by
    an empty message list.

    Args:
        text: list of message dicts, oldest first, each with a "content" key.
        max_length: trimming continues while the total is >= this value.

    Returns:
        The same list object, after trimming.
    """
    total = sum(len(message["content"]) for message in text)
    # Count how many leading messages must go, then remove them in one slice
    # (avoids recomputing the total and shifting the list on every deletion).
    drop = 0
    while total >= max_length and drop < len(text) - 1:
        total -= len(text[drop]["content"])
        drop += 1
    del text[:drop]
    return text
|
||
|
||
|
||
def v1wsUrl(version, appid, api_secret, api_key):
    """Return the signed Spark WebSocket URL for ``version``.

    ``version`` must be a key of ``Spark_url`` (e.g. "v3.1"); an unknown
    version raises KeyError from the dict lookup.
    """
    endpoint = Spark_url[version]
    return Ws_Param(appid, api_key, api_secret, endpoint).create_url()
|
||
|
||
|
||
def get_query(data, knowledge_base_name, score_threshold, top_k,
              search_url="http://127.0.0.1:7861/knowledge_base/search_docs",
              timeout=30):
    """Return the latest message's content, augmented with knowledge-base context.

    Posts the content of the last entry in ``data["payload"]["message"]["text"]``
    to the knowledge-base search endpoint. If any documents come back, their
    ``page_content`` values are joined into a prompt that asks the model to
    answer the question from that context. Otherwise the question is returned
    unchanged.

    Args:
        data: Spark request body; only the last message's "content" is read.
        knowledge_base_name: knowledge base to search.
        score_threshold: forwarded unchanged in the search request.
        top_k: forwarded unchanged in the search request.
        search_url: search endpoint (defaults to the local service).
        timeout: seconds to wait for the search service, so a stalled
            service cannot block the caller indefinitely.

    Returns:
        The string to use as the last message's content.

    Raises:
        requests.HTTPError: the search endpoint answered with an error status.
        requests.Timeout: the search endpoint did not answer within ``timeout``.
    """
    query = data["payload"]["message"]["text"][-1]["content"]  # newest message
    q = {
        "knowledge_base_name": knowledge_base_name,
        "query": query,
        "score_threshold": score_threshold,
        "top_k": top_k
    }
    response = requests.post(search_url, json=q, timeout=timeout)
    # Fail clearly on an error status instead of treating an error body as documents.
    response.raise_for_status()
    docs = response.json()
    if not docs:
        return query

    wiki_content = "\n".join(doc['page_content'] for doc in docs)
    prompt = "请将以下内容作为已知信息:\n" + wiki_content + (
        "\n请根据以上内容回答用户的问题。\n问题:\n") + query + "\n回答: "
    return prompt
|
||
|
||
|
||
async def _close_quietly(ws):
    """Close the client WebSocket; ignore failures if it is already closed.

    Cleanup is best-effort: the client may have disconnected or the socket may
    already be closed, and an error here must not mask the original failure.
    """
    try:
        await ws.close()
    except Exception:  # deliberate: closing an already-closed socket is not an error here
        pass


@app.websocket("/{version}/chat")
async def wsk(ws: WebSocket, version: str):
    """Proxy one chat request from the client to the Spark WebSocket API.

    Query parameters: ``appid``, ``api_secret``, ``api_key`` (required);
    ``knowledge_base_name`` (default "threshold"), ``score_threshold``
    (default 0.5), ``top_k`` (default 2).

    The first JSON message from the client is a Spark request body. Its last
    message is replaced by the knowledge-base-augmented prompt from
    ``get_query``, and the history is trimmed with ``checklen``. The request is
    forwarded upstream, and every upstream frame is relayed to the client as
    text. Relaying stops after the final frame (``payload.choices.status == 2``)
    or a frame without a status. The client socket is always closed on exit.
    """
    import asyncio

    try:
        await ws.accept()
        params = ws.query_params
        appid = params.get("appid")
        api_secret = params.get("api_secret")
        api_key = params.get("api_key")
        knowledge_base_name = params.get("knowledge_base_name") or "threshold"
        # Query-string values arrive as str; send numbers in the search request
        # (presumably what the knowledge-base API expects — confirm).
        score_threshold = float(params.get("score_threshold") or 0.5)
        top_k = int(params.get("top_k") or 2)

        missing = [name for name, value in (("appid", appid),
                                            ("api_secret", api_secret),
                                            ("api_key", api_key)) if not value]
        if missing:
            await ws.send_text("missing query parameter(s): " + ", ".join(missing))
            return

        wsUrl = v1wsUrl(version, appid, api_secret, api_key)
        r = await ws.receive_json()
        messages = r["payload"]["message"]["text"]
        # get_query performs a blocking HTTP request; run it in the default
        # thread pool so it does not stall the event loop for other clients.
        loop = asyncio.get_running_loop()
        messages[-1]["content"] = await loop.run_in_executor(
            None, get_query, r, knowledge_base_name, score_threshold, top_k)
        r["payload"]["message"]["text"] = checklen(messages)

        try:
            async with websockets.connect(wsUrl) as upstream:
                await upstream.send(json.dumps(r))
                async for message in upstream:
                    data = json.loads(message)
                    await ws.send_text(json.dumps(data, ensure_ascii=False))
                    status = data.get("payload", {}).get("choices", {}).get("status")
                    # status 2 marks the last frame. A frame with no status
                    # (e.g. an error response) cannot be continued either.
                    if status is None or status == 2:
                        break
        except Exception as e:
            # Report upstream/relay failures to the client before closing.
            await ws.send_text(str(e))
    except Exception as e:
        print(e)
    finally:
        await _close_quietly(ws)
|
||
|
||
|
||
if __name__ == "__main__":
    # Development entry point: serve the app on every interface, port 8005.
    uvicorn.run(app, port=8005, host="0.0.0.0")
|