From a02a2a9df8b2f962d5adae1ef648e3cc1c876e69 Mon Sep 17 00:00:00 2001 From: Saurav Date: Mon, 10 Nov 2025 21:52:52 +0000 Subject: [PATCH 1/7] feat(xds): Add configuration objects for ExtAuthz and GrpcService This commit introduces configuration objects for the external authorization (ExtAuthz) filter and the gRPC service it uses. These classes provide a structured, immutable representation of the configuration defined in the xDS protobuf messages. The main new classes are: - `ExtAuthzConfig`: Represents the configuration for the `ExtAuthz` filter, including settings for the gRPC service, header mutation rules, and other filter behaviors. - `GrpcServiceConfig`: Represents the configuration for a gRPC service, including the target URI, credentials, and other settings. - `HeaderMutationRulesConfig`: Represents the configuration for header mutation rules. This commit also includes parsers to create these configuration objects from the corresponding protobuf messages, as well as unit tests for the new classes. --- .../xds/internal/extauthz/ExtAuthzConfig.java | 250 ++++++++++++++ .../extauthz/ExtAuthzParseException.java | 34 ++ .../grpcservice/GrpcServiceConfig.java | 308 ++++++++++++++++++ .../GrpcServiceConfigChannelFactory.java | 26 ++ .../GrpcServiceParseException.java | 33 ++ .../InsecureGrpcChannelFactory.java | 43 +++ .../HeaderMutationRulesConfig.java | 77 +++++ .../internal/extauthz/ExtAuthzConfigTest.java | 259 +++++++++++++++ .../grpcservice/GrpcServiceConfigTest.java | 243 ++++++++++++++ .../InsecureGrpcChannelFactoryTest.java | 57 ++++ .../HeaderMutationRulesConfigTest.java | 84 +++++ 11 files changed, 1414 insertions(+) create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigChannelFactory.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactory.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzConfigTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactoryTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java new file mode 100644 index 00000000000..e826f501d9c --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzConfig.java @@ -0,0 +1,250 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.grpc.Status; +import io.grpc.internal.GrpcUtil; +import io.grpc.xds.internal.MatcherParser; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig; +import io.grpc.xds.internal.grpcservice.GrpcServiceParseException; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesConfig; +import java.util.Optional; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; + +/** + * Represents the configuration for the external authorization (ext_authz) filter. This class + * encapsulates the settings defined in the + * {@link io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz} proto, providing a + * structured, immutable representation for use within gRPC. It includes configurations for the gRPC + * service used for authorization, header mutation rules, and other filter behaviors. + */ +@AutoValue +public abstract class ExtAuthzConfig { + + /** Creates a new builder for creating {@link ExtAuthzConfig} instances. */ + public static Builder builder() { + return new AutoValue_ExtAuthzConfig.Builder().allowedHeaders(ImmutableList.of()) + .disallowedHeaders(ImmutableList.of()).statusOnError(Status.PERMISSION_DENIED) + .filterEnabled(Matchers.FractionMatcher.create(100, 100)); + } + + /** + * Parses the {@link io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz} proto to + * create an {@link ExtAuthzConfig} instance. + * + * @param extAuthzProto The ext_authz proto to parse. + * @return An {@link ExtAuthzConfig} instance. + * @throws ExtAuthzParseException if the proto is invalid or contains unsupported features. + */ + public static ExtAuthzConfig fromProto(ExtAuthz extAuthzProto) throws ExtAuthzParseException { + if (!extAuthzProto.hasGrpcService()) { + throw new ExtAuthzParseException( + "unsupported ExtAuthz service type: only grpc_service is " + "supported"); + } + GrpcServiceConfig grpcServiceConfig; + try { + grpcServiceConfig = GrpcServiceConfig.fromProto(extAuthzProto.getGrpcService()); + } catch (GrpcServiceParseException e) { + throw new ExtAuthzParseException("Failed to parse GrpcService config: " + e.getMessage(), e); + } + Builder builder = builder().grpcService(grpcServiceConfig) + .failureModeAllow(extAuthzProto.getFailureModeAllow()) + .failureModeAllowHeaderAdd(extAuthzProto.getFailureModeAllowHeaderAdd()) + .includePeerCertificate(extAuthzProto.getIncludePeerCertificate()) + .denyAtDisable(extAuthzProto.getDenyAtDisable().getDefaultValue().getValue()); + + if (extAuthzProto.hasFilterEnabled()) { + builder.filterEnabled(parsePercent(extAuthzProto.getFilterEnabled().getDefaultValue())); + } + + if (extAuthzProto.hasStatusOnError()) { + builder.statusOnError( + GrpcUtil.httpStatusToGrpcStatus(extAuthzProto.getStatusOnError().getCodeValue())); + } + + if (extAuthzProto.hasAllowedHeaders()) { + builder.allowedHeaders(extAuthzProto.getAllowedHeaders().getPatternsList().stream() + .map(MatcherParser::parseStringMatcher).collect(ImmutableList.toImmutableList())); + } + + if (extAuthzProto.hasDisallowedHeaders()) { + builder.disallowedHeaders(extAuthzProto.getDisallowedHeaders().getPatternsList().stream() + .map(MatcherParser::parseStringMatcher).collect(ImmutableList.toImmutableList())); + } + + if (extAuthzProto.hasDecoderHeaderMutationRules()) { + builder.decoderHeaderMutationRules( + parseHeaderMutationRules(extAuthzProto.getDecoderHeaderMutationRules())); + } + + return builder.build(); + } + + /** + * The gRPC service configuration for the external authorization service. This is a required + * field. + * + * @see ExtAuthz#getGrpcService() + */ + public abstract GrpcServiceConfig grpcService(); + + /** + * Changes the filter's behavior on errors from the authorization service. If {@code true}, the + * filter will accept the request even if the authorization service fails or returns an error. + * + * @see ExtAuthz#getFailureModeAllow() + */ + public abstract boolean failureModeAllow(); + + /** + * Determines if the {@code x-envoy-auth-failure-mode-allowed} header is added to the request when + * {@link #failureModeAllow()} is true. + * + * @see ExtAuthz#getFailureModeAllowHeaderAdd() + */ + public abstract boolean failureModeAllowHeaderAdd(); + + /** + * Specifies if the peer certificate is sent to the external authorization service. + * + * @see ExtAuthz#getIncludePeerCertificate() + */ + public abstract boolean includePeerCertificate(); + + /** + * The gRPC status returned to the client when the authorization server returns an error or is + * unreachable. Defaults to {@code PERMISSION_DENIED}. + * + * @see io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz#getStatusOnError() + */ + public abstract Status statusOnError(); + + /** + * Specifies whether to deny requests when the filter is disabled. Defaults to {@code false}. + * + * @see ExtAuthz#getDenyAtDisable() + */ + public abstract boolean denyAtDisable(); + + /** + * The fraction of requests that will be checked by the authorization service. Defaults to all + * requests. + * + * @see ExtAuthz#getFilterEnabled() + */ + public abstract Matchers.FractionMatcher filterEnabled(); + + /** + * Specifies which request headers are sent to the authorization service. If not set, all headers + * are sent. + * + * @see ExtAuthz#getAllowedHeaders() + */ + public abstract ImmutableList allowedHeaders(); + + /** + * Specifies which request headers are not sent to the authorization service. This overrides + * {@link #allowedHeaders()}. + * + * @see ExtAuthz#getDisallowedHeaders() + */ + public abstract ImmutableList disallowedHeaders(); + + /** + * Rules for what modifications an ext_authz server may make to request headers. + * + * @see ExtAuthz#getDecoderHeaderMutationRules() + */ + public abstract Optional decoderHeaderMutationRules(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder grpcService(GrpcServiceConfig grpcService); + + public abstract Builder failureModeAllow(boolean failureModeAllow); + + public abstract Builder failureModeAllowHeaderAdd(boolean failureModeAllowHeaderAdd); + + public abstract Builder includePeerCertificate(boolean includePeerCertificate); + + public abstract Builder statusOnError(Status statusOnError); + + public abstract Builder denyAtDisable(boolean denyAtDisable); + + public abstract Builder filterEnabled(Matchers.FractionMatcher filterEnabled); + + public abstract Builder allowedHeaders(Iterable allowedHeaders); + + public abstract Builder disallowedHeaders(Iterable disallowedHeaders); + + public abstract Builder decoderHeaderMutationRules(HeaderMutationRulesConfig rules); + + public abstract ExtAuthzConfig build(); + } + + + private static Matchers.FractionMatcher parsePercent( + io.envoyproxy.envoy.type.v3.FractionalPercent proto) throws ExtAuthzParseException { + int denominator; + switch (proto.getDenominator()) { + case HUNDRED: + denominator = 100; + break; + case TEN_THOUSAND: + denominator = 10_000; + break; + case MILLION: + denominator = 1_000_000; + break; + case UNRECOGNIZED: + default: + throw new ExtAuthzParseException("Unknown denominator type: " + proto.getDenominator()); + } + return Matchers.FractionMatcher.create(proto.getNumerator(), denominator); + } + + private static HeaderMutationRulesConfig parseHeaderMutationRules(HeaderMutationRules proto) + throws ExtAuthzParseException { + HeaderMutationRulesConfig.Builder builder = HeaderMutationRulesConfig.builder(); + builder.disallowAll(proto.getDisallowAll().getValue()); + builder.disallowIsError(proto.getDisallowIsError().getValue()); + if (proto.hasAllowExpression()) { + builder.allowExpression( + parseRegex(proto.getAllowExpression().getRegex(), "allow_expression")); + } + if (proto.hasDisallowExpression()) { + builder.disallowExpression( + parseRegex(proto.getDisallowExpression().getRegex(), "disallow_expression")); + } + return builder.build(); + } + + private static Pattern parseRegex(String regex, String fieldName) throws ExtAuthzParseException { + try { + return Pattern.compile(regex); + } catch (PatternSyntaxException e) { + throw new ExtAuthzParseException( + "Invalid regex pattern for " + fieldName + ": " + e.getMessage(), e); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java new file mode 100644 index 00000000000..78edea5c305 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzParseException.java @@ -0,0 +1,34 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +/** + * A custom exception for signaling errors during the parsing of external authorization + * (ext_authz) configurations. + */ +public class ExtAuthzParseException extends Exception { + + private static final long serialVersionUID = 0L; + + public ExtAuthzParseException(String message) { + super(message); + } + + public ExtAuthzParseException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java new file mode 100644 index 00000000000..da9be978f87 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfig.java @@ -0,0 +1,308 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import com.google.auth.oauth2.AccessToken; +import com.google.auth.oauth2.OAuth2Credentials; +import com.google.auto.value.AutoValue; +import com.google.common.io.BaseEncoding; +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.xds.v3.XdsCredentials; +import io.grpc.CallCredentials; +import io.grpc.ChannelCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Metadata; +import io.grpc.alts.GoogleDefaultChannelCredentials; +import io.grpc.auth.MoreCallCredentials; +import io.grpc.xds.XdsChannelCredentials; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Date; +import java.util.List; +import java.util.Optional; + + +/** + * A Java representation of the {@link io.envoyproxy.envoy.config.core.v3.GrpcService} proto, + * designed for parsing and internal use within gRPC. This class encapsulates the configuration for + * a gRPC service, including target URI, credentials, and other settings. The parsing logic adheres + * to the specifications outlined in + * A102: xDS GrpcService Support. This class is immutable and uses the AutoValue library for its + * implementation. + */ +@AutoValue +public abstract class GrpcServiceConfig { + + public static Builder builder() { + return new AutoValue_GrpcServiceConfig.Builder(); + } + + /** + * Parses the {@link io.envoyproxy.envoy.config.core.v3.GrpcService} proto to create a + * {@link GrpcServiceConfig} instance. This method adheres to gRFC A102, which specifies that only + * the {@code google_grpc} target specifier is supported. Other fields like {@code timeout} and + * {@code initial_metadata} are also parsed as per the gRFC. + * + * @param grpcServiceProto The proto to parse. + * @return A {@link GrpcServiceConfig} instance. + * @throws GrpcServiceParseException if the proto is invalid or uses unsupported features. + */ + public static GrpcServiceConfig fromProto(GrpcService grpcServiceProto) + throws GrpcServiceParseException { + if (!grpcServiceProto.hasGoogleGrpc()) { + throw new GrpcServiceParseException( + "Unsupported: GrpcService must have GoogleGrpc, got: " + grpcServiceProto); + } + GoogleGrpcConfig googleGrpcConfig = + GoogleGrpcConfig.fromProto(grpcServiceProto.getGoogleGrpc()); + + Builder builder = GrpcServiceConfig.builder().googleGrpc(googleGrpcConfig); + + if (!grpcServiceProto.getInitialMetadataList().isEmpty()) { + Metadata initialMetadata = new Metadata(); + for (io.envoyproxy.envoy.config.core.v3.HeaderValue header : grpcServiceProto + .getInitialMetadataList()) { + String key = header.getKey(); + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + initialMetadata.put(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER), + BaseEncoding.base64().decode(header.getValue())); + } else { + initialMetadata.put(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER), + header.getValue()); + } + } + builder.initialMetadata(initialMetadata); + } + + if (grpcServiceProto.hasTimeout()) { + com.google.protobuf.Duration timeout = grpcServiceProto.getTimeout(); + builder.timeout(Duration.ofSeconds(timeout.getSeconds(), timeout.getNanos())); + } + return builder.build(); + } + + public abstract GoogleGrpcConfig googleGrpc(); + + public abstract Optional timeout(); + + public abstract Optional initialMetadata(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder googleGrpc(GoogleGrpcConfig googleGrpc); + + public abstract Builder timeout(Duration timeout); + + public abstract Builder initialMetadata(Metadata initialMetadata); + + public abstract GrpcServiceConfig build(); + } + + /** + * Represents the configuration for a Google gRPC service, as defined in the + * {@link io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc} proto. This class + * encapsulates settings specific to Google's gRPC implementation, such as target URI and + * credentials. The parsing of this configuration is guided by gRFC A102, which specifies how gRPC + * clients should interpret the GrpcService proto. + */ + @AutoValue + public abstract static class GoogleGrpcConfig { + + private static final String TLS_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "tls.v3.TlsCredentials"; + private static final String LOCAL_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "local.v3.LocalCredentials"; + private static final String XDS_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "xds.v3.XdsCredentials"; + private static final String INSECURE_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "insecure.v3.InsecureCredentials"; + private static final String GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL = + "type.googleapis.com/envoy.extensions.grpc_service.channel_credentials." + + "google_default.v3.GoogleDefaultCredentials"; + + public static Builder builder() { + return new AutoValue_GrpcServiceConfig_GoogleGrpcConfig.Builder(); + } + + /** + * Parses the {@link io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc} proto to create + * a {@link GoogleGrpcConfig} instance. + * + * @param googleGrpcProto The proto to parse. + * @return A {@link GoogleGrpcConfig} instance. + * @throws GrpcServiceParseException if the proto is invalid. + */ + public static GoogleGrpcConfig fromProto(GrpcService.GoogleGrpc googleGrpcProto) + throws GrpcServiceParseException { + + HashedChannelCredentials channelCreds = + extractChannelCredentials(googleGrpcProto.getChannelCredentialsPluginList()); + + CallCredentials callCreds = + extractCallCredentials(googleGrpcProto.getCallCredentialsPluginList()); + + return GoogleGrpcConfig.builder().target(googleGrpcProto.getTargetUri()) + .hashedChannelCredentials(channelCreds).callCredentials(callCreds).build(); + } + + public abstract String target(); + + public abstract HashedChannelCredentials hashedChannelCredentials(); + + public abstract CallCredentials callCredentials(); + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder target(String target); + + public abstract Builder hashedChannelCredentials(HashedChannelCredentials channelCredentials); + + public abstract Builder callCredentials(CallCredentials callCredentials); + + public abstract GoogleGrpcConfig build(); + } + + private static T getFirstSupported(List configs, Parser parser, + String configName) throws GrpcServiceParseException { + List errors = new ArrayList<>(); + for (U config : configs) { + try { + return parser.parse(config); + } catch (GrpcServiceParseException e) { + errors.add(e.getMessage()); + } + } + throw new GrpcServiceParseException( + "No valid supported " + configName + " found. Errors: " + errors); + } + + private static HashedChannelCredentials channelCredsFromProto(Any cred) + throws GrpcServiceParseException { + String typeUrl = cred.getTypeUrl(); + try { + switch (typeUrl) { + case GOOGLE_DEFAULT_CREDENTIALS_TYPE_URL: + return HashedChannelCredentials.of(GoogleDefaultChannelCredentials.create(), + cred.hashCode()); + case INSECURE_CREDENTIALS_TYPE_URL: + return HashedChannelCredentials.of(InsecureChannelCredentials.create(), + cred.hashCode()); + case XDS_CREDENTIALS_TYPE_URL: + XdsCredentials xdsConfig = cred.unpack(XdsCredentials.class); + HashedChannelCredentials fallbackCreds = + channelCredsFromProto(xdsConfig.getFallbackCredentials()); + return HashedChannelCredentials.of( + XdsChannelCredentials.create(fallbackCreds.channelCredentials()), cred.hashCode()); + case LOCAL_CREDENTIALS_TYPE_URL: + // TODO(sauravzg) : What's the java alternative to LocalCredentials. + throw new GrpcServiceParseException("LocalCredentials are not yet supported."); + case TLS_CREDENTIALS_TYPE_URL: + // TODO(sauravzg) : How to instantiate a TlsChannelCredentials from TlsCredentials + // proto? + throw new GrpcServiceParseException("TlsCredentials are not yet supported."); + default: + throw new GrpcServiceParseException("Unsupported channel credentials type: " + typeUrl); + } + } catch (InvalidProtocolBufferException e) { + // TODO(sauravzg): Add unit tests when we have a solution for TLS creds. + // This code is as of writing unreachable because all channel credential message + // types except TLS are empty messages. + throw new GrpcServiceParseException( + "Failed to parse channel credentials: " + e.getMessage()); + } + } + + private static CallCredentials callCredsFromProto(Any cred) throws GrpcServiceParseException { + try { + AccessTokenCredentials accessToken = cred.unpack(AccessTokenCredentials.class); + // TODO(sauravzg): Verify if the current behavior is per spec.The `AccessTokenCredentials` + // config doesn't have any timeout/refresh, so set the token to never expire. + return MoreCallCredentials.from(OAuth2Credentials + .create(new AccessToken(accessToken.getToken(), new Date(Long.MAX_VALUE)))); + } catch (InvalidProtocolBufferException e) { + throw new GrpcServiceParseException( + "Unsupported call credentials type: " + cred.getTypeUrl()); + } + } + + private static HashedChannelCredentials extractChannelCredentials( + List channelCredentialPlugins) throws GrpcServiceParseException { + return getFirstSupported(channelCredentialPlugins, GoogleGrpcConfig::channelCredsFromProto, + "channel_credentials"); + } + + private static CallCredentials extractCallCredentials(List callCredentialPlugins) + throws GrpcServiceParseException { + return getFirstSupported(callCredentialPlugins, GoogleGrpcConfig::callCredsFromProto, + "call_credentials"); + } + } + + /** + * A container for {@link ChannelCredentials} and a hash for the purpose of caching. + */ + @AutoValue + public abstract static class HashedChannelCredentials { + /** + * Creates a new {@link HashedChannelCredentials} instance. + * + * @param creds The channel credentials. + * @param hash The hash of the credentials. + * @return A new {@link HashedChannelCredentials} instance. + */ + public static HashedChannelCredentials of(ChannelCredentials creds, int hash) { + return new AutoValue_GrpcServiceConfig_HashedChannelCredentials(creds, hash); + } + + /** + * Returns the channel credentials. + */ + public abstract ChannelCredentials channelCredentials(); + + /** + * Returns the hash of the credentials. + */ + public abstract int hash(); + } + + /** + * Defines a generic interface for parsing a configuration of type {@code U} into a result of type + * {@code T}. This functional interface is used to abstract the parsing logic for different parts + * of the GrpcService configuration. + * + * @param The type of the object that will be returned after parsing. + * @param The type of the configuration object that will be parsed. + */ + private interface Parser { + + /** + * Parses the given configuration. + * + * @param config The configuration object to parse. + * @return The parsed object of type {@code T}. + * @throws GrpcServiceParseException if an error occurs during parsing. + */ + T parse(U config) throws GrpcServiceParseException; + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigChannelFactory.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigChannelFactory.java new file mode 100644 index 00000000000..0d02989eaa3 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigChannelFactory.java @@ -0,0 +1,26 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import io.grpc.ManagedChannel; + +/** + * A factory for creating {@link ManagedChannel}s from a {@link GrpcServiceConfig}. + */ +public interface GrpcServiceConfigChannelFactory { + ManagedChannel createChannel(GrpcServiceConfig config); +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java new file mode 100644 index 00000000000..319ad3d07e3 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/GrpcServiceParseException.java @@ -0,0 +1,33 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +/** + * Exception thrown when there is an error parsing the gRPC service config. + */ +public class GrpcServiceParseException extends Exception { + + private static final long serialVersionUID = 1L; + + public GrpcServiceParseException(String message) { + super(message); + } + + public GrpcServiceParseException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactory.java b/xds/src/main/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactory.java new file mode 100644 index 00000000000..d6325d43be4 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactory.java @@ -0,0 +1,43 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import io.grpc.Grpc; +import io.grpc.ManagedChannel; + +/** + * An insecure implementation of {@link GrpcServiceConfigChannelFactory} that creates a plaintext + * channel. This is a stub implementation for channel creation until the GrpcService trusted server + * implementation is completely implemented. + */ +public final class InsecureGrpcChannelFactory implements GrpcServiceConfigChannelFactory { + + private static final InsecureGrpcChannelFactory INSTANCE = new InsecureGrpcChannelFactory(); + + private InsecureGrpcChannelFactory() {} + + public static InsecureGrpcChannelFactory getInstance() { + return INSTANCE; + } + + @Override + public ManagedChannel createChannel(GrpcServiceConfig config) { + GrpcServiceConfig.GoogleGrpcConfig googleGrpc = config.googleGrpc(); + return Grpc.newChannelBuilder(googleGrpc.target(), + googleGrpc.hashedChannelCredentials().channelCredentials()).build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java new file mode 100644 index 00000000000..fd8048fdbd2 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfig.java @@ -0,0 +1,77 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.auto.value.AutoValue; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import java.util.Optional; +import java.util.regex.Pattern; + +/** + * Represents the configuration for header mutation rules, as defined in the + * {@link io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules} proto. + */ +@AutoValue +public abstract class HeaderMutationRulesConfig { + /** Creates a new builder for creating {@link HeaderMutationRulesConfig} instances. */ + public static Builder builder() { + return new AutoValue_HeaderMutationRulesConfig.Builder().disallowAll(false) + .disallowIsError(false); + } + + /** + * If set, allows any header that matches this regular expression. + * + * @see HeaderMutationRules#getAllowExpression() + */ + public abstract Optional allowExpression(); + + /** + * If set, disallows any header that matches this regular expression. + * + * @see HeaderMutationRules#getDisallowExpression() + */ + public abstract Optional disallowExpression(); + + /** + * If true, disallows all header mutations. + * + * @see HeaderMutationRules#getDisallowAll() + */ + public abstract boolean disallowAll(); + + /** + * If true, disallows any header mutation that would result in an invalid header value. + * + * @see HeaderMutationRules#getDisallowIsError() + */ + public abstract boolean disallowIsError(); + + + @AutoValue.Builder + public abstract static class Builder { + public abstract Builder allowExpression(Pattern matcher); + + public abstract Builder disallowExpression(Pattern matcher); + + public abstract Builder disallowAll(boolean disallowAll); + + public abstract Builder disallowIsError(boolean disallowIsError); + + public abstract HeaderMutationRulesConfig build(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzConfigTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzConfigTest.java new file mode 100644 index 00000000000..9b9a55b4079 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzConfigTest.java @@ -0,0 +1,259 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.fail; + +import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; +import io.envoyproxy.envoy.config.common.mutation_rules.v3.HeaderMutationRules; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.RuntimeFeatureFlag; +import io.envoyproxy.envoy.config.core.v3.RuntimeFractionalPercent; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.type.matcher.v3.ListStringMatcher; +import io.envoyproxy.envoy.type.matcher.v3.RegexMatcher; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.envoyproxy.envoy.type.v3.FractionalPercent; +import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; +import io.grpc.Status; +import io.grpc.xds.internal.Matchers; +import io.grpc.xds.internal.headermutations.HeaderMutationRulesConfig; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class ExtAuthzConfigTest { + + private static final Any GOOGLE_DEFAULT_CHANNEL_CREDS = + Any.pack(GoogleDefaultCredentials.newBuilder().build()); + private static final Any FAKE_ACCESS_TOKEN_CALL_CREDS = + Any.pack(AccessTokenCredentials.newBuilder().build()); + + private ExtAuthz.Builder extAuthzBuilder; + + @Before + public void setUp() { + extAuthzBuilder = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder() + .setGoogleGrpc(io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster") + .addChannelCredentialsPlugin(GOOGLE_DEFAULT_CHANNEL_CREDS) + .addCallCredentialsPlugin(FAKE_ACCESS_TOKEN_CALL_CREDS).build()) + .build()); + } + + @Test + public void fromProto_missingGrpcService_throws() { + ExtAuthz extAuthz = ExtAuthz.newBuilder().build(); + try { + ExtAuthzConfig.fromProto(extAuthz); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat() + .isEqualTo("unsupported ExtAuthz service type: only grpc_service is supported"); + } + } + + @Test + public void fromProto_invalidGrpcService_throws() { + ExtAuthz extAuthz = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder().build()) + .build(); + try { + ExtAuthzConfig.fromProto(extAuthz); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Failed to parse GrpcService config:"); + } + } + + @Test + public void fromProto_invalidAllowExpression_throws() { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("[invalid").build()).build()) + .build(); + try { + ExtAuthzConfig.fromProto(extAuthz); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Invalid regex pattern for allow_expression:"); + } + } + + @Test + public void fromProto_invalidDisallowExpression_throws() { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("[invalid").build()).build()) + .build(); + try { + ExtAuthzConfig.fromProto(extAuthz); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().startsWith("Invalid regex pattern for disallow_expression:"); + } + } + + @Test + public void fromProto_success() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setGrpcService(extAuthzBuilder.getGrpcServiceBuilder() + .setTimeout(com.google.protobuf.Duration.newBuilder().setSeconds(5).build()) + .addInitialMetadata(HeaderValue.newBuilder().setKey("key").setValue("value").build()) + .build()) + .setFailureModeAllow(true).setFailureModeAllowHeaderAdd(true) + .setIncludePeerCertificate(true) + .setStatusOnError( + io.envoyproxy.envoy.type.v3.HttpStatus.newBuilder().setCodeValue(403).build()) + .setDenyAtDisable( + RuntimeFeatureFlag.newBuilder().setDefaultValue(BoolValue.of(true)).build()) + .setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue(FractionalPercent.newBuilder().setNumerator(50) + .setDenominator(DenominatorType.TEN_THOUSAND).build()) + .build()) + .setAllowedHeaders(ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("allowed-header").build()).build()) + .setDisallowedHeaders(ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setPrefix("disallowed-").build()).build()) + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow.*").build()) + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow.*").build()) + .setDisallowAll(BoolValue.of(true)).setDisallowIsError(BoolValue.of(true)).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.grpcService().googleGrpc().target()).isEqualTo("test-cluster"); + assertThat(config.grpcService().timeout().get().getSeconds()).isEqualTo(5); + assertThat(config.grpcService().initialMetadata().isPresent()).isTrue(); + assertThat(config.failureModeAllow()).isTrue(); + assertThat(config.failureModeAllowHeaderAdd()).isTrue(); + assertThat(config.includePeerCertificate()).isTrue(); + assertThat(config.statusOnError().getCode()).isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(config.statusOnError().getDescription()).isEqualTo("HTTP status code 403"); + assertThat(config.denyAtDisable()).isTrue(); + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(50, 10_000)); + assertThat(config.allowedHeaders()).hasSize(1); + assertThat(config.allowedHeaders().get(0).matches("allowed-header")).isTrue(); + assertThat(config.disallowedHeaders()).hasSize(1); + assertThat(config.disallowedHeaders().get(0).matches("disallowed-foo")).isTrue(); + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().get().pattern()).isEqualTo("allow.*"); + assertThat(rules.disallowExpression().get().pattern()).isEqualTo("disallow.*"); + assertThat(rules.disallowAll()).isTrue(); + assertThat(rules.disallowIsError()).isTrue(); + } + + @Test + public void fromProto_saneDefaults() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder.build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.failureModeAllow()).isFalse(); + assertThat(config.failureModeAllowHeaderAdd()).isFalse(); + assertThat(config.includePeerCertificate()).isFalse(); + assertThat(config.statusOnError()).isEqualTo(Status.PERMISSION_DENIED); + assertThat(config.denyAtDisable()).isFalse(); + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(100, 100)); + assertThat(config.allowedHeaders()).isEmpty(); + assertThat(config.disallowedHeaders()).isEmpty(); + assertThat(config.decoderHeaderMutationRules().isPresent()).isFalse(); + } + + @Test + public void fromProto_headerMutationRules_allowExpressionOnly() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setAllowExpression(RegexMatcher.newBuilder().setRegex("allow.*").build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().get().pattern()).isEqualTo("allow.*"); + assertThat(rules.disallowExpression().isPresent()).isFalse(); + } + + @Test + public void fromProto_headerMutationRules_disallowExpressionOnly() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setDecoderHeaderMutationRules(HeaderMutationRules.newBuilder() + .setDisallowExpression(RegexMatcher.newBuilder().setRegex("disallow.*").build()) + .build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.decoderHeaderMutationRules().isPresent()).isTrue(); + HeaderMutationRulesConfig rules = config.decoderHeaderMutationRules().get(); + assertThat(rules.allowExpression().isPresent()).isFalse(); + assertThat(rules.disallowExpression().get().pattern()).isEqualTo("disallow.*"); + } + + @Test + public void fromProto_filterEnabled_hundred() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setFilterEnabled(RuntimeFractionalPercent.newBuilder().setDefaultValue(FractionalPercent + .newBuilder().setNumerator(25).setDenominator(DenominatorType.HUNDRED).build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.filterEnabled()).isEqualTo(Matchers.FractionMatcher.create(25, 100)); + } + + @Test + public void fromProto_filterEnabled_million() throws ExtAuthzParseException { + ExtAuthz extAuthz = extAuthzBuilder + .setFilterEnabled( + RuntimeFractionalPercent.newBuilder().setDefaultValue(FractionalPercent.newBuilder() + .setNumerator(123456).setDenominator(DenominatorType.MILLION).build()).build()) + .build(); + + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthz); + + assertThat(config.filterEnabled()) + .isEqualTo(Matchers.FractionMatcher.create(123456, 1_000_000)); + } + + @Test + public void fromProto_filterEnabled_unrecognizedDenominator() { + ExtAuthz extAuthz = extAuthzBuilder + .setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue( + FractionalPercent.newBuilder().setNumerator(1).setDenominatorValue(4).build()) + .build()) + .build(); + + try { + ExtAuthzConfig.fromProto(extAuthz); + fail("Expected ExtAuthzParseException"); + } catch (ExtAuthzParseException e) { + assertThat(e).hasMessageThat().isEqualTo("Unknown denominator type: UNRECOGNIZED"); + } + } +} \ No newline at end of file diff --git a/xds/src/test/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigTest.java b/xds/src/test/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigTest.java new file mode 100644 index 00000000000..7a506220973 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/grpcservice/GrpcServiceConfigTest.java @@ -0,0 +1,243 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; + +import com.google.common.io.BaseEncoding; +import com.google.protobuf.Any; +import com.google.protobuf.Duration; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.insecure.v3.InsecureCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.local.v3.LocalCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.xds.v3.XdsCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.Metadata; +import java.nio.charset.StandardCharsets; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class GrpcServiceConfigTest { + + @Test + public void fromProto_success() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + HeaderValue asciiHeader = + HeaderValue.newBuilder().setKey("test_key").setValue("test_value").build(); + HeaderValue binaryHeader = HeaderValue.newBuilder().setKey("test_key-bin") + .setValue( + BaseEncoding.base64().encode("test_value_binary".getBytes(StandardCharsets.UTF_8))) + .build(); + Duration timeout = Duration.newBuilder().setSeconds(10).build(); + GrpcService grpcService = + GrpcService.newBuilder().setGoogleGrpc(googleGrpc).addInitialMetadata(asciiHeader) + .addInitialMetadata(binaryHeader).setTimeout(timeout).build(); + + GrpcServiceConfig config = GrpcServiceConfig.fromProto(grpcService); + + // Assert target URI + assertThat(config.googleGrpc().target()).isEqualTo("test_uri"); + + // Assert channel credentials + assertThat(config.googleGrpc().hashedChannelCredentials().channelCredentials()) + .isInstanceOf(InsecureChannelCredentials.class); + assertThat(config.googleGrpc().hashedChannelCredentials().hash()) + .isEqualTo(insecureCreds.hashCode()); + + // Assert call credentials + assertThat(config.googleGrpc().callCredentials().getClass().getName()) + .isEqualTo("io.grpc.auth.GoogleAuthLibraryCallCredentials"); + + // Assert initial metadata + assertThat(config.initialMetadata().isPresent()).isTrue(); + assertThat(config.initialMetadata().get() + .get(Metadata.Key.of("test_key", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("test_value"); + assertThat(config.initialMetadata().get() + .get(Metadata.Key.of("test_key-bin", Metadata.BINARY_BYTE_MARSHALLER))) + .isEqualTo("test_value_binary".getBytes(StandardCharsets.UTF_8)); + + // Assert timeout + assertThat(config.timeout().isPresent()).isTrue(); + assertThat(config.timeout().get()).isEqualTo(java.time.Duration.ofSeconds(10)); + } + + @Test + public void fromProto_minimalSuccess_defaults() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = GrpcServiceConfig.fromProto(grpcService); + + assertThat(config.googleGrpc().target()).isEqualTo("test_uri"); + assertThat(config.initialMetadata().isPresent()).isFalse(); + assertThat(config.timeout().isPresent()).isFalse(); + } + + @Test + public void fromProto_missingGoogleGrpc() { + GrpcService grpcService = GrpcService.newBuilder().build(); + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat() + .startsWith("Unsupported: GrpcService must have GoogleGrpc, got: "); + } + + @Test + public void fromProto_emptyCallCredentials() { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat() + .isEqualTo("No valid supported call_credentials found. Errors: []"); + } + + @Test + public void fromProto_emptyChannelCredentials() { + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat() + .isEqualTo("No valid supported channel_credentials found. Errors: []"); + } + + @Test + public void fromProto_googleDefaultCredentials() throws GrpcServiceParseException { + Any googleDefaultCreds = Any.pack(GoogleDefaultCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(googleDefaultCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = GrpcServiceConfig.fromProto(grpcService); + + assertThat(config.googleGrpc().hashedChannelCredentials().channelCredentials()) + .isInstanceOf(io.grpc.CompositeChannelCredentials.class); + assertThat(config.googleGrpc().hashedChannelCredentials().hash()) + .isEqualTo(googleDefaultCreds.hashCode()); + } + + @Test + public void fromProto_localCredentials() throws GrpcServiceParseException { + Any localCreds = Any.pack(LocalCredentials.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(localCreds).addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat().contains("LocalCredentials are not yet supported."); + } + + @Test + public void fromProto_xdsCredentials_withInsecureFallback() throws GrpcServiceParseException { + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + XdsCredentials xdsCreds = + XdsCredentials.newBuilder().setFallbackCredentials(insecureCreds).build(); + Any xdsCredsAny = Any.pack(xdsCreds); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(xdsCredsAny).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceConfig config = GrpcServiceConfig.fromProto(grpcService); + + assertThat(config.googleGrpc().hashedChannelCredentials().channelCredentials()) + .isInstanceOf(io.grpc.ChannelCredentials.class); + assertThat(config.googleGrpc().hashedChannelCredentials().hash()) + .isEqualTo(xdsCredsAny.hashCode()); + } + + @Test + public void fromProto_tlsCredentials_notSupported() { + Any tlsCreds = Any + .pack(io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.tls.v3.TlsCredentials + .getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(tlsCreds).addCallCredentialsPlugin(accessTokenCreds).build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat().contains("TlsCredentials are not yet supported."); + } + + @Test + public void fromProto_invalidChannelCredentialsProto() { + // Pack a Duration proto, but try to unpack it as GoogleDefaultCredentials + Any invalidCreds = Any.pack(com.google.protobuf.Duration.getDefaultInstance()); + Any accessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("test_token").build()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(invalidCreds).addCallCredentialsPlugin(accessTokenCreds) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat() + .contains("No valid supported channel_credentials found. Errors: [Unsupported channel " + + "credentials type: type.googleapis.com/google.protobuf.Duration"); + } + + @Test + public void fromProto_invalidCallCredentialsProto() { + // Pack a Duration proto, but try to unpack it as AccessTokenCredentials + Any insecureCreds = Any.pack(InsecureCredentials.getDefaultInstance()); + Any invalidCallCredentials = Any.pack(Duration.getDefaultInstance()); + GrpcService.GoogleGrpc googleGrpc = GrpcService.GoogleGrpc.newBuilder().setTargetUri("test_uri") + .addChannelCredentialsPlugin(insecureCreds).addCallCredentialsPlugin(invalidCallCredentials) + .build(); + GrpcService grpcService = GrpcService.newBuilder().setGoogleGrpc(googleGrpc).build(); + + GrpcServiceParseException exception = assertThrows(GrpcServiceParseException.class, + () -> GrpcServiceConfig.fromProto(grpcService)); + assertThat(exception).hasMessageThat().contains("Unsupported call credentials type:"); + } +} + diff --git a/xds/src/test/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactoryTest.java b/xds/src/test/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactoryTest.java new file mode 100644 index 00000000000..8d7347f56c6 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/grpcservice/InsecureGrpcChannelFactoryTest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.grpcservice; + +import static org.junit.Assert.assertNotNull; + +import io.grpc.CallCredentials; +import io.grpc.InsecureChannelCredentials; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig.GoogleGrpcConfig; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig.HashedChannelCredentials; +import java.util.concurrent.Executor; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +/** Unit tests for {@link InsecureGrpcChannelFactory}. */ +@RunWith(JUnit4.class) +public class InsecureGrpcChannelFactoryTest { + + private static final class NoOpCallCredentials extends CallCredentials { + @Override + public void applyRequestMetadata(RequestInfo requestInfo, Executor appExecutor, + MetadataApplier applier) { + applier.apply(new Metadata()); + } + } + + @Test + public void testCreateChannel() { + InsecureGrpcChannelFactory factory = InsecureGrpcChannelFactory.getInstance(); + GrpcServiceConfig config = GrpcServiceConfig.builder() + .googleGrpc(GoogleGrpcConfig.builder().target("localhost:8080") + .hashedChannelCredentials( + HashedChannelCredentials.of(InsecureChannelCredentials.create(), 0)) + .callCredentials(new NoOpCallCredentials()).build()) + .build(); + ManagedChannel channel = factory.createChannel(config); + assertNotNull(channel); + channel.shutdownNow(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java new file mode 100644 index 00000000000..e2bda9cb836 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationRulesConfigTest.java @@ -0,0 +1,84 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +import java.util.regex.Pattern; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationRulesConfigTest { + @Test + public void testBuilderDefaultValues() { + HeaderMutationRulesConfig config = HeaderMutationRulesConfig.builder().build(); + assertFalse(config.disallowAll()); + assertFalse(config.disallowIsError()); + assertThat(config.allowExpression()).isEmpty(); + assertThat(config.disallowExpression()).isEmpty(); + } + + @Test + public void testBuilder_setDisallowAll() { + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowAll(true).build(); + assertTrue(config.disallowAll()); + } + + @Test + public void testBuilder_setDisallowIsError() { + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowIsError(true).build(); + assertTrue(config.disallowIsError()); + } + + @Test + public void testBuilder_setAllowExpression() { + Pattern pattern = Pattern.compile("allow.*"); + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().allowExpression(pattern).build(); + assertThat(config.allowExpression()).hasValue(pattern); + } + + @Test + public void testBuilder_setDisallowExpression() { + Pattern pattern = Pattern.compile("disallow.*"); + HeaderMutationRulesConfig config = + HeaderMutationRulesConfig.builder().disallowExpression(pattern).build(); + assertThat(config.disallowExpression()).hasValue(pattern); + } + + @Test + public void testBuilder_setAll() { + Pattern allowPattern = Pattern.compile("allow.*"); + Pattern disallowPattern = Pattern.compile("disallow.*"); + HeaderMutationRulesConfig config = HeaderMutationRulesConfig.builder() + .disallowAll(true) + .disallowIsError(true) + .allowExpression(allowPattern) + .disallowExpression(disallowPattern) + .build(); + assertTrue(config.disallowAll()); + assertTrue(config.disallowIsError()); + assertThat(config.allowExpression()).hasValue(allowPattern); + assertThat(config.disallowExpression()).hasValue(disallowPattern); + } +} From 2d08a46ad90eabf71cdf6620a4f17e1c93095beb Mon Sep 17 00:00:00 2001 From: Saurav Date: Mon, 10 Nov 2025 21:52:52 +0000 Subject: [PATCH 2/7] feat(xds): Implement request builder for external authorization This commit introduces the `CheckRequestBuilder` library, which is responsible for constructing the `CheckRequest` message sent to the external authorization service. The `CheckRequestBuilder` gathers information from various sources, including: - `ServerCall` attributes (local and remote addresses, SSL session). - `MethodDescriptor` (full method name). - Request headers. It uses this information to populate the `AttributeContext` of the `CheckRequest` message, which provides the authorization service with the necessary context to make an authorization decision. This commit also introduces the `ExtAuthzCertificateProvider`, a helper class for extracting certificate information, such as the principal and PEM-encoded certificate. Unit tests for the new components are also included. --- .../extauthz/CheckRequestBuilder.java | 316 ++++++++++++++++ .../extauthz/ExtAuthzCertificateProvider.java | 132 +++++++ .../extauthz/CheckRequestBuilderTest.java | 350 ++++++++++++++++++ .../ExtAuthzCertificateProviderTest.java | 140 +++++++ 4 files changed, 938 insertions(+) create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProvider.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProviderTest.java diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java new file mode 100644 index 00000000000..55234cd50dc --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckRequestBuilder.java @@ -0,0 +1,316 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.auto.value.AutoValue; +import com.google.common.io.BaseEncoding; +import com.google.protobuf.Timestamp; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.config.core.v3.SocketAddress; +import io.envoyproxy.envoy.service.auth.v3.AttributeContext; +import io.envoyproxy.envoy.service.auth.v3.CheckRequest; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.xds.internal.Matchers; +import java.io.UnsupportedEncodingException; +import java.net.InetSocketAddress; +import java.security.cert.Certificate; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.List; +import java.util.Locale; +import java.util.Optional; +import java.util.logging.Level; +import java.util.logging.Logger; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; + +/** + * Interface for building external authorization check requests. + */ +public interface CheckRequestBuilder { + + /** + * A factory for creating {@link CheckRequestBuilder} instances. + */ + @FunctionalInterface + interface Factory { + /** + * Creates a new instance of the CheckRequestBuilder. + * + * @param config The external authorization configuration. + * @param certificateProvider The provider for certificate information. + * @return A new CheckRequestBuilder instance. + */ + CheckRequestBuilder create(ExtAuthzConfig config, + ExtAuthzCertificateProvider certificateProvider); + } + + /** The default factory for creating {@link CheckRequestBuilder} instances. */ + Factory INSTANCE = CheckRequestBuilderImpl::new; + + /** + * Builds a CheckRequest for a server-side call. + * + * @param serverCall The server call. + * @param headers The request headers. + * @param requestTime The time of the request. + * @return A new CheckRequest. + */ + CheckRequest buildRequest(ServerCall serverCall, Metadata headers, Timestamp requestTime); + + /** + * Builds a CheckRequest for a client-side call. + * + * @param methodDescriptor The method descriptor of the call. + * @param headers The request headers. + * @param requestTime The time of the request. + * @return A new CheckRequest. + */ + CheckRequest buildRequest(MethodDescriptor methodDescriptor, Metadata headers, + Timestamp requestTime); + + /** + * Implementation of the CheckRequestBuilder interface. + */ + final class CheckRequestBuilderImpl implements CheckRequestBuilder { + private static final Logger logger = Logger.getLogger(CheckRequestBuilderImpl.class.getName()); + + private static final String METHOD = "POST"; + private static final String PROTOCOL = "HTTP/2"; + private static final long SIZE = -1; + + private final ExtAuthzConfig config; + private final ExtAuthzCertificateProvider certificateProvider; + + CheckRequestBuilderImpl(ExtAuthzConfig config, + ExtAuthzCertificateProvider certificateProvider) { + this.config = config; + this.certificateProvider = certificateProvider; + } + + @Override + public CheckRequest buildRequest(MethodDescriptor methodDescriptor, Metadata headers, + Timestamp requestTime) { + return build(CheckRequestParams.builder().setMethodDescriptor(methodDescriptor) + .setHeaders(headers).setRequestTime(requestTime).build()); + } + + @Override + public CheckRequest buildRequest(ServerCall serverCall, Metadata headers, + Timestamp requestTime) { + CheckRequestParams.Builder paramsBuilder = + CheckRequestParams.builder().setMethodDescriptor(serverCall.getMethodDescriptor()) + .setHeaders(headers).setRequestTime(requestTime); + java.net.SocketAddress localAddress = + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_LOCAL_ADDR); + if (localAddress != null) { + paramsBuilder.setLocalAddress(localAddress); + } + java.net.SocketAddress remoteAddress = + serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_REMOTE_ADDR); + if (remoteAddress != null) { + paramsBuilder.setRemoteAddress(remoteAddress); + } + SSLSession sslSession = serverCall.getAttributes().get(Grpc.TRANSPORT_ATTR_SSL_SESSION); + if (sslSession != null) { + paramsBuilder.setSslSession(sslSession); + } + return build(paramsBuilder.build()); + } + + private CheckRequest build(CheckRequestParams params) { + AttributeContext.Builder attrBuilder = AttributeContext.newBuilder(); + if (params.remoteAddress().isPresent()) { + attrBuilder.setSource(buildSource(params.remoteAddress().get(), params.sslSession())); + } + if (params.localAddress().isPresent()) { + attrBuilder + .setDestination(buildDestination(params.localAddress().get(), params.sslSession())); + } + attrBuilder.setRequest(buildAttributeRequest(params.headers(), + params.methodDescriptor().getFullMethodName(), params.requestTime())); + return CheckRequest.newBuilder().setAttributes(attrBuilder).build(); + } + + private AttributeContext.Peer buildSource(java.net.SocketAddress socketAddress, + Optional sslSession) { + AttributeContext.Peer.Builder peerBuilder = buildPeer(socketAddress).toBuilder(); + if (sslSession.isPresent()) { + try { + Certificate[] certs = sslSession.get().getPeerCertificates(); + if (certs != null && certs.length > 0 && certs[0] instanceof X509Certificate) { + X509Certificate cert = (X509Certificate) certs[0]; + peerBuilder.setPrincipal(certificateProvider.getPrincipal(cert)); + if (config.includePeerCertificate()) { + try { + peerBuilder.setCertificate(certificateProvider.getUrlPemEncodedCertificate(cert)); + } catch (UnsupportedEncodingException | CertificateEncodingException e) { + logger.log(Level.WARNING, + "Error encoding peer certificate. " + + "This is not expected, but if it happens, the certificate should not " + + "be set according to the spec.", + e); + } + } + } + } catch (SSLPeerUnverifiedException e) { + logger.log(Level.FINE, + "Peer is not authenticated. " + + "This is expected, principal and certificate should not be set " + + "according to the spec.", + e); + } + } + return peerBuilder.build(); + } + + private AttributeContext.Peer buildDestination(java.net.SocketAddress socketAddress, + Optional sslSession) { + AttributeContext.Peer.Builder peerBuilder = buildPeer(socketAddress).toBuilder(); + if (sslSession.isPresent()) { + Certificate[] certs = sslSession.get().getLocalCertificates(); + if (certs != null && certs.length > 0 && certs[0] instanceof X509Certificate) { + peerBuilder.setPrincipal(certificateProvider.getPrincipal((X509Certificate) certs[0])); + } + } + return peerBuilder.build(); + } + + private AttributeContext.Peer buildPeer(java.net.SocketAddress socketAddress) { + AttributeContext.Peer.Builder peerBuilder = AttributeContext.Peer.newBuilder(); + if (socketAddress instanceof InetSocketAddress) { + InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; + peerBuilder.setAddress(Address.newBuilder() + .setSocketAddress(SocketAddress.newBuilder() + .setAddress(inetSocketAddress.getAddress().getHostAddress()) + .setPortValue(inetSocketAddress.getPort())) + .build()); + } + return peerBuilder.build(); + } + + private AttributeContext.Request buildAttributeRequest(Metadata headers, String fullMethodName, + Timestamp requestTime) { + AttributeContext.Request.Builder reqBuilder = AttributeContext.Request.newBuilder(); + reqBuilder.setTime(requestTime); + AttributeContext.HttpRequest.Builder httpReqBuilder = + AttributeContext.HttpRequest.newBuilder(); + httpReqBuilder.setPath(fullMethodName); + httpReqBuilder.setMethod(METHOD); + httpReqBuilder.setProtocol(PROTOCOL); + httpReqBuilder.setSize(SIZE); + for (String key : headers.keys()) { + if (!isAllowed(key)) { + continue; + } + Optional value; + if (key.endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + value = getBinaryHeaderValue(headers, key); + } else { + value = getAsciiHeaderValue(headers, key); + } + value.ifPresent( + headerValue -> httpReqBuilder.putHeaders(key.toLowerCase(Locale.ROOT), headerValue)); + } + reqBuilder.setHttp(httpReqBuilder); + return reqBuilder.build(); + } + + private Optional getBinaryHeaderValue(Metadata headers, String key) { + Iterable binaryValues = + headers.getAll(Metadata.Key.of(key, Metadata.BINARY_BYTE_MARSHALLER)); + if (binaryValues == null) { + // Unreachable code, since we iterate over the keys. Exists for defensive programming. + return Optional.empty(); + } + List base64Values = new ArrayList<>(); + for (byte[] value : binaryValues) { + base64Values.add(BaseEncoding.base64().encode(value)); + } + return Optional.of(String.join(",", base64Values)); + } + + private Optional getAsciiHeaderValue(Metadata headers, String key) { + Iterable stringValues = + headers.getAll(Metadata.Key.of(key, Metadata.ASCII_STRING_MARSHALLER)); + if (stringValues == null) { + // Unreachable code, since we iterate over the keys. Exists for defensive programming. + return Optional.empty(); + } + return Optional.of(String.join(",", stringValues)); + } + + private boolean isAllowed(String header) { + for (Matchers.StringMatcher matcher : config.disallowedHeaders()) { + if (matcher.matches(header)) { + return false; + } + } + if (config.allowedHeaders().isEmpty()) { + return true; + } + for (Matchers.StringMatcher matcher : config.allowedHeaders()) { + if (matcher.matches(header)) { + return true; + } + } + return false; + } + + @AutoValue + abstract static class CheckRequestParams { + abstract Metadata headers(); + + abstract MethodDescriptor methodDescriptor(); + + abstract Timestamp requestTime(); + + abstract Optional localAddress(); + + abstract Optional remoteAddress(); + + abstract Optional sslSession(); + + static Builder builder() { + Builder builder = + new AutoValue_CheckRequestBuilder_CheckRequestBuilderImpl_CheckRequestParams.Builder(); + return builder; + } + + @AutoValue.Builder + abstract static class Builder { + abstract Builder setHeaders(Metadata headers); + + abstract Builder setMethodDescriptor(MethodDescriptor method); + + abstract Builder setRequestTime(Timestamp time); + + abstract Builder setLocalAddress(java.net.SocketAddress localAddress); + + abstract Builder setRemoteAddress(java.net.SocketAddress remoteAddress); + + abstract Builder setSslSession(SSLSession sslSession); + + abstract CheckRequestParams build(); + } + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProvider.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProvider.java new file mode 100644 index 00000000000..b4ec8dd8303 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProvider.java @@ -0,0 +1,132 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.common.io.BaseEncoding; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.Logger; + +/** + * An interface for providing certificate-related information. + */ +public interface ExtAuthzCertificateProvider { + /** + * Creates a new instance of the CertificateProvider. + * + * @return A new CertificateProvider instance. + */ + static ExtAuthzCertificateProvider create() { + return new DefaultCertificateProvider(); + } + + /** + * Gets the principal from a certificate. It returns the cert's first IP Address SAN if set, + * otherwise the cert's first DNS SAN if set, otherwise the subject field of the certificate in + * RFC 2253 format. + * + * @param cert The certificate. + * @return The principal. + */ + String getPrincipal(X509Certificate cert); + + /** + * Gets the URL PEM encoded certificate. It Pem encodes first and then urlencodes. + * + * @param cert The certificate. + * @return The URL PEM encoded certificate. + * @throws CertificateEncodingException If an error occurs while encoding the certificate. + * @throws UnsupportedEncodingException If an error occurs while encoding the URL. + */ + String getUrlPemEncodedCertificate(X509Certificate cert) + throws CertificateEncodingException, UnsupportedEncodingException; + + /** + * Default implementation of the CertificateProvider interface. + */ + final class DefaultCertificateProvider implements ExtAuthzCertificateProvider { + private static final Logger logger = + Logger.getLogger(DefaultCertificateProvider.class.getName()); + // From RFC 5280, section 4.2.1.6, Subject Alternative Name + // dNSName (2) + // iPAddress (7) + private static final int SAN_TYPE_DNS_NAME = 2; + private static final int SAN_TYPE_IP_ADDRESS = 7; + + @Override + public String getPrincipal(X509Certificate cert) { + try { + Collection> sans = cert.getSubjectAlternativeNames(); + if (sans != null) { + // Look for IP Address SAN. + for (List san : sans) { + if (san.size() == 2 && san.get(0) instanceof Integer + && (Integer) san.get(0) == SAN_TYPE_IP_ADDRESS) { + return (String) san.get(1); + } + } + // If no IP Address SAN, look for DNS SAN. + for (List san : sans) { + if (san.size() == 2 && san.get(0) instanceof Integer + && (Integer) san.get(0) == SAN_TYPE_DNS_NAME) { + return (String) san.get(1); + } + } + } + } catch (java.security.cert.CertificateParsingException e) { + logger.log(Level.WARNING, "Error parsing certificate SANs. " + "This is not expected," + + "falling back to the subject according to the spec.", e); + } + return cert.getSubjectX500Principal().getName(); + } + + @Override + public String getUrlPemEncodedCertificate(X509Certificate cert) + throws CertificateEncodingException, UnsupportedEncodingException { + String pemCert = CertPemConverter.toPem(cert); + return URLEncoder.encode(pemCert, StandardCharsets.UTF_8.toString()); + } + } + + /** + * A utility class for PEM encoding. + */ + final class CertPemConverter { + + private static final String X509_PEM_HEADER = "-----BEGIN CERTIFICATE-----\n"; + private static final String X509_PEM_FOOTER = "\n-----END CERTIFICATE-----\n"; + + private CertPemConverter() {} + + /** + * Converts a certificate to a PEM string. + * + * @param cert The certificate to convert. + * @return The PEM encoded certificate. + * @throws CertificateEncodingException If an error occurs while encoding the certificate. + */ + public static String toPem(X509Certificate cert) throws CertificateEncodingException { + return X509_PEM_HEADER + BaseEncoding.base64().encode(cert.getEncoded()) + X509_PEM_FOOTER; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java new file mode 100644 index 00000000000..1faa0062a04 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckRequestBuilderTest.java @@ -0,0 +1,350 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.protobuf.Any; +import com.google.protobuf.Timestamp; +import io.envoyproxy.envoy.config.core.v3.Address; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.service.auth.v3.AttributeContext; +import io.envoyproxy.envoy.service.auth.v3.CheckRequest; +import io.envoyproxy.envoy.type.matcher.v3.ListStringMatcher; +import io.envoyproxy.envoy.type.matcher.v3.StringMatcher; +import io.grpc.Attributes; +import io.grpc.Grpc; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.ServerCall; +import io.grpc.testing.TestMethodDescriptors; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; +import javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class CheckRequestBuilderTest { + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + @Mock + private ServerCall serverCall; + @Mock + private SSLSession sslSession; + @Mock + private ExtAuthzCertificateProvider certificateProvider; + + private CheckRequestBuilder checkRequestBuilder; + private MethodDescriptor methodDescriptor; + private Timestamp requestTime; + + @Before + public void setUp() throws ExtAuthzParseException { + ExtAuthzConfig config = buildExtAuthzConfig(); + checkRequestBuilder = CheckRequestBuilder.INSTANCE.create(config, certificateProvider); + methodDescriptor = TestMethodDescriptors.voidMethod(); + requestTime = Timestamp.newBuilder().setSeconds(12345).setNanos(67890).build(); + } + + @Test + public void buildRequest_forServer_happyPath() throws Exception { + // Setup for addresses + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + + // Setup for SSL and certificates + X509Certificate peerCert = mock(X509Certificate.class); + X509Certificate localCert = mock(X509Certificate.class); + Certificate[] peerCerts = new Certificate[] {peerCert}; + Certificate[] localCerts = new Certificate[] {localCert}; + when(sslSession.getPeerCertificates()).thenReturn(peerCerts); + when(sslSession.getLocalCertificates()).thenReturn(localCerts); + when(certificateProvider.getPrincipal(peerCert)).thenReturn("peer-principal"); + when(certificateProvider.getPrincipal(localCert)).thenReturn("local-principal"); + when(certificateProvider.getUrlPemEncodedCertificate(peerCert)).thenReturn("encoded-peer-cert"); + + // Setup for headers + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("allowed-header", Metadata.ASCII_STRING_MARSHALLER), "v1"); + headers.put(Metadata.Key.of("disallowed-header", Metadata.ASCII_STRING_MARSHALLER), "v2"); + headers.put(Metadata.Key.of("overridden-header", Metadata.ASCII_STRING_MARSHALLER), "v3"); + byte[] binaryValue = new byte[] {1, 2, 3}; + headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), binaryValue); + + // Configure CheckRequestBuilder to allow specific headers + ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("allowed-header").build()) + .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()).build(); + ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()) + .addPatterns(StringMatcher.newBuilder().setExact("overridden-header").build()).build(); + ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); + checkRequestBuilder = CheckRequestBuilder.INSTANCE.create(config, certificateProvider); + + // Setup server call attributes + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + // Build and verify the request + CheckRequest request = checkRequestBuilder.buildRequest(serverCall, headers, requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.getSource().getAddress().getSocketAddress().getAddress()) + .isEqualTo("192.168.1.1"); + assertThat(attrContext.getSource().getPrincipal()).isEqualTo("peer-principal"); + assertThat(attrContext.getSource().getCertificate()).isEqualTo("encoded-peer-cert"); + assertThat(attrContext.getDestination().getAddress().getSocketAddress().getAddress()) + .isEqualTo("10.0.0.2"); + assertThat(attrContext.getDestination().getPrincipal()).isEqualTo("local-principal"); + + AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); + assertThat(http.getHeadersMap()).containsEntry("allowed-header", "v1"); + assertThat(http.getHeadersMap()).doesNotContainKey("bin-header-bin"); + assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); + assertThat(http.getHeadersMap()).doesNotContainKey("overridden-header"); + } + + @Test + public void buildRequest_forServer_noTransportAttrs() { + when(serverCall.getAttributes()).thenReturn(Attributes.EMPTY); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + Metadata headers = new Metadata(); + + CheckRequest request = checkRequestBuilder.buildRequest(serverCall, headers, requestTime); + + assertThat(request.getAttributes().getRequest().getTime()).isEqualTo(requestTime); + assertThat(request.getAttributes().getRequest().getHttp().getPath()) + .isEqualTo(methodDescriptor.getFullMethodName()); + assertThat(request.getAttributes().getRequest().getHttp().getMethod()).isEqualTo("POST"); + assertThat(request.getAttributes().getRequest().getHttp().getProtocol()).isEqualTo("HTTP/2"); + assertThat(request.getAttributes().getRequest().getHttp().getSize()).isEqualTo(-1); + assertThat(request.getAttributes().getRequest().getHttp().getHeadersMap()).isEmpty(); + assertThat(request.getAttributes().hasSource()).isFalse(); + assertThat(request.getAttributes().hasDestination()).isFalse(); + } + + + @Test + public void buildRequest_forClient_happyPath_emptyAllowedHeaders() throws Exception { + // Setup for headers + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("some-header", Metadata.ASCII_STRING_MARSHALLER), "v1"); + headers.put(Metadata.Key.of("disallowed-header", Metadata.ASCII_STRING_MARSHALLER), "v2"); + byte[] binaryValue = new byte[] {1, 2, 3}; + headers.put(Metadata.Key.of("bin-header-bin", Metadata.BINARY_BYTE_MARSHALLER), binaryValue); + + // Configure CheckRequestBuilder with empty allowed headers + ListStringMatcher allowedHeaders = ListStringMatcher.newBuilder().build(); // empty + ListStringMatcher disallowedHeaders = ListStringMatcher.newBuilder() + .addPatterns(StringMatcher.newBuilder().setExact("disallowed-header").build()).build(); + ExtAuthzConfig config = buildExtAuthzConfig(allowedHeaders, disallowedHeaders, true); + checkRequestBuilder = CheckRequestBuilder.INSTANCE.create(config, certificateProvider); + + // Build and verify the request + CheckRequest request = checkRequestBuilder.buildRequest(methodDescriptor, headers, requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.hasSource()).isFalse(); + assertThat(attrContext.hasDestination()).isFalse(); + + AttributeContext.HttpRequest http = attrContext.getRequest().getHttp(); + assertThat(http.getPath()).isEqualTo(methodDescriptor.getFullMethodName()); + assertThat(http.getHeadersMap()).containsEntry("some-header", "v1"); + assertThat(http.getHeadersMap()).containsEntry("bin-header-bin", "AQID"); + assertThat(http.getHeadersMap()).doesNotContainKey("disallowed-header"); + } + + @Test + public void buildRequest_forServer_noSslSession() { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext attrContext = request.getAttributes(); + assertThat(attrContext.hasSource()).isTrue(); + Address sourceAddress = attrContext.getSource().getAddress(); + assertThat(sourceAddress.getSocketAddress().getAddress()).isEqualTo("192.168.1.1"); + assertThat(sourceAddress.getSocketAddress().getPortValue()).isEqualTo(12345); + assertThat(attrContext.getSource().getPrincipal()).isEmpty(); + + assertThat(attrContext.hasDestination()).isTrue(); + Address destAddress = attrContext.getDestination().getAddress(); + assertThat(destAddress.getSocketAddress().getAddress()).isEqualTo("10.0.0.2"); + assertThat(destAddress.getSocketAddress().getPortValue()).isEqualTo(443); + assertThat(attrContext.getDestination().getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_sslPeerUnverified() throws Exception { + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + when(sslSession.getPeerCertificates()).thenThrow(new SSLPeerUnverifiedException("unverified")); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + } + + @Test + public void buildRequest_forServer_includePeerCertFalse() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), + ListStringMatcher.newBuilder().build(), false); + checkRequestBuilder = CheckRequestBuilder.INSTANCE.create(config, certificateProvider); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + X509Certificate peerCert = mock(X509Certificate.class); + Certificate[] peerCerts = new Certificate[] {peerCert}; + + when(sslSession.getPeerCertificates()).thenReturn(peerCerts); + when(certificateProvider.getPrincipal(peerCert)).thenReturn("peer-principal"); + + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEqualTo("peer-principal"); + assertThat(source.getCertificate()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nullOrEmptyCertificates() throws Exception { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + + // Test with null certificates + when(sslSession.getPeerCertificates()).thenReturn(null); + when(sslSession.getLocalCertificates()).thenReturn(null); + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + AttributeContext.Peer destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + + // Test with empty certificates + when(sslSession.getPeerCertificates()).thenReturn(new Certificate[0]); + when(sslSession.getLocalCertificates()).thenReturn(new Certificate[0]); + request = checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + assertThat(source.getCertificate()).isEmpty(); + destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nonX509Certificate() throws Exception { + SocketAddress localAddress = new InetSocketAddress("10.0.0.2", 443); + SocketAddress remoteAddress = new InetSocketAddress("192.168.1.1", 12345); + Attributes attributes = + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_LOCAL_ADDR, localAddress) + .set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress) + .set(Grpc.TRANSPORT_ATTR_SSL_SESSION, sslSession).build(); + when(serverCall.getAttributes()).thenReturn(attributes); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + Certificate nonX509Cert = mock(Certificate.class); + Certificate[] certs = new Certificate[] {nonX509Cert}; + + when(sslSession.getPeerCertificates()).thenReturn(certs); + when(sslSession.getLocalCertificates()).thenReturn(certs); + + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + + AttributeContext.Peer source = request.getAttributes().getSource(); + assertThat(source.getPrincipal()).isEmpty(); + AttributeContext.Peer destination = request.getAttributes().getDestination(); + assertThat(destination.getPrincipal()).isEmpty(); + } + + @Test + public void buildRequest_forServer_nonInetSocketAddress() { + SocketAddress remoteAddress = mock(SocketAddress.class); + when(serverCall.getAttributes()).thenReturn( + Attributes.newBuilder().set(Grpc.TRANSPORT_ATTR_REMOTE_ADDR, remoteAddress).build()); + when(serverCall.getMethodDescriptor()).thenReturn(methodDescriptor); + CheckRequest request = + checkRequestBuilder.buildRequest(serverCall, new Metadata(), requestTime); + assertThat(request.getAttributes().getSource().hasAddress()).isFalse(); + } + + private ExtAuthzConfig buildExtAuthzConfig() throws ExtAuthzParseException { + return buildExtAuthzConfig(ListStringMatcher.newBuilder().build(), + ListStringMatcher.newBuilder().build(), true); + } + + private ExtAuthzConfig buildExtAuthzConfig(ListStringMatcher allowed, + ListStringMatcher disallowed, boolean includePeerCertificate) throws ExtAuthzParseException { + Any googleDefaultChannelCreds = Any.pack(GoogleDefaultCredentials.newBuilder().build()); + Any fakeAccessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); + ExtAuthz.Builder builder = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder() + .setGoogleGrpc(io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster").addChannelCredentialsPlugin(googleDefaultChannelCreds) + .addCallCredentialsPlugin(fakeAccessTokenCreds).build()) + .build()) + .setIncludePeerCertificate(includePeerCertificate).setAllowedHeaders(allowed) + .setDisallowedHeaders(disallowed); + return ExtAuthzConfig.fromProto(builder.build()); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProviderTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProviderTest.java new file mode 100644 index 00000000000..fdeff595d56 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzCertificateProviderTest.java @@ -0,0 +1,140 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import java.net.URLEncoder; +import java.nio.charset.StandardCharsets; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateParsingException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import javax.security.auth.x500.X500Principal; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + + + +@RunWith(JUnit4.class) +public class ExtAuthzCertificateProviderTest { + private final ExtAuthzCertificateProvider provider = ExtAuthzCertificateProvider.create(); + + @Test + public void getPrincipal_ipAddressSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + List ipSan = Arrays.asList(7, "192.168.1.1"); // SAN_TYPE_IP_ADDRESS + Collection> sans = Arrays.asList(ipSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("192.168.1.1"); + } + + @Test + public void getPrincipal_dnsSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + List san = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME + Collection> sans = Collections.singletonList(san); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); + } + + @Test + public void getPrincipal_noSan_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getSubjectAlternativeNames()).thenReturn(Collections.emptyList()); + X500Principal principal = new X500Principal("CN=testclient, O=gRPC authors"); + when(mockCert.getSubjectX500Principal()).thenReturn(principal); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("CN=testclient,O=gRPC authors"); + } + + @Test + public void getPrincipal_nullSans_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getSubjectAlternativeNames()).thenReturn(null); + X500Principal principal = new X500Principal("CN=testclient, O=gRPC authors"); + when(mockCert.getSubjectX500Principal()).thenReturn(principal); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("CN=testclient,O=gRPC authors"); + } + + @Test + public void getPrincipal_ipSanWrongSize_usesDnsSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + List ipSan = Collections.singletonList(7); // SAN_TYPE_IP_ADDRESS, wrong size + List dnsSan = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME + Collection> sans = Arrays.asList(ipSan, dnsSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); + } + + @Test + public void getPrincipal_ipSanWrongType_usesDnsSan() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + // SAN_TYPE_IP_ADDRESS, wrong type + List ipSan = Arrays.asList("not-an-integer", "192.168.1.1"); + List dnsSan = Arrays.asList(2, "foo.test.google.fr"); // SAN_TYPE_DNS_NAME + Collection> sans = Arrays.asList(ipSan, dnsSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("foo.test.google.fr"); + } + + @Test + public void getPrincipal_dnsSanWrongType_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + // Wrong SAN type for DNS check + List otherSan = Arrays.asList(6, "foo.test.google.fr"); // SAN_TYPE_URI + Collection> sans = Collections.singletonList(otherSan); + when(mockCert.getSubjectAlternativeNames()).thenReturn(sans); + when(mockCert.getSubjectX500Principal()).thenReturn(new X500Principal("CN=test")); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("CN=test"); + } + + @Test + public void getPrincipal_sanParsingException_usesSubject() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getSubjectAlternativeNames()).thenThrow(new CertificateParsingException()); + X500Principal principal = new X500Principal("CN=testclient, O=gRPC authors"); + when(mockCert.getSubjectX500Principal()).thenReturn(principal); + assertThat(provider.getPrincipal(mockCert)).isEqualTo("CN=testclient,O=gRPC authors"); + } + + @Test + public void getUrlPemEncodedCertificate() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + byte[] certData = "cert-data".getBytes(StandardCharsets.UTF_8); + when(mockCert.getEncoded()).thenReturn(certData); + + String pem = "-----BEGIN CERTIFICATE-----\n" + "Y2VydC1kYXRh" // base64 of "cert-data" + + "\n-----END CERTIFICATE-----\n"; + String urlEncodedPem = URLEncoder.encode(pem, StandardCharsets.UTF_8.toString()); + assertThat(provider.getUrlPemEncodedCertificate(mockCert)).isEqualTo(urlEncodedPem); + } + + @Test + public void getUrlPemEncodedCertificate_encodingException() throws Exception { + X509Certificate mockCert = mock(X509Certificate.class); + when(mockCert.getEncoded()).thenThrow(new CertificateEncodingException("test")); + assertThrows(CertificateEncodingException.class, + () -> provider.getUrlPemEncodedCertificate(mockCert)); + } +} From 86d68ac6a9954a69a5dac1949259142a62de891b Mon Sep 17 00:00:00 2001 From: Saurav Date: Fri, 24 Oct 2025 13:58:34 +0000 Subject: [PATCH 3/7] feat(xds): Add header mutations library This commit introduces a library for handling header mutations as specified by the xDS protocol. This library provides the core functionality for modifying request and response headers based on a set of rules. The main components of this library are: - `HeaderMutator`: Applies header mutations to `Metadata` objects. - `HeaderMutationFilter`: Filters header mutations based on a set of configurable rules, such as disallowing mutations of system headers. - `HeaderMutations`: A value class that represents the set of mutations to be applied to request and response headers. - `HeaderMutationDisallowedException`: An exception that is thrown when a disallowed header mutation is attempted. This commit also includes comprehensive unit tests for the new library. --- .../HeaderMutationDisallowedException.java | 32 ++ .../headermutations/HeaderMutationFilter.java | 172 ++++++++++ .../headermutations/HeaderMutations.java | 58 ++++ .../headermutations/HeaderMutator.java | 143 ++++++++ .../HeaderMutationFilterTest.java | 245 ++++++++++++++ .../headermutations/HeaderMutationsTest.java | 50 +++ .../headermutations/HeaderMutatorTest.java | 311 ++++++++++++++++++ 7 files changed, 1011 insertions(+) create mode 100644 xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java new file mode 100644 index 00000000000..b8d4eb582fb --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationDisallowedException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2024 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import io.grpc.Status; +import io.grpc.StatusException; + +/** + * Exception thrown when a header mutation is disallowed. + */ +public final class HeaderMutationDisallowedException extends StatusException { + + private static final long serialVersionUID = 1L; + + public HeaderMutationDisallowedException(String message) { + super(Status.INTERNAL.withDescription(message)); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java new file mode 100644 index 00000000000..0452354d823 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutationFilter.java @@ -0,0 +1,172 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableSet; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.grpc.xds.internal.headermutations.HeaderMutations.RequestHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import java.util.Collection; +import java.util.Locale; +import java.util.Optional; +import java.util.function.Predicate; + +/** + * The HeaderMutationFilter class is responsible for filtering header mutations based on a given set + * of rules. + */ +public interface HeaderMutationFilter { + + /** + * A factory for creating {@link HeaderMutationFilter} instances. + */ + @FunctionalInterface + interface Factory { + /** + * Creates a new instance of {@code HeaderMutationFilter}. + * + * @param mutationRules The rules for header mutations. If an empty {@code Optional} is + * provided, all header mutations are allowed by default, except for certain system + * headers. If a {@link HeaderMutationRulesConfig} is provided, mutations will be + * filtered based on the specified rules. + */ + HeaderMutationFilter create(Optional mutationRules); + } + + /** + * The default factory for creating {@link HeaderMutationFilter} instances. + */ + Factory INSTANCE = HeaderMutationFilterImpl::new; + + /** + * Filters the given header mutations based on the configured rules and returns the allowed + * mutations. + * + * @param mutations The header mutations to filter + * @return The allowed header mutations. + * @throws HeaderMutationDisallowedException if a disallowed mutation is encountered and the rules + * specify that this should be an error. + */ + HeaderMutations filter(HeaderMutations mutations) throws HeaderMutationDisallowedException; + + /** Default implementation of {@link HeaderMutationFilter}. */ + final class HeaderMutationFilterImpl implements HeaderMutationFilter { + private final Optional mutationRules; + + /** + * Set of HTTP/2 pseudo-headers and the host header that are critical for routing and protocol + * correctness. These headers cannot be mutated by user configuration. + */ + private static final ImmutableSet IMMUTABLE_HEADERS = + ImmutableSet.of("host", ":authority", ":scheme", ":method"); + + private HeaderMutationFilterImpl(Optional mutationRules) { // NOPMD + this.mutationRules = mutationRules; + } + + @Override + public HeaderMutations filter(HeaderMutations mutations) + throws HeaderMutationDisallowedException { + ImmutableList allowedRequestHeaders = + filterCollection(mutations.requestMutations().headers(), + header -> isHeaderMutationAllowed(header.getHeader().getKey()) + && !appendsSystemHeader(header)); + ImmutableList allowedRequestHeadersToRemove = + filterCollection(mutations.requestMutations().headersToRemove(), + header -> isHeaderMutationAllowed(header) && isHeaderRemovalAllowed(header)); + ImmutableList allowedResponseHeaders = + filterCollection(mutations.responseMutations().headers(), + header -> isHeaderMutationAllowed(header.getHeader().getKey()) + && !appendsSystemHeader(header)); + return HeaderMutations.create( + RequestHeaderMutations.create(allowedRequestHeaders, allowedRequestHeadersToRemove), + ResponseHeaderMutations.create(allowedResponseHeaders)); + } + + /** + * A generic helper to filter a collection based on a predicate. + * + * @param items The collection of items to filter. + * @param isAllowedPredicate The predicate to apply to each item. + * @param The type of items in the collection. + * @return An immutable list of allowed items. + * @throws HeaderMutationDisallowedException if an item is disallowed and disallowIsError is + * true. + */ + private ImmutableList filterCollection(Collection items, + Predicate isAllowedPredicate) throws HeaderMutationDisallowedException { + ImmutableList.Builder allowed = ImmutableList.builder(); + for (T item : items) { + if (isAllowedPredicate.test(item)) { + allowed.add(item); + } else if (disallowIsError()) { + throw new HeaderMutationDisallowedException( + "Header mutation disallowed for header: " + item); + } + } + return allowed.build(); + } + + private boolean isHeaderRemovalAllowed(String headerKey) { + return !isSystemHeaderKey(headerKey); + } + + private boolean appendsSystemHeader(HeaderValueOption headerValueOption) { + String key = headerValueOption.getHeader().getKey(); + boolean isAppend = headerValueOption + .getAppendAction() == HeaderValueOption.HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD; + return isAppend && isSystemHeaderKey(key); + } + + private boolean isSystemHeaderKey(String key) { + return key.startsWith(":") || key.toLowerCase(Locale.ROOT).equals("host"); + } + + private boolean isHeaderMutationAllowed(String headerName) { + String lowerCaseHeaderName = headerName.toLowerCase(Locale.ROOT); + if (IMMUTABLE_HEADERS.contains(lowerCaseHeaderName)) { + return false; + } + return mutationRules.map(rules -> isHeaderMutationAllowed(lowerCaseHeaderName, rules)) + .orElse(true); + } + + private boolean isHeaderMutationAllowed(String lowerCaseHeaderName, + HeaderMutationRulesConfig rules) { + // TODO(sauravzg): The priority is slightly unclear in the spec. + // Both `disallowAll` and `disallow_expression` take precedence over `all other + // settings`. + // `allow_expression` takes precedence over everything except `disallow_expression`. + // This is a conflict between ordering for `allow_expression` and `disallowAll`. + // Choosing to proceed with current envoy implementation which favors `allow_expression` over + // `disallowAll`. + if (rules.disallowExpression().isPresent() + && rules.disallowExpression().get().matcher(lowerCaseHeaderName).matches()) { + return false; + } + if (rules.allowExpression().isPresent()) { + return rules.allowExpression().get().matcher(lowerCaseHeaderName).matches(); + } + return !rules.disallowAll(); + } + + private boolean disallowIsError() { + return mutationRules.map(HeaderMutationRulesConfig::disallowIsError).orElse(false); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java new file mode 100644 index 00000000000..e0cb3daede3 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutations.java @@ -0,0 +1,58 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; + +/** A collection of header mutations for both request and response headers. */ +@AutoValue +public abstract class HeaderMutations { + + public static HeaderMutations create(RequestHeaderMutations requestMutations, + ResponseHeaderMutations responseMutations) { + return new AutoValue_HeaderMutations(requestMutations, responseMutations); + } + + public abstract RequestHeaderMutations requestMutations(); + + public abstract ResponseHeaderMutations responseMutations(); + + /** Represents mutations for request headers. */ + @AutoValue + public abstract static class RequestHeaderMutations { + public static RequestHeaderMutations create(ImmutableList headers, + ImmutableList headersToRemove) { + return new AutoValue_HeaderMutations_RequestHeaderMutations(headers, headersToRemove); + } + + public abstract ImmutableList headers(); + + public abstract ImmutableList headersToRemove(); + } + + /** Represents mutations for response headers. */ + @AutoValue + public abstract static class ResponseHeaderMutations { + public static ResponseHeaderMutations create(ImmutableList headers) { + return new AutoValue_HeaderMutations_ResponseHeaderMutations(headers); + } + + public abstract ImmutableList headers(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java new file mode 100644 index 00000000000..de5b946bbc7 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/headermutations/HeaderMutator.java @@ -0,0 +1,143 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import com.google.common.io.BaseEncoding; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption.HeaderAppendAction; +import io.grpc.Metadata; +import io.grpc.xds.internal.headermutations.HeaderMutations.RequestHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import java.nio.charset.StandardCharsets; +import java.util.logging.Logger; + +/** + * The HeaderMutator class is an implementation of the HeaderMutator interface. It provides methods + * to apply header mutations to a given set of headers based on a given set of rules. + */ +public interface HeaderMutator { + /** + * Creates a new instance of {@code HeaderMutator}. + */ + static HeaderMutator create() { + return new HeaderMutatorImpl(); + } + + /** + * Applies the given header mutations to the provided metadata headers. + * + * @param mutations The header mutations to apply. + * @param headers The metadata headers to which the mutations will be applied. + */ + void applyRequestMutations(RequestHeaderMutations mutations, Metadata headers); + + + /** + * Applies the given header mutations to the provided metadata headers. + * + * @param mutations The header mutations to apply. + * @param headers The metadata headers to which the mutations will be applied. + */ + void applyResponseMutations(ResponseHeaderMutations mutations, Metadata headers); + + /** Default implementation of {@link HeaderMutator}. */ + final class HeaderMutatorImpl implements HeaderMutator { + + private static final Logger logger = Logger.getLogger(HeaderMutatorImpl.class.getName()); + + @Override + public void applyRequestMutations(final RequestHeaderMutations mutations, Metadata headers) { + // TODO(sauravzg): The specification is not clear on order of header removals and additions. + // in case of conflicts. Copying the order from Envoy here, which does removals at the end. + applyHeaderUpdates(mutations.headers(), headers); + for (String headerToRemove : mutations.headersToRemove()) { + headers.discardAll(Metadata.Key.of(headerToRemove, Metadata.ASCII_STRING_MARSHALLER)); + } + } + + @Override + public void applyResponseMutations(final ResponseHeaderMutations mutations, Metadata headers) { + applyHeaderUpdates(mutations.headers(), headers); + } + + private void applyHeaderUpdates(final Iterable headerOptions, + Metadata headers) { + for (HeaderValueOption headerOption : headerOptions) { + HeaderValue headerValue = headerOption.getHeader(); + updateHeader(headerValue, headerOption.getAppendAction(), headers); + } + } + + private void updateHeader(final HeaderValue header, final HeaderAppendAction action, + Metadata mutableHeaders) { + if (header.getKey().endsWith(Metadata.BINARY_HEADER_SUFFIX)) { + updateHeader(action, Metadata.Key.of(header.getKey(), Metadata.BINARY_BYTE_MARSHALLER), + getBinaryHeaderValue(header), mutableHeaders); + } else { + updateHeader(action, Metadata.Key.of(header.getKey(), Metadata.ASCII_STRING_MARSHALLER), + getAsciiValue(header), mutableHeaders); + } + } + + private void updateHeader(final HeaderAppendAction action, final Metadata.Key key, + final T value, Metadata mutableHeaders) { + switch (action) { + case APPEND_IF_EXISTS_OR_ADD: + mutableHeaders.put(key, value); + break; + case ADD_IF_ABSENT: + if (!mutableHeaders.containsKey(key)) { + mutableHeaders.put(key, value); + } + break; + case OVERWRITE_IF_EXISTS_OR_ADD: + mutableHeaders.discardAll(key); + mutableHeaders.put(key, value); + break; + case OVERWRITE_IF_EXISTS: + if (mutableHeaders.containsKey(key)) { + mutableHeaders.discardAll(key); + mutableHeaders.put(key, value); + } + break; + case UNRECOGNIZED: + // Ignore invalid value + logger.warning("Unrecognized HeaderAppendAction: " + action); + break; + default: + // Should be unreachable unless there's a proto schema mismatch. + logger.warning("Unknown HeaderAppendAction: " + action); + } + } + + private byte[] getBinaryHeaderValue(HeaderValue header) { + return BaseEncoding.base64().decode(getAsciiValue(header)); + } + + private String getAsciiValue(HeaderValue header) { + // TODO(sauravzg): GRPC only supports base64 encoded binary headers, so we decode bytes to + // String using `StandardCharsets.US_ASCII`. + // Envoy's spec `raw_value` specification can contain non UTF-8 bytes, so this may potentially + // cause an exception or corruption. + if (!header.getRawValue().isEmpty()) { + return header.getRawValue().toString(StandardCharsets.US_ASCII); + } + return header.getValue(); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java new file mode 100644 index 00000000000..e73460924c7 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationFilterTest.java @@ -0,0 +1,245 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption.HeaderAppendAction; +import io.grpc.xds.internal.headermutations.HeaderMutations.RequestHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import java.util.Optional; +import java.util.regex.Pattern; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutationFilterTest { + + private static HeaderValueOption header(String key, String value) { + return HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(key).setValue(value)).build(); + } + + private static HeaderValueOption header(String key, String value, HeaderAppendAction action) { + return HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(key).setValue(value)).setAppendAction(action) + .build(); + } + + @Test + public void filter_removesImmutableHeaders() throws HeaderMutationDisallowedException { + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.empty()); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of(header("add-key", "add-value"), header(":authority", "new-authority"), + header("host", "new-host"), header(":scheme", "https"), header(":method", "PUT")), + ImmutableList.of("remove-key", "host", ":authority", ":scheme", ":method")), + ResponseHeaderMutations.create(ImmutableList.of(header("resp-add-key", "resp-add-value"), + header(":scheme", "https")))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()) + .containsExactly(header("add-key", "add-value")); + assertThat(filtered.requestMutations().headersToRemove()).containsExactly("remove-key"); + assertThat(filtered.responseMutations().headers()) + .containsExactly(header("resp-add-key", "resp-add-value")); + } + + @Test + public void filter_cannotAppendToSystemHeaders() throws HeaderMutationDisallowedException { + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.empty()); + HeaderMutations mutations = + HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of( + header("add-key", "add-value", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(":authority", "new-authority", + HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header("host", "new-host", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(":path", "/new-path", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD)), + ImmutableList.of()), + ResponseHeaderMutations.create(ImmutableList + .of(header("host", "new-host", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD)))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()).containsExactly( + header("add-key", "add-value", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD)); + assertThat(filtered.responseMutations().headers()).isEmpty(); + } + + @Test + public void filter_cannotRemoveSystemHeaders() throws HeaderMutationDisallowedException { + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.empty()); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create(ImmutableList.of(), + ImmutableList.of("remove-key", "host", ":foo", ":bar")), + ResponseHeaderMutations.create(ImmutableList.of())); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headersToRemove()).containsExactly("remove-key"); + } + + @Test + public void filter_canOverrideSystemHeadersNotInImmutableHeaders() + throws HeaderMutationDisallowedException { + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.empty()); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of(header("user-agent", "new-agent"), + header(":path", "/new/path", HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header(":grpc-trace-bin", "binary-value", HeaderAppendAction.ADD_IF_ABSENT)), + ImmutableList.of()), + ResponseHeaderMutations.create(ImmutableList + .of(header(":alt-svc", "h3=:443", HeaderAppendAction.OVERWRITE_IF_EXISTS)))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()).containsExactly( + header("user-agent", "new-agent"), + header(":path", "/new/path", HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header(":grpc-trace-bin", "binary-value", HeaderAppendAction.ADD_IF_ABSENT)); + assertThat(filtered.responseMutations().headers()) + .containsExactly(header(":alt-svc", "h3=:443", HeaderAppendAction.OVERWRITE_IF_EXISTS)); + } + + @Test + public void filter_disallowAll_disablesAllModifications() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder().disallowAll(true).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create(ImmutableList.of(header("add-key", "add-value")), + ImmutableList.of("remove-key")), + ResponseHeaderMutations.create(ImmutableList.of(header("resp-add-key", "resp-add-value")))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()).isEmpty(); + assertThat(filtered.requestMutations().headersToRemove()).isEmpty(); + assertThat(filtered.responseMutations().headers()).isEmpty(); + } + + @Test + public void filter_disallowExpression_filtersRelevantExpressions() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder() + .disallowExpression(Pattern.compile("^x-private-.*")).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of(header("x-public", "value"), header("x-private-key", "value")), + ImmutableList.of("x-public-remove", "x-private-remove")), + ResponseHeaderMutations.create( + ImmutableList.of(header("x-public-resp", "value"), header("x-private-resp", "value")))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()).containsExactly(header("x-public", "value")); + assertThat(filtered.requestMutations().headersToRemove()).containsExactly("x-public-remove"); + assertThat(filtered.responseMutations().headers()) + .containsExactly(header("x-public-resp", "value")); + } + + @Test + public void filter_allowExpression_onlyAllowsRelevantExpressions() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder() + .allowExpression(Pattern.compile("^x-allowed-.*")).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = + HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of(header("x-allowed-key", "value"), + header("not-allowed-key", "value")), + ImmutableList.of("x-allowed-remove", "not-allowed-remove")), + ResponseHeaderMutations.create(ImmutableList.of(header("x-allowed-resp-key", "value"), + header("not-allowed-resp-key", "value")))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()) + .containsExactly(header("x-allowed-key", "value")); + assertThat(filtered.requestMutations().headersToRemove()).containsExactly("x-allowed-remove"); + assertThat(filtered.responseMutations().headers()) + .containsExactly(header("x-allowed-resp-key", "value")); + } + + @Test + public void filter_allowExpression_overridesDisallowAll() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = HeaderMutationRulesConfig.builder().disallowAll(true) + .allowExpression(Pattern.compile("^x-allowed-.*")).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create( + ImmutableList.of(header("x-allowed-key", "value"), header("not-allowed", "value")), + ImmutableList.of("x-allowed-remove", "not-allowed-remove")), + ResponseHeaderMutations.create(ImmutableList.of(header("x-allowed-resp-key", "value"), + header("not-allowed-resp-key", "value")))); + + HeaderMutations filtered = filter.filter(mutations); + + assertThat(filtered.requestMutations().headers()) + .containsExactly(header("x-allowed-key", "value")); + assertThat(filtered.requestMutations().headersToRemove()).containsExactly("x-allowed-remove"); + assertThat(filtered.responseMutations().headers()) + .containsExactly(header("x-allowed-resp-key", "value")); + } + + @Test(expected = HeaderMutationDisallowedException.class) + public void filter_disallowIsError_throwsExceptionOnDisallowed() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = + HeaderMutationRulesConfig.builder().disallowAll(true).disallowIsError(true).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create(RequestHeaderMutations + .create(ImmutableList.of(header("add-key", "add-value")), ImmutableList.of()), + ResponseHeaderMutations.create(ImmutableList.of())); + filter.filter(mutations); + } + + @Test(expected = HeaderMutationDisallowedException.class) + public void filter_disallowIsError_throwsExceptionOnDisallowedRemove() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = + HeaderMutationRulesConfig.builder().disallowAll(true).disallowIsError(true).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create(ImmutableList.of(), ImmutableList.of("remove-key")), + ResponseHeaderMutations.create(ImmutableList.of())); + filter.filter(mutations); + } + + @Test(expected = HeaderMutationDisallowedException.class) + public void filter_disallowIsError_throwsExceptionOnDisallowedResponseHeader() + throws HeaderMutationDisallowedException { + HeaderMutationRulesConfig rules = + HeaderMutationRulesConfig.builder().disallowAll(true).disallowIsError(true).build(); + HeaderMutationFilter filter = HeaderMutationFilter.INSTANCE.create(Optional.of(rules)); + HeaderMutations mutations = HeaderMutations.create( + RequestHeaderMutations.create(ImmutableList.of(), ImmutableList.of()), + ResponseHeaderMutations.create(ImmutableList.of(header("resp-add-key", "resp-add-value")))); + filter.filter(mutations); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java new file mode 100644 index 00000000000..f1dc0561692 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutationsTest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.grpc.xds.internal.headermutations.HeaderMutations.RequestHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import org.junit.Test; + +public class HeaderMutationsTest { + @Test + public void testCreate() { + HeaderValueOption reqHeader = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey("req-key").setValue("req-value").build()) + .build(); + RequestHeaderMutations requestMutations = RequestHeaderMutations + .create(ImmutableList.of(reqHeader), ImmutableList.of("remove-req-key")); + assertThat(requestMutations.headers()).containsExactly(reqHeader); + assertThat(requestMutations.headersToRemove()).containsExactly("remove-req-key"); + + HeaderValueOption respHeader = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey("resp-key").setValue("resp-value").build()) + .build(); + ResponseHeaderMutations responseMutations = + ResponseHeaderMutations.create(ImmutableList.of(respHeader)); + assertThat(responseMutations.headers()).containsExactly(respHeader); + + HeaderMutations mutations = HeaderMutations.create(requestMutations, responseMutations); + assertThat(mutations.requestMutations()).isEqualTo(requestMutations); + assertThat(mutations.responseMutations()).isEqualTo(responseMutations); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java new file mode 100644 index 00000000000..df6ce383d8c --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/headermutations/HeaderMutatorTest.java @@ -0,0 +1,311 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.headermutations; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import com.google.common.io.BaseEncoding; +import com.google.common.testing.TestLogHandler; +import com.google.protobuf.ByteString; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption.HeaderAppendAction; +import io.grpc.Metadata; +import io.grpc.xds.internal.headermutations.HeaderMutations.RequestHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import java.nio.charset.StandardCharsets; +import java.util.List; +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class HeaderMutatorTest { + + private static final Metadata.Key ASCII_KEY = + Metadata.Key.of("some-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key BINARY_KEY = + Metadata.Key.of("some-key-bin", Metadata.BINARY_BYTE_MARSHALLER); + private static final Metadata.Key APPEND_KEY = + Metadata.Key.of("append-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key ADD_KEY = + Metadata.Key.of("add-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_KEY = + Metadata.Key.of("overwrite-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key REMOVE_KEY = + Metadata.Key.of("remove-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key NEW_ADD_KEY = + Metadata.Key.of("new-add-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key NEW_OVERWRITE_KEY = + Metadata.Key.of("new-overwrite-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_IF_EXISTS_KEY = + Metadata.Key.of("overwrite-if-exists-key", Metadata.ASCII_STRING_MARSHALLER); + private static final Metadata.Key OVERWRITE_IF_EXISTS_ABSENT_KEY = + Metadata.Key.of("overwrite-if-exists-absent-key", Metadata.ASCII_STRING_MARSHALLER); + + private final HeaderMutator headerMutator = HeaderMutator.create(); + + private static final TestLogHandler logHandler = new TestLogHandler(); + private static final Logger logger = + Logger.getLogger(HeaderMutator.HeaderMutatorImpl.class.getName()); + + @Before + public void setUp() { + logHandler.clear(); + logger.addHandler(logHandler); + logger.setLevel(Level.WARNING); + } + + @After + public void tearDown() { + logger.removeHandler(logHandler); + } + + private static HeaderValueOption header(String key, String value, HeaderAppendAction action) { + return HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(key).setValue(value)).setAppendAction(action) + .build(); + } + + @Test + public void applyRequestMutations_asciiHeaders() { + Metadata headers = new Metadata(); + headers.put(APPEND_KEY, "append-value-1"); + headers.put(ADD_KEY, "add-value-original"); + headers.put(OVERWRITE_KEY, "overwrite-value-original"); + headers.put(REMOVE_KEY, "remove-value-original"); + headers.put(OVERWRITE_IF_EXISTS_KEY, "original-value"); + + RequestHeaderMutations mutations = RequestHeaderMutations.create(ImmutableList.of( + // Append to existing header + header(APPEND_KEY.name(), "append-value-2", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + // Try to add to an existing header (should be no-op) + header(ADD_KEY.name(), "add-value-new", HeaderAppendAction.ADD_IF_ABSENT), + // Add a new header + header(NEW_ADD_KEY.name(), "new-add-value", HeaderAppendAction.ADD_IF_ABSENT), + // Overwrite an existing header + header(OVERWRITE_KEY.name(), "overwrite-value-new", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + // Overwrite a new header + header(NEW_OVERWRITE_KEY.name(), "new-overwrite-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + // Overwrite an existing header if it exists + header(OVERWRITE_IF_EXISTS_KEY.name(), "new-value", HeaderAppendAction.OVERWRITE_IF_EXISTS), + // Try to overwrite a header that does not exist + header(OVERWRITE_IF_EXISTS_ABSENT_KEY.name(), "new-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS)), + ImmutableList.of(REMOVE_KEY.name())); + + headerMutator.applyRequestMutations(mutations, headers); + + assertThat(headers.getAll(APPEND_KEY)).containsExactly("append-value-1", "append-value-2"); + assertThat(headers.get(ADD_KEY)).isEqualTo("add-value-original"); + assertThat(headers.get(NEW_ADD_KEY)).isEqualTo("new-add-value"); + assertThat(headers.get(OVERWRITE_KEY)).isEqualTo("overwrite-value-new"); + assertThat(headers.get(NEW_OVERWRITE_KEY)).isEqualTo("new-overwrite-value"); + assertThat(headers.containsKey(REMOVE_KEY)).isFalse(); + assertThat(headers.get(OVERWRITE_IF_EXISTS_KEY)).isEqualTo("new-value"); + assertThat(headers.containsKey(OVERWRITE_IF_EXISTS_ABSENT_KEY)).isFalse(); + } + + @Test + public void applyRequestMutations_InvalidAppendAction_isIgnored() { + Metadata headers = new Metadata(); + headers.put(ASCII_KEY, "value1"); + headerMutator + .applyRequestMutations( + RequestHeaderMutations + .create( + ImmutableList.of( + HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(ASCII_KEY.name()) + .setValue("value2")) + .setAppendActionValue(-1).build(), + HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(BINARY_KEY.name()) + .setValue("value2")) + .setAppendActionValue(-5).build()), + ImmutableList.of()), + headers); + assertThat(headers.getAll(ASCII_KEY)).containsExactly("value1"); + } + + @Test + public void applyRequestMutations_removalHasPriority() { + Metadata headers = new Metadata(); + headers.put(REMOVE_KEY, "value"); + RequestHeaderMutations mutations = RequestHeaderMutations.create( + ImmutableList.of( + header(REMOVE_KEY.name(), "new-value", HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD)), + ImmutableList.of(REMOVE_KEY.name())); + + headerMutator.applyRequestMutations(mutations, headers); + + assertThat(headers.containsKey(REMOVE_KEY)).isFalse(); + } + + @Test + public void applyRequestMutations_binary_withBase64RawValue() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(BINARY_KEY.name()).setRawValue( + ByteString.copyFrom(BaseEncoding.base64().encode(value), StandardCharsets.US_ASCII))) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + headerMutator.applyRequestMutations( + RequestHeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyRequestMutations_binary_withBase64Value() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + String base64Value = BaseEncoding.base64().encode(value); + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(BINARY_KEY.name()).setValue(base64Value)) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + + headerMutator.applyRequestMutations( + RequestHeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyRequestMutations_ascii_withRawValue() { + Metadata headers = new Metadata(); + byte[] value = "raw-value".getBytes(StandardCharsets.US_ASCII); + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(ASCII_KEY.name()) + .setRawValue(ByteString.copyFrom(value))) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + headerMutator.applyRequestMutations( + RequestHeaderMutations.create(ImmutableList.of(option), ImmutableList.of()), headers); + assertThat(headers.get(Metadata.Key.of(ASCII_KEY.name(), Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("raw-value"); + } + + @Test + public void applyResponseMutations_asciiHeaders() { + Metadata headers = new Metadata(); + headers.put(APPEND_KEY, "append-value-1"); + headers.put(ADD_KEY, "add-value-original"); + headers.put(OVERWRITE_KEY, "overwrite-value-original"); + + ResponseHeaderMutations mutations = ResponseHeaderMutations.create(ImmutableList.of( + header(APPEND_KEY.name(), "append-value-2", HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD), + header(ADD_KEY.name(), "add-value-new", HeaderAppendAction.ADD_IF_ABSENT), + header(NEW_ADD_KEY.name(), "new-add-value", HeaderAppendAction.ADD_IF_ABSENT), + header(OVERWRITE_KEY.name(), "overwrite-value-new", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD), + header(NEW_OVERWRITE_KEY.name(), "new-overwrite-value", + HeaderAppendAction.OVERWRITE_IF_EXISTS_OR_ADD))); + + headerMutator.applyResponseMutations(mutations, headers); + + assertThat(headers.getAll(APPEND_KEY)).containsExactly("append-value-1", "append-value-2"); + assertThat(headers.get(ADD_KEY)).isEqualTo("add-value-original"); + assertThat(headers.get(NEW_ADD_KEY)).isEqualTo("new-add-value"); + assertThat(headers.get(OVERWRITE_KEY)).isEqualTo("overwrite-value-new"); + assertThat(headers.get(NEW_OVERWRITE_KEY)).isEqualTo("new-overwrite-value"); + } + + + @Test + public void applyResponseMutations_InvalidAppendAction_isIgnored() { + Metadata headers = new Metadata(); + headers.put(ASCII_KEY, "value1"); + headerMutator + .applyResponseMutations( + ResponseHeaderMutations + .create( + ImmutableList.of( + HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(ASCII_KEY.name()) + .setValue("value2")) + .setAppendActionValue(-1).build(), + HeaderValueOption + .newBuilder().setHeader(HeaderValue.newBuilder() + .setKey(BINARY_KEY.name()).setValue("value2")) + .setAppendActionValue(-5).build())), + headers); + assertThat(headers.getAll(ASCII_KEY)).containsExactly("value1"); + } + + @Test + public void applyResponseMutations_binary_withBase64RawValue() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(BINARY_KEY.name()).setRawValue( + ByteString.copyFrom(BaseEncoding.base64().encode(value), StandardCharsets.US_ASCII))) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + headerMutator.applyResponseMutations(ResponseHeaderMutations.create(ImmutableList.of(option)), + headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyResponseMutations_binary_withBase64Value() { + Metadata headers = new Metadata(); + byte[] value = new byte[] {1, 2, 3}; + String base64Value = BaseEncoding.base64().encode(value); + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(BINARY_KEY.name()).setValue(base64Value)) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + + headerMutator.applyResponseMutations(ResponseHeaderMutations.create(ImmutableList.of(option)), + headers); + assertThat(headers.get(BINARY_KEY)).isEqualTo(value); + } + + @Test + public void applyResponseMutations_ascii_withRawValue() { + Metadata headers = new Metadata(); + byte[] value = "raw-value".getBytes(StandardCharsets.US_ASCII); + HeaderValueOption option = HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey(ASCII_KEY.name()) + .setRawValue(ByteString.copyFrom(value))) + .setAppendAction(HeaderAppendAction.APPEND_IF_EXISTS_OR_ADD).build(); + + headerMutator.applyResponseMutations(ResponseHeaderMutations.create(ImmutableList.of(option)), + headers); + assertThat(headers.get(Metadata.Key.of(ASCII_KEY.name(), Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("raw-value"); + } + + @Test + public void applyRequestMutations_unrecognizedAction_logsWarning() { + Metadata headers = new Metadata(); + RequestHeaderMutations mutations = + RequestHeaderMutations.create(ImmutableList.of(HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey("key").setValue("value")) + .setAppendActionValue(-1).build()), ImmutableList.of()); + headerMutator.applyRequestMutations(mutations, headers); + + List records = logHandler.getStoredLogRecords(); + assertThat(records).hasSize(1); + assertThat(records.get(0).getMessage()) + .contains("Unrecognized HeaderAppendAction: UNRECOGNIZED"); + } +} From ff89060115ecc738bfeec88b1fd3dcdf38297492 Mon Sep 17 00:00:00 2001 From: Saurav Date: Sun, 2 Nov 2025 19:36:53 +0000 Subject: [PATCH 4/7] feat(xds): Implement response handling for external authorization This commit introduces the `CheckResponseHandler` and `AuthzResponse` classes, which are responsible for processing responses from the external authorization service. The `CheckResponseHandler` parses the `CheckResponse` protobuf, determines whether the request should be allowed or denied, and applies any header mutations specified in the response. It handles both `OkHttpResponse` and `DeniedHttpResponse` messages. The `AuthzResponse` class is a value object that represents the outcome of the authorization check, encapsulating the decision (allow or deny), the status to be returned to the client (for deny decisions), and any header mutations. This commit also includes unit tests for the new components. --- .../xds/internal/extauthz/AuthzResponse.java | 91 +++++++++ .../extauthz/CheckResponseHandler.java | 148 ++++++++++++++ .../internal/extauthz/AuthzResponseTest.java | 66 ++++++ .../extauthz/CheckResponseHandlerTest.java | 191 ++++++++++++++++++ 4 files changed, 496 insertions(+) create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/AuthzResponse.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/CheckResponseHandler.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/AuthzResponseTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/CheckResponseHandlerTest.java diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/AuthzResponse.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/AuthzResponse.java new file mode 100644 index 00000000000..530badb631b --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/AuthzResponse.java @@ -0,0 +1,91 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.auto.value.AutoValue; +import com.google.common.collect.ImmutableList; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import java.util.Optional; + +/** + * Represents the outcome of an authorization check, detailing whether the request is allowed or + * denied and including any associated headers or status information. + */ +@AutoValue +public abstract class AuthzResponse { + + /** Defines the authorization decision. */ + public enum Decision { + /** The request is permitted. */ + ALLOW, + /** The request is rejected. */ + DENY, + } + + /** Creates a builder for an ALLOW response, initializing with the specified headers. */ + public static Builder allow(Metadata headers) { + return new AutoValue_AuthzResponse.Builder().setDecision(Decision.ALLOW) + .setResponseHeaderMutations(ResponseHeaderMutations.create(ImmutableList.of())) + .setHeaders(headers); + } + + /** Creates a builder for a DENY response, initializing with the specified status. */ + public static Builder deny(Status status) { + return new AutoValue_AuthzResponse.Builder().setDecision(Decision.DENY) + .setResponseHeaderMutations(ResponseHeaderMutations.create(ImmutableList.of())) + .setStatus(status); + } + + /** Returns the authorization decision. */ + public abstract Decision decision(); + + /** + * For DENY decisions, this provides the status to be returned to the calling client. It is empty + * for ALLOW decisions. + */ + public abstract Optional status(); + + /** + * For ALLOW decisions, this provides the headers to be appended to the request headers for + * upstream. It is empty for DENY decisions. + */ + public abstract Optional headers(); + + /** + * Returns mutations to be applied to the response headers. This is used for both ALLOW and DENY + * decisions. + */ + public abstract ResponseHeaderMutations responseHeaderMutations(); + + /** Builder for creating {@link AuthzResponse} instances. */ + @AutoValue.Builder + public abstract static class Builder { + + abstract Builder setDecision(Decision decision); + + abstract Builder setStatus(Status status); + + abstract Builder setHeaders(Metadata headers); + + public abstract Builder setResponseHeaderMutations( + ResponseHeaderMutations responseHeaderMutations); + + public abstract AuthzResponse build(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckResponseHandler.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckResponseHandler.java new file mode 100644 index 00000000000..6f03bcd1302 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/CheckResponseHandler.java @@ -0,0 +1,148 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.service.auth.v3.CheckResponse; +import io.envoyproxy.envoy.service.auth.v3.DeniedHttpResponse; +import io.envoyproxy.envoy.service.auth.v3.OkHttpResponse; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.internal.GrpcUtil; +import io.grpc.xds.internal.headermutations.HeaderMutationDisallowedException; +import io.grpc.xds.internal.headermutations.HeaderMutationFilter; +import io.grpc.xds.internal.headermutations.HeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutator; + +/** + * Handles the response from the external authorization service, processing it to determine the + * authorization decision and applying any necessary header mutations. + */ +public interface CheckResponseHandler { + + /** + * A factory for creating {@link CheckResponseHandler} instances. + */ + @FunctionalInterface + interface Factory { + /** + * Creates a new ResponseHandler. + * + * @param headerMutator Utility to apply header mutations. + * @param headerMutationFilter Filter to apply to header mutations. + * @param config The external authorization configuration. + */ + CheckResponseHandler create(HeaderMutator headerMutator, + HeaderMutationFilter headerMutationFilter, ExtAuthzConfig config); + } + + /** + * The default factory for creating {@link CheckResponseHandler} instances. + */ + Factory INSTANCE = ResponseHandlerImpl::new; + + /** + * Processes the CheckResponse from the external authorization service. + * + * @param response The response from the authorization service. + * @param headers The request headers, which may be mutated as part of handling the response. + * @return An {@link AuthzResponse} indicating the outcome of the authorization check. + */ + AuthzResponse handleResponse(final CheckResponse response, Metadata headers); + + /** Default implementation of {@link CheckResponseHandler}. */ + static final class ResponseHandlerImpl implements CheckResponseHandler { + private final HeaderMutator headerMutator; + private final HeaderMutationFilter headerMutationFilter; + private final ExtAuthzConfig config; + + ResponseHandlerImpl(HeaderMutator headerMutator, // NOPMD + HeaderMutationFilter headerMutationFilter, ExtAuthzConfig config) { + this.headerMutator = headerMutator; + this.headerMutationFilter = headerMutationFilter; + this.config = config; + } + + @Override + public AuthzResponse handleResponse(final CheckResponse response, Metadata headers) { + try { + if (response.getStatus().getCode() == Status.Code.OK.value()) { + return handleOkResponse(response, headers); + } else { + return handleNotOkResponse(response); + } + } catch (HeaderMutationDisallowedException e) { + return AuthzResponse.deny(e.getStatus()).build(); + } + } + + private AuthzResponse handleOkResponse(final CheckResponse response, Metadata headers) + throws HeaderMutationDisallowedException { + if (!response.hasOkResponse()) { + return AuthzResponse.allow(headers).build(); + } + OkHttpResponse okResponse = response.getOkResponse(); + HeaderMutations requestedMutations = buildHeaderMutationsFromOkResponse(okResponse); + HeaderMutations allowedMutations = headerMutationFilter.filter(requestedMutations); + + applyMutations(allowedMutations, headers); + return AuthzResponse.allow(headers) + .setResponseHeaderMutations(allowedMutations.responseMutations()).build(); + } + + private HeaderMutations buildHeaderMutationsFromOkResponse(OkHttpResponse okResponse) { + return HeaderMutations.create( + HeaderMutations.RequestHeaderMutations.create( + ImmutableList.copyOf(okResponse.getHeadersList()), + ImmutableList.copyOf(okResponse.getHeadersToRemoveList())), + HeaderMutations.ResponseHeaderMutations + .create(ImmutableList.copyOf(okResponse.getResponseHeadersToAddList()))); + } + + private AuthzResponse handleNotOkResponse(CheckResponse response) + throws HeaderMutationDisallowedException { + Status statusToReturn = config.statusOnError(); + if (!response.hasDeniedResponse()) { + return AuthzResponse.deny(statusToReturn).build(); + } + DeniedHttpResponse deniedResponse = response.getDeniedResponse(); + HeaderMutations requestedMutations = buildHeaderMutationsFromDeniedResponse(deniedResponse); + HeaderMutations allowedMutations = headerMutationFilter.filter(requestedMutations); + + Status status = statusToReturn; + if (deniedResponse.hasStatus()) { + status = GrpcUtil.httpStatusToGrpcStatus(deniedResponse.getStatus().getCodeValue()) + .withDescription(deniedResponse.getBody()); + } + return AuthzResponse.deny(status) + .setResponseHeaderMutations(allowedMutations.responseMutations()).build(); + } + + private HeaderMutations buildHeaderMutationsFromDeniedResponse( + DeniedHttpResponse deniedResponse) { + return HeaderMutations.create( + HeaderMutations.RequestHeaderMutations.create(ImmutableList.of(), ImmutableList.of()), + HeaderMutations.ResponseHeaderMutations + .create(ImmutableList.copyOf(deniedResponse.getHeadersList()))); + } + + + private void applyMutations(final HeaderMutations mutations, Metadata headers) { + headerMutator.applyRequestMutations(mutations.requestMutations(), headers); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/AuthzResponseTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/AuthzResponseTest.java new file mode 100644 index 00000000000..e81e356fe75 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/AuthzResponseTest.java @@ -0,0 +1,66 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; + +import com.google.common.collect.ImmutableList; +import io.envoyproxy.envoy.config.core.v3.HeaderValue; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.xds.internal.extauthz.AuthzResponse.Decision; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; + +@RunWith(JUnit4.class) +public class AuthzResponseTest { + @Test + public void testAllow() { + Metadata headers = new Metadata(); + headers.put(Metadata.Key.of("foo", Metadata.ASCII_STRING_MARSHALLER), "bar"); + AuthzResponse response = AuthzResponse.allow(headers).build(); + assertThat(response.decision()).isEqualTo(Decision.ALLOW); + assertThat(response.headers()).hasValue(headers); + assertThat(response.status()).isEmpty(); + assertThat(response.responseHeaderMutations().headers()).isEmpty(); + } + + @Test + public void testAllowWithHeaderMutations() { + Metadata headers = new Metadata(); + ResponseHeaderMutations mutations = + ResponseHeaderMutations.create(ImmutableList.of(HeaderValueOption.newBuilder() + .setHeader(HeaderValue.newBuilder().setKey("key").setValue("value")).build())); + AuthzResponse response = + AuthzResponse.allow(headers).setResponseHeaderMutations(mutations).build(); + assertThat(response.decision()).isEqualTo(Decision.ALLOW); + assertThat(response.responseHeaderMutations()).isEqualTo(mutations); + } + + @Test + public void testDeny() { + Status status = Status.PERMISSION_DENIED.withDescription("reason"); + AuthzResponse response = AuthzResponse.deny(status).build(); + assertThat(response.decision()).isEqualTo(Decision.DENY); + assertThat(response.status()).hasValue(status); + assertThat(response.headers()).isEmpty(); + assertThat(response.responseHeaderMutations().headers()).isEmpty(); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckResponseHandlerTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckResponseHandlerTest.java new file mode 100644 index 00000000000..31b14a312c4 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/CheckResponseHandlerTest.java @@ -0,0 +1,191 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.rpc.Code; +import io.envoyproxy.envoy.config.core.v3.HeaderValueOption; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.service.auth.v3.CheckResponse; +import io.envoyproxy.envoy.service.auth.v3.DeniedHttpResponse; +import io.envoyproxy.envoy.service.auth.v3.OkHttpResponse; +import io.envoyproxy.envoy.type.v3.HttpStatus; +import io.grpc.Metadata; +import io.grpc.Status; +import io.grpc.xds.internal.extauthz.AuthzResponse.Decision; +import io.grpc.xds.internal.headermutations.HeaderMutationDisallowedException; +import io.grpc.xds.internal.headermutations.HeaderMutationFilter; +import io.grpc.xds.internal.headermutations.HeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutator; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class CheckResponseHandlerTest { + @Rule + public final MockitoRule mockitoRule = MockitoJUnit.rule(); + + @Mock + private HeaderMutator headerMutator; + @Mock + private HeaderMutationFilter headerMutationFilter; + + private CheckResponseHandler responseHandler; + + @Before + public void setUp() throws Exception { + responseHandler = + CheckResponseHandler.INSTANCE.create(headerMutator, headerMutationFilter, + buildExtAuthzConfig()); + when(headerMutationFilter.filter(any(HeaderMutations.class))) + .thenAnswer(invocation -> invocation.getArgument(0)); + } + + @Test + public void handleResponse_ok() { + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Code.OK_VALUE).build()).build(); + Metadata headers = new Metadata(); + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + assertThat(authzResponse.decision()).isEqualTo(Decision.ALLOW); + assertThat(authzResponse.headers()).hasValue(headers); + } + + @Test + public void handleResponse_okWithMutations() { + HeaderValueOption option = HeaderValueOption.newBuilder().build(); + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Code.OK_VALUE).build()) + .setOkResponse(OkHttpResponse.newBuilder().addHeaders(option) + .addHeadersToRemove("remove-key").addResponseHeadersToAdd(option).build()) + .build(); + Metadata headers = new Metadata(); + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + assertThat(authzResponse.decision()).isEqualTo(Decision.ALLOW); + assertThat(authzResponse.headers()).hasValue(headers); + HeaderMutations expectedMutations = HeaderMutations.create( + HeaderMutations.RequestHeaderMutations.create(ImmutableList.of(option), + ImmutableList.of("remove-key")), + HeaderMutations.ResponseHeaderMutations.create(ImmutableList.of(option))); + verify(headerMutator).applyRequestMutations(expectedMutations.requestMutations(), headers); + assertThat(authzResponse.responseHeaderMutations()) + .isEqualTo(expectedMutations.responseMutations()); + } + + @Test + public void handleResponse_notOk() { + CheckResponse checkResponse = CheckResponse.newBuilder().setStatus(com.google.rpc.Status + .newBuilder().setCode(Code.PERMISSION_DENIED_VALUE).setMessage("denied").build()).build(); + Metadata headers = new Metadata(); + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + assertThat(authzResponse.decision()).isEqualTo(Decision.DENY); + assertThat(authzResponse.status().isPresent()).isTrue(); + assertThat(authzResponse.status().get().getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(authzResponse.status().get().getDescription()).isEqualTo("HTTP status code 403"); + verify(headerMutator, never()).applyRequestMutations(any(), any()); + } + + @Test + public void handleResponse_deniedResponseWithoutStatusOverride() { + HeaderValueOption option = HeaderValueOption.newBuilder().build(); + DeniedHttpResponse deniedHttpResponse = + DeniedHttpResponse.newBuilder().addHeaders(option).build(); + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Code.ABORTED_VALUE).build()) + .setDeniedResponse(deniedHttpResponse).build(); + Metadata headers = new Metadata(); + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + assertThat(authzResponse.decision()).isEqualTo(Decision.DENY); + assertThat(authzResponse.status().get().getCode()) + .isEqualTo(Status.PERMISSION_DENIED.getCode()); + assertThat(authzResponse.status().get().getDescription()).isEqualTo("HTTP status code 403"); + HeaderMutations.ResponseHeaderMutations expectedMutations = + HeaderMutations.ResponseHeaderMutations.create(ImmutableList.of(option)); + assertThat(authzResponse.responseHeaderMutations()).isEqualTo(expectedMutations); + verify(headerMutator, never()).applyRequestMutations(any(), any()); + } + + @Test + public void handleResponse_deniedResponseWithStatusOverride() { + DeniedHttpResponse deniedHttpResponse = + DeniedHttpResponse.newBuilder().setStatus(HttpStatus.newBuilder().setCodeValue(401).build()) + .setBody("custom body").build(); + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Code.ABORTED_VALUE).build()) + .setDeniedResponse(deniedHttpResponse).build(); + Metadata headers = new Metadata(); + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + assertThat(authzResponse.decision()).isEqualTo(Decision.DENY); + assertThat(authzResponse.status().isPresent()).isTrue(); + Status status = authzResponse.status().get(); + assertThat(status.getCode()).isEqualTo(Status.Code.UNAUTHENTICATED); + assertThat(status.getDescription()).isEqualTo("custom body"); + HeaderMutations.ResponseHeaderMutations expectedMutations = + HeaderMutations.ResponseHeaderMutations.create(ImmutableList.of()); + assertThat(authzResponse.responseHeaderMutations()).isEqualTo(expectedMutations); + verify(headerMutator, never()).applyRequestMutations(any(), any()); + } + + @Test + public void handleResponse_okWithDisallowedMutation() throws HeaderMutationDisallowedException { + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Code.OK_VALUE).build()) + .setOkResponse(OkHttpResponse.newBuilder().build()).build(); + Metadata headers = new Metadata(); + HeaderMutationDisallowedException exception = + new HeaderMutationDisallowedException("disallowed"); + when(headerMutationFilter.filter(any(HeaderMutations.class))).thenThrow(exception); + + AuthzResponse authzResponse = responseHandler.handleResponse(checkResponse, headers); + + assertThat(authzResponse.decision()).isEqualTo(Decision.DENY); + assertThat(authzResponse.status().get().getCode()).isEqualTo(Status.INTERNAL.getCode()); + assertThat(authzResponse.status().get().getDescription()).isEqualTo("disallowed"); + } + + private ExtAuthzConfig buildExtAuthzConfig() throws ExtAuthzParseException { + Any googleDefaultChannelCreds = Any.pack(GoogleDefaultCredentials.newBuilder().build()); + Any fakeAccessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); + ExtAuthz extAuthz = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder() + .setGoogleGrpc(io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster").addChannelCredentialsPlugin(googleDefaultChannelCreds) + .addCallCredentialsPlugin(fakeAccessTokenCreds).build()) + .build()) + .setStatusOnError( + io.envoyproxy.envoy.type.v3.HttpStatus.newBuilder().setCodeValue(403).build()) + .build(); + return ExtAuthzConfig.fromProto(extAuthz); + } +} From e1787e2d9b707731d646ce00e38997368b88d458 Mon Sep 17 00:00:00 2001 From: Saurav Date: Thu, 23 Oct 2025 10:26:01 +0000 Subject: [PATCH 5/7] feat(xds): Add ExtAuthzClientInterceptor and related components This commit introduces the client-side implementation of the external authorization filter. The main component is the `ExtAuthzClientInterceptor`, which intercepts outgoing RPCs and performs external authorization checks. It uses a `BufferingAuthzClientCall` to buffer the outgoing RPC until the authorization decision is received from the authorization service. The following new classes are introduced: - `ExtAuthzClientInterceptor`: The main client interceptor for external authorization. - `BufferingAuthzClientCall`: A `ClientCall` implementation that buffers requests until an authorization decision is made. - `CallBuffer`: A helper class for `BufferingAuthzClientCall` to manage the buffered calls. - `FailingClientCall`: A utility `ClientCall` that immediately fails, used when the filter is disabled and configured to deny calls. This commit also includes comprehensive unit and integration tests for the new components. --- .../io/grpc/internal/FailingClientCall.java | 57 ++ .../grpc/internal/FailingClientCallTest.java | 76 +++ .../java/io/grpc/xds/ThreadSafeRandom.java | 20 +- .../grpc/xds/internal/ThreadSafeRandom.java | 54 ++ .../extauthz/BufferingAuthzClientCall.java | 224 +++++++ .../xds/internal/extauthz/CallBuffer.java | 87 +++ .../extauthz/ExtAuthzClientInterceptor.java | 80 +++ .../BufferingAuthzClientCallTest.java | 611 ++++++++++++++++++ .../xds/internal/extauthz/CallBufferTest.java | 190 ++++++ .../ExtAuthzClientInterceptorTest.java | 175 +++++ 10 files changed, 1563 insertions(+), 11 deletions(-) create mode 100644 core/src/main/java/io/grpc/internal/FailingClientCall.java create mode 100644 core/src/test/java/io/grpc/internal/FailingClientCallTest.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/ThreadSafeRandom.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/BufferingAuthzClientCall.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/CallBuffer.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzClientInterceptor.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/BufferingAuthzClientCallTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/CallBufferTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzClientInterceptorTest.java diff --git a/core/src/main/java/io/grpc/internal/FailingClientCall.java b/core/src/main/java/io/grpc/internal/FailingClientCall.java new file mode 100644 index 00000000000..33c7012f09f --- /dev/null +++ b/core/src/main/java/io/grpc/internal/FailingClientCall.java @@ -0,0 +1,57 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import io.grpc.ClientCall; +import io.grpc.Metadata; +import io.grpc.Status; +import javax.annotation.Nullable; + +/** + * A {@link ClientCall} that fails immediately upon starting. + */ +public final class FailingClientCall extends ClientCall { + + private final Status error; + + /** + * Creates a new call that will fail with the given error. + */ + public FailingClientCall(Status error) { + this.error = error; + } + + /** + * Immediately fails the call by calling {@link Listener#onClose}. + */ + @Override + public void start(Listener responseListener, Metadata headers) { + responseListener.onClose(error, new Metadata()); + } + + @Override + public void request(int numMessages) {} + + @Override + public void cancel(@Nullable String message, @Nullable Throwable cause) {} + + @Override + public void halfClose() {} + + @Override + public void sendMessage(ReqT message) {} +} diff --git a/core/src/test/java/io/grpc/internal/FailingClientCallTest.java b/core/src/test/java/io/grpc/internal/FailingClientCallTest.java new file mode 100644 index 00000000000..6fabfdd4b91 --- /dev/null +++ b/core/src/test/java/io/grpc/internal/FailingClientCallTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2016 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.internal; + +import static org.junit.Assert.assertEquals; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +import io.grpc.ClientCall; +import io.grpc.Metadata; +import io.grpc.Status; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Unit tests for {@link FailingClientCall}. */ +@RunWith(JUnit4.class) +public class FailingClientCallTest { + + @Rule public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private ClientCall.Listener mockListener; + + @Test + public void startCallsOnClose() { + Status error = Status.UNAVAILABLE.withDescription("test error"); + FailingClientCall call = new FailingClientCall<>(error); + Metadata metadata = new Metadata(); + call.start(mockListener, metadata); + + ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(mockListener).onClose(eq(error), metadataCaptor.capture()); + assertEquals(0, metadataCaptor.getValue().keys().size()); + verifyNoMoreInteractions(mockListener); + } + + @Test + public void otherMethodsAreNoOps() { + Status error = Status.UNAVAILABLE.withDescription("test error"); + FailingClientCall call = new FailingClientCall<>(error); + Metadata metadata = new Metadata(); + + call.start(mockListener, metadata); // Must call start first + + call.request(1); + call.cancel("message", new RuntimeException("cause")); + call.halfClose(); + call.sendMessage(new Object()); + + ArgumentCaptor metadataCaptor = ArgumentCaptor.forClass(Metadata.class); + verify(mockListener).onClose(eq(error), metadataCaptor.capture()); + assertEquals(0, metadataCaptor.getValue().keys().size()); + verifyNoMoreInteractions(mockListener); + } +} diff --git a/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java b/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java index 533ccee2375..87bd2ef1023 100644 --- a/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java +++ b/xds/src/main/java/io/grpc/xds/ThreadSafeRandom.java @@ -16,36 +16,34 @@ package io.grpc.xds; -import java.util.concurrent.ThreadLocalRandom; import javax.annotation.concurrent.ThreadSafe; -@ThreadSafe // Except for impls/mocks in tests -interface ThreadSafeRandom { - int nextInt(int bound); - - long nextLong(); - - long nextLong(long bound); +// TODO(sauravzg): Remove this class once all usages within xds are migrated to +// the internal version. +@ThreadSafe +interface ThreadSafeRandom extends io.grpc.xds.internal.ThreadSafeRandom { final class ThreadSafeRandomImpl implements ThreadSafeRandom { static final ThreadSafeRandom instance = new ThreadSafeRandomImpl(); + private final io.grpc.xds.internal.ThreadSafeRandom delegate = + io.grpc.xds.internal.ThreadSafeRandom.ThreadSafeRandomImpl.INSTANCE; private ThreadSafeRandomImpl() {} @Override public int nextInt(int bound) { - return ThreadLocalRandom.current().nextInt(bound); + return delegate.nextInt(bound); } @Override public long nextLong() { - return ThreadLocalRandom.current().nextLong(); + return delegate.nextLong(); } @Override public long nextLong(long bound) { - return ThreadLocalRandom.current().nextLong(bound); + return delegate.nextLong(bound); } } } diff --git a/xds/src/main/java/io/grpc/xds/internal/ThreadSafeRandom.java b/xds/src/main/java/io/grpc/xds/internal/ThreadSafeRandom.java new file mode 100644 index 00000000000..a51bfc8d6da --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/ThreadSafeRandom.java @@ -0,0 +1,54 @@ +/* + * Copyright 2023 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal; + +import java.util.concurrent.ThreadLocalRandom; +import javax.annotation.concurrent.ThreadSafe; + +/** + * A thread-safe random number generator. This is intended for internal use only. + */ +@ThreadSafe // Except for impls/mocks in tests +public interface ThreadSafeRandom { + int nextInt(int bound); + + long nextLong(); + + long nextLong(long bound); + + final class ThreadSafeRandomImpl implements ThreadSafeRandom { + + public static final ThreadSafeRandom INSTANCE = new ThreadSafeRandomImpl(); + + private ThreadSafeRandomImpl() {} + + @Override + public int nextInt(int bound) { + return ThreadLocalRandom.current().nextInt(bound); + } + + @Override + public long nextLong() { + return ThreadLocalRandom.current().nextLong(); + } + + @Override + public long nextLong(long bound) { + return ThreadLocalRandom.current().nextLong(bound); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/BufferingAuthzClientCall.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/BufferingAuthzClientCall.java new file mode 100644 index 00000000000..8cede96c897 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/BufferingAuthzClientCall.java @@ -0,0 +1,224 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.protobuf.util.Timestamps; +import io.envoyproxy.envoy.service.auth.v3.AuthorizationGrpc; +import io.envoyproxy.envoy.service.auth.v3.CheckRequest; +import io.envoyproxy.envoy.service.auth.v3.CheckResponse; +import io.grpc.Attributes; +import io.grpc.ClientCall; +import io.grpc.ForwardingClientCallListener; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Status; +import io.grpc.stub.StreamObserver; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutator; +import java.util.concurrent.atomic.AtomicBoolean; +import javax.annotation.Nullable; + +public final class BufferingAuthzClientCall extends ClientCall { + + private static final String X_ENVOY_AUTH_FAILURE_MODE_ALLOWED = + "x-envoy-auth-failure-mode-allowed"; + + /** A factory for creating {@link BufferingAuthzClientCall} instances. */ + @FunctionalInterface + public interface Factory { + ClientCall create(ClientCall delegate, + ExtAuthzConfig config, AuthorizationGrpc.AuthorizationStub authzStub, + CheckRequestBuilder checkRequestBuilder, CheckResponseHandler responseHandler, + HeaderMutator headerMutator, MethodDescriptor method, CallBuffer callBuffer); + } + + public static final Factory FACTORY_INSTANCE = BufferingAuthzClientCall::new; + + private final ClientCall delegate; + private final ExtAuthzConfig config; + private final MethodDescriptor method; + private final AuthorizationGrpc.AuthorizationStub authzStub; + private final CallBuffer callBuffer; + private final CheckRequestBuilder checkRequestBuilder; + private final CheckResponseHandler responseHandler; + private final HeaderMutator headerMutator; + private final AtomicBoolean callFailed = new AtomicBoolean(false); + + private BufferingAuthzClientCall(ClientCall delegate, ExtAuthzConfig config, + AuthorizationGrpc.AuthorizationStub authzStub, CheckRequestBuilder checkRequestBuilder, + CheckResponseHandler responseHandler, HeaderMutator headerMutator, + MethodDescriptor method, CallBuffer callBuffer) { + this.delegate = delegate; + this.config = config; + this.authzStub = authzStub; + this.checkRequestBuilder = checkRequestBuilder; + this.responseHandler = responseHandler; + this.headerMutator = headerMutator; + this.method = method; + this.callBuffer = callBuffer; + } + + private ClientCall delegate() { + return delegate; + } + + @Override + public boolean isReady() { + return callBuffer.isProcessed() && delegate.isReady(); + } + + + @Override + public void start(Listener responseListener, Metadata headers) { + // Headers is not thread-safe beyond `start`, so we need to create a copy to use in the async + // callback. + Metadata headersCopy = new Metadata(); + headersCopy.merge(headers); + StreamObserver observer = new StreamObserver() { + @Override + public void onNext(CheckResponse value) { + // This operation may mutate the headers + AuthzResponse authzResponse = responseHandler.handleResponse(value, headers); + if (authzResponse.decision() == AuthzResponse.Decision.ALLOW) { + // A allow response is guaranteed to have metadata set, so the `get` without + // check is safe. + delegate.start( + HeaderMutatingClientCallListener.create(responseListener, + authzResponse.responseHeaderMutations(), headerMutator), + authzResponse.headers().get()); + callBuffer.runAndFlush(); + } else { + // A deny response is guaranteed to have a status set, so the `get` without + // check is safe. + failUnstartedCall(authzResponse.status().get(), new Metadata(), responseListener); + } + } + + @Override + public void onError(Throwable t) { + // If failureModeAllow is true, bypass the authorization failure + if (config.failureModeAllow()) { + if (config.failureModeAllowHeaderAdd()) { + Metadata.Key failureModeKey = Metadata.Key.of(X_ENVOY_AUTH_FAILURE_MODE_ALLOWED, + Metadata.ASCII_STRING_MARSHALLER); + headersCopy.put(failureModeKey, "true"); + } + delegate.start(responseListener, headersCopy); + callBuffer.runAndFlush(); + } else { + // Authorization failed and failureModeAllow is false + Status statusToReturn = config.statusOnError().withCause(t); + failUnstartedCall(statusToReturn, new Metadata(), responseListener); + } + } + + @Override + public void onCompleted() { + // no-op, since this is a unary API. + } + }; + CheckRequest request = checkRequestBuilder.buildRequest(method, headers, + Timestamps.fromMillis(System.currentTimeMillis())); + authzStub.check(request, observer); + } + + @Override + public void request(int numMessages) { + runOrBuffer(() -> delegate().request(numMessages)); + } + + @Override + public void cancel(@Nullable String message, @Nullable Throwable cause) { + delegate().cancel(message, cause); + callBuffer.abandon(); + } + + @Override + public void halfClose() { + runOrBuffer(() -> delegate().halfClose()); + } + + @Override + public void sendMessage(ReqT message) { + runOrBuffer(() -> delegate().sendMessage(message)); + } + + @Override + public void setMessageCompression(boolean enabled) { + runOrBuffer(() -> delegate().setMessageCompression(enabled)); + } + + @Override + public Attributes getAttributes() { + // Since returning attributes can't be buffered and no other method except `cancel` can be + // called on the delegated object until it's started,we will have to unfortunately return empty + // until we are sure that `start` had been called. + if (!callBuffer.isProcessed() || callFailed.get()) { + return Attributes.EMPTY; + } else { + return delegate.getAttributes(); + } + } + + private void runOrBuffer(Runnable runnable) { + if (callFailed.get()) { + return; + } + if (callBuffer.isProcessed()) { + runnable.run(); + } else { + callBuffer.runOrBuffer(runnable); + } + } + + private void failUnstartedCall(Status status, Metadata trailers, + Listener responseListener) { + callFailed.set(true); + responseListener.onClose(status, trailers); + callBuffer.abandon(); + } + + /** + * A {@link ForwardingClientCallListener} that mutates the response headers before passing them to + * the delegate. + */ + private static final class HeaderMutatingClientCallListener + extends ForwardingClientCallListener.SimpleForwardingClientCallListener { + + private final ResponseHeaderMutations responseHeaderMutations; + private final HeaderMutator headerMutator; + + static ClientCall.Listener create(ClientCall.Listener delegate, + ResponseHeaderMutations responseHeaderMutations, HeaderMutator headerMutator) { + return new HeaderMutatingClientCallListener<>(delegate, responseHeaderMutations, + headerMutator); + } + + private HeaderMutatingClientCallListener(ClientCall.Listener delegate, + ResponseHeaderMutations responseHeaderMutations, HeaderMutator headerMutator) { + super(delegate); + this.responseHeaderMutations = responseHeaderMutations; + this.headerMutator = headerMutator; + } + + @Override + public void onHeaders(Metadata headers) { + headerMutator.applyResponseMutations(responseHeaderMutations, headers); + super.onHeaders(headers); + } + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/CallBuffer.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/CallBuffer.java new file mode 100644 index 00000000000..8732b4c522d --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/CallBuffer.java @@ -0,0 +1,87 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import javax.annotation.concurrent.ThreadSafe; + +/** + * A buffer for client calls that are pending an authorization decision. + */ +@ThreadSafe +final class CallBuffer { + + private final AtomicBoolean processed = new AtomicBoolean(false); + private final List bufferedCalls = new ArrayList<>(); + private final Object lock = new Object(); + + /** + * Buffers a runnable to be executed later. If the buffer has already been processed, the + * runnable is executed immediately. + * + * @param runnable the runnable to buffer. + */ + public void runOrBuffer(Runnable runnable) { + synchronized (lock) { + if (processed.get()) { + runnable.run(); + } else { + bufferedCalls.add(runnable); + } + } + } + + /** + * Executes all buffered runnables and marks the buffer as processed. + */ + public void runAndFlush() { + List toRun; + synchronized (lock) { + if (processed.getAndSet(true)) { + return; + } + toRun = new ArrayList<>(bufferedCalls); + bufferedCalls.clear(); + } + for (Runnable runnable : toRun) { + runnable.run(); + } + } + + /** + * Abandons all buffered runnables and marks the buffer as processed. + */ + public void abandon() { + synchronized (lock) { + if (processed.getAndSet(true)) { + return; + } + bufferedCalls.clear(); + } + } + + /** + * Returns whether the buffer has been processed. + * + * @return true if the buffer has been processed, false otherwise. + */ + public boolean isProcessed() { + return processed.get(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzClientInterceptor.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzClientInterceptor.java new file mode 100644 index 00000000000..5e982b5c29f --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzClientInterceptor.java @@ -0,0 +1,80 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.errorprone.annotations.ThreadSafe; +import io.envoyproxy.envoy.service.auth.v3.AuthorizationGrpc; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.MethodDescriptor; +import io.grpc.internal.FailingClientCall; +import io.grpc.xds.internal.Matchers.FractionMatcher; +import io.grpc.xds.internal.ThreadSafeRandom; +import io.grpc.xds.internal.headermutations.HeaderMutator; + +@ThreadSafe +public final class ExtAuthzClientInterceptor implements ClientInterceptor { + + /** A factory for creating {@link ExtAuthzClientInterceptor} instances. */ + @FunctionalInterface + public interface Factory { + ClientInterceptor create(ExtAuthzConfig config, AuthorizationGrpc.AuthorizationStub authzStub, + ThreadSafeRandom random, BufferingAuthzClientCall.Factory clientCallFactory, + CheckRequestBuilder checkRequestBuilder, CheckResponseHandler responseHandler, + HeaderMutator headerMutator); + } + + public static final Factory INSTANCE = ExtAuthzClientInterceptor::new; + + private final ExtAuthzConfig config; + private final AuthorizationGrpc.AuthorizationStub authzStub; + private final ThreadSafeRandom random; + private final BufferingAuthzClientCall.Factory clientCallFactory; + private final CheckRequestBuilder checkRequestBuilder; + private final CheckResponseHandler responseHandler; + private final HeaderMutator headerMutator; + + + private ExtAuthzClientInterceptor(ExtAuthzConfig config, + AuthorizationGrpc.AuthorizationStub authzStub, ThreadSafeRandom random, + BufferingAuthzClientCall.Factory clientCallFactory, CheckRequestBuilder checkRequestBuilder, + CheckResponseHandler responseHandler, HeaderMutator headerMutator) { + this.config = config; + this.random = random; + this.authzStub = authzStub; + this.clientCallFactory = clientCallFactory; + this.checkRequestBuilder = checkRequestBuilder; + this.responseHandler = responseHandler; + this.headerMutator = headerMutator; + } + + @Override + public ClientCall interceptCall(MethodDescriptor method, + CallOptions callOptions, Channel next) { + FractionMatcher filterEnabled = config.filterEnabled(); + if (random.nextInt(filterEnabled.denominator()) < filterEnabled.numerator()) { + if (config.denyAtDisable()) { + return new FailingClientCall<>(config.statusOnError()); + } + return next.newCall(method, callOptions); + } + return clientCallFactory.create(next.newCall(method, callOptions), config, authzStub, + checkRequestBuilder, responseHandler, headerMutator, method, new CallBuffer()); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/BufferingAuthzClientCallTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/BufferingAuthzClientCallTest.java new file mode 100644 index 00000000000..2704d18b4f3 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/BufferingAuthzClientCallTest.java @@ -0,0 +1,611 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.protobuf.Timestamp; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.service.auth.v3.AttributeContext; +import io.envoyproxy.envoy.service.auth.v3.AttributeContext.HttpRequest; +import io.envoyproxy.envoy.service.auth.v3.AttributeContext.Request; +import io.envoyproxy.envoy.service.auth.v3.AuthorizationGrpc; +import io.envoyproxy.envoy.service.auth.v3.CheckRequest; +import io.envoyproxy.envoy.service.auth.v3.CheckResponse; +import io.envoyproxy.envoy.service.auth.v3.OkHttpResponse; +import io.grpc.CallOptions; +import io.grpc.ClientCall; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.MethodDescriptor; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.ClientCalls; +import io.grpc.stub.MetadataUtils; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutator; +import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; +import org.mockito.stubbing.Answer; + +/** Unit tests for {@link BufferingAuthzClientCall}. */ +@RunWith(JUnit4.class) +public class BufferingAuthzClientCallTest { + + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + @Mock + private AuthorizationGrpc.AuthorizationImplBase authzService; + @Mock + private CheckRequestBuilder checkRequestBuilder; + @Mock + private CheckResponseHandler responseHandler; + @Mock + private HeaderMutator headerMutator; + + private ManagedChannel channel; + private Server server; + + private final AtomicReference serverHeadersCapture = new AtomicReference<>(); + private final AtomicReference clientHeadersCapture = new AtomicReference<>(); + private final AtomicReference clientTrailersCapture = new AtomicReference<>(); + + + @Before + public void setUp() throws IOException { + server = InProcessServerBuilder + .forName("in-process-server").addService(authzService).addService(ServerInterceptors + .intercept(new SimpleServiceImpl(), + new MetadataCapturingServerInterceptor(serverHeadersCapture))) + .directExecutor() + .build().start(); + channel = + InProcessChannelBuilder + .forName("in-process-server").intercept(MetadataUtils + .newCaptureMetadataInterceptor(clientHeadersCapture, clientTrailersCapture)) + .directExecutor() + .build(); + } + + @After + public void tearDown() { + server.shutdownNow(); + channel.shutdownNow(); + } + + @Test + public void onUnary_allowresponse() throws InterruptedException, ExtAuthzParseException { + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 403); + + CheckRequest checkRequest = CheckRequest.newBuilder().setAttributes(AttributeContext + .newBuilder() + .setRequest(Request.newBuilder() + .setHttp(HttpRequest.newBuilder().setId("RequestId").build()).build()) + .build()) + .build(); + when(checkRequestBuilder.buildRequest(eq(SimpleServiceGrpc.getUnaryRpcMethod()), + any(Metadata.class), any(Timestamp.class))).thenReturn(checkRequest); + + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Status.Code.OK.value()).build()) + .setOkResponse(OkHttpResponse.getDefaultInstance()).build(); + doAnswer((Answer) invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(checkResponse); + observer.onCompleted(); + return null; + }).when(authzService).check(eq(checkRequest), + ArgumentMatchers.>any()); + + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("key1", Metadata.ASCII_STRING_MARSHALLER), "value1"); + ResponseHeaderMutations responseHeaderMutations = + ResponseHeaderMutations.create(ImmutableList.of()); + AuthzResponse allowResponse = + AuthzResponse.allow(metadata) + .setResponseHeaderMutations(responseHeaderMutations).build(); + when(responseHandler.handleResponse(eq(checkResponse), any(Metadata.class))) + .thenReturn(allowResponse); + + doAnswer((Answer) invocation -> { + Metadata headers = invocation.getArgument(1); + headers.put(Metadata.Key.of("key2", Metadata.ASCII_STRING_MARSHALLER), "value2"); + return null; + }).when(headerMutator).applyResponseMutations(eq(responseHeaderMutations), + any(Metadata.class)); + + ClientCall realCall = + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), CallOptions.DEFAULT); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + ClientCall call = BufferingAuthzClientCall.FACTORY_INSTANCE + .create(realCall, config, authzStub, checkRequestBuilder, responseHandler, headerMutator, + SimpleServiceGrpc.getUnaryRpcMethod(), new CallBuffer()); + SimpleServiceUnaryResponseObserver simpleServiceResponseObserver = + new SimpleServiceUnaryResponseObserver(); + SimpleRequest simpleRequest = SimpleRequest.newBuilder().setRequestMessage("World").build(); + ClientCalls.asyncUnaryCall(call, simpleRequest, simpleServiceResponseObserver); + simpleServiceResponseObserver.await(); + + assertThat(simpleServiceResponseObserver.getResponse().getResponseMessage()) + .isEqualTo("Hello World"); + assertThat( + serverHeadersCapture.get().get(Metadata.Key.of("key1", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("value1"); + assertThat( + clientHeadersCapture.get().get(Metadata.Key.of("key2", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("value2"); + assertThat(clientTrailersCapture.get()).isNotNull(); + verify(headerMutator).applyResponseMutations(eq(responseHeaderMutations), any(Metadata.class)); + } + + @Test + public void onUnary_denyResponse() throws InterruptedException, ExtAuthzParseException { + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 403); + CheckRequest checkRequest = CheckRequest.newBuilder().build(); + when(checkRequestBuilder.buildRequest(eq(SimpleServiceGrpc.getUnaryRpcMethod()), + any(Metadata.class), any(Timestamp.class))).thenReturn(checkRequest); + + CheckResponse checkResponse = CheckResponse.newBuilder().setStatus( + com.google.rpc.Status.newBuilder().setCode(Status.Code.PERMISSION_DENIED.value()).build()) + .build(); + doAnswer((Answer) invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(checkResponse); + observer.onCompleted(); + return null; + }).when(authzService).check(eq(checkRequest), + ArgumentMatchers.>any()); + + Status expectedStatus = Status.PERMISSION_DENIED.withDescription("ext authz denied"); + AuthzResponse denyResponse = AuthzResponse.deny(expectedStatus).build(); + when(responseHandler.handleResponse(eq(checkResponse), any(Metadata.class))) + .thenReturn(denyResponse); + + ClientCall realCall = + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), CallOptions.DEFAULT); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + CallBuffer callBuffer = new CallBuffer(); + ClientCall call = BufferingAuthzClientCall.FACTORY_INSTANCE + .create(realCall, config, authzStub, checkRequestBuilder, responseHandler, headerMutator, + SimpleServiceGrpc.getUnaryRpcMethod(), callBuffer); + SimpleServiceUnaryResponseObserver simpleServiceResponseObserver = + new SimpleServiceUnaryResponseObserver(); + SimpleRequest simpleRequest = SimpleRequest.newBuilder().setRequestMessage("World").build(); + ClientCalls.asyncUnaryCall(call, + simpleRequest, simpleServiceResponseObserver); + simpleServiceResponseObserver.await(); + + assertThat(simpleServiceResponseObserver.getResponse()).isNull(); + assertThat(simpleServiceResponseObserver.getError()).isNotNull(); + assertThat(simpleServiceResponseObserver.getError().getCode()) + .isEqualTo(expectedStatus.getCode()); + assertThat(simpleServiceResponseObserver.getError().getDescription()) + .isEqualTo(expectedStatus.getDescription()); + assertThat(callBuffer.isProcessed()).isTrue(); + } + + @Test + public void onUnary_authzServerError_failTheCall() + throws InterruptedException, ExtAuthzParseException { + CheckRequest checkRequest = CheckRequest.newBuilder().build(); + when(checkRequestBuilder.buildRequest(eq(SimpleServiceGrpc.getUnaryRpcMethod()), + any(Metadata.class), any(Timestamp.class))).thenReturn(checkRequest); + + Status authzError = Status.UNAVAILABLE.withDescription("ext authz server unavailable"); + doAnswer((Answer) invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onError(authzError.asRuntimeException()); + return null; + }).when(authzService).check(eq(checkRequest), + ArgumentMatchers.>any()); + + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 503); + + + ClientCall realCall = + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), CallOptions.DEFAULT); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + CallBuffer callBuffer = new CallBuffer(); + ClientCall call = BufferingAuthzClientCall.FACTORY_INSTANCE + .create(realCall, config, authzStub, checkRequestBuilder, responseHandler, headerMutator, + SimpleServiceGrpc.getUnaryRpcMethod(), callBuffer); + SimpleServiceUnaryResponseObserver simpleServiceResponseObserver = + new SimpleServiceUnaryResponseObserver(); + SimpleRequest simpleRequest = SimpleRequest.newBuilder().setRequestMessage("World").build(); + ClientCalls.asyncUnaryCall(call, simpleRequest, simpleServiceResponseObserver); + simpleServiceResponseObserver.await(); + + assertThat(simpleServiceResponseObserver.getResponse()).isNull(); + assertThat(simpleServiceResponseObserver.getError()).isNotNull(); + assertThat(simpleServiceResponseObserver.getError().getCode()) + .isEqualTo(Status.Code.UNAVAILABLE); + assertThat(callBuffer.isProcessed()).isTrue(); + verify(responseHandler, never()).handleResponse(any(CheckResponse.class), any(Metadata.class)); + } + + @Test + public void onUnary_authzServerError_failureModeAllow() + throws InterruptedException, ExtAuthzParseException { + ExtAuthzConfig config = buildExtAuthzConfig(true, false, 503); + CheckRequest checkRequest = CheckRequest.newBuilder().build(); + when(checkRequestBuilder.buildRequest(eq(SimpleServiceGrpc.getUnaryRpcMethod()), + any(Metadata.class), any(Timestamp.class))).thenReturn(checkRequest); + + Status authzError = Status.UNAVAILABLE.withDescription("authz server unavailable"); + doAnswer((Answer) invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onError(authzError.asRuntimeException()); + return null; + }).when(authzService).check(eq(checkRequest), + ArgumentMatchers.>any()); + + ClientCall realCall = + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), CallOptions.DEFAULT); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + ClientCall call = BufferingAuthzClientCall.FACTORY_INSTANCE + .create(realCall, config, authzStub, checkRequestBuilder, responseHandler, headerMutator, + SimpleServiceGrpc.getUnaryRpcMethod(), new CallBuffer()); + SimpleServiceUnaryResponseObserver simpleServiceResponseObserver = + new SimpleServiceUnaryResponseObserver(); + SimpleRequest simpleRequest = SimpleRequest.newBuilder().setRequestMessage("World").build(); + ClientCalls.asyncUnaryCall(call, simpleRequest, simpleServiceResponseObserver); + simpleServiceResponseObserver.await(); + + assertThat(simpleServiceResponseObserver.getResponse().getResponseMessage()) + .isEqualTo("Hello World"); + assertThat( + serverHeadersCapture.get().get(Metadata.Key.of("key1", Metadata.ASCII_STRING_MARSHALLER))) + .isNull(); + assertThat( + clientHeadersCapture.get().get(Metadata.Key.of("key2", Metadata.ASCII_STRING_MARSHALLER))) + .isNull(); + assertThat(clientTrailersCapture.get()).isNotNull(); + verify(headerMutator, never()).applyResponseMutations(any(ResponseHeaderMutations.class), + any(Metadata.class)); + verify(responseHandler, never()).handleResponse(any(CheckResponse.class), any(Metadata.class)); + } + + @Test + public void onUnary_authzServerError_failureModeAllowHeaderAdd() + throws InterruptedException, ExtAuthzParseException { + ExtAuthzConfig config = buildExtAuthzConfig(true, true, 503); + CheckRequest checkRequest = CheckRequest.newBuilder().build(); + when(checkRequestBuilder.buildRequest(eq(SimpleServiceGrpc.getUnaryRpcMethod()), + any(Metadata.class), any(Timestamp.class))).thenReturn(checkRequest); + + Status authzError = Status.UNAVAILABLE.withDescription("authz server unavailable"); + doAnswer((Answer) invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onError(authzError.asRuntimeException()); + return null; + }).when(authzService).check(eq(checkRequest), + ArgumentMatchers.>any()); + + ClientCall realCall = + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), CallOptions.DEFAULT); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + ClientCall call = BufferingAuthzClientCall.FACTORY_INSTANCE + .create(realCall, config, authzStub, checkRequestBuilder, responseHandler, headerMutator, + SimpleServiceGrpc.getUnaryRpcMethod(), new CallBuffer()); + SimpleServiceUnaryResponseObserver simpleServiceResponseObserver = + new SimpleServiceUnaryResponseObserver(); + SimpleRequest simpleRequest = SimpleRequest.newBuilder().setRequestMessage("World").build(); + ClientCalls.asyncUnaryCall(call, simpleRequest, simpleServiceResponseObserver); + simpleServiceResponseObserver.await(); + + assertThat(simpleServiceResponseObserver.getResponse().getResponseMessage()) + .isEqualTo("Hello World"); + assertThat(serverHeadersCapture.get().get( + Metadata.Key.of("x-envoy-auth-failure-mode-allowed", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("true"); + assertThat(clientTrailersCapture.get()).isNotNull(); + verify(headerMutator, never()).applyResponseMutations(any(ResponseHeaderMutations.class), + any(Metadata.class)); + verify(responseHandler, never()).handleResponse(any(CheckResponse.class), any(Metadata.class)); + } + + @Test + public void onStreaming_allowResponse() throws InterruptedException, ExtAuthzParseException { + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 403); + MethodDescriptor streamingMethod = + SimpleServiceGrpc.getBidiStreamingRpcMethod(); + CheckRequest checkRequest = CheckRequest.newBuilder() + .setAttributes(AttributeContext.newBuilder() + .setRequest(Request.newBuilder() + .setHttp(HttpRequest.newBuilder().setId("RequestId").build()).build()) + .build()) + .build(); + when(checkRequestBuilder.buildRequest(eq(streamingMethod), any(Metadata.class), + any(Timestamp.class))).thenReturn(checkRequest); + + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Status.Code.OK.value()).build()) + .setOkResponse(OkHttpResponse.getDefaultInstance()).build(); + doAnswer((Answer) invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(checkResponse); + observer.onCompleted(); + return null; + }).when(authzService).check(eq(checkRequest), + ArgumentMatchers.>any()); + + Metadata metadata = new Metadata(); + metadata.put(Metadata.Key.of("key1", Metadata.ASCII_STRING_MARSHALLER), "value1"); + ResponseHeaderMutations responseHeaderMutations = + ResponseHeaderMutations.create(ImmutableList.of()); + AuthzResponse allowResponse = + AuthzResponse.allow(metadata).setResponseHeaderMutations(responseHeaderMutations).build(); + when(responseHandler.handleResponse(eq(checkResponse), any(Metadata.class))) + .thenReturn(allowResponse); + + doAnswer((Answer) invocation -> { + Metadata headers = invocation.getArgument(1); + headers.put(Metadata.Key.of("key2", Metadata.ASCII_STRING_MARSHALLER), "value2"); + return null; + }).when(headerMutator).applyResponseMutations(eq(responseHeaderMutations), any(Metadata.class)); + + ClientCall realCall = + channel.newCall(streamingMethod, CallOptions.DEFAULT); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + ClientCall call = + BufferingAuthzClientCall.FACTORY_INSTANCE.create(realCall, config, authzStub, + checkRequestBuilder, responseHandler, headerMutator, streamingMethod, new CallBuffer()); + SimpleServiceStreamingResponseObserver simpleServiceResponseObserver = + new SimpleServiceStreamingResponseObserver(); + StreamObserver requestObserver = + ClientCalls.asyncBidiStreamingCall(call, simpleServiceResponseObserver); + requestObserver.onNext(SimpleRequest.newBuilder().setRequestMessage("World").build()); + requestObserver.onNext(SimpleRequest.newBuilder().setRequestMessage("gRPC").build()); + requestObserver.onCompleted(); + simpleServiceResponseObserver.await(); + + assertThat(simpleServiceResponseObserver.getResponses()) + .containsExactly(SimpleResponse.newBuilder().setResponseMessage("Hello World").build(), + SimpleResponse.newBuilder().setResponseMessage("Hello gRPC").build()) + .inOrder(); + assertThat( + serverHeadersCapture.get().get(Metadata.Key.of("key1", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("value1"); + assertThat( + clientHeadersCapture.get().get(Metadata.Key.of("key2", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("value2"); + assertThat(clientTrailersCapture.get()).isNotNull(); + verify(headerMutator).applyResponseMutations(eq(responseHeaderMutations), any(Metadata.class)); + } + + @Test + public void onStreaming_denyResponse() throws InterruptedException, ExtAuthzParseException { + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 403); + MethodDescriptor streamingMethod = + SimpleServiceGrpc.getBidiStreamingRpcMethod(); + CheckRequest checkRequest = CheckRequest.newBuilder().build(); + when(checkRequestBuilder.buildRequest(eq(streamingMethod), any(Metadata.class), + any(Timestamp.class))).thenReturn(checkRequest); + + CheckResponse checkResponse = CheckResponse.newBuilder().setStatus( + com.google.rpc.Status.newBuilder().setCode(Status.Code.PERMISSION_DENIED.value()).build()) + .build(); + doAnswer((Answer) invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(checkResponse); + observer.onCompleted(); + return null; + }).when(authzService).check(eq(checkRequest), + ArgumentMatchers.>any()); + + Status expectedStatus = Status.PERMISSION_DENIED.withDescription("ext authz denied"); + AuthzResponse denyResponse = AuthzResponse.deny(expectedStatus).build(); + when(responseHandler.handleResponse(eq(checkResponse), any(Metadata.class))) + .thenReturn(denyResponse); + + ClientCall realCall = + channel.newCall(streamingMethod, CallOptions.DEFAULT); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + CallBuffer callBuffer = new CallBuffer(); + ClientCall call = + BufferingAuthzClientCall.FACTORY_INSTANCE.create(realCall, config, authzStub, + checkRequestBuilder, responseHandler, headerMutator, streamingMethod, callBuffer); + SimpleServiceStreamingResponseObserver simpleServiceResponseObserver = + new SimpleServiceStreamingResponseObserver(); + StreamObserver requestObserver = + ClientCalls.asyncBidiStreamingCall(call, simpleServiceResponseObserver); + requestObserver.onNext(SimpleRequest.newBuilder().setRequestMessage("World").build()); + requestObserver.onNext(SimpleRequest.newBuilder().setRequestMessage("gRPC").build()); + requestObserver.onCompleted(); + simpleServiceResponseObserver.await(); + + assertThat(simpleServiceResponseObserver.getResponses()).isEmpty(); + assertThat(simpleServiceResponseObserver.getError()).isNotNull(); + assertThat(simpleServiceResponseObserver.getError().getCode()) + .isEqualTo(expectedStatus.getCode()); + assertThat(simpleServiceResponseObserver.getError().getDescription()) + .isEqualTo(expectedStatus.getDescription()); + assertThat(callBuffer.isProcessed()).isTrue(); + } + + private ExtAuthzConfig buildExtAuthzConfig(boolean failureModeAllow, + boolean failureModeAllowHeaderAdd, int httpStatusOnError) throws ExtAuthzParseException { + Any googleDefaultChannelCreds = Any.pack(GoogleDefaultCredentials.newBuilder().build()); + Any fakeAccessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); + ExtAuthz.Builder builder = ExtAuthz.newBuilder() + .setGrpcService(GrpcService.newBuilder() + .setGoogleGrpc(GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster").addChannelCredentialsPlugin(googleDefaultChannelCreds) + .addCallCredentialsPlugin(fakeAccessTokenCreds).build()) + .build()) + .setFailureModeAllow(failureModeAllow) + .setFailureModeAllowHeaderAdd(failureModeAllowHeaderAdd) + .setStatusOnError(io.envoyproxy.envoy.type.v3.HttpStatus.newBuilder() + .setCodeValue(httpStatusOnError).build()) + .setIncludePeerCertificate(true); + return ExtAuthzConfig.fromProto(builder.build()); + } + + private static class SimpleServiceUnaryResponseObserver + implements StreamObserver { + final AtomicReference responseCapture = new AtomicReference<>(); + final AtomicReference errorCapture = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(1); + + @Override + public void onNext(SimpleResponse value) { + responseCapture.set(value); + } + + @Override + public void onError(Throwable t) { + errorCapture.set(Status.fromThrowable(t)); + latch.countDown(); + } + + @Override + public void onCompleted() { + latch.countDown(); + } + + public void await() throws InterruptedException { + latch.await(5, TimeUnit.SECONDS); + } + + public SimpleResponse getResponse() { + return responseCapture.get(); + } + + public Status getError() { + return errorCapture.get(); + } + } + + private static class SimpleServiceStreamingResponseObserver + implements StreamObserver { + final ImmutableList.Builder responsesCapture = new ImmutableList.Builder<>(); + final AtomicReference errorCapture = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(1); + + @Override + public void onNext(SimpleResponse value) { + responsesCapture.add(value); + } + + @Override + public void onError(Throwable t) { + errorCapture.set(Status.fromThrowable(t)); + latch.countDown(); + } + + @Override + public void onCompleted() { + latch.countDown(); + } + + public void await() throws InterruptedException { + latch.await(5, TimeUnit.SECONDS); + } + + public ImmutableList getResponses() { + return responsesCapture.build(); + } + + public Status getError() { + return errorCapture.get(); + } + } + + private static final class MetadataCapturingServerInterceptor implements ServerInterceptor { + + final AtomicReference headersCapture; + + MetadataCapturingServerInterceptor(AtomicReference headersCapture) { + this.headersCapture = headersCapture; + } + + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + Metadata metadataCopy = new Metadata(); + metadataCopy.merge(headers); + headersCapture.set(metadataCopy); + return next.startCall(call, headers); + } + } + + private static class SimpleServiceImpl extends SimpleServiceGrpc.SimpleServiceImplBase { + @Override + public void unaryRpc(SimpleRequest request, StreamObserver streamObserver) { + streamObserver.onNext(SimpleResponse.newBuilder() + .setResponseMessage("Hello " + request.getRequestMessage()).build()); + streamObserver.onCompleted(); + } + + @Override + public StreamObserver bidiStreamingRpc( + final StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(SimpleRequest request) { + responseObserver.onNext(SimpleResponse.newBuilder() + .setResponseMessage("Hello " + request.getRequestMessage()).build()); + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/CallBufferTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/CallBufferTest.java new file mode 100644 index 00000000000..15baf213be8 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/CallBufferTest.java @@ -0,0 +1,190 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.InOrder; + +@RunWith(JUnit4.class) +public class CallBufferTest { + + private CallBuffer callBuffer; + + @Before + public void setUp() { + callBuffer = new CallBuffer(); + } + + @Test + public void runOrBuffer_beforeProcessed_buffersCall() { + Runnable runnable = mock(Runnable.class); + callBuffer.runOrBuffer(runnable); + verify(runnable, never()).run(); + assertThat(callBuffer.isProcessed()).isFalse(); + } + + @Test + public void runAndFlush_executesBufferedCallsInOrder() { + Runnable runnable1 = mock(Runnable.class); + Runnable runnable2 = mock(Runnable.class); + callBuffer.runOrBuffer(runnable1); + callBuffer.runOrBuffer(runnable2); + + InOrder inOrder = inOrder(runnable1, runnable2); + verify(runnable1, never()).run(); + verify(runnable2, never()).run(); + + callBuffer.runAndFlush(); + + inOrder.verify(runnable1).run(); + inOrder.verify(runnable2).run(); + assertThat(callBuffer.isProcessed()).isTrue(); + } + + @Test + public void runOrBuffer_afterRunAndFlush_runsImmediately() { + callBuffer.runAndFlush(); + assertThat(callBuffer.isProcessed()).isTrue(); + + Runnable runnable = mock(Runnable.class); + callBuffer.runOrBuffer(runnable); + verify(runnable).run(); + } + + @Test + public void abandon_discardsBufferedCalls() { + Runnable runnable = mock(Runnable.class); + callBuffer.runOrBuffer(runnable); + verify(runnable, never()).run(); + + callBuffer.abandon(); + assertThat(callBuffer.isProcessed()).isTrue(); + + // Another flush should not run the abandoned runnable + callBuffer.runAndFlush(); + verify(runnable, never()).run(); + } + + @Test + public void runOrBuffer_afterAbandon_runsImmediately() { + callBuffer.abandon(); + assertThat(callBuffer.isProcessed()).isTrue(); + + Runnable runnable = mock(Runnable.class); + callBuffer.runOrBuffer(runnable); + verify(runnable).run(); + } + + @Test + public void runAndFlush_isIdempotent() { + Runnable runnable = mock(Runnable.class); + callBuffer.runOrBuffer(runnable); + + callBuffer.runAndFlush(); + callBuffer.runAndFlush(); + + verify(runnable, times(1)).run(); + } + + @Test + public void abandon_isIdempotent() { + Runnable runnable = mock(Runnable.class); + callBuffer.runOrBuffer(runnable); + + callBuffer.abandon(); + callBuffer.abandon(); + + verify(runnable, never()).run(); + } + + @Test + public void abandon_afterRunAndFlush_isNoOp() { + Runnable runnable = mock(Runnable.class); + callBuffer.runOrBuffer(runnable); + + callBuffer.runAndFlush(); + callBuffer.abandon(); + + verify(runnable, times(1)).run(); + } + + @Test + public void runAndFlush_afterAbandon_isNoOp() { + Runnable runnable = mock(Runnable.class); + callBuffer.runOrBuffer(runnable); + + callBuffer.abandon(); + callBuffer.runAndFlush(); + + verify(runnable, never()).run(); + } + + // TODO(sauravzg): How to remove dependency on time using explicit synchronization? + @Test + public void concurrentRunOrBuffer_thenRunAndFlush() throws Exception { + int numThreads = 10; + ExecutorService executor = Executors.newFixedThreadPool(numThreads); + List runnables = new ArrayList<>(); + List> futures = new ArrayList<>(); + for (int i = 0; i < numThreads; i++) { + runnables.add(mock(Runnable.class)); + } + CountDownLatch latch = new CountDownLatch(numThreads); + + for (Runnable runnable : runnables) { + futures.add(executor.submit(() -> { + callBuffer.runOrBuffer(runnable); + latch.countDown(); + })); + } + + latch.await(5, TimeUnit.SECONDS); + for (Runnable runnable : runnables) { + verify(runnable, never()).run(); + } + + callBuffer.runAndFlush(); + + for (Runnable runnable : runnables) { + verify(runnable).run(); + } + + for (Future future : futures) { + future.get(); + } + + executor.shutdown(); + executor.awaitTermination(5, TimeUnit.SECONDS); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzClientInterceptorTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzClientInterceptorTest.java new file mode 100644 index 00000000000..cee7681ddb6 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzClientInterceptorTest.java @@ -0,0 +1,175 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.config.core.v3.RuntimeFeatureFlag; +import io.envoyproxy.envoy.config.core.v3.RuntimeFractionalPercent; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.service.auth.v3.AuthorizationGrpc; +import io.envoyproxy.envoy.type.v3.FractionalPercent; +import io.grpc.CallOptions; +import io.grpc.Channel; +import io.grpc.ClientCall; +import io.grpc.ClientInterceptor; +import io.grpc.MethodDescriptor; +import io.grpc.internal.FailingClientCall; +import io.grpc.testing.TestMethodDescriptors; +import io.grpc.xds.internal.ThreadSafeRandom; +import io.grpc.xds.internal.headermutations.HeaderMutator; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class ExtAuthzClientInterceptorTest { + + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + @Mock + private CheckResponseHandler responseHandler; + @Mock + private HeaderMutator headerMutator; + + @Mock + private CheckRequestBuilder checkRequestBuilder; + + @Mock + ClientCall expectedCall; + + @Mock + ClientCall nextCall; + + private AuthorizationGrpc.AuthorizationStub authzStub; + + @Mock + private ThreadSafeRandom random; + + @Mock + private BufferingAuthzClientCall.Factory clientCallFactory; + + @Mock + private Channel next; + + private MethodDescriptor method = TestMethodDescriptors.voidMethod(); + private CallOptions callOptions = CallOptions.DEFAULT; + + @Before + public void setUp() { + authzStub = AuthorizationGrpc.newStub(mock(Channel.class)); + } + + @Test + public void interceptCall_denyAtDisable() throws ExtAuthzParseException { + when(random.nextInt(100)).thenReturn(50); + ExtAuthz extAuthzProto = ExtAuthz.newBuilder().setGrpcService(GrpcService.newBuilder() + .setGoogleGrpc(GrpcService.GoogleGrpc.newBuilder().setTargetUri("test-cluster") + .addChannelCredentialsPlugin(Any.pack(GoogleDefaultCredentials.newBuilder().build())) + .addCallCredentialsPlugin( + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build())) + .build()) + .build()) + .setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue(FractionalPercent.newBuilder().setNumerator(100) + .setDenominator(FractionalPercent.DenominatorType.HUNDRED).build()) + .build()) + .setDenyAtDisable( + RuntimeFeatureFlag.newBuilder().setDefaultValue(BoolValue.of(true)).build()) + .setStatusOnError( + io.envoyproxy.envoy.type.v3.HttpStatus.newBuilder().setCodeValue(403).build()) + .build(); + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthzProto); + ClientInterceptor interceptor = ExtAuthzClientInterceptor.INSTANCE.create(config, authzStub, + random, clientCallFactory, checkRequestBuilder, responseHandler, headerMutator); + + ClientCall call = interceptor.interceptCall(method, callOptions, next); + + assertThat(call).isInstanceOf(FailingClientCall.class); + } + + @Test + public void interceptCall_delegateToRealCall() throws ExtAuthzParseException { + when(random.nextInt(100)).thenReturn(50); + ExtAuthz extAuthzProto = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder() + .setGoogleGrpc(io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster") + .addChannelCredentialsPlugin( + Any.pack(GoogleDefaultCredentials.newBuilder().build())) + .addCallCredentialsPlugin( + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build())) + .build())) + .setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue(FractionalPercent.newBuilder().setNumerator(100) + .setDenominator(FractionalPercent.DenominatorType.HUNDRED).build()) + .build()) + .build(); + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthzProto); + ClientInterceptor interceptor = ExtAuthzClientInterceptor.INSTANCE.create(config, authzStub, + random, clientCallFactory, checkRequestBuilder, responseHandler, headerMutator); + when(next.newCall(method, callOptions)).thenReturn(expectedCall); + + ClientCall call = interceptor.interceptCall(method, callOptions, next); + + assertThat(call).isSameInstanceAs(expectedCall); + } + + @Test + public void interceptCall_factoryCreatesCall() throws ExtAuthzParseException { + ExtAuthz extAuthzProto = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder() + .setGoogleGrpc(io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("test-cluster") + .addChannelCredentialsPlugin( + Any.pack(GoogleDefaultCredentials.newBuilder().build())) + .addCallCredentialsPlugin( + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build())) + .build()) + .build()) + .setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue(FractionalPercent.newBuilder().setNumerator(0) + .setDenominator(FractionalPercent.DenominatorType.HUNDRED).build()) + .build()) + .build(); + ExtAuthzConfig config = ExtAuthzConfig.fromProto(extAuthzProto); + when(random.nextInt(100)).thenReturn(50); + ClientInterceptor interceptor = ExtAuthzClientInterceptor.INSTANCE.create(config, authzStub, + random, clientCallFactory, checkRequestBuilder, responseHandler, headerMutator); + when(next.newCall(method, callOptions)).thenReturn(nextCall); + when(clientCallFactory.create(eq(nextCall), eq(config), eq(authzStub), eq(checkRequestBuilder), + eq(responseHandler), eq(headerMutator), eq(method), any(CallBuffer.class))) + .thenReturn(expectedCall); + ClientCall call = interceptor.interceptCall(method, callOptions, next); + assertThat(call).isSameInstanceAs(expectedCall); + } +} From 475cc91f6f7cbbda0d5bea3580cb89ed58db2163 Mon Sep 17 00:00:00 2001 From: Saurav Date: Thu, 23 Oct 2025 10:26:01 +0000 Subject: [PATCH 6/7] feat(xds): Add ExtAuthzServerInterceptor and tests This commit introduces the `ExtAuthzServerInterceptor`, a server interceptor that performs external authorization for incoming RPCs. The interceptor checks if the external authorization filter is enabled. If it is, it calls the external authorization service and handles the response. It supports both unary and streaming RPCs. The interceptor handles the following scenarios: - Allow responses: The RPC is allowed to proceed. - Deny responses: The RPC is denied with a `PERMISSION_DENIED` status. - Authorization service errors: The RPC is either denied or allowed to proceed based on the `failure_mode_allow` configuration. This commit also includes comprehensive integration tests for the `ExtAuthzServerInterceptor`, covering various scenarios and configurations. --- .../extauthz/ExtAuthzServerInterceptor.java | 236 +++++++ .../ExtAuthzServerInterceptorTest.java | 583 ++++++++++++++++++ 2 files changed, 819 insertions(+) create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzServerInterceptor.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzServerInterceptorTest.java diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzServerInterceptor.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzServerInterceptor.java new file mode 100644 index 00000000000..4f5a97d367f --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/ExtAuthzServerInterceptor.java @@ -0,0 +1,236 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.protobuf.util.Timestamps; +import io.envoyproxy.envoy.service.auth.v3.AuthorizationGrpc; +import io.envoyproxy.envoy.service.auth.v3.CheckRequest; +import io.envoyproxy.envoy.service.auth.v3.CheckResponse; +import io.grpc.ForwardingServerCall.SimpleForwardingServerCall; +import io.grpc.ForwardingServerCallListener; +import io.grpc.Metadata; +import io.grpc.ServerCall; +import io.grpc.ServerCall.Listener; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.stub.StreamObserver; +import io.grpc.xds.internal.Matchers.FractionMatcher; +import io.grpc.xds.internal.ThreadSafeRandom; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutator; +import java.util.concurrent.atomic.AtomicReference; + +/** + * A server interceptor that performs external authorization for incoming RPCs. + */ +public final class ExtAuthzServerInterceptor implements ServerInterceptor { + + /** + * A factory for creating {@link ExtAuthzServerInterceptor} instances. + */ + @FunctionalInterface + public interface Factory { + /** + * Creates a new {@link ExtAuthzServerInterceptor}. + * + * @param config the external authorization configuration. + * @param authzStub the gRPC stub for the authorization service. + * @param random the random number generator for filter matching. + * @param checkRequestBuilder the builder for creating authorization check requests. + * @param responseHandler the handler for processing authorization responses. + * @param headerMutator the mutator for applying header mutations. + * @return a new {@link ServerInterceptor}. + */ + ServerInterceptor create(ExtAuthzConfig config, AuthorizationGrpc.AuthorizationStub authzStub, + ThreadSafeRandom random, CheckRequestBuilder checkRequestBuilder, + CheckResponseHandler responseHandler, HeaderMutator headerMutator); + } + + /** + * A factory for creating {@link ExtAuthzServerInterceptor} instances. This is the only supported + * way to create a new ExtAuthzServerInterceptor. + */ + public static final Factory INSTANCE = ExtAuthzServerInterceptor::new; + + private final ExtAuthzConfig config; + private final AuthorizationGrpc.AuthorizationStub authzStub; + private final ThreadSafeRandom random; + private final CheckRequestBuilder checkRequestBuilder; + private final CheckResponseHandler responseHandler; + private final HeaderMutator headerMutator; + + private ExtAuthzServerInterceptor(ExtAuthzConfig config, + AuthorizationGrpc.AuthorizationStub authzStub, ThreadSafeRandom random, + CheckRequestBuilder checkRequestBuilder, + CheckResponseHandler responseHandler, HeaderMutator headerMutator) { + this.config = config; + this.random = random; + this.authzStub = authzStub; + this.checkRequestBuilder = checkRequestBuilder; + this.responseHandler = responseHandler; + this.headerMutator = headerMutator; + } + + /** + * Intercepts an incoming call to perform external authorization. + * + * @param call the server call to intercept. + * @param headers the headers of the incoming call. + * @param next the next handler in the chain. + * @return a listener for the server call. + */ + @Override + public ServerCall.Listener interceptCall(ServerCall call, + final Metadata headers, ServerCallHandler next) { + FractionMatcher filterEnabled = config.filterEnabled(); + if (random.nextInt(filterEnabled.denominator()) < filterEnabled.numerator()) { + if (config.denyAtDisable()) { + call.close(config.statusOnError(), new Metadata()); + return new ServerCall.Listener() {}; + } + return next.startCall(call, headers); + } + ExtAuthzForwardingListener listener = new ExtAuthzForwardingListener<>(config, + authzStub, headers, call, next, checkRequestBuilder, responseHandler, headerMutator); + listener.startAuthzCall(); + return listener; + } + + /** + * A forwarding server call listener that handles the external authorization process. + */ + private static final class ExtAuthzForwardingListener + extends ForwardingServerCallListener { + private static final String X_ENVOY_AUTH_FAILURE_MODE_ALLOWED = + "x-envoy-auth-failure-mode-allowed"; + + private final ExtAuthzConfig config; + private final AuthorizationGrpc.AuthorizationStub authzStub; + private final Metadata headers; + private final ServerCall realServerCall; + private final ServerCallHandler serverCallHandler; + private final CheckRequestBuilder checkRequestBuilder; + private final CheckResponseHandler responseHandler; + private final HeaderMutator headerMutator; + private final AtomicReference> delegateListener; + + /** + * Constructs a new {@link ExtAuthzForwardingListener}. + */ + ExtAuthzForwardingListener(ExtAuthzConfig config, + AuthorizationGrpc.AuthorizationStub authzStub, Metadata headers, + ServerCall serverCall, ServerCallHandler serverCallHandler, + CheckRequestBuilder checkRequestBuilder, CheckResponseHandler responseHandler, + HeaderMutator headerMutator) { + this.config = config; + this.authzStub = authzStub; + this.headers = headers; + this.realServerCall = serverCall; + this.serverCallHandler = serverCallHandler; + this.checkRequestBuilder = checkRequestBuilder; + this.responseHandler = responseHandler; + this.headerMutator = headerMutator; + this.delegateListener = + new AtomicReference>(new ServerCall.Listener() {}); + } + + /** + * Starts the external authorization call. + */ + void startAuthzCall() { + CheckRequest checkRequest = checkRequestBuilder.buildRequest(realServerCall, headers, + Timestamps.fromMillis(System.currentTimeMillis())); + StreamObserver observer = new StreamObserver() { + @Override + public void onNext(CheckResponse value) { + // The handleResponse method may add or modify headers based on the authorization + // response. + AuthzResponse authzResponse = responseHandler.handleResponse(value, headers); + if (authzResponse.decision() == AuthzResponse.Decision.ALLOW) { + AuthzServerCall authzServerCall = new AuthzServerCall<>(realServerCall, + authzResponse.responseHeaderMutations(), headerMutator); + delegateListener + .set(serverCallHandler.startCall(authzServerCall, authzResponse.headers().get())); + } else { + // A deny response is guaranteed to have a status set, so the `get` without + // check is safe. + realServerCall.close(authzResponse.status().get(), new Metadata()); + } + } + + @Override + public void onError(Throwable t) { + if (config.failureModeAllow()) { + if (config.failureModeAllowHeaderAdd()) { + Metadata.Key key = Metadata.Key.of(X_ENVOY_AUTH_FAILURE_MODE_ALLOWED, + Metadata.ASCII_STRING_MARSHALLER); + headers.put(key, "true"); + } + delegateListener.set(serverCallHandler.startCall(realServerCall, headers)); + } else { + realServerCall.close(config.statusOnError().withCause(t), new Metadata()); + } + } + + @Override + public void onCompleted() { + // No-op. The authorization service uses a unary RPC, so we only expect one response. + } + }; + authzStub.check(checkRequest, observer); + } + + @Override + protected Listener delegate() { + return delegateListener.get(); + } + } + + /** + * A server call that applies response header mutations from the authorization service. + */ + private static class AuthzServerCall + extends SimpleForwardingServerCall { + private final ResponseHeaderMutations responseHeaderMutations; + private final HeaderMutator headerMutator; + + /** + * Constructs a new {@link AuthzServerCall}. + * + * @param delegate the original server call. + * @param responseHeaderMutations the response header mutations to apply. + * @param headerMutator the mutator for applying header mutations. + */ + private AuthzServerCall(ServerCall delegate, + ResponseHeaderMutations responseHeaderMutations, HeaderMutator headerMutator) { + super(delegate); + this.responseHeaderMutations = responseHeaderMutations; + this.headerMutator = headerMutator; + } + + /** + * Sends the headers after applying any mutations from the authorization service. + * + * @param headers the headers to send. + */ + @Override + public void sendHeaders(Metadata headers) { + headerMutator.applyResponseMutations(responseHeaderMutations, headers); + super.sendHeaders(headers); + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzServerInterceptorTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzServerInterceptorTest.java new file mode 100644 index 00000000000..a86bcbdc52b --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/ExtAuthzServerInterceptorTest.java @@ -0,0 +1,583 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.common.collect.ImmutableList; +import com.google.protobuf.Any; +import com.google.protobuf.BoolValue; +import com.google.protobuf.Timestamp; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.config.core.v3.RuntimeFeatureFlag; +import io.envoyproxy.envoy.config.core.v3.RuntimeFractionalPercent; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.service.auth.v3.AuthorizationGrpc; +import io.envoyproxy.envoy.service.auth.v3.CheckRequest; +import io.envoyproxy.envoy.service.auth.v3.CheckResponse; +import io.envoyproxy.envoy.type.v3.FractionalPercent; +import io.envoyproxy.envoy.type.v3.FractionalPercent.DenominatorType; +import io.envoyproxy.envoy.type.v3.HttpStatus; +import io.grpc.ManagedChannel; +import io.grpc.Metadata; +import io.grpc.Server; +import io.grpc.ServerCall; +import io.grpc.ServerCallHandler; +import io.grpc.ServerInterceptor; +import io.grpc.ServerInterceptors; +import io.grpc.Status; +import io.grpc.inprocess.InProcessChannelBuilder; +import io.grpc.inprocess.InProcessServerBuilder; +import io.grpc.stub.ClientCalls; +import io.grpc.stub.MetadataUtils; +import io.grpc.stub.StreamObserver; +import io.grpc.testing.protobuf.SimpleRequest; +import io.grpc.testing.protobuf.SimpleResponse; +import io.grpc.testing.protobuf.SimpleServiceGrpc; +import io.grpc.xds.internal.ThreadSafeRandom; +import io.grpc.xds.internal.headermutations.HeaderMutations.ResponseHeaderMutations; +import io.grpc.xds.internal.headermutations.HeaderMutator; +import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; +import org.junit.After; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** Integration tests for {@link ExtAuthzServerInterceptor}. */ +@RunWith(JUnit4.class) +public class ExtAuthzServerInterceptorTest { + + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + @Mock + private ThreadSafeRandom mockRandom; + @Mock + private CheckRequestBuilder mockCheckRequestBuilder; + @Mock + private CheckResponseHandler mockResponseHandler; + @Mock + private HeaderMutator mockHeaderMutator; + @Mock + private AuthorizationGrpc.AuthorizationImplBase authzService; + + private Server server; + private ManagedChannel channel; + private final AtomicReference serverHeadersCapture = new AtomicReference<>(); + private final AtomicReference clientResponseHeadersCapture = new AtomicReference<>(); + private final AtomicReference clientResponseTrailersCapture = new AtomicReference<>(); + + private final SimpleServiceGrpc.SimpleServiceImplBase simpleServiceImpl = + new SimpleServiceGrpc.SimpleServiceImplBase() { + @Override + public void unaryRpc(SimpleRequest request, + StreamObserver responseObserver) { + responseObserver.onNext(SimpleResponse.newBuilder() + .setResponseMessage("Hello " + request.getRequestMessage()).build()); + responseObserver.onCompleted(); + } + + @Override + public StreamObserver bidiStreamingRpc( + final StreamObserver responseObserver) { + return new StreamObserver() { + @Override + public void onNext(SimpleRequest request) { + responseObserver.onNext(SimpleResponse.newBuilder() + .setResponseMessage("Hello " + request.getRequestMessage()).build()); + } + + @Override + public void onError(Throwable t) { + responseObserver.onError(t); + } + + @Override + public void onCompleted() { + responseObserver.onCompleted(); + } + }; + } + }; + + @Before + public void setUp() throws IOException {} + + @After + public void tearDown() { + if (server != null) { + server.shutdownNow(); + } + if (channel != null) { + channel.shutdownNow(); + } + } + + @Test + public void interceptCall_allow() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 403, false, 0); + String serverName = InProcessServerBuilder.generateName(); + channel = buildChannel(serverName); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + server = buildAndStartServer(config, authzStub, serverName); + CheckRequest checkRequest = CheckRequest.getDefaultInstance(); + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Status.Code.OK.value())).build(); + ResponseHeaderMutations responseHeaderMutations = + ResponseHeaderMutations.create(ImmutableList.of()); + setUpAllowCheck(checkRequest, checkResponse, responseHeaderMutations); + + SimpleServiceUnaryResponseObserver responseObserver = new SimpleServiceUnaryResponseObserver(); + ClientCalls.asyncUnaryCall( + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), io.grpc.CallOptions.DEFAULT), + SimpleRequest.newBuilder().setRequestMessage("world").build(), responseObserver); + responseObserver.await(); + assertThat(responseObserver.getResponse().getResponseMessage()).isEqualTo("Hello world"); + assertThat(responseObserver.getError()).isNull(); + assertThat(serverHeadersCapture.get() + .get(Metadata.Key.of("auth-key", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("auth-value"); + assertThat(clientResponseHeadersCapture.get() + .get(Metadata.Key.of("client-resp-key", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("client-resp-value"); + assertThat(clientResponseTrailersCapture.get()).isNotNull(); + } + + @Test + public void interceptCall_deny() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 403, false, 0); + String serverName = InProcessServerBuilder.generateName(); + channel = buildChannel(serverName); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + server = buildAndStartServer(config, authzStub, serverName); + when(mockRandom.nextInt(100)).thenReturn(50); + CheckRequest checkRequest = CheckRequest.getDefaultInstance(); + when(mockCheckRequestBuilder.buildRequest(any(ServerCall.class), any(Metadata.class), + any(Timestamp.class))).thenReturn(checkRequest); + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus( + com.google.rpc.Status.newBuilder().setCode(Status.Code.PERMISSION_DENIED.value())) + .build(); + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(checkResponse); + observer.onCompleted(); + return null; + }).when(authzService).check(eq(checkRequest), ArgumentMatchers.any()); + Status expectedStatus = Status.PERMISSION_DENIED.withDescription("ext authz denied"); + AuthzResponse denyResponse = AuthzResponse.deny(expectedStatus).build(); + when(mockResponseHandler.handleResponse(eq(checkResponse), any())).thenReturn(denyResponse); + + SimpleServiceUnaryResponseObserver responseObserver = new SimpleServiceUnaryResponseObserver(); + ClientCalls.asyncUnaryCall( + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), io.grpc.CallOptions.DEFAULT), + SimpleRequest.newBuilder().setRequestMessage("world").build(), responseObserver); + responseObserver.await(); + + assertThat(responseObserver.getResponse()).isNull(); + assertThat(responseObserver.getError()).isNotNull(); + assertThat(responseObserver.getError().getCode()).isEqualTo(expectedStatus.getCode()); + assertThat(responseObserver.getError().getDescription()) + .isEqualTo(expectedStatus.getDescription()); + } + + @Test + public void interceptCall_authzServerError_failCall() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 503, false, 0); + String serverName = InProcessServerBuilder.generateName(); + channel = buildChannel(serverName); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + server = buildAndStartServer(config, authzStub, serverName); + when(mockRandom.nextInt(100)).thenReturn(50); + CheckRequest checkRequest = CheckRequest.getDefaultInstance(); + when(mockCheckRequestBuilder.buildRequest(any(ServerCall.class), any(Metadata.class), + any(Timestamp.class))).thenReturn(checkRequest); + Status authzError = Status.UNAVAILABLE.withDescription("authz server unavailable"); + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onError(authzError.asRuntimeException()); + return null; + }).when(authzService).check(eq(checkRequest), ArgumentMatchers.any()); + + SimpleServiceUnaryResponseObserver responseObserver = new SimpleServiceUnaryResponseObserver(); + ClientCalls.asyncUnaryCall( + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), io.grpc.CallOptions.DEFAULT), + SimpleRequest.newBuilder().setRequestMessage("world").build(), responseObserver); + responseObserver.await(); + + assertThat(responseObserver.getResponse()).isNull(); + assertThat(responseObserver.getError()).isNotNull(); + assertThat(responseObserver.getError().getCode()).isEqualTo(config.statusOnError().getCode()); + } + + @Test + public void interceptCall_authzServerError_allow() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(true, false, 503, false, 0); + String serverName = InProcessServerBuilder.generateName(); + channel = buildChannel(serverName); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + server = buildAndStartServer(config, authzStub, serverName); + when(mockRandom.nextInt(100)).thenReturn(50); + CheckRequest checkRequest = CheckRequest.getDefaultInstance(); + when(mockCheckRequestBuilder.buildRequest(any(ServerCall.class), any(Metadata.class), + any(Timestamp.class))).thenReturn(checkRequest); + Status authzError = Status.UNAVAILABLE.withDescription("authz server unavailable"); + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onError(authzError.asRuntimeException()); + return null; + }).when(authzService).check(eq(checkRequest), ArgumentMatchers.any()); + + SimpleServiceUnaryResponseObserver responseObserver = new SimpleServiceUnaryResponseObserver(); + ClientCalls.asyncUnaryCall( + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), io.grpc.CallOptions.DEFAULT), + SimpleRequest.newBuilder().setRequestMessage("world").build(), responseObserver); + responseObserver.await(); + + assertThat(responseObserver.getResponse().getResponseMessage()).isEqualTo("Hello world"); + assertThat(responseObserver.getError()).isNull(); + assertThat(serverHeadersCapture.get().get( + Metadata.Key.of("x-envoy-auth-failure-mode-allowed", Metadata.ASCII_STRING_MARSHALLER))) + .isNull(); + } + + @Test + public void interceptCall_authzServerError_allowWithHeaderAdd() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(true, true, 503, false, 0); + String serverName = InProcessServerBuilder.generateName(); + channel = buildChannel(serverName); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + server = buildAndStartServer(config, authzStub, serverName); + when(mockRandom.nextInt(100)).thenReturn(50); + CheckRequest checkRequest = CheckRequest.getDefaultInstance(); + when(mockCheckRequestBuilder.buildRequest(any(ServerCall.class), any(Metadata.class), + any(Timestamp.class))).thenReturn(checkRequest); + Status authzError = Status.UNAVAILABLE.withDescription("authz server unavailable"); + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onError(authzError.asRuntimeException()); + return null; + }).when(authzService).check(eq(checkRequest), ArgumentMatchers.any()); + + SimpleServiceUnaryResponseObserver responseObserver = new SimpleServiceUnaryResponseObserver(); + ClientCalls.asyncUnaryCall( + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), io.grpc.CallOptions.DEFAULT), + SimpleRequest.newBuilder().setRequestMessage("world").build(), responseObserver); + responseObserver.await(); + + assertThat(responseObserver.getResponse().getResponseMessage()).isEqualTo("Hello world"); + assertThat(responseObserver.getError()).isNull(); + assertThat(serverHeadersCapture.get().get( + Metadata.Key.of("x-envoy-auth-failure-mode-allowed", Metadata.ASCII_STRING_MARSHALLER))) + .isEqualTo("true"); + } + + @Test + public void interceptCall_filterDisabled_denyAtDisable() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 403, true, 100); + String serverName = InProcessServerBuilder.generateName(); + channel = buildChannel(serverName); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + server = buildAndStartServer(config, authzStub, serverName); + when(mockRandom.nextInt(100)).thenReturn(50); + + SimpleServiceUnaryResponseObserver responseObserver = new SimpleServiceUnaryResponseObserver(); + ClientCalls.asyncUnaryCall( + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), io.grpc.CallOptions.DEFAULT), + SimpleRequest.newBuilder().setRequestMessage("world").build(), responseObserver); + responseObserver.await(); + + assertThat(responseObserver.getResponse()).isNull(); + assertThat(responseObserver.getError()).isNotNull(); + assertThat(responseObserver.getError().getCode()).isEqualTo(config.statusOnError().getCode()); + verify(authzService, never()).check(any(), any()); + } + + @Test + public void interceptCall_filterDisabled_allow() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 403, false, 100); + String serverName = InProcessServerBuilder.generateName(); + channel = buildChannel(serverName); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + server = buildAndStartServer(config, authzStub, serverName); + when(mockRandom.nextInt(100)).thenReturn(50); + + SimpleServiceUnaryResponseObserver responseObserver = new SimpleServiceUnaryResponseObserver(); + ClientCalls.asyncUnaryCall( + channel.newCall(SimpleServiceGrpc.getUnaryRpcMethod(), io.grpc.CallOptions.DEFAULT), + SimpleRequest.newBuilder().setRequestMessage("world").build(), responseObserver); + responseObserver.await(); + + assertThat(responseObserver.getResponse().getResponseMessage()).isEqualTo("Hello world"); + assertThat(responseObserver.getError()).isNull(); + verify(authzService, never()).check(any(), any()); + } + + @Test + public void interceptCall_streaming_allow() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 403, false, 0); + String serverName = InProcessServerBuilder.generateName(); + channel = buildChannel(serverName); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + server = buildAndStartServer(config, authzStub, serverName); + when(mockRandom.nextInt(100)).thenReturn(50); + CheckRequest checkRequest = CheckRequest.getDefaultInstance(); + when(mockCheckRequestBuilder.buildRequest(any(ServerCall.class), any(Metadata.class), + any(Timestamp.class))).thenReturn(checkRequest); + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus(com.google.rpc.Status.newBuilder().setCode(Status.Code.OK.value())).build(); + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(checkResponse); + observer.onCompleted(); + return null; + }).when(authzService).check(eq(checkRequest), ArgumentMatchers.any()); + AuthzResponse allowResponse = AuthzResponse.allow(new Metadata()).build(); + when(mockResponseHandler.handleResponse(eq(checkResponse), any())).thenReturn(allowResponse); + + SimpleServiceStreamingResponseObserver responseObserver = + new SimpleServiceStreamingResponseObserver(); + StreamObserver requestObserver = ClientCalls.asyncBidiStreamingCall( + channel.newCall(SimpleServiceGrpc.getBidiStreamingRpcMethod(), io.grpc.CallOptions.DEFAULT), + responseObserver); + requestObserver.onNext(SimpleRequest.newBuilder().setRequestMessage("world").build()); + requestObserver.onCompleted(); + responseObserver.await(); + + assertThat(responseObserver.getResponses()).hasSize(1); + assertThat(responseObserver.getResponses().get(0).getResponseMessage()) + .isEqualTo("Hello world"); + assertThat(responseObserver.getError()).isNull(); + } + + @Test + public void interceptCall_streaming_deny() throws Exception { + ExtAuthzConfig config = buildExtAuthzConfig(false, false, 403, false, 0); + String serverName = InProcessServerBuilder.generateName(); + channel = buildChannel(serverName); + AuthorizationGrpc.AuthorizationStub authzStub = AuthorizationGrpc.newStub(channel); + server = buildAndStartServer(config, authzStub, serverName); + when(mockRandom.nextInt(100)).thenReturn(50); + CheckRequest checkRequest = CheckRequest.getDefaultInstance(); + when(mockCheckRequestBuilder.buildRequest(any(ServerCall.class), any(Metadata.class), + any(Timestamp.class))).thenReturn(checkRequest); + CheckResponse checkResponse = CheckResponse.newBuilder() + .setStatus( + com.google.rpc.Status.newBuilder().setCode(Status.Code.PERMISSION_DENIED.value())) + .build(); + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(checkResponse); + observer.onCompleted(); + return null; + }).when(authzService).check(eq(checkRequest), ArgumentMatchers.any()); + Status expectedStatus = Status.PERMISSION_DENIED.withDescription("ext authz denied"); + AuthzResponse denyResponse = AuthzResponse.deny(expectedStatus).build(); + when(mockResponseHandler.handleResponse(eq(checkResponse), any())).thenReturn(denyResponse); + + SimpleServiceStreamingResponseObserver responseObserver = + new SimpleServiceStreamingResponseObserver(); + StreamObserver requestObserver = ClientCalls.asyncBidiStreamingCall( + channel.newCall(SimpleServiceGrpc.getBidiStreamingRpcMethod(), io.grpc.CallOptions.DEFAULT), + responseObserver); + requestObserver.onNext(SimpleRequest.newBuilder().setRequestMessage("world").build()); + requestObserver.onCompleted(); + responseObserver.await(); + + assertThat(responseObserver.getResponses()).isEmpty(); + assertThat(responseObserver.getError()).isNotNull(); + assertThat(responseObserver.getError().getCode()).isEqualTo(expectedStatus.getCode()); + assertThat(responseObserver.getError().getDescription()) + .isEqualTo(expectedStatus.getDescription()); + } + + private ManagedChannel buildChannel(String serverName) { + return InProcessChannelBuilder.forName(serverName).intercept(MetadataUtils + .newCaptureMetadataInterceptor(clientResponseHeadersCapture, clientResponseTrailersCapture)) + .directExecutor().build(); + + } + + private Server buildAndStartServer(ExtAuthzConfig config, + AuthorizationGrpc.AuthorizationStub authzStub, String serverName) throws IOException { + ServerInterceptor interceptor = ExtAuthzServerInterceptor.INSTANCE.create(config, authzStub, + mockRandom, mockCheckRequestBuilder, mockResponseHandler, mockHeaderMutator); + + return InProcessServerBuilder.forName(serverName).addService(authzService) + .addService(ServerInterceptors.intercept(simpleServiceImpl, + new MetadataCapturingServerInterceptor(serverHeadersCapture), interceptor)) + .directExecutor().build().start(); + + } + + private ExtAuthzConfig buildExtAuthzConfig(boolean failureModeAllow, + boolean failureModeAllowHeaderAdd, int httpStatusOnError, boolean denyAtDisable, int percent) + throws ExtAuthzParseException { + Any googleDefaultChannelCreds = Any.pack(GoogleDefaultCredentials.newBuilder().build()); + Any fakeAccessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); + ExtAuthz extAuthz = ExtAuthz.newBuilder() + .setGrpcService(GrpcService.newBuilder() + .setGoogleGrpc(GrpcService.GoogleGrpc.newBuilder().setTargetUri("test-cluster") + .addChannelCredentialsPlugin(googleDefaultChannelCreds) + .addCallCredentialsPlugin(fakeAccessTokenCreds).build()) + .build()) + .setFailureModeAllow(failureModeAllow) + .setFailureModeAllowHeaderAdd(failureModeAllowHeaderAdd) + .setDenyAtDisable( + RuntimeFeatureFlag.newBuilder().setDefaultValue(BoolValue.of(denyAtDisable)).build()) + .setStatusOnError(HttpStatus.newBuilder().setCodeValue(httpStatusOnError).build()) + .setFilterEnabled(RuntimeFractionalPercent.newBuilder() + .setDefaultValue(FractionalPercent.newBuilder().setNumerator(percent) + .setDenominator(DenominatorType.HUNDRED).build()) + .build()) + .setIncludePeerCertificate(denyAtDisable).build(); + return ExtAuthzConfig.fromProto(extAuthz); + } + + private static class SimpleServiceUnaryResponseObserver + implements StreamObserver { + + final AtomicReference responseCapture = new AtomicReference<>(); + final AtomicReference errorCapture = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(1); + + @Override + public void onNext(SimpleResponse value) { + responseCapture.set(value); + } + + @Override + public void onError(Throwable t) { + errorCapture.set(Status.fromThrowable(t)); + latch.countDown(); + } + + @Override + public void onCompleted() { + latch.countDown(); + } + + public void await() throws InterruptedException { + latch.await(5, TimeUnit.SECONDS); + } + + public SimpleResponse getResponse() { + return responseCapture.get(); + } + + public Status getError() { + return errorCapture.get(); + } + } + + private static class SimpleServiceStreamingResponseObserver + implements StreamObserver { + + final ImmutableList.Builder responsesCapture = new ImmutableList.Builder<>(); + final AtomicReference errorCapture = new AtomicReference<>(); + final CountDownLatch latch = new CountDownLatch(1); + + @Override + public void onNext(SimpleResponse value) { + responsesCapture.add(value); + } + + @Override + public void onError(Throwable t) { + errorCapture.set(Status.fromThrowable(t)); + latch.countDown(); + } + + @Override + public void onCompleted() { + latch.countDown(); + } + + public void await() throws InterruptedException { + latch.await(5, TimeUnit.SECONDS); + } + + public ImmutableList getResponses() { + return responsesCapture.build(); + } + + public Status getError() { + return errorCapture.get(); + } + } + + private static final class MetadataCapturingServerInterceptor implements ServerInterceptor { + private final AtomicReference headersCapture; + + MetadataCapturingServerInterceptor(AtomicReference headersCapture) { + this.headersCapture = headersCapture; + } + + @Override + public ServerCall.Listener interceptCall(ServerCall call, + Metadata headers, ServerCallHandler next) { + Metadata metadataCopy = new Metadata(); + metadataCopy.merge(headers); + headersCapture.set(metadataCopy); + return next.startCall(call, headers); + } + } + + private void setUpAllowCheck(CheckRequest checkRequest, CheckResponse checkResponse, + ResponseHeaderMutations responseHeaderMutations) { + when(mockRandom.nextInt(100)).thenReturn(50); + when(mockCheckRequestBuilder.buildRequest(any(ServerCall.class), any(Metadata.class), + any(Timestamp.class))).thenReturn(checkRequest); + doAnswer(invocation -> { + StreamObserver observer = invocation.getArgument(1); + observer.onNext(checkResponse); + observer.onCompleted(); + return null; + }).when(authzService).check(eq(checkRequest), ArgumentMatchers.any()); + Metadata headersFromServer = new Metadata(); + headersFromServer.put(Metadata.Key.of("auth-key", Metadata.ASCII_STRING_MARSHALLER), + "auth-value"); + AuthzResponse allowResponse = AuthzResponse.allow(headersFromServer) + .setResponseHeaderMutations(responseHeaderMutations).build(); + when(mockResponseHandler.handleResponse(eq(checkResponse), any())).thenReturn(allowResponse); + doAnswer(invocation -> { + Metadata headers = invocation.getArgument(1); + headers.put(Metadata.Key.of("client-resp-key", Metadata.ASCII_STRING_MARSHALLER), + "client-resp-value"); + return null; + }).when(mockHeaderMutator).applyResponseMutations(eq(responseHeaderMutations), + any(Metadata.class)); + } +} From a1d4d2368e671572c4b58bb1dd5dae441a542dbc Mon Sep 17 00:00:00 2001 From: Saurav Date: Thu, 6 Nov 2025 10:14:30 +0000 Subject: [PATCH 7/7] feat(xds): Add ExternalAuthorizationFilter This commit introduces the `ExternalAuthorizationFilter`, an implementation of the `Filter` interface that provides external authorization capabilities. The `ExternalAuthorizationFilter` is responsible for: - Parsing `ExtAuthz` and `ExtAuthzPerRoute` configurations. - Creating `ExtAuthzClientInterceptor` and `ExtAuthzServerInterceptor` to handle client and server-side authorization. - Managing the lifecycle of the authorization stub using a `StubManager`. The `StubManager` is a new class that manages the lifecycle of the `AuthorizationStub`, including creating and caching the gRPC channel and stub based on the provided configuration. This ensures that a single channel and stub are reused for the same configuration, improving performance and resource utilization. --- .../main/java/io/grpc/xds/ExtAuthzFilter.java | 217 +++++++++++++ .../xds/internal/extauthz/StubManager.java | 114 +++++++ .../java/io/grpc/xds/ExtAuthzFilterTest.java | 298 ++++++++++++++++++ .../internal/extauthz/StubManagerTest.java | 152 +++++++++ 4 files changed, 781 insertions(+) create mode 100644 xds/src/main/java/io/grpc/xds/ExtAuthzFilter.java create mode 100644 xds/src/main/java/io/grpc/xds/internal/extauthz/StubManager.java create mode 100644 xds/src/test/java/io/grpc/xds/ExtAuthzFilterTest.java create mode 100644 xds/src/test/java/io/grpc/xds/internal/extauthz/StubManagerTest.java diff --git a/xds/src/main/java/io/grpc/xds/ExtAuthzFilter.java b/xds/src/main/java/io/grpc/xds/ExtAuthzFilter.java new file mode 100644 index 00000000000..c74b8817d90 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/ExtAuthzFilter.java @@ -0,0 +1,217 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import com.google.protobuf.Any; +import com.google.protobuf.InvalidProtocolBufferException; +import com.google.protobuf.Message; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.service.auth.v3.AuthorizationGrpc; +import io.grpc.ClientInterceptor; +import io.grpc.ServerInterceptor; +import io.grpc.xds.internal.ThreadSafeRandom; +import io.grpc.xds.internal.ThreadSafeRandom.ThreadSafeRandomImpl; +import io.grpc.xds.internal.extauthz.BufferingAuthzClientCall; +import io.grpc.xds.internal.extauthz.CheckRequestBuilder; +import io.grpc.xds.internal.extauthz.CheckResponseHandler; +import io.grpc.xds.internal.extauthz.ExtAuthzCertificateProvider; +import io.grpc.xds.internal.extauthz.ExtAuthzClientInterceptor; +import io.grpc.xds.internal.extauthz.ExtAuthzConfig; +import io.grpc.xds.internal.extauthz.ExtAuthzParseException; +import io.grpc.xds.internal.extauthz.ExtAuthzServerInterceptor; +import io.grpc.xds.internal.extauthz.StubManager; +import io.grpc.xds.internal.grpcservice.InsecureGrpcChannelFactory; +import io.grpc.xds.internal.headermutations.HeaderMutationFilter; +import io.grpc.xds.internal.headermutations.HeaderMutator; +import java.util.concurrent.ScheduledExecutorService; +import javax.annotation.Nullable; + +final class ExtAuthzFilter implements Filter { + + private static final String TYPE_URL = + "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthz"; + + private static final String TYPE_URL_OVERRIDE_CONFIG = + "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute"; + + + static final class ExtAuthzFilterConfig implements Filter.FilterConfig { + + private final ExtAuthzConfig extAuthzConfig; + + ExtAuthzFilterConfig(ExtAuthzConfig extAuthzConfig) { + this.extAuthzConfig = extAuthzConfig; + } + + public ExtAuthzConfig extAuthzConfig() { + return extAuthzConfig; + } + + @Override + public String typeUrl() { + return ExtAuthzFilter.TYPE_URL; + } + + public static ExtAuthzFilterConfig fromProto(ExtAuthz extAuthzProto) + throws ExtAuthzParseException { + return new ExtAuthzFilterConfig(ExtAuthzConfig.fromProto(extAuthzProto)); + } + } + + // Placeholder for the external authorization filter's override config. + static final class ExtAuthzFilterConfigOverride implements Filter.FilterConfig { + @Override + public final String typeUrl() { + return ExtAuthzFilter.TYPE_URL_OVERRIDE_CONFIG; + } + } + + static final class Provider implements Filter.Provider { + + @Override + public String[] typeUrls() { + return new String[] {TYPE_URL, TYPE_URL_OVERRIDE_CONFIG}; + } + + @Override + public boolean isClientFilter() { + return true; + } + + @Override + public boolean isServerFilter() { + return true; + } + + @Override + public ExtAuthzFilter newInstance(String name) { + // Create a dedicated scheduler for this filter instance's StubManager + StubManager stubManager = StubManager.create(InsecureGrpcChannelFactory.getInstance()); + return new ExtAuthzFilter(stubManager, ThreadSafeRandomImpl.INSTANCE, + BufferingAuthzClientCall.FACTORY_INSTANCE, ExtAuthzCertificateProvider.create(), + CheckRequestBuilder.INSTANCE, CheckResponseHandler.INSTANCE, + ExtAuthzClientInterceptor.INSTANCE, ExtAuthzServerInterceptor.INSTANCE, + HeaderMutationFilter.INSTANCE, HeaderMutator.create()); + } + + @Override + public ConfigOrError parseFilterConfig(Message rawProtoMessage) { + ExtAuthz extAuthzProto; + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + Any anyMessage = (Any) rawProtoMessage; + try { + extAuthzProto = anyMessage.unpack(ExtAuthz.class); + return ConfigOrError.fromConfig(ExtAuthzFilterConfig.fromProto(extAuthzProto)); + } catch (InvalidProtocolBufferException | ExtAuthzParseException e) { + return ConfigOrError.fromError("Invalid proto: " + e); + } + } + + @Override + public ConfigOrError parseFilterConfigOverride( + Message rawProtoMessage) { + if (!(rawProtoMessage instanceof Any)) { + return ConfigOrError.fromError("Invalid config type: " + rawProtoMessage.getClass()); + } + return ConfigOrError.fromConfig(new ExtAuthzFilterConfigOverride()); + } + } + + private final StubManager stubManager; + private final ThreadSafeRandom random; + private final BufferingAuthzClientCall.Factory bufferingAuthzClientCallFactory; + private final ExtAuthzCertificateProvider certificateProvider; + private final CheckRequestBuilder.Factory checkRequestBuilderFactory; + private final CheckResponseHandler.Factory checkResponseHandlerFactory; + private final ExtAuthzClientInterceptor.Factory extAuthzClientInterceptorFactory; + private final ExtAuthzServerInterceptor.Factory extAuthzServerInterceptorFactory; + private final HeaderMutationFilter.Factory headerMutationFilterFactory; + private final HeaderMutator headerMutator; + + + ExtAuthzFilter(StubManager stubManager, ThreadSafeRandom random, + BufferingAuthzClientCall.Factory bufferingAuthzClientCallFactory, + ExtAuthzCertificateProvider certificateProvider, + CheckRequestBuilder.Factory checkRequestBuilderFactory, + CheckResponseHandler.Factory checkResponseHandlerFactory, + ExtAuthzClientInterceptor.Factory extAuthzClientInterceptorFactory, + ExtAuthzServerInterceptor.Factory extAuthzServerInterceptorFactory, + HeaderMutationFilter.Factory headerMutationFilterFactory, HeaderMutator headerMutator) { + this.stubManager = stubManager; + this.random = random; + this.bufferingAuthzClientCallFactory = bufferingAuthzClientCallFactory; + this.certificateProvider = certificateProvider; + this.checkRequestBuilderFactory = checkRequestBuilderFactory; + this.checkResponseHandlerFactory = checkResponseHandlerFactory; + this.extAuthzClientInterceptorFactory = extAuthzClientInterceptorFactory; + this.extAuthzServerInterceptorFactory = extAuthzServerInterceptorFactory; + this.headerMutationFilterFactory = headerMutationFilterFactory; + this.headerMutator = headerMutator; + } + + @Nullable + @Override + public ClientInterceptor buildClientInterceptor(FilterConfig config, + @Nullable FilterConfig overrideConfig, ScheduledExecutorService scheduler) { + if (overrideConfig != null) { + return null; + } + if (!(config instanceof ExtAuthzFilterConfig)) { + return null; + } + ExtAuthzFilterConfig extAuthzFilterConfig = (ExtAuthzFilterConfig) config; + AuthorizationGrpc.AuthorizationStub stub = + stubManager.getStub(extAuthzFilterConfig.extAuthzConfig()); + ExtAuthzConfig extAuthzConfig = extAuthzFilterConfig.extAuthzConfig(); + return extAuthzClientInterceptorFactory.create(extAuthzConfig, stub, + random, bufferingAuthzClientCallFactory, + checkRequestBuilderFactory.create(extAuthzConfig, certificateProvider), + checkResponseHandlerFactory.create(headerMutator, + headerMutationFilterFactory.create(extAuthzConfig.decoderHeaderMutationRules()), + extAuthzConfig), + headerMutator); + } + + @Nullable + @Override + public ServerInterceptor buildServerInterceptor(FilterConfig config, + @Nullable FilterConfig overrideConfig) { + if (overrideConfig != null) { + return null; + } + if (!(config instanceof ExtAuthzFilterConfig)) { + return null; + } + ExtAuthzFilterConfig extAuthzFilterConfig = (ExtAuthzFilterConfig) config; + AuthorizationGrpc.AuthorizationStub stub = + stubManager.getStub(extAuthzFilterConfig.extAuthzConfig()); + ExtAuthzConfig extAuthzConfig = extAuthzFilterConfig.extAuthzConfig(); + return extAuthzServerInterceptorFactory.create(extAuthzConfig, stub, random, + checkRequestBuilderFactory.create(extAuthzConfig, certificateProvider), + checkResponseHandlerFactory.create(headerMutator, + headerMutationFilterFactory.create(extAuthzConfig.decoderHeaderMutationRules()), + extAuthzConfig), + headerMutator); + } + + @Override + public void close() { + stubManager.close(); + } +} diff --git a/xds/src/main/java/io/grpc/xds/internal/extauthz/StubManager.java b/xds/src/main/java/io/grpc/xds/internal/extauthz/StubManager.java new file mode 100644 index 00000000000..d5a0698a393 --- /dev/null +++ b/xds/src/main/java/io/grpc/xds/internal/extauthz/StubManager.java @@ -0,0 +1,114 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import com.google.auto.value.AutoValue; +import io.envoyproxy.envoy.service.auth.v3.AuthorizationGrpc; +import io.grpc.ManagedChannel; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfig.GoogleGrpcConfig; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfigChannelFactory; +import java.util.Optional; +import javax.annotation.concurrent.GuardedBy; + +/** + * Manages the lifecycle of the authorization stub. + */ +public interface StubManager { + /** Creates a new instance of {@code StubManager}. */ + static StubManager create(GrpcServiceConfigChannelFactory channelFactory) { + return new StubManagerImpl(channelFactory); + } + + /** + * Returns a stub for the given configuration. + */ + AuthorizationGrpc.AuthorizationStub getStub(ExtAuthzConfig config); + + /** + * Frees underlying resources on shutdown. + */ + public void close(); + + /** + * Default implementation of {@link StubManager}. + */ + final class StubManagerImpl implements StubManager { + + private final GrpcServiceConfigChannelFactory channelFactory; + private final Object lock = new Object(); + + @GuardedBy("lock") + private Optional stubHolder = Optional.empty(); + + private StubManagerImpl(GrpcServiceConfigChannelFactory channelFactory) { // NOPMD + this.channelFactory = channelFactory; + } + + @Override + public AuthorizationGrpc.AuthorizationStub getStub(ExtAuthzConfig config) { + GoogleGrpcConfig googleGrpc = config.grpcService().googleGrpc(); + ChannelKey newChannelKey = + ChannelKey.of(googleGrpc.target(), googleGrpc.hashedChannelCredentials().hash()); + + synchronized (lock) { + if (stubHolder.isPresent() && stubHolder.get().channelKey().equals(newChannelKey)) { + return stubHolder.get().stub(); + } + Optional oldChannel = stubHolder.map(StubHolder::channel); + ManagedChannel newChannel = channelFactory.createChannel(config.grpcService()); + stubHolder = Optional.of( + StubHolder.create(newChannelKey, newChannel, AuthorizationGrpc.newStub(newChannel))); + oldChannel.ifPresent(ManagedChannel::shutdown); + return stubHolder.get().stub(); + } + } + + @AutoValue + abstract static class ChannelKey { + static ChannelKey of(String target, int hash) { + return new AutoValue_StubManager_StubManagerImpl_ChannelKey(target, hash); + } + + abstract String target(); + + abstract int hash(); + } + + @AutoValue + abstract static class StubHolder { + static StubHolder create(ChannelKey channelKey, ManagedChannel channel, + AuthorizationGrpc.AuthorizationStub stub) { + return new AutoValue_StubManager_StubManagerImpl_StubHolder(channelKey, channel, stub); + } + + abstract ChannelKey channelKey(); + + abstract ManagedChannel channel(); + + abstract AuthorizationGrpc.AuthorizationStub stub(); + } + + @Override + public void close() { + synchronized (lock) { + stubHolder.ifPresent(holder -> { + holder.channel().shutdown(); + }); + } + } + } +} diff --git a/xds/src/test/java/io/grpc/xds/ExtAuthzFilterTest.java b/xds/src/test/java/io/grpc/xds/ExtAuthzFilterTest.java new file mode 100644 index 00000000000..f8073418a8e --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/ExtAuthzFilterTest.java @@ -0,0 +1,298 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.protobuf.Any; +import com.google.protobuf.Empty; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.google_default.v3.GoogleDefaultCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.insecure.v3.InsecureCredentials; +import io.envoyproxy.envoy.service.auth.v3.AuthorizationGrpc; +import io.grpc.ClientInterceptor; +import io.grpc.ManagedChannel; +import io.grpc.ServerInterceptor; +import io.grpc.xds.ExtAuthzFilter.ExtAuthzFilterConfig; +import io.grpc.xds.ExtAuthzFilter.ExtAuthzFilterConfigOverride; +import io.grpc.xds.internal.extauthz.BufferingAuthzClientCall; +import io.grpc.xds.internal.extauthz.CheckRequestBuilder; +import io.grpc.xds.internal.extauthz.CheckResponseHandler; +import io.grpc.xds.internal.extauthz.ExtAuthzCertificateProvider; +import io.grpc.xds.internal.extauthz.ExtAuthzClientInterceptor; +import io.grpc.xds.internal.extauthz.ExtAuthzConfig; +import io.grpc.xds.internal.extauthz.ExtAuthzParseException; +import io.grpc.xds.internal.extauthz.ExtAuthzServerInterceptor; +import io.grpc.xds.internal.extauthz.StubManager; +import io.grpc.xds.internal.headermutations.HeaderMutationFilter; +import io.grpc.xds.internal.headermutations.HeaderMutator; +import java.util.concurrent.ScheduledExecutorService; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +/** + * Unit tests for {@link ExtAuthzFilter}. + */ +@RunWith(JUnit4.class) +public class ExtAuthzFilterTest { + + @Rule + public final MockitoRule mocks = MockitoJUnit.rule(); + + @Mock + private StubManager mockStubManager; + @Mock + private ThreadSafeRandom mockRandom; + @Mock + private BufferingAuthzClientCall.Factory mockBufferingAuthzClientCallFactory; + @Mock + private ExtAuthzCertificateProvider mockCertificateProvider; + @Mock + private CheckRequestBuilder.Factory mockCheckRequestBuilderFactory; + @Mock + private CheckRequestBuilder mockCheckRequestBuilder; + @Mock + private CheckResponseHandler.Factory mockCheckResponseHandlerFactory; + @Mock + private CheckResponseHandler mockCheckResponseHandler; + @Mock + private ExtAuthzClientInterceptor.Factory mockClientInterceptorFactory; + @Mock + private ExtAuthzServerInterceptor.Factory mockServerInterceptorFactory; + @Mock + private HeaderMutationFilter.Factory mockHeaderMutationFilterFactory; + @Mock + private HeaderMutationFilter mockHeaderMutationFilter; + @Mock + private HeaderMutator mockHeaderMutator; + @Mock + private ManagedChannel mockChannel; + @Mock + private ScheduledExecutorService mockScheduler; + + private ExtAuthzFilter filter; + private final ExtAuthzFilter.Provider provider = new ExtAuthzFilter.Provider(); + private ExtAuthzConfig extAuthzConfig; + private AuthorizationGrpc.AuthorizationStub authzStub; + + @Before + public void setUp() { + authzStub = AuthorizationGrpc.newStub(mockChannel); + filter = new ExtAuthzFilter(mockStubManager, mockRandom, + mockBufferingAuthzClientCallFactory, mockCertificateProvider, + mockCheckRequestBuilderFactory, mockCheckResponseHandlerFactory, + mockClientInterceptorFactory, mockServerInterceptorFactory, mockHeaderMutationFilterFactory, + mockHeaderMutator); + } + + private ExtAuthzConfig buildExtAuthzConfig() throws ExtAuthzParseException { + ExtAuthz extAuthz = ExtAuthz.newBuilder() + .setGrpcService(GrpcService.newBuilder() + .setGoogleGrpc(GrpcService.GoogleGrpc.newBuilder().setTargetUri("authz.service.com") + .addChannelCredentialsPlugin(Any.pack(InsecureCredentials.newBuilder().build())) + .addCallCredentialsPlugin( + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build())) + .build()) + .build()) + .build(); + return ExtAuthzConfig.fromProto(extAuthz); + } + + @Test + public void buildClientInterceptor_success() throws ExtAuthzParseException { + extAuthzConfig = buildExtAuthzConfig(); + ExtAuthzFilterConfig filterConfig = new ExtAuthzFilterConfig(extAuthzConfig); + when(mockStubManager.getStub(extAuthzConfig)).thenReturn(authzStub); + when(mockCheckRequestBuilderFactory.create(extAuthzConfig, mockCertificateProvider)) + .thenReturn(mockCheckRequestBuilder); + when(mockHeaderMutationFilterFactory.create(any())).thenReturn(mockHeaderMutationFilter); + when(mockCheckResponseHandlerFactory.create(mockHeaderMutator, mockHeaderMutationFilter, + extAuthzConfig)).thenReturn(mockCheckResponseHandler); + ExtAuthzClientInterceptor interceptor = + (ExtAuthzClientInterceptor) ExtAuthzClientInterceptor.INSTANCE.create(null, null, null, + null, null, null, null); + when(mockClientInterceptorFactory.create(extAuthzConfig, authzStub, mockRandom, + mockBufferingAuthzClientCallFactory, mockCheckRequestBuilder, mockCheckResponseHandler, + mockHeaderMutator)).thenReturn(interceptor); + + ClientInterceptor created = filter.buildClientInterceptor(filterConfig, null, mockScheduler); + assertThat(created).isSameInstanceAs(interceptor); + } + + @Test + public void buildClientInterceptor_withOverride_returnsNull() throws ExtAuthzParseException { + extAuthzConfig = buildExtAuthzConfig(); + ClientInterceptor interceptor = + filter.buildClientInterceptor(new ExtAuthzFilterConfig(extAuthzConfig), + new ExtAuthzFilterConfigOverride(), mockScheduler); + assertThat(interceptor).isNull(); + } + + @Test + public void buildClientInterceptor_wrongConfigType_returnsNull() { + ClientInterceptor interceptor = + filter.buildClientInterceptor(mock(Filter.FilterConfig.class), null, mockScheduler); + assertThat(interceptor).isNull(); + } + + @Test + public void buildServerInterceptor_success() throws ExtAuthzParseException { + extAuthzConfig = buildExtAuthzConfig(); + ExtAuthzFilterConfig filterConfig = new ExtAuthzFilterConfig(extAuthzConfig); + when(mockStubManager.getStub(extAuthzConfig)).thenReturn(authzStub); + when(mockCheckRequestBuilderFactory.create(extAuthzConfig, mockCertificateProvider)) + .thenReturn(mockCheckRequestBuilder); + when(mockHeaderMutationFilterFactory.create(any())).thenReturn(mockHeaderMutationFilter); + when(mockCheckResponseHandlerFactory.create(mockHeaderMutator, mockHeaderMutationFilter, + extAuthzConfig)).thenReturn(mockCheckResponseHandler); + ExtAuthzServerInterceptor interceptor = + (ExtAuthzServerInterceptor) ExtAuthzServerInterceptor.INSTANCE.create(null, null, null, + null, null, null); + when(mockServerInterceptorFactory.create(extAuthzConfig, authzStub, mockRandom, + mockCheckRequestBuilder, mockCheckResponseHandler, mockHeaderMutator)) + .thenReturn(interceptor); + + ServerInterceptor created = filter.buildServerInterceptor(filterConfig, null); + assertThat(created).isSameInstanceAs(interceptor); + } + + @Test + public void buildServerInterceptor_withOverride_returnsNull() throws ExtAuthzParseException { + extAuthzConfig = buildExtAuthzConfig(); + ServerInterceptor interceptor = filter.buildServerInterceptor( + new ExtAuthzFilterConfig(extAuthzConfig), new ExtAuthzFilterConfigOverride()); + assertThat(interceptor).isNull(); + } + + @Test + public void buildServerInterceptor_wrongConfigType_returnsNull() { + ServerInterceptor interceptor = + filter.buildServerInterceptor(mock(Filter.FilterConfig.class), null); + assertThat(interceptor).isNull(); + } + + @Test + public void close_shouldCloseStubManager() { + filter.close(); + verify(mockStubManager).close(); + } + + @Test + public void provider_typeUrls() { + assertThat(provider.typeUrls()).asList().containsExactly( + "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthz", + "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute"); + } + + @Test + public void provider_isClientAndServerFilter() { + assertThat(provider.isClientFilter()).isTrue(); + assertThat(provider.isServerFilter()).isTrue(); + } + + @Test + public void provider_newInstance() { + ExtAuthzFilter instance = provider.newInstance("test-filter"); + assertThat(instance).isNotNull(); + } + + @Test + public void provider_parseFilterConfig_success() { + + Any googleDefaultChannelCreds = Any.pack(GoogleDefaultCredentials.newBuilder().build()); + Any fakeAccessTokenCreds = + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build()); + ExtAuthz extAuthz = ExtAuthz.newBuilder() + .setGrpcService(io.envoyproxy.envoy.config.core.v3.GrpcService.newBuilder() + .setGoogleGrpc(io.envoyproxy.envoy.config.core.v3.GrpcService.GoogleGrpc.newBuilder() + .setTargetUri("authz.service.com") + .addChannelCredentialsPlugin(googleDefaultChannelCreds) + .addCallCredentialsPlugin(fakeAccessTokenCreds).build()) + .build()) + .setStatusOnError( + io.envoyproxy.envoy.type.v3.HttpStatus.newBuilder().setCodeValue(403).build()) + .build(); + Any anyProto = Any.pack(extAuthz); + ConfigOrError result = provider.parseFilterConfig(anyProto); + assertThat(result.config).isNotNull(); + assertThat(result.errorDetail).isNull(); + assertThat(result.config.extAuthzConfig().grpcService().googleGrpc().target()) + .isEqualTo("authz.service.com"); + } + + @Test + public void provider_parseFilterConfig_invalidProto() { + Any anyProto = Any.pack(Empty.getDefaultInstance()); + + ConfigOrError result = provider.parseFilterConfig(anyProto); + + assertThat(result.config).isNull(); + assertThat(result.errorDetail).contains("Invalid proto"); + } + + @Test + public void provider_parseFilterConfig_notAny() { + ConfigOrError result = + provider.parseFilterConfig(Empty.getDefaultInstance()); + assertThat(result.config).isNull(); + assertThat(result.errorDetail).contains("Invalid config type"); + } + + @Test + public void provider_parseFilterConfigOverride_success() { + Any anyProto = Any.pack(ExtAuthz.getDefaultInstance()); + ConfigOrError result = + provider.parseFilterConfigOverride(anyProto); + assertThat(result.config).isNotNull(); + assertThat(result.errorDetail).isNull(); + } + + @Test + public void provider_parseFilterConfigOverride_notAny() { + ConfigOrError result = + provider.parseFilterConfigOverride(Empty.getDefaultInstance()); + assertThat(result.config).isNull(); + assertThat(result.errorDetail).contains("Invalid config type"); + } + + @Test + public void extAuthzFilterConfig_typeUrl() throws ExtAuthzParseException { + extAuthzConfig = buildExtAuthzConfig(); + ExtAuthzFilterConfig config = new ExtAuthzFilterConfig(extAuthzConfig); + assertThat(config.typeUrl()) + .isEqualTo("type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthz"); + } + + @Test + public void extAuthzFilterConfigOverride_typeUrl() { + ExtAuthzFilterConfigOverride override = new ExtAuthzFilterConfigOverride(); + assertThat(override.typeUrl()).isEqualTo( + "type.googleapis.com/envoy.extensions.filters.http.ext_authz.v3.ExtAuthzPerRoute"); + } +} diff --git a/xds/src/test/java/io/grpc/xds/internal/extauthz/StubManagerTest.java b/xds/src/test/java/io/grpc/xds/internal/extauthz/StubManagerTest.java new file mode 100644 index 00000000000..6a60f0bf142 --- /dev/null +++ b/xds/src/test/java/io/grpc/xds/internal/extauthz/StubManagerTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2025 The gRPC Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.grpc.xds.internal.extauthz; + +import static com.google.common.truth.Truth.assertThat; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import com.google.protobuf.Any; +import io.envoyproxy.envoy.config.core.v3.GrpcService; +import io.envoyproxy.envoy.extensions.filters.http.ext_authz.v3.ExtAuthz; +import io.envoyproxy.envoy.extensions.grpc_service.call_credentials.access_token.v3.AccessTokenCredentials; +import io.envoyproxy.envoy.extensions.grpc_service.channel_credentials.insecure.v3.InsecureCredentials; +import io.envoyproxy.envoy.service.auth.v3.AuthorizationGrpc; +import io.grpc.ManagedChannel; +import io.grpc.xds.internal.grpcservice.GrpcServiceConfigChannelFactory; +import org.junit.Before; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.JUnit4; +import org.mockito.Mock; +import org.mockito.junit.MockitoJUnit; +import org.mockito.junit.MockitoRule; + +@RunWith(JUnit4.class) +public class StubManagerTest { + + @Rule + public final MockitoRule mockito = MockitoJUnit.rule(); + + @Mock + private GrpcServiceConfigChannelFactory channelFactory; + @Mock + private ManagedChannel channel1; + @Mock + private ManagedChannel channel2; + + private StubManager stubManager; + private ExtAuthzConfig config1; + private ExtAuthzConfig config2; + + @Before + public void setUp() throws ExtAuthzParseException { + stubManager = StubManager.create(channelFactory); + config1 = buildExtAuthzConfig("target1"); + config2 = buildExtAuthzConfig("target2"); + + when(channelFactory.createChannel(config1.grpcService())).thenReturn(channel1); + when(channelFactory.createChannel(config2.grpcService())).thenReturn(channel2); + } + + private ExtAuthzConfig buildExtAuthzConfig(String targetUri) throws ExtAuthzParseException { + ExtAuthz extAuthz = ExtAuthz.newBuilder() + .setGrpcService(GrpcService.newBuilder() + .setGoogleGrpc(GrpcService.GoogleGrpc.newBuilder().setTargetUri(targetUri) + .addChannelCredentialsPlugin(Any.pack(InsecureCredentials.newBuilder().build())) + .addCallCredentialsPlugin( + Any.pack(AccessTokenCredentials.newBuilder().setToken("fake-token").build())) + .build()) + .build()) + .build(); + return ExtAuthzConfig.fromProto(extAuthz); + } + + @Test + public void getStub_createsNewStubAndChannel_firstTime() { + AuthorizationGrpc.AuthorizationStub stub = stubManager.getStub(config1); + assertThat(stub).isNotNull(); + verify(channelFactory).createChannel(config1.grpcService()); + } + + @Test + public void getStub_returnsExistingStub_sameConfig() throws ExtAuthzParseException { + AuthorizationGrpc.AuthorizationStub stub1 = stubManager.getStub(config1); + ExtAuthzConfig sameAsConfig1 = buildExtAuthzConfig("target1"); + AuthorizationGrpc.AuthorizationStub stub2 = stubManager.getStub(sameAsConfig1); + + assertThat(stub1).isSameInstanceAs(stub2); + verify(channelFactory, times(1)).createChannel(config1.grpcService()); + } + + @Test + public void getStub_createsNewStubAndShutsDownOld_differentConfig() { + AuthorizationGrpc.AuthorizationStub stub1 = stubManager.getStub(config1); + AuthorizationGrpc.AuthorizationStub stub2 = stubManager.getStub(config2); + + assertThat(stub1).isNotSameInstanceAs(stub2); + verify(channelFactory).createChannel(config1.grpcService()); + verify(channelFactory).createChannel(config2.grpcService()); + verify(channel1).shutdown(); + verify(channel2, never()).shutdown(); + } + + @Test + public void getStub_createsNewStubAndShutsDownOld_differentTarget() + throws ExtAuthzParseException { + config2 = buildExtAuthzConfig("target1-different"); + when(channelFactory.createChannel(config2.grpcService())).thenReturn(channel2); + + AuthorizationGrpc.AuthorizationStub stub1 = stubManager.getStub(config1); + AuthorizationGrpc.AuthorizationStub stub2 = stubManager.getStub(config2); + + assertThat(stub1).isNotSameInstanceAs(stub2); + verify(channelFactory).createChannel(config1.grpcService()); + verify(channelFactory).createChannel(config2.grpcService()); + verify(channel1).shutdown(); + } + + @Test + public void getStub_createsNewStubAndShutsDownOld_differentCredentialsHash() + throws ExtAuthzParseException { + when(channelFactory.createChannel(config2.grpcService())).thenReturn(channel2); + + AuthorizationGrpc.AuthorizationStub stub1 = stubManager.getStub(config1); + AuthorizationGrpc.AuthorizationStub stub2 = stubManager.getStub(config2); + + assertThat(stub1).isNotSameInstanceAs(stub2); + verify(channelFactory).createChannel(config1.grpcService()); + verify(channelFactory).createChannel(config2.grpcService()); + verify(channel1).shutdown(); + } + + @Test + public void close_shutsDownChannel() { + stubManager.getStub(config1); + stubManager.close(); + verify(channel1).shutdown(); + } + + @Test + public void close_noStubCreated_doesNothing() { + stubManager.close(); + verify(channel1, never()).shutdown(); + } +}