June 27, 2020

The modeling of time-to-event data, also known as survival analysis, requires specialized methods that can deal with censoring and truncation, time-varying features and effects, and that extend to settings with multiple competing events. However, many machine learning methods for survival analysis only consider the standard setting with right-censored data and the proportional hazards assumption. The methods that do provide extensions usually address at most a subset of these challenges and often require specialized software that cannot be integrated directly into standard machine learning workflows. In this work, we present a very general machine learning framework for time-to-event analysis that uses a data augmentation strategy to reduce complex survival tasks to standard Poisson regression tasks. This reformulation is based on well-developed statistical theory. With the proposed approach, any algorithm that can optimize a Poisson (log-)likelihood, such as gradient boosted trees, deep neural networks, model-based boosting and many more, can be used in the context of time-to-event analysis. The proposed technique does not require any assumptions with respect to the distribution of event times or the functional shapes of feature and interaction effects. Based on the proposed framework, we develop new methods that are competitive with specialized state-of-the-art approaches in terms of accuracy and versatility, but with comparatively small investments of programming effort or requirements for specialized methodological know-how.

Survival analysis is a branch of statistics that provides a framework for the analysis of time-to-event data, i.e., the outcome is defined by the time it takes until an event occurs. Analysis of such data requires specialized techniques because, in contrast to standard regression or classification tasks,

the outcome often cannot be fully observed (censoring, truncation),

the features can change their value during the observation period (time-varying features (TVF)),

the association of the feature(s) with the outcome can change over time (time-varying effects (TVE)),

one or more other events occur that make it impossible to observe the event of interest (competing risks (CR)),

more generally, in a multi-state setting, observation units can move from and to different states (multi-state models (MSM)).

Failure to take these issues into consideration usually results in biased estimates, incorrect interpretation of feature effects on the outcome, loss of predictive accuracy, or a combination thereof. In this work, we use a reformulation of the survival task as a standard regression task, which provides a holistic approach to survival analysis. Within this framework, censoring, truncation and time-varying features (TVF) can be incorporated via specific data transformations, while time-varying effects (TVE), competing risks and multi-state models can be re-expressed in terms of interaction effects. This abstraction of the survival task away from specialized algorithms is illustrated schematically in Figure 1. Task-appropriate pre-processing (leftmost subgraph) yields a standardized data format that allows the estimation of feature-conditional hazard rates using any learning algorithm that can minimize the negative Poisson log-likelihood, such as gradient boosted trees (GBT), deep neural networks (DNN), regularization-based methods and others (middle subgraph).

We define a general machine learning framework for survival analysis based on piece-wise exponential models (cf. Section 2). Within this framework, different concepts specific to time-to-event data analysis can be understood in terms of data augmentation and inclusion of interaction terms. By re-expressing the survival task as a Poisson regression task, a large variety of algorithms become available for survival analysis. Based on the proposed approach, we implement a gradient boosted trees algorithm with comparatively low development effort and show that it achieves state-of-the-art performance (cf. Section 3).

The machine learning community has developed many highly efficient methods for high-dimensional settings in different domains, including survival analysis. The individual methods and implementations, however, often support only a subset of the cases relevant for time-to-event analysis mentioned above. For example, the random survival forest (RSF) proposed in [1] was later extended to the competing risks setting [2], but does not support left-truncation, TVF and TVE, or multi-state models. Another popular implementation of random forests [3] only supports right-censored data and proportional hazards models. An extension of RSF, the oblique RSF (ORSF, [4]), was shown to outperform other RSF-based algorithms, but has the same limitations. With respect to TVF and TVE, a review of tree- and forest-based methods for survival analysis stated that “the modeling of time–varying features and time–varying effects deserves much more attention” [5]. Similarly, a more recent review of machine learning methods for survival analysis [6] lists only the time-dependent Cox model [7], and L1- and L2-regularized extensions thereof, as a possibility for the inclusion of TVF.

Deep learning-based methods for time-to-event data have also received much attention lately. An early use of neural networks for Cox-type models was proposed by [8]. More recently, [9] presented a framework for deep single-event survival analysis based on a joint latent process for features and survival times using deep exponential families. For competing risks data, a deep learning framework based on Gaussian processes was described in [10]. Another recent framework is DeepHit, which can handle competing risks using a custom loss function [11] and was extended to handle TVF [12]; however, left-truncation, multi-state models and TVE were not discussed.

Boosting has also been a popular technique for high-dimensional survival analysis. For example, [13] proposed a Cox-type boosting approach for the estimation of proportional sub-distribution hazards. A flexible multi-state model based on the stratified Cox partial likelihood in the context of model-based boosting [14] is presented in [15]. Furthermore, gradient boosted trees (GBT) for the Cox PH model are available in the popular XGBoost library [16], which was also shown to perform well compared with the ORSF [4]. Recently, [17] derived a custom gradient boosted trees algorithm that supports TVF and demonstrated that including TVF improves predictive performance compared to boosting algorithms that do not take TVF into account.

Compared to methods based on Cox regression, few publications have developed methods based on the piece-wise exponential model, on which the framework proposed here is based. Among them is an early application of neural networks to survival analysis suggested in [18] and extended by [19]. The latter offers a general framework based on the representation of generalized linear models via feed-forward neural networks, but does not discuss MSM. Piece-wise exponential trees with TVF and splits based on the piece-wise exponential survival function were suggested by [20]. A spline-based estimation of the hazard function was discussed in [21], which could also be represented via neural networks (cf. [22]). A flexible estimation of multi-state models based on the piece-wise exponential model, with shared effects estimated via structured fusion Lasso, was developed in [23]. All of these methods can be viewed as special cases within the proposed framework. For example, [19] could be extended to different neural network architectures and MSMs, and [20] could be extended to forests.

In the context of survival analysis, an observation usually consists of a tuple \((t_i, \delta_i, \mathbf{x}_i)\), where \(t_i\) is the observed event or censoring time for observation unit \(i=1,\ldots,n\), \(\delta_i \in \{0,1\}\) is the event or status indicator (i.e., 1 if the event occurred, 0 if the observation was censored) and \(\mathbf{x}_i\) is the \(p\)-dimensional feature vector. The presence of censoring requires special estimation techniques, as the time-to-event cannot be observed when censoring occurs before the event of interest. Thus \(t_i = \min(T_i, C_i)\), where \(T_i\) and \(C_i\) are the random variables of the event time and censoring time, respectively. A classic example is the time until death, where censoring occurs when patients drop out of the study (for reasons unrelated to the event of interest, \(T_i\perp C_i\)). Left-truncation occurs when the event of interest already occurred before the subject could be included in the sample, and thus presents a form of sampling bias. In some settings, another event could preclude observation of the event of interest or change the probability of its occurrence. In this case, we speak of competing risks (CR), and the observation consists of \((t_i, \delta_i, k, \mathbf{x}_i)\), where \(k = 1,\ldots, K\) indicates the type of event that occurred at \(t_i\) if \(\delta_i = 1\). More generally, there might be multiple states that the observation units can transition from and to. We then speak of multi-state models (MSM), and \(k\) is an indicator for the different transitions (cf. Eq. \(\ref{eq:cshazard}\)).

In general, the goal of survival analysis is to estimate the conditional distribution of event times defined by the survival probability \(S(t|\mathbf{x}) = P(T > t|\mathbf{x})\). While some methods focus on the estimation of \(S(t|\mathbf{x})\) directly, it is often more convenient to estimate the (log-)hazard

\[\begin{equation}\lambda(t|\mathbf{x}) := \lim\limits_{\Delta t \to 0} \frac{P(t\leq T <t+\Delta t | T\geq t, \mathbf{x})}{\Delta t}\,\label{eq:hazard} \end{equation}\]

from which \(S(t|\mathbf{x})\) follows as

\[\begin{equation}S(t|\mathbf{x}) = \exp\left(-\int_0^t \lambda(s|\mathbf{x})\mathrm{d}s\right).\end{equation}\]

Here we represent \(\ref{eq:hazard}\) via \[\begin{equation}\lambda(t|\mathbf{x}(t)) = \exp(g(\mathbf{x}(t), t)),\label{eq:phmodel}
\end{equation}\] where \(g\) is a general function of potentially time-varying features \(\mathbf{x}(t)\) that can include high-order feature interactions, non-linearity and time-dependence of feature effects (TVE) via an interaction with \(t\).

In this work, we approximate \(\ref{eq:phmodel}\) using the piece-wise exponential model [24]. Let \(t_i\) denote the observed event or censoring time and \(\delta_i \in \{0,1\}\) the respective status indicator for observation units \(i=1,\ldots,n\). The distribution of censoring times can depend on features but is assumed to be independent of the event time process \(T\). By partitioning the follow-up, i.e., the time span under investigation, into \(j=1,\ldots,J\) intervals with cut-points \(\kappa_0 = 0 < \cdots < \kappa_J\) and partitions \((\kappa_0, \kappa_1],\ldots, (\kappa_{j-1},\kappa_{j}],\ldots, (\kappa_{J-1},\kappa_{J}]\), we can rewrite \(\ref{eq:phmodel}\) using piece-wise constant hazard rates \[\begin{equation}\lambda(t| \mathbf{x}_i(t)) \equiv \exp(g(\mathbf{x}_{ij}, t_j)):=\lambda_{ij},\quad\forall t \in (\kappa_{j-1}, \kappa_j],\label{eq:pcmodel}\end{equation}\]

with \(t_j\) a representation of time in interval \(j\), e.g., \(t_j:=\kappa_j\), and \(\mathbf{x}_{ij}\) the value of the TVF in interval \(j\). Depending on the desired resolution, additional cut-points can be introduced at each time point at which feature values are updated; otherwise, multiple feature values have to be aggregated within one interval. This model assumes that only the current value of \(\mathbf{x}_{ij}\) affects the hazard in interval \(j\), but more sophisticated approaches that take into account the entire history of TVF have been suggested within this framework [25]. Piece-wise constant hazards imply piece-wise exponential log-likelihood contributions \[\begin{equation}\begin{aligned}\ell_i &= \log\left(\lambda(t_i|\mathbf{x}_i)^{\delta_i}S(t_i|\mathbf{x}_i)\right)\\ &= \sum_{j=1}^{J_i}\left(\delta_{ij}\log\lambda_{ij} - \lambda_{ij}t_{ij}\right),\end{aligned}\label{eq:ll}\end{equation}\]

where \(J_i\) is the last interval in which observation unit \(i\) was observed, such that \(t_{i} \in (\kappa_{J_i-1},\kappa_{J_i}]\) and

\[\begin{equation}\delta_{ij} = \begin{cases}1 & t_i \in (\kappa_{j-1}, \kappa_j] \wedge \delta_i = 1\\0 & \text{else}\end{cases},\quad t_{ij} = \begin{cases}t_{i}-\kappa_{j-1} & t_i \in (\kappa_{j-1}, \kappa_j]\\ \kappa_{j}-\kappa_{j-1}& \text{else}\end{cases}.\label{eq:ped-95status} \end{equation}\]
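The second equality in \(\ref{eq:ll}\) follows directly from the piece-wise constant hazard. Writing it out, with \(t_{ij}\) the time under risk in interval \(j\) and using that \(\delta_{ij}=0\) for \(j<J_i\) while \(\delta_{iJ_i}=\delta_i\):

\[\begin{aligned}
\log S(t_i|\mathbf{x}_i) &= -\int_0^{t_i}\lambda(s|\mathbf{x}_i)\,\mathrm{d}s = -\sum_{j=1}^{J_i}\lambda_{ij}\left(\min(\kappa_j, t_i)-\kappa_{j-1}\right) = -\sum_{j=1}^{J_i}\lambda_{ij}t_{ij},\\
\delta_i\log\lambda(t_i|\mathbf{x}_i) &= \delta_i\log\lambda_{iJ_i} = \sum_{j=1}^{J_i}\delta_{ij}\log\lambda_{ij},
\end{aligned}\]

and adding both lines gives \(\ell_i = \sum_{j=1}^{J_i}\left(\delta_{ij}\log\lambda_{ij} - \lambda_{ij}t_{ij}\right)\).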

Concrete examples of the data transformations required to obtain \(\ref{eq:ped-95status}\) for right-censored data (including TVF) are provided in [26] (cf. Tables 1 and 2 therein).
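To make the augmentation concrete, the following is a minimal Python sketch, not the implementation from [26] or the authors' software. It expands each \((t_i, \delta_i)\) into one row per interval at risk with \(\delta_{ij}\), \(t_{ij}\) and \(t_j=\kappa_j\), and places cut-points at the unique event times. The names `PEDRow`, `to_ped` and `event_time_cuts` are hypothetical, and the optional `entry` argument for left-truncation (skipping intervals that end before entry and shortening the first one) is our own assumption about how to realize the shifted start index \(j_i\).

```python
from dataclasses import dataclass
from typing import List, Optional, Sequence


@dataclass
class PEDRow:
    i: int          # observation unit
    j: int          # interval index (1-based)
    t_j: float      # time representation of interval j, here kappa_j
    t_ij: float     # time under risk in interval j (later used as log offset)
    delta_ij: int   # 1 if the event occurred in interval j, else 0


def event_time_cuts(times: Sequence[float], status: Sequence[int]) -> List[float]:
    """Cut-points kappa_0 = 0 < kappa_1 < ... placed at the unique event times."""
    return [0.0] + sorted({t for t, d in zip(times, status) if d == 1})


def to_ped(times: Sequence[float], status: Sequence[int], cuts: Sequence[float],
           entry: Optional[Sequence[float]] = None) -> List[PEDRow]:
    """Hypothetical helper: one row per interval in which unit i is at risk.

    Follow-up beyond the last cut-point kappa_J is ignored. `entry` holds
    optional left-truncation times (an assumption, see lead-in).
    """
    rows: List[PEDRow] = []
    for i, (t_i, d_i) in enumerate(zip(times, status), start=1):
        start = 0.0 if entry is None else entry[i - 1]
        for j in range(1, len(cuts)):
            lo, hi = cuts[j - 1], cuts[j]
            if lo >= t_i:
                break          # interval begins after the observed time
            if hi <= start:
                continue       # unit has not yet entered the risk set
            last = t_i <= hi   # t_i falls into (kappa_{j-1}, kappa_j]
            rows.append(PEDRow(i, j, hi, min(hi, t_i) - max(lo, start),
                               int(last and d_i == 1)))
    return rows


# Three units: censored at 1.3, censored at 0.5, event at 2.7.
for row in to_ped([1.3, 0.5, 2.7], [0, 0, 1], [0.0, 1.0, 1.5, 3.0]):
    print(row)
```

With cut-points \(0, 1, 1.5, 3\) this yields the same \(t_j\), \(t_{ij}\) and \(\delta_{ij}\) values as the \(k=1\) block of Table 2. Using `event_time_cuts` on the training data gives the default cut-point choice discussed below.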

Using the working assumption \(\delta_{ij}\stackrel{\text{ind.}}{\sim}\text{Poisson}(\mu_{ij}=\lambda_{ij}t_{ij})\) and with \(f(\delta_{ij})\) the Poisson probability mass function, [24] showed that the Poisson log-likelihood

\[\begin{equation}\ell_i = \log\left(\prod_{j=1}^{J_i} f(\delta_{ij})\right) = \sum_{j=1}^{J_i} ( \delta_{ij}\log \lambda_{ij} + \delta_{ij}\log t_{ij} - \lambda_{ij}t_{ij})\label{eq:poisson-ll} \end{equation}\]

differs from \(\ref{eq:ll}\) only by the term \(\sum_{j}\delta_{ij}\log t_{ij}\), which does not depend on the hazard. Maximizing \(\ref{eq:poisson-ll}\) is therefore equivalent to maximizing \(\ref{eq:ll}\), and the model can be estimated by Poisson regression, i.e., by minimizing the negative Poisson log-likelihood. Note that \(\ref{eq:poisson-ll}\) can be directly extended to the setting with left-truncated event times [27] by replacing \(j=1\) with \(j_i\), the first interval in which observation unit \(i\) is in the risk set. The expectation is given by \(\mu_{ij}=\lambda_{ij}t_{ij} = \exp(g(\mathbf{x}_{ij},t_j) + \log(t_{ij}))\). For estimation, \(\log(t_{ij})\) is included as an offset; thus, the hazard rate \(\lambda_{ij} = \frac{\mu_{ij}}{t_{ij}} = \exp(g(\mathbf{x}_{ij},t_j))\) is defined as the conditional expected number of events in interval \(j\) divided by the time under risk. Note that the Poisson assumption is simply a computational vehicle for the estimation of the hazard \(\ref{eq:pcmodel}\) rather than an assumption about the distribution of the event times. Despite the partition of the follow-up into intervals, this is a method for continuous event times, as the information about the time under risk in each interval is contained in the offset and thus used during estimation. The number and placement of cut-points controls the approximation of the hazard and could thus be viewed as a tuning parameter. In our experience, however, setting cut-points at the unique event times \(\{t_i:\delta_i = 1,i=1,\ldots,n\}\) in the training data has consistently led to a good approximation (given sufficient regularization), as the number of cut-points increases in regions with many events. For larger data sets, we recommend setting these cut-points based on a smaller, representative subsample of the data (cf. Section 4).
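The equivalence can be checked numerically. The sketch below (plain Python, function names hypothetical) evaluates \(\ref{eq:ll}\) and the Poisson log-likelihood \(\ref{eq:poisson-ll}\), the latter via \(\mu_{ij}=\exp(\log\lambda_{ij}+\log t_{ij})\), i.e., with \(\log(t_{ij})\) as offset, for two arbitrary sets of hazards. In both cases the gap equals \(\sum_j\delta_{ij}\log t_{ij}\), which does not involve the hazards.

```python
import math
from typing import Sequence


def pem_loglik(delta: Sequence[int], t_risk: Sequence[float],
               hazard: Sequence[float]) -> float:
    """Piece-wise exponential log-likelihood: sum_j (delta_ij log lambda_ij - lambda_ij t_ij)."""
    return sum(d * math.log(lam) - lam * t for d, t, lam in zip(delta, t_risk, hazard))


def poisson_loglik(delta: Sequence[int], t_risk: Sequence[float],
               hazard: Sequence[float]) -> float:
    """Poisson log-likelihood with mean mu_ij = exp(log(lambda_ij) + log(t_ij))."""
    total = 0.0
    for d, t, lam in zip(delta, t_risk, hazard):
        mu = math.exp(math.log(lam) + math.log(t))           # linear predictor + offset
        total += d * math.log(mu) - mu - math.lgamma(d + 1)  # log pmf; lgamma(d+1) = 0 for d in {0, 1}
    return total


delta, t_risk = [0, 0, 1], [1.0, 0.5, 1.2]
for hazard in ([0.2, 0.4, 0.7], [1.5, 0.1, 3.0]):
    gap = poisson_loglik(delta, t_risk, hazard) - pem_loglik(delta, t_risk, hazard)
    print(round(gap, 10), round(math.log(1.2), 10))  # identical for both hazard vectors
```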

For the extension of \(\ref{eq:phmodel}\) to MSMs, we define \[\begin{equation}\lambda(t|\mathbf{x}, k) = \exp\left(f(\mathbf{x}(t), t,k)\right),\;k = 1,\ldots,K,\label{eq:cshazard} \end{equation}\]

as the transition-specific hazard for the transition indexed by \(k\), i.e., \(k\) is an index of transitions \(m_k\rightarrow m_k'\), where \(m_k\) is the initial state and \(m_k'\) a transient or absorbing state. The set of possible transitions is given by \(\{m_k\rightarrow m'_k: k = 1,\ldots, K\} \subseteq \{m\rightarrow m': m,m'\in \{0,\ldots,M\}, m\neq m'\}\), where \(M+1\) denotes the total number of possible states. \(f(\mathbf{x}(t), t, k)\) is a function of potentially time-varying features \(\mathbf{x}(t)\), including multivariate and/or non-linear effects. The dependency of \(f(\mathbf{x}(t), t, k)\) on time \(t\) (TVE) and transition \(k\) (MSM) is expressed in terms of interactions by defining \(\tilde{\mathbf{x}}(t) := (\mathbf{x}(t), t, k)\) and \(f(\mathbf{x}(t), t, k) = f(\tilde{\mathbf{x}}(t))\). Let \(t_{i,k}\) be the event or censoring time w.r.t. transition \(m_k\rightarrow m'_k\) and \(\delta_{i,k} \in \{0,1\}\) the respective transition indicator. As an extension of \(\ref{eq:ped-95status}\), we define \[\begin{equation*} \delta_{ij,k} = \begin{cases} 1 & t_{i,k} \in (\kappa_{j-1}, \kappa_j] \wedge \delta_{i,k} = 1\\ 0 & \text{else} \end{cases},\quad t_{ij,k} = \begin{cases} t_{i,k}-\kappa_{j-1} & t_{i,k} \in (\kappa_{j-1}, \kappa_j]\\ \kappa_{j}-\kappa_{j-1}& \text{else}\end{cases}.\end{equation*}\]

Table 1 shows how the data must be transformed in order to estimate \(\ref{eq:cshazard}\) via PEMs in the competing risks setting, i.e., when \(k=1,\ldots,K\) indexes the transitions \(m_k=0\rightarrow m_k'\), \(m_k'\in \{1,\ldots,M\}\); a concrete example is given in Table 2. For each \(i=1,\ldots,n\), there is one row for each interval in which the observation unit was at risk for a specific transition. Thus, one data set is created per transition, in which transitions to state \(m_k'\) are encoded as \(1\) and everything else, i.e., censoring and transitions to other states, is encoded as \(0\). These transition-specific data sets, each augmented with the transition index \(k\) as an additional feature, are then concatenated. Note that we used the same interval cut-points \(\kappa_j\) for all transitions in Table 1. However, it would also be possible to choose transition-specific cut-points \(\kappa_{j,k}\) or, more generally, to use multiple time scales [28]. In the general multi-state setting, the number of observation units at risk might depend on the transition, and the intervals visited by \(i\) are given by \((t_{i,m},\kappa_{j_{i,k}}], \ldots, (\kappa_{J_{i,k}-1},\kappa_{J_{i,k}}]\), where \(t_{i,m}\) is the time point at which \(i\) enters state \(m\).

\(i\) | \(j\) | \(\delta_{ij,k}\) | \(t_j\) | \(t_{ij}\) | \(k\) | \(x_{ij,1}\) | \(\ldots\) | \(x_{ij,P}\) |
---|---|---|---|---|---|---|---|---|

\(1\) | \(1\) | \(\delta_{11,1}\) | \(t_1\) | \(t_{11}\) | \(1\) | \(x_{11,1}\) | \(\ldots\) | \(x_{11,P}\) |

\(1\) | \(2\) | \(\delta_{12,1}\) | \(t_2\) | \(t_{12}\) | \(1\) | \(x_{12,1}\) | \(\ldots\) | \(x_{12,P}\) |

\(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) |

\(1\) | \(J_1\) | \(\delta_{1J_1,1}\) | \(t_{J_1}\) | \(t_{1J_1}\) | \(1\) | \(x_{1J_1,1}\) | \(\ldots\) | \(x_{1J_1,P}\) |

\(2\) | \(1\) | \(\delta_{21,1}\) | \(t_1\) | \(t_{21}\) | \(1\) | \(x_{21,1}\) | \(\ldots\) | \(x_{21,P}\) |

\(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) |

\(n\) | \(1\) | \(\delta_{n1,1}\) | \(t_1\) | \(t_{n1}\) | \(1\) | \(x_{n1,1}\) | \(\ldots\) | \(x_{n1,P}\) |

\(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) |

\(n\) | \(J_n\) | \(\delta_{nJ_n,1}\) | \(t_{J_n}\) | \(t_{nJ_n}\) | \(1\) | \(x_{nJ_n,1}\) | \(\ldots\) | \(x_{nJ_n,P}\) |

\(1\) | \(1\) | \(\delta_{11,2}\) | \(t_1\) | \(t_{11}\) | \(2\) | \(x_{11,1}\) | \(\ldots\) | \(x_{11,P}\) |

\(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) |

\(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) |

\(1\) | \(1\) | \(\delta_{11,K}\) | \(t_1\) | \(t_{11}\) | \(K\) | \(x_{11,1}\) | \(\ldots\) | \(x_{11,P}\) |

\(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) | \(\vdots\) |

| \(i\) | \(j\) | \(\delta_{ij}\) | \(t_j\) | \(t_{ij}\) | \(k\) |
|---|---|---|---|---|---|
| 1 | 1 | 0 | \(1\) | 1 | 1 |
| 1 | 2 | 0 | \(1.5\) | 0.3 | 1 |
| 2 | 1 | 0 | \(1\) | 0.5 | 1 |
| 3 | 1 | 0 | \(1\) | 1 | 1 |
| 3 | 2 | 0 | \(1.5\) | 0.5 | 1 |
| 3 | 3 | 1 | \(3\) | 1.2 | 1 |

| \(i\) | \(j\) | \(\delta_{ij}\) | \(t_j\) | \(t_{ij}\) | \(k\) |
|---|---|---|---|---|---|
| 1 | 1 | 0 | \(1\) | 1 | 2 |
| 1 | 2 | 1 | \(1.5\) | 0.3 | 2 |
| 2 | 1 | 0 | \(1\) | 0.5 | 2 |
| 3 | 1 | 0 | \(1\) | 1 | 2 |
| 3 | 2 | 0 | \(1.5\) | 0.5 | 2 |
| 3 | 3 | 0 | \(3\) | 1.2 | 2 |
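The stacking of transition-specific data sets can be sketched in a few lines of Python. This is an illustrative helper (`stack_competing_risks` is a hypothetical name), not the authors' code. Reading the \(t_j\) and \(t_{ij}\) columns of Table 2 as cut-points \(0, 1, 1.5, 3\) and times under risk, the example corresponds to unit 1 transitioning to state 2 at \(t=1.3\), unit 2 being censored at \(0.5\) and unit 3 experiencing an event of type 1 at \(2.7\); feature columns are omitted for brevity.

```python
from typing import List, Sequence, Tuple

Row = Tuple[int, int, int, float, float, int]  # (i, j, delta_ij, t_j, t_ij, k)


def stack_competing_risks(times: Sequence[float], cause: Sequence[int],
                          cuts: Sequence[float], n_causes: int) -> List[Row]:
    """cause[i] = 0 for censoring and k in 1..K for an event of type k."""
    rows: List[Row] = []
    for k in range(1, n_causes + 1):              # one data set per transition k
        for i, (t_i, c_i) in enumerate(zip(times, cause), start=1):
            for j in range(1, len(cuts)):
                lo, hi = cuts[j - 1], cuts[j]
                if lo >= t_i:
                    break                          # no longer at risk
                in_interval = t_i <= hi
                # only an event of type k counts; censoring and other causes -> 0
                delta = int(in_interval and c_i == k)
                rows.append((i, j, delta, hi, min(hi, t_i) - lo, k))
    return rows                                    # blocks for k = 1..K concatenated


for row in stack_competing_risks([1.3, 0.5, 2.7], [2, 0, 1], [0.0, 1.0, 1.5, 3.0], 2):
    print(row)
```

The printed rows reproduce both blocks of Table 2 (up to floating-point rounding of \(t_{ij}\)). Since \(k\) is kept as a column, a learner can either share or separate effects across transitions by (not) splitting on it.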

In Section 3, we evaluate the suggested approach using an implementation based on GBT that we refer to as GBT (PEM). As the computing engine, we used the extreme gradient boosting (XGBoost) library [16] without any alterations to the algorithm. Therefore, all features of the library can be used directly when estimating the hazard on the transformed data set. Note, however, that depending on the algorithm used, one must be able to specify an offset during estimation, and some other algorithm- or implementation-specific settings may be required. For example, when using XGBoost to estimate the GBT (PEM), the objective function needs to be set to the Poisson objective and the base score must be set to 1: the base score enters the model on the log scale, so the default of 0.5 would imply an incorrect offset of \(\log(0.5)\), while \(\log(1)=0\). The offset \(\log(t_{ij})\) must be attached to the data via the base margin argument during estimation. In contrast, when predicting the conditional hazard \(\lambda(t|\mathbf{x}_i)=\lambda_{ij}\) for new data points, the offset should be omitted; otherwise, the algorithm will predict \(\hat{\mu}_{ij}=\hat{\lambda}_{ij}\cdot t_{ij}\) (the expectation) instead of \(\hat{\lambda}_{ij}\) (the hazard). When predicting the cumulative hazard or survival probability, however, the time under risk in each interval must be taken into account, such that \(\hat{S}(t|\mathbf{x}_i) = \exp\left(-\int_0^t \hat{\lambda}(s|\mathbf{x}_i)\mathrm{d}s\right) = \exp\left(-\sum_{j=1}^{j(t)}\hat{\lambda}_{ij}\tilde{t}_{j}\right)\), where \(j(t)\) denotes the interval for which \(t \in (\kappa_{j-1}, \kappa_j]\) and \(\tilde{t}_j=\min(\kappa_j-\kappa_{j-1}, t-\kappa_{j-1})\) is the time spent in interval \(j\). A prototype implementation of the GBT (PEM) algorithm that takes these issues into account and also provides the necessary helper functions for data transformation, estimation, tuning, prediction and evaluation is provided at https://github.com/adibender/pem.xgb.
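The settings above can be summarized in a short Python sketch. It is not taken from pem.xgb; it only restates the described configuration as an XGBoost parameter dictionary (`count:poisson`, `base_score`, `eta` are XGBoost parameter names) and shows, in comments, where the offset is attached via the `base_margin` argument of `xgboost.DMatrix`, so the block runs without XGBoost installed. The function `survival_from_hazards` is a hypothetical helper implementing \(\hat{S}(t|\mathbf{x}_i)=\exp(-\sum_{j\le j(t)}\hat{\lambda}_{ij}\tilde{t}_j)\) for one unit, assuming \(t\le\kappa_J\).

```python
import math
from typing import Sequence

# Parameter dictionary mirroring the settings described in the text.
PEM_XGB_PARAMS = {
    "objective": "count:poisson",  # Poisson negative log-likelihood
    "base_score": 1.0,             # enters on the log scale: log(1) = 0, no spurious offset
    "eta": 0.05,                   # learning rate used in Section 3
}
# Training:   xgboost.DMatrix(X_ped, label=delta_ij, base_margin=numpy.log(t_ij))
# Prediction: xgboost.DMatrix(X_new) without base_margin -> hazards lambda_ij


def survival_from_hazards(t: float, cuts: Sequence[float],
                          hazards: Sequence[float]) -> float:
    """S(t|x) = exp(-sum_{j <= j(t)} lambda_j * tilde_t_j) for piece-wise constant hazards.

    cuts = [kappa_0 = 0, ..., kappa_J]; hazards[j - 1] is lambda_j; assumes t <= kappa_J.
    """
    cum_hazard = 0.0
    for j in range(1, len(cuts)):
        lo, hi = cuts[j - 1], cuts[j]
        if lo >= t:
            break
        cum_hazard += hazards[j - 1] * (min(hi, t) - lo)  # tilde_t_j: time spent in interval j
    return math.exp(-cum_hazard)


cuts, hazards = [0.0, 1.0, 1.5, 3.0], [0.1, 0.2, 0.4]
print(survival_from_hazards(1.3, cuts, hazards))  # exp(-(0.1 * 1 + 0.2 * 0.3))
```

Only the partial time spent in the interval containing \(t\) enters the last term, which is why the predicted hazards (without offset) must be multiplied by \(\tilde{t}_j\) rather than by the full interval length.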

We perform a set of benchmark experiments with real-world and synthetic data sets, exclusively using openly and directly available data, including a subset of the data sets from recent publications on oblique random survival forests (ORSF, [4]) and DeepHit [11]. DeepHit and ORSF have both been shown to outperform other approaches such as RSF [1], conditional forests [29], regularized Cox regression [30] and DeepSurv [9]. We compare our approach against each of these two algorithms separately, using the evaluation measures from the respective publications to ensure comparability. All code to perform the respective analyses as well as additional supplementary files are provided in a GitHub repository: https://github.com/adibender/machine-learning-for-survival-ecml2020.

The data sets used for single-event comparisons are listed in Table 3. The “synthetic (TVE)” data set is created based on the additive predictor \(g(\mathbf{x},t) = 6\cdot f_0(x_0, t) - 0.1\cdot x_1 + f_2(x_2, t) + f_3(x_3, t)\), where \(f_0\), \(f_2\) and \(f_3\) are bivariate, non-linear functions of their inputs (see code repository for details) and \(x_0, \ldots, x_3\) are feature columns of \(\mathbf{x}\). Additionally, 20 noise variables are drawn from the uniform distribution \(U(0, 1)\).

For the comparison with ORSF, we use the Integrated Brier Score \(\text{IBS}(\tau) = \frac{1}{\tau} \int_{0}^\tau \widehat{\text{BS}}(u, \hat{S})\mathrm{d}u\), where \(\widehat{\text{BS}}(t, \hat{S})\) is the estimated Brier Score at time \(t\) weighted by the inverse probability of censoring weights [31] and \(\hat{S}\) the estimated survival probability function of the respective algorithm. In addition, [4] report the time-dependent C-Index [32]. We only consider the IBS here as it measures calibration as well as discrimination, while the C-index only measures the latter. Note that the IBS depends on the specific evaluation time \(\tau\) and different methods might perform better at different evaluation times. Therefore, we calculate the IBS for three different time-points, the 25%, 50% and 75% quantiles of the event times in the test data, in the following referred to as Q25, Q50 and Q75, respectively.
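For readers who want to reproduce the measure, the following is a minimal Python sketch of an inverse-probability-of-censoring-weighted Brier score and its integral, not the evaluation code used in the experiments. It assumes the common form in which units with an observed event before \(t\) are weighted by \(1/\hat{G}(t_i-)\) and units still at risk by \(1/\hat{G}(t)\), with \(\hat{G}\) the Kaplan-Meier estimate of the censoring distribution; the integral is approximated by the trapezoidal rule on an equidistant grid (our choice). All function names are hypothetical, and \(\hat{G}\) must be positive on \([0,\tau]\).

```python
from bisect import bisect_left, bisect_right
from typing import Callable, List, Sequence


def censoring_km(times: Sequence[float], status: Sequence[int]):
    """Kaplan-Meier estimate of G(t) = P(C > t), treating censorings (status 0) as events."""
    pts: List[float] = []
    vals: List[float] = []
    g = 1.0
    for u in sorted({t for t, d in zip(times, status) if d == 0}):
        at_risk = sum(1 for t in times if t >= u)
        n_cens = sum(1 for t, d in zip(times, status) if t == u and d == 0)
        g *= 1.0 - n_cens / at_risk
        pts.append(u)
        vals.append(g)

    def G(t: float, left: bool = False) -> float:
        # right-continuous step function; left=True returns the left limit G(t-)
        idx = bisect_left(pts, t) if left else bisect_right(pts, t)
        return 1.0 if idx == 0 else vals[idx - 1]
    return G


def brier_ipcw(t, times, status, surv_t, G) -> float:
    """IPCW Brier score at time t; surv_t[i] is the predicted S(t | x_i)."""
    total = 0.0
    for t_i, d_i, s_i in zip(times, status, surv_t):
        if t_i <= t and d_i == 1:
            total += s_i ** 2 / G(t_i, left=True)   # event observed before t
        elif t_i > t:
            total += (1.0 - s_i) ** 2 / G(t)        # still at risk at t
    return total / len(times)


def integrated_brier(tau: float, times, status,
                     surv_fn: Callable[[float], Sequence[float]], n_grid: int = 100) -> float:
    """IBS(tau) = (1/tau) * int_0^tau BS(u) du via the trapezoidal rule."""
    G = censoring_km(times, status)
    grid = [tau * k / n_grid for k in range(n_grid + 1)]
    bs = [brier_ipcw(u, times, status, surv_fn(u), G) for u in grid]
    area = sum((grid[k + 1] - grid[k]) * (bs[k] + bs[k + 1]) / 2 for k in range(n_grid))
    return area / tau
```

Here `surv_fn(u)` returns the predicted survival probabilities of all test units at time \(u\); \(\tau\) would be set to Q25, Q50 or Q75 of the test event times.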

| | name | n | p | censoring (%) |
|---|---|---|---|---|
| 1 | PBC | 412 | 14 | 61.90 |
| 2 | Breast | 614 | 1690 | 78.20 |
| 3 | GBSG 2 | 686 | 8 | 56.40 |
| 4 | Tumor | 776 | 7 | 51.70 |
| 5 | synthetic (TVE) | 1000 | 24 | \(\sim 33\) |

For comparison with DeepHit, we use the METABRIC data set (cf. [11]) for the single-event comparison as well as two CR data sets. The “MGUS 2” data set is described in [33]. The “synthetic (TVE CR)” data set is simulated using, for the first cause, an additive predictor identical to the one used for the “synthetic (TVE)” data. The predictor for the second cause has a simpler structure, \(f_0(x_0, t) + 2\cdot x_4 - 0.1 \cdot x_5\), but with a non-proportional baseline hazard with respect to \(x_0\in\{-1,1\}\). The number of noise variables is limited to 10 in this setting. Here we report the weighted C-index alongside the weighted Brier Score, as the former was the main measure reported in [11]. The proposed GBT (PEM) approach for CR (cf. Section 2) is a cause-specific hazards model; however, the parameters of both causes are estimated jointly and the hazards of both causes can have shared effects (see Figure 2). The simulation setting therefore constitutes a difficult setup: there are no shared effects, and optimization w.r.t. the first cause favors parameters that allow flexible models, while optimization w.r.t. the second cause favors sparse models and thus parameters that restrict flexibility.

| name | n | p | censoring (%) |
|---|---|---|---|
| METABRIC | 1981 | 79 | 55.20 |
| MGUS 2 | 1384 | 6 | 29.6 |
| Synthetic (TVE CR) | 500 | 14 | \(\sim 23\) |

We compare our approach against four algorithms: the non-parametric Kaplan-Meier estimate (Reference) as a minimal baseline, the Cox proportional hazards model [34] (a baseline for linear, time-constant effects), the Oblique Random Survival Forest (ORSF) [4] and DeepHit [11]. For each experimental replication on a specific data set, 70% of the data is randomly assigned as training data and the remaining 30% is used to calculate the evaluation measures at the three time points Q25, Q50 and Q75. Algorithms are tuned on the training data using random search with a fixed budget and 4-fold cross-validation. Each algorithm is then retrained on the entire training data set using the best set of parameters before making final predictions on the test set. Unless stated otherwise, the random search consists of 20 iterations per algorithm. For the GBT (PEM), we define the search space as follows (possible range in brackets): maximum tree depth \(\{1,\ldots,20\}\), minimum loss reduction [0, 5], minimum child weight \(\{5, \ldots, 50\}\), subsample percentage of rows [0.5, 1], subsample percentage of features in each tree [0.5, 1] and L2-regularization [1, 3]. The learning rate is set to 0.05 and the number of rounds to 5000, with early stopping after 50 rounds without improvement. For the ORSF, we tune the elastic net mixing parameter (0, 1), the parameter that penalizes the complexity of the linear predictor in each node (0.25, 0.75), the minimum number of events to split a node \(\{5,\ldots, 20\}\) and the minimum number of observations to split a node \(\{10, 40\}\). For DeepHit, we use 50 random search iterations, searching over {1, 2, 3, 5} shared layers with {50, 100, 200, 300} dimensions, {1, 2, 3, 5} cause-specific network layers with {50, 100, 200, 300} dimensions, ReLU, eLU or Tanh as activation function in these layers and a batch size in {32, 64, 128}, with a maximum of 50000 iterations, a dropout rate of 0.6 (taken from the original paper) and a learning rate of 0.0001.
The network-specific parameters \(\alpha\) and \(\gamma\) are also chosen in accordance with the original paper and set to \(1\) and \(0\), respectively, while the network-specific parameter \(\beta\) is varied in the random search with possible values in {0.1, 0.5, 1, 3, 5}.
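The GBT (PEM) search space can be written down as a random-search sampler. The sketch below is illustrative only: the mapping of the described hyperparameters to the XGBoost names `max_depth`, `gamma` (minimum loss reduction), `min_child_weight`, `subsample`, `colsample_bytree` and `lambda` (L2 regularization) is our assumption, and the number of rounds and early stopping are kept separately because `xgboost.train` takes them as the `num_boost_round` and `early_stopping_rounds` arguments rather than as parameters.

```python
import random
from typing import Any, Dict

# Passed to xgboost.train(...) as keyword arguments, not inside the params dict.
TRAIN_SETTINGS = {"num_boost_round": 5000, "early_stopping_rounds": 50}


def sample_gbt_pem_config(rng: random.Random) -> Dict[str, Any]:
    """One random-search draw from the GBT (PEM) search space described in the text."""
    return {
        "max_depth": rng.randint(1, 20),            # maximum tree depth {1, ..., 20}
        "gamma": rng.uniform(0.0, 5.0),             # minimum loss reduction [0, 5]
        "min_child_weight": rng.randint(5, 50),     # minimum child weight {5, ..., 50}
        "subsample": rng.uniform(0.5, 1.0),         # row subsampling [0.5, 1]
        "colsample_bytree": rng.uniform(0.5, 1.0),  # feature subsampling per tree [0.5, 1]
        "lambda": rng.uniform(1.0, 3.0),            # L2 regularization [1, 3]
        "eta": 0.05,                                # fixed learning rate
        "objective": "count:poisson",
        "base_score": 1.0,
    }


rng = random.Random(2020)
configs = [sample_gbt_pem_config(rng) for _ in range(20)]  # 20 random-search iterations
print(configs[0])
```

Each candidate would then be scored by 4-fold cross-validation on the augmented training data; note that the folds should be formed over observation units \(i\), not over augmented rows, so that all intervals of a unit stay in the same fold.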

The results of the single-event comparison with ORSF are summarized in Table 5. The proposed method performs well compared to ORSF in many settings. Notably, both algorithms are often not much better than the Cox PH model, suggesting that the PH assumption is not strongly violated in those data sets and that the sample size might be too small to detect small deviations w.r.t. non-linearity of feature effects, interaction effects and time-varying effects. The “synthetic (TVE)” setting illustrates that in the presence of strong, non-linear and non-linearly time-varying effects, our approach clearly outperforms the other methods. For the PBC data, we additionally ran an analysis including TVF with GBT (PEM). In this case, the inclusion of TVF resulted in worse performance (IBS of 4.3 (Q25), 6.4 (Q50) and 9.2 (Q75)), which suggests that the inclusion of TVF led to overfitting or that simply including the last observed value and carrying it forward is not appropriate in this setting.

| data | time | Kaplan-Meier | Cox-PH | ORSF | GBT (PEM) |
|---|---|---|---|---|---|
| Breast | Q25 | \(\textbf{1.9}\) | - | 2.0 | 2.0 |
| | Q50 | 4.1 | - | \(\textbf{4.0}\) | \(\textbf{4.0}\) |
| | Q75 | 7.2 | - | \(\textbf{6.7}\) | \(\textbf{6.7}\) |
| GBSG 2 | Q25 | 3.1 | 3.1 | \(\textbf{2.9}\) | 3.0 |
| | Q50 | 6.8 | 6.5 | \(\textbf{6.2}\) | 6.4 |
| | Q75 | 12.5 | 11.4 | \(\textbf{11.1}\) | 11.3 |
| PBC | Q25 | 5.4 | \(\textbf{3.7}\) | 4.0 | 3.8 |
| | Q50 | 9.1 | \(\textbf{5.3}\) | 6.1 | 5.5 |
| | Q75 | 14.0 | 8.1 | 8.6 | \(\textbf{7.8}\) |
| synthetic (TVE) | Q25 | 9.8 | 7.3 | 7.0 | \(\textbf{4.6}\) |
| | Q50 | 19.2 | 10.3 | 9.9 | \(\textbf{6.7}\) |
| | Q75 | 23.7 | 11.1 | 11.7 | \(\textbf{8.6}\) |
| Tumor | Q25 | 6.7 | 6.0 | \(\textbf{5.5}\) | 5.8 |
| | Q50 | 12.3 | 11.2 | \(\textbf{10.8}\) | 10.9 |
| | Q75 | 17.6 | 16.3 | \(\textbf{16.2}\) | \(\textbf{16.2}\) |

Table 6 summarizes the results of the comparisons with DeepHit. The GBT (PEM) again shows good overall performance. For the synthetic data set, our method clearly outperforms the other approaches for cause 1, because it is capable of estimating non-linear as well as time-varying effects. On the MGUS 2 data set, DeepHit achieves the best Brier Score for cause 1, while GBT (PEM) outperforms the other approaches for cause 2. On the synthetic data, the cause-specific Cox-PH model shows the best discrimination (C-Index) for the second cause, but is worse than GBT (PEM) and DeepHit w.r.t. the Brier Score.

| data | measure | method | cause 1 Q25 | cause 1 Q50 | cause 1 Q75 | cause 2 Q25 | cause 2 Q50 | cause 2 Q75 |
|---|---|---|---|---|---|---|---|---|
| METABRIC | Brier Score | Cox-PH | 13.3 | 22.1 | 26.4 | - | - | - |
| | | DeepHit | 14.3 | 23.5 | 27.0 | - | - | - |
| | | GBT (PEM) | \(\textbf{12.8}\) | \(\textbf{21.3}\) | \(\textbf{25.9}\) | - | - | - |
| | C-Index | Cox-PH | 63.7 | 65.1 | 64.7 | - | - | - |
| | | DeepHit | 68.6 | 63.3 | 54.9 | - | - | - |
| | | GBT (PEM) | \(\textbf{71.9}\) | \(\textbf{71.5}\) | \(\textbf{67.7}\) | - | - | - |
| MGUS 2 | Brier Score | Cox-PH (CS) | 23.6 | 43.7 | 64.3 | 13.4 | 20.5 | 22.3 |
| | | DeepHit | \(\textbf{22.8}\) | \(\textbf{41.0}\) | \(\textbf{57.8}\) | 14.9 | 27.0 | 41.5 |
| | | GBT (PEM) | 22.9 | 41.6 | 60.5 | \(\textbf{13.0}\) | \(\textbf{20.1}\) | \(\textbf{22.1}\) |
| | C-Index | Cox-PH (CS) | 66.7 | \(\textbf{65.9}\) | \(\textbf{62.4}\) | 68.8 | 69.4 | 70.1 |
| | | DeepHit | 59.6 | 57.0 | 52.3 | 65.5 | 67.2 | 68.6 |
| | | GBT (PEM) | \(\textbf{68.4}\) | 62.9 | 60.5 | \(\textbf{72.6}\) | \(\textbf{70.9}\) | \(\textbf{70.8}\) |
| synthetic (TVE, CR) | Brier Score | Cox-PH | 9.4 | 13.1 | 25.1 | 35.5 | 44.3 | 50.6 |
| | | DeepHit | 9.5 | 16.0 | 28.9 | 33.0 | 38.8 | \(\textbf{41.0}\) |
| | | GBT (PEM) | \(\textbf{7.2}\) | \(\textbf{11.6}\) | \(\textbf{20.6}\) | \(\textbf{30.1}\) | \(\textbf{38.0}\) | 43.6 |
| | C-Index | Cox-PH | 90.2 | 89.5 | 85.4 | \(\textbf{86.5}\) | \(\textbf{83.9}\) | \(\textbf{81.6}\) |
| | | DeepHit | 92.3 | 90.8 | 84.6 | 82.0 | 80.1 | 79.8 |
| | | GBT (PEM) | \(\textbf{93.9}\) | \(\textbf{92.2}\) | \(\textbf{87.5}\) | 80.9 | 80.8 | 81.0 |

We now briefly describe algorithmic details and discuss the complexity of the resulting algorithms when using the proposed framework.

The proposed framework is general in the sense that it transforms a survival task into a regression task. Nevertheless, different methods (and algorithms) have different strengths and weaknesses, and different strategies can be applied to specify alternative models within this framework. For example, in tree-based methods, time-variation of feature effects could be controlled by allowing interactions of the time variable with only a subset of features (e.g., based on prior information); shared vs. transition-specific effects in the multi-state setting could be controlled in the same way. Tree-based methods also make it particularly intuitive to see how TVE and the extension to multi-state models enter the model via interaction terms. This is illustrated in Figure 2. In panel (A) of Figure 2, features and split points before (i.e., above) the split w.r.t. time indicate feature effects common to all time points. Once the data in panel (A) is split w.r.t. time \(t\), the predicted hazard for observations with \(x_1 < 0.5\) differs between intervals with \(\kappa_j < 3\) and intervals with \(\kappa_j \geq 3\). Similarly, in a multi-state setting (panel (B) in Figure 2), splits above the split w.r.t. \(k\) indicate effects shared by all transitions, while splits below it indicate different effects for transitions \(k < 2\) vs. \(k \geq 2\). Forcing a split w.r.t. \(k\) at the root node would be equivalent to estimating cause-specific hazards on each subset, with no shared effects.
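To make the reading of panel (A) concrete, the following Python sketch applies a hand-fixed tree to a handful of hypothetical piecewise-exponential data (PED) rows. It assumes (our assumption, not the authors' implementation) that each row carries the feature \(x_1\), the interval end point \(\kappa_j\), the interval event indicator \(\delta_{ij}\) and the time at risk \(t_{ij}\). Under a Poisson likelihood with offset \(\log t_{ij}\), the maximum-likelihood constant hazard in a leaf is the number of events divided by the total time at risk, so a split on \(\kappa_j\) is precisely what lets the predicted hazard change over time. A real learner would choose the splits itself; the tree and the data are illustrative only.

```python
from collections import defaultdict

# Hypothetical PED rows (x1, kappa_j, delta_ij, t_ij) for the intervals
# (0, 1], (1, 3], (3, 5]: feature value, interval end point,
# event indicator in the interval, and time at risk in the interval.
PED_ROWS = [
    (0.2, 1.0, 0, 1.0), (0.2, 3.0, 0, 2.0), (0.2, 5.0, 1, 1.5),  # event at 4.5
    (0.3, 1.0, 1, 0.5),                                          # event at 0.5
    (0.4, 1.0, 0, 1.0), (0.4, 3.0, 0, 2.0),                      # censored at 3
    (0.7, 1.0, 1, 0.8),                                          # event at 0.8
    (0.9, 1.0, 0, 1.0), (0.9, 3.0, 0, 2.0), (0.9, 5.0, 1, 2.0),  # event at 5
]


def leaf(x1, kappa_j):
    """Hand-fixed (hypothetical) tree: split on x1 first, then on time
    (the interval end point kappa_j) only in the x1 < 0.5 branch."""
    if x1 < 0.5:
        return "x1<0.5, kappa<3" if kappa_j < 3 else "x1<0.5, kappa>=3"
    return "x1>=0.5"  # no time split: the hazard is constant over time here


def leaf_hazards(rows):
    """Poisson MLE of a constant hazard per leaf: events / total time at risk."""
    events, exposure = defaultdict(float), defaultdict(float)
    for x1, kappa_j, delta, t in rows:
        key = leaf(x1, kappa_j)
        events[key] += delta
        exposure[key] += t
    return {key: events[key] / exposure[key] for key in exposure}


hazards = leaf_hazards(PED_ROWS)
for key, h in sorted(hazards.items()):
    print(f"{key:18s} hazard = {h:.3f}")
```

For \(x_1 < 0.5\) the prediction changes at \(\kappa_j = 3\) (a time-varying effect), whereas for \(x_1 \geq 0.5\) it stays constant because that branch contains no split on time.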

Neural networks are particularly flexible when it comes to the specification of different PEMs. For example, the network could be split into two subnetworks, one for the temporal component and one for the features, whose outputs are added on the log-hazard scale. This is equivalent to specifying a proportional hazards model, while still allowing for non-linearity and high-dimensional interactions in the feature effects. Similarly, defining a separate subnetwork for the time variable for each category of a categorical feature would imply a stratified proportional hazards model.
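The proportional hazards argument can be checked numerically. The sketch below is a hypothetical, plain-Python illustration (tiny networks with fixed random weights, no training, function names of our own choosing): a time subnetwork \(f_0(t)\) and a feature subnetwork \(f(\mathbf{x})\) are added on the log-hazard scale, so the log-hazard ratio between two feature vectors does not depend on \(t\). Giving each stratum its own time subnetwork yields the stratified variant.

```python
import math
import random


def make_mlp(n_in, n_hidden, seed):
    """Tiny one-hidden-layer network with fixed random weights (illustration only)."""
    rng = random.Random(seed)
    w1 = [[rng.gauss(0, 1) for _ in range(n_in)] for _ in range(n_hidden)]
    b1 = [rng.gauss(0, 1) for _ in range(n_hidden)]
    w2 = [rng.gauss(0, 1) for _ in range(n_hidden)]

    def forward(inputs):
        hidden = [math.tanh(sum(w * v for w, v in zip(row, inputs)) + b)
                  for row, b in zip(w1, b1)]
        return sum(w * h for w, h in zip(w2, hidden))

    return forward


# f(x): non-linear feature effects, shared by all strata.
feature_net = make_mlp(n_in=2, n_hidden=8, seed=0)
# f_0^(s)(t): one time subnetwork per stratum s of a categorical feature.
time_nets = {"A": make_mlp(1, 8, seed=1), "B": make_mlp(1, 8, seed=2)}


def log_hazard(t, x, stratum="A"):
    # t and x only meet in this sum, so there is no time-feature interaction:
    # log h(t | x, s) = f_0^(s)(t) + f(x)  (stratified proportional hazards).
    return time_nets[stratum]([t]) + feature_net(x)


x_a, x_b = [0.5, -1.0], [1.5, 0.3]
times = [0.5, 1.0, 2.0, 4.0]
for stratum in ("A", "B"):
    log_hr = [log_hazard(t, x_a, stratum) - log_hazard(t, x_b, stratum) for t in times]
    print(stratum, [round(v, 6) for v in log_hr])  # constant over t within a stratum
```

Adding a time input to `feature_net` would break this additivity and turn the model into one with time-varying effects.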

As described in the literature review, various approaches exist that account for special survival characteristics like TVF, CR or continuous time-scale prediction by altering the underlying method. While adapting the structure of the algorithm itself potentially increases the complexity of the method, our approach leaves the algorithm of choice unchanged, as different time points and transitions are simply included as features. This makes it possible to employ commonly used prediction methods without introducing further algorithmic complexity. We note, however, that our approach might be improved upon in terms of scaling with respect to the number of intervals \(J\) relative to the number of observations \(n\). In the worst case, when one cut-point is introduced for each observed event or censoring time, the total number of data points is quadratic in \(n\): with \(n\) distinct observed times, the observation with the \(i\)-th smallest time contributes \(i\) rows, for a total of \(n(n+1)/2\). We therefore propose a refinement of the presented method that improves run-times without forfeiting performance. Instead of setting cut-points at all unique event times, we suggest defining cut-points more sparsely, for example, based on a sub-sample of the original data.
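The quadratic worst case can be reproduced with a small augmentation routine. The sketch below is our own minimal version (the function name is hypothetical, not taken from any package, and the data are simulated): it splits each observation at the cut-points into intervals \((\kappa_{j-1}, \kappa_j]\), records the event indicator and the time at risk (whose logarithm would serve as the Poisson offset), and counts the resulting rows when every observed time is used as a cut-point.

```python
import random


def split_into_intervals(time, status, cuts):
    """Split one observation at the sorted cut-points `cuts`.

    Returns one (start, end, delta, time_at_risk) row per interval
    (kappa_{j-1}, kappa_j] the observation is at risk in; log(time_at_risk)
    would be the offset of the Poisson model. Follow-up beyond the last
    cut-point is dropped (treated as censored there).
    """
    rows, start = [], 0.0
    for end in cuts:
        if time <= start:
            break
        delta = int(status == 1 and time <= end)  # event falls into this interval
        rows.append((start, end, delta, min(time, end) - start))
        start = end
    return rows


rng = random.Random(42)
n = 400
times = [rng.expovariate(1.0) for _ in range(n)]      # hypothetical data
statuses = [int(rng.random() < 0.7) for _ in range(n)]
cuts = sorted(set(times))                               # one cut per observed time

n_rows = sum(len(split_into_intervals(t, d, cuts)) for t, d in zip(times, statuses))
print(n_rows, n * (n + 1) // 2)  # equal when all n observed times are distinct
```

For example, `split_into_intervals(2.5, 1, [1.0, 2.0, 3.0])` yields three rows with times at risk 1.0, 1.0 and 0.5 and the event recorded in the last interval.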

To investigate this strategy, we conducted a scaling experiment in which the sample size was successively doubled from \(n=400\) up to \(n=3200\). For each sample size, ten replications of the experiment described for the “synthetic TVE” setting in Section 3 were performed, and the elapsed time (hours) as well as the performance (IBS) were measured for two different strategies of cut-point selection. The first strategy (full) uses all event times (\(t_i\) where \(\delta_i = 1\)) as cut-points. The second strategy (sub-sample) is identical to the first, except that the event times are taken from a sub-sample of size \(n'=200\), drawn randomly from the training data in each iteration. The results in Table 7 show that the “sub-sample” strategy leads to an approximately linear increase in computation time, while the performance remains virtually unchanged. A sparser choice of cut-points could potentially also lead to a more robust and thus improved hazard estimation, as more events are available in each interval, but we did not investigate this formally.

| metric | strategy | n = 400 | n = 800 | n = 1600 | n = 3200 |
|---|---|---|---|---|---|
| time (hours) | full | 0.10 | 0.48 | 2.49 | 8.94 |
| time (hours) | sub-sample | **0.09** | **0.20** | **0.51** | **1.04** |
| IBS | full | 8.10 | 6.50 | 6.40 | **5.90** |
| IBS | sub-sample | **8.00** | **6.40** | **6.20** | **5.90** |
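The effect of the sub-sample strategy on the size of the augmented data can be sketched without fitting any model. The Python snippet below uses hypothetical exponential survival times rather than the “synthetic TVE” setting, and it counts PED rows instead of measuring run time or IBS. Cut-points are the event times of either all observations (full) or a random sub-sample of \(n'=200\) observations (sub-sample); appending the largest observed time as a final cut-point, so that all follow-up is covered, is our own assumption.

```python
import bisect
import random


def n_ped_rows(times, cuts):
    """Number of PED rows: one per interval (kappa_{j-1}, kappa_j] that starts
    before the observed time, capped at the number of intervals."""
    return sum(min(bisect.bisect_left(cuts, t) + 1, len(cuts)) for t in times)


def event_cut_points(times, statuses, idx):
    """Unique event times among observations `idx`, plus the largest observed
    time overall (our assumption, so that all follow-up is covered)."""
    cuts = {times[i] for i in idx if statuses[i] == 1}
    cuts.add(max(times))
    return sorted(cuts)


def compare(n, n_sub=200, seed=1):
    rng = random.Random(seed)
    times = [rng.expovariate(1.0) for _ in range(n)]      # hypothetical data
    statuses = [int(rng.random() < 0.7) for _ in range(n)]
    full = event_cut_points(times, statuses, range(n))
    sub = event_cut_points(times, statuses, rng.sample(range(n), min(n_sub, n)))
    return n_ped_rows(times, full), n_ped_rows(times, sub), len(sub)


for n in (400, 800, 1600, 3200):
    full_rows, sub_rows, _ = compare(n)
    print(f"n={n:5d}  full={full_rows:9d}  sub-sample={sub_rows:8d}")
```

Because the number of sub-sample cut-points is bounded by \(n'+1\), each observation contributes a bounded number of rows, so the augmented data grow roughly linearly in \(n\), whereas with the full strategy they grow roughly quadratically.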

We have presented a general machine learning framework for time-to-event analysis based on a data augmentation strategy that reduces a large variety of survival analysis tasks to the optimization of a Poisson likelihood, and we demonstrated its versatility and state-of-the-art performance. Because Poisson regression is available in most machine learning frameworks, the approach offers additional practical advantages. For example, photon-ML [35] is a scalable machine learning library for Apache Spark [36] that has no native support for survival analysis but implements generalized linear mixed models. Survival models with high-cardinality random effects (frailty) are therefore directly available using our framework. Similarly, LightGBM [37], a high-performance implementation of GBT, currently has no implementation of survival methods, but could also be used for high-dimensional survival tasks based on PEMs, including reliability analysis or churn analysis with intermediate states.
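All that a gradient-boosting library needs from the framework is the Poisson loss with an offset. The sketch below (plain Python with hypothetical function names, not the API of photon-ML or LightGBM) gives the gradient and Hessian of the Poisson negative log-likelihood with respect to a raw score \(f\), where \(\mu_{ij} = \exp(f + \log t_{ij})\), and uses them in Newton steps for a single leaf constant; the result coincides with the log of the number of events divided by the time at risk. In libraries with a built-in Poisson objective, the offset \(\log t_{ij}\) is commonly supplied as an initial score (e.g., LightGBM's `init_score`); this should be checked against the respective documentation.

```python
import math


def poisson_grad_hess(raw_score, delta, log_exposure):
    """Gradient and Hessian of the Poisson negative log-likelihood
    mu - delta * log(mu) w.r.t. the raw score, with mu = exp(raw + offset)."""
    mu = math.exp(raw_score + log_exposure)
    return mu - delta, mu


def newton_leaf_value(deltas, exposures, n_iter=50):
    """Fit one constant log-hazard (a single leaf) by Newton's method."""
    f = 0.0
    for _ in range(n_iter):
        grad = hess = 0.0
        for d, t in zip(deltas, exposures):
            g, h = poisson_grad_hess(f, d, math.log(t))
            grad += g
            hess += h
        f -= grad / hess
    return f


deltas = [1, 0, 1, 0]             # hypothetical PED event indicators
exposures = [1.0, 2.0, 0.5, 1.5]  # hypothetical times at risk t_ij
f_hat = newton_leaf_value(deltas, exposures)
print(f_hat, math.log(sum(deltas) / sum(exposures)))  # both approximately log(2 / 5)
```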

This work has been funded by the German Federal Ministry of Education and Research (BMBF) under Grant No. 01IS18036A. The authors of this work take full responsibility for its content.

[1] Ishwaran, H., Kogalur, U.B., Blackstone, E.H., Lauer, M.S.: Random survival forests. The Annals of Applied Statistics **2**(3), 841–860 (2008).

[2] Ishwaran, H., Gerds, T.A., Kogalur, U.B., Moore, R.D., Gange, S.J., Lau, B.M.: Random survival forests for competing risks. Biostatistics **15**(4), 757–773 (Oct 2014).

[3] Wright, M.N., Ziegler, A.: ranger: A Fast Implementation of Random Forests for High Dimensional Data in C++ and R. Journal of Statistical Software **77**(1), 1–17 (Mar 2017).

[4] Jaeger, B.C., Long, D.L., Long, D.M., Sims, M., Szychowski, J.M., Min, Y.I., Mcclure, L.A., Howard, G., Simon, N.: Oblique random survival forests. The Annals of Applied Statistics **13**(3), 1847–1883 (Sep 2019).

[5] Bou-Hamad, I., Larocque, D., Ben-Ameur, H.: A review of survival trees. Statistics Surveys **5**, 44–71 (2011).

[6] Wang, P., Li, Y., Reddy, C.K.: Machine Learning for Survival Analysis: A Survey. ACM Computing Surveys (CSUR) **51**(6), 110:1–110:36 (Feb 2019).

[7] Klein, J.P., Moeschberger, M.L.: Survival analysis: techniques for censored and truncated data. Springer Science & Business Media (2006).

[8] Faraggi, D., Simon, R.: A neural network model for survival data. Statistics in Medicine **14**(1), 73–82 (1995).

[9] Ranganath, R., Perotte, A., Elhadad, N., Blei, D.: Deep Survival Analysis. arXiv:1608.02158 (Aug 2016).

[10] Alaa, A.M., van der Schaar, M.: Deep multi-task Gaussian processes for survival analysis with competing risks. In: Proceedings of the 31st International Conference on Neural Information Processing Systems, pp. 2326–2334 (2017).

[11] Lee, C., Zame, W.R., Yoon, J., Schaar, M.v.d.: DeepHit: A Deep Learning Approach to Survival Analysis With Competing Risks. In: Thirty-Second AAAI Conference on Artificial Intelligence (Apr 2018).

[12] Lee, C., Yoon, J., Schaar, M.v.d.: Dynamic-DeepHit: A Deep Learning Approach for Dynamic Survival Analysis With Competing Risks Based on Longitudinal Data. IEEE Transactions on Biomedical Engineering **67**(1), 122–133 (2020).

[13] Binder, H., Allignol, A., Schumacher, M., Beyersmann, J.: Boosting for high-dimensional time-to-event data with competing risks. Bioinformatics **25**(7), 890–896 (Apr 2009).

[14] Hothorn, T., Bühlmann, P.: Model-based boosting in high dimensions. Bioinformatics **22**(22), 2828–2829 (2006).

[15] Reulen, H., Kneib, T.: Boosting multi-state models. Lifetime Data Analysis pp. 1–22 (May 2015).

[16] Chen, T., Guestrin, C.: XGBoost: A Scalable Tree Boosting System. In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining (KDD ’16), pp. 785–794 (2016), arXiv:1603.02754.

[17] Lee, D.K.K., Chen, N., Ishwaran, H.: Boosted nonparametric hazards with time-dependent covariates. arXiv:1701.07926 (Nov 2019).

[18] Liestøl, K., Andersen, P.K., Andersen, U.: Survival analysis and neural nets. Statistics in Medicine **13**(12), 1189–1200 (1994).

[19] Biganzoli, E., Boracchi, P., Marubini, E.: A general framework for neural network models on censored survival data. Neural Networks **15**(2), 209–218 (Mar 2002).

[20] Huang, X., Chen, S., Soong, S.j.: Piecewise Exponential Survival Trees with Time-Dependent Covariates. Biometrics **54**(4), 1420–1433 (1998).

[21] Cai, T., Hyndman, R.J., Wand, M.P.: Mixed model-based hazard estimation. Journal of Computational and Graphical Statistics **11**(4), 784–798 (2002).

[22] Fornili, M., Ambrogi, F., Boracchi, P., Biganzoli, E.: Piecewise Exponential Artificial Neural Networks (PEANN) for Modeling Hazard Function with Right Censored Data. In: Formenti, E., Tagliaferri, R., Wit, E. (eds.) Computational Intelligence Methods for Bioinformatics and Biostatistics. pp. 125–136. Lecture Notes in Computer Science, Springer International Publishing, Cham (2014).

[23] Sennhenn‐Reulen, H., Kneib, T.: Structured fusion lasso penalized multi-state models. Statistics in Medicine **35**(25), 4637–4659 (2016).

[24] Friedman, M.: Piecewise exponential models for survival data with covariates. The Annals of Statistics **10**(1), 101–113 (1982).

[25] Bender, A., Scheipl, F., Hartl, W., Day, A.G., Küchenhoff, H.: Penalized estimation of complex, non-linear exposure-lag-response associations. Biostatistics (2018).

[26] Bender, A., Groll, A., Scheipl, F.: A generalized additive model approach to time-to-event analysis. Statistical Modelling p. 1471082X17748083 (2018).

[27] Guo, G.: Event-History Analysis for Left-Truncated Data. Sociological Methodology **23**, 217–243 (1993).

[28] Iacobelli, S., Carstensen, B.: Multiple time scales in multi-state models. Statistics in Medicine **32**(30), 5315–5327 (Dec 2013).

[29] Hothorn, T., Hornik, K., Zeileis, A.: Unbiased recursive partitioning: a conditional inference framework. Journal of Computational and Graphical Statistics **15**(3), 651–674 (2006).

[30] Friedman, J.H., Hastie, T., Tibshirani, R.: Regularization Paths for Generalized Linear Models via Coordinate Descent. Journal of Statistical Software **33**(1), 1–22 (Feb 2010).

[31] Gerds, T.A., Schumacher, M.: Consistent Estimation of the Expected Brier Score in General Survival Models with Right-Censored Event Times. Biometrical Journal **48**(6), 1029–1040 (Dec 2006).

[32] Gerds, T.A., Kattan, M.W., Schumacher, M., Yu, C.: Estimating a time-dependent concordance index for survival prediction models with covariate dependent censoring. Statistics in Medicine **32**(13), 2173–2184 (Jun 2013).

[33] Kyle, R.A., Therneau, T.M., Rajkumar, S.V., Offord, J.R., Larson, D.R., Plevak, M.F., Melton, L.J.: A Long-Term Study of Prognosis in Monoclonal Gammopathy of Undetermined Significance. New England Journal of Medicine **346**(8), 564–569 (Feb 2002).

[34] Cox, D.R.: Regression Models and Life-Tables. Journal of the Royal Statistical Society. Series B (Methodological) **34**(2), 187–220 (1972).

[35] Zhang, X., Zhou, Y., Ma, Y., Chen, B.C., Zhang, L., Agarwal, D.: GLMix: Generalized linear mixed models for large-scale response prediction. In: Proceedings of the 22nd ACM SIGKDD International Conference on Knowledge Discovery and Data Mining. pp. 363–372 (2016).

[36] Zaharia, M., Xin, R.S., Wendell, P., Das, T., Armbrust, M., Dave, A., Meng, X., Rosen, J., Venkataraman, S., Franklin, M.J., Ghodsi, A., Gonzalez, J., Shenker, S., Stoica, I.: Apache Spark: A unified engine for big data processing. Commun. ACM **59**(11), 56–65 (Oct 2016).

[37] Ke, G., Meng, Q., Finley, T., Wang, T., Chen, W., Ma, W., Ye, Q., Liu, T.Y.: LightGBM: A Highly Efficient Gradient Boosting Decision Tree. In: Guyon, I., Luxburg, U.V., Bengio, S., Wallach, H., Fergus, R., Vishwanathan, S., Garnett, R. (eds.) Advances in Neural Information Processing Systems 30, pp. 3146–3154. Curran Associates, Inc. (2017).