Skip to content

Commit

Permalink
Fix multiplication and division of complex numbers
Browse files Browse the repository at this point in the history
Multiplication and division of complex numbers are not just pointwise
applications of those operations.

Fixes: diffblue#8375
  • Loading branch information
tautschnig committed Jul 11, 2024
1 parent 629dbcd commit 6a6c480
Show file tree
Hide file tree
Showing 4 changed files with 131 additions and 40 deletions.
15 changes: 15 additions & 0 deletions regression/cbmc/complex2/main.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#include <complex.h>

int main()
{
char choice;
float re = choice ? 1.3f : 2.1f; // a non-constant well-behaved float
float complex z1 = I + re;
float complex z2 = z1 * z1;
float complex expected = 2 * I * re + re * re - 1; // (a+i)^2 = 2ai + a^2 - 1
float complex actual =
re * re + I; // (a1 + b1*i)*(a2 + b2*i) = (a1*a2 + b1*b2*i)
__CPROVER_assert(z2 == expected, "right");
__CPROVER_assert(z2 == actual, "wrong");
return 0;
}
10 changes: 10 additions & 0 deletions regression/cbmc/complex2/test.desc
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
CORE no-new-smt
main.c

^\[main.assertion.1\] line 12 right: SUCCESS$
^\[main.assertion.2\] line 13 wrong: FAILURE$
^VERIFICATION FAILED$
^EXIT=10$
^SIGNAL=0$
--
^warning: ignoring
56 changes: 50 additions & 6 deletions src/goto-programs/remove_complex.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,10 @@ static void remove_complex(exprt &expr)

