diff --git a/R/add_coverage.R b/R/add_coverage.R deleted file mode 100644 index 97fdedefc..000000000 --- a/R/add_coverage.R +++ /dev/null @@ -1,100 +0,0 @@ -#' @title Get Quantile And Interval Coverage Values For Quantile-Based Forecasts -#' -#' @description For a validated forecast object in a quantile-based format -#' (see [as_forecast()] for more information), this function computes -#' - interval coverage of central prediction intervals -#' - quantile coverage for predictive quantiles -#' - the deviation between desired and actual coverage (both for interval and -#' quantile coverage) -#' -#' Coverage values are computed for a specific level of grouping, as specified -#' in the `by` argument. By default, coverage values are computed per model. -#' -#' **Interval coverage** -#' -#' Interval coverage for a given interval range is defined as the proportion of -#' observations that fall within the corresponding central prediction intervals. -#' Central prediction intervals are symmetric around the median and formed -#' by two quantiles that denote the lower and upper bound. For example, the 50% -#' central prediction interval is the interval between the 0.25 and 0.75 -#' quantiles of the predictive distribution. -#' -#' **Quantile coverage** -#' -#' Quantile coverage for a given quantile is defined as the proportion of -#' observed values that are smaller than the corresponding predictive quantile. -#' For example, the 0.5 quantile coverage is the proportion of observed values -#' that are smaller than the 0.5 quantile of the predictive distribution. -#' Just as above, for a single observation and the quantile of a single -#' predictive distribution, the value will either be `TRUE` or `FALSE`. -#' -#' **Coverage deviation** -#' -#' The coverage deviation is the difference between the desired coverage -#' (can be either interval or quantile coverage) and the -#' actual coverage. For example, if the desired coverage is 90% and the actual -#' coverage is 80%, the coverage deviation is -0.1. -#' @return A data.table with columns as specified in `by` and additional -#' columns for the coverage values described above -#' @inheritParams score -#' @param by character vector that denotes the level of grouping for which the -#' coverage values should be computed. By default (`"model"`), one coverage -#' value per model will be returned. -#' @return a data.table with columns "interval_coverage", -#' "interval_coverage_deviation", "quantile_coverage", -#' "quantile_coverage_deviation" and the columns specified in `by`. -#' @importFrom data.table setcolorder -#' @importFrom checkmate assert_subset -#' @examples -#' library(magrittr) # pipe operator -#' example_quantile %>% -#' as_forecast() %>% -#' get_coverage(by = "model") -#' @export -#' @keywords scoring -#' @export -get_coverage <- function(data, by = "model") { - # input checks --------------------------------------------------------------- - data <- copy(data) - data <- na.omit(data) - suppressWarnings(suppressMessages(validate_forecast(data))) - assert_subset(get_forecast_type(data), "quantile") - - # remove "quantile_level" and "interval_range" from `by` if present, as these - # are included anyway - by <- setdiff(by, c("quantile_level", "interval_range")) - assert_subset(by, names(data)) - - # convert to wide interval format and compute interval coverage -------------- - interval_data <- quantile_to_interval(data, format = "wide") - interval_data[, - interval_coverage := (observed <= upper) & (observed >= lower) - ][, c("lower", "upper", "observed") := NULL] - interval_data[, interval_coverage_deviation := - interval_coverage - interval_range / 100] - - # merge interval range data with original data ------------------------------- - # preparations - data[, interval_range := get_range_from_quantile(quantile_level)] - data_cols <- colnames(data) # store so we can reset column order later - forecast_unit <- get_forecast_unit(data) - - data <- merge(data, interval_data, - by = unique(c(forecast_unit, "interval_range"))) - - # compute quantile coverage and deviation ------------------------------------ - data[, quantile_coverage := observed <= predicted] - data[, quantile_coverage_deviation := quantile_coverage - quantile_level] - - # summarise coverage values according to `by` and cleanup -------------------- - # reset column order - new_metrics <- c("interval_coverage", "interval_coverage_deviation", - "quantile_coverage", "quantile_coverage_deviation") - setcolorder(data, unique(c(data_cols, "interval_range", new_metrics))) - # remove forecast class and convert to regular data.table - data <- as.data.table(data) - by <- unique(c(by, "quantile_level", "interval_range")) - # summarise - data <- data[, lapply(.SD, mean), by = by, .SDcols = new_metrics] - return(data[]) -} diff --git a/R/available_forecasts.R b/R/available_forecasts.R deleted file mode 100644 index b03f4f66d..000000000 --- a/R/available_forecasts.R +++ /dev/null @@ -1,78 +0,0 @@ -#' @title Count Number of Available Forecasts -#' -#' @description -#' Given a data set with forecasts, this function counts the number of available forecasts. -#' The level of grouping can be specified using the `by` argument (e.g. to -#' count the number of forecasts per model, or the number of forecasts per -#' model and location). -#' This is useful to determine whether there are any missing forecasts. -#' -#' @param by character vector or `NULL` (the default) that denotes the -#' categories over which the number of forecasts should be counted. -#' By default (`by = NULL`) this will be the unit of a single forecast (i.e. -#' all available columns (apart from a few "protected" columns such as -#' 'predicted' and 'observed') plus "quantile_level" or "sample_id" where -#' present). -#' -#' @param collapse character vector (default: `c("quantile_level", "sample_id"`) -#' with names of categories for which the number of rows should be collapsed to -#' one when counting. For example, a single forecast is usually represented by a -#' set of several quantiles or samples and collapsing these to one makes sure -#' that a single forecast only gets counted once. Setting `collapse = c()` -#' would mean that all quantiles / samples would be counted as individual -#' forecasts. -#' -#' @return A data.table with columns as specified in `by` and an additional -#' column "count" with the number of forecasts. -#' -#' @inheritParams score -#' @importFrom data.table .I .N nafill -#' @export -#' @keywords check-forecasts -#' @examples -#' \dontshow{ -#' data.table::setDTthreads(2) # restricts number of cores used on CRAN -#' } -#' -#' get_forecast_counts( -#' as_forecast(example_quantile), -#' by = c("model", "target_type") -#' ) -get_forecast_counts <- function(data, - by = NULL, - collapse = c("quantile_level", "sample_id")) { - data <- copy(data) - suppressWarnings(suppressMessages(validate_forecast(data))) - forecast_unit <- get_forecast_unit(data) - data <- na.omit(data) - - if (is.null(by)) { - by <- forecast_unit - } - - # collapse several rows to 1, e.g. treat a set of 10 quantiles as one, - # because they all belong to one single forecast that should be counted once - collapse_by <- setdiff( - c(forecast_unit, "quantile_level", "sample_id"), - collapse - ) - # filter out "quantile_level" or "sample" if present in collapse_by, but not data - collapse_by <- intersect(collapse_by, names(data)) - - data <- data[data[, .I[1], by = collapse_by]$V1] - - # count number of rows = number of forecasts - out <- as.data.table(data)[, .(count = .N), by = by] - - # make sure that all combinations in "by" are included in the output (with - # count = 0). To achieve that, take the unique values in data and expand grid - col_vecs <- unclass(out) - col_vecs$count <- NULL - col_vecs <- lapply(col_vecs, unique) - out_empty <- expand.grid(col_vecs, stringsAsFactors = FALSE) - - out <- merge(out, out_empty, by = by, all.y = TRUE) - out[, count := nafill(count, fill = 0)] - - return(out[]) -} diff --git a/R/validate.R b/R/forecast.R similarity index 91% rename from R/validate.R rename to R/forecast.R index 8108d849d..e20ce9086 100644 --- a/R/validate.R +++ b/R/forecast.R @@ -391,44 +391,3 @@ is_forecast.forecast_point <- function(x, ...) { is_forecast.forecast_quantile <- function(x, ...) { inherits(x, "forecast_quantile") } - - -#' @title Validate metrics -#' -#' @description This function validates whether the list of metrics is a list -#' of valid functions. -#' -#' The function is used in [score()] to make sure that all metrics are valid -#' functions -#' -#' @param metrics A named list with metrics. Every element should be a scoring -#' function to be applied to the data. -#' @importFrom cli cli_warn -#' -#' @return A named list of metrics, with those filtered out that are not -#' valid functions -#' @importFrom checkmate assert_list test_list check_function -#' @keywords internal_input_check -validate_metrics <- function(metrics) { - - assert_list(metrics, min.len = 1, names = "named") - - for (i in seq_along(metrics)) { - check_fun <- check_function(metrics[[i]]) - if (!is.logical(check_fun)) { - #nolint start: keyword_quote_linter - cli_warn( - c( - "!" = "`Metrics` element number {i} is not a valid function." - ) - ) - #nolint end - names(metrics)[i] <- "scoringutils_delete" - } - } - metrics[names(metrics) == "scoringutils_delete"] <- NULL - - assert_list(metrics, min.len = 1, .var.name = "valid metrics") - - return(metrics) -} diff --git a/R/get_-functions.R b/R/get_-functions.R index 2a9396924..ae05293db 100644 --- a/R/get_-functions.R +++ b/R/get_-functions.R @@ -282,3 +282,185 @@ get_duplicate_forecasts <- function( out[, scoringutils_InternalDuplicateCheck := NULL] return(out[]) } + + +#' @title Get Quantile And Interval Coverage Values For Quantile-Based Forecasts +#' +#' @description For a validated forecast object in a quantile-based format +#' (see [as_forecast()] for more information), this function computes +#' - interval coverage of central prediction intervals +#' - quantile coverage for predictive quantiles +#' - the deviation between desired and actual coverage (both for interval and +#' quantile coverage) +#' +#' Coverage values are computed for a specific level of grouping, as specified +#' in the `by` argument. By default, coverage values are computed per model. +#' +#' **Interval coverage** +#' +#' Interval coverage for a given interval range is defined as the proportion of +#' observations that fall within the corresponding central prediction intervals. +#' Central prediction intervals are symmetric around the median and formed +#' by two quantiles that denote the lower and upper bound. For example, the 50% +#' central prediction interval is the interval between the 0.25 and 0.75 +#' quantiles of the predictive distribution. +#' +#' **Quantile coverage** +#' +#' Quantile coverage for a given quantile is defined as the proportion of +#' observed values that are smaller than the corresponding predictive quantile. +#' For example, the 0.5 quantile coverage is the proportion of observed values +#' that are smaller than the 0.5 quantile of the predictive distribution. +#' Just as above, for a single observation and the quantile of a single +#' predictive distribution, the value will either be `TRUE` or `FALSE`. +#' +#' **Coverage deviation** +#' +#' The coverage deviation is the difference between the desired coverage +#' (can be either interval or quantile coverage) and the +#' actual coverage. For example, if the desired coverage is 90% and the actual +#' coverage is 80%, the coverage deviation is -0.1. +#' @return A data.table with columns as specified in `by` and additional +#' columns for the coverage values described above +#' @inheritParams score +#' @param by character vector that denotes the level of grouping for which the +#' coverage values should be computed. By default (`"model"`), one coverage +#' value per model will be returned. +#' @return a data.table with columns "interval_coverage", +#' "interval_coverage_deviation", "quantile_coverage", +#' "quantile_coverage_deviation" and the columns specified in `by`. +#' @importFrom data.table setcolorder +#' @importFrom checkmate assert_subset +#' @examples +#' library(magrittr) # pipe operator +#' example_quantile %>% +#' as_forecast() %>% +#' get_coverage(by = "model") +#' @export +#' @keywords scoring +#' @export +get_coverage <- function(data, by = "model") { + # input checks --------------------------------------------------------------- + data <- copy(data) + data <- na.omit(data) + suppressWarnings(suppressMessages(validate_forecast(data))) + assert_subset(get_forecast_type(data), "quantile") + + # remove "quantile_level" and "interval_range" from `by` if present, as these + # are included anyway + by <- setdiff(by, c("quantile_level", "interval_range")) + assert_subset(by, names(data)) + + # convert to wide interval format and compute interval coverage -------------- + interval_data <- quantile_to_interval(data, format = "wide") + interval_data[, + interval_coverage := (observed <= upper) & (observed >= lower) + ][, c("lower", "upper", "observed") := NULL] + interval_data[, interval_coverage_deviation := + interval_coverage - interval_range / 100] + + # merge interval range data with original data ------------------------------- + # preparations + data[, interval_range := get_range_from_quantile(quantile_level)] + data_cols <- colnames(data) # store so we can reset column order later + forecast_unit <- get_forecast_unit(data) + + data <- merge(data, interval_data, + by = unique(c(forecast_unit, "interval_range"))) + + # compute quantile coverage and deviation ------------------------------------ + data[, quantile_coverage := observed <= predicted] + data[, quantile_coverage_deviation := quantile_coverage - quantile_level] + + # summarise coverage values according to `by` and cleanup -------------------- + # reset column order + new_metrics <- c("interval_coverage", "interval_coverage_deviation", + "quantile_coverage", "quantile_coverage_deviation") + setcolorder(data, unique(c(data_cols, "interval_range", new_metrics))) + # remove forecast class and convert to regular data.table + data <- as.data.table(data) + by <- unique(c(by, "quantile_level", "interval_range")) + # summarise + data <- data[, lapply(.SD, mean), by = by, .SDcols = new_metrics] + return(data[]) +} + + +#' @title Count Number of Available Forecasts +#' +#' @description +#' Given a data set with forecasts, this function counts the number of available forecasts. +#' The level of grouping can be specified using the `by` argument (e.g. to +#' count the number of forecasts per model, or the number of forecasts per +#' model and location). +#' This is useful to determine whether there are any missing forecasts. +#' +#' @param by character vector or `NULL` (the default) that denotes the +#' categories over which the number of forecasts should be counted. +#' By default (`by = NULL`) this will be the unit of a single forecast (i.e. +#' all available columns (apart from a few "protected" columns such as +#' 'predicted' and 'observed') plus "quantile_level" or "sample_id" where +#' present). +#' +#' @param collapse character vector (default: `c("quantile_level", "sample_id"`) +#' with names of categories for which the number of rows should be collapsed to +#' one when counting. For example, a single forecast is usually represented by a +#' set of several quantiles or samples and collapsing these to one makes sure +#' that a single forecast only gets counted once. Setting `collapse = c()` +#' would mean that all quantiles / samples would be counted as individual +#' forecasts. +#' +#' @return A data.table with columns as specified in `by` and an additional +#' column "count" with the number of forecasts. +#' +#' @inheritParams score +#' @importFrom data.table .I .N nafill +#' @export +#' @keywords check-forecasts +#' @examples +#' \dontshow{ +#' data.table::setDTthreads(2) # restricts number of cores used on CRAN +#' } +#' +#' get_forecast_counts( +#' as_forecast(example_quantile), +#' by = c("model", "target_type") +#' ) +get_forecast_counts <- function(data, + by = NULL, + collapse = c("quantile_level", "sample_id")) { + data <- copy(data) + suppressWarnings(suppressMessages(validate_forecast(data))) + forecast_unit <- get_forecast_unit(data) + data <- na.omit(data) + + if (is.null(by)) { + by <- forecast_unit + } + + # collapse several rows to 1, e.g. treat a set of 10 quantiles as one, + # because they all belong to one single forecast that should be counted once + collapse_by <- setdiff( + c(forecast_unit, "quantile_level", "sample_id"), + collapse + ) + # filter out "quantile_level" or "sample" if present in collapse_by, but not data + collapse_by <- intersect(collapse_by, names(data)) + + data <- data[data[, .I[1], by = collapse_by]$V1] + + # count number of rows = number of forecasts + out <- as.data.table(data)[, .(count = .N), by = by] + + # make sure that all combinations in "by" are included in the output (with + # count = 0). To achieve that, take the unique values in data and expand grid + col_vecs <- unclass(out) + col_vecs$count <- NULL + col_vecs <- lapply(col_vecs, unique) + out_empty <- expand.grid(col_vecs, stringsAsFactors = FALSE) + + out <- merge(out, out_empty, by = by, all.y = TRUE) + out[, count := nafill(count, fill = 0)] + + return(out[]) +} diff --git a/R/metrics-validate.R b/R/metrics-validate.R new file mode 100644 index 000000000..3c979be33 --- /dev/null +++ b/R/metrics-validate.R @@ -0,0 +1,39 @@ +#' @title Validate metrics +#' +#' @description This function validates whether the list of metrics is a list +#' of valid functions. +#' +#' The function is used in [score()] to make sure that all metrics are valid +#' functions +#' +#' @param metrics A named list with metrics. Every element should be a scoring +#' function to be applied to the data. +#' @importFrom cli cli_warn +#' +#' @return A named list of metrics, with those filtered out that are not +#' valid functions +#' @importFrom checkmate assert_list test_list check_function +#' @keywords internal_input_check +validate_metrics <- function(metrics) { + + assert_list(metrics, min.len = 1, names = "named") + + for (i in seq_along(metrics)) { + check_fun <- check_function(metrics[[i]]) + if (!is.logical(check_fun)) { + #nolint start: keyword_quote_linter + cli_warn( + c( + "!" = "`Metrics` element number {i} is not a valid function." + ) + ) + #nolint end + names(metrics)[i] <- "scoringutils_delete" + } + } + metrics[names(metrics) == "scoringutils_delete"] <- NULL + + assert_list(metrics, min.len = 1, .var.name = "valid metrics") + + return(metrics) +} diff --git a/R/pairwise-comparisons.R b/R/pairwise-comparisons.R index 37006865c..88bfeccf1 100644 --- a/R/pairwise-comparisons.R +++ b/R/pairwise-comparisons.R @@ -506,3 +506,58 @@ permutation_test <- function(scores1, # plus ones to make sure p-val is never 0? return(pVal) } + + +#' @title Add pairwise comparisons +#' @description Adds a columns with relative skills computed by running +#' pairwise comparisons on the scores. +#' For more information on +#' the computation of relative skill, see [pairwise_comparison()]. +#' Relative skill will be calculated for the aggregation level specified in +#' `by`. +#' @inheritParams pairwise_comparison +#' @export +#' @keywords keyword scoring +add_pairwise_comparison <- function( + scores, + by = "model", + metric = intersect(c("wis", "crps", "brier_score"), names(scores)), + baseline = NULL +) { + + # input checks are done in `pairwise_comparison()` + # do pairwise comparisons ---------------------------------------------------- + pairwise <- pairwise_comparison( + scores = scores, + metric = metric, + baseline = baseline, + by = by + ) + + # store original metrics + metrics <- get_metrics(scores) + + if (!is.null(pairwise)) { + # delete unnecessary columns + pairwise[, c( + "compare_against", "mean_scores_ratio", + "pval", "adj_pval" + ) := NULL] + pairwise <- unique(pairwise) + + # merge back + scores <- merge( + scores, pairwise, all.x = TRUE, by = get_forecast_unit(pairwise) + ) + } + + # Update score names + new_metrics <- paste( + metric, c("relative_skill", "scaled_relative_skill"), + sep = "_" + ) + new_metrics <- new_metrics[new_metrics %in% names(scores)] + scores <- new_scores(scores, metrics = c(metrics, new_metrics)) + + return(scores) +} diff --git a/R/print.R b/R/print.R new file mode 100644 index 000000000..853c83e22 --- /dev/null +++ b/R/print.R @@ -0,0 +1,84 @@ +#' @title Print Information About A Forecast Object +#' @description This function prints information about a forecast object, +#' including "Forecast type", "Score columns", +#' "Forecast unit". +#' +#' @param x An object of class 'forecast_*' object as produced by +#' `as_forecast()` +#' @param ... additional arguments for [print()] +#' @return returns x invisibly +#' @importFrom cli cli_inform cli_warn col_blue cli_text +#' @export +#' @keywords check-forecasts +#' @examples +#' dat <- as_forecast(example_quantile) +#' print(dat) +print.forecast_binary <- function(x, ...) { + + # check whether object passes validation + validation <- try(do.call(validate_forecast, list(data = x)), silent = TRUE) + if (inherits(validation, "try-error")) { + cli_warn( + c( + "!" = "Error in validating forecast object: {validation}." + ) + ) + } + + # get forecast type, forecast unit and score columns + forecast_type <- try( + do.call(get_forecast_type, list(data = x)), + silent = TRUE + ) + forecast_unit <- get_forecast_unit(x) + + # Print forecast object information + if (inherits(forecast_type, "try-error")) { + cli_inform( + "Could not determine forecast type due to error in validation." + ) + } else { + cli_text( + col_blue( + "Forecast type:" + ) + ) + cli_text( + "{forecast_type}" + ) + } + + if (length(forecast_unit) == 0) { + cli_inform( + c( + "!" = "Could not determine forecast unit." + ) + ) + } else { + cli_text( + col_blue( + "Forecast unit:" + ) + ) + cli_text( + "{forecast_unit}" + ) + } + + cat("\n") + NextMethod(x, ...) + + return(invisible(x)) +} + +#' @rdname print.forecast_binary +#' @export +print.forecast_quantile <- print.forecast_binary + +#' @rdname print.forecast_binary +#' @export +print.forecast_point <- print.forecast_binary + +#' @rdname print.forecast_binary +#' @export +print.forecast_sample <- print.forecast_binary diff --git a/R/summarise_scores.R b/R/summarise_scores.R index ea40a1e59..14d9b6a4e 100644 --- a/R/summarise_scores.R +++ b/R/summarise_scores.R @@ -103,59 +103,3 @@ summarise_scores <- function(scores, #' @keywords scoring #' @export summarize_scores <- summarise_scores - - - -#' @title Add pairwise comparisons -#' @description Adds a columns with relative skills computed by running -#' pairwise comparisons on the scores. -#' For more information on -#' the computation of relative skill, see [pairwise_comparison()]. -#' Relative skill will be calculated for the aggregation level specified in -#' `by`. -#' @inheritParams pairwise_comparison -#' @export -#' @keywords keyword scoring -add_pairwise_comparison <- function( - scores, - by = "model", - metric = intersect(c("wis", "crps", "brier_score"), names(scores)), - baseline = NULL -) { - - # input checks are done in `pairwise_comparison()` - # do pairwise comparisons ---------------------------------------------------- - pairwise <- pairwise_comparison( - scores = scores, - metric = metric, - baseline = baseline, - by = by - ) - - # store original metrics - metrics <- get_metrics(scores) - - if (!is.null(pairwise)) { - # delete unnecessary columns - pairwise[, c( - "compare_against", "mean_scores_ratio", - "pval", "adj_pval" - ) := NULL] - pairwise <- unique(pairwise) - - # merge back - scores <- merge( - scores, pairwise, all.x = TRUE, by = get_forecast_unit(pairwise) - ) - } - - # Update score names - new_metrics <- paste( - metric, c("relative_skill", "scaled_relative_skill"), - sep = "_" - ) - new_metrics <- new_metrics[new_metrics %in% names(scores)] - scores <- new_scores(scores, metrics = c(metrics, new_metrics)) - - return(scores) -} diff --git a/R/utils.R b/R/utils.R index 4ea9509f3..d7a7cbca0 100644 --- a/R/utils.R +++ b/R/utils.R @@ -71,88 +71,3 @@ ensure_data.table <- function(data) { } return(data) } - -#' @title Print Information About A Forecast Object -#' @description This function prints information about a forecast object, -#' including "Forecast type", "Score columns", -#' "Forecast unit". -#' -#' @param x An object of class 'forecast_*' object as produced by -#' `as_forecast()` -#' @param ... additional arguments for [print()] -#' @return returns x invisibly -#' @importFrom cli cli_inform cli_warn col_blue cli_text -#' @export -#' @keywords check-forecasts -#' @examples -#' dat <- as_forecast(example_quantile) -#' print(dat) -print.forecast_binary <- function(x, ...) { - - # check whether object passes validation - validation <- try(do.call(validate_forecast, list(data = x)), silent = TRUE) - if (inherits(validation, "try-error")) { - cli_warn( - c( - "!" = "Error in validating forecast object: {validation}." - ) - ) - } - - # get forecast type, forecast unit and score columns - forecast_type <- try( - do.call(get_forecast_type, list(data = x)), - silent = TRUE - ) - forecast_unit <- get_forecast_unit(x) - - # Print forecast object information - if (inherits(forecast_type, "try-error")) { - cli_inform( - "Could not determine forecast type due to error in validation." - ) - } else { - cli_text( - col_blue( - "Forecast type:" - ) - ) - cli_text( - "{forecast_type}" - ) - } - - if (length(forecast_unit) == 0) { - cli_inform( - c( - "!" = "Could not determine forecast unit." - ) - ) - } else { - cli_text( - col_blue( - "Forecast unit:" - ) - ) - cli_text( - "{forecast_unit}" - ) - } - - cat("\n") - NextMethod(x, ...) - - return(invisible(x)) -} - -#' @rdname print.forecast_binary -#' @export -print.forecast_quantile <- print.forecast_binary - -#' @rdname print.forecast_binary -#' @export -print.forecast_point <- print.forecast_binary - -#' @rdname print.forecast_binary -#' @export -print.forecast_sample <- print.forecast_binary diff --git a/man/add_pairwise_comparison.Rd b/man/add_pairwise_comparison.Rd index 9d3a8930b..57f93133d 100644 --- a/man/add_pairwise_comparison.Rd +++ b/man/add_pairwise_comparison.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/summarise_scores.R +% Please edit documentation in R/pairwise-comparisons.R \name{add_pairwise_comparison} \alias{add_pairwise_comparison} \title{Add pairwise comparisons} diff --git a/man/as_forecast.Rd b/man/as_forecast.Rd index efe1f4fce..0ce306e62 100644 --- a/man/as_forecast.Rd +++ b/man/as_forecast.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/validate.R +% Please edit documentation in R/as_forecast.R \name{as_forecast} \alias{as_forecast} \alias{as_forecast.default} diff --git a/man/get_coverage.Rd b/man/get_coverage.Rd index 9cc53124f..83a0615a0 100644 --- a/man/get_coverage.Rd +++ b/man/get_coverage.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/add_coverage.R +% Please edit documentation in R/get_-functions.R \name{get_coverage} \alias{get_coverage} \title{Get Quantile And Interval Coverage Values For Quantile-Based Forecasts} diff --git a/man/get_forecast_counts.Rd b/man/get_forecast_counts.Rd index eb30d3e15..17444511b 100644 --- a/man/get_forecast_counts.Rd +++ b/man/get_forecast_counts.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/available_forecasts.R +% Please edit documentation in R/get_-functions.R \name{get_forecast_counts} \alias{get_forecast_counts} \title{Count Number of Available Forecasts} diff --git a/man/is_forecast.Rd b/man/is_forecast.Rd index 4be25b453..a86b9f905 100644 --- a/man/is_forecast.Rd +++ b/man/is_forecast.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/validate.R +% Please edit documentation in R/as_forecast.R \name{is_forecast} \alias{is_forecast} \alias{is_forecast.default} diff --git a/man/new_forecast.Rd b/man/new_forecast.Rd index 0a86f349e..b19cd8768 100644 --- a/man/new_forecast.Rd +++ b/man/new_forecast.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/validate.R +% Please edit documentation in R/as_forecast.R \name{new_forecast} \alias{new_forecast} \title{Class constructor for scoringutils objects} diff --git a/man/print.forecast_binary.Rd b/man/print.forecast_binary.Rd index 60737a0e5..72f4b8f35 100644 --- a/man/print.forecast_binary.Rd +++ b/man/print.forecast_binary.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/utils.R +% Please edit documentation in R/print.R \name{print.forecast_binary} \alias{print.forecast_binary} \alias{print.forecast_quantile} diff --git a/man/validate_forecast.Rd b/man/validate_forecast.Rd index d475fb5e8..76065f0d9 100644 --- a/man/validate_forecast.Rd +++ b/man/validate_forecast.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/validate.R +% Please edit documentation in R/as_forecast.R \name{validate_forecast} \alias{validate_forecast} \alias{validate_forecast.forecast_quantile} diff --git a/man/validate_general.Rd b/man/validate_general.Rd index 8394145cf..1c89543fb 100644 --- a/man/validate_general.Rd +++ b/man/validate_general.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/validate.R +% Please edit documentation in R/as_forecast.R \name{validate_general} \alias{validate_general} \title{Validation Common To All Forecast Types} diff --git a/man/validate_metrics.Rd b/man/validate_metrics.Rd index d373eaa58..e443754ee 100644 --- a/man/validate_metrics.Rd +++ b/man/validate_metrics.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/validate.R +% Please edit documentation in R/as_forecast.R \name{validate_metrics} \alias{validate_metrics} \title{Validate metrics} diff --git a/tests/testthat/_snaps/utils.md b/tests/testthat/_snaps/print.md similarity index 100% rename from tests/testthat/_snaps/utils.md rename to tests/testthat/_snaps/print.md diff --git a/tests/testthat/test-add_coverage.R b/tests/testthat/test-add_coverage.R deleted file mode 100644 index e89c6af3d..000000000 --- a/tests/testthat/test-add_coverage.R +++ /dev/null @@ -1,53 +0,0 @@ -ex_coverage <- example_quantile[model == "EuroCOVIDhub-ensemble"] - -test_that("get_coverage() works as expected", { - cov <- example_quantile %>% - na.omit() %>% - as_forecast() %>% - get_coverage(by = get_forecast_unit(example_quantile)) - - expect_equal( - sort(colnames(cov)), - sort(c(get_forecast_unit(example_quantile), c( - "interval_range", "quantile_level", "interval_coverage", "interval_coverage_deviation", - "quantile_coverage", "quantile_coverage_deviation" - ))) - ) - - expect_equal(nrow(cov), nrow(na.omit(example_quantile))) -}) - -test_that("get_coverage() outputs an object of class c('data.table', 'data.frame'", { - ex <- as_forecast(na.omit(example_quantile)) - cov <- get_coverage(ex) - expect_s3_class(cov, c("data.table", "data.frame"), exact = TRUE) -}) - -test_that("get_coverage() can deal with non-symmetric prediction intervals", { - # the expected result is that `get_coverage()` just works. However, - # all interval coverages with missing values should just be `NA` - test <- data.table::copy(example_quantile) %>% - na.omit() %>% - as_forecast() - test <- test[!quantile_level %in% c(0.2, 0.3, 0.5)] - - expect_no_condition(cov <- get_coverage(test)) - - prediction_intervals <- get_range_from_quantile(c(0.2, 0.3, 0.5)) - - missing <- cov[interval_range %in% prediction_intervals] - not_missing <- cov[!interval_range %in% prediction_intervals] - - expect_true(all(is.na(missing$interval_coverage))) - expect_false(any(is.na(not_missing))) - - # test for a version where values are not missing, but just `NA` - # since `get_coverage()` calls `na.omit`, the result should be the same. - test <- data.table::copy(example_quantile) %>% - na.omit() %>% - as_forecast() %>% - suppressMessages() - test <- test[quantile_level %in% c(0.2, 0.3, 0.5), predicted := NA] - cov2 <- get_coverage(test) - expect_equal(cov, cov2) -}) diff --git a/tests/testthat/test-as_forecast.R b/tests/testthat/test-as_forecast.R deleted file mode 100644 index 76254633d..000000000 --- a/tests/testthat/test-as_forecast.R +++ /dev/null @@ -1,110 +0,0 @@ -test_that("Running `as_forecast()` twice returns the same object", { - ex <- na.omit(example_continuous) - - expect_identical( - as_forecast(as_forecast(ex)), - as_forecast(ex) - ) -}) - -test_that("as_forecast() works as expected", { - test <- na.omit(data.table::copy(example_quantile)) - expect_s3_class(as_forecast(test), "forecast_quantile") - - # expect error when arguments are not correct - expect_error(as_forecast(test, observed = 3), "Must be of type 'character'") - expect_error(as_forecast(test, quantile_level = c("1", "2")), "Must have length 1") - expect_error(as_forecast(test, observed = "missing"), "Must be a subset of") - - # expect no condition with columns already present - expect_no_condition( - as_forecast(test, - observed = "observed", predicted = "predicted", - forecast_unit = c( - "location", "model", "target_type", - "target_end_date", "horizon" - ), - quantile_level = "quantile_level" - ) - ) - - # additional test with renaming the model column - test <- na.omit(data.table::copy(example_continuous)) - data.table::setnames(test, - old = c("observed", "predicted", "sample_id", "model"), - new = c("obs", "pred", "sample", "mod") - ) - expect_no_condition( - as_forecast(test, - observed = "obs", predicted = "pred", model = "mod", - forecast_unit = c( - "location", "model", "target_type", - "target_end_date", "horizon" - ), - sample_id = "sample" - ) - ) - - # test if desired forecast type does not correspond to inferred one - test <- na.omit(data.table::copy(example_continuous)) - expect_error( - as_forecast(test, forecast_type = "quantile"), - "Forecast type determined by scoringutils based on input" - ) -}) - - -test_that("is_forecast() works as expected", { - ex_binary <- suppressMessages(as_forecast(example_binary)) - ex_point <- suppressMessages(as_forecast(example_point)) - ex_quantile <- suppressMessages(as_forecast(example_quantile)) - ex_continuous <- suppressMessages(as_forecast(example_continuous)) - - expect_true(is_forecast(ex_binary)) - expect_true(is_forecast(ex_point)) - expect_true(is_forecast(ex_quantile)) - expect_true(is_forecast(ex_continuous)) - - expect_false(is_forecast(1:10)) - expect_false(is_forecast(data.table::as.data.table(example_point))) - expect_false(is_forecast.forecast_sample(ex_quantile)) - expect_false(is_forecast.forecast_quantile(ex_binary)) -}) - - -test_that("validate_forecast() works as expected", { - # test that by default, `as_forecast()` errors - expect_error(validate_forecast(data.frame(x = 1:10)), - "The input needs to be a forecast object.") -}) - -test_that("validate_forecast.forecast_binary works as expected", { - test <- na.omit(data.table::copy(example_binary)) - test[, "sample_id" := 1:nrow(test)] - - # error if there is a superficial sample_id column - expect_error( - as_forecast(test), - "Input looks like a binary forecast, but an additional column called `sample_id` or `quantile` was found." - ) - - # expect error if probabilties are not in [0, 1] - test <- na.omit(data.table::copy(example_binary)) - test[, "predicted" := predicted + 1] - expect_error( - as_forecast(test), - "Input looks like a binary forecast, but found the following issue" - ) -}) - -test_that("validate_forecast.forecast_point() works as expected", { - test <- na.omit(data.table::copy(example_point)) - test <- as_forecast(test) - - # expect an error if column is changed to character after initial validation. - test <- test[, "predicted" := as.character(predicted)] - expect_error( - validate_forecast(test), - "Input looks like a point forecast, but found the following issue" - ) -}) diff --git a/tests/testthat/test-available_forecasts.R b/tests/testthat/test-available_forecasts.R deleted file mode 100644 index 43db41e91..000000000 --- a/tests/testthat/test-available_forecasts.R +++ /dev/null @@ -1,44 +0,0 @@ -test_that("get_forecast_counts() works as expected", { - af <- suppressMessages(as_forecast(example_quantile)) - af <- get_forecast_counts( - af, - by = c("model", "target_type", "target_end_date") - ) - - expect_type(af, "list") - expect_type(af$target_type, "character") - expect_type(af$`count`, "integer") - expect_equal(nrow(af[is.na(`count`)]), 0) - af <- na.omit(example_quantile) %>% - as_forecast() %>% - get_forecast_counts(by = "model") - expect_equal(nrow(af), 4) - expect_equal(af$`count`, c(256, 256, 128, 247)) - - # Ensure the returning object class is exactly same as a data.table. - expect_s3_class(af, c("data.table", "data.frame"), exact = TRUE) - - # Setting `collapse = c()` means that all quantiles and samples are counted - af <- na.omit(example_quantile) %>% - as_forecast() %>% - get_forecast_counts(by = "model", collapse = c()) - expect_equal(nrow(af), 4) - expect_equal(af$`count`, c(5888, 5888, 2944, 5681)) - - # setting by = NULL, the default, results in by equal to forecast unit - af <- na.omit(example_quantile) %>% - as_forecast() %>% - get_forecast_counts() - expect_equal(nrow(af), 50688) - - # check whether collapsing also works for model-based forecasts - af <- na.omit(example_integer) %>% - as_forecast() %>% - get_forecast_counts(by = "model") - expect_equal(nrow(af), 4) - - af <- na.omit(example_integer) %>% - as_forecast() %>% - get_forecast_counts(by = "model", collapse = c()) - expect_equal(af$count, c(10240, 10240, 5120, 9880)) -}) diff --git a/tests/testthat/test-check_forecasts.R b/tests/testthat/test-check_forecasts.R deleted file mode 100644 index 16537f8ef..000000000 --- a/tests/testthat/test-check_forecasts.R +++ /dev/null @@ -1,119 +0,0 @@ -test_that("as_forecast() function works", { - check <- suppressMessages(as_forecast(example_quantile)) - expect_s3_class(check, "forecast_quantile") -}) - -test_that("as_forecast() function has an error for empty data.frame", { - d <- data.frame(observed = numeric(), predicted = numeric(), model = character()) - - expect_error( - as_forecast(d), - "Assertion on 'data' failed: Must have at least 1 rows, but has 0 rows." - ) -}) - -test_that("as_forecast() errors if there is both a sample_id and a quantile_level column", { - example <- data.table::copy(example_quantile)[, sample_id := 1] - expect_error( - as_forecast(example), - "Found columns `quantile_level` and `sample_id`. Only one of these is allowed" - ) -}) - -test_that("check_columns_present() works", { - expect_equal( - check_columns_present(example_quantile, c("observed", "predicted", "nop")), - "Column 'nop' not found in data" - ) - expect_true( - check_columns_present(example_quantile, c("observed", "predicted")) - ) -}) - -test_that("check_duplicates() works", { - bad <- rbind( - example_quantile[1000:1010], - example_quantile[1000:1010] - ) - - expect_equal(scoringutils:::check_duplicates(bad), - "There are instances with more than one forecast for the same target. This can't be right and needs to be resolved. Maybe you need to check the unit of a single forecast and add missing columns? Use the function get_duplicate_forecasts() to identify duplicate rows" - ) -}) - -# test_that("as_forecast() function returns a message with NA in the data", { -# expect_message( -# { check <- as_forecast(example_quantile) }, -# "\\d+ values for `predicted` are NA" -# ) -# expect_match( -# unlist(check$messages), -# "\\d+ values for `predicted` are NA" -# ) -# }) - -# test_that("as_forecast() function returns messages with NA in the data", { -# example <- data.table::copy(example_quantile) -# example[horizon == 2, observed := NA] -# check <- suppressMessages(as_forecast(example)) -# -# expect_equal(length(check$messages), 2) -# }) - -test_that("as_forecast() function throws an error with duplicate forecasts", { - example <- rbind(example_quantile, - example_quantile[1000:1010]) - - expect_error( - suppressMessages(suppressWarnings(as_forecast(example))), - "Assertion on 'data' failed: There are instances with more than one forecast for the same target. This can't be right and needs to be resolved. Maybe you need to check the unit of a single forecast and add missing columns? Use the function get_duplicate_forecasts() to identify duplicate rows.", #nolint - fixed = TRUE - ) -}) - -test_that("as_forecast() function warns when no model column is present", { - no_model <- data.table::copy(example_quantile[model == "EuroCOVIDhub-ensemble"])[, model := NULL][] - expect_warning( - as_forecast(no_model), - "There is no column called `model` in the data.") -}) - -test_that("as_forecast() function throws an error when no predictions or observed values are present", { - expect_error(suppressMessages(suppressWarnings(as_forecast( - data.table::copy(example_quantile)[, predicted := NULL] - ))), - "Assertion on 'data' failed: Column 'predicted' not found in data.") - - expect_error(suppressMessages(suppressWarnings(as_forecast( - data.table::copy(example_quantile)[, observed := NULL] - ))), - "Assertion on 'data' failed: Column 'observed' not found in data.") - - expect_error(suppressMessages(suppressWarnings(as_forecast( - data.table::copy(example_quantile)[, c("observed", "predicted") := NULL] - ))), - "Assertion on 'data' failed: Columns 'observed', 'predicted' not found in data.") -}) - -# test_that("as_forecast() function throws an error when no predictions or observed values are present", { -# expect_error(suppressMessages(suppressWarnings(as_forecast( -# data.table::copy(example_quantile)[, predicted := NA] -# )))) -# expect_error(suppressMessages(suppressWarnings(check_forecasts( -# data.table::copy(example_quantile)[, observed := NA] -# )))) -# }) - -# test_that("as_forecast() function throws an sample/quantile not present", { -# expect_error(suppressMessages(suppressWarnings(as_forecast( -# data.table::copy(example_quantile)[, quantile := NULL] -# )))) -# }) - -test_that("output of as_forecasts() is accepted as input to score()", { - check <- suppressMessages(as_forecast(example_binary)) - expect_no_error( - score_check <- score(na.omit(check)) - ) - expect_equal(score_check, suppressMessages(score(as_forecast(example_binary)))) -}) diff --git a/tests/testthat/test-forecast.R b/tests/testthat/test-forecast.R new file mode 100644 index 000000000..f8e16646b --- /dev/null +++ b/tests/testthat/test-forecast.R @@ -0,0 +1,237 @@ +test_that("Running `as_forecast()` twice returns the same object", { + ex <- na.omit(example_continuous) + + expect_identical( + as_forecast(as_forecast(ex)), + as_forecast(ex) + ) +}) + +test_that("as_forecast() works as expected", { + test <- na.omit(data.table::copy(example_quantile)) + expect_s3_class(as_forecast(test), "forecast_quantile") + + # expect error when arguments are not correct + expect_error(as_forecast(test, observed = 3), "Must be of type 'character'") + expect_error(as_forecast(test, quantile_level = c("1", "2")), "Must have length 1") + expect_error(as_forecast(test, observed = "missing"), "Must be a subset of") + + # expect no condition with columns already present + expect_no_condition( + as_forecast(test, + observed = "observed", predicted = "predicted", + forecast_unit = c( + "location", "model", "target_type", + "target_end_date", "horizon" + ), + quantile_level = "quantile_level" + ) + ) + + # additional test with renaming the model column + test <- na.omit(data.table::copy(example_continuous)) + data.table::setnames(test, + old = c("observed", "predicted", "sample_id", "model"), + new = c("obs", "pred", "sample", "mod") + ) + expect_no_condition( + as_forecast(test, + observed = "obs", predicted = "pred", model = "mod", + forecast_unit = c( + "location", "model", "target_type", + "target_end_date", "horizon" + ), + sample_id = "sample" + ) + ) + + # test if desired forecast type does not correspond to inferred one + test <- na.omit(data.table::copy(example_continuous)) + expect_error( + as_forecast(test, forecast_type = "quantile"), + "Forecast type determined by scoringutils based on input" + ) +}) + +test_that("as_forecast() function works", { + check <- suppressMessages(as_forecast(example_quantile)) + expect_s3_class(check, "forecast_quantile") +}) + +test_that("as_forecast() function has an error for empty data.frame", { + d <- data.frame(observed = numeric(), predicted = numeric(), model = character()) + + expect_error( + as_forecast(d), + "Assertion on 'data' failed: Must have at least 1 rows, but has 0 rows." + ) +}) + +test_that("as_forecast() errors if there is both a sample_id and a quantile_level column", { + example <- data.table::copy(example_quantile)[, sample_id := 1] + expect_error( + as_forecast(example), + "Found columns `quantile_level` and `sample_id`. Only one of these is allowed" + ) +}) + +test_that("check_columns_present() works", { + expect_equal( + check_columns_present(example_quantile, c("observed", "predicted", "nop")), + "Column 'nop' not found in data" + ) + expect_true( + check_columns_present(example_quantile, c("observed", "predicted")) + ) +}) + +test_that("check_duplicates() works", { + bad <- rbind( + example_quantile[1000:1010], + example_quantile[1000:1010] + ) + + expect_equal(scoringutils:::check_duplicates(bad), + "There are instances with more than one forecast for the same target. This can't be right and needs to be resolved. Maybe you need to check the unit of a single forecast and add missing columns? Use the function get_duplicate_forecasts() to identify duplicate rows" + ) +}) + +# test_that("as_forecast() function returns a message with NA in the data", { +# expect_message( +# { check <- as_forecast(example_quantile) }, +# "\\d+ values for `predicted` are NA" +# ) +# expect_match( +# unlist(check$messages), +# "\\d+ values for `predicted` are NA" +# ) +# }) + +# test_that("as_forecast() function returns messages with NA in the data", { +# example <- data.table::copy(example_quantile) +# example[horizon == 2, observed := NA] +# check <- suppressMessages(as_forecast(example)) +# +# expect_equal(length(check$messages), 2) +# }) + +test_that("as_forecast() function throws an error with duplicate forecasts", { + example <- rbind(example_quantile, + example_quantile[1000:1010]) + + expect_error( + suppressMessages(suppressWarnings(as_forecast(example))), + "Assertion on 'data' failed: There are instances with more than one forecast for the same target. This can't be right and needs to be resolved. Maybe you need to check the unit of a single forecast and add missing columns? Use the function get_duplicate_forecasts() to identify duplicate rows.", #nolint + fixed = TRUE + ) +}) + +test_that("as_forecast() function warns when no model column is present", { + no_model <- data.table::copy(example_quantile[model == "EuroCOVIDhub-ensemble"])[, model := NULL][] + expect_warning( + as_forecast(no_model), + "There is no column called `model` in the data.") +}) + +test_that("as_forecast() function throws an error when no predictions or observed values are present", { + expect_error(suppressMessages(suppressWarnings(as_forecast( + data.table::copy(example_quantile)[, predicted := NULL] + ))), + "Assertion on 'data' failed: Column 'predicted' not found in data.") + + expect_error(suppressMessages(suppressWarnings(as_forecast( + data.table::copy(example_quantile)[, observed := NULL] + ))), + "Assertion on 'data' failed: Column 'observed' not found in data.") + + expect_error(suppressMessages(suppressWarnings(as_forecast( + data.table::copy(example_quantile)[, c("observed", "predicted") := NULL] + ))), + "Assertion on 'data' failed: Columns 'observed', 'predicted' not found in data.") +}) + +# test_that("as_forecast() function throws an error when no predictions or observed values are present", { +# expect_error(suppressMessages(suppressWarnings(as_forecast( +# data.table::copy(example_quantile)[, predicted := NA] +# )))) +# expect_error(suppressMessages(suppressWarnings(check_forecasts( +# data.table::copy(example_quantile)[, observed := NA] +# )))) +# }) + +# test_that("as_forecast() function throws an sample/quantile not present", { +# expect_error(suppressMessages(suppressWarnings(as_forecast( +# data.table::copy(example_quantile)[, quantile := NULL] +# )))) +# }) + +test_that("output of as_forecasts() is accepted as input to score()", { + check <- suppressMessages(as_forecast(example_binary)) + expect_no_error( + score_check <- score(na.omit(check)) + ) + expect_equal(score_check, suppressMessages(score(as_forecast(example_binary)))) +}) + +# ============================================================================== +# is_forecast() +# ============================================================================== + +test_that("is_forecast() works as expected", { + ex_binary <- suppressMessages(as_forecast(example_binary)) + ex_point <- suppressMessages(as_forecast(example_point)) + ex_quantile <- suppressMessages(as_forecast(example_quantile)) + ex_continuous <- suppressMessages(as_forecast(example_continuous)) + + expect_true(is_forecast(ex_binary)) + expect_true(is_forecast(ex_point)) + expect_true(is_forecast(ex_quantile)) + expect_true(is_forecast(ex_continuous)) + + expect_false(is_forecast(1:10)) + expect_false(is_forecast(data.table::as.data.table(example_point))) + expect_false(is_forecast.forecast_sample(ex_quantile)) + expect_false(is_forecast.forecast_quantile(ex_binary)) +}) + + +# ============================================================================== +# validate_forecast() +# ============================================================================== + +test_that("validate_forecast() works as expected", { + # test that by default, `as_forecast()` errors + expect_error(validate_forecast(data.frame(x = 1:10)), + "The input needs to be a forecast object.") +}) + +test_that("validate_forecast.forecast_binary works as expected", { + test <- na.omit(data.table::copy(example_binary)) + test[, "sample_id" := 1:nrow(test)] + + # error if there is a superficial sample_id column + expect_error( + as_forecast(test), + "Input looks like a binary forecast, but an additional column called `sample_id` or `quantile` was found." + ) + + # expect error if probabilties are not in [0, 1] + test <- na.omit(data.table::copy(example_binary)) + test[, "predicted" := predicted + 1] + expect_error( + as_forecast(test), + "Input looks like a binary forecast, but found the following issue" + ) +}) + +test_that("validate_forecast.forecast_point() works as expected", { + test <- na.omit(data.table::copy(example_point)) + test <- as_forecast(test) + + # expect an error if column is changed to character after initial validation. + test <- test[, "predicted" := as.character(predicted)] + expect_error( + validate_forecast(test), + "Input looks like a point forecast, but found the following issue" + ) +}) diff --git a/tests/testthat/test-get_-functions.R b/tests/testthat/test-get_-functions.R index aa2bb7430..a5a23bd1e 100644 --- a/tests/testthat/test-get_-functions.R +++ b/tests/testthat/test-get_-functions.R @@ -182,3 +182,110 @@ test_that("get_forecast_type() works as expected", { fixed = TRUE ) }) + + +# ============================================================================== +# `get_coverage()` +# ============================================================================== +ex_coverage <- example_quantile[model == "EuroCOVIDhub-ensemble"] + +test_that("get_coverage() works as expected", { + cov <- example_quantile %>% + na.omit() %>% + as_forecast() %>% + get_coverage(by = get_forecast_unit(example_quantile)) + + expect_equal( + sort(colnames(cov)), + sort(c(get_forecast_unit(example_quantile), c( + "interval_range", "quantile_level", "interval_coverage", "interval_coverage_deviation", + "quantile_coverage", "quantile_coverage_deviation" + ))) + ) + + expect_equal(nrow(cov), nrow(na.omit(example_quantile))) +}) + +test_that("get_coverage() outputs an object of class c('data.table', 'data.frame'", { + ex <- as_forecast(na.omit(example_quantile)) + cov <- get_coverage(ex) + expect_s3_class(cov, c("data.table", "data.frame"), exact = TRUE) +}) + +test_that("get_coverage() can deal with non-symmetric prediction intervals", { + # the expected result is that `get_coverage()` just works. However, + # all interval coverages with missing values should just be `NA` + test <- data.table::copy(example_quantile) %>% + na.omit() %>% + as_forecast() + test <- test[!quantile_level %in% c(0.2, 0.3, 0.5)] + + expect_no_condition(cov <- get_coverage(test)) + + prediction_intervals <- get_range_from_quantile(c(0.2, 0.3, 0.5)) + + missing <- cov[interval_range %in% prediction_intervals] + not_missing <- cov[!interval_range %in% prediction_intervals] + + expect_true(all(is.na(missing$interval_coverage))) + expect_false(any(is.na(not_missing))) + + # test for a version where values are not missing, but just `NA` + # since `get_coverage()` calls `na.omit`, the result should be the same. + test <- data.table::copy(example_quantile) %>% + na.omit() %>% + as_forecast() %>% + suppressMessages() + test <- test[quantile_level %in% c(0.2, 0.3, 0.5), predicted := NA] + cov2 <- get_coverage(test) + expect_equal(cov, cov2) +}) + + +# ============================================================================== +# `get_forecast_counts()` +# ============================================================================== +test_that("get_forecast_counts() works as expected", { + af <- suppressMessages(as_forecast(example_quantile)) + af <- get_forecast_counts( + af, + by = c("model", "target_type", "target_end_date") + ) + + expect_type(af, "list") + expect_type(af$target_type, "character") + expect_type(af$`count`, "integer") + expect_equal(nrow(af[is.na(`count`)]), 0) + af <- na.omit(example_quantile) %>% + as_forecast() %>% + get_forecast_counts(by = "model") + expect_equal(nrow(af), 4) + expect_equal(af$`count`, c(256, 256, 128, 247)) + + # Ensure the returning object class is exactly same as a data.table. + expect_s3_class(af, c("data.table", "data.frame"), exact = TRUE) + + # Setting `collapse = c()` means that all quantiles and samples are counted + af <- na.omit(example_quantile) %>% + as_forecast() %>% + get_forecast_counts(by = "model", collapse = c()) + expect_equal(nrow(af), 4) + expect_equal(af$`count`, c(5888, 5888, 2944, 5681)) + + # setting by = NULL, the default, results in by equal to forecast unit + af <- na.omit(example_quantile) %>% + as_forecast() %>% + get_forecast_counts() + expect_equal(nrow(af), 50688) + + # check whether collapsing also works for model-based forecasts + af <- na.omit(example_integer) %>% + as_forecast() %>% + get_forecast_counts(by = "model") + expect_equal(nrow(af), 4) + + af <- na.omit(example_integer) %>% + as_forecast() %>% + get_forecast_counts(by = "model", collapse = c()) + expect_equal(af$count, c(10240, 10240, 5120, 9880)) +}) diff --git a/tests/testthat/test-metrics-validate.R b/tests/testthat/test-metrics-validate.R new file mode 100644 index 000000000..b75b40a90 --- /dev/null +++ b/tests/testthat/test-metrics-validate.R @@ -0,0 +1,18 @@ +test_that("validate_metrics() works as expected", { + test_fun <- function(x, y, ...) { + if (hasArg("test")) { + message("test argument found") + } + return(y) + } + ## Additional tests for validate_metrics() + # passing in something that's not a function or a known metric + expect_warning( + expect_warning( + score(as_forecast(na.omit(example_binary)), metrics = list( + "test1" = test_fun, "test" = test_fun, "hi" = "hi", "2" = 3) + ), + "`Metrics` element number 3 is not a valid function" + ), + "`Metrics` element number 4 is not a valid function") +}) diff --git a/tests/testthat/test-pairwise_comparison.R b/tests/testthat/test-pairwise_comparison.R index 3bc3b6e77..1892a6fbe 100644 --- a/tests/testthat/test-pairwise_comparison.R +++ b/tests/testthat/test-pairwise_comparison.R @@ -448,3 +448,45 @@ test_that("compare_two_models() throws error with wrong inputs", { ) }) +test_that("add_pairwise_comparison() works with point forecasts", { + expect_no_condition( + pw_point <- add_pairwise_comparison( + scores_point, + metric = "se_point" + ) + ) + pw_point <- summarise_scores(pw_point, by = "model") + + pw_manual <- pairwise_comparison( + scores_point, by = "model", metric = "se_point" + ) + + expect_equal( + pw_point$relative_skill, + unique(pw_manual$relative_skill) + ) +}) + +test_that("add_pairwise_comparison() can compute relative measures", { + scores_with <- add_pairwise_comparison( + scores_quantile, + ) + scores_with <- summarise_scores(scores_with, by = "model") + + expect_equal( + scores_with[, wis_relative_skill], + c(1.6, 0.81, 0.75, 1.03), tolerance = 0.01 + ) + + scores_with <- add_pairwise_comparison( + scores_quantile, by = "model", + metric = "ae_median" + ) + scores_with <- summarise_scores(scores_with, by = "model") + + expect_equal( + scores_with[, ae_median_relative_skill], + c(1.6, 0.78, 0.77, 1.04), tolerance = 0.01 + ) +}) + diff --git a/tests/testthat/test-print.R b/tests/testthat/test-print.R new file mode 100644 index 000000000..11a622f12 --- /dev/null +++ b/tests/testthat/test-print.R @@ -0,0 +1,56 @@ +test_that("print() works on forecast_* objects", { + # Check print works on each forecast object + test_dat <- list(na.omit(example_binary), na.omit(example_quantile), + na.omit(example_point), na.omit(example_continuous), na.omit(example_integer)) + for (dat in test_dat){ + dat <- as_forecast(dat) + forecast_type <- get_forecast_type(dat) + forecast_unit <- get_forecast_unit(dat) + + # Check Forecast type + expect_snapshot(print(dat)) + expect_snapshot(print(dat)) + # Check Forecast unit + expect_snapshot(print(dat)) + expect_snapshot(print(dat)) + + # Check print.data.table works. + output_original <- capture.output(print(dat)) + output_test <- capture.output(print(data.table(dat))) + expect_contains(output_original, output_test) + } +}) + +test_that("print methods fail gracefully", { + test <- as_forecast(na.omit(example_quantile)) + test$observed <- NULL + + # message if forecast type can't be computed + expect_warning( + expect_message( + expect_output( + print(test), + pattern = "Forecast unit:" + ), + "Could not determine forecast type due to error in validation." + ), + "Error in validating forecast object:" + ) + + # message if forecast unit can't be computed + test <- 1:10 + class(test) <- "forecast_point" + expect_warning( + expect_message( + expect_message( + expect_output( + print(test), + pattern = "Forecast unit:" + ), + "Could not determine forecast unit." + ), + "Could not determine forecast type" + ), + "Error in validating forecast object:" + ) +}) diff --git a/tests/testthat/test-score.R b/tests/testthat/test-score.R index f569ad6d3..9503906b1 100644 --- a/tests/testthat/test-score.R +++ b/tests/testthat/test-score.R @@ -150,18 +150,6 @@ test_that( "something" ) - - ## Additional tests for validate_metrics() - # passing in something that's not a function or a known metric - expect_warning( - expect_warning( - score(df, metrics = list( - "test1" = test_fun, "test" = test_fun, "hi" = "hi", "2" = 3) - ), - "`Metrics` element number 3 is not a valid function" - ), - "`Metrics` element number 4 is not a valid function") - # passing a single named argument for metrics by position expect_contains( names(score(df, list("hi" = test_fun))), diff --git a/tests/testthat/test-summarise_scores.R b/tests/testthat/test-summarise_scores.R index b95d56fa0..8302f5814 100644 --- a/tests/testthat/test-summarise_scores.R +++ b/tests/testthat/test-summarise_scores.R @@ -31,25 +31,6 @@ test_that("summarise_scores() handles wrong by argument well", { ) }) -test_that("summarise_scores() works with point forecasts", { - expect_no_condition( - pw_point <- add_pairwise_comparison( - scores_point, - metric = "se_point" - ) - ) - pw_point <- summarise_scores(pw_point, by = "model") - - pw_manual <- pairwise_comparison( - scores_point, by = "model", metric = "se_point" - ) - - expect_equal( - pw_point$relative_skill, - unique(pw_manual$relative_skill) - ) -}) - test_that("summarise_scores() handles the `metrics` attribute correctly", { test <- data.table::copy(scores_quantile) attr(test, "metrics") <- NULL @@ -68,29 +49,6 @@ test_that("summarise_scores() handles the `metrics` attribute correctly", { ) }) -test_that("summarise_scores() can compute relative measures", { - scores_with <- add_pairwise_comparison( - scores_quantile, - ) - scores_with <- summarise_scores(scores_with, by = "model") - - expect_equal( - scores_with[, wis_relative_skill], - c(1.6, 0.81, 0.75, 1.03), tolerance = 0.01 - ) - - scores_with <- add_pairwise_comparison( - scores_quantile, by = "model", - metric = "ae_median" - ) - scores_with <- summarise_scores(scores_with, by = "model") - - expect_equal( - scores_with[, ae_median_relative_skill], - c(1.6, 0.78, 0.77, 1.04), tolerance = 0.01 - ) -}) - test_that("summarise_scores() across argument works as expected", { ex <- data.table::copy(example_quantile) ex <- suppressMessages(as_forecast(ex)) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index f7c6444bb..907d303ec 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -78,65 +78,3 @@ test_that("get_metrics() works as expected", { "scores have been previously computed, but are no longer column names" ) }) - - -# ============================================================================== -# print -# ============================================================================== - -test_that("print() works on forecast_* objects", { - # Check print works on each forecast object - test_dat <- list(na.omit(example_binary), na.omit(example_quantile), - na.omit(example_point), na.omit(example_continuous), na.omit(example_integer)) - for (dat in test_dat){ - dat <- as_forecast(dat) - forecast_type <- get_forecast_type(dat) - forecast_unit <- get_forecast_unit(dat) - - # Check Forecast type - expect_snapshot(print(dat)) - expect_snapshot(print(dat)) - # Check Forecast unit - expect_snapshot(print(dat)) - expect_snapshot(print(dat)) - - # Check print.data.table works. - output_original <- capture.output(print(dat)) - output_test <- capture.output(print(data.table(dat))) - expect_contains(output_original, output_test) - } -}) - -test_that("print methods fail gracefully", { - test <- as_forecast(na.omit(example_quantile)) - test$observed <- NULL - - # message if forecast type can't be computed - expect_warning( - expect_message( - expect_output( - print(test), - pattern = "Forecast unit:" - ), - "Could not determine forecast type due to error in validation." - ), - "Error in validating forecast object:" - ) - - # message if forecast unit can't be computed - test <- 1:10 - class(test) <- "forecast_point" - expect_warning( - expect_message( - expect_message( - expect_output( - print(test), - pattern = "Forecast unit:" - ), - "Could not determine forecast unit." - ), - "Could not determine forecast type" - ), - "Error in validating forecast object:" - ) -})