def __init__(self,
             function_list: Optional[List[Union[str, Dict, BaseTool]]] = None,
             llm: Optional[Union[dict, BaseChatModel]] = None,
             system_message: Optional[str] = DEFAULT_SYSTEM_MESSAGE,
             name: Optional[str] = None,
             description: Optional[str] = None,
             **kwargs):
    """Set up the agent's model, tools and identity.

    Args:
        function_list: Tools to register. Each entry is a tool name, a tool
            configuration dict, or a tool object, such as 'code_interpreter',
            {'name': 'code_interpreter', 'timeout': 10}, or CodeInterpreter().
        llm: The LLM configuration or an LLM object. A configuration looks like
            {'model': '', 'api_key': '', 'model_server': ''}.
        system_message: The system message used for LLM chat.
        name: The name of this agent.
        description: The description of this agent, used for multi_agent.
        **kwargs: Accepted for compatibility; not read by this initializer.
    """
    # A dict is treated as a configuration and turned into a chat model;
    # anything else (including None) is stored as given.
    self.llm = get_chat_model(llm) if isinstance(llm, dict) else llm
    self.extra_generate_cfg: dict = {}

    # Tools are kept in a name -> tool mapping that _init_tool fills in.
    self.function_map = {}
    for tool in function_list or ():
        self._init_tool(tool)

    # Prompt and identity information of the agent.
    self.system_message = system_message
    self.name = name
    self.description = description
2.2 非流式运行接口
(1) 传参描述
参数
参数描述
messages
传入消息列表
(2) 代码解析:调用流式运行接口,获取流式输出的最终响应
def run_nonstream(self, messages: List[Union[Dict, Message]], **kwargs) -> Union[List[Message], List[Dict]]:
    """Run the agent and return only the final response.

    Same as ``self.run`` but non-streaming: the response generator is consumed
    to the end and the last item it yields is returned, instead of handing the
    incremental responses to the caller.

    Args:
        messages: A list of messages.
        **kwargs: Forwarded unchanged to ``self.run``.

    Returns:
        The last list of messages yielded by ``self.run``.

    Raises:
        ValueError: If ``self.run`` does not yield any response.
    """
    no_response = object()
    last_responses = no_response
    # Only the most recent item is kept, so the intermediate streamed
    # responses are not accumulated in memory while draining the generator.
    for last_responses in self.run(messages, **kwargs):
        pass
    if last_responses is no_response:
        raise ValueError('The agent did not yield any response.')
    return last_responses
def run(self, messages: List[Union[Dict, Message]],
        **kwargs) -> Union[Iterator[List[Message]], Iterator[List[Dict]]]:
    """Return a response generator based on the received messages.

    The input messages are converted to a uniform type, the agent's system
    message is merged in, and the ``_run`` method is called to generate a reply.

    Args:
        messages: A list of messages, given as dicts and/or Message objects.
        **kwargs: Forwarded to ``self._run``. If 'lang' is absent it is set to
            'zh' when the messages contain Chinese, otherwise to 'en'.

    Yields:
        Each response produced by ``self._run``: lists of dicts when every
        input message was a dict, otherwise lists of Message objects
        (including when the input is empty).
    """
    # Work on a deep copy so the caller's messages are never modified.
    messages = copy.deepcopy(messages)

    # Only return dicts when all input messages are dicts; an empty input or
    # any non-dict message switches the output to Message objects.
    _return_message_type = 'dict' if messages else 'message'
    new_messages = []
    for msg in messages:
        if isinstance(msg, dict):
            new_messages.append(Message(**msg))
        else:
            new_messages.append(msg)
            _return_message_type = 'message'

    # Choose the language from the message content unless the caller set it.
    if 'lang' not in kwargs:
        kwargs['lang'] = 'zh' if has_chinese_messages(new_messages) else 'en'

    if self.system_message:
        # The emptiness check matters: with no input messages there is no
        # first message to inspect, and indexing it would raise IndexError.
        if not new_messages or new_messages[0][ROLE] != SYSTEM:
            # Add the system instruction to the agent, default to `DEFAULT_SYSTEM_MESSAGE`
            new_messages.insert(0, Message(role=SYSTEM, content=self.system_message))
        elif self.system_message != DEFAULT_SYSTEM_MESSAGE:
            # The messages already start with a system message and the agent
            # has a custom one: prepend it unless it is already there.
            head = new_messages[0]
            content = head[CONTENT]
            if isinstance(content, str):
                if not content.startswith(self.system_message):
                    head[CONTENT] = self.system_message + '\n\n' + content
            else:
                assert isinstance(content, list)
                assert content[0].text
                if not content[0].text.startswith(self.system_message):
                    head[CONTENT] = [ContentItem(text=self.system_message + '\n\n')] + content

    for rsp in self._run(messages=new_messages, **kwargs):
        # Stamp the agent's name on responses that do not carry one.
        for item in rsp:
            if not item.name and self.name:
                item.name = self.name
        if _return_message_type == 'message':
            yield [Message(**x) if isinstance(x, dict) else x for x in rsp]
        else:
            yield [x.model_dump() if not isinstance(x, dict) else x for x in rsp]
