-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #37 from gnn-tracking/dev
dev
- Loading branch information
Showing
28 changed files
with
547 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
111 changes: 111 additions & 0 deletions
111
scripts/full_detector/configs/fixed-all-in-one/strict-mutant-agouti-legacy-metric.yaml
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
# pytorch_lightning==2.1.1 | ||
data: | ||
identifier: point_clouds_10 | ||
train: | ||
dirs: | ||
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_1 | ||
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_2 | ||
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_3 | ||
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_4 | ||
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_5 | ||
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_6 | ||
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_7 | ||
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_8 | ||
sample_size: 450 | ||
val: | ||
dirs: | ||
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_9 | ||
stop: 5 | ||
test: null | ||
cpus: 6 | ||
model: | ||
class_path: gnn_tracking.training.tc.TCModule | ||
init_args: | ||
loss_fct: | ||
class_path: gnn_tracking.metrics.losses.metric_learning.OldGraphConstructionHingeEmbeddingLoss | ||
init_args: | ||
lw_repulsive: 0.006 | ||
attr_pt_thld: 0.9 | ||
max_num_neighbors: 256 | ||
p_attr: 2 | ||
p_rep: 2 | ||
r_emb: 1 | ||
cluster_scanner: | ||
class_path: gnn_tracking.postprocessing.dbscanscanner.DBSCANHyperParamScanner | ||
init_args: | ||
eps_range: | ||
- 0 | ||
- 1 | ||
min_samples_range: | ||
- 1 | ||
- 4 | ||
n_trials: 60 | ||
keep_best: 30 | ||
n_jobs: 6 | ||
guide: double_majority_pt0.9 | ||
pt_thlds: | ||
- 0.0 | ||
- 0.5 | ||
- 0.9 | ||
- 1.5 | ||
max_eta: 4.0 | ||
model: | ||
class_path: gnn_tracking.models.track_condensation_networks.GraphTCNForMLGCPipeline | ||
init_args: | ||
node_indim: 38 | ||
edge_indim: 76 | ||
h_dim: 192 | ||
e_dim: 192 | ||
h_outdim: 24 | ||
hidden_dim: 192 | ||
L_hc: 4 | ||
alpha_hc: 0.5 | ||
ec: null | ||
feed_edge_weights: false | ||
ec_threshold: 0.5 | ||
mask_orphan_nodes: false | ||
use_ec_embeddings_for_hc: false | ||
alpha_latent: 0.5 | ||
n_embedding_coords: 24 | ||
optimizer: | ||
class_path: torch.optim.Adam | ||
init_args: | ||
lr: 7.0e-05 | ||
betas: | ||
- 0.9 | ||
- 0.999 | ||
eps: 1.0e-08 | ||
weight_decay: 0.0 | ||
amsgrad: false | ||
foreach: null | ||
maximize: false | ||
capturable: false | ||
differentiable: false | ||
fused: null | ||
scheduler: | ||
class_path: torch.optim.lr_scheduler.ExponentialLR | ||
init_args: | ||
gamma: 0.933 | ||
last_epoch: -1 | ||
verbose: false | ||
preproc: | ||
class_path: gnn_tracking.models.graph_construction.MLGraphConstructionFromChkpt | ||
init_args: | ||
ml_chkpt_path: /home/kl5675/Documents/23/git_sync/hyperparameter_optimization2/scripts/full_detector/lightning_logs/amber-gibbon-of-joy/checkpoints/epoch=78-step=71100.compat_newcompatible.ckpt | ||
|
||
ec_chkpt_path: "" | ||
ml_class_name: gnn_tracking.training.ml.MLModule | ||
ec_class_name: gnn_tracking.training.ec.ECModule | ||
ml_model_only: true | ||
ec_model_only: true | ||
max_radius: 1.0 | ||
max_num_neighbors: 25 | ||
use_embedding_features: true | ||
ratio_of_false: null | ||
build_edge_features: true | ||
ec_threshold: null | ||
ml_freeze: true | ||
ec_freeze: true | ||
embedding_slice: | ||
- null | ||
- null |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
class_path: gnn_tracking.training.ml.MLModule | ||
init_args: | ||
model: | ||
class_path: gnn_tracking.models.graph_construction.GraphConstructionFCNN | ||
init_args: | ||
alpha: 0.5 | ||
depth: 5 | ||
hidden_dim: 512 | ||
in_dim: 14 | ||
out_dim: 24 | ||
loss_fct: | ||
class_path: gnn_tracking.metrics.losses.metric_learning.OldGraphConstructionHingeEmbeddingLoss | ||
init_args: | ||
lw_repulsive: 0.006 | ||
attr_pt_thld: 0.9 | ||
max_num_neighbors: 256 | ||
p_attr: 2 | ||
p_rep: 2 | ||
r_emb: 1 | ||
gc_scanner: | ||
class_path: gnn_tracking.graph_construction.k_scanner.GraphConstructionKNNScanner | ||
init_args: | ||
ks: [7, 8, 9, 10, 11, 12, 13, 14, 15] | ||
optimizer: | ||
class_path: torch.optim.Adam | ||
init_args: | ||
lr: 0.0007 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,34 @@ | ||
import torch | ||
from gnn_tracking.training.callbacks import PrintValidationMetrics | ||
from gnn_tracking.utils.loading import TrackingDataModule | ||
from gnn_tracking.utils.nomenclature import random_trial_name | ||
from pytorch_lightning.callbacks import RichProgressBar | ||
from pytorch_lightning.loggers import TensorBoardLogger | ||
|
||
from hpo2.lightning_utils import ContinueTrainingCLI | ||
|
||
name = random_trial_name() | ||
|
||
|
||
tb_logger = TensorBoardLogger(".", version=name) | ||
|
||
|
||
def cli_main(): | ||
torch.set_float32_matmul_precision("medium") | ||
|
||
# noinspection PyUnusedLocal | ||
cli = ContinueTrainingCLI( # noqa: F841 | ||
datamodule_class=TrackingDataModule, | ||
trainer_defaults={ | ||
"callbacks": [ | ||
RichProgressBar(leave=True), | ||
PrintValidationMetrics(), | ||
], | ||
"logger": [tb_logger], | ||
}, | ||
compile_model=True, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
cli_main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
import torch | ||
import wandb | ||
from gnn_tracking.training.callbacks import ExpandWandbConfig, PrintValidationMetrics | ||
from gnn_tracking.utils.loading import TrackingDataModule | ||
from gnn_tracking.utils.nomenclature import random_trial_name | ||
from lightning_fabric.plugins.environments.slurm import SLURMEnvironment | ||
from pytorch_lightning.callbacks import ( | ||
EarlyStopping, | ||
LearningRateMonitor, | ||
ModelCheckpoint, | ||
RichProgressBar, | ||
) | ||
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger | ||
from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback | ||
|
||
from hpo2.lightning_utils import ContinueTrainingCLI | ||
|
||
name = random_trial_name() | ||
|
||
|
||
logger = WandbLogger( | ||
project="gnn_tracking_fd", | ||
group="gc-loss-legacy-norm", | ||
offline=True, | ||
version=name, | ||
tags=["no-ec", "gc-loss", "continued"], | ||
) | ||
|
||
# Make sure that wandb init is called | ||
_ = logger.experiment | ||
wandb.define_metric( | ||
"max_trk.double_majority_pt0.9", | ||
step_metric="trk.double_majority_pt0.9", | ||
summary="max", | ||
) | ||
|
||
tb_logger = TensorBoardLogger(".", version=name) | ||
|
||
|
||
def cli_main(): | ||
torch.set_float32_matmul_precision("medium") | ||
|
||
# noinspection PyUnusedLocal | ||
cli = ContinueTrainingCLI( # noqa: F841 | ||
datamodule_class=TrackingDataModule, | ||
trainer_defaults={ | ||
"callbacks": [ | ||
RichProgressBar(leave=True), | ||
TriggerWandbSyncLightningCallback(), | ||
PrintValidationMetrics(), | ||
ExpandWandbConfig(), | ||
EarlyStopping( | ||
monitor="trk.perfect_pt0.9", | ||
mode="max", | ||
patience=20, | ||
verbose=True, | ||
), | ||
ModelCheckpoint( | ||
save_top_k=2, | ||
monitor="trk.perfect_pt0.9", | ||
mode="max", | ||
verbose=True, | ||
), | ||
LearningRateMonitor(logging_interval="step", log_momentum=True), | ||
], | ||
"logger": [tb_logger, logger], | ||
"plugins": [SLURMEnvironment(auto_requeue=False)], | ||
"max_time": "1:23:30:00", | ||
}, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
cli_main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import torch | ||
from gnn_tracking.training.callbacks import PrintValidationMetrics | ||
from gnn_tracking.utils.loading import TrackingDataModule | ||
from gnn_tracking.utils.nomenclature import random_trial_name | ||
from pytorch_lightning.callbacks import RichProgressBar | ||
from pytorch_lightning.loggers import TensorBoardLogger | ||
|
||
from hpo2.lightning_utils import ContinueTrainingCLI | ||
|
||
name = random_trial_name() | ||
|
||
|
||
tb_logger = TensorBoardLogger(".", version=name) | ||
|
||
|
||
def cli_main(): | ||
torch.set_float32_matmul_precision("medium") | ||
|
||
# noinspection PyUnusedLocal | ||
cli = ContinueTrainingCLI( # noqa: F841 | ||
datamodule_class=TrackingDataModule, | ||
trainer_defaults={ | ||
"callbacks": [ | ||
RichProgressBar(leave=True), | ||
PrintValidationMetrics(), | ||
], | ||
"logger": [tb_logger], | ||
}, | ||
) | ||
|
||
|
||
if __name__ == "__main__": | ||
cli_main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.