Skip to content

Commit 297320e

Browse files
topepo‘topepo’dajmcdonsimonpcouchhfrick
authored
Add quantile regression mode (#1209)
* add a quantile regression mode to test with * update type checkers * avoid confusion with global all_models object * add quantile_level argument to set_mode() * initial data for quantreg * some initial tests * fix some issues * enable quantile prediction * tests for quantreg * Quantile predictions output constructor (#1191) * small change to predict checks * add vctrs for quantiles and test, refactor *_rq_preds * revise tests * Apply some of the suggestions from code review Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com> * rename tests on suggestion from code review * export missing funs from vctrs for formatting * convert errors to snapshot tests * pass call through input check * update snapshots for caller_env * rename to parsnip_quantiles, add format snapshot tests * Apply suggestions from @topepo Co-authored-by: Max Kuhn <mxkuhn@gmail.com> * rename parsnip_quantiles to quantile_pred * rename parsnip_quantiles to quantile_pred and add vector probability check * fix: two bugs introduced earlier * add formatting tests for single quantile * replace walk with a loop to avoid "Error in map()" * remove row/col names * adjust quantile_pred format * as_tibble method * updated NEWS file * add PR number * small new update * helper methods * update docs * re-enable quantiles prediction for #1203 * update some tests * no longer needed * use tibble::new_tibble * braces * test as_tibble * remove print methods --------- Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com> Co-authored-by: Max Kuhn <mxkuhn@gmail.com> Co-authored-by: ‘topepo’ <‘mxkuhn@gmail.com’> * quantile regression updates for new hardhat model (#1207) * bump hardhat version * remove parts now in hardhat * update for new hardhat version * quantile_levels (plural now) * news update * typo * rename helper function * run CI on PRs from branches * forgotten remote * actions for edited PRs * plural * expand branch list * export function for censored to use * updated snapshot * remake snapshot * Revert "remake snapshot" This reverts commit 954e326. * updated snapshot * Update R/arguments.R Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com> * typo * changes from reviewer feedback --------- Co-authored-by: ‘topepo’ <‘mxkuhn@gmail.com’> Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com> * Change to `quantile` argument to `quantile levels` (#1208) * quantile -> quantile_levels for #1203 * defer test until censored updates in new PR * update docs for quantile_levels * update test * disable quantile predictions for surv_reg --------- Co-authored-by: ‘topepo’ <‘mxkuhn@gmail.com’> * post conflict merge updates * update news * version bump and fix typo * revert GHA branches * small bug fix * Apply suggestions from code review Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com> Co-authored-by: Emil Hvitfeldt <emil.hvitfeldt@posit.co> * don't export median * add call arg * added documentation on model * add mode * convert error to warning * remove rankdeficient * added skip * add deprecated `quantile` arg back in * remove numeric prediction --------- Co-authored-by: ‘topepo’ <‘mxkuhn@gmail.com’> Co-authored-by: Daniel McDonald <dajmcdon@gmail.com> Co-authored-by: Simon P. Couch <simonpatrickcouch@gmail.com> Co-authored-by: Hannah Frick <hfrick@users.noreply.github.com> Co-authored-by: Emil Hvitfeldt <emil.hvitfeldt@posit.co>
1 parent 5ce414e commit 297320e

35 files changed

+906
-108
lines changed

DESCRIPTION

+8-7
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
Package: parsnip
22
Title: A Common API to Modeling and Analysis Functions
3-
Version: 1.2.1.9002
3+
Version: 1.2.1.9003
44
Authors@R: c(
55
person("Max", "Kuhn", , "max@posit.co", role = c("aut", "cre")),
66
person("Davis", "Vaughan", , "davis@posit.co", role = "aut"),
@@ -25,7 +25,7 @@ Imports:
2525
ggplot2,
2626
globals,
2727
glue,
28-
hardhat (>= 1.4.0),
28+
hardhat (>= 1.4.0.9002),
2929
lifecycle,
3030
magrittr,
3131
pillar,
@@ -40,8 +40,8 @@ Imports:
4040
vctrs (>= 0.6.0),
4141
withr
4242
Suggests:
43-
C50,
4443
bench,
44+
C50,
4545
covr,
4646
dials (>= 1.1.0),
4747
earth,
@@ -69,16 +69,17 @@ Suggests:
6969
xgboost (>= 1.5.0.1)
7070
VignetteBuilder:
7171
knitr
72+
Remotes:
73+
r-lib/sparsevctrs,
74+
tidymodels/hardhat
7275
ByteCompile: true
7376
Config/Needs/website: C50, dbarts, earth, glmnet, keras, kernlab, kknn,
74-
LiblineaR, mgcv, nnet, parsnip, randomForest, ranger, rpart, rstanarm,
75-
tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate,
77+
LiblineaR, mgcv, nnet, parsnip, quantreg, randomForest, ranger, rpart,
78+
rstanarm, tidymodels/tidymodels, tidyverse/tidytemplate, rstudio/reticulate,
7679
xgboost
7780
Config/rcmdcheck/ignore-inconsequential-notes: true
7881
Config/testthat/edition: 3
7982
Encoding: UTF-8
8083
LazyData: true
8184
Roxygen: list(markdown = TRUE)
82-
Remotes:
83-
r-lib/sparsevctrs
8485
RoxygenNote: 7.3.2

NAMESPACE

+1
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,7 @@ export(make_classes)
264264
export(make_engine_list)
265265
export(make_seealso_list)
266266
export(mars)
267+
export(matrix_to_quantile_pred)
267268
export(max_mtry_formula)
268269
export(maybe_data_frame)
269270
export(maybe_matrix)

NEWS.md

+20-4
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,41 @@
11
# parsnip (development version)
22

3+
## New Features
4+
5+
* A new model mode (`"quantile regression"`) was added. Including:
6+
* A `linear_reg()` engine for `"quantreg"`.
7+
* Predictions are encoded via a custom vector type. See [hardhat::quantile_pred()].
8+
* Predicted quantile levels are designated when the new mode is specified. See `?set_mode`.
9+
310
* `fit_xy()` can now take dgCMatrix input for `x` argument (#1121).
411

512
* `fit_xy()` can now take sparse tibbles as data values (#1165).
613

714
* `predict()` can now take dgCMatrix and sparse tibble input for `new_data` argument, and error informatively when model doesn't support it (#1167).
815

9-
* Transitioned package errors and warnings to use cli (#1147 and #1148 by
10-
@shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160,
11-
#1161, #1081).
16+
* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).
17+
18+
## Other Changes
19+
20+
* Transitioned package errors and warnings to use cli (#1147 and #1148 by @shum461, #1153 by @RobLBaker and @wright13, #1154 by @JamesHWade, #1160, #1161, #1081).
1221

1322
* `fit_xy()` currently raises an error for `gen_additive_mod()` model specifications as the default engine (`"mgcv"`) specifies smoothing terms in model formulas. However, some engines specify smooths via additional arguments, in which case the restriction on `fit_xy()` is excessive. parsnip will now only raise an error when fitting a `gen_additive_mod()` with `fit_xy()` when using the `"mgcv"` engine (#775).
1423

1524
* Aligned `null_model()` with other model types; the model type now has an engine argument that defaults to `"parsnip"` and is checked with the same machinery that checks other model types in the package (#1083).
1625

17-
* New `extract_fit_time()` method has been added that returns the time it took to train the model (#853).
26+
## Bug Fixes
1827

1928
* Ensure that `knit_engine_docs()` has the required packages installed (#1156).
2029

2130
* Fixed bug where some models fit using `fit_xy()` couldn't predict (#1166).
2231

32+
## Breaking Change
33+
34+
* For quantile prediction, the `quantile` argument to `predict()` has been deprecate in facor of `quantile_levels`. This does not affect models with mode `"quantile regression"`.
35+
36+
* The quantile regression prediction type was disabled for the deprecated `surv_reg()` model.
37+
38+
2339
# parsnip 1.2.1
2440

2541
* Added a missing `tidy()` method for survival analysis glmnet models (#1086).

R/aaa-import-standalone-types-check.R

+16-5
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,3 @@
1-
# Standalone file: do not edit by hand
2-
# Source: <https://github.com/r-lib/rlang/blob/main/R/standalone-types-check.R>
3-
# ----------------------------------------------------------------------
4-
#
51
# ---
62
# repo: r-lib/rlang
73
# file: standalone-types-check.R
@@ -13,6 +9,9 @@
139
#
1410
# ## Changelog
1511
#
12+
# 2024-08-15:
13+
# - `check_character()` gains an `allow_na` argument (@martaalcalde, #1724)
14+
#
1615
# 2023-03-13:
1716
# - Improved error messages of number checkers (@teunbrand)
1817
# - Added `allow_infinite` argument to `check_number_whole()` (@mgirlich).
@@ -461,15 +460,28 @@ check_formula <- function(x,
461460

462461
# Vectors -----------------------------------------------------------------
463462

463+
# TODO: Figure out what to do with logical `NA` and `allow_na = TRUE`
464+
464465
check_character <- function(x,
465466
...,
467+
allow_na = TRUE,
466468
allow_null = FALSE,
467469
arg = caller_arg(x),
468470
call = caller_env()) {
471+
469472
if (!missing(x)) {
470473
if (is_character(x)) {
474+
if (!allow_na && any(is.na(x))) {
475+
abort(
476+
sprintf("`%s` can't contain NA values.", arg),
477+
arg = arg,
478+
call = call
479+
)
480+
}
481+
471482
return(invisible(NULL))
472483
}
484+
473485
if (allow_null && is_null(x)) {
474486
return(invisible(NULL))
475487
}
@@ -479,7 +491,6 @@ check_character <- function(x,
479491
x,
480492
"a character vector",
481493
...,
482-
allow_na = FALSE,
483494
allow_null = allow_null,
484495
arg = arg,
485496
call = call

R/aaa_models.R

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Initialize model environments
22

3-
all_modes <- c("classification", "regression", "censored regression")
3+
all_modes <- c("classification", "regression", "censored regression", "quantile regression")
44

55
# ------------------------------------------------------------------------------
66

@@ -195,8 +195,8 @@ stop_missing_engine <- function(cls, call) {
195195
}
196196

197197
check_mode_for_new_engine <- function(cls, eng, mode, call = caller_env()) {
198-
all_modes <- get_from_env(paste0(cls, "_modes"))
199-
if (!(mode %in% all_modes)) {
198+
model_modes <- get_from_env(paste0(cls, "_modes"))
199+
if (!(mode %in% model_modes)) {
200200
cli::cli_abort(
201201
"{.val {mode}} is not a known mode for model {.fn {cls}}.",
202202
call = call

R/aaa_quantiles.R

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
#' Reformat quantile predictions
2+
#'
3+
#' @param x A matrix of predictions with rows as samples and columns as quantile
4+
#' levels.
5+
#' @param object A parsnip `model_fit` object from a quantile regression model.
6+
#' @keywords internal
7+
#' @export
8+
matrix_to_quantile_pred <- function(x, object) {
9+
if (!is.matrix(x)) {
10+
x <- as.matrix(x)
11+
}
12+
rownames(x) <- NULL
13+
n_pred_quantiles <- ncol(x)
14+
quantile_levels <- object$spec$quantile_levels
15+
16+
tibble::new_tibble(x = list(.pred_quantile = hardhat::quantile_pred(x, quantile_levels)))
17+
}

R/arguments.R

+21-3
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ check_eng_args <- function(args, obj, core_args) {
4949
#' set_args(mtry = 3, importance = TRUE) %>%
5050
#' set_mode("regression")
5151
#'
52+
#' linear_reg() %>%
53+
#' set_mode("quantile regression", quantile_levels = c(0.2, 0.5, 0.8))
5254
#' @export
5355
set_args <- function(object, ...) {
5456
UseMethod("set_args")
@@ -89,12 +91,18 @@ set_args.default <- function(object,...) {
8991

9092
#' @rdname set_args
9193
#' @export
92-
set_mode <- function(object, mode) {
94+
set_mode <- function(object, mode, ...) {
9395
UseMethod("set_mode")
9496
}
9597

98+
#' @rdname set_args
99+
#' @param quantile_levels A vector of values between zero and one (only for the
100+
#' `"quantile regression"` mode); otherwise, it is `NULL`. The model uses these
101+
#' values to appropriately train quantile regression models to make predictions
102+
#' for these values (e.g., `quantile_levels = 0.5` is the median).
96103
#' @export
97-
set_mode.model_spec <- function(object, mode) {
104+
set_mode.model_spec <- function(object, mode, quantile_levels = NULL, ...) {
105+
check_dots_empty()
98106
cls <- class(object)[1]
99107
if (rlang::is_missing(mode)) {
100108
spec_modes <- rlang::env_get(get_model_env(), paste0(cls, "_modes"))
@@ -111,11 +119,21 @@ set_mode.model_spec <- function(object, mode) {
111119

112120
object$mode <- mode
113121
object$user_specified_mode <- TRUE
122+
if (mode == "quantile regression") {
123+
hardhat::check_quantile_levels(quantile_levels)
124+
} else {
125+
if (!is.null(quantile_levels)) {
126+
cli::cli_warn("{.arg quantile_levels} is only used when the mode is
127+
{.val quantile regression}.")
128+
}
129+
}
130+
131+
object$quantile_levels <- quantile_levels
114132
object
115133
}
116134

117135
#' @export
118-
set_mode.default <- function(object, mode) {
136+
set_mode.default <- function(object, mode, ...) {
119137
error_set_object(object, func = "set_mode")
120138

121139
invisible(FALSE)

R/fit.R

+8-1
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,10 @@ fit.model_spec <-
176176
eval_env$formula <- formula
177177
eval_env$weights <- wts
178178

179+
if (!is.null(object$quantile_levels)) {
180+
eval_env$quantile_levels <- object$quantile_levels
181+
}
182+
179183
data <- materialize_sparse_tibble(data, object, "data")
180184

181185
fit_interface <-
@@ -187,7 +191,6 @@ fit.model_spec <-
187191
with a spark data object."
188192
)
189193

190-
191194
# populate `method` with the details for this model type
192195
object <- add_methods(object, engine = object$engine)
193196

@@ -295,6 +298,10 @@ fit_xy.model_spec <-
295298
eval_env$y_var <- y_var
296299
eval_env$weights <- weights_to_numeric(case_weights, object)
297300

301+
if (!is.null(object$quantile_levels)) {
302+
eval_env$quantile_levels <- object$quantile_levels
303+
}
304+
298305
# TODO case weights: pass in eval_env not individual elements
299306
fit_interface <- check_xy_interface(eval_env$x, eval_env$y, cl, object)
300307

R/install_packages.R

+2-2
Original file line numberDiff line numberDiff line change
@@ -26,8 +26,8 @@ install_engine_packages <- function(extension = TRUE, extras = TRUE,
2626
}
2727

2828
if (extras) {
29-
rmd_pkgs <- c("tidymodels", "broom.mixed", "glmnet", "Cubist", "xrf", "ape",
30-
"rmarkdown")
29+
rmd_pkgs <- c("ape", "broom.mixed", "Cubist", "glmnet", "quantreg",
30+
"rmarkdown", "tidymodels", "xrf")
3131
engine_packages <- unique(c(engine_packages, rmd_pkgs))
3232
}
3333

R/linear_reg_data.R

+46
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
set_new_model("linear_reg")
22

33
set_model_mode("linear_reg", "regression")
4+
set_model_mode("linear_reg", "quantile regression")
45

56
# ------------------------------------------------------------------------------
67

@@ -582,3 +583,48 @@ set_pred(
582583
)
583584
)
584585

586+
# ------------------------------------------------------------------------------
587+
588+
set_model_engine(model = "linear_reg", mode = "quantile regression", eng = "quantreg")
589+
set_dependency(model = "linear_reg", eng = "quantreg", pkg = "quantreg", mode = "quantile regression")
590+
591+
set_fit(
592+
model = "linear_reg",
593+
eng = "quantreg",
594+
mode = "quantile regression",
595+
value = list(
596+
interface = "formula",
597+
protect = c("formula", "data", "weights"),
598+
func = c(pkg = "quantreg", fun = "rq"),
599+
defaults = list(tau = expr(quantile_levels))
600+
)
601+
)
602+
603+
set_encoding(
604+
model = "linear_reg",
605+
eng = "quantreg",
606+
mode = "quantile regression",
607+
options = list(
608+
predictor_indicators = "traditional",
609+
compute_intercept = TRUE,
610+
remove_intercept = TRUE,
611+
allow_sparse_x = FALSE
612+
)
613+
)
614+
615+
set_pred(
616+
model = "linear_reg",
617+
eng = "quantreg",
618+
mode = "quantile regression",
619+
type = "quantile",
620+
value = list(
621+
pre = NULL,
622+
post = matrix_to_quantile_pred,
623+
func = c(fun = "predict"),
624+
args =
625+
list(
626+
object = expr(object$fit),
627+
newdata = expr(new_data)
628+
)
629+
)
630+
)

R/linear_reg_quantreg.R

+11
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
#' Linear quantile regression via the quantreg package
2+
#'
3+
#' [quantreg::rq()] optimizes quantile loss to fit models with numeric outcomes.
4+
#'
5+
#' @includeRmd man/rmd/linear_reg_quantreg.md details
6+
#'
7+
#' @name details_linear_reg_quantreg
8+
#' @keywords internal
9+
NULL
10+
11+
# See inst/README-DOCS.md for a description of how these files are processed

R/predict.R

+4-2
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,14 @@ check_pred_type <- function(object, type, ..., call = rlang::caller_env()) {
201201
regression = "numeric",
202202
classification = "class",
203203
"censored regression" = "time",
204+
"quantile regression" = "quantile",
204205
cli::cli_abort(
205-
"{.arg type} should be 'regression', 'censored regression', or 'classification'.",
206+
"{.arg type} should be one of {.or {.val {all_modes}}}.",
206207
call = call
207208
)
208209
)
209210
}
211+
210212
if (!(type %in% pred_types))
211213
cli::cli_abort(
212214
"{.arg type} should be one of {.or {.arg {pred_types}}}.",
@@ -373,7 +375,7 @@ check_pred_type_dots <- function(object, type, ..., call = rlang::caller_env())
373375

374376
# ----------------------------------------------------------------------------
375377

376-
other_args <- c("interval", "level", "std_error", "quantile",
378+
other_args <- c("interval", "level", "std_error", "quantile_levels",
377379
"time", "eval_time", "increasing")
378380

379381
eval_time_types <- c("survival", "hazard")

0 commit comments

Comments
 (0)