Skip to content

Commit

Permalink
fix: fix a bug and refactor dataloading
Browse files Browse the repository at this point in the history
  • Loading branch information
CaibinSh committed May 26, 2024
1 parent 353059f commit 5066899
Showing 1 changed file with 11 additions and 11 deletions.
22 changes: 11 additions & 11 deletions scar/main/_scar.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,7 @@ def __init__(
raw_count = raw_count.fillna(0) # missing vals -> zeros

# Loading numpy to tensor on GPU
self.raw_count = torch.from_numpy(raw_count.values).int().to(self.device)
self.raw_count = raw_count.values
"""raw_count : np.ndarray, raw count matrix.
"""
self.n_features = raw_count.shape[1]
Expand Down Expand Up @@ -358,7 +358,7 @@ def __init__(
self.batch_id = torch.zeros(raw_count.shape[0]).int().to(self.device)
self.n_batch = 1

self.ambient_profile = torch.from_numpy(ambient_profile).float().to(self.device)
self.ambient_profile = ambient_profile
"""ambient_profile : np.ndarray, the probability of occurrence of each ambient transcript.
"""

Expand Down Expand Up @@ -442,11 +442,11 @@ def train(
train_ids, test_ids = train_test_split(list_ids, train_size=train_size)

# Generators
training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, train_ids)
training_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, train_ids, device=self.device)
training_generator = torch.utils.data.DataLoader(
training_set, batch_size=batch_size, shuffle=shuffle
)
val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, test_ids)
val_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, test_ids, device=self.device)
val_generator = torch.utils.data.DataLoader(
val_set, batch_size=batch_size, shuffle=shuffle
)
Expand Down Expand Up @@ -588,7 +588,7 @@ def inference(
native_frequencies, and noise_ratio. \
A feature_assignment will be added in 'sgRNA' or 'tag' or 'CMO' feature type.
"""
total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id)
total_set = UMIDataset(self.raw_count, self.ambient_profile, self.batch_id, device=self.device)
n_features = self.n_features
sample_size = self.raw_count.shape[0]
self.native_counts = np.empty([sample_size, n_features])
Expand Down Expand Up @@ -707,12 +707,12 @@ def assignment(self, cutoff=3, moi=None):
class UMIDataset(torch.utils.data.Dataset):
"""Characterizes dataset for PyTorch"""

def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None):
def __init__(self, raw_count, ambient_profile, batch_id, device, list_ids=None):
"""Initialization"""
self.raw_count = raw_count
self.ambient_profile = ambient_profile
self.batch_id = batch_id
self.batch_onehot = self._onehot(batch_id)
self.raw_count = torch.from_numpy(raw_count).int().to(device)
self.ambient_profile = torch.from_numpy(ambient_profile).float().to(device)
self.batch_id = batch_id.to(torch.int64).to(device)
self.batch_onehot = self._onehot(batch_id.to(torch.int64)).to(device)

if list_ids:
self.list_ids = list_ids
Expand All @@ -735,6 +735,6 @@ def __getitem__(self, index):
def _onehot(self, batch_id):
"""One-hot encoding"""
n_batch = batch_id.unique().size()[0]
x_onehot = torch.zeros(n_batch, n_batch).to(batch_id.device)
x_onehot = torch.zeros(n_batch, n_batch)
x_onehot.scatter_(1, batch_id.unique().unsqueeze(1), 1)
return x_onehot

0 comments on commit 5066899

Please sign in to comment.