
Estimating with TwoStagesFitterExact

Introduction

The CoxPHFitter from the Python lifelines package, which is used in the first stage of TwoStagesFitter, employs Efron’s approximation of the partial likelihood function when ties are present. While Efron's method is computationally efficient for large sample sizes, it may yield biased coefficient estimates when the sample size is small.

Because of this potential bias, for datasets with up to approximately 500 observations it is recommended to use the exact method, i.e., TwoStagesFitterExact, as illustrated below. This method uses ConditionalLogit models from statsmodels to estimate the \(\beta_j\) coefficients with the exact likelihood. Because the exact likelihood is considerably more expensive to compute, particularly when many events are tied at the same time point, it is practical only for small samples. The additional model selection and screening tools that PyDTS provides for TwoStagesFitter also have corresponding "Exact" versions for small samples, which rely on TwoStagesFitterExact.
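To illustrate the difference, the sketch below writes out both likelihood contributions for a single time point at which several events are tied. It was written for this page and is not code from lifelines, statsmodels, or PyDTS: `scores` holds the relative risks \(\exp(Z_i^\top\beta)\) of the subjects at risk, and `failed` lists the indices of the subjects with an event at that time.

```python
import itertools
import math


def efron_contribution(scores, failed):
    """Efron's approximation for one time point with tied events."""
    d = len(failed)
    numerator = math.prod(scores[i] for i in failed)
    risk_total = sum(scores)
    tied_total = sum(scores[i] for i in failed)
    # The l-th tied event removes, on average, a fraction l/d of the tied
    # subjects' risk from the risk-set total.
    denominator = math.prod(risk_total - (l / d) * tied_total for l in range(d))
    return numerator / denominator


def exact_contribution(scores, failed):
    """Exact (conditional logistic) contribution: the probability that exactly
    the subjects in `failed` have the event, given that len(failed) of the
    subjects at risk do."""
    d = len(failed)
    numerator = math.prod(scores[i] for i in failed)
    # Brute force: sum over every subset of size d of the risk set.
    denominator = sum(
        math.prod(scores[i] for i in subset)
        for subset in itertools.combinations(range(len(scores)), d)
    )
    return numerator / denominator
```

With a single event both reduce to the usual Cox contribution, `scores[i] / sum(scores)`. With ties they are different functions of \(\beta\) and therefore yield different estimates. The exact denominator has `math.comb(n, d)` terms for a risk set of size `n` with `d` tied events; practical implementations avoid full enumeration with a recursion, but the cost still grows quickly with `n` and `d`, which is why the exact method is reserved for small samples.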

Data Preparation

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from pydts.examples_utils.generate_simulations_data import generate_quick_start_df
from pydts.examples_utils.plots import plot_example_pred_output
import warnings
pd.set_option("display.max_rows", 500)
warnings.filterwarnings('ignore')
# Jupyter magic; omit this line when running as a plain Python script
%matplotlib inline
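# True simulation parameters: time-dependent alpha_j(t) and covariate coefficients beta_j for event types 1 and 2
real_coef_dict = {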
    "alpha": {
        1: lambda t: -1. + 0.4 * np.log(t),
        2: lambda t: -1. + 0.4 * np.log(t),
    },
    "beta": {
        1: -0.4*np.log([0.8, 3, 3, 2.5, 2]),
        2: -0.3*np.log([1, 3, 4, 3, 2]),
    }
}

n_patients = 300
n_cov = 5
# Simulate 300 patients with 5 covariates, 4 discrete time points and 2 event types
patients_df = generate_quick_start_df(n_patients=n_patients, n_cov=n_cov, d_times=4,
                                      j_events=2, pid_col='pid', seed=0,
                                      real_coef_dict=real_coef_dict, censoring_prob=0.1)

patients_df.head()
pid Z1 Z2 Z3 Z4 Z5 J T C X
0 0 0.548814 0.715189 0.602763 0.544883 0.423655 1 2 5 2
1 1 0.645894 0.437587 0.891773 0.963663 0.383442 2 2 5 2
2 2 0.791725 0.528895 0.568045 0.925597 0.071036 1 1 5 1
3 3 0.087129 0.020218 0.832620 0.778157 0.870012 1 4 5 4
4 4 0.978618 0.799159 0.461479 0.780529 0.118274 2 1 5 1
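For later comparison with the estimates, the true \(\beta_j\) vectors used in the simulation can be computed directly. This small sketch simply re-evaluates the beta expressions in real_coef_dict above with plain Python; the names `true_beta` and `j` are ours.

```python
import math

# Same expressions as the "beta" entries of real_coef_dict above.
true_beta = {
    1: [-0.4 * math.log(v) for v in [0.8, 3, 3, 2.5, 2]],
    2: [-0.3 * math.log(v) for v in [1, 3, 4, 3, 2]],
}

for j in (1, 2):
    print(f"event type {j}:", [round(b, 4) for b in true_beta[j]])
```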

Estimation

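In the following, we apply the estimation method of Meir et al. (2022). Note that the DataFrame passed to fit() must not contain a column named 'C'. Here both 'C' and 'T' are dropped before fitting, so the fitter uses the observed time 'X' and the event type 'J'.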

from pydts.fitters import TwoStagesFitterExact
new_fitter = TwoStagesFitterExact()
new_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1))
new_fitter.print_summary()
1 2
coef std err z P>|z| [0.025 0.975] coef std err z P>|z| [0.025 0.975]
Z1 -0.2946 0.347 -0.848 0.397 -0.976 0.386 0.2337 0.321 0.728 0.467 -0.395 0.863
Z2 -0.8902 0.367 -2.427 0.015 -1.609 -0.171 -0.3483 0.341 -1.022 0.307 -1.016 0.319
Z3 -0.1380 0.348 -0.397 0.692 -0.820 0.544 -0.7829 0.333 -2.349 0.019 -1.436 -0.130
Z4 -0.4728 0.328 -1.442 0.149 -1.115 0.170 0.1123 0.310 0.362 0.718 -0.496 0.721
Z5 -0.3284 0.349 -0.941 0.347 -1.012 0.356 -0.0659 0.330 -0.200 0.842 -0.712 0.580


Model summary for event: 1

J X n_jt success alpha_jt
1 1 48 True -0.632059
1 2 32 True -0.552136
1 3 27 True -0.083812
1 4 15 True 0.056266


Model summary for event: 2

J X n_jt success alpha_jt
2 1 56 True -1.093067
2 2 43 True -0.817784
2 3 27 True -0.744723
2 4 16 True -0.545621
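The n_jt column counts the observed events of type \(j\) at time \(t\), and alpha_jt is the corresponding second-stage estimate. Below is a minimal sketch of that step, assuming (as in the two-stage approach of Meir et al.) that \(\alpha_{jt}\) is chosen so that the fitted hazards \(\mathrm{expit}(\alpha_{jt} + Z_i^\top\hat\beta_j)\), summed over the subjects still at risk at time \(t\), equal n_jt. PyDTS uses its own numerical optimizer (the success column presumably reports its convergence), so this bisection version and its helper names are illustrative only.

```python
import math


def expit(x):
    return 1.0 / (1.0 + math.exp(-x))


def solve_alpha_jt(linear_predictors, n_jt, lo=-20.0, hi=20.0, tol=1e-10):
    """Find alpha such that sum_i expit(alpha + Z_i' beta_j) == n_jt.

    linear_predictors: Z_i' beta_j for every subject still at risk at time t.
    n_jt: observed number of type-j events at t (0 < n_jt < number at risk).
    The left-hand side increases monotonically in alpha, so bisection works.
    """
    def excess(alpha):
        return sum(expit(alpha + eta) for eta in linear_predictors) - n_jt

    while hi - lo > tol:
        mid = (lo + hi) / 2.0
        if excess(mid) > 0:
            hi = mid  # too many expected events: alpha must decrease
        else:
            lo = mid
    return (lo + hi) / 2.0
```

For example, if all linear predictors are 0 and half of the risk set has the event, the solution is \(\alpha_{jt} = 0\).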

Standard Error of the Regression Coefficients

new_fitter.get_beta_SE()
1 2
coef std err z P>|z| [0.025 0.975] coef std err z P>|z| [0.025 0.975]
Z1 -0.2946 0.347 -0.848 0.397 -0.976 0.386 0.2337 0.321 0.728 0.467 -0.395 0.863
Z2 -0.8902 0.367 -2.427 0.015 -1.609 -0.171 -0.3483 0.341 -1.022 0.307 -1.016 0.319
Z3 -0.1380 0.348 -0.397 0.692 -0.820 0.544 -0.7829 0.333 -2.349 0.019 -1.436 -0.130
Z4 -0.4728 0.328 -1.442 0.149 -1.115 0.170 0.1123 0.310 0.362 0.718 -0.496 0.721
Z5 -0.3284 0.349 -0.941 0.347 -1.012 0.356 -0.0659 0.330 -0.200 0.842 -0.712 0.580
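The z, P>|z|, and confidence-interval columns follow from coef and std err through the usual normal (Wald) approximation. The sketch below reproduces the Z2 row of event type 1 up to rounding; the helper `wald_summary` is ours, not part of PyDTS or statsmodels.

```python
import math
import statistics


def wald_summary(coef, se, level=0.95):
    """Wald z statistic, two-sided p-value and confidence interval."""
    z = coef / se
    p_value = math.erfc(abs(z) / math.sqrt(2.0))  # equals 2 * (1 - Phi(|z|))
    # Standard normal quantile for the requested level (about 1.96 for 95%).
    q = statistics.NormalDist().inv_cdf(0.5 + level / 2.0)
    return z, p_value, (coef - q * se, coef + q * se)


z, p_value, (lower, upper) = wald_summary(-0.8902, 0.367)  # Z2, event type 1
print(f"z={z:.3f}  p={p_value:.3f}  CI=[{lower:.3f}, {upper:.3f}]")
```

The small difference in z (about -2.426 here versus -2.427 in the table) comes from using the rounded standard error.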

Regularization

The Exact version supports adding regularization when estimating the \(\beta_j\) coefficients. This is done by passing the fit_beta_kwargs argument to the fit() method. The added regularization term has the form $$ \lambda \cdot \left( \frac{1-w}{2}\,\|\beta\|_{2}^{2} + w\,\|\beta\|_{1} \right), $$ where \(\lambda\) is the penalizer and \(w \in [0, 1]\) is the L1 weight. In statsmodels, these are the alpha and L1_wt arguments, respectively (this alpha is unrelated to the \(\alpha_{jt}\) parameters of the model). Setting L1_wt to 1, to 0, or to a value in between gives L1, L2, or Elastic Net regularization, as follows:
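The penalty term can be evaluated directly, which helps when reasoning about the effect of alpha and L1_wt. A small sketch with a helper of our own (not a PyDTS or statsmodels function), following the formula above:

```python
def elastic_net_penalty(beta, alpha, L1_wt):
    """alpha * ((1 - L1_wt) / 2 * ||beta||_2^2 + L1_wt * ||beta||_1).

    L1_wt = 1 gives the pure L1 (lasso) penalty, L1_wt = 0 the pure L2 (ridge) penalty.
    """
    l2_squared = sum(b * b for b in beta)
    l1 = sum(abs(b) for b in beta)
    return alpha * ((1.0 - L1_wt) / 2.0 * l2_squared + L1_wt * l1)


beta = [1.0, -2.0]
for L1_wt in (1.0, 0.5, 0.0):
    print(L1_wt, elastic_net_penalty(beta, alpha=0.5, L1_wt=L1_wt))
```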

L1

L1_regularized_fitter = TwoStagesFitterExact()

fit_beta_kwargs = {
    'model_fit_kwargs': {
        # Per event type: alpha is the penalizer, L1_wt=1 gives a pure L1 (lasso) penalty
        1: {'alpha': 0.003, 'L1_wt': 1},
        2: {'alpha': 0.005, 'L1_wt': 1},
    }
}

L1_regularized_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)

L1_regularized_fitter.get_beta_SE()
1 2
Z1 -0.058633 0.000000
Z2 -0.653238 0.000000
Z3 0.000000 -0.455494
Z4 -0.290221 0.000000
Z5 -0.143967 0.000000
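The exact zeros above are characteristic of the L1 penalty, while the L2 penalty only shrinks coefficients. The one-dimensional case shows why: for a quadratic loss \(\tfrac12 (b - z)^2\) plus a penalty on \(b\), the L1 minimizer is the soft-thresholding of \(z\), whereas the L2 minimizer is a rescaling of \(z\). Coordinate-descent solvers, such as the elastic net option of statsmodels' fit_regularized, build on steps of this kind; the sketch is a textbook illustration, not that solver's code.

```python
def l1_step(z, lam):
    """Minimizer of 0.5 * (b - z)**2 + lam * abs(b): soft-thresholding."""
    if z > lam:
        return z - lam
    if z < -lam:
        return z + lam
    return 0.0  # small signals are set exactly to zero


def l2_step(z, lam):
    """Minimizer of 0.5 * (b - z)**2 + 0.5 * lam * b**2: proportional shrinkage."""
    return z / (1.0 + lam)


for z in (0.5, 0.05, -0.3):
    print(z, l1_step(z, 0.1), l2_step(z, 0.1))
```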

L2

L2_regularized_fitter = TwoStagesFitterExact()

fit_beta_kwargs = {
    'model_fit_kwargs': {
        # L1_wt=0 gives a pure L2 (ridge) penalty; alpha=0.0 leaves event type 1 unpenalized
        1: {'alpha': 0.0, 'L1_wt': 0},
        2: {'alpha': 0.002, 'L1_wt': 0},
    }
}

L2_regularized_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)

L2_regularized_fitter.get_beta_SE()
1 2
Z1 -0.294621 0.203402
Z2 -0.890082 -0.305744
Z3 -0.137424 -0.685264
Z4 -0.473069 0.096045
Z5 -0.328632 -0.074304
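The L1 and L2 examples share one layout: a 'model_fit_kwargs' entry that maps each event type to its own alpha and L1_wt. When trying several penalty values, it can be convenient to build this dictionary with a small helper. The function below is hypothetical (it is not part of PyDTS) and only reproduces the layout used in these examples.

```python
def make_fit_beta_kwargs(alpha_by_event, L1_wt):
    """Build fit_beta_kwargs in the layout used in the examples above.

    alpha_by_event: dict mapping event type -> penalizer value (alpha).
    L1_wt: 1 for L1, 0 for L2, a value in between for elastic net.
    """
    return {
        'model_fit_kwargs': {
            event: {'alpha': alpha, 'L1_wt': L1_wt}
            for event, alpha in alpha_by_event.items()
        }
    }


# Reproduces the L2 settings above: no penalty for event type 1.
l2_kwargs = make_fit_beta_kwargs({1: 0.0, 2: 0.002}, L1_wt=0)
print(l2_kwargs)
```

The result can then be passed as fit_beta_kwargs to fit().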

Elastic Net

EN_regularized_fitter = TwoStagesFitterExact()

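# Elastic net: an L1_wt strictly between 0 and 1 mixes the L1 and L2 penalties.
# As in the L1 and L2 examples, the settings go under 'model_fit_kwargs', keyed by event type.
fit_beta_kwargs = {
    'model_fit_kwargs': {
        1: {'alpha': 0.003, 'L1_wt': 0.5},
        2: {'alpha': 0.003, 'L1_wt': 0.5},
    }
}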

EN_regularized_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)

EN_regularized_fitter.get_beta_SE()
1 2
coef std err z P>|z| [0.025 0.975] coef std err z P>|z| [0.025 0.975]
Z1 -0.2946 0.347 -0.848 0.397 -0.976 0.386 0.2337 0.321 0.728 0.467 -0.395 0.863
Z2 -0.8902 0.367 -2.427 0.015 -1.609 -0.171 -0.3483 0.341 -1.022 0.307 -1.016 0.319
Z3 -0.1380 0.348 -0.397 0.692 -0.820 0.544 -0.7829 0.333 -2.349 0.019 -1.436 -0.130
Z4 -0.4728 0.328 -1.442 0.149 -1.115 0.170 0.1123 0.310 0.362 0.718 -0.496 0.721
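Z5 -0.3284 0.349 -0.941 0.347 -1.012 0.356 -0.0659 0.330 -0.200 0.842 -0.712 0.580

Note that this table is identical to the unpenalized summary in the Estimation section, including the standard errors that the L1 and L2 fits do not report, so it does not reflect an Elastic Net penalty. Penalized coefficients are obtained by passing the settings under 'model_fit_kwargs', keyed by event type, as in the L1 and L2 examples.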

Prediction

Full predictions are obtained with the predict_cumulative_incident_function() method.

Its input is a pandas.DataFrame containing, for each observation, the covariate columns used in the fit() method (Z1-Z5 in our example). In the example below, the pid column is kept as well and used to index the output.

The following columns will be added:

  1. The overall survival at each time point \(t\)
  2. The hazard for each failure type \(j\) at each time point \(t\)
  3. The probability of an event of type \(j\) at time \(t\)
  4. The cumulative incidence function (CIF) of event type \(j\) at time \(t\)
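These quantities are linked. With \(S(0)=1\), the probability of an event of type \(j\) at time \(t\) is \(\lambda_j(t)\,S(t-1)\), the overall survival is \(S(t) = S(t-1)\bigl(1-\sum_j \lambda_j(t)\bigr)\), and the CIF is the running sum of those probabilities. The sketch below applies these relations to the hazards predicted for ID=0 in the output further down and reproduces that individual's overall_survival, prob, and cif values up to rounding. The function name is ours, not part of PyDTS.

```python
def cif_from_hazards(hazards):
    """hazards: dict mapping event type j -> [lambda_j(1), ..., lambda_j(T)].

    Returns overall survival S(t), event probabilities P(T=t, J=j) and
    cumulative incidence functions CIF_j(t) for t = 1..T.
    """
    events = sorted(hazards)
    n_times = len(hazards[events[0]])
    survival = []
    prob = {j: [] for j in events}
    cif = {j: [] for j in events}
    s_prev = 1.0  # S(0) = 1: everyone is event-free before the first time point
    for t in range(n_times):
        for j in events:
            p = hazards[j][t] * s_prev  # an event at t requires surviving to t-1
            prob[j].append(p)
            cif[j].append((cif[j][-1] if cif[j] else 0.0) + p)
        s_prev *= 1.0 - sum(hazards[j][t] for j in events)
        survival.append(s_prev)
    return survival, prob, cif


# Hazards predicted for ID=0 (hazard_j1_t1 ... hazard_j2_t4 in the output).
hazards_id0 = {
    1: [0.128945, 0.138191, 0.203904, 0.227586],
    2: [0.160757, 0.201439, 0.213448, 0.248774],
}
survival, prob, cif = cif_from_hazards(hazards_id0)
print([round(s, 6) for s in survival])
```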

In the following, we provide predictions for the individuals with ID values (pid) 0, 1, and 2. The output is transposed for easier viewing.

pred_df = new_fitter.predict_cumulative_incident_function(
    patients_df.drop(['J', 'T', 'C', 'X'], axis=1).head(3)).set_index('pid').T
pred_df.index.name = ''
pred_df.columns = ['ID=0', 'ID=1', 'ID=2']
pred_df
ID=0 ID=1 ID=2
Z1 0.548814 0.645894 0.791725
Z2 0.715189 0.437587 0.528895
Z3 0.602763 0.891773 0.568045
Z4 0.544883 0.963663 0.925597
Z5 0.423655 0.383442 0.071036
overall_survival_t1 0.710298 0.718560 0.675201
overall_survival_t2 0.469059 0.481542 0.418215
overall_survival_t3 0.273296 0.285601 0.225275
overall_survival_t4 0.143109 0.152871 0.107048
hazard_j1_t1 0.128945 0.128253 0.132910
hazard_j1_t2 0.138191 0.137457 0.142394
hazard_j1_t3 0.203904 0.202903 0.209619
hazard_j1_t4 0.227586 0.226502 0.233770
hazard_j2_t1 0.160757 0.153188 0.191889
hazard_j2_t2 0.201439 0.192393 0.238213
hazard_j2_t3 0.213448 0.204001 0.251724
hazard_j2_t4 0.248774 0.238237 0.291040
prob_j1_at_t1 0.128945 0.128253 0.132910
prob_j1_at_t2 0.098157 0.098771 0.096145
prob_j1_at_t3 0.095643 0.097706 0.087666
prob_j1_at_t4 0.062198 0.064689 0.052662
prob_j2_at_t1 0.160757 0.153188 0.191889
prob_j2_at_t2 0.143082 0.138246 0.160842
prob_j2_at_t3 0.100120 0.098235 0.105275
prob_j2_at_t4 0.067989 0.068041 0.065564
cif_j1_at_t1 0.128945 0.128253 0.132910
cif_j1_at_t2 0.227102 0.227024 0.229055
cif_j1_at_t3 0.322745 0.324730 0.316721
cif_j1_at_t4 0.384944 0.389420 0.369383
cif_j2_at_t1 0.160757 0.153188 0.191889
cif_j2_at_t2 0.303839 0.291434 0.352730
cif_j2_at_t3 0.403958 0.389669 0.458005
cif_j2_at_t4 0.471948 0.457710 0.523569
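The hazards themselves are consistent with a logistic model for each event type, \(\lambda_j(t \mid Z) = \mathrm{expit}(\alpha_{jt} + Z^\top\beta_j)\), using the alpha_jt values from the model summaries and the coefficients from the Estimation section. The sketch below checks this for ID=0 at \(t=1\) with the rounded printed estimates. It is a consistency check, not the PyDTS prediction code, and it agrees with hazard_j1_t1 and hazard_j2_t1 to about four decimals.

```python
import math


def expit(x):
    return 1.0 / (1.0 + math.exp(-x))


# Covariates of ID=0 and the rounded estimates printed earlier on this page.
z_id0 = [0.548814, 0.715189, 0.602763, 0.544883, 0.423655]
beta = {
    1: [-0.2946, -0.8902, -0.1380, -0.4728, -0.3284],
    2: [0.2337, -0.3483, -0.7829, 0.1123, -0.0659],
}
alpha_t1 = {1: -0.632059, 2: -1.093067}  # alpha_jt at t = 1

hazard_t1 = {}
for j in (1, 2):
    eta = alpha_t1[j] + sum(b * z for b, z in zip(beta[j], z_id0))
    hazard_t1[j] = expit(eta)
    print(f"hazard_j{j}_t1 ~ {hazard_t1[j]:.4f}")
```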