A Look at Spring AI's Tool Calling
Preface
This article takes a look at Tool Calling in Spring AI.
ToolCallback
org/springframework/ai/tool/ToolCallback.java
public interface ToolCallback extends FunctionCallback {

    /**
     * Definition used by the AI model to determine when and how to call the tool.
     */
    ToolDefinition getToolDefinition();

    /**
     * Metadata providing additional information on how to handle the tool.
     */
    default ToolMetadata getToolMetadata() {
        return ToolMetadata.builder().build();
    }

    /**
     * Execute tool with the given input and return the result to send back to the AI
     * model.
     */
    String call(String toolInput);

    /**
     * Execute tool with the given input and context, and return the result to send back
     * to the AI model.
     */
    default String call(String toolInput, @Nullable ToolContext toolContext) {
        if (toolContext != null && !toolContext.getContext().isEmpty()) {
            throw new UnsupportedOperationException("Tool context is not supported!");
        }
        return call(toolInput);
    }

    @Override
    @Deprecated // Call getToolDefinition().name() instead
    default String getName() {
        return getToolDefinition().name();
    }

    @Override
    @Deprecated // Call getToolDefinition().description() instead
    default String getDescription() {
        return getToolDefinition().description();
    }

    @Override
    @Deprecated // Call getToolDefinition().inputSchema() instead
    default String getInputTypeSchema() {
        return getToolDefinition().inputSchema();
    }

}
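ToolCallback extends the FunctionCallback interface (FunctionCallback itself is slated for deprecation). It mainly defines the getToolDefinition, getToolMetadata and call methods, and it has two basic implementations: MethodToolCallback and FunctionToolCallback. Note that the default two-argument call method throws UnsupportedOperationException when it receives a non-empty ToolContext, so an implementation that needs the context must override it.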
MethodToolCallback
org/springframework/ai/tool/method/MethodToolCallback.java
public class MethodToolCallback implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(MethodToolCallback.class);

    private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();

    private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;

    private final ToolMetadata toolMetadata;

    private final Method toolMethod;

    @Nullable
    private final Object toolObject;

    private final ToolCallResultConverter toolCallResultConverter;

    public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod,
            @Nullable Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        Assert.notNull(toolMethod, "toolMethod cannot be null");
        Assert.isTrue(Modifier.isStatic(toolMethod.getModifiers()) || toolObject != null,
                "toolObject cannot be null for non-static methods");
        this.toolDefinition = toolDefinition;
        this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
        this.toolMethod = toolMethod;
        this.toolObject = toolObject;
        this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
                : DEFAULT_RESULT_CONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
        return toolMetadata;
    }

    @Override
    public String call(String toolInput) {
        return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");

        logger.debug("Starting execution of tool: {}", toolDefinition.name());

        validateToolContextSupport(toolContext);

        Map<String, Object> toolArguments = extractToolArguments(toolInput);

        Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);

        Object result = callMethod(methodArguments);

        logger.debug("Successful execution of tool: {}", toolDefinition.name());

        Type returnType = toolMethod.getGenericReturnType();

        return toolCallResultConverter.convert(result, returnType);
    }

    @Nullable
    private Object callMethod(Object[] methodArguments) {
        if (isObjectNotPublic() || isMethodNotPublic()) {
            toolMethod.setAccessible(true);
        }

        Object result;
        try {
            result = toolMethod.invoke(toolObject, methodArguments);
        }
        catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        }
        catch (InvocationTargetException ex) {
            throw new ToolExecutionException(toolDefinition, ex.getCause());
        }
        return result;
    }

    // ......
}
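MethodToolCallback implements ToolCallback. Its call method parses the JSON input into an argument map, builds the method arguments with buildMethodArguments, obtains the return value with callMethod, and finally converts that value to a String with toolCallResultConverter.convert, passing the method's generic return type. callMethod invokes the tool method via reflection and wraps any exception thrown by the tool (the cause of the InvocationTargetException) in a ToolExecutionException.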
The following types are currently not supported as method parameter or return types:
- Optional
- Asynchronous types (e.g. CompletableFuture, Future)
- Reactive types (e.g. Flow, Mono, Flux)
- Functional types (e.g. Function, Supplier, Consumer).
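To make the reflective flow of MethodToolCallback.call concrete, here is a minimal, dependency-free sketch in plain Java. It is not Spring AI code: JSON parsing is skipped (the arguments are assumed to be already parsed into a Map), and a hypothetical @P annotation supplies parameter names, whereas Spring AI derives them from the method itself. ReflectiveTool, ToolFailure and CalculatorTools are illustrative names only.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.Map;

// Hypothetical annotation giving each parameter the name used in the JSON arguments.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
@interface P {
    String value();
}

// Hypothetical exception playing the role of ToolExecutionException.
class ToolFailure extends RuntimeException {
    ToolFailure(String toolName, Throwable cause) {
        super("Tool " + toolName + " failed: " + cause, cause);
    }
}

final class ReflectiveTool {
    private final Object target;
    private final Method method;

    private ReflectiveTool(Object target, Method method) {
        this.target = target;
        this.method = method;
    }

    // Finds the tool method by name on the target's class.
    static ReflectiveTool of(Object target, String methodName) {
        for (Method m : target.getClass().getDeclaredMethods()) {
            if (m.getName().equals(methodName)) {
                return new ReflectiveTool(target, m);
            }
        }
        throw new IllegalArgumentException("No method named " + methodName);
    }

    // toolArguments is assumed to be the already-parsed JSON input from the model.
    Object call(Map<String, ?> toolArguments) {
        Parameter[] params = method.getParameters();
        Object[] args = new Object[params.length];
        for (int i = 0; i < params.length; i++) {
            // Missing arguments stay null; every parameter must carry @P in this sketch.
            args[i] = toolArguments.get(params[i].getAnnotation(P.class).value());
        }
        method.setAccessible(true); // tool classes and methods are often package-private
        try {
            return method.invoke(target, args);
        } catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        } catch (InvocationTargetException ex) {
            // Unwrap the exception thrown by the tool body itself.
            throw new ToolFailure(method.getName(), ex.getCause());
        }
    }
}

class CalculatorTools {
    int add(@P("a") Integer a, @P("b") Integer b) {
        return a + b;
    }

    int divide(@P("a") Integer a, @P("b") Integer b) {
        return a / b; // throws ArithmeticException when b == 0
    }
}
```

For example, `ReflectiveTool.of(new CalculatorTools(), "add").call(Map.of("a", 2, "b", 3))` returns 5, while calling `divide` with `b = 0` surfaces as a ToolFailure whose cause is the original ArithmeticException, the same unwrapping idea as in callMethod.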
FunctionToolCallback
org/springframework/ai/tool/function/FunctionToolCallback.java
public class FunctionToolCallback<I, O> implements ToolCallback {

    private static final Logger logger = LoggerFactory.getLogger(FunctionToolCallback.class);

    private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();

    private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();

    private final ToolDefinition toolDefinition;

    private final ToolMetadata toolMetadata;

    private final Type toolInputType;

    private final BiFunction<I, ToolContext, O> toolFunction;

    private final ToolCallResultConverter toolCallResultConverter;

    public FunctionToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Type toolInputType,
            BiFunction<I, ToolContext, O> toolFunction, @Nullable ToolCallResultConverter toolCallResultConverter) {
        Assert.notNull(toolDefinition, "toolDefinition cannot be null");
        Assert.notNull(toolInputType, "toolInputType cannot be null");
        Assert.notNull(toolFunction, "toolFunction cannot be null");
        this.toolDefinition = toolDefinition;
        this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
        this.toolFunction = toolFunction;
        this.toolInputType = toolInputType;
        this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
                : DEFAULT_RESULT_CONVERTER;
    }

    @Override
    public ToolDefinition getToolDefinition() {
        return toolDefinition;
    }

    @Override
    public ToolMetadata getToolMetadata() {
        return toolMetadata;
    }

    @Override
    public String call(String toolInput) {
        return call(toolInput, null);
    }

    @Override
    public String call(String toolInput, @Nullable ToolContext toolContext) {
        Assert.hasText(toolInput, "toolInput cannot be null or empty");

        logger.debug("Starting execution of tool: {}", toolDefinition.name());

        I request = JsonParser.fromJson(toolInput, toolInputType);
        O response = toolFunction.apply(request, toolContext);

        logger.debug("Successful execution of tool: {}", toolDefinition.name());

        return toolCallResultConverter.convert(response, null);
    }

    @Override
    public String toString() {
        return "FunctionToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}';
    }

    // ......
}
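FunctionToolCallback implements ToolCallback. Its call method deserializes the request with JsonParser.fromJson(toolInput, toolInputType), obtains the response with toolFunction.apply(request, toolContext), and finally converts it with toolCallResultConverter.convert(response, null). Unlike MethodToolCallback, it passes null as the return type, so the default converter's void check never applies here.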
The following types are currently not supported as function input or output types:
- Primitive types
- Optional
- Collection types (e.g. List, Map, Array, Set)
- Asynchronous types (e.g. CompletableFuture, Future)
- Reactive types (e.g. Flow, Mono, Flux).
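The FunctionToolCallback pipeline (decode the input, apply the function together with the context, convert the result) can be modeled the same way. In this sketch, for illustration only, a hand-written regex decoder stands in for JsonParser.fromJson, a lambda stands in for ToolCallResultConverter, and a plain Map replaces ToolContext; FunctionTool, WeatherRequest, WeatherResponse and WeatherDemo are simplified, hypothetical types.

```java
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

// Simplified model of the FunctionToolCallback pipeline.
final class FunctionTool<I, O> {
    private final String name;
    private final Function<String, I> inputDecoder;                 // stands in for JsonParser.fromJson
    private final BiFunction<I, Map<String, Object>, O> toolFunction;
    private final Function<O, String> resultConverter;              // stands in for ToolCallResultConverter

    FunctionTool(String name, Function<String, I> inputDecoder,
            BiFunction<I, Map<String, Object>, O> toolFunction, Function<O, String> resultConverter) {
        this.name = Objects.requireNonNull(name, "name cannot be null");
        this.inputDecoder = Objects.requireNonNull(inputDecoder, "inputDecoder cannot be null");
        this.toolFunction = Objects.requireNonNull(toolFunction, "toolFunction cannot be null");
        this.resultConverter = Objects.requireNonNull(resultConverter, "resultConverter cannot be null");
    }

    String name() {
        return name;
    }

    String call(String toolInput, Map<String, Object> toolContext) {
        if (toolInput == null || toolInput.isBlank()) {
            throw new IllegalArgumentException("toolInput cannot be null or empty");
        }
        I request = inputDecoder.apply(toolInput);              // 1. decode the model's arguments
        O response = toolFunction.apply(request, toolContext);  // 2. run the function with the context
        return resultConverter.apply(response);                 // 3. convert the result for the model
    }
}

record WeatherRequest(String location) {}

record WeatherResponse(double temp, String unit) {}

final class WeatherDemo {
    private static final Pattern LOCATION = Pattern.compile("\"location\"\\s*:\\s*\"([^\"]*)\"");

    // Toy decoder for {"location":"..."}; a real setup would use a JSON library.
    static WeatherRequest decode(String json) {
        Matcher m = LOCATION.matcher(json);
        if (!m.find()) {
            throw new IllegalArgumentException("missing location in " + json);
        }
        return new WeatherRequest(m.group(1));
    }

    static FunctionTool<WeatherRequest, WeatherResponse> weatherTool() {
        return new FunctionTool<WeatherRequest, WeatherResponse>(
                "currentWeather",
                WeatherDemo::decode,
                (request, context) -> new WeatherResponse(30.0, "C"), // fixed value, like the WeatherService example
                response -> "{\"temp\":" + response.temp() + ",\"unit\":\"" + response.unit() + "\"}");
    }
}
```

Keeping decoding, execution and conversion as separate functions mirrors why FunctionToolCallback needs an explicit input type: the function itself never sees JSON, only a typed request object.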
Examples
class DateTimeTools {

    String getCurrentDateTime() {
        return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
    }

}
MethodToolCallback
Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
ToolCallback toolCallback = MethodToolCallback.builder()
    .toolDefinition(ToolDefinition.builder(method)
            .description("Get the current date and time in the user's timezone")
            .build())
    .toolMethod(method)
    .toolObject(new DateTimeTools())
    .build();
Alternatively, use the @Tool annotation:
class DateTimeTools {

    @Tool(description = "Get the current date and time in the user's timezone")
    String getCurrentDateTime() {
        return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
    }

}
Or create callbacks for all @Tool-annotated methods of an object with ToolCallbacks.from:
ToolCallback[] dateTimeTools = ToolCallbacks.from(new DateTimeTools());
FunctionToolCallback
public class WeatherService implements Function<WeatherRequest, WeatherResponse> {
    public WeatherResponse apply(WeatherRequest request) {
        return new WeatherResponse(30.0, Unit.C);
    }
}

// Supporting types used by WeatherService
public enum Unit { C, F }
public record WeatherRequest(String location, Unit unit) {}
public record WeatherResponse(double temp, Unit unit) {}

ToolCallback toolCallback = FunctionToolCallback
    .builder("currentWeather", new WeatherService())
    .description("Get the weather in location")
    .inputType(WeatherRequest.class)
    .build();

ChatClient.create(chatModel)
    .prompt("What's the weather like in Copenhagen?")
    .tools(toolCallback)
    .call()
    .content();
Or set it on the ChatOptions:
ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(toolCallback)
    .build();
Prompt prompt = new Prompt("What's the weather like in Copenhagen?", chatOptions);
chatModel.call(prompt);
Or register the function as a Spring bean and refer to it by bean name:
@Configuration(proxyBeanMethods = false)
class WeatherTools {

    WeatherService weatherService = new WeatherService();

    @Bean
    @Description("Get the weather in location")
    Function<WeatherRequest, WeatherResponse> currentWeather() {
        return weatherService;
    }

}

ChatClient.create(chatModel)
    .prompt("What's the weather like in Copenhagen?")
    .tools("currentWeather")
    .call()
    .content();
Tool Specification
ToolDefinition
org/springframework/ai/tool/definition/ToolDefinition.java
public interface ToolDefinition {

    /**
     * The tool name. Unique within the tool set provided to a model.
     */
    String name();

    /**
     * The tool description, used by the AI model to determine what the tool does.
     */
    String description();

    /**
     * The schema of the parameters used to call the tool.
     */
    String inputSchema();

    /**
     * Create a default {@link ToolDefinition} builder.
     */
    static DefaultToolDefinition.Builder builder() {
        return DefaultToolDefinition.builder();
    }

    /**
     * Create a default {@link ToolDefinition} builder from a {@link Method}.
     */
    static DefaultToolDefinition.Builder builder(Method method) {
        Assert.notNull(method, "method cannot be null");
        return DefaultToolDefinition.builder()
            .name(ToolUtils.getToolName(method))
            .description(ToolUtils.getToolDescription(method))
            .inputSchema(JsonSchemaGenerator.generateForMethodInput(method));
    }

    /**
     * Create a default {@link ToolDefinition} instance from a {@link Method}.
     */
    static ToolDefinition from(Method method) {
        return ToolDefinition.builder(method).build();
    }

}
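ToolDefinition defines the name, description and inputSchema properties. Its builder(Method) factory returns a DefaultToolDefinition builder pre-populated from a Method: the name and description come from ToolUtils, and the input schema from JsonSchemaGenerator.generateForMethodInput.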
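As a rough illustration of deriving a definition from a Method, the sketch below assumes a hypothetical @ToolInfo annotation in place of @Tool and a simple record in place of DefaultToolDefinition. Falling back to the method name when no name or description is given is a choice of this sketch, not necessarily what ToolUtils does; the input schema is simply passed in.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.time.ZonedDateTime;

// Hypothetical annotation playing the role of Spring AI's @Tool.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface ToolInfo {
    String name() default "";
    String description() default "";
}

// Simplified stand-in for ToolDefinition: name, description, input schema.
record SimpleToolDefinition(String name, String description, String inputSchema) {

    // Builds a definition from the named method; the schema is supplied by the caller.
    static SimpleToolDefinition from(Class<?> type, String methodName, String inputSchema) {
        for (Method method : type.getDeclaredMethods()) {
            if (method.getName().equals(methodName)) {
                return from(method, inputSchema);
            }
        }
        throw new IllegalArgumentException("No method named " + methodName);
    }

    static SimpleToolDefinition from(Method method, String inputSchema) {
        ToolInfo info = method.getAnnotation(ToolInfo.class);
        // Explicit annotation values win; otherwise fall back to the method name.
        String name = (info != null && !info.name().isEmpty()) ? info.name() : method.getName();
        String description = (info != null && !info.description().isEmpty())
                ? info.description() : method.getName();
        return new SimpleToolDefinition(name, description, inputSchema);
    }
}

class DateTimeTools {
    @ToolInfo(name = "currentDateTime", description = "Get the current date and time in the user's timezone")
    String getCurrentDateTime() {
        return ZonedDateTime.now().toString();
    }

    // No annotation: name and description both fall back to "undocumented".
    String undocumented() {
        return "";
    }
}
```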
Example:
Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
ToolDefinition toolDefinition = ToolDefinition.builder(method)
    .name("currentDateTime")
    .description("Get the current date and time in the user's timezone")
    .inputSchema(JsonSchemaGenerator.generateForMethodInput(method))
    .build();
JSON Schema
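Spring AI provides JsonSchemaGenerator to generate the JSON schema of the input parameters of a given method or function. Parameter descriptions can be supplied with any of the following annotations: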
- @ToolParam(description = "…") from Spring AI
- @JsonClassDescription("…") from Jackson
- @JsonPropertyDescription("…") from Jackson
- @Schema(description = "…") from Swagger
Example:
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;

import org.springframework.ai.tool.annotation.Tool;
import org.springframework.ai.tool.annotation.ToolParam;
import org.springframework.context.i18n.LocaleContextHolder;

class DateTimeTools {

    @Tool(description = "Set a user alarm for the given time")
    void setAlarm(@ToolParam(description = "Time in ISO-8601 format") String time) {
        LocalDateTime alarmTime = LocalDateTime.parse(time, DateTimeFormatter.ISO_DATE_TIME);
        System.out.println("Alarm set for " + alarmTime);
    }

}
Whether a parameter is required can be declared with any of the following annotations:
- @ToolParam(required = false) from Spring AI
- @JsonProperty(required = false) from Jackson
- @Schema(required = false) from Swagger
- @Nullable from Spring Framework
Example:
class CustomerTools {

    @Tool(description = "Update customer information")
    void updateCustomerInfo(Long id, String name, @ToolParam(required = false) String email) {
        System.out.println("Updated info for customer with id: " + id);
    }

}
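The idea behind JsonSchemaGenerator.generateForMethodInput can be sketched as: walk the method parameters, map Java types to JSON Schema types, attach descriptions, and list the required parameters. This is a toy generator, not Spring AI's: it assumes a hypothetical @Param annotation carrying name, description and required, treats every parameter as required unless marked otherwise, and ignores nested types and the other annotations listed above.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;

// Hypothetical parameter annotation combining name, description and required.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
@interface Param {
    String name();
    String description() default "";
    boolean required() default true;
}

final class MiniSchemaGenerator {

    static String forMethod(Class<?> type, String methodName) {
        for (Method m : type.getDeclaredMethods()) {
            if (m.getName().equals(methodName)) {
                return forMethod(m);
            }
        }
        throw new IllegalArgumentException("No method named " + methodName);
    }

    // Every parameter must carry @Param in this sketch.
    static String forMethod(Method method) {
        StringBuilder props = new StringBuilder();
        List<String> required = new ArrayList<>();
        for (Parameter p : method.getParameters()) {
            Param meta = p.getAnnotation(Param.class);
            if (props.length() > 0) {
                props.append(',');
            }
            props.append(quote(meta.name())).append(":{\"type\":").append(quote(jsonType(p.getType())));
            if (!meta.description().isEmpty()) {
                props.append(",\"description\":").append(quote(meta.description()));
            }
            props.append('}');
            if (meta.required()) {
                required.add(quote(meta.name()));
            }
        }
        return "{\"type\":\"object\",\"properties\":{" + props + "},\"required\":["
                + String.join(",", required) + "]}";
    }

    // Maps a few common Java types to JSON Schema types; everything else is "object".
    private static String jsonType(Class<?> type) {
        if (type == String.class) return "string";
        if (type == int.class || type == Integer.class || type == long.class || type == Long.class) return "integer";
        if (type == double.class || type == Double.class || type == float.class || type == Float.class) return "number";
        if (type == boolean.class || type == Boolean.class) return "boolean";
        return "object";
    }

    private static String quote(String s) {
        return "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
    }
}

class CustomerTools {
    void updateCustomerInfo(@Param(name = "id") Long id, @Param(name = "name") String name,
            @Param(name = "email", description = "Customer e-mail address", required = false) String email) {
        System.out.println("Updated info for customer with id: " + id);
    }
}
```

The optional email parameter still appears under properties but is left out of the required array, which is what tells the model it may omit it.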
Result Conversion
Spring AI provides ToolCallResultConverter to convert the data returned by a tool call into a String before it is sent back to the AI model.
org/springframework/ai/tool/execution/ToolCallResultConverter.java
@FunctionalInterface
public interface ToolCallResultConverter {

    /**
     * Given an Object returned by a tool, convert it to a String compatible with the
     * given class type.
     */
    String convert(@Nullable Object result, @Nullable Type returnType);

}
It has a default implementation, DefaultToolCallResultConverter:
public final class DefaultToolCallResultConverter implements ToolCallResultConverter {

    private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class);

    @Override
    public String convert(@Nullable Object result, @Nullable Type returnType) {
        if (returnType == Void.TYPE) {
            logger.debug("The tool has no return type. Converting to conventional response.");
            return "Done";
        }
        else {
            logger.debug("Converting tool result to JSON.");
            return JsonParser.toJson(result);
        }
    }

}
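DefaultToolCallResultConverter returns the fixed string "Done" for tools whose return type is void; for everything else it serializes the result into a JSON string with JsonParser.toJson(result).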
You can also specify your own converter, for example:
class CustomerTools {

    @Tool(description = "Retrieve customer information", resultConverter = CustomToolCallResultConverter.class)
    Customer getCustomerInfo(Long id) {
        return customerRepository.findById(id);
    }

}
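The converter contract is small enough to model directly. The sketch below mirrors the rule of DefaultToolCallResultConverter (a void return type yields "Done", anything else becomes JSON), with a minimal hand-written JSON writer standing in for JsonParser.toJson. ResultConverter, DefaultLikeResultConverter and PlainTextResultConverter are hypothetical names; the last one only suggests what a custom converter such as CustomToolCallResultConverter might do.

```java
import java.lang.reflect.Type;
import java.util.Collection;
import java.util.Map;
import java.util.StringJoiner;

@FunctionalInterface
interface ResultConverter {
    String convert(Object result, Type returnType);
}

final class DefaultLikeResultConverter implements ResultConverter {
    @Override
    public String convert(Object result, Type returnType) {
        // Same rule as DefaultToolCallResultConverter: void methods yield "Done".
        if (returnType == void.class) {
            return "Done";
        }
        return toJson(result);
    }

    // Minimal JSON writer standing in for JsonParser.toJson (Jackson in Spring AI).
    static String toJson(Object value) {
        if (value == null) {
            return "null";
        }
        if (value instanceof Number || value instanceof Boolean) {
            return value.toString();
        }
        if (value instanceof Map<?, ?> map) {
            StringJoiner json = new StringJoiner(",", "{", "}");
            map.forEach((k, v) -> json.add(toJson(String.valueOf(k)) + ":" + toJson(v)));
            return json.toString();
        }
        if (value instanceof Collection<?> items) {
            StringJoiner json = new StringJoiner(",", "[", "]");
            items.forEach(item -> json.add(toJson(item)));
            return json.toString();
        }
        // Everything else is written as a JSON string with quotes and backslashes escaped.
        return "\"" + value.toString().replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
    }
}

// Hypothetical custom converter: plain text instead of JSON.
final class PlainTextResultConverter implements ResultConverter {
    @Override
    public String convert(Object result, Type returnType) {
        return returnType == void.class ? "OK" : "Result: " + result;
    }
}
```

Because FunctionToolCallback passes null as the return type, a null result there is rendered as the JSON literal null rather than "Done".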
Tool Context
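Spring AI provides ToolContext for passing additional contextual information to tools. It lets developers supply extra, user-provided data (for example a tenant ID) that the tool can use during execution alongside the tool arguments passed by the AI model. For example: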
class CustomerTools {

    @Tool(description = "Retrieve customer information")
    Customer getCustomerInfo(Long id, ToolContext toolContext) {
        return customerRepository.findById(id, toolContext.getContext().get("tenantId"));
    }

}
With ChatClient:
ChatModel chatModel = ...

String response = ChatClient.create(chatModel)
        .prompt("Tell me more about the customer with ID 42")
        .tools(new CustomerTools())
        .toolContext(Map.of("tenantId", "acme"))
        .call()
        .content();

System.out.println(response);
With ChatModel:
ChatModel chatModel = ...
ToolCallback[] customerTools = ToolCallbacks.from(new CustomerTools());
ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(customerTools)
    .toolContext(Map.of("tenantId", "acme"))
    .build();
Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);
chatModel.call(prompt);
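ToolContext interacts with the default ToolCallback.call(toolInput, toolContext) method shown earlier: a tool that does not override it rejects a non-empty context. The plain-Java sketch below models that rule with a Map in place of ToolContext; ContextAwareTool, CustomerLookupTool and EchoTool are hypothetical, and the tenant data is invented for the example.

```java
import java.util.Map;

// Simplified model of ToolCallback.call(input, context): by default a
// non-empty context is rejected, so context-aware tools must override it.
interface ContextAwareTool {
    String call(String toolInput);

    default String call(String toolInput, Map<String, Object> toolContext) {
        if (toolContext != null && !toolContext.isEmpty()) {
            throw new UnsupportedOperationException("Tool context is not supported!");
        }
        return call(toolInput);
    }
}

// Hypothetical tenant-aware lookup: the model supplies the customer id,
// the application supplies the tenant id through the context.
final class CustomerLookupTool implements ContextAwareTool {
    private final Map<String, Map<String, String>> customersByTenant = Map.of(
            "acme", Map.of("42", "Alice"),
            "globex", Map.of("42", "Bob"));

    @Override
    public String call(String customerId) {
        throw new UnsupportedOperationException("tenantId is required");
    }

    @Override
    public String call(String customerId, Map<String, Object> toolContext) {
        Object tenant = (toolContext == null) ? null : toolContext.get("tenantId");
        if (!(tenant instanceof String tenantId)) {
            return "unknown tenant";
        }
        return customersByTenant.getOrDefault(tenantId, Map.of()).getOrDefault(customerId, "unknown customer");
    }
}

// A tool that ignores context and therefore keeps the default behavior.
final class EchoTool implements ContextAwareTool {
    @Override
    public String call(String toolInput) {
        return toolInput;
    }
}
```

The same customer id resolves to different records depending on the tenant, which is the kind of application-side data the model should not have to provide.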
Return Direct
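Spring AI provides the returnDirect attribute. When it is set to true, the result of the tool call is returned directly to the caller instead of being sent back to the model. By default it is false: the tool result is sent back to the AI model, which processes it before the final answer is returned to the user.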
For example:
class CustomerTools {

    @Tool(description = "Retrieve customer information", returnDirect = true)
    Customer getCustomerInfo(Long id) {
        return customerRepository.findById(id);
    }

}
Or, when building a ToolCallback programmatically, through ToolMetadata:
ToolMetadata toolMetadata = ToolMetadata.builder()
    .returnDirect(true)
    .build();
ToolCallingManager
org/springframework/ai/model/tool/ToolCallingManager.java
public interface ToolCallingManager {

    /**
     * Resolve the tool definitions from the model's tool calling options.
     */
    List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions);

    /**
     * Execute the tool calls requested by the model.
     */
    ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse);

    /**
     * Create a default {@link ToolCallingManager} builder.
     */
    static DefaultToolCallingManager.Builder builder() {
        return DefaultToolCallingManager.builder();
    }

}
ToolCallingManager defines the resolveToolDefinitions and executeToolCalls methods; its default implementation is DefaultToolCallingManager.
DefaultToolCallingManager
org/springframework/ai/model/tool/DefaultToolCallingManager.java
public class DefaultToolCallingManager implements ToolCallingManager {

    private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallingManager.class);

    // @formatter:off

    private static final ObservationRegistry DEFAULT_OBSERVATION_REGISTRY
            = ObservationRegistry.NOOP;

    private static final ToolCallbackResolver DEFAULT_TOOL_CALLBACK_RESOLVER
            = new DelegatingToolCallbackResolver(List.of());

    private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
            = DefaultToolExecutionExceptionProcessor.builder().build();

    // @formatter:on

    private final ObservationRegistry observationRegistry;

    private final ToolCallbackResolver toolCallbackResolver;

    private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;

    public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
            ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
        Assert.notNull(observationRegistry, "observationRegistry cannot be null");
        Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
        Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");

        this.observationRegistry = observationRegistry;
        this.toolCallbackResolver = toolCallbackResolver;
        this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
    }

    @Override
    public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
        Assert.notNull(chatOptions, "chatOptions cannot be null");

        List<FunctionCallback> toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks());
        for (String toolName : chatOptions.getToolNames()) {
            // Skip the tool if it is already present in the request toolCallbacks.
            // That might happen if a tool is defined in the options
            // both as a ToolCallback and as a tool name.
            if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) {
                continue;
            }
            FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
            if (toolCallback == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }
            toolCallbacks.add(toolCallback);
        }

        return toolCallbacks.stream().map(functionCallback -> {
            if (functionCallback instanceof ToolCallback toolCallback) {
                return toolCallback.getToolDefinition();
            }
            else {
                return ToolDefinition.builder()
                    .name(functionCallback.getName())
                    .description(functionCallback.getDescription())
                    .inputSchema(functionCallback.getInputTypeSchema())
                    .build();
            }
        }).toList();
    }

    @Override
    public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
        Assert.notNull(prompt, "prompt cannot be null");
        Assert.notNull(chatResponse, "chatResponse cannot be null");

        Optional<Generation> toolCallGeneration = chatResponse.getResults()
            .stream()
            .filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
            .findFirst();

        if (toolCallGeneration.isEmpty()) {
            throw new IllegalStateException("No tool call requested by the chat model");
        }

        AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();

        ToolContext toolContext = buildToolContext(prompt, assistantMessage);

        InternalToolExecutionResult internalToolExecutionResult = executeToolCall(prompt, assistantMessage,
                toolContext);

        List<Message> conversationHistory = buildConversationHistoryAfterToolExecution(prompt.getInstructions(),
                assistantMessage, internalToolExecutionResult.toolResponseMessage());

        return ToolExecutionResult.builder()
            .conversationHistory(conversationHistory)
            .returnDirect(internalToolExecutionResult.returnDirect())
            .build();
    }

    // ......

    /**
     * Execute the tool call and return the response message. To ensure backward
     * compatibility, both {@link ToolCallback} and {@link FunctionCallback} are
     * supported.
     */
    private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage,
            ToolContext toolContext) {
        List<FunctionCallback> toolCallbacks = List.of();
        if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
            toolCallbacks = toolCallingChatOptions.getToolCallbacks();
        }
        else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) {
            toolCallbacks = functionOptions.getFunctionCallbacks();
        }

        List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();

        Boolean returnDirect = null;

        for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {

            logger.debug("Executing tool call: {}", toolCall.name());

            String toolName = toolCall.name();
            String toolInputArguments = toolCall.arguments();

            FunctionCallback toolCallback = toolCallbacks.stream()
                .filter(tool -> toolName.equals(tool.getName()))
                .findFirst()
                .orElseGet(() -> toolCallbackResolver.resolve(toolName));

            if (toolCallback == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
            }

            if (returnDirect == null && toolCallback instanceof ToolCallback callback) {
                returnDirect = callback.getToolMetadata().returnDirect();
            }
            else if (toolCallback instanceof ToolCallback callback) {
                returnDirect = returnDirect && callback.getToolMetadata().returnDirect();
            }
            else if (returnDirect == null) {
                // This is a temporary solution to ensure backward compatibility with
                // FunctionCallback.
                // TODO: remove this block when FunctionCallback is removed.
                returnDirect = false;
            }

            String toolResult;
            try {
                toolResult = toolCallback.call(toolInputArguments, toolContext);
            }
            catch (ToolExecutionException ex) {
                toolResult = toolExecutionExceptionProcessor.process(ex);
            }

            toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult));
        }

        return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
    }

    private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
            AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) {
        List<Message> messages = new ArrayList<>(previousMessages);
        messages.add(assistantMessage);
        messages.add(toolResponseMessage);
        return messages;
    }

}
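DefaultToolCallingManager's resolveToolDefinitions method starts from chatOptions.getToolCallbacks(), resolves each name in chatOptions.getToolNames() through toolCallbackResolver (skipping names already present among the callbacks), and maps every callback to its ToolDefinition. The executeToolCalls method first picks the first generation whose assistantMessage contains tool calls, then builds the toolContext, runs executeToolCall to obtain the execution result, and finally builds the conversationHistory from it.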
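The executeToolCall method iterates over assistantMessage.getToolCalls(). For each call it looks up the callback by name among the callbacks in the prompt options, falls back to toolCallbackResolver.resolve(toolName), and then obtains the result with toolCallback.call(toolInputArguments, toolContext). If a ToolExecutionException is thrown, toolExecutionExceptionProcessor.process(ex) handles it as a fallback. The returnDirect flag of the overall result is true only if every invoked tool has returnDirect set to true.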
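The core loop of executeToolCall can be condensed into a small plain-Java model. This is a simplified sketch, not the Spring AI implementation: tools are records with a returnDirect flag, the fallback resolver is a Function, and a tool failure is turned into its message, mimicking DefaultToolExecutionExceptionProcessor with alwaysThrow = false. It shows the lookup order, the exception fallback, and that returnDirect holds only when every invoked tool sets it. All type names here are hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Simplified models of a tool, a requested call, and the execution result.
record SimpleTool(String name, boolean returnDirect, Function<String, String> body) {}
record ToolCallRequest(String id, String name, String arguments) {}
record ToolResponse(String id, String name, String result) {}
record ExecutionResult(List<ToolResponse> responses, boolean returnDirect) {}

// Hypothetical exception type thrown by tool bodies.
class ToolExecutionFailure extends RuntimeException {
    ToolExecutionFailure(String message) {
        super(message);
    }
}

final class MiniToolCallingManager {
    private final Function<String, SimpleTool> resolver; // fallback lookup, may return null

    MiniToolCallingManager(Function<String, SimpleTool> resolver) {
        this.resolver = resolver;
    }

    ExecutionResult execute(List<SimpleTool> optionTools, List<ToolCallRequest> calls) {
        List<ToolResponse> responses = new ArrayList<>();
        Boolean returnDirect = null;
        for (ToolCallRequest call : calls) {
            // 1. Prefer tools from the request options, then fall back to the resolver.
            SimpleTool tool = optionTools.stream()
                    .filter(t -> t.name().equals(call.name()))
                    .findFirst()
                    .orElseGet(() -> resolver.apply(call.name()));
            if (tool == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + call.name());
            }
            // 2. returnDirect stays true only if every invoked tool has it set.
            returnDirect = (returnDirect == null) ? tool.returnDirect() : returnDirect && tool.returnDirect();
            // 3. A failing tool reports its error message back instead of aborting.
            String result;
            try {
                result = tool.body().apply(call.arguments());
            } catch (ToolExecutionFailure ex) {
                result = ex.getMessage();
            }
            responses.add(new ToolResponse(call.id(), call.name(), result));
        }
        return new ExecutionResult(responses, Boolean.TRUE.equals(returnDirect));
    }
}
```

Mixing one returnDirect tool with one ordinary tool in a single response therefore sends everything back to the model, which is the conservative choice: the model still gets to see results that were not meant to bypass it.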
ToolExecutionExceptionProcessor
org/springframework/ai/tool/execution/ToolExecutionExceptionProcessor.java
@FunctionalInterface
public interface ToolExecutionExceptionProcessor {

    /**
     * Convert an exception thrown by a tool to a String that can be sent back to the AI
     * model or throw an exception to be handled by the caller.
     */
    String process(ToolExecutionException exception);

}
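ToolExecutionExceptionProcessor defines a single process method, which either converts the exception into a String to send back to the AI model or throws an exception for the caller to handle.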
DefaultToolExecutionExceptionProcessor
public class DefaultToolExecutionExceptionProcessor implements ToolExecutionExceptionProcessor {

    private final static Logger logger = LoggerFactory.getLogger(DefaultToolExecutionExceptionProcessor.class);

    private static final boolean DEFAULT_ALWAYS_THROW = false;

    private final boolean alwaysThrow;

    public DefaultToolExecutionExceptionProcessor(boolean alwaysThrow) {
        this.alwaysThrow = alwaysThrow;
    }

    @Override
    public String process(ToolExecutionException exception) {
        Assert.notNull(exception, "exception cannot be null");
        if (alwaysThrow) {
            throw exception;
        }
        logger.debug("Exception thrown by tool: {}. Message: {}", exception.getToolDefinition().name(),
                exception.getMessage());
        return exception.getMessage();
    }

    // ......
}
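When alwaysThrow is true (the default is false), DefaultToolExecutionExceptionProcessor rethrows the exception; otherwise it returns the exception message, which is then sent back to the model as the tool result.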
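A custom processor only has to decide, per exception, between returning a message and rethrowing. The sketch below models both the default behavior and a hypothetical selective policy (security failures propagate, everything else goes back to the model). It uses stand-in types, ToolExecError instead of ToolExecutionException, so it runs without Spring AI.

```java
// Hypothetical stand-in for ToolExecutionException: keeps the tool name and the cause.
class ToolExecError extends RuntimeException {
    final String toolName;

    ToolExecError(String toolName, Throwable cause) {
        super(cause.getMessage(), cause);
        this.toolName = toolName;
    }
}

@FunctionalInterface
interface ExceptionProcessor {
    String process(ToolExecError exception);
}

// Mirrors DefaultToolExecutionExceptionProcessor: rethrow or return the message.
final class DefaultLikeExceptionProcessor implements ExceptionProcessor {
    private final boolean alwaysThrow;

    DefaultLikeExceptionProcessor(boolean alwaysThrow) {
        this.alwaysThrow = alwaysThrow;
    }

    @Override
    public String process(ToolExecError exception) {
        if (alwaysThrow) {
            throw exception; // propagates to whoever called the chat model
        }
        return exception.getMessage(); // becomes the tool result seen by the model
    }
}

// Hypothetical selective policy: security failures propagate, others go back to the model.
final class SelectiveExceptionProcessor implements ExceptionProcessor {
    @Override
    public String process(ToolExecError exception) {
        if (exception.getCause() instanceof SecurityException) {
            throw exception;
        }
        return "Tool " + exception.toolName + " failed: " + exception.getMessage();
    }
}
```

Returning the message lets the model recover (for example by retrying with different arguments), while rethrowing is the safer choice for failures the model should not try to work around.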
User-Controlled Tool Execution
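ToolCallingChatOptions provides the internalToolExecutionEnabled attribute. Setting it to false lets you control tool execution yourself (alternatively, you can implement a ToolExecutionEligibilityPredicate to decide when the framework executes tools). For example: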
ChatModel chatModel = ...
ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();

ChatOptions chatOptions = ToolCallingChatOptions.builder()
    .toolCallbacks(ToolCallbacks.from(new CustomerTools()))
    .internalToolExecutionEnabled(false)
    .build();
Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);

ChatResponse chatResponse = chatModel.call(prompt);

while (chatResponse.hasToolCalls()) {
    ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);

    prompt = new Prompt(toolExecutionResult.conversationHistory(), chatOptions);

    chatResponse = chatModel.call(prompt);
}

System.out.println(chatResponse.getResult().getOutput().getText());
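Here the application itself executes the tool calls with toolCallingManager.executeToolCalls and sends the resulting conversation history back to chatModel, looping until the model no longer requests tool calls.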
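The user-controlled loop can be exercised without a real model. In this sketch a hypothetical FakeChatModel requests the customer tool once and answers after it sees a tool message; the Msg record is a simplified stand-in for Spring AI's message types, and the tool argument travels in the message text purely for brevity.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Minimal message model: role is "user", "assistant" or "tool";
// a non-null toolName on an assistant message means "please call this tool".
record Msg(String role, String text, String toolName) {}

// A fake chat model: requests the "customer" tool until a tool message is present.
final class FakeChatModel {
    Msg call(List<Msg> history) {
        for (Msg m : history) {
            if (m.role().equals("tool")) {
                return new Msg("assistant", "Customer 42 is " + m.text(), null);
            }
        }
        return new Msg("assistant", "42", "customer"); // tool call with argument "42"
    }
}

final class UserControlledLoop {
    static String run(FakeChatModel model, Map<String, Function<String, String>> tools, String question) {
        List<Msg> history = new ArrayList<>(List.of(new Msg("user", question, null)));
        Msg response = model.call(history);
        // While the model requests a tool: execute it ourselves, append the
        // assistant and tool messages to the history, and call the model again.
        while (response.toolName() != null) {
            Function<String, String> tool = tools.get(response.toolName());
            if (tool == null) {
                throw new IllegalStateException("Unknown tool: " + response.toolName());
            }
            history.add(response);
            history.add(new Msg("tool", tool.apply(response.text()), response.toolName()));
            response = model.call(history);
        }
        return response.text();
    }
}
```

Owning the loop is what makes it possible to log, approve, or rewrite each tool call before the model sees its result.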
ToolCallbackResolver
spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java
public interface ToolCallbackResolver {

    /**
     * Resolve the {@link FunctionCallback} for the given tool name.
     */
    @Nullable
    FunctionCallback resolve(String toolName);

}
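ToolCallbackResolver defines the resolve method, which looks up the FunctionCallback for a given toolName and may return null if none is found. It has three implementations: StaticToolCallbackResolver, SpringBeanToolCallbackResolver and DelegatingToolCallbackResolver.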
StaticToolCallbackResolver
spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java
public class StaticToolCallbackResolver implements ToolCallbackResolver {

    private static final Logger logger = LoggerFactory.getLogger(StaticToolCallbackResolver.class);

    private final Map<String, FunctionCallback> toolCallbacks = new HashMap<>();

    public StaticToolCallbackResolver(List<FunctionCallback> toolCallbacks) {
        Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
        Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");

        toolCallbacks.forEach(callback -> {
            if (callback instanceof ToolCallback toolCallback) {
                this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback);
            }
            this.toolCallbacks.put(callback.getName(), callback);
        });
    }

    @Override
    public FunctionCallback resolve(String toolName) {
        Assert.hasText(toolName, "toolName cannot be null or empty");
        logger.debug("ToolCallback resolution attempt from static registry");
        return toolCallbacks.get(toolName);
    }

}
StaticToolCallbackResolver looks tools up in a map built from the list passed to its constructor, keyed by tool name.
SpringBeanToolCallbackResolver
spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java
public class SpringBeanToolCallbackResolver implements ToolCallbackResolver {

    private static final Logger logger = LoggerFactory.getLogger(SpringBeanToolCallbackResolver.class);

    private static final Map<String, ToolCallback> toolCallbacksCache = new HashMap<>();

    private static final SchemaType DEFAULT_SCHEMA_TYPE = SchemaType.JSON_SCHEMA;

    private final GenericApplicationContext applicationContext;

    private final SchemaType schemaType;

    public SpringBeanToolCallbackResolver(GenericApplicationContext applicationContext,
            @Nullable SchemaType schemaType) {
        Assert.notNull(applicationContext, "applicationContext cannot be null");

        this.applicationContext = applicationContext;
        this.schemaType = schemaType != null ? schemaType : DEFAULT_SCHEMA_TYPE;
    }

    @Override
    public ToolCallback resolve(String toolName) {
        Assert.hasText(toolName, "toolName cannot be null or empty");

        logger.debug("ToolCallback resolution attempt from Spring application context");

        ToolCallback resolvedToolCallback = toolCallbacksCache.get(toolName);

        if (resolvedToolCallback != null) {
            return resolvedToolCallback;
        }

        ResolvableType toolType = TypeResolverHelper.resolveBeanType(applicationContext, toolName);
        ResolvableType toolInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(toolType))
                ? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(toolType, 0);

        String toolDescription = resolveToolDescription(toolName, toolInputType.toClass());
        Object bean = applicationContext.getBean(toolName);

        resolvedToolCallback = buildToolCallback(toolName, toolType, toolInputType, toolDescription, bean);

        toolCallbacksCache.put(toolName, resolvedToolCallback);

        return resolvedToolCallback;
    }

    // ......
}
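SpringBeanToolCallbackResolver uses the GenericApplicationContext to look up the bean named toolName in the Spring container, builds a ToolCallback from it, and stores the result in toolCallbacksCache. Because that cache is a static field, it is shared by all resolver instances.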
DelegatingToolCallbackResolver
spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java
public class DelegatingToolCallbackResolver implements ToolCallbackResolver {

    private final List<ToolCallbackResolver> toolCallbackResolvers;

    public DelegatingToolCallbackResolver(List<ToolCallbackResolver> toolCallbackResolvers) {
        Assert.notNull(toolCallbackResolvers, "toolCallbackResolvers cannot be null");
        Assert.noNullElements(toolCallbackResolvers, "toolCallbackResolvers cannot contain null elements");

        this.toolCallbackResolvers = toolCallbackResolvers;
    }

    @Override
    @Nullable
    public FunctionCallback resolve(String toolName) {
        Assert.hasText(toolName, "toolName cannot be null or empty");

        for (ToolCallbackResolver toolCallbackResolver : toolCallbackResolvers) {
            FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
            if (toolCallback != null) {
                return toolCallback;
            }
        }
        return null;
    }

}
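DelegatingToolCallbackResolver delegates resolve to the toolCallbackResolvers passed to its constructor, in order, and returns the first non-null result (or null if none of them knows the tool).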
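The three resolver strategies compose naturally. The plain-Java sketch below models them with a hypothetical Resolver interface, with tools represented as Function<String, String> and a factory function standing in for the Spring application context. Unlike SpringBeanToolCallbackResolver, which uses a static HashMap as its cache, this sketch keeps a per-instance ConcurrentHashMap, a deliberate choice so that separate resolvers do not share state.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Function;

interface Resolver {
    // Returns null when this resolver does not know the tool.
    Function<String, String> resolve(String toolName);
}

// Like StaticToolCallbackResolver: a fixed name -> tool map.
final class StaticResolver implements Resolver {
    private final Map<String, Function<String, String>> tools;

    StaticResolver(Map<String, Function<String, String>> tools) {
        this.tools = new HashMap<>(tools);
    }

    @Override
    public Function<String, String> resolve(String toolName) {
        return tools.get(toolName);
    }
}

// Like SpringBeanToolCallbackResolver: look the tool up lazily, then cache it.
final class CachingResolver implements Resolver {
    private final Function<String, Function<String, String>> factory; // stands in for the application context
    private final Map<String, Function<String, String>> cache = new ConcurrentHashMap<>();
    int lookups = 0; // counts factory lookups, for demonstration only (not thread-safe)

    CachingResolver(Function<String, Function<String, String>> factory) {
        this.factory = factory;
    }

    @Override
    public Function<String, String> resolve(String toolName) {
        Function<String, String> cached = cache.get(toolName);
        if (cached != null) {
            return cached;
        }
        lookups++;
        Function<String, String> tool = factory.apply(toolName);
        if (tool != null) {
            cache.put(toolName, tool); // ConcurrentHashMap rejects null values
        }
        return tool;
    }
}

// Like DelegatingToolCallbackResolver: the first non-null answer wins.
final class DelegatingResolver implements Resolver {
    private final List<Resolver> delegates;

    DelegatingResolver(List<? extends Resolver> delegates) {
        this.delegates = List.copyOf(delegates);
    }

    @Override
    public Function<String, String> resolve(String toolName) {
        for (Resolver r : delegates) {
            Function<String, String> tool = r.resolve(toolName);
            if (tool != null) {
                return tool;
            }
        }
        return null;
    }
}
```

Putting the cheap static resolver first means the lazy lookup is only consulted for names the static registry does not know.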
Summary
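Spring AI implements Tool Calling through ToolCallback, which extends the FunctionCallback interface (FunctionCallback is slated for deprecation). ToolCallback mainly defines the getToolDefinition, getToolMetadata and call methods, and it has two basic implementations: MethodToolCallback and FunctionToolCallback.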
Tool Specification covers Tool Callback, Tool Definition, JSON Schema, Result Conversion, Tool Context and Return Direct.
Tool Execution covers Framework-Controlled Tool Execution, User-Controlled Tool Execution and Exception Handling.