Skip to content

Commit 192feeb

Browse files
committed
Use SHAP to explain the detected anomalies
1 parent 0a7db5b commit 192feeb

File tree

2 files changed

+209
-97
lines changed

2 files changed

+209
-97
lines changed

domains/anomaly-detection/explore/AnomalyDetectionIsolationForestExploration.ipynb

Lines changed: 208 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,8 @@
5959
"from sklearn.decomposition import PCA\n",
6060
"from sklearn.ensemble import IsolationForest, RandomForestClassifier\n",
6161
"\n",
62+
"import shap # Explainable AI tool\n",
63+
"\n",
6264
"import matplotlib.pyplot as plot"
6365
]
6466
},
@@ -410,10 +412,10 @@
410412
" anomaly_label_column: str = 'anomalyLabel',\n",
411413
" anomaly_score_column: str = 'anomalyScore',\n",
412414
") -> pd.DataFrame:\n",
413-
" isolation_forest = IsolationForest(n_estimators=200, contamination=0.05, random_state=42)\n",
415+
" isolation_forest = IsolationForest(n_estimators=200, contamination='auto', random_state=42)\n",
414416
" anomaly_score = isolation_forest.fit_predict(prepared_features)\n",
415417
"\n",
416-
" original_features[anomaly_label_column] = anomaly_score * -1 # 1 = anomaly, 0 = no anomaly\n",
418+
" original_features[anomaly_label_column] = (anomaly_score == -1).astype(int) # 1 = anomaly, 0 = normal\n",
417419
" original_features[anomaly_score_column] = isolation_forest.decision_function(prepared_features) * -1 # higher = more anomalous\n",
418420
" return original_features"
419421
]
@@ -440,7 +442,7 @@
440442
" anomaly_label_column: str = \"anomalyLabel\",\n",
441443
" anomaly_score_column: str = \"anomalyScore\"\n",
442444
") -> pd.DataFrame:\n",
443-
" anomalies = anomaly_detected_features[anomaly_detected_features[anomaly_label_column] == -1]\n",
445+
" anomalies = anomaly_detected_features[anomaly_detected_features[anomaly_label_column] == 1]\n",
444446
" return anomalies.sort_values(by=anomaly_score_column, ascending=False).reset_index(drop=True).head(10)"
445447
]
446448
},
@@ -456,96 +458,18 @@
456458
},
457459
{
458460
"cell_type": "markdown",
459-
"id": "efa822ca",
461+
"id": "a3936d79",
460462
"metadata": {},
461463
"source": [
462-
"### 1.4 Plot the 20 most influential features\n",
463-
"\n",
464-
"Use Random Forest as a proxy to estimate the importance of each feature contributing to the anomaly score."
465-
]
466-
},
467-
{
468-
"cell_type": "code",
469-
"execution_count": null,
470-
"id": "24427977",
471-
"metadata": {},
472-
"outputs": [],
473-
"source": [
474-
"def get_feature_importances(\n",
475-
" anomaly_detected_features: pd.DataFrame, \n",
476-
" prepared_features: numpy_typing.NDArray,\n",
477-
" anomaly_label_column: str = \"anomalyLabel\",\n",
478-
") -> numpy_typing.NDArray:\n",
479-
" \"\"\"\n",
480-
" Use Random Forest as a proxy model to find out which are the most important features for the anomaly detection model (Isolation Forest).\n",
481-
" This helps to see if embedding components dominate (top 10 filled with them), and then tune accordingly.\n",
482-
" \"\"\"\n",
483-
" # Use IsolationForest labels as a \"pseudo ground truth\"\n",
484-
" y_pseudo = (anomaly_detected_features[anomaly_label_column] == -1).astype(int)\n",
485-
"\n",
486-
" # Fit classifier to match the IF model\n",
487-
" proxy_random_forest = RandomForestClassifier(n_estimators=100, random_state=42)\n",
488-
" proxy_random_forest.fit(prepared_features, y_pseudo)\n",
489-
"\n",
490-
" return proxy_random_forest.feature_importances_"
491-
]
492-
},
493-
{
494-
"cell_type": "code",
495-
"execution_count": null,
496-
"id": "97b21d49",
497-
"metadata": {},
498-
"outputs": [],
499-
"source": [
500-
"java_package_anomaly_detection_importances = get_feature_importances(java_package_anomaly_detection_features, java_package_anomaly_detection_features_prepared)\n",
501-
"java_package_anomaly_detection_importances_series = pd.Series(java_package_anomaly_detection_importances, index=java_package_anomaly_detection_feature_names).sort_values(ascending=False)\n",
502-
"#display(java_type_anomaly_detection_importances_series.head(10))"
503-
]
504-
},
505-
{
506-
"cell_type": "code",
507-
"execution_count": null,
508-
"id": "14d0b03e",
509-
"metadata": {},
510-
"outputs": [],
511-
"source": [
512-
"def plot_feature_importances(feature_importances_series: pd.Series, title_prefix: str) -> None:\n",
513-
" feature_importances_series.head(20).plot(\n",
514-
" kind='barh',\n",
515-
" figsize=(10, 6),\n",
516-
" color='skyblue',\n",
517-
" title=f\"{title_prefix}: Top 20 Feature Importances (Random Forest Proxy)\",\n",
518-
" xlabel=\"Importance\"\n",
519-
" )\n",
520-
" plot.gca().invert_yaxis() # Most important feature at the top\n",
521-
" plot.tight_layout()\n",
522-
" plot.show()"
523-
]
524-
},
525-
{
526-
"cell_type": "code",
527-
"execution_count": null,
528-
"id": "974a2bae",
529-
"metadata": {},
530-
"outputs": [],
531-
"source": [
532-
"plot_feature_importances(java_package_anomaly_detection_importances_series, title_prefix='Java Packages')"
533-
]
534-
},
535-
{
536-
"cell_type": "markdown",
537-
"id": "c9dd6246",
538-
"metadata": {},
539-
"source": [
540-
"### 1.5. Plot anomalies\n",
464+
"### 1.4. Plot anomalies\n",
541465
"\n",
542466
"Plots clustered nodes and highlights anomalies."
543467
]
544468
},
545469
{
546470
"cell_type": "code",
547471
"execution_count": null,
548-
"id": "ab1e76ab",
472+
"id": "c5604735",
549473
"metadata": {},
550474
"outputs": [],
551475
"source": [
@@ -640,13 +564,177 @@
640564
{
641565
"cell_type": "code",
642566
"execution_count": null,
643-
"id": "aea29887",
567+
"id": "61ec7904",
644568
"metadata": {},
645569
"outputs": [],
646570
"source": [
647571
"plot_anomalies(java_package_anomaly_detection_features, title_prefix=\"Java Package Anomalies\")"
648572
]
649573
},
574+
{
575+
"cell_type": "markdown",
576+
"id": "efa822ca",
577+
"metadata": {},
578+
"source": [
579+
"### 1.5 Print the 10 most influential features\n",
580+
"\n",
581+
"Use a Random Forest classifier, trained on the anomaly labels of the Isolation Forest, as a proxy to estimate how much each feature contributes to a node being labelled as an anomaly."
582+
]
583+
},
584+
{
585+
"cell_type": "code",
586+
"execution_count": null,
587+
"id": "24427977",
588+
"metadata": {},
589+
"outputs": [],
590+
"source": [
591+
"def get_proxy_random_forest(\n",
592+
" anomaly_detected_features: pd.DataFrame, \n",
593+
" prepared_features: numpy_typing.NDArray,\n",
594+
" anomaly_label_column: str = \"anomalyLabel\",\n",
595+
") -> RandomForestClassifier:\n",
596+
" \"\"\"\n",
597+
" Use Random Forest as a proxy model to find out which are the most important features for the anomaly detection model (Isolation Forest).\n",
598+
" This helps to see if embedding components dominate (top 10 filled with them), and then tune accordingly.\n",
599+
" \"\"\"\n",
600+
" # Use IsolationForest labels as a \"pseudo ground truth\"\n",
601+
" y_pseudo = anomaly_detected_features[anomaly_label_column]\n",
602+
"\n",
603+
" # Fit classifier to match the IF model\n",
604+
" proxy_random_forest = RandomForestClassifier(n_estimators=100, random_state=42)\n",
605+
" proxy_random_forest.fit(prepared_features, y_pseudo)\n",
606+
"\n",
607+
" return proxy_random_forest"
608+
]
609+
},
610+
{
611+
"cell_type": "code",
612+
"execution_count": null,
613+
"id": "97b21d49",
614+
"metadata": {},
615+
"outputs": [],
616+
"source": [
617+
"java_package_proxy_random_forest = get_proxy_random_forest(java_package_anomaly_detection_features, java_package_anomaly_detection_features_prepared)\n",
618+
"java_package_anomaly_detection_importances = java_package_proxy_random_forest.feature_importances_\n",
619+
"java_package_anomaly_detection_importances_series = pd.Series(java_package_anomaly_detection_importances, index=java_package_anomaly_detection_feature_names).sort_values(ascending=False)\n",
620+
"print(java_package_anomaly_detection_importances_series.head(10))"
621+
]
622+
},
623+
{
624+
"cell_type": "raw",
625+
"id": "14d0b03e",
626+
"metadata": {
627+
"vscode": {
628+
"languageId": "raw"
629+
}
630+
},
631+
"source": [
632+
"# TODO Remove if not used anymore because of a better plot using SHAP\n",
633+
"def plot_feature_importances(feature_importances_series: pd.Series, title_prefix: str) -> None:\n",
634+
" feature_importances_series.head(20).plot(\n",
635+
" kind='barh',\n",
636+
" figsize=(10, 6),\n",
637+
" color='skyblue',\n",
638+
" title=f\"{title_prefix}: Top 20 Feature Importances (Random Forest Proxy)\",\n",
639+
" xlabel=\"Importance\"\n",
640+
" )\n",
641+
" plot.gca().invert_yaxis() # Most important feature at the top\n",
642+
" plot.tight_layout()\n",
643+
" plot.show()\n",
644+
"\n",
645+
"plot_feature_importances(java_package_anomaly_detection_importances_series, title_prefix='Java Packages')"
646+
]
647+
},
648+
{
649+
"cell_type": "markdown",
650+
"id": "db03216e",
651+
"metadata": {},
652+
"source": [
653+
"### 1.6 Use SHAP to explain the Isolation Forest Model (via its Random Forest proxy)"
654+
]
655+
},
656+
{
657+
"cell_type": "code",
658+
"execution_count": null,
659+
"id": "e8c5905d",
660+
"metadata": {},
661+
"outputs": [],
662+
"source": [
663+
"def explain_anomalies_with_shap(\n",
664+
" random_forest_model: RandomForestClassifier,\n",
665+
" anomaly_detected_features: pd.DataFrame,\n",
666+
" prepared_features: numpy_typing.NDArray,\n",
667+
" feature_names: list[str],\n",
668+
" title_prefix: str = \"\",\n",
669+
" anomaly_label_column: str = \"anomalyLabel\",\n",
670+
") -> None:\n",
671+
" \"\"\"\n",
672+
" Explain anomalies using SHAP values.\n",
673+
" \"\"\"\n",
674+
"\n",
675+
" # Use TreeExplainer for Random Forest\n",
676+
" explainer = shap.TreeExplainer(random_forest_model)\n",
677+
" \n",
678+
" shap_values = explainer.shap_values(prepared_features)\n",
679+
" print(f\"Input shape: {anomaly_detected_features.shape}\")\n",
680+
" print(f\"SHAP shape: {np.shape(shap_values)}\")\n",
681+
"\n",
682+
" anomaly_rows = anomaly_detected_features[anomaly_label_column] == 1 # Filter anomalies\n",
683+
" shap.summary_plot(\n",
684+
" shap_values[anomaly_rows, :, 1], # Class 1 = anomaly\n",
685+
" prepared_features[anomaly_rows],\n",
686+
" feature_names=feature_names,\n",
687+
" plot_type=\"bar\",\n",
688+
" title=f\"{title_prefix} Anomalies explained using SHAP\",\n",
689+
" max_display=20,\n",
690+
" plot_size=(12, 6) # (width, height) in inches\n",
691+
" )\n",
692+
"\n",
693+
" # Create DataFrame of SHAP values for class 1 (anomaly)\n",
694+
" shap_df = pd.DataFrame(\n",
695+
" shap_values[:, :, 1], # select SHAP values for class 1\n",
696+
" columns=feature_names\n",
697+
" )\n",
698+
"\n",
699+
" # Add anomaly label to shap_df\n",
700+
" shap_df[\"anomalyLabel\"] = anomaly_detected_features[\"anomalyLabel\"].values\n",
701+
"\n",
702+
" # Filter to only anomalies using the boolean mask\n",
703+
" anomaly_shap_df = shap_df[anomaly_rows].drop(columns=[\"anomalyLabel\"])\n",
704+
"\n",
705+
" # Get top 3 features per anomaly (by absolute SHAP value)\n",
706+
" top3_per_anomaly = anomaly_shap_df.apply(\n",
707+
" lambda row: list(\n",
708+
" row.abs().sort_values(ascending=False).head(3).index\n",
709+
" ),\n",
710+
" axis=1\n",
711+
" )\n",
712+
"\n",
713+
" # Add top 3 influential features to every anomaly row\n",
714+
" anomaly_detected_features[\"anomalyLabelInfluentialFeatures\"] = None\n",
715+
" anomaly_detected_features.loc[\n",
716+
" anomaly_rows, \"anomalyLabelInfluentialFeatures\"\n",
717+
" ] = top3_per_anomaly.values\n",
718+
"\n",
719+
" display(anomaly_detected_features[anomaly_detected_features[\"anomalyLabel\"] == 1].sort_values(by='anomalyScore', ascending=False).head(10))\n"
720+
]
721+
},
722+
{
723+
"cell_type": "code",
724+
"execution_count": null,
725+
"id": "7d671e71",
726+
"metadata": {},
727+
"outputs": [],
728+
"source": [
729+
"explain_anomalies_with_shap(\n",
730+
" random_forest_model=java_package_proxy_random_forest,\n",
731+
" anomaly_detected_features=java_package_anomaly_detection_features, \n",
732+
" prepared_features=java_package_anomaly_detection_features_prepared,\n",
733+
" feature_names=java_package_anomaly_detection_feature_names,\n",
734+
" title_prefix=\"Java Package\"\n",
735+
")"
736+
]
737+
},
650738
{
651739
"cell_type": "markdown",
652740
"id": "5682bb64",
@@ -767,48 +855,71 @@
767855
"display(get_top_10_anomalies(java_type_anomaly_detection_features))"
768856
]
769857
},
858+
{
859+
"cell_type": "markdown",
860+
"id": "68a00628",
861+
"metadata": {},
862+
"source": [
863+
"### 2.4. Plot anomalies\n",
864+
"\n",
865+
"Plots clustered nodes and highlights anomalies."
866+
]
867+
},
868+
{
869+
"cell_type": "code",
870+
"execution_count": null,
871+
"id": "4ecc9fb4",
872+
"metadata": {},
873+
"outputs": [],
874+
"source": [
875+
"plot_anomalies(java_type_anomaly_detection_features, title_prefix=\"Java Type Anomalies\")"
876+
]
877+
},
770878
{
771879
"cell_type": "markdown",
772880
"id": "4e565f84",
773881
"metadata": {},
774882
"source": [
775-
"### 2.4 Plot the 20 most influential features\n",
883+
"### 2.5 Print the 10 most influential features\n",
776884
"\n",
777885
"Use a Random Forest classifier, trained on the anomaly labels of the Isolation Forest, as a proxy to estimate how much each feature contributes to a node being labelled as an anomaly."
778886
]
779887
},
780888
{
781889
"cell_type": "code",
782890
"execution_count": null,
783-
"id": "1b97f299",
891+
"id": "86945e66",
784892
"metadata": {},
785893
"outputs": [],
786894
"source": [
787-
"java_type_anomaly_detection_importances = get_feature_importances(java_type_anomaly_detection_features, java_type_anomaly_detection_features_prepared)\n",
895+
"java_type_proxy_random_forest = get_proxy_random_forest(java_type_anomaly_detection_features, java_type_anomaly_detection_features_prepared)\n",
896+
"java_type_anomaly_detection_importances = java_type_proxy_random_forest.feature_importances_\n",
788897
"java_type_anomaly_detection_importances_series = pd.Series(java_type_anomaly_detection_importances, index=java_type_anomaly_detection_feature_names).sort_values(ascending=False)\n",
789-
"#display(java_type_anomaly_detection_importances_series.head(10))\n",
790-
"\n",
791-
"plot_feature_importances(java_type_anomaly_detection_importances_series, title_prefix='Java Types')"
898+
"print(java_type_anomaly_detection_importances_series.head(10))"
792899
]
793900
},
794901
{
795902
"cell_type": "markdown",
796-
"id": "68a00628",
903+
"id": "b12a0379",
797904
"metadata": {},
798905
"source": [
799-
"### 2.5. Plot anomalies\n",
800-
"\n",
801-
"Plots clustered nodes and highlights anomalies."
906+
"### 2.6 Use SHAP to explain the Isolation Forest Model (via its Random Forest proxy)"
802907
]
803908
},
804909
{
805910
"cell_type": "code",
806911
"execution_count": null,
807-
"id": "4ecc9fb4",
912+
"id": "2d4b35c6",
808913
"metadata": {},
809914
"outputs": [],
810915
"source": [
811-
"plot_anomalies(java_type_anomaly_detection_features, title_prefix=\"Java Type Anomalies\")"
916+
"explain_anomalies_with_shap(\n",
917+
" random_forest_model=java_type_proxy_random_forest,\n",
918+
" anomaly_detected_features=java_type_anomaly_detection_features, \n",
919+
" prepared_features=java_type_anomaly_detection_features_prepared,\n",
920+
" feature_names=java_type_anomaly_detection_feature_names,\n",
921+
" title_prefix=\"Java Type\"\n",
922+
")"
812923
]
813924
}
814925
],

0 commit comments

Comments
 (0)