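Let's train a TabNet Regression model with TabNet, a deep learning architecture said to be well suited to tabular data, look for a way to explain its predictions as well as possible with SHAP (SHapley Additive exPlanations, built on Shapley values), and store the results in a single dataframe.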

The TabNet paper and explainable AI topics such as SHAP will be covered in more detail in other posts.

install package

  • Written in a Colab environment
! pip install pytorch-tabnet ## skip if already installed
! pip install shap ## skip if already installed

env setting

### drive mount

from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
### gpu mapping
import torch
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using {}".format(DEVICE))
Using cpu
### check the Colab Pro+ memory spec
# MemTotal:       13298580 kB
# MemFree:         9764400 kB
# MemAvailable:   12273036 kB

! head -n 3 /proc/meminfo  ### if this differs from the comments above, the memory of the experiment environment is slightly different
MemTotal:       13297228 kB
MemFree:         8930012 kB
MemAvailable:   12369728 kB

Generate Data & Split Train/Test

### import model 

from sklearn.datasets import load_iris  # load_boston dropped: it was removed in scikit-learn 1.2 and is not used here
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, accuracy_score, f1_score

from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor
from sklearn.metrics import r2_score
import shap
import os
import time

import warnings
warnings.filterwarnings(action='ignore') ### suppress the warning that appears when nsamples is set to 20

Hyperparameter settings

X_SIZE = 10000 ## fixed
TEST_SIZE = 20 ### 1000, 2000
COL_SIZE = 15 ### 30, 60
SEED = 2022
N_sample = 'auto'  ## 'auto' or 2048

Data generation

  • Filled with random values, so the features carry no real signal about the target
### make data 
x_data = np.random.rand(X_SIZE, COL_SIZE)
df = pd.DataFrame(x_data)
df['target'] = np.random.randint(1000, 50000, size=(X_SIZE, 1))
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 target
0 0.009359 0.499058 0.113384 0.049974 0.685408 0.486988 0.897657 0.647452 0.896963 0.721135 0.831353 0.827568 0.833580 0.957044 0.368044 30906
1 0.494838 0.339509 0.619429 0.977530 0.096433 0.744206 0.292499 0.298675 0.752473 0.018664 0.523737 0.864436 0.388843 0.212192 0.475181 8187
2 0.564672 0.349429 0.975909 0.037820 0.794270 0.357883 0.747964 0.914509 0.372662 0.964883 0.081386 0.042451 0.296796 0.363704 0.490255 23806
3 0.668519 0.673415 0.572101 0.080592 0.898331 0.038389 0.782194 0.036656 0.267184 0.205224 0.258894 0.932615 0.008125 0.403473 0.894102 10204
4 0.204209 0.021776 0.697167 0.191023 0.546433 0.603225 0.988794 0.092446 0.064287 0.987952 0.452108 0.853911 0.401445 0.388206 0.884407 27620
... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ... ...
9995 0.017416 0.824987 0.246643 0.141063 0.184951 0.384777 0.722438 0.279597 0.194048 0.816772 0.070302 0.708632 0.497547 0.113425 0.302923 20149
9996 0.434106 0.286644 0.964673 0.237779 0.093510 0.788614 0.645321 0.475191 0.551407 0.438434 0.801701 0.698005 0.065917 0.594159 0.664846 9288
9997 0.951556 0.573942 0.489135 0.139136 0.991655 0.563769 0.347741 0.782542 0.520789 0.944053 0.820197 0.364698 0.538379 0.761037 0.904788 34624
9998 0.800968 0.732879 0.651727 0.610226 0.644994 0.756211 0.247786 0.620484 0.464670 0.879303 0.108468 0.580453 0.742119 0.414510 0.988418 24557
9999 0.082800 0.654022 0.453132 0.713547 0.766718 0.452666 0.910464 0.052970 0.132754 0.090441 0.807935 0.648001 0.722958 0.820611 0.093902 9184

10000 rows × 16 columns

train test split - hold out exactly TEST_SIZE rows, as configured above, for the test set

X, y = df.drop('target', axis = 1).values, df['target'].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = TEST_SIZE, random_state = 2022)
y_train, y_test = y_train.reshape(-1,1), y_test.reshape(-1,1)

Save & Load Model

model_path = f"/content/drive/MyDrive/tabnet_reg_10000/tabnet_reg_{X_SIZE}_{TEST_SIZE}" ### save straight to Google Drive for now
regressor = TabNetRegressor(verbose=1, seed=2022)

if os.path.isfile(model_path+".zip"):
    print(f"LOAD SAVED MODEL -- {model_path}")
    regressor.load_model(model_path+".zip")  # reuse the model saved by an earlier run
else:
    print(f"TRAINING MODEL & SAVE")
    regressor.fit(X_train=X_train, y_train=y_train, 
                batch_size = 64, 
                max_epochs=3) ### default max_epochs is 100
    regressor.save_model(model_path)  # writes model_path + ".zip"

epoch 0  | loss: 863465970.37419|  0:00:02s
epoch 1  | loss: 835548792.98065|  0:00:07s
epoch 2  | loss: 775246533.57419|  0:00:12s
Successfully saved model at /content/drive/MyDrive/tabnet_reg_10000/tabnet_reg_10000_20.zip
TabNetRegressor(n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], cat_emb_dim=1, n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02, lambda_sparse=0.001, seed=2022, clip_value=1, verbose=1, optimizer_fn=<class 'torch.optim.adam.Adam'>, optimizer_params={'lr': 0.02}, scheduler_fn=None, scheduler_params={}, mask_type='sparsemax', input_dim=15, output_dim=1, device_name='auto', n_shared_decoder=1, n_indep_decoder=1)
model_path = f"/content/drive/MyDrive/tabnet_reg_10000/tabnet_reg_{X_SIZE}_{TEST_SIZE}" 
regressor = TabNetRegressor(verbose=1, seed=2022)
loaded_tabnetregressor = TabNetRegressor()
loaded_tabnetregressor.load_model(model_path+".zip") # load the saved model
TabNetRegressor(n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], cat_emb_dim=1, n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02, lambda_sparse=0.001, seed=2022, clip_value=1, verbose=1, optimizer_fn=<class 'torch.optim.adam.Adam'>, optimizer_params={'lr': 0.02}, scheduler_fn=None, scheduler_params={}, mask_type='sparsemax', input_dim=15, output_dim=1, device_name='auto', n_shared_decoder=1, n_indep_decoder=1)

Calculate SHAP Value

  • Uses KernelExplainer. For a deep learning model, the model-agnostic Permutation or Kernel explainer is a suitable choice, since both only need a prediction function
### extract SHAP values
print("Test data size used for the SHAP values -", X_test.shape)
explainer = shap.KernelExplainer(loaded_tabnetregressor.predict, X_test)  # pass in the saved model; X_test also serves as the background data

start = time.time() 
shap_values = explainer.shap_values(X_test, nsamples = N_sample)
cell_run_time = time.time() - start
Test data size used for the SHAP values - (20, 15)
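
To get a feel for why this cell can be slow, here is a rough cost sketch. It assumes the behavior described in the shap documentation for KernelExplainer: nsamples='auto' resolves to 2 * n_features + 2048 (capped at 2**n_features - 2 when there are few features), and every sampled coalition is evaluated against every background row. The function names `resolve_nsamples` and `kernel_shap_model_evals` are made up for this post and are not part of shap.

```python
def resolve_nsamples(n_features, nsamples="auto"):
    """Assumed KernelExplainer rule: 'auto' -> 2 * M + 2048, capped at 2**M - 2 when M <= 30."""
    n = 2 * n_features + 2048 if nsamples == "auto" else int(nsamples)
    if n_features <= 30:
        n = min(n, 2 ** n_features - 2)
    return n


def kernel_shap_model_evals(n_explain, n_background, n_features, nsamples="auto"):
    """Rough count of rows sent through model.predict for one shap_values call."""
    return n_explain * resolve_nsamples(n_features, nsamples) * n_background


print(resolve_nsamples(15))                      # coalitions sampled per explained row
print(kernel_shap_model_evals(20, 20, 15))       # TEST_SIZE = 20, background = X_test
print(kernel_shap_model_evals(1000, 1000, 15))   # TEST_SIZE = 1000, background = X_test
```

Because X_test is passed both as the background and as the rows to explain, this estimate grows with the square of TEST_SIZE. Handing the explainer a smaller background, for example one built with shap.sample or shap.kmeans, is a common way to bring the count down.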

### cross check 
base_value = explainer.expected_value[0]
print("base value", base_value)
base value 2976.890240478515
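
The base value is the average model output over the background data, and the SHAP values of a row describe how each feature moves the prediction away from that average. The sketch below computes exact Shapley values for a tiny stand-in model by enumerating every feature subset, which is the quantity KernelExplainer approximates by sampling. `toy_model`, `exact_shapley`, and the sample rows are hypothetical and only serve the illustration; nothing here calls TabNet or shap.

```python
from itertools import combinations
from math import factorial


def exact_shapley(predict, x, background):
    """Exact Shapley values of predict at x; absent features are filled from background rows."""
    m = len(x)

    def value(subset):
        # Mean prediction when only the features in `subset` take their values from x.
        total = 0.0
        for row in background:
            mixed = [x[i] if i in subset else row[i] for i in range(m)]
            total += predict(mixed)
        return total / len(background)

    phi = [0.0] * m
    for i in range(m):
        others = [j for j in range(m) if j != i]
        for size in range(m):
            # Shapley weight for coalitions of this size that do not contain feature i.
            weight = factorial(size) * factorial(m - size - 1) / factorial(m)
            for combo in combinations(others, size):
                s = set(combo)
                phi[i] += weight * (value(s | {i}) - value(s))
    return value(set()), phi


def toy_model(row):
    # Hypothetical stand-in for loaded_tabnetregressor.predict on a single row.
    return 3.0 * row[0] + 2.0 * row[1] * row[2]


background = [[0.0, 0.0, 0.0], [1.0, 1.0, 1.0], [0.5, 0.2, 0.8]]
x = [1.0, 0.5, 0.4]

base_value, phi = exact_shapley(toy_model, x, background)
print("base value:", base_value)
print("shapley values:", phi)
print("base + sum(phi) matches prediction:", abs(base_value + sum(phi) - toy_model(x)) < 1e-9)
```

The enumeration visits every subset of the features, so it is only practical for a handful of columns. With 15 features, sampling as KernelExplainer does is the realistic option.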

Merge Predicted Value, Shap Value into Test Dataframe

shap_importance = pd.DataFrame(shap_values[0])
shap_importance.columns = [str(i)+'_SHAPVAL' for i in shap_importance.columns]
shap_importance.loc[:,'TOTAL_SHAPVAL'] = shap_importance.sum(axis=1)
final_report_df = pd.concat([pd.DataFrame(X_test), shap_importance], axis = 1)

y_pred = loaded_tabnetregressor.predict(X_test)
final_report_df['y_pred'] = y_pred
final_report_df['y_pred']  = final_report_df['y_pred'].astype('float64')
final_report_df['base_value']  =  pd.Series([base_value] * X_test.shape[0])
final_report_df['TOTAL_SHAPVAL + base_value'] =  final_report_df['base_value']  + final_report_df['TOTAL_SHAPVAL']

0 1 2 3 4 5 6 7 8 9 ... 9_SHAPVAL 10_SHAPVAL 11_SHAPVAL 12_SHAPVAL 13_SHAPVAL 14_SHAPVAL TOTAL_SHAPVAL y_pred base_value TOTAL_SHAPVAL + base_value
0 0.665340 0.305386 0.160187 0.608090 0.876134 0.577168 0.409957 0.951125 0.078057 0.580620 ... 13.989868 58.825361 0.000000 24.958757 5.865831 207.362089 582.052631 3558.942871 2976.89024 3558.942871
1 0.896871 0.016248 0.043515 0.174851 0.286535 0.426831 0.938803 0.245879 0.652431 0.358082 ... -5.117060 13.446534 1.012135 -7.112571 0.000000 -299.820383 -229.182477 2747.707520 2976.89024 2747.707764
2 0.195551 0.534170 0.054984 0.067725 0.899910 0.864291 0.928148 0.663764 0.529692 0.069272 ... -13.821463 16.158704 0.000000 0.000000 0.000000 -797.671184 -1612.761945 1364.128174 2976.89024 1364.128296
3 0.780031 0.958210 0.060079 0.127752 0.121757 0.998124 0.619833 0.771568 0.357500 0.250555 ... -8.957114 2.896907 0.000000 -2.727257 1.221694 -259.071928 139.967181 3116.857422 2976.89024 3116.857422
4 0.833555 0.818795 0.188312 0.866473 0.029219 0.869027 0.231443 0.502989 0.734818 0.029553 ... -46.340093 97.512771 0.000000 -40.298334 -10.574096 14.495060 128.472308 3105.362305 2976.89024 3105.362549
5 0.387548 0.608834 0.373447 0.163071 0.871580 0.123094 0.695192 0.221676 0.473412 0.179041 ... -5.709209 -32.263871 0.000000 -1.806092 2.042695 519.058980 438.634662 3415.524658 2976.89024 3415.524902
6 0.051641 0.885105 0.505377 0.459587 0.717950 0.021153 0.329931 0.348929 0.030949 0.122640 ... -16.405837 21.901906 -2.739079 0.000000 0.000000 -619.011598 -1625.876813 1351.013306 2976.89024 1351.013428
7 0.622831 0.676264 0.349113 0.840550 0.952585 0.700288 0.910378 0.059485 0.070335 0.392437 ... 2.886479 -40.003732 0.000000 7.700811 0.000000 174.672702 53.446185 3030.336914 2976.89024 3030.336426
8 0.540494 0.498669 0.432838 0.635259 0.044647 0.157382 0.870429 0.626106 0.463865 0.389576 ... 0.000000 -22.330007 0.000000 -3.186513 0.000000 200.520145 297.537982 3274.428223 2976.89024 3274.428223
9 0.292736 0.394259 0.501476 0.755071 0.963806 0.097967 0.544224 0.091328 0.650178 0.943432 ... 17.487106 -37.783851 0.000000 0.000000 -5.146068 578.151493 538.363422 3515.253174 2976.89024 3515.253662
10 0.807821 0.546079 0.826843 0.606432 0.196059 0.102094 0.280362 0.134659 0.952526 0.791225 ... 23.460712 23.505667 0.000000 20.740959 3.940929 169.248973 579.097552 3555.987793 2976.89024 3555.987793
11 0.553643 0.405755 0.070046 0.709272 0.852507 0.653668 0.500025 0.306107 0.014521 0.307509 ... -3.137991 8.399765 0.000000 0.000000 2.952945 222.107463 587.687885 3564.578369 2976.89024 3564.578125
12 0.269201 0.886692 0.236271 0.228237 0.790845 0.247772 0.858724 0.103983 0.214732 0.714835 ... 33.948094 0.000000 0.000000 -4.854956 -2.832984 -642.829525 -1401.765851 1575.124268 2976.89024 1575.124390
13 0.511182 0.153969 0.987458 0.664299 0.489893 0.038254 0.708350 0.582648 0.444579 0.579134 ... 8.152289 3.099521 0.000000 4.080772 6.835309 167.510164 599.708881 3576.599121 2976.89024 3576.599121
14 0.035127 0.145992 0.092245 0.939648 0.698673 0.515788 0.870113 0.728663 0.240244 0.733680 ... 32.049755 -13.377773 0.000000 -11.680317 -6.447017 -558.299584 -1539.971051 1436.918945 2976.89024 1436.919189
15 0.458821 0.365300 0.913334 0.213902 0.705668 0.444712 0.655041 0.711478 0.109087 0.047111 ... 0.000000 -47.550775 0.000000 -8.974647 -14.749852 212.535800 522.897113 3499.787109 2976.89024 3499.787354
16 0.586684 0.721589 0.554982 0.728450 0.987413 0.365390 0.149811 0.879073 0.132804 0.158433 ... -10.001263 -16.914937 0.000000 4.579777 0.000000 243.402676 539.024310 3515.915283 2976.89024 3515.914551
17 0.447370 0.856104 0.490726 0.512460 0.901070 0.723599 0.946287 0.362711 0.341831 0.226329 ... -4.604177 -30.644028 0.000000 20.202807 10.523801 161.331920 237.945941 3214.836182 2976.89024 3214.836182
18 0.602242 0.488993 0.776788 0.334477 0.417354 0.396566 0.848444 0.130109 0.666934 0.080961 ... -5.500653 5.991046 0.000000 0.000000 5.574099 253.524126 580.183002 3557.073242 2976.89024 3557.073242
19 0.459065 0.059892 0.831910 0.959609 0.553826 0.954812 0.142540 0.463129 0.410726 0.211993 ... -8.669551 -23.851995 3.546563 0.000000 0.000000 52.351894 584.539447 3561.429932 2976.89024 3561.429688

20 rows × 34 columns
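
The last column exists as a sanity check: base_value plus the summed SHAP values should reproduce y_pred. In the table the two agree only up to the last few digits, presumably because the model predicts in float32, so an exact equality test would fail on some rows. A minimal sketch of a tolerance-based check, using three rows copied from the table above (the tolerance of 1e-6 is an assumption, not something the original run used):

```python
from math import isclose

# (TOTAL_SHAPVAL, y_pred) pairs copied from rows 0, 1 and 2 of the report table
rows = [
    (582.052631, 3558.942871),
    (-229.182477, 2747.707520),
    (-1612.761945, 1364.128174),
]
base_value = 2976.89024

checks = []
for total_shap, y_pred in rows:
    reconstructed = base_value + total_shap
    ok = isclose(reconstructed, y_pred, rel_tol=1e-6)  # relative tolerance, assumed
    checks.append(ok)
    print(f"{reconstructed:.6f} vs {y_pred:.6f} -> {ok}")
```

On the full dataframe, the same idea can be expressed with np.isclose(final_report_df['y_pred'], final_report_df['TOTAL_SHAPVAL + base_value'], rtol=1e-6).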

Optional) Building a run-time log for plotting

When you want to compare run times while swapping Explainers or settings, keeping a log dataframe makes the runs easy to track.

log = pd.DataFrame([[TEST_SIZE, COL_SIZE, N_sample, cell_run_time]], columns=['test_size','col_size','nsamples', 'run_time'])
# log.to_excel("/content/drive/MyDrive/tabnet_reg_10000/log.xlsx", index=False) # when creating the log file for the first time
# pd.read_excel("/content/drive/MyDrive/tabnet_reg_10000/log.xlsx")
test_size col_size nsamples run_time
0 20 15 auto 21.929386
import os

cache_path = '/content/drive/MyDrive/tabnet_reg_10000/log.xlsx'

if not os.path.exists(cache_path):  # no log file yet: create it with the current run
  initial_log = pd.DataFrame([[TEST_SIZE, COL_SIZE, N_sample, cell_run_time]],columns=['test_size','col_size','nsamples', 'run_time'])
  initial_log.to_excel(cache_path, index=False)
else:  # log file exists: append the current run below the existing rows
  # mode='a' with if_sheet_exists='overlay' requires pandas >= 1.4 and openpyxl
  with pd.ExcelWriter(cache_path, engine='openpyxl', mode='a', if_sheet_exists='overlay') as writer:
    for sheetname, ws in writer.sheets.items():
      log.to_excel(writer, sheet_name=sheetname, startrow=ws.max_row, index=False, header=False)  # append rows to existing excel cache file

pd.read_excel(cache_path)  # show the accumulated log

test_size col_size nsamples run_time
0 500 15 20 322.463880
1 200 15 20 62.606001
2 1000 15 auto 43426.773710
3 100 15 auto 1695.991297
4 300 15 auto 18396.789017
5 100 15 auto 492.172346
6 100 15 auto 498.190434
7 200 15 auto 1911.865682
8 400 15 auto 7451.173161
9 300 15 auto 3855.254270
10 500 15 auto 10974.687255
11 400 15 auto 6170.332176
12 100 15 auto 418.124938
13 200 15 auto 1512.140173
14 300 15 auto 3359.933324
15 400 15 auto 6090.160212
16 500 15 auto 9173.967759
17 1000 15 auto 36812.828763
18 20 15 auto 21.929386
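
One way to summarize this log is to estimate how run time scales with test_size. The sketch below fits a straight line to log(run_time) against log(test_size) using the nsamples='auto' rows copied from the table above; the slope of that line is the scaling exponent. `loglog_slope` is a helper written for this post, and plain least squares on 17 noisy timings is only a rough estimate.

```python
from math import log

# (test_size, run_time in seconds) for the nsamples='auto' rows of the log above
runs = [
    (1000, 43426.773710), (100, 1695.991297), (300, 18396.789017),
    (100, 492.172346), (100, 498.190434), (200, 1911.865682),
    (400, 7451.173161), (300, 3855.254270), (500, 10974.687255),
    (400, 6170.332176), (100, 418.124938), (200, 1512.140173),
    (300, 3359.933324), (400, 6090.160212), (500, 9173.967759),
    (1000, 36812.828763), (20, 21.929386),
]


def loglog_slope(points):
    """Least-squares slope of log(run_time) versus log(test_size)."""
    xs = [log(size) for size, _ in points]
    ys = [log(seconds) for _, seconds in points]
    mean_x = sum(xs) / len(xs)
    mean_y = sum(ys) / len(ys)
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var = sum((x - mean_x) ** 2 for x in xs)
    return cov / var


slope = loglog_slope(runs)
print(f"run_time grows roughly like test_size ** {slope:.2f}")
```

An exponent near 2 would fit the setup used in this post, where the test set is both the data being explained and the background handed to the explainer, so doubling TEST_SIZE roughly quadruples the work.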
