diff --git a/espaloma/data/dataset.py b/espaloma/data/dataset.py index 4e413c86..ba842611 100644 --- a/espaloma/data/dataset.py +++ b/espaloma/data/dataset.py @@ -186,10 +186,14 @@ def split(self, partition): """ n_data = len(self) - partition = [int(n_data * x / sum(partition)) for x in partition] + p_sizes = [] + for i, _partition in enumerate(partition): + p_size = int((n_data - sum(p_sizes)) * _partition / sum(partition[i:])) + p_sizes.append(p_size) + assert sum(p_sizes) == n_data, f"{p_sizes}, {sum(p_sizes)}" ds = [] idx = 0 - for p_size in partition: + for p_size in p_sizes: ds.append(self[idx : idx + p_size]) idx += p_size