
Performance Measures

Models can be evaluated on held-out test data or by cross-validation (CV) using the evaluation functions available in PyDTS and the performance measures presented in the Methods section.

import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
from pydts.examples_utils.generate_simulations_data import generate_quick_start_df
import warnings
pd.set_option("display.max_rows", 500)
warnings.filterwarnings('ignore')
%matplotlib inline

real_coef_dict = {
    "alpha": {
        1: lambda t: -1 - 0.3 * np.log(t),
        2: lambda t: -1.75 - 0.15 * np.log(t)
    },
    "beta": {
        1: -np.log([0.8, 3, 3, 2.5, 2]),
        2: -np.log([1, 3, 4, 3, 2])
    }
}

n_patients = 50000
n_cov = 5

patients_df = generate_quick_start_df(n_patients=n_patients, n_cov=n_cov, d_times=30, j_events=2, 
                                      pid_col='pid', seed=0, censoring_prob=0.8, 
                                      real_coef_dict=real_coef_dict)

train_df, test_df = train_test_split(patients_df, test_size=0.2)

patients_df.head()
pid Z1 Z2 Z3 Z4 Z5 J T C X
0 0 0.548814 0.715189 0.602763 0.544883 0.423655 0 31 10 10
1 1 0.645894 0.437587 0.891773 0.963663 0.383442 0 31 24 24
2 2 0.791725 0.528895 0.568045 0.925597 0.071036 0 17 11 11
3 3 0.087129 0.020218 0.832620 0.778157 0.870012 1 1 31 1
4 4 0.978618 0.799159 0.461479 0.780529 0.118274 0 15 14 14
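The printed rows suggest how the simulated columns relate to each other. Z1 to Z5 are the covariates. T appears to be the simulated event time, C the censoring time, X the observed time, and J the observed event type, with 0 marking a censored row. This reading is inferred from the five rows above and is not taken from the generate_quick_start_df documentation. The sketch below checks it on those rows only. Because T and C are simulation ground truth that would not be available in real data, this is presumably why the cross-validation call further down drops them before fitting.

```python
# Rows copied from patients_df.head() above, as (pid, J, T, C, X).
# Assumed meaning (inferred from the printed rows, not documented here):
# T = simulated event time, C = censoring time, X = observed time,
# J = observed event type (0 = censored).
rows = [
    (0, 0, 31, 10, 10),
    (1, 0, 31, 24, 24),
    (2, 0, 17, 11, 11),
    (3, 1, 1, 31, 1),
    (4, 0, 15, 14, 14),
]


def observed_time(t, c):
    """Observed time is the earlier of the event time and the censoring time."""
    return min(t, c)


def is_censored(t, c):
    """Censored when censoring happens strictly before the event (tie rule is an assumption)."""
    return c < t


checks = []
for pid, j, t, c, x in rows:
    x_ok = observed_time(t, c) == x
    j_ok = (j == 0) == is_censored(t, c)
    checks.append((pid, x_ok, j_ok))

for pid, x_ok, j_ok in checks:
    print(f"pid={pid}: X == min(T, C): {x_ok}, J == 0 iff C < T: {j_ok}")
```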

For example, in the following code the survival models are estimated with the two-stage approach on the dataset train_df. Assuming that the event of main interest is \(j=1\), the predicted probabilities \(\pi_{i1}(t)\) for the subjects in test_df are calculated and stored in pred_df, and \(\widehat{\mbox{AUC}}_1(t)\), \(t=1,\ldots,d\), is then obtained by

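from pydts.fitters import TwoStagesFitter
from pydts.evaluation import (event_specific_auc_at_t_all, event_specific_integrated_auc,
                              event_specific_integrated_brier_score, global_auc,
                              global_brier_score)

# C (censoring time) and T (event time) are simulation ground truth; drop them
# so that they are not used as covariates when fitting
fitter = TwoStagesFitter()
fitter.fit(df=train_df.drop(['C', 'T'], axis=1))
# Predicted probabilities of event j=1 for every test subject and time point
pred_df = fitter.predict_prob_event_j_all(test_df, event=1)
auc_1 = event_specific_auc_at_t_all(pred_df, event=1)
print('AUC(t) for event 1 is:')
auc_1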
INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.
AUC(t) for event 1 is:

1     0.988007
2     0.989981
3     0.990528
4     0.992057
5     0.991414
6     0.992463
7     0.992465
8     0.993415
9     0.993273
10    0.992431
11    0.994390
12    0.995089
13    0.992101
14    0.993462
15    0.995626
16    0.994383
17    0.994551
18    0.992995
19    0.993614
20    0.993318
21    0.994786
22    0.992859
23    0.994097
24    0.997663
25    0.998358
26    0.997495
27    0.997700
28    0.998260
29    0.994920
30    0.988650
Name: 1, dtype: float64
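To make the meaning of an event-specific AUC(t) concrete, here is a minimal from-scratch sketch. It assumes a common discrete-time definition rather than PyDTS's exact estimator: at time t, the cases are subjects observed to have event j at t (X_i = t, J_i = j), the controls are subjects still event-free after t (X_m > t), and the AUC is the fraction of case-control pairs in which the case has the higher predicted score, with ties counted as one half. The helper event_specific_auc_at_t and the toy data are hypothetical and are unrelated to the PyDTS function event_specific_auc_at_t_all.

```python
def event_specific_auc_at_t(scores, X, J, event, t):
    """Sketch of an event-specific AUC at time t (assumed definition).

    scores[i] is a predicted risk score for `event` at time t for subject i.
    Cases: X_i == t and J_i == event. Controls: X_m > t (still event-free after t).
    Ties between a case and a control count as one half.
    """
    cases = [s for s, x, j in zip(scores, X, J) if x == t and j == event]
    controls = [s for s, x in zip(scores, X) if x > t]
    if not cases or not controls:
        return float("nan")  # undefined without at least one case and one control
    concordant = 0.0
    for sc in cases:
        for sm in controls:
            if sc > sm:
                concordant += 1.0
            elif sc == sm:
                concordant += 0.5
    return concordant / (len(cases) * len(controls))


# Toy data (hypothetical, not from the simulation above).
scores = [0.9, 0.4, 0.7, 0.2, 0.4, 0.1]
X = [2, 2, 3, 5, 4, 2]
J = [1, 1, 2, 0, 1, 0]

auc_1_at_2 = event_specific_auc_at_t(scores, X, J, event=1, t=2)
print(auc_1_at_2)  # 0.75
```

In the toy data, the subject censored at t = 2 (last entry) is neither a case nor a control, and the subject with event 2 at t = 3 still serves as a control at t = 2 because it is event-free through t = 2.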

Other measures, such as the integrated event-specific measures \(\widehat{\mbox{AUC}}_1\) and \(\widehat{\mbox{BS}}_1\) and the global measures \(\widehat{\mbox{AUC}}\) and \(\widehat{\mbox{BS}}\), can be calculated by

pred_df = fitter.predict_prob_events(test_df)
ibs_1 = event_specific_integrated_brier_score(pred_df, event = 1)
iauc_1 = event_specific_integrated_auc(pred_df, event = 1)
bs = global_brier_score(pred_df)
auc = global_auc(pred_df)
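The integrated measures summarize a per-time measure over t. As a rough illustration only, the sketch below computes a simplified event-specific Brier score at time t over the subjects still at risk at t, and then averages it over time. It deliberately omits any censoring weighting that the PyDTS estimator may apply, treats a subject censored at t as event-free at t, and weights each time point by the number of observed event-1 cases there. All three choices are assumptions; the actual estimators are those presented in the Methods section. The helper names and data are hypothetical.

```python
def event_specific_brier_at_t(probs, X, J, event, t):
    """Simplified event-specific Brier score at time t (assumed form, no censoring weights).

    Only subjects still at risk at t (X_i >= t) contribute. The observed outcome is
    1 if the subject had `event` exactly at t, else 0 (a subject censored at t is
    treated as event-free at t). probs[i] is the predicted probability of `event`
    at t for subject i.
    """
    terms = []
    for p, x, j in zip(probs, X, J):
        if x >= t:
            d = 1.0 if (x == t and j == event) else 0.0
            terms.append((d - p) ** 2)
    return sum(terms) / len(terms) if terms else float("nan")


def integrate_over_time(values_by_t, weights_by_t):
    """Weighted average of a per-time measure; the choice of weights is an assumption."""
    total = sum(weights_by_t[t] for t in values_by_t)
    return sum(values_by_t[t] * weights_by_t[t] for t in values_by_t) / total


# Toy data (hypothetical): predicted probabilities of event 1 at each time point.
X = [1, 2, 2, 3]
J = [1, 1, 0, 2]
probs_by_t = {
    1: [0.5, 0.1, 0.2, 0.1],
    2: [0.0, 0.6, 0.3, 0.2],
}

bs_by_t = {t: event_specific_brier_at_t(p, X, J, event=1, t=t) for t, p in probs_by_t.items()}
# Assumed weights: number of observed event-1 cases at each t.
weights = {t: sum(1 for x, j in zip(X, J) if x == t and j == 1) for t in probs_by_t}
ibs_1 = integrate_over_time(bs_by_t, weights)
print(bs_by_t, ibs_1)
```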

Model evaluation based on K-fold CV of the TwoStagesFitter can be done with TwoStagesCV by

from pydts.cross_validation import TwoStagesCV

cross_validator = TwoStagesCV()
# C and T are dropped so that only Z1-Z5 are used as covariates
cross_validator.cross_validate(full_df=patients_df.drop(['C', 'T'], axis=1),
                               n_splits=5, seed=0,
                               metrics=['BS', 'IBS', 'GBS', 'AUC', 'IAUC', 'GAUC'])
Fitting fold 1/5
INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.
Fitting fold 2/5
INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.
Fitting fold 3/5
INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.
Fitting fold 4/5
INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.
Fitting fold 5/5
INFO: Pandarallel will run on 8 workers.
INFO: Pandarallel will use standard multiprocessing data transfer (pipe) to transfer data between the main process and workers.

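TwoStagesCV performs the fold splitting internally. To show what a K-fold split over patients involves, here is a generic shuffled K-fold sketch using only the standard library. kfold_patient_splits is a hypothetical helper, and TwoStagesCV's actual assignment of patients to folds (even for the same seed) will very likely differ. Folds are numbered from 0, as in the results tables.

```python
import random


def kfold_patient_splits(pids, n_splits=5, seed=0):
    """Yield (fold, train_ids, test_ids) so that each patient is in exactly one test fold.

    Hypothetical helper: a generic shuffled K-fold split over patient ids,
    not the splitting logic used by TwoStagesCV.
    """
    ids = list(pids)
    random.Random(seed).shuffle(ids)
    n = len(ids)
    # Contiguous chunks whose sizes differ by at most one
    fold_sizes = [n // n_splits + (1 if i < n % n_splits else 0) for i in range(n_splits)]
    start = 0
    for fold, size in enumerate(fold_sizes):
        test_ids = ids[start:start + size]
        train_ids = ids[:start] + ids[start + size:]
        start += size
        yield fold, train_ids, test_ids


splits = list(kfold_patient_splits(range(12), n_splits=5, seed=0))
for fold, train_ids, test_ids in splits:
    print(f"fold {fold}: {len(train_ids)} train, {len(test_ids)} test -> {sorted(test_ids)}")
```

Each model is fitted on the training patients of a fold and evaluated on its held-out patients, which is what produces one row of results per fold.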

The cross-validation results, AUC(t) and BS(t) for each fold and each event type (rows) at each time t (columns), are stored in cross_validator.results:

cross_validator.results
1 2 3 4 5 6 7 8 9 10 ... 21 22 23 24 25 26 27 28 29 30
metric fold
BS 0 1 0.063049 0.053907 0.045755 0.043523 0.042300 0.046062 0.037669 0.039918 0.038884 0.039600 ... 0.036635 0.058504 0.046140 0.053152 0.061784 0.050324 0.065542 0.062469 0.090904 0.072492
2 0.024376 0.019168 0.019545 0.017727 0.019594 0.017552 0.021547 0.020059 0.016670 0.019916 ... 0.018437 0.034735 0.019682 0.032066 0.023314 0.031648 0.027103 0.019322 0.048674 0.048384
AUC 0 1 0.652021 0.643642 0.641287 0.643345 0.667766 0.643727 0.655765 0.667565 0.570066 0.653903 ... 0.514947 0.531114 0.499528 0.478013 0.578898 0.500812 0.414534 0.475877 0.582044 0.412960
2 0.672119 0.677260 0.688408 0.646792 0.686449 0.659598 0.629191 0.626242 0.693669 0.635776 ... 0.612318 0.495373 0.510674 0.517435 0.542218 0.504568 0.725689 0.297011 0.563298 0.587978
BS 1 1 0.063104 0.051763 0.047522 0.045185 0.039390 0.040410 0.040533 0.037199 0.035189 0.039389 ... 0.048439 0.039984 0.049740 0.037045 0.054407 0.063798 0.056508 0.063322 0.055487 0.046272
2 0.026567 0.020499 0.023158 0.019674 0.017881 0.018550 0.012963 0.019118 0.018793 0.020397 ... 0.016248 0.024147 0.021958 0.020464 0.024052 0.030794 0.022079 0.047698 0.055665 0.039775
AUC 1 1 0.637347 0.653555 0.670154 0.653391 0.674059 0.664124 0.626964 0.639071 0.622290 0.624590 ... 0.470615 0.619649 0.576780 0.408460 0.472122 0.534623 0.390513 0.458682 0.399079 0.547307
2 0.702004 0.665337 0.698247 0.670620 0.612123 0.662859 0.732118 0.652722 0.610839 0.666786 ... 0.564700 0.564492 0.474570 0.610849 0.528472 0.602485 0.547596 0.466056 0.532451 0.328246
BS 2 1 0.064886 0.049643 0.045128 0.044718 0.040647 0.041574 0.037698 0.035049 0.041622 0.038163 ... 0.048869 0.041845 0.053532 0.052515 0.027061 0.060276 0.048393 0.064393 0.074945 0.069480
2 0.024451 0.016186 0.021169 0.021732 0.021547 0.021889 0.017873 0.020496 0.015362 0.014820 ... 0.018748 0.021711 0.017476 0.020739 0.033913 0.033076 0.031291 0.034467 0.043016 0.055695
AUC 2 1 0.659157 0.654793 0.649652 0.659449 0.674832 0.660598 0.594342 0.621234 0.648151 0.634931 ... 0.611217 0.482856 0.548870 0.517922 0.367427 0.536657 0.581426 0.453719 0.453618 0.531055
2 0.657810 0.657075 0.706675 0.713073 0.645342 0.666765 0.653272 0.674614 0.565575 0.620555 ... 0.552351 0.616195 0.470557 0.589532 0.554307 0.558689 0.494114 0.490396 0.369201 0.509630
BS 3 1 0.059994 0.050895 0.045448 0.043562 0.037141 0.043832 0.038703 0.043434 0.035425 0.037136 ... 0.039925 0.050829 0.056631 0.044824 0.062753 0.075129 0.069417 0.059778 0.069127 0.041407
2 0.024225 0.021932 0.019773 0.015853 0.017388 0.017645 0.017031 0.018022 0.015532 0.019680 ... 0.022589 0.030680 0.029422 0.024571 0.031540 0.020359 0.031400 0.025750 0.042687 0.055311
AUC 3 1 0.659007 0.665953 0.682254 0.636794 0.677254 0.626931 0.650321 0.615777 0.639428 0.600239 ... 0.464096 0.516016 0.450649 0.644558 0.493753 0.510769 0.481684 0.635580 0.504782 0.445025
2 0.702751 0.677411 0.691677 0.699105 0.657049 0.655707 0.675905 0.646852 0.718233 0.665166 ... 0.594124 0.541976 0.492703 0.830657 0.479912 0.431244 0.579373 0.386118 0.505585 0.381362
BS 4 1 0.064885 0.053155 0.047227 0.045384 0.041575 0.035789 0.038320 0.040585 0.031461 0.032940 ... 0.038013 0.043923 0.066064 0.044236 0.061176 0.055079 0.067624 0.050423 0.075282 0.103932
2 0.023856 0.019616 0.022393 0.021449 0.020102 0.016353 0.019711 0.020459 0.021224 0.017895 ... 0.025120 0.015200 0.019491 0.019334 0.022800 0.033254 0.037406 0.033727 0.043146 0.062432
AUC 4 1 0.659357 0.636783 0.654152 0.637880 0.631202 0.647121 0.638941 0.625363 0.599964 0.641299 ... 0.524715 0.494028 0.502059 0.530022 0.579939 0.578329 0.511615 0.481707 0.389172 0.331990
2 0.672312 0.657230 0.687828 0.674431 0.653015 0.616802 0.636077 0.697158 0.613994 0.613293 ... 0.524942 0.592352 0.463863 0.605593 0.477705 0.490352 0.519561 0.739740 0.509843 0.459805

20 rows × 30 columns

The integrated AUC and BS for each fold (columns) and each event type (rows) are stored in cross_validator.integrated_auc and cross_validator.integrated_bs:

pd.DataFrame.from_records(cross_validator.integrated_auc)
0 1 2 3 4
1 0.629493 0.627465 0.636286 0.634377 0.622095
2 0.651277 0.647620 0.646869 0.655923 0.642630
pd.DataFrame.from_records(cross_validator.integrated_bs)
0 1 2 3 4
1 0.048382 0.046735 0.046411 0.046305 0.047747
2 0.020957 0.021362 0.021003 0.020547 0.021381
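CV estimates are commonly reported as the mean and standard deviation across folds. The sketch below does this for the integrated AUC and BS, using the per-fold values copied from the two tables above (rows are event types, columns are folds 0 to 4). It uses only the standard library statistics module; the dictionary layout is chosen here for illustration.

```python
from statistics import mean, stdev

# Integrated measures per fold (folds 0-4), copied from the two tables above.
integrated_auc = {
    1: [0.629493, 0.627465, 0.636286, 0.634377, 0.622095],
    2: [0.651277, 0.647620, 0.646869, 0.655923, 0.642630],
}
integrated_bs = {
    1: [0.048382, 0.046735, 0.046411, 0.046305, 0.047747],
    2: [0.020957, 0.021362, 0.021003, 0.020547, 0.021381],
}

summary = {}
for name, table in (("IAUC", integrated_auc), ("IBS", integrated_bs)):
    for event, values in table.items():
        # Mean and sample standard deviation across the 5 folds
        summary[(name, event)] = (mean(values), stdev(values))
        m, s = summary[(name, event)]
        print(f"{name} event {event}: {m:.4f} +/- {s:.4f}")
```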

Finally, the global AUC and global BS for each fold are stored in cross_validator.global_auc and cross_validator.global_bs:

print(cross_validator.global_auc)
{0: 0.635987860592661, 1: 0.633667449693418, 2: 0.6395581436224567, 3: 0.6408781976004154, 4: 0.6284087721165604}

print(cross_validator.global_bs)
{0: 0.0402049045296208, 1: 0.03892613679526756, 2: 0.03855544852975435, 3: 0.0385327046054388, 4: 0.039640037102306486}
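The global measures combine the event-specific integrated measures. Assume the global value is a weighted average \(w_1 m_1 + (1 - w_1) m_2\) of the two integrated event-specific values \(m_1, m_2\). That form is an assumption here; the Methods section gives the actual definition. Under it, the fold-0 numbers printed above can be used to back out the implied weight \(w_1\). The AUC pair and the BS pair give nearly the same weight, about 0.70. That agreement is consistent with, but does not prove, a common event-level weighting behind both global measures. implied_weight_event_1 is a hypothetical helper.

```python
# Fold-0 values copied from the outputs above.
iauc = {1: 0.629493, 2: 0.651277}
ibs = {1: 0.048382, 2: 0.020957}
global_auc_0 = 0.635987860592661
global_bs_0 = 0.0402049045296208


def implied_weight_event_1(per_event, global_value):
    """Solve global = w * m1 + (1 - w) * m2 for w (assumed two-event weighted average)."""
    return (global_value - per_event[2]) / (per_event[1] - per_event[2])


w_from_auc = implied_weight_event_1(iauc, global_auc_0)
w_from_bs = implied_weight_event_1(ibs, global_bs_0)
print(f"implied w_1 from AUC: {w_from_auc:.4f}, from BS: {w_from_bs:.4f}")
```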

References

[1] Meir, Tomer*, Gutman, Rom*, and Gorfine, Malka, "PyDTS: A Python Package for Discrete-Time Survival Analysis with Competing Risks" (2022)