From 2901042c51d613be0430bef5ae113904f31d9622 Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Thu, 29 Jun 2023 13:51:25 -0400 Subject: [PATCH] Revert the kmeans1d change. --- lib/ccv_numeric.c | 58 ++++++----------------------------------------- 1 file changed, 7 insertions(+), 51 deletions(-) diff --git a/lib/ccv_numeric.c b/lib/ccv_numeric.c index ad1124987..73f4ceca8 100644 --- a/lib/ccv_numeric.c +++ b/lib/ccv_numeric.c @@ -1266,65 +1266,21 @@ void ccv_distance_transform(ccv_dense_matrix_t* a, ccv_dense_matrix_t** b, int t #undef for_block } -__attribute__((__always_inline__)) inline static double _kmeans1d_cost(double* cumsum, double* cumsum2, int i, int j) { if (j < i) return 0; - double j_minus_i_plus_1_d = (double)(j - i + 1); - double temp = (cumsum[j + 1] - cumsum[i]); - double mu = temp / j_minus_i_plus_1_d; + double mu = (cumsum[j + 1] - cumsum[i]) / (j - i + 1); double result = cumsum2[j + 1] - cumsum2[i]; - result += j_minus_i_plus_1_d * (mu * mu); - result -= (2 * mu) * temp; + result += (j - i + 1) * (mu * mu); + result -= (2 * mu) * (cumsum[j + 1] - cumsum[i]); return result; } -__attribute__((__always_inline__)) inline static double _kmeans1d_lookup(double* D, double* cumsum, double* cumsum2, int i, int j) { - const int i_minus_j_plus_1 = i - j + 1; - const int col = i_minus_j_plus_1 < 0 ? i : j - 1; - double result = (col >= 0 ? D[col] : 0); - if (i_minus_j_plus_1 < 1) - return result; - double i_minus_j_plus_1_d = (double)i_minus_j_plus_1; - double temp = (cumsum[i + 1] - cumsum[j]); - double mu = temp / i_minus_j_plus_1_d; - double result_alt = result + cumsum2[i + 1] - cumsum2[j]; - result_alt += i_minus_j_plus_1_d * (mu * mu); - result_alt -= (2 * mu) * temp; - return result_alt; -} - -__attribute__((__always_inline__)) -inline static int _kmeans1d_lookup_compare(double* D, double* cumsum, double* cumsum2, int i, int j1, int j2) -{ - // Uses either 2-wide SIMD or instruction-level parallelism. - const int i_minus_j1_plus_1 = i - j1 + 1; - const int i_minus_j2_plus_1 = i - j2 + 1; - const int col1 = i_minus_j1_plus_1 < 0 ? i : j1 - 1; - const int col2 = i_minus_j2_plus_1 < 0 ? i : j2 - 1; - double result1 = (col1 >= 0 ? D[col1] : 0); - double result2 = (col2 >= 0 ? D[col2] : 0); - - double i_minus_j1_plus_1_d = (double)i_minus_j1_plus_1; - double i_minus_j2_plus_1_d = (double)i_minus_j2_plus_1; - double cumsum_i_1 = cumsum[i + 1]; - double cumsum2_i_1 = cumsum2[i + 1]; - double temp1 = (cumsum_i_1 - cumsum[j1]); - double temp2 = (cumsum_i_1 - cumsum[j2]); - double mu1 = temp1 / i_minus_j1_plus_1_d; - double mu2 = temp2 / i_minus_j2_plus_1_d; - double result1_alt = result1 + cumsum2_i_1 - cumsum2[j1]; - double result2_alt = result2 + cumsum2_i_1 - cumsum2[j2]; - result1_alt += i_minus_j1_plus_1_d * (mu1 * mu1); - result2_alt += i_minus_j2_plus_1_d * (mu2 * mu2); - result1_alt -= (2 * mu1) * temp1; - result2_alt -= (2 * mu2) * temp2; - result1 = (i_minus_j1_plus_1 < 1) ? result1 : result1_alt; - result2 = (i_minus_j2_plus_1 < 1) ? result2 : result2_alt; - return result1 >= result2; + const int col = i < j - 1 ? i : j - 1; + return (col >= 0 ? D[col] : 0) + _kmeans1d_cost(cumsum, cumsum2, j, i); } static void _smawk2(int row_start, int row_stride, int row_size, int* cols, int col_size, int* reserved, double* D, double* cumsum, double* cumsum2, int* result) @@ -1342,7 +1298,7 @@ static void _smawk2(int row_start, int row_stride, int row_size, int* cols, int if (_col_size == 0) break; const int row = row_start + row_stride * (_col_size - 1); - if (_kmeans1d_lookup_compare(D, cumsum, cumsum2, row, col, _cols[_col_size - 1])) + if (_kmeans1d_lookup(D, cumsum, cumsum2, row, col) >= _kmeans1d_lookup(D, cumsum, cumsum2, row, _cols[_col_size - 1])) break; --_col_size; } @@ -1396,7 +1352,7 @@ static void _smawk1(int row_start, int row_stride, int row_size, int* cols, int if (_col_size == 0) break; const int row = row_start + row_stride * (_col_size - 1); - if (_kmeans1d_lookup_compare(D, cumsum, cumsum2, row, col, _cols[_col_size - 1])) + if (_kmeans1d_lookup(D, cumsum, cumsum2, row, col) >= _kmeans1d_lookup(D, cumsum, cumsum2, row, _cols[_col_size - 1])) break; --_col_size; }