# Patch summary (reviewer note; this text precedes the first "diff --git"
# header and is not part of any hunk):
# - autotest/ogr/ogr_parquet.py: adds two tests that set an attribute filter
#   on "fid" and read the layer both with GetNextFeature() and through
#   GetArrowStreamAsNumPy(); one uses data/parquet/test.parquet, the other a
#   /vsimem file written with "-lco FID=my_fid".
# - ograrrowlayer.hpp: when m_iFIDArrowColumn < 0, GetNextArrowArray() passes
#   BASE_SEQUENTIAL_FID (the value of m_nFeatureIdx before the batch is
#   accounted for) to PostFilterArrowArray(); when m_osFIDColumn is empty,
#   SkipToNextFeatureDueToAttributeFilter() evaluates FID constraints against
#   m_nFeatureIdx.
# - ogrlayerarrow.cpp / ogrsf_frmts.h: PostFilterArrowArray() gains a
#   papszOptions parameter, forwarded to FillValidityArrayFromAttrQuery(),
#   which uses it (or the layer FID column) to call oFeature.SetFID().
# NOTE(review): several casts and containers appear without template
# arguments (e.g. "static_cast(m_nFeatureIdx)", "std::vector aoUsedFieldsInfo");
# this looks like text lost when the patch was extracted - confirm against
# the upstream commit before applying.
# NOTE(review): line breaks were restored and checked against the hunk line
# counts; the indentation of context lines is best effort.
diff --git a/autotest/ogr/ogr_parquet.py b/autotest/ogr/ogr_parquet.py
index 7c9c1fd3d241..e11034488c9a 100755
--- a/autotest/ogr/ogr_parquet.py
+++ b/autotest/ogr/ogr_parquet.py
@@ -2123,6 +2123,92 @@ def test_ogr_parquet_arrow_stream_numpy_fast_attribute_filter(filter):
     assert fc != 0
 
 
+###############################################################################
+
+
+def test_ogr_parquet_arrow_stream_numpy_attribute_filter_on_fid_without_fid_column():
+    pytest.importorskip("osgeo.gdal_array")
+    pytest.importorskip("numpy")
+
+    ds = ogr.Open("data/parquet/test.parquet")
+    lyr = ds.GetLayer(0)
+    ignored_fields = ["decimal128", "decimal256", "time64_ns"]
+    lyr_defn = lyr.GetLayerDefn()
+    for i in range(lyr_defn.GetFieldCount()):
+        fld_defn = lyr_defn.GetFieldDefn(i)
+        if fld_defn.GetName().startswith("map_"):
+            ignored_fields.append(fld_defn.GetNameRef())
+    lyr.SetIgnoredFields(ignored_fields)
+    lyr.SetAttributeFilter("fid in (1, 3)")
+    assert lyr.TestCapability(ogr.OLCFastGetArrowStream) == 1
+
+    f = lyr.GetNextFeature()
+    assert f["uint8"] == 2
+    f = lyr.GetNextFeature()
+    assert f["uint8"] == 4
+    f = lyr.GetNextFeature()
+    assert f is None
+
+    lyr.ResetReading()
+    stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
+    vals = []
+    for batch in stream:
+        for v in batch["uint8"]:
+            vals.append(v)
+    assert vals == [2, 4]
+
+    # Check that it works if the effect of the attribute filter is to
+    # skip entire row groups.
+    lyr.SetAttributeFilter("fid = 4")
+
+    lyr.ResetReading()
+    stream = lyr.GetArrowStreamAsNumPy(options=["USE_MASKED_ARRAYS=NO"])
+    vals = []
+    for batch in stream:
+        for v in batch["uint8"]:
+            vals.append(v)
+    assert vals == [5]
+
+
+###############################################################################
+
+
+def test_ogr_parquet_arrow_stream_numpy_attribute_filter_on_fid_with_fid_column():
+    pytest.importorskip("osgeo.gdal_array")
+    pytest.importorskip("numpy")
+
+    filename = "/vsimem/test_ogr_parquet_arrow_stream_numpy_attribute_filter_on_fid_with_fid_column.parquet"
+    gdal.VectorTranslate(
+        filename, "data/poly.shp", options="-unsetFieldWidth -lco FID=my_fid"
+    )
+    ds = ogr.Open(filename)
+    lyr = ds.GetLayer(0)
+    lyr.SetAttributeFilter("fid in (1, 3)")
+    assert lyr.TestCapability(ogr.OLCFastGetArrowStream) == 1
+
+    f = lyr.GetNextFeature()
+    assert f["EAS_ID"] == 179
+    f = lyr.GetNextFeature()
+    assert f["EAS_ID"] == 173
+    f = lyr.GetNextFeature()
+    assert f is None
+
+    lyr.ResetReading()
+    stream = lyr.GetArrowStreamAsNumPy()
+    vals_fid = []
+    vals_EAS_ID = []
+    for batch in stream:
+        for v in batch["my_fid"]:
+            vals_fid.append(v)
+        for v in batch["EAS_ID"]:
+            vals_EAS_ID.append(v)
+    assert vals_fid == [1, 3]
+    assert vals_EAS_ID == [179, 173]
+
+    ds = None
+    gdal.Unlink(filename)
+
+
 ###############################################################################
 # Test attribute filter through ArrowStream API
 # We use the pyarrow API, to be able to test we correctly deal with decimal
