Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
mszadkow committed Jul 19, 2024
1 parent 116865c commit 03e2736
Show file tree
Hide file tree
Showing 10 changed files with 183 additions and 70 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -253,8 +253,8 @@ func TestReconciler(t *testing.T) {
reconcilerOptions: []jobframework.Option{
jobframework.WithManageJobsWithoutQueueName(true),
},
job: testingxgboostjob.MakeXGBoostJob("xgboostjob", "ns").Parallelism(2).Obj(),
wantJob: testingxgboostjob.MakeXGBoostJob("xgboostjob", "ns").Parallelism(2).Obj(),
job: testingxgboostjob.MakeXGBoostJob("xgboostjob", "ns").XGBReplicaSpecsDefault().Parallelism(2).Obj(),
wantJob: testingxgboostjob.MakeXGBoostJob("xgboostjob", "ns").XGBReplicaSpecsDefault().Parallelism(2).Obj(),
wantWorkloads: []kueue.Workload{
*utiltesting.MakeWorkload("xgboostjob", "ns").
PodSets(
Expand All @@ -268,12 +268,13 @@ func TestReconciler(t *testing.T) {
reconcilerOptions: []jobframework.Option{
jobframework.WithManageJobsWithoutQueueName(false),
},
job: testingxgboostjob.MakeXGBoostJob("xgboostjob", "ns").Parallelism(2).Obj(),
wantJob: testingxgboostjob.MakeXGBoostJob("xgboostjob", "ns").Parallelism(2).Obj(),
job: testingxgboostjob.MakeXGBoostJob("xgboostjob", "ns").XGBReplicaSpecsDefault().Parallelism(2).Obj(),
wantJob: testingxgboostjob.MakeXGBoostJob("xgboostjob", "ns").XGBReplicaSpecsDefault().Parallelism(2).Obj(),
wantWorkloads: []kueue.Workload{},
},
"when workload is evicted, suspended is reset, restore node affinity": {
job: testingxgboostjob.MakeXGBoostJob("xgboostjob", "ns").
XGBReplicaSpecsDefault().
Image("").
Args(nil).
Queue("foo").
Expand Down Expand Up @@ -317,6 +318,7 @@ func TestReconciler(t *testing.T) {
Obj(),
},
wantJob: testingxgboostjob.MakeXGBoostJob("xgboostjob", "ns").
XGBReplicaSpecsDefault().
Image("").
Args(nil).
Queue("foo").
Expand Down
7 changes: 4 additions & 3 deletions pkg/util/testingjobs/mxjob/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ func (j *MXJobWrapper) MXReplicaSpecs(replicaSpecs ...MXReplicaSpecRequirement)
for _, rs := range replicaSpecs {
j.Spec.MXReplicaSpecs[rs.ReplicaType].Replicas = ptr.To[int32](rs.ReplicaCount)
j.Spec.MXReplicaSpecs[rs.ReplicaType].Template.Spec.RestartPolicy = corev1.RestartPolicy(rs.RestartPolicy)
j.Spec.MXReplicaSpecs[rs.ReplicaType].Template.Spec.Containers[0].Name = "mxnet"

if rs.Annotations != nil {
j.Spec.MXReplicaSpecs[rs.ReplicaType].Template.ObjectMeta.Annotations = rs.Annotations
Expand All @@ -77,7 +78,7 @@ func (j *MXJobWrapper) MXReplicaSpecsDefault() *MXJobWrapper {
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "mxnet",
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
Expand All @@ -95,7 +96,7 @@ func (j *MXJobWrapper) MXReplicaSpecsDefault() *MXJobWrapper {
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "mxnet",
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
Expand All @@ -113,7 +114,7 @@ func (j *MXJobWrapper) MXReplicaSpecsDefault() *MXJobWrapper {
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "mxnet",
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
Expand Down
5 changes: 3 additions & 2 deletions pkg/util/testingjobs/paddlejob/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func (j *PaddleJobWrapper) PaddleReplicaSpecs(replicaSpecs ...PaddleReplicaSpecR
for _, rs := range replicaSpecs {
j.Spec.PaddleReplicaSpecs[rs.ReplicaType].Replicas = ptr.To[int32](rs.ReplicaCount)
j.Spec.PaddleReplicaSpecs[rs.ReplicaType].Template.Spec.RestartPolicy = corev1.RestartPolicy(rs.RestartPolicy)
j.Spec.PaddleReplicaSpecs[rs.ReplicaType].Template.Spec.Containers[0].Name = "paddle"

if rs.Annotations != nil {
j.Spec.PaddleReplicaSpecs[rs.ReplicaType].Template.ObjectMeta.Annotations = rs.Annotations
Expand All @@ -76,7 +77,7 @@ func (j *PaddleJobWrapper) PaddleReplicaSpecsDefault() *PaddleJobWrapper {
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "paddle",
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
Expand All @@ -94,7 +95,7 @@ func (j *PaddleJobWrapper) PaddleReplicaSpecsDefault() *PaddleJobWrapper {
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "paddle",
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
Expand Down
62 changes: 45 additions & 17 deletions pkg/util/testingjobs/pytorchjob/wrappers_pytorchjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ func MakePyTorchJob(name, ns string) *PyTorchJobWrapper {
RunPolicy: kftraining.RunPolicy{
Suspend: ptr.To(true),
},
PyTorchReplicaSpecs: make(map[kftraining.ReplicaType]*kftraining.ReplicaSpec),
},
}}
}
Expand All @@ -54,28 +55,55 @@ type PyTorchReplicaSpecRequirement struct {
}

func (j *PyTorchJobWrapper) PyTorchReplicaSpecs(replicaSpecs ...PyTorchReplicaSpecRequirement) *PyTorchJobWrapper {
j.Spec.PyTorchReplicaSpecs = make(map[kftraining.ReplicaType]*kftraining.ReplicaSpec)

j = j.PyTorchReplicaSpecsDefault()
for _, rs := range replicaSpecs {
j.Spec.PyTorchReplicaSpecs[rs.ReplicaType] = &kftraining.ReplicaSpec{
Replicas: ptr.To[int32](rs.ReplicaCount),
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Annotations: rs.Annotations,
j.Spec.PyTorchReplicaSpecs[rs.ReplicaType].Replicas = ptr.To[int32](rs.ReplicaCount)
j.Spec.PyTorchReplicaSpecs[rs.ReplicaType].Template.Spec.RestartPolicy = corev1.RestartPolicy(rs.RestartPolicy)
j.Spec.PyTorchReplicaSpecs[rs.ReplicaType].Template.Spec.Containers[0].Name = "pytorch"

if rs.Annotations != nil {
j.Spec.PyTorchReplicaSpecs[rs.ReplicaType].Template.ObjectMeta.Annotations = rs.Annotations
}
}

return j
}

func (j *PyTorchJobWrapper) PyTorchReplicaSpecsDefault() *PyTorchJobWrapper {
j.Spec.PyTorchReplicaSpecs[kftraining.PyTorchJobReplicaTypeMaster] = &kftraining.ReplicaSpec{
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
},
},
Spec: corev1.PodSpec{
RestartPolicy: corev1.RestartPolicy(rs.RestartPolicy),
Containers: []corev1.Container{
{
Name: "pytorch", // each pytorchjob container must have the name "pytorch"
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
},
NodeSelector: map[string]string{},
},
},
}

j.Spec.PyTorchReplicaSpecs[kftraining.PyTorchJobReplicaTypeWorker] = &kftraining.ReplicaSpec{
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
},
NodeSelector: map[string]string{},
},
NodeSelector: map[string]string{},
},
}
},
}

return j
Expand Down
79 changes: 62 additions & 17 deletions pkg/util/testingjobs/tfjob/wrappers_tfjob.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,28 +56,73 @@ type TFReplicaSpecRequirement struct {
}

func (j *TFJobWrapper) TFReplicaSpecs(replicaSpecs ...TFReplicaSpecRequirement) *TFJobWrapper {
j.Spec.TFReplicaSpecs = make(map[kftraining.ReplicaType]*kftraining.ReplicaSpec)

j = j.TFReplicaSpecsDefault()
for _, rs := range replicaSpecs {
j.Spec.TFReplicaSpecs[rs.ReplicaType] = &kftraining.ReplicaSpec{
Replicas: ptr.To[int32](rs.ReplicaCount),
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Annotations: rs.Annotations,
j.Spec.TFReplicaSpecs[rs.ReplicaType].Replicas = ptr.To[int32](rs.ReplicaCount)
j.Spec.TFReplicaSpecs[rs.ReplicaType].Template.Spec.RestartPolicy = corev1.RestartPolicy(rs.RestartPolicy)
j.Spec.TFReplicaSpecs[rs.ReplicaType].Template.Spec.Containers[0].Name = "tensorflow"

if rs.Annotations != nil {
j.Spec.TFReplicaSpecs[rs.ReplicaType].Template.ObjectMeta.Annotations = rs.Annotations
}
}

return j
}

func (j *TFJobWrapper) TFReplicaSpecsDefault() *TFJobWrapper {
j.Spec.TFReplicaSpecs[kftraining.TFJobReplicaTypeChief] = &kftraining.ReplicaSpec{
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
},
},
Spec: corev1.PodSpec{
RestartPolicy: corev1.RestartPolicy(rs.RestartPolicy),
Containers: []corev1.Container{
{
Name: "tensorflow", // each tfjob container must have the name "tensorflow"
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
},
NodeSelector: map[string]string{},
},
},
}

j.Spec.TFReplicaSpecs[kftraining.TFJobReplicaTypePS] = &kftraining.ReplicaSpec{
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
},
NodeSelector: map[string]string{},
},
NodeSelector: map[string]string{},
},
}
},
}

j.Spec.TFReplicaSpecs[kftraining.TFJobReplicaTypeWorker] = &kftraining.ReplicaSpec{
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
},
},
NodeSelector: map[string]string{},
},
},
}

