Produce rows as slices of pyarrow.Table
ptallada committed Aug 3, 2024
1 parent ac09074 commit 5efd997
Showing 5 changed files with 247 additions and 42 deletions.
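In short, the cursor now buffers results as a pyarrow.Table and returns slices of it from the fetch methods, instead of lists of tuples. A minimal usage sketch (host, port and table name are hypothetical; the new optional schema argument is left at its default):

    from pyhive import hive

    cursor = hive.connect(host='hs2.example.com', port=10000).cursor()
    cursor.execute('SELECT * FROM web_logs')  # hypothetical table

    while True:
        batch = cursor.fetchmany(10000)  # a pyarrow.Table slice, not a list of tuples
        if batch is None:                # fetchmany() now returns None when exhausted
            break
        print(batch.num_rows, batch.schema)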
25 changes: 7 additions & 18 deletions TCLIService/ttypes.py

Some generated files are not rendered by default.

6 changes: 3 additions & 3 deletions pyhive/common.py
@@ -43,12 +43,12 @@ def _reset_state(self):

         # Internal helper state
         self._state = self._STATE_NONE
-        self._data = collections.deque()
+        self._data = None
         self._columns = None

-    def _fetch_while(self, fn):
+    def _fetch_while(self, fn, schema):
         while fn():
-            self._fetch_more()
+            self._fetch_more(schema)
             if fn():
                 time.sleep(self._poll_interval)
158 changes: 137 additions & 21 deletions pyhive/hive.py
@@ -10,6 +10,11 @@

 import base64
 import datetime
+import io
+import itertools
+import numpy as np
+import pyarrow as pa
+import pyarrow.json
 import re
 from decimal import Decimal
 from ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED, create_default_context
@@ -40,7 +45,8 @@

_logger = logging.getLogger(__name__)

_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,6})?)')
_TIMESTAMP_PATTERN = re.compile(r'(\d+-\d+-\d+ \d+:\d+:\d+(\.\d{,9})?)')
_INTERVAL_DAY_TIME_PATTERN = re.compile(r'(\d+) (\d+):(\d+):(\d+(?:.\d+)?)')

ssl_cert_parameter_map = {
"none": CERT_NONE,
@@ -106,9 +112,36 @@ def _parse_timestamp(value):
         value = None
     return value

+def _parse_date(value):
+    if value:
+        format = '%Y-%m-%d'
+        value = datetime.datetime.strptime(value, format).date()
+    else:
+        value = None
+    return value

-TYPES_CONVERTER = {"DECIMAL_TYPE": Decimal,
-                   "TIMESTAMP_TYPE": _parse_timestamp}
+def _parse_interval_day_time(value):
+    if value:
+        match = _INTERVAL_DAY_TIME_PATTERN.match(value)
+        if match:
+            days = int(match.group(1))
+            hours = int(match.group(2))
+            minutes = int(match.group(3))
+            seconds = float(match.group(4))
+            value = datetime.timedelta(days=days, hours=hours, minutes=minutes, seconds=seconds)
+        else:
+            raise Exception(
+                'Cannot convert "{}" into an interval_day_time'.format(value))
+    else:
+        value = None
+    return value

+TYPES_CONVERTER = {
+    "DECIMAL_TYPE": Decimal,
+    "TIMESTAMP_TYPE": _parse_timestamp,
+    "DATE_TYPE": _parse_date,
+    "INTERVAL_DAY_TIME_TYPE": _parse_interval_day_time,
+}


 class HiveParamEscaper(common.ParamEscaper):
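For context, a quick sketch of what the two new converters produce for sample Hive string values (the inputs are made up):

    _parse_date('2024-08-03')
    # -> datetime.date(2024, 8, 3)

    _parse_interval_day_time('5 10:30:15.5')
    # -> datetime.timedelta(days=5, seconds=37815, microseconds=500000)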
@@ -488,7 +521,50 @@ def cancel(self):
         response = self._connection.client.CancelOperation(req)
         _check_status(response)

-    def _fetch_more(self):
+    def fetchone(self, schema=[]):
+        return self.fetchmany(1, schema)
+
+    def fetchall(self, schema=[]):
+        return self.fetchmany(-1, schema)
+
+    def fetchmany(self, size=None, schema=[]):
+        if size is None:
+            size = self.arraysize
+
+        if self._state == self._STATE_NONE:
+            raise exc.ProgrammingError("No query yet")
+
+        if size == -1:
+            # Fetch everything
+            self._fetch_while(lambda: self._state != self._STATE_FINISHED, schema)
+        else:
+            self._fetch_while(
+                lambda: (self._state != self._STATE_FINISHED) and
+                        (self._data is None or self._data.num_rows < size),
+                schema
+            )
+
+        if not self._data:
+            return None
+
+        if size == -1:
+            # Return everything that was fetched
+            size = self._data.num_rows
+        else:
+            size = min(size, self._data.num_rows)
+
+        self._rownumber += size
+        rows = self._data[:size]
+
+        if size == self._data.num_rows:
+            # The buffer was consumed entirely
+            self._data = None
+        else:
+            self._data = self._data[size:]
+
+        return rows
+
+    def _fetch_more(self, ext_schema):
         """Send another TFetchResultsReq and update state"""
         assert(self._state == self._STATE_RUNNING), "Should be running when in _fetch_more"
         assert(self._operationHandle is not None), "Should have an op handle in _fetch_more"
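The buffering above leans on pyarrow's zero-copy slicing: fetchmany() hands back self._data[:size] and keeps the remainder buffered. A toy illustration of the mechanism, with made-up data:

    import pyarrow as pa

    t = pa.table({'c': list(range(10))})
    head, rest = t[:4], t[4:]  # zero-copy slices over the same buffers
    assert head.num_rows == 4 and rest.num_rows == 6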
@@ -503,15 +579,21 @@ def _fetch_more(self):
         _check_status(response)
         schema = self.description
         assert not response.results.rows, 'expected data in columnar format'
-        columns = [_unwrap_column(col, col_schema[1]) for col, col_schema in
-                   zip(response.results.columns, schema)]
-        new_data = list(zip(*columns))
-        self._data += new_data
+        columns = [_unwrap_column(col, col_schema[1], e_schema) for col, col_schema, e_schema in
+                   itertools.zip_longest(response.results.columns, schema, ext_schema)]
+        names = [col[0] for col in schema]
+        new_data = pa.Table.from_batches([pa.RecordBatch.from_arrays(columns, names=names)])
         # response.hasMoreRows seems to always be False, so we instead check the number of rows
         # https://github.com/apache/hive/blob/release-1.2.1/service/src/java/org/apache/hive/service/cli/thrift/ThriftCLIService.java#L678
         # if not response.hasMoreRows:
-        if not new_data:
+        if new_data.num_rows == 0:
             self._state = self._STATE_FINISHED
+            return
+
+        if self._data is None:
+            self._data = new_data
+        else:
+            self._data = pa.concat_tables([self._data, new_data])

     def poll(self, get_progress_update=True):
         """Poll for and return the raw status data provided by the Hive Thrift REST API.
@@ -585,21 +667,55 @@ def fetch_logs(self):
 #


-def _unwrap_column(col, type_=None):
+def _unwrap_column(col, type_=None, schema=None):
     """Return a list of raw values from a TColumn instance."""
     for attr, wrapper in iteritems(col.__dict__):
         if wrapper is not None:
-            result = wrapper.values
-            nulls = wrapper.nulls  # bit set describing what's null
-            assert isinstance(nulls, bytes)
-            for i, char in enumerate(nulls):
-                byte = ord(char) if sys.version_info[0] == 2 else char
-                for b in range(8):
-                    if byte & (1 << b):
-                        result[i * 8 + b] = None
-            converter = TYPES_CONVERTER.get(type_, None)
-            if converter and type_:
-                result = [converter(row) if row else row for row in result]
+            if attr in ['boolVal', 'byteVal', 'i16Val', 'i32Val', 'i64Val', 'doubleVal']:
+                values = wrapper.values
+                # unpack the null bit set into a boolean array (LSB-first, as sent by Thrift)
+                nulls = np.unpackbits(np.frombuffer(wrapper.nulls, dtype='uint8'), bitorder='little').view(bool)
+                # build a full-length mask, as trailing False values are not sent
+                mask = np.zeros(values.shape, dtype='?')
+                end = min(len(mask), len(nulls))
+                mask[:end] = nulls[:end]
+
+                # float values are transferred as double
+                if type_ == 'FLOAT_TYPE':
+                    values = values.astype('>f4')
+
+                result = pa.array(values.byteswap().view(values.dtype.newbyteorder()), mask=mask)
+
+            else:
+                result = wrapper.values
+                nulls = wrapper.nulls  # bit set describing what's null
+                if len(result) == 0:
+                    return pa.array([])
+                assert isinstance(nulls, bytes)
+                for i, char in enumerate(nulls):
+                    byte = ord(char) if sys.version_info[0] == 2 else char
+                    for b in range(8):
+                        if byte & (1 << b):
+                            result[i * 8 + b] = None
+                converter = TYPES_CONVERTER.get(type_, None)
+                if converter and type_:
+                    result = [converter(row) if row else row for row in result]
+
+                if type_ in ['ARRAY_TYPE', 'MAP_TYPE', 'STRUCT_TYPE']:
+                    # Hive serializes complex types as pseudo-JSON strings; wrap each row
+                    # in a one-field object so Arrow's JSON reader can parse the column
+                    fd = io.BytesIO()
+                    for row in result:
+                        if row is None:
+                            row = 'null'
+                        fd.write(f'{{"c":{row}}}\n'.encode('utf8'))
+                    fd.seek(0)
+
+                    if schema is None:
+                        # NOTE: JSON map conversion (from the original struct) is not supported
+                        result = pa.json.read_json(fd, parse_options=None)[0].combine_chunks()
+                    else:
+                        sch = pa.schema([('c', schema)])
+                        opts = pa.json.ParseOptions(explicit_schema=sch)
+                        result = pa.json.read_json(fd, parse_options=opts)[0].combine_chunks()
             return result
     raise DataError("Got empty column value {}".format(col))  # pragma: no cover
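For reference, the Thrift null bit set unpacked in _unwrap_column is LSB-first within each byte, and trailing all-zero bytes may be omitted, which is why the code pads a full-length mask. A small sketch (the sample byte is made up):

    import numpy as np

    nulls = b'\x05'  # bits 0 and 2 set: rows 0 and 2 are NULL
    bits = np.unpackbits(np.frombuffer(nulls, dtype='uint8'), bitorder='little').view(bool)
    mask = np.zeros(10, dtype='?')  # suppose the column holds 10 values
    mask[:len(bits)] = bits
    # mask[0] and mask[2] are True; everything else is False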

97 changes: 97 additions & 0 deletions pyhive/schema.py
@@ -0,0 +1,97 @@
"""
This module attempts to reconstruct an Arrow schema from the info dumped at the beginning of a Hive query log.
SUPPORTS:
* All primitive types _except_ INTERVAL.
* STRUCT and ARRAY types.
* Composition of any combination of previous types.
LIMITATIONS:
* PyHive does not support INTERVAL types yet. A converter needs to be implemented.
* Hive sends complex types always as strings as something _similar_ to JSON.
* Arrow can parse most of this pseudo-JSON excluding:
* MAP and INTERVAL types
* A custom parser would be needed to implement support for all types and their composition.
"""

import pyparsing as pp
import pyarrow as pa

def a_type(s, loc, toks):
    m_basic = {
        'tinyint'   : pa.int8(),
        'smallint'  : pa.int16(),
        'int'       : pa.int32(),
        'bigint'    : pa.int64(),
        'float'     : pa.float32(),
        'double'    : pa.float64(),
        'boolean'   : pa.bool_(),
        'string'    : pa.string(),
        'char'      : pa.string(),
        'varchar'   : pa.string(),
        'binary'    : pa.binary(),
        'timestamp' : pa.timestamp('ns'),
        'date'      : pa.date32(),
        #'interval_year_month' : pa.month_day_nano_interval(),
        #'interval_day_time'   : pa.month_day_nano_interval(),
    }

    typ, args = toks[0], toks[1:]

    if typ in m_basic:
        return m_basic[typ]
    if typ == 'decimal':
        return pa.decimal128(*map(int, args))
    if typ == 'array':
        return pa.list_(args[0])
    #if typ == 'map':
    #    return pa.map_(args[0], args[1])
    if typ == 'struct':
        return pa.struct(args)
    raise NotImplementedError(f"Type {typ} is not supported")

def a_field(s, loc, toks):
    return pa.field(toks[0], toks[1])

LB, RB, LP, RP, LT, RT, COMMA, COLON = map(pp.Suppress, "[]()<>,:")

def t_args(n):
    return LP + pp.delimitedList(pp.Word(pp.nums), ",", min=n, max=n) + RP

t_basic = pp.one_of(
"tinyint smallint int bigint float double boolean string binary timestamp date decimal",
caseless=True, as_keyword=True
)
t_interval = pp.one_of(
"interval_year_month interval_day_time",
caseless=True, as_keyword=True
)
t_char = pp.one_of("char varchar", caseless=True, as_keyword=True) + t_args(1)
t_decimal = pp.CaselessKeyword("decimal") + t_args(2)
t_primitive = (t_basic ^ t_char ^ t_decimal).set_parse_action(a_type)

t_type = pp.Forward()

t_label = pp.Word(pp.alphas + "_", pp.alphanums + "_")
t_array = pp.CaselessKeyword('array') + LT + t_type + RT
t_map = pp.CaselessKeyword('map') + LT + t_primitive + COMMA + t_type + RT
t_struct = pp.CaselessKeyword('struct') + LT + pp.delimitedList((t_label + COLON + t_type).set_parse_action(a_field), ",") + RT
t_complex = (t_array ^ t_map ^ t_struct).set_parse_action(a_type)

t_type <<= t_primitive ^ t_complex
t_top_type = t_type ^ t_interval

l_schema, l_fieldschemas, l_fieldschema, l_name, l_type, l_comment, l_properties, l_null = map(
lambda x: pp.Keyword(x).suppress(), "Schema fieldSchemas FieldSchema name type comment properties null".split(' ')
)
t_fieldschema = l_fieldschema + LP + l_name + COLON + t_label.suppress() + COMMA + l_type + COLON + t_top_type + COMMA + l_comment + COLON + l_null + RP
t_schema = l_schema + LP + l_fieldschemas + COLON + LB + pp.delimitedList(t_fieldschema, ',') + RB + COMMA + l_properties + COLON + l_null + RP

def parse_schema(logs):
    prefix = 'INFO : Returning Hive schema: '

    for line in logs:
        if line.startswith(prefix):
            str_schema = line[len(prefix):]

            return t_schema.parse_string(str_schema).as_list()
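A hedged sketch of how this parser plugs into the new fetch path; the wiring below is an assumption based on the signatures in this commit, not something the diff shows end to end:

    from pyhive import hive
    from pyhive.schema import parse_schema

    cursor = hive.connect('hs2.example.com').cursor()  # hypothetical host
    cursor.execute('SELECT id, tags FROM t')           # hypothetical table with a complex column

    # Hive dumps its schema at the start of the query log; recover one Arrow type per column.
    types = parse_schema(cursor.fetch_logs())
    table = cursor.fetchall(schema=types)  # complex columns parsed against an explicit Arrow schema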
3 changes: 3 additions & 0 deletions pyproject.toml
@@ -0,0 +1,3 @@
[build-system]
requires = ["setuptools"]
build-backend = "setuptools.build_meta"
