From b12ed13b4e17b7e594680397449d08f7cb723936 Mon Sep 17 00:00:00 2001 From: Quentin Khan Date: Wed, 6 May 2026 05:16:15 -0700 Subject: [PATCH] Fix outdated memory planning optimization invalidated by reshapes. When XNNPack's memory planner finds a node that can be computed in place because the input and output shapes are the same, it aliases the output data pointer to the input buffer. After a reshape, this optimization may not be valid anymore. For instance, an input could require a broadcast and be smaller than its output: this happens in our own test cases and was hidden by the fact that the inputs and outputs were external values (which disables the above optimization). PiperOrigin-RevId: 911271536 --- src/runtime.c | 20 +++++++++ test/subgraph/BUILD | 15 +++++++ test/subgraph/planning.cc | 88 +++++++++++++++++++++++++++++++++++++++ 3 files changed, 123 insertions(+) create mode 100644 test/subgraph/planning.cc diff --git a/src/runtime.c b/src/runtime.c index 14eb39f8f49..ed390e8c997 100644 --- a/src/runtime.c +++ b/src/runtime.c @@ -884,6 +884,26 @@ enum xnn_status xnn_reshape_runtime(xnn_runtime_t runtime) { xnn_operator_type_to_string_v2(opdata->operator_objects[0])); return status; } + + for (size_t i = 0; i < opdata->num_inputs && !reallocation_required; i++) { + const uint32_t input_id = opdata->inputs[i]; + if (input_id == XNN_INVALID_VALUE_ID) { + continue; + } + for (size_t j = 0; j < opdata->num_outputs && !reallocation_required; j++) { + const uint32_t output_id = opdata->outputs[j]; + if (output_id == XNN_INVALID_VALUE_ID) { + continue; + } + const struct xnn_runtime_value* input = &runtime->values[input_id]; + const struct xnn_runtime_value* output = &runtime->values[output_id]; + if (input->data != NULL && input->data == output->data) { + if (xnn_runtime_tensor_get_size(input) != xnn_runtime_tensor_get_size(output)) { + reallocation_required = true; + } + } + } + } } if (reallocation_required || !runtime->memory_planned) { runtime->memory_planned 
= true; diff --git a/test/subgraph/BUILD b/test/subgraph/BUILD index 63c89880b65..40b63e20f77 100644 --- a/test/subgraph/BUILD +++ b/test/subgraph/BUILD @@ -364,3 +364,18 @@ xnnpack_unit_test( "//test:replicable_random_device", ], ) + +xnnpack_unit_test( + name = "planning_test", + srcs = ["planning.cc"], + tags = ["no_ynnpack"], + deps = [ + "//:xnnpack_h", + "//litert/tensor", + "//litert/tensor:arithmetic", + "//litert/tensor:datatypes", + "//litert/tensor/backends/xnnpack:arithmetic", + "//litert/tensor/backends/xnnpack:conversion", + "//litert/tensor/utils:matchers_no_g3", + ], +) diff --git a/test/subgraph/planning.cc b/test/subgraph/planning.cc new file mode 100644 index 00000000000..718d7b0d9f2 --- /dev/null +++ b/test/subgraph/planning.cc @@ -0,0 +1,88 @@ +// Copyright 2026 Google LLC +// +// This source code is licensed under the BSD-style license found in the +// LICENSE file in the root directory of this source tree. + +#include +#include +#include +#include +#include + +#include +#include +#include "include/xnnpack.h" +#include "litert/tensor/arithmetic.h" +#include "litert/tensor/backends/xnnpack/arithmetic.h" +#include "litert/tensor/backends/xnnpack/conversion.h" +#include "litert/tensor/datatypes.h" +#include "litert/tensor/tensor.h" +#include "litert/tensor/utils/matchers.h" + +namespace xnnpack { +namespace { + +namespace lrt = ::litert::tensor; +using XTensor = lrt::Tensor; + +using testing::ElementsAreArray; + +TEST(PlanningTest, ReshapingToBroadcastWorks) { + std::unique_ptr graph; + uint32_t a_id = XNN_INVALID_VALUE_ID; + uint32_t b_id = XNN_INVALID_VALUE_ID; + uint32_t c_id = XNN_INVALID_VALUE_ID; + + { + XTensor a({.type = lrt::Type::kI16, .shape = {3, 3}}); + XTensor b({.type = lrt::Type::kI16, .shape = {3, 3}}); + + XTensor c = Add(Cast(a, lrt::Type::kFP32), Cast(b, lrt::Type::kFP32)); + c.SetShape({3, 3}); + c = Cast(c, lrt::Type::kI16); + + LRT_TENSOR_ASSERT_OK_AND_ASSIGN(graph, lrt::BuildXnnpackGraph({c})); + 
LRT_TENSOR_ASSERT_OK_AND_ASSIGN(const size_t a_idx, graph->Lookup(a)); + LRT_TENSOR_ASSERT_OK_AND_ASSIGN(const size_t b_idx, graph->Lookup(b)); + LRT_TENSOR_ASSERT_OK_AND_ASSIGN(const size_t c_idx, graph->Lookup(c)); + a_id = graph->values()[a_idx].id; + b_id = graph->values()[b_idx].id; + c_id = graph->values()[c_idx].id; + } + + xnn_runtime_t runtime; + xnn_create_runtime_v4(graph->subgraph(), /*weights_cache=*/nullptr, + /*workspace=*/nullptr, /*threadpool=*/nullptr, + /*flags=*/0, &runtime); + + std::array a_data{1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::array b_data{1, 2, 3, 4, 5, 6, 7, 8, 9}; + std::array c_data{}; + + std::array values{ + xnn_external_value{.id = a_id, .data = a_data.data()}, + xnn_external_value{.id = b_id, .data = b_data.data()}, + xnn_external_value{.id = c_id, .data = c_data.data()}, + }; + + xnn_reshape_runtime(runtime); + xnn_setup_runtime_v2(runtime, values.size(), values.data()); + xnn_invoke_runtime(runtime); + + EXPECT_THAT(c_data, ElementsAreArray({2, 4, 6, 8, 10, 12, 14, 16, 18})); + + // Change `a` so that the operation now needs a broadcast. The internal buffer + // after the cast can't be reused by the add op to write its output to. + std::array new_a_dims{3, 1}; + xnn_reshape_external_value(runtime, a_id, new_a_dims.size(), + new_a_dims.data()); + + xnn_reshape_runtime(runtime); + xnn_setup_runtime_v2(runtime, values.size(), values.data()); + xnn_invoke_runtime(runtime); + + EXPECT_THAT(c_data, ElementsAreArray({2, 3, 4, 6, 7, 8, 10, 11, 12})); +} + +} // namespace +} // namespace xnnpack