Skip to content

Commit

Permalink
Split RunPolicy validators to Update and Create
Browse files Browse the repository at this point in the history
  • Loading branch information
mszadkow committed Sep 13, 2024
1 parent 80576f6 commit c94ab34
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 51 deletions.
21 changes: 10 additions & 11 deletions pkg/common/util/webhooks.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,21 @@ var supportedJobControllers = sets.New(
v1.MultiKueueController,
v1.KubeflowJobsController)

func ValidateManagedBy(oldRunPolicy *v1.RunPolicy, newRunPolicy *v1.RunPolicy) field.ErrorList {
func ValidateRunPolicyCreate(runPolicy *v1.RunPolicy) field.ErrorList {
errs := field.ErrorList{}
// Validate immutability
if oldRunPolicy != nil && newRunPolicy != nil {
oldManager := oldRunPolicy.ManagedBy
newManager := newRunPolicy.ManagedBy
fieldPath := field.NewPath("spec", "runPolicy", "managedBy")
errs = apivalidation.ValidateImmutableField(newManager, oldManager, fieldPath)
}
// Validate the value
if newRunPolicy != nil && newRunPolicy.ManagedBy != nil {
manager := *newRunPolicy.ManagedBy
if runPolicy.ManagedBy != nil {
manager := *runPolicy.ManagedBy
if !supportedJobControllers.Has(manager) {
fieldPath := field.NewPath("spec", "runPolicy", "managedBy")
errs = append(errs, field.NotSupported(fieldPath, manager, supportedJobControllers.UnsortedList()))
}
}
return errs
}

func ValidateRunPolicyUpdate(oldRunPolicy, newRunPolicy *v1.RunPolicy) field.ErrorList {
oldManager := oldRunPolicy.ManagedBy
newManager := newRunPolicy.ManagedBy
fieldPath := field.NewPath("spec", "runPolicy", "managedBy")
return apivalidation.ValidateImmutableField(newManager, oldManager, fieldPath)
}
15 changes: 5 additions & 10 deletions pkg/webhooks/paddlepaddle/paddlepaddle_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -77,18 +77,13 @@ func validatePaddleJob(oldJob, newJob *trainingoperator.PaddleJob) field.ErrorLi
allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), newJob.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ","))))
}

allErrs = append(allErrs, validateRunPolicy(oldJob, newJob)...)
allErrs = append(allErrs, validateSpec(newJob.Spec.PaddleReplicaSpecs)...)
return allErrs
}

func validateRunPolicy(oldJob, newJob *trainingoperator.PaddleJob) field.ErrorList {
var oldRunPolicy, newRunPolicy *trainingoperator.RunPolicy = nil, &newJob.Spec.RunPolicy
if oldJob != nil {
oldRunPolicy = &oldJob.Spec.RunPolicy
allErrs = append(allErrs, util.ValidateRunPolicyUpdate(&oldJob.Spec.RunPolicy, &newJob.Spec.RunPolicy)...)
} else {
allErrs = append(allErrs, util.ValidateRunPolicyCreate(&newJob.Spec.RunPolicy)...)
}

return util.ValidateManagedBy(oldRunPolicy, newRunPolicy)
allErrs = append(allErrs, validateSpec(newJob.Spec.PaddleReplicaSpecs)...)
return allErrs
}

func validateSpec(rSpecs map[trainingoperator.ReplicaType]*trainingoperator.ReplicaSpec) field.ErrorList {
Expand Down
15 changes: 5 additions & 10 deletions pkg/webhooks/pytorch/pytorchjob_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,22 +80,17 @@ func validatePyTorchJob(oldJob, newJob *trainingoperator.PyTorchJob) (admission.
if errors := apimachineryvalidation.NameIsDNS1035Label(newJob.ObjectMeta.Name, false); len(errors) != 0 {
allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), newJob.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ","))))
}
allErrs = append(allErrs, validateRunPolicy(oldJob, newJob)...)
if oldJob != nil {
allErrs = append(allErrs, util.ValidateRunPolicyUpdate(&oldJob.Spec.RunPolicy, &newJob.Spec.RunPolicy)...)
} else {
allErrs = append(allErrs, util.ValidateRunPolicyCreate(&newJob.Spec.RunPolicy)...)
}
ws, err := validateSpec(newJob.Spec)
warnings = append(warnings, ws...)
allErrs = append(allErrs, err...)
return warnings, allErrs
}

