Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,11 @@ class CPVar {
std::vector<Advisor> on_bounds;
std::vector<Advisor> on_array_size_change;

const dwave::optimization::ArrayNode* node_;

protected:
const CPModel& model_;

// but should I have this?
const dwave::optimization::ArrayNode* node_;
const ssize_t cp_var_index_;

class Listener : public DomainListener {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
// Copyright 2026 D-Wave Systems Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include "dwave-optimization/cp/core/cpvar.hpp"
#include "dwave-optimization/cp/core/propagator.hpp"
#include "dwave-optimization/nodes/indexing.hpp"

namespace dwave::optimization::cp {

struct BasicIndexingForwardTransform : IndexTransform {
BasicIndexingForwardTransform(const ArrayNode* array_ptr, const BasicIndexingNode* bi_ptr);

void affected(ssize_t i, std::vector<ssize_t>& out) override;

const ArrayNode* array_ptr_;
const BasicIndexingNode* bi_ptr_;
std::vector<BasicIndexingNode::slice_or_int> slices;
};

class BasicIndexingPropagator : public Propagator {
public:
BasicIndexingPropagator(ssize_t index, CPVar* array, CPVar* basic_indexing);

void initialize_state(CPState& state) const override;
CPStatus propagate(CPPropagatorsState& p_state, CPVarsState& v_state) const override;

private:
CPVar* array_;
CPVar* basic_indexing_;
};

} // namespace dwave::optimization::cp
2 changes: 1 addition & 1 deletion dwave/optimization/src/cp/core/cpvar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ namespace dwave::optimization::cp {
// ------ CPVar -------

CPVar::CPVar(const CPModel& model, const dwave::optimization::ArrayNode* node_ptr, int index)
: model_(model), node_(node_ptr), cp_var_index_(index) {}
: node_(node_ptr), model_(model), cp_var_index_(index) {}

double CPVar::min(const CPVarsState& state, int index) const {
const CPVarData* data = data_ptr<CPVarData>(state);
Expand Down
153 changes: 153 additions & 0 deletions dwave/optimization/src/cp/propagators/indexing_propagators.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
// Copyright 2026 D-Wave Systems Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "dwave-optimization/cp/propagators/indexing_propagators.hpp"

#include "dwave-optimization/nodes/indexing.hpp"

namespace dwave::optimization::cp {

BasicIndexingForwardTransform::BasicIndexingForwardTransform(const ArrayNode* array_ptr,
const BasicIndexingNode* bi_ptr)
: array_ptr_(array_ptr), bi_ptr_(bi_ptr) {
slices = bi_ptr->infer_indices();
for (ssize_t axis = 0; axis < array_ptr->ndim(); ++axis) {
if (std::holds_alternative<Slice>(slices[axis])) {
if (std::get<Slice>(slices[axis]).step != 1) {
throw std::invalid_argument("step != 1 not supported");
}

slices[axis] = std::get<Slice>(slices[axis]).fit(array_ptr->shape()[axis]);
}
}
}

void BasicIndexingForwardTransform::affected(ssize_t i, std::vector<ssize_t>& out) {
std::vector<ssize_t> in_multi_index = unravel_index(i, array_ptr_->shape());
std::vector<ssize_t> out_multi_index;
bool belongs = true;
// Iterate through the axes to see if any index is outside the slice
for (ssize_t axis = 0; axis < array_ptr_->ndim(); ++axis) {
if (std::holds_alternative<ssize_t>(slices[axis])) {
if (in_multi_index[axis] == std::get<ssize_t>(slices[axis])) continue;
} else {
const auto& slice = std::get<Slice>(slices[axis]);
if (in_multi_index[axis] >= slice.start and in_multi_index[axis] < slice.stop) {
out_multi_index.push_back(in_multi_index[axis] - slice.start);
continue;
}
}

belongs = false;
break;
}

if (belongs) {
out.push_back(ravel_multi_index(out_multi_index, bi_ptr_->shape()));
}
}

BasicIndexingPropagator::BasicIndexingPropagator(ssize_t index, CPVar* array, CPVar* basic_indexing)
: Propagator(index) {
// TODO: not supporting dynamic variables for now
if (array->min_size() != array->max_size()) {
throw std::invalid_argument("dynamic arrays not supported");
}

array_ = array;
basic_indexing_ = basic_indexing;
}

void BasicIndexingPropagator::initialize_state(CPState& state) const {
CPPropagatorsState& p_state = state.get_propagators_state();
assert(propagator_index_ >= 0);
assert(propagator_index_ < static_cast<ssize_t>(p_state.size()));
p_state[propagator_index_] = std::make_unique<PropagatorData>(state.get_state_manager(),
basic_indexing_->max_size());
}

CPStatus BasicIndexingPropagator::propagate(CPPropagatorsState& p_state,
CPVarsState& v_state) const {
auto data = data_ptr<PropagatorData>(p_state);

const BasicIndexingNode* bi = dynamic_cast<const BasicIndexingNode*>(basic_indexing_->node_);
assert(bi);

// Not caching this for now as we may need to fit these at propagate time for
// dynamic arrays
std::vector<BasicIndexingNode::slice_or_int> slices = bi->infer_indices();
for (ssize_t axis = 0; axis < array_->node_->ndim(); ++axis) {
if (std::holds_alternative<Slice>(slices[axis])) {
assert(std::get<Slice>(slices[axis]).step == 1);
slices[axis] = std::get<Slice>(slices[axis]).fit(array_->node_->shape()[axis]);
}
}

std::deque<ssize_t>& indices_to_process = data->indices_to_process();

assert(indices_to_process.size() > 0);
while (indices_to_process.size() > 0) {
ssize_t bi_index = indices_to_process.front();
indices_to_process.pop_front();

// Derive the original array index based on the index of the basic indexing variable.
// We unravel the basic indexing variable index, transform the multi-index into
// one on the original array, and then ravel it to get the final linear index on
// the array.
std::vector<ssize_t> bi_multi_index =
unravel_index(bi_index, basic_indexing_->node_->shape());
std::vector<ssize_t> arr_multi_index;
ssize_t bi_axis = 0;
for (ssize_t axis = 0; axis < array_->node_->ndim(); ++axis) {
if (std::holds_alternative<ssize_t>(slices[axis])) {
arr_multi_index.push_back(std::get<ssize_t>(slices[axis]));
continue;
}
assert(std::holds_alternative<Slice>(slices[axis]));
const auto& slice = std::get<Slice>(slices[axis]);
assert(slice.step == 1);
arr_multi_index.push_back(bi_multi_index[bi_axis] + slice.start);
bi_axis++;
}
ssize_t array_index = ravel_multi_index(arr_multi_index, array_->node_->shape());

// Now we make the bounds of the array element and the basic indexing element equal

// Make the upper bounds consistent
if (CPStatus status = basic_indexing_->remove_above(
v_state, array_->max(v_state, array_index), bi_index);
not status)
return status;
if (CPStatus status = array_->remove_above(v_state, basic_indexing_->max(v_state, bi_index),
array_index);
not status)
return status;

// Make the lower bounds consistent
if (CPStatus status = basic_indexing_->remove_below(
v_state, array_->min(v_state, array_index), bi_index);
not status)
return status;
if (CPStatus status = array_->remove_below(v_state, basic_indexing_->min(v_state, bi_index),
array_index);
not status)
return status;

data->set_scheduled(false, bi_index);
}

return CPStatus::OK;
}

} // namespace dwave::optimization::cp
1 change: 1 addition & 0 deletions meson.build
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ dwave_optimization_src = [

'dwave/optimization/src/cp/propagators/binaryop.cpp',
'dwave/optimization/src/cp/propagators/identity_propagator.cpp',
'dwave/optimization/src/cp/propagators/indexing_propagators.cpp',
'dwave/optimization/src/cp/propagators/reduce.cpp',

'dwave/optimization/src/cp/state/copier.cpp',
Expand Down
156 changes: 156 additions & 0 deletions tests/cpp/cp/test_propagator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include "dwave-optimization/cp/core/index_transform.hpp"
#include "dwave-optimization/cp/core/interval_array.hpp"
#include "dwave-optimization/cp/propagators/identity_propagator.hpp"
#include "dwave-optimization/cp/propagators/indexing_propagators.hpp"
#include "dwave-optimization/cp/state/copier.hpp"
#include "dwave-optimization/nodes.hpp"
#include "dwave-optimization/state.hpp"
Expand Down Expand Up @@ -109,4 +110,159 @@ TEST_CASE("ElementWiseIdentityPropagator") {
}
}
}

TEST_CASE("BasicIndexingPropagator") {
using namespace dwave::optimization;

GIVEN("A dwopt graph with basic indexing") {
Graph graph;
auto i = graph.emplace_node<IntegerNode>(5, -3, 4);
auto b = graph.emplace_node<BasicIndexingNode>(i, Slice(1, 4));

// Lock the graph
graph.topological_sort();

// Construct the CP corresponding model
AND_GIVEN("The CP Model") {
CPModel model;

// Add the variabbles to the model
CPVar* i_var = model.emplace_variable<CPVar>(model, i, i->topological_index());
CPVar* b_var = model.emplace_variable<CPVar>(model, b, b->topological_index());

Propagator* p = model.emplace_propagator<BasicIndexingPropagator>(
model.num_propagators(), i_var, b_var);

// build the advisor for the propagator p aimed to the variable for i
Advisor advisor_i(p, 0, std::make_unique<BasicIndexingForwardTransform>(i, b));
i_var->propagate_on_domain_change(std::move(advisor_i));

// build the advisor for the propagator p aimed to the variable for b
Advisor advisor_b(p, 1, std::make_unique<ElementWiseTransform>());
b_var->propagate_on_domain_change(std::move(advisor_b));

REQUIRE(i_var->on_domain.size() == 1);
REQUIRE(b_var->on_domain.size() == 1);

WHEN("We initialize a state") {
CPState state = model.initialize_state<Copier>();
CPVarsState& s_state = state.get_variables_state();
CPPropagatorsState& p_state = state.get_propagators_state();

REQUIRE(s_state.size() == 2);
REQUIRE(p_state.size() == 1);

i_var->initialize_state(state);
b_var->initialize_state(state);
p->initialize_state(state);

AND_WHEN("We alter the domain of the integer variable inside the slice") {
CPStatus status = i_var->assign(s_state, -2, 0);
REQUIRE(status == CPStatus::OK);
THEN("We see that the propagator is not triggered") {
CHECK(not p_state[0]->scheduled());
CHECK(p_state[0]->indices_to_process().size() == 0);
}
}

AND_WHEN("We alter the domain of the integer variable inside the slice") {
CPStatus status = i_var->assign(s_state, -2, 3);
REQUIRE(status == CPStatus::OK);
THEN("We see that the propagator is triggered to run on the same index") {
REQUIRE(p_state[0]->scheduled());
REQUIRE(p_state[0]->indices_to_process().size() == 1);

CHECK(p_state[0]->scheduled(2));
}

AND_WHEN("We call the fix point engine") {
CPEngine engine;
engine.fix_point(state);

THEN("The sum output variable 2 is correctly fixed") {
CHECK(b_var->min(s_state, 2) == -2);
CHECK(b_var->max(s_state, 2) == -2);
}
}
}
}
}
}

GIVEN("A dwopt graph with basic indexing on a 2d array") {
Graph graph;
auto i = graph.emplace_node<IntegerNode>(std::initializer_list<ssize_t>{4, 7}, -3, 4);
auto b = graph.emplace_node<BasicIndexingNode>(i, 2, Slice(1, 4));

// Lock the graph
graph.topological_sort();

// Construct the CP corresponding model
AND_GIVEN("The CP Model") {
CPModel model;

// Add the variabbles to the model
CPVar* i_var = model.emplace_variable<CPVar>(model, i, i->topological_index());
CPVar* b_var = model.emplace_variable<CPVar>(model, b, b->topological_index());

Propagator* p = model.emplace_propagator<BasicIndexingPropagator>(
model.num_propagators(), i_var, b_var);

// build the advisor for the propagator p aimed to the variable for i
Advisor advisor_i(p, 0, std::make_unique<BasicIndexingForwardTransform>(i, b));
i_var->propagate_on_domain_change(std::move(advisor_i));

// build the advisor for the propagator p aimed to the variable for b
Advisor advisor_b(p, 1, std::make_unique<ElementWiseTransform>());
b_var->propagate_on_domain_change(std::move(advisor_b));

REQUIRE(i_var->on_domain.size() == 1);
REQUIRE(b_var->on_domain.size() == 1);

WHEN("We initialize a state") {
CPState state = model.initialize_state<Copier>();
CPVarsState& s_state = state.get_variables_state();
CPPropagatorsState& p_state = state.get_propagators_state();

REQUIRE(s_state.size() == 2);
REQUIRE(p_state.size() == 1);

i_var->initialize_state(state);
b_var->initialize_state(state);
p->initialize_state(state);

AND_WHEN("We alter the domain of the integer variable inside the slice") {
CPStatus status = i_var->assign(s_state, -2, 0);
REQUIRE(status == CPStatus::OK);
THEN("We see that the propagator is not triggered") {
CHECK(not p_state[0]->scheduled());
CHECK(p_state[0]->indices_to_process().size() == 0);
}
}

AND_WHEN("We alter the domain of the integer variable inside the slice") {
CPStatus status = i_var->assign(s_state, -2, 16);
REQUIRE(status == CPStatus::OK);
THEN("We see that the propagator is triggered to run on the same index") {
REQUIRE(p_state[0]->scheduled());
REQUIRE(p_state[0]->indices_to_process().size() == 1);

CHECK(p_state[0]->scheduled(1));
}

AND_WHEN("We call the fix point engine") {
CPEngine engine;
engine.fix_point(state);

THEN("The sum output variable 2 is correctly fixed") {
CHECK(b_var->min(s_state, 1) == -2);
CHECK(b_var->max(s_state, 1) == -2);
}
}
}
}
}
}
}

} // namespace dwave::optimization::cp