diff --git a/dwave/optimization/include/dwave-optimization/cp/core/cpvar.hpp b/dwave/optimization/include/dwave-optimization/cp/core/cpvar.hpp index bbd6c0d2c..6eca0cc13 100644 --- a/dwave/optimization/include/dwave-optimization/cp/core/cpvar.hpp +++ b/dwave/optimization/include/dwave-optimization/cp/core/cpvar.hpp @@ -154,11 +154,11 @@ class CPVar { std::vector on_bounds; std::vector on_array_size_change; + const dwave::optimization::ArrayNode* node_; + protected: const CPModel& model_; - // but should I have this? - const dwave::optimization::ArrayNode* node_; const ssize_t cp_var_index_; class Listener : public DomainListener { diff --git a/dwave/optimization/include/dwave-optimization/cp/propagators/indexing_propagators.hpp b/dwave/optimization/include/dwave-optimization/cp/propagators/indexing_propagators.hpp new file mode 100644 index 000000000..53c9b5f27 --- /dev/null +++ b/dwave/optimization/include/dwave-optimization/cp/propagators/indexing_propagators.hpp @@ -0,0 +1,44 @@ +// Copyright 2026 D-Wave Systems Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include "dwave-optimization/cp/core/cpvar.hpp" +#include "dwave-optimization/cp/core/propagator.hpp" +#include "dwave-optimization/nodes/indexing.hpp" + +namespace dwave::optimization::cp { + +struct BasicIndexingForwardTransform : IndexTransform { + BasicIndexingForwardTransform(const ArrayNode* array_ptr, const BasicIndexingNode* bi_ptr); + + void affected(ssize_t i, std::vector& out) override; + + const ArrayNode* array_ptr_; + const BasicIndexingNode* bi_ptr_; + std::vector slices; +}; + +class BasicIndexingPropagator : public Propagator { + public: + BasicIndexingPropagator(ssize_t index, CPVar* array, CPVar* basic_indexing); + + void initialize_state(CPState& state) const override; + CPStatus propagate(CPPropagatorsState& p_state, CPVarsState& v_state) const override; + + private: + CPVar* array_; + CPVar* basic_indexing_; +}; + +} // namespace dwave::optimization::cp diff --git a/dwave/optimization/src/cp/core/cpvar.cpp b/dwave/optimization/src/cp/core/cpvar.cpp index b1b421706..982f5a221 100644 --- a/dwave/optimization/src/cp/core/cpvar.cpp +++ b/dwave/optimization/src/cp/core/cpvar.cpp @@ -20,7 +20,7 @@ namespace dwave::optimization::cp { // ------ CPVar ------- CPVar::CPVar(const CPModel& model, const dwave::optimization::ArrayNode* node_ptr, int index) - : model_(model), node_(node_ptr), cp_var_index_(index) {} + : node_(node_ptr), model_(model), cp_var_index_(index) {} double CPVar::min(const CPVarsState& state, int index) const { const CPVarData* data = data_ptr(state); diff --git a/dwave/optimization/src/cp/propagators/indexing_propagators.cpp b/dwave/optimization/src/cp/propagators/indexing_propagators.cpp new file mode 100644 index 000000000..ed21b0445 --- /dev/null +++ b/dwave/optimization/src/cp/propagators/indexing_propagators.cpp @@ -0,0 +1,153 @@ +// Copyright 2026 D-Wave Systems Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "dwave-optimization/cp/propagators/indexing_propagators.hpp" + +#include "dwave-optimization/nodes/indexing.hpp" + +namespace dwave::optimization::cp { + +BasicIndexingForwardTransform::BasicIndexingForwardTransform(const ArrayNode* array_ptr, + const BasicIndexingNode* bi_ptr) + : array_ptr_(array_ptr), bi_ptr_(bi_ptr) { + slices = bi_ptr->infer_indices(); + for (ssize_t axis = 0; axis < array_ptr->ndim(); ++axis) { + if (std::holds_alternative(slices[axis])) { + if (std::get(slices[axis]).step != 1) { + throw std::invalid_argument("step != 1 not supported"); + } + + slices[axis] = std::get(slices[axis]).fit(array_ptr->shape()[axis]); + } + } +} + +void BasicIndexingForwardTransform::affected(ssize_t i, std::vector& out) { + std::vector in_multi_index = unravel_index(i, array_ptr_->shape()); + std::vector out_multi_index; + bool belongs = true; + // Iterate through the axes to see if any index is outside the slice + for (ssize_t axis = 0; axis < array_ptr_->ndim(); ++axis) { + if (std::holds_alternative(slices[axis])) { + if (in_multi_index[axis] == std::get(slices[axis])) continue; + } else { + const auto& slice = std::get(slices[axis]); + if (in_multi_index[axis] >= slice.start and in_multi_index[axis] < slice.stop) { + out_multi_index.push_back(in_multi_index[axis] - slice.start); + continue; + } + } + + belongs = false; + break; + } + + if (belongs) { + out.push_back(ravel_multi_index(out_multi_index, bi_ptr_->shape())); + } +} + +BasicIndexingPropagator::BasicIndexingPropagator(ssize_t index, CPVar* array, CPVar* basic_indexing) + : Propagator(index) { + // TODO: not supporting dynamic variables for now + if (array->min_size() != array->max_size()) { + throw std::invalid_argument("dynamic arrays not supported"); + } + + array_ = array; + basic_indexing_ = basic_indexing; +} + +void BasicIndexingPropagator::initialize_state(CPState& state) const { + CPPropagatorsState& p_state = state.get_propagators_state(); + assert(propagator_index_ >= 0); + assert(propagator_index_ < static_cast(p_state.size())); + p_state[propagator_index_] = std::make_unique(state.get_state_manager(), + basic_indexing_->max_size()); +} + +CPStatus BasicIndexingPropagator::propagate(CPPropagatorsState& p_state, + CPVarsState& v_state) const { + auto data = data_ptr(p_state); + + const BasicIndexingNode* bi = dynamic_cast(basic_indexing_->node_); + assert(bi); + + // Not caching this for now as we may need to fit these at propagate time for + // dynamic arrays + std::vector slices = bi->infer_indices(); + for (ssize_t axis = 0; axis < array_->node_->ndim(); ++axis) { + if (std::holds_alternative(slices[axis])) { + assert(std::get(slices[axis]).step == 1); + slices[axis] = std::get(slices[axis]).fit(array_->node_->shape()[axis]); + } + } + + std::deque& indices_to_process = data->indices_to_process(); + + assert(indices_to_process.size() > 0); + while (indices_to_process.size() > 0) { + ssize_t bi_index = indices_to_process.front(); + indices_to_process.pop_front(); + + // Derive the original array index based on the index of the basic indexing variable. + // We unravel the basic indexing variable index, transform the multi-index into + // one on the original array, and then ravel it to get the final linear index on + // the array. + std::vector bi_multi_index = + unravel_index(bi_index, basic_indexing_->node_->shape()); + std::vector arr_multi_index; + ssize_t bi_axis = 0; + for (ssize_t axis = 0; axis < array_->node_->ndim(); ++axis) { + if (std::holds_alternative(slices[axis])) { + arr_multi_index.push_back(std::get(slices[axis])); + continue; + } + assert(std::holds_alternative(slices[axis])); + const auto& slice = std::get(slices[axis]); + assert(slice.step == 1); + arr_multi_index.push_back(bi_multi_index[bi_axis] + slice.start); + bi_axis++; + } + ssize_t array_index = ravel_multi_index(arr_multi_index, array_->node_->shape()); + + // Now we make the bounds of the array element and the basic indexing element equal + + // Make the upper bounds consistent + if (CPStatus status = basic_indexing_->remove_above( + v_state, array_->max(v_state, array_index), bi_index); + not status) + return status; + if (CPStatus status = array_->remove_above(v_state, basic_indexing_->max(v_state, bi_index), + array_index); + not status) + return status; + + // Make the lower bounds consistent + if (CPStatus status = basic_indexing_->remove_below( + v_state, array_->min(v_state, array_index), bi_index); + not status) + return status; + if (CPStatus status = array_->remove_below(v_state, basic_indexing_->min(v_state, bi_index), + array_index); + not status) + return status; + + data->set_scheduled(false, bi_index); + } + + return CPStatus::OK; +} + +} // namespace dwave::optimization::cp diff --git a/meson.build b/meson.build index 78403b998..c8461a64c 100644 --- a/meson.build +++ b/meson.build @@ -62,6 +62,7 @@ dwave_optimization_src = [ 'dwave/optimization/src/cp/propagators/binaryop.cpp', 'dwave/optimization/src/cp/propagators/identity_propagator.cpp', + 'dwave/optimization/src/cp/propagators/indexing_propagators.cpp', 'dwave/optimization/src/cp/propagators/reduce.cpp', 'dwave/optimization/src/cp/state/copier.cpp', diff --git a/tests/cpp/cp/test_propagator.cpp b/tests/cpp/cp/test_propagator.cpp index ce0f7c03f..a8fe9e1f5 100644 --- a/tests/cpp/cp/test_propagator.cpp +++ b/tests/cpp/cp/test_propagator.cpp @@ -23,6 +23,7 @@ #include "dwave-optimization/cp/core/index_transform.hpp" #include "dwave-optimization/cp/core/interval_array.hpp" #include "dwave-optimization/cp/propagators/identity_propagator.hpp" +#include "dwave-optimization/cp/propagators/indexing_propagators.hpp" #include "dwave-optimization/cp/state/copier.hpp" #include "dwave-optimization/nodes.hpp" #include "dwave-optimization/state.hpp" @@ -109,4 +110,159 @@ TEST_CASE("ElementWiseIdentityPropagator") { } } } + +TEST_CASE("BasicIndexingPropagator") { + using namespace dwave::optimization; + + GIVEN("A dwopt graph with basic indexing") { + Graph graph; + auto i = graph.emplace_node(5, -3, 4); + auto b = graph.emplace_node(i, Slice(1, 4)); + + // Lock the graph + graph.topological_sort(); + + // Construct the CP corresponding model + AND_GIVEN("The CP Model") { + CPModel model; + + // Add the variabbles to the model + CPVar* i_var = model.emplace_variable(model, i, i->topological_index()); + CPVar* b_var = model.emplace_variable(model, b, b->topological_index()); + + Propagator* p = model.emplace_propagator( + model.num_propagators(), i_var, b_var); + + // build the advisor for the propagator p aimed to the variable for i + Advisor advisor_i(p, 0, std::make_unique(i, b)); + i_var->propagate_on_domain_change(std::move(advisor_i)); + + // build the advisor for the propagator p aimed to the variable for b + Advisor advisor_b(p, 1, std::make_unique()); + b_var->propagate_on_domain_change(std::move(advisor_b)); + + REQUIRE(i_var->on_domain.size() == 1); + REQUIRE(b_var->on_domain.size() == 1); + + WHEN("We initialize a state") { + CPState state = model.initialize_state(); + CPVarsState& s_state = state.get_variables_state(); + CPPropagatorsState& p_state = state.get_propagators_state(); + + REQUIRE(s_state.size() == 2); + REQUIRE(p_state.size() == 1); + + i_var->initialize_state(state); + b_var->initialize_state(state); + p->initialize_state(state); + + AND_WHEN("We alter the domain of the integer variable inside the slice") { + CPStatus status = i_var->assign(s_state, -2, 0); + REQUIRE(status == CPStatus::OK); + THEN("We see that the propagator is not triggered") { + CHECK(not p_state[0]->scheduled()); + CHECK(p_state[0]->indices_to_process().size() == 0); + } + } + + AND_WHEN("We alter the domain of the integer variable inside the slice") { + CPStatus status = i_var->assign(s_state, -2, 3); + REQUIRE(status == CPStatus::OK); + THEN("We see that the propagator is triggered to run on the same index") { + REQUIRE(p_state[0]->scheduled()); + REQUIRE(p_state[0]->indices_to_process().size() == 1); + + CHECK(p_state[0]->scheduled(2)); + } + + AND_WHEN("We call the fix point engine") { + CPEngine engine; + engine.fix_point(state); + + THEN("The sum output variable 2 is correctly fixed") { + CHECK(b_var->min(s_state, 2) == -2); + CHECK(b_var->max(s_state, 2) == -2); + } + } + } + } + } + } + + GIVEN("A dwopt graph with basic indexing on a 2d array") { + Graph graph; + auto i = graph.emplace_node(std::initializer_list{4, 7}, -3, 4); + auto b = graph.emplace_node(i, 2, Slice(1, 4)); + + // Lock the graph + graph.topological_sort(); + + // Construct the CP corresponding model + AND_GIVEN("The CP Model") { + CPModel model; + + // Add the variabbles to the model + CPVar* i_var = model.emplace_variable(model, i, i->topological_index()); + CPVar* b_var = model.emplace_variable(model, b, b->topological_index()); + + Propagator* p = model.emplace_propagator( + model.num_propagators(), i_var, b_var); + + // build the advisor for the propagator p aimed to the variable for i + Advisor advisor_i(p, 0, std::make_unique(i, b)); + i_var->propagate_on_domain_change(std::move(advisor_i)); + + // build the advisor for the propagator p aimed to the variable for b + Advisor advisor_b(p, 1, std::make_unique()); + b_var->propagate_on_domain_change(std::move(advisor_b)); + + REQUIRE(i_var->on_domain.size() == 1); + REQUIRE(b_var->on_domain.size() == 1); + + WHEN("We initialize a state") { + CPState state = model.initialize_state(); + CPVarsState& s_state = state.get_variables_state(); + CPPropagatorsState& p_state = state.get_propagators_state(); + + REQUIRE(s_state.size() == 2); + REQUIRE(p_state.size() == 1); + + i_var->initialize_state(state); + b_var->initialize_state(state); + p->initialize_state(state); + + AND_WHEN("We alter the domain of the integer variable inside the slice") { + CPStatus status = i_var->assign(s_state, -2, 0); + REQUIRE(status == CPStatus::OK); + THEN("We see that the propagator is not triggered") { + CHECK(not p_state[0]->scheduled()); + CHECK(p_state[0]->indices_to_process().size() == 0); + } + } + + AND_WHEN("We alter the domain of the integer variable inside the slice") { + CPStatus status = i_var->assign(s_state, -2, 16); + REQUIRE(status == CPStatus::OK); + THEN("We see that the propagator is triggered to run on the same index") { + REQUIRE(p_state[0]->scheduled()); + REQUIRE(p_state[0]->indices_to_process().size() == 1); + + CHECK(p_state[0]->scheduled(1)); + } + + AND_WHEN("We call the fix point engine") { + CPEngine engine; + engine.fix_point(state); + + THEN("The sum output variable 2 is correctly fixed") { + CHECK(b_var->min(s_state, 1) == -2); + CHECK(b_var->max(s_state, 1) == -2); + } + } + } + } + } + } +} + } // namespace dwave::optimization::cp