Skip to content

Commit

Permalink
Merge pull request diffblue#8235 from tautschnig/cleanup/no-follow-so…
Browse files Browse the repository at this point in the history
…lvers

Solvers: Replace uses of namespacet::follow
  • Loading branch information
tautschnig authored May 9, 2024
2 parents 6fca7bb + 8a0782b commit f18b509
Show file tree
Hide file tree
Showing 7 changed files with 94 additions and 39 deletions.
5 changes: 4 additions & 1 deletion src/solvers/flattening/boolbv_struct.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,10 @@ Author: Daniel Kroening, [email protected]

bvt boolbvt::convert_struct(const struct_exprt &expr)
{
const struct_typet &struct_type=to_struct_type(ns.follow(expr.type()));
const struct_typet &struct_type =
expr.type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(expr.type()))
: to_struct_type(expr.type());

std::size_t width=boolbv_width(struct_type);

Expand Down
14 changes: 10 additions & 4 deletions src/solvers/flattening/boolbv_typecast.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -538,17 +538,23 @@ bool boolbvt::type_conversion(
return false;
}
}
else if(ns.follow(dest_type).id() == ID_struct)
else if(dest_type.id() == ID_struct || dest_type.id() == ID_struct_tag)
{
const struct_typet &dest_struct = to_struct_type(ns.follow(dest_type));
const struct_typet &dest_struct =
dest_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(dest_type))
: to_struct_type(dest_type);

if(ns.follow(src_type).id() == ID_struct)
if(src_type.id() == ID_struct || src_type.id() == ID_struct_tag)
{
// we do subsets

dest.resize(dest_width, const_literal(false));

const struct_typet &op_struct = to_struct_type(ns.follow(src_type));
const struct_typet &op_struct =
src_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(src_type))
: to_struct_type(src_type);

const struct_typet::componentst &dest_comp = dest_struct.components();

