diff --git a/DESCRIPTION b/DESCRIPTION index 37cb21d..2c97d4f 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -13,8 +13,8 @@ Depends: Imports: rpart, rpart.plot (>= 3.0.6), - cluster, lavaan, + cluster, ggplot2, tidyr, methods, @@ -25,7 +25,6 @@ Imports: clisymbols, future.apply, data.table, - ctsemOMX, expm, gridBase Suggests: @@ -35,7 +34,8 @@ Suggests: MASS, psychTools, testthat, - future + future, + ctsemOMX Description: SEM Trees and SEM Forests -- an extension of model-based decision trees and forests to Structural Equation Models (SEM). SEM trees hierarchically split empirical data into homogeneous groups each sharing similar data patterns diff --git a/NAMESPACE b/NAMESPACE index 6d07d54..cf5199b 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -59,6 +59,7 @@ export(prune) export(se) export(semforest) export(semforest.control) +export(semforest_control) export(semforest_score_control) export(semtree) export(semtree.constraints) diff --git a/R/growTree.R b/R/growTree.R index fb26334..e47c4a3 100644 --- a/R/growTree.R +++ b/R/growTree.R @@ -315,10 +315,6 @@ growTree <- function(model = NULL, mydata = NULL, return(NULL) } ) - } - # 3. Traditional cross validation for covariate split selection - else if (control$method == "cv") { - stop("This split selection procedure is not supported anymore. Please see the new score-based tests for split selection.") } else { ui_fail("Error. 
Unknown split method selected") stop() diff --git a/R/semforest.R b/R/semforest.R index b5d86fe..2bb864c 100644 --- a/R/semforest.R +++ b/R/semforest.R @@ -114,6 +114,13 @@ semforest <- function(model, } + # set mtry heuristically if not set manually + if (is.null(semforest.control$mtry)) { + num_covariates <- length(covariate.ids) + mtry <- ceiling(sqrt(num_covariates)) + ui_message("Setting mtry = ",mtry," based on ",num_covariates," predictors.\n") + } + # pass mtry from forest to tree control if (!is.na(semforest.control$semtree.control$mtry)) { ui_stop( diff --git a/R/semforest.control.R b/R/semforest.control.R index dc6268b..0238afe 100644 --- a/R/semforest.control.R +++ b/R/semforest.control.R @@ -4,7 +4,7 @@ #' algorithm. #' #' -#' @aliases semforest.control print.semforest.control semforest_score_control +#' @aliases semforest.control semforest_control print.semforest.control semforest_score_control #' @param num.trees Number of trees. #' @param sampling Sampling procedure. Can be subsample or bootstrap. #' @param control A SEM Tree control object. Will be generated by default. @@ -55,3 +55,9 @@ semforest_score_control <- function(...) return(ctrl) } + +#' @export +semforest_control <- function(...) +{ + semforest.control(...) 
+} diff --git a/R/semtree.control.R b/R/semtree.control.R index 2d1262e..038db87 100644 --- a/R/semtree.control.R +++ b/R/semtree.control.R @@ -99,7 +99,7 @@ #' #' @export semtree.control <- - function(method = "naive", + function(method = c("naive","score","fair","fair3"), min.N = 20, max.depth = NA, alpha = .05, @@ -146,7 +146,7 @@ semtree.control <- # minimum number of cases in leaf options$min.bucket <- min.bucket # method - options$method <- method + options$method <- match.arg(method) # maximal depth of the tree , set to NA for unrestricted trees options$max.depth <- max.depth # test invariance of strong restrictions diff --git a/man/semforest.control.Rd b/man/semforest.control.Rd index 0a0d309..a523105 100644 --- a/man/semforest.control.Rd +++ b/man/semforest.control.Rd @@ -2,6 +2,7 @@ % Please edit documentation in R/semforest.control.R \name{semforest.control} \alias{semforest.control} +\alias{semforest_control} \alias{print.semforest.control} \alias{semforest_score_control} \title{SEM Forest Control Object} diff --git a/man/semtree.control.Rd b/man/semtree.control.Rd index ff7444f..b908002 100644 --- a/man/semtree.control.Rd +++ b/man/semtree.control.Rd @@ -6,7 +6,7 @@ \title{SEM Tree Control Object} \usage{ semtree.control( - method = "naive", + method = c("naive", "score", "fair", "fair3"), min.N = 20, max.depth = NA, alpha = 0.05, diff --git a/vignettes/forests.Rmd b/vignettes/forests.Rmd index 316361a..6a73c48 100644 --- a/vignettes/forests.Rmd +++ b/vignettes/forests.Rmd @@ -69,7 +69,7 @@ summary(result) Create a forest control object that stores all tuning parameters of the forest. Note that we use only 5 trees for illustration. Please increase the number in real applications to several hundreds. To speed up computation time, consider score-based test for variable selection in the trees. ```{r} -control <- semforest.control(num.trees = 5) +control <- semforest_control(num.trees = 5) print(control) ```