@@ -11,7 +11,7 @@ flexsurv_post <- function(pred, object) {
11
11
tidyr :: nest(.by = .row ) %> %
12
12
dplyr :: select(- .row )
13
13
}
14
- pred
14
+ pred
15
15
}
16
16
17
17
flexsurv_rename_time <- function (pred ){
@@ -27,3 +27,46 @@ flexsurv_rename_time <- function(pred){
27
27
dplyr :: rename(.eval_time = .time )
28
28
}
29
29
}
30
+
31
+ # ------------------------------------------------------------------------------
32
+ # Conversion of quantile predictions to the vctrs format
33
+
34
+ # For single quantile levels, flexsurv returns a data frame with column
35
+ # ".pred_quantile" and perhaps also ".pred_lower" and ".pred_upper"
36
+
37
+ # With mutiple quantile levels, flexsurv returns a data frame with a ".pred"
38
+ # column with co.lumns ".quantile" and ".pred_quantile" and perhaps
39
+ # ".pred_lower" and ".pred_upper"
40
+
41
+ flexsurv_to_quantile_pred <- function (x , object ) {
42
+ # if one level, convert to nested format
43
+ if (! identical(names(x ), " .pred" )) {
44
+ # convert to the same format as predictions with mulitplel levels
45
+ x <- re_nest(x )
46
+ }
47
+
48
+ # Get column names to convert to vctrs encoding
49
+ nms <- names(x $ .pred [[1 ]])
50
+ possible_cols <- c(" .pred_quantile" , " .pred_lower" , " .pred_upper" )
51
+ existing_cols <- intersect(possible_cols , nms )
52
+
53
+ # loop over prediction column names
54
+ res <- list ()
55
+ for (col in existing_cols ) {
56
+ res [[col ]] <- purrr :: map_vec(x $ .pred , nested_df_iter , col = col )
57
+ }
58
+ tibble :: new_tibble(res )
59
+
60
+ }
61
+
62
+ re_nest <- function (df ) {
63
+ .row <- 1 : nrow(df )
64
+ df <- vctrs :: vec_split(df , by = .row )
65
+ df $ key <- NULL
66
+ names(df ) <- " .pred"
67
+ df
68
+ }
69
+
70
+ nested_df_iter <- function (df , col ) {
71
+ hardhat :: quantile_pred(matrix (df [[col ]], nrow = 1 ), quantile_levels = df $ .quantile )
72
+ }
0 commit comments