diff --git a/src/main/cpp/CMakeLists.txt b/src/main/cpp/CMakeLists.txt
index 1ad65687e2..d83253747b 100644
--- a/src/main/cpp/CMakeLists.txt
+++ b/src/main/cpp/CMakeLists.txt
@@ -1,5 +1,5 @@
 # =============================================================================
-# Copyright (c) 2022-2023, NVIDIA CORPORATION.
+# Copyright (c) 2022-2024, NVIDIA CORPORATION.
 #
 # Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except
 # in compliance with the License. You may obtain a copy of the License at
@@ -165,6 +165,7 @@ add_library(
   src/cast_float_to_string.cu
   src/cast_string.cu
   src/cast_string_to_float.cu
+  src/datetime_parser.cu
   src/datetime_rebase.cu
   src/decimal_utils.cu
   src/histogram.cu
diff --git a/src/main/cpp/src/CastStringJni.cpp b/src/main/cpp/src/CastStringJni.cpp
index b7d898a0c8..1d39e7152e 100644
--- a/src/main/cpp/src/CastStringJni.cpp
+++ b/src/main/cpp/src/CastStringJni.cpp
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@
 #include

 #include "cudf_jni_apis.hpp"
+#include "datetime_parser.hpp"
 #include "dtype_utils.hpp"
 #include "jni_utils.hpp"

@@ -255,4 +256,57 @@ JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_fromInteger
   }
   CATCH_CAST_EXCEPTION(env, 0);
 }
+
+JNIEXPORT jlong JNICALL
+Java_com_nvidia_spark_rapids_jni_CastStrings_toTimestamp(JNIEnv* env,
+                                                         jclass,
+                                                         jlong input_column,
+                                                         jlong transitions_handle,
+                                                         jlong tz_indices_col,
+                                                         jint tz_default_index,
+                                                         jboolean ansi_enabled)
+{
+  JNI_NULL_CHECK(env, input_column, "input column is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+
+    auto const& input_view =
+      cudf::strings_column_view(*reinterpret_cast<cudf::column_view const*>(input_column));
+    auto const transitions =
+      reinterpret_cast<cudf::table_view const*>(transitions_handle)->column(0);
+    const cudf::column_view* tz_indices_view =
+      reinterpret_cast<cudf::column_view const*>(tz_indices_col);
+    auto const tz_index = static_cast<cudf::size_type>(tz_default_index);
+    auto ret_cv = spark_rapids_jni::string_to_timestamp_with_tz(
+      input_view, transitions, *tz_indices_view, tz_index, ansi_enabled);
+    if (ret_cv) { return cudf::jni::release_as_jlong(ret_cv); }
+  }
+  CATCH_STD(env, 0);
+
+  // success is false, so throw an exception.
+  // Note: there is no need to release ret_cv, because it's nullptr when success
+  // is false.
+  JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Parse failed on Ansi mode", 0);
+}
+
+JNIEXPORT jlong JNICALL Java_com_nvidia_spark_rapids_jni_CastStrings_toTimestampWithoutTimeZone(
+  JNIEnv* env, jclass, jlong input_column, jboolean allow_time_zone, jboolean ansi_enabled)
+{
+  JNI_NULL_CHECK(env, input_column, "input column is null", 0);
+  try {
+    cudf::jni::auto_set_device(env);
+    auto const& input_view =
+      cudf::strings_column_view(*reinterpret_cast<cudf::column_view const*>(input_column));
+
+    auto ret_cv =
+      spark_rapids_jni::string_to_timestamp_without_tz(input_view, allow_time_zone, ansi_enabled);
+    if (ret_cv) { return cudf::jni::release_as_jlong(ret_cv); }
+  }
+  CATCH_STD(env, 0);
+
+  // success is false, so throw an exception.
+  // Note: there is no need to release ret_cv, because it's nullptr when success
+  // is false.
+  JNI_THROW_NEW(env, "java/lang/IllegalArgumentException", "Parse failed on Ansi mode", 0);
+}
 }
diff --git a/src/main/cpp/src/datetime_parser.cu b/src/main/cpp/src/datetime_parser.cu
new file mode 100644
index 0000000000..505f9821ad
--- /dev/null
+++ b/src/main/cpp/src/datetime_parser.cu
@@ -0,0 +1,672 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "datetime_parser.hpp"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+#include
+#include
+#include
+#include
+#include
+
+#include
+
+#include
+#include
+
+namespace spark_rapids_jni {
+namespace {
+
+/**
+ * Represents a local date time in a time zone.
+ */
+struct timestamp_components {
+  /**
+   * year: Max 6 digits.
+   * Spark stores a timestamp into a Long in microseconds.
+   * A Long is able to represent a timestamp within [+-]200 thousand years.
+   * Calculated from: Long.MaxValue/MinValue / microseconds_per_year
+   */
+  int32_t year;
+  int8_t month;
+  int8_t day;
+  int8_t hour;
+  int8_t minute;
+  int8_t second;
+  int32_t microseconds;
+};
+
+/**
+ * Whether a character is whitespace.
+ */
+__device__ __host__ inline bool is_whitespace(const char chr)
+{
+  switch (chr) {
+    case ' ':
+    case '\r':
+    case '\t':
+    case '\n': return true;
+    default: return false;
+  }
+}
+
+/**
+ * Ported from Spark.
+ */
+__device__ __host__ bool is_valid_digits(int segment, int digits)
+{
+  // A Long is able to represent a timestamp within [+-]200 thousand years
+  constexpr int maxDigitsYear = 6;
+  // For the nanosecond part, more than 6 digits is allowed, but will be
+  // truncated.
+  return segment == 6 || (segment == 0 && digits >= 4 && digits <= maxDigitsYear) ||
+         // For the zoneId segment(7), it could be zero digits when it's a
+         // region-based zone ID
+         (segment == 7 && digits <= 2) ||
+         (segment != 0 && segment != 6 && segment != 7 && digits > 0 && digits <= 2);
+}
+
+/**
+ * Functor to get a string from a string column device view.
+ */
+struct get_string_fn {
+  cudf::column_device_view const& string_view;
+
+  __device__ cudf::string_view operator()(size_t idx)
+  {
+    return string_view.element<cudf::string_view>(idx);
+  }
+};
+
+/**
+ * We have to distinguish INVALID values from UNSUPPORTED values.
+ * INVALID means the value is invalid in Spark SQL.
+ * UNSUPPORTED means the value is valid in Spark SQL but not yet supported by
+ * the RAPIDS plugin. INVALID values are treated the same way as in Spark SQL;
+ * for UNSUPPORTED values, we throw a cuDF exception.
+ */
+enum ParseResult { OK = 0, INVALID = 1, UNSUPPORTED = 2 };
+
+template <bool with_timezone>
+struct parse_timestamp_string_fn {
+  // the three members below are required:
+  cudf::column_device_view const d_strings;
+  cudf::size_type const default_tz_index;
+  bool const allow_tz_in_date_str;
+
+  // the two members below are optional:
+  // The list column of transitions to figure out the correct offset
+  // to adjust the timestamp. The type of the values in this column is
+  // LIST<STRUCT<int64, int64, int32, int64>>.
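+  // An illustrative (hypothetical) row layout: a fixed-offset zone such as
+  // UTC+08:00 would carry a single entry (INT64_MIN, INT64_MIN, 28800,
+  // INT64_MIN), i.e. sentinel instants plus an offset of 8 * 3600 seconds,
+  // while zones with historical transitions carry one entry per transition.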
+  thrust::optional<cudf::detail::lists_column_device_view const> transitions = thrust::nullopt;
+  thrust::optional<cudf::column_device_view const> tz_indices = thrust::nullopt;
+
+  __device__ thrust::tuple<cudf::timestamp_us, uint8_t> operator()(const cudf::size_type& idx) const
+  {
+    // inherit the null mask of the input column
+    if (!d_strings.is_valid(idx)) {
+      return thrust::make_tuple(cudf::timestamp_us{cudf::duration_us{0}}, ParseResult::INVALID);
+    }
+
+    auto const d_str     = d_strings.element<cudf::string_view>(idx);
+    auto parse_ret_tuple = parse_string_to_timestamp_us(d_str);
+    auto ts_comp         = thrust::get<0>(parse_ret_tuple);
+    auto tz_lit_ptr      = thrust::get<1>(parse_ret_tuple);
+    auto tz_lit_len      = thrust::get<2>(parse_ret_tuple);
+    auto result          = thrust::get<3>(parse_ret_tuple);
+
+    switch (result) {
+      case ParseResult::INVALID:
+        return thrust::make_tuple(cudf::timestamp_us{cudf::duration_us{0}}, ParseResult::INVALID);
+      case ParseResult::UNSUPPORTED:
+        return thrust::make_tuple(cudf::timestamp_us{cudf::duration_us{0}},
+                                  ParseResult::UNSUPPORTED);
+      case ParseResult::OK: break;
+    }
+
+    if constexpr (!with_timezone) {
+      // path without timezone, in which unix_timestamp is straightforwardly
+      // computed
+      auto const ts_unaligned = compute_epoch_us(ts_comp);
+      return thrust::make_tuple(cudf::timestamp_us{cudf::duration_us{ts_unaligned}},
+                                ParseResult::OK);
+    }
+
+    // path with timezone, in which the timezone offset has to be determined
+    // before computing unix_timestamp
+    int64_t utc_offset;
+    if (tz_lit_ptr == nullptr) {
+      // no trailing tz in the string; use the default tz
+      utc_offset = compute_utc_offset(compute_loose_epoch_s(ts_comp), default_tz_index);
+    } else {
+      auto const tz_view = cudf::string_view(tz_lit_ptr, tz_lit_len);
+      // First, try parsing as a UTC-like timezone rep
+      auto [fix_offset, ret_code] = parse_utc_like_tz(tz_view);
+      if (ret_code == ParseUtcLikeTzResult::UTC_LIKE_TZ) {
+        utc_offset = fix_offset;
+      } else if (ret_code == ParseUtcLikeTzResult::NOT_UTC_LIKE_TZ) {
+        // Then, try parsing as a region-based timezone ID
+        auto const tz_index = query_index_from_tz_db(tz_view);
+        if (tz_index < 0) {
+          // TODO: distinguish unsupported and invalid tz
+          return thrust::make_tuple(cudf::timestamp_us{cudf::duration_us{0}}, ParseResult::INVALID);
+        } else {
+          // supported tz
+          utc_offset = compute_utc_offset(compute_loose_epoch_s(ts_comp), tz_index);
+        }
+      } else {
+        // (ret_code == ParseUtcLikeTzResult::INVALID) quick path to mark the value invalid
+        return thrust::make_tuple(cudf::timestamp_us{cudf::duration_us{0}}, ParseResult::INVALID);
+      }
+    }
+
+    // Compute the epoch as the UTC timezone, then apply the timezone offset.
+    auto const ts_unaligned = compute_epoch_us(ts_comp);
+
+    return thrust::make_tuple(
+      cudf::timestamp_us{cudf::duration_us{ts_unaligned - utc_offset * 1000000L}}, ParseResult::OK);
+  }
+
+  enum ParseUtcLikeTzResult {
+    UTC_LIKE_TZ     = 0,  // successfully parsed the timezone offset
+    NOT_UTC_LIKE_TZ = 1,  // not a valid UTC-like timezone representation, maybe valid region-based
+    INVALID         = 2   // not a valid timezone representation
+  };
+
+  /**
+   * Parse a UTC-like timezone representation such as: UTC+11:22:33, GMT-8:08:01.
+   * This function is intended to align fully with Apache Spark's behavior. It
+   * returns the parsed offset along with a ParseUtcLikeTzResult status.
+   *
+   * with colon
+   *    hh:mm    : ^(GMT|UTC)?[+-](\d|0[0-9]|1[0-8]):(\d|[0-5][0-9])
+   *    hh:mm:ss : ^(GMT|UTC)?[+-](\d|0[0-9]|1[0-8]):[0-5][0-9]:[0-5][0-9]
+   * without colon
+   *    hh only    : ^(GMT|UTC)?[+-](\d|0[0-9]|1[0-8])
+   *    hh:mm:(ss) : ^(GMT|UTC)?[+-](0[0-9]|1[0-8])([0-5][0-9])?([0-5][0-9])?
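+   *
+   * Examples (illustrative, derived from the patterns above):
+   *   "GMT"          -> offset 0
+   *   "UTC+8"        -> offset +28800 seconds
+   *   "GMT-1:30"     -> offset -5400 seconds
+   *   "UTC+11:22:33" -> offset +40953 seconds
+   *   "UTC+19"       -> INVALID (see the restriction below)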
+   *
+   * additional restriction: 18:00:00 is the upper bound (which means 18:00:01
+   * is invalid)
+   */
+  __device__ inline thrust::pair<int64_t, ParseUtcLikeTzResult> parse_utc_like_tz(
+    cudf::string_view const& tz_lit) const
+  {
+    cudf::size_type const len = tz_lit.size_bytes();
+
+    char const* ptr = tz_lit.data();
+
+    size_t char_offset = 0;
+    // skip UTC|GMT if present
+    if (len > 2 && ((*ptr == 'G' && *(ptr + 1) == 'M' && *(ptr + 2) == 'T') ||
+                    (*ptr == 'U' && *(ptr + 1) == 'T' && *(ptr + 2) == 'C'))) {
+      char_offset = 3;
+    }
+
+    // return for the pattern UTC|GMT (without an exact offset)
+    if (len == char_offset) return {0, ParseUtcLikeTzResult::UTC_LIKE_TZ};
+
+    // parse sign +|-
+    char const sign_char = *(ptr + char_offset++);
+    int64_t sign;
+    if (sign_char == '+') {
+      sign = 1L;
+    } else if (sign_char == '-') {
+      sign = -1L;
+    } else {
+      // if the rep starts with UTC|GMT, it can NOT be a region-based rep
+      return {
+        0, char_offset < 3 ? ParseUtcLikeTzResult::NOT_UTC_LIKE_TZ : ParseUtcLikeTzResult::INVALID};
+    }
+
+    // parse hh:mm:ss
+    int64_t hms[3] = {0L, 0L, 0L};
+    bool has_colon = false;
+    for (cudf::size_type i = 0; i < 3; i++) {
+      // deal with the first digit
+      hms[i] = *(ptr + char_offset++) - '0';
+      if (hms[i] < 0 || hms[i] > 9) return {0, ParseUtcLikeTzResult::INVALID};
+
+      // deal with a trailing single-digit value:
+      //   hh (GMT+8)       - valid
+      //   mm (GMT+11:2)    - must be separated from (h)h by `:`
+      //   ss (GMT-11:22:3) - invalid
+      if (len == char_offset) {
+        if (i == 2 || (i == 1 && !has_colon)) return {0, ParseUtcLikeTzResult::INVALID};
+        break;
+      }
+
+      // deal with `:`
+      if (*(ptr + char_offset) == ':') {
+        // 1. (i == 1) one-digit mm followed by ss is invalid (+11:2:3)
+        // 2. (i == 2) one-digit ss is invalid (+11:22:3)
+        // 3. a trailing `:` is invalid (GMT+8:)
+        if (i > 0 || len == ++char_offset) return {0, ParseUtcLikeTzResult::INVALID};
+        has_colon = true;
+        continue;
+      }
+
+      // deal with the second digit
+      auto const digit = *(ptr + char_offset++) - '0';
+      if (digit < 0 || digit > 9) return {0, ParseUtcLikeTzResult::INVALID};
+      hms[i] = hms[i] * 10 + digit;
+
+      if (len == char_offset) break;
+      // deal with `:`
+      if (*(ptr + char_offset) == ':') {
+        // a trailing `:` is invalid (UTC+11:)
+        if (len == ++char_offset) return {0, ParseUtcLikeTzResult::INVALID};
+        has_colon = true;
+      }
+    }
+
+    // the upper bound is 18:00:00 (regardless of sign)
+    if (hms[0] > 18 || hms[1] > 59 || hms[2] > 59) return {0, ParseUtcLikeTzResult::INVALID};
+    if (hms[0] == 18 && hms[1] + hms[2] > 0) return {0, ParseUtcLikeTzResult::INVALID};
+
+    return {sign * (hms[0] * 3600L + hms[1] * 60L + hms[2]), ParseUtcLikeTzResult::UTC_LIKE_TZ};
+  }
+
+  /**
+   * Use binary search to find the tz index in the tz DB.
+   */
+  __device__ inline int query_index_from_tz_db(cudf::string_view const& tz_lit) const
+  {
+    auto const tz_col                  = tz_indices->child(0);
+    auto const index_in_transition_col = tz_indices->child(1);
+
+    auto const string_iter_begin =
+      thrust::make_transform_iterator(thrust::make_counting_iterator(0), get_string_fn{tz_col});
+    auto const string_iter_end = string_iter_begin + tz_col.size();
+    auto const it              = thrust::lower_bound(
+      thrust::seq, string_iter_begin, string_iter_end, tz_lit, thrust::less<cudf::string_view>());
+    if (it != string_iter_end && *it == tz_lit) {
+      // found the tz
+      auto const tz_name_index =
+        static_cast<cudf::size_type>(thrust::distance(string_iter_begin, it));
+      return static_cast<int>(index_in_transition_col.element<int32_t>(tz_name_index));
+    } else {
+      // did not find the tz
+      return -1;
+    }
+  }
+
+  /**
+   * Perform a binary search to find the offset from UTC based on local epoch
+   * instants. Basically, this is the same approach as
+   * `convert_timestamp_tz_functor`.
+   */
+  __device__ inline int64_t compute_utc_offset(int64_t const loose_epoch_second,
+                                               cudf::size_type const tz_index) const
+  {
+    auto const& utc_offsets    = transitions->child().child(2);
+    auto const& loose_instants = transitions->child().child(3);
+
+    auto const local_transitions = cudf::list_device_view{*transitions, tz_index};
+    auto const list_size         = local_transitions.size();
+
+    auto const transition_times = cudf::device_span<int64_t const>(
+      loose_instants.data<int64_t>() + local_transitions.element_offset(0),
+      static_cast<size_t>(list_size));
+
+    auto const it = thrust::upper_bound(
+      thrust::seq, transition_times.begin(), transition_times.end(), loose_epoch_second);
+    auto const idx = static_cast<cudf::size_type>(thrust::distance(transition_times.begin(), it));
+    auto const list_offset = local_transitions.element_offset(idx - 1);
+    return static_cast<int64_t>(utc_offsets.element<int32_t>(list_offset));
+  }
+
+  /**
+   * The formula to compute a loose epoch from a local time. The loose epoch is
+   * used to search for the corresponding timezone offset of a specific zone ID
+   * in the TimezoneDB. The goal of the loose epoch is to map a local time to a
+   * number that is proportional to the real timestamp, as cheaply as possible.
+   * The loose epoch, as a computation approach, lets us align the probe (kernel
+   * side) with the TimezoneDB (Java side). We can then run a binary search over
+   * the loose epoch instants of the TimezoneDB to find the correct timezone
+   * offset.
+   *
+   * The loose epoch column is used for the binary search.
+   * Here we use 400 days per year and 31 days per month; this is safe because
+   * the mapping from local time to loose epoch is monotonic.
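+   *
+   * For example (illustrative): local time 2020-03-15 12:00:00 maps to
+   * (2020 * 400 + (3 - 1) * 31 + 15 - 1) * 86400 + 12 * 3600
+   * = 808076 * 86400 + 43200, which is not the real epoch second but orders
+   * local times consistently with it.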
+   */
+  __device__ inline int64_t compute_loose_epoch_s(timestamp_components const& ts) const
+  {
+    return (ts.year * 400 + (ts.month - 1) * 31 + ts.day - 1) * 86400L + ts.hour * 3600L +
+           ts.minute * 60L + ts.second;
+  }
+
+  /**
+   * Leverage the cuda::std::chrono library to convert a local time to a UTC
+   * timestamp (in microseconds).
+   */
+  __device__ inline int64_t compute_epoch_us(timestamp_components const& ts) const
+  {
+    auto const ymd =  // the chrono class handles the leap year calculations for us
+      cuda::std::chrono::year_month_day(cuda::std::chrono::year{ts.year},
+                                        cuda::std::chrono::month{static_cast<uint32_t>(ts.month)},
+                                        cuda::std::chrono::day{static_cast<uint32_t>(ts.day)});
+    auto const days = cuda::std::chrono::sys_days(ymd).time_since_epoch().count();
+
+    int64_t const timestamp_s =
+      (days * 24L * 3600L) + (ts.hour * 3600L) + (ts.minute * 60L) + ts.second;
+    return timestamp_s * 1000000L + ts.microseconds;
+  }
+
+  /**
+   * Ported from Spark:
+   *   https://github.com/apache/spark/blob/v3.5.0/sql/api/src/main/scala/
+   *   org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala#L394
+   *
+   * Parse a string with a time zone to a timestamp.
+   * The ParseResult in the returned tuple is INVALID if the parse failed.
+   */
+  __device__ inline thrust::tuple<timestamp_components, char const*, cudf::size_type, ParseResult>
+  parse_string_to_timestamp_us(cudf::string_view const& timestamp_str) const
+  {
+    timestamp_components ts_comp{};
+    char const* parsed_tz_ptr        = nullptr;
+    cudf::size_type parsed_tz_length = -1;
+    auto invalid_ret =
+      thrust::make_tuple(ts_comp, parsed_tz_ptr, parsed_tz_length, ParseResult::INVALID);
+
+    const char* curr_ptr = timestamp_str.data();
+    const char* end_ptr  = curr_ptr + timestamp_str.size_bytes();
+
+    // trim left
+    while (curr_ptr < end_ptr && is_whitespace(*curr_ptr)) {
+      ++curr_ptr;
+    }
+    // trim right
+    while (curr_ptr < end_ptr - 1 && is_whitespace(*(end_ptr - 1))) {
+      --end_ptr;
+    }
+
+    if (curr_ptr == end_ptr) { return invalid_ret; }
+
+    const char* const bytes            = curr_ptr;
+    const cudf::size_type bytes_length = end_ptr - curr_ptr;
+
+    // segments stores: [year, month, day, hour, minute, seconds, microseconds,
+    // unused, unused]; the two tail items are unused, but they are kept here
+    // because Spark keeps them
+    int segments[]             = {1, 1, 1, 0, 0, 0, 0, 0, 0};
+    int segments_len           = 9;
+    int i                      = 0;
+    int current_segment_value  = 0;
+    int current_segment_digits = 0;
+    size_t j                   = 0;
+    int digits_milli           = 0;
+    thrust::optional<int> year_sign;
+    if ('-' == bytes[j] || '+' == bytes[j]) {
+      if ('-' == bytes[j]) {
+        year_sign = -1;
+      } else {
+        year_sign = 1;
+      }
+      j += 1;
+    }
+
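+    // Illustrative walk-through: "2023-11-05 3:04:55.123456" fills segments as
+    // [2023, 11, 5, 3, 4, 55, 123456, 0, 0]; a trailing literal such as
+    // "Asia/Shanghai" would instead be handed back via parsed_tz_ptr.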
+    while (j < bytes_length) {
+      char const b           = bytes[j];
+      int const parsed_value = static_cast<int32_t>(b - '0');
+      if (parsed_value < 0 || parsed_value > 9) {
+        if (0 == j && 'T' == b) {
+          i += 3;
+        } else if (i < 2) {
+          if (b == '-') {
+            if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; }
+            segments[i]            = current_segment_value;
+            current_segment_value  = 0;
+            current_segment_digits = 0;
+            i += 1;
+          } else if (0 == i && ':' == b && !year_sign.has_value()) {
+            if (!is_valid_digits(3, current_segment_digits)) { return invalid_ret; }
+            segments[3]            = current_segment_value;
+            current_segment_value  = 0;
+            current_segment_digits = 0;
+            i = 4;
+          } else {
+            return invalid_ret;
+          }
+        } else if (2 == i) {
+          if (' ' == b || 'T' == b) {
+            if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; }
+            segments[i]            = current_segment_value;
+            current_segment_value  = 0;
+            current_segment_digits = 0;
+            i += 1;
+          } else {
+            return invalid_ret;
+          }
+        } else if (3 == i || 4 == i) {
+          if (':' == b) {
+            if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; }
+            segments[i]            = current_segment_value;
+            current_segment_value  = 0;
+            current_segment_digits = 0;
+            i += 1;
+          } else {
+            return invalid_ret;
+          }
+        } else if (5 == i || 6 == i) {
+          if ('.' == b && 5 == i) {
+            if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; }
+            segments[i]            = current_segment_value;
+            current_segment_value  = 0;
+            current_segment_digits = 0;
+            i += 1;
+          } else {
+            if (!is_valid_digits(i, current_segment_digits) || !allow_tz_in_date_str) {
+              return invalid_ret;
+            }
+            segments[i]            = current_segment_value;
+            current_segment_value  = 0;
+            current_segment_digits = 0;
+            i += 1;
+            parsed_tz_ptr = bytes + j;
+            // strip the whitespace between the timestamp and the timezone
+            while (parsed_tz_ptr < end_ptr && is_whitespace(*parsed_tz_ptr))
+              ++parsed_tz_ptr;
+            parsed_tz_length = end_ptr - parsed_tz_ptr;
+            break;
+          }
+          if (i == 6 && '.' != b) { i += 1; }
+        } else {
+          if (i < segments_len && (':' == b || ' ' == b)) {
+            if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; }
+            segments[i]            = current_segment_value;
+            current_segment_value  = 0;
+            current_segment_digits = 0;
+            i += 1;
+          } else {
+            return invalid_ret;
+          }
+        }
+      } else {
+        if (6 == i) { digits_milli += 1; }
+        // We will truncate the nanosecond part if there are more than 6 digits,
+        // which results in loss of precision
+        if (6 != i || current_segment_digits < 6) {
+          current_segment_value = current_segment_value * 10 + parsed_value;
+        }
+        current_segment_digits += 1;
+      }
+      j += 1;
+    }
+
+    if (!is_valid_digits(i, current_segment_digits)) { return invalid_ret; }
+    segments[i] = current_segment_value;
+
+    while (digits_milli < 6) {
+      segments[6] *= 10;
+      digits_milli += 1;
+    }
+
+    segments[0] *= year_sign.value_or(1);
+    // the code above is ported from Spark.
+
+    // Copy segments into the equivalent kernel timestamp_components.
+    // Note: to keep the code above equivalent to the Spark implementation, it
+    // does not store the values into `timestamp_components` directly.
+    ts_comp.year         = segments[0];
+    ts_comp.month        = static_cast<int8_t>(segments[1]);
+    ts_comp.day          = static_cast<int8_t>(segments[2]);
+    ts_comp.hour         = static_cast<int8_t>(segments[3]);
+    ts_comp.minute       = static_cast<int8_t>(segments[4]);
+    ts_comp.second       = static_cast<int8_t>(segments[5]);
+    ts_comp.microseconds = segments[6];
+
+    return thrust::make_tuple(ts_comp, parsed_tz_ptr, parsed_tz_length, ParseResult::OK);
+  }
+};
+
+/**
+ * The common entry point of string_to_timestamp; two paths call this function:
+ *   - `string_to_timestamp_with_tz`    : with time zone
+ *   - `string_to_timestamp_without_tz` : without time zone
+ * The parameters transitions, tz_indices and default_tz_index are only for
+ * handling inputs with a timezone.
+ * It's called from `string_to_timestamp_without_tz` if transitions and
+ * tz_indices are nullptr; otherwise it's called from
+ * `string_to_timestamp_with_tz`.
+ */
+std::unique_ptr<cudf::column> to_timestamp(
+  cudf::strings_column_view const& input,
+  bool const ansi_mode,
+  bool const allow_tz_in_date_str,
+  cudf::size_type const default_tz_index         = -1,
+  cudf::column_view const* transitions           = nullptr,
+  cudf::column_view const* tz_indices            = nullptr,
+  rmm::cuda_stream_view stream                   = cudf::get_default_stream(),
+  rmm::mr::device_memory_resource* mr            = rmm::mr::get_current_device_resource())
+{
+  auto const d_strings = cudf::column_device_view::create(input.parent(), stream);
+  // column to store the result timestamp
+  auto result_col =
+    cudf::make_timestamp_column(cudf::data_type{cudf::type_id::TIMESTAMP_MICROSECONDS},
+                                input.size(),
+                                cudf::mask_state::UNALLOCATED,
+                                stream,
+                                mr);
+  // column to store the status `ParseResult`
+  auto result_valid_col = cudf::make_fixed_width_column(
+    cudf::data_type{cudf::type_id::UINT8}, input.size(), cudf::mask_state::UNALLOCATED, stream, mr);
+
+  if (transitions == nullptr || tz_indices == nullptr) {
+    thrust::transform(
+      rmm::exec_policy(stream),
+      thrust::make_counting_iterator(0),
+      thrust::make_counting_iterator(input.size()),
+      thrust::make_zip_iterator(
+        thrust::make_tuple(result_col->mutable_view().begin<cudf::timestamp_us>(),
+                           result_valid_col->mutable_view().begin<uint8_t>())),
+      parse_timestamp_string_fn<false>{*d_strings, default_tz_index, allow_tz_in_date_str});
+  } else {
+    auto const ft_cdv_ptr    = cudf::column_device_view::create(*transitions, stream);
+    auto const d_transitions = cudf::detail::lists_column_device_view{*ft_cdv_ptr};
+    auto const d_tz_indices  = cudf::column_device_view::create(*tz_indices, stream);
+
+    thrust::transform(rmm::exec_policy(stream),
+                      thrust::make_counting_iterator(0),
+                      thrust::make_counting_iterator(input.size()),
+                      thrust::make_zip_iterator(
+                        thrust::make_tuple(result_col->mutable_view().begin<cudf::timestamp_us>(),
+                                           result_valid_col->mutable_view().begin<uint8_t>())),
+                      parse_timestamp_string_fn<true>{
+                        *d_strings, default_tz_index, true, d_transitions, *d_tz_indices});
+  }
+
+  auto valid_view = result_valid_col->mutable_view();
+
+  // throw a cuDF exception if any unsupported format exists
+  auto exception_exists =
+    thrust::any_of(rmm::exec_policy(stream),
+                   valid_view.begin<uint8_t>(),
+                   valid_view.end<uint8_t>(),
+                   [] __device__(uint8_t e) { return e == ParseResult::UNSUPPORTED; });
+  if (exception_exists) { CUDF_FAIL("Encountered an unsupported timestamp format!"); }
+
+  // build the updated null mask and compute the null count
+  auto [valid_bitmask, valid_null_count] = cudf::detail::valid_if(
+    valid_view.begin<uint8_t>(),
+    valid_view.end<uint8_t>(),
+    [] __device__(uint8_t e) { return e == 0; },
+    stream,
+    mr);
+
+  // `output null count > input null count` indicates that new null values were
+  // generated during the `to_timestamp` conversion to replace invalid inputs.
+  if (ansi_mode && input.null_count() < valid_null_count) { return nullptr; }
+
+  result_col->set_null_mask(valid_bitmask, valid_null_count, stream);
+  return result_col;
+}
+
+}  // anonymous namespace
+
+/**
+ * Parse a string column with time zone to a timestamp column.
+ * If a string does not have a time zone in it, use the default time zone.
+ *
+ * Returns nullptr if ANSI mode is true and the strings have invalid data;
+ * otherwise, returns a non-null timestamp column (invalid entries are null in
+ * this column).
+ */
+std::unique_ptr<cudf::column> string_to_timestamp_with_tz(cudf::strings_column_view const& input,
+                                                          cudf::column_view const& transitions,
+                                                          cudf::column_view const& tz_indices,
+                                                          cudf::size_type const default_tz_index,
+                                                          bool const ansi_mode)
+{
+  if (input.size() == 0) { return nullptr; }
+  return to_timestamp(input, ansi_mode, true, default_tz_index, &transitions, &tz_indices);
+}
+
+/**
+ * Parse a string column without time zone to a timestamp column.
+ * Returns nullptr if ANSI mode is true and the strings have any invalid value;
+ * otherwise returns a non-null timestamp column.
+ */
+std::unique_ptr<cudf::column> string_to_timestamp_without_tz(cudf::strings_column_view const& input,
+                                                             bool const allow_time_zone,
+                                                             bool const ansi_mode)
+{
+  if (input.size() == 0) { return nullptr; }
+  return to_timestamp(input, ansi_mode, allow_time_zone);
+}
+
+}  // namespace spark_rapids_jni
diff --git a/src/main/cpp/src/datetime_parser.hpp b/src/main/cpp/src/datetime_parser.hpp
new file mode 100644
index 0000000000..ba83f43064
--- /dev/null
+++ b/src/main/cpp/src/datetime_parser.hpp
@@ -0,0 +1,136 @@
+/*
+ * Copyright (c) 2024, NVIDIA CORPORATION.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include
+
+namespace spark_rapids_jni {
+
+/**
+ * Trims and parses a timestamp string column with a time zone suffix to a
+ * timestamp column. e.g.: 1991-04-14T02:00:00Asia/Shanghai => 1991-04-13
+ * 18:00:00
+ *
+ * Refer to: https://github.com/apache/spark/blob/v3.5.0/sql/api/src/main/scala/
+ *   org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala#L394
+ *
+ * Spark supports the following formats:
+ *   `[+-]yyyy*`
+ *   `[+-]yyyy*-[m]m`
+ *   `[+-]yyyy*-[m]m-[d]d`
+ *   `[+-]yyyy*-[m]m-[d]d `
+ *   `[+-]yyyy*-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *   `[+-]yyyy*-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *   `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *   `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *
+ * Unlike Spark, Spark-Rapids only supports the following formats:
+ *   `[+-]yyyy*`
+ *   `[+-]yyyy*-[m]m`
+ *   `[+-]yyyy*-[m]m-[d]d`
+ *   `[+-]yyyy*-[m]m-[d]d `
+ *   `[+-]yyyy*-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *   `[+-]yyyy*-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *
+ * Spark supports the following zone id forms:
+ *   - Z - Zulu time zone UTC+0
+ *   - +|-[h]h:[m]m
+ *   - A short id, see
+ *     https://docs.oracle.com/javase/8/docs/api/java/time/ZoneId.html#SHORT_IDS
+ *   - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-,
+ *     and a suffix in the formats:
+ *       - +|-h[h]
+ *       - +|-hh[:]mm
+ *       - +|-hh:mm:ss
+ *       - +|-hhmmss
+ *   - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
+ *
+ * Unlike Spark, Spark-Rapids currently does not support DST time zones.
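+ *
+ * For example (illustrative), with the default time zone Asia/Shanghai (UTC+8,
+ * no DST at present):
+ *   "2023-01-01 00:00:00"       -> 2022-12-31 16:00:00 UTC (default zone applied)
+ *   "2023-01-01 00:00:00 UTC+8" -> 2022-12-31 16:00:00 UTC (explicit offset used)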
+ *
+ * @param input input string column view.
+ * @param transitions refers to the TimezoneDB; the table of transitions
+ * contains all information for the timezones.
+ * @param tz_indices refers to the TimezoneDB; maps a time zone to its
+ * TimezoneDB transition index.
+ * @param default_tz_index the index of the default timezone in the TimezoneDB;
+ * if an input date-like string does not contain a time zone (like:
+ * YYYY-MM-DD:hhmmss), this time zone is used.
+ * @param ansi_mode whether to enforce ANSI mode. If true, an exception is
+ * thrown when encountering any invalid input.
+ * @returns a pointer to the timestamp result column, which is nullptr if there
+ * are invalid inputs and ANSI mode is on.
+ */
+std::unique_ptr<cudf::column> string_to_timestamp_with_tz(cudf::strings_column_view const& input,
+                                                          cudf::column_view const& transitions,
+                                                          cudf::column_view const& tz_indices,
+                                                          cudf::size_type const default_tz_index,
+                                                          bool const ansi_mode);
+
+/**
+ * Trims and parses a timestamp string column with a time zone suffix to a
+ * timestamp column. e.g.: 1991-04-14T02:00:00Asia/Shanghai => 1991-04-13
+ * 18:00:00
+ *
+ * Refer to: https://github.com/apache/spark/blob/v3.5.0/sql/api/src/main/scala/
+ *   org/apache/spark/sql/catalyst/util/SparkDateTimeUtils.scala#L394
+ *
+ * Spark supports the following formats:
+ *   `[+-]yyyy*`
+ *   `[+-]yyyy*-[m]m`
+ *   `[+-]yyyy*-[m]m-[d]d`
+ *   `[+-]yyyy*-[m]m-[d]d `
+ *   `[+-]yyyy*-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *   `[+-]yyyy*-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *   `[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *   `T[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *
+ * Unlike Spark, Spark-Rapids only supports the following formats:
+ *   `[+-]yyyy*`
+ *   `[+-]yyyy*-[m]m`
+ *   `[+-]yyyy*-[m]m-[d]d`
+ *   `[+-]yyyy*-[m]m-[d]d `
+ *   `[+-]yyyy*-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *   `[+-]yyyy*-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+ *
+ * Spark supports the following zone id forms:
+ *   - Z - Zulu time zone UTC+0
+ *   - +|-[h]h:[m]m
+ *   - A short id, see
+ *     https://docs.oracle.com/javase/8/docs/api/java/time/ZoneId.html#SHORT_IDS
+ *   - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-,
+ *     and a suffix in the formats:
+ *       - +|-h[h]
+ *       - +|-hh[:]mm
+ *       - +|-hh:mm:ss
+ *       - +|-hhmmss
+ *   - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
+ *
+ * Unlike Spark, Spark-Rapids currently does not support DST time zones.
+ *
+ * @param input input string column view.
+ * @param allow_time_zone whether to allow a time zone in the timestamp string,
+ * e.g.: 1991-04-14T02:00:00Asia/Shanghai is invalid when time zones are not
+ * allowed.
+ * @param ansi_mode whether to enforce ANSI mode. If true, an exception is
+ * thrown when encountering any invalid input.
+ * @returns a pointer to the timestamp result column, which is nullptr if there
+ * are invalid inputs and ANSI mode is on.
+ */
+std::unique_ptr<cudf::column> string_to_timestamp_without_tz(cudf::strings_column_view const& input,
+                                                             bool const allow_time_zone,
+                                                             bool const ansi_mode);
+
+}  // namespace spark_rapids_jni
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java b/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java
index 2b2267f034..3c4c4a3cc6 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/CastStrings.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -18,6 +18,8 @@
 
 import ai.rapids.cudf.*;
 
+import java.time.ZoneId;
+
 /** Utility class for casting between string columns and native type columns */
 public class CastStrings {
   static {
@@ -152,6 +154,130 @@ public static ColumnVector fromIntegersWithBase(ColumnView cv, int base) {
     return new ColumnVector(fromIntegersWithBase(cv.getNativeView(), base));
   }
 
+  /**
+   * Trims and parses a timestamp string column with a time zone suffix to a
+   * timestamp column.
+   * Uses the default time zone if a string does not contain a time zone.
+   *
+   * Supports the following formats:
+   *   `[+-]yyyy*`
+   *   `[+-]yyyy*-[m]m`
+   *   `[+-]yyyy*-[m]m-[d]d`
+   *   `[+-]yyyy*-[m]m-[d]d `
+   *   `[+-]yyyy*-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+   *   `[+-]yyyy*-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+   *
+   * Spark supports the following zone id forms:
+   *   - Z - Zulu time zone UTC+0
+   *   - +|-[h]h:[m]m
+   *   - A short id, see
+   *     https://docs.oracle.com/javase/8/docs/api/java/time/ZoneId.html#SHORT_IDS
+   *   - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-,
+   *     and a suffix in the formats:
+   *       - +|-h[h]
+   *       - +|-hh[:]mm
+   *       - +|-hh:mm:ss
+   *       - +|-hhmmss
+   *   - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
+   *
+   * Unlike Spark, Spark-Rapids currently does not support DST time zones.
+   *
+   * Note:
+   * - Casting special strings (epoch, now, today, yesterday, tomorrow) to
+   *   timestamp is not supported. Spark 3.1.x supports casting these special
+   *   strings, while Spark 3.2.0+ does not.
+   * - DST time zones are not supported; they return null in non-ANSI mode.
+   *   TODO: DST support.
+   *
+   * Example:
+   *   input = [" 2023", "2023-01-01T08:00:00Asia/Shanghai "]
+   *   ts = toTimestamp(input, ZoneId.of("UTC"), false) // ansiEnabled = false
+   *   ts is: ['2023-01-01 00:00:00', '2023-01-01T00:00:00']
+   *
+   * Example:
+   *   input = ["2023-01-01T08:00:00 non-exist-time-zone"]
+   *   In ANSI mode: throws IllegalArgumentException
+   *   In non-ANSI mode: returns a null value
+   *
+   * @param cv The input string column to be converted.
+   * @param defaultTimeZone The default time zone to use if a string does not
+   *                        contain a time zone.
+   * @param ansiEnabled whether ANSI mode is on
+   * @return a timestamp column
+   * @throws IllegalArgumentException if any string in cv has an invalid format
+   *   or the time zone is nonexistent/invalid when ansiEnabled is true
+   */
+  public static ColumnVector toTimestamp(ColumnView cv, ZoneId defaultTimeZone, boolean ansiEnabled) {
+    if (!GpuTimeZoneDB.isSupportedTimeZone(defaultTimeZone)) {
+      throw new IllegalArgumentException(String.format("Unsupported timezone: %s",
+          defaultTimeZone.getId()));
+    }
+
+    Integer tzIndex = GpuTimeZoneDB.getZoneIDMap().get(defaultTimeZone.getId());
+    try (Table transitions = GpuTimeZoneDB.getTransitions();
+         ColumnVector tzIndices = GpuTimeZoneDB.getZoneIDVector()) {
+      return new ColumnVector(toTimestamp(cv.getNativeView(), transitions.getNativeView(),
+          tzIndices.getNativeView(), tzIndex, ansiEnabled));
+    }
+  }
+
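+  // Illustrative usage sketch (hypothetical values, not a test fixture):
+  //   try (ColumnVector input = ColumnVector.fromStrings("2023-01-01T08:00:00Z");
+  //        ColumnVector ts = CastStrings.toTimestamp(input, ZoneId.of("Asia/Shanghai"), false)) {
+  //     // ts holds TIMESTAMP_MICROSECONDS values; with ansiEnabled == false,
+  //     // rows that fail to parse become null instead of throwing.
+  //   }
+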
+  /**
+   * Trims and parses a timestamp string column with a time zone suffix to a
+   * timestamp column.
+   * The time zones in the timestamp strings are not used.
+   *
+   * Supports the following formats:
+   *   `[+-]yyyy*`
+   *   `[+-]yyyy*-[m]m`
+   *   `[+-]yyyy*-[m]m-[d]d`
+   *   `[+-]yyyy*-[m]m-[d]d `
+   *   `[+-]yyyy*-[m]m-[d]d [h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+   *   `[+-]yyyy*-[m]m-[d]dT[h]h:[m]m:[s]s.[ms][ms][ms][us][us][us][zone_id]`
+   *
+   * Spark supports the following zone id forms:
+   *   - Z - Zulu time zone UTC+0
+   *   - +|-[h]h:[m]m
+   *   - A short id, see
+   *     https://docs.oracle.com/javase/8/docs/api/java/time/ZoneId.html#SHORT_IDS
+   *   - An id with one of the prefixes UTC+, UTC-, GMT+, GMT-, UT+ or UT-,
+   *     and a suffix in the formats:
+   *       - +|-h[h]
+   *       - +|-hh[:]mm
+   *       - +|-hh:mm:ss
+   *       - +|-hhmmss
+   *   - Region-based zone IDs in the form `area/city`, such as `Europe/Paris`
+   *
+   * Unlike Spark, Spark-Rapids currently does not support DST time zones.
+   *
+   * Note: Casting special strings (epoch, now, today, yesterday, tomorrow) to
+   * timestamp is not supported. Spark 3.1.x supports casting these special
+   * strings, while Spark 3.2.0+ does not.
+   *
+   * Example:
+   *   input = [" 2023", "2023-01-01T08:00:00Asia/Shanghai "]
+   *   ts = toTimestampWithoutTimeZone(input, true, false)
+   *        // allowTimeZone = true, ansiEnabled = false
+   *   ts is: ['2023-01-01 00:00:00', '2023-01-01T08:00:00']
+   *
+   * Note: this function never uses the time zones in the strings.
+   * allowTimeZone controls whether a time zone may appear in the timestamp string.
+   * If allowTimeZone is true, any time zones present are ignored.
+   * If allowTimeZone is false, this function throws an exception in ANSI mode
+   * when any string contains a time zone.
+   *
+   * @param cv The input string column to be converted.
+   * @param allowTimeZone whether to allow a time zone in the timestamp
+   *                      string. e.g.:
+   *                      1991-04-14T02:00:00Asia/Shanghai is invalid
+   *                      when time zones are not allowed.
+   * @param ansiEnabled whether ANSI mode is on
+   * @return a timestamp column
+   * @throws IllegalArgumentException if any string in cv has an invalid format,
+   *   or contains a time zone while `allowTimeZone` is false, when ANSI mode is on
+   *
+   */
+  public static ColumnVector toTimestampWithoutTimeZone(ColumnView cv, boolean allowTimeZone, boolean ansiEnabled) {
+    return new ColumnVector(toTimestampWithoutTimeZone(cv.getNativeView(), allowTimeZone, ansiEnabled));
+  }
+
   private static native long toInteger(long nativeColumnView, boolean ansi_enabled, boolean strip,
       int dtype);
   private static native long toDecimal(long nativeColumnView, boolean ansi_enabled, boolean strip,
@@ -163,4 +289,8 @@ private static native long toDecimal(long nativeColumnView, boolean ansi_enabled
   private static native long toIntegersWithBase(long nativeColumnView, int base,
       boolean ansiEnabled, int dtype);
   private static native long fromIntegersWithBase(long nativeColumnView, int base);
-}
\ No newline at end of file
+  private static native long toTimestamp(long input,
+      long transitions, long tzIndices, int tzIndex, boolean ansiEnabled);
+  private static native long toTimestampWithoutTimeZone(long input, boolean allowTimeZone,
+      boolean ansiEnabled);
+}
diff --git a/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java b/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java
index 643db278df..efcd592604 100644
--- a/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java
+++ b/src/main/java/com/nvidia/spark/rapids/jni/GpuTimeZoneDB.java
@@ -24,16 +24,29 @@
 import org.slf4j.LoggerFactory;
 
 import java.time.Instant;
+import java.time.LocalDateTime;
 import java.time.ZoneId;
 import java.time.zone.ZoneOffsetTransition;
 import java.time.zone.ZoneRules;
 import java.time.zone.ZoneRulesException;
 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
+import java.util.Set;
 import java.util.TimeZone;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Executor;
 import java.util.concurrent.Executors;
+import java.util.concurrent.RejectedExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+import java.util.function.Function;
 
 /**
  * Gpu time zone utility.
@@ -49,14 +62,21 @@ public class GpuTimeZoneDB {
 
   private static final Logger log = LoggerFactory.getLogger(GpuTimeZoneDB.class);
 
   // For the timezone database, we store the transitions in a ColumnVector that is a list of
   // structs. The type of this column vector is:
-  // LIST<STRUCT<int64, int64, int32>>
-  private Map<String, Integer> zoneIdToTable;
-
+  // LIST<STRUCT<int64, int64, int32, int64>>
+  // use this reference to indicate if the time zone cache is initialized.
+  // `fixedTransitions` saves transitions for deduplicated time zones; different
+  // time zones may map to one normalized time zone.
   private HostColumnVector fixedTransitions;
+  // time zone -> index in `fixedTransitions`
+  // The keys of `zoneIdToTable` are the time zone names before dedup.
+  private Map<String, Integer> zoneIdToTable;
+
+  // host column vector for `zoneIdToTable`, sorted by time zone strings
+  private HostColumnVector zoneIdToTableVec;
+
   // Guarantee singleton instance
   private GpuTimeZoneDB() {
   }
@@ -190,7 +210,7 @@ private void shutdownImpl() {
     closeResources();
   }
 
-  private void closeResources() { 
+  private void closeResources() {
     if (zoneIdToTable != null) {
       zoneIdToTable.clear();
       zoneIdToTable = null;
@@ -199,6 +219,10 @@ private void closeResources() {
       fixedTransitions.close();
       fixedTransitions = null;
     }
+    if (zoneIdToTableVec != null) {
+      zoneIdToTableVec.close();
+      zoneIdToTableVec = null;
+    }
   }
 
   public static ColumnVector fromTimestampToUtcTimestamp(ColumnVector input, ZoneId currentTimeZone) {
@@ -260,86 +284,197 @@ public static ZoneId getZoneId(String timeZoneId) {
   @SuppressWarnings("unchecked")
   private void loadData() {
     try {
-      List<List<HostColumnVector.StructData>> masterTransitions = new ArrayList<>();
-      zoneIdToTable = new HashMap<>();
-      for (String tzId : TimeZone.getAvailableIDs()) {
-        ZoneId zoneId;
-        try {
-          zoneId = ZoneId.of(tzId).normalized(); // we use the normalized form to dedupe
-        } catch (ZoneRulesException e) {
-          // Sometimes the list of getAvailableIDs() is one of the 3-letter abbreviations, however,
-          // this use is deprecated due to ambiguity reasons (same abbrevation can be used for
-          // multiple time zones). These are not supported by ZoneId.of(...) directly here.
-          continue;
+      // Note: ZoneId.normalized transforms a fixed-offset time zone to a standard fixed offset,
+      // e.g.: ZoneId.of("Etc/GMT").normalized.getId = Z; ZoneId.of("Etc/GMT+0").normalized.getId = Z.
+      // Both Etc/GMT and Etc/GMT+0 normalize to Z.
+      // We use the normalized form to dedupe,
+      // but we should record the map from the TimeZone.getAvailableIDs() set to the normalized set.
+      // `fixedTransitions` saves transitions for normalized time zones.
+      // Spark uses time zones from TimeZone.getAvailableIDs(),
+      // so we keep a map from TimeZone.getAvailableIDs() to indices of `fixedTransitions`.
+
+      // get and sort the time zones
+      String[] timeZones = TimeZone.getAvailableIDs();
+      List<String> sortedTimeZones = new ArrayList<>(Arrays.asList(timeZones));
+      // Note: Z is a special normalized time zone from UTC: ZoneId.of("UTC").normalized = Z.
+      // TimeZone.getAvailableIDs does not contain Z and ZoneId.SHORT_IDS also does not contain Z,
+      // so Z must be added to `zoneIdToTable`.
+      sortedTimeZones.add("Z");
+      Collections.sort(sortedTimeZones);
+
+      // Note: Spark uses ZoneId.SHORT_IDS.
+      // `TimeZone.getAvailableIDs` contains all keys in `ZoneId.SHORT_IDS`,
+      // so no extra work is needed for ZoneId.SHORT_IDS; here we just check this assumption.
+      for (String tz : ZoneId.SHORT_IDS.keySet()) {
+        if (!sortedTimeZones.contains(tz)) {
+          throw new IllegalStateException(
+              String.format("Can not find short Id %s in time zones %s", tz, sortedTimeZones));
         }
-        ZoneRules zoneRules = zoneId.getRules();
+      }
+
+      // A simple approach to transform a LocalDateTime to a value which is proportional to
+      // the exact EpochSecond. After caching these values along with the EpochSeconds, we
+      // can easily search out which time zone transition rule we should apply according
+      // to the LocalDateTime structs. The search procedure is the same as the binary search
+      // with exact EpochSeconds (convert_timestamp_tz_functor), except that it uses the
+      // "loose instant" as the search index instead of the exact EpochSeconds.
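+      // For example (illustrative): LocalDateTime 1970-01-02T00:00 maps to
+      // 86400L * (1970 * 400 + 0 + 1) = 86400L * 788001, far from the real
+      // epoch second (86400) but ordered consistently with it.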
+      Function<LocalDateTime, Long> localToLooseEpochSecond = lt ->
+          86400L * (lt.getYear() * 400L + (lt.getMonthValue() - 1) * 31L +
+              lt.getDayOfMonth() - 1) +
+          3600L * lt.getHour() + 60L * lt.getMinute() + lt.getSecond();
+
+      List<List<HostColumnVector.StructData>> masterTransitions = new ArrayList<>();
+
+      // map: normalizedTimeZone -> index in fixedTransitions
+      Map<String, Integer> mapForNormalizedTimeZone = new HashMap<>();
+      // go through all time zones and save them by normalized time zone
+      List<String> sortedSupportedTimeZones = new ArrayList<>();
+      for (String timeZone : sortedTimeZones) {
+        ZoneId normalizedZoneId = ZoneId.of(timeZone, ZoneId.SHORT_IDS).normalized();
+        String normalizedTimeZone = normalizedZoneId.getId();
+        ZoneRules zoneRules = normalizedZoneId.getRules();
         // Filter by non-repeating rules
         if (!zoneRules.isFixedOffset() && !zoneRules.getTransitionRules().isEmpty()) {
           continue;
         }
-        if (!zoneIdToTable.containsKey(zoneId.getId())) {
-          List<ZoneOffsetTransition> transitions = zoneRules.getTransitions();
+        sortedSupportedTimeZones.add(timeZone);
+        if (!mapForNormalizedTimeZone.containsKey(normalizedTimeZone)) { // dedup
+          List<HostColumnVector.StructData> data = getTransitionData(localToLooseEpochSecond, zoneRules);
+          // add the transition data for the time zone
           int idx = masterTransitions.size();
-          List<HostColumnVector.StructData> data = new ArrayList<>();
-          if (zoneRules.isFixedOffset()) {
-            data.add(
-              new HostColumnVector.StructData(Long.MIN_VALUE, Long.MIN_VALUE,
-                zoneRules.getOffset(Instant.now()).getTotalSeconds())
-            );
-          } else {
-            // Capture the first official offset (before any transition) using Long min
-            ZoneOffsetTransition first = transitions.get(0);
-            data.add(
-              new HostColumnVector.StructData(Long.MIN_VALUE, Long.MIN_VALUE,
-                first.getOffsetBefore().getTotalSeconds())
-            );
-            transitions.forEach(t -> {
-              // Whether transition is an overlap vs gap.
-              // In Spark:
-              // if it's a gap, then we use the offset after *on* the instant
-              // If it's an overlap, then there are 2 sets of valid timestamps in that are overlapping
-              // So, for the transition to UTC, you need to compare to instant + {offset before}
-              // The time math still uses {offset after}
-              if (t.isGap()) {
-                data.add(
-                  new HostColumnVector.StructData(
-                    t.getInstant().getEpochSecond(),
-                    t.getInstant().getEpochSecond() + t.getOffsetAfter().getTotalSeconds(),
-                    t.getOffsetAfter().getTotalSeconds())
-                );
-              } else {
-                data.add(
-                  new HostColumnVector.StructData(
-                    t.getInstant().getEpochSecond(),
-                    t.getInstant().getEpochSecond() + t.getOffsetBefore().getTotalSeconds(),
-                    t.getOffsetAfter().getTotalSeconds())
-                );
-              }
-            });
-          }
+          mapForNormalizedTimeZone.put(normalizedTimeZone, idx);
           masterTransitions.add(data);
-          zoneIdToTable.put(zoneId.getId(), idx);
         }
       }
+
       HostColumnVector.DataType childType = new HostColumnVector.StructType(false,
           new HostColumnVector.BasicType(false, DType.INT64),
           new HostColumnVector.BasicType(false, DType.INT64),
-          new HostColumnVector.BasicType(false, DType.INT32));
+          new HostColumnVector.BasicType(false, DType.INT32),
+          new HostColumnVector.BasicType(false, DType.INT64));
       HostColumnVector.DataType resultType =
           new HostColumnVector.ListType(false, childType);
-      fixedTransitions = HostColumnVector.fromLists(resultType,
-          masterTransitions.toArray(new List[0]));
+
+      // generate all transitions for all time zones
+      fixedTransitions = HostColumnVector.fromLists(resultType, masterTransitions.toArray(new List[0]));
+
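+      // At this point (illustrative): masterTransitions holds one transition list
+      // per normalized zone, e.g. "Etc/GMT", "Etc/GMT+0" and "Z" all share the
+      // single list stored for the normalized zone "Z".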
+      // generate `zoneIdToTable`; the keys should be the time zones, not the normalized time zones
+      zoneIdToTable = new HashMap<>();
+      for (String timeZone : sortedSupportedTimeZones) {
+        // map from the time zone to its normalized form
+        String normalized = ZoneId.of(timeZone, ZoneId.SHORT_IDS).normalized().getId();
+        Integer index = mapForNormalizedTimeZone.get(normalized);
+        if (index != null) {
+          zoneIdToTable.put(timeZone, index);
+        } else {
+          throw new IllegalStateException("Could not find index for normalized time zone " + normalized);
+        }
+      }
+      // generate the host vector
+      zoneIdToTableVec = generateZoneIdToTableVec(sortedSupportedTimeZones, zoneIdToTable);
+    } catch (IllegalStateException e) {
+      throw e;
     } catch (Exception e) {
       throw new IllegalStateException("load time zone DB cache failed!", e);
     }
   }
 
-  private Map<String, Integer> getZoneIDMap() {
-    return zoneIdToTable;
+  // Generate the transition data for one time zone
+  private List<HostColumnVector.StructData> getTransitionData(Function<LocalDateTime, Long> localToLooseEpochSecond,
+      ZoneRules zoneRules) {
+    List<ZoneOffsetTransition> transitions = zoneRules.getTransitions();
+    List<HostColumnVector.StructData> data = new ArrayList<>();
+    if (zoneRules.isFixedOffset()) {
+      data.add(
+        new HostColumnVector.StructData(Long.MIN_VALUE, Long.MIN_VALUE,
+          zoneRules.getOffset(Instant.now()).getTotalSeconds(), Long.MIN_VALUE)
+      );
+    } else {
+      // Capture the first official offset (before any transition) using Long min
+      ZoneOffsetTransition first = transitions.get(0);
+      data.add(
+        new HostColumnVector.StructData(Long.MIN_VALUE, Long.MIN_VALUE,
+          first.getOffsetBefore().getTotalSeconds(), Long.MIN_VALUE)
+      );
+      transitions.forEach(t -> {
+        // Whether the transition is an overlap vs a gap.
+        // In Spark:
+        // if it's a gap, then we use the offset after *on* the instant;
+        // if it's an overlap, then there are 2 sets of valid timestamps that are overlapping.
+        // So, for the transition to UTC, you need to compare to instant + {offset before}.
+        // The time math still uses {offset after}.
+        if (t.isGap()) {
+          data.add(
+            new HostColumnVector.StructData(
+              t.getInstant().getEpochSecond(),
+              t.getInstant().getEpochSecond() + t.getOffsetAfter().getTotalSeconds(),
+              t.getOffsetAfter().getTotalSeconds(),
+              localToLooseEpochSecond.apply(t.getDateTimeAfter()) // this column is for rebasing the local date time
+            )
+          );
+        } else {
+          data.add(
+            new HostColumnVector.StructData(
+              t.getInstant().getEpochSecond(),
+              t.getInstant().getEpochSecond() + t.getOffsetBefore().getTotalSeconds(),
+              t.getOffsetAfter().getTotalSeconds(),
+              localToLooseEpochSecond.apply(t.getDateTimeBefore()) // this column is for rebasing the local date time
+            )
+          );
+        }
+      });
+    }
+    return data;
+  }
+
+  /**
+   * Generate the struct vector mapping time zones to indices in the transition table.
+   * A regular time zone is first mapped to its normalized time zone; the index is
+   * then taken from the normalized map.
+   * @param sortedSupportedTimeZones the sorted, supported time zones
+   * @param zoneIdToTableMap a map from non-normalized time zone to index in the transition table
+   */
+  private static HostColumnVector generateZoneIdToTableVec(List<String> sortedSupportedTimeZones, Map<String, Integer> zoneIdToTableMap) {
+    HostColumnVector.DataType type = new HostColumnVector.StructType(false,
+        new HostColumnVector.BasicType(false, DType.STRING),
+        new HostColumnVector.BasicType(false, DType.INT32));
+    ArrayList<HostColumnVector.StructData> data = new ArrayList<>();
+
+    for (String timeZone : sortedSupportedTimeZones) {
+      Integer mapTo = zoneIdToTableMap.get(timeZone);
+      if (mapTo != null) {
+        data.add(new HostColumnVector.StructData(timeZone, mapTo));
+      } else {
+        throw new IllegalStateException("Could not find index for time zone " + timeZone);
+      }
+    }
+    return HostColumnVector.fromStructs(type, data);
+  }
+
+  /**
+   * Get the map from time zone to time zone index in the transition table.
+   * @return the map from time zone to time zone index in the transition table.
+   */
+  public static Map<String, Integer> getZoneIDMap() {
+    cacheDatabase();
+    return instance.zoneIdToTable;
   }
 
-  private Table getTransitions() {
-    try (ColumnVector fixedTransitions = getFixedTransitions()) {
+  /**
+   * Get the vector mapping time zones to indices in the transition table.
+   * @return the vector mapping time zones to indices in the transition table.
+   */
+  public static ColumnVector getZoneIDVector() {
+    cacheDatabase();
+    return instance.zoneIdToTableVec.copyToDevice();
+  }
+
+  /**
+   * Get the transition table.
+   * @return the transition table.
+   */
+  public static Table getTransitions() {
+    cacheDatabase();
+    try (ColumnVector fixedTransitions = instance.getFixedTransitions()) {
       return new Table(fixedTransitions);
     }
   }
@@ -355,11 +490,10 @@ private ColumnVector getFixedTransitions() {
    * fixed transitions for a particular zoneId.
    *
    * It has default visibility so the test can access it.
-   * @param zoneId
+   * @param zoneId a time zone from TimeZone.getAvailableIDs, without applying `ZoneId.normalized`
    * @return list of fixed transitions
    */
   List getHostFixedTransitions(String zoneId) {
-    zoneId = ZoneId.of(zoneId).normalized().toString(); // we use the normalized form to dedupe
     Integer idx = getZoneIDMap().get(zoneId);
     if (idx == null) {
       return null;
@@ -367,7 +501,6 @@ List getHostFixedTransitions(String zoneId) {
     return fixedTransitions.getList(idx);
   }
 
-
   private static native long convertTimestampColumnToUTC(long input, long transitions, int tzIndex);
 
   private static native long convertUTCTimestampColumnToTimeZone(long input, long transitions,
       int tzIndex);
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java b/src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java
index c39766454a..cafe69a6b4 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java
@@ -1,5 +1,5 @@
 /*
- * Copyright (c) 2022-2023, NVIDIA CORPORATION.
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
@@ -17,16 +17,21 @@
 package com.nvidia.spark.rapids.jni;
 
 import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
 import static org.junit.jupiter.api.Assertions.fail;
 
+import java.time.*;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.AbstractMap;
+import java.util.Map;
 
 import org.junit.jupiter.api.Test;
 
 import ai.rapids.cudf.AssertUtils;
 import ai.rapids.cudf.ColumnVector;
 import ai.rapids.cudf.DType;
+import ai.rapids.cudf.HostColumnVector;
 import ai.rapids.cudf.Table;
 
 public class CastStringsTest {
@@ -324,4 +329,178 @@ void baseHex2DecTest() {
       convTestInternal(input, expected, 16);
     }
   }
+
+  @Test
+  void toTimestampTestAnsiWithoutTz() {
+    assertThrows(IllegalArgumentException.class, () -> {
+      try (ColumnVector input = ColumnVector.fromStrings(" invalid_value ")) {
+        // ansiEnabled is true
+        CastStrings.toTimestampWithoutTimeZone(input, false, true);
+      }
+    });
+
+    Instant instant = LocalDateTime.parse("2023-11-05T03:04:55").toInstant(ZoneOffset.UTC);
+    long expectedResults = instant.getEpochSecond() * 1000000L;
+
+    try (
+        ColumnVector input = ColumnVector.fromStrings("2023-11-05 3:04:55");
+        ColumnVector expected = ColumnVector.timestampMicroSecondsFromBoxedLongs(expectedResults);
+        ColumnVector actual = CastStrings.toTimestampWithoutTimeZone(input, false, true)) {
+      AssertUtils.assertColumnsAreEqual(expected, actual);
+    }
+  }
+
+  @Test
+  void toTimestampTestWithTz() {
+    List<Map.Entry<String, Long>> entries = new ArrayList<>();
+    // Without timezone
+    entries.add(new AbstractMap.SimpleEntry<>(" 2000-01-29 ", 949104000000000L));
+    // Timezone IDs
+    entries.add(new AbstractMap.SimpleEntry<>("2023-11-05 3:4:55 America/Sao_Paulo", 1699164295000000L));
+    entries.add(new AbstractMap.SimpleEntry<>("2023-11-5T03:04:55.1 Asia/Shanghai", 1699124695100000L));
+    entries.add(new AbstractMap.SimpleEntry<>("2000-1-29 13:59:8 Iran", 949141748000000L));
+    entries.add(new AbstractMap.SimpleEntry<>("1968-03-25T23:59:1.123Asia/Tokyo", -55846858877000L));
+    entries.add(new AbstractMap.SimpleEntry<>("1968-03-25T23:59:1.123456Asia/Tokyo", -55846858876544L));
+
+    // UTC-like timezones
+    // no adjustment
+    entries.add(new AbstractMap.SimpleEntry<>("1970-9-9 2:33:44 Z", 21695624000000L));
+    entries.add(new AbstractMap.SimpleEntry<>(" 1969-12-1 2:3:4.999Z", -2671015001000L));
+    entries.add(new AbstractMap.SimpleEntry<>("1954-10-20 00:11:22 GMT ", -479692118000000L));
+    entries.add(new AbstractMap.SimpleEntry<>("1984-1-3 00:11:22UTC", 441936682000000L));
+    // hh
+    entries.add(new AbstractMap.SimpleEntry<>("1998-11-05T20:00:1.12 UTC+18 ", 910231201120000L));
+    entries.add(new AbstractMap.SimpleEntry<>("1998-11-05T20:00:1.12UTC+0", 910296001120000L));
+    entries.add(new AbstractMap.SimpleEntry<>("1998-11-05T20:00:1.12UTC-00", 910296001120000L));
+    entries.add(new AbstractMap.SimpleEntry<>(" 1998-11-05T20:00:1.12 GMT+09 ", 910263601120000L));
+    entries.add(new AbstractMap.SimpleEntry<>("1998-11-05T20:00:1.12 GMT-1", 910299601120000L));
+    entries.add(new AbstractMap.SimpleEntry<>("1998-11-05T20:00:1.12 UTC-6", 910317601120000L));
+    entries.add(new AbstractMap.SimpleEntry<>("1998-11-05T20:00:1.12 UTC-18", 910360801120000L));
+    entries.add(new AbstractMap.SimpleEntry<>("1998-11-05T20:00:1.12UTC-00", 910296001120000L));
+    entries.add(new AbstractMap.SimpleEntry<>(" 1998-11-05T20:00:1.12 +09 ", 910263601120000L));
+    entries.add(new AbstractMap.SimpleEntry<>("1998-11-05T20:00:1.12 -1", 910299601120000L));
AbstractMap.SimpleEntry<>("1998-11-05T20:00:1.12 +18 ", 910231201120000L)); + entries.add(new AbstractMap.SimpleEntry<>("1998-11-05T20:00:1.12-00", 910296001120000L)); + // hh:mm + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.999 UTC+1428", -2723095001000L)); + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.999 GMT-1501", -2616955001000L)); + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.999 GMT+1:22", -2675935001000L)); + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.8888 GMT+8:2", -2699935111200L)); + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.999 UTC+17:9", -2732755001000L)); + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.999 UTC-09:11", -2637955001000L)); + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.999 +1428 ", -2723095001000L)); + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.999-1501 ", -2616955001000L)); + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.999 +1:22 ", -2675935001000L)); + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.8888 +8:2 ", -2699935111200L)); + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.999+17:9", -2732755001000L)); + entries.add(new AbstractMap.SimpleEntry<>("1969-12-1 2:3:4.999 -09:11", -2637955001000L)); + // hh:mm::ss + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 22:33:44.1 GMT+112233", 1571569871100000L)); + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 22:33:44.1 UTC-100102", 1571646886100000L)); + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 22:33:44.1 UTC+11:22:33", 1571569871100000L)); + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 22:33:44.1 GMT-10:10:10", 1571647434100000L)); + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 22:33:44.1 GMT-8:08:01", 1571640105100000L)); + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 22:33:44.1 UTC+4:59:59", 1571592825100000L)); + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 00:1:20.3 +102030", 1571492450300000L)); + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 00:1:20.3 -020103", 1571536943300000L)); + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 22:33:44.1 -8:08:01 ", 1571640105100000L)); + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 22:33:44.1+4:59:59", 1571592825100000L)); + // short TZ ID: BST->Asia/Dhaka, CTT->Asia/Shanghai + entries.add(new AbstractMap.SimpleEntry<>("2023-11-5T03:04:55.1 CTT", 1699124695100000L)); + entries.add(new AbstractMap.SimpleEntry<>("2023-11-5T03:04:55.1 BST", 1699124695100000L + 7200L * 1000000L)); // BST is 2 hours later than CTT + // short TZ ID: EST: -05:00; HST: -10:00; MST: -07:00 + entries.add(new AbstractMap.SimpleEntry<>("2023-11-5T03:04:55.1 EST", 1699124695100000L + 13L * 3600L * 1000000L)); // EST is 8 + 5 hours later than Asia/Shanghai + entries.add(new AbstractMap.SimpleEntry<>("2023-11-5T03:04:55.1 HST", 1699124695100000L + 18L * 3600L * 1000000L)); // HST is 8 + 10 hours later than Asia/Shanghai + entries.add(new AbstractMap.SimpleEntry<>("2023-11-5T03:04:55.1 MST", 1699124695100000L + 15L * 3600L * 1000000L)); // MST is 8 + 7 hours later than Asia/Shanghai + // test time zones not in notmalized names, e.g,: ZoneId.of("Etc/GMT").normalized.getId = Z; ZoneId.of("Etc/GMT+0").normalized.getId = Z; Etc/GMT+10 -> -10:00 + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 22:33:44.1 Etc/GMT", 1571610824100000L)); + entries.add(new AbstractMap.SimpleEntry<>("2019-10-20 22:33:44.1 Etc/GMT+0", 1571610824100000L)); + entries.add(new 
AbstractMap.SimpleEntry<>("2019-10-20 22:33:44.1 Etc/GMT+10", 1571646824100000L)); + + int validDataSize = entries.size(); + + // Invalid instances + // Timezone without hh:mm:ss + entries.add(new AbstractMap.SimpleEntry<>("2000-01-29 Iran", null)); + // Invalid Timezone ID + entries.add(new AbstractMap.SimpleEntry<>("2000-01-29 10:20:30 Asia/London", null)); + // Invalid UTC-like timezone + // overflow + entries.add(new AbstractMap.SimpleEntry<>("2000-01-29 10:20:30 +10:60", null)); + entries.add(new AbstractMap.SimpleEntry<>("2000-01-29 10:20:30 UTC-7:59:60", null)); + entries.add(new AbstractMap.SimpleEntry<>("2000-01-29 10:20:30 +19", null)); + entries.add(new AbstractMap.SimpleEntry<>("2000-01-29 10:20:30 UTC-23", null)); + entries.add(new AbstractMap.SimpleEntry<>("2000-01-29 10:20:30 GMT+1801", null)); + entries.add(new AbstractMap.SimpleEntry<>("2000-01-29 10:20:30 -180001", null)); + entries.add(new AbstractMap.SimpleEntry<>("2000-01-29 10:20:30 UTC+18:00:10", null)); + entries.add(new AbstractMap.SimpleEntry<>("2000-01-29 10:20:30 GMT-23:5", null)); + List inputs = new ArrayList<>(); + List expects = new ArrayList<>(); + for (Map.Entry entry : entries) { + inputs.add(entry.getKey()); + expects.add(entry.getValue()); + } + + // Throw unsupported exception for symbols because Europe/London contains DST rules + assertThrows(IllegalArgumentException.class, () -> { + try (ColumnVector input = ColumnVector.fromStrings("2000-01-29 1:2:3 Europe/London")) { + CastStrings.toTimestamp(input, ZoneId.of("UTC"), true); + } + }); + + // Throw IllegalArgumentException for symbols of special dates + // Note: Spark 31x supports "epoch", "now", "today", "yesterday", "tomorrow". + // But Spark 32x to Spark 35x do not supports. + // Currently JNI do not supports + for (String date : new String[]{"epoch", "now", "today", "yesterday", "tomorrow"}) + assertThrows(IllegalArgumentException.class, () -> { + try (ColumnVector input = ColumnVector.fromStrings(date)) { + CastStrings.toTimestamp(input, ZoneId.of("UTC"), true); + } + }); + + // non-ANSI mode + try ( + ColumnVector input = ColumnVector.fromStrings(inputs.toArray(new String[0])); + ColumnVector expected = ColumnVector.timestampMicroSecondsFromBoxedLongs(expects.toArray(new Long[0])); + ColumnVector actual = CastStrings.toTimestamp(input, ZoneId.of("UTC"), false); + ColumnVector actual2 = CastStrings.toTimestamp(input, ZoneId.of("Z"), false)) { + AssertUtils.assertColumnsAreEqual(expected, actual); + AssertUtils.assertColumnsAreEqual(expected, actual2); + } + + // Should NOT throw exception because all inputs are valid + String[] validInputs = inputs.stream().limit(validDataSize).toArray(String[]::new); + Long[] validExpects = expects.stream().limit(validDataSize).toArray(Long[]::new); + try ( + ColumnVector input = ColumnVector.fromStrings(validInputs); + ColumnVector expected = ColumnVector.timestampMicroSecondsFromBoxedLongs(validExpects); + ColumnVector actual = CastStrings.toTimestamp(input, ZoneId.of("UTC"), true)) { + AssertUtils.assertColumnsAreEqual(expected, actual); + } + + // Throw IllegalArgumentException for invalid timestamps under ANSI mode + assertThrows(IllegalArgumentException.class, () -> { + try (ColumnVector input = ColumnVector.fromStrings(inputs.toArray(new String[0]))) { + CastStrings.toTimestamp(input, ZoneId.of("UTC"), true); + } + }); + + // Throw IllegalArgumentException for non-exist-tz in ANSI mode + assertThrows(IllegalArgumentException.class, () -> { + try (ColumnVector input = ColumnVector.fromStrings("2000-01-29 
+
+    // Throw IllegalArgumentException for a non-existent timezone in ANSI mode
+    assertThrows(IllegalArgumentException.class, () -> {
+      try (ColumnVector input = ColumnVector.fromStrings("2000-01-29 1:2:3 non-exist-tz")) {
+        CastStrings.toTimestamp(input, ZoneId.of("UTC"), true);
+      }
+    });
+
+    // Return null for a non-existent timezone in non-ANSI mode
+    Long[] nullExpected = {null};
+    try (
+        ColumnVector input = ColumnVector.fromStrings("2000-01-29 1:2:3 non-exist-tz");
+        ColumnVector expected = ColumnVector.timestampMicroSecondsFromBoxedLongs(nullExpected);
+        ColumnVector actual = CastStrings.toTimestamp(input, ZoneId.of("UTC"), false)) {
+      AssertUtils.assertColumnsAreEqual(expected, actual);
+    }
+
+  }
 }
diff --git a/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java b/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java
index 7aaec496de..f50fe64c51 100644
--- a/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java
+++ b/src/test/java/com/nvidia/spark/rapids/jni/TimeZoneTest.java
@@ -22,6 +22,7 @@
 import static ai.rapids.cudf.AssertUtils.assertColumnsAreEqual;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertNull;
 
 import ai.rapids.cudf.ColumnVector;
@@ -45,12 +46,17 @@ static void cleanup() {
   void databaseLoadedTest() {
     // Check for a few timezones
     GpuTimeZoneDB instance = GpuTimeZoneDB.getInstance();
+
+    // "UTC+8" is not in `TimeZone.getAvailableIDs`, so this returns null;
+    // UTC+8 can be handled by the kernel directly
     List transitions = instance.getHostFixedTransitions("UTC+8");
-    assertNotNull(transitions);
-    assertEquals(1, transitions.size());
+    assertNull(transitions);
+
     transitions = instance.getHostFixedTransitions("Asia/Shanghai");
     assertNotNull(transitions);
+    ZoneId shanghai = ZoneId.of("Asia/Shanghai").normalized();
+    // a minimum-transition placeholder is inserted, so the size is n + 1
    assertEquals(shanghai.getRules().getTransitions().size() + 1, transitions.size());
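For context on the n + 1 assertion: the n comes straight from java.time, and per the comment in the diff, GpuTimeZoneDB adds one minimum-value placeholder row on top of it. A sketch of where n is obtained (illustration only; the placeholder row itself is internal to GpuTimeZoneDB):

    import java.time.ZoneId;
    import java.time.zone.ZoneOffsetTransition;
    import java.util.List;

    List<ZoneOffsetTransition> trans =
        ZoneId.of("Asia/Shanghai").getRules().getTransitions();
    int n = trans.size();   // the host table built by GpuTimeZoneDB has n + 1 rows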