Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add HeterogeneousEqual and HeterogeneousNotEqual to binary operation #220

Merged
merged 8 commits into from
May 26, 2024
217 changes: 148 additions & 69 deletions biscuit-auth/src/datalog/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ pub enum Binary {
BitwiseOr,
BitwiseXor,
NotEqual,
HeterogeneousEqual,
HeterogeneousNotEqual,
}

impl Binary {
Expand All @@ -95,8 +97,14 @@ impl Binary {
(Binary::GreaterThan, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i > j)),
(Binary::LessOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i <= j)),
(Binary::GreaterOrEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i >= j)),
(Binary::Equal, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Integer(i), Term::Integer(j)) => Ok(Term::Bool(i != j)),
(Binary::Equal | Binary::HeterogeneousEqual, Term::Integer(i), Term::Integer(j)) => {
Ok(Term::Bool(i == j))
}
(
Binary::NotEqual | Binary::HeterogeneousNotEqual,
Term::Integer(i),
Term::Integer(j),
) => Ok(Term::Bool(i != j)),
(Binary::Add, Term::Integer(i), Term::Integer(j)) => i
.checked_add(j)
.map(Term::Integer)
Expand Down Expand Up @@ -159,26 +167,42 @@ impl Binary {
_ => Err(error::Expression::UnknownSymbol(s1)),
}
}
(Binary::Equal, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Str(i), Term::Str(j)) => Ok(Term::Bool(i != j)),
(Binary::Equal | Binary::HeterogeneousEqual, Term::Str(i), Term::Str(j)) => {
Ok(Term::Bool(i == j))
}
(Binary::NotEqual | Binary::HeterogeneousNotEqual, Term::Str(i), Term::Str(j)) => {
Ok(Term::Bool(i != j))
}

// date
(Binary::LessThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i < j)),
(Binary::GreaterThan, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i > j)),
(Binary::LessOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i <= j)),
(Binary::GreaterOrEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i >= j)),
(Binary::Equal, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Date(i), Term::Date(j)) => Ok(Term::Bool(i != j)),
(Binary::Equal | Binary::HeterogeneousEqual, Term::Date(i), Term::Date(j)) => {
Ok(Term::Bool(i == j))
}
(Binary::NotEqual | Binary::HeterogeneousNotEqual, Term::Date(i), Term::Date(j)) => {
Ok(Term::Bool(i != j))
}

// symbol

// byte array
(Binary::Equal, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Bytes(i), Term::Bytes(j)) => Ok(Term::Bool(i != j)),
(Binary::Equal | Binary::HeterogeneousEqual, Term::Bytes(i), Term::Bytes(j)) => {
Ok(Term::Bool(i == j))
}
(Binary::NotEqual | Binary::HeterogeneousNotEqual, Term::Bytes(i), Term::Bytes(j)) => {
Ok(Term::Bool(i != j))
}

// set
(Binary::Equal, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set == s)),
(Binary::NotEqual, Term::Set(set), Term::Set(s)) => Ok(Term::Bool(set != s)),
(Binary::Equal | Binary::HeterogeneousEqual, Term::Set(set), Term::Set(s)) => {
Ok(Term::Bool(set == s))
}
(Binary::NotEqual | Binary::HeterogeneousNotEqual, Term::Set(set), Term::Set(s)) => {
Ok(Term::Bool(set != s))
}
(Binary::Intersection, Term::Set(set), Term::Set(s)) => {
Ok(Term::Set(set.intersection(&s).cloned().collect()))
}
Expand All @@ -205,16 +229,31 @@ impl Binary {
// boolean
(Binary::And, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i & j)),
(Binary::Or, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i | j)),
(Binary::Equal, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i == j)),
(Binary::NotEqual, Term::Bool(i), Term::Bool(j)) => Ok(Term::Bool(i != j)),
(Binary::Equal | Binary::HeterogeneousEqual, Term::Bool(i), Term::Bool(j)) => {
Ok(Term::Bool(i == j))
}
(Binary::NotEqual | Binary::HeterogeneousNotEqual, Term::Bool(i), Term::Bool(j)) => {
Ok(Term::Bool(i != j))
}

