172 lines
7.3 KiB
Python
172 lines
7.3 KiB
Python
import asyncio
|
||
import logging
|
||
import json
|
||
from typing import List, Dict, Any
|
||
import httpx # 用于添加httpx依赖用于调用Ollama API
|
||
from mcp.client.sse import sse_client
|
||
from mcp.client.session import ClientSession
|
||
from mcp.types import TextContent, AnyUrl
|
||
|
||
# 配置日志
|
||
logging.basicConfig(level=logging.INFO)
|
||
logger = logging.getLogger(__name__)
|
||
|
||
|
||
class AutoToolChatSession:
|
||
def __init__(self, server_url: str, ollama_base_url: str = "http://localhost:11434",
|
||
llm_model: str = "qwen2.5:7b", api_key: str = "ollama"):
|
||
self.server_url = server_url
|
||
self.ollama_base_url = ollama_base_url # Ollama本地服务地址
|
||
self.llm_model = llm_model # 模型名称
|
||
self.api_key = api_key # Ollama的API密钥(默认为'ollama')
|
||
self.session: ClientSession | None = None
|
||
self.tools_description = ""
|
||
|
||
async def _get_llm_response(self, messages: List[Dict[str, str]]) -> str:
|
||
"""调用本地Ollama模型获取响应"""
|
||
# 构建Ollama API请求URL
|
||
url = f"{self.ollama_base_url}/api/chat"
|
||
|
||
# 转换消息格式以适配Ollama API(Ollama使用"role"和"content"字段,格式兼容)
|
||
payload = {
|
||
"model": self.llm_model,
|
||
"messages": messages,
|
||
"stream": False, # 非流式响应
|
||
"temperature": 0.7
|
||
}
|
||
|
||
# 发送请求
|
||
async with httpx.AsyncClient() as client:
|
||
try:
|
||
response = await client.post(
|
||
url,
|
||
json=payload,
|
||
headers={
|
||
"Content-Type": "application/json",
|
||
"Authorization": f"Bearer {self.api_key}" # 传递API密钥
|
||
}
|
||
)
|
||
response.raise_for_status()
|
||
data = response.json()
|
||
return data["message"]["content"]
|
||
except httpx.RequestError as e:
|
||
logging.error(f"Ollama API调用失败: {str(e)}")
|
||
return f"无法连接到本地模型服务,请检查Ollama是否运行。错误: {str(e)}"
|
||
|
||
async def _initialize_tools(self):
|
||
"""获取服务器工具列表并生成描述"""
|
||
assert self.session is not None, "会话未初始化"
|
||
|
||
tools = await self.session.list_tools()
|
||
tool_descriptions = []
|
||
for tool in tools.tools:
|
||
tool_info = (
|
||
f"- 工具名称: {tool.name}\n"
|
||
f" 描述: {tool.description}\n"
|
||
f" 参数: {json.dumps(tool.inputSchema.get('properties', {}), indent=2)}"
|
||
)
|
||
tool_descriptions.append(tool_info)
|
||
|
||
self.tools_description = "\n".join(tool_descriptions)
|
||
logger.info("已加载工具列表:\n%s", self.tools_description)
|
||
|
||
async def _process_tool_call(self, tool_call: Dict[str, Any]) -> str:
|
||
"""执行工具调用并返回结果"""
|
||
assert self.session is not None, "会话未初始化"
|
||
|
||
tool_name = tool_call["tool"]
|
||
arguments = tool_call["arguments"]
|
||
|
||
try:
|
||
response = await self.session.call_tool(tool_name, arguments)
|
||
# 提取结构化结果
|
||
if response.structuredContent is not None:
|
||
return json.dumps(response.structuredContent, ensure_ascii=False)
|
||
# 提取文本结果
|
||
for content in response.content:
|
||
if isinstance(content, TextContent):
|
||
return content.text
|
||
return "工具调用成功,但未返回内容"
|
||
except Exception as e:
|
||
return f"工具调用失败: {str(e)}"
|
||
|
||
async def start_chat(self):
|
||
"""启动聊天会话"""
|
||
async with sse_client(self.server_url) as (read_stream, write_stream):
|
||
async with ClientSession(read_stream, write_stream) as session:
|
||
self.session = session
|
||
await session.initialize()
|
||
await self._initialize_tools()
|
||
|
||
# 系统提示词:指导LLM如何选择工具
|
||
system_prompt = (
|
||
"你是一个智能助手,可以使用以下工具解决用户问题:\n"
|
||
f"{self.tools_description}\n\n"
|
||
"使用规则:\n"
|
||
"1. 如果需要使用工具,必须返回JSON格式:\n"
|
||
' {"tool": "工具名称", "arguments": {"参数名": "值"}}\n'
|
||
"2. 如果不需要工具,直接用自然语言回答\n"
|
||
"3. 工具返回结果后,你需要将结果整理为自然语言回答用户"
|
||
)
|
||
|
||
messages = [{"role": "system", "content": system_prompt}]
|
||
logger.info("聊天会话已启动,输入 'quit' 退出")
|
||
|
||
while True:
|
||
try:
|
||
user_input = input("你: ").strip()
|
||
if user_input.lower() in ["quit", "exit"]:
|
||
logger.info("退出会话...")
|
||
break
|
||
|
||
messages.append({"role": "user", "content": user_input})
|
||
llm_response = await self._get_llm_response(messages)
|
||
logger.info(f"LLM原始响应: {llm_response}")
|
||
|
||
# 尝试解析工具调用
|
||
try:
|
||
tool_call = json.loads(llm_response)
|
||
if "tool" in tool_call and "arguments" in tool_call:
|
||
logger.info(f"执行工具调用: {tool_call['tool']}")
|
||
tool_result = await self._process_tool_call(tool_call)
|
||
|
||
# 将工具结果送回LLM整理
|
||
messages.append({"role": "assistant", "content": llm_response})
|
||
messages.append({"role": "system", "content": f"工具返回: {tool_result}"})
|
||
|
||
final_response = await self._get_llm_response(messages)
|
||
print(f"助手: {final_response}")
|
||
messages.append({"role": "assistant", "content": final_response})
|
||
continue
|
||
except json.JSONDecodeError:
|
||
# 不是工具调用,直接返回
|
||
pass
|
||
|
||
# 直接回答
|
||
print(f"助手: {llm_response}")
|
||
messages.append({"role": "assistant", "content": llm_response})
|
||
|
||
except KeyboardInterrupt:
|
||
logger.info("\n退出会话...")
|
||
break
|
||
except Exception as e:
|
||
logger.error(f"发生错误: {str(e)}")
|
||
continue
|
||
|
||
|
||
if __name__ == "__main__":
|
||
# 配置参数(适配Ollama)
|
||
SSE_SERVER_URL = "http://127.0.0.1:8000/sse" # MCP服务器地址
|
||
OLLAMA_BASE_URL = "http://localhost:11434" # Ollama默认本地地址
|
||
LLM_MODEL = "qwen2.5:7b" # 模型名称
|
||
OLLAMA_API_KEY = "ollama" # Ollama的API密钥
|
||
|
||
chat_session = AutoToolChatSession(
|
||
server_url=SSE_SERVER_URL,
|
||
ollama_base_url=OLLAMA_BASE_URL,
|
||
llm_model=LLM_MODEL,
|
||
api_key=OLLAMA_API_KEY
|
||
)
|
||
asyncio.run(chat_session.start_chat())
|
||
|