- Added an OpenAIClientModel class for calling OpenAI-compatible models
- Refactored the AutoToolChatSession class to support OpenAI-compatible models
- Added more log output to ease debugging and tracing of program flow
- Improved the tool-call and result-handling logic
- Switched environment-variable loading to the dotenv library
import asyncio
import logging
import json
import os
from typing import List, Dict, Any

from mcp import ClientSession
from mcp.client.sse import sse_client
from mcp.types import TextContent
from openai import OpenAI
from dotenv import load_dotenv, find_dotenv

# ---------------------- Step 1: enhanced logging configuration (richer context) ----------------------
# Include timestamp, module name, and log level in every record so each step is traceable.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",  # timestamp added
    datefmt="%Y-%m-%d %H:%M:%S"  # timestamp format
)
logger = logging.getLogger(__name__)

# Load environment variables from a .env file, if one exists.
_ = load_dotenv(find_dotenv())
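
# A .env file picked up by find_dotenv() might look like this (a sketch; the
# values mirror the defaults used in __main__ below, and the API key is a
# placeholder, not a real credential):
#
#   LLM_MODEL=qwen3-30b-a3b
#   LLM_API_KEY=your-api-key
#   LLM_BASE_URL=https://dashscope.aliyuncs.com/compatible-mode/v1
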
class OpenAIClientModel:
    """Wrapper for calling models through an OpenAI-compatible API (unchanged)."""

    def __init__(self, model_name: str, api_key: str, base_url: str):
        self.model_name = model_name
        self.api_key = api_key
        self.base_url = base_url
        self.client = OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
            timeout=300,
        )
        logger.info(f"Initialized OpenAI-compatible model: {model_name} (base URL: {base_url})")

    def get_response(self, messages: List[Dict[str, str]], stream: bool = False) -> Any:
        try:
            # Drop messages with empty content; some backends reject them.
            valid_messages = [m for m in messages if m.get("content")]
            # ---------------------- New log: print the full prompt sent to the model (pretty-printed) ----------------------
            logger.info(
                f"Calling model [{self.model_name}] with full prompt:\n"
                f"{json.dumps(valid_messages, ensure_ascii=False, indent=2)}"
            )

            # Provider-specific switch for Qwen3: disable the model's "thinking"
            # mode on non-streaming calls.
            extra_params = {}
            if not stream and "qwen3" in self.model_name.lower():
                extra_params["enable_thinking"] = False

            response = self.client.chat.completions.create(
                model=self.model_name,
                messages=valid_messages,
                stream=stream,
                temperature=0.7,
                timeout=300,
                extra_body=extra_params
            )
            logger.info(f"Model [{self.model_name}] call succeeded, response received")
            return response
        except Exception as e:
            error_msg = f"OpenAI-compatible model call failed: {str(e)}"
            logger.error(error_msg)
            raise RuntimeError(error_msg) from e
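
# A minimal usage sketch for OpenAIClientModel (values mirror the defaults used
# in __main__ below; the API key comes from the environment):
#
#   client = OpenAIClientModel(
#       model_name=os.getenv("LLM_MODEL", "qwen3-30b-a3b"),
#       api_key=os.getenv("LLM_API_KEY", "your-api-key"),
#       base_url=os.getenv("LLM_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1"),
#   )
#   reply = client.get_response([{"role": "user", "content": "Hello"}])
#   print(reply.choices[0].message.content)
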
class AutoToolChatSession:
    """Chat session in which the LLM can call MCP tools automatically."""

    def __init__(self, server_url: str,
                 llm_model: str,
                 llm_api_key: str,
                 llm_base_url: str):
        self.server_url = server_url
        self.session: ClientSession | None = None
        self.tools_description = ""
        self.llm_client = OpenAIClientModel(
            model_name=llm_model,
            api_key=llm_api_key,
            base_url=llm_base_url
        )

    async def _get_llm_response(self, messages: List[Dict[str, str]]) -> str:
        """Call the model (unchanged; logging was added in OpenAIClientModel)."""
        try:
            # The OpenAI SDK call is blocking, so run it in a worker thread
            # to avoid stalling the event loop.
            response = await asyncio.to_thread(
                self.llm_client.get_response,
                messages=messages,
                stream=False
            )
            llm_content = (response.choices[0].message.content or "").strip()
            logger.info(f"Raw model response:\n{llm_content}")  # keep the original response log
            return llm_content
        except Exception as e:
            error_msg = f"Failed to get model response: {str(e)}"
            logger.error(error_msg)
            return error_msg

    async def _initialize_tools(self):
        """Fetch the tool list (unchanged)."""
        assert self.session is not None, "Session not initialized"
        tools = await self.session.list_tools()
        tool_descriptions = []
        for tool in tools.tools:
            tool_info = (
                f"- Tool name: {tool.name}\n"
                f"  Description: {tool.description}\n"
                f"  Parameters: {json.dumps(tool.inputSchema.get('properties', {}), indent=2, ensure_ascii=False)}"
            )
            tool_descriptions.append(tool_info)
        self.tools_description = "\n".join(tool_descriptions)
        logger.info("Loaded MCP tool list:\n%s", self.tools_description)

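    # With a hypothetical weather tool registered on the server, one formatted
    # entry in tools_description would look roughly like this (the tool name
    # and parameter are illustrative, not from this repo):
    #
    #   - Tool name: get_weather
    #     Description: Query the current weather for a given city
    #     Parameters: {
    #       "city": {
    #         "type": "string"
    #       }
    #     }
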
    async def _process_tool_call(self, tool_call: Dict[str, Any]) -> str:
        """Execute a tool call (new: logs the tool's return value)."""
        assert self.session is not None, "Session not initialized"
        tool_name = tool_call["tool"]
        arguments = tool_call["arguments"]
        # ---------------------- New log: print the tool-call arguments ----------------------
        logger.info(
            f"Calling tool [{tool_name}] with arguments:\n"
            f"{json.dumps(arguments, ensure_ascii=False, indent=2)}"
        )
        try:
            response = await self.session.call_tool(tool_name, arguments)
            # Extract the tool result, preferring structured content
            if response.structuredContent is not None:
                tool_result = json.dumps(response.structuredContent, ensure_ascii=False, indent=2)
            elif response.content:
                tool_result = next((c.text for c in response.content if isinstance(c, TextContent)), "No text result")
            else:
                tool_result = "Tool call succeeded but returned no content"

            # ---------------------- New log: print the tool's full return value ----------------------
            logger.info(
                f"Tool [{tool_name}] call succeeded, result:\n"
                f"{tool_result}"
            )
            return tool_result
        except Exception as e:
            error_msg = f"Tool [{tool_name}] call failed: {str(e)}"
            logger.error(error_msg)  # new: detailed log on tool-call failure
            return error_msg
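
    # The expected shape of `tool_call`, matching the JSON contract stated in
    # the system prompt below ("get_weather" is an illustrative name):
    #
    #   {"tool": "get_weather", "arguments": {"city": "Beijing"}}
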
    async def start_chat(self):
        """Run the chat loop (new: extra step-by-step logging)."""
        logger.info(f"Connecting to MCP service, SSE address: {self.server_url}")
        async with sse_client(self.server_url) as (read_stream, write_stream):
            async with ClientSession(read_stream, write_stream) as session:
                self.session = session
                await session.initialize()
                logger.info("MCP session initialized, loading tool list")
                await self._initialize_tools()

                # System prompt (unchanged; it is logged when the model is called)
                system_prompt = (
                    "You are an intelligent assistant that can use the following tools to solve the user's problem:\n"
                    f"{self.tools_description}\n\n"
                    "Usage rules:\n"
                    "1. If a tool is needed, you must reply with JSON in this format:\n"
                    '   {"tool": "tool name", "arguments": {"parameter name": "value"}}\n'
                    "2. If no tool is needed, answer directly in natural language\n"
                    "3. After a tool returns its result, turn that result into a natural-language answer for the user"
                )
                messages = [{"role": "system", "content": system_prompt}]
                logger.info("Chat session started, type 'quit' to exit")

                while True:
                    try:
                        user_input = input("You: ").strip()
                        if user_input.lower() in ["quit", "exit"]:
                            logger.info("User requested exit, ending session")
                            break

                        # Append the user message
                        messages.append({"role": "user", "content": user_input})
                        logger.info(f"Received user input: {user_input}, calling the model")

                        # First model call (decides whether a tool is needed)
                        llm_response = await self._get_llm_response(messages)

                        # Try to parse a tool call
                        try:
                            tool_call = json.loads(llm_response)
                            # json.loads can return non-dict values, so check the type too
                            if isinstance(tool_call, dict) and "tool" in tool_call and "arguments" in tool_call:
                                # Call the tool (result logging was added in _process_tool_call)
                                tool_result = await self._process_tool_call(tool_call)

                                # Feed the tool result back to the LLM for a second call
                                messages.append({"role": "assistant", "content": llm_response})
                                messages.append({"role": "system", "content": f"Tool returned: {tool_result}"})
                                logger.info("Tool result appended to history, calling the model again to summarize")

                                # Second model call (summarizes the tool result)
                                final_response = await self._get_llm_response(messages)
                                print(f"Assistant: {final_response}")
                                messages.append({"role": "assistant", "content": final_response})
                                continue
                        except json.JSONDecodeError:
                            # Not a tool call; fall through to the plain reply below
                            logger.info("Model reply is not JSON, no tool needed, returning the natural-language response")

                        # Print the response directly
                        print(f"Assistant: {llm_response}")
                        messages.append({"role": "assistant", "content": llm_response})

                    except KeyboardInterrupt:
                        logger.info("User interrupted the session, exiting")
                        break
                    except Exception as e:
                        logger.error(f"Session error: {str(e)}", exc_info=True)  # new exc_info=True prints the traceback
                        continue

if __name__ == "__main__":
    config = {
        "sse_server_url": "http://127.0.0.1:8000/sse",
        "llm_model": os.getenv("LLM_MODEL", "qwen3-30b-a3b"),
        "llm_api_key": os.getenv("LLM_API_KEY", "your-api-key"),
        "llm_base_url": os.getenv("LLM_BASE_URL", "https://dashscope.aliyuncs.com/compatible-mode/v1")
    }
    # Alternative config for a local Ollama server:
    # config = {
    #     "sse_server_url": "http://127.0.0.1:8000/sse",
    #     "llm_model": os.getenv("LLM_MODEL", "qwen2.5:7b"),
    #     "llm_api_key": os.getenv("LLM_API_KEY", "your-api-key"),
    #     "llm_base_url": os.getenv("LLM_BASE_URL", "http://localhost:11434/v1")
    # }
    logger.info(
        f"Startup configuration:\n"
        f"MCP SSE address: {config['sse_server_url']}\n"
        f"LLM model: {config['llm_model']}"
    )

    chat_session = AutoToolChatSession(
        server_url=config["sse_server_url"],
        llm_model=config["llm_model"],
        llm_api_key=config["llm_api_key"],
        llm_base_url=config["llm_base_url"]
    )

    asyncio.run(chat_session.start_chat())
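
# A typical run (a sketch; assumes an MCP SSE server is already listening on
# http://127.0.0.1:8000/sse and the LLM_* environment variables are set; the
# script filename is illustrative):
#
#   $ python auto_tool_chat.py
#   You: ...
#   Assistant: ...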