ai/web_demo/web_demo2.py

109 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer
from api import get_docs
from memory import MyConversationBufferWindowMemory
# Model ID or local folder path handed to the from_pretrained loaders.
model_path = "../THUDM/chatglm3-6b"

# Configure the page title, icon and layout.
st.set_page_config(page_title="ChatGLM3-6B 演示", page_icon=":robot:", layout="wide")
@st.cache_resource
def get_model():
    """Load the tokenizer and the model from ``model_path``.

    Both are loaded with ``trust_remote_code=True``. The model is moved to
    the GPU with ``.cuda()`` and put into evaluation mode before it is
    returned. The function is wrapped in ``st.cache_resource``.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    # Multi-GPU alternative: replace the load below with
    #     from utils import load_model_on_gpus
    #     model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
    # and set num_gpus to the number of GPUs actually available.
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda().eval()
    return tokenizer, model
# Obtain the ChatGLM3 tokenizer and model.
tokenizer, model = get_model()

# Seed the session state with an empty chat history and no past key values
# whenever one of those keys is missing.
for _state_key, _initial_value in (("history", []), ("past_key_values", None)):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# Sidebar sliders for the generation settings max_length, top_p and temperature.
max_length = st.sidebar.slider(
    "max_length", min_value=0, max_value=32768, value=8192, step=1
)
top_p = st.sidebar.slider(
    "top_p", min_value=0.0, max_value=1.0, value=0.8, step=0.01
)
temperature = st.sidebar.slider(
    "temperature", min_value=0.0, max_value=1.0, value=0.6, step=0.01
)

# Sidebar button that clears the conversation: reset the stored history and
# past key values, empty the CUDA cache when CUDA is available, then rerun.
if st.sidebar.button("清理会话历史", key="clean"):
    st.session_state.history = []
    st.session_state.past_key_values = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    st.rerun()
# Render the stored chat history: entries whose role is "user" go into a
# user bubble, every other entry goes into an assistant bubble.
for message in st.session_state.history:
    speaker = "user" if message["role"] == "user" else "assistant"
    with st.chat_message(name=speaker, avatar=speaker):
        st.markdown(message["content"])
# Reserve one empty slot inside a user bubble and one inside an assistant
# bubble (in that order); the current turn is written into them later.
input_placeholder = st.chat_message(name="user", avatar="user").empty()
message_placeholder = st.chat_message(name="assistant", avatar="assistant").empty()

# A previous non-empty system prompt (asking for context-based answers in
# Chinese) is disabled; the system prompt is currently the empty string.
system_prompt = ""

# Conversation memory constructed with k=2.
# NOTE(review): nothing in this file writes to this memory, and it is
# constructed at module level, so it presumably starts empty on every script
# run -- confirm whether it should live in st.session_state.
memory = MyConversationBufferWindowMemory(k=2)
def build_prompt(prompt_text):
    """Assemble the full prompt string for one user question.

    The result is the concatenation, in order, of: a fixed instruction line,
    the value of ``get_docs(prompt_text)`` wrapped in ``<已知信息>`` tags, a
    fixed description of the dialogue, the ``'history'`` entry loaded from the
    module-level ``memory``, the user's question, and the assistant cue.

    Args:
        prompt_text: The question typed by the user.

    Returns:
        The assembled prompt as a single string.
    """
    # Read the memory first, then fetch the documents, matching the order in
    # which the two lookups are expected to run.
    dialogue_so_far = memory.load_memory_variables({})['history']
    sections = [
        "你是一个聪明的AI助手你需要通过已知信息和人类与AI助手之间的友好对话来回答人类的问题。\n",
        f"<已知信息>{get_docs(prompt_text)}</已知信息>\n",
        "下面是人类和AI助手之间的友好对话。AI助手很会健谈并从其上下文中提供了许多具体细节。如果AI助手不知道问题的答案它会如实地说它不知道。\n当前对话:\n",
        dialogue_so_far,
        f"\n人类:{prompt_text}\n",
        "\nAI助手",
    ]
    return "".join(sections)
# Read the user's question from the chat input box.
prompt_text = st.chat_input("请输入您的问题")
flag = True

# When the user submitted something, generate a reply.
if prompt_text:
    input_placeholder.markdown(prompt_text)
    history = st.session_state.history
    past_key_values = st.session_state.past_key_values

    # Experimental: build the context-augmented prompt and print it.
    # NOTE(review): `prompt` is only printed; stream_chat below still receives
    # the raw `prompt_text` -- confirm whether the augmented prompt was meant
    # to be sent to the model instead.
    prompt = build_prompt(prompt_text)
    print(prompt)

    generation_options = {
        "past_key_values": past_key_values,
        "max_length": max_length,
        "top_p": top_p,
        "temperature": temperature,
        "return_past_key_values": True,
    }
    reply_stream = model.stream_chat(
        tokenizer, prompt_text, history, **generation_options
    )
    # Each streamed item carries the response so far plus the updated history
    # and past key values; the placeholder is overwritten on every step.
    for response, history, past_key_values in reply_stream:
        message_placeholder.markdown(response)

    # Store the updated history and past key values in the session state.
    st.session_state.history = history
    st.session_state.past_key_values = past_key_values