diff --git a/modopt/opt/algorithms/admm.py b/modopt/opt/algorithms/admm.py index b881b770..23d382b4 100644 --- a/modopt/opt/algorithms/admm.py +++ b/modopt/opt/algorithms/admm.py @@ -38,7 +38,10 @@ def __init__(self, cost_funcs, A, B, b, tau, **kwargs): self.B = B self.b = b self.tau = tau - + if self.A is self.B: + self.AuplusBv = lambda u, v: self.A.op(u + v) + else: + self.AuplusBv = lambda u, v: self.A.op(u) + self.B.op(v) def _calc_cost(self, u, v, **kwargs): """Calculate the cost. @@ -60,7 +63,7 @@ def _calc_cost(self, u, v, **kwargs): xp = get_array_module(u) cost = self.cost_funcs[0](u) cost += self.cost_funcs[1](v) - cost += self.tau * xp.linalg.norm(self.A.op(u) + self.B.op(v) - self.b) + cost += self.tau * xp.linalg.norm(self.AuplusBv(u,v) - self.b) return cost diff --git a/modopt/opt/algorithms/forward_backward.py b/modopt/opt/algorithms/forward_backward.py index 702799c6..d81bcde8 100644 --- a/modopt/opt/algorithms/forward_backward.py +++ b/modopt/opt/algorithms/forward_backward.py @@ -955,7 +955,11 @@ def _update(self): t_shifted_ratio = (self._t_old - 1) / self._t_new sigma_t_ratio = self._sigma * self._t_old / self._t_new beta_xi_t_shifted_ratio = t_shifted_ratio * self._beta / self._xi - self._z = - beta_xi_t_shifted_ratio * (self._x_old - self._z) + + #self._z = - beta_xi_t_shifted_ratio * (self._x_old - self._z) + self._z -= self._x_old + self._z *= beta_xi_t_shifted_ratio + self._z += self._u_new self._z += t_shifted_ratio * (self._u_new - self._u_old) self._z += sigma_t_ratio * (self._u_new - self._x_old)