diff --git a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java index 2f1126abc..53fe948b4 100644 --- a/core/src/main/java/com/google/adk/flows/llmflows/Functions.java +++ b/core/src/main/java/com/google/adk/flows/llmflows/Functions.java @@ -46,6 +46,7 @@ import io.opentelemetry.context.Scope; import io.reactivex.rxjava3.core.Flowable; import io.reactivex.rxjava3.core.Maybe; +import io.reactivex.rxjava3.core.Observable; import io.reactivex.rxjava3.core.Single; import io.reactivex.rxjava3.disposables.Disposable; import io.reactivex.rxjava3.functions.Function; @@ -152,15 +153,16 @@ public static Maybe handleFunctionCalls( Function> functionCallMapper = getFunctionCallMapper(invocationContext, tools, toolConfirmations, false); - Flowable functionResponseEventsFlowable; + Observable functionResponseEventsObservable; if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { - functionResponseEventsFlowable = - Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper); + functionResponseEventsObservable = + Observable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper); } else { - functionResponseEventsFlowable = - Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper); + functionResponseEventsObservable = + Observable.fromIterable(functionCalls) + .concatMapEager(call -> functionCallMapper.apply(call).toObservable()); } - return functionResponseEventsFlowable + return functionResponseEventsObservable .toList() .flatMapMaybe( events -> { @@ -217,16 +219,17 @@ public static Maybe handleFunctionCallsLive( Function> functionCallMapper = getFunctionCallMapper(invocationContext, tools, toolConfirmations, true); - Flowable responseEventsFlowable; + Observable responseEventsObservable; if (invocationContext.runConfig().toolExecutionMode() == ToolExecutionMode.SEQUENTIAL) { - responseEventsFlowable = - Flowable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper); + responseEventsObservable = + Observable.fromIterable(functionCalls).concatMapMaybe(functionCallMapper); } else { - responseEventsFlowable = - Flowable.fromIterable(functionCalls).flatMapMaybe(functionCallMapper); + responseEventsObservable = + Observable.fromIterable(functionCalls) + .concatMapEager(call -> functionCallMapper.apply(call).toObservable()); } - return responseEventsFlowable + return responseEventsObservable .toList() .flatMapMaybe( events -> { diff --git a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java index 775fab5fc..97092f68c 100644 --- a/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java +++ b/core/src/test/java/com/google/adk/flows/llmflows/FunctionsTest.java @@ -23,6 +23,8 @@ import static org.junit.Assert.assertThrows; import com.google.adk.agents.InvocationContext; +import com.google.adk.agents.RunConfig; +import com.google.adk.agents.RunConfig.ToolExecutionMode; import com.google.adk.events.Event; import com.google.adk.testing.TestUtils; import com.google.common.collect.ImmutableList; @@ -151,8 +153,11 @@ public void handleFunctionCalls_singleFunctionCall() { } @Test - public void handleFunctionCalls_multipleFunctionCalls() { - InvocationContext invocationContext = createInvocationContext(createRootAgent()); + public void handleFunctionCalls_multipleFunctionCalls_parallel() { + InvocationContext invocationContext = + createInvocationContext( + createRootAgent(), + RunConfig.builder().setToolExecutionMode(ToolExecutionMode.PARALLEL).build()); ImmutableMap args1 = ImmutableMap.of("key1", "value2"); ImmutableMap args2 = ImmutableMap.of("key2", "value2"); Event event = @@ -201,7 +206,66 @@ public void handleFunctionCalls_multipleFunctionCalls() { .name("echo_tool") .response(ImmutableMap.of("result", args2)) .build()) - .build()); + .build()) + .inOrder(); + } + + @Test + public void handleFunctionCalls_multipleFunctionCalls_sequential() { + InvocationContext invocationContext = + createInvocationContext( + createRootAgent(), + RunConfig.builder().setToolExecutionMode(ToolExecutionMode.SEQUENTIAL).build()); + ImmutableMap args1 = ImmutableMap.of("key1", "value2"); + ImmutableMap args2 = ImmutableMap.of("key2", "value2"); + Event event = + createEvent("event").toBuilder() + .content( + Content.fromParts( + Part.fromText("..."), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("function_call_id1") + .name("echo_tool") + .args(args1) + .build()) + .build(), + Part.builder() + .functionCall( + FunctionCall.builder() + .id("function_call_id2") + .name("echo_tool") + .args(args2) + .build()) + .build())) + .build(); + + Event functionResponseEvent = + Functions.handleFunctionCalls( + invocationContext, event, ImmutableMap.of("echo_tool", new TestUtils.EchoTool())) + .blockingGet(); + + assertThat(functionResponseEvent).isNotNull(); + assertThat(functionResponseEvent.content().get().parts().get()) + .containsExactly( + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id("function_call_id1") + .name("echo_tool") + .response(ImmutableMap.of("result", args1)) + .build()) + .build(), + Part.builder() + .functionResponse( + FunctionResponse.builder() + .id("function_call_id2") + .name("echo_tool") + .response(ImmutableMap.of("result", args2)) + .build()) + .build()) + .inOrder(); } @Test