diff --git a/sdk/python/kubeflow/training/utils/utils.py b/sdk/python/kubeflow/training/utils/utils.py index 250252677f..7b76b24f3d 100644 --- a/sdk/python/kubeflow/training/utils/utils.py +++ b/sdk/python/kubeflow/training/utils/utils.py @@ -193,6 +193,8 @@ def get_container_spec( args: Optional[List[str]] = None, resources: Union[dict, models.V1ResourceRequirements, None] = None, volume_mounts: Optional[List[models.V1VolumeMount]] = None, + env: Optional[List[models.V1EnvVar]] = None, + env_from: Optional[List[models.V1EnvFromSource]] = None, ) -> models.V1Container: """ Get container spec for the given parameters. @@ -230,6 +232,10 @@ def get_container_spec( # Add resources to the container spec. container_spec.resources = resources + # Add environment variables to the container spec. + container_spec.env = env if env else None + container_spec.env_from = env_from if env_from else None + return container_spec @@ -237,6 +243,7 @@ def get_pod_template_spec( containers: List[models.V1Container], init_containers: Optional[List[models.V1Container]] = None, volumes: Optional[List[models.V1Volume]] = None, + restart_policy: Optional[str] = None, ) -> models.V1PodTemplateSpec: """ Get Pod template spec for the given parameters. @@ -251,6 +258,7 @@ def get_pod_template_spec( init_containers=init_containers, containers=containers, volumes=volumes, + restart_policy=restart_policy, ), )