Commit

feat: add functionality for batch ambient removal
CaibinSh committed May 26, 2024
1 parent 3a0058d commit 5bc0d84
Showing 2 changed files with 38 additions and 17 deletions.
25 changes: 19 additions & 6 deletions scar/main/_scar.py
@@ -300,7 +300,8 @@ def __init__(
ambient_profile[batch_id, :] = subset.X.sum(axis=0) / subset.X.sum()

# add a mapper to locate the batch id
- self.batch_id = batch_id_per_cell
+ self.batch_id = torch.from_numpy(batch_id_per_cell).int().to(self.device)
+ self.n_batch = batch_id.unique().size()[0]

# get ambient profile from AnnData.uns
elif (ambient_profile is None) and ("ambient_profile_all" in raw_count.uns):
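Note on the hunk above: the per-cell batch mapper is now stored as an integer tensor on the model device, and the number of batches is recorded so the downstream networks can be sized for batch conditioning. Below is a minimal sketch, not taken from the commit, of how a per-batch ambient profile and a per-cell batch id could be derived from an AnnData object; the `batch_key` column name and the helper function itself are illustrative assumptions.

```python
# Illustrative sketch (not part of the commit): derive a per-batch ambient
# profile and an integer batch id per cell from an AnnData object.
import numpy as np
import torch

def batch_ambient_profile(adata, batch_key="batch", device="cpu"):
    batches = adata.obs[batch_key].astype("category")
    batch_id_per_cell = batches.cat.codes.to_numpy()   # integer batch id per cell
    n_batch = batches.cat.categories.size
    profile = np.empty((n_batch, adata.n_vars))
    for b in range(n_batch):
        subset = adata[batch_id_per_cell == b]
        # normalized ambient frequencies of this batch (works for dense or sparse X)
        profile[b, :] = np.asarray(subset.X.sum(axis=0)).ravel() / subset.X.sum()
    return (
        torch.from_numpy(profile).float().to(device),
        torch.from_numpy(batch_id_per_cell).int().to(device),
        n_batch,
    )
```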
@@ -354,7 +355,8 @@ def __init__(
.reshape(1, -1)
)
# add a mapper to locate the artificial batch id
- self.batch_id = np.zeros(raw_count.shape[0], dtype=int)
+ self.batch_id = torch.zeros(raw_count.shape[0]).int().to(self.device)
+ self.n_batch = 1

self.ambient_profile = torch.from_numpy(ambient_profile).float().to(self.device)
"""ambient_profile : np.ndarray, the probability of occurrence of each ambient transcript.
@@ -461,6 +463,7 @@ def train(
feature_type=self.feature_type,
count_model=self.count_model,
sparsity=self.sparsity,
+ n_batch=self.n_batch,
verbose=verbose,
).to(self.device)
# Define optimizer
@@ -491,9 +494,9 @@ def train(
train_recon_loss = 0

vae_nets.train()
- for x_batch, ambient_freq in training_generator:
+ for x_batch, ambient_freq, batch_id_onehot in training_generator:
optim.zero_grad()
- dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch)
+ dec_nr, dec_prob, means, var, dec_dp = vae_nets(x_batch, batch_id_onehot)
recon_loss_minibatch, kld_loss_minibatch, loss_minibatch = loss_fn(
x_batch,
dec_nr,
@@ -600,7 +603,7 @@ def inference(
total_set, batch_size=batch_size, shuffle=False
)

- for x_batch_tot, ambient_freq_tot in generator_full_data:
+ for x_batch_tot, ambient_freq_tot, x_batch_id_onehot_tot in generator_full_data:
minibatch_size = x_batch_tot.shape[
0
] # if not the last batch, equals to batch size
@@ -612,6 +615,7 @@
noise_ratio_batch,
) = self.trained_model.inference(
x_batch_tot,
+ x_batch_id_onehot_tot,
ambient_freq_tot[0, :],
count_model_inf=count_model_inf,
adjust=adjust,
@@ -708,6 +712,7 @@ def __init__(self, raw_count, ambient_profile, batch_id, list_ids=None):
self.raw_count = raw_count
self.ambient_profile = ambient_profile
self.batch_id = batch_id
+ self.batch_onehot = self._onehot(batch_id)

if list_ids:
self.list_ids = list_ids
@@ -724,4 +729,12 @@ def __getitem__(self, index):
sc_id = self.list_ids[index]
sc_count = self.raw_count[sc_id, :]
sc_ambient = self.ambient_profile[self.batch_id[sc_id], :]
- return sc_count, sc_ambient
+ sc_batch_id_onehot = self.batch_onehot[self.batch_id[sc_id], :]
+ return sc_count, sc_ambient, sc_batch_id_onehot
+
+ def _onehot(self, batch_id):
+ """One-hot encoding"""
+ n_batch = batch_id.unique().size()[0]
+ x_onehot = torch.zeros(n_batch, n_batch).to(batch_id.device)
+ x_onehot.scatter_(1, batch_id.unique().unsqueeze(1), 1)
+ return x_onehot
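The new `_onehot` helper builds an n_batch x n_batch lookup table whose row b is the one-hot vector for batch b, and `__getitem__` selects the row matching each cell's batch id, so the training and inference generators now yield (count, ambient profile, batch one-hot) triples. A standalone sketch of the same construction is below; it is illustrative only, assumes contiguous batch ids 0..n_batch-1, and adds `.long()` casts so the snippet runs on its own.

```python
# Illustrative sketch of the one-hot lookup used by the dataset class.
import torch

batch_id = torch.tensor([0, 0, 1, 2, 1]).int()        # batch id per cell
n_batch = batch_id.unique().size()[0]                  # 3 batches

# n_batch x n_batch table: row b holds the one-hot encoding of batch b
x_onehot = torch.zeros(n_batch, n_batch)
x_onehot.scatter_(1, batch_id.unique().unsqueeze(1).long(), 1)

# per-cell lookup, mirroring __getitem__
sc_batch_id_onehot = x_onehot[batch_id[3].long(), :]   # cell 3 is in batch 2 -> tensor([0., 0., 1.])
print(sc_batch_id_onehot)
```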
30 changes: 19 additions & 11 deletions scar/main/_vae.py
@@ -51,6 +51,7 @@ def __init__(
dropout_prob=0,
feature_type="mRNA",
count_model="binomial",
+ n_batch=1,
sparsity=0.9,
verbose=True,
):
@@ -81,10 +82,11 @@ def __init__(
sparsity = 1

self.encoder = Encoder(
- n_features, nn_layer1, nn_layer2, latent_dim, dropout_prob
+ n_features, n_batch, nn_layer1, nn_layer2, latent_dim, dropout_prob
)
self.decoder = Decoder(
n_features,
+ n_batch,
nn_layer1,
nn_layer2,
latent_dim,
@@ -105,16 +107,17 @@ def __init__(
vae_logger.info(f"...dropout_prob: {dropout_prob:.2f}")
vae_logger.info(f"...expected data sparsity: {sparsity:.2f}")

- def forward(self, input_matrix):
+ def forward(self, input_matrix, batch_id_onehot=None):
"""forward function"""
- sampling, means, var = self.encoder(input_matrix)
- dec_nr, dec_prob, dec_dp = self.decoder(sampling)
+ sampling, means, var = self.encoder(input_matrix, batch_id_onehot)
+ dec_nr, dec_prob, dec_dp = self.decoder(sampling, batch_id_onehot)
return dec_nr, dec_prob, means, var, dec_dp

@torch.no_grad()
def inference(
self,
input_matrix,
+ batch_id_onehot,
amb_prob,
count_model_inf="poisson",
adjust="micro",
@@ -128,7 +131,7 @@ def inference(
assert adjust in [False, "global", "micro"]

# Estimate native signals
- dec_nr, dec_prob, _, _, _ = self.forward(input_matrix)
+ dec_nr, dec_prob, _, _, _ = self.forward(input_matrix, batch_id_onehot)

# Copy tensor to CPU
input_matrix_np = input_matrix.cpu().numpy()
@@ -230,11 +233,13 @@ class Encoder(nn.Module):
Consists of 2 FC layers.
"""

- def __init__(self, n_features, nn_layer1, nn_layer2, latent_dim, dropout_prob):
+ def __init__(self, n_features, n_batch, nn_layer1, nn_layer2, latent_dim, dropout_prob):
"""initialization"""
super().__init__()
self.activation = nn.SELU()
- self.fc1 = nn.Linear(n_features, nn_layer1)
+ # if n_batch > 1:
+ # n_features += n_batch
+ self.fc1 = nn.Linear(n_features + n_batch, nn_layer1)
self.bn1 = nn.BatchNorm1d(nn_layer1, momentum=0.01, eps=0.001)
self.dp1 = nn.Dropout(p=dropout_prob)
self.fc2 = nn.Linear(nn_layer1, nn_layer2)
@@ -250,9 +255,10 @@ def reparametrize(self, means, log_vars):
var = log_vars.exp() + 1e-4
return torch.distributions.Normal(means, var.sqrt()).rsample(), var

- def forward(self, input_matrix):
+ def forward(self, input_matrix, batch_id_onehot):
"""forward function"""
input_matrix = (input_matrix + 1).log2() # log transformation of count data
+ input_matrix = torch.cat([input_matrix, batch_id_onehot], 1)
enc = self.fc1(input_matrix)
enc = self.bn1(enc)
enc = self.activation(enc)
@@ -284,6 +290,7 @@ class Decoder(nn.Module):
def __init__(
self,
n_features,
+ n_batch,
nn_layer1,
nn_layer2,
latent_dim,
@@ -297,7 +304,7 @@ def __init__(
self.normalization_native_freq = hnormalization()
self.noise_activation = mytanh()
self.activation_native_freq = mysoftplus(sparsity)
- self.fc4 = nn.Linear(latent_dim, nn_layer2)
+ self.fc4 = nn.Linear(latent_dim + n_batch, nn_layer2)
self.bn4 = nn.BatchNorm1d(nn_layer2, momentum=0.01, eps=0.001)
self.dp4 = nn.Dropout(p=dropout_prob)
self.fc5 = nn.Linear(nn_layer2, nn_layer1)
@@ -311,10 +318,11 @@ def __init__(
self.dropoutprob = nn.Linear(nn_layer1, 1)
self.dropout_activation = mytanh()

- def forward(self, sampling):
+ def forward(self, sampling, batch_id_onehot):
"""forward function"""
# decoder
- dec = self.fc4(sampling)
+ cond_sampling = torch.cat([sampling, batch_id_onehot], 1)
+ dec = self.fc4(cond_sampling)
dec = self.bn4(dec)
dec = self.activation(dec)
dec = self.fc5(dec)
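Taken together, the Encoder and Decoder changes follow the standard conditional-VAE pattern: the batch one-hot vector is concatenated to the log-transformed counts before the first encoder layer and to the latent sample before the first decoder layer, so both input widths grow by n_batch. A minimal shape sketch follows; the layer sizes are illustrative values, not necessarily the package defaults.

```python
# Illustrative shape check for the batch-conditioned encoder/decoder inputs.
import torch
import torch.nn as nn

n_features, n_batch, latent_dim, nn_layer1, nn_layer2 = 2000, 3, 15, 150, 100

enc_fc1 = nn.Linear(n_features + n_batch, nn_layer1)   # first encoder layer
dec_fc4 = nn.Linear(latent_dim + n_batch, nn_layer2)   # first decoder layer

counts = torch.randint(0, 50, (8, n_features)).float()               # a minibatch of raw counts
onehot = torch.eye(n_batch)[torch.tensor([0, 1, 2, 0, 1, 2, 0, 1])]  # per-cell batch one-hot

x = (counts + 1).log2()                     # same log transform as the Encoder
enc_in = torch.cat([x, onehot], dim=1)      # shape (8, n_features + n_batch)
h = enc_fc1(enc_in)

z = torch.randn(8, latent_dim)              # stand-in for the reparametrized latent sample
dec_in = torch.cat([z, onehot], dim=1)      # shape (8, latent_dim + n_batch)
out = dec_fc4(dec_in)

print(h.shape, out.shape)                   # torch.Size([8, 150]) torch.Size([8, 100])
```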