diff --git a/ogr/ogrsf_frmts/arrow_common/ograrrowlayer.hpp b/ogr/ogrsf_frmts/arrow_common/ograrrowlayer.hpp
index 812b8d95ee1f..87f33bb93b27 100644
--- a/ogr/ogrsf_frmts/arrow_common/ograrrowlayer.hpp
+++ b/ogr/ogrsf_frmts/arrow_common/ograrrowlayer.hpp
@@ -2424,6 +2424,8 @@ inline void OGRArrowLayer::ComputeConstraintsArrayIdx()
             if (constraint.iField == m_poFeatureDefn->GetFieldCount() + SPF_FID)
             {
                 constraint.iArrayIdx = m_nRequestedFIDColumn;
+                if (constraint.iArrayIdx < 0 && m_osFIDColumn.empty())
+                    return;
             }
             else
             {
@@ -2437,8 +2439,7 @@ inline void OGRArrowLayer::ComputeConstraintsArrayIdx()
                          "it being ignored",
                          constraint.iField ==
                                  m_poFeatureDefn->GetFieldCount() + SPF_FID
-                             ? (m_osFIDColumn.empty() ? "FID"
-                                                      : m_osFIDColumn.c_str())
+                             ? m_osFIDColumn.c_str()
                              : m_poFeatureDefn->GetFieldDefn(constraint.iField)
                                    ->GetNameRef());
             }
@@ -2448,12 +2449,11 @@ inline void OGRArrowLayer::ComputeConstraintsArrayIdx()
             if (constraint.iField == m_poFeatureDefn->GetFieldCount() + SPF_FID)
             {
                 constraint.iArrayIdx = m_iFIDArrowColumn;
-                if (constraint.iArrayIdx < 0)
+                if (constraint.iArrayIdx < 0 && !m_osFIDColumn.empty())
                 {
                     CPLDebug(GetDriverUCName().c_str(),
                              "Constraint on field %s cannot be applied",
-                             m_osFIDColumn.empty() ? "FID"
-                                                   : m_osFIDColumn.c_str());
+                             m_osFIDColumn.c_str());
                 }
             }
             else
