import asyncio import logging import json from typing import List, Dict, Any import httpx # 用于添加httpx依赖用于调用Ollama API from mcp.client.sse import sse_client from mcp.client.session import ClientSession from mcp.types import TextContent, AnyUrl # 配置日志 logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) class AutoToolChatSession: def __init__(self, server_url: str, ollama_base_url: str = "http://localhost:11434", llm_model: str = "qwen2.5:7b", api_key: str = "ollama"): self.server_url = server_url self.ollama_base_url = ollama_base_url # Ollama本地服务地址 self.llm_model = llm_model # 模型名称 self.api_key = api_key # Ollama的API密钥(默认为'ollama') self.session: ClientSession | None = None self.tools_description = "" async def _get_llm_response(self, messages: List[Dict[str, str]]) -> str: """调用本地Ollama模型获取响应""" # 构建Ollama API请求URL url = f"{self.ollama_base_url}/api/chat" # 转换消息格式以适配Ollama API(Ollama使用"role"和"content"字段,格式兼容) payload = { "model": self.llm_model, "messages": messages, "stream": False, # 非流式响应 "temperature": 0.7 } # 发送请求 async with httpx.AsyncClient() as client: try: response = await client.post( url, json=payload, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {self.api_key}" # 传递API密钥 } ) response.raise_for_status() data = response.json() return data["message"]["content"] except httpx.RequestError as e: logging.error(f"Ollama API调用失败: {str(e)}") return f"无法连接到本地模型服务,请检查Ollama是否运行。错误: {str(e)}" async def _initialize_tools(self): """获取服务器工具列表并生成描述""" assert self.session is not None, "会话未初始化" tools = await self.session.list_tools() tool_descriptions = [] for tool in tools.tools: tool_info = ( f"- 工具名称: {tool.name}\n" f" 描述: {tool.description}\n" f" 参数: {json.dumps(tool.inputSchema.get('properties', {}), indent=2)}" ) tool_descriptions.append(tool_info) self.tools_description = "\n".join(tool_descriptions) logger.info("已加载工具列表:\n%s", self.tools_description) async def _process_tool_call(self, tool_call: Dict[str, Any]) -> str: """执行工具调用并返回结果""" assert self.session is not None, "会话未初始化" tool_name = tool_call["tool"] arguments = tool_call["arguments"] try: response = await self.session.call_tool(tool_name, arguments) # 提取结构化结果 if response.structuredContent is not None: return json.dumps(response.structuredContent, ensure_ascii=False) # 提取文本结果 for content in response.content: if isinstance(content, TextContent): return content.text return "工具调用成功,但未返回内容" except Exception as e: return f"工具调用失败: {str(e)}" async def start_chat(self): """启动聊天会话""" async with sse_client(self.server_url) as (read_stream, write_stream): async with ClientSession(read_stream, write_stream) as session: self.session = session await session.initialize() await self._initialize_tools() # 系统提示词:指导LLM如何选择工具 system_prompt = ( "你是一个智能助手,可以使用以下工具解决用户问题:\n" f"{self.tools_description}\n\n" "使用规则:\n" "1. 如果需要使用工具,必须返回JSON格式:\n" ' {"tool": "工具名称", "arguments": {"参数名": "值"}}\n' "2. 如果不需要工具,直接用自然语言回答\n" "3. 工具返回结果后,你需要将结果整理为自然语言回答用户" ) messages = [{"role": "system", "content": system_prompt}] logger.info("聊天会话已启动,输入 'quit' 退出") while True: try: user_input = input("你: ").strip() if user_input.lower() in ["quit", "exit"]: logger.info("退出会话...") break messages.append({"role": "user", "content": user_input}) llm_response = await self._get_llm_response(messages) logger.info(f"LLM原始响应: {llm_response}") # 尝试解析工具调用 try: tool_call = json.loads(llm_response) if "tool" in tool_call and "arguments" in tool_call: logger.info(f"执行工具调用: {tool_call['tool']}") tool_result = await self._process_tool_call(tool_call) # 将工具结果送回LLM整理 messages.append({"role": "assistant", "content": llm_response}) messages.append({"role": "system", "content": f"工具返回: {tool_result}"}) final_response = await self._get_llm_response(messages) print(f"助手: {final_response}") messages.append({"role": "assistant", "content": final_response}) continue except json.JSONDecodeError: # 不是工具调用,直接返回 pass # 直接回答 print(f"助手: {llm_response}") messages.append({"role": "assistant", "content": llm_response}) except KeyboardInterrupt: logger.info("\n退出会话...") break except Exception as e: logger.error(f"发生错误: {str(e)}") continue if __name__ == "__main__": # 配置参数(适配Ollama) SSE_SERVER_URL = "http://127.0.0.1:8000/sse" # MCP服务器地址 OLLAMA_BASE_URL = "http://localhost:11434" # Ollama默认本地地址 LLM_MODEL = "qwen2.5:7b" # 模型名称 OLLAMA_API_KEY = "ollama" # Ollama的API密钥 chat_session = AutoToolChatSession( server_url=SSE_SERVER_URL, ollama_base_url=OLLAMA_BASE_URL, llm_model=LLM_MODEL, api_key=OLLAMA_API_KEY ) asyncio.run(chat_session.start_chat())