// null
(Binary::Equal, Term::Null, Term::Null) => Ok(Term::Bool(true)),
(Binary::Equal, Term::Null, _) => Ok(Term::Bool(false)),
(Binary::Equal, _, Term::Null) => Ok(Term::Bool(false)),
(Binary::NotEqual, Term::Null, Term::Null) => Ok(Term::Bool(false)),
(Binary::NotEqual, Term::Null, _) => Ok(Term::Bool(true)),
(Binary::NotEqual, _, Term::Null) => Ok(Term::Bool(true)),
(Binary::Equal | Binary::HeterogeneousEqual, Term::Null, Term::Null) => {
Ok(Term::Bool(true))
}
(Binary::Equal | Binary::HeterogeneousEqual, Term::Null, _) => Ok(Term::Bool(false)),
Geal marked this conversation as resolved.
Show resolved Hide resolved
(Binary::Equal | Binary::HeterogeneousEqual, _, Term::Null) => Ok(Term::Bool(false)),
Geal marked this conversation as resolved.
Show resolved Hide resolved
(Binary::NotEqual | Binary::HeterogeneousNotEqual, Term::Null, Term::Null) => {
Ok(Term::Bool(false))
}
(Binary::NotEqual | Binary::HeterogeneousNotEqual, Term::Null, _) => {
Ok(Term::Bool(true))
}
(Binary::NotEqual | Binary::HeterogeneousNotEqual, _, Term::Null) => {
Ok(Term::Bool(true))
}
Geal marked this conversation as resolved.
Show resolved Hide resolved

(Binary::HeterogeneousEqual, _, _) => Ok(Term::Bool(false)),
(Binary::HeterogeneousNotEqual, _, _) => Ok(Term::Bool(true)),