@@ -2755,10 +2755,24 @@ inline bool OGRArrowLayer::SkipToNextFeatureDueToAttributeFilter() const
     {
         if (constraint.iArrayIdx < 0)
         {
-            // can happen if ignoring a field that is needed by the
-            // attribute filter. ComputeConstraintsArrayIdx() will have
-            // warned about that
-            continue;
+            if (constraint.iField ==
+                    m_poFeatureDefn->GetFieldCount() + SPF_FID &&
+                m_osFIDColumn.empty())
+            {
+                if (!ConstraintEvaluator(constraint,
+                                         static_cast(m_nFeatureIdx)))
+                {
+                    return true;
+                }
+                continue;
+            }
+            else
+            {
+                // can happen if ignoring a field that is needed by the
+                // attribute filter. ComputeConstraintsArrayIdx() will have
+                // warned about that
+                continue;
+            }
         }
 
         const arrow::Array *array =
@@ -4014,9 +4028,19 @@ inline int OGRArrowLayer::GetNextArrowArray(struct ArrowArrayStream *stream,
 
         OverrideArrowRelease(m_poArrowDS, out_array);
 
+        const auto nFeatureIdxCur = m_nFeatureIdx;
+        m_nFeatureIdx += m_nIdxInBatch;
+
         if (m_poAttrQuery || m_poFilterGeom)
         {
-            PostFilterArrowArray(&m_sCachedSchema, out_array);
+            CPLStringList aosOptions;
+            if (m_iFIDArrowColumn < 0)
+                aosOptions.SetNameValue(
+                    "BASE_SEQUENTIAL_FID",
+                    CPLSPrintf(CPL_FRMT_GIB,
+                               static_cast(nFeatureIdxCur)));
+            PostFilterArrowArray(&m_sCachedSchema, out_array,
+                                 aosOptions.List());
             if (out_array->length == 0)
             {
                 // If there are no records after filtering, start again
@@ -4024,6 +4048,7 @@ inline int OGRArrowLayer::GetNextArrowArray(struct ArrowArrayStream *stream,
                 continue;
             }
         }
+
         break;
     }
 
diff --git a/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp b/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp
index 36c76d7ad782..fc6167379f41 100644
--- a/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp
+++ b/ogr/ogrsf_frmts/generic/ogrlayerarrow.cpp
@@ -30,6 +30,8 @@
 #include "ogr_api.h"
 #include "ogr_recordbatch.h"
 #include "ograrrowarrayhelper.h"
+#include "ogr_p.h"
+#include "ogr_swq.h"
 #include "ogr_wkb.h"
 
 #include "cpl_float.h"
@@ -2826,7 +2828,7 @@ BuildMapFieldNameToArrowPath(const struct ArrowSchema *schema,
 static size_t FillValidityArrayFromAttrQuery(
     const OGRLayer *poLayer, OGRFeatureQuery *poAttrQuery,
     const struct ArrowSchema *schema, struct ArrowArray *array,
-    std::vector &abyValidityFromFilters)
+    std::vector &abyValidityFromFilters, CSLConstList papszOptions)
 {
     size_t nCountIntersecting = 0;
     auto poFeatureDefn = const_cast(poLayer)->GetLayerDefn();
@@ -2844,6 +2846,7 @@ static size_t FillValidityArrayFromAttrQuery(
     };
 
     std::vector aoUsedFieldsInfo;
+    bool bNeedsFID = false;
     const CPLStringList aosUsedFields(poAttrQuery->GetUsedFields());
     for (int i = 0; i < aosUsedFields.size(); ++i)
     {
@@ -2865,13 +2868,139 @@ static size_t FillValidityArrayFromAttrQuery(
                          aosUsedFields[i]);
             }
         }
+        else if (EQUAL(aosUsedFields[i], "FID"))
+        {
+            bNeedsFID = true;
+        }
+        else
+        {
+            CPLDebug("OGR", "Cannot find used field %s", aosUsedFields[i]);
+        }
     }
 
     const size_t nLength = abyValidityFromFilters.size();
+
+    GIntBig nBaseSeqFID = -1;
+    std::vector anArrowPathToFIDColumn;
+    if (bNeedsFID)
+    {
+        // BASE_SEQUENTIAL_FID is set when there is no Arrow column for the FID
+        // and we assume sequential FID numbering
+        const char *pszBaseSeqFID =
+            CSLFetchNameValue(papszOptions, "BASE_SEQUENTIAL_FID");
+        if (pszBaseSeqFID)
+        {
+            nBaseSeqFID = CPLAtoGIntBig(pszBaseSeqFID);
+
+            // Optimization for "FID = constant"
+            swq_expr_node *poNode =
+                static_cast(poAttrQuery->GetSWQExpr());
+            if (poNode->eNodeType == SNT_OPERATION &&
+                poNode->nOperation == SWQ_EQ && poNode->nSubExprCount == 2 &&
+                poNode->papoSubExpr[0]->eNodeType == SNT_COLUMN &&
+                poNode->papoSubExpr[1]->eNodeType == SNT_CONSTANT &&
+                poNode->papoSubExpr[0]->field_index ==
+                    poFeatureDefn->GetFieldCount() + SPF_FID &&
+                poNode->papoSubExpr[1]->field_type == SWQ_INTEGER64)
+            {
+                if (nBaseSeqFID + static_cast(nLength) <
+                        poNode->papoSubExpr[1]->int_value ||
+                    nBaseSeqFID > poNode->papoSubExpr[1]->int_value)
+                {
+                    return 0;
+                }
+            }
+        }
+        else
+        {
+            const char *pszFIDColumn =
+                const_cast(poLayer)->GetFIDColumn();
+            if (pszFIDColumn && pszFIDColumn[0])
+            {
+                const auto oIter = oMapFieldNameToArrowPath.find(pszFIDColumn);
+                if (oIter != oMapFieldNameToArrowPath.end())
+                {
+                    anArrowPathToFIDColumn = oIter->second;
+                }
+            }
+            if (anArrowPathToFIDColumn.empty())
+            {
+                CPLError(CE_Failure, CPLE_AppDefined,
+                         "Filtering on FID requested but cannot associate a "
+                         "FID with Arrow records");
+            }
+        }
+    }
+
     for (size_t iRow = 0; iRow < nLength; ++iRow)
     {
         if (!abyValidityFromFilters[iRow])
             continue;
+
+        if (bNeedsFID)
+        {
+            if (nBaseSeqFID >= 0)
+            {
+                oFeature.SetFID(nBaseSeqFID + iRow);
+            }
+            else if (!anArrowPathToFIDColumn.empty())
+            {
+                oFeature.SetFID(OGRNullFID);
+
+                const struct ArrowSchema *psSchemaField = schema;
+                const struct ArrowArray *psArray = array;
+                bool bSkip = false;
+                for (size_t i = 0; i < anArrowPathToFIDColumn.size(); ++i)
+                {
+                    const int iChild = anArrowPathToFIDColumn[i];
+                    if (i > 0)
+                    {
+                        const uint8_t *pabyValidity =
+                            psArray->null_count == 0
+                                ? nullptr
+                                : static_cast(
+                                      const_cast(psArray->buffers[0]));
+                        const size_t nOffsettedIndex =
+                            static_cast(iRow + psArray->offset);
+                        if (pabyValidity &&
+                            !TestBit(pabyValidity, nOffsettedIndex))
+                        {
+                            bSkip = true;
+                            break;
+                        }
+                    }
+
+                    psSchemaField = psSchemaField->children[iChild];
+                    psArray = psArray->children[iChild];
+                }
+                if (bSkip)
+                    continue;
+
+                const char *format = psSchemaField->format;
+                const uint8_t *pabyValidity =
+                    psArray->null_count == 0
+                        ? nullptr
+                        : static_cast(
+                              const_cast(psArray->buffers[0]));
+                const size_t nOffsettedIndex =
+                    static_cast(iRow + psArray->offset);
+                if (pabyValidity && !TestBit(pabyValidity, nOffsettedIndex))
+                {
+                    // do nothing
+                }
+                else if (format[0] == 'i')
+                {
+                    oFeature.SetFID(static_cast(
+                        psArray->buffers[1])[nOffsettedIndex]);
+                }
+                else if (format[0] == 'l')
+                {
+                    oFeature.SetFID(static_cast(
+                        psArray->buffers[1])[nOffsettedIndex]);
+                }
+            }
+        }
+
         for (const auto &sInfo : aoUsedFieldsInfo)
         {
             const int iOGRFieldIndex = sInfo.iOGRFieldIndex;
@@ -3227,7 +3356,8 @@ static size_t FillValidityArrayFromAttrQuery(
  * Assumes that CanPostFilterArrowArray() has been called and returned true.
  */
 void OGRLayer::PostFilterArrowArray(const struct ArrowSchema *schema,
-                                    struct ArrowArray *array) const
+                                    struct ArrowArray *array,
+                                    CSLConstList papszOptions) const
 {
     if (!m_poFilterGeom && !m_poAttrQuery)
         return;
@@ -3276,7 +3406,8 @@ void OGRLayer::PostFilterArrowArray(const struct ArrowSchema *schema,
     const size_t nCountIntersecting =
         m_poAttrQuery && nCountIntersectingGeom > 0
             ? FillValidityArrayFromAttrQuery(this, m_poAttrQuery, schema, array,
-                                             abyValidityFromFilters)
+                                             abyValidityFromFilters,
+                                             papszOptions)
         : m_poFilterGeom ? nCountIntersectingGeom
                          : nLength;
     // Nothing to do ?
@@ -3286,7 +3417,8 @@ void OGRLayer::PostFilterArrowArray(const struct ArrowSchema *schema,
         return;
     }
 
-    if (!CompactStructArray(schema, array, 0, abyValidityFromFilters))
+    if (nCountIntersecting > 0 &&
+        !CompactStructArray(schema, array, 0, abyValidityFromFilters))
     {
         array->release(array);
         memset(array, 0, sizeof(*array));
diff --git a/ogr/ogrsf_frmts/ogrsf_frmts.h b/ogr/ogrsf_frmts/ogrsf_frmts.h
index f6c5107a66f7..50b29045f80b 100644
--- a/ogr/ogrsf_frmts/ogrsf_frmts.h
+++ b/ogr/ogrsf_frmts/ogrsf_frmts.h
@@ -161,7 +161,8 @@ class CPL_DLL OGRLayer : public GDALMajorObject
     virtual bool
     CanPostFilterArrowArray(const struct ArrowSchema *schema) const;
     void PostFilterArrowArray(const struct ArrowSchema *schema,
-                              struct ArrowArray *array) const;
+                              struct ArrowArray *array,
+                              CSLConstList papszOptions) const;
 
   public:
     OGRLayer();