func validateRunPolicy(oldJob, newJob *trainingoperator.PyTorchJob) field.ErrorList {
var oldRunPolicy, newRunPolicy *trainingoperator.RunPolicy = nil, &newJob.Spec.RunPolicy
if oldJob != nil {
oldRunPolicy = &oldJob.Spec.RunPolicy
}

return util.ValidateManagedBy(oldRunPolicy, newRunPolicy)
}

func validateSpec(spec trainingoperator.PyTorchJobSpec) (admission.Warnings, field.ErrorList) {
var allErrs field.ErrorList
var warnings admission.Warnings
Expand Down
15 changes: 5 additions & 10 deletions pkg/webhooks/tensorflow/tfjob_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,18 +75,13 @@ func validateTFJob(oldJob, newJob *trainingoperator.TFJob) field.ErrorList {
if errors := apimachineryvalidation.NameIsDNS1035Label(newJob.Name, false); len(errors) != 0 {
allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), newJob.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ","))))
}
allErrs = append(allErrs, validateRunPolicy(oldJob, newJob)...)
allErrs = append(allErrs, validateSpec(newJob.Spec)...)
return allErrs
}

func validateRunPolicy(oldJob, newJob *trainingoperator.TFJob) field.ErrorList {
var oldRunPolicy, newRunPolicy *trainingoperator.RunPolicy = nil, &newJob.Spec.RunPolicy
if oldJob != nil {
oldRunPolicy = &oldJob.Spec.RunPolicy
allErrs = append(allErrs, util.ValidateRunPolicyUpdate(&oldJob.Spec.RunPolicy, &newJob.Spec.RunPolicy)...)
} else {
allErrs = append(allErrs, util.ValidateRunPolicyCreate(&newJob.Spec.RunPolicy)...)
}

return util.ValidateManagedBy(oldRunPolicy, newRunPolicy)
allErrs = append(allErrs, validateSpec(newJob.Spec)...)
return allErrs
}

func validateSpec(spec trainingoperator.TFJobSpec) field.ErrorList {
Expand Down
15 changes: 5 additions & 10 deletions pkg/webhooks/xgboost/xgboostjob_webhook.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,18 +76,13 @@ func validateXGBoostJob(oldJob, newJob *trainingoperator.XGBoostJob) field.Error
if errors := apimachineryvalidation.NameIsDNS1035Label(newJob.Name, false); len(errors) != 0 {
allErrs = append(allErrs, field.Invalid(field.NewPath("metadata").Child("name"), newJob.Name, fmt.Sprintf("should match: %v", strings.Join(errors, ","))))
}
allErrs = append(allErrs, validateRunPolicy(oldJob, newJob)...)
allErrs = append(allErrs, validateSpec(newJob.Spec)...)
return allErrs
}

func validateRunPolicy(oldJob, newJob *trainingoperator.XGBoostJob) field.ErrorList {
var oldRunPolicy, newRunPolicy *trainingoperator.RunPolicy = nil, &newJob.Spec.RunPolicy
if oldJob != nil {
oldRunPolicy = &oldJob.Spec.RunPolicy
allErrs = append(allErrs, util.ValidateRunPolicyUpdate(&oldJob.Spec.RunPolicy, &newJob.Spec.RunPolicy)...)
} else {
allErrs = append(allErrs, util.ValidateRunPolicyCreate(&newJob.Spec.RunPolicy)...)
}

return util.ValidateManagedBy(oldRunPolicy, newRunPolicy)
allErrs = append(allErrs, validateSpec(newJob.Spec)...)
return allErrs
}

func validateSpec(spec trainingoperator.XGBoostJobSpec) field.ErrorList {
Expand Down

0 comments on commit c94ab34

Please sign in to comment.