2.4 内部运行方法
这是一个抽象方法,需要子类继承并重写实现。
2.5 内部大语言模型调用方法
(1) 传参描述
参数
参数描述
messages
传入消息列表
functions
提供给大语言模型调用的工具列表
stream
大语言模型是否流式输出标志,默认是True,也就是默认流式输出
extra_generate_cfg
一些额外的生成参数,例如,上面根据消息中是否包含中文产生的语言类型信息
(2) 代码解析:调用大语言模型的chat方法
def _call_llm(
    self,
    messages: List[Message],
    functions: Optional[List[Dict]] = None,
    stream: bool = True,
    extra_generate_cfg: Optional[dict] = None,
) -> Iterator[List[Message]]:
    """The interface of calling the LLM for the agent.

    The messages are handed to ``self.llm.chat`` as given, together with the
    agent-level generation config merged with the per-call one.

    Args:
        messages: A list of messages.
        functions: The list of functions provided to the LLM.
        stream: LLM streaming output or non-streaming output.
            For consistency, streaming output is the default across all agents.
        extra_generate_cfg: Extra generation config for this call; combined
            with ``self.extra_generate_cfg`` through ``merge_generate_cfgs``.

    Returns:
        Whatever ``self.llm.chat`` returns (the response generator of the LLM).
    """
    chat = self.llm.chat
    generate_cfg = merge_generate_cfgs(
        base_generate_cfg=self.extra_generate_cfg,
        new_generate_cfg=extra_generate_cfg,
    )
    return chat(
        messages=messages,
        functions=functions,
        stream=stream,
        extra_generate_cfg=generate_cfg,
    )
def _call_tool(self, tool_name: str, tool_args: Union[str, dict] = '{}', **kwargs) -> Union[str, List[ContentItem]]:
    """The interface of calling tools for the agent.

    Args:
        tool_name: The name of one tool.
        tool_args: Model generated or user given tool parameters.
        **kwargs: Forwarded unchanged to the tool's ``call`` method.

    Returns:
        The output of the tool: a string or a list of ContentItem as returned
        by the tool, or any other result serialized as JSON text. An unknown
        tool name or an unexpected tool failure is reported as a message
        string rather than raised.

    Raises:
        ToolServiceError, DocParserError: Re-raised as they come from the tool.
    """
    if tool_name not in self.function_map:
        return f'Tool {tool_name} does not exists.'
    selected_tool = self.function_map[tool_name]

    try:
        output = selected_tool.call(tool_args, **kwargs)
    except (ToolServiceError, DocParserError) as ex:
        raise ex
    except Exception as ex:
        # Any other failure is logged and returned as text, traceback included,
        # instead of propagating to the caller.
        stack = ''.join(traceback.format_tb(ex.__traceback__))
        error_message = (f'An error occurred when calling tool `{tool_name}`:\n'
                         f'{type(ex).__name__}: {ex!s}\n'
                         f'Traceback:\n{stack}')
        logger.warning(error_message)
        return error_message

    if isinstance(output, str):
        return output
    if isinstance(output, list) and all(isinstance(item, ContentItem) for item in output):
        return output  # multimodal tool results
    # Everything else is serialized to JSON text.
    return json.dumps(output, ensure_ascii=False, indent=4)
def _init_tool(self, tool: Union[str, Dict, BaseTool]):
    """Register one tool (or one group of MCP tools) in ``self.function_map``.

    Args:
        tool: A BaseTool object, a dict containing 'mcpServers' (handed to
            ``MCPManager().initConfig`` to obtain the tools), a tool
            configuration dict with a 'name' key, or a tool name.

    Raises:
        ValueError: If a tool given by name or configuration is not in
            ``TOOL_REGISTRY``.
    """

    def warn_if_duplicate(name):
        # A later registration replaces an earlier one with the same name.
        if name in self.function_map:
            logger.warning(f'Repeatedly adding tool {name}, will use the newest tool in function list')

    # Case 1: a ready-made tool object is stored as is.
    if isinstance(tool, BaseTool):
        tool_name = tool.name
        warn_if_duplicate(tool_name)
        self.function_map[tool_name] = tool
        return

    # Case 2: an MCP configuration yields several tools.
    if isinstance(tool, dict) and 'mcpServers' in tool:
        for mcp_tool in MCPManager().initConfig(tool):
            tool_name = mcp_tool.name
            warn_if_duplicate(tool_name)
            self.function_map[tool_name] = mcp_tool
        return

    # Case 3: a name or a configuration dict, resolved through TOOL_REGISTRY.
    if isinstance(tool, dict):
        tool_name, tool_cfg = tool['name'], tool
    else:
        tool_name, tool_cfg = tool, None
    if tool_name not in TOOL_REGISTRY:
        raise ValueError(f'Tool {tool_name} is not registered.')
    warn_if_duplicate(tool_name)
    self.function_map[tool_name] = TOOL_REGISTRY[tool_name](tool_cfg)
2.8 内部检查工具
(1) 参数解析
参数
参数描述
message
大语言模型产生的消息
(2) 代码解析
def _detect_tool(self, message: Message) -> Tuple[bool, str, str, str]:
    """A built-in tool call detection for func_call format messages.

    Args:
        message: One message generated by the LLM.

    Returns:
        A tuple of: whether a tool needs to be called, the tool name, the tool
        args, and the text reply. Name and args are None when the message has
        no function call; the text reply is '' when the content is empty.
    """
    call = message.function_call
    if call:
        name, arguments = call.name, call.arguments
    else:
        name, arguments = None, None

    # Empty or missing content is normalized to an empty string.
    reply = message.content or ''
    return name is not None, name, arguments, reply