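"""LangChain LLM wrapper around a locally loaded ChatGLM3 model.

The ChatGLM3 class below adapts the model to LangChain's LLM interface and
translates between the structured-chat agent prompt (tool descriptions,
Observation/Thought blocks) and ChatGLM3's native tool-call history format.
"""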
import json
from typing import List, Optional

from langchain.llms.base import LLM
from transformers import AutoTokenizer, AutoModel, AutoConfig

from utils import tool_config_from_file


class ChatGLM3(LLM):
    # Generation settings and runtime state. `tokenizer` and `model` are
    # populated by load_model(); `history` holds ChatGLM3-format turns.
    max_token: int = 8192
    do_sample: bool = False
    temperature: float = 0.8
    top_p: float = 0.8
    tokenizer: object = None
    model: object = None
    history: List = []
    tool_names: List = []
    has_search: bool = False

    def __init__(self):
        super().__init__()

    @property
    def _llm_type(self) -> str:
        return "ChatGLM3"

    def load_model(self, model_name_or_path=None):
        # Load config, tokenizer and weights, then move the model to the GPU
        # in half precision.
        model_config = AutoConfig.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name_or_path,
            trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_name_or_path, config=model_config, trust_remote_code=True
        ).half().cuda()

    def _tool_history(self, prompt: str):
        # Parse the tool descriptions out of the agent prompt and rebuild them
        # as a ChatGLM3 system message carrying the tool configs.
        ans = []
        tool_prompts = prompt.split(
            "You have access to the following tools:\n\n")[1].split("\n\nUse a json blob")[0].split("\n")

        tool_names = [tool.split(":")[0] for tool in tool_prompts]
        self.tool_names = tool_names
        tools_json = []
        for i, tool in enumerate(tool_names):
            tool_config = tool_config_from_file(tool)
            if tool_config:
                tools_json.append(tool_config)
            else:
                raise ValueError(
                    f"Tool {tool} config not found! Its description is {tool_prompts[i]}"
                )

        ans.append({
            "role": "system",
            "content": "Answer the following questions as best as you can. You have access to the following tools:",
            "tools": tools_json
        })
        # The user query is the last "Human:" block of the agent prompt.
        query = f"""{prompt.split("Human: ")[-1].strip()}"""
        return ans, query

    def _extract_observation(self, prompt: str):
        # Pull the latest tool output (the text after "Observation: ") out of
        # the agent prompt and append it to the history as an observation turn.
        return_json = prompt.split("Observation: ")[-1].split("\nThought:")[0]
        self.history.append({
            "role": "observation",
            "content": return_json
        })
        return

    def _extract_tool(self):
        # If the model's last turn carries tool-call metadata, convert it into
        # the agent's expected "Action" json blob; otherwise wrap the plain
        # answer as a "Final Answer" action.
        if len(self.history[-1]["metadata"]) > 0:
            metadata = self.history[-1]["metadata"]
            content = self.history[-1]["content"]
            if "tool_call" in content:
                for tool in self.tool_names:
                    if tool in metadata:
                        input_para = content.split("='")[-1].split("'")[0]
                        action_json = {
                            "action": tool,
                            "action_input": input_para
                        }
                        self.has_search = True
                        return f"""
Action:
```
{json.dumps(action_json, ensure_ascii=False)}
```"""
        final_answer_json = {
            "action": "Final Answer",
            "action_input": self.history[-1]["content"]
        }
        self.has_search = False
        return f"""
Action:
```
{json.dumps(final_answer_json, ensure_ascii=False)}
```"""

    def _call(self, prompt: str, history: Optional[List] = None, stop: Optional[List[str]] = None):
        # Avoid shared mutable defaults: fresh list per call, default stop token.
        if history is None:
            history = []
        if stop is None:
            stop = ["<|user|>"]
        # Debug output: show the raw prompt coming from the agent.
        print("======")
        print(prompt)
        print("======")
        if not self.has_search:
            # First round: build the system message with tools plus the user query.
            self.history, query = self._tool_history(prompt)
        else:
            # Follow-up round: feed the tool observation back, no new user query.
            self._extract_observation(prompt)
            query = ""
        _, self.history = self.model.chat(
            self.tokenizer,
            query,
            history=self.history,
            do_sample=self.do_sample,
            max_length=self.max_token,
            temperature=self.temperature,
        )
        response = self._extract_tool()
        history.append((prompt, response))
        return response
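

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original module): one way this wrapper
# could be wired into a LangChain structured-chat agent. The model path, the
# "Calculator" tool and the matching config that tool_config_from_file() is
# expected to find for that tool name are assumptions for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from langchain.agents import AgentType, Tool, initialize_agent

    def calculator(expression: str) -> str:
        # Toy tool body; a real deployment would use a safer evaluator.
        return str(eval(expression))

    llm = ChatGLM3()
    llm.load_model("THUDM/chatglm3-6b")  # assumed model path

    tools = [
        Tool(
            name="Calculator",
            func=calculator,
            description="Useful for arithmetic. Input is a math expression.",
        )
    ]
    agent = initialize_agent(
        tools,
        llm,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    agent.run("What is 34 * 34?")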