if(expr.type().id()==ID_complex)
{
if(expr.id()==ID_plus || expr.id()==ID_minus ||
expr.id()==ID_mult || expr.id()==ID_div)
if(expr.id() == ID_plus || expr.id() == ID_minus)
{
// FIXME plus and mult are defined as n-ary operations
// rather than binary. This code assumes that they
// can only have exactly 2 operands, and it is not clear
// that it is safe to do so in this context
// plus and mult are n-ary expressions, but front-ends currently ensure
// that we only see them as binary ones
PRECONDITION(expr.operands().size() == 2);
// do component-wise:
// x+y -> complex(x.r+y.r,x.i+y.i)
Expand All @@ -153,6 +150,53 @@ static void remove_complex(exprt &expr)

expr=struct_expr;
}
else if(expr.id() == ID_mult)
{
// plus and mult are n-ary expressions, but front-ends currently ensure
// that we only see them as binary ones
PRECONDITION(expr.operands().size() == 2);
exprt lhs_real = complex_member(to_binary_expr(expr).op0(), ID_real);
exprt lhs_imag = complex_member(to_binary_expr(expr).op0(), ID_imag);
exprt rhs_real = complex_member(to_binary_expr(expr).op1(), ID_real);
exprt rhs_imag = complex_member(to_binary_expr(expr).op1(), ID_imag);

struct_exprt struct_expr{
{minus_exprt{
mult_exprt{lhs_real, rhs_real}, mult_exprt{lhs_imag, rhs_imag}},
plus_exprt{
mult_exprt{lhs_imag, rhs_real}, mult_exprt{lhs_real, rhs_imag}}},
expr.type()};

struct_expr.op0().add_source_location() = expr.source_location();
struct_expr.op1().add_source_location() = expr.source_location();

expr = struct_expr;
}
else if(expr.id() == ID_div)
{
exprt lhs_real = complex_member(to_binary_expr(expr).op0(), ID_real);
exprt lhs_imag = complex_member(to_binary_expr(expr).op0(), ID_imag);
exprt rhs_real = complex_member(to_binary_expr(expr).op1(), ID_real);
exprt rhs_imag = complex_member(to_binary_expr(expr).op1(), ID_imag);

plus_exprt numerator_real{
mult_exprt{lhs_real, rhs_real}, mult_exprt{lhs_imag, rhs_imag}};
minus_exprt numerator_imag{
mult_exprt{lhs_imag, rhs_real}, mult_exprt{lhs_real, rhs_imag}};

plus_exprt denominator{
mult_exprt{rhs_real, rhs_real}, mult_exprt{rhs_imag, rhs_imag}};

struct_exprt struct_expr{
{div_exprt{numerator_real, denominator},
div_exprt{numerator_imag, denominator}},
expr.type()};

struct_expr.op0().add_source_location() = expr.source_location();
struct_expr.op1().add_source_location() = expr.source_location();

expr = struct_expr;
}
else if(expr.id()==ID_unary_minus)
{
auto const &unary_minus_expr = to_unary_minus_expr(expr);
Expand Down
90 changes: 56 additions & 34 deletions src/solvers/flattening/boolbv_floatbv_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,44 +131,66 @@ bvt boolbvt::convert_floatbv_op(const ieee_float_op_exprt &expr)
sub_width > 0 && width % sub_width == 0,
"width of a complex subtype must be positive and evenly divide the "
"width of the complex expression");
DATA_INVARIANT(
sub_width * 2 == width, "a complex type consists of exactly two parts");

bvt lhs_real{lhs_as_bv.begin(), lhs_as_bv.begin() + sub_width};
bvt rhs_real{rhs_as_bv.begin(), rhs_as_bv.begin() + sub_width};

std::size_t size=width/sub_width;
bvt result_bv;
result_bv.resize(width);
bvt lhs_imag{lhs_as_bv.begin() + sub_width, lhs_as_bv.end()};
bvt rhs_imag{rhs_as_bv.begin() + sub_width, rhs_as_bv.end()};

for(std::size_t i=0; i<size; i++)
bvt result_real, result_imag;

if(expr.id() == ID_floatbv_plus || expr.id() == ID_floatbv_minus)
{
result_real = float_utils.add_sub(
lhs_real, rhs_real, expr.id() == ID_floatbv_minus);
result_imag = float_utils.add_sub(
lhs_imag, rhs_imag, expr.id() == ID_floatbv_minus);
}
else if(expr.id() == ID_floatbv_mult)
{
// Could be optimised to just three multiplications with more additions
// instead, but then we'd have to worry about the impact of possible
// overflows. So we use the naive approach for now:
result_real = float_utils.add_sub(
float_utils.mul(lhs_real, rhs_real),
float_utils.mul(lhs_imag, rhs_imag),
true);
result_imag = float_utils.add_sub(
float_utils.mul(lhs_real, rhs_imag),
float_utils.mul(lhs_imag, rhs_real),
false);
}
else if(expr.id() == ID_floatbv_div)
{
bvt lhs_sub_bv, rhs_sub_bv, sub_result_bv;

lhs_sub_bv.assign(
lhs_as_bv.begin() + i * sub_width,
lhs_as_bv.begin() + (i + 1) * sub_width);
rhs_sub_bv.assign(
rhs_as_bv.begin() + i * sub_width,
rhs_as_bv.begin() + (i + 1) * sub_width);

if(expr.id()==ID_floatbv_plus)
sub_result_bv = float_utils.add_sub(lhs_sub_bv, rhs_sub_bv, false);
else if(expr.id()==ID_floatbv_minus)
sub_result_bv = float_utils.add_sub(lhs_sub_bv, rhs_sub_bv, true);
else if(expr.id()==ID_floatbv_mult)
sub_result_bv = float_utils.mul(lhs_sub_bv, rhs_sub_bv);
else if(expr.id()==ID_floatbv_div)
sub_result_bv = float_utils.div(lhs_sub_bv, rhs_sub_bv);
else
UNREACHABLE;

INVARIANT(
sub_result_bv.size() == sub_width,
"we constructed a new complex of the right size");
INVARIANT(
i * sub_width + sub_width - 1 < result_bv.size(),
"the sub-bitvector fits into the result bitvector");
std::copy(
sub_result_bv.begin(),
sub_result_bv.end(),
result_bv.begin() + i * sub_width);
bvt numerator_real = float_utils.add_sub(
float_utils.mul(lhs_real, rhs_real),
float_utils.mul(lhs_imag, rhs_imag),
false);
bvt numerator_imag = float_utils.add_sub(
float_utils.mul(lhs_imag, rhs_real),
float_utils.mul(lhs_real, rhs_imag),
true);

bvt denominator = float_utils.add_sub(
float_utils.mul(rhs_real, rhs_real),
float_utils.mul(rhs_imag, rhs_imag),
false);

result_real = float_utils.div(numerator_real, denominator);
result_imag = float_utils.div(numerator_imag, denominator);
}
else
UNREACHABLE;

bvt result_bv = std::move(result_real);
result_bv.reserve(width);
result_bv.insert(
result_bv.end(),
std::make_move_iterator(result_imag.begin()),
std::make_move_iterator(result_imag.end()));

return result_bv;
}
Expand Down

0 comments on commit 6a6c480

Please sign in to comment.