using System.Globalization;
using System.Text.Json;
using NewLife.Log;
using NewLife.Studio.AI.Safety;
namespace NewLife.Studio.AI.ToolCalling.BuiltInTools;
/// <summary>内置数据库工具的注册</summary>
public static class BuiltInDatabaseTools
{
    /// <summary>Row count used by <c>query.sample</c> when <c>limit</c> is absent or unusable.</summary>
    private const int DefaultSampleLimit = 5;

    /// <summary>Identifier passed to the table-listing callback for the currently opened database.</summary>
    private const string CurrentDatabase = "current";

    /// <summary>Registers all built-in database tools with the <see cref="ToolRegistry"/>.</summary>
    /// <remarks>
    /// This method only declares the tool names, descriptions and JSON schemas, and extracts the
    /// arguments from each call's JSON payload; the actual work is done by the supplied callbacks.
    /// </remarks>
    /// <param name="registry">Registry that receives the tool definitions.</param>
    /// <param name="listConnections">Backs <c>connections.list</c>; takes no arguments.</param>
    /// <param name="openDatabase">Backs <c>db.open</c>; receives the <c>connection_name</c> argument.</param>
    /// <param name="listTables">Backs <c>schema.tables</c>; always receives the literal <c>"current"</c>.</param>
    /// <param name="describeTable">Backs <c>schema.table</c>; receives the <c>table_name</c> argument.</param>
    /// <param name="executeSelect">
    /// Backs <c>query.select</c>; only invoked with SQL that <c>QuerySafetyFilter.Validate</c> accepted.
    /// </param>
    /// <param name="sampleData">
    /// Backs <c>query.sample</c>; receives <c>table_name</c> and a positive row limit (default 5).
    /// </param>
    /// <exception cref="ArgumentNullException">Any argument is null.</exception>
    public static void RegisterAll(ToolRegistry registry,
        Func<Task<string>> listConnections,
        Func<string, Task<string>> openDatabase,
        Func<string, Task<string>> listTables,
        Func<string, Task<string>> describeTable,
        Func<string, Task<string>> executeSelect,
        Func<string, int, Task<string>> sampleData)
    {
        // Fail fast at registration time instead of with a NullReferenceException on first tool call.
        if (registry is null) throw new ArgumentNullException(nameof(registry));
        if (listConnections is null) throw new ArgumentNullException(nameof(listConnections));
        if (openDatabase is null) throw new ArgumentNullException(nameof(openDatabase));
        if (listTables is null) throw new ArgumentNullException(nameof(listTables));
        if (describeTable is null) throw new ArgumentNullException(nameof(describeTable));
        if (executeSelect is null) throw new ArgumentNullException(nameof(executeSelect));
        if (sampleData is null) throw new ArgumentNullException(nameof(sampleData));

        // connections.list
        registry.Register(
            "connections.list",
            "列出所有已保存的数据库连接(脱敏显示)",
            new { type = "object", properties = new { } },
            async _ => await listConnections()
        );

        // db.open
        registry.Register(
            "db.open",
            "打开指定数据库连接,传入连接名称,返回会话 ID",
            new
            {
                type = "object",
                properties = new
                {
                    connection_name = new { type = "string", description = "连接名称" }
                },
                required = new[] { "connection_name" }
            },
            async args => await openDatabase(ReadString(args, "connection_name"))
        );

        // schema.tables
        registry.Register(
            "schema.tables",
            "列出当前已打开数据库的所有表",
            new { type = "object", properties = new { } },
            async _ => await listTables(CurrentDatabase)
        );

        // schema.table
        registry.Register(
            "schema.table",
            "查看指定表的列详情",
            new
            {
                type = "object",
                properties = new
                {
                    table_name = new { type = "string", description = "表名" }
                },
                required = new[] { "table_name" }
            },
            async args => await describeTable(ReadString(args, "table_name"))
        );

        // query.select
        registry.Register(
            "query.select",
            "执行 SELECT 查询(经安全过滤,仅允许只读)",
            new
            {
                type = "object",
                properties = new
                {
                    sql = new { type = "string", description = "SELECT 语句" }
                },
                required = new[] { "sql" }
            },
            async args =>
            {
                var sql = ReadString(args, "sql");

                // Rejected SQL is reported back to the caller as a message; it never reaches executeSelect.
                var (isSafe, reason) = QuerySafetyFilter.Validate(sql);
                if (!isSafe)
                {
                    return $"查询被拒绝: {reason}";
                }

                return await executeSelect(sql);
            }
        );

        // query.sample
        registry.Register(
            "query.sample",
            "取表的前 N 行样本数据",
            new
            {
                type = "object",
                properties = new
                {
                    table_name = new { type = "string", description = "表名" },
                    limit = new { type = "integer", description = "行数限制,默认 5" }
                },
                required = new[] { "table_name" }
            },
            async args =>
            {
                string table;
                int limit;

                // Copy the values out and release the pooled JSON buffers before awaiting.
                using (var doc = JsonDocument.Parse(args))
                {
                    var root = doc.RootElement;
                    table = root.GetProperty("table_name").GetString() ?? "";
                    limit = ReadSampleLimit(root);
                }

                return await sampleData(table, limit);
            }
        );
    }

    /// <summary>
    /// Parses a tool-call JSON payload and returns the named string property,
    /// or <c>""</c> when the property is JSON <c>null</c>.
    /// </summary>
    /// <param name="args">Raw JSON arguments of the tool call.</param>
    /// <param name="propertyName">Name of the required string property.</param>
    /// <exception cref="JsonException"><paramref name="args"/> is not valid JSON.</exception>
    /// <exception cref="System.Collections.Generic.KeyNotFoundException">The property is missing.</exception>
    /// <exception cref="InvalidOperationException">The root is not an object or the property is not a string.</exception>
    private static string ReadString(string args, string propertyName)
    {
        // JsonDocument rents pooled memory; dispose it once the string has been copied out.
        using var doc = JsonDocument.Parse(args);
        return doc.RootElement.GetProperty(propertyName).GetString() ?? "";
    }

    /// <summary>Reads the optional <c>limit</c> argument of <c>query.sample</c>.</summary>
    /// <remarks>
    /// Accepts an integer JSON number or a string containing an integer, since models sometimes
    /// quote numbers. A missing, null, non-integer, out-of-range or non-positive value falls back
    /// to <see cref="DefaultSampleLimit"/> instead of throwing.
    /// </remarks>
    /// <param name="root">Root element of the tool-call arguments.</param>
    /// <returns>A positive row limit.</returns>
    private static int ReadSampleLimit(JsonElement root)
    {
        if (!root.TryGetProperty("limit", out var element))
        {
            return DefaultSampleLimit;
        }

        var limit = element.ValueKind switch
        {
            JsonValueKind.Number when element.TryGetInt32(out var number) => number,
            JsonValueKind.String when int.TryParse(element.GetString(), NumberStyles.Integer, CultureInfo.InvariantCulture, out var parsed) => parsed,
            _ => DefaultSampleLimit,
        };

        return limit > 0 ? limit : DefaultSampleLimit;
    }
}
|