From c94ab340a1bbf4f6228644af6afc4a12fa3589f0 Mon Sep 17 00:00:00 2001 From: Michal Szadkowski Date: Fri, 13 Sep 2024 12:50:48 +0200 Subject: [PATCH] Split RunPolicy validators to Update and Create --- pkg/common/util/webhooks.go | 21 +++++++++---------- .../paddlepaddle/paddlepaddle_webhook.go | 15 +++++-------- pkg/webhooks/pytorch/pytorchjob_webhook.go | 15 +++++-------- pkg/webhooks/tensorflow/tfjob_webhook.go | 15 +++++-------- pkg/webhooks/xgboost/xgboostjob_webhook.go | 15 +++++-------- 5 files changed, 30 insertions(+), 51 deletions(-) diff --git a/pkg/common/util/webhooks.go b/pkg/common/util/webhooks.go index a99f08eb95..46693239bb 100644 --- a/pkg/common/util/webhooks.go +++ b/pkg/common/util/webhooks.go @@ -12,18 +12,10 @@ var supportedJobControllers = sets.New( v1.MultiKueueController, v1.KubeflowJobsController) -func ValidateManagedBy(oldRunPolicy *v1.RunPolicy, newRunPolicy *v1.RunPolicy) field.ErrorList { +func ValidateRunPolicyCreate(runPolicy *v1.RunPolicy) field.ErrorList { errs := field.ErrorList{} - // Validate immutability - if oldRunPolicy != nil && newRunPolicy != nil { - oldManager := oldRunPolicy.ManagedBy - newManager := newRunPolicy.ManagedBy - fieldPath := field.NewPath("spec", "runPolicy", "managedBy") - errs = apivalidation.ValidateImmutableField(newManager, oldManager, fieldPath) - } - // Validate the value - if newRunPolicy != nil && newRunPolicy.ManagedBy != nil { - manager := *newRunPolicy.ManagedBy + if runPolicy.ManagedBy != nil { + manager := *runPolicy.ManagedBy if !supportedJobControllers.Has(manager) { fieldPath := field.NewPath("spec", "runPolicy", "managedBy") errs = append(errs, field.NotSupported(fieldPath, manager, supportedJobControllers.UnsortedList())) @@ -31,3 +23,10 @@ func ValidateManagedBy(oldRunPolicy *v1.RunPolicy, newRunPolicy *v1.RunPolicy) f } return errs } + +func ValidateRunPolicyUpdate(oldRunPolicy, newRunPolicy *v1.RunPolicy) field.ErrorList { + oldManager := oldRunPolicy.ManagedBy + newManager := newRunPolicy.ManagedBy + fieldPath := field.NewPath("spec", "runPolicy", "managedBy") + return apivalidation.ValidateImmutableField(newManager, oldManager, fieldPath) +} diff --git a/pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go b/pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go index 6817cd6da2..8f584003e9 100644 --- a/pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go +++ b/pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go @@ -77,18 +77,13 @@ func validatePaddleJob(oldJob, newJob *trainingoperator.PaddleJob) field.ErrorLi allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), newJob.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ",")))) } - allErrs = append(allErrs, validateRunPolicy(oldJob, newJob)...) - allErrs = append(allErrs, validateSpec(newJob.Spec.PaddleReplicaSpecs)...) - return allErrs -} - -func validateRunPolicy(oldJob, newJob *trainingoperator.PaddleJob) field.ErrorList { - var oldRunPolicy, newRunPolicy *trainingoperator.RunPolicy = nil, &newJob.Spec.RunPolicy if oldJob != nil { - oldRunPolicy = &oldJob.Spec.RunPolicy + allErrs = append(allErrs, util.ValidateRunPolicyUpdate(&oldJob.Spec.RunPolicy, &newJob.Spec.RunPolicy)...) + } else { + allErrs = append(allErrs, util.ValidateRunPolicyCreate(&newJob.Spec.RunPolicy)...) } - - return util.ValidateManagedBy(oldRunPolicy, newRunPolicy) + allErrs = append(allErrs, validateSpec(newJob.Spec.PaddleReplicaSpecs)...) + return allErrs } func validateSpec(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList { diff --git a/pkg/webhooks/pytorch/pytorchjob_webhook.go b/pkg/webhooks/pytorch/pytorchjob_webhook.go index 362df4a4ac..e8511ff18b 100644 --- a/pkg/webhooks/pytorch/pytorchjob_webhook.go +++ b/pkg/webhooks/pytorch/pytorchjob_webhook.go @@ -80,22 +80,17 @@ func validatePyTorchJob(oldJob, newJob *trainingoperator.PyTorchJob) (admission. if errors := apimachineryvalidation.NameIsDNS1035Label(newJob.ObjectMeta.Name, false); len(errors) != 0 { allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), newJob.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ",")))) } - allErrs = append(allErrs, validateRunPolicy(oldJob, newJob)...) + if oldJob != nil { + allErrs = append(allErrs, util.ValidateRunPolicyUpdate(&oldJob.Spec.RunPolicy, &newJob.Spec.RunPolicy)...) + } else { + allErrs = append(allErrs, util.ValidateRunPolicyCreate(&newJob.Spec.RunPolicy)...) + } ws, err := validateSpec(newJob.Spec) warnings = append(warnings, ws...) allErrs = append(allErrs, err...) return warnings, allErrs } -func validateRunPolicy(oldJob, newJob *trainingoperator.PyTorchJob) field.ErrorList { - var oldRunPolicy, newRunPolicy *trainingoperator.RunPolicy = nil, &newJob.Spec.RunPolicy - if oldJob != nil { - oldRunPolicy = &oldJob.Spec.RunPolicy - } - - return util.ValidateManagedBy(oldRunPolicy, newRunPolicy) -} - func validateSpec(spec trainingoperator.PyTorchJobSpec) (admission.Warnings, field.ErrorList) { var allErrs field.ErrorList var warnings admission.Warnings diff --git a/pkg/webhooks/tensorflow/tfjob_webhook.go b/pkg/webhooks/tensorflow/tfjob_webhook.go index d666e8f98e..499ff986cd 100644 --- a/pkg/webhooks/tensorflow/tfjob_webhook.go +++ b/pkg/webhooks/tensorflow/tfjob_webhook.go @@ -75,18 +75,13 @@ func validateTFJob(oldJob, newJob *trainingoperator.TFJob) field.ErrorList { if errors := apimachineryvalidation.NameIsDNS1035Label(newJob.Name, false); len(errors) != 0 { allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), newJob.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ",")))) } - allErrs = append(allErrs, validateRunPolicy(oldJob, newJob)...) - allErrs = append(allErrs, validateSpec(newJob.Spec)...) - return allErrs -} - -func validateRunPolicy(oldJob, newJob *trainingoperator.TFJob) field.ErrorList { - var oldRunPolicy, newRunPolicy *trainingoperator.RunPolicy = nil, &newJob.Spec.RunPolicy if oldJob != nil { - oldRunPolicy = &oldJob.Spec.RunPolicy + allErrs = append(allErrs, util.ValidateRunPolicyUpdate(&oldJob.Spec.RunPolicy, &newJob.Spec.RunPolicy)...) + } else { + allErrs = append(allErrs, util.ValidateRunPolicyCreate(&newJob.Spec.RunPolicy)...) } - - return util.ValidateManagedBy(oldRunPolicy, newRunPolicy) + allErrs = append(allErrs, validateSpec(newJob.Spec)...) + return allErrs } func validateSpec(spec trainingoperator.TFJobSpec) field.ErrorList { diff --git a/pkg/webhooks/xgboost/xgboostjob_webhook.go b/pkg/webhooks/xgboost/xgboostjob_webhook.go index eb95949149..a411a4c87c 100644 --- a/pkg/webhooks/xgboost/xgboostjob_webhook.go +++ b/pkg/webhooks/xgboost/xgboostjob_webhook.go @@ -76,18 +76,13 @@ func validateXGBoostJob(oldJob, newJob *trainingoperator.XGBoostJob) field.Error if errors := apimachineryvalidation.NameIsDNS1035Label(newJob.Name, false); len(errors) != 0 { allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), newJob.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ",")))) } - allErrs = append(allErrs, validateRunPolicy(oldJob, newJob)...) - allErrs = append(allErrs, validateSpec(newJob.Spec)...) - return allErrs -} - -func validateRunPolicy(oldJob, newJob *trainingoperator.XGBoostJob) field.ErrorList { - var oldRunPolicy, newRunPolicy *trainingoperator.RunPolicy = nil, &newJob.Spec.RunPolicy if oldJob != nil { - oldRunPolicy = &oldJob.Spec.RunPolicy + allErrs = append(allErrs, util.ValidateRunPolicyUpdate(&oldJob.Spec.RunPolicy, &newJob.Spec.RunPolicy)...) + } else { + allErrs = append(allErrs, util.ValidateRunPolicyCreate(&newJob.Spec.RunPolicy)...) } - - return util.ValidateManagedBy(oldRunPolicy, newRunPolicy) + allErrs = append(allErrs, validateSpec(newJob.Spec)...) + return allErrs } func validateSpec(spec trainingoperator.XGBoostJobSpec) field.ErrorList {