A Look at Tool Calling in Spring AI

06-01

This article takes a look at Tool Calling in Spring AI.

ToolCallback

org/springframework/ai/tool/ToolCallback.java

public interface ToolCallback extends FunctionCallback {
	/**
	 * Definition used by the AI model to determine when and how to call the tool.
	 */
	ToolDefinition getToolDefinition();
	/**
	 * Metadata providing additional information on how to handle the tool.
	 */
	default ToolMetadata getToolMetadata() {
		return ToolMetadata.builder().build();
	}
	/**
	 * Execute tool with the given input and return the result to send back to the AI
	 * model.
	 */
	String call(String toolInput);
	/**
	 * Execute tool with the given input and context, and return the result to send back
	 * to the AI model.
	 */
	default String call(String toolInput, @Nullable ToolContext toolContext) {
		if (toolContext != null && !toolContext.getContext().isEmpty()) {
			throw new UnsupportedOperationException("Tool context is not supported!");
		}
		return call(toolInput);
	}
	@Override
	@Deprecated // Call getToolDefinition().name() instead
	default String getName() {
		return getToolDefinition().name();
	}
	@Override
	@Deprecated // Call getToolDefinition().description() instead
	default String getDescription() {
		return getToolDefinition().description();
	}
	@Override
	@Deprecated // Call getToolDefinition().inputSchema() instead
	default String getInputTypeSchema() {
		return getToolDefinition().inputSchema();
	}
}

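ToolCallback extends the FunctionCallback interface, which is itself slated for deprecation. The inherited getName, getDescription and getInputTypeSchema methods are deprecated and simply delegate to getToolDefinition(). ToolCallback mainly defines the getToolDefinition, getToolMetadata and call methods. It has two basic implementations: MethodToolCallback and FunctionToolCallback.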
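The two call overloads are related through a default method. The context-aware overload rejects a non-empty context and otherwise delegates to the single-argument version. As a result, a simple tool only has to implement call(String).

The following is a minimal, self-contained model of that pattern with no Spring AI dependency. `SimpleTool` and `EchoTool` are hypothetical names, and a plain `Map` stands in for `ToolContext`.

```java
import java.util.Map;

// Hypothetical stand-in for ToolCallback, reduced to its two call overloads.
// A plain Map plays the role of ToolContext.
interface SimpleTool {

    String call(String toolInput);

    // Same shape as the default in ToolCallback: a non-empty context is rejected,
    // otherwise the call falls through to the single-argument overload.
    default String call(String toolInput, Map<String, Object> toolContext) {
        if (toolContext != null && !toolContext.isEmpty()) {
            throw new UnsupportedOperationException("Tool context is not supported!");
        }
        return call(toolInput);
    }
}

// Hypothetical tool that ignores context, so it only implements call(String).
class EchoTool implements SimpleTool {
    @Override
    public String call(String toolInput) {
        return "echo:" + toolInput;
    }
}
```

Implementations that do use the context, such as MethodToolCallback and FunctionToolCallback below, override the two-argument overload instead.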

MethodToolCallback

org/springframework/ai/tool/method/MethodToolCallback.java

public class MethodToolCallback implements ToolCallback {
	private static final Logger logger = LoggerFactory.getLogger(MethodToolCallback.class);
	private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();
	private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();
	private final ToolDefinition toolDefinition;
	private final ToolMetadata toolMetadata;
	private final Method toolMethod;
	@Nullable
	private final Object toolObject;
	private final ToolCallResultConverter toolCallResultConverter;
	public MethodToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Method toolMethod,
			@Nullable Object toolObject, @Nullable ToolCallResultConverter toolCallResultConverter) {
		Assert.notNull(toolDefinition, "toolDefinition cannot be null");
		Assert.notNull(toolMethod, "toolMethod cannot be null");
		Assert.isTrue(Modifier.isStatic(toolMethod.getModifiers()) || toolObject != null,
				"toolObject cannot be null for non-static methods");
		this.toolDefinition = toolDefinition;
		this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
		this.toolMethod = toolMethod;
		this.toolObject = toolObject;
		this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
				: DEFAULT_RESULT_CONVERTER;
	}
	@Override
	public ToolDefinition getToolDefinition() {
		return toolDefinition;
	}
	@Override
	public ToolMetadata getToolMetadata() {
		return toolMetadata;
	}
	@Override
	public String call(String toolInput) {
		return call(toolInput, null);
	}
	@Override
	public String call(String toolInput, @Nullable ToolContext toolContext) {
		Assert.hasText(toolInput, "toolInput cannot be null or empty");
		logger.debug("Starting execution of tool: {}", toolDefinition.name());
		validateToolContextSupport(toolContext);
		Map<String, Object> toolArguments = extractToolArguments(toolInput);
		Object[] methodArguments = buildMethodArguments(toolArguments, toolContext);
		Object result = callMethod(methodArguments);
		logger.debug("Successful execution of tool: {}", toolDefinition.name());
		Type returnType = toolMethod.getGenericReturnType();
		return toolCallResultConverter.convert(result, returnType);
	}
	@Nullable
	private Object callMethod(Object[] methodArguments) {
		if (isObjectNotPublic() || isMethodNotPublic()) {
			toolMethod.setAccessible(true);
		}
		Object result;
		try {
			result = toolMethod.invoke(toolObject, methodArguments);
		}
		catch (IllegalAccessException ex) {
			throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
		}
		catch (InvocationTargetException ex) {
			throw new ToolExecutionException(toolDefinition, ex.getCause());
		}
		return result;
	}
	//......
}	

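MethodToolCallback implements ToolCallback. Its call method works in three steps:

1. It builds the method arguments with buildMethodArguments.
2. It invokes the method through callMethod.
3. It converts the return value with toolCallResultConverter.convert.

callMethod invokes the method via reflection. It calls setAccessible(true) when the tool object or the method is not public. It also wraps any exception thrown by the tool method (reported as an InvocationTargetException) in a ToolExecutionException.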
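The reflective part of callMethod can be sketched without Spring AI. `ReflectiveInvoker`, `ToolInvocationException` and `CalculatorTools` are hypothetical names. The sketch reproduces only two ideas visible in the code above:

- opening up non-public members with `setAccessible(true)`;
- unwrapping `InvocationTargetException`, so that the tool's own exception becomes the cause of the wrapper.

```java
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;

// Hypothetical exception standing in for Spring AI's ToolExecutionException.
class ToolInvocationException extends RuntimeException {
    ToolInvocationException(String toolName, Throwable cause) {
        super("Tool '" + toolName + "' failed: " + cause.getMessage(), cause);
    }
}

// Hypothetical model of MethodToolCallback.callMethod.
final class ReflectiveInvoker {

    private final Method method;
    private final Object target; // may be null for static methods

    ReflectiveInvoker(Method method, Object target) {
        if (!Modifier.isStatic(method.getModifiers()) && target == null) {
            throw new IllegalArgumentException("target cannot be null for non-static methods");
        }
        this.method = method;
        this.target = target;
    }

    // Convenience lookup by name (assumes the method name is not overloaded).
    static ReflectiveInvoker of(Object target, String methodName) {
        for (Method m : target.getClass().getDeclaredMethods()) {
            if (m.getName().equals(methodName)) {
                return new ReflectiveInvoker(m, target);
            }
        }
        throw new IllegalArgumentException("No method named " + methodName);
    }

    Object invoke(Object... args) {
        boolean classNotPublic = !Modifier.isPublic(method.getDeclaringClass().getModifiers());
        if (classNotPublic || !Modifier.isPublic(method.getModifiers())) {
            method.setAccessible(true);
        }
        try {
            return method.invoke(target, args);
        }
        catch (IllegalAccessException ex) {
            throw new IllegalStateException("Could not access method: " + ex.getMessage(), ex);
        }
        catch (InvocationTargetException ex) {
            // The tool's own exception is the cause; surface it instead of the reflection wrapper.
            throw new ToolInvocationException(method.getName(), ex.getCause());
        }
    }
}

// Hypothetical package-private tool class used to exercise the invoker.
class CalculatorTools {
    int add(int a, int b) { return a + b; }
    int divide(int a, int b) { return a / b; }
}
```

Unwrapping matters because `Method.invoke` always wraps whatever the target method throws. Without `getCause()`, the error would describe the reflection wrapper rather than the actual failure.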

Currently, the following types are not supported as method parameters or return types:

  • Optional
  • Asynchronous types (e.g. CompletableFuture, Future)
  • Reactive types (e.g. Flow, Mono, Flux)
  • Functional types (e.g. Function, Supplier, Consumer).

    FunctionToolCallback

    org/springframework/ai/tool/function/FunctionToolCallback.java

    public class FunctionToolCallback<I, O> implements ToolCallback {
    	private static final Logger logger = LoggerFactory.getLogger(FunctionToolCallback.class);
    	private static final ToolCallResultConverter DEFAULT_RESULT_CONVERTER = new DefaultToolCallResultConverter();
    	private static final ToolMetadata DEFAULT_TOOL_METADATA = ToolMetadata.builder().build();
    	private final ToolDefinition toolDefinition;
    	private final ToolMetadata toolMetadata;
    	private final Type toolInputType;
    	private final BiFunction<I, ToolContext, O> toolFunction;
    	private final ToolCallResultConverter toolCallResultConverter;
    	public FunctionToolCallback(ToolDefinition toolDefinition, @Nullable ToolMetadata toolMetadata, Type toolInputType,
    			BiFunction<I, ToolContext, O> toolFunction, @Nullable ToolCallResultConverter toolCallResultConverter) {
    		Assert.notNull(toolDefinition, "toolDefinition cannot be null");
    		Assert.notNull(toolInputType, "toolInputType cannot be null");
    		Assert.notNull(toolFunction, "toolFunction cannot be null");
    		this.toolDefinition = toolDefinition;
    		this.toolMetadata = toolMetadata != null ? toolMetadata : DEFAULT_TOOL_METADATA;
    		this.toolFunction = toolFunction;
    		this.toolInputType = toolInputType;
    		this.toolCallResultConverter = toolCallResultConverter != null ? toolCallResultConverter
    				: DEFAULT_RESULT_CONVERTER;
    	}
    	@Override
    	public ToolDefinition getToolDefinition() {
    		return toolDefinition;
    	}
    	@Override
    	public ToolMetadata getToolMetadata() {
    		return toolMetadata;
    	}
    	@Override
    	public String call(String toolInput) {
    		return call(toolInput, null);
    	}
    	@Override
    	public String call(String toolInput, @Nullable ToolContext toolContext) {
    		Assert.hasText(toolInput, "toolInput cannot be null or empty");
    		logger.debug("Starting execution of tool: {}", toolDefinition.name());
    		I request = JsonParser.fromJson(toolInput, toolInputType);
    		O response = toolFunction.apply(request, toolContext);
    		logger.debug("Successful execution of tool: {}", toolDefinition.name());
    		return toolCallResultConverter.convert(response, null);
    	}
    	@Override
    	public String toString() {
    		return "FunctionToolCallback{" + "toolDefinition=" + toolDefinition + ", toolMetadata=" + toolMetadata + '}';
    	}
    	//......
    }	
    

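    FunctionToolCallback implements ToolCallback. Its call method works in three steps:

    1. It converts the request with JsonParser.fromJson(toolInput, toolInputType).
    2. It obtains the response via toolFunction.apply(request, toolContext).
    3. It converts the result with toolCallResultConverter.convert(response, null).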
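The parse, apply and convert path can be modeled with plain `java.util.function` types. This is a standalone sketch, not the Spring AI class:

- the parser function stands in for `JsonParser.fromJson`;
- the converter function stands in for `ToolCallResultConverter`;
- a `Map` replaces `ToolContext`;
- `FunctionTool`, `WeatherQuery` and `WeatherReport` are hypothetical names.

```java
import java.util.Map;
import java.util.Objects;
import java.util.function.BiFunction;
import java.util.function.Function;

// Hypothetical model of FunctionToolCallback.call: parse input, apply (input, context), convert result.
final class FunctionTool<I, O> {

    private final Function<String, I> inputParser;                 // stands in for JsonParser.fromJson
    private final BiFunction<I, Map<String, Object>, O> toolFunction;
    private final Function<O, String> resultConverter;             // stands in for ToolCallResultConverter

    FunctionTool(Function<String, I> inputParser,
                 BiFunction<I, Map<String, Object>, O> toolFunction,
                 Function<O, String> resultConverter) {
        this.inputParser = Objects.requireNonNull(inputParser);
        this.toolFunction = Objects.requireNonNull(toolFunction);
        this.resultConverter = Objects.requireNonNull(resultConverter);
    }

    String call(String toolInput, Map<String, Object> toolContext) {
        if (toolInput == null || toolInput.isBlank()) {
            throw new IllegalArgumentException("toolInput cannot be null or empty");
        }
        I request = inputParser.apply(toolInput);
        O response = toolFunction.apply(request, toolContext);
        return resultConverter.apply(response);
    }
}

// Hypothetical request/response types for the usage example.
record WeatherQuery(String location) {}
record WeatherReport(String location, double temp) {}
```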

    Currently, the following types are not supported as function input or output types:

    • Primitive types
    • Optional
    • Collection types (e.g. List, Map, Array, Set)
    • Asynchronous types (e.g. CompletableFuture, Future)
    • Reactive types (e.g. Flow, Mono, Flux).

      Examples

      class DateTimeTools {
          String getCurrentDateTime() {
              return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
          }
      }
      

      MethodToolCallback

      Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
      ToolCallback toolCallback = MethodToolCallback.builder()
          .toolDefinition(ToolDefinition.builder(method)
                  .description("Get the current date and time in the user's timezone")
                  .build())
          .toolMethod(method)
          .toolObject(new DateTimeTools())
          .build();
      

      Alternatively, annotate the method with @Tool:

      class DateTimeTools {
          @Tool(description = "Get the current date and time in the user's timezone")
          String getCurrentDateTime() {
              return LocalDateTime.now().atZone(LocaleContextHolder.getTimeZone().toZoneId()).toString();
          }
      }
      

      Or build the callbacks with ToolCallbacks.from:

      ToolCallback[] dateTimeTools = ToolCallbacks.from(new DateTimeTools());
      

      FunctionToolCallback

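      public class WeatherService implements Function<WeatherRequest, WeatherResponse> {
          public WeatherResponse apply(WeatherRequest request) {
              return new WeatherResponse(30.0, Unit.C);
          }
      }
      // Request/response types used by WeatherService
      public enum Unit { C, F }
      public record WeatherRequest(String location, Unit unit) {}
      public record WeatherResponse(double temp, Unit unit) {}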
      ToolCallback toolCallback = FunctionToolCallback
          .builder("currentWeather", new WeatherService())
          .description("Get the weather in location")
          .inputType(WeatherRequest.class)
          .build();
      ChatClient.create(chatModel)
          .prompt("What's the weather like in Copenhagen?")
          .tools(toolCallback)
          .call()
          .content();    
      

      Or set it on the chat options:

      ChatOptions chatOptions = ToolCallingChatOptions.builder()
          .toolCallbacks(toolCallback)
          .build();
      Prompt prompt = new Prompt("What's the weather like in Copenhagen?", chatOptions);
      chatModel.call(prompt);
      

      Or register it as a Spring bean:

      @Configuration(proxyBeanMethods = false)
      class WeatherTools {
          WeatherService weatherService = new WeatherService();
          @Bean
          @Description("Get the weather in location")
          Function<WeatherRequest, WeatherResponse> currentWeather() {
              return weatherService;
          }
      }
      ChatClient.create(chatModel)
          .prompt("What's the weather like in Copenhagen?")
          .tools("currentWeather")
          .call()
          .content();
      

      Tool Specification

      ToolDefinition

      org/springframework/ai/tool/definition/ToolDefinition.java

      public interface ToolDefinition {
      	/**
      	 * The tool name. Unique within the tool set provided to a model.
      	 */
      	String name();
      	/**
      	 * The tool description, used by the AI model to determine what the tool does.
      	 */
      	String description();
      	/**
      	 * The schema of the parameters used to call the tool.
      	 */
      	String inputSchema();
      	/**
      	 * Create a default {@link ToolDefinition} builder.
      	 */
      	static DefaultToolDefinition.Builder builder() {
      		return DefaultToolDefinition.builder();
      	}
      	/**
      	 * Create a default {@link ToolDefinition} builder from a {@link Method}.
      	 */
      	static DefaultToolDefinition.Builder builder(Method method) {
      		Assert.notNull(method, "method cannot be null");
      		return DefaultToolDefinition.builder()
      			.name(ToolUtils.getToolName(method))
      			.description(ToolUtils.getToolDescription(method))
      			.inputSchema(JsonSchemaGenerator.generateForMethodInput(method));
      	}
      	/**
      	 * Create a default {@link ToolDefinition} instance from a {@link Method}.
      	 */
      	static ToolDefinition from(Method method) {
      		return ToolDefinition.builder(method).build();
      	}
      }
      

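      ToolDefinition defines the name, description and inputSchema properties. Its builder(Method) factory returns a DefaultToolDefinition builder pre-populated from a Method, and from(Method) builds the definition directly.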
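To illustrate how a definition can be derived from a `Method`, here is a standalone sketch in the spirit of `ToolDefinition.builder(Method)`. `ToolInfo`, `SimpleToolDefinition` and `ClockTools` are hypothetical. When no annotation value is given, the sketch falls back to the method name; this is a simplification and not necessarily what `ToolUtils` does. Schema generation is left as a placeholder.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;

// Hypothetical annotation playing the role of @Tool.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
@interface ToolInfo {
    String name() default "";
    String description() default "";
}

// Hypothetical, simplified counterpart of ToolDefinition: name, description and input schema.
record SimpleToolDefinition(String name, String description, String inputSchema) {

    // Explicit annotation values win; otherwise the method name is used (a simplification).
    static SimpleToolDefinition from(Method method) {
        ToolInfo info = method.getAnnotation(ToolInfo.class);
        String name = (info != null && !info.name().isEmpty()) ? info.name() : method.getName();
        String description = (info != null && !info.description().isEmpty())
                ? info.description()
                : method.getName();
        // Schema generation is out of scope here; an empty object schema is a placeholder.
        String inputSchema = "{\"type\":\"object\",\"properties\":{}}";
        return new SimpleToolDefinition(name, description, inputSchema);
    }
}

// Hypothetical sample tools: one annotated method and one relying on the fallbacks.
class ClockTools {
    @ToolInfo(name = "currentDateTime", description = "Get the current date and time")
    String getCurrentDateTime() {
        return java.time.LocalDateTime.now().toString();
    }

    String timeZone() {
        return java.time.ZoneId.systemDefault().getId();
    }
}
```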


      Example

      Method method = ReflectionUtils.findMethod(DateTimeTools.class, "getCurrentDateTime");
      ToolDefinition toolDefinition = ToolDefinition.builder(method)
          .name("currentDateTime")
          .description("Get the current date and time in the user's timezone")
          .inputSchema(JsonSchemaGenerator.generateForMethodInput(method))
          .build();
      

      JSON Schema

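      Spring AI provides JsonSchemaGenerator to generate the JSON schema for the input parameters of a method or function. Parameter descriptions can be supplied with any of the following annotations: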

      @ToolParam(description = "…") from Spring AI
      @JsonClassDescription("…") from Jackson
      @JsonPropertyDescription("…") from Jackson
      @Schema(description = "…") from Swagger
      

      Example

      import java.time.LocalDateTime;
      import java.time.format.DateTimeFormatter;
      import org.springframework.ai.tool.annotation.Tool;
      import org.springframework.ai.tool.annotation.ToolParam;
      import org.springframework.context.i18n.LocaleContextHolder;
      class DateTimeTools {
          @Tool(description = "Set a user alarm for the given time")
          void setAlarm(@ToolParam(description = "Time in ISO-8601 format") String time) {
              LocalDateTime alarmTime = LocalDateTime.parse(time, DateTimeFormatter.ISO_DATE_TIME);
              System.out.println("Alarm set for " + alarmTime);
          }
      }
      

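      Parameters are treated as required by default. To mark a parameter as optional, use any of the following annotations: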

      @ToolParam(required = false) from Spring AI
      @JsonProperty(required = false) from Jackson
      @Schema(required = false) from Swagger
      @Nullable from Spring Framework.
      

      Example:

      class CustomerTools {
          @Tool(description = "Update customer information")
          void updateCustomerInfo(Long id, String name, @ToolParam(required = false) String email) {
              System.out.println("Updated info for customer with id: " + id);
          }
      }
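The following standalone sketch shows how parameter annotations can feed a JSON schema:

- each parameter becomes a property;
- descriptions are copied into the property;
- parameters are listed under `required` unless marked optional.

`Param` and `SchemaSketch` are hypothetical names, and this is not `JsonSchemaGenerator`'s implementation. Unlike `@ToolParam`, `Param` carries an explicit name, so the sketch does not depend on compiling with `-parameters`. It also does no JSON escaping.

```java
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Method;
import java.lang.reflect.Parameter;
import java.util.ArrayList;
import java.util.List;

// Hypothetical parameter annotation loosely modeled on @ToolParam, plus an explicit name.
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.PARAMETER)
@interface Param {
    String name();
    String description() default "";
    boolean required() default true;
}

final class SchemaSketch {

    // Maps a few Java types to JSON Schema type names; everything else becomes "object".
    static String jsonType(Class<?> type) {
        if (type == String.class) return "string";
        if (type == int.class || type == Integer.class || type == long.class || type == Long.class) return "integer";
        if (type == double.class || type == Double.class || type == float.class || type == Float.class) return "number";
        if (type == boolean.class || type == Boolean.class) return "boolean";
        return "object";
    }

    // One property per parameter; a parameter is required unless annotated with required = false.
    static String forMethodInput(Method method) {
        List<String> properties = new ArrayList<>();
        List<String> required = new ArrayList<>();
        for (Parameter p : method.getParameters()) {
            Param param = p.getAnnotation(Param.class);
            String name = param != null ? param.name() : p.getName();
            StringBuilder prop = new StringBuilder("\"" + name + "\":{\"type\":\"" + jsonType(p.getType()) + "\"");
            if (param != null && !param.description().isEmpty()) {
                prop.append(",\"description\":\"").append(param.description()).append("\"");
            }
            prop.append("}");
            properties.add(prop.toString());
            if (param == null || param.required()) {
                required.add("\"" + name + "\"");
            }
        }
        return "{\"type\":\"object\",\"properties\":{" + String.join(",", properties)
                + "},\"required\":[" + String.join(",", required) + "]}";
    }
}

// Hypothetical sample mirroring updateCustomerInfo from the text.
class CustomerToolsSketch {
    void updateCustomerInfo(@Param(name = "id") Long id,
                            @Param(name = "name") String name,
                            @Param(name = "email", description = "Contact email", required = false) String email) {
    }
}
```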
      

      Result Conversion

      Spring AI provides ToolCallResultConverter to convert the data returned by a tool call before it is sent back to the AI model.

      org/springframework/ai/tool/execution/ToolCallResultConverter.java

      @FunctionalInterface
      public interface ToolCallResultConverter {
      	/**
      	 * Given an Object returned by a tool, convert it to a String compatible with the
      	 * given class type.
      	 */
      	String convert(@Nullable Object result, @Nullable Type returnType);
      }
      

      Its default implementation is DefaultToolCallResultConverter:

      public final class DefaultToolCallResultConverter implements ToolCallResultConverter {
      	private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallResultConverter.class);
      	@Override
      	public String convert(@Nullable Object result, @Nullable Type returnType) {
      		if (returnType == Void.TYPE) {
      			logger.debug("The tool has no return type. Converting to conventional response.");
      			return "Done";
      		}
      		else {
      			logger.debug("Converting tool result to JSON.");
      			return JsonParser.toJson(result);
      		}
      	}
      }
      

      DefaultToolCallResultConverter returns "Done" when the return type is void. Otherwise, it serializes the result to a JSON string with JsonParser.toJson(result).
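The conversion rule is small enough to model directly. `ResultConverterSketch` is a hypothetical name. Its tiny serializer stands in for `JsonParser.toJson` and covers only null, strings, numbers, booleans, maps and collections. Other objects fall back to a quoted `toString()`, which is not how Jackson would serialize a bean.

```java
import java.lang.reflect.Type;
import java.util.Collection;
import java.util.Map;
import java.util.stream.Collectors;

// Hypothetical model of DefaultToolCallResultConverter: "Done" for void, JSON text otherwise.
final class ResultConverterSketch {

    static String convert(Object result, Type returnType) {
        if (returnType == Void.TYPE) { // Void.TYPE is the Class object for primitive void
            return "Done";
        }
        return toJson(result);
    }

    // Minimal serializer standing in for JsonParser.toJson.
    static String toJson(Object value) {
        if (value == null) return "null";
        if (value instanceof String s) return quote(s);
        if (value instanceof Number || value instanceof Boolean) return value.toString();
        if (value instanceof Map<?, ?> map) {
            return map.entrySet().stream()
                    .map(e -> quote(String.valueOf(e.getKey())) + ":" + toJson(e.getValue()))
                    .collect(Collectors.joining(",", "{", "}"));
        }
        if (value instanceof Collection<?> items) {
            return items.stream().map(ResultConverterSketch::toJson)
                    .collect(Collectors.joining(",", "[", "]"));
        }
        return quote(value.toString()); // fallback only; real JSON libraries serialize bean properties
    }

    private static String quote(String s) {
        return "\"" + s.replace("\\", "\\\\").replace("\"", "\\\"") + "\"";
    }
}
```

FunctionToolCallback calls `convert(response, null)`. The void branch can therefore only trigger for method-based tools, whose declared return type is passed in.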

      You can also specify your own converter, for example:

      class CustomerTools {
          @Tool(description = "Retrieve customer information", resultConverter = CustomToolCallResultConverter.class)
          Customer getCustomerInfo(Long id) {
              return customerRepository.findById(id);
          }
      }
      

      Tool Context

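      Spring AI provides ToolContext for passing additional contextual information to tools. It lets developers supply extra, user-provided data that the tool can use during execution, alongside the tool arguments passed by the AI model. For example: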

      class CustomerTools {
          @Tool(description = "Retrieve customer information")
          Customer getCustomerInfo(Long id, ToolContext toolContext) {
              return customerRepository.findById(id, toolContext.getContext().get("tenantId"));
          }
      }
      

      With ChatClient:

      ChatModel chatModel = ...
      String response = ChatClient.create(chatModel)
              .prompt("Tell me more about the customer with ID 42")
              .tools(new CustomerTools())
              .toolContext(Map.of("tenantId", "acme"))
              .call()
              .content();
      System.out.println(response);
      

      With ChatModel:

      ChatModel chatModel = ...
      ToolCallback[] customerTools = ToolCallbacks.from(new CustomerTools());
      ChatOptions chatOptions = ToolCallingChatOptions.builder()
          .toolCallbacks(customerTools)
          .toolContext(Map.of("tenantId", "acme"))
          .build();
      Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);
      chatModel.call(prompt);
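Conceptually, the context travels beside the model's arguments rather than inside them: the model supplies the customer id, and the application supplies tenantId.

The following is a standalone sketch of that split. `ToolContextSketch` and `findCustomer` are hypothetical names, and a plain `Map` stands in for `ToolContext`.

```java
import java.util.Map;
import java.util.function.BiFunction;

// Hypothetical model of the tool-context path: the model supplies the arguments,
// the application supplies the context, and the framework hands both to the tool.
final class ToolContextSketch {

    // Hypothetical repository lookup keyed by tenant and id.
    static String findCustomer(long id, String tenantId) {
        return "customer " + id + " of tenant " + tenantId;
    }

    // The tool reads tenantId from the context map, not from the model's arguments.
    static final BiFunction<Long, Map<String, Object>, String> GET_CUSTOMER_INFO =
            (id, context) -> findCustomer(id, (String) context.get("tenantId"));

    // Stand-in for the framework: combines the model's argument with the request's context.
    static String invoke(Long modelArgument, Map<String, Object> requestContext) {
        return GET_CUSTOMER_INFO.apply(modelArgument, requestContext);
    }
}
```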
      

      Return Direct

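      Spring AI provides the returnDirect attribute. When it is set to true, the result of the tool call is returned directly to the caller instead of being sent back to the model. By default, the result is sent back to the AI model, which processes it before replying to the user.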

      For example:

      class CustomerTools {
          @Tool(description = "Retrieve customer information", returnDirect = true)
          Customer getCustomerInfo(Long id) {
              return customerRepository.findById(id);
          }
      }
      

      Or set it through ToolMetadata:

      ToolMetadata toolMetadata = ToolMetadata.builder()
          .returnDirect(true)
          .build();
      

      ToolCallingManager

      org/springframework/ai/model/tool/ToolCallingManager.java

      public interface ToolCallingManager {
      	/**
      	 * Resolve the tool definitions from the model's tool calling options.
      	 */
      	List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions);
      	/**
      	 * Execute the tool calls requested by the model.
      	 */
      	ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse);
      	/**
      	 * Create a default {@link ToolCallingManager} builder.
      	 */
      	static DefaultToolCallingManager.Builder builder() {
      		return DefaultToolCallingManager.builder();
      	}
      }
      

      ToolCallingManager defines the resolveToolDefinitions and executeToolCalls methods. Its default implementation is DefaultToolCallingManager.

      DefaultToolCallingManager

      org/springframework/ai/model/tool/DefaultToolCallingManager.java

      public class DefaultToolCallingManager implements ToolCallingManager {
      	private static final Logger logger = LoggerFactory.getLogger(DefaultToolCallingManager.class);
      	// @formatter:off
      	private static final ObservationRegistry DEFAULT_OBSERVATION_REGISTRY
      			= ObservationRegistry.NOOP;
      	private static final ToolCallbackResolver DEFAULT_TOOL_CALLBACK_RESOLVER
      			= new DelegatingToolCallbackResolver(List.of());
      	private static final ToolExecutionExceptionProcessor DEFAULT_TOOL_EXECUTION_EXCEPTION_PROCESSOR
      			= DefaultToolExecutionExceptionProcessor.builder().build();
      	// @formatter:on
      	private final ObservationRegistry observationRegistry;
      	private final ToolCallbackResolver toolCallbackResolver;
      	private final ToolExecutionExceptionProcessor toolExecutionExceptionProcessor;
      	public DefaultToolCallingManager(ObservationRegistry observationRegistry, ToolCallbackResolver toolCallbackResolver,
      			ToolExecutionExceptionProcessor toolExecutionExceptionProcessor) {
      		Assert.notNull(observationRegistry, "observationRegistry cannot be null");
      		Assert.notNull(toolCallbackResolver, "toolCallbackResolver cannot be null");
      		Assert.notNull(toolExecutionExceptionProcessor, "toolCallExceptionConverter cannot be null");
      		this.observationRegistry = observationRegistry;
      		this.toolCallbackResolver = toolCallbackResolver;
      		this.toolExecutionExceptionProcessor = toolExecutionExceptionProcessor;
      	}
      	@Override
      	public List<ToolDefinition> resolveToolDefinitions(ToolCallingChatOptions chatOptions) {
      		Assert.notNull(chatOptions, "chatOptions cannot be null");
      		List<FunctionCallback> toolCallbacks = new ArrayList<>(chatOptions.getToolCallbacks());
      		for (String toolName : chatOptions.getToolNames()) {
      			// Skip the tool if it is already present in the request toolCallbacks.
      			// That might happen if a tool is defined in the options
      			// both as a ToolCallback and as a tool name.
      			if (chatOptions.getToolCallbacks().stream().anyMatch(tool -> tool.getName().equals(toolName))) {
      				continue;
      			}
      			FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
      			if (toolCallback == null) {
      				throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
      			}
      			toolCallbacks.add(toolCallback);
      		}
      		return toolCallbacks.stream().map(functionCallback -> {
      			if (functionCallback instanceof ToolCallback toolCallback) {
      				return toolCallback.getToolDefinition();
      			}
      			else {
      				return ToolDefinition.builder()
      					.name(functionCallback.getName())
      					.description(functionCallback.getDescription())
      					.inputSchema(functionCallback.getInputTypeSchema())
      					.build();
      			}
      		}).toList();
      	}
      	@Override
      	public ToolExecutionResult executeToolCalls(Prompt prompt, ChatResponse chatResponse) {
      		Assert.notNull(prompt, "prompt cannot be null");
      		Assert.notNull(chatResponse, "chatResponse cannot be null");
      		Optional<Generation> toolCallGeneration = chatResponse.getResults()
      			.stream()
      			.filter(g -> !CollectionUtils.isEmpty(g.getOutput().getToolCalls()))
      			.findFirst();
      		if (toolCallGeneration.isEmpty()) {
      			throw new IllegalStateException("No tool call requested by the chat model");
      		}
      		AssistantMessage assistantMessage = toolCallGeneration.get().getOutput();
      		ToolContext toolContext = buildToolContext(prompt, assistantMessage);
      		InternalToolExecutionResult internalToolExecutionResult = executeToolCall(prompt, assistantMessage,
      				toolContext);
      		List<Message> conversationHistory = buildConversationHistoryAfterToolExecution(prompt.getInstructions(),
      				assistantMessage, internalToolExecutionResult.toolResponseMessage());
      		return ToolExecutionResult.builder()
      			.conversationHistory(conversationHistory)
      			.returnDirect(internalToolExecutionResult.returnDirect())
      			.build();
      	}
      	//......
      	/**
      	 * Execute the tool call and return the response message. To ensure backward
      	 * compatibility, both {@link ToolCallback} and {@link FunctionCallback} are
      	 * supported.
      	 */
      	private InternalToolExecutionResult executeToolCall(Prompt prompt, AssistantMessage assistantMessage,
      			ToolContext toolContext) {
      		List<FunctionCallback> toolCallbacks = List.of();
      		if (prompt.getOptions() instanceof ToolCallingChatOptions toolCallingChatOptions) {
      			toolCallbacks = toolCallingChatOptions.getToolCallbacks();
      		}
      		else if (prompt.getOptions() instanceof FunctionCallingOptions functionOptions) {
      			toolCallbacks = functionOptions.getFunctionCallbacks();
      		}
      		List<ToolResponseMessage.ToolResponse> toolResponses = new ArrayList<>();
      		Boolean returnDirect = null;
      		for (AssistantMessage.ToolCall toolCall : assistantMessage.getToolCalls()) {
      			logger.debug("Executing tool call: {}", toolCall.name());
      			String toolName = toolCall.name();
      			String toolInputArguments = toolCall.arguments();
      			FunctionCallback toolCallback = toolCallbacks.stream()
      				.filter(tool -> toolName.equals(tool.getName()))
      				.findFirst()
      				.orElseGet(() -> toolCallbackResolver.resolve(toolName));
      			if (toolCallback == null) {
      				throw new IllegalStateException("No ToolCallback found for tool name: " + toolName);
      			}
      			if (returnDirect == null && toolCallback instanceof ToolCallback callback) {
      				returnDirect = callback.getToolMetadata().returnDirect();
      			}
      			else if (toolCallback instanceof ToolCallback callback) {
      				returnDirect = returnDirect && callback.getToolMetadata().returnDirect();
      			}
      			else if (returnDirect == null) {
      				// This is a temporary solution to ensure backward compatibility with
      				// FunctionCallback.
      				// TODO: remove this block when FunctionCallback is removed.
      				returnDirect = false;
      			}
      			String toolResult;
      			try {
      				toolResult = toolCallback.call(toolInputArguments, toolContext);
      			}
      			catch (ToolExecutionException ex) {
      				toolResult = toolExecutionExceptionProcessor.process(ex);
      			}
      			toolResponses.add(new ToolResponseMessage.ToolResponse(toolCall.id(), toolName, toolResult));
      		}
      		return new InternalToolExecutionResult(new ToolResponseMessage(toolResponses, Map.of()), returnDirect);
      	}
      	private List<Message> buildConversationHistoryAfterToolExecution(List<Message> previousMessages,
      			AssistantMessage assistantMessage, ToolResponseMessage toolResponseMessage) {
      		List<Message> messages = new ArrayList<>(previousMessages);
      		messages.add(assistantMessage);
      		messages.add(toolResponseMessage);
      		return messages;
      	}	
      }	
      

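      In DefaultToolCallingManager, resolveToolDefinitions works as follows:

      1. It starts from chatOptions.getToolCallbacks().
      2. It uses toolCallbackResolver to resolve every name in chatOptions.getToolNames() that is not already among those callbacks.
      3. It converts each callback into a ToolDefinition.

      executeToolCalls works as follows:

      1. It picks the first generation whose AssistantMessage contains tool calls.
      2. It builds the ToolContext.
      3. It runs executeToolCall to obtain the execution result.
      4. It builds the conversation history from the previous messages, followed by the assistant message and the tool response message.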

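      executeToolCall iterates over assistantMessage.getToolCalls(). For each call:

      1. It looks the tool up by name among the prompt's tool callbacks, falling back to toolCallbackResolver.resolve(toolName).
      2. It obtains the result via toolCallback.call(toolInputArguments, toolContext).
      3. If a ToolExecutionException is thrown, it falls back to toolExecutionExceptionProcessor.process(ex).

      The returnDirect flag of the overall result is true only if every tool invoked in that round is a ToolCallback whose metadata has returnDirect set.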
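The control flow of executeToolCall can be modeled without Spring AI. `ToolLoopSketch` and its nested types are hypothetical stand-ins; only the flow follows the code above:

- tools attached to the request take precedence over the resolver;
- returnDirect is AND-ed across every tool in the round;
- a tool failure is turned into a response by the exception processor instead of aborting the loop.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical model of DefaultToolCallingManager.executeToolCall.
final class ToolLoopSketch {

    record ToolCall(String id, String name, String arguments) {}
    record ToolResponse(String id, String name, String result) {}
    record ExecutionResult(List<ToolResponse> responses, boolean returnDirect) {}

    // Hypothetical stand-in for ToolExecutionException.
    static class ToolFailure extends RuntimeException {
        ToolFailure(String message) { super(message); }
    }

    interface Tool {
        String name();
        boolean returnDirect();
        String call(String arguments, Map<String, Object> context);
    }

    // Small factory so callers can build tools from lambdas.
    static Tool tool(String name, boolean returnDirect, Function<String, String> body) {
        return new Tool() {
            @Override public String name() { return name; }
            @Override public boolean returnDirect() { return returnDirect; }
            @Override public String call(String arguments, Map<String, Object> context) {
                return body.apply(arguments);
            }
        };
    }

    static ExecutionResult execute(List<ToolCall> toolCalls, List<Tool> requestTools,
                                   Function<String, Tool> fallbackResolver,
                                   Function<ToolFailure, String> exceptionProcessor,
                                   Map<String, Object> context) {
        List<ToolResponse> responses = new ArrayList<>();
        Boolean returnDirect = null;
        for (ToolCall call : toolCalls) {
            // Tools attached to the request win; the resolver is only a fallback.
            Tool tool = requestTools.stream()
                    .filter(t -> t.name().equals(call.name()))
                    .findFirst()
                    .orElseGet(() -> fallbackResolver.apply(call.name()));
            if (tool == null) {
                throw new IllegalStateException("No ToolCallback found for tool name: " + call.name());
            }
            // returnDirect holds only if every tool called in this round asks for it.
            returnDirect = (returnDirect == null) ? tool.returnDirect() : returnDirect && tool.returnDirect();
            String result;
            try {
                result = tool.call(call.arguments(), context);
            }
            catch (ToolFailure ex) {
                result = exceptionProcessor.apply(ex); // the failure becomes text for the model
            }
            responses.add(new ToolResponse(call.id(), call.name(), result));
        }
        return new ExecutionResult(responses, Boolean.TRUE.equals(returnDirect));
    }
}
```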

      ToolExecutionExceptionProcessor

      org/springframework/ai/tool/execution/ToolExecutionExceptionProcessor.java

      @FunctionalInterface
      public interface ToolExecutionExceptionProcessor {
      	/**
      	 * Convert an exception thrown by a tool to a String that can be sent back to the AI
      	 * model or throw an exception to be handled by the caller.
      	 */
      	String process(ToolExecutionException exception);
      }
      

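      ToolExecutionExceptionProcessor defines a single process method. It either converts an exception thrown by a tool into a String that is sent back to the AI model, or throws an exception for the caller to handle.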

      DefaultToolExecutionExceptionProcessor

      public class DefaultToolExecutionExceptionProcessor implements ToolExecutionExceptionProcessor {
      	private final static Logger logger = LoggerFactory.getLogger(DefaultToolExecutionExceptionProcessor.class);
      	private static final boolean DEFAULT_ALWAYS_THROW = false;
      	private final boolean alwaysThrow;
      	public DefaultToolExecutionExceptionProcessor(boolean alwaysThrow) {
      		this.alwaysThrow = alwaysThrow;
      	}
      	@Override
      	public String process(ToolExecutionException exception) {
      		Assert.notNull(exception, "exception cannot be null");
      		if (alwaysThrow) {
      			throw exception;
      		}
      		logger.debug("Exception thrown by tool: {}. Message: {}", exception.getToolDefinition().name(),
      				exception.getMessage());
      		return exception.getMessage();
      	}
      	//......
      }	
      

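      DefaultToolExecutionExceptionProcessor rethrows the exception when alwaysThrow is true (the default is false). Otherwise, it returns the exception's message, which is sent back to the model.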
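The two branches of `process` can be shown with a standalone sketch. `ExceptionProcessorSketch` and `ToolExecutionFailure` are hypothetical names; the latter stands in for `ToolExecutionException`.

```java
// Hypothetical model of DefaultToolExecutionExceptionProcessor: either rethrow the
// tool failure or turn its message into the text that is sent back to the model.
final class ExceptionProcessorSketch {

    // Hypothetical stand-in for ToolExecutionException.
    static class ToolExecutionFailure extends RuntimeException {
        ToolExecutionFailure(String message) { super(message); }
    }

    private final boolean alwaysThrow;

    ExceptionProcessorSketch(boolean alwaysThrow) {
        this.alwaysThrow = alwaysThrow;
    }

    String process(ToolExecutionFailure exception) {
        if (exception == null) {
            throw new IllegalArgumentException("exception cannot be null");
        }
        if (alwaysThrow) {
            throw exception; // propagate to whoever called the chat model
        }
        return exception.getMessage(); // the model sees the error text and can react to it
    }
}
```

Returning the message lets the model recover, for example by retrying with different arguments. Rethrowing is the stricter choice when a tool failure should abort the whole request.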

      User-Controlled Tool Execution

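      ToolCallingChatOptions provides the internalToolExecutionEnabled property. Setting it to false lets you control the tool execution process yourself. You can also implement ToolExecutionEligibilityPredicate for that purpose. For example: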

      ChatModel chatModel = ...
      ToolCallingManager toolCallingManager = ToolCallingManager.builder().build();
      ChatOptions chatOptions = ToolCallingChatOptions.builder()
          .toolCallbacks(ToolCallbacks.from(new CustomerTools()))
          .internalToolExecutionEnabled(false)
          .build();
      Prompt prompt = new Prompt("Tell me more about the customer with ID 42", chatOptions);
      ChatResponse chatResponse = chatModel.call(prompt);
      while (chatResponse.hasToolCalls()) {
          ToolExecutionResult toolExecutionResult = toolCallingManager.executeToolCalls(prompt, chatResponse);
          prompt = new Prompt(toolExecutionResult.conversationHistory(), chatOptions);
          chatResponse = chatModel.call(prompt);
      }
      System.out.println(chatResponse.getResult().getOutput().getText());
      

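      Here the application calls toolCallingManager.executeToolCalls itself and passes the resulting conversation history back to chatModel. It repeats this until the model stops requesting tool calls.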
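The shape of this loop can be exercised without a real model. Every type in the sketch below is a hypothetical stand-in, not a Spring AI type. A scripted function plays the model: it requests a tool until it sees a tool result in the history, then answers.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.function.Function;

// Hypothetical model of the user-controlled loop: ask the model, run any requested
// tools yourself, append the results to the history, and ask again.
final class ManualLoopSketch {

    record Message(String role, String text) {}
    record ToolRequest(String name, String arguments) {}
    record Response(String text, List<ToolRequest> toolRequests) {
        boolean hasToolCalls() {
            return !toolRequests.isEmpty();
        }
    }

    // Scripted stand-in for chatModel.call: asks for the customer tool until a tool result is present.
    static Response scriptedModel(List<Message> history) {
        boolean sawToolResult = history.stream().anyMatch(m -> m.role().equals("tool"));
        return sawToolResult
                ? new Response("Customer 42 is ACME Corp.", List.of())
                : new Response("", List.of(new ToolRequest("getCustomerInfo", "{\"id\":42}")));
    }

    static String run(String userPrompt, Map<String, Function<String, String>> tools) {
        List<Message> history = new ArrayList<>(List.of(new Message("user", userPrompt)));
        Response response = scriptedModel(history);
        while (response.hasToolCalls()) {
            // Keep the model's tool-call turn, then add one tool message per executed call.
            history.add(new Message("assistant", response.text()));
            for (ToolRequest request : response.toolRequests()) {
                String result = tools.get(request.name()).apply(request.arguments());
                history.add(new Message("tool", result));
            }
            response = scriptedModel(history); // send the extended history back to the model
        }
        return response.text();
    }
}
```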

      ToolCallbackResolver

      spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/ToolCallbackResolver.java

      public interface ToolCallbackResolver {
      	/**
      	 * Resolve the {@link FunctionCallback} for the given tool name.
      	 */
      	@Nullable
      	FunctionCallback resolve(String toolName);
      }
      

      ToolCallbackResolver defines the resolve method, which returns the FunctionCallback for a given toolName, or null if none is found. It has three implementations: StaticToolCallbackResolver, SpringBeanToolCallbackResolver and DelegatingToolCallbackResolver.

      StaticToolCallbackResolver

      spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/StaticToolCallbackResolver.java

      public class StaticToolCallbackResolver implements ToolCallbackResolver {
      	private static final Logger logger = LoggerFactory.getLogger(StaticToolCallbackResolver.class);
      	private final Map<String, FunctionCallback> toolCallbacks = new HashMap<>();
      	public StaticToolCallbackResolver(List<FunctionCallback> toolCallbacks) {
      		Assert.notNull(toolCallbacks, "toolCallbacks cannot be null");
      		Assert.noNullElements(toolCallbacks, "toolCallbacks cannot contain null elements");
      		toolCallbacks.forEach(callback -> {
      			if (callback instanceof ToolCallback toolCallback) {
      				this.toolCallbacks.put(toolCallback.getToolDefinition().name(), toolCallback);
      			}
      			this.toolCallbacks.put(callback.getName(), callback);
      		});
      	}
      	@Override
      	public FunctionCallback resolve(String toolName) {
      		Assert.hasText(toolName, "toolName cannot be null or empty");
      		logger.debug("ToolCallback resolution attempt from static registry");
      		return toolCallbacks.get(toolName);
      	}
      }
      

      StaticToolCallbackResolver looks tools up in the list passed to its constructor, which it indexes by tool name.
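The static resolver boils down to a name-indexed map built once. Below is a standalone sketch with a hypothetical `NamedTool` record in place of `FunctionCallback`.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical model of StaticToolCallbackResolver: index a fixed list of tools
// by name once, then answer lookups from that map (null when the name is unknown).
final class StaticResolverSketch {

    record NamedTool(String name, String description) {}

    private final Map<String, NamedTool> toolsByName = new HashMap<>();

    StaticResolverSketch(List<NamedTool> tools) {
        for (NamedTool tool : tools) {
            if (tool == null) {
                throw new IllegalArgumentException("tools cannot contain null elements");
            }
            toolsByName.put(tool.name(), tool); // a later tool with the same name replaces an earlier one
        }
    }

    NamedTool resolve(String toolName) {
        if (toolName == null || toolName.isBlank()) {
            throw new IllegalArgumentException("toolName cannot be null or empty");
        }
        return toolsByName.get(toolName);
    }
}
```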

      SpringBeanToolCallbackResolver

      spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/SpringBeanToolCallbackResolver.java

      public class SpringBeanToolCallbackResolver implements ToolCallbackResolver {
      	private static final Logger logger = LoggerFactory.getLogger(SpringBeanToolCallbackResolver.class);
      	private static final Map<String, ToolCallback> toolCallbacksCache = new HashMap<>();
      	private static final SchemaType DEFAULT_SCHEMA_TYPE = SchemaType.JSON_SCHEMA;
      	private final GenericApplicationContext applicationContext;
      	private final SchemaType schemaType;
      	public SpringBeanToolCallbackResolver(GenericApplicationContext applicationContext,
      			@Nullable SchemaType schemaType) {
      		Assert.notNull(applicationContext, "applicationContext cannot be null");
      		this.applicationContext = applicationContext;
      		this.schemaType = schemaType != null ? schemaType : DEFAULT_SCHEMA_TYPE;
      	}
      	@Override
      	public ToolCallback resolve(String toolName) {
      		Assert.hasText(toolName, "toolName cannot be null or empty");
      		logger.debug("ToolCallback resolution attempt from Spring application context");
      		ToolCallback resolvedToolCallback = toolCallbacksCache.get(toolName);
      		if (resolvedToolCallback != null) {
      			return resolvedToolCallback;
      		}
      		ResolvableType toolType = TypeResolverHelper.resolveBeanType(applicationContext, toolName);
      		ResolvableType toolInputType = (ResolvableType.forType(Supplier.class).isAssignableFrom(toolType))
      				? ResolvableType.forType(Void.class) : TypeResolverHelper.getFunctionArgumentType(toolType, 0);
      		String toolDescription = resolveToolDescription(toolName, toolInputType.toClass());
      		Object bean = applicationContext.getBean(toolName);
      		resolvedToolCallback = buildToolCallback(toolName, toolType, toolInputType, toolDescription, bean);
      		toolCallbacksCache.put(toolName, resolvedToolCallback);
      		return resolvedToolCallback;
      	}
      	//......
      }	
      

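      SpringBeanToolCallbackResolver uses the GenericApplicationContext to look up the bean named toolName in the Spring container. It builds a ToolCallback from that bean and caches the result in toolCallbacksCache. Because the cache is a static HashMap, it is shared by all resolver instances.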
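The resolve-once-then-cache idea can be sketched independently of Spring. `CachingResolverSketch` is a hypothetical name, and the `beanLookup` function stands in for the application-context lookup plus callback construction. Using `ConcurrentHashMap.computeIfAbsent` is this sketch's own choice for thread safety; the code above uses a plain static `HashMap` with get/put.

```java
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

// Hypothetical resolve-once-then-cache lookup in the style of SpringBeanToolCallbackResolver.
final class CachingResolverSketch<T> {

    private final Map<String, T> cache = new ConcurrentHashMap<>();
    private final Function<String, T> beanLookup; // hypothetical stand-in for the context lookup
    private final AtomicInteger lookups = new AtomicInteger();

    CachingResolverSketch(Function<String, T> beanLookup) {
        this.beanLookup = beanLookup;
    }

    T resolve(String toolName) {
        // computeIfAbsent runs the expensive lookup at most once per name;
        // a null result is not cached, so a missing tool is looked up again next time.
        return cache.computeIfAbsent(toolName, name -> {
            lookups.incrementAndGet();
            return beanLookup.apply(name);
        });
    }

    int lookupCount() {
        return lookups.get();
    }
}
```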

      DelegatingToolCallbackResolver

      spring-ai-model/src/main/java/org/springframework/ai/tool/resolution/DelegatingToolCallbackResolver.java

      public class DelegatingToolCallbackResolver implements ToolCallbackResolver {
      	private final List<ToolCallbackResolver> toolCallbackResolvers;
      	public DelegatingToolCallbackResolver(List<ToolCallbackResolver> toolCallbackResolvers) {
      		Assert.notNull(toolCallbackResolvers, "toolCallbackResolvers cannot be null");
      		Assert.noNullElements(toolCallbackResolvers, "toolCallbackResolvers cannot contain null elements");
      		this.toolCallbackResolvers = toolCallbackResolvers;
      	}
      	@Override
      	@Nullable
      	public FunctionCallback resolve(String toolName) {
      		Assert.hasText(toolName, "toolName cannot be null or empty");
      		for (ToolCallbackResolver toolCallbackResolver : toolCallbackResolvers) {
      			FunctionCallback toolCallback = toolCallbackResolver.resolve(toolName);
      			if (toolCallback != null) {
      				return toolCallback;
      			}
      		}
      		return null;
      	}
      }
      

      DelegatingToolCallbackResolver delegates resolve to the toolCallbackResolvers passed to its constructor. It queries them in order and returns the first non-null result.
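A standalone sketch of the delegation order follows, with resolvers modeled as plain functions. `DelegatingResolverSketch` is a hypothetical name.

```java
import java.util.List;
import java.util.Objects;
import java.util.function.Function;

// Hypothetical model of DelegatingToolCallbackResolver: ask each resolver in order
// and return the first non-null answer.
final class DelegatingResolverSketch<T> {

    private final List<Function<String, T>> resolvers;

    DelegatingResolverSketch(List<Function<String, T>> resolvers) {
        // List.copyOf also rejects null elements, matching the noNullElements check.
        this.resolvers = List.copyOf(Objects.requireNonNull(resolvers));
    }

    T resolve(String toolName) {
        for (Function<String, T> resolver : resolvers) {
            T tool = resolver.apply(toolName);
            if (tool != null) {
                return tool; // earlier resolvers take precedence
            }
        }
        return null;
    }
}
```

DefaultToolCallingManager's DEFAULT_TOOL_CALLBACK_RESOLVER is a DelegatingToolCallbackResolver over an empty list, so on its own it resolves nothing.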

      Summary

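      Spring AI implements Tool Calling through ToolCallback, which extends the FunctionCallback interface (slated for deprecation). ToolCallback mainly defines the getToolDefinition, getToolMetadata and call methods. It has two basic implementations: MethodToolCallback and FunctionToolCallback.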

      Tool Specification covers Tool Callback, Tool Definition, JSON Schema, Result Conversion, Tool Context and Return Direct.

      Tool Execution covers Framework-Controlled Tool Execution, User-Controlled Tool Execution and Exception Handling.
