Skip to content

Commit 2dbac85

Browse files
author
‘topepo’
committed
data manipulation functions for tidymodels/parsnip#1203
1 parent a297661 commit 2dbac85

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

DESCRIPTION

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ Imports:
2727
dplyr (>= 0.8.0.1),
2828
generics,
2929
glue,
30-
hardhat (>= 1.1.0),
30+
hardhat (>= 1.4.0.9002),
3131
lifecycle,
3232
mboost,
3333
prodlim (>= 2023.03.31),
@@ -48,6 +48,8 @@ Suggests:
4848
rmarkdown,
4949
rpart,
5050
testthat (>= 3.0.0)
51+
Remotes:
52+
tidymodels/hardhat
5153
Config/Needs/website:
5254
tidymodels,
5355
tidyverse/tidytemplate

R/survival_reg-flexsurv.R

Lines changed: 44 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ flexsurv_post <- function(pred, object) {
1111
tidyr::nest(.by = .row) %>%
1212
dplyr::select(-.row)
1313
}
14-
pred
14+
pred
1515
}
1616

1717
flexsurv_rename_time <- function(pred){
@@ -27,3 +27,46 @@ flexsurv_rename_time <- function(pred){
2727
dplyr::rename(.eval_time = .time)
2828
}
2929
}
30+
31+
# ------------------------------------------------------------------------------
32+
# Conversion of quantile predictions to the vctrs format
33+
34+
# For single quantile levels, flexsurv returns a data frame with column
35+
# ".pred_quantile" and perhaps also ".pred_lower" and ".pred_upper"
36+
37+
# With mutiple quantile levels, flexsurv returns a data frame with a ".pred"
38+
# column with co.lumns ".quantile" and ".pred_quantile" and perhaps
39+
# ".pred_lower" and ".pred_upper"
40+
41+
flexsurv_to_quantile_pred <- function(x, object) {
42+
# if one level, convert to nested format
43+
if(!identical(names(x), ".pred")) {
44+
# convert to the same format as predictions with mulitplel levels
45+
x <- re_nest(x)
46+
}
47+
48+
# Get column names to convert to vctrs encoding
49+
nms <- names(x$.pred[[1]])
50+
possible_cols <- c(".pred_quantile", ".pred_lower", ".pred_upper")
51+
existing_cols <- intersect(possible_cols, nms)
52+
53+
# loop over prediction column names
54+
res <- list()
55+
for (col in existing_cols) {
56+
res[[col]] <- purrr::map_vec(x$.pred, nested_df_iter, col = col)
57+
}
58+
tibble::new_tibble(res)
59+
60+
}
61+
62+
re_nest <- function(df) {
63+
.row <- 1:nrow(df)
64+
df <- vctrs::vec_split(df, by = .row)
65+
df$key <- NULL
66+
names(df) <- ".pred"
67+
df
68+
}
69+
70+
nested_df_iter <- function(df, col) {
71+
hardhat::quantile_pred(matrix(df[[col]], nrow = 1), quantile_levels = df$.quantile)
72+
}

0 commit comments

Comments
 (0)