Skip to content

Commit

Permalink
share freeze_model among run_lmp and run_relax; rename head -> model_frozen_head in run_lmp and run_relax
Browse files Browse the repository at this point in the history

Signed-off-by: zjgemi <[email protected]>
  • Loading branch information
zjgemi committed Sep 3, 2024
1 parent 1666e91 commit 6acdbef
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 51 deletions.
2 changes: 1 addition & 1 deletion dpgen2/entrypoint/submit.py
Original file line number Diff line number Diff line change
Expand Up @@ -606,7 +606,7 @@ def workflow_concurrent_learning(
init_data = upload_artifact_and_print_uri(init_data, "multi_init_data")
train_config["multitask"] = True
train_config["head"] = head
explore_config["head"] = head
explore_config["model_frozen_head"] = head
else:
if config["inputs"]["init_data_uri"] is not None:
init_data = get_artifact_from_uri(config["inputs"]["init_data_uri"])
Expand Down
50 changes: 27 additions & 23 deletions dpgen2/op/run_lmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,28 +150,7 @@ def execute(
elif ext == ".pt":
# freeze model
mname = pytorch_model_name_pattern % (idx)
freeze_args = "-o %s" % mname
if config.get("head") is not None:
freeze_args += " --head %s" % config["head"]
freeze_cmd = "dp --pt freeze -c %s %s" % (mm, freeze_args)
ret, out, err = run_command(freeze_cmd, shell=True)
if ret != 0:
logging.error(
"".join(
(
"freeze failed\n",
"command was",
freeze_cmd,
"out msg",
out,
"\n",
"err msg",
err,
"\n",
)
)
)
raise TransientError("freeze failed")
freeze_model(mm, mname, config.get("model_frozen_head"))
else:
raise RuntimeError(
"Model file with extension '%s' is not supported" % ext
Expand Down Expand Up @@ -240,7 +219,7 @@ def lmp_args():
default=False,
doc=doc_shuffle_models,
),
Argument("head", str, optional=True, default=None, doc=doc_head),
Argument("model_frozen_head", str, optional=True, default=None, doc=doc_head),
]

@staticmethod
Expand Down Expand Up @@ -310,3 +289,28 @@ def find_only_one_key(lmp_lines, key, raise_not_found=True):
else:
return None
return found[0]


def freeze_model(input_model, frozen_model, head=None):
    """Freeze a PyTorch DeePMD model checkpoint into a deployable model.

    Runs ``dp --pt freeze`` on *input_model*, writing the result to
    *frozen_model*.

    Parameters
    ----------
    input_model : str or os.PathLike
        Path of the ``.pt`` checkpoint to freeze.
    frozen_model : str or os.PathLike
        Output path for the frozen model (passed to ``-o``).
    head : str, optional
        Model head to select for multi-task models; when ``None`` the
        ``--head`` flag is omitted.

    Raises
    ------
    TransientError
        If the freeze command exits with a non-zero status.
    """
    freeze_args = "-o %s" % frozen_model
    if head is not None:
        freeze_args += " --head %s" % head
    freeze_cmd = "dp --pt freeze -c %s %s" % (input_model, freeze_args)
    ret, out, err = run_command(freeze_cmd, shell=True)
    if ret != 0:
        # Use lazy %s logging with explicit separators: the previous
        # "".join((...)) ran the labels straight into the values
        # ("command wasdp --pt freeze..."), making the log unreadable.
        logging.error(
            "freeze failed\ncommand was: %s\nout msg: %s\nerr msg: %s\n",
            freeze_cmd,
            out,
            err,
        )
        raise TransientError("freeze failed")
32 changes: 5 additions & 27 deletions dpgen2/op/run_relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
from dpgen2.exploration.task import (
DiffCSPTaskGroup,
)
from dpgen2.utils.run_command import (
run_command,
)

from .run_caly_model_devi import (
atoms2lmpdump,
)
from .run_lmp import (
freeze_model,
)


class RunRelax(OP):
Expand Down Expand Up @@ -81,33 +81,11 @@ def execute(
task_group = ip["diffcsp_task_grp"]
task = next(iter(task_group)) # Only support single task
models = ip["models"]
if ip["expl_config"].get("head") is not None:
if ip["expl_config"].get("model_frozen_head") is not None:
frozen_models = []
for idx in range(len(models)):
mname = pytorch_model_name_pattern % (idx)
freeze_cmd = "dp --pt freeze -c %s --head %s -o %s" % (
models[idx],
ip["expl_config"]["head"],
mname,
)
ret, out, err = run_command(freeze_cmd, shell=True)
if ret != 0:
logging.error(
"".join(
(
"freeze failed\n",
"command was",
freeze_cmd,
"out msg",
out,
"\n",
"err msg",
err,
"\n",
)
)
)
raise RuntimeError("freeze failed")
freeze_model(models[idx], mname, ip["expl_config"]["model_frozen_head"])
frozen_models.append(Path(mname))
models = frozen_models
relaxer = Relaxer(models[0])
Expand Down

0 comments on commit 6acdbef

Please sign in to comment.