Skip to content

Commit

Permalink
Merge pull request #37 from gnn-tracking/dev
Browse files Browse the repository at this point in the history
dev
  • Loading branch information
klieret authored Mar 12, 2024
2 parents bc73845 + ad5c47e commit dcb04cd
Show file tree
Hide file tree
Showing 28 changed files with 547 additions and 29 deletions.
4 changes: 2 additions & 2 deletions scripts/full_detector/configs/exponential_lr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,5 @@ init_args:
scheduler:
class_path: torch.optim.lr_scheduler.ExponentialLR
init_args:
# 1e-3 at epoch 100
gamma: 0.933
# LR decay factor reaches 1e-4 at epoch 100 (0.912**100 ~= 1e-4)
gamma: 0.912
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ model:
preproc:
class_path: gnn_tracking.models.graph_construction.MLGraphConstructionFromChkpt
init_args:
ml_chkpt_path: /home/kl5675/Documents/23/git_sync/hyperparameter_optimization2/scripts/full_detector/lightning_logs/garrulous-peach-manatee/checkpoints/epoch=111-step=50400.compat.ckpt
ml_chkpt_path: /home/kl5675/Documents/23/git_sync/hyperparameter_optimization2/scripts/full_detector/lightning_logs/amber-gibbon-of-joy/checkpoints/epoch=78-step=71100.compat_newcompatible.ckpt
ec_chkpt_path: ""
ml_class_name: gnn_tracking.training.ml.MLModule
ec_class_name: gnn_tracking.training.ec.ECModule
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,111 @@
# pytorch_lightning==2.1.1
data:
identifier: point_clouds_10
train:
dirs:
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_1
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_2
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_3
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_4
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_5
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_6
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_7
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_8
sample_size: 450
val:
dirs:
- /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_9
stop: 5
test: null
cpus: 6
model:
class_path: gnn_tracking.training.tc.TCModule
init_args:
loss_fct:
class_path: gnn_tracking.metrics.losses.metric_learning.OldGraphConstructionHingeEmbeddingLoss
init_args:
lw_repulsive: 0.006
attr_pt_thld: 0.9
max_num_neighbors: 256
p_attr: 2
p_rep: 2
r_emb: 1
cluster_scanner:
class_path: gnn_tracking.postprocessing.dbscanscanner.DBSCANHyperParamScanner
init_args:
eps_range:
- 0
- 1
min_samples_range:
- 1
- 4
n_trials: 60
keep_best: 30
n_jobs: 6
guide: double_majority_pt0.9
pt_thlds:
- 0.0
- 0.5
- 0.9
- 1.5
max_eta: 4.0
model:
class_path: gnn_tracking.models.track_condensation_networks.GraphTCNForMLGCPipeline
init_args:
node_indim: 38
edge_indim: 76
h_dim: 192
e_dim: 192
h_outdim: 24
hidden_dim: 192
L_hc: 4
alpha_hc: 0.5
ec: null
feed_edge_weights: false
ec_threshold: 0.5
mask_orphan_nodes: false
use_ec_embeddings_for_hc: false
alpha_latent: 0.5
n_embedding_coords: 24
optimizer:
class_path: torch.optim.Adam
init_args:
lr: 7.0e-05
betas:
- 0.9
- 0.999
eps: 1.0e-08
weight_decay: 0.0
amsgrad: false
foreach: null
maximize: false
capturable: false
differentiable: false
fused: null
scheduler:
class_path: torch.optim.lr_scheduler.ExponentialLR
init_args:
gamma: 0.933
last_epoch: -1
verbose: false
preproc:
class_path: gnn_tracking.models.graph_construction.MLGraphConstructionFromChkpt
init_args:
ml_chkpt_path: /home/kl5675/Documents/23/git_sync/hyperparameter_optimization2/scripts/full_detector/lightning_logs/amber-gibbon-of-joy/checkpoints/epoch=78-step=71100.compat_newcompatible.ckpt

