Estimating with TwoStagesFitter

Estimation

Below we apply the estimation method of Meir et al. (2022). Note that the input dataframe must not contain a column named 'C', so the example drops it (along with 'T') before calling fit().

from pydts.fitters import TwoStagesFitter
new_fitter = TwoStagesFitter()
new_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1))
new_fitter.print_summary()
INFO: Pandarallel will run on 4 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.

covariate j1_params j1_SE j2_params j2_SE
Z1 0.187949 0.025068 0.040169 0.037807
Z2 -1.100792 0.025610 -1.100246 0.038696
Z3 -1.093466 0.025726 -1.410202 0.039280
Z4 -0.874521 0.025437 -1.097849 0.038642
Z5 -0.652655 0.025280 -0.654501 0.038179


Model summary for event: 1

J  X   n_jt  success  alpha_jt
1  1   3374  True     -0.987702
1  2   2328  True     -1.220809
1  3   1805  True     -1.358580
1  4   1524  True     -1.409997
1  5   1214  True     -1.530437
1  6   1114  True     -1.511889
1  7    916  True     -1.614043
1  8    830  True     -1.618019
1  9    683  True     -1.718359
1  10   626  True     -1.714668
1  11   569  True     -1.720344
1  12   516  True     -1.728207
1  13   419  True     -1.845399
1  14   410  True     -1.776981
1  15   326  True     -1.909345
1  16   320  True     -1.841848
1  17   280  True     -1.881339
1  18   240  True     -1.950204
1  19   243  True     -1.837087
1  20   204  True     -1.914093
1  21   176  True     -1.978425
1  22   167  True     -1.935467
1  23   166  True     -1.832599
1  24   118  True     -2.068397
1  25   114  True     -1.996911
1  26   109  True     -1.925090
1  27    89  True     -2.008449
1  28    70  True     -2.120056
1  29    67  True     -2.033129
1  30    47  True     -2.231271


Model summary for event: 2

J  X   n_jt  success  alpha_jt
2  1   1250  True     -1.737087
2  2    839  True     -1.981763
2  3    805  True     -1.881945
2  4    644  True     -1.991485
2  5    570  True     -1.998569
2  6    483  True     -2.055976
2  7    416  True     -2.099660
2  8    409  True     -2.019652
2  9    323  True     -2.150486
2  10   306  True     -2.112509
2  11   240  True     -2.250577
2  12   246  True     -2.142076
2  13   226  True     -2.132065
2  14   198  True     -2.168557
2  15   170  True     -2.215715
2  16   162  True     -2.178298
2  17   147  True     -2.178342
2  18   115  True     -2.346988
2  19   125  True     -2.151499
2  20   118  True     -2.113865
2  21    83  True     -2.380588
2  22    89  True     -2.190208
2  23    65  True     -2.421944
2  24    59  True     -2.401785
2  25    58  True     -2.318061
2  26    53  True     -2.291874
2  27    43  True     -2.373117
2  28    38  True     -2.368179
2  29    43  True     -2.115566
2  30    37  True     -2.113986

from pydts.examples_utils.plots import plot_second_model_coefs
plot_second_model_coefs(new_fitter.alpha_df, new_fitter.beta_models, new_fitter.times, n_cov=5)
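
To see how the two sets of estimates combine, the sketch below computes a single hazard from one alpha_jt and the event-specific beta vector. It assumes the logit-link discrete-time model of Meir et al. (2022), in which the hazard is the logistic function of alpha_jt plus the linear predictor. The helper name `hazard` is illustrative and is not part of pydts. The coefficients are copied from the summary above, and the covariate values are those of the individual with pid 0 shown in the Prediction section.

```python
import math


def hazard(alpha_jt, beta_j, z):
    """Logistic hazard for one event type at one time point (illustrative helper)."""
    # Linear predictor: time-specific intercept plus covariate effects
    eta = alpha_jt + sum(b * x for b, x in zip(beta_j, z))
    return 1.0 / (1.0 + math.exp(-eta))


# Estimates for event type 1, copied from the summary above
beta_j1 = [0.187949, -1.100792, -1.093466, -0.874521, -0.652655]
alpha_j1_t1 = -0.987702

# Covariates Z1-Z5 of the individual with pid 0
z_pid0 = [0.548814, 0.715189, 0.602763, 0.544883, 0.423655]

h = hazard(alpha_j1_t1, beta_j1, z_pid0)
print(f"hazard for event 1 at t=1: {h:.6f}")
```

Under this assumption the result should agree, up to rounding, with the hazard_j1_t1 value reported for ID=0 in the prediction output.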

Standard Error of the Regression Coefficients

new_fitter.get_beta_SE()
covariate j1_params j1_SE j2_params j2_SE
Z1 0.187949 0.025068 0.040169 0.037807
Z2 -1.100792 0.025610 -1.100246 0.038696
Z3 -1.093466 0.025726 -1.410202 0.039280
Z4 -0.874521 0.025437 -1.097849 0.038642
Z5 -0.652655 0.025280 -0.654501 0.038179
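
One common use of these standard errors is a Wald-type confidence interval. The sketch below assumes the estimators are approximately normal, which is an assumption made here for illustration and is not stated on this page. `wald_ci` is a hypothetical helper, not a pydts function.

```python
def wald_ci(coef, se, z=1.96):
    """Return (z-score, lower, upper) for an approximate 95% Wald interval."""
    return coef / se, coef - z * se, coef + z * se


# (coefficient, standard error) pairs for Z1, copied from the table above
z1_j1 = (0.187949, 0.025068)
z1_j2 = (0.040169, 0.037807)

for name, (coef, se) in [("Z1, event 1", z1_j1), ("Z1, event 2", z1_j2)]:
    score, lower, upper = wald_ci(coef, se)
    print(f"{name}: z={score:.2f}, CI=({lower:.3f}, {upper:.3f})")
```

Under the normality assumption, an interval that contains zero (as for Z1 and event 2 here) means the coefficient cannot be distinguished from zero at the 5% level.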

Regularization

Regularization can be added when estimating the beta coefficients. It is applied through the penalizer and l1_ratio arguments of the lifelines CoxPHFitter, which are passed via the fit_beta_kwargs argument of the fit() method. The added regularization term has the form:

$$ \mbox{Penalizer} \cdot \Bigg( \frac{1-\mbox{L1_ratio}}{2}||\beta||_{2}^{2} + \mbox{L1_ratio} ||\beta||_1 \Bigg) $$

Examples of L1, L2 and Elastic Net regularization follow.
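
As a quick numerical illustration of the penalty term, the sketch below evaluates the formula directly for the three l1_ratio settings used in the examples. The function `elastic_net_penalty` and the beta values are made up for illustration. How lifelines scales or smooths this term internally is not covered here.

```python
def elastic_net_penalty(beta, penalizer, l1_ratio):
    """Evaluate penalizer * ((1 - l1_ratio) / 2 * ||beta||_2^2 + l1_ratio * ||beta||_1)."""
    l2_squared = sum(b * b for b in beta)
    l1_norm = sum(abs(b) for b in beta)
    return penalizer * ((1.0 - l1_ratio) / 2.0 * l2_squared + l1_ratio * l1_norm)


beta = [0.2, -1.1, -1.1, -0.9, -0.7]  # illustrative values, not fitted estimates

for label, ratio in [("L1", 1.0), ("L2", 0.0), ("Elastic Net", 0.5)]:
    value = elastic_net_penalty(beta, penalizer=0.003, l1_ratio=ratio)
    print(f"{label}: {value:.6f}")
```

With l1_ratio set to 1 only the absolute-value term remains, with 0 only the squared term remains, and intermediate values mix the two.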

L1

L1_regularized_fitter = TwoStagesFitter()

fit_beta_kwargs = {
    'model_kwargs': {
        'penalizer': 0.003,
        'l1_ratio': 1
    }
}

L1_regularized_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)

L1_regularized_fitter.get_beta_SE()
INFO: Pandarallel will run on 4 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.

covariate j1_params j1_SE j2_params j2_SE
Z1 0.000002 0.000102 5.690226e-08 0.000041
Z2 -0.774487 0.025401 -3.574822e-01 0.038251
Z3 -0.762942 0.025533 -6.516077e-01 0.038510
Z4 -0.552172 0.025318 -3.590965e-01 0.038235
Z5 -0.340120 0.025211 -1.435430e-06 0.000132

L2

L2_regularized_fitter = TwoStagesFitter()

fit_beta_kwargs = {
    'model_kwargs': {
        'penalizer': 0.003,
        'l1_ratio': 0
    }
}

L2_regularized_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)

L2_regularized_fitter.get_beta_SE()
INFO: Pandarallel will run on 4 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.

covariate j1_params j1_SE j2_params j2_SE
Z1 0.172957 0.024069 0.032774 0.034626
Z2 -1.007262 0.024506 -0.903957 0.035205
Z3 -1.000509 0.024629 -1.162132 0.035589
Z4 -0.799488 0.024384 -0.903531 0.035177
Z5 -0.597079 0.024255 -0.537159 0.034911

Elastic Net

EN_regularized_fitter = TwoStagesFitter()

fit_beta_kwargs = {
    'model_kwargs': {
        'penalizer': 0.003,
        'l1_ratio': 0.5
    }
}

EN_regularized_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)

EN_regularized_fitter.get_beta_SE()
INFO: Pandarallel will run on 4 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.

covariate j1_params j1_SE j2_params j2_SE
Z1 0.039322 0.024542 0.000001 0.000190
Z2 -0.895581 0.024938 -0.654614 0.036595
Z3 -0.886332 0.025065 -0.928867 0.036941
Z4 -0.680998 0.024832 -0.655263 0.036573
Z5 -0.473818 0.024711 -0.265356 0.036382

Separated Penalty Coefficients

The methods above can also use a separate penalty coefficient for each covariate: pass a vector whose length equals the number of covariates to the penalizer keyword instead of a scalar. For example, L2 regularization can be applied only to covariates Z1 and Z2 as follows:

import numpy as np  # needed for the per-covariate penalty vector

L2_regularized_fitter = TwoStagesFitter()

fit_beta_kwargs = {
    'model_kwargs': {
        'penalizer': np.array([0.04, 0.04, 0, 0, 0]),
        'l1_ratio': 0
    }
}

L2_regularized_fitter.fit(df=patients_df.drop(['C', 'T'], axis=1), fit_beta_kwargs=fit_beta_kwargs)

L2_regularized_fitter.get_beta_SE()
INFO: Pandarallel will run on 4 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.

covariate j1_params j1_SE j2_params j2_SE
Z1 0.088314 0.017178 0.011120 0.020019
Z2 -0.515292 0.017378 -0.306194 0.020269
Z3 -1.069182 0.025695 -1.374391 0.039205
Z4 -0.853807 0.025419 -1.066715 0.038602
Z5 -0.641989 0.025272 -0.637811 0.038161
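
Writing the penalty vector by hand becomes error-prone with many covariates. The sketch below builds it from covariate names. `penalizer_vector` is a hypothetical helper, and it assumes that the vector entries must follow the order of the covariate columns in the dataframe passed to fit().

```python
def penalizer_vector(covariates, penalties, default=0.0):
    """Return one penalty per covariate, in the order given by `covariates`."""
    unknown = set(penalties) - set(covariates)
    if unknown:
        raise ValueError(f"Unknown covariates: {sorted(unknown)}")
    return [penalties.get(name, default) for name in covariates]


covariates = ["Z1", "Z2", "Z3", "Z4", "Z5"]

# Penalize only Z1 and Z2; every other covariate gets the default of 0
vector = penalizer_vector(covariates, {"Z1": 0.04, "Z2": 0.04})
print(vector)  # [0.04, 0.04, 0.0, 0.0, 0.0]
```

The resulting list can be wrapped in np.array(...) and passed as the 'penalizer' entry of model_kwargs, as in the example above.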

Prediction

Full prediction is given by the method predict_cumulative_incident_function().

The input is a pandas.DataFrame containing, for each observation, the covariate columns that were used in the fit() method (Z1-Z5 in our example).

The following columns are added to the output:

  1. The overall survival at each time point \(t\)
  2. The hazard for each failure type \(j\) at each time point \(t\)
  3. The probability of event type \(j\) at time \(t\)
  4. The Cumulative Incidence Function (CIF) of event type \(j\) at time \(t\)
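
The four quantities are tied together by the usual discrete-time competing-risks relations, which the sketch below assumes: overall survival is the running product of one minus the summed hazards, the probability of event \(j\) at \(t\) is its hazard times the survival up to \(t-1\), and the CIF is the running sum of those probabilities. The function `predict_from_hazards` is written for illustration and is not the pydts implementation. The hazards are the first three values for ID=0, copied from the output table below.

```python
def predict_from_hazards(hazards):
    """hazards: {event_type: [hazard at t=1, t=2, ...]} for one individual."""
    n_times = len(next(iter(hazards.values())))
    survival = []
    prob = {j: [] for j in hazards}
    cif = {j: [] for j in hazards}
    prev_survival = 1.0  # everyone is event-free before the first time point
    for t in range(n_times):
        for j, h in hazards.items():
            p = h[t] * prev_survival  # event j at t requires surviving up to t-1
            prob[j].append(p)
            cif[j].append(p + (cif[j][-1] if cif[j] else 0.0))
        prev_survival *= 1.0 - sum(h[t] for h in hazards.values())
        survival.append(prev_survival)
    return survival, prob, cif


# First three hazards per event type for ID=0
hazards_id0 = {
    1: [0.043775, 0.034991, 0.030625],
    2: [0.014380, 0.011294, 0.012465],
}
survival, prob, cif = predict_from_hazards(hazards_id0)
for t in range(3):
    print(t + 1, f"{survival[t]:.6f}", f"{cif[1][t]:.6f}", f"{cif[2][t]:.6f}")
```

The printed values can be compared with the overall_survival, cif_j1 and cif_j2 rows for ID=0 in the output table. Small differences in the last digit are expected because the input hazards are rounded.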

Below we provide predictions for the individuals with ID values (pid) 0, 1 and 2. The output is transposed for easier viewing.

pred_df = new_fitter.predict_cumulative_incident_function(
    patients_df.drop(['J', 'T', 'C', 'X'], axis=1).head(3)).set_index('pid').T
pred_df.index.name = ''
pred_df.columns = ['ID=0', 'ID=1', 'ID=2']
plot_example_pred_output(pred_df)
pred_df
ID=0 ID=1 ID=2
Z1 0.548814 0.645894 0.791725
Z2 0.715189 0.437587 0.528895
Z3 0.602763 0.891773 0.568045
Z4 0.544883 0.963663 0.925597
Z5 0.423655 0.383442 0.071036
overall_survival_t1 0.941845 0.959704 0.932244
overall_survival_t2 0.898252 0.928975 0.881883
overall_survival_t3 0.859546 0.901548 0.837705
overall_survival_t4 0.824888 0.876595 0.798379
overall_survival_t5 0.794348 0.854430 0.764029
overall_survival_t6 0.765050 0.832843 0.731231
overall_survival_t7 0.739087 0.813536 0.702371
overall_survival_t8 0.713461 0.794332 0.674101
overall_survival_t9 0.691252 0.777486 0.649717
overall_survival_t10 0.669427 0.760785 0.625897
overall_survival_t11 0.649221 0.745080 0.603897
overall_survival_t12 0.629094 0.729356 0.582149
overall_survival_t13 0.610989 0.715148 0.562735
overall_survival_t14 0.592808 0.700678 0.543291
overall_survival_t15 0.576893 0.687938 0.526379
overall_survival_t16 0.560503 0.674682 0.509031
overall_survival_t17 0.544986 0.662036 0.492696
overall_survival_t18 0.531334 0.650765 0.478355
overall_survival_t19 0.516058 0.638087 0.462409
overall_survival_t20 0.501756 0.626164 0.447587
overall_survival_t21 0.489552 0.615813 0.434935
overall_survival_t22 0.476519 0.604735 0.421524
overall_survival_t23 0.463848 0.593720 0.408451
overall_survival_t24 0.453297 0.584569 0.397667
overall_survival_t25 0.442188 0.574864 0.386360
overall_survival_t26 0.430726 0.564748 0.374731
overall_survival_t27 0.420435 0.555592 0.364334
overall_survival_t28 0.411084 0.547251 0.354946
overall_survival_t29 0.400506 0.537811 0.344410
overall_survival_t30 0.391319 0.529622 0.335339
hazard_j1_t1 0.043775 0.031795 0.052250
hazard_j1_t10 0.021649 0.015626 0.025957
hazard_j1_t11 0.021529 0.015539 0.025814
hazard_j1_t12 0.021364 0.015419 0.025617
hazard_j1_t13 0.019047 0.013737 0.022849
hazard_j1_t14 0.020368 0.014696 0.024427
hazard_j1_t15 0.017888 0.012897 0.021464
hazard_j1_t16 0.019113 0.013785 0.022928
hazard_j1_t17 0.018387 0.013259 0.022060
hazard_j1_t18 0.017184 0.012387 0.020622
hazard_j1_t19 0.019202 0.013850 0.023035
hazard_j1_t2 0.034991 0.025352 0.041840
hazard_j1_t20 0.017805 0.012837 0.021364
hazard_j1_t21 0.016714 0.012047 0.020060
hazard_j1_t22 0.017435 0.012569 0.020922
hazard_j1_t23 0.019287 0.013912 0.023136
hazard_j1_t24 0.015298 0.011022 0.018365
hazard_j1_t25 0.016413 0.011829 0.019700
hazard_j1_t26 0.017613 0.012698 0.021135
hazard_j1_t27 0.016227 0.011695 0.019478
hazard_j1_t28 0.014539 0.010472 0.017457
hazard_j1_t29 0.015838 0.011413 0.019012
hazard_j1_t3 0.030625 0.022161 0.036652
hazard_j1_t30 0.013028 0.009381 0.015648
hazard_j1_t4 0.029135 0.021074 0.034880
hazard_j1_t5 0.025915 0.018728 0.031045
hazard_j1_t6 0.026387 0.019071 0.031608
hazard_j1_t7 0.023886 0.017251 0.028626
hazard_j1_t8 0.023794 0.017184 0.028516
hazard_j1_t9 0.021571 0.015569 0.025864
hazard_j2_t1 0.014380 0.008500 0.015506
hazard_j2_t10 0.009924 0.005855 0.010704
hazard_j2_t11 0.008655 0.005104 0.009337
hazard_j2_t12 0.009637 0.005686 0.010396
hazard_j2_t13 0.009733 0.005743 0.010499
hazard_j2_t14 0.009388 0.005538 0.010127
hazard_j2_t15 0.008959 0.005284 0.009665
hazard_j2_t16 0.009298 0.005485 0.010030
hazard_j2_t17 0.009297 0.005484 0.010029
hazard_j2_t18 0.007866 0.004637 0.008486
hazard_j2_t19 0.009548 0.005633 0.010299
hazard_j2_t2 0.011294 0.006668 0.012181
hazard_j2_t20 0.009910 0.005848 0.010690
hazard_j2_t21 0.007608 0.004485 0.008208
hazard_j2_t22 0.009189 0.005420 0.009912
hazard_j2_t23 0.007302 0.004304 0.007878
hazard_j2_t24 0.007450 0.004391 0.008037
hazard_j2_t25 0.008095 0.004773 0.008733
hazard_j2_t26 0.008308 0.004899 0.008963
hazard_j2_t27 0.007665 0.004518 0.008269
hazard_j2_t28 0.007702 0.004540 0.008310
hazard_j2_t29 0.009894 0.005838 0.010672
hazard_j2_t3 0.012465 0.007363 0.013443
hazard_j2_t30 0.009909 0.005847 0.010689
hazard_j2_t4 0.011186 0.006604 0.012065
hazard_j2_t5 0.011108 0.006557 0.011981
hazard_j2_t6 0.010495 0.006194 0.011320
hazard_j2_t7 0.010051 0.005931 0.010841
hazard_j2_t8 0.010879 0.006422 0.011734
hazard_j2_t9 0.009558 0.005638 0.010310
prob_j1_at_t1 0.043775 0.031795 0.052250
prob_j1_at_t2 0.032956 0.024330 0.039005
prob_j1_at_t3 0.027509 0.020587 0.032323
prob_j1_at_t4 0.025043 0.018999 0.029219
prob_j1_at_t5 0.021377 0.016416 0.024785
prob_j1_at_t6 0.020961 0.016295 0.024149
prob_j1_at_t7 0.018274 0.014368 0.020932
prob_j1_at_t8 0.017586 0.013980 0.020029
prob_j1_at_t9 0.015390 0.012367 0.017435
prob_j1_at_t10 0.014965 0.012149 0.016865
prob_j1_at_t11 0.014412 0.011822 0.016157
prob_j1_at_t12 0.013870 0.011488 0.015470
prob_j1_at_t13 0.011982 0.010019 0.013301
prob_j1_at_t14 0.012445 0.010510 0.013746
prob_j1_at_t15 0.010604 0.009037 0.011661
prob_j1_at_t16 0.011026 0.009483 0.012069
prob_j1_at_t17 0.010306 0.008945 0.011229
prob_j1_at_t18 0.009365 0.008201 0.010160
prob_j1_at_t19 0.010203 0.009013 0.011019
prob_j1_at_t20 0.009188 0.008191 0.009879
prob_j1_at_t21 0.008386 0.007543 0.008978
prob_j1_at_t22 0.008535 0.007740 0.009100
prob_j1_at_t23 0.009191 0.008413 0.009752
prob_j1_at_t24 0.007096 0.006544 0.007501
prob_j1_at_t25 0.007440 0.006915 0.007834
prob_j1_at_t26 0.007788 0.007300 0.008166
prob_j1_at_t27 0.006990 0.006604 0.007299
prob_j1_at_t28 0.006113 0.005818 0.006360
prob_j1_at_t29 0.006511 0.006246 0.006748
prob_j1_at_t30 0.005218 0.005045 0.005389
prob_j2_at_t1 0.014380 0.008500 0.015506
prob_j2_at_t2 0.010637 0.006399 0.011356
prob_j2_at_t3 0.011197 0.006840 0.011855
prob_j2_at_t4 0.009615 0.005954 0.010107
prob_j2_at_t5 0.009163 0.005748 0.009565
prob_j2_at_t6 0.008337 0.005292 0.008649
prob_j2_at_t7 0.007689 0.004939 0.007928
prob_j2_at_t8 0.008040 0.005224 0.008241
prob_j2_at_t9 0.006819 0.004479 0.006950
prob_j2_at_t10 0.006860 0.004552 0.006955
prob_j2_at_t11 0.005794 0.003883 0.005844
prob_j2_at_t12 0.006257 0.004236 0.006278
prob_j2_at_t13 0.006123 0.004188 0.006112
prob_j2_at_t14 0.005736 0.003961 0.005699
prob_j2_at_t15 0.005311 0.003703 0.005251
prob_j2_at_t16 0.005364 0.003773 0.005279
prob_j2_at_t17 0.005211 0.003700 0.005105
prob_j2_at_t18 0.004287 0.003070 0.004181
prob_j2_at_t19 0.005073 0.003666 0.004927
prob_j2_at_t20 0.005114 0.003731 0.004943
prob_j2_at_t21 0.003817 0.002808 0.003674
prob_j2_at_t22 0.004498 0.003338 0.004311
prob_j2_at_t23 0.003480 0.002603 0.003321
prob_j2_at_t24 0.003455 0.002607 0.003283
prob_j2_at_t25 0.003669 0.002790 0.003473
prob_j2_at_t26 0.003674 0.002816 0.003463
prob_j2_at_t27 0.003301 0.002552 0.003099
prob_j2_at_t28 0.003238 0.002523 0.003027
prob_j2_at_t29 0.004067 0.003195 0.003788
prob_j2_at_t30 0.003969 0.003144 0.003681
cif_j1_at_t1 0.043775 0.031795 0.052250
cif_j1_at_t2 0.076731 0.056126 0.091255
cif_j1_at_t3 0.104240 0.076713 0.123578
cif_j1_at_t4 0.129283 0.095712 0.152797
cif_j1_at_t5 0.150660 0.112129 0.177582
cif_j1_at_t6 0.171621 0.128424 0.201731
cif_j1_at_t7 0.189895 0.142791 0.222664
cif_j1_at_t8 0.207480 0.156771 0.242692
cif_j1_at_t9 0.222870 0.169138 0.260127
cif_j1_at_t10 0.237835 0.181287 0.276992
cif_j1_at_t11 0.252248 0.193109 0.293148
cif_j1_at_t12 0.266118 0.204597 0.308618
cif_j1_at_t13 0.278100 0.214616 0.321919
cif_j1_at_t14 0.290544 0.225126 0.335665
cif_j1_at_t15 0.301148 0.234163 0.347326
cif_j1_at_t16 0.312174 0.243646 0.359395
cif_j1_at_t17 0.322480 0.252591 0.370624
cif_j1_at_t18 0.331845 0.260792 0.380785
cif_j1_at_t19 0.342048 0.269805 0.391804
cif_j1_at_t20 0.351236 0.277996 0.401683
cif_j1_at_t21 0.359623 0.285540 0.410661
cif_j1_at_t22 0.368158 0.293280 0.419761
cif_j1_at_t23 0.377348 0.301693 0.429513
cif_j1_at_t24 0.384444 0.308236 0.437014
cif_j1_at_t25 0.391884 0.315151 0.444848
cif_j1_at_t26 0.399673 0.322451 0.453014
cif_j1_at_t27 0.406662 0.329055 0.460313
cif_j1_at_t28 0.412775 0.334874 0.466673
cif_j1_at_t29 0.419285 0.341119 0.473422
cif_j1_at_t30 0.424503 0.346164 0.478811
cif_j2_at_t1 0.014380 0.008500 0.015506
cif_j2_at_t2 0.025018 0.014900 0.026862
cif_j2_at_t3 0.036214 0.021739 0.038717
cif_j2_at_t4 0.045829 0.027693 0.048824
cif_j2_at_t5 0.054992 0.033441 0.058389
cif_j2_at_t6 0.063329 0.038733 0.067038
cif_j2_at_t7 0.071018 0.043673 0.074965
cif_j2_at_t8 0.079059 0.048897 0.083207
cif_j2_at_t9 0.085878 0.053376 0.090156
cif_j2_at_t10 0.092737 0.057928 0.097111
cif_j2_at_t11 0.098531 0.061811 0.102955
cif_j2_at_t12 0.104788 0.066048 0.109233
cif_j2_at_t13 0.110912 0.070236 0.115345
cif_j2_at_t14 0.116647 0.074197 0.121044
cif_j2_at_t15 0.121959 0.077899 0.126295
cif_j2_at_t16 0.127322 0.081672 0.131574
cif_j2_at_t17 0.132534 0.085372 0.136679
cif_j2_at_t18 0.136820 0.088442 0.140860
cif_j2_at_t19 0.141894 0.092108 0.145787
cif_j2_at_t20 0.147008 0.095839 0.150730
cif_j2_at_t21 0.150825 0.098647 0.154404
cif_j2_at_t22 0.155324 0.101985 0.158715
cif_j2_at_t23 0.158803 0.104588 0.162036
cif_j2_at_t24 0.162259 0.107195 0.165319
cif_j2_at_t25 0.165928 0.109985 0.168792
cif_j2_at_t26 0.169602 0.112801 0.172254
cif_j2_at_t27 0.172903 0.115352 0.175353
cif_j2_at_t28 0.176141 0.117875 0.178380
cif_j2_at_t29 0.180208 0.121070 0.182168
cif_j2_at_t30 0.184177 0.124214 0.185850