Skip to content

Commit

Permalink
...
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Aug 20, 2024
1 parent fef4cdb commit 2110ceb
Show file tree
Hide file tree
Showing 2 changed files with 1 addition and 10 deletions.
9 changes: 0 additions & 9 deletions R/LearnerTorch.R
Original file line number Diff line number Diff line change
Expand Up @@ -452,15 +452,6 @@ LearnerTorch = R6Class("LearnerTorch",
# Ideally we could rely on state$train_task, but there is this complication
# https://github.com/mlr-org/mlr3/issues/947
param_vals$device = auto_device(param_vals$device)
if (!test_equal_col_info(ci_train, ci_predict)) { # nolint
stopf(paste0(
"Predict task's column info does not match the train task's column info.\n",
"This migth be handled more gracefully in the future.\n",
"Training column info:\n'%s'\n",
"Prediction column info:\n'%s'"),
paste0(capture.output(ci_train), collapse = "\n"),
paste0(capture.output(ci_predict), collapse = "\n"))
}
private$.verify_predict_task(task, param_vals)

with_torch_settings(seed = self$model$seed, num_threads = param_vals$num_threads, {
Expand Down
2 changes: 1 addition & 1 deletion tests/testthat/test_learner_torch_methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -178,5 +178,5 @@ test_that("wrong column info stops learner from prediction.", {

learner = lrn("regr.torch_featureless", epochs = 1, batch_size = 50)
learner$train(t1)
expect_error(learner$predict(t2), "more gracefully")
expect_error(learner$predict(t2))
})

0 comments on commit 2110ceb

Please sign in to comment.