- 实现了基于LangChain的MCP Agent,支持连接MCP服务器调用工具 - 添加了环境配置文件(.env),包含LLM模型和API配置信息 - 创建了完整的工具系统,包括BaseTool基类和Bash、Terminate、Add等工具 - 集成了天气查询工具,支持通过中国气象局API获取天气预报信息 - 实现了交互式对话功能,支持多轮工具调用和结果处理 - 添加了详细的CLAUDE.md开发指导文档
404 lines
14 KiB
Python
"""
使用 LangChain create_agent 创建 MCP Agent 客户端

本模块通过 LangChain 的 create_agent 函数创建智能体,
并连接 MCP 服务器调用其注册的工具(bash, terminate, add, weather 等)。
"""
|
||
import asyncio
|
||
import os
|
||
import json
|
||
import logging
|
||
from typing import Any, Dict, List, Optional
|
||
|
||
from dotenv import load_dotenv, find_dotenv
|
||
|
||
from langchain.agents import create_agent
|
||
from langchain_core.tools import BaseTool, ToolException
|
||
from langchain_openai import ChatOpenAI
|
||
|
||
from mcp import ClientSession
|
||
from mcp.client.sse import sse_client
|
||
from mcp.types import TextContent
|
||
from pydantic.v1 import BaseModel, Field
|
||
|
||
# Module-wide logging configuration.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
_LOG_DATEFMT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(level=logging.INFO, format=_LOG_FORMAT, datefmt=_LOG_DATEFMT)
logger = logging.getLogger(__name__)
|
||
|
||
# Load environment variables from the nearest .env file (provides LLM_MODEL,
# LLM_API_KEY, LLM_BASE_URL read later); the boolean result is intentionally ignored.
_ = load_dotenv(find_dotenv())
|
||
|
||
|
||
|
||
from typing import Type
|
||
|
||
class MCPTool(BaseTool):
    """Adapter exposing a single MCP server tool as a LangChain ``BaseTool``.

    Calls are forwarded through the shared MCP ``ClientSession`` and the
    response is flattened into a string for the agent to consume.
    """

    # Declared as pydantic fields so BaseTool's validation accepts them.
    name: str = ""
    description: str = ""
    session: Optional[ClientSession] = None
    args_schema: Optional[Type[BaseModel]] = None

    def __init__(
        self,
        name: str,
        description: str,
        session: ClientSession,
        args_schema: Optional[Type[BaseModel]] = None,
        **kwargs: Any
    ):
        """Store tool metadata and the MCP session via BaseTool's initializer."""
        super().__init__(
            name=name,
            description=description,
            session=session,
            args_schema=args_schema,
            **kwargs
        )

    def _run(self, **kwargs: Any) -> Any:
        """Synchronous entry point required by BaseTool.

        NOTE(review): ``asyncio.run`` raises RuntimeError when called from a
        running event loop — this path is only safe from synchronous callers;
        async agents use ``_arun`` instead.
        """
        return asyncio.run(self._async_run(**kwargs))

    async def _arun(self, **kwargs: Any) -> Any:
        """Asynchronous entry point used by async agents."""
        return await self._async_run(**kwargs)

    async def _async_run(self, **kwargs: Any) -> Any:
        """Forward the call to the MCP server and normalize the result.

        Returns:
            The JSON-encoded ``structuredContent`` when present, otherwise the
            concatenated text blocks, otherwise a fixed success message.

        Raises:
            ToolException: if the session is missing or the remote call fails.
        """
        if self.session is None:
            raise ToolException(f"MCP session is not initialized for tool: {self.name}")

        try:
            logger.info(f"Calling MCP tool: {self.name} with args: {kwargs}")
            response = await self.session.call_tool(self.name, kwargs)

            # Prefer structured output; fall back to plain text content blocks.
            if response.structuredContent is not None:
                result = json.dumps(response.structuredContent, ensure_ascii=False)
            elif response.content:
                result = "\n".join(
                    c.text for c in response.content if isinstance(c, TextContent)
                )
            else:
                result = "工具执行成功,无返回内容"

            # Truncate long results in the log only; return the full string.
            logger.info(f"MCP tool {self.name} returned: {result[:500]}")
            return result

        except Exception as e:
            error_msg = f"工具 {self.name} 执行失败:{str(e)}"
            logger.error(error_msg)
            # Chain the original exception so the root cause stays visible.
            raise ToolException(error_msg) from e
