"""Agent streaming callback handler.

Every LangChain callback handled below is turned into a JSON string (see
``dumps``) and pushed onto ``self.queue`` with ``put_nowait``. Each message is
a snapshot of ``cur_tool``: a dict describing the current tool call and the
latest LLM output, tagged with one of the ``Status`` codes.
"""
from __future__ import annotations

import asyncio
import json
from typing import Any, Dict, List, Optional
from uuid import UUID

from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.schema import AgentFinish, AgentAction
from langchain.schema.output import LLMResult


def dumps(obj: Dict) -> str:
    """Serialize *obj* to JSON, emitting non-ASCII characters as-is (no \\u escapes)."""
    return json.dumps(obj, ensure_ascii=False)


class Status:
    """Integer codes stored in the ``status`` field of every emitted event."""

    start: int = 1         # an LLM / chat-model call started
    running: int = 2       # a streamed LLM token
    complete: int = 3      # an LLM call finished
    agent_action: int = 4  # a tool call started
    agent_finish: int = 5  # the agent produced its final answer
    error: int = 6         # an LLM call or a tool call raised
    tool_finish: int = 7   # a tool call returned


# Markers at which a tool's input is cut, tried in this priority order.
_TOOL_INPUT_STOP_WORDS = ("Observation:", "Thought", "\"", "(", "\n", "\t")


def _new_event(**fields: Any) -> Dict[str, Any]:
    """Return an event dict that always carries every key, with *fields* applied.

    Using one template keeps the emitted JSON schema identical for tool events
    and for LLM events that happen before any tool has run.
    """
    event: Dict[str, Any] = {
        "tool_name": "",
        "input_str": "",
        "output_str": "",
        "status": Status.start,
        "run_id": "",
        "llm_token": "",
        "final_answer": "",
        "error": "",
    }
    event.update(fields)
    return event


def _truncate_tool_input(input_str: str) -> str:
    """Cut *input_str* at the first stop word (by priority) that occurs in it.

    Some models do not stop generating after the tool input, so the trailing
    text is trimmed here. Only the copy recorded in the emitted event is
    trimmed. NOTE(review): stop words are tried in priority order, not by
    earliest position. A leading ``"`` or ``(`` therefore yields an empty
    string when no higher-priority stop word is present. Confirm this is
    intended.
    """
    for stop_word in _TOOL_INPUT_STOP_WORDS:
        index = input_str.find(stop_word)
        if index != -1:
            return input_str[:index]
    return input_str


class CustomAsyncIteratorCallbackHandler(AsyncIteratorCallbackHandler):
    """Stream agent, LLM and tool events as JSON strings through ``self.queue``.

    ``cur_tool`` holds the state of the current step; each callback updates it
    and enqueues a JSON snapshot. ``out`` gates token streaming: it is turned
    off once a token containing "Action" arrives, because the model has started
    writing a tool call. It is turned back on when a tool finishes.
    """

    def __init__(self):
        super().__init__()
        # Fresh queue and completion event owned by this handler instance.
        self.queue = asyncio.Queue()
        self.done = asyncio.Event()
        self.cur_tool = _new_event()
        self.out = True

    def _emit(self, **fields: Any) -> None:
        """Merge *fields* into ``cur_tool`` and enqueue a JSON snapshot of it."""
        self.cur_tool.update(fields)
        self.queue.put_nowait(dumps(self.cur_tool))

    async def on_tool_start(
        self,
        serialized: Dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        metadata: Dict[str, Any] | None = None,
        **kwargs: Any,
    ) -> None:
        """Start a new tool record and emit it with ``Status.agent_action``.

        The tool name comes from ``serialized["name"]``. The input is trimmed
        with ``_truncate_tool_input`` for models that cannot stop on their own.
        """
        self.cur_tool = _new_event()
        self._emit(
            tool_name=serialized["name"],
            input_str=_truncate_tool_input(input_str),
            status=Status.agent_action,
            run_id=run_id.hex,
        )

    async def on_tool_end(
        self,
        output: str,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Emit the tool's output (with "Answer:" removed) as ``Status.tool_finish``.

        *output* is coerced with ``str()``, so a non-string tool result neither
        breaks ``.replace`` nor the JSON serialization.
        """
        self.out = True  # re-enable token streaming that "Action" switched off
        self._emit(
            status=Status.tool_finish,
            output_str=str(output).replace("Answer:", ""),
        )

    async def on_tool_error(
        self,
        error: Exception | KeyboardInterrupt,
        *,
        run_id: UUID,
        parent_run_id: UUID | None = None,
        tags: List[str] | None = None,
        **kwargs: Any,
    ) -> None:
        """Emit the tool failure message as ``Status.error``."""
        self._emit(status=Status.error, error=str(error))

    async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
        """Stream *token* as ``Status.running`` unless streaming is paused.

        When a token contains "Action", only the text before it (plus a
        newline) is emitted and further tokens are suppressed. This avoids
        repeating the tool call that ``on_tool_start`` reports separately.
        """
        if "Action" in token:
            before_action = token.split("Action")[0]
            self._emit(status=Status.running, llm_token=before_action + "\n")
            self.out = False
        if token and self.out:
            self._emit(status=Status.running, llm_token=token)

    async def on_llm_start(
        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
    ) -> None:
        """Emit ``Status.start`` with an empty token for a completion-model call."""
        self._emit(status=Status.start, llm_token="")

    async def on_chat_model_start(
        self,
        serialized: Dict[str, Any],
        messages: List[List],
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> None:
        """Emit ``Status.start`` with an empty token for a chat-model call."""
        self._emit(status=Status.start, llm_token="")

    async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
        """Emit ``Status.complete`` with a newline token.

        Does not set ``self.done``. NOTE(review): presumably the caller
        signals completion, since one agent run spans several LLM calls.
        Confirm this.
        """
        self._emit(status=Status.complete, llm_token="\n")

    async def on_llm_error(
        self, error: Exception | KeyboardInterrupt, **kwargs: Any
    ) -> None:
        """Emit the LLM failure message as ``Status.error``."""
        self._emit(status=Status.error, error=str(error))

    async def on_agent_finish(
        self,
        finish: AgentFinish,
        *,
        run_id: UUID,
        parent_run_id: Optional[UUID] = None,
        tags: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> None:
        """Emit the final answer (``finish.return_values["output"]``), then reset state."""
        self._emit(
            status=Status.agent_finish,
            final_answer=finish.return_values["output"],
        )
        # Start the next run from a clean, full-schema record.
        self.cur_tool = _new_event()