[Deep Learning/AI] Explaining a Deep Learning Model Trained with TabNet Using SHAP (Full Code Walkthrough)
TabNet is a deep learning architecture known to work well on tabular data. Let's train a TabNet regression model, find a way to explain its predictions as well as possible with SHAP (Shapley values), and store the results in a single dataframe.
The TabNet paper and explainable-AI topics such as SHAP will be covered in more depth in separate posts.
install packages
- Written in a Colab environment
! pip install pytorch-tabnet ## skip if already installed
! pip install shap ## skip if already installed
env setting
### drive mount
from google.colab import drive
drive.mount('/content/drive')
Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
### gpu mapping
import torch
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print("Using {}".format(DEVICE))
Using cpu
### check Colab Pro+ memory specs
# MemTotal: 13298580 kB
# MemFree: 9764400 kB
# MemAvailable: 12273036 kB
! head -n 3 /proc/meminfo ### if the numbers differ from the comment above, the runtime simply has slightly different memory
MemTotal: 13297228 kB
MemFree: 8930012 kB
MemAvailable: 12369728 kB
Generate Data & Split Train/Test
### imports
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.tree import DecisionTreeClassifier, DecisionTreeRegressor
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, mean_absolute_error, accuracy_score, f1_score
from pytorch_tabnet.tab_model import TabNetClassifier, TabNetRegressor
from sklearn.metrics import r2_score
import shap
import os
import time
import warnings
warnings.filterwarnings(action='ignore') ### suppress the warnings raised when nsamples is set to 20
Hyperparameter Settings
X_SIZE = 10000 ## fixed
TEST_SIZE = 20 ### alternatives tried: 1000, 2000
COL_SIZE = 15 ### alternatives tried: 30, 60
SEED = 2022
N_sample = 'auto' ## 'auto' or an integer such as 2048
Generate Data
- populate the dataset with random values
### make data
np.random.seed(SEED)
x_data = np.random.rand(X_SIZE, COL_SIZE)
df = pd.DataFrame(x_data)
df['target'] = np.random.randint(1000, 50000, size=(X_SIZE, 1))
df
| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | 10 | 11 | 12 | 13 | 14 | target |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.009359 | 0.499058 | 0.113384 | 0.049974 | 0.685408 | 0.486988 | 0.897657 | 0.647452 | 0.896963 | 0.721135 | 0.831353 | 0.827568 | 0.833580 | 0.957044 | 0.368044 | 30906 |
1 | 0.494838 | 0.339509 | 0.619429 | 0.977530 | 0.096433 | 0.744206 | 0.292499 | 0.298675 | 0.752473 | 0.018664 | 0.523737 | 0.864436 | 0.388843 | 0.212192 | 0.475181 | 8187 |
2 | 0.564672 | 0.349429 | 0.975909 | 0.037820 | 0.794270 | 0.357883 | 0.747964 | 0.914509 | 0.372662 | 0.964883 | 0.081386 | 0.042451 | 0.296796 | 0.363704 | 0.490255 | 23806 |
3 | 0.668519 | 0.673415 | 0.572101 | 0.080592 | 0.898331 | 0.038389 | 0.782194 | 0.036656 | 0.267184 | 0.205224 | 0.258894 | 0.932615 | 0.008125 | 0.403473 | 0.894102 | 10204 |
4 | 0.204209 | 0.021776 | 0.697167 | 0.191023 | 0.546433 | 0.603225 | 0.988794 | 0.092446 | 0.064287 | 0.987952 | 0.452108 | 0.853911 | 0.401445 | 0.388206 | 0.884407 | 27620 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
9995 | 0.017416 | 0.824987 | 0.246643 | 0.141063 | 0.184951 | 0.384777 | 0.722438 | 0.279597 | 0.194048 | 0.816772 | 0.070302 | 0.708632 | 0.497547 | 0.113425 | 0.302923 | 20149 |
9996 | 0.434106 | 0.286644 | 0.964673 | 0.237779 | 0.093510 | 0.788614 | 0.645321 | 0.475191 | 0.551407 | 0.438434 | 0.801701 | 0.698005 | 0.065917 | 0.594159 | 0.664846 | 9288 |
9997 | 0.951556 | 0.573942 | 0.489135 | 0.139136 | 0.991655 | 0.563769 | 0.347741 | 0.782542 | 0.520789 | 0.944053 | 0.820197 | 0.364698 | 0.538379 | 0.761037 | 0.904788 | 34624 |
9998 | 0.800968 | 0.732879 | 0.651727 | 0.610226 | 0.644994 | 0.756211 | 0.247786 | 0.620484 | 0.464670 | 0.879303 | 0.108468 | 0.580453 | 0.742119 | 0.414510 | 0.988418 | 24557 |
9999 | 0.082800 | 0.654022 | 0.453132 | 0.713547 | 0.766718 | 0.452666 | 0.910464 | 0.052970 | 0.132754 | 0.090441 | 0.807935 | 0.648001 | 0.722958 | 0.820611 | 0.093902 | 9184 |
10000 rows × 16 columns
Train/Test Split - split off exactly the number of test rows configured above (an integer test_size is an absolute row count, not a fraction)
X, y = df.drop('target', axis = 1).values, df['target'].values
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size = TEST_SIZE, random_state = 2022)
y_train, y_test = y_train.reshape(-1,1), y_test.reshape(-1,1)
Save & Load Model
model_path = f"/content/drive/MyDrive/tabnet_reg_10000/tabnet_reg_{X_SIZE}_{TEST_SIZE}" ### 우선 구글 드라이브에 바로 저장
regressor = TabNetRegressor(verbose=1, seed=2022)
if os.path.isfile(model_path + ".zip"):
    print(f"LOAD SAVED MODEL -- {model_path}")
    regressor.load_model(model_path + ".zip")
else:
    print("TRAINING MODEL & SAVE")
    regressor.fit(X_train=X_train, y_train=y_train,
                  eval_metric=['rmsle'],
                  batch_size=64,
                  max_epochs=3) ### default is max_epochs=100
    regressor.save_model(model_path)
print(regressor)
print(regressor)
TRAINING MODEL & SAVE
epoch 0 | loss: 863465970.37419| 0:00:02s
epoch 1 | loss: 835548792.98065| 0:00:07s
epoch 2 | loss: 775246533.57419| 0:00:12s
Successfully saved model at /content/drive/MyDrive/tabnet_reg_10000/tabnet_reg_10000_20.zip
TabNetRegressor(n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], cat_emb_dim=1, n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02, lambda_sparse=0.001, seed=2022, clip_value=1, verbose=1, optimizer_fn=<class 'torch.optim.adam.Adam'>, optimizer_params={'lr': 0.02}, scheduler_fn=None, scheduler_params={}, mask_type='sparsemax', input_dim=15, output_dim=1, device_name='auto', n_shared_decoder=1, n_indep_decoder=1)
model_path = f"/content/drive/MyDrive/tabnet_reg_10000/tabnet_reg_{X_SIZE}_{TEST_SIZE}"
regressor = TabNetRegressor(verbose=1, seed=2022)
loaded_tabnetregressor = TabNetRegressor()
loaded_tabnetregressor.load_model(model_path+".zip") # 저장한 모델 불러오기
loaded_tabnetregressor
TabNetRegressor(n_d=8, n_a=8, n_steps=3, gamma=1.3, cat_idxs=[], cat_dims=[], cat_emb_dim=1, n_independent=2, n_shared=2, epsilon=1e-15, momentum=0.02, lambda_sparse=0.001, seed=2022, clip_value=1, verbose=1, optimizer_fn=<class 'torch.optim.adam.Adam'>, optimizer_params={'lr': 0.02}, scheduler_fn=None, scheduler_params={}, mask_type='sparsemax', input_dim=15, output_dim=1, device_name='auto', n_shared_decoder=1, n_indep_decoder=1)
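As a quick sanity check (my own addition, not part of the original flow), the reloaded model should reproduce the predictions of the model still in memory:
### sanity check: the reloaded model should match the trained/loaded one in memory
assert np.allclose(regressor.predict(X_test), loaded_tabnetregressor.predict(X_test))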
Calculate SHAP Value
Using KernelExplainer – model-agnostic explainers such as Permutation or Kernel are the practical choice for a deep learning model like TabNet, since they only need access to the model's predict function.
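For reference, the Permutation option mentioned above looks almost identical in code. A minimal sketch (illustrative only, not run in this post): shap.PermutationExplainer expects a callable with 1-D output, so TabNet's (n, 1) predictions are flattened, and X_test is reused as the background masker here.
### illustrative Permutation alternative (not used below)
perm_explainer = shap.PermutationExplainer(
    lambda x: loaded_tabnetregressor.predict(x).flatten(),  # flatten (n, 1) -> (n,)
    X_test)  # background masker
perm_shap_values = perm_explainer(X_test).values  # (n_samples, n_features)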
### extract SHAP values
print("test data size for SHAP values -", X_test.shape)
explainer = shap.KernelExplainer(loaded_tabnetregressor.predict, X_test) # wrap the loaded model's predict; X_test doubles as the background data here
start = time.time()
shap_values = explainer.shap_values(X_test, nsamples = N_sample)
cell_run_time = time.time() - start
print(cell_run_time)
test data size for SHAP values - (20, 15)
21.929385900497437
### cross check
base_value = explainer.expected_value[0]
print("base value", base_value)
base value 2976.890240478515
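This base value is KernelExplainer's expected_value, i.e. the mean model prediction over the background data (X_test here), so it can be cross-checked directly:
### cross check: base value == mean prediction over the background data
print(loaded_tabnetregressor.predict(X_test).mean())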
Merge Predicted Values & SHAP Values into the Test Dataframe
shap_importance = pd.DataFrame(shap_values[0]) # shap_values is a list with one array per model output; a regressor has a single output
shap_importance.columns = [str(i)+'_SHAPVAL' for i in shap_importance.columns]
shap_importance.loc[:,'TOTAL_SHAPVAL'] = shap_importance.sum(axis=1)
final_report_df = pd.concat([pd.DataFrame(X_test), shap_importance], axis = 1)
y_pred = loaded_tabnetregressor.predict(X_test)
final_report_df['y_pred'] = y_pred
final_report_df['y_pred'] = final_report_df['y_pred'].astype('float64')
final_report_df['base_value'] = pd.Series([base_value] * X_test.shape[0])
final_report_df['TOTAL_SHAPVAL + base_value'] = final_report_df['base_value'] + final_report_df['TOTAL_SHAPVAL']
display(final_report_df)
| | 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 | 9 | ... | 9_SHAPVAL | 10_SHAPVAL | 11_SHAPVAL | 12_SHAPVAL | 13_SHAPVAL | 14_SHAPVAL | TOTAL_SHAPVAL | y_pred | base_value | TOTAL_SHAPVAL + base_value |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.665340 | 0.305386 | 0.160187 | 0.608090 | 0.876134 | 0.577168 | 0.409957 | 0.951125 | 0.078057 | 0.580620 | ... | 13.989868 | 58.825361 | 0.000000 | 24.958757 | 5.865831 | 207.362089 | 582.052631 | 3558.942871 | 2976.89024 | 3558.942871 |
1 | 0.896871 | 0.016248 | 0.043515 | 0.174851 | 0.286535 | 0.426831 | 0.938803 | 0.245879 | 0.652431 | 0.358082 | ... | -5.117060 | 13.446534 | 1.012135 | -7.112571 | 0.000000 | -299.820383 | -229.182477 | 2747.707520 | 2976.89024 | 2747.707764 |
2 | 0.195551 | 0.534170 | 0.054984 | 0.067725 | 0.899910 | 0.864291 | 0.928148 | 0.663764 | 0.529692 | 0.069272 | ... | -13.821463 | 16.158704 | 0.000000 | 0.000000 | 0.000000 | -797.671184 | -1612.761945 | 1364.128174 | 2976.89024 | 1364.128296 |
3 | 0.780031 | 0.958210 | 0.060079 | 0.127752 | 0.121757 | 0.998124 | 0.619833 | 0.771568 | 0.357500 | 0.250555 | ... | -8.957114 | 2.896907 | 0.000000 | -2.727257 | 1.221694 | -259.071928 | 139.967181 | 3116.857422 | 2976.89024 | 3116.857422 |
4 | 0.833555 | 0.818795 | 0.188312 | 0.866473 | 0.029219 | 0.869027 | 0.231443 | 0.502989 | 0.734818 | 0.029553 | ... | -46.340093 | 97.512771 | 0.000000 | -40.298334 | -10.574096 | 14.495060 | 128.472308 | 3105.362305 | 2976.89024 | 3105.362549 |
5 | 0.387548 | 0.608834 | 0.373447 | 0.163071 | 0.871580 | 0.123094 | 0.695192 | 0.221676 | 0.473412 | 0.179041 | ... | -5.709209 | -32.263871 | 0.000000 | -1.806092 | 2.042695 | 519.058980 | 438.634662 | 3415.524658 | 2976.89024 | 3415.524902 |
6 | 0.051641 | 0.885105 | 0.505377 | 0.459587 | 0.717950 | 0.021153 | 0.329931 | 0.348929 | 0.030949 | 0.122640 | ... | -16.405837 | 21.901906 | -2.739079 | 0.000000 | 0.000000 | -619.011598 | -1625.876813 | 1351.013306 | 2976.89024 | 1351.013428 |
7 | 0.622831 | 0.676264 | 0.349113 | 0.840550 | 0.952585 | 0.700288 | 0.910378 | 0.059485 | 0.070335 | 0.392437 | ... | 2.886479 | -40.003732 | 0.000000 | 7.700811 | 0.000000 | 174.672702 | 53.446185 | 3030.336914 | 2976.89024 | 3030.336426 |
8 | 0.540494 | 0.498669 | 0.432838 | 0.635259 | 0.044647 | 0.157382 | 0.870429 | 0.626106 | 0.463865 | 0.389576 | ... | 0.000000 | -22.330007 | 0.000000 | -3.186513 | 0.000000 | 200.520145 | 297.537982 | 3274.428223 | 2976.89024 | 3274.428223 |
9 | 0.292736 | 0.394259 | 0.501476 | 0.755071 | 0.963806 | 0.097967 | 0.544224 | 0.091328 | 0.650178 | 0.943432 | ... | 17.487106 | -37.783851 | 0.000000 | 0.000000 | -5.146068 | 578.151493 | 538.363422 | 3515.253174 | 2976.89024 | 3515.253662 |
10 | 0.807821 | 0.546079 | 0.826843 | 0.606432 | 0.196059 | 0.102094 | 0.280362 | 0.134659 | 0.952526 | 0.791225 | ... | 23.460712 | 23.505667 | 0.000000 | 20.740959 | 3.940929 | 169.248973 | 579.097552 | 3555.987793 | 2976.89024 | 3555.987793 |
11 | 0.553643 | 0.405755 | 0.070046 | 0.709272 | 0.852507 | 0.653668 | 0.500025 | 0.306107 | 0.014521 | 0.307509 | ... | -3.137991 | 8.399765 | 0.000000 | 0.000000 | 2.952945 | 222.107463 | 587.687885 | 3564.578369 | 2976.89024 | 3564.578125 |
12 | 0.269201 | 0.886692 | 0.236271 | 0.228237 | 0.790845 | 0.247772 | 0.858724 | 0.103983 | 0.214732 | 0.714835 | ... | 33.948094 | 0.000000 | 0.000000 | -4.854956 | -2.832984 | -642.829525 | -1401.765851 | 1575.124268 | 2976.89024 | 1575.124390 |
13 | 0.511182 | 0.153969 | 0.987458 | 0.664299 | 0.489893 | 0.038254 | 0.708350 | 0.582648 | 0.444579 | 0.579134 | ... | 8.152289 | 3.099521 | 0.000000 | 4.080772 | 6.835309 | 167.510164 | 599.708881 | 3576.599121 | 2976.89024 | 3576.599121 |
14 | 0.035127 | 0.145992 | 0.092245 | 0.939648 | 0.698673 | 0.515788 | 0.870113 | 0.728663 | 0.240244 | 0.733680 | ... | 32.049755 | -13.377773 | 0.000000 | -11.680317 | -6.447017 | -558.299584 | -1539.971051 | 1436.918945 | 2976.89024 | 1436.919189 |
15 | 0.458821 | 0.365300 | 0.913334 | 0.213902 | 0.705668 | 0.444712 | 0.655041 | 0.711478 | 0.109087 | 0.047111 | ... | 0.000000 | -47.550775 | 0.000000 | -8.974647 | -14.749852 | 212.535800 | 522.897113 | 3499.787109 | 2976.89024 | 3499.787354 |
16 | 0.586684 | 0.721589 | 0.554982 | 0.728450 | 0.987413 | 0.365390 | 0.149811 | 0.879073 | 0.132804 | 0.158433 | ... | -10.001263 | -16.914937 | 0.000000 | 4.579777 | 0.000000 | 243.402676 | 539.024310 | 3515.915283 | 2976.89024 | 3515.914551 |
17 | 0.447370 | 0.856104 | 0.490726 | 0.512460 | 0.901070 | 0.723599 | 0.946287 | 0.362711 | 0.341831 | 0.226329 | ... | -4.604177 | -30.644028 | 0.000000 | 20.202807 | 10.523801 | 161.331920 | 237.945941 | 3214.836182 | 2976.89024 | 3214.836182 |
18 | 0.602242 | 0.488993 | 0.776788 | 0.334477 | 0.417354 | 0.396566 | 0.848444 | 0.130109 | 0.666934 | 0.080961 | ... | -5.500653 | 5.991046 | 0.000000 | 0.000000 | 5.574099 | 253.524126 | 580.183002 | 3557.073242 | 2976.89024 | 3557.073242 |
19 | 0.459065 | 0.059892 | 0.831910 | 0.959609 | 0.553826 | 0.954812 | 0.142540 | 0.463129 | 0.410726 | 0.211993 | ... | -8.669551 | -23.851995 | 3.546563 | 0.000000 | 0.000000 | 52.351894 | 584.539447 | 3561.429932 | 2976.89024 | 3561.429688 |
20 rows × 34 columns
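The last two columns match up to float rounding: this is SHAP's local accuracy property, where the base value plus the sum of all SHAP values reproduces each prediction. A minimal check on the dataframe built above:
### local accuracy: base_value + TOTAL_SHAPVAL should reproduce y_pred
assert np.allclose(final_report_df['TOTAL_SHAPVAL + base_value'],
                   final_report_df['y_pred'], atol=1e-2)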
Optional) Making a Run-Time Graph
When you want to benchmark run time while swapping explainers, it is convenient to build a log dataframe and track the runs there (a plotting sketch follows at the end of this section).
cell_run_time
21.929385900497437
log = pd.DataFrame([[TEST_SIZE, COL_SIZE, N_sample, cell_run_time]], columns=['test_size','col_size','nsamples', 'run_time'])
log
# log.to_excel("/content/drive/MyDrive/tabnet_reg_10000/log.xlsx", index=False) # only when creating the log file for the first time
# pd.read_excel("/content/drive/MyDrive/tabnet_reg_10000/log.xlsx")
| | test_size | col_size | nsamples | run_time |
|---|---|---|---|---|
0 | 20 | 15 | auto | 21.929386 |
import os

cache_path = '/content/drive/MyDrive/tabnet_reg_10000/log.xlsx'
if not os.path.exists(cache_path): # create the initial log file if it does not exist yet
    initial_log = pd.DataFrame([[TEST_SIZE, COL_SIZE, N_sample, cell_run_time]], columns=['test_size','col_size','nsamples', 'run_time'])
    initial_log.to_excel(cache_path, index=False)
else:
    # append the new row to the existing Excel log
    # (mode='a' with if_sheet_exists='overlay' requires pandas >= 1.4;
    #  the old writer.book = book / writer.save() pattern was removed in pandas 1.5)
    with pd.ExcelWriter(cache_path, engine='openpyxl', mode='a', if_sheet_exists='overlay') as writer:
        for sheetname, ws in writer.sheets.items():
            log.to_excel(writer, sheet_name=sheetname, startrow=ws.max_row, index=False, header=False) # append rows below the existing ones
pd.read_excel('/content/drive/MyDrive/tabnet_reg_10000/log.xlsx')
| | test_size | col_size | nsamples | run_time |
|---|---|---|---|---|
0 | 500 | 15 | 20 | 322.463880 |
1 | 200 | 15 | 20 | 62.606001 |
2 | 1000 | 15 | auto | 43426.773710 |
3 | 100 | 15 | auto | 1695.991297 |
4 | 300 | 15 | auto | 18396.789017 |
5 | 100 | 15 | auto | 492.172346 |
6 | 100 | 15 | auto | 498.190434 |
7 | 200 | 15 | auto | 1911.865682 |
8 | 400 | 15 | auto | 7451.173161 |
9 | 300 | 15 | auto | 3855.254270 |
10 | 500 | 15 | auto | 10974.687255 |
11 | 400 | 15 | auto | 6170.332176 |
12 | 100 | 15 | auto | 418.124938 |
13 | 200 | 15 | auto | 1512.140173 |
14 | 300 | 15 | auto | 3359.933324 |
15 | 400 | 15 | auto | 6090.160212 |
16 | 500 | 15 | auto | 9173.967759 |
17 | 1000 | 15 | auto | 36812.828763 |
18 | 20 | 15 | auto | 21.929386 |
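From the accumulated log, the run-time graph this section is named for takes only a few lines. A minimal sketch, assuming the log file above and the matplotlib import from earlier:
### plot mean run time per test size for the nsamples='auto' runs
log_df = pd.read_excel('/content/drive/MyDrive/tabnet_reg_10000/log.xlsx')
auto_runs = log_df[log_df['nsamples'] == 'auto']
auto_runs.groupby('test_size')['run_time'].mean().plot(marker='o')
plt.xlabel('test size')
plt.ylabel('run time (s)')
plt.title('KernelExplainer run time by test size')
plt.show()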