"""Streaming tests for the chat HTTP endpoints.

Each test POSTs a JSON payload to an endpoint under ``api_base_url`` with
``stream=True``, echoes the request and the streamed response to stdout, and
asserts on the HTTP status (and, for retrieval endpoints, on the returned
``docs``).
"""
import requests
import json
import sys
from pathlib import Path

# The project imports below rely on this path entry, so it must stay ahead of
# them.  NOTE(review): presumably three levels up is the repository root.
sys.path.append(str(Path(__file__).parent.parent.parent))
from configs import BING_SUBSCRIPTION_KEY
from server.utils import api_address

from pprint import pprint


api_base_url = api_address()


def dump_input(d, title):
    """Pretty-print the request payload ``d`` under a banner built from ``title``."""
    print("\n")
    print("=" * 30 + title + "  input " + "=" * 30)
    pprint(d)


def dump_output(r, title):
    """Print the streamed body of response ``r`` under a banner built from ``title``.

    Chunks are written as they arrive, without newlines, so the output reads
    as one continuous answer.
    """
    print("\n")
    print("=" * 30 + title + "  output" + "=" * 30)
    for line in r.iter_content(None, decode_unicode=True):
        print(line, end="", flush=True)


headers = {
    'accept': 'application/json',
    'Content-Type': 'application/json',
}

# Shared request fixture.  Tests must treat it as read-only and derive their
# own payloads from it, otherwise one test changes what the next one sends.
data = {
    "query": "请用100字左右的文字介绍自己",
    "history": [
        {
            "role": "user",
            "content": "你好"
        },
        {
            "role": "assistant",
            "content": "你好,我是人工智能大模型"
        }
    ],
    "stream": True,
    "temperature": 0.7,
}


def _stream_answers(response, title):
    """Print streamed ``answer`` fragments and return the last decoded chunk.

    Every chunk yielded by ``response`` is parsed as a JSON object.  Returns
    ``None`` when the stream produced no chunks at all.
    """
    print("\n")
    print("=" * 30 + title + "  output" + "=" * 30)
    last_chunk = None
    for raw in response.iter_content(None, decode_unicode=True):
        if not raw:
            # Nothing to decode; json.loads would raise on an empty chunk.
            continue
        last_chunk = json.loads(raw)
        if "answer" in last_chunk:
            print(last_chunk["answer"], end="", flush=True)
    return last_chunk


def test_chat_fastchat(api="/chat/fastchat"):
    """POST a ``messages``-style payload to ``api`` and expect HTTP 200."""
    url = f"{api_base_url}{api}"
    data2 = {
        "stream": True,
        # Builds a new list, so the shared history is not modified.
        "messages": data["history"] + [{"role": "user", "content": "推荐一部科幻电影"}]
    }
    dump_input(data2, api)
    response = requests.post(url, headers=headers, json=data2, stream=True)
    dump_output(response, api)
    assert response.status_code == 200


def test_chat_chat(api="/chat/chat"):
    """POST the shared ``data`` fixture to ``api`` and expect HTTP 200."""
    url = f"{api_base_url}{api}"
    dump_input(data, api)
    response = requests.post(url, headers=headers, json=data, stream=True)
    dump_output(response, api)
    assert response.status_code == 200


def test_knowledge_chat(api="/chat/knowledge_base_chat"):
    """Query the ``samples`` knowledge base and expect a non-empty ``docs`` list."""
    url = f"{api_base_url}{api}"
    # Local payload under its own name, so the module-level ``data`` fixture
    # is neither shadowed nor confused with decoded response chunks.
    payload = {
        "query": "如何提问以获得高质量答案",
        "knowledge_base_name": "samples",
        "history": [
            {
                "role": "user",
                "content": "你好"
            },
            {
                "role": "assistant",
                "content": "你好,我是 ChatGLM"
            }
        ],
        "stream": True
    }
    dump_input(payload, api)
    response = requests.post(url, headers=headers, json=payload, stream=True)
    last_chunk = _stream_answers(response, api)
    pprint(last_chunk)
    assert last_chunk is not None, "response stream produced no chunks"
    assert "docs" in last_chunk and len(last_chunk["docs"]) > 0
    assert response.status_code == 200


def test_search_engine_chat(api="/chat/search_engine_chat"):
    """Run the search-engine chat once per engine and expect non-empty ``docs``.

    When ``BING_SUBSCRIPTION_KEY`` is not configured, the Bing request is
    expected to return the "key required" error body instead of a stream, and
    the streaming assertions are skipped for that engine.
    """
    url = f"{api_base_url}{api}"
    for se in ["bing", "duckduckgo"]:
        # Fresh payload per engine: the request is never overwritten by a
        # response body and the shared fixture stays untouched.
        payload = {
            **data,
            "query": "室温超导最新进展是什么样?",
            "search_engine_name": se,
        }
        dump_input(payload, api + f" by {se}")
        response = requests.post(url, json=payload, stream=True)
        if se == "bing" and not BING_SUBSCRIPTION_KEY:
            error = response.json()
            assert error["code"] == 404
            assert error["msg"] == "要使用Bing搜索引擎,需要设置 `BING_SUBSCRIPTION_KEY`"
            # The error body carries no "docs", so there is nothing to stream.
            continue

        last_chunk = _stream_answers(response, api + f" by {se}")
        assert last_chunk is not None, "response stream produced no chunks"
        assert "docs" in last_chunk and len(last_chunk["docs"]) > 0
        pprint(last_chunk["docs"])
        assert response.status_code == 200