ec_chkpt_path: ""
ml_class_name: gnn_tracking.training.ml.MLModule
ec_class_name: gnn_tracking.training.ec.ECModule
ml_model_only: true
ec_model_only: true
max_radius: 1.0
max_num_neighbors: 25
use_embedding_features: true
ratio_of_false: null
build_edge_features: true
ec_threshold: null
ml_freeze: true
ec_freeze: true
embedding_slice:
- null
- null
27 changes: 27 additions & 0 deletions scripts/full_detector/configs/gc/model-legacy-loss.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
class_path: gnn_tracking.training.ml.MLModule
init_args:
model:
class_path: gnn_tracking.models.graph_construction.GraphConstructionFCNN
init_args:
alpha: 0.5
depth: 5
hidden_dim: 512
in_dim: 14
out_dim: 24
loss_fct:
class_path: gnn_tracking.metrics.losses.metric_learning.OldGraphConstructionHingeEmbeddingLoss
init_args:
lw_repulsive: 0.006
attr_pt_thld: 0.9
max_num_neighbors: 256
p_attr: 2
p_rep: 2
r_emb: 1
gc_scanner:
class_path: gnn_tracking.graph_construction.k_scanner.GraphConstructionKNNScanner
init_args:
ks: [7, 8, 9, 10, 11, 12, 13, 14, 15]
optimizer:
class_path: torch.optim.Adam
init_args:
lr: 0.0007
1 change: 1 addition & 0 deletions scripts/full_detector/configs/gc/model.yml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ init_args:
p_attr: 2
p_rep: 2
r_emb: 1
rep_normalization: n_rep_edges
gc_scanner:
class_path: gnn_tracking.graph_construction.k_scanner.GraphConstructionKNNScanner
init_args:
Expand Down
9 changes: 5 additions & 4 deletions scripts/full_detector/configs/oc/model_gc_loss.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ init_args:
init_args:
node_indim: 38
edge_indim: 76
h_dim: 192
e_dim: 192
hidden_dim: 192
h_dim: 256
e_dim: 256
hidden_dim: 256
h_outdim: 24
L_hc: 3
L_hc: 4
alpha_latent: 0.5
n_embedding_coords: 24
heterogeneous_node_encoder: true
Expand All @@ -30,6 +30,7 @@ init_args:
p_attr: 2
p_rep: 2
r_emb: 1
rep_normalization: n_rep_edges
cluster_scanner:
class_path: gnn_tracking.postprocessing.dbscanscanner.DBSCANHyperParamScanner
init_args:
Expand Down
14 changes: 10 additions & 4 deletions scripts/full_detector/continue_gc.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,12 @@
from gnn_tracking.utils.loading import TrackingDataModule
from gnn_tracking.utils.nomenclature import random_trial_name
from lightning_fabric.plugins.environments.slurm import SLURMEnvironment
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, RichProgressBar
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
RichProgressBar,
)
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback

Expand All @@ -13,11 +18,11 @@


wandb_logger = WandbLogger(
project="gnn_tracking_gc",
group="full-detector",
project="gnn_tracking_gc_fd",
group="",
offline=True,
version=name,
tags=["full-detector", "continued"],
tags=["continued"],
)


Expand All @@ -38,6 +43,7 @@ def cli_main():
ExpandWandbConfig(),
EarlyStopping(monitor="total", mode="min", patience=10),
ModelCheckpoint(save_top_k=2, monitor="total", mode="min"),
LearningRateMonitor(logging_interval="step", log_momentum=True),
],
"logger": [tb_logger, wandb_logger],
"plugins": [SLURMEnvironment(auto_requeue=False)],
Expand Down
34 changes: 34 additions & 0 deletions scripts/full_detector/continue_gc_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import torch
from gnn_tracking.training.callbacks import PrintValidationMetrics
from gnn_tracking.utils.loading import TrackingDataModule
from gnn_tracking.utils.nomenclature import random_trial_name
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger

from hpo2.lightning_utils import ContinueTrainingCLI

# Trial name obtained from random_trial_name(); reused below as the logger version.
name = random_trial_name()


# TensorBoard logger created with "." as its first argument and the trial name
# as its version; cli_main passes it to the trainer defaults.
tb_logger = TensorBoardLogger(".", version=name)


def cli_main():
    """Start ContinueTrainingCLI with a minimal test setup.

    Sets the float32 matmul precision to "medium", then builds the CLI for
    TrackingDataModule with a persistent rich progress bar, validation-metric
    printing, TensorBoard logging only, and ``compile_model=True``.
    """
    torch.set_float32_matmul_precision("medium")

    test_callbacks = [
        RichProgressBar(leave=True),
        PrintValidationMetrics(),
    ]
    trainer_defaults = dict(callbacks=test_callbacks, logger=[tb_logger])

    # The instance is never used afterwards; presumably constructing the CLI
    # is what runs it (LightningCLI-style) -- confirm in hpo2.lightning_utils.
    # noinspection PyUnusedLocal
    cli = ContinueTrainingCLI(  # noqa: F841
        datamodule_class=TrackingDataModule,
        trainer_defaults=trainer_defaults,
        compile_model=True,
    )


# Script entry point: run the CLI only when this file is executed directly.
if __name__ == "__main__":
cli_main()
74 changes: 74 additions & 0 deletions scripts/full_detector/continue_oc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import torch
import wandb
from gnn_tracking.training.callbacks import ExpandWandbConfig, PrintValidationMetrics
from gnn_tracking.utils.loading import TrackingDataModule
from gnn_tracking.utils.nomenclature import random_trial_name
from lightning_fabric.plugins.environments.slurm import SLURMEnvironment
from pytorch_lightning.callbacks import (
EarlyStopping,
LearningRateMonitor,
ModelCheckpoint,
RichProgressBar,
)
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback

from hpo2.lightning_utils import ContinueTrainingCLI

# Trial name obtained from random_trial_name(); shared by both loggers below
# as their version.
name = random_trial_name()


# Weights & Biases logger for this continued run. offline=True keeps the run
# offline; cli_main also registers TriggerWandbSyncLightningCallback.
logger = WandbLogger(
project="gnn_tracking_fd",
group="gc-loss-legacy-norm",
offline=True,
version=name,
tags=["no-ec", "gc-loss", "continued"],
)

