-
-
Notifications
You must be signed in to change notification settings - Fork 25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PipeOpLearnerCVPlus #838
base: master
Are you sure you want to change the base?
PipeOpLearnerCVPlus #838
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think the reasons why the tests fail is that lrn("regr.ranger")
is not in the lrn()
dictionary when mlr3learners
is not loaded. So you either have to load that, or see if you can use lrn("regr.rpart")
instead, which is present when using mlr3 without mlr3learners.
R/PipeOpLearnerCVPlus.R
Outdated
} | ||
|
||
task_type = mlr_reflections$task_types[type, mult = "first"]$task | ||
out_type = mlr_reflections$task_types[type, mult = "first"]$task |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think you need out_type
here? You can just use the task_type
variable.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
out_type
was necessary here, but I changed it to mlr_reflections$task_types[type, mult = "first"]$prediction
like you suggested.
R/PipeOpLearnerCVPlus.R
Outdated
prds = rbindlist(map(rr$predictions(predict_sets = "test"), as.data.table), idcol = "fold") | ||
|
||
# Add trained models and residuals to PipeOp state | ||
self$state = list(cv_models = rr$learners, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
self$state = list(cv_models = rr$learners, | |
self$state = list(cv_model_states = map(rr$learners, "state"), |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is best to only store the essential information in the PipeOp's $state
, which are the model states. Try not to store the entire learner objects.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
You'd also need to update the documentation.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Okay, I changed the state and documentation to only include the state.
R/PipeOpLearnerCVPlus.R
Outdated
if (is.null(self$state) || is_noop(self$state)) { | ||
private$.learner | ||
} else { | ||
multiplicity_recurse(self$state, clone_with_state, learner = private$.learner) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this does not work, since the PipeOp's $state
is not the same as the Learner's $state
; this is different in the PipeOpLearnerCV. But tbh I don't know what kind of result one would want here, maybe best to discuss this.
One possibility of what could work would be
multiplicity_recurse(self$state, learner = function(state) {
map(state$cv_model_states, clone_with_state, learner = private$.learner)
})
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This assumes you accept the 'cv_model_states' edit below. Also I haven't checked it, there may be a typo in there.
R/PipeOpLearnerCVPlus.R
Outdated
mu_hat = map(self$state$cv_models, function(learner) { | ||
as.data.table(learner$predict(task)) | ||
}) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
mu_hat = map(self$state$cv_models, function(learner) { | |
as.data.table(learner$predict(task)) | |
}) | |
mu_hat = map(self$state$cv_model_states, function(state) { | |
on.exit({private$.learner$state = NULL}) | |
private$.learner$state = state | |
as.data.table(private$.learner$predict(task)) | |
}) |
Also don't forget to add a NEWS entry. And feel free to add yourself as |
No description provided.