Skip to content

Commit

Permalink
pep8 fix
Browse files Browse the repository at this point in the history
  • Loading branch information
aPovidlo committed Jul 31, 2024
1 parent afed3f9 commit 3a2c7c0
Showing 1 changed file with 20 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,21 @@ def fit(self, input_data: InputData):
if self.params.get('use_eval_set'):
train_input, eval_input = train_test_data_setup(input_data)

X_train, y_train = self.convert_to_dataframe(train_input, identify_cats=self.params.get('enable_categorical'))
X_eval, y_eval = self.convert_to_dataframe(eval_input, identify_cats=self.params.get('enable_categorical'))
X_train, y_train = self.convert_to_dataframe(
train_input, identify_cats=self.params.get('enable_categorical')
)

X_eval, y_eval = self.convert_to_dataframe(
eval_input, identify_cats=self.params.get('enable_categorical')
)

self.model.eval_metric = self.set_eval_metric(self.classes_)

self.model.fit(X=X_train, y=y_train, eval_set=[(X_eval, y_eval)], verbose=self.model_params['verbosity'])
else:
X_train, y_train = self.convert_to_dataframe(input_data, identify_cats=self.params.get('enable_categorical'))
X_train, y_train = self.convert_to_dataframe(
input_data, identify_cats=self.params.get('enable_categorical')
)
self.features_names = input_data.features_names

self.model.fit(X=X_train, y=y_train, verbose=self.model_params['verbosity'])
Expand Down Expand Up @@ -155,8 +162,13 @@ def fit(self, input_data: InputData):
if self.params.get('use_eval_set'):
train_input, eval_input = train_test_data_setup(input_data)

X_train, y_train = self.convert_to_dataframe(train_input, identify_cats=self.params.get('enable_categorical'))
X_eval, y_eval = self.convert_to_dataframe(eval_input, identify_cats=self.params.get('enable_categorical'))
X_train, y_train = self.convert_to_dataframe(
train_input, identify_cats=self.params.get('enable_categorical')
)

X_eval, y_eval = self.convert_to_dataframe(
eval_input, identify_cats=self.params.get('enable_categorical')
)

eval_metric = self.set_eval_metric(self.classes_)
callbacks = self.update_callbacks()
Expand All @@ -168,7 +180,9 @@ def fit(self, input_data: InputData):
)

else:
X_train, y_train = self.convert_to_dataframe(input_data, identify_cats=self.params.get('enable_categorical'))
X_train, y_train = self.convert_to_dataframe(
input_data, identify_cats=self.params.get('enable_categorical')
)

self.model.fit(
X=X_train, y=y_train,
Expand Down

0 comments on commit 3a2c7c0

Please sign in to comment.