Skip to content

Commit

Permalink
fix: fix a bug
Browse files Browse the repository at this point in the history
  • Loading branch information
CaibinSh committed May 26, 2024
1 parent 5066899 commit e5ebaf8
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions scar/main/_scar.py
Original file line number Diff line number Diff line change
Expand Up @@ -442,11 +442,11 @@ def train(
train_ids, test_ids = train_test_split(list_ids, train_size=train_size)

# Generators
training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, train_ids, device=self.device)
training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=train_ids)
training_generator = torch.utils.data.DataLoader(
training_set, batch_size=batch_size, shuffle=shuffle
)
val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, test_ids, device=self.device)
val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device, list_ids=test_ids)
val_generator = torch.utils.data.DataLoader(
val_set, batch_size=batch_size, shuffle=shuffle
)
Expand Down

0 comments on commit e5ebaf8

Please sign in to comment.