From 539e5a7430028a9e3225aada281df61d069252a6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Wed, 25 Sep 2024 17:52:32 -0400 Subject: [PATCH 1/5] quantile -> quantile_levels for #1203 --- NEWS.md | 3 +++ R/predict.R | 2 +- R/predict_quantile.R | 21 ++++++++++++++++----- man/other_predict.Rd | 5 ++--- man/set_args.Rd | 2 +- tests/testthat/test-linear_reg_quantreg.R | 5 +++++ 6 files changed, 28 insertions(+), 10 deletions(-) diff --git a/NEWS.md b/NEWS.md index e2a63b619..004da1ef5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,9 @@ * New `extract_fit_time()` method has been added that returns the time it took to train the model (#853). +## Breaking Change + +* For quantile prediction, the `predict()` argument has been changed from `quantile` to `quantile_levels` for consistency. This does not affect models with mode `"quantile regression"`. # parsnip 1.2.1 diff --git a/R/predict.R b/R/predict.R index 3a2681048..397b92112 100644 --- a/R/predict.R +++ b/R/predict.R @@ -344,7 +344,7 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env()) # ---------------------------------------------------------------------------- - other_args <- c("interval", "level", "std_error", "quantile", + other_args <- c("interval", "level", "std_error", "quantile_levels", "time", "eval_time", "increasing") is_pred_arg <- names(the_dots) %in% other_args if (any(!is_pred_arg)) { diff --git a/R/predict_quantile.R b/R/predict_quantile.R index fc2d91b15..a3950a75c 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -1,14 +1,13 @@ #' @keywords internal #' @rdname other_predict -#' @param quantile A vector of numbers between 0 and 1 for the quantile being -#' predicted. +#' @param quantile_levels A vector of values between zero and one. 
#' @inheritParams predict.model_fit #' @method predict_quantile model_fit #' @export predict_quantile.model_fit #' @export predict_quantile.model_fit <- function(object, new_data, - quantile = (1:9)/10, + quantile_levels = NULL, interval = "none", level = 0.95, ...) { @@ -20,6 +19,20 @@ predict_quantile.model_fit <- function(object, return(NULL) } + if (object$spec$mode != "quantile regression") { + if (is.null(quantile_levels)) { + quantile_levels <- (1:9)/10 + } + hardhat::check_quantile_levels(quantile_levels) + # Pass some extra arguments to be used in post-processor + object$quantile_levels <- quantile_levels + } else { + if (!is.null(quantile_levels)) { + cli::cli_abort("{.arg quantile_levels} are specified by {.fn set_mode} + when the mode is {.val quantile regression}.") + } + } + new_data <- prepare_data(object, new_data) # preprocess data @@ -27,8 +40,6 @@ predict_quantile.model_fit <- function(object, new_data <- object$spec$method$pred$quantile$pre(new_data, object) } - # Pass some extra arguments to be used in post-processor - object$spec$method$pred$quantile$args$p <- quantile pred_call <- make_pred_call(object$spec$method$pred$quantile) res <- eval_tidy(pred_call) diff --git a/man/other_predict.Rd b/man/other_predict.Rd index 6c997e28d..d1342d87f 100644 --- a/man/other_predict.Rd +++ b/man/other_predict.Rd @@ -49,7 +49,7 @@ predict_numeric(object, ...) \method{predict_quantile}{model_fit}( object, new_data, - quantile = (1:9)/10, + quantile_levels = NULL, interval = "none", level = 0.95, ... @@ -103,8 +103,7 @@ interval estimates.} \item{std_error}{A single logical for whether the standard error should be returned (assuming that the model can compute it).} -\item{quantile}{A vector of numbers between 0 and 1 for the quantile being -predicted.} +\item{quantile_levels}{A vector of values between zero and one.} } \description{ These are internal functions not meant to be directly called by the user. 
diff --git a/man/set_args.Rd b/man/set_args.Rd index 6d3b60f3d..b31e4ad4c 100644 --- a/man/set_args.Rd +++ b/man/set_args.Rd @@ -21,7 +21,7 @@ set_mode(object, mode, ...) "regression")} \item{quantile_levels}{A vector of values between zero and one (only for the -\verb{quantile regression } mode); otherwise, it is \code{NULL}. The model uses these +\code{"quantile regression"} mode); otherwise, it is \code{NULL}. The model uses these values to appropriately train quantile regression models to make predictions for these values (e.g., \code{quantile_levels = 0.5} is the median).} } diff --git a/tests/testthat/test-linear_reg_quantreg.R b/tests/testthat/test-linear_reg_quantreg.R index 7edc7c3a5..2cce1f6ae 100644 --- a/tests/testthat/test-linear_reg_quantreg.R +++ b/tests/testthat/test-linear_reg_quantreg.R @@ -83,6 +83,11 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_named(ten_quant_df, c(".pred_quantile", ".quantile_levels", ".row")) expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10) + expect_snapshot( + ten_quant_pred <- predict(ten_quant, new_data = sac_test), + error = TRUE + ) + ### ten_quant_one_row <- predict(ten_quant, new_data = sac_test[1,]) From 623a9c768c0b8a5c9ad3c414f6a39bc442b0a50d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=98topepo=E2=80=99?= <‘mxkuhn@gmail.com’> Date: Wed, 25 Sep 2024 17:52:48 -0400 Subject: [PATCH 2/5] defer test until censored updates in new PR --- tests/testthat/test-surv_reg_survreg.R | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/testthat/test-surv_reg_survreg.R b/tests/testthat/test-surv_reg_survreg.R index dbb279998..c6a051884 100644 --- a/tests/testthat/test-surv_reg_survreg.R +++ b/tests/testthat/test-surv_reg_survreg.R @@ -10,8 +10,6 @@ complete_form <- survival::Surv(time) ~ group # ------------------------------------------------------------------------------ test_that('survival execution', { - skip_on_travis() - 
rlang::local_options(lifecycle_verbosity = "quiet") surv_basic <- surv_reg() %>% set_engine("survival") surv_lnorm <- surv_reg(dist = "lognormal") %>% set_engine("survival") @@ -46,7 +44,7 @@ test_that('survival execution', { }) test_that('survival prediction', { - skip_on_travis() + skip_if_not_installed("censored", minimum_version = "0.3.2.9001") rlang::local_options(lifecycle_verbosity = "quiet") surv_basic <- surv_reg() %>% set_engine("survival") @@ -67,7 +65,7 @@ test_that('survival prediction', { apply(exp_quant, 1, function(x) tibble(.pred = x, .quantile = (2:4) / 5)) exp_quant <- tibble(.pred = exp_quant) - obs_quant <- predict(res, head(lung), type = "quantile", quantile = (2:4)/5) + obs_quant <- predict(res, head(lung), type = "quantile", quantile_level = (2:4)/5) expect_equal(as.data.frame(exp_quant), as.data.frame(obs_quant)) From ad7e3e0d52b62fef82ba6af86c101b0ce790ba5e Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 26 Sep 2024 13:34:34 -0400 Subject: [PATCH 3/5] update docs for quantile_levels --- R/predict_quantile.R | 16 +++++++++------- man/other_predict.Rd | 4 +++- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/R/predict_quantile.R b/R/predict_quantile.R index a3950a75c..56ec31bde 100644 --- a/R/predict_quantile.R +++ b/R/predict_quantile.R @@ -1,6 +1,8 @@ #' @keywords internal #' @rdname other_predict -#' @param quantile_levels A vector of values between zero and one. +#' @param quantile_levels A vector of values between zero and one for the +#' quantile to be predicted. If the model has a `"quantile regression"` mode, +#' this value should be `NULL`. For other modes, the default is `(1:9)/10`. 
#' @inheritParams predict.model_fit #' @method predict_quantile model_fit #' @export predict_quantile.model_fit @@ -19,18 +21,18 @@ predict_quantile.model_fit <- function(object, return(NULL) } - if (object$spec$mode != "quantile regression") { + if (object$spec$mode == "quantile regression") { + if (!is.null(quantile_levels)) { + cli::cli_abort("When the mode is {.val quantile regression}, + {.arg quantile_levels} are specified by {.fn set_mode}.") + } + } else { if (is.null(quantile_levels)) { quantile_levels <- (1:9)/10 } hardhat::check_quantile_levels(quantile_levels) # Pass some extra arguments to be used in post-processor object$quantile_levels <- quantile_levels - } else { - if (!is.null(quantile_levels)) { - cli::cli_abort("{.arg quantile_levels} are specified by {.fn set_mode} - when the mode is {.val quantile regression}.") - } } new_data <- prepare_data(object, new_data) diff --git a/man/other_predict.Rd b/man/other_predict.Rd index d1342d87f..313ff4d72 100644 --- a/man/other_predict.Rd +++ b/man/other_predict.Rd @@ -103,7 +103,9 @@ interval estimates.} \item{std_error}{A single logical for whether the standard error should be returned (assuming that the model can compute it).} -\item{quantile_levels}{A vector of values between zero and one.} +\item{quantile_levels}{A vector of values between zero and one for the +quantile to be predicted. If the model has a \code{"quantile regression"} mode, +this value should be \code{NULL}. For other modes, the default is \code{(1:9)/10}.} } \description{ These are internal functions not meant to be directly called by the user. 
From 53c091656c8c0e12e6398e0a2d292998c1ee6d10 Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 26 Sep 2024 13:34:52 -0400 Subject: [PATCH 4/5] update test --- tests/testthat/_snaps/linear_reg_quantreg.md | 9 +++++++++ tests/testthat/test-linear_reg_quantreg.R | 2 +- 2 files changed, 10 insertions(+), 1 deletion(-) create mode 100644 tests/testthat/_snaps/linear_reg_quantreg.md diff --git a/tests/testthat/_snaps/linear_reg_quantreg.md b/tests/testthat/_snaps/linear_reg_quantreg.md new file mode 100644 index 000000000..cba265991 --- /dev/null +++ b/tests/testthat/_snaps/linear_reg_quantreg.md @@ -0,0 +1,9 @@ +# linear quantile regression via quantreg - multiple quantiles + + Code + ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0: + 9) / 9) + Condition + Error in `predict_quantile()`: + ! When the mode is "quantile regression", `quantile_levels` are specified by `set_mode()`. + diff --git a/tests/testthat/test-linear_reg_quantreg.R b/tests/testthat/test-linear_reg_quantreg.R index 2cce1f6ae..0785fe7b5 100644 --- a/tests/testthat/test-linear_reg_quantreg.R +++ b/tests/testthat/test-linear_reg_quantreg.R @@ -84,7 +84,7 @@ test_that('linear quantile regression via quantreg - multiple quantiles', { expect_true(nrow(ten_quant_df) == nrow(sac_test) * 10) expect_snapshot( - ten_quant_pred <- predict(ten_quant, new_data = sac_test), + ten_quant_pred <- predict(ten_quant, new_data = sac_test, quantile_levels = (0:9)/9), error = TRUE ) From 3aec2d0785d2f325e4df934ba60455aed0f50237 Mon Sep 17 00:00:00 2001 From: topepo Date: Thu, 26 Sep 2024 13:35:21 -0400 Subject: [PATCH 5/5] disable quantile predictions for surv_reg --- NEWS.md | 1 + R/surv_reg_data.R | 38 -------------------------- tests/testthat/test-surv_reg_survreg.R | 12 +------- 3 files changed, 2 insertions(+), 49 deletions(-) diff --git a/NEWS.md b/NEWS.md index 004da1ef5..dfa96528f 100644 --- a/NEWS.md +++ b/NEWS.md @@ -14,6 +14,7 @@ ## Breaking Change * For quantile prediction, the 
`predict()` argument has been changed from `quantile` to `quantile_levels` for consistency. This does not affect models with mode `"quantile regression"`. +* The `"quantile"` prediction type was disabled for the deprecated `surv_reg()` model. # parsnip 1.2.1 diff --git a/R/surv_reg_data.R b/R/surv_reg_data.R index 9313ede22..a37dc50bd 100644 --- a/R/surv_reg_data.R +++ b/R/surv_reg_data.R @@ -59,25 +59,6 @@ set_pred( ) ) -set_pred( - model = "surv_reg", - eng = "flexsurv", - mode = "regression", - type = "quantile", - value = list( - pre = NULL, - post = flexsurv_quant, - func = c(fun = "summary"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "quantile", - quantiles = expr(quantile) - ) - ) -) - # ------------------------------------------------------------------------------ set_model_engine("surv_reg", mode = "regression", eng = "survival") @@ -133,22 +114,3 @@ set_pred( ) ) ) - -set_pred( - model = "surv_reg", - eng = "survival", - mode = "regression", - type = "quantile", - value = list( - pre = NULL, - post = survreg_quant, - func = c(fun = "predict"), - args = - list( - object = expr(object$fit), - newdata = expr(new_data), - type = "quantile", - p = expr(quantile) - ) - ) -) diff --git a/tests/testthat/test-surv_reg_survreg.R b/tests/testthat/test-surv_reg_survreg.R index c6a051884..cda216c51 100644 --- a/tests/testthat/test-surv_reg_survreg.R +++ b/tests/testthat/test-surv_reg_survreg.R @@ -44,7 +44,7 @@ test_that('survival execution', { }) test_that('survival prediction', { - skip_if_not_installed("censored", minimum_version = "0.3.2.9001") + skip_if_not_installed("survival") rlang::local_options(lifecycle_verbosity = "quiet") surv_basic <- surv_reg() %>% set_engine("survival") @@ -59,16 +59,6 @@ test_that('survival prediction', { exp_pred <- predict(extract_fit_engine(res), head(lung)) exp_pred <- tibble(.pred = unname(exp_pred)) expect_equal(exp_pred, predict(res, head(lung))) - - exp_quant <- 
predict(extract_fit_engine(res), head(lung), p = (2:4)/5, type = "quantile") - exp_quant <- - apply(exp_quant, 1, function(x) - tibble(.pred = x, .quantile = (2:4) / 5)) - exp_quant <- tibble(.pred = exp_quant) - obs_quant <- predict(res, head(lung), type = "quantile", quantile_level = (2:4)/5) - - expect_equal(as.data.frame(exp_quant), as.data.frame(obs_quant)) - })