[BUG] Use os.path.join in read_hudi only for local fs (#2336)
Closes #2295

This PR modifies the path-joining logic in `read_hudi` to use `os.path.join` only when the filesystem is a local filesystem, and to join paths manually with `/` otherwise. This fixes reading Hudi tables from S3 on a Windows machine.

Simulated a Hudi read from a Windows machine with:
```python
from unittest.mock import patch
import daft

with patch("os.path.join", side_effect=lambda *args: "\\".join(args)):
    df = daft.read_hudi("s3://daft-public-data/hudi/v6_simplekeygen_nonhivestyle/")
    df.show()
```
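
For context, a rough illustration (not part of this PR) of why `os.path.join` mangles remote paths on Windows; it uses the standard library's `ntpath` module, which is what `os.path` resolves to on Windows, and the object path is purely illustrative:

```python
import ntpath  # os.path on Windows; importable on any platform for demonstration

# Joining an S3 base path with a relative file path under Windows semantics
# inserts "\" separators, producing a key that S3 cannot resolve:
broken = ntpath.join("s3://daft-public-data/hudi/v6_simplekeygen_nonhivestyle", ".hoodie", "hoodie.properties")
print(broken)  # s3://daft-public-data/hudi/v6_simplekeygen_nonhivestyle\.hoodie\hoodie.properties

# The join_path helper added in daft/filesystem.py (see the diff below) keeps
# "/" for non-local filesystems, so the same join yields a valid S3 path.
```
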
colin-ho authored Jun 4, 2024
1 parent 9a9e52e · commit b7295da
Showing 5 changed files with 34 additions and 12 deletions.
daft/filesystem.py: 17 additions & 0 deletions

@@ -2,6 +2,7 @@
 
 import dataclasses
 import logging
+import os
 import pathlib
 import sys
 import urllib.parse
@@ -320,3 +321,19 @@ def glob_path_with_stats(
        num_rows.append(infos.get("rows"))
 
     return FileInfos.from_infos(file_paths=file_paths, file_sizes=file_sizes, num_rows=num_rows)
+
+
+###
+# Path joining
+###
+
+
+def join_path(fs: FileSystem, base_path: str, *sub_paths: str) -> str:
+    """
+    Join a base path with sub-paths using the appropriate path separator
+    for the given filesystem.
+    """
+    if isinstance(fs, LocalFileSystem):
+        return os.path.join(base_path, *sub_paths)
+    else:
+        return f"{base_path.rstrip('/')}/{'/'.join(sub_paths)}"

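
For illustration only, a rough sketch of how the new `join_path` helper above is expected to behave, assuming the `FileSystem` and `LocalFileSystem` names in its signature refer to pyarrow's classes (consistent with the `pafs` usage elsewhere in this diff); the table path and region below are made up:

```python
import pyarrow.fs as pafs

from daft.filesystem import join_path

# Remote filesystems always get "/" separators, regardless of the host OS.
s3 = pafs.S3FileSystem(anonymous=True, region="us-east-1")
assert (
    join_path(s3, "s3://daft-public-data/hudi/some_table/", ".hoodie", "hoodie.properties")
    == "s3://daft-public-data/hudi/some_table/.hoodie/hoodie.properties"
)

# Local filesystems fall back to os.path.join, so Windows keeps its native "\" separator.
local = pafs.LocalFileSystem()
print(join_path(local, "/tmp/hudi_table", ".hoodie"))  # "/tmp/hudi_table/.hoodie" on POSIX
```
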
daft/hudi/hudi_scan.py: 4 additions & 5 deletions

@@ -1,7 +1,6 @@
 from __future__ import annotations
 
 import logging
-import os
 from collections.abc import Iterator
 
 import daft
@@ -12,7 +11,7 @@
     ScanTask,
     StorageConfig,
 )
-from daft.filesystem import _resolve_paths_and_filesystem
+from daft.filesystem import _resolve_paths_and_filesystem, join_path
 from daft.hudi.pyhudi.table import HUDI_METAFIELD_PARTITION_PATH, HudiTable, HudiTableMetadata
 from daft.io.scan import PartitionField, ScanOperator
 from daft.logical.schema import Schema
@@ -23,8 +22,8 @@
 class HudiScanOperator(ScanOperator):
     def __init__(self, table_uri: str, storage_config: StorageConfig) -> None:
         super().__init__()
-        resolved_path, resolved_fs = _resolve_paths_and_filesystem(table_uri, storage_config.config.io_config)
-        self._table = HudiTable(table_uri, resolved_fs, resolved_path[0])
+        resolved_path, self._resolved_fs = _resolve_paths_and_filesystem(table_uri, storage_config.config.io_config)
+        self._table = HudiTable(table_uri, self._resolved_fs, resolved_path[0])
         self._storage_config = storage_config
         self._schema = Schema.from_pyarrow_schema(self._table.schema)
         partition_fields = set(self._table.props.partition_fields)
@@ -69,7 +68,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
             if limit_files and rows_left <= 0:
                 break
 
-            path = os.path.join(self._table.table_uri, files_metadata["path"][task_idx].as_py())
+            path = join_path(self._resolved_fs, self._table.table_uri, files_metadata["path"][task_idx].as_py())
             record_count = files_metadata["num_records"][task_idx].as_py()
             try:
                 size_bytes = files_metadata["size_bytes"][task_idx].as_py()

daft/hudi/pyhudi/table.py: 2 additions & 2 deletions

@@ -1,12 +1,12 @@
 from __future__ import annotations
 
-import os
 from collections import defaultdict
 from dataclasses import dataclass
 
 import pyarrow as pa
 import pyarrow.fs as pafs
 
+from daft.filesystem import join_path
 from daft.hudi.pyhudi.filegroup import BaseFile, FileGroup, FileSlice
 from daft.hudi.pyhudi.timeline import Timeline
 from daft.hudi.pyhudi.utils import (
@@ -92,7 +92,7 @@ def get_latest_file_slices(self) -> list[FileSlice]:
 class HudiTableProps:
     def __init__(self, fs: pafs.FileSystem, base_path: str):
         self._props = {}
-        hoodie_properties_file = os.path.join(base_path, ".hoodie", "hoodie.properties")
+        hoodie_properties_file = join_path(fs, base_path, ".hoodie", "hoodie.properties")
         with fs.open_input_file(hoodie_properties_file) as f:
             lines = f.readall().decode("utf-8").splitlines()
             for line in lines:

daft/hudi/pyhudi/timeline.py: 7 additions & 3 deletions

@@ -9,6 +9,8 @@
 import pyarrow.fs as pafs
 import pyarrow.parquet as pq
 
+from daft.filesystem import join_path
+
 
 class State(Enum):
     REQUESTED = 0
@@ -47,7 +49,7 @@ def has_completed_commit(self) -> bool:
         return len(self.completed_commit_instants) > 0
 
     def _load_completed_commit_instants(self):
-        timeline_path = os.path.join(self.base_path, ".hoodie")
+        timeline_path = join_path(self.fs, self.base_path, ".hoodie")
         write_action_exts = {".commit"}
         commit_instants = []
         for file_info in self.fs.get_file_info(pafs.FileSelector(timeline_path)):
@@ -61,7 +63,9 @@ def _load_completed_commit_instants(self):
     def get_latest_commit_metadata(self) -> dict:
         if not self.has_completed_commit:
             return {}
-        latest_instant_file_path = os.path.join(self.base_path, ".hoodie", self.completed_commit_instants[-1].file_name)
+        latest_instant_file_path = join_path(
+            self.fs, self.base_path, ".hoodie", self.completed_commit_instants[-1].file_name
+        )
         with self.fs.open_input_file(latest_instant_file_path) as f:
             return json.load(f)
 
@@ -71,6 +75,6 @@ def get_latest_commit_schema(self) -> pa.Schema:
             return pa.schema([])
 
         _, write_stats = next(iter(latest_commit_metadata["partitionToWriteStats"].items()))
-        base_file_path = os.path.join(self.base_path, write_stats[0]["path"])
+        base_file_path = join_path(self.fs, self.base_path, write_stats[0]["path"])
         with self.fs.open_input_file(base_file_path) as f:
             return pq.read_schema(f)

daft/hudi/pyhudi/utils.py: 4 additions & 2 deletions

@@ -7,14 +7,16 @@
 import pyarrow.fs as pafs
 import pyarrow.parquet as pq
 
+from daft.filesystem import join_path
+
 
 @dataclass(init=False)
 class FsFileMetadata:
     def __init__(self, fs: pafs.FileSystem, base_path: str, path: str, base_name: str):
         self.base_path = base_path
         self.path = path
         self.base_name = base_name
-        with fs.open_input_file(os.path.join(base_path, path)) as f:
+        with fs.open_input_file(join_path(fs, base_path, path)) as f:
             metadata = pq.read_metadata(f)
             self.size = metadata.serialized_size
             self.num_records = metadata.num_rows
@@ -70,7 +72,7 @@ def _extract_min_max(metadata: pq.FileMetaData):
 def list_relative_file_paths(
     base_path: str, sub_path: str, fs: pafs.FileSystem, includes: list[str] | None
 ) -> list[FsFileMetadata]:
-    listed_paths: list[pafs.FileInfo] = fs.get_file_info(pafs.FileSelector(os.path.join(base_path, sub_path)))
+    listed_paths: list[pafs.FileInfo] = fs.get_file_info(pafs.FileSelector(join_path(fs, base_path, sub_path)))
     file_paths = []
     common_prefix_len = len(base_path) + 1
     for listed_path in listed_paths:
