
balance dataset #14

Open
bw4sz opened this issue Nov 9, 2021 · 4 comments
Comments

@bw4sz
Collaborator

bw4sz commented Nov 9, 2021

try balancing without any floor or ceiling resampling value

@ethanwhite
Member

Balancing approaches are implemented in #22. Both @bw4sz and I have experimented with versions of this, both with and without floors/ceilings, and have not seen any improvement (see also #21). This feels counterintuitive, but may be because Focal Loss has already addressed the class imbalance to the degree possible.
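For context, the core of Focal Loss can be sketched in a few lines (a minimal stand-in for illustration, not the loss code used in this repo; the batch and class sizes are made up):

```python
import torch
import torch.nn.functional as F

def focal_loss(logits, targets, gamma=2.0):
    """Focal Loss (Lin et al. 2017): cross-entropy scaled by (1 - p_t)^gamma,
    which down-weights easy, well-classified examples so rare classes
    contribute relatively more to the gradient. gamma=0 recovers plain
    cross-entropy."""
    ce = F.cross_entropy(logits, targets, reduction="none")
    pt = torch.exp(-ce)  # model's probability for the true class
    return ((1 - pt) ** gamma * ce).mean()

logits = torch.randn(8, 24)             # made-up batch: 8 samples, 24 classes
targets = torch.randint(0, 24, (8,))
loss = focal_loss(logits, targets)
```

Because hard (often rare-class) examples already dominate the gradient under this loss, an explicit resampler may have little left to correct, which would be consistent with the null results here.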

@bw4sz
Collaborator Author

bw4sz commented Dec 20, 2021

I want to keep coming back to this. It just feels too vital to let go. I've had no success outside of completely balanced data, but I think the sampling process is leading to a ton of inter-run variability, and overall it feels like we are denying the model at least some reasonable prevalence information. Especially when we use site-level metadata, the overfitting argument no longer applies: we are already providing site-specific info.

@bw4sz
Collaborator Author

bw4sz commented Dec 21, 2021

One of the challenges of this research program is that each decision seems to cascade and affect the others. I certainly tested the sampler when we added it to species classification. Now, on a first try, I cannot find any difference between balanced and unbalanced ('raw') sampling.

https://www.comet.ml/bw4sz/deeptreeattention/176395505730431ca567e5ddef84267d/e3b659a81e02473596572d5a815f17c5/compare?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=step

This either means that the sampler is not working as intended, or that some other change in the meantime has rendered it irrelevant. I will continue to follow this.

@bw4sz
Collaborator Author

bw4sz commented Dec 21, 2021

Balancing with a ceiling now wins, but only by the smallest margin. I find the code confusing, though; it needs more thought.

    def train_dataloader(self):
        """Load a training file. The default location is saved during self.setup(); to override it, set self.train_file before training."""

        # Per-class counts from the training CSV
        train = pd.read_csv(self.train_file)
        class_counts = train.label.value_counts().to_dict()

        # Weight each sample by the inverse of its class frequency,
        # with the frequency capped at a ceiling of 100
        data_weights = []
        for idx in range(len(self.train_ds)):
            path, image, targets = self.train_ds[idx]
            label = int(targets.numpy())
            class_freq = min(class_counts[label], 100)
            data_weights.append(1 / class_freq)

        sampler = torch.utils.data.sampler.WeightedRandomSampler(weights=data_weights, num_samples=len(self.train_ds))
        data_loader = torch.utils.data.DataLoader(
            self.train_ds,
            batch_size=self.config["batch_size"],
            num_workers=self.config["workers"],
            sampler=sampler
        )

        return data_loader
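As an aside, the per-sample loop above decodes every image just to read its label; the same capped inverse-frequency weights could be derived directly from the CSV. A hypothetical vectorized sketch, with a toy DataFrame standing in for self.train_file (the "label" column and the ceiling of 100 match the snippet above):

```python
import pandas as pd

# Toy stand-in for the training CSV: three classes with counts 250 / 40 / 5
train = pd.DataFrame({"label": [0] * 250 + [1] * 40 + [2] * 5})

# Per-class counts, capped at 100, then mapped back onto each row
class_freq = train.label.value_counts().clip(upper=100)
data_weights = (1.0 / class_freq.loc[train.label]).tolist()
```

This only works if self.train_ds iterates in the same row order as the CSV; otherwise the per-item loop is the safe option.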
(DeepTreeAttention) [b.weinstein@login3 DeepTreeAttention]$ python
Python 3.8.11 (default, Aug  3 2021, 15:09:35)
[GCC 7.5.0] :: Anaconda, Inc. on linux
Type "help", "copyright", "credits" or "license" for more information.
>>> from src.data import *
>>> import pandas as pd
>>> config = read_config("config.yml")
>>> data_module = TreeData(csv_file="data/raw/neon_vst_data_2021.csv", regenerate=False, client=None, metadata=True, comet_logger=None)
>>> data_module.setup()
>>> dl = data_module.train_dataloader()
>>> labels = []
>>> for batch in dl:
...     paths, inputs, batch_labels = batch
...     labels.append(batch_labels.numpy())
...
...
>>>
>>> labels = np.concatenate(labels)
>>> g = pd.Series(labels).value_counts().reset_index(name="taxonID")
>>>
>>> g
    index  taxonID
0      13     1979
1      19      687
2       1      676
3       8      290
4       7      120
5      14      105
6      15      102
7      23       95
8      20       93
9      21       90
10      2       89
11      9       87
12     17       83
13     11       82
14     10       81
15      0       81
16      4       79
17     22       79
18     16       78
19     18       76
20      3       74
21      5       72
22      6       65
23     12       64

Basically, by undersampling the top class we end up oversampling the bottom ones, which is strange because replacement = True is the default.
https://www.comet.ml/bw4sz/deeptreeattention/eaf46472a6ab4e15b16e5d9286556901/4e121fa6f04e426bad08183ad17f5b94/compare?experiment-tab=chart&showOutliers=true&smoothing=0&transformY=smoothing&xAxis=epoch
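The behavior is easy to reproduce in isolation. A minimal, hypothetical check of WeightedRandomSampler with capped inverse-frequency weights (made-up class counts of 1000 / 100 / 10, not the real dataset):

```python
import torch
from collections import Counter

torch.manual_seed(0)

# Made-up labels: three classes with counts 1000 / 100 / 10
labels = [0] * 1000 + [1] * 100 + [2] * 10
counts = Counter(labels)

# Capped inverse-frequency weights, mirroring the ceiling of 100 above
weights = [1 / min(counts[l], 100) for l in labels]

# replacement=True is the default for WeightedRandomSampler
sampler = torch.utils.data.WeightedRandomSampler(weights, num_samples=len(labels))
drawn = Counter(labels[i] for i in sampler)
```

With these weights the total mass per class is 10 : 1 : 1, so the top class falls from roughly 90% of the raw data to roughly 83% of draws, while the rarest class jumps from under 1% to around 8%: undersampling the head does oversample the tail.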
