This module contains implementations of the traditional survival analysis functions, as well as the loss functions for uncensored data, as defined in the original DRSA paper.

Survival Analysis Functions

Following the notation used in the DRSA paper, we define the following:

  • Let $z$ be the true occurrence time for the event of interest.

  • Let $t$ be the time that a given data point was observed.

  • For each observation, there exist $L$ time slices, i.e., $0 < t_1 < t_2 < \dots < t_L$, at which we either observe the event (uncensored) or do not (censored).

  • Let $V_l = (t_{l-1}, t_l]$, for $l = 1, 2, \dots, L$ (taking $t_0 = 0$), denote the resulting disjoint intervals.

Discrete Survival function

Though it is given its own name in survival analysis, the survival function is simply $1 - \text{CDF}(z)$. In the discrete, empirical case, the survival function is estimated as follows (this is equation (5) in the paper).

$$ S(t_l) = Pr(z > t_l) = \sum_{j > l}Pr(z\in V_j) $$
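
Since surviving past $t_l$ means surviving each interval $V_1, \dots, V_l$ in turn, the survival rate can be computed directly from the predicted conditional hazards as $S(t_l) = \prod_{j \leq l}(1 - h_j)$. A minimal sketch of that computation (an assumption consistent with the ProdBackward in the example output below, not necessarily the library's exact code):

# sketch, not the library implementation
import torch

def survival_rate_sketch(h):
    # S(t_L) = prod_{l <= L} (1 - h_l): multiply the per-step
    # probabilities of surviving each observed interval.
    return torch.prod(1 - h, dim=1)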

survival_rate[source]

survival_rate(h)

Given the predicted conditional hazard rate, this function estimates the survival rate.

input:

  • h:
    • type: torch.Tensor
    • predicted conditional hazard rate at each observed time step.
    • note: h.shape == (batch_size, seq_len, 1), as this is most amenable to use in training neural nets with PyTorch.

output:

  • s:
    • type: torch.Tensor
    • estimated survival rate at time t.
    • note: s.shape == (batch_size, 1)
# example
import torch

h1 = torch.tensor([[0.001],
                   [0.5],
                   [0.55],
                   [0.15],
                   [0.15],
                   [0.15],
                   [0.15],
                   [0.9]], requires_grad=True)
h2 = torch.tensor([[0.001],
                   [0.005],
                   [0.1],
                   [0.11],
                   [0.12],
                   [0.15],
                   [0.15],
                   [0.9]], requires_grad=True)
h = torch.stack([h1, h2], dim=0)
survival_rate(h)
tensor([[0.0117],
        [0.0506]], grad_fn=<ProdBackward1>)

Discrete Event Rate function

The event rate function is calculated as $\text{CDF}(z)$. In the discrete, empirical case, it is estimated as follows (this is equation (5) in the paper).

$$ W(t_l) = Pr(z \leq t_l) = \sum_{j\leq l}Pr(z\in V_j) $$
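
Because $W(t_l) = 1 - S(t_l)$, the event rate follows directly from the survival rate. A minimal sketch (again an assumption consistent with the RsubBackward in the example output below, not necessarily the library's exact code):

# sketch, not the library implementation
import torch

def event_rate_sketch(h):
    # W(t_L) = 1 - S(t_L): one minus the product of the
    # per-step survival probabilities.
    return 1 - torch.prod(1 - h, dim=1)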

event_rate[source]

event_rate(h)

Given the predicted conditional hazard rate, this function estimates the event rate.

input:

  • h:
    • type: torch.Tensor
    • predicted conditional hazard rate at each observed time step.
    • note: h.shape == (batch_size, seq_len, 1), as this is most amenable to use in training neural nets with PyTorch.

output:

  • w:
    • type: torch.Tensor
    • estimated event rate at time t.
    • note: w.shape == (batch_size, 1)
# example
event_rate(h)
tensor([[0.9883],
        [0.9494]], grad_fn=<RsubBackward1>)

Discrete Event Time Probability function

The event time probability function is calculated as $\text{PDF}(z)$. In the discrete, empirical case, it is estimated as follows (this is equation (6) in the paper).

$$p_l = Pr(z\in V_l) = W(t_l) - W(t_{l-1}) = S(t_{l-1}) - S(t_{l})$$
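
Expanding this difference using the product form of $S$, the probability that the event lands in the final observed interval works out to $p_L = h_L \prod_{j < L}(1 - h_j)$: survive every earlier interval, then experience the event in $V_L$. A hedged sketch of that computation (consistent with the example output below, not necessarily the library's exact code):

# sketch, not the library implementation
import torch

def event_time_sketch(h):
    # p_L = prod_{j < L} (1 - h_j) * h_L: survive all earlier
    # intervals, then fail in the last one.
    return torch.prod(1 - h[:, :-1], dim=1) * h[:, -1]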

event_time[source]

event_time(h)

Given the predicted conditional hazard rate, this function estimates the probability that the event occurs at time t.

input:

  • h:
    • type: torch.Tensor
    • predicted conditional hazard rate at each observed time step.
    • note: h.shape == (batch_size, seq_len, 1), as this is most amenable to use in training neural nets with PyTorch.

output:

  • p:
    • type: torch.Tensor
    • estimated probability of event at time t.
    • note: p.shape == (batch_size, 1)
# example
event_time(h)
tensor([[0.1056],
        [0.4556]], grad_fn=<MulBackward0>)

Discrete Conditional Hazard Rate

The conditional hazard rate is the quantity which will be predicted at each time step by a recurrent survival analysis model. In the discrete, empirical case, it is estimated as follows (this is equation (7) in the paper).

$$h_l = Pr(z\in V_l | z > t_{l-1}) = \frac{Pr(z\in V_l)}{Pr(z>t_{l-1})} = \frac{p_l}{S(t_{l-1})}$$
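
Rearranging this identity gives $p_l = h_l \, S(t_{l-1})$ and $S(t_l) = (1 - h_l) \, S(t_{l-1})$, which is why every function above can be computed from the hazards alone. A quick numerical sanity check (the constant-hazard tensor below is illustrative only, not part of the library):

# illustrative sanity check
import torch

h_const = torch.full((1, 4, 1), 0.2)                               # constant hazard h_l = 0.2
s = torch.cumprod(1 - h_const, dim=1)                              # S(t_l) at each step
s_prev = torch.cat([torch.ones_like(s[:, :1]), s[:, :-1]], dim=1)  # S(t_{l-1}), with S(t_0) = 1
p = s_prev - s                                                     # p_l = S(t_{l-1}) - S(t_l)
print(torch.allclose(p / s_prev, h_const))                         # True: recovers h_l = p_l / S(t_{l-1})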

Log Survival Analysis Functions

We additionally define the log of each of the traditional survival analysis functions; these prove useful for numerical stability, since estimating the originals requires multiplying many floating-point values together.
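
Concretely, the product inside the survival rate becomes a sum in log space, $\log S(t_l) = \sum_{j \leq l} \log(1 - h_j)$, which avoids underflow. A minimal sketch (an assumption consistent with the SumBackward in the example output below, not necessarily the library's exact code):

# sketch, not the library implementation
import torch

def log_survival_rate_sketch(h):
    # log S(t_L) = sum_{l <= L} log(1 - h_l): summing logs avoids
    # underflow from multiplying many small probabilities.
    return torch.sum(torch.log(1 - h), dim=1)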

Log Survival Function

log_survival_rate[source]

log_survival_rate(h)

Given the predicted conditional hazard rate, this function estimates the log survival rate.

input:

  • h:
    • type: torch.Tensor
    • predicted conditional hazard rate at each observed time step.
    • note: h.shape == (batch_size, seq_len, 1), as this is most amenable to use in training neural nets with PyTorch.

output:

  • s:
    • type: torch.Tensor
    • estimated log survival rate at time t.
    • note: s.shape == (batch_size, 1)
# example
log_survival_rate(h)
tensor([[-4.4453],
        [-2.9834]], grad_fn=<SumBackward1>)

Log Event Rate Function

log_event_rate[source]

log_event_rate(h)

Given the predicted conditional hazard rate, this function estimates the log event rate.

input:

  • h:
    • type: torch.Tensor
    • predicted conditional hazard rate at each observed time step.
    • note: h.shape == (batch_size, seq_len, 1), as this is most amenable to use in training neural nets with PyTorch.

output:

  • w:
    • type: torch.Tensor
    • estimated log event rate at time t.
    • note: w.shape == (batch_size, 1)
# example
log_event_rate(h)
tensor([[-0.0118],
        [-0.0519]], grad_fn=<LogBackward>)

Log Event Time Function

log_event_time[source]

log_event_time(h)

Given the predicted conditional hazard rate, this function estimates the log probability that the event occurs at time t.

input:

  • h:
    • type: torch.Tensor
    • predicted conditional hazard rate at each observed time step.
    • note: h.shape == (batch_size, seq_len, 1), as this is most amenable to use in training neural nets with PyTorch.

output:

  • p:
    • type: torch.Tensor
    • estimated log probability of event at time t.
    • note: p.shape == (batch_size, 1)
# example
log_event_time(h)
tensor([[-2.2481],
        [-0.7861]], grad_fn=<AddBackward0>)

Loss Functions

Now we transform these generic survival analysis functions into loss functions that can be automatically differentiated by PyTorch, in order to train a Deep Recurrent Survival Analysis (DRSA) model.

We make a few notes below:

  1. The functions below adhere to the common pattern used across all of PyTorch's loss functions, which is to take two arguments named input and target. We note, however, that due to the nature of this survival data, the target is inherent to the data structure (each hazard sequence ends at the observed event time) and is thus unnecessary; see the training-step sketch after this list.

  2. The original DRSA paper defines three loss functions, two of which apply to uncensored data and one of which applies to censored data. This library's focus is on DRSA models using only uncensored data, so those two are the only losses we'll be defining.
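
As a hedged illustration of note 1, the losses defined below drop into a standard PyTorch training step just like the built-in loss functions. The model, optimizer, and data here are hypothetical stand-ins, and adding the two losses together is one possible combination rather than the paper's prescription:

# illustrative training step; model and data are hypothetical stand-ins
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 1), torch.nn.Sigmoid())  # emits hazards in (0, 1)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(2, 8, 4)                                  # (batch_size, seq_len, features)
pred_h = model(x)                                         # (batch_size, seq_len, 1) hazards
loss = event_time_loss(pred_h) + event_rate_loss(pred_h)  # one possible combination
opt.zero_grad()
loss.backward()
opt.step()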

Event Time Loss

event_time_loss[source]

event_time_loss(input, target=None)

Loss function applied to uncensored data in order to optimize the PDF of the true event time, z.

input:

  • input:
    • type: torch.Tensor
    • predicted conditional hazard rate at each observed time step.
    • note: input.shape == (batch_size, seq_len, 1)
  • target:
    • unused; only present to mimic PyTorch loss functions

output:

  • evt_loss:
    • type: torch.Tensor
    • Loss associated with how wrong each predicted event-time probability was at each time step.
# example
event_time_loss(h)
tensor(1.5171, grad_fn=<NegBackward>)
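
Note that this value matches the negative mean of the log_event_time(h) values computed above: $-\left((-2.2481) + (-0.7861)\right) / 2 \approx 1.5171$.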

Event Rate Loss

event_rate_loss[source]

event_rate_loss(input, target=None)

Loss function applied to uncensored data in order to optimize the CDF of the true event time, z.

input:

  • input:
    • type: torch.Tensor
    • predicted conditional hazard rate at each observed time step.
    • note: input.shape == (batch_size, seq_len, 1)
  • target:
    • unused; only present to mimic PyTorch loss functions

output:

  • evr_loss:
    • type: torch.Tensor
    • Loss associated with how the cumulative predicted probabilities differ from the ground-truth labels.
# example
event_rate_loss(h)
tensor(0.0319, grad_fn=<NegBackward>)
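
Likewise, this value matches the negative mean of the log_event_rate(h) values computed above: $-\left((-0.0118) + (-0.0519)\right) / 2 \approx 0.0319$.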