Skip to content

Commit

Permalink
modify tutorial notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
poyentung committed Sep 20, 2022
1 parent 0157f2a commit 0674cd4
Show file tree
Hide file tree
Showing 6 changed files with 448 additions and 166 deletions.
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from setuptools import setup, find_packages
with open("README.md", "r") as fh:
long_description = fh.read()
__version__ = "0.1.29"
__version__ = "0.1.30"

setup(
name='emsigma',
Expand Down
40 changes: 27 additions & 13 deletions sigma/gui/gui.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import random
import numpy as np
import pandas as pd
from typing import List
from typing import List, Dict
import hyperspy.api as hs
from matplotlib import pyplot as plt
import matplotlib as mpl
Expand Down Expand Up @@ -367,28 +367,42 @@ def dropdown_b_eventhandler(change):
display(plots_output)


def view_pixel_distributions(sem, norm_list=[], peak="Fe_Ka", cmap="viridis"):
out = widgets.Output()
with out:
def view_pixel_distributions(sem, norm_list=[], cmap="viridis"):
peak_options = sem.feature_list
dropdown_peaks = widgets.Dropdown(options=peak_options, description="Element:")

plots_output = widgets.Output()

with plots_output:
fig = visual.plot_pixel_distributions(
sem=sem, norm_list=norm_list, peak=peak, cmap=cmap
)
sem=sem, norm_list=norm_list, peak=dropdown_peaks.value, cmap=cmap
)
plt.show()

out_box = widgets.Box([out])
save_fig(fig)

def dropdown_option_eventhandler(change):
plots_output.clear_output()
with plots_output:
fig = visual.plot_pixel_distributions(
sem=sem, norm_list=norm_list, peak=dropdown_peaks.value, cmap=cmap
)
plt.show()
save_fig(fig)

dropdown_peaks.observe(dropdown_option_eventhandler, names="value")
out_box = widgets.VBox([dropdown_peaks, plots_output])
display(out_box)
save_fig(fig)


def view_intensity_maps(edx, element_list):
pick_color(visual.plot_intensity_maps, edx=edx, element_list=element_list)


def view_bic(
latent,
n_components=20,
model="BayesianGaussianMixture",
model_args={"random_state": 6},
latent: np.ndarray,
n_components: int = 20,
model: str = "BayesianGaussianMixture",
model_args: Dict = {"random_state": 6},
):
bic_list = PixelSegmenter.bic(latent, n_components, model, model_args)
fig = go.Figure(
Expand Down
29 changes: 15 additions & 14 deletions sigma/src/dim_reduction.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
import torch
import datetime
from typing import Dict
import numpy as np
import torch.nn as nn
from torch.optim import Adam
Expand Down Expand Up @@ -73,22 +74,22 @@ def set_up_results_dirs(self):
def run_model(
self,
num_epochs: int,
patience: int,
batch_size: int,
learning_rate=1e-4,
weight_decay=0.0,
task="train_eval",
patience: int = 50,
learning_rate: float = 1e-4,
weight_decay: float = 0.0,
task: str = "train_eval",
noise_added=None,
criterion="MSE",
KLD_lambda=1e-4,
print_latent=False,
lr_scheduler_args={
"factor": 0.5,
"verbose": True,
"patience": 5,
"threshold": 1e-2,
"min_lr": 1e-7,
},
criterion: str = "MSE",
KLD_lambda:float = 1e-4,
print_latent: bool = False,
lr_scheduler_args: Dict={
"factor": 0.5,
"verbose": True,
"patience": 5,
"threshold": 1e-2,
"min_lr": 1e-7,
},
):
# Loss functions
if criterion == "MSE":
Expand Down
9 changes: 5 additions & 4 deletions sigma/src/segmentation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from sigma.utils.load import SEMDataset
from sigma.utils.visualisation import make_colormap

from typing import Dict
import hyperspy.api as hs
import numpy as np
import pandas as pd
Expand All @@ -23,11 +24,11 @@
class PixelSegmenter(object):
def __init__(
self,
latent: np,
dataset_norm: np,
latent: np.ndarray,
dataset_norm: np.ndarray,
sem: SEMDataset,
method="BayesianGaussianMixture",
method_args={"n_components": 8, "random_state": 4},
method: str = "BayesianGaussianMixture",
method_args: Dict = {"n_components": 8, "random_state": 4},
):

self.latent = latent
Expand Down
Loading

0 comments on commit 0674cd4

Please sign in to comment.