Skip to content

Commit

Permalink
Merge pull request OSGeo#8520 from rouault/parquet_fid_filtering
Browse files Browse the repository at this point in the history
Parquet: fix SetAttributeFilter() with the special FID column
  • Loading branch information
rouault authored Oct 6, 2023
2 parents a38fb95 + 70f9fae commit 2cf3921
Show file tree
Hide file tree
Showing 6 changed files with 290 additions and 23 deletions.
86 changes: 86 additions & 0 deletions autotest/ogr/ogr_parquet.py
Original file line number Diff line number Diff line change
Expand Up @@ -2123,6 +2123,92 @@ def test_ogr_parquet_arrow_stream_numpy_fast_attribute_filter(filter):
assert fc != 0


###############################################################################


def test_ogr_parquet_arrow_stream_numpy_attribute_filter_on_fid_without_fid_column():
pytest.importorskip("osgeo.gdal_array")
pytest.importorskip("numpy")

ds = ogr.Open("data/parquet/test.parquet")
lyr = ds.GetLayer(0)
ignored_fields = ["decimal128", "decimal256", "time64_ns"]
lyr_defn = lyr.GetLayerDefn()
for i in range(lyr_defn.GetFieldCount()):
fld_defn = lyr_defn.GetFieldDefn(i)
if fld_defn.GetName().startswith("map_"):
ignored_fields.append(fld_defn.GetNameRef())
lyr.SetIgnoredFields(ignored_fields)
lyr.SetAttributeFilter("fid in (1, 3)")
assert lyr.TestCapability(ogr.OLCFastGetArrowStream) == 1

f = lyr.GetNextFeature()
assert f["uint8"] == 2
f = lyr.GetNextFeature()
assert f["uint8"] == 4
f = lyr.GetNextFeature()
assert f is None

lyr.ResetReading()
stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
vals = []
for batch in stream:
for v in batch["uint8"]:
vals.append(v)
assert vals == [2, 4]

# Check that it works if the effect of the attribute filter is to
# skip entire row groups.
lyr.SetAttributeFilter("fid = 4")

lyr.ResetReading()
stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
vals = []
for batch in stream:
for v in batch["uint8"]:
vals.append(v)
assert vals == [5]


###############################################################################


def test_ogr_parquet_arrow_stream_numpy_attribute_filter_on_fid_with_fid_column():
pytest.importorskip("osgeo.gdal_array")
pytest.importorskip("numpy")

filename = "/vsimem/test_ogr_parquet_arrow_stream_numpy_attribute_filter_on_fid_with_fid_column.parquet"
gdal.VectorTranslate(
filename, "data/poly.shp", options="-unsetFieldWidth -lco FID=my_fid"
)
ds = ogr.Open(filename)
lyr = ds.GetLayer(0)
lyr.SetAttributeFilter("fid in (1, 3)")
assert lyr.TestCapability(ogr.OLCFastGetArrowStream) == 1

f = lyr.GetNextFeature()
assert f["EAS_ID"] == 179
f = lyr.GetNextFeature()
assert f["EAS_ID"] == 173
f = lyr.GetNextFeature()
assert f is None

lyr.ResetReading()
stream = lyr.GetArrowStreamAsNumPy()
vals_fid = []
vals_EAS_ID = []
for batch in stream:
for v in batch["my_fid"]:
vals_fid.append(v)
for v in batch["EAS_ID"]:
vals_EAS_ID.append(v)
assert vals_fid == [1, 3]
assert vals_EAS_ID == [179, 173]

ds = None
gdal.Unlink(filename)


