Survival Analysis Functions
Following the notation of the DRSA paper, we define the following:
Let $z$ be the true occurrence time for the event of interest.
Let $t$ be the time that a given data point was observed.
For each observation, there are $L$ time slices, i.e., $0 < t_1 < t_2 < \dots < t_L$, at which we either observe the event (uncensored) or do not (censored).
Let $V_l = (t_{l-1}, t_l]$, for $l = 1, 2, \dots, L$ and with $t_0 = 0$, denote the $l$-th of the $L$ disjoint intervals between consecutive time slices.
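To make the interval notation concrete, here is a minimal pure-Python sketch that finds the 1-based index $l$ of the interval $V_l = (t_{l-1}, t_l]$ containing an event time $z$. The function name `interval_index`, the example boundaries, and the convention of returning `L + 1` when $z > t_L$ (the event falls after the last observed slice) are all assumptions for illustration.

```python
from bisect import bisect_left


def interval_index(z, boundaries):
    """Return the 1-based index l such that t_{l-1} < z <= t_l.

    `boundaries` holds the sorted time slices [t_1, ..., t_L], with t_0 = 0
    implied, and z is assumed to be > 0. Returns L + 1 when z > t_L
    (hypothetical convention for an event after the last slice).
    """
    # bisect_left gives the first position i with boundaries[i] >= z,
    # so boundaries[i - 1] < z <= boundaries[i]; add 1 for 1-based l.
    return bisect_left(boundaries, z) + 1


# Hypothetical time slices t_1..t_4
ts = [1.0, 2.0, 3.0, 4.0]
print(interval_index(0.5, ts))  # 1: z lies in V_1 = (0, 1]
print(interval_index(2.0, ts))  # 2: intervals are right-closed, so t_2 is in V_2
print(interval_index(2.5, ts))  # 3: z lies in V_3 = (2, 3]
print(interval_index(9.0, ts))  # 5: beyond t_L
```

Using `bisect_left` rather than `bisect_right` is what makes the intervals right-closed: a time exactly equal to $t_l$ lands in $V_l$, not $V_{l+1}$.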
Discrete Survival Function
Though it is given its own name in survival analysis, the survival function is simply the complement of the cumulative distribution function of $z$: $S(t) = Pr(z > t) = 1 - \text{CDF}(t)$. In the discrete, empirical case, the survival function is estimated as follows (this is equation (5) in the paper).
$$ S(t_l) = Pr(z > t_l) = \sum_{j > l}Pr(z\in V_j) $$
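As a sketch of equation (5), suppose we already hold the per-interval event probabilities $p_j = Pr(z \in V_j)$, plus one final entry for the mass beyond $t_L$, and that they sum to one. Then $S(t_l)$ is a reverse cumulative sum over the intervals strictly after $V_l$. The function name `survival_from_event_probs`, the extra tail entry, and the example probabilities are assumptions for illustration.

```python
import torch


def survival_from_event_probs(p: torch.Tensor) -> torch.Tensor:
    """Evaluate S(t_l) = sum_{j > l} p_j for l = 1, ..., L (equation (5)).

    `p` has shape (L + 1,): entries 0..L-1 are Pr(z in V_1), ..., Pr(z in V_L)
    and the last entry is Pr(z > t_L). Assumes p sums to 1.
    """
    # Reverse cumulative sum: tail[m] = p[m] + p[m + 1] + ... + p[L]
    tail = torch.flip(torch.cumsum(torch.flip(p, dims=[0]), dim=0), dims=[0])
    # S(t_l) sums the intervals strictly after V_l, i.e. tail[l] for l = 1..L
    return tail[1:]


# Hypothetical probabilities for L = 3 intervals plus the tail beyond t_3
p = torch.tensor([0.2, 0.3, 0.1, 0.4])
print(survival_from_event_probs(p))  # approximately tensor([0.8000, 0.5000, 0.4000])
```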
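# example
import torch

# Two example sequences of predicted conditional hazard rates h_l,
# one value per time step (8 steps), each of shape (seq_len, 1)
h1 = torch.tensor([[0.001],
                   [0.5],
                   [0.55],
                   [0.15],
                   [0.15],
                   [0.15],
                   [0.15],
                   [0.9]], requires_grad=True)
h2 = torch.tensor([[0.001],
                   [0.005],
                   [0.1],
                   [0.11],
                   [0.12],
                   [0.15],
                   [0.15],
                   [0.9]], requires_grad=True)
# Batch the sequences: shape (batch_size, seq_len, 1) == (2, 8, 1)
h = torch.stack([h1, h2], dim=0)
survival_rate(h)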
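For reference, a survival rate can be recovered from hazard rates through the identity $S(t_l) = \prod_{i=1}^{l}(1 - h_i)$, which follows from the conditional hazard rate (equation (7), in the next section) together with $S(t_0) = 1$. The sketch below evaluates it at every time step with `torch.cumprod`. The name `survival_rate_sketch`, its per-step output shape, and the demo tensor `h_demo` are assumptions; the library's `survival_rate` may return a different shape (for example, only the final step).

```python
import torch


def survival_rate_sketch(h: torch.Tensor) -> torch.Tensor:
    """S(t_l) = prod_{i <= l} (1 - h_i) at every step l.

    `h` holds conditional hazard rates with shape (batch_size, seq_len, 1);
    the running product is taken along the time dimension (dim=1).
    """
    return torch.cumprod(1 - h, dim=1)


# A single hypothetical sequence with hazards 0.25, 0.5, 0.5
h_demo = torch.tensor([0.25, 0.5, 0.5]).reshape(1, 3, 1)
print(survival_rate_sketch(h_demo).squeeze())  # tensor([0.7500, 0.3750, 0.1875])
```

A separate name, `h_demo`, is used so the sketch does not overwrite the `h` batch used by the examples.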
# example
event_rate(h)
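Assuming `event_rate` denotes the probability that the event has occurred by $t_l$, i.e., $W(t_l) = Pr(z \le t_l) = 1 - S(t_l)$, it can be sketched from hazard rates as below. The function name, the per-step output shape, and the demo values are assumptions for illustration, not the library's implementation.

```python
import torch


def event_rate_sketch(h: torch.Tensor) -> torch.Tensor:
    """W(t_l) = Pr(z <= t_l) = 1 - prod_{i <= l} (1 - h_i) at every step l.

    `h` has shape (batch_size, seq_len, 1); time runs along dim=1.
    """
    return 1 - torch.cumprod(1 - h, dim=1)


h_demo = torch.tensor([0.25, 0.5, 0.5]).reshape(1, 3, 1)
print(event_rate_sketch(h_demo).squeeze())  # tensor([0.2500, 0.6250, 0.8125])
```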
# example
event_time(h)
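Similarly, assuming `event_time` refers to the probability $p_l = Pr(z \in V_l)$ that the event falls in interval $V_l$, rearranging equation (7) in the next section gives $p_l = h_l S(t_{l-1})$ with $S(t_0) = 1$. A hypothetical per-step sketch:

```python
import torch


def event_time_sketch(h: torch.Tensor) -> torch.Tensor:
    """p_l = Pr(z in V_l) = h_l * S(t_{l-1}) at every step l, with S(t_0) = 1.

    `h` has shape (batch_size, seq_len, 1); time runs along dim=1.
    """
    survival = torch.cumprod(1 - h, dim=1)
    # Shift survival one step later so position l holds S(t_{l-1}); S(t_0) = 1
    ones = torch.ones_like(h[:, :1])
    survival_prev = torch.cat([ones, survival[:, :-1]], dim=1)
    return h * survival_prev


h_demo = torch.tensor([0.25, 0.5, 0.5]).reshape(1, 3, 1)
print(event_time_sketch(h_demo).squeeze())  # tensor([0.2500, 0.3750, 0.1875])
```

As a sanity check, the probabilities sum to $1 - S(t_L)$, the mass that falls inside the observation window: $0.25 + 0.375 + 0.1875 = 0.8125 = 1 - 0.1875$.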
Discrete Conditional Hazard Rate
The conditional hazard rate is the quantity that a recurrent survival analysis model predicts at each time step. In the discrete, empirical case, it is estimated as follows (this is equation (7) in the paper).
$$h_l = Pr(z\in V_l | z > t_{l-1}) = \frac{Pr(z\in V_l)}{Pr(z>t_{l-1})} = \frac{p_l}{S(t_{l-1})}$$
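Equation (7) also lets us turn hazard rates back into survival and event probabilities. Because $V_l$ and $(t_l, \infty)$ split $(t_{l-1}, \infty)$ into two disjoint pieces, and $S(t_0) = Pr(z > 0) = 1$:

$$ S(t_{l-1}) = Pr(z\in V_l) + Pr(z > t_l) = p_l + S(t_l) $$
$$ S(t_l) = S(t_{l-1}) - p_l = S(t_{l-1})(1 - h_l) = \prod_{i=1}^{l}(1 - h_i) $$
$$ p_l = h_l \, S(t_{l-1}) = h_l \prod_{i=1}^{l-1}(1 - h_i) $$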
# example
log_survival_rate(h)
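Since $S(t_l) = \prod_{i=1}^{l}(1 - h_i)$ (from equation (7) with $S(t_0) = 1$), working in log space turns the product into a sum, $\log S(t_l) = \sum_{i=1}^{l} \log(1 - h_i)$, which avoids underflow when many factors are small. A minimal sketch, assuming hazards in $[0, 1)$ and using `torch.log1p(-h)` for an accurate $\log(1 - h_i)$ when $h_i$ is near zero; the name is hypothetical and this is not necessarily how `log_survival_rate` is implemented.

```python
import torch


def log_survival_rate_sketch(h: torch.Tensor) -> torch.Tensor:
    """log S(t_l) = sum_{i <= l} log(1 - h_i) at every step l.

    `h` has shape (batch_size, seq_len, 1) with entries in [0, 1).
    """
    # log1p(-h) computes log(1 - h) accurately when h is close to 0
    return torch.cumsum(torch.log1p(-h), dim=1)


h_demo = torch.tensor([0.25, 0.5, 0.5]).reshape(1, 3, 1)
log_s = log_survival_rate_sketch(h_demo)
print(torch.exp(log_s).squeeze())  # approximately tensor([0.7500, 0.3750, 0.1875])
```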
# example
log_event_rate(h)
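For the log of the event rate, $\log W(t_l) = \log(1 - S(t_l))$, a common numerically careful route is $1 - S = -\text{expm1}(\log S)$, which stays accurate when $S(t_l)$ is close to 1. This hypothetical sketch assumes hazards in $[0, 1)$ and returns $-\infty$ at any step where $S(t_l) = 1$ exactly (all hazards so far are zero).

```python
import torch


def log_event_rate_sketch(h: torch.Tensor) -> torch.Tensor:
    """log W(t_l) = log(1 - S(t_l)) at every step l, computed from log S.

    `h` has shape (batch_size, seq_len, 1) with entries in [0, 1).
    """
    log_s = torch.cumsum(torch.log1p(-h), dim=1)
    # 1 - S = -expm1(log S); expm1 avoids cancellation when S is near 1
    return torch.log(-torch.expm1(log_s))


h_demo = torch.tensor([0.25, 0.5, 0.5]).reshape(1, 3, 1)
print(torch.exp(log_event_rate_sketch(h_demo)).squeeze())
# approximately tensor([0.2500, 0.6250, 0.8125])
```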
# example
log_event_time(h)
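Likewise, $\log p_l = \log h_l + \sum_{i=1}^{l-1} \log(1 - h_i)$, i.e., the log hazard at step $l$ plus the log survival up to the previous step. A hypothetical sketch, assuming hazards in $(0, 1)$:

```python
import torch


def log_event_time_sketch(h: torch.Tensor) -> torch.Tensor:
    """log p_l = log h_l + log S(t_{l-1}) at every step l, with log S(t_0) = 0.

    `h` has shape (batch_size, seq_len, 1) with entries in (0, 1).
    """
    log_s = torch.cumsum(torch.log1p(-h), dim=1)
    # Shift log S one step later so position l holds log S(t_{l-1})
    zeros = torch.zeros_like(h[:, :1])
    log_s_prev = torch.cat([zeros, log_s[:, :-1]], dim=1)
    return torch.log(h) + log_s_prev


h_demo = torch.tensor([0.25, 0.5, 0.5]).reshape(1, 3, 1)
print(torch.exp(log_event_time_sketch(h_demo)).squeeze())
# approximately tensor([0.2500, 0.3750, 0.1875])
```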
Loss Functions
Now, we transform these generic survival analysis functions into loss functions that PyTorch can automatically differentiate, in order to train a Deep Recurrent Survival Analysis model.
We make a few notes below:
The functions below follow the pattern used by most of PyTorch's loss functions, which take two arguments named `input` and `target`. Due to the nature of this survival data, however, the target is inherent to the data structure, so the `target` argument is unnecessary.
The original DRSA paper defines 3 loss functions: 2 for uncensored data and 1 for censored data. This library focuses on DRSA models that use only uncensored data, so those are the only losses we define.
# example
event_time_loss(h)
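A minimal sketch of an event-time loss, assuming (as the note on `target` suggests) that each uncensored sequence ends at the interval containing its event. Under that assumption the loss is the negative log-likelihood of the event falling in the final interval, $-\log p_L$, averaged over the batch. The name `event_time_loss_sketch` and the mean reduction are assumptions; the `(input, target=None)` signature mirrors the PyTorch pattern described above.

```python
import torch


def event_time_loss_sketch(input: torch.Tensor, target=None) -> torch.Tensor:
    """Negative mean log-probability that each event falls in its final interval.

    `input` holds hazard rates with shape (batch_size, seq_len, 1), each
    sequence assumed to end at its event; `target` is accepted but ignored.
    """
    log_one_minus_h = torch.log1p(-input)
    # log p_L = log h_L + sum_{i < L} log(1 - h_i), shape (batch_size, 1)
    log_p_last = torch.log(input[:, -1]) + log_one_minus_h[:, :-1].sum(dim=1)
    # Average over the batch and negate to obtain a loss to minimize
    return -log_p_last.mean()


h_demo = torch.tensor([[[0.25], [0.5], [0.5]]], requires_grad=True)
loss = event_time_loss_sketch(h_demo)
loss.backward()  # gradients flow back to the hazard rates
print(loss)  # about 1.674, i.e. -log(0.1875)
```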
# example
event_rate_loss(h)
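An event-rate loss can be sketched the same way, under the same assumption that each sequence ends at its observed event: it is the negative mean of $\log W(t_L) = \log(1 - S(t_L))$, the log-probability that the event has occurred by the final step. The name and the mean reduction are again assumptions, not the library's implementation.

```python
import torch


def event_rate_loss_sketch(input: torch.Tensor, target=None) -> torch.Tensor:
    """Negative mean log-probability that the event occurred by the last step.

    `input` holds hazard rates with shape (batch_size, seq_len, 1);
    `target` is accepted but ignored.
    """
    # log S(t_L) = sum_i log(1 - h_i), shape (batch_size, 1)
    log_s_last = torch.log1p(-input).sum(dim=1)
    # log W(t_L) = log(1 - S(t_L)), computed via expm1 for accuracy
    log_w_last = torch.log(-torch.expm1(log_s_last))
    return -log_w_last.mean()


h_demo = torch.tensor([[[0.25], [0.5], [0.5]]], requires_grad=True)
loss = event_rate_loss_sketch(h_demo)
loss.backward()
print(loss)  # about 0.2076, i.e. -log(0.8125)
```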