
Commit

Merge pull request #710 from epiforecasts/rules-to-metrics
Issue #709: Rename `rules` to `metrics`
nikosbosse authored Mar 9, 2024
2 parents 5db9070 + cab02d6 commit 3e434e4
Showing 17 changed files with 150 additions and 150 deletions.
10 changes: 5 additions & 5 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ export(logs_binary)
export(logs_sample)
export(mad_sample)
export(merge_pred_and_obs)
export(metrics_binary)
export(metrics_point)
export(metrics_quantile)
export(metrics_sample)
export(new_forecast)
export(overprediction)
export(pairwise_comparison)
Expand All @@ -64,15 +68,11 @@ export(plot_score_table)
export(plot_wis)
export(quantile_score)
export(quantile_to_interval)
export(rules_binary)
export(rules_point)
export(rules_quantile)
export(rules_sample)
export(run_safely)
export(sample_to_quantile)
export(score)
export(se_mean_sample)
export(select_rules)
export(select_metrics)
export(set_forecast_unit)
export(summarise_scores)
export(summarize_scores)
Expand Down
12 changes: 6 additions & 6 deletions NEWS.md
Expand Up @@ -8,21 +8,21 @@ The update introduces breaking changes. If you want to keep using the older vers
- In `score()`, the required columns "true_value" and "prediction" were renamed to "observed" and "predicted", and "model" is now also a required column. Scoring functions now use the argument names "observed" and "predicted" consistently throughout.
- The overall scoring workflow was updated. Most functions now operate on forecast objects, which can be created using the function `as_forecast()`. This function replaces the previous `check_forecast()` function and validates the inputs. `as_forecast()` also allows users to rename required columns and specify the forecast unit in a single step, taking over the functionality of `set_forecast_unit()` in most cases. `score()` is now a generic function that dispatches the correct method based on the forecast type. Forecast types currently supported are "binary", "point", "sample" and "quantile" with corresponding classes "forecast_binary", "forecast_point", "forecast_sample" and "forecast_quantile".
- `set_forecast_unit()` now errors if any of the values in `forecast_unit` are not columns of the data.
- Scoring rules (functions used for scoring) received a consistent interface and input checks:
- Scoring rules for binary forecasts:
- All scoring functions exported by the package received a consistent interface and input checks:
- Metrics and scoring rules for binary forecasts:
- `observed`: factor with exactly 2 levels
- `predicted`: numeric, vector with probabilities
- Scoring rules for point forecasts:
- Metrics and scoring rules for point forecasts:
- `observed`: numeric vector
- `predicted`: numeric vector
- Scoring rules for sample-based forecasts:
- Metrics and scoring rules for sample-based forecasts:
- `observed`: numeric, either a scalar or a vector
- `predicted`: numeric, a vector (if `observed` is a scalar) or a matrix (if `observed` is a vector)
- Scoring rules for quantile-based forecasts:
- Metrics and scoring rules for quantile-based forecasts:
- `observed`: numeric, either a scalar or a vector
- `predicted`: numeric, a vector (if `observed` is a scalar) or a matrix (if `observed` is a vector)
- `quantile_level`: numeric, a vector with quantile-levels. Can alternatively be a matrix of the same shape as `predicted`.
- Users can now supply their own scoring rules to `score()` as a list of functions. Default scoring rules can be accessed using the functions `rules_point()`, `rules_sample()`, `rules_quantile()` and `rules_binary()`, which return a named list of scoring rules suitable for the respective forecast type. Column names of scores in the output of `score()` correspond to the names of the scoring rules (i.e. the names of the functions in the list of scoring rules).
- Users can now supply their own metrics and scoring rules to `score()` as a list of functions. Default scoring rules can be accessed using the functions `metrics_point()`, `metrics_sample()`, `metrics_quantile()` and `metrics_binary()`, which return a named list of scoring rules suitable for the respective forecast type. Column names of scores in the output of `score()` correspond to the names of the scoring rules (i.e. the names of the functions in the list of metrics).
- `score()` now returns objects of class `scores` with a stored attribute `metrics` that holds the names of the scoring rules that were used. Users can call `get_metrics()` to access the names of those scoring rules.
- `check_forecasts()` was replaced by a different workflow. There is now a function, `as_forecast()`, that determines the forecast type of the data, constructs a forecast object and validates it using the function `validate_forecast()` (a generic that dispatches the correct method based on the forecast type). Objects of class `forecast_binary`, `forecast_point`, `forecast_sample` and `forecast_quantile` have print methods that fulfill the functionality of `check_forecasts()`.
- Users can test whether an object is of class `forecast_*()` using the function `is_forecast()`. Users can also test for a specific `forecast_*` class using the appropriate `is_forecast.forecast_*` method. For example, to check whether an object is of class `forecast_quantile`, you would use `scoringutils:::is_forecast.forecast_quantile()`.
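The entries above describe the new interface: per-forecast-type input formats, and a named list of metric functions whose names become score column names. A minimal base-R sketch of that idea for the binary case follows. Note that `brier_score` and `logs_binary` here are simplified re-implementations written for illustration, not the package's own exports, and the final `sapply()` stands in for what `score()` does internally.

```r
# Sketch of the binary interface described above:
# observed is a factor with exactly 2 levels, predicted a numeric
# vector of probabilities for the second level.
brier_score <- function(observed, predicted) {
  stopifnot(is.factor(observed), nlevels(observed) == 2)
  outcome <- as.numeric(observed) - 1  # second factor level -> 1
  (predicted - outcome)^2
}

logs_binary <- function(observed, predicted) {
  stopifnot(is.factor(observed), nlevels(observed) == 2)
  outcome <- as.numeric(observed) - 1
  -log(ifelse(outcome == 1, predicted, 1 - predicted))
}

# A named list of metrics: the names become score column names
metrics <- list(brier_score = brier_score, log_score = logs_binary)

observed  <- factor(c("no", "yes", "yes"), levels = c("no", "yes"))
predicted <- c(0.1, 0.8, 0.6)

# One score column per metric, one row per forecast
scores <- sapply(metrics, function(fun) fun(observed, predicted))
```

Here `scores` is a 3-by-2 matrix with columns `brier_score` and `log_score`, mirroring how `score()` turns the names of the supplied functions into output columns.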
Expand Down
80 changes: 40 additions & 40 deletions R/default-scoring-rules.R
@@ -1,7 +1,7 @@
#' @title Select Scoring Rules From A List of Possible Scoring Rules
#' @title Select Metrics From A List of Functions
#' @description Helper function to return only the scoring rules selected by
#' the user from a list of possible scoring rules.
#' @param rules A list of scoring rules.
#' the user from a list of possible functions.
#' @param metrics A list of scoring functions.
#' @param select A character vector of scoring rules to select from the list.
#' If `select` is `NULL` (the default), all possible scoring rules are returned.
#' @param exclude A character vector of scoring rules to exclude from the list.
Expand All @@ -11,84 +11,84 @@
#' @importFrom checkmate assert_subset assert_list
#' @export
#' @examples
#' select_rules(
#' rules = rules_binary(),
#' select_metrics(
#' metrics = metrics_binary(),
#' select = "brier_score"
#' )
#' select_rules(
#' rules = rules_binary(),
#' select_metrics(
#' metrics = metrics_binary(),
#' exclude = "log_score"
#' )
select_rules <- function(rules, select = NULL, exclude = NULL) {
select_metrics <- function(metrics, select = NULL, exclude = NULL) {
assert_character(x = c(select, exclude), null.ok = TRUE)
assert_list(rules, names = "named")
allowed <- names(rules)
assert_list(metrics, names = "named")
allowed <- names(metrics)

if (is.null(select) && is.null(exclude)) {
return(rules)
return(metrics)
} else if (is.null(select)) {
assert_subset(exclude, allowed)
select <- allowed[!allowed %in% exclude]
return(rules[select])
return(metrics[select])
} else {
assert_subset(select, allowed)
return(rules[select])
return(metrics[select])
}
}
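Because the diff interleaves the old `rules` lines with the new `metrics` lines, the post-rename function can be hard to read off the page. Below is a standalone sketch of how `select_metrics()` reads after this commit, simplified by using plain `stopifnot()` in place of the {checkmate} assertions used in the package.

```r
# Standalone sketch of select_metrics() as it reads after this diff
# (simplified input checks; not the packaged implementation).
select_metrics <- function(metrics, select = NULL, exclude = NULL) {
  stopifnot(is.list(metrics), !is.null(names(metrics)))
  allowed <- names(metrics)

  if (is.null(select) && is.null(exclude)) {
    return(metrics)                       # nothing requested: return all
  }
  if (is.null(select)) {
    stopifnot(all(exclude %in% allowed))  # drop the excluded metrics
    select <- allowed[!allowed %in% exclude]
  } else {
    stopifnot(all(select %in% allowed))   # keep only the selected metrics
  }
  metrics[select]
}

# Illustrative metric functions (not the package's own):
example_metrics <- list(
  brier_score = function(o, p) (p - o)^2,
  log_score   = function(o, p) -log(p)
)

names(select_metrics(example_metrics, select = "brier_score"))  # "brier_score"
names(select_metrics(example_metrics, exclude = "log_score"))   # "brier_score"
```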


#' @title Scoring Rules for Binary Forecasts
#' @title Default Metrics And Scoring Rules for Binary Forecasts
#' @description Helper function that returns a named list of default
#' scoring rules suitable for binary forecasts.
#'
#' The default scoring rules are:
#' - "brier_score" = [brier_score()]
#' - "log_score" = [logs_binary()]
#' @inherit select_rules params return
#' @inherit select_metrics params return
#' @export
#' @keywords metric
#' @examples
#' rules_binary()
#' rules_binary(select = "brier_score")
#' rules_binary(exclude = "log_score")
rules_binary <- function(select = NULL, exclude = NULL) {
#' metrics_binary()
#' metrics_binary(select = "brier_score")
#' metrics_binary(exclude = "log_score")
metrics_binary <- function(select = NULL, exclude = NULL) {
all <- list(
brier_score = brier_score,
log_score = logs_binary
)
selected <- select_rules(all, select, exclude)
selected <- select_metrics(all, select, exclude)
return(selected)
}


#' @title Scoring Rules for Point Forecasts
#' @title Default Metrics And Scoring Rules for Point Forecasts
#' @description Helper function that returns a named list of default
#' scoring rules suitable for point forecasts.
#'
#' The default scoring rules are:
#' - "ae_point" = [ae()][Metrics::ae()]
#' - "se_point" = [se()][Metrics::se()]
#' - "ape" = [ape()][Metrics::ape()]
#' @inherit select_rules params return
#' @inherit select_metrics params return
#' @export
#' @keywords metric
#' @examples
#' rules_point()
#' rules_point(select = "ape")
rules_point <- function(select = NULL, exclude = NULL) {
#' metrics_point()
#' metrics_point(select = "ape")
metrics_point <- function(select = NULL, exclude = NULL) {
all <- list(
ae_point = Metrics::ae,
se_point = Metrics::se,
ape = Metrics::ape
)
selected <- select_rules(all, select, exclude)
selected <- select_metrics(all, select, exclude)
return(selected)
}


#' @title Scoring Rules for Sample-Based Forecasts
#' @title Default Metrics And Scoring Rules for Sample-Based Forecasts
#' @description Helper function that returns a named list of default
#' scoring rules suitable for forecasts in a sample-based format
#' scoring rules suitable for forecasts in a sample-based format.
#'
#' The default scoring rules are:
#' - "mad" = [mad_sample()]
Expand All @@ -99,13 +99,13 @@ rules_point <- function(select = NULL, exclude = NULL) {
#' - "mad" = [mad_sample()]
#' - "ae_median" = [ae_median_sample()]
#' - "se_mean" = [se_mean_sample()]
#' @inherit select_rules params return
#' @inherit select_metrics params return
#' @export
#' @keywords metric
#' @examples
#' rules_sample()
#' rules_sample(select = "mad")
rules_sample <- function(select = NULL, exclude = NULL) {
#' metrics_sample()
#' metrics_sample(select = "mad")
metrics_sample <- function(select = NULL, exclude = NULL) {
all <- list(
bias = bias_sample,
dss = dss_sample,
Expand All @@ -115,14 +115,14 @@ rules_sample <- function(select = NULL, exclude = NULL) {
ae_median = ae_median_sample,
se_mean = se_mean_sample
)
selected <- select_rules(all, select, exclude)
selected <- select_metrics(all, select, exclude)
return(selected)
}


#' @title Scoring Rules for Quantile-Based Forecasts
#' @title Default Metrics And Scoring Rules for Quantile-Based Forecasts
#' @description Helper function that returns a named list of default
#' scoring rules suitable for forecasts in a quantile-based format
#' scoring rules suitable for forecasts in a quantile-based format.
#'
#' The default scoring rules are:
#' - "wis" = [wis]
Expand All @@ -144,13 +144,13 @@ rules_sample <- function(select = NULL, exclude = NULL) {
#' accept get passed on to it. `interval_range = 90` is set in the function definition,
#' as passing an argument `interval_range = 90` to [score()] would mean it would also
#' get passed to `interval_coverage_50`.
#' @inherit select_rules params return
#' @inherit select_metrics params return
#' @export
#' @keywords metric
#' @examples
#' rules_quantile()
#' rules_quantile(select = "wis")
rules_quantile <- function(select = NULL, exclude = NULL) {
#' metrics_quantile()
#' metrics_quantile(select = "wis")
metrics_quantile <- function(select = NULL, exclude = NULL) {
all <- list(
wis = wis,
overprediction = overprediction,
Expand All @@ -164,6 +164,6 @@ rules_quantile <- function(select = NULL, exclude = NULL) {
interval_coverage_deviation = interval_coverage_deviation,
ae_median = ae_median_quantile
)
selected <- select_rules(all, select, exclude)
selected <- select_metrics(all, select, exclude)
return(selected)
}
2 changes: 1 addition & 1 deletion R/get_-functions.R
Expand Up @@ -130,7 +130,7 @@ get_type <- function(x) {
}


#' @title Get Names Of The Scoring Rules That Were Used For Scoring
#' @title Get Names Of The Metrics That Were Used For Scoring
#' @description
#' When applying a scoring rule via [score()], the names of the scoring rules
#' become column names of the
Expand Down
24 changes: 12 additions & 12 deletions R/score.R
Expand Up @@ -16,8 +16,8 @@
#' @param data A forecast object (a validated data.table with predicted and
#' observed values, see [as_forecast()])
#' @param metrics A named list of scoring functions. Names will be used as
#' column names in the output. See [rules_point()], [rules_binary()],
#' [rules_quantile()], and [rules_sample()] for more information on the
#' column names in the output. See [metrics_point()], [metrics_binary()],
#' [metrics_quantile()], and [metrics_sample()] for more information on the
#' default metrics used.
#' @param ... additional arguments
#' @return An object of class `scores`. This object is a data.table with
Expand Down Expand Up @@ -81,13 +81,13 @@ score.default <- function(data, metrics, ...) {
#' @importFrom data.table setattr copy
#' @rdname score
#' @export
score.forecast_binary <- function(data, metrics = rules_binary(), ...) {
score.forecast_binary <- function(data, metrics = metrics_binary(), ...) {
data <- copy(data)
suppressWarnings(suppressMessages(validate_forecast(data)))
data <- na.omit(data)
metrics <- validate_metrics(metrics)

scores <- apply_rules(
scores <- apply_metrics(
data, metrics,
data$observed, data$predicted, ...
)
Expand All @@ -102,13 +102,13 @@ score.forecast_binary <- function(data, metrics = rules_binary(), ...) {
#' @importFrom data.table setattr copy
#' @rdname score
#' @export
score.forecast_point <- function(data, metrics = rules_point(), ...) {
score.forecast_point <- function(data, metrics = metrics_point(), ...) {
data <- copy(data)
suppressWarnings(suppressMessages(validate_forecast(data)))
data <- na.omit(data)
metrics <- validate_metrics(metrics)

scores <- apply_rules(
scores <- apply_metrics(
data, metrics,
data$observed, data$predicted, ...
)
Expand All @@ -121,7 +121,7 @@ score.forecast_point <- function(data, metrics = rules_point(), ...) {
#' @importFrom data.table setattr copy
#' @rdname score
#' @export
score.forecast_sample <- function(data, metrics = rules_sample(), ...) {
score.forecast_sample <- function(data, metrics = metrics_sample(), ...) {
data <- copy(data)
suppressWarnings(suppressMessages(validate_forecast(data)))
data <- na.omit(data)
Expand All @@ -144,7 +144,7 @@ score.forecast_sample <- function(data, metrics = rules_sample(), ...) {
predicted <- do.call(rbind, data$predicted)
data[, c("observed", "predicted", "scoringutils_N") := NULL]

data <- apply_rules(
data <- apply_metrics(
data, metrics,
observed, predicted, ...
)
Expand All @@ -160,7 +160,7 @@ score.forecast_sample <- function(data, metrics = rules_sample(), ...) {
#' @importFrom data.table `:=` as.data.table rbindlist %like% setattr copy
#' @rdname score
#' @export
score.forecast_quantile <- function(data, metrics = rules_quantile(), ...) {
score.forecast_quantile <- function(data, metrics = metrics_quantile(), ...) {
data <- copy(data)
suppressWarnings(suppressMessages(validate_forecast(data)))
data <- na.omit(data)
Expand Down Expand Up @@ -190,7 +190,7 @@ score.forecast_quantile <- function(data, metrics = rules_quantile(), ...) {
"observed", "predicted", "quantile_level", "scoringutils_quantile_level"
) := NULL]

data <- apply_rules(
data <- apply_metrics(
data, metrics,
observed, predicted, quantile_level, ...
)
Expand All @@ -206,15 +206,15 @@ score.forecast_quantile <- function(data, metrics = rules_quantile(), ...) {

#' @title Apply A List Of Functions To A Data Table Of Forecasts
#' @description This helper function applies scoring rules (stored as a list of
#' functions) to a data table of forecasts. `apply_rules` is used within
#' functions) to a data table of forecasts. `apply_metrics` is used within
#' `score()` to apply all scoring rules to the data.
#' Scoring rules are wrapped in [run_safely()] to catch errors and to make
#' sure that only arguments are passed to the scoring rule that are actually
#' accepted by it.
#' @inheritParams score
#' @return A data table with the forecasts and the calculated metrics
#' @keywords internal
apply_rules <- function(data, metrics, ...) {
apply_metrics <- function(data, metrics, ...) {
expr <- expression(
data[, (metric_name) := do.call(run_safely, list(..., fun = fun))]
)
Expand Down
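The `apply_metrics()` documentation above describes the core pattern: each metric in the named list is applied to the forecast data, and `run_safely()` makes sure only arguments the metric actually accepts are passed, catching errors along the way. A toy base-R sketch of that pattern follows; `run_safely_sketch` is a simplified stand-in for the package's `run_safely()` (here, errors simply yield `NA`), not its actual implementation.

```r
# Toy sketch of the apply-metrics pattern: forward only the arguments a
# metric accepts, and turn errors into NA instead of aborting scoring.
run_safely_sketch <- function(..., fun) {
  args <- list(...)
  accepted <- names(formals(fun))
  args <- args[names(args) %in% accepted]  # drop unknown arguments
  tryCatch(do.call(fun, args), error = function(e) NA_real_)
}

apply_metrics_sketch <- function(metrics, ...) {
  # one result vector per metric, keyed by the metric's name
  sapply(names(metrics), function(metric_name) {
    run_safely_sketch(..., fun = metrics[[metric_name]])
  }, simplify = FALSE)
}

metrics <- list(
  ae = function(observed, predicted) abs(observed - predicted),
  se = function(observed, predicted) (observed - predicted)^2
)

scores <- apply_metrics_sketch(
  metrics,
  observed = c(1, 2), predicted = c(1.5, 1),
  quantile_level = 0.5  # silently ignored: neither metric accepts it
)
scores$ae  # c(0.5, 1)
```

The argument-filtering step is what lets `score()` pass one shared set of arguments (`observed`, `predicted`, `quantile_level`, ...) to every metric, even though each metric only declares the subset it needs.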
8 changes: 4 additions & 4 deletions R/z_globalVariables.R
Expand Up @@ -44,10 +44,10 @@ globalVariables(c(
"metric",
"metrics_select",
"metrics",
"rules_binary",
"rules_point",
"rules_quantile",
"rules_sample",
"metrics_binary",
"metrics_point",
"metrics_quantile",
"metrics_sample",
"model",
"n_obs",
"n_obs wis_component_name",
Expand Down
12 changes: 6 additions & 6 deletions man/apply_rules.Rd → man/apply_metrics.Rd