###############################################################################
# Test attribute filter through ArrowStream API
# We use the pyarrow API, to be able to test we correctly deal with decimal
Expand Down
45 changes: 35 additions & 10 deletions ogr/ogrsf_frmts/arrow_common/ograrrowlayer.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2424,6 +2424,8 @@ inline void OGRArrowLayer::ComputeConstraintsArrayIdx()
if (constraint.iField == m_poFeatureDefn->GetFieldCount() + SPF_FID)
{
constraint.iArrayIdx = m_nRequestedFIDColumn;
if (constraint.iArrayIdx < 0 && m_osFIDColumn.empty())
return;
}
else
{
Expand All @@ -2437,8 +2439,7 @@ inline void OGRArrowLayer::ComputeConstraintsArrayIdx()
"it being ignored",
constraint.iField ==
m_poFeatureDefn->GetFieldCount() + SPF_FID
? (m_osFIDColumn.empty() ? "FID"
: m_osFIDColumn.c_str())
? m_osFIDColumn.c_str()
: m_poFeatureDefn->GetFieldDefn(constraint.iField)
->GetNameRef());
}
Expand All @@ -2448,12 +2449,11 @@ inline void OGRArrowLayer::ComputeConstraintsArrayIdx()
if (constraint.iField == m_poFeatureDefn->GetFieldCount() + SPF_FID)
{
constraint.iArrayIdx = m_iFIDArrowColumn;
if (constraint.iArrayIdx < 0)
if (constraint.iArrayIdx < 0 && !m_osFIDColumn.empty())
{
CPLDebug(GetDriverUCName().c_str(),
"Constraint on field %s cannot be applied",
m_osFIDColumn.empty() ? "FID"
: m_osFIDColumn.c_str());
m_osFIDColumn.c_str());
}
}
else
Expand Down Expand Up @@ -2755,10 +2755,24 @@ inline bool OGRArrowLayer::SkipToNextFeatureDueToAttributeFilter() const
{
if (constraint.iArrayIdx < 0)
{
// can happen if ignoring a field that is needed by the
// attribute filter. ComputeConstraintsArrayIdx() will have
// warned about that
continue;
if (constraint.iField ==
m_poFeatureDefn->GetFieldCount() + SPF_FID &&
m_osFIDColumn.empty())
{
if (!ConstraintEvaluator(constraint,
static_cast<GIntBig>(m_nFeatureIdx)))
{
return true;
}
continue;
}
else
{
// can happen if ignoring a field that is needed by the
// attribute filter. ComputeConstraintsArrayIdx() will have
// warned about that
continue;
}
}

const arrow::Array *array =
Expand Down Expand Up @@ -4014,16 +4028,27 @@ inline int OGRArrowLayer::GetNextArrowArray(struct ArrowArrayStream *stream,

OverrideArrowRelease(m_poArrowDS, out_array);

const auto nFeatureIdxCur = m_nFeatureIdx;
m_nFeatureIdx += m_nIdxInBatch;

if (m_poAttrQuery || m_poFilterGeom)
{
PostFilterArrowArray(&m_sCachedSchema, out_array);
CPLStringList aosOptions;
if (m_iFIDArrowColumn < 0)
aosOptions.SetNameValue(
"BASE_SEQUENTIAL_FID",
CPLSPrintf(CPL_FRMT_GIB,
static_cast<GIntBig>(nFeatureIdxCur)));
PostFilterArrowArray(&m_sCachedSchema, out_array,
aosOptions.List());
if (out_array->length == 0)
{
// If there are no records after filtering, start again
// with a new batch
continue;
}
}

break;
}

Expand Down
140 changes: 136 additions & 4 deletions ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,8 @@
#include "ogr_api.h"
#include "ogr_recordbatch.h"
#include "ograrrowarrayhelper.h"
#include "ogr_p.h"
#include "ogr_swq.h"
#include "ogr_wkb.h"

#include "cpl_float.h"
Expand Down Expand Up @@ -2826,7 +2828,7 @@ BuildMapFieldNameToArrowPath(const struct ArrowSchema *schema,
static size_t FillValidityArrayFromAttrQuery(
const OGRLayer *poLayer, OGRFeatureQuery *poAttrQuery,
const struct ArrowSchema *schema, struct ArrowArray *array,
std::vector<bool> &abyValidityFromFilters)
std::vector<bool> &abyValidityFromFilters, CSLConstList papszOptions)
{
size_t nCountIntersecting = 0;
auto poFeatureDefn = const_cast<OGRLayer *>(poLayer)->GetLayerDefn();
Expand All @@ -2844,6 +2846,7 @@ static size_t FillValidityArrayFromAttrQuery(
};
std::vector<UsedFieldsInfo> aoUsedFieldsInfo;

bool bNeedsFID = false;
const CPLStringList aosUsedFields(poAttrQuery->GetUsedFields());
for (int i = 0; i < aosUsedFields.size(); ++i)
{
Expand All @@ -2865,13 +2868,139 @@ static size_t FillValidityArrayFromAttrQuery(
aosUsedFields[i]);
}
}
else if (EQUAL(aosUsedFields[i], "FID"))
{
bNeedsFID = true;
}
else
{
CPLDebug("OGR", "Cannot find used field %s", aosUsedFields[i]);
}
}

