diff --git a/R/LearnerTorch.R b/R/LearnerTorch.R index 10466256..fcb9c942 100644 --- a/R/LearnerTorch.R +++ b/R/LearnerTorch.R @@ -452,15 +452,6 @@ LearnerTorch = R6Class("LearnerTorch", # Ideally we could rely on state$train_task, but there is this complication # https://github.com/mlr-org/mlr3/issues/947 param_vals$device = auto_device(param_vals$device) - if (!test_equal_col_info(ci_train, ci_predict)) { # nolint - stopf(paste0( - "Predict task's column info does not match the train task's column info.\n", - "This migth be handled more gracefully in the future.\n", - "Training column info:\n'%s'\n", - "Prediction column info:\n'%s'"), - paste0(capture.output(ci_train), collapse = "\n"), - paste0(capture.output(ci_predict), collapse = "\n")) - } private$.verify_predict_task(task, param_vals) with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads, { diff --git a/tests/testthat/test_learner_torch_methods.R b/tests/testthat/test_learner_torch_methods.R index 673b1e4e..6f1088df 100644 --- a/tests/testthat/test_learner_torch_methods.R +++ b/tests/testthat/test_learner_torch_methods.R @@ -178,5 +178,5 @@ test_that("wrong column info stops learner from prediction.", { learner = lrn("regr.torch_featureless", epochs = 1, batch_size = 50) learner$train(t1) - expect_error(learner$predict(t2), "more gracefully") + expect_error(learner$predict(t2)) })