Skip to content

Commit 30a7936

Browse files
author
‘topepo’
committed
more changes for tidymodels/parsnip#1203
1 parent 2dbac85 commit 30a7936

6 files changed

+61
-100
lines changed

DESCRIPTION

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@ Imports:
3535
rlang (>= 1.0.0),
3636
stats,
3737
tibble (>= 3.1.3),
38-
tidyr (>= 1.0.0)
38+
tidyr (>= 1.0.0),
39+
vctrs
3940
Suggests:
4041
aorsf (>= 0.1.2),
4142
coin,

R/censored-package.R

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,12 @@ utils::globalVariables(
6262
".id", ".tmp", "engine", "predictor_indicators", ".strata", "group",
6363
".pred_quantile", ".quantile", "interval", "level", ".pred_linear_pred",
6464
".pred_link", ".pred_time", ".pred_survival", "next_event_time",
65-
"sum_component", "time_interval"
65+
"sum_component", "time_interval", "quantile_levels"
6666
)
6767
)
6868

6969
# quiet R-CMD-check NOTEs that prodlim is unused
70-
# (parsnip uses it for all censored regression models
70+
# (parsnip uses it for all censored regression models
7171
# but only has it in Suggests)
7272
#' @importFrom prodlim prodlim
7373
NULL

