Skip to content

Commit

Permalink
Use Kueue as default (#470)
Browse files Browse the repository at this point in the history
* Made Kueue the default queueing strategy

Updated oauth test to have mcad=True

Changed .codeflare/appwrappers to .codeflare/resources

Addressed comments & added SUSPENDED status

Review changes & list_cluster functions

Updated tests and load_components

Update tests, Rebase

* Update src/codeflare_sdk/cluster/cluster.py

Co-authored-by: Antonin Stefanutti <[email protected]>

---------

Co-authored-by: Antonin Stefanutti <[email protected]>
  • Loading branch information
Bobbins228 and astefanutti authored Apr 5, 2024
1 parent bd49ef7 commit 017fa12
Show file tree
Hide file tree
Showing 13 changed files with 515 additions and 23 deletions.
2 changes: 2 additions & 0 deletions src/codeflare_sdk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
RayCluster,
AppWrapper,
get_cluster,
list_all_queued,
list_all_clusters,
)

from .job import JobDefinition, Job, DDPJobDefinition, DDPJob, RayJobClient
Expand Down
8 changes: 7 additions & 1 deletion src/codeflare_sdk/cluster/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,12 @@
AppWrapper,
)

from .cluster import Cluster, ClusterConfiguration, get_cluster
from .cluster import (
Cluster,
ClusterConfiguration,
get_cluster,
list_all_queued,
list_all_clusters,
)

from .awload import AWManager
43 changes: 32 additions & 11 deletions src/codeflare_sdk/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ def create_app_wrapper(self):
dispatch_priority = self.config.dispatch_priority
write_to_file = self.config.write_to_file
verify_tls = self.config.verify_tls
local_queue = self.config.local_queue
return generate_appwrapper(
name=name,
namespace=namespace,
Expand All @@ -213,6 +214,7 @@ def create_app_wrapper(self):
priority_val=priority_val,
write_to_file=write_to_file,
verify_tls=verify_tls,
local_queue=local_queue,
)

# creates a new cluster with the provided or default spec
Expand Down Expand Up @@ -319,6 +321,9 @@ def status(
# check the ray cluster status
cluster = _ray_cluster_status(self.config.name, self.config.namespace)
if cluster:
if cluster.status == RayClusterStatus.SUSPENDED:
ready = False
status = CodeFlareClusterStatus.SUSPENDED
if cluster.status == RayClusterStatus.UNKNOWN:
ready = False
status = CodeFlareClusterStatus.STARTING
Expand Down Expand Up @@ -588,17 +593,24 @@ def list_all_clusters(namespace: str, print_to_console: bool = True):
return clusters


def list_all_queued(namespace: str, print_to_console: bool = True):
def list_all_queued(namespace: str, print_to_console: bool = True, mcad: bool = False):
"""
Returns (and prints by default) a list of all currently queued-up AppWrappers
Returns (and prints by default) a list of all currently queued-up Ray Clusters
in a given namespace.
"""
app_wrappers = _get_app_wrappers(
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
)
if print_to_console:
pretty_print.print_app_wrappers_status(app_wrappers)
return app_wrappers
if mcad:
resources = _get_app_wrappers(
namespace, filter=[AppWrapperStatus.RUNNING, AppWrapperStatus.PENDING]
)
if print_to_console:
pretty_print.print_app_wrappers_status(resources)
else:
resources = _get_ray_clusters(
namespace, filter=[RayClusterStatus.READY, RayClusterStatus.SUSPENDED]
)
if print_to_console:
pretty_print.print_ray_clusters_status(resources)
return resources


def get_current_namespace(): # pragma: no cover
Expand Down Expand Up @@ -798,7 +810,9 @@ def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
return None


def _get_ray_clusters(namespace="default") -> List[RayCluster]:
def _get_ray_clusters(
namespace="default", filter: Optional[List[RayClusterStatus]] = None
) -> List[RayCluster]:
list_of_clusters = []
try:
config_check()
Expand All @@ -812,8 +826,15 @@ def _get_ray_clusters(namespace="default") -> List[RayCluster]:
except Exception as e: # pragma: no cover
return _kube_api_error_handling(e)

for rc in rcs["items"]:
list_of_clusters.append(_map_to_ray_cluster(rc))
# Get a list of RCs with the filter if it is passed to the function
if filter is not None:
for rc in rcs["items"]:
ray_cluster = _map_to_ray_cluster(rc)
if filter and ray_cluster.status in filter:
list_of_clusters.append(ray_cluster)
else:
for rc in rcs["items"]:
list_of_clusters.append(_map_to_ray_cluster(rc))
return list_of_clusters


Expand Down
4 changes: 3 additions & 1 deletion src/codeflare_sdk/cluster/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ class ClusterConfiguration:
num_gpus: int = 0
template: str = f"{dir}/templates/base-template.yaml"
instascale: bool = False
mcad: bool = True
mcad: bool = False
envs: dict = field(default_factory=dict)
image: str = ""
local_interactive: bool = False
Expand All @@ -60,3 +60,5 @@ def __post_init__(self):
print(
"Warning: TLS verification has been disabled - Endpoint checks will be bypassed"
)

local_queue: str = None
2 changes: 2 additions & 0 deletions src/codeflare_sdk/cluster/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class RayClusterStatus(Enum):
UNHEALTHY = "unhealthy"
FAILED = "failed"
UNKNOWN = "unknown"
SUSPENDED = "suspended"


class AppWrapperStatus(Enum):
Expand Down Expand Up @@ -59,6 +60,7 @@ class CodeFlareClusterStatus(Enum):
QUEUEING = 4
FAILED = 5
UNKNOWN = 6
SUSPENDED = 7


@dataclass
Expand Down
62 changes: 57 additions & 5 deletions src/codeflare_sdk/utils/generate_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
(in the cluster sub-module) for AppWrapper generation.
"""

from typing import Optional
import typing
import yaml
import sys
Expand Down Expand Up @@ -460,29 +461,79 @@ def _create_oauth_sidecar_object(
)


def write_components(user_yaml: dict, output_file_name: str):
def get_default_kueue_name(namespace: str):
# If the local queue is set, use it. Otherwise, try to use the default queue.
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
local_queues = api_instance.list_namespaced_custom_object(
group="kueue.x-k8s.io",
version="v1beta1",
namespace=namespace,
plural="localqueues",
)
except Exception as e: # pragma: no cover
return _kube_api_error_handling(e)
for lq in local_queues["items"]:
if (
"annotations" in lq["metadata"]
and "kueue.x-k8s.io/default-queue" in lq["metadata"]["annotations"]
and lq["metadata"]["annotations"]["kueue.x-k8s.io/default-queue"].lower()
== "true"
):
return lq["metadata"]["name"]
raise ValueError(
"Default Local Queue with kueue.x-k8s.io/default-queue: true annotation not found please create a default Local Queue or provide the local_queue name in Cluster Configuration"
)


def write_components(
user_yaml: dict, output_file_name: str, namespace: str, local_queue: Optional[str]
):
# Create the directory if it doesn't exist
directory_path = os.path.dirname(output_file_name)
if not os.path.exists(directory_path):
os.makedirs(directory_path)

components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
open(output_file_name, "w").close()
lq_name = local_queue or get_default_kueue_name(namespace)
with open(output_file_name, "a") as outfile:
for component in components:
if "generictemplate" in component:
if (
"workload.codeflare.dev/appwrapper"
in component["generictemplate"]["metadata"]["labels"]
):
del component["generictemplate"]["metadata"]["labels"][
"workload.codeflare.dev/appwrapper"
]
labels = component["generictemplate"]["metadata"]["labels"]
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
outfile.write("---\n")
yaml.dump(
component["generictemplate"], outfile, default_flow_style=False
)
print(f"Written to: {output_file_name}")


def load_components(user_yaml: dict, name: str):
def load_components(
user_yaml: dict, name: str, namespace: str, local_queue: Optional[str]
):
component_list = []
components = user_yaml.get("spec", "resources")["resources"].get("GenericItems")
lq_name = local_queue or get_default_kueue_name(namespace)
for component in components:
if "generictemplate" in component:
if (
"workload.codeflare.dev/appwrapper"
in component["generictemplate"]["metadata"]["labels"]
):
del component["generictemplate"]["metadata"]["labels"][
"workload.codeflare.dev/appwrapper"
]
labels = component["generictemplate"]["metadata"]["labels"]
labels.update({"kueue.x-k8s.io/queue-name": lq_name})
component_list.append(component["generictemplate"])

resources = "---\n" + "---\n".join(
Expand Down Expand Up @@ -523,6 +574,7 @@ def generate_appwrapper(
priority_val: int,
write_to_file: bool,
verify_tls: bool,
local_queue: Optional[str],
):
user_yaml = read_template(template)
appwrapper_name, cluster_name = gen_names(name)
Expand Down Expand Up @@ -575,18 +627,18 @@ def generate_appwrapper(
if is_openshift_cluster():
enable_openshift_oauth(user_yaml, cluster_name, namespace)

directory_path = os.path.expanduser("~/.codeflare/appwrapper/")
directory_path = os.path.expanduser("~/.codeflare/resources/")
outfile = os.path.join(directory_path, appwrapper_name + ".yaml")

if write_to_file:
if mcad:
write_user_appwrapper(user_yaml, outfile)
else:
write_components(user_yaml, outfile)
write_components(user_yaml, outfile, namespace, local_queue)
return outfile
else:
if mcad:
user_yaml = load_appwrapper(user_yaml, name)
else:
user_yaml = load_components(user_yaml, name)
user_yaml = load_components(user_yaml, name, namespace, local_queue)
return user_yaml
24 changes: 24 additions & 0 deletions src/codeflare_sdk/utils/pretty_print.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,30 @@ def print_app_wrappers_status(app_wrappers: List[AppWrapper], starting: bool = F
console.print(Panel.fit(table))


def print_ray_clusters_status(app_wrappers: List[AppWrapper], starting: bool = False):
if not app_wrappers:
print_no_resources_found()
return # shortcircuit

console = Console()
table = Table(
box=box.ASCII_DOUBLE_HEAD,
title="[bold] :rocket: Cluster Queue Status :rocket:",
)
table.add_column("Name", style="cyan", no_wrap=True)
table.add_column("Status", style="magenta")

for app_wrapper in app_wrappers:
name = app_wrapper.name
status = app_wrapper.status.value
if starting:
status += " (starting)"
table.add_row(name, status)
table.add_row("") # empty row for spacing

console.print(Panel.fit(table))


def print_cluster_status(cluster: RayCluster):
"Pretty prints the status of a passed-in cluster"
if not cluster:
Expand Down
1 change: 1 addition & 0 deletions tests/e2e/mnist_raycluster_sdk_oauth_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def run_mnist_raycluster_sdk_oauth(self):
instascale=False,
image=ray_image,
write_to_file=True,
mcad=True,
)
)

Expand Down
1 change: 1 addition & 0 deletions tests/e2e/mnist_raycluster_sdk_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def run_mnist_raycluster_sdk(self):
instascale=False,
image=ray_image,
write_to_file=True,
mcad=True,
)
)

Expand Down
1 change: 1 addition & 0 deletions tests/e2e/start_ray_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
num_gpus=0,
instascale=False,
image=ray_image,
mcad=True,
)
)

Expand Down
2 changes: 1 addition & 1 deletion tests/test-case-no-mcad.yamls
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ metadata:
sdk.codeflare.dev/local_interactive: 'False'
labels:
controller-tools.k8s.io: '1.0'
workload.codeflare.dev/appwrapper: unit-test-cluster-ray
kueue.x-k8s.io/queue-name: local-queue-default
name: unit-test-cluster-ray
namespace: ns
spec:
Expand Down
Loading

0 comments on commit 017fa12

Please sign in to comment.