return j
Expand Down
61 changes: 44 additions & 17 deletions pkg/util/testingjobs/xgboostjob/wrappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,28 +55,55 @@ type XGBReplicaSpecRequirement struct {
}

func (j *XGBoostJobWrapper) XGBReplicaSpecs(replicaSpecs ...XGBReplicaSpecRequirement) *XGBoostJobWrapper {
j.Spec.XGBReplicaSpecs = make(map[kftraining.ReplicaType]*kftraining.ReplicaSpec)

j = j.XGBReplicaSpecsDefault()
for _, rs := range replicaSpecs {
j.Spec.XGBReplicaSpecs[rs.ReplicaType] = &kftraining.ReplicaSpec{
Replicas: ptr.To[int32](rs.ReplicaCount),
Template: corev1.PodTemplateSpec{
ObjectMeta: metav1.ObjectMeta{
Annotations: rs.Annotations,
j.Spec.XGBReplicaSpecs[rs.ReplicaType].Replicas = ptr.To[int32](rs.ReplicaCount)
j.Spec.XGBReplicaSpecs[rs.ReplicaType].Template.Spec.RestartPolicy = corev1.RestartPolicy(rs.RestartPolicy)
j.Spec.XGBReplicaSpecs[rs.ReplicaType].Template.Spec.Containers[0].Name = "xgboost"

if rs.Annotations != nil {
j.Spec.XGBReplicaSpecs[rs.ReplicaType].Template.ObjectMeta.Annotations = rs.Annotations
}
}

return j
}

func (j *XGBoostJobWrapper) XGBReplicaSpecsDefault() *XGBoostJobWrapper {
j.Spec.XGBReplicaSpecs[kftraining.XGBoostJobReplicaTypeMaster] = &kftraining.ReplicaSpec{
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
},
},
Spec: corev1.PodSpec{
RestartPolicy: corev1.RestartPolicy(rs.RestartPolicy),
Containers: []corev1.Container{
{
Name: "xgboost", // each XgBoostJob container must have the name "xgboost"
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
},
NodeSelector: map[string]string{},
},
},
}

j.Spec.XGBReplicaSpecs[kftraining.XGBoostJobReplicaTypeWorker] = &kftraining.ReplicaSpec{
Replicas: ptr.To[int32](1),
Template: corev1.PodTemplateSpec{
Spec: corev1.PodSpec{
RestartPolicy: "Never",
Containers: []corev1.Container{
{
Name: "c",
Image: "pause",
Command: []string{},
Resources: corev1.ResourceRequirements{Requests: corev1.ResourceList{}},
},
NodeSelector: map[string]string{},
},
NodeSelector: map[string]string{},
},
}
},
}

return j
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ var _ = ginkgo.Describe("Job controller", ginkgo.Ordered, ginkgo.ContinueOnFailu
})

ginkgo.It("Should reconcile MXJobs", func() {
kfJob := kubeflowjob.KubeflowJob{KFJobControl: (*workloadmxjob.JobControl)(testingmxjob.MakeMXJob(jobName, ns.Name).Obj())}
kfJob := kubeflowjob.KubeflowJob{KFJobControl: (*workloadmxjob.JobControl)(testingmxjob.MakeMXJob(jobName, ns.Name).MXReplicaSpecsDefault().Obj())}
createdJob := kubeflowjob.KubeflowJob{KFJobControl: (*workloadmxjob.JobControl)(&kftraining.MXJob{})}

kftesting.ShouldReconcileJob(ctx, k8sClient, kfJob, createdJob, []kftesting.PodSetsResource{
Expand Down
Loading

0 comments on commit 03e2736

Please sign in to comment.