From b82cd22d7cd02776e3f00ddb55a327b7992068b7 Mon Sep 17 00:00:00 2001 From: Laurent Perron Date: Thu, 20 Jul 2023 15:13:56 -0700 Subject: [PATCH] [ModelBuilder] Add set_coefficient API, fix add_term to not create duplicate variable entries --- ortools/linear_solver/python/model_builder.py | 7 +++++- .../python/model_builder_helper.cc | 3 +++ .../python/model_builder_test.py | 22 +++++++++++++++++++ .../wrappers/model_builder_helper.cc | 15 +++++++++++++ .../wrappers/model_builder_helper.h | 2 ++ 5 files changed, 48 insertions(+), 1 deletion(-) diff --git a/ortools/linear_solver/python/model_builder.py b/ortools/linear_solver/python/model_builder.py index 27dff6f926f..c099242659b 100644 --- a/ortools/linear_solver/python/model_builder.py +++ b/ortools/linear_solver/python/model_builder.py @@ -722,8 +722,13 @@ def __str__(self): def __repr__(self): return self.__str__() + def set_coefficient(self, var: Variable, coeff: NumberT) -> None: + """Sets the coefficient of the variable in the constraint.""" + self.__helper.set_constraint_coefficient(self.__index, var.index, coeff) + def add_term(self, var: Variable, coeff: NumberT) -> None: - self.__helper.add_term_to_constraint(self.__index, var.index, coeff) + """Adds var * coeff to the constraint.""" + self.__helper.safe_add_term_to_constraint(self.__index, var.index, coeff) class ModelBuilder: diff --git a/ortools/linear_solver/python/model_builder_helper.cc b/ortools/linear_solver/python/model_builder_helper.cc index 49ef475d1c2..b270a64a777 100644 --- a/ortools/linear_solver/python/model_builder_helper.cc +++ b/ortools/linear_solver/python/model_builder_helper.cc @@ -305,6 +305,9 @@ PYBIND11_MODULE(model_builder_helper, m) { helper->AddConstraintTerm(ct_index, i, c); } }) + .def("safe_add_term_to_constraint", + &ModelBuilderHelper::SafeAddConstraintTerm, arg("ct_index"), + arg("var_index"), arg("coeff")) .def("set_constraint_name", &ModelBuilderHelper::SetConstraintName, arg("ct_index"), arg("name")) .def("set_constraint_coefficient", diff --git a/ortools/linear_solver/python/model_builder_test.py b/ortools/linear_solver/python/model_builder_test.py index 28f1b0962f3..6f6b510579c 100644 --- a/ortools/linear_solver/python/model_builder_test.py +++ b/ortools/linear_solver/python/model_builder_test.py @@ -306,6 +306,28 @@ def test_duplicate_variables(self): solver = mb.ModelSolver("sat") self.assertEqual(mb.SolveStatus.OPTIMAL, solver.solve(model)) + def test_add_term(self): + model = mb.ModelBuilder() + x = model.new_int_var(0.0, 4.0, "x") + y = model.new_int_var(0.0, 4.0, "y") + z = model.new_int_var(0.0, 4.0, "z") + t = model.new_int_var(0.0, 4.0, "t") + ct = model.add(x + 2 * y == 3) + self.assertEqual(ct.helper.constraint_var_indices(ct.index), [0, 1]) + self.assertEqual(ct.helper.constraint_coefficients(ct.index), [1, 2]) + ct.add_term(x, 2) + self.assertEqual(ct.helper.constraint_var_indices(ct.index), [0, 1]) + self.assertEqual(ct.helper.constraint_coefficients(ct.index), [3, 2]) + ct.set_coefficient(x, 5) + self.assertEqual(ct.helper.constraint_var_indices(ct.index), [0, 1]) + self.assertEqual(ct.helper.constraint_coefficients(ct.index), [5, 2]) + ct.add_term(z, 4) + self.assertEqual(ct.helper.constraint_var_indices(ct.index), [0, 1, 2]) + self.assertEqual(ct.helper.constraint_coefficients(ct.index), [5, 2, 4]) + ct.set_coefficient(t, -1) + self.assertEqual(ct.helper.constraint_var_indices(ct.index), [0, 1, 2, 3]) + self.assertEqual(ct.helper.constraint_coefficients(ct.index), [5, 2, 4, -1]) + def test_issue_3614(self): total_number_of_choices = 5 + 1 total_unique_products = 3 diff --git a/ortools/linear_solver/wrappers/model_builder_helper.cc b/ortools/linear_solver/wrappers/model_builder_helper.cc index 2f401eeaa1e..e3731363ac8 100644 --- a/ortools/linear_solver/wrappers/model_builder_helper.cc +++ b/ortools/linear_solver/wrappers/model_builder_helper.cc @@ -147,6 +147,21 @@ void ModelBuilderHelper::AddConstraintTerm(int ct_index, int var_index, ct_proto->add_coefficient(coeff); } +void ModelBuilderHelper::SafeAddConstraintTerm(int ct_index, int var_index, + double coeff) { + MPConstraintProto* ct_proto = model_.mutable_constraint(ct_index); + for (int i = 0; i < ct_proto->var_index_size(); ++i) { + if (ct_proto->var_index(i) == var_index) { + ct_proto->set_coefficient(i, coeff + ct_proto->coefficient(i)); + return; + } + } + // If we reach this point, the variable does not exist in the constraint yet, + // so we add it to the constraint as a new term. + ct_proto->add_var_index(var_index); + ct_proto->add_coefficient(coeff); +} + void ModelBuilderHelper::SetConstraintName(int ct_index, const std::string& name) { model_.mutable_constraint(ct_index)->set_name(name); diff --git a/ortools/linear_solver/wrappers/model_builder_helper.h b/ortools/linear_solver/wrappers/model_builder_helper.h index fd19158bef3..4abc54e30de 100644 --- a/ortools/linear_solver/wrappers/model_builder_helper.h +++ b/ortools/linear_solver/wrappers/model_builder_helper.h @@ -73,6 +73,8 @@ class ModelBuilderHelper { void SetConstraintLowerBound(int ct_index, double lb); void SetConstraintUpperBound(int ct_index, double ub); void AddConstraintTerm(int ct_index, int var_index, double coeff); + // Safe version that checks is does not create duplicate entries. + void SafeAddConstraintTerm(int ct_index, int var_index, double coeff); void SetConstraintName(int ct_index, const std::string& name); void SetConstraintCoefficient(int ct_index, int var_index, double coeff);