|
||
|
||
|
||
class MCPLangChainAgent:
    """LangChain-based wrapper around an MCP server connection.

    Lifecycle: ``connect()`` (or ``async with``) opens the SSE transport,
    initializes the MCP session and converts the server's tools into
    LangChain tools; ``create_agent()`` builds the agent; ``invoke()`` runs
    it; ``disconnect()`` tears everything down.
    """

    def __init__(
        self,
        sse_server_url: str = "http://127.0.0.1:8000/sse",
        llm: Optional[Any] = None,
    ):
        """
        Initialize the MCP agent wrapper.

        Args:
            sse_server_url: SSE endpoint of the MCP server.
            llm: A LangChain-compatible LLM instance.
        """
        self.sse_server_url = sse_server_url
        self.llm = llm
        self.session: Optional[ClientSession] = None
        self.tools: List[BaseTool] = []
        self.agent = None
        self._read_stream = None
        self._write_stream = None

    def _generate_args_schema(self, tool_name: str, input_schema: dict) -> Optional[Type[BaseModel]]:
        """Build a pydantic (v1) args-schema class from a JSON Schema dict.

        Returns None when the tool declares no properties, which lets
        LangChain treat the tool as argument-less.
        """
        if not input_schema or not input_schema.get("properties"):
            return None

        properties = input_schema.get("properties", {})
        required_fields = input_schema.get("required", [])

        # Collect (python_type, Field) pairs for each declared property.
        fields = {}
        for prop_name, prop_details in properties.items():
            field_type = self._map_json_type_to_python(prop_details.get("type", "string"))
            field_description = prop_details.get("description", "")

            # Required fields use the Ellipsis sentinel; optional ones default to None
            # (pydantic v1 treats a None default as implicitly Optional).
            if prop_name in required_fields:
                fields[prop_name] = (field_type, Field(..., description=field_description))
            else:
                fields[prop_name] = (field_type, Field(None, description=field_description))

        # Dynamically create the pydantic model class.
        schema_class = type(
            f"{tool_name.title()}Schema",
            (BaseModel,),
            {"__annotations__": {k: v[0] for k, v in fields.items()},
             **{k: v[1] for k, v in fields.items()}}
        )
        return schema_class

    def _map_json_type_to_python(self, json_type: str):
        """Map a JSON Schema type name to a Python type (unknown -> str)."""
        type_mapping = {
            "string": str,
            "integer": int,
            "number": float,
            "boolean": bool,
            "array": list,
            "object": dict,
        }
        return type_mapping.get(json_type, str)

    async def connect(self) -> None:
        """Connect to the MCP server and load its tools.

        Side effects: stores the SSE and session context managers on ``self``
        so that ``disconnect()`` can exit them later.
        """
        logger.info(f"Connecting to MCP server: {self.sse_server_url}")

        # Enter the SSE client context manually; __aexit__ happens in disconnect().
        self._sse_cm = sse_client(self.sse_server_url)
        self._read_stream, self._write_stream = await self._sse_cm.__aenter__()

        # Create and initialize the MCP session.
        self._session_cm = ClientSession(self._read_stream, self._write_stream)
        self.session = await self._session_cm.__aenter__()
        await self.session.initialize()

        # Fetch the tool list from the server.
        tools_response = await self.session.list_tools()
        logger.info(f"Retrieved {len(tools_response.tools)} tools from MCP server")

        # Convert each MCP tool into a LangChain tool.
        self.tools = []
        for tool in tools_response.tools:
            args_schema_class = self._generate_args_schema(tool.name, tool.inputSchema)

            mcp_tool = MCPTool(
                name=tool.name,
                description=tool.description or "No description provided",
                session=self.session,
                args_schema=args_schema_class
            )
            self.tools.append(mcp_tool)
            # Guard against a None description before slicing for the log
            # (the same None case is already handled two lines above).
            logger.info(f"Loaded tool: {tool.name} - {(tool.description or '')[:50]}...")

        logger.info(f"Successfully connected and loaded {len(self.tools)} tools")

    async def disconnect(self) -> None:
        """Exit the session and SSE contexts; safe to call more than once."""
        try:
            if self.session:
                await self._session_cm.__aexit__(None, None, None)
                self.session = None
            if hasattr(self, '_sse_cm') and self._sse_cm:
                await self._sse_cm.__aexit__(None, None, None)
            self._read_stream = None
            self._write_stream = None
            logger.info("Disconnected from MCP server")
        except Exception as e:
            # Best-effort teardown: log and continue rather than crash on exit.
            logger.warning(f"Error during disconnect: {e}")

    def create_agent(self, system_message: Optional[str] = None) -> Any:
        """
        Create the LangChain agent over the loaded tools.

        Args:
            system_message: Optional system prompt; a default is used if omitted.

        Returns:
            The agent instance (also stored on ``self.agent``).

        Raises:
            ValueError: if no LLM is set or no tools have been loaded.
        """
        if self.llm is None:
            raise ValueError(
                "LLM is not set. Please provide a LangChain compatible LLM instance."
            )

        if not self.tools:
            raise ValueError("No tools available. Please call connect() first.")

        # Default system prompt.
        system_prompt = system_message or (
            "你是一个智能助手,可以使用提供的工具来帮助用户解决问题。"
            "请根据用户的问题,合理选择并使用工具。"
            "如果不需要使用工具,直接回答用户的问题。"
        )

        # LangChain 1.x create_agent API (LangGraph-style agent).
        self.agent = create_agent(
            model=self.llm,
            tools=self.tools,
            system_prompt=system_prompt,
        )

        logger.info("Agent created successfully")
        return self.agent

    async def invoke(self, input_text: str) -> Dict[str, Any]:
        """
        Run the agent on a single user utterance.

        Args:
            input_text: The user's message.

        Returns:
            The agent's result state; for create_agent-built agents this is a
            dict containing a "messages" list.

        Raises:
            ValueError: if ``create_agent()`` has not been called yet.
        """
        if self.agent is None:
            raise ValueError("Agent is not created. Please call create_agent() first.")

        logger.info(f"Invoking agent with input: {input_text}")
        # create_agent produces a LangGraph-style agent: it consumes and emits
        # a "messages" list (the previous {"input": ...} payload did not match
        # the "messages" shape this file's result handling expects).
        result = await self.agent.ainvoke(
            {"messages": [{"role": "user", "content": input_text}]}
        )
        logger.info("Agent execution completed.")
        logger.info(f"Raw agent result type: {type(result)}, value: {result}")
        return result

    async def __aenter__(self):
        """Async context manager entry: connect and load tools."""
        await self.connect()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit: disconnect from the server."""
        await self.disconnect()
        return False  # do not suppress exceptions
|
||
|
||
|
||
def create_langchain_llm(
    model_name: Optional[str] = None,
    api_key: Optional[str] = None,
    base_url: Optional[str] = None,
    temperature: float = 0.7,
    **kwargs: Any
) -> ChatOpenAI:
    """Build a LangChain ``ChatOpenAI`` instance.

    Explicit arguments win; otherwise values fall back to the LLM_MODEL,
    LLM_API_KEY and LLM_BASE_URL environment variables, then to hard-coded
    defaults.

    Args:
        model_name: Model identifier.
        api_key: API key.
        base_url: Base URL of the OpenAI-compatible endpoint.
        temperature: Sampling temperature.
        **kwargs: Extra keyword arguments forwarded to ``ChatOpenAI``.

    Returns:
        A configured LangChain ``ChatOpenAI`` instance.
    """
    cfg: Dict[str, Any] = dict(
        model=model_name or os.getenv("LLM_MODEL", "deepseek-v3"),
        api_key=api_key or os.getenv("LLM_API_KEY", "your-api-key"),
        base_url=base_url or os.getenv("LLM_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
        temperature=temperature,
        timeout=300,
    )

    # Caller-supplied options override the defaults assembled above.
    cfg.update(kwargs)

    logger.info(f"Creating LLM with model: {cfg['model']}, base_url: {cfg['base_url']}")
    return ChatOpenAI(**cfg)
|
||
|
||
|
||
def _extract_agent_output(result: Any) -> str:
    """Normalize an agent invocation result into a display string.

    Handles the LangChain/LangGraph ``{"messages": [...]}`` shape, a handful
    of conventional output keys, and plain non-dict results; falls back to
    the literal "无响应内容" when nothing usable is found.
    """
    if not isinstance(result, dict):
        # Non-dict results are shown as-is.
        return str(result) if result is not None else "无响应内容"

    messages = result.get("messages")
    if isinstance(messages, list):
        if not messages:
            return "无响应内容"
        last_message = messages[-1]
        # AIMessage-like objects carry their text in .content.
        if hasattr(last_message, "content"):
            return str(last_message.content) if last_message.content is not None else "无响应内容"
        return str(last_message) if last_message is not None else "无响应内容"

    # Try a handful of conventional output keys.
    return (result.get("output") or
            result.get("result") or
            result.get("response") or
            result.get("content") or
            result.get("answer") or
            "无响应内容")


async def start_interactive_chat():
    """Run an interactive REPL that routes user input through the MCP agent.

    Connects to the local MCP server, builds the agent, then loops on
    ``input()`` until the user types 'quit'/'exit' or interrupts.
    """
    print("\n" + "=" * 60)
    print("MCP LangChain Agent - 交互式对话")
    print("=" * 60)
    print("可用工具:add, bash, terminate, get_weather_by_location")
    print("输入 'quit' 或 'exit' 退出")
    print("=" * 60 + "\n")

    # Build the LLM from environment configuration.
    llm = create_langchain_llm()

    # Hold the MCP connection open for the duration of the chat.
    async with MCPLangChainAgent(
        sse_server_url="http://127.0.0.1:8000/sse",
        llm=llm
    ) as agent_wrapper:

        # Build the agent with an explicit tool inventory in the prompt.
        # (The return value was previously bound to an unused local.)
        agent_wrapper.create_agent(
            system_message=(
                "你是一个智能助手,可以使用以下工具:\n"
                "- add(a, b): 计算两个整数的和\n"
                "- bash(command): 执行 bash 命令\n"
                "- terminate(status): 终止交互(status: 'success' 或 'failure')\n"
                "- get_weather_by_location(user_id, user_key, province, place): 获取天气预报\n"
                "请根据用户需求合理选择工具。"
            )
        )

        while True:
            try:
                user_input = input("你:").strip()

                if user_input.lower() in ["quit", "exit"]:
                    logger.info("用户退出")
                    break

                if not user_input:
                    continue

                print("\n助手:思考中...", end="\r")

                # Run the agent and render its reply.
                result = await agent_wrapper.invoke(user_input)
                logger.info(f"Processing result: {result} (type: {type(result)})")
                print(f"助手:{_extract_agent_output(result)}\n")

            except KeyboardInterrupt:
                logger.info("\n用户强制中断")
                break
            except Exception as e:
                # Keep the REPL alive on tool/LLM failures; log the traceback.
                logger.error(f"执行错误:{str(e)}", exc_info=True)
                print(f"错误:{str(e)}\n")
|
||
|
||
|
||
if __name__ == "__main__":
    # Script entry point: run the interactive chat loop until the user quits.
    try:
        asyncio.run(start_interactive_chat())
    except KeyboardInterrupt:
        # Graceful farewell on Ctrl-C raised outside the inner loop's handler.
        print("\n再见!")
|