ai/langchain_demo/main.py

58 lines
1.7 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

from typing import List
from ChatGLM3 import ChatGLM3
from langchain.agents import load_tools
from Tool.Weather import Weather
from Tool.Calculator import Calculator
from langchain.agents import initialize_agent
from langchain.agents import AgentType
def run_tool(tools, llm, prompt_chain: List[str]):
    """Build a structured-chat agent from ``tools`` and run each prompt on it.

    Args:
        tools: Iterable mixing tool names (``str``) and ready-made tool
            objects. Names are resolved through ``load_tools``; any other
            entry is handed to the agent unchanged.
        llm: Language model passed to both ``load_tools`` and
            ``initialize_agent``.
        prompt_chain: Prompts executed one after another on the same agent.

    Returns:
        None. The value returned by each ``agent.run`` call is discarded.
    """
    # String entries are looked up by name; non-string entries are used as-is.
    resolved_tools = [
        load_tools([entry], llm=llm)[0] if isinstance(entry, str) else entry
        for entry in tools
    ]

    agent = initialize_agent(
        resolved_tools,
        llm,
        agent=AgentType.STRUCTURED_CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
        handle_parsing_errors=True,
    )

    for prompt in prompt_chain:
        agent.run(prompt)
if __name__ == "__main__":
    # Load the ChatGLM3 model from a path relative to the working directory.
    model_path = "../THUDM/chatglm3-6b"
    llm = ChatGLM3()
    llm.load_model(model_name_or_path=model_path)

    # arxiv: single-tool invocation example 1
    # run_tool(["arxiv"], llm, [
    #     "帮我查询GLM-130B相关工作"  # "Help me look up work related to GLM-130B"
    # ])

    # weather: single-tool invocation example 2 (the only example enabled)
    weather_prompts = [
        "今天北京天气怎么样?",  # "What's the weather like in Beijing today?"
        "What's the weather like in Shanghai today",
    ]
    run_tool([Weather()], llm, weather_prompts)

    # calculator: single-tool invocation example 3
    # run_tool([Calculator()], llm, [
    #     "12345679乘以54等于多少",  # "What is 12345679 multiplied by 54"
    #     "3.14的3.14次方等于多少?",  # "What is 3.14 raised to the power of 3.14?"
    #     "根号2加上根号三等于多少",  # "What is sqrt(2) plus sqrt(3)"
    # ])

    # arxiv + weather + calculator: several tools combined in one agent
    # run_tool([Calculator(), "arxiv", Weather()], llm, [
    #     "帮我检索GLM-130B相关论文",  # "Help me search for papers related to GLM-130B"
    #     "今天北京天气怎么样?",  # "What's the weather like in Beijing today?"
    #     "根号3减去根号二再加上4等于多少",  # "What is sqrt(3) minus sqrt(2), plus 4"
    # ])