Skip to content

Commit

Permalink
Update diversity strategy to handle greedy update
Browse files Browse the repository at this point in the history
  • Loading branch information
patel-zeel committed Oct 28, 2023
1 parent 921690f commit 7d0d3e8
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 33 deletions.
50 changes: 26 additions & 24 deletions astra/torch/al/strategies/diversity.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,19 @@
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from astra.torch.al import Strategy

from typing import Sequence, Dict
from typing import Sequence, Dict, Union, List


class DiversityStrategy(Strategy):
def query(
self,
net: nn.Module,
pool_indices: Sequence[int],
context_indices: Sequence[int] = None,
pool_indices: Union[List[int], np.ndarray, torch.Tensor],
context_indices: Union[List[int], np.ndarray, torch.Tensor] = None,
n_query_samples: int = 1,
n_mc_samples: int = None,
batch_size: int = None,
Expand All @@ -33,33 +34,34 @@ def query(
if batch_size is None:
batch_size = len(pool_indices)

data_loader = DataLoader(self.dataset[pool_indices])
context_data_loader = DataLoader(self.dataset[context_indices])
data_loader = DataLoader(self.dataset)

with torch.no_grad():
# Get the features for the pool
pool_features_list = []
# Get all features
features_list = []
for x, _ in data_loader:
pool_features = net(x)
pool_features_list.append(pool_features)
pool_features = torch.cat(pool_features_list, dim=0) # (pool_dim, feature_dim)

# Get the features for the context
context_features_list = []
for x, _ in context_data_loader:
context_features = net(x.to(self.device))
context_features_list.append(context_features)
context_features = torch.cat(context_features_list, dim=0) # (context_dim, feature_dim)
features = net(x)
features_list.append(features)
features = torch.cat(features_list, dim=0) # (data_dim, feature_dim)

best_indices = {}
# TODO: Fix this for loop to do the following:
# - Get the max score. Get corresponding index.
# - Add that index to selected_indices and also to context indices.
# - Remove that index from pool indices.
# - Repeat until len(selected_indices) == n_query_samples.

if not isinstance(pool_indices, list):
# tolist() works for both numpy arrays and torch tensors. It also works for tensors on GPU.
pool_indices = pool_indices.tolist()
if not isinstance(context_indices, list):
context_indices = context_indices.tolist()

for acq_name, acquisition in self.acquisitions.items():
scores = acquisition.acquire_scores(pool_features, context_features)
selected_indices = torch.topk(scores, n_query_samples).indices
selected_indices = []
# TODO: We can make this loop faster by computing scores only for updated indices. There can be a method in acquisition to update the scores.
for _ in range(n_query_samples):
scores = acquisition.acquire_scores(features, pool_indices, context_indices)
best_index = torch.argmax(scores)
selected_indices.append(best_index)
pool_indices.pop(best_index)
context_indices.append(best_index)
selected_indices = torch.tensor(selected_indices, device=self.device)
best_indices[acq_name] = selected_indices

return best_indices
6 changes: 3 additions & 3 deletions astra/torch/al/strategies/ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@

from astra.torch.al import Strategy

from typing import Sequence, Dict
from typing import Sequence, Dict, List, Union


class EnsembleStrategy(Strategy):
def query(
self,
net: Sequence[nn.Module],
pool_indices: Sequence[int],
net: Union[List[int], np.ndarray, torch.Tensor],
pool_indices: Union[List[int], np.ndarray, torch.Tensor],
context_indices: Sequence[int] = None,
n_query_samples: int = 1,
n_mc_samples: int = None,
Expand Down
6 changes: 3 additions & 3 deletions astra/torch/al/strategies/mc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@

from astra.torch.al import Strategy

from typing import Sequence, Dict
from typing import Sequence, Dict, Union, List


class MCStrategy(Strategy):
def query(
self,
net: nn.Module,
pool_indices: Sequence[int],
context_indices: Sequence[int] = None,
pool_indices: Union[List[int], np.ndarray, torch.Tensor],
context_indices: Union[List[int], np.ndarray, torch.Tensor] = None,
n_query_samples: int = 1,
n_mc_samples: int = 10,
batch_size: int = None,
Expand Down
7 changes: 4 additions & 3 deletions astra/torch/al/strategies/random.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
import numpy as np
import torch
import torch.nn as nn
from astra.torch.al import Strategy

from typing import Sequence, Dict
from typing import Dict, Union, List


class RandomStrategy(Strategy):
def query(
self,
net: nn.Module,
pool_indices: Sequence[int],
context_indices: Sequence[int] = None,
pool_indices: Union[List[int], np.ndarray, torch.Tensor],
context_indices: Union[List[int], np.ndarray, torch.Tensor] = None,
n_query_samples: int = 1,
n_mc_samples: int = 10,
batch_size: int = None,
Expand Down

0 comments on commit 7d0d3e8

Please sign in to comment.