Skip to content

Commit

Permalink
restrict $contains operator to scalar in array
Browse files Browse the repository at this point in the history
  • Loading branch information
olirice committed Feb 27, 2024
1 parent 10ccfde commit de0aa8e
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 31 deletions.
6 changes: 3 additions & 3 deletions docs/concepts_metadata.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ Comparison operators compare a provided value with a value stored in metadata fi
| $lt | Matches values that are less than a specified value |
| $lte | Matches values that are less than or equal to a specified value |
| $in | Matches values that are contained by scalar list of specified values |
| $contains | Matches values where a specified value is contained within the metadata field value |
| $contains | Matches values where a scalar is contained within an array metadata field |


### Logical Operators
Expand Down Expand Up @@ -99,10 +99,10 @@ Those variants are most consistently able to make use of indexes.
}
```

`tags` contain "important"
`tags`, an array, contains the string "important"

```json
{
"tags": {"$contains": "important"}
}
```
```
52 changes: 29 additions & 23 deletions src/tests/test_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,8 +583,10 @@ def test_filters_contains(client: vecs.Client) -> None:
records = [
("0", [0, 0, 0, 0], {"a": 1, "b": 2}),
("1", [1, 0, 0, 0], {"a": [1, 2, 3]}),
("2", [1, 1, 0, 0], {"a": {"1": "2"}}),
("3", [0, 0, 0, 0], {"a": "1"}),
("2", [1, 1, 0, 0], {"a": {"1": "2", "x": "y"}}),
("3", [0, 0, 0, 0], {"a": ["1"]}),
("4", [1, 0, 0, 0], {"a": [4, 3, 2, 1]}),
("5", [1, 0, 0, 0], {"a": [2]}),
]

bar.upsert(records)
Expand All @@ -595,35 +597,39 @@ def test_filters_contains(client: vecs.Client) -> None:
data=[0, 0, 0, 0],
limit=3,
filters={"a": {"$contains": 1}},
) == ["0", "1"]

# Test $contains operator for list value
assert bar.query(
data=[1, 0, 0, 0],
limit=3,
filters={"a": {"$contains": [1, 2, 3]}},
) == ["1"]

# Test $contains operator for dictionary value
assert bar.query(
data=[1, 1, 0, 0],
limit=3,
filters={"a": {"$contains": {"1": "2"}}},
) == ["2"]
) == ["1", "4"]

# Test $contains operator for string value
# Test $contains operator for string value. Strings treated differently than ints
assert bar.query(
data=[0, 0, 0, 0],
limit=3,
filters={"a": {"$contains": "1"}},
) == ["3"]

# Test $contains operator for non-existent value
assert bar.query(
data=[0, 0, 0, 0],
limit=3,
filters={"a": {"$contains": 5}},
) == []
assert (
bar.query(
data=[0, 0, 0, 0],
limit=3,
filters={"a": {"$contains": 5}},
)
== []
)

# Test $contains requires a scalar value
with pytest.raises(vecs.exc.FilterError):
bar.query(
data=[1, 0, 0, 0],
limit=3,
filters={"a": {"$contains": [1, 2, 3]}},
)

with pytest.raises(vecs.exc.FilterError):
bar.query(
data=[1, 0, 0, 0],
limit=3,
filters={"a": {"$contains": {"a": 1}}},
)


def test_access_index(client: vecs.Client) -> None:
Expand Down
39 changes: 34 additions & 5 deletions src/vecs/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -860,7 +860,16 @@ def build_filters(json_col: Column, filters: Dict):
if len(value) > 1:
raise FilterError("only one operator permitted")
for operator, clause in value.items():
if operator not in ("$eq", "$ne", "$lt", "$lte", "$gt", "$gte", "$in", '$contains'):
if operator not in (
"$eq",
"$ne",
"$lt",
"$lte",
"$gt",
"$gte",
"$in",
"$contains",
):
raise FilterError("unknown operator")

# equality of singular values can take advantage of the metadata index
Expand All @@ -877,19 +886,39 @@ def build_filters(json_col: Column, filters: Dict):
for elem in clause:
if not isinstance(elem, (int, str, float)):
raise FilterError(
"argument to $in filter must be a list or scalars"
"argument to $in filter must be a list of scalars"
)

# cast the array of scalars to a postgres array of jsonb so we can
# directly compare json types in the query
contains_value = [cast(elem, postgresql.JSONB) for elem in clause]
return json_col.op("->")(key).in_(contains_value)

matches_value = cast(clause, postgresql.JSONB)

# @> in Postgres is heavily overloaded.
# By default, it will return True for
#
# scalar in array
# '[1, 2, 3]'::jsonb @> '1'::jsonb -- true#
# equality:
# '1'::jsonb @> '1'::jsonb -- true
# key value pair in object
# '{"a": 1, "b": 2}'::jsonb @> '{"a": 1}'::jsonb -- true
#
# At this time we only want to allow "scalar in array" so
# we assert that the clause is a scalar and the target metadata
# is an array
if operator == "$contains":
contains_value = cast(clause, postgresql.JSONB)
return json_col.op("->")(key).contains(contains_value)
if not isinstance(clause, (int, str, float)):
raise FilterError(
"argument to $contains filter must be a scalar"
)

matches_value = cast(clause, postgresql.JSONB)
return and_(
json_col.op("->")(key).contains(matches_value),
func.jsonb_typeof(json_col.op("->")(key)) == "array",
)

# handles non-singular values
if operator == "$eq":
Expand Down

0 comments on commit de0aa8e

Please sign in to comment.