import json

import streamlit as st
import torch
from transformers import AutoModel, AutoTokenizer

# from api import get_docs
from server.knowledge_base.kb_doc_api import search_docs

# Configure the page title, icon, and layout
st.set_page_config(
    page_title="ChatGLM3-6B 演示",
    page_icon=":robot:",
    layout="wide"
)

# Set to a Hugging Face model ID or a local folder path
model_path = "THUDM/chatglm3-6b-32k"


def get_docs(query, top_k, score_threshold, knowledge_base_name):
    """Retrieve the top-k documents for the query and join them into one context string."""
    docs = search_docs(query, knowledge_base_name, top_k, score_threshold)
    context = "\n".join(doc.page_content for doc in docs)
    return context


@st.cache_resource
def get_model():
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModel.from_pretrained(model_path, trust_remote_code=True).cuda()
    # Multi-GPU support: replace the line above with the two lines below,
    # setting num_gpus to the actual number of GPUs.
    # from utils import load_model_on_gpus
    # model = load_model_on_gpus("THUDM/chatglm3-6b", num_gpus=2)
    model = model.eval()
    return tokenizer, model


# Load the ChatGLM3 model and tokenizer
tokenizer, model = get_model()

system_prompt = """你是一个农业领域的AI助手,你需要通过已知信息和历史对话来回答人类的问题。
特别注意:已知信息不一定是正确答案,还要根据对话的历史问题和答案进行思考回答,如果找不到答案,请用自己的知识进行回答,不可以乱回答!!!
答案使用中文。\n"""

# Initialize the chat history and past key values
if "history" not in st.session_state:
    st.session_state.history = [{'role': 'system', 'content': system_prompt}]
if "past_key_values" not in st.session_state:
    st.session_state.past_key_values = None
if "all_prompt_text" not in st.session_state:
    st.session_state.all_prompt_text = []

# Sliders for max_length, top_p, and temperature
max_length = st.sidebar.slider("max_length", 0, 32768, 32768, step=1024)
top_p = st.sidebar.slider("top_p", 0.0, 1.0, 0.8, step=0.01)
temperature = st.sidebar.slider("temperature", 0.0, 1.0, 0.6, step=0.01)

# Clear the chat history (keeping the system prompt) and free the KV cache
buttonClean = st.sidebar.button("清理会话历史", key="clean")
if buttonClean:
    st.session_state.history = [{'role': 'system', 'content': system_prompt}]
    st.session_state.all_prompt_text = []
    st.session_state.past_key_values = None
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    st.rerun()

# Render the chat history (the system prompt is not displayed)
for message in st.session_state.history:
    if message["role"] == "user":
        with st.chat_message(name="user", avatar="user"):
            st.markdown(message["content"])
    elif message["role"] == "assistant":
        with st.chat_message(name="assistant", avatar="assistant"):
            st.markdown(message["content"])

# Placeholders for the user input and the streamed model output
with st.chat_message(name="user", avatar="user"):
    input_placeholder = st.empty()
with st.chat_message(name="assistant", avatar="assistant"):
    message_placeholder = st.empty()


def build_prompt(prompt_text, new_query):
    """Retrieve context for the accumulated query and combine it with the current question."""
    knowledge_base_name = 'agriculture'
    # top_k=3; score_threshold=1 is a permissive threshold that keeps all retrieved hits
    context = get_docs(new_query, 3, 1, knowledge_base_name)
    print("query:", new_query)
    print("context:", context)
    prompt = f"\n\n<已知信息>:\n{context}\n\n\n"
    prompt += f"\n\n<问题>:{prompt_text}\n\n"
    return prompt


# Get the user input
prompt_text = st.chat_input("请输入您的问题")

# If the user entered something, generate a reply
if prompt_text:
    input_placeholder.markdown(prompt_text)
    history = st.session_state.history
    # Debug dump of the current history
    with open('a.json', 'w', encoding='utf-8') as f:
        json.dump(history, f, ensure_ascii=False)
    past_key_values = st.session_state.past_key_values

    # Collect past questions and past assistant answers
    st.session_state.all_prompt_text.append(prompt_text)
    assistants = [h['content'] for h in history if h['role'] == 'assistant']

    # Interleave past questions and answers into one retrieval query
    new_query = ""
    for i in range(len(st.session_state.all_prompt_text)):
        new_query += st.session_state.all_prompt_text[i]
        if i < len(st.session_state.all_prompt_text) - 1:
            new_query += ","
        else:
            new_query += "?"
        if i < len(assistants):
            new_query += assistants[i]

    # Rebuild the history as alternating user/assistant turns,
    # keeping the system prompt at the head
    history = [{'role': 'system', 'content': system_prompt}]
    for q, a in zip(st.session_state.all_prompt_text, assistants):
        history.append({'role': 'user', 'content': q})
        history.append({'role': 'assistant', 'content': a})

    prompt = build_prompt(prompt_text, new_query)
    # TODO: bound the history with a sliding window so the context length stays under control
    for response, history, past_key_values in model.stream_chat(
        tokenizer,
        prompt,
        history,
        past_key_values=past_key_values,
        max_length=max_length,
        top_p=top_p,
        temperature=temperature,
        return_past_key_values=True,
    ):
        message_placeholder.markdown(response)

    # Update the stored history and past key values
    st.session_state.history = history
    st.session_state.past_key_values = past_key_values
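

# A minimal sketch for the sliding-window TODO above, not part of the original code:
# `truncate_history` and `max_turns` are illustrative names, and the window size would
# need tuning against the 32k context of chatglm3-6b-32k.
def truncate_history(history, max_turns=5):
    """Keep the system message (if any) plus the last `max_turns` user/assistant pairs."""
    system = [m for m in history if m['role'] == 'system']
    turns = [m for m in history if m['role'] != 'system']
    # Each turn is one user message followed by one assistant message,
    # so a window of max_turns pairs keeps the last 2 * max_turns messages.
    return system + turns[-2 * max_turns:]

# Example use: call it on the rebuilt history right before model.stream_chat, e.g.
#   history = truncate_history(history, max_turns=5)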
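
# To launch the demo (assuming this file is saved as web_demo.py and the
# 'agriculture' knowledge base has already been built in the kb_doc_api backend):
#   streamlit run web_demo.py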