Skip to content

Commit

Permalink
Overrides dtype to match input
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonb5 committed Nov 8, 2023
1 parent df2d23a commit 3d5d0d9
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions xcdat/regridder/regrid2.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,8 @@ def _regrid(

output_data = []

target_dtype = input_data.dtype

# need to optimize
for y in range(y_length):
y_seg = np.take(input_data, lat_mapping[y], axis=y_index)
Expand All @@ -138,12 +140,12 @@ def _regrid(
)

cell_value = np.nansum(
np.multiply(x_seg, cell_weight), axis=(y_index, x_index)
) / np.sum(cell_weight)
np.multiply(x_seg, cell_weight, dtype=target_dtype), axis=(y_index, x_index), dtype=target_dtype
) / np.sum(cell_weight, dtype=target_dtype)

output_data.append(cell_value)

output_data = np.asarray(output_data, dtype=np.float32)
output_data = np.asarray(output_data, dtype=target_dtype)
output_data = output_data.reshape(tuple(data_shape.values()))

return output_data
Expand Down

0 comments on commit 3d5d0d9

Please sign in to comment.