feat(agent): 新增MCP Agent客户端和工具系统
- 实现了基于LangChain的MCP Agent,支持连接MCP服务器调用工具 - 添加了环境配置文件(.env),包含LLM模型和API配置信息 - 创建了完整的工具系统,包括BaseTool基类和Bash、Terminate、Add等工具 - 集成了天气查询工具,支持通过中国气象局API获取天气预报信息 - 实现了交互式对话功能,支持多轮工具调用和结果处理 - 添加了详细的CLAUDE.md开发指导文档
This commit is contained in:
parent
842e8f4f72
commit
feb1a0b280
5
.env
Normal file
5
.env
Normal file
@ -0,0 +1,5 @@
|
||||
LLM_MODEL=qwen3-30b-a3b-instruct-2507
|
||||
LLM_API_KEY=sk-2c4834c388724189903e8011420d6c47
|
||||
LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
|
||||
APIHZ_ID=10007673
|
||||
APIHZ_KEY=be63d2fb5354f76abb18f62583edffae
|
||||
6
.idea/ai_toolkit.xml
generated
Normal file
6
.idea/ai_toolkit.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="AI Toolkit Settings">
|
||||
<option name="importsOfInterestPresent" value="true" />
|
||||
</component>
|
||||
</project>
|
||||
6
.idea/copilot.data.migration.agent.xml
generated
Normal file
6
.idea/copilot.data.migration.agent.xml
generated
Normal file
@ -0,0 +1,6 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="AgentMigrationStateService">
|
||||
<option name="migrationStatus" value="COMPLETED" />
|
||||
</component>
|
||||
</project>
|
||||
3
.idea/misc.xml
generated
3
.idea/misc.xml
generated
@ -1,4 +1,7 @@
|
||||
<?xml version="1.0" encoding="UTF-8"?>
|
||||
<project version="4">
|
||||
<component name="Black">
|
||||
<option name="sdkName" value="Python 3.12 (TestMCP)" />
|
||||
</component>
|
||||
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12 (TestMCP)" project-jdk-type="Python SDK" />
|
||||
</project>
|
||||
106
CLAUDE.md
Normal file
106
CLAUDE.md
Normal file
@ -0,0 +1,106 @@
|
||||
# CLAUDE.md
|
||||
|
||||
此文件为 Claude Code(claude.ai/code)在处理此存储库中的代码时提供指导。
|
||||
|
||||
# 交互规则(优先级最高)
|
||||
1. 所有与本仓库相关的沟通、文档编辑、代码注释、解释说明均使用**中文**
|
||||
2. 技术术语以及命令保留英文原词(如 Langchain、Deepagents),但解释说明必须用中文
|
||||
3. 代码本身的语法/关键字保持英文(符合编程规范),但注释、文档、交互回复全部使用中文
|
||||
4. 当前项目为技术验证项目,目前使用Langchain、Deepagents等框架,遇到框架相关问题需要首先参考官方技术文档,链接:
|
||||
```text
|
||||
[Langchain](https://docs.langchain.com/oss/python/langchain/overview)、[Deepagents](https://docs.langchain.com/oss/python/deepagents/overview)
|
||||
```
|
||||
|
||||
## 项目概述
|
||||
|
||||
TestMCP 是一个 MCP (Model Context Protocol) 服务器实现,带有 OpenAI 兼容的客户端,用于构建 AI Agent 系统。该项目是 "OpenManus" 的简化/修改版本 - 一个支持工具调用的 AI Agent 框架。
|
||||
|
||||
## 命令
|
||||
|
||||
### 运行 Demo MCP 服务器 (FastMCP 带天气工具)
|
||||
```bash
|
||||
python main.py
|
||||
```
|
||||
运行 SSE 服务器,默认端口,包含 `add()` 和 `get_weather_by_location()` 工具。
|
||||
|
||||
### 运行 OpenManus 风格的 MCP 服务器
|
||||
```bash
|
||||
python -m app.mcp_server.server --transport stdio
|
||||
```
|
||||
|
||||
### 运行客户端
|
||||
```bash
|
||||
python client.py
|
||||
```
|
||||
通过 SSE 连接到 `http://127.0.0.1:8000/sse` 的 MCP 服务器。
|
||||
|
||||
### 启动Agent调用本地工具
|
||||
```bash
|
||||
python tool_call_agent.py
|
||||
```
|
||||
|
||||
### 包管理
|
||||
项目使用 `uv` 作为包管理器:
|
||||
```bash
|
||||
uv sync # 安装依赖
|
||||
uv add <package> # 添加新依赖
|
||||
```
|
||||
|
||||
## 架构
|
||||
|
||||
### 核心组件
|
||||
|
||||
```
|
||||
app/
|
||||
├── mcp_server/
|
||||
│ └── server.py # MCPServer 类,包含 FastMCP 和工具注册系统
|
||||
├── tools/
|
||||
│ ├── base.py # BaseTool (抽象类), ToolResult, CLIResult
|
||||
│ ├── bash.py # Bash 命令执行工具
|
||||
│ └── terminate.py # 终止工具
|
||||
├── utils/
|
||||
│ └── logger.py # 基于 Loguru 的日志系统,支持文件轮转
|
||||
└── exceptions.py # ToolError, OpenManusError, TokenLimitExceeded
|
||||
```
|
||||
|
||||
### 入口点
|
||||
|
||||
1. **`client.py`** - 主客户端入口
|
||||
- `AutoToolChatSession`: 通过 SSE 连接 MCP 服务器,调用 LLM,执行工具
|
||||
- 支持 OpenAI 兼容模型 (DashScope, Ollama 等)
|
||||
|
||||
2. **`main.py`** - 使用 FastMCP 的 Demo 服务器
|
||||
- 工具:`add()`, `get_weather_by_location()`
|
||||
- 资源:`greeting://{name}`
|
||||
- 提示词:`greet_user()`
|
||||
|
||||
3. **`app/mcp_server/server.py`** - OpenManus 风格的 MCP 服务器
|
||||
- `MCPServer` 类,支持动态工具注册
|
||||
- 内置工具:bash, terminate
|
||||
|
||||
### 工具系统
|
||||
|
||||
所有工具继承自 `BaseTool` (位于 `app/tools/base.py`):
|
||||
- 必须实现 `async execute(self, **kwargs) -> ToolResult`
|
||||
- 工具参数通过 `parameters` 字典定义 (JSON Schema 格式)
|
||||
- 结果通过 `ToolResult` 返回,包含 `output`, `error`, `base64_image` 字段
|
||||
|
||||
### 环境变量配置
|
||||
|
||||
通过 `.env` 文件配置必需的环境变量:
|
||||
```bash
|
||||
LLM_API_KEY=<your-api-key>
|
||||
APIHZ_ID=<apihz-user-id> # 用于天气 API
|
||||
APIHZ_KEY=<apihz-user-key> # 用于天气 API
|
||||
```
|
||||
|
||||
客户端配置 (位于 `client.py`):
|
||||
- `LLM_MODEL`: 模型名称 (默认:deepseek-v3)
|
||||
- `LLM_BASE_URL`: OpenAI 兼容的 API 端点
|
||||
|
||||
### 关键设计模式
|
||||
|
||||
- **MCP Protocol**: 使用 Model Context Protocol 暴露工具
|
||||
- **SSE Transport**: 客户端通过 Server-Sent Events 连接
|
||||
- **Async Design**: 使用 `asyncio` 进行并发工具执行
|
||||
- **Pydantic Models**: 用于数据验证 (`ToolResult`, `BaseTool`)
|
||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
403
app/agent.py
Normal file
403
app/agent.py
Normal file
@ -0,0 +1,403 @@
|
||||
"""
|
||||
使用 LangChain create_agent 创建 MCP Agent 客户端
|
||||
|
||||
本模块通过 LangChain 的 create_agent 函数创建智能体,
|
||||
并连接 MCP 服务器调用其注册的工具(bash, terminate, add, weather 等)。
|
||||
"""
|
||||
import asyncio
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
|
||||
from langchain.agents import create_agent
|
||||
from langchain_core.tools import BaseTool, ToolException
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
from mcp.types import TextContent
|
||||
from pydantic.v1 import BaseModel, Field
|
||||
|
||||
# 配置日志
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ = load_dotenv(find_dotenv())
|
||||
|
||||
|
||||
|
||||
from typing import Type
|
||||
|
||||
class MCPTool(BaseTool):
|
||||
"""将 MCP 工具转换为 LangChain 工具的适配器类"""
|
||||
|
||||
name: str = ""
|
||||
description: str = ""
|
||||
session: Optional[ClientSession] = None
|
||||
args_schema: Optional[Type[BaseModel]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
description: str,
|
||||
session: ClientSession,
|
||||
args_schema: Optional[Type[BaseModel]] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
super().__init__(
|
||||
name=name,
|
||||
description=description,
|
||||
session=session,
|
||||
args_schema=args_schema,
|
||||
**kwargs
|
||||
)
|
||||
|
||||
def _run(self, **kwargs: Any) -> Any:
|
||||
"""同步执行工具(通过 asyncio 运行)"""
|
||||
return asyncio.run(self._async_run(**kwargs))
|
||||
|
||||
async def _arun(self, **kwargs: Any) -> Any:
|
||||
"""异步执行工具"""
|
||||
return await self._async_run(**kwargs)
|
||||
|
||||
async def _async_run(self, **kwargs: Any) -> Any:
|
||||
"""执行 MCP 工具调用"""
|
||||
if self.session is None:
|
||||
raise ToolException(f"MCP session is not initialized for tool: {self.name}")
|
||||
|
||||
try:
|
||||
logger.info(f"Calling MCP tool: {self.name} with args: {kwargs}")
|
||||
response = await self.session.call_tool(self.name, kwargs)
|
||||
|
||||
# 提取结果
|
||||
if response.structuredContent is not None:
|
||||
result = json.dumps(response.structuredContent, ensure_ascii=False)
|
||||
elif response.content:
|
||||
result = "\n".join(
|
||||
c.text for c in response.content if isinstance(c, TextContent)
|
||||
)
|
||||
else:
|
||||
result = "工具执行成功,无返回内容"
|
||||
|
||||
# 限制日志长度
|
||||
log_result = result[:500] if len(result) > 500 else result
|
||||
logger.info(f"MCP tool {self.name} returned: {log_result}")
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"工具 {self.name} 执行失败:{str(e)}"
|
||||
logger.error(error_msg)
|
||||
raise ToolException(error_msg)
|
||||
|
||||
|
||||
class MCPLangChainAgent:
|
||||
"""基于 LangChain 的 MCP Agent 封装类"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
sse_server_url: str = "http://127.0.0.1:8000/sse",
|
||||
llm: Optional[Any] = None,
|
||||
):
|
||||
"""
|
||||
初始化 MCP Agent
|
||||
|
||||
Args:
|
||||
sse_server_url: MCP 服务器的 SSE 地址
|
||||
llm: LangChain 兼容的 LLM 实例
|
||||
"""
|
||||
self.sse_server_url = sse_server_url
|
||||
self.llm = llm
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.tools: List[BaseTool] = []
|
||||
self.agent = None
|
||||
self._read_stream = None
|
||||
self._write_stream = None
|
||||
|
||||
def _generate_args_schema(self, tool_name: str, input_schema: dict) -> Optional[Type[BaseModel]]:
|
||||
"""为 MCP 工具生成 LangChain 兼容的参数 schema"""
|
||||
if not input_schema or not input_schema.get("properties"):
|
||||
return None
|
||||
|
||||
properties = input_schema.get("properties", {})
|
||||
required_fields = input_schema.get("required", [])
|
||||
|
||||
# 动态创建字段定义
|
||||
fields = {}
|
||||
for prop_name, prop_details in properties.items():
|
||||
field_type = self._map_json_type_to_python(prop_details.get("type", "string"))
|
||||
field_description = prop_details.get("description", "")
|
||||
|
||||
# 根据字段是否必需来设置默认值
|
||||
if prop_name in required_fields:
|
||||
fields[prop_name] = (field_type, Field(..., description=field_description))
|
||||
else:
|
||||
fields[prop_name] = (field_type, Field(None, description=field_description))
|
||||
|
||||
# 动态创建 Pydantic 模型类
|
||||
schema_class = type(
|
||||
f"{tool_name.title()}Schema",
|
||||
(BaseModel,),
|
||||
{"__annotations__": {k: v[0] for k, v in fields.items()}, **{k: v[1] for k, v in fields.items()}}
|
||||
)
|
||||
return schema_class
|
||||
|
||||
def _map_json_type_to_python(self, json_type: str):
|
||||
"""将 JSON Schema 类型映射到 Python 类型"""
|
||||
type_mapping = {
|
||||
"string": str,
|
||||
"integer": int,
|
||||
"number": float,
|
||||
"boolean": bool,
|
||||
"array": list,
|
||||
"object": dict
|
||||
}
|
||||
return type_mapping.get(json_type, str)
|
||||
|
||||
async def connect(self) -> None:
|
||||
"""连接到 MCP 服务器并加载工具"""
|
||||
logger.info(f"Connecting to MCP server: {self.sse_server_url}")
|
||||
|
||||
# 创建 SSE 连接 - 使用同步方式进入上下文
|
||||
self._sse_cm = sse_client(self.sse_server_url)
|
||||
self._read_stream, self._write_stream = await self._sse_cm.__aenter__()
|
||||
|
||||
# 创建并初始化会话
|
||||
self._session_cm = ClientSession(self._read_stream, self._write_stream)
|
||||
self.session = await self._session_cm.__aenter__()
|
||||
await self.session.initialize()
|
||||
|
||||
# 获取工具列表
|
||||
tools_response = await self.session.list_tools()
|
||||
logger.info(f"Retrieved {len(tools_response.tools)} tools from MCP server")
|
||||
|
||||
# 转换 MCP 工具为 LangChain 工具
|
||||
self.tools = []
|
||||
for tool in tools_response.tools:
|
||||
# 生成参数 schema 类
|
||||
args_schema_class = self._generate_args_schema(tool.name, tool.inputSchema)
|
||||
|
||||
mcp_tool = MCPTool(
|
||||
name=tool.name,
|
||||
description=tool.description or "No description provided",
|
||||
session=self.session,
|
||||
args_schema=args_schema_class
|
||||
)
|
||||
self.tools.append(mcp_tool)
|
||||
logger.info(f"Loaded tool: {tool.name} - {tool.description[:50]}...")
|
||||
|
||||
logger.info(f"Successfully connected and loaded {len(self.tools)} tools")
|
||||
|
||||
async def disconnect(self) -> None:
|
||||
"""断开 MCP 服务器连接"""
|
||||
try:
|
||||
if self.session:
|
||||
await self._session_cm.__aexit__(None, None, None)
|
||||
self.session = None
|
||||
if hasattr(self, '_sse_cm') and self._sse_cm:
|
||||
await self._sse_cm.__aexit__(None, None, None)
|
||||
self._read_stream = None
|
||||
self._write_stream = None
|
||||
logger.info("Disconnected from MCP server")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error during disconnect: {e}")
|
||||
|
||||
def create_agent(self, system_message: Optional[str] = None) -> Any:
|
||||
"""
|
||||
创建 LangChain Agent
|
||||
|
||||
Args:
|
||||
system_message: 系统提示词
|
||||
|
||||
Returns:
|
||||
Agent 实例
|
||||
"""
|
||||
if self.llm is None:
|
||||
raise ValueError(
|
||||
"LLM is not set. Please provide a LangChain compatible LLM instance."
|
||||
)
|
||||
|
||||
if not self.tools:
|
||||
raise ValueError("No tools available. Please call connect() first.")
|
||||
|
||||
# 系统提示词
|
||||
system_prompt = system_message or (
|
||||
"你是一个智能助手,可以使用提供的工具来帮助用户解决问题。"
|
||||
"请根据用户的问题,合理选择并使用工具。"
|
||||
"如果不需要使用工具,直接回答用户的问题。"
|
||||
)
|
||||
|
||||
# 使用 LangChain 1.2+ 的 create_agent API
|
||||
self.agent = create_agent(
|
||||
model=self.llm,
|
||||
tools=self.tools,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
logger.info("Agent created successfully")
|
||||
return self.agent
|
||||
|
||||
async def invoke(self, input_text: str) -> Dict[str, Any]:
|
||||
"""
|
||||
执行 Agent
|
||||
|
||||
Args:
|
||||
input_text: 用户输入
|
||||
|
||||
Returns:
|
||||
Agent 执行结果
|
||||
"""
|
||||
if self.agent is None:
|
||||
raise ValueError("Agent is not created. Please call create_agent() first.")
|
||||
|
||||
logger.info(f"Invoking agent with input: {input_text}")
|
||||
result = await self.agent.ainvoke({"input": input_text})
|
||||
logger.info(f"Agent execution completed.")
|
||||
logger.info(f"Raw agent result type: {type(result)}, value: {result}")
|
||||
return result
|
||||
|
||||
async def __aenter__(self):
|
||||
"""异步上下文管理器入口"""
|
||||
await self.connect()
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc_val, exc_tb):
|
||||
"""异步上下文管理器出口"""
|
||||
await self.disconnect()
|
||||
return False # 不抑制异常
|
||||
|
||||
|
||||
def create_langchain_llm(
|
||||
model_name: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
base_url: Optional[str] = None,
|
||||
temperature: float = 0.7,
|
||||
**kwargs: Any
|
||||
) -> ChatOpenAI:
|
||||
"""
|
||||
创建 LangChain 兼容的 ChatOpenAI 实例
|
||||
|
||||
Args:
|
||||
model_name: 模型名称
|
||||
api_key: API 密钥
|
||||
base_url: API 基础 URL
|
||||
temperature: 温度参数
|
||||
**kwargs: 其他参数
|
||||
|
||||
Returns:
|
||||
LangChain ChatOpenAI 实例
|
||||
"""
|
||||
llm_config = {
|
||||
"model": model_name or os.getenv("LLM_MODEL", "deepseek-v3"),
|
||||
"api_key": api_key or os.getenv("LLM_API_KEY", "your-api-key"),
|
||||
"base_url": base_url or os.getenv("LLM_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
|
||||
"temperature": temperature,
|
||||
"timeout": 300,
|
||||
}
|
||||
|
||||
# 合并额外参数
|
||||
llm_config.update(kwargs)
|
||||
|
||||
logger.info(f"Creating LLM with model: {llm_config['model']}, base_url: {llm_config['base_url']}")
|
||||
return ChatOpenAI(**llm_config)
|
||||
|
||||
|
||||
async def start_interactive_chat():
|
||||
"""启动交互式对话"""
|
||||
print("\n" + "=" * 60)
|
||||
print("MCP LangChain Agent - 交互式对话")
|
||||
print("=" * 60)
|
||||
print("可用工具:add, bash, terminate, get_weather_by_location")
|
||||
print("输入 'quit' 或 'exit' 退出")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# 创建 LLM 实例
|
||||
llm = create_langchain_llm()
|
||||
|
||||
# 创建并运行 Agent
|
||||
async with MCPLangChainAgent(
|
||||
sse_server_url="http://127.0.0.1:8000/sse",
|
||||
llm=llm
|
||||
) as agent_wrapper:
|
||||
|
||||
# 创建 Agent
|
||||
agent = agent_wrapper.create_agent(
|
||||
system_message=(
|
||||
"你是一个智能助手,可以使用以下工具:\n"
|
||||
"- add(a, b): 计算两个整数的和\n"
|
||||
"- bash(command): 执行 bash 命令\n"
|
||||
"- terminate(status): 终止交互(status: 'success' 或 'failure')\n"
|
||||
"- get_weather_by_location(user_id, user_key, province, place): 获取天气预报\n"
|
||||
"请根据用户需求合理选择工具。"
|
||||
)
|
||||
)
|
||||
|
||||
while True:
|
||||
try:
|
||||
user_input = input("你:").strip()
|
||||
|
||||
if user_input.lower() in ["quit", "exit"]:
|
||||
logger.info("用户退出")
|
||||
break
|
||||
|
||||
if not user_input:
|
||||
continue
|
||||
|
||||
print("\n助手:思考中...", end="\r")
|
||||
|
||||
# 执行 Agent
|
||||
result = await agent_wrapper.invoke(user_input)
|
||||
|
||||
# 提取最终输出
|
||||
logger.info(f"Processing result: {result} (type: {type(result)})")
|
||||
|
||||
# 检查结果结构并相应地提取输出
|
||||
if isinstance(result, dict):
|
||||
# 如果结果包含 messages 键(LangChain agent 的典型返回格式)
|
||||
if "messages" in result and isinstance(result["messages"], list):
|
||||
# 从 messages 列表中提取最后一个消息的内容
|
||||
messages = result["messages"]
|
||||
if messages:
|
||||
# 获取最后一个消息对象
|
||||
last_message = messages[-1]
|
||||
# 检查是否是 AIMessage 对象并尝试提取 content
|
||||
if hasattr(last_message, 'content'):
|
||||
output = str(last_message.content) if last_message.content is not None else "无响应内容"
|
||||
else:
|
||||
output = str(last_message) if last_message is not None else "无响应内容"
|
||||
else:
|
||||
output = "无响应内容"
|
||||
else:
|
||||
# 尝试多种可能的键
|
||||
output = (result.get("output") or
|
||||
result.get("result") or
|
||||
result.get("response") or
|
||||
result.get("content") or
|
||||
result.get("answer") or
|
||||
"无响应内容")
|
||||
else:
|
||||
# 如果结果不是字典,直接转换为字符串
|
||||
output = str(result) if result is not None else "无响应内容"
|
||||
|
||||
# 显示结果
|
||||
print(f"助手:{output}\n")
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("\n用户强制中断")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"执行错误:{str(e)}", exc_info=True)
|
||||
print(f"错误:{str(e)}\n")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
asyncio.run(start_interactive_chat())
|
||||
except KeyboardInterrupt:
|
||||
print("\n再见!")
|
||||
13
app/exceptions.py
Normal file
13
app/exceptions.py
Normal file
@ -0,0 +1,13 @@
|
||||
class ToolError(Exception):
|
||||
"""当工具遇到错误时引发。"""
|
||||
|
||||
def __init__(self, message):
|
||||
self.message = message
|
||||
|
||||
|
||||
class OpenManusError(Exception):
|
||||
"""所有 OpenManus 错误的基础异常"""
|
||||
|
||||
|
||||
class TokenLimitExceeded(OpenManusError):
|
||||
"""当超过 token 限制时引发的异常"""
|
||||
0
app/mcp_server/__init__.py
Normal file
0
app/mcp_server/__init__.py
Normal file
194
app/mcp_server/server.py
Normal file
194
app/mcp_server/server.py
Normal file
@ -0,0 +1,194 @@
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stderr)])
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import atexit
|
||||
import json
|
||||
from inspect import Parameter, Signature
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
|
||||
from app.utils.logger import logger
|
||||
from app.tools.base import BaseTool
|
||||
from app.tools.bash import Bash
|
||||
# from app.tools.browser_use_tool import BrowserUseTool
|
||||
# from app.tools.str_replace_editor import StrReplaceEditor
|
||||
from app.tools.terminate import Terminate
|
||||
from app.tools.add import Add
|
||||
from app.tools.weather import GetWeatherByLocation
|
||||
|
||||
|
||||
class MCPServer:
|
||||
"""具有工具注册和管理功能的 MCP 服务器实现。"""
|
||||
|
||||
def __init__(self, name: str = "openmanus"):
|
||||
self.server = FastMCP(name)
|
||||
self.tools: Dict[str, BaseTool] = {}
|
||||
|
||||
# 初始化标准工具
|
||||
self.tools["bash"] = Bash()
|
||||
# self.tools["browser"] = BrowserUseTool()
|
||||
# self.tools["editor"] = StrReplaceEditor()
|
||||
self.tools["terminate"] = Terminate()
|
||||
self.tools["add"] = Add()
|
||||
self.tools["weather"] = GetWeatherByLocation()
|
||||
|
||||
def register_tool(self, tool: BaseTool, method_name: Optional[str] = None) -> None:
|
||||
"""注册一个工具,包含参数验证和文档。"""
|
||||
tool_name = method_name or tool.name
|
||||
tool_param = tool.to_param()
|
||||
tool_function = tool_param["function"]
|
||||
|
||||
# 定义要注册的异步函数
|
||||
async def tool_method(**kwargs):
|
||||
logger.info(f"Executing {tool_name}: {kwargs}")
|
||||
result = await tool.execute(**kwargs)
|
||||
|
||||
logger.info(f"Result of {tool_name}: {result}")
|
||||
|
||||
# 处理不同类型的结果(匹配原始逻辑)
|
||||
if hasattr(result, "model_dump"):
|
||||
return json.dumps(result.model_dump())
|
||||
elif isinstance(result, dict):
|
||||
return json.dumps(result)
|
||||
return result
|
||||
|
||||
# 设置方法元数据
|
||||
tool_method.__name__ = tool_name
|
||||
tool_method.__doc__ = self._build_docstring(tool_function)
|
||||
tool_method.__signature__ = self._build_signature(tool_function)
|
||||
|
||||
# 存储参数模式(对于以编程方式访问它的工具很重要)
|
||||
param_props = tool_function.get("parameters", {}).get("properties", {})
|
||||
required_params = tool_function.get("parameters", {}).get("required", [])
|
||||
tool_method._parameter_schema = {
|
||||
param_name: {
|
||||
"description": param_details.get("description", ""),
|
||||
"type": param_details.get("type", "any"),
|
||||
"required": param_name in required_params,
|
||||
}
|
||||
for param_name, param_details in param_props.items()
|
||||
}
|
||||
|
||||
# 注册到服务器
|
||||
self.server.tool()(tool_method)
|
||||
logger.info(f"Registered tool: {tool_name}")
|
||||
|
||||
def _build_docstring(self, tool_function: dict) -> str:
|
||||
"""从工具函数元数据构建格式化的文档字符串。"""
|
||||
description = tool_function.get("description", "")
|
||||
param_props = tool_function.get("parameters", {}).get("properties", {})
|
||||
required_params = tool_function.get("parameters", {}).get("required", [])
|
||||
|
||||
# 构建文档字符串(匹配原始格式)
|
||||
docstring = description
|
||||
if param_props:
|
||||
docstring += "\n\nParameters:\n"
|
||||
for param_name, param_details in param_props.items():
|
||||
required_str = (
|
||||
"(required)" if param_name in required_params else "(optional)"
|
||||
)
|
||||
param_type = param_details.get("type", "any")
|
||||
param_desc = param_details.get("description", "")
|
||||
docstring += (
|
||||
f" {param_name} ({param_type}) {required_str}: {param_desc}\n"
|
||||
)
|
||||
|
||||
return docstring
|
||||
|
||||
def _build_signature(self, tool_function: dict) -> Signature:
|
||||
"""从工具函数元数据构建函数签名。"""
|
||||
param_props = tool_function.get("parameters", {}).get("properties", {})
|
||||
required_params = tool_function.get("parameters", {}).get("required", [])
|
||||
|
||||
parameters = []
|
||||
|
||||
# 遵循原始类型映射
|
||||
for param_name, param_details in param_props.items():
|
||||
param_type = param_details.get("type", "")
|
||||
default = Parameter.empty if param_name in required_params else None
|
||||
|
||||
# 将 JSON Schema 类型映射到 Python 类型(与原始相同)
|
||||
annotation = Any
|
||||
if param_type == "string":
|
||||
annotation = str
|
||||
elif param_type == "integer":
|
||||
annotation = int
|
||||
elif param_type == "number":
|
||||
annotation = float
|
||||
elif param_type == "boolean":
|
||||
annotation = bool
|
||||
elif param_type == "object":
|
||||
annotation = dict
|
||||
elif param_type == "array":
|
||||
annotation = list
|
||||
|
||||
# 创建与原始结构相同的参数
|
||||
param = Parameter(
|
||||
name=param_name,
|
||||
kind=Parameter.KEYWORD_ONLY,
|
||||
default=default,
|
||||
annotation=annotation,
|
||||
)
|
||||
parameters.append(param)
|
||||
|
||||
return Signature(parameters=parameters)
|
||||
|
||||
async def cleanup(self) -> None:
|
||||
"""清理服务器资源。"""
|
||||
logger.info("Cleaning up resources")
|
||||
# 遵循原始清理逻辑 - 仅清理浏览器工具
|
||||
if "browser" in self.tools and hasattr(self.tools["browser"], "cleanup"):
|
||||
await self.tools["browser"].cleanup()
|
||||
|
||||
def register_all_tools(self) -> None:
|
||||
"""向服务器注册所有工具。"""
|
||||
for tool in self.tools.values():
|
||||
self.register_tool(tool)
|
||||
|
||||
def run(self, transport: str = "stdio") -> None:
|
||||
"""运行 MCP 服务器。"""
|
||||
# 注册所有工具
|
||||
self.register_all_tools()
|
||||
|
||||
# 注册清理函数(匹配原始行为)
|
||||
atexit.register(lambda: asyncio.run(self.cleanup()))
|
||||
|
||||
# 启动服务器(使用与原始相同的日志记录)
|
||||
logger.info(f"Starting OpenManus server ({transport} mode)")
|
||||
self.server.run(transport=transport)
|
||||
|
||||
|
||||
def parse_args() -> argparse.Namespace:
|
||||
"""解析命令行参数。"""
|
||||
parser = argparse.ArgumentParser(description="OpenManus MCP Server")
|
||||
parser.add_argument(
|
||||
"--transport",
|
||||
choices=["stdio", "sse"],
|
||||
default="stdio",
|
||||
help="通信方法:stdio 或 sse (默认:stdio)",
|
||||
)
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = parse_args()
|
||||
|
||||
# 创建服务器
|
||||
server = MCPServer()
|
||||
server.register_all_tools()
|
||||
|
||||
if args.transport == "sse":
|
||||
# SSE 模式:直接运行 FastMCP 服务器(默认端口 8000)
|
||||
logger.info(f"Starting MCP server with SSE transport")
|
||||
server.server.run(transport="sse")
|
||||
else:
|
||||
# stdio 模式
|
||||
atexit.register(lambda: asyncio.run(server.cleanup()))
|
||||
logger.info(f"Starting MCP server with stdio transport")
|
||||
server.server.run(transport="stdio")
|
||||
96
app/tool_call_agent.py
Normal file
96
app/tool_call_agent.py
Normal file
@ -0,0 +1,96 @@
|
||||
import os
|
||||
import uuid
|
||||
from deepagents import create_deep_agent
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
from langchain.agents import create_agent
|
||||
from langgraph.checkpoint.memory import InMemorySaver
|
||||
from app.tools.agent_tools import add, get_weather_by_location
|
||||
from langchain_openai import ChatOpenAI
|
||||
from app.tools.execute_sql import execute_query
|
||||
|
||||
_ = load_dotenv(find_dotenv())
|
||||
|
||||
checkpointer = InMemorySaver()
|
||||
|
||||
# model = ChatTongyi(
|
||||
# model="qwen3-30b-a3b-thinking-2507",
|
||||
# dashscope_api_key=os.getenv('LLM_API_KEY'),
|
||||
# base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
# )
|
||||
|
||||
model = ChatOpenAI(
|
||||
model="qwen3-30b-a3b-thinking-2507",
|
||||
api_key=os.getenv('LLM_API_KEY'),
|
||||
base_url="https://dashscope.aliyuncs.com/compatible-mode/v1"
|
||||
)
|
||||
|
||||
# agent = create_agent(
|
||||
# model=model,
|
||||
# tools=[add, get_weather_by_location],
|
||||
# system_prompt="你是一个有帮助的助手。请简洁准确。"
|
||||
# )
|
||||
|
||||
agent = create_deep_agent(
|
||||
model=model,
|
||||
tools=[add, get_weather_by_location, execute_query],
|
||||
checkpointer=checkpointer,
|
||||
system_prompt="你是一个有帮助的助手。请简洁准确,用中文进行回答。"
|
||||
)
|
||||
|
||||
thread_id = str(uuid.uuid4())
|
||||
config = {
|
||||
"configurable": {
|
||||
"thread_id": thread_id
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 启用流式输出
|
||||
print("=== 开始流式输出 ===")
|
||||
stream_result = agent.stream(
|
||||
{"messages": [{"role": "user", "content": "查询车上咖啡机数据,不要修改原来的sql,除非运行报错。"
|
||||
"sql: select counts(*) from coffee_train"}]},
|
||||
config=config
|
||||
)
|
||||
|
||||
# 逐块处理流式输出
|
||||
full_response = ""
|
||||
for chunk in stream_result:
|
||||
# 打印每一块内容(调试用途)
|
||||
print("接收到数据块:", chunk)
|
||||
|
||||
# 解析消息内容
|
||||
if 'messages' in chunk:
|
||||
messages = chunk['messages']
|
||||
for msg in messages:
|
||||
# 打印消息类型和所有属性
|
||||
print(f"--- 消息类型: {type(msg).__name__} ---")
|
||||
if hasattr(msg, 'content') and msg.content:
|
||||
print(f"消息内容: {msg.content}")
|
||||
|
||||
# 处理 AIMessage(带工具调用)
|
||||
if hasattr(msg, 'tool_calls') and msg.tool_calls:
|
||||
print("正在调用工具...")
|
||||
for tool_call in msg.tool_calls:
|
||||
print(f" 工具名: {tool_call.get('name', 'N/A')}")
|
||||
print(f" 参数: {tool_call.get('args', {})}")
|
||||
print(f" 调用ID: {tool_call.get('id', 'N/A')}")
|
||||
|
||||
# 处理 ToolMessage(工具响应)
|
||||
elif hasattr(msg, 'name') and msg.name:
|
||||
print(f"工具名称: {msg.name}")
|
||||
print(f"工具调用ID: {msg.tool_call_id}")
|
||||
print(f"工具响应: {msg.content}")
|
||||
|
||||
# 处理普通消息内容
|
||||
elif hasattr(msg, 'content'):
|
||||
content = msg.content
|
||||
full_response += content
|
||||
|
||||
print("\n=== 流式输出结束 ===")
|
||||
|
||||
# 打印最终完整响应
|
||||
print("\n=== 最终完整响应 ===")
|
||||
print(full_response)
|
||||
|
||||
15
app/tools/__init__.py
Normal file
15
app/tools/__init__.py
Normal file
@ -0,0 +1,15 @@
|
||||
from app.tools.base import BaseTool, ToolResult, CLIResult
|
||||
from app.tools.bash import Bash
|
||||
from app.tools.terminate import Terminate
|
||||
from app.tools.add import Add
|
||||
from app.tools.weather import GetWeatherByLocation
|
||||
|
||||
__all__ = [
|
||||
"BaseTool",
|
||||
"ToolResult",
|
||||
"CLIResult",
|
||||
"Bash",
|
||||
"Terminate",
|
||||
"Add",
|
||||
"GetWeatherByLocation",
|
||||
]
|
||||
30
app/tools/add.py
Normal file
30
app/tools/add.py
Normal file
@ -0,0 +1,30 @@
|
||||
from app.tools.base import BaseTool, ToolResult
|
||||
|
||||
|
||||
_ADD_DESCRIPTION = """计算两个整数的和。"""
|
||||
|
||||
|
||||
class Add(BaseTool):
|
||||
"""加法计算工具"""
|
||||
|
||||
name: str = "add"
|
||||
description: str = _ADD_DESCRIPTION
|
||||
parameters: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {
|
||||
"type": "integer",
|
||||
"description": "第一个加数",
|
||||
},
|
||||
"b": {
|
||||
"type": "integer",
|
||||
"description": "第二个加数",
|
||||
},
|
||||
},
|
||||
"required": ["a", "b"],
|
||||
}
|
||||
|
||||
async def execute(self, a: int, b: int) -> ToolResult:
|
||||
"""执行加法计算"""
|
||||
result = a + b
|
||||
return self.success_response({"result": result})
|
||||
190
app/tools/agent_tools.py
Normal file
190
app/tools/agent_tools.py
Normal file
@ -0,0 +1,190 @@
|
||||
from langchain.agents.middleware import wrap_tool_call
|
||||
from langchain.tools import tool
|
||||
from typing import Dict, Optional, Union
|
||||
import requests
|
||||
import os
|
||||
from dotenv import load_dotenv, find_dotenv
|
||||
from langchain_core.messages import ToolMessage
|
||||
|
||||
_ = load_dotenv(find_dotenv())
|
||||
@tool
|
||||
def add(a: int, b: int) -> int:
|
||||
"""Add two numbers"""
|
||||
return a + b
|
||||
|
||||
@tool
|
||||
def get_weather_by_location(province: str, place: str) -> Dict[
|
||||
str, Union[str, float, int]
|
||||
]:
|
||||
"""
|
||||
调用中国气象局天气预报API,获取指定省份和地点的完整当日天气信息(含基础预报、实时数据、气象预警)
|
||||
|
||||
### 参数说明
|
||||
- province: str - 查询省份/直辖市,建议去除"省""市"后缀(例:"四川"而非"四川省","北京"而非"北京市")
|
||||
- place: str - 查询城市/区/县,建议去除"市""区""县"后缀(例:"绵阳"而非"绵阳市","大兴"而非"大兴区")
|
||||
|
||||
### 返回说明(成功,code=200)
|
||||
基础预报信息:
|
||||
- code: int - 状态码,成功固定为200
|
||||
- guo: str - 国家名称(例:"中国")
|
||||
- sheng: str - 标准化省份/直辖市名称(例:"四川")
|
||||
- shi: str - 标准化城市/地区名称(例:"绵阳")
|
||||
- name: str - 与shi一致,冗余字段(例:"绵阳")
|
||||
- weather1: str - 当日主要天气(例:"阵雨")
|
||||
- weather2: str - 当日次要天气(例:"阵雨")
|
||||
- wd1: str - 当日最高温度(单位:℃,例:"25")
|
||||
- wd2: str - 当日最低温度(单位:℃,例:"18")
|
||||
- winddirection1: str - 当日主要风向(例:"无持续风向")
|
||||
- winddirection2: str - 当日次要风向(例:"无持续风向")
|
||||
- windleve1: str - 当日主要风力等级(例:"微风")
|
||||
- windleve2: str - 当日次要风力等级(例:"微风")
|
||||
- weather1img: str - 主要天气图标URL(例:"https://rescdn.apihz.cn/resimg/tianqi/zhenyu.png")
|
||||
- weather2img: str - 次要天气图标URL(例:"https://rescdn.apihz.cn/resimg/tianqi/zhenyu.png")
|
||||
- lon: str - 地区经度(保留3位小数,例:"104.730")
|
||||
- lat: str - 地区纬度(保留3位小数,例:"31.440")
|
||||
- uptime: str - 预报数据更新时间(格式:YYYY-MM-DD HH:MM:SS,例:"2025-08-29 12:00:00")
|
||||
|
||||
实时天气数据(顶层字段):
|
||||
- now_precipitation: float - 当前降水量(单位:mm,例:0.0)
|
||||
- now_temperature: float - 当前温度(单位:℃,例:19.3)
|
||||
- now_pressure: int - 当前气压(单位:hPa,例:956)
|
||||
- now_humidity: int - 当前湿度(单位:%,例:85)
|
||||
- now_windDirection: str - 当前风向(例:"东北风")
|
||||
- now_windDirectionDegree: int - 当前风向角度(0°=北风,例:28)
|
||||
- now_windSpeed: float - 当前风速(单位:m/s,例:3.2)
|
||||
- now_windScale: str - 当前风力等级(例:"微风")
|
||||
- now_feelst: float - 当前体感温度(单位:℃,例:19.7)
|
||||
- now_uptime: str - 实时数据更新时间(格式:YYYY/MM/DD HH:MM,例:"2025/08/29 10:05")
|
||||
|
||||
气象预警(顶层字段,无预警时为空字符串):
|
||||
- alarm_id: str - 预警唯一ID(例:"51070041600000_20250828215515")
|
||||
- alarm_title: str - 预警标题(例:"绵阳市气象台更新大风蓝色预警信号[IV级/一般]")
|
||||
- alarm_signaltype: str - 预警类型(例:"大风")
|
||||
- alarm_signallevel: str - 预警等级(例:"蓝色")
|
||||
- alarm_effective: str - 预警生效时间(例:"2025/08/28 21:55")
|
||||
- alarm_eventType: str - 预警事件编码(例:"11B06")
|
||||
- alarm_severity: str - 预警等级英文编码(例:"BLUE")
|
||||
- alarm_type: str - 预警类型编码(例:"p0007004")
|
||||
|
||||
### 返回说明(失败,code=400)
|
||||
- code: int - 错误状态码,固定为400
|
||||
- msg: str - 错误详情(例:"通讯秘钥错误。"、"API响应解析失败")
|
||||
"""
|
||||
api_url = "https://cn.apihz.cn/api/tianqi/tqyb.php"
|
||||
clean_province = province.replace("省", "").replace("市", "").strip()
|
||||
clean_place = place.replace("市", "").replace("区", "").replace("县", "").strip()
|
||||
final_user_id = os.getenv("APIHZ_ID")
|
||||
final_user_key = os.getenv("APIHZ_KEY")
|
||||
params = {
|
||||
"id": final_user_id,
|
||||
"key": final_user_key,
|
||||
"sheng": clean_province,
|
||||
"place": clean_place
|
||||
}
|
||||
|
||||
try:
|
||||
# 1. 发送请求(增加超时重试,避免偶发网络波动)
|
||||
response = requests.get(api_url, params=params, timeout=10)
|
||||
response.raise_for_status() # 捕获4xx/5xx HTTP错误
|
||||
|
||||
# 2. 解析响应:先判断响应内容是否为空,再转JSON
|
||||
response_text = response.text.strip()
|
||||
if not response_text:
|
||||
return {"code": 400, "msg": "API响应为空,无法解析天气数据"}
|
||||
|
||||
# 3. 转JSON并防御None(确保weather_data是字典)
|
||||
weather_data = response.json()
|
||||
if not isinstance(weather_data, dict):
|
||||
return {"code": 400, "msg": f"API响应格式错误,不是有效字典(实际类型:{type(weather_data).__name__})"}
|
||||
|
||||
# 4. 处理API返回的错误状态(code=400)
|
||||
if weather_data.get("code") == 400:
|
||||
return {
|
||||
"code": 400,
|
||||
"msg": weather_data.get("msg", "API返回错误,原因未知")
|
||||
}
|
||||
# 5. 确保API返回成功状态(code=200)
|
||||
elif weather_data.get("code") != 200:
|
||||
return {
|
||||
"code": 400,
|
||||
"msg": f"API返回非成功状态码:{weather_data.get('code', '未知')}"
|
||||
}
|
||||
|
||||
# 6. 提取嵌套数据(确保now_info/alarm_info是字典,避免None)
|
||||
now_info = weather_data.get("nowinfo", {})
|
||||
if not isinstance(now_info, dict):
|
||||
now_info = {} # 若now_info不是字典,强制设为空字典
|
||||
|
||||
alarm_info = weather_data.get("alarm", {})
|
||||
if not isinstance(alarm_info, dict):
|
||||
alarm_info = {} # 若alarm不是字典,强制设为空字典
|
||||
|
||||
# 7. 构造最终返回结果(全部顶层字段,无嵌套)
|
||||
return {
|
||||
# 基础预报字段
|
||||
"code": 200,
|
||||
"guo": weather_data.get("guo", ""),
|
||||
"sheng": weather_data.get("sheng", ""),
|
||||
"shi": weather_data.get("shi", ""),
|
||||
"name": weather_data.get("name", ""),
|
||||
"weather1": weather_data.get("weather1", ""),
|
||||
"weather2": weather_data.get("weather2", ""),
|
||||
"wd1": weather_data.get("wd1", ""),
|
||||
"wd2": weather_data.get("wd2", ""),
|
||||
"winddirection1": weather_data.get("winddirection1", ""),
|
||||
"winddirection2": weather_data.get("winddirection2", ""),
|
||||
"windleve1": weather_data.get("windleve1", ""),
|
||||
"windleve2": weather_data.get("windleve2", ""),
|
||||
"weather1img": weather_data.get("weather1img", ""),
|
||||
"weather2img": weather_data.get("weather2img", ""),
|
||||
"lon": weather_data.get("lon", ""),
|
||||
"lat": weather_data.get("lat", ""),
|
||||
"uptime": weather_data.get("uptime", ""),
|
||||
# 实时天气字段
|
||||
"now_precipitation": float(now_info.get("precipitation", 0.0)),
|
||||
"now_temperature": float(now_info.get("temperature", 0.0)),
|
||||
"now_pressure": int(now_info.get("pressure", 0)),
|
||||
"now_humidity": int(now_info.get("humidity", 0)),
|
||||
"now_windDirection": now_info.get("windDirection", ""),
|
||||
"now_windDirectionDegree": int(now_info.get("windDirectionDegree", 0)),
|
||||
"now_windSpeed": float(now_info.get("windSpeed", 0.0)),
|
||||
"now_windScale": now_info.get("windScale", ""),
|
||||
"now_feelst": float(now_info.get("feelst", 0.0)),
|
||||
"now_uptime": now_info.get("uptime", ""),
|
||||
# 预警字段
|
||||
"alarm_id": alarm_info.get("id", ""),
|
||||
"alarm_title": alarm_info.get("title", ""),
|
||||
"alarm_signaltype": alarm_info.get("signaltype", ""),
|
||||
"alarm_signallevel": alarm_info.get("signallevel", ""),
|
||||
"alarm_effective": alarm_info.get("effective", ""),
|
||||
"alarm_eventType": alarm_info.get("eventType", ""),
|
||||
"alarm_severity": alarm_info.get("severity", ""),
|
||||
"alarm_type": alarm_info.get("type", "")
|
||||
}
|
||||
|
||||
# 8. 捕获各类异常(明确错误原因)
|
||||
except requests.exceptions.HTTPError as e:
|
||||
status_code = response.status_code if "response" in locals() else "未知"
|
||||
return {"code": 400, "msg": f"HTTP请求错误(状态码:{status_code}):{str(e)}"}
|
||||
except requests.exceptions.ConnectionError:
|
||||
return {"code": 400, "msg": "网络连接错误:无法连接到天气API服务器"}
|
||||
except requests.exceptions.Timeout:
|
||||
return {"code": 400, "msg": "请求超时:API服务器10秒内未响应"}
|
||||
except ValueError as e:
|
||||
# 捕获JSON解析错误(如响应不是合法JSON)
|
||||
return {"code": 400, "msg": f"API响应解析失败(JSON格式错误):{str(e)}"}
|
||||
except Exception as e:
|
||||
# 捕获其他未知错误(附带具体错误信息,便于调试)
|
||||
return {"code": 400, "msg": f"未知错误:{str(e)}(错误类型:{type(e).__name__})"}
|
||||
|
||||
@wrap_tool_call
|
||||
def handle_tool_errors(request, handler):
|
||||
"""使用自定义消息处理工具执行错误。"""
|
||||
try:
|
||||
return handler(request)
|
||||
except Exception as e:
|
||||
# 向模型返回自定义错误消息
|
||||
return ToolMessage(
|
||||
content=f"工具错误:请检查您的输入并重试。({str(e)})",
|
||||
tool_call_id=request.tool_call["id"]
|
||||
)
|
||||
153
app/tools/base.py
Normal file
153
app/tools/base.py
Normal file
@ -0,0 +1,153 @@
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from app.utils.logger import logger
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
"""表示工具执行的结果。"""
|
||||
|
||||
output: Any = Field(default=None)
|
||||
error: Optional[str] = Field(default=None)
|
||||
base64_image: Optional[str] = Field(default=None)
|
||||
system: Optional[str] = Field(default=None)
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
def __bool__(self):
|
||||
return any(getattr(self, field) for field in self.__fields__)
|
||||
|
||||
def __add__(self, other: "ToolResult"):
|
||||
def combine_fields(
|
||||
field: Optional[str], other_field: Optional[str], concatenate: bool = True
|
||||
):
|
||||
if field and other_field:
|
||||
if concatenate:
|
||||
return field + other_field
|
||||
raise ValueError("Cannot combine tool results")
|
||||
return field or other_field
|
||||
|
||||
return ToolResult(
|
||||
output=combine_fields(self.output, other.output),
|
||||
error=combine_fields(self.error, other.error),
|
||||
base64_image=combine_fields(self.base64_image, other.base64_image, False),
|
||||
system=combine_fields(self.system, other.system),
|
||||
)
|
||||
|
||||
def __str__(self):
|
||||
return f"Error: {self.error}" if self.error else self.output
|
||||
|
||||
def replace(self, **kwargs):
|
||||
"""返回一个替换了给定字段的新 ToolResult。"""
|
||||
# return self.copy(update=kwargs)
|
||||
return type(self)(**{**self.dict(), **kwargs})
|
||||
|
||||
|
||||
class BaseTool(ABC, BaseModel):
|
||||
"""所有工具的整合基类,结合了 BaseModel 和 Tool 功能。
|
||||
|
||||
提供:
|
||||
- Pydantic 模型验证
|
||||
- 模式注册
|
||||
- 标准化结果处理
|
||||
- 抽象执行接口
|
||||
|
||||
属性:
|
||||
name (str): 工具名称
|
||||
description (str): 工具描述
|
||||
parameters (dict): 工具参数模式
|
||||
_schemas (Dict[str, List[ToolSchema]]): 已注册的方法模式
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
parameters: Optional[dict] = None
|
||||
# _schemas: Dict[str, List[ToolSchema]] = {}
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
underscore_attrs_are_private = False
|
||||
|
||||
# def __init__(self, **data):
|
||||
# """Initialize tool with model validation and schema registration."""
|
||||
# super().__init__(**data)
|
||||
# logger.debug(f"Initializing tool class: {self.__class__.__name__}")
|
||||
# self._register_schemas()
|
||||
|
||||
# def _register_schemas(self):
|
||||
# """Register schemas from all decorated methods."""
|
||||
# for name, method in inspect.getmembers(self, predicate=inspect.ismethod):
|
||||
# if hasattr(method, 'tool_schemas'):
|
||||
# self._schemas[name] = method.tool_schemas
|
||||
# logger.debug(f"Registered schemas for method '{name}' in {self.__class__.__name__}")
|
||||
|
||||
async def __call__(self, **kwargs) -> Any:
|
||||
"""使用给定参数执行工具。"""
|
||||
return await self.execute(**kwargs)
|
||||
|
||||
@abstractmethod
|
||||
async def execute(self, **kwargs) -> Any:
|
||||
"""使用给定参数执行工具。"""
|
||||
|
||||
def to_param(self) -> Dict:
|
||||
"""将工具转换为函数调用格式。
|
||||
|
||||
Returns:
|
||||
包含 OpenAI 函数调用格式的工具元数据的字典
|
||||
"""
|
||||
return {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": self.name,
|
||||
"description": self.description,
|
||||
"parameters": self.parameters,
|
||||
},
|
||||
}
|
||||
|
||||
# def get_schemas(self) -> Dict[str, List[ToolSchema]]:
|
||||
# """Get all registered tool schemas.
|
||||
|
||||
# Returns:
|
||||
# Dict mapping method names to their schema definitions
|
||||
# """
|
||||
# return self._schemas
|
||||
|
||||
def success_response(self, data: Union[Dict[str, Any], str]) -> ToolResult:
|
||||
"""创建成功的工具结果。
|
||||
|
||||
Args:
|
||||
data: 结果数据(字典或字符串)
|
||||
|
||||
Returns:
|
||||
带有 success=True 和格式化输出的 ToolResult
|
||||
"""
|
||||
if isinstance(data, str):
|
||||
text = data
|
||||
else:
|
||||
text = json.dumps(data, indent=2)
|
||||
logger.debug(f"Created success response for {self.__class__.__name__}")
|
||||
return ToolResult(output=text)
|
||||
|
||||
def fail_response(self, msg: str) -> ToolResult:
|
||||
"""创建失败的工具结果。
|
||||
|
||||
Args:
|
||||
msg: 描述失败的错误消息
|
||||
|
||||
Returns:
|
||||
带有 success=False 和错误消息的 ToolResult
|
||||
"""
|
||||
logger.debug(f"Tool {self.__class__.__name__} returned failed result: {msg}")
|
||||
return ToolResult(error=msg)
|
||||
|
||||
|
||||
class CLIResult(ToolResult):
|
||||
"""可以渲染为 CLI 输出的 ToolResult。"""
|
||||
|
||||
|
||||
class ToolFailure(ToolResult):
|
||||
"""表示失败的 ToolResult。"""
|
||||
158
app/tools/bash.py
Normal file
158
app/tools/bash.py
Normal file
@ -0,0 +1,158 @@
|
||||
import asyncio
|
||||
import os
|
||||
from typing import Optional
|
||||
|
||||
from app.exceptions import ToolError
|
||||
from app.tools.base import BaseTool, CLIResult
|
||||
|
||||
|
||||
_BASH_DESCRIPTION = """在终端中执行 bash 命令。
|
||||
* 长时间运行的命令:对于可能无限期运行的命令,应该在后台运行并将输出重定向到文件,例如 command = `python3 app.py > server.log 2>&1 &`。
|
||||
* 交互式:如果 bash 命令返回退出代码 `-1`,这意味着进程尚未完成。助手必须向终端发送第二次调用,使用空的 `command`(这将检索任何额外的日志),或者它可以向正在运行的进程的 STDIN 发送附加文本(将 `command` 设置为文本),或者它可以发送 command=`ctrl+c` 来中断进程。
|
||||
* 超时:如果命令执行结果说 "Command timed out. Sending SIGINT to the process",助手应该重试在后台运行该命令。
|
||||
"""
|
||||
|
||||
|
||||
class _BashSession:
|
||||
"""bash shell 的会话。"""
|
||||
|
||||
_started: bool
|
||||
_process: asyncio.subprocess.Process
|
||||
|
||||
command: str = "/bin/bash"
|
||||
_output_delay: float = 0.2 # 秒
|
||||
_timeout: float = 120.0 # 秒
|
||||
_sentinel: str = "<<exit>>"
|
||||
|
||||
def __init__(self):
|
||||
self._started = False
|
||||
self._timed_out = False
|
||||
|
||||
async def start(self):
|
||||
if self._started:
|
||||
return
|
||||
|
||||
self._process = await asyncio.create_subprocess_shell(
|
||||
self.command,
|
||||
preexec_fn=os.setsid,
|
||||
shell=True,
|
||||
bufsize=0,
|
||||
stdin=asyncio.subprocess.PIPE,
|
||||
stdout=asyncio.subprocess.PIPE,
|
||||
stderr=asyncio.subprocess.PIPE,
|
||||
)
|
||||
|
||||
self._started = True
|
||||
|
||||
def stop(self):
|
||||
"""终止 bash shell。"""
|
||||
if not self._started:
|
||||
raise ToolError("Session has not started.")
|
||||
if self._process.returncode is not None:
|
||||
return
|
||||
self._process.terminate()
|
||||
|
||||
async def run(self, command: str):
|
||||
"""在 bash shell 中执行命令。"""
|
||||
if not self._started:
|
||||
raise ToolError("Session has not started.")
|
||||
if self._process.returncode is not None:
|
||||
return CLIResult(
|
||||
system="tool must be restarted",
|
||||
error=f"bash has exited with returncode {self._process.returncode}",
|
||||
)
|
||||
if self._timed_out:
|
||||
raise ToolError(
|
||||
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||
)
|
||||
|
||||
# 我们知道这些不是 None,因为我们使用 PIPEs 创建了进程
|
||||
assert self._process.stdin
|
||||
assert self._process.stdout
|
||||
assert self._process.stderr
|
||||
|
||||
# 向进程发送命令
|
||||
self._process.stdin.write(
|
||||
command.encode() + f"; echo '{self._sentinel}'\n".encode()
|
||||
)
|
||||
await self._process.stdin.drain()
|
||||
|
||||
# 从进程读取输出,直到找到标记
|
||||
try:
|
||||
async with asyncio.timeout(self._timeout):
|
||||
while True:
|
||||
await asyncio.sleep(self._output_delay)
|
||||
# 如果我们直接从 stdout/stderr 读取,它将永远等待 EOF。
|
||||
# 改为直接使用 StreamReader 缓冲区。
|
||||
output = (
|
||||
self._process.stdout._buffer.decode()
|
||||
) # pyright: ignore[reportAttributeAccessIssue]
|
||||
if self._sentinel in output:
|
||||
# 去除标记并中断
|
||||
output = output[: output.index(self._sentinel)]
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
self._timed_out = True
|
||||
raise ToolError(
|
||||
f"timed out: bash has not returned in {self._timeout} seconds and must be restarted",
|
||||
) from None
|
||||
|
||||
if output.endswith("\n"):
|
||||
output = output[:-1]
|
||||
|
||||
error = (
|
||||
self._process.stderr._buffer.decode()
|
||||
) # pyright: ignore[reportAttributeAccessIssue]
|
||||
if error.endswith("\n"):
|
||||
error = error[:-1]
|
||||
|
||||
# 清除缓冲区,以便可以正确读取下一个输出
|
||||
self._process.stdout._buffer.clear() # pyright: ignore[reportAttributeAccessIssue]
|
||||
self._process.stderr._buffer.clear() # pyright: ignore[reportAttributeAccessIssue]
|
||||
|
||||
return CLIResult(output=output, error=error)
|
||||
|
||||
|
||||
class Bash(BaseTool):
|
||||
"""用于执行 bash 命令的工具"""
|
||||
|
||||
name: str = "bash"
|
||||
description: str = _BASH_DESCRIPTION
|
||||
parameters: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"command": {
|
||||
"type": "string",
|
||||
"description": "要执行的 bash 命令。当先前的退出代码为 `-1` 时可以为空以查看其他日志。可以是 `ctrl+c` 来中断当前正在运行的进程。",
|
||||
},
|
||||
},
|
||||
"required": ["command"],
|
||||
}
|
||||
|
||||
_session: Optional[_BashSession] = None
|
||||
|
||||
async def execute(
|
||||
self, command: str | None = None, restart: bool = False, **kwargs
|
||||
) -> CLIResult:
|
||||
if restart:
|
||||
if self._session:
|
||||
self._session.stop()
|
||||
self._session = _BashSession()
|
||||
await self._session.start()
|
||||
|
||||
return CLIResult(system="tool has been restarted.")
|
||||
|
||||
if self._session is None:
|
||||
self._session = _BashSession()
|
||||
await self._session.start()
|
||||
|
||||
if command is not None:
|
||||
return await self._session.run(command)
|
||||
|
||||
raise ToolError("no command provided.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
bash = Bash()
|
||||
rst = asyncio.run(bash.execute("ls -l"))
|
||||
print(rst)
|
||||
56
app/tools/execute_sql.py
Normal file
56
app/tools/execute_sql.py
Normal file
@ -0,0 +1,56 @@
|
||||
from datetime import datetime, date
|
||||
from decimal import Decimal
|
||||
from typing import Dict, Any, List
|
||||
import pymysql
|
||||
from langchain_core.tools import tool
|
||||
from app.agent import logger
|
||||
|
||||
@tool
|
||||
def execute_query(sql: str) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
执行SQL查询并返回结果。
|
||||
|
||||
注意事项:
|
||||
- 确保SQL语法正确,常见错误包括:counts() 应为 count(),表名拼写错误等
|
||||
- 如果收到错误响应,请根据错误信息修正SQL后重试
|
||||
- 返回格式:成功时返回查询结果列表,失败时返回包含error和message的字典
|
||||
|
||||
参数:
|
||||
sql: 要执行的SQL查询语句
|
||||
|
||||
返回:
|
||||
查询结果列表(成功)或错误信息字典(失败)
|
||||
"""
|
||||
dbconfig = {
|
||||
"host": "192.168.10.91",
|
||||
"port": 33062,
|
||||
"user": "root",
|
||||
"password": "123456",
|
||||
"db": "text2sql",
|
||||
"charset": "utf8mb4",
|
||||
}
|
||||
connection = None
|
||||
result = []
|
||||
try:
|
||||
connection = pymysql.connect(**dbconfig)
|
||||
with connection.cursor(pymysql.cursors.DictCursor) as cursor:
|
||||
cursor.execute(sql)
|
||||
result = cursor.fetchall()
|
||||
|
||||
# 类型转换
|
||||
for row in result:
|
||||
for key, value in row.items():
|
||||
if isinstance(value, (datetime, date)):
|
||||
row[key] = value.strftime("%Y-%m-%d %H:%M:%S") if isinstance(value, datetime) else value.strftime("%Y-%m-%d")
|
||||
elif isinstance(value, Decimal):
|
||||
row[key] = float(value)
|
||||
return result
|
||||
except Exception as e:
|
||||
error_msg = f"SQL执行失败: {str(e)}\n请检查SQL语法是否正确,并根据错误信息修正SQL。"
|
||||
logger.error(error_msg)
|
||||
# 返回包含错误信息的结果,而不是抛出异常
|
||||
# 使用字典格式以便模型理解错误
|
||||
return [{"error": True, "message": error_msg}]
|
||||
finally:
|
||||
if connection:
|
||||
connection.close()
|
||||
25
app/tools/terminate.py
Normal file
25
app/tools/terminate.py
Normal file
@ -0,0 +1,25 @@
|
||||
from app.tools.base import BaseTool
|
||||
|
||||
|
||||
_TERMINATE_DESCRIPTION = """当请求已满足或助手无法继续执行任务时终止交互。
|
||||
当你完成所有任务后,调用此工具来结束工作。"""
|
||||
|
||||
|
||||
class Terminate(BaseTool):
|
||||
name: str = "terminate"
|
||||
description: str = _TERMINATE_DESCRIPTION
|
||||
parameters: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "string",
|
||||
"description": "交互的完成状态。",
|
||||
"enum": ["success", "failure"],
|
||||
}
|
||||
},
|
||||
"required": ["status"],
|
||||
}
|
||||
|
||||
async def execute(self, status: str) -> str:
|
||||
"""完成当前执行"""
|
||||
return f"The interaction has been completed with status: {status}"
|
||||
158
app/tools/weather.py
Normal file
158
app/tools/weather.py
Normal file
@ -0,0 +1,158 @@
|
||||
import os
|
||||
import requests
|
||||
from typing import Dict, Union
|
||||
|
||||
from app.tools.base import BaseTool, ToolResult
|
||||
|
||||
|
||||
_WEATHER_DESCRIPTION = """
|
||||
调用中国气象局天气预报 API,获取指定省份和地点的完整当日天气信息(含基础预报、实时数据、气象预警)。
|
||||
|
||||
### 参数说明
|
||||
- user_id: str - 接口调用身份标识,需从 http://www.apihz.cn 注册获取,不可为空
|
||||
- user_key: str - 接口通讯秘钥,与 user_id 对应,注册后获取,不可为空
|
||||
- province: str - 查询省份/直辖市,建议去除"省""市"后缀
|
||||
- place: str - 查询城市/区/县,建议去除"市""区""县"后缀
|
||||
|
||||
### 返回说明
|
||||
成功时返回包含基础预报、实时天气和气象预警的完整信息;失败时返回错误码和错误消息。
|
||||
"""
|
||||
|
||||
|
||||
class GetWeatherByLocation(BaseTool):
|
||||
"""天气预报查询工具"""
|
||||
|
||||
name: str = "get_weather_by_location"
|
||||
description: str = _WEATHER_DESCRIPTION
|
||||
parameters: dict = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"user_id": {
|
||||
"type": "string",
|
||||
"description": "接口调用身份标识,需从 http://www.apihz.cn 注册获取",
|
||||
},
|
||||
"user_key": {
|
||||
"type": "string",
|
||||
"description": "接口通讯秘钥,与 user_id 对应",
|
||||
},
|
||||
"province": {
|
||||
"type": "string",
|
||||
"description": "查询省份/直辖市,建议去除'省''市'后缀",
|
||||
},
|
||||
"place": {
|
||||
"type": "string",
|
||||
"description": "查询城市/区/县,建议去除'市''区''县'后缀",
|
||||
},
|
||||
},
|
||||
"required": ["user_id", "user_key", "province", "place"],
|
||||
}
|
||||
|
||||
async def execute(
|
||||
self, user_id: str, user_key: str, province: str, place: str
|
||||
) -> ToolResult:
|
||||
"""执行天气查询"""
|
||||
api_url = "https://cn.apihz.cn/api/tianqi/tqyb.php"
|
||||
clean_province = province.replace("省", "").replace("市", "").strip()
|
||||
clean_place = place.replace("市", "").replace("区", "").replace("县", "").strip()
|
||||
final_user_id = os.getenv("APIHZ_ID", user_id)
|
||||
final_user_key = os.getenv("APIHZ_KEY", user_key)
|
||||
params = {
|
||||
"id": final_user_id,
|
||||
"key": final_user_key,
|
||||
"sheng": clean_province,
|
||||
"place": clean_place,
|
||||
}
|
||||
|
||||
try:
|
||||
# 发送请求
|
||||
response = requests.get(api_url, params=params, timeout=10)
|
||||
response.raise_for_status()
|
||||
|
||||
# 解析响应
|
||||
response_text = response.text.strip()
|
||||
if not response_text:
|
||||
return self.fail_response("API 响应为空,无法解析天气数据")
|
||||
|
||||
weather_data = response.json()
|
||||
if not isinstance(weather_data, dict):
|
||||
return self.fail_response(
|
||||
f"API 响应格式错误,不是有效字典(实际类型:{type(weather_data).__name__})"
|
||||
)
|
||||
|
||||
# 处理 API 返回的错误状态
|
||||
if weather_data.get("code") == 400:
|
||||
return self.fail_response(
|
||||
weather_data.get("msg", "API 返回错误,原因未知")
|
||||
)
|
||||
elif weather_data.get("code") != 200:
|
||||
return self.fail_response(
|
||||
f"API 返回非成功状态码:{weather_data.get('code', '未知')}"
|
||||
)
|
||||
|
||||
# 提取嵌套数据
|
||||
now_info = weather_data.get("nowinfo", {})
|
||||
if not isinstance(now_info, dict):
|
||||
now_info = {}
|
||||
|
||||
alarm_info = weather_data.get("alarm", {})
|
||||
if not isinstance(alarm_info, dict):
|
||||
alarm_info = {}
|
||||
|
||||
# 构造最终返回结果
|
||||
result = {
|
||||
# 基础预报字段
|
||||
"code": 200,
|
||||
"guo": weather_data.get("guo", ""),
|
||||
"sheng": weather_data.get("sheng", ""),
|
||||
"shi": weather_data.get("shi", ""),
|
||||
"name": weather_data.get("name", ""),
|
||||
"weather1": weather_data.get("weather1", ""),
|
||||
"weather2": weather_data.get("weather2", ""),
|
||||
"wd1": weather_data.get("wd1", ""),
|
||||
"wd2": weather_data.get("wd2", ""),
|
||||
"winddirection1": weather_data.get("winddirection1", ""),
|
||||
"winddirection2": weather_data.get("winddirection2", ""),
|
||||
"windleve1": weather_data.get("windleve1", ""),
|
||||
"windleve2": weather_data.get("windleve2", ""),
|
||||
"weather1img": weather_data.get("weather1img", ""),
|
||||
"weather2img": weather_data.get("weather2img", ""),
|
||||
"lon": weather_data.get("lon", ""),
|
||||
"lat": weather_data.get("lat", ""),
|
||||
"uptime": weather_data.get("uptime", ""),
|
||||
# 实时天气字段
|
||||
"now_precipitation": float(now_info.get("precipitation", 0.0)),
|
||||
"now_temperature": float(now_info.get("temperature", 0.0)),
|
||||
"now_pressure": int(now_info.get("pressure", 0)),
|
||||
"now_humidity": int(now_info.get("humidity", 0)),
|
||||
"now_windDirection": now_info.get("windDirection", ""),
|
||||
"now_windDirectionDegree": int(now_info.get("windDirectionDegree", 0)),
|
||||
"now_windSpeed": float(now_info.get("windSpeed", 0.0)),
|
||||
"now_windScale": now_info.get("windScale", ""),
|
||||
"now_feelst": float(now_info.get("feelst", 0.0)),
|
||||
"now_uptime": now_info.get("uptime", ""),
|
||||
# 预警字段
|
||||
"alarm_id": alarm_info.get("id", ""),
|
||||
"alarm_title": alarm_info.get("title", ""),
|
||||
"alarm_signaltype": alarm_info.get("signaltype", ""),
|
||||
"alarm_signallevel": alarm_info.get("signallevel", ""),
|
||||
"alarm_effective": alarm_info.get("effective", ""),
|
||||
"alarm_eventType": alarm_info.get("eventType", ""),
|
||||
"alarm_severity": alarm_info.get("severity", ""),
|
||||
"alarm_type": alarm_info.get("type", ""),
|
||||
}
|
||||
|
||||
return self.success_response(result)
|
||||
|
||||
except requests.exceptions.HTTPError as e:
|
||||
status_code = (
|
||||
response.status_code if "response" in locals() else "未知"
|
||||
)
|
||||
return self.fail_response(f"HTTP 请求错误(状态码:{status_code}):{str(e)}")
|
||||
except requests.exceptions.ConnectionError:
|
||||
return self.fail_response("网络连接错误:无法连接到天气 API 服务器")
|
||||
except requests.exceptions.Timeout:
|
||||
return self.fail_response("请求超时:API 服务器 10 秒内未响应")
|
||||
except ValueError as e:
|
||||
return self.fail_response(f"API 响应解析失败(JSON 格式错误):{str(e)}")
|
||||
except Exception as e:
|
||||
return self.fail_response(f"未知错误:{str(e)}(错误类型:{type(e).__name__})")
|
||||
0
app/utils/__init__.py
Normal file
0
app/utils/__init__.py
Normal file
89
app/utils/logger.py
Normal file
89
app/utils/logger.py
Normal file
@ -0,0 +1,89 @@
|
||||
import logging
|
||||
import sys
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class InterceptHandler(logging.Handler):
|
||||
def emit(self, record):
|
||||
# 获取对应的 Loguru 级别
|
||||
try:
|
||||
level = logger.level(record.levelname).name
|
||||
except ValueError:
|
||||
level = record.levelno
|
||||
|
||||
# 从 logging 记录中找到调用位置
|
||||
frame, depth = logging.currentframe(), 2
|
||||
while frame.f_code.co_filename == logging.__file__:
|
||||
frame = frame.f_back
|
||||
depth += 1
|
||||
|
||||
logger.opt(depth=depth, exception=record.exc_info).log(
|
||||
level, record.getMessage()
|
||||
)
|
||||
|
||||
|
||||
def setup_logger():
|
||||
"""
|
||||
配置 Loguru 日志记录器
|
||||
- 移除默认的 handler
|
||||
- 添加一个用于 INFO 级别日志的文件 sink
|
||||
- 添加一个用于 ERROR 级别日志的文件 sink
|
||||
- 保持控制台输出(级别为 DEBUG)
|
||||
"""
|
||||
# 检查是否已经初始化,防止重复初始化
|
||||
if hasattr(setup_logger, '_initialized') and setup_logger._initialized:
|
||||
return
|
||||
|
||||
# 1. 移除默认的控制台 handler
|
||||
logger.remove()
|
||||
|
||||
# 2. 添加一个新的控制台 handler,设置级别为 DEBUG
|
||||
logger.add(
|
||||
sys.stderr,
|
||||
level="INFO",
|
||||
format="<green>{time:YYYY-MM-DD HH:mm:ss.SSS}</green> | "
|
||||
"<level>{level: <8}</level> | "
|
||||
"<cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>",
|
||||
colorize=True, # 开启颜色
|
||||
)
|
||||
|
||||
# 3. 添加 INFO 日志文件 sink
|
||||
# - level="INFO": 只记录 INFO 及以上级别的日志
|
||||
# - filter: 确保 ERROR 级别的日志不会写入这个文件
|
||||
# - rotation: 每天午夜创建一个新文件
|
||||
# - retention: 只保留最近 14 天的日志
|
||||
# - enqueue=True: 异步写入,提升性能
|
||||
# - encoding="utf-8": 确保中文不会乱码
|
||||
logger.add(
|
||||
"logs/app.info.log",
|
||||
level="INFO",
|
||||
filter=lambda record: record["level"].name in ["INFO", "WARNING"],
|
||||
rotation="00:00", # 每天午夜轮转
|
||||
retention="14 days", # 保留 14 天
|
||||
enqueue=True, # 异步写入
|
||||
backtrace=True, # 记录堆栈信息
|
||||
diagnose=True, # 诊断信息
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# 4. 添加 ERROR 日志文件 sink
|
||||
# - level="ERROR": 只记录 ERROR 及以上级别的日志 (ERROR, CRITICAL)
|
||||
# - rotation: 文件大小超过 10 MB 时轮转
|
||||
# - retention: 只保留最近 30 天的日志
|
||||
logger.add(
|
||||
"logs/app.error.log",
|
||||
level="ERROR",
|
||||
rotation="10 MB", # 按文件大小轮转
|
||||
retention="30 days", # 保留 30 天
|
||||
enqueue=True, # 异步写入
|
||||
backtrace=True,
|
||||
diagnose=True,
|
||||
encoding="utf-8",
|
||||
)
|
||||
|
||||
# 拦截 logging,以便将所有日志都转发到 Loguru
|
||||
logging.basicConfig(handlers=[InterceptHandler()], level=0, force=True)
|
||||
|
||||
# 设置已初始化标志
|
||||
setup_logger._initialized = True
|
||||
logger.info("日志系统配置完成。")
|
||||
@ -2,7 +2,7 @@ import asyncio
|
||||
import logging
|
||||
import json
|
||||
import os
|
||||
from typing import List, Dict, Any, Generator
|
||||
from typing import List, Dict, Any
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
@ -215,7 +215,7 @@ class AutoToolChatSession:
|
||||
if __name__ == "__main__":
|
||||
config = {
|
||||
"sse_server_url": "http://127.0.0.1:8000/sse",
|
||||
"llm_model": os.getenv("LLM_MODEL", "qwen3-30b-a3b"),
|
||||
"llm_model": os.getenv("LLM_MODEL", "deepseek-v3"),
|
||||
"llm_api_key": os.getenv("LLM_API_KEY", "your-api-key"),
|
||||
"llm_base_url": os.getenv("LLM_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
|
||||
}
|
||||
|
||||
@ -5,6 +5,13 @@ description = "Add your description here"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.12"
|
||||
dependencies = [
|
||||
"dashscope>=1.25.12",
|
||||
"deepagents>=0.4.3",
|
||||
"langchain>=1.2.10",
|
||||
"langchain-community>=0.4.1",
|
||||
"langchain-openai>=0.3.0",
|
||||
"loguru>=0.7.3",
|
||||
"mcp[cli]>=1.13.1",
|
||||
"openai>=1.102.0",
|
||||
"pymysql>=1.1.2",
|
||||
]
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user