const size_t nLength = abyValidityFromFilters.size();

GIntBig nBaseSeqFID = -1;
std::vector<int> anArrowPathToFIDColumn;
if (bNeedsFID)
{
// BASE_SEQUENTIAL_FID is set when there is no Arrow column for the FID
// and we assume sequential FID numbering
const char *pszBaseSeqFID =
CSLFetchNameValue(papszOptions, "BASE_SEQUENTIAL_FID");
if (pszBaseSeqFID)
{
nBaseSeqFID = CPLAtoGIntBig(pszBaseSeqFID);

// Optimizimation for "FID = constant"
swq_expr_node *poNode =
static_cast<swq_expr_node *>(poAttrQuery->GetSWQExpr());
if (poNode->eNodeType == SNT_OPERATION &&
poNode->nOperation == SWQ_EQ && poNode->nSubExprCount == 2 &&
poNode->papoSubExpr[0]->eNodeType == SNT_COLUMN &&
poNode->papoSubExpr[1]->eNodeType == SNT_CONSTANT &&
poNode->papoSubExpr[0]->field_index ==
poFeatureDefn->GetFieldCount() + SPF_FID &&
poNode->papoSubExpr[1]->field_type == SWQ_INTEGER64)
{
if (nBaseSeqFID + static_cast<int64_t>(nLength) <
poNode->papoSubExpr[1]->int_value ||
nBaseSeqFID > poNode->papoSubExpr[1]->int_value)
{
return 0;
}
}
}
else
{
const char *pszFIDColumn =
const_cast<OGRLayer *>(poLayer)->GetFIDColumn();
if (pszFIDColumn && pszFIDColumn[0])
{
const auto oIter = oMapFieldNameToArrowPath.find(pszFIDColumn);
if (oIter != oMapFieldNameToArrowPath.end())
{
anArrowPathToFIDColumn = oIter->second;
}
}
if (anArrowPathToFIDColumn.empty())
{
CPLError(CE_Failure, CPLE_AppDefined,
"Filtering on FID requested but cannot associate a "
"FID with Arrow records");
}
}
}