Expand Down
12 changes: 8 additions & 4 deletions src/solvers/flattening/boolbv_update.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,11 @@ void boolbvt::convert_update_rec(
{
const irep_idt &component_name=designator.get(ID_component_name);

if(ns.follow(type).id() == ID_struct)
if(type.id() == ID_struct || type.id() == ID_struct_tag)
{
const struct_typet &struct_type = to_struct_type(ns.follow(type));
const struct_typet &struct_type =
type.id() == ID_struct_tag ? ns.follow_tag(to_struct_tag_type(type))
: to_struct_type(type);

std::size_t struct_offset=0;

Expand Down Expand Up @@ -144,9 +146,11 @@ void boolbvt::convert_update_rec(
convert_update_rec(
designators, d+1, new_type, new_offset, new_value, bv);
}
else if(ns.follow(type).id() == ID_union)
else if(type.id() == ID_union || type.id() == ID_union_tag)
{
const union_typet &union_type = to_union_type(ns.follow(type));
const union_typet &union_type = type.id() == ID_union_tag
? ns.follow_tag(to_union_tag_type(type))
: to_union_type(type);

const union_typet::componentt &component=
union_type.get_component(component_name);
Expand Down
31 changes: 21 additions & 10 deletions src/solvers/flattening/bv_pointers.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,18 +312,23 @@ std::optional<bvt> bv_pointerst::convert_address_of_rec(const exprt &expr)
{
const member_exprt &member_expr=to_member_expr(expr);
const exprt &struct_op = member_expr.compound();
const typet &struct_op_type=ns.follow(struct_op.type());

// recursive call
auto bv_opt = convert_address_of_rec(struct_op);
if(!bv_opt.has_value())
return {};

bvt bv = std::move(*bv_opt);
if(struct_op_type.id()==ID_struct)
if(
struct_op.type().id() == ID_struct ||
struct_op.type().id() == ID_struct_tag)
{
auto offset = member_offset(
to_struct_type(struct_op_type), member_expr.get_component_name(), ns);
const struct_typet &struct_op_type =
struct_op.type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(struct_op.type()))
: to_struct_type(struct_op.type());
auto offset =
member_offset(struct_op_type, member_expr.get_component_name(), ns);
CHECK_RETURN(offset.has_value());

// add offset
Expand All @@ -333,7 +338,8 @@ std::optional<bvt> bv_pointerst::convert_address_of_rec(const exprt &expr)
else
{
INVARIANT(
struct_op_type.id() == ID_union,
struct_op.type().id() == ID_union ||
struct_op.type().id() == ID_union_tag,
"member expression should operate on struct or union");
// nothing to do, all members have offset 0
}
Expand Down Expand Up @@ -551,21 +557,26 @@ bvt bv_pointerst::convert_pointer_type(const exprt &expr)
else if(expr.id() == ID_field_address)
{
const auto &field_address_expr = to_field_address_expr(expr);
const typet &compound_type = ns.follow(field_address_expr.compound_type());
const typet &compound_type = field_address_expr.compound_type();

// recursive call
auto bv = convert_bitvector(field_address_expr.base());

if(compound_type.id() == ID_struct)
if(compound_type.id() == ID_struct || compound_type.id() == ID_struct_tag)
{
auto offset = member_offset(
to_struct_type(compound_type), field_address_expr.component_name(), ns);
const struct_typet &struct_type =
compound_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(compound_type))
: to_struct_type(compound_type);
auto offset =
member_offset(struct_type, field_address_expr.component_name(), ns);
CHECK_RETURN(offset.has_value());

// add offset
bv = offset_arithmetic(field_address_expr.type(), bv, *offset);
}
else if(compound_type.id() == ID_union)
else if(
compound_type.id() == ID_union || compound_type.id() == ID_union_tag)
{
// nothing to do, all fields have offset 0
}
Expand Down
40 changes: 28 additions & 12 deletions src/solvers/smt2/smt2_conv.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -846,7 +846,10 @@ void smt2_convt::convert_address_of_rec(
struct_op_type.id() == ID_struct || struct_op_type.id() == ID_struct_tag,
"member expression operand shall have struct type");

const struct_typet &struct_type = to_struct_type(ns.follow(struct_op_type));
const struct_typet &struct_type =
struct_op_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(struct_op_type))
: to_struct_type(struct_op_type);

const irep_idt &component_name = member_expr.get_component_name();

Expand Down Expand Up @@ -3159,7 +3162,10 @@ void smt2_convt::convert_floatbv_typecast(const floatbv_typecast_exprt &expr)

void smt2_convt::convert_struct(const struct_exprt &expr)
{
const struct_typet &struct_type = to_struct_type(ns.follow(expr.type()));
const struct_typet &struct_type =
expr.type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(expr.type()))
: to_struct_type(expr.type());

const struct_typet::componentst &components=
struct_type.components();
Expand Down Expand Up @@ -3262,10 +3268,9 @@ void smt2_convt::flatten_array(const exprt &expr)

void smt2_convt::convert_union(const union_exprt &expr)
{
const union_typet &union_type = to_union_type(ns.follow(expr.type()));
const exprt &op=expr.op();

std::size_t total_width=boolbv_width(union_type);
std::size_t total_width = boolbv_width(expr.type());

std::size_t member_width=boolbv_width(op.type());

Expand Down Expand Up @@ -4182,7 +4187,10 @@ void smt2_convt::convert_with(const with_exprt &expr)
}
else if(expr_type.id() == ID_struct || expr_type.id() == ID_struct_tag)
{
const struct_typet &struct_type = to_struct_type(ns.follow(expr_type));
const struct_typet &struct_type =
expr_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(expr_type))
: to_struct_type(expr_type);

const exprt &index = expr.where();
const exprt &value = expr.new_value();
Expand Down Expand Up @@ -4253,11 +4261,9 @@ void smt2_convt::convert_with(const with_exprt &expr)
}
else if(expr_type.id() == ID_union || expr_type.id() == ID_union_tag)
{
const union_typet &union_type = to_union_type(ns.follow(expr_type));

const exprt &value = expr.new_value();

std::size_t total_width=boolbv_width(union_type);
std::size_t total_width = boolbv_width(expr_type);

std::size_t member_width=boolbv_width(value.type());

Expand Down Expand Up @@ -4399,7 +4405,10 @@ void smt2_convt::convert_member(const member_exprt &expr)

if(struct_op_type.id() == ID_struct || struct_op_type.id() == ID_struct_tag)
{
const struct_typet &struct_type = to_struct_type(ns.follow(struct_op_type));
const struct_typet &struct_type =
struct_op_type.id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(struct_op_type))
: to_struct_type(struct_op_type);

INVARIANT(
struct_type.has_component(name), "struct should have accessed component");
Expand Down Expand Up @@ -4496,7 +4505,9 @@ void smt2_convt::flatten2bv(const exprt &expr)
if(use_datatypes)
{
// concatenate elements
const struct_typet &struct_type = to_struct_type(ns.follow(type));
const struct_typet &struct_type =
type.id() == ID_struct_tag ? ns.follow_tag(to_struct_tag_type(type))
: to_struct_type(type);

const struct_typet::componentst &components=
struct_type.components();
Expand Down Expand Up @@ -4622,7 +4633,9 @@ void smt2_convt::unflatten(

out << "(mk-" << smt_typename;

const struct_typet &struct_type = to_struct_type(ns.follow(type));
const struct_typet &struct_type =
type.id() == ID_struct_tag ? ns.follow_tag(to_struct_tag_type(type))
: to_struct_type(type);

const struct_typet::componentst &components=
struct_type.components();
Expand Down Expand Up @@ -5501,8 +5514,11 @@ void smt2_convt::convert_type(const typet &type)
else if(type.id() == ID_union || type.id() == ID_union_tag)
{
std::size_t width=boolbv_width(type);
const union_typet &union_type = type.id() == ID_union_tag
? ns.follow_tag(to_union_tag_type(type))
: to_union_type(type);
CHECK_RETURN_WITH_DIAGNOSTICS(
to_union_type(ns.follow(type)).components().empty() || width != 0,
union_type.components().empty() || width != 0,
"failed to get width of union");

out << "(_ BitVec " << width << ")";
Expand Down
17 changes: 12 additions & 5 deletions src/solvers/smt2_incremental/encoding/struct_encoding.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ extricate_updates(const with_exprt &struct_expr)
static exprt encode(const with_exprt &with, const namespacet &ns)
{
const auto tag_type = type_checked_cast<struct_tag_typet>(with.type());
const auto struct_type =
type_checked_cast<struct_typet>(ns.follow(with.type()));
const auto struct_type = ns.follow_tag(tag_type);
const auto updates = extricate_updates(with);
const auto components =
make_range(struct_type.components())
Expand Down Expand Up @@ -194,11 +193,19 @@ static std::size_t count_trailing_bit_width(
/// the combined width of the fields which follow the field being selected.
exprt struct_encodingt::encode_member(const member_exprt &member_expr) const
{
const auto &type = ns.get().follow(member_expr.compound().type());
const auto &compound_type = member_expr.compound().type();
const auto offset_bits = [&]() -> std::size_t {
if(can_cast_type<union_typet>(type))
if(
can_cast_type<union_typet>(compound_type) ||
can_cast_type<union_tag_typet>(compound_type))
{
return 0;
const auto &struct_type = type_checked_cast<struct_typet>(type);
}
const auto &struct_type =
compound_type.id() == ID_struct_tag
? ns.get().follow_tag(
type_checked_cast<struct_tag_typet>(compound_type))
: type_checked_cast<struct_typet>(compound_type);
return count_trailing_bit_width(
struct_type, member_expr.get_component_name(), *boolbv_width);
}();
Expand Down
14 changes: 11 additions & 3 deletions src/solvers/strings/string_refinement.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,10 @@ static void add_equations_for_symbol_resolution(
{
if(rhs.type().id() == ID_struct || rhs.type().id() == ID_struct_tag)
{
const struct_typet &struct_type = to_struct_type(ns.follow(rhs.type()));
const struct_typet &struct_type =
rhs.type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(rhs.type()))
: to_struct_type(rhs.type());
for(const auto &comp : struct_type.components())
{
if(is_char_pointer_type(comp.type()))
Expand Down Expand Up @@ -377,7 +380,10 @@ extract_strings_from_lhs(const exprt &lhs, const namespacet &ns)
result.push_back(lhs);
else if(lhs.type().id() == ID_struct || lhs.type().id() == ID_struct_tag)
{
const struct_typet &struct_type = to_struct_type(ns.follow(lhs.type()));
const struct_typet &struct_type =
lhs.type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(lhs.type()))
: to_struct_type(lhs.type());
for(const auto &comp : struct_type.components())
{
const std::vector<exprt> strings_in_comp = extract_strings_from_lhs(
Expand Down Expand Up @@ -439,7 +445,9 @@ static void add_string_equation_to_symbol_resolution(
eq.rhs().type().id() == ID_struct_tag)
{
const struct_typet &struct_type =
to_struct_type(ns.follow(eq.rhs().type()));
eq.rhs().type().id() == ID_struct_tag
? ns.follow_tag(to_struct_tag_type(eq.rhs().type()))
: to_struct_type(eq.rhs().type());
for(const auto &comp : struct_type.components())
{
const member_exprt lhs_data(eq.lhs(), comp.get_name(), comp.type());
Expand Down

0 comments on commit f18b509

Please sign in to comment.