Skip to content

Commit

Permalink
[ModelBuilder] Add set_coefficient API, fix add_term to not create du…
Browse files Browse the repository at this point in the history
…plicate variable entries
  • Loading branch information
lperron committed Jul 20, 2023
1 parent 36d137c commit b82cd22
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 1 deletion.
7 changes: 6 additions & 1 deletion ortools/linear_solver/python/model_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,8 +722,13 @@ def __str__(self):
def __repr__(self):
return self.__str__()

def set_coefficient(self, var: Variable, coeff: NumberT) -> None:
"""Sets the coefficient of the variable in the constraint."""
self.__helper.set_constraint_coefficient(self.__index, var.index, coeff)

def add_term(self, var: Variable, coeff: NumberT) -> None:
self.__helper.add_term_to_constraint(self.__index, var.index, coeff)
"""Adds var * coeff to the constraint."""
self.__helper.safe_add_term_to_constraint(self.__index, var.index, coeff)


class ModelBuilder:
Expand Down
3 changes: 3 additions & 0 deletions ortools/linear_solver/python/model_builder_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -305,6 +305,9 @@ PYBIND11_MODULE(model_builder_helper, m) {
helper->AddConstraintTerm(ct_index, i, c);
}
})
.def("safe_add_term_to_constraint",
&ModelBuilderHelper::SafeAddConstraintTerm, arg("ct_index"),
arg("var_index"), arg("coeff"))
.def("set_constraint_name", &ModelBuilderHelper::SetConstraintName,
arg("ct_index"), arg("name"))
.def("set_constraint_coefficient",
Expand Down
22 changes: 22 additions & 0 deletions ortools/linear_solver/python/model_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,6 +306,28 @@ def test_duplicate_variables(self):
solver = mb.ModelSolver("sat")
self.assertEqual(mb.SolveStatus.OPTIMAL, solver.solve(model))

def test_add_term(self):
model = mb.ModelBuilder()
x = model.new_int_var(0.0, 4.0, "x")
y = model.new_int_var(0.0, 4.0, "y")
z = model.new_int_var(0.0, 4.0, "z")
t = model.new_int_var(0.0, 4.0, "t")
ct = model.add(x + 2 * y == 3)
self.assertEqual(ct.helper.constraint_var_indices(ct.index), [0, 1])
self.assertEqual(ct.helper.constraint_coefficients(ct.index), [1, 2])
ct.add_term(x, 2)
self.assertEqual(ct.helper.constraint_var_indices(ct.index), [0, 1])
self.assertEqual(ct.helper.constraint_coefficients(ct.index), [3, 2])
ct.set_coefficient(x, 5)
self.assertEqual(ct.helper.constraint_var_indices(ct.index), [0, 1])
self.assertEqual(ct.helper.constraint_coefficients(ct.index), [5, 2])
ct.add_term(z, 4)
self.assertEqual(ct.helper.constraint_var_indices(ct.index), [0, 1, 2])
self.assertEqual(ct.helper.constraint_coefficients(ct.index), [5, 2, 4])
ct.set_coefficient(t, -1)
self.assertEqual(ct.helper.constraint_var_indices(ct.index), [0, 1, 2, 3])
self.assertEqual(ct.helper.constraint_coefficients(ct.index), [5, 2, 4, -1])

def test_issue_3614(self):
total_number_of_choices = 5 + 1
total_unique_products = 3
Expand Down
15 changes: 15 additions & 0 deletions ortools/linear_solver/wrappers/model_builder_helper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,21 @@ void ModelBuilderHelper::AddConstraintTerm(int ct_index, int var_index,
ct_proto->add_coefficient(coeff);
}

void ModelBuilderHelper::SafeAddConstraintTerm(int ct_index, int var_index,
double coeff) {
MPConstraintProto* ct_proto = model_.mutable_constraint(ct_index);
for (int i = 0; i < ct_proto->var_index_size(); ++i) {
if (ct_proto->var_index(i) == var_index) {
ct_proto->set_coefficient(i, coeff + ct_proto->coefficient(i));
return;
}
}
// If we reach this point, the variable does not exist in the constraint yet,
// so we add it to the constraint as a new term.
ct_proto->add_var_index(var_index);
ct_proto->add_coefficient(coeff);
}

void ModelBuilderHelper::SetConstraintName(int ct_index,
const std::string& name) {
model_.mutable_constraint(ct_index)->set_name(name);
Expand Down
2 changes: 2 additions & 0 deletions ortools/linear_solver/wrappers/model_builder_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@ class ModelBuilderHelper {
void SetConstraintLowerBound(int ct_index, double lb);
void SetConstraintUpperBound(int ct_index, double ub);
void AddConstraintTerm(int ct_index, int var_index, double coeff);
// Safe version that checks is does not create duplicate entries.
void SafeAddConstraintTerm(int ct_index, int var_index, double coeff);
void SetConstraintName(int ct_index, const std::string& name);
void SetConstraintCoefficient(int ct_index, int var_index, double coeff);

Expand Down

0 comments on commit b82cd22

Please sign in to comment.