# Make sure that wandb init is called
# (the attribute access is done only for that effect; the value is discarded)
_ = logger.experiment
# Registers "max_trk.double_majority_pt0.9" with wandb, using
# "trk.double_majority_pt0.9" as its step metric and a "max" summary.
# NOTE(review): cli_main monitors "trk.perfect_pt0.9" for early stopping and
# checkpointing, i.e. a different metric than the one registered here --
# confirm this is intended.
wandb.define_metric(
"max_trk.double_majority_pt0.9",
step_metric="trk.double_majority_pt0.9",
summary="max",
)

# TensorBoard logger created with "." as its first argument, versioned by the
# same trial name as the wandb logger.
tb_logger = TensorBoardLogger(".", version=name)


def cli_main():
    """Start ContinueTrainingCLI for a continued run with full logging.

    Sets the float32 matmul precision to "medium", then builds the CLI for
    TrackingDataModule with progress/metric callbacks, early stopping and
    checkpointing on "trk.perfect_pt0.9" (maximised), learning-rate
    monitoring, both module-level loggers and a SLURM environment plugin
    with auto-requeue disabled.
    """
    torch.set_float32_matmul_precision("medium")

    # Early stopping and checkpointing watch the same metric in the same way.
    monitored = dict(monitor="trk.perfect_pt0.9", mode="max", verbose=True)

    trainer_defaults = {}
    trainer_defaults["callbacks"] = [
        RichProgressBar(leave=True),
        TriggerWandbSyncLightningCallback(),
        PrintValidationMetrics(),
        ExpandWandbConfig(),
        EarlyStopping(patience=20, **monitored),
        ModelCheckpoint(save_top_k=2, **monitored),
        LearningRateMonitor(logging_interval="step", log_momentum=True),
    ]
    trainer_defaults["logger"] = [tb_logger, logger]
    trainer_defaults["plugins"] = [SLURMEnvironment(auto_requeue=False)]
    # Presumably forwarded to the Lightning Trainer, whose max_time string
    # format is DD:HH:MM:SS (1 day, 23 h, 30 min) -- confirm.
    trainer_defaults["max_time"] = "1:23:30:00"

    # The instance is never used afterwards; presumably constructing the CLI
    # is what runs it (LightningCLI-style) -- confirm in hpo2.lightning_utils.
    # noinspection PyUnusedLocal
    cli = ContinueTrainingCLI(  # noqa: F841
        datamodule_class=TrackingDataModule,
        trainer_defaults=trainer_defaults,
    )


# Script entry point: run the CLI only when this file is executed directly.
if __name__ == "__main__":
cli_main()
33 changes: 33 additions & 0 deletions scripts/full_detector/continue_oc_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import torch
from gnn_tracking.training.callbacks import PrintValidationMetrics
from gnn_tracking.utils.loading import TrackingDataModule
from gnn_tracking.utils.nomenclature import random_trial_name
from pytorch_lightning.callbacks import RichProgressBar
from pytorch_lightning.loggers import TensorBoardLogger

from hpo2.lightning_utils import ContinueTrainingCLI

# Trial name obtained from random_trial_name(); reused below as the logger version.
name = random_trial_name()


# TensorBoard logger created with "." as its first argument and the trial name
# as its version; it is the only logger cli_main hands to the trainer defaults.
tb_logger = TensorBoardLogger(".", version=name)


def cli_main():
    """Start ContinueTrainingCLI with a minimal test setup.

    Sets the float32 matmul precision to "medium", then builds the CLI for
    TrackingDataModule with a persistent rich progress bar, validation-metric
    printing and TensorBoard logging only.
    """
    torch.set_float32_matmul_precision("medium")

    test_callbacks = [
        RichProgressBar(leave=True),
        PrintValidationMetrics(),
    ]
    trainer_defaults = dict(callbacks=test_callbacks, logger=[tb_logger])

    # The instance is never used afterwards; presumably constructing the CLI
    # is what runs it (LightningCLI-style) -- confirm in hpo2.lightning_utils.
    # noinspection PyUnusedLocal
    cli = ContinueTrainingCLI(  # noqa: F841
        datamodule_class=TrackingDataModule,
        trainer_defaults=trainer_defaults,
    )


# Script entry point: run the CLI only when this file is executed directly.
if __name__ == "__main__":
cli_main()
3 changes: 1 addition & 2 deletions scripts/full_detector/run_gc_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import torch
from gnn_tracking.training.callbacks import PrintValidationMetrics
from gnn_tracking.utils.loading import TrackingDataModule
from pytorch_lightning.callbacks import RichProgressBar
Expand All @@ -7,7 +6,7 @@


def cli_main():
torch.set_float32_matmul_precision("medium")
# torch.set_float32_matmul_precision("medium")

# noinspection PyUnusedLocal
cli = TorchCompileCLI( # noqa: F841
Expand Down
Loading

0 comments on commit dcb04cd

Please sign in to comment.