_ => {
//println!("unexpected value type on the stack");
Expand All @@ -229,8 +268,8 @@ impl Binary {
Binary::GreaterThan => format!("{} > {}", left, right),
Binary::LessOrEqual => format!("{} <= {}", left, right),
Binary::GreaterOrEqual => format!("{} >= {}", left, right),
Binary::Equal => format!("{} == {}", left, right),
Binary::NotEqual => format!("{} != {}", left, right),
Binary::Equal | Binary::HeterogeneousEqual => format!("{} == {}", left, right),
Binary::NotEqual | Binary::HeterogeneousNotEqual => format!("{} != {}", left, right),
Binary::Contains => format!("{}.contains({})", left, right),
Binary::Prefix => format!("{}.starts_with({})", left, right),
Binary::Suffix => format!("{}.ends_with({})", left, right),
Expand Down Expand Up @@ -324,6 +363,8 @@ impl Expression {

#[cfg(test)]
mod tests {
use std::collections::BTreeSet;

use super::*;
use crate::datalog::{SymbolTable, TemporarySymbolTable};

Expand Down Expand Up @@ -482,81 +523,119 @@ mod tests {
fn null_equal() {
let symbols = SymbolTable::new();
let mut tmp_symbols = TemporarySymbolTable::new(&symbols);

let ops = vec![
Op::Value(Term::Null),
Op::Value(Term::Null),
let values: HashMap<u32, Term> = HashMap::new();
let operands = vec![Op::Value(Term::Null), Op::Value(Term::Null)];
let operators = vec![
Op::Binary(Binary::Equal),
Op::Binary(Binary::HeterogeneousEqual),
];

let values: HashMap<u32, Term> = HashMap::new();

println!("ops: {:?}", ops);
for op in operators {
let mut ops = operands.clone();
ops.push(op);
println!("ops: {:?}", ops);

let e = Expression { ops };
println!("print: {}", e.print(&symbols).unwrap());
let e = Expression { ops };
println!("print: {}", e.print(&symbols).unwrap());

let res = e.evaluate(&values, &mut tmp_symbols);
assert_eq!(res, Ok(Term::Bool(true)));
let res = e.evaluate(&values, &mut tmp_symbols);
assert_eq!(res, Ok(Term::Bool(true)));
}
}

#[test]
fn null_not_equal() {
let symbols = SymbolTable::new();
let mut tmp_symbols = TemporarySymbolTable::new(&symbols);

let ops = vec![
Op::Value(Term::Null),
Op::Value(Term::Null),
let values: HashMap<u32, Term> = HashMap::new();
let operands = vec![Op::Value(Term::Null), Op::Value(Term::Null)];
let operators = vec![
Op::Binary(Binary::NotEqual),
Op::Binary(Binary::HeterogeneousNotEqual),
];

let values: HashMap<u32, Term> = HashMap::new();

println!("ops: {:?}", ops);
for op in operators {
let mut ops = operands.clone();
ops.push(op);
println!("ops: {:?}", ops);

let e = Expression { ops };
println!("print: {}", e.print(&symbols).unwrap());
let e = Expression { ops };
println!("print: {}", e.print(&symbols).unwrap());

let res = e.evaluate(&values, &mut tmp_symbols);
assert_eq!(res, Ok(Term::Bool(false)));
let res = e.evaluate(&values, &mut tmp_symbols);
assert_eq!(res, Ok(Term::Bool(false)));
}
}

#[test]
fn null_heterogeneous() {
let symbols = SymbolTable::new();
let mut tmp_symbols = TemporarySymbolTable::new(&symbols);

let ops = vec![
Op::Value(Term::Null),
Op::Value(Term::Integer(1)),
Op::Binary(Binary::Equal),
];

let values: HashMap<u32, Term> = HashMap::new();
let operands = vec![Op::Value(Term::Null), Op::Value(Term::Integer(1))];
let operators = HashMap::from([
(Op::Binary(Binary::NotEqual), true),
(Op::Binary(Binary::HeterogeneousNotEqual), true),
(Op::Binary(Binary::Equal), false),
(Op::Binary(Binary::HeterogeneousEqual), false),
Geal marked this conversation as resolved.
Show resolved Hide resolved
]);

for (op, result) in operators {
let mut ops = operands.clone();
ops.push(op);
println!("ops: {:?}", ops);

println!("ops: {:?}", ops);

let e = Expression { ops };
println!("print: {}", e.print(&symbols).unwrap());

let res = e.evaluate(&values, &mut tmp_symbols);
assert_eq!(res, Ok(Term::Bool(false)));
let e = Expression { ops };
println!("print: {}", e.print(&symbols).unwrap());

let ops = vec![
Op::Value(Term::Null),
Op::Value(Term::Integer(1)),
Op::Binary(Binary::NotEqual),
];
let res = e.evaluate(&values, &mut tmp_symbols);
assert_eq!(res, Ok(Term::Bool(result)));
}
}

#[test]
fn equal_heterogeneous() {
let symbols = SymbolTable::new();
let mut tmp_symbols = TemporarySymbolTable::new(&symbols);
let values: HashMap<u32, Term> = HashMap::new();

println!("ops: {:?}", ops);

let e = Expression { ops };
println!("print: {}", e.print(&symbols).unwrap());

let res = e.evaluate(&values, &mut tmp_symbols);
assert_eq!(res, Ok(Term::Bool(true)));
let operands_samples = [
vec![Op::Value(Term::Bool(true)), Op::Value(Term::Integer(1))],
vec![Op::Value(Term::Bool(true)), Op::Value(Term::Str(1))],
vec![Op::Value(Term::Integer(1)), Op::Value(Term::Str(1))],
vec![
Op::Value(Term::Set(BTreeSet::from([Term::Integer(1)]))),
Op::Value(Term::Set(BTreeSet::from([Term::Str(1)]))),
],
vec![
Op::Value(Term::Bytes(Vec::new())),
Op::Value(Term::Integer(1)),
],
vec![
Op::Value(Term::Bytes(Vec::new())),
Op::Value(Term::Str(1025)),
],
vec![Op::Value(Term::Date(12)), Op::Value(Term::Integer(1))],
];
let operators = HashMap::from([
(Op::Binary(Binary::HeterogeneousNotEqual), true),
(Op::Binary(Binary::HeterogeneousEqual), false),
]);

for operands in operands_samples {
let operands_reversed: Vec<_> = operands.iter().cloned().rev().collect();
for operand in [operands, operands_reversed] {
for (op, result) in &operators {
let mut ops = operand.clone();
ops.push(op.clone());
println!("ops: {:?}", ops);

let e = Expression { ops };
println!("print: {}", e.print(&symbols).unwrap());

let res = e.evaluate(&values, &mut tmp_symbols);
assert_eq!(res, Ok(Term::Bool(*result)));
}
}
}
}
}
13 changes: 7 additions & 6 deletions biscuit-auth/src/datalog/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -965,12 +965,13 @@ pub fn contains_v4_op(expressions: &[Expression]) -> bool {

fn contains_v5_op(expressions: &[Expression]) -> bool {
expressions.iter().any(|expression| {
expression.ops.iter().any(|op| {
if let Op::Value(term) = op {
contains_v5_term(term)
} else {
false
}
expression.ops.iter().any(|op| match op {
Op::Value(term) => contains_v5_term(term),
Op::Binary(binary) => match binary {
Binary::HeterogeneousEqual | Binary::HeterogeneousNotEqual => true,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

_ => false,
},
_ => false,
})
})
}
Expand Down
8 changes: 8 additions & 0 deletions biscuit-auth/src/format/convert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -633,6 +633,8 @@ pub mod v2 {
Binary::BitwiseOr => Kind::BitwiseOr,
Binary::BitwiseXor => Kind::BitwiseXor,
Binary::NotEqual => Kind::NotEqual,
Binary::HeterogeneousEqual => Kind::HeterogeneousEqual,
Binary::HeterogeneousNotEqual => Kind::HeterogeneousNotEqual,
} as i32,
})
}
Expand Down Expand Up @@ -687,6 +689,12 @@ pub mod v2 {
Some(op_binary::Kind::BitwiseOr) => Op::Binary(Binary::BitwiseOr),
Some(op_binary::Kind::BitwiseXor) => Op::Binary(Binary::BitwiseXor),
Some(op_binary::Kind::NotEqual) => Op::Binary(Binary::NotEqual),
Some(op_binary::Kind::HeterogeneousEqual) => {
Op::Binary(Binary::HeterogeneousEqual)
}
Some(op_binary::Kind::HeterogeneousNotEqual) => {
Op::Binary(Binary::HeterogeneousNotEqual)
}
None => {
return Err(error::Format::DeserializationError(
"deserialization error: binary operation is empty".to_string(),
Expand Down
2 changes: 2 additions & 0 deletions biscuit-auth/src/format/schema.proto
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,8 @@ message OpBinary {
BitwiseOr = 18;
BitwiseXor = 19;
NotEqual = 20;
HeterogeneousEqual = 21;
HeterogeneousNotEqual = 22;
}

required Kind kind = 1;
Expand Down
2 changes: 2 additions & 0 deletions biscuit-auth/src/format/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,8 @@ pub mod op_binary {
BitwiseOr = 18,
BitwiseXor = 19,
NotEqual = 20,
HeterogeneousEqual = 21,
HeterogeneousNotEqual = 22,
}
}
#[derive(Clone, PartialEq, ::prost::Message)]
Expand Down
4 changes: 2 additions & 2 deletions biscuit-auth/src/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ mod tests {
ops: vec![
Op::Value(int(1)),
Op::Value(int(2)),
Op::Binary(Binary::Equal),
Op::Binary(Binary::HeterogeneousEqual),
],
}],
)],
Expand Down Expand Up @@ -629,7 +629,7 @@ mod tests {
ops: vec![
Op::Value(int(1)),
Op::Value(int(2)),
Op::Binary(Binary::Equal),
Op::Binary(Binary::HeterogeneousEqual),
],
}],
)],
Expand Down
2 changes: 2 additions & 0 deletions biscuit-auth/src/token/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,8 @@ impl From<biscuit_parser::builder::Binary> for Binary {
biscuit_parser::builder::Binary::BitwiseOr => Binary::BitwiseOr,
biscuit_parser::builder::Binary::BitwiseXor => Binary::BitwiseXor,
biscuit_parser::builder::Binary::NotEqual => Binary::NotEqual,
biscuit_parser::builder::Binary::HeterogeneousEqual => Binary::HeterogeneousEqual,
biscuit_parser::builder::Binary::HeterogeneousNotEqual => Binary::HeterogeneousNotEqual,
}
}
}
Expand Down
Loading