R/survival_reg-data.R

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -86,14 +86,14 @@ make_survival_reg_survival <- function() {
8686
type = "quantile",
8787
value = list(
8888
pre = NULL,
89-
post = survreg_quant,
89+
post = parsnip::matrix_to_quantile_pred,
9090
func = c(fun = "predict"),
9191
args =
9292
list(
9393
object = expr(object$fit),
9494
newdata = expr(new_data),
9595
type = "quantile",
96-
p = expr(quantile)
96+
p = expr(quantile_levels)
9797
)
9898
)
9999
)
@@ -236,14 +236,14 @@ make_survival_reg_flexsurv <- function() {
236236
type = "quantile",
237237
value = list(
238238
pre = NULL,
239-
post = NULL,
239+
post = flexsurv_to_quantile_pred,
240240
func = c(fun = "predict"),
241241
args =
242242
list(
243243
object = rlang::expr(object$fit),
244244
newdata = rlang::expr(new_data),
245245
type = "quantile",
246-
p = rlang::expr(quantile),
246+
p = rlang::expr(quantile_levels),
247247
conf.int = rlang::expr(interval == "confidence"),
248248
conf.level = rlang::expr(level)
249249
)
@@ -393,14 +393,14 @@ make_survival_reg_flexsurvspline <- function() {
393393
type = "quantile",
394394
value = list(
395395
pre = NULL,
396-
post = NULL,
396+
post = flexsurv_to_quantile_pred,
397397
func = c(fun = "predict"),
398398
args =
399399
list(
400400
object = rlang::expr(object$fit),
401401
newdata = rlang::expr(new_data),
402402
type = "quantile",
403-
p = rlang::expr(quantile),
403+
p = rlang::expr(quantile_levels),
404404
conf.int = rlang::expr(interval == "confidence"),
405405
conf.level = rlang::expr(level)
406406
)

tests/testthat/test-survival_reg-flexsurv.R

Lines changed: 21 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@ library(testthat)
22

33
test_that("model object", {
44
skip_if_not_installed("flexsurv")
5-
5+
66
set.seed(1234)
77
exp_f_fit <- flexsurv::flexsurvreg(
88
Surv(time, status) ~ age + ph.ecog,
@@ -149,7 +149,7 @@ test_that("survival probabilities for single eval time point", {
149149

150150
test_that("can predict for out-of-domain timepoints", {
151151
skip_if_not_installed("flexsurv")
152-
152+
153153
eval_time_obs_max_and_ood <- c(1022, 2000)
154154
obs_without_NA <- lung[2,]
155155

@@ -236,41 +236,28 @@ test_that("quantile predictions", {
236236
)
237237

238238
expect_s3_class(pred, "tbl_df")
239-
expect_equal(names(pred), ".pred")
239+
expect_equal(names(pred), ".pred_quantile")
240240
expect_equal(nrow(pred), 3)
241-
expect_true(
242-
all(purrr::map_lgl(
243-
pred$.pred,
244-
~ all(dim(.x) == c(9, 2))
245-
))
246-
)
247-
expect_true(
248-
all(purrr::map_lgl(
249-
pred$.pred,
250-
~ all(names(.x) == c(".quantile", ".pred_quantile"))
251-
))
252-
)
253-
expect_equal(
254-
tidyr::unnest(pred, cols = .pred)$.pred_quantile,
255-
do.call(rbind, exp_pred)$est
256-
)
241+
expect_s3_class(pred$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))
242+
243+
for (.row in 1:nrow(pred)) {
244+
expect_equal(
245+
unclass(pred$.pred_quantile[.row])[[1]],
246+
exp_pred[[.row]]$est
247+
)
248+
}
257249

258250
# add confidence interval
259-
pred <- predict(fit_s,
251+
pred_ci <- predict(fit_s,
260252
new_data = bladder[1:3, ], type = "quantile",
261253
interval = "confidence", level = 0.7
262254
)
263-
expect_true(
264-
all(purrr::map_lgl(
265-
pred$.pred,
266-
~ all(names(.x) == c(
267-
".quantile",
268-
".pred_quantile",
269-
".pred_lower",
270-
".pred_upper"
271-
))
272-
))
273-
)
255+
expect_s3_class(pred_ci, "tbl_df")
256+
expect_equal(names(pred_ci), c(".pred_quantile", ".pred_lower", ".pred_upper"))
257+
expect_equal(nrow(pred_ci), 3)
258+
expect_s3_class(pred_ci$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))
259+
expect_s3_class(pred_ci$.pred_lower, c("quantile_pred", "vctrs_vctr", "list"))
260+
expect_s3_class(pred_ci$.pred_upper, c("quantile_pred", "vctrs_vctr", "list"))
274261

275262
# single observation
276263
f_pred_1 <- predict(fit_s, bladder[2,], type = "quantile")
@@ -354,7 +341,7 @@ test_that("hazard for single eval time point", {
354341

355342
test_that("`fix_xy()` works", {
356343
skip_if_not_installed("flexsurv")
357-
344+
358345
lung_x <- as.matrix(lung[, c("age", "ph.ecog")])
359346
lung_y <- Surv(lung$time, lung$status)
360347
lung_pred <- lung[1:5, ]
@@ -401,13 +388,13 @@ test_that("`fix_xy()` works", {
401388
f_fit,
402389
new_data = lung_pred,
403390
type = "quantile",
404-
quantile = c(0.2, 0.8)
391+
quantile_levels = c(0.2, 0.8)
405392
)
406393
xy_pred_quantile <- predict(
407394
xy_fit,
408395
new_data = lung_pred,
409396
type = "quantile",
410-
quantile = c(0.2, 0.8)
397+
quantile_levels = c(0.2, 0.8)
411398
)
412399
expect_equal(f_pred_quantile, xy_pred_quantile)
413400

tests/testthat/test-survival_reg-flexsurvspline.R

Lines changed: 15 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ test_that("survival probability prediction", {
6161
head(lung),
6262
type = "survival",
6363
times = c(0, 500, 1000)
64-
)
64+
)
6565
if (packageVersion("flexsurv") < "2.3") {
6666
exp_pred <- exp_pred %>%
6767
dplyr::rowwise() %>%
@@ -211,59 +211,26 @@ test_that("quantile predictions", {
211211
set_mode("censored regression") %>%
212212
fit(Surv(stop, event) ~ rx + size + enum, data = bladder)
213213
pred <- predict(fit_s, new_data = bladder[1:3, ], type = "quantile")
214-
215-
set.seed(1)
216-
exp_fit <- flexsurv::flexsurvspline(
217-
Surv(stop, event) ~ rx + size + enum,
218-
data = bladder,
219-
k = 1
220-
)
221-
exp_pred <- summary(
222-
exp_fit,
223-
newdata = bladder[1:3, ],
224-
type = "quantile",
225-
quantiles = (1:9) / 10
226-
)
227-
228214
expect_s3_class(pred, "tbl_df")
229-
expect_equal(names(pred), ".pred")
215+
expect_equal(names(pred), ".pred_quantile")
230216
expect_equal(nrow(pred), 3)
231-
expect_true(
232-
all(purrr::map_lgl(
233-
pred$.pred,
234-
~ all(dim(.x) == c(9, 2))
235-
))
236-
)
237-
expect_true(
238-
all(purrr::map_lgl(
239-
pred$.pred,
240-
~ all(names(.x) == c(".quantile", ".pred_quantile"))
241-
))
242-
)
243-
expect_equal(
244-
tidyr::unnest(pred, cols = .pred)$.pred_quantile,
245-
do.call(rbind, exp_pred)$est
246-
)
217+
expect_s3_class(pred$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))
218+
247219

248220
# add confidence interval
249-
pred <- predict(
221+
pred_ci <- predict(
250222
fit_s,
251223
new_data = bladder[1:3, ],
252224
type = "quantile",
253225
interval = "confidence",
254226
level = 0.7
255227
)
256-
expect_true(
257-
all(purrr::map_lgl(
258-
pred$.pred,
259-
~ all(names(.x) == c(
260-
".quantile",
261-
".pred_quantile",
262-
".pred_lower",
263-
".pred_upper"
264-
))
265-
))
266-
)
228+
expect_s3_class(pred_ci, "tbl_df")
229+
expect_equal(names(pred_ci), c(".pred_quantile", ".pred_lower", ".pred_upper"))
230+
expect_equal(nrow(pred_ci), 3)
231+
expect_s3_class(pred_ci$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))
232+
expect_s3_class(pred_ci$.pred_lower, c("quantile_pred", "vctrs_vctr", "list"))
233+
expect_s3_class(pred_ci$.pred_upper, c("quantile_pred", "vctrs_vctr", "list"))
267234

268235
# single observation
269236
f_pred_1 <- predict(fit_s, bladder[2,], type = "quantile")
@@ -284,7 +251,7 @@ test_that("hazard prediction", {
284251
head(lung),
285252
type = "hazard",
286253
times = c(0, 500, 1000)
287-
)
254+
)
288255
if (packageVersion("flexsurv") < "2.3") {
289256
exp_pred <- exp_pred %>%
290257
dplyr::rowwise() %>%
@@ -409,13 +376,13 @@ test_that("`fix_xy()` works", {
409376
f_fit,
410377
new_data = lung_pred,
411378
type = "quantile",
412-
quantile = c(0.2, 0.8)
379+
quantile_levels = c(0.2, 0.8)
413380
)
414381
xy_pred_quantile <- predict(
415382
xy_fit,
416383
new_data = lung_pred,
417384
type = "quantile",
418-
quantile = c(0.2, 0.8)
385+
quantile_levels = c(0.2, 0.8)
419386
)
420387
expect_equal(f_pred_quantile, xy_pred_quantile)
421388

@@ -438,7 +405,7 @@ test_that("`fix_xy()` works", {
438405

439406
test_that("can handle case weights", {
440407
skip_if_not_installed("flexsurv")
441-
408+
442409
# flexsurv engine can only take weights > 0
443410
set.seed(1)
444411
wts <- runif(nrow(lung))

tests/testthat/test-survival_reg-survival.R

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -122,16 +122,22 @@ test_that("prediction of survival time quantile", {
122122
fit(Surv(time, status) ~ age + sex, data = lung)
123123

124124
exp_quant <- predict(res$fit, head(lung), p = (2:4) / 5, type = "quantile")
125-
exp_quant <- apply(exp_quant, 1, function(x) {
126-
tibble::tibble(.quantile = (2:4) / 5, .pred_quantile = x)
127-
})
128-
exp_quant <- tibble::tibble(.pred = exp_quant)
129-
obs_quant <- predict(res, head(lung), type = "quantile", quantile = (2:4) / 5)
125+
obs_quant <- predict(res, head(lung), type = "quantile", quantile_levels = (2:4) / 5)
130126

131-
expect_equal(as.data.frame(exp_quant), as.data.frame(obs_quant))
127+
expect_s3_class(obs_quant, "tbl_df")
128+
expect_equal(names(obs_quant), ".pred_quantile")
129+
expect_equal(nrow(obs_quant), 6)
130+
expect_s3_class(obs_quant$.pred_quantile, c("quantile_pred", "vctrs_vctr", "list"))
131+
132+
for (.row in 1:nrow(obs_quant)) {
133+
expect_equal(
134+
unclass(obs_quant$.pred_quantile[.row])[[1]],
135+
exp_quant[.row,]
136+
)
137+
}
132138

133139
# single observation
134-
f_pred_1 <- predict(res, lung[1, ], type = "quantile")
140+
f_pred_1 <- predict(res, lung[1, ], type = "quantile", quantile_levels = .5)
135141
expect_identical(nrow(f_pred_1), 1L)
136142
})
137143

@@ -213,13 +219,13 @@ test_that("`fix_xy()` works", {
213219
f_fit,
214220
new_data = lung_pred,
215221
type = "quantile",
216-
quantile = c(0.2, 0.8)
222+
quantile_levels = c(0.2, 0.8)
217223
)
218224
xy_pred_quantile <- predict(
219225
xy_fit,
220226
new_data = lung_pred,
221227
type = "quantile",
222-
quantile = c(0.2, 0.8)
228+
quantile_levels = c(0.2, 0.8)
223229
)
224230
expect_equal(f_pred_quantile, xy_pred_quantile)
225231

0 commit comments

Comments
 (0)