Survival Analysis Functions
Following the notation used in the DRSA paper, we define the following:
Let $z$ be the true occurrence time for the event of interest.
Let $t$ be the time that a given data point was observed.
For each observation, there exist $L$ time slices, i.e. $0 < t_1 < t_2 < \dots < t_L$, at which we either observe the event (uncensored) or do not (censored).
Let $V_l = (t_{l-1}, t_l]$, for $l = 1, 2, \dots, L$, denote the disjoint intervals between consecutive time slices, taking $t_0 = 0$.
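As a small illustration of this discretization, the sketch below maps an occurrence time $z$ to the index $l$ of the interval $V_l$ that contains it. The function name `interval_index` and the example boundaries are hypothetical and not part of the library; $t_0 = 0$ is assumed.

```python
from bisect import bisect_left


def interval_index(z, boundaries):
    """Return l such that z lies in V_l = (t_{l-1}, t_l], or None if z > t_L.

    `boundaries` is the sorted list [t_1, ..., t_L]; t_0 = 0 is assumed.
    """
    if z <= 0:
        raise ValueError("z must be positive")
    # bisect_left counts the boundaries strictly less than z, which is l - 1
    l = bisect_left(boundaries, z) + 1
    return l if l <= len(boundaries) else None


boundaries = [1.0, 2.0, 3.0, 4.0]  # hypothetical t_1, ..., t_4
print(interval_index(2.0, boundaries))  # right endpoint, so z is in V_2
print(interval_index(2.5, boundaries))  # strictly inside V_3
print(interval_index(4.5, boundaries))  # beyond t_L, no interval
```

Because each interval is closed on the right, a time that coincides with a boundary $t_l$ is assigned to $V_l$ rather than $V_{l+1}$.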
Discrete Survival Function
Though it's given its own name in survival analysis, the survival function is simply calculated as $1 - \text{CDF}(z)$. In the discrete, empirical case, the survival function is estimated as follows (this is equation (5) in the paper).
$$ S(t_l) = Pr(z > t_l) = \sum_{j > l}Pr(z\in V_j) $$
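To make the sum concrete, here is a minimal sketch in plain Python that computes $S(t_l)$ for every $l$ from a list of interval probabilities $p_j = Pr(z \in V_j)$. The helper name `survival_from_pmf` and the numbers are illustrative assumptions, and the sketch assumes all probability mass lies within the $L$ intervals.

```python
def survival_from_pmf(p):
    """Return [S(t_1), ..., S(t_L)] where S(t_l) = sum_{j > l} p_j.

    `p` is the list [p_1, ..., p_L] of interval probabilities.
    """
    survival = []
    tail = 0.0
    # walk backwards so each tail sum reuses the previous one
    for p_j in reversed(p):
        survival.append(tail)  # mass strictly after the current interval
        tail += p_j
    survival.reverse()
    return survival


p = [0.1, 0.2, 0.3, 0.4]  # hypothetical interval probabilities
S = survival_from_pmf(p)
print(S)  # non-increasing, ending at 0 when p sums to 1
```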
# example
import torch

# two sequences of per-time-step hazard rates, 8 time steps each
h1 = torch.tensor([[0.001],
                   [0.5],
                   [0.55],
                   [0.15],
                   [0.15],
                   [0.15],
                   [0.15],
                   [0.9]], requires_grad=True)
h2 = torch.tensor([[0.001],
                   [0.005],
                   [0.1],
                   [0.11],
                   [0.12],
                   [0.15],
                   [0.15],
                   [0.9]], requires_grad=True)
# stack into a single tensor of shape (2, 8, 1)
h = torch.stack([h1, h2], dim=0)
survival_rate(h)
# example
event_rate(h)
# example
event_time(h)
Discrete Conditional Hazard Rate
The conditional hazard rate is the quantity which will be predicted at each time step by a recurrent survival analysis model. In the discrete, empirical case, it is estimated as follows (this is equation (7) in the paper).
$$h_l = Pr(z\in V_l | z > t_{l-1}) = \frac{Pr(z\in V_l)}{Pr(z>t_{l-1})} = \frac{p_l}{S(t_{l-1})}$$
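Rearranging the identity above gives $p_l = h_l \, S(t_{l-1})$, and since $S(t_l) = S(t_{l-1}) - p_l$, it follows that $S(t_l) = (1 - h_l)\, S(t_{l-1})$. The sketch below unrolls this recursion in plain Python, assuming $S(t_0) = 1$. The function name `hazards_to_survival_and_pmf` is hypothetical, and this is not necessarily how the library's own functions are implemented.

```python
def hazards_to_survival_and_pmf(h):
    """Convert hazard rates [h_1, ..., h_L] into survival and interval probabilities.

    Returns (survival, pmf) with survival[l-1] = S(t_l) and pmf[l-1] = p_l.
    Assumes S(t_0) = 1, i.e. the event has not occurred at time 0.
    """
    survival, pmf = [], []
    s_prev = 1.0  # S(t_0)
    for h_l in h:
        pmf.append(h_l * s_prev)          # p_l = h_l * S(t_{l-1})
        s_prev = (1.0 - h_l) * s_prev     # S(t_l) = (1 - h_l) * S(t_{l-1})
        survival.append(s_prev)
    return survival, pmf


survival, pmf = hazards_to_survival_and_pmf([0.5, 0.5])
print(survival)  # [0.5, 0.25]
print(pmf)       # [0.5, 0.25]
```

Under these assumptions the survival probability at $t_L$ is the product $\prod_{l \le L} (1 - h_l)$, and the probability that the event falls in the last interval is $h_L \prod_{l < L} (1 - h_l)$.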
# example
log_survival_rate(h)
# example
log_event_rate(h)
# example
log_event_time(h)
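The source does not show how the log variants are computed. One common reason to work in log space is that a long product of factors $(1 - h_l)$ can underflow, while the equivalent sum of logs does not. The sketch below illustrates this in plain Python; the function name `log_survival` and the input are illustrative assumptions, not the library's implementation.

```python
import math


def log_survival(h):
    """log S(t_L) = sum_l log(1 - h_l), computed without forming the product."""
    # log1p(-x) evaluates log(1 - x) accurately when x is small
    return sum(math.log1p(-h_l) for h_l in h)


h_long = [0.9] * 1000  # hypothetical long sequence with high hazard rates

naive = 1.0
for h_l in h_long:
    naive *= 1.0 - h_l  # the running product eventually underflows

log_s = log_survival(h_long)
print(naive)  # underflows to 0.0, so its log would be undefined
print(log_s)  # finite negative number
```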
Loss Functions
Now, we transform these generic survival analysis functions into loss functions that can be automatically differentiated by PyTorch, in order to train a Deep Recurrent Survival Analysis model.
We make a few notes below:
The functions below adhere to the common pattern used across all of PyTorch's loss functions, which is to take two arguments named input and target. We note, however, that due to the nature of this survival data, the target is inherent to the data structure and thus unnecessary.
The original DRSA paper defines 3 loss functions, 2 of which are directed towards uncensored data, and 1 of which applies to censored data. This library's focus is on DRSA models using only uncensored data, so those are the only losses we'll be defining.
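As a rough illustration of the pattern described in these notes, the sketch below writes two negative log-likelihood losses in plain Python with an `(input, target=None)` signature. The names `event_time_nll` and `event_rate_nll` are hypothetical, and the sketch assumes each sequence of hazard rates ends at the interval in which the event was observed. It is meant to show the arithmetic only; the library's losses operate on PyTorch tensors so that gradients can flow.

```python
import math


def event_time_nll(input, target=None):
    """Mean negative log-probability that the event falls in the last interval.

    `input` is a list of hazard-rate sequences, one per observation.
    `target` is accepted only to match the (input, target) convention and is ignored.
    """
    total = 0.0
    for h in input:
        # log p_L = log h_L + sum_{l < L} log(1 - h_l)
        log_p = math.log(h[-1]) + sum(math.log1p(-h_l) for h_l in h[:-1])
        total -= log_p
    return total / len(input)


def event_rate_nll(input, target=None):
    """Mean negative log-probability that the event occurs by the last time slice."""
    total = 0.0
    for h in input:
        log_s = sum(math.log1p(-h_l) for h_l in h)  # log S(t_L)
        # log(1 - S(t_L)) written as log(-expm1(log S)) for accuracy
        total -= math.log(-math.expm1(log_s))
    return total / len(input)


batch = [[0.5, 0.5]]  # one hypothetical observation with two time steps
print(event_time_nll(batch))  # equals -log(0.25)
print(event_rate_nll(batch))  # equals -log(0.75)
```

Ignoring `target` keeps the call signature compatible with code written for standard loss functions while reflecting that the event information is already encoded in where each sequence ends.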
# example
event_time_loss(h)
# example
event_rate_loss(h)