from datetime import datetime, date from decimal import Decimal from typing import Dict, Any, List import pymysql from langchain_core.tools import tool from app.agent import logger @tool def execute_query(sql: str) -> List[Dict[str, Any]]: """ 执行SQL查询并返回结果。 注意事项: - 确保SQL语法正确,常见错误包括:counts() 应为 count(),表名拼写错误等 - 如果收到错误响应,请根据错误信息修正SQL后重试 - 返回格式:成功时返回查询结果列表,失败时返回包含error和message的字典 参数: sql: 要执行的SQL查询语句 返回: 查询结果列表(成功)或错误信息字典(失败) """ dbconfig = { "host": "192.168.10.91", "port": 33062, "user": "root", "password": "123456", "db": "text2sql", "charset": "utf8mb4", } connection = None result = [] try: connection = pymysql.connect(**dbconfig) with connection.cursor(pymysql.cursors.DictCursor) as cursor: cursor.execute(sql) result = cursor.fetchall() # 类型转换 for row in result: for key, value in row.items(): if isinstance(value, (datetime, date)): row[key] = value.strftime("%Y-%m-%d %H:%M:%S") if isinstance(value, datetime) else value.strftime("%Y-%m-%d") elif isinstance(value, Decimal): row[key] = float(value) return result except Exception as e: error_msg = f"SQL执行失败: {str(e)}\n请检查SQL语法是否正确,并根据错误信息修正SQL。" logger.error(error_msg) # 返回包含错误信息的结果,而不是抛出异常 # 使用字典格式以便模型理解错误 return [{"error": True, "message": error_msg}] finally: if connection: connection.close()