diff --git a/scar/main/_scar.py b/scar/main/_scar.py index ecf1ef7..7bc390f 100644 --- a/scar/main/_scar.py +++ b/scar/main/_scar.py @@ -325,7 +325,7 @@ def __init__( raw_count = raw_count.fillna(0) # missing vals -> zeros # Loading numpy to tensor on GPU - self.raw_count = torch.from_numpy(raw_count.values).int().to(self.device) + self.raw_count = raw_count.values """raw_count : np.ndarray, raw count matrix. """ self.n_features = raw_count.shape[1] @@ -358,7 +358,7 @@ def __init__( self.batch_id = torch.zeros(raw_count.shape[0]).int().to(self.device) self.n_batch = 1 - self.ambient_profile = torch.from_numpy(ambient_profile).float().to(self.device) + self.ambient_profile = ambient_profile """ambient_profile : np.ndarray, the probability of occurrence of each ambient transcript. """ @@ -442,11 +442,11 @@ def train( train_ids, test_ids = train_test_split(list_ids, train_size=train_size) # Generators - training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, train_ids) + training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, train_ids, device=self.device) training_generator = torch.utils.data.DataLoader( training_set, batch_size=batch_size, shuffle=shuffle ) - val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, test_ids) + val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, test_ids, device=self.device) val_generator = torch.utils.data.DataLoader( val_set, batch_size=batch_size, shuffle=shuffle ) @@ -588,7 +588,7 @@ def inference( native_frequencies, and noise_ratio. \ A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type. """ - total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id) + total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device) n_features = self.n_features sample_size = self.raw_count.shape[0] self.native_counts = np.empty([sample_size, n_features]) @@ -707,12 +707,12 @@ def assignment(self, cutoff=3, moi=None): class UMIDataset(torch.utils.data.Dataset): """Characterizes dataset for PyTorch""" - def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None): + def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None): """Initialization""" - self.raw_count = raw_count - self.ambient_profile = ambient_profile - self.batch_id = batch_id - self.batch_onehot = self._onehot(batch_id) + self.raw_count = torch.from_numpy(raw_count).int().to(device) + self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device) + self.batch_id = batch_id.to(torch.int64).to(device) + self.batch_onehot = self._onehot(batch_id.to(torch.int64)).to(device) if list_ids: self.list_ids = list_ids @@ -735,6 +735,6 @@ def __getitem__(self, index): def _onehot(self, batch_id): """One-hot encoding""" n_batch = batch_id.unique().size()[0] - x_onehot = torch.zeros(n_batch, n_batch).to(batch_id.device) + x_onehot = torch.zeros(n_batch, n_batch) x_onehot.scatter_(1, batch_id.unique().unsqueeze(1), 1) return x_onehot \ No newline at end of file