diff --git a/python/python/lance/torch/kmeans.py b/python/python/lance/torch/kmeans.py index ac22042dd8..dc0703b5f1 100644 --- a/python/python/lance/torch/kmeans.py +++ b/python/python/lance/torch/kmeans.py @@ -174,7 +174,7 @@ def _updated_centroids( self, centroids: torch.Tensor, counts: torch.Tensor ) -> torch.Tensor: centroids = centroids / counts[:, None] - zero_counts = centroids == 0 + zero_counts = counts == 0 for idx in zero_counts.nonzero(as_tuple=False): # split the largest cluster and remove empty cluster max_idx = torch.argmax(counts).item()