import json
from typing import List, Optional

from langchain.llms.base import LLM
from transformers import AutoConfig, AutoModel, AutoTokenizer

from utils import tool_config_from_file


class ChatGLM3(LLM):
    """LangChain ``LLM`` wrapper around a ChatGLM3 model with tool calling.

    The wrapper keeps its own ChatGLM3-style message history between calls.
    On a fresh turn (``has_search`` is False) it parses the tool list out of
    the LangChain prompt and starts a new history with a system message that
    carries the tool configs. After the model has requested a tool
    (``has_search`` is True) it instead appends the prompt's ``Observation:``
    text to the existing history. Each reply is converted into a JSON action
    blob (a tool action or ``"Final Answer"``).
    """

    max_token: int = 8192  # forwarded to model.chat as max_length
    do_sample: bool = False
    temperature: float = 0.8
    top_p: float = 0.8  # forwarded to model.chat
    tokenizer: object = None  # populated by load_model()
    model: object = None  # populated by load_model()
    history: List = []  # message dicts exchanged with model.chat
    tool_names: List = []  # tool names parsed from the latest fresh prompt
    has_search: bool = False  # True after a tool action was emitted

    def __init__(self, **kwargs):
        # Forward keyword arguments so fields (temperature, top_p, ...) can be
        # set at construction; ChatGLM3() with no arguments works as before.
        super().__init__(**kwargs)

    @property
    def _llm_type(self) -> str:
        """Identifier LangChain uses for this LLM type."""
        return "ChatGLM3"

    def load_model(self, model_name_or_path=None):
        """Load config, tokenizer and model from ``model_name_or_path``.

        The model is converted to half precision and moved to CUDA, so a GPU
        is required. ``trust_remote_code=True`` lets transformers run the
        custom code shipped with the model repository.

        Args:
            model_name_or_path: model identifier or local directory accepted
                by the transformers ``from_pretrained`` loaders.
        """
        model_config = AutoConfig.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=model_config, trust_remote_code=True
        ).half().cuda()

    def _tool_history(self, prompt: str):
        """Build a fresh ChatGLM3 history and user query from a LangChain prompt.

        Tools are read from the text between
        ``"You have access to the following tools:\\n\\n"`` and
        ``"\\n\\nUse a json blob"``, one per line. The text before the first
        ``":"`` is taken as the tool name. Each tool's config comes from
        ``tool_config_from_file``. The query is the text after the last
        ``"Human: "``, stripped of surrounding whitespace.

        Side effect: stores the parsed names in ``self.tool_names``.

        Returns:
            ``(history, query)``, where ``history`` is a one-element list
            holding the system message with the tool configs.

        Raises:
            ValueError: if ``tool_config_from_file`` returns a falsy value for
                a tool.
        """
        tool_prompts = prompt.split(
            "You have access to the following tools:\n\n"
        )[1].split("\n\nUse a json blob")[0].split("\n")
        tool_names = [tool.split(":")[0] for tool in tool_prompts]
        self.tool_names = tool_names

        tools_json = []
        for tool, tool_prompt in zip(tool_names, tool_prompts):
            tool_config = tool_config_from_file(tool)
            if not tool_config:
                # Fail loudly: a tool without a config would be silently left
                # out of the system message.
                raise ValueError(
                    f"Tool {tool} config not found! It's description is {tool_prompt}"
                )
            tools_json.append(tool_config)

        ans = [{
            "role": "system",
            "content": "Answer the following questions as best as you can. You have access to the following tools:",
            "tools": tools_json,
        }]
        query = f"""{prompt.split("Human: ")[-1].strip()}"""
        return ans, query

    def _extract_observation(self, prompt: str):
        """Append the prompt's latest ``Observation:`` text to the history.

        The observation is the text after the last ``"Observation: "`` and
        before the following ``"\\nThought:"``. It is stored as a message with
        role ``"observation"``.
        """
        return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
        self.history.append({
            "role": "observation",
            "content": return_json
        })
        return

    @staticmethod
    def _format_action(action_json) -> str:
        """Render ``action_json`` as an ``Action:`` line followed by a fenced JSON blob."""
        return (
            "\nAction: \n```\n"
            f"{json.dumps(action_json, ensure_ascii=False)}\n"
            "```"
        )

    def _extract_tool(self):
        """Turn the model's latest history entry into a LangChain action string.

        If the last message has non-empty ``metadata`` and its content
        mentions ``"tool_call"``, the first known tool name contained in the
        metadata is emitted as an action. Its input is the text between the
        last ``="'"`` and the next ``"'"`` in the content. ``has_search``
        is then set to True. Otherwise the content is emitted as a
        ``"Final Answer"`` action and ``has_search`` is reset to False.
        """
        last_message = self.history[-1]
        metadata = last_message.get("metadata")
        content = last_message["content"]

        # Messages without metadata are treated as plain answers instead of
        # raising KeyError/TypeError.
        if metadata and "tool_call" in content:
            for tool in self.tool_names:
                if tool in metadata:
                    input_para = content.split("='")[-1].split("'")[0]
                    self.has_search = True
                    return self._format_action({
                        "action": tool,
                        "action_input": input_para,
                    })

        self.has_search = False
        return self._format_action({
            "action": "Final Answer",
            "action_input": content,
        })

    def _call(self, prompt: str, history: Optional[List] = None,
              stop: Optional[List[str]] = ["<|user|>"]):
        """Run one model turn for ``prompt`` and return a LangChain action string.

        Args:
            prompt: the full prompt text supplied by LangChain.
            history: optional caller-owned list; ``(prompt, response)`` is
                appended to it. It is not sent to the model. Defaults to a
                fresh list per call, so no state leaks between calls.
            stop: accepted for interface compatibility; not used here.

        Returns:
            The action string produced by ``_extract_tool``.
        """
        if history is None:
            history = []

        print("======")
        print(prompt)
        print("======")

        if not self.has_search:
            # New question: rebuild the history with the tool system prompt.
            self.history, query = self._tool_history(prompt)
        else:
            # A tool was requested last turn: feed its observation back.
            self._extract_observation(prompt)
            query = ""

        _, self.history = self.model.chat(
            self.tokenizer,
            query,
            history=self.history,
            do_sample=self.do_sample,
            max_length=self.max_token,
            temperature=self.temperature,
            top_p=self.top_p,
        )
        response = self._extract_tool()
        history.append((prompt, response))
        return response