for (size_t iRow = 0; iRow < nLength; ++iRow)
{
if (!abyValidityFromFilters[iRow])
continue;

if (bNeedsFID)
{
if (nBaseSeqFID >= 0)
{
oFeature.SetFID(nBaseSeqFID + iRow);
}
else if (!anArrowPathToFIDColumn.empty())
{
oFeature.SetFID(OGRNullFID);

const struct ArrowSchema *psSchemaField = schema;
const struct ArrowArray *psArray = array;
bool bSkip = false;
for (size_t i = 0; i < anArrowPathToFIDColumn.size(); ++i)
{
const int iChild = anArrowPathToFIDColumn[i];
if (i > 0)
{
const uint8_t *pabyValidity =
psArray->null_count == 0
? nullptr
: static_cast<uint8_t *>(
const_cast<void *>(psArray->buffers[0]));
const size_t nOffsettedIndex =
static_cast<size_t>(iRow + psArray->offset);
if (pabyValidity &&
!TestBit(pabyValidity, nOffsettedIndex))
{
bSkip = true;
break;
}
}

psSchemaField = psSchemaField->children[iChild];
psArray = psArray->children[iChild];
}
if (bSkip)
continue;

const char *format = psSchemaField->format;
const uint8_t *pabyValidity =
psArray->null_count == 0
? nullptr
: static_cast<uint8_t *>(
const_cast<void *>(psArray->buffers[0]));
const size_t nOffsettedIndex =
static_cast<size_t>(iRow + psArray->offset);
if (pabyValidity && !TestBit(pabyValidity, nOffsettedIndex))
{
// do nothing
}
else if (format[0] == 'i')
{
oFeature.SetFID(static_cast<const int32_t *>(
psArray->buffers[1])[nOffsettedIndex]);
}
else if (format[0] == 'l')
{
oFeature.SetFID(static_cast<const int64_t *>(
psArray->buffers[1])[nOffsettedIndex]);
}
}
}

for (const auto &sInfo : aoUsedFieldsInfo)
{
const int iOGRFieldIndex = sInfo.iOGRFieldIndex;
Expand Down Expand Up @@ -3227,7 +3356,8 @@ static size_t FillValidityArrayFromAttrQuery(
* Assumes that CanPostFilterArrowArray() has been called and returned true.
*/
void OGRLayer::PostFilterArrowArray(const struct ArrowSchema *schema,
struct ArrowArray *array) const
struct ArrowArray *array,
CSLConstList papszOptions) const
{
if (!m_poFilterGeom && !m_poAttrQuery)
return;
Expand Down Expand Up @@ -3276,7 +3406,8 @@ void OGRLayer::PostFilterArrowArray(const struct ArrowSchema *schema,
const size_t nCountIntersecting =
m_poAttrQuery && nCountIntersectingGeom > 0
? FillValidityArrayFromAttrQuery(this, m_poAttrQuery, schema, array,
abyValidityFromFilters)
abyValidityFromFilters,
papszOptions)
: m_poFilterGeom ? nCountIntersectingGeom
: nLength;
// Nothing to do ?
Expand All @@ -3286,7 +3417,8 @@ void OGRLayer::PostFilterArrowArray(const struct ArrowSchema *schema,
return;
}

if (!CompactStructArray(schema, array, 0, abyValidityFromFilters))
if (nCountIntersecting > 0 &&
!CompactStructArray(schema, array, 0, abyValidityFromFilters))
{
array->release(array);
memset(array, 0, sizeof(*array));
Expand Down
3 changes: 2 additions & 1 deletion ogr/ogrsf_frmts/ogrsf_frmts.h
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ class CPL_DLL OGRLayer : public GDALMajorObject
virtual bool
CanPostFilterArrowArray(const struct ArrowSchema *schema) const;
void PostFilterArrowArray(const struct ArrowSchema *schema,
struct ArrowArray *array) const;
struct ArrowArray *array,
CSLConstList papszOptions) const;

public:
OGRLayer();
Expand Down
1 change: 1 addition & 0 deletions ogr/ogrsf_frmts/parquet/ogr_parquet.h
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,7 @@ class OGRParquetLayer final : public OGRParquetLayerBase
std::vector<int> m_anMapGeomFieldIndexToParquetColumn{};
bool m_bHasMissingMappingToParquet = false;

std::vector<int64_t> m_anSelectedGroupsStartFeatureIdx{};
std::vector<int> m_anRequestedParquetColumns{}; // only valid when
// m_bIgnoredFields is set
#ifdef DEBUG
Expand Down
Loading

0 comments on commit 2cf3921

Please sign in to comment.