diff --git a/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_controller.go b/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_controller.go index e53957d1d4..eb091ba7a6 100644 --- a/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_controller.go +++ b/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_controller.go @@ -45,7 +45,7 @@ func init() { JobType: &kftraining.PaddleJob{}, AddToScheme: kftraining.AddToScheme, IsManagingObjectsOwner: isPaddleJob, - MultiKueueAdapter: &multikueueAdapter{}, + MultiKueueAdapter: kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk), })) } diff --git a/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter.go b/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter.go index 850ee3beb2..6617b478e3 100644 --- a/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter.go +++ b/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter.go @@ -17,101 +17,27 @@ limitations under the License. package paddlejob import ( - "context" - "errors" - "fmt" - kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/apimachinery/pkg/types" - "k8s.io/klog/v2" "sigs.k8s.io/controller-runtime/pkg/client" - kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" - "sigs.k8s.io/kueue/pkg/controller/constants" "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/kubeflowjob" "sigs.k8s.io/kueue/pkg/util/api" - clientutil "sigs.k8s.io/kueue/pkg/util/client" ) -type multikueueAdapter struct{} - -var _ jobframework.MultiKueueAdapter = (*multikueueAdapter)(nil) - -func (b *multikueueAdapter) SyncJob(ctx context.Context, localClient client.Client, remoteClient client.Client, key types.NamespacedName, workloadName, origin string) error { - localJob := kftraining.PaddleJob{} - err := localClient.Get(ctx, key, &localJob) - if err != nil { - return err - } +var _ jobframework.MultiKueueAdapter = kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk) - remoteJob := &kftraining.PaddleJob{} - err = remoteClient.Get(ctx, key, remoteJob) - if client.IgnoreNotFound(err) != nil { - return err - } - - // if the remote exists, just copy the status - if err == nil { - return clientutil.PatchStatus(ctx, localClient, &localJob, func() (bool, error) { - localJob.Status = remoteJob.Status - return true, nil - }) - } - - remoteJob = &kftraining.PaddleJob{ - ObjectMeta: api.CloneObjectMetaForCreation(&localJob.ObjectMeta), - Spec: *localJob.Spec.DeepCopy(), - } - - // add the prebuilt workload - if remoteJob.Labels == nil { - remoteJob.Labels = make(map[string]string, 2) - } - remoteJob.Labels[constants.PrebuiltWorkloadLabel] = workloadName - remoteJob.Labels[kueuealpha.MultiKueueOriginLabel] = origin - - return remoteClient.Create(ctx, remoteJob) +func copyJobStatus(dst, src *kftraining.PaddleJob) { + dst.Status = src.Status } -func (b *multikueueAdapter) DeleteRemoteObject(ctx context.Context, remoteClient client.Client, key types.NamespacedName) error { - job := kftraining.PaddleJob{} - err := remoteClient.Get(ctx, key, &job) - if err != nil { - return client.IgnoreNotFound(err) +func copyJobSpec(dst, src *kftraining.PaddleJob) { + *dst = kftraining.PaddleJob{ + ObjectMeta: api.CloneObjectMetaForCreation(&src.ObjectMeta), + Spec: *src.Spec.DeepCopy(), } - return client.IgnoreNotFound(remoteClient.Delete(ctx, &job)) } -func (b *multikueueAdapter) KeepAdmissionCheckPending() bool { - return false -} - -func (b *multikueueAdapter) IsJobManagedByKueue(context.Context, client.Client, types.NamespacedName) (bool, string, error) { - return true, "", nil -} - -func (b *multikueueAdapter) GVK() schema.GroupVersionKind { - return gvk -} - -var _ jobframework.MultiKueueWatcher = (*multikueueAdapter)(nil) - -func (*multikueueAdapter) GetEmptyList() client.ObjectList { +func getEmptyList() client.ObjectList { return &kftraining.PaddleJobList{} } - -func (*multikueueAdapter) WorkloadKeyFor(o runtime.Object) (types.NamespacedName, error) { - paddleJob, isPaddleJob := o.(*kftraining.PaddleJob) - if !isPaddleJob { - return types.NamespacedName{}, errors.New("not a PaddleJob") - } - - prebuiltWl, hasPrebuiltWorkload := paddleJob.Labels[constants.PrebuiltWorkloadLabel] - if !hasPrebuiltWorkload { - return types.NamespacedName{}, fmt.Errorf("no prebuilt workload found for PaddleJob: %s", klog.KObj(paddleJob)) - } - - return types.NamespacedName{Name: prebuiltWl, Namespace: paddleJob.Namespace}, nil -} diff --git a/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter_test.go b/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter_test.go index 2c71041ecf..d896d989b4 100644 --- a/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter_test.go +++ b/pkg/controller/jobs/kubeflow/jobs/paddlejob/paddlejob_multikueue_adapter_test.go @@ -31,6 +31,8 @@ import ( kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" "sigs.k8s.io/kueue/pkg/controller/constants" + "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/kubeflowjob" "sigs.k8s.io/kueue/pkg/util/slices" utiltesting "sigs.k8s.io/kueue/pkg/util/testing" kfutiltesting "sigs.k8s.io/kueue/pkg/util/testingjobs/paddlejob" @@ -52,7 +54,7 @@ func TestMultikueueAdapter(t *testing.T) { managersPaddleJobs []kftraining.PaddleJob workerPaddleJobs []kftraining.PaddleJob - operation func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error + operation func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error wantError error wantManagersPaddleJobs []kftraining.PaddleJob @@ -62,7 +64,8 @@ func TestMultikueueAdapter(t *testing.T) { managersPaddleJobs: []kftraining.PaddleJob{ *paddleJobBuilder.Clone().Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "paddlejob1", Namespace: TestNamespace}, "wl1", "origin1") }, @@ -87,7 +90,7 @@ func TestMultikueueAdapter(t *testing.T) { StatusConditions(kftraining.JobCondition{Type: kftraining.JobSucceeded, Status: corev1.ConditionTrue}). Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "paddlejob1", Namespace: TestNamespace}, "wl1", "origin1") }, @@ -111,7 +114,7 @@ func TestMultikueueAdapter(t *testing.T) { Label(kueuealpha.MultiKueueOriginLabel, "origin1"). Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "paddlejob1", Namespace: TestNamespace}) }, }, @@ -129,7 +132,7 @@ func TestMultikueueAdapter(t *testing.T) { ctx, _ := utiltesting.ContextWithLog(t) - adapter := &multikueueAdapter{} + adapter := kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk) gotErr := tc.operation(ctx, adapter, managerClient, workerClient) diff --git a/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter.go b/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter.go index 2976d79add..2e11d957e7 100644 --- a/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter.go +++ b/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter.go @@ -17,101 +17,27 @@ limitations under the License. package pytorchjob import ( - "context" - "errors" - "fmt" - kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/apimachinery/pkg/types" - "k8s.io/klog/v2" "sigs.k8s.io/controller-runtime/pkg/client" - kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" - "sigs.k8s.io/kueue/pkg/controller/constants" "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/kubeflowjob" "sigs.k8s.io/kueue/pkg/util/api" - clientutil "sigs.k8s.io/kueue/pkg/util/client" ) -type multikueueAdapter struct{} - -var _ jobframework.MultiKueueAdapter = (*multikueueAdapter)(nil) - -func (b *multikueueAdapter) SyncJob(ctx context.Context, localClient client.Client, remoteClient client.Client, key types.NamespacedName, workloadName, origin string) error { - localJob := kftraining.PyTorchJob{} - err := localClient.Get(ctx, key, &localJob) - if err != nil { - return err - } +var _ jobframework.MultiKueueAdapter = kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk) - remoteJob := &kftraining.PyTorchJob{} - err = remoteClient.Get(ctx, key, remoteJob) - if client.IgnoreNotFound(err) != nil { - return err - } - - // if the remote exists, just copy the status - if err == nil { - return clientutil.PatchStatus(ctx, localClient, &localJob, func() (bool, error) { - localJob.Status = remoteJob.Status - return true, nil - }) - } - - remoteJob = &kftraining.PyTorchJob{ - ObjectMeta: api.CloneObjectMetaForCreation(&localJob.ObjectMeta), - Spec: *localJob.Spec.DeepCopy(), - } - - // add the prebuilt workload - if remoteJob.Labels == nil { - remoteJob.Labels = make(map[string]string, 2) - } - remoteJob.Labels[constants.PrebuiltWorkloadLabel] = workloadName - remoteJob.Labels[kueuealpha.MultiKueueOriginLabel] = origin - - return remoteClient.Create(ctx, remoteJob) +func copyJobStatus(dst, src *kftraining.PyTorchJob) { + dst.Status = src.Status } -func (b *multikueueAdapter) DeleteRemoteObject(ctx context.Context, remoteClient client.Client, key types.NamespacedName) error { - job := kftraining.PyTorchJob{} - err := remoteClient.Get(ctx, key, &job) - if err != nil { - return client.IgnoreNotFound(err) +func copyJobSpec(dst, src *kftraining.PyTorchJob) { + *dst = kftraining.PyTorchJob{ + ObjectMeta: api.CloneObjectMetaForCreation(&src.ObjectMeta), + Spec: *src.Spec.DeepCopy(), } - return client.IgnoreNotFound(remoteClient.Delete(ctx, &job)) } -func (b *multikueueAdapter) KeepAdmissionCheckPending() bool { - return false -} - -func (b *multikueueAdapter) IsJobManagedByKueue(context.Context, client.Client, types.NamespacedName) (bool, string, error) { - return true, "", nil -} - -func (b *multikueueAdapter) GVK() schema.GroupVersionKind { - return gvk -} - -var _ jobframework.MultiKueueWatcher = (*multikueueAdapter)(nil) - -func (*multikueueAdapter) GetEmptyList() client.ObjectList { +func getEmptyList() client.ObjectList { return &kftraining.PyTorchJobList{} } - -func (*multikueueAdapter) WorkloadKeyFor(o runtime.Object) (types.NamespacedName, error) { - pyTorchJob, isPyTorchJob := o.(*kftraining.PyTorchJob) - if !isPyTorchJob { - return types.NamespacedName{}, errors.New("not a PyTorchJob") - } - - prebuiltWl, hasPrebuiltWorkload := pyTorchJob.Labels[constants.PrebuiltWorkloadLabel] - if !hasPrebuiltWorkload { - return types.NamespacedName{}, fmt.Errorf("no prebuilt workload found for PyTorchJob: %s", klog.KObj(pyTorchJob)) - } - - return types.NamespacedName{Name: prebuiltWl, Namespace: pyTorchJob.Namespace}, nil -} diff --git a/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter_test.go b/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter_test.go index 8ae5db1a41..baed632397 100644 --- a/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter_test.go +++ b/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorch_multikueue_adapter_test.go @@ -32,6 +32,8 @@ import ( kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" "sigs.k8s.io/kueue/pkg/controller/constants" + "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/kubeflowjob" "sigs.k8s.io/kueue/pkg/util/slices" utiltesting "sigs.k8s.io/kueue/pkg/util/testing" kfutiltesting "sigs.k8s.io/kueue/pkg/util/testingjobs/pytorchjob" @@ -53,7 +55,7 @@ func TestMultikueueAdapter(t *testing.T) { managersPyTorchJobs []kftraining.PyTorchJob workerPyTorchJobs []kftraining.PyTorchJob - operation func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error + operation func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error wantError error wantManagersPyTorchJobs []kftraining.PyTorchJob @@ -63,7 +65,7 @@ func TestMultikueueAdapter(t *testing.T) { managersPyTorchJobs: []kftraining.PyTorchJob{ *pyTorchJobBuilder.Clone().Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "pytorchjob1", Namespace: TestNamespace}, "wl1", "origin1") }, @@ -88,7 +90,7 @@ func TestMultikueueAdapter(t *testing.T) { StatusConditions(kftraining.JobCondition{Type: kftraining.JobSucceeded, Status: corev1.ConditionTrue}). Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "pytorchjob1", Namespace: TestNamespace}, "wl1", "origin1") }, @@ -112,7 +114,7 @@ func TestMultikueueAdapter(t *testing.T) { Label(kueuealpha.MultiKueueOriginLabel, "origin1"). Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "pytorchjob1", Namespace: TestNamespace}) }, }, @@ -130,7 +132,7 @@ func TestMultikueueAdapter(t *testing.T) { ctx, _ := utiltesting.ContextWithLog(t) - adapter := &multikueueAdapter{} + adapter := kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk) gotErr := tc.operation(ctx, adapter, managerClient, workerClient) diff --git a/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorchjob_controller.go b/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorchjob_controller.go index 1bfc88cb95..b63f17cbc0 100644 --- a/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorchjob_controller.go +++ b/pkg/controller/jobs/kubeflow/jobs/pytorchjob/pytorchjob_controller.go @@ -45,7 +45,7 @@ func init() { JobType: &kftraining.PyTorchJob{}, AddToScheme: kftraining.AddToScheme, IsManagingObjectsOwner: isPyTorchJob, - MultiKueueAdapter: &multikueueAdapter{}, + MultiKueueAdapter: kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk), })) } diff --git a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller.go b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller.go index e3932ab393..02f9896a19 100644 --- a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller.go +++ b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_controller.go @@ -45,7 +45,7 @@ func init() { JobType: &kftraining.TFJob{}, AddToScheme: kftraining.AddToScheme, IsManagingObjectsOwner: isTFJob, - MultiKueueAdapter: &multikueueAdapter{}, + MultiKueueAdapter: kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk), })) } diff --git a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter.go b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter.go index 4c81578ef8..8a053c67a1 100644 --- a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter.go +++ b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter.go @@ -17,101 +17,27 @@ limitations under the License. package tfjob import ( - "context" - "errors" - "fmt" - kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/apimachinery/pkg/types" - "k8s.io/klog/v2" "sigs.k8s.io/controller-runtime/pkg/client" - kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" - "sigs.k8s.io/kueue/pkg/controller/constants" "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/kubeflowjob" "sigs.k8s.io/kueue/pkg/util/api" - clientutil "sigs.k8s.io/kueue/pkg/util/client" ) -type multikueueAdapter struct{} - -var _ jobframework.MultiKueueAdapter = (*multikueueAdapter)(nil) - -func (b *multikueueAdapter) SyncJob(ctx context.Context, localClient client.Client, remoteClient client.Client, key types.NamespacedName, workloadName, origin string) error { - localJob := kftraining.TFJob{} - err := localClient.Get(ctx, key, &localJob) - if err != nil { - return err - } +var _ jobframework.MultiKueueAdapter = kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk) - remoteJob := &kftraining.TFJob{} - err = remoteClient.Get(ctx, key, remoteJob) - if client.IgnoreNotFound(err) != nil { - return err - } - - // if the remote exists, just copy the status - if err == nil { - return clientutil.PatchStatus(ctx, localClient, &localJob, func() (bool, error) { - localJob.Status = remoteJob.Status - return true, nil - }) - } - - remoteJob = &kftraining.TFJob{ - ObjectMeta: api.CloneObjectMetaForCreation(&localJob.ObjectMeta), - Spec: *localJob.Spec.DeepCopy(), - } - - // add the prebuilt workload - if remoteJob.Labels == nil { - remoteJob.Labels = make(map[string]string, 2) - } - remoteJob.Labels[constants.PrebuiltWorkloadLabel] = workloadName - remoteJob.Labels[kueuealpha.MultiKueueOriginLabel] = origin - - return remoteClient.Create(ctx, remoteJob) +func copyJobStatus(dst, src *kftraining.TFJob) { + dst.Status = src.Status } -func (b *multikueueAdapter) DeleteRemoteObject(ctx context.Context, remoteClient client.Client, key types.NamespacedName) error { - job := kftraining.TFJob{} - err := remoteClient.Get(ctx, key, &job) - if err != nil { - return client.IgnoreNotFound(err) +func copyJobSpec(dst, src *kftraining.TFJob) { + *dst = kftraining.TFJob{ + ObjectMeta: api.CloneObjectMetaForCreation(&src.ObjectMeta), + Spec: *src.Spec.DeepCopy(), } - return client.IgnoreNotFound(remoteClient.Delete(ctx, &job)) } -func (b *multikueueAdapter) KeepAdmissionCheckPending() bool { - return false -} - -func (b *multikueueAdapter) IsJobManagedByKueue(ctx context.Context, c client.Client, key types.NamespacedName) (bool, string, error) { - return true, "", nil -} - -func (b *multikueueAdapter) GVK() schema.GroupVersionKind { - return gvk -} - -var _ jobframework.MultiKueueWatcher = (*multikueueAdapter)(nil) - -func (*multikueueAdapter) GetEmptyList() client.ObjectList { +func getEmptyList() client.ObjectList { return &kftraining.TFJobList{} } - -func (*multikueueAdapter) WorkloadKeyFor(o runtime.Object) (types.NamespacedName, error) { - tfJob, isTfJob := o.(*kftraining.TFJob) - if !isTfJob { - return types.NamespacedName{}, errors.New("not a TFJob") - } - - prebuiltWl, hasPrebuiltWorkload := tfJob.Labels[constants.PrebuiltWorkloadLabel] - if !hasPrebuiltWorkload { - return types.NamespacedName{}, fmt.Errorf("no prebuilt workload found for TFJob: %s", klog.KObj(tfJob)) - } - - return types.NamespacedName{Name: prebuiltWl, Namespace: tfJob.Namespace}, nil -} diff --git a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter_test.go b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter_test.go index 69fd4d5915..d50de8e761 100644 --- a/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter_test.go +++ b/pkg/controller/jobs/kubeflow/jobs/tfjob/tfjob_multikueue_adapter_test.go @@ -31,6 +31,8 @@ import ( kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" "sigs.k8s.io/kueue/pkg/controller/constants" + "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/kubeflowjob" "sigs.k8s.io/kueue/pkg/util/slices" utiltesting "sigs.k8s.io/kueue/pkg/util/testing" kfutiltesting "sigs.k8s.io/kueue/pkg/util/testingjobs/tfjob" @@ -52,7 +54,7 @@ func TestMultikueueAdapter(t *testing.T) { managersTFJobs []kftraining.TFJob workerTFJobs []kftraining.TFJob - operation func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error + operation func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error wantError error wantManagersTFJobs []kftraining.TFJob @@ -62,7 +64,7 @@ func TestMultikueueAdapter(t *testing.T) { managersTFJobs: []kftraining.TFJob{ *tfJobBuilder.Clone().Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}, "wl1", "origin1") }, @@ -87,7 +89,7 @@ func TestMultikueueAdapter(t *testing.T) { StatusConditions(kftraining.JobCondition{Type: kftraining.JobSucceeded, Status: corev1.ConditionTrue}). Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}, "wl1", "origin1") }, @@ -111,7 +113,7 @@ func TestMultikueueAdapter(t *testing.T) { Label(kueuealpha.MultiKueueOriginLabel, "origin1"). Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "tfjob1", Namespace: TestNamespace}) }, }, @@ -129,7 +131,7 @@ func TestMultikueueAdapter(t *testing.T) { ctx, _ := utiltesting.ContextWithLog(t) - adapter := &multikueueAdapter{} + adapter := kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk) gotErr := tc.operation(ctx, adapter, managerClient, workerClient) diff --git a/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_controller.go b/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_controller.go index 72f624c876..8ddd0dae14 100644 --- a/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_controller.go +++ b/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_controller.go @@ -45,7 +45,7 @@ func init() { JobType: &kftraining.XGBoostJob{}, AddToScheme: kftraining.AddToScheme, IsManagingObjectsOwner: isXGBoostJob, - MultiKueueAdapter: &multikueueAdapter{}, + MultiKueueAdapter: kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk), })) } diff --git a/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter.go b/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter.go index b5b841f0bc..cf1ca5d680 100644 --- a/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter.go +++ b/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter.go @@ -17,101 +17,27 @@ limitations under the License. package xgboostjob import ( - "context" - "errors" - "fmt" - kftraining "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" - "k8s.io/apimachinery/pkg/runtime" - "k8s.io/apimachinery/pkg/runtime/schema" - "k8s.io/apimachinery/pkg/types" - "k8s.io/klog/v2" "sigs.k8s.io/controller-runtime/pkg/client" - kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" - "sigs.k8s.io/kueue/pkg/controller/constants" "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/kubeflowjob" "sigs.k8s.io/kueue/pkg/util/api" - clientutil "sigs.k8s.io/kueue/pkg/util/client" ) -type multikueueAdapter struct{} - -var _ jobframework.MultiKueueAdapter = (*multikueueAdapter)(nil) - -func (b *multikueueAdapter) SyncJob(ctx context.Context, localClient client.Client, remoteClient client.Client, key types.NamespacedName, workloadName, origin string) error { - localJob := kftraining.XGBoostJob{} - err := localClient.Get(ctx, key, &localJob) - if err != nil { - return err - } +var _ jobframework.MultiKueueAdapter = kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk) - remoteJob := &kftraining.XGBoostJob{} - err = remoteClient.Get(ctx, key, remoteJob) - if client.IgnoreNotFound(err) != nil { - return err - } - - // if the remote exists, just copy the status - if err == nil { - return clientutil.PatchStatus(ctx, localClient, &localJob, func() (bool, error) { - localJob.Status = remoteJob.Status - return true, nil - }) - } - - remoteJob = &kftraining.XGBoostJob{ - ObjectMeta: api.CloneObjectMetaForCreation(&localJob.ObjectMeta), - Spec: *localJob.Spec.DeepCopy(), - } - - // add the prebuilt workload - if remoteJob.Labels == nil { - remoteJob.Labels = make(map[string]string, 2) - } - remoteJob.Labels[constants.PrebuiltWorkloadLabel] = workloadName - remoteJob.Labels[kueuealpha.MultiKueueOriginLabel] = origin - - return remoteClient.Create(ctx, remoteJob) +func copyJobStatus(dst, src *kftraining.XGBoostJob) { + dst.Status = src.Status } -func (b *multikueueAdapter) DeleteRemoteObject(ctx context.Context, remoteClient client.Client, key types.NamespacedName) error { - job := kftraining.XGBoostJob{} - err := remoteClient.Get(ctx, key, &job) - if err != nil { - return client.IgnoreNotFound(err) +func copyJobSpec(dst, src *kftraining.XGBoostJob) { + *dst = kftraining.XGBoostJob{ + ObjectMeta: api.CloneObjectMetaForCreation(&src.ObjectMeta), + Spec: *src.Spec.DeepCopy(), } - return client.IgnoreNotFound(remoteClient.Delete(ctx, &job)) } -func (b *multikueueAdapter) KeepAdmissionCheckPending() bool { - return false -} - -func (b *multikueueAdapter) IsJobManagedByKueue(context.Context, client.Client, types.NamespacedName) (bool, string, error) { - return true, "", nil -} - -func (b *multikueueAdapter) GVK() schema.GroupVersionKind { - return gvk -} - -var _ jobframework.MultiKueueWatcher = (*multikueueAdapter)(nil) - -func (*multikueueAdapter) GetEmptyList() client.ObjectList { +func getEmptyList() client.ObjectList { return &kftraining.XGBoostJobList{} } - -func (*multikueueAdapter) WorkloadKeyFor(o runtime.Object) (types.NamespacedName, error) { - xgBoostJob, isXgBoostJob := o.(*kftraining.XGBoostJob) - if !isXgBoostJob { - return types.NamespacedName{}, errors.New("not a XgBoostJob") - } - - prebuiltWl, hasPrebuiltWorkload := xgBoostJob.Labels[constants.PrebuiltWorkloadLabel] - if !hasPrebuiltWorkload { - return types.NamespacedName{}, fmt.Errorf("no prebuilt workload found for XgBoostJob: %s", klog.KObj(xgBoostJob)) - } - - return types.NamespacedName{Name: prebuiltWl, Namespace: xgBoostJob.Namespace}, nil -} diff --git a/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter_test.go b/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter_test.go index 85dbbcb6c9..a0084482a3 100644 --- a/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter_test.go +++ b/pkg/controller/jobs/kubeflow/jobs/xgboostjob/xgboostjob_multikueue_adapter_test.go @@ -31,6 +31,8 @@ import ( kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" "sigs.k8s.io/kueue/pkg/controller/constants" + "sigs.k8s.io/kueue/pkg/controller/jobframework" + "sigs.k8s.io/kueue/pkg/controller/jobs/kubeflow/kubeflowjob" "sigs.k8s.io/kueue/pkg/util/slices" utiltesting "sigs.k8s.io/kueue/pkg/util/testing" kfutiltesting "sigs.k8s.io/kueue/pkg/util/testingjobs/xgboostjob" @@ -52,7 +54,7 @@ func TestMultikueueAdapter(t *testing.T) { managersXGBoostJobs []kftraining.XGBoostJob workerXGBoostJobs []kftraining.XGBoostJob - operation func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error + operation func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error wantError error wantManagersXGBoostJobs []kftraining.XGBoostJob @@ -62,7 +64,7 @@ func TestMultikueueAdapter(t *testing.T) { managersXGBoostJobs: []kftraining.XGBoostJob{ *xgboostJobBuilder.Clone().Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "xgboostjob1", Namespace: TestNamespace}, "wl1", "origin1") }, @@ -87,7 +89,7 @@ func TestMultikueueAdapter(t *testing.T) { StatusConditions(kftraining.JobCondition{Type: kftraining.JobSucceeded, Status: corev1.ConditionTrue}). Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.SyncJob(ctx, managerClient, workerClient, types.NamespacedName{Name: "xgboostjob1", Namespace: TestNamespace}, "wl1", "origin1") }, @@ -111,7 +113,7 @@ func TestMultikueueAdapter(t *testing.T) { Label(kueuealpha.MultiKueueOriginLabel, "origin1"). Obj(), }, - operation: func(ctx context.Context, adapter *multikueueAdapter, managerClient, workerClient client.Client) error { + operation: func(ctx context.Context, adapter jobframework.MultiKueueAdapter, managerClient, workerClient client.Client) error { return adapter.DeleteRemoteObject(ctx, workerClient, types.NamespacedName{Name: "xgboostjob1", Namespace: TestNamespace}) }, }, @@ -129,7 +131,7 @@ func TestMultikueueAdapter(t *testing.T) { ctx, _ := utiltesting.ContextWithLog(t) - adapter := &multikueueAdapter{} + adapter := kubeflowjob.NewMKAdapter(copyJobSpec, copyJobStatus, getEmptyList, gvk) gotErr := tc.operation(ctx, adapter, managerClient, workerClient) diff --git a/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_multikueue_adapter.go b/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_multikueue_adapter.go new file mode 100644 index 0000000000..f422524848 --- /dev/null +++ b/pkg/controller/jobs/kubeflow/kubeflowjob/kubeflowjob_multikueue_adapter.go @@ -0,0 +1,145 @@ +/* +Copyright 2024 The Kubernetes Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package kubeflowjob + +import ( + "context" + "fmt" + + metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" + "k8s.io/apimachinery/pkg/runtime" + "k8s.io/apimachinery/pkg/runtime/schema" + "k8s.io/apimachinery/pkg/types" + "k8s.io/klog/v2" + + "sigs.k8s.io/controller-runtime/pkg/client" + + kueuealpha "sigs.k8s.io/kueue/apis/kueue/v1alpha1" + "sigs.k8s.io/kueue/pkg/controller/constants" + "sigs.k8s.io/kueue/pkg/controller/jobframework" + clientutil "sigs.k8s.io/kueue/pkg/util/client" +) + +type objAsPtr[T any] interface { + metav1.Object + client.Object + *T +} + +type adapter[PtrT objAsPtr[T], T any] struct { + copySpec func(dst, src PtrT) + copyStatus func(dst, src PtrT) + emptyList func() client.ObjectList + gvk schema.GroupVersionKind +} + +type fullInterface interface { + jobframework.MultiKueueAdapter + jobframework.MultiKueueWatcher +} + +func NewMKAdapter[PtrT objAsPtr[T], T any]( + copySpec func(dst, src PtrT), + copyStatus func(dst, src PtrT), + emptyList func() client.ObjectList, + gvk schema.GroupVersionKind, +) fullInterface { + return &adapter[PtrT, T]{ + copySpec: copySpec, + copyStatus: copyStatus, + emptyList: emptyList, + gvk: gvk, + } +} + +func (a adapter[PtrT, T]) GVK() schema.GroupVersionKind { + return a.gvk +} + +func (a adapter[PtrT, T]) KeepAdmissionCheckPending() bool { + return false +} + +func (a adapter[PtrT, T]) IsJobManagedByKueue(context.Context, client.Client, types.NamespacedName) (bool, string, error) { + return true, "", nil +} + +func (a adapter[PtrT, T]) SyncJob( + ctx context.Context, + localClient client.Client, + remoteClient client.Client, + key types.NamespacedName, + workloadName, origin string) error { + localJob := PtrT(new(T)) + err := localClient.Get(ctx, key, localJob) + if err != nil { + return err + } + + remoteJob := PtrT(new(T)) + err = remoteClient.Get(ctx, key, remoteJob) + if client.IgnoreNotFound(err) != nil { + return err + } + + if err == nil { + return clientutil.PatchStatus(ctx, localClient, localJob, func() (bool, error) { + // if the remote exists, just copy the status + a.copyStatus(localJob, remoteJob) + return true, nil + }) + } + + remoteJob = PtrT(new(T)) + a.copySpec(remoteJob, localJob) + + // add the prebuilt workload + labels := remoteJob.GetLabels() + if remoteJob.GetLabels() == nil { + labels = make(map[string]string, 2) + } + labels[constants.PrebuiltWorkloadLabel] = workloadName + labels[kueuealpha.MultiKueueOriginLabel] = origin + remoteJob.SetLabels(labels) + + return remoteClient.Create(ctx, remoteJob) +} + +func (a adapter[PtrT, T]) DeleteRemoteObject(ctx context.Context, remoteClient client.Client, key types.NamespacedName) error { + job := PtrT(new(T)) + job.SetName(key.Name) + job.SetNamespace(key.Namespace) + return client.IgnoreNotFound(remoteClient.Delete(ctx, job)) +} + +func (a adapter[PtrT, T]) GetEmptyList() client.ObjectList { + return a.emptyList() +} + +func (a adapter[PtrT, T]) WorkloadKeyFor(o runtime.Object) (types.NamespacedName, error) { + job, isTheJob := o.(PtrT) + if !isTheJob { + return types.NamespacedName{}, fmt.Errorf("not a %s", a.gvk.Kind) + } + + prebuiltWl, hasPrebuiltWorkload := job.GetLabels()[constants.PrebuiltWorkloadLabel] + if !hasPrebuiltWorkload { + return types.NamespacedName{}, fmt.Errorf("no prebuilt workload found for %s: %s", a.gvk.Kind, klog.KObj(job)) + } + + return types.NamespacedName{Name: prebuiltWl, Namespace: job.GetNamespace()}, nil +}