diff --git a/java-spanner/google-cloud-spanner-executor/src/main/java/com/google/cloud/executor/spanner/CloudClientExecutor.java b/java-spanner/google-cloud-spanner-executor/src/main/java/com/google/cloud/executor/spanner/CloudClientExecutor.java index a9323fbdc25f..96cc0d06e994 100644 --- a/java-spanner/google-cloud-spanner-executor/src/main/java/com/google/cloud/executor/spanner/CloudClientExecutor.java +++ b/java-spanner/google-cloud-spanner-executor/src/main/java/com/google/cloud/executor/spanner/CloudClientExecutor.java @@ -166,17 +166,22 @@ import java.io.ByteArrayInputStream; import java.io.File; import java.io.IOException; +import java.io.InvalidClassException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; +import java.io.ObjectStreamClass; import java.io.Serializable; import java.math.BigDecimal; import java.text.ParseException; import java.time.Duration; import java.time.LocalDate; import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Set; import java.util.UUID; import java.util.concurrent.ExecutionException; import java.util.concurrent.Executor; @@ -3795,11 +3800,38 @@ private static com.google.spanner.v1.Type cloudTypeToTypeProto(@Nonnull Type clo } } + /** Define the Allowlist: ONLY allow specific, safe internal classes */ + private static final Set ALLOWED_CLASSES = new HashSet<>(Arrays.asList( + "com.google.cloud.spanner.PartitionOptions", + "com.google.cloud.spanner.Partition", + "com.google.cloud.spanner.BatchTransactionId", + "java.util.ArrayList", + "java.lang.Number", + "java.lang.Long", + "java.lang.Integer" + )); + /** Unmarshall ByteString to serializable object. */ - private T unmarshall(ByteString input) - throws IOException, ClassNotFoundException { - ObjectInputStream objectInputStream = new ObjectInputStream(input.newInput()); - return (T) objectInputStream.readObject(); + private T unmarshall(ByteString input) throws IOException, ClassNotFoundException { + try (InputStream is = input.newInput(); + ObjectInputStream ois = new SecureObjectInputStream(is)) { + return (T) ois.readObject(); + } + } + + /** The "Look-ahead" Filter */ + private static class SecureObjectInputStream extends ObjectInputStream { + public SecureObjectInputStream(InputStream in) throws IOException { + super(in); + } + + @Override + protected Class resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException { + if (!ALLOWED_CLASSES.contains(desc.getName())) { + throw new InvalidClassException("Unauthorized deserialization attempt: ", desc.getName()); + } + return super.resolveClass(desc); + } } /** Marshall a serializable object into ByteString. */