|
59 | 59 | "from sklearn.decomposition import PCA\n",
|
60 | 60 | "from sklearn.ensemble import IsolationForest, RandomForestClassifier\n",
|
61 | 61 | "\n",
|
| 62 | + "import shap # Explainable AI tool\n", |
| 63 | + "\n", |
62 | 64 | "import matplotlib.pyplot as plot"
|
63 | 65 | ]
|
64 | 66 | },
|
|
410 | 412 | " anomaly_label_column: str = 'anomalyLabel',\n",
|
411 | 413 | " anomaly_score_column: str = 'anomalyScore',\n",
|
412 | 414 | ") -> pd.DataFrame:\n",
|
413 |
| - " isolation_forest = IsolationForest(n_estimators=200, contamination=0.05, random_state=42)\n", |
| 415 | + " isolation_forest = IsolationForest(n_estimators=200, contamination='auto', random_state=42)\n", |
414 | 416 | " anomaly_score = isolation_forest.fit_predict(prepared_features)\n",
|
415 | 417 | "\n",
|
416 |
| - " original_features[anomaly_label_column] = anomaly_score * -1 # 1 = anomaly, 0 = no anomaly\n", |
| 418 | + " original_features[anomaly_label_column] = (anomaly_score == -1).astype(int) # 1 = anomaly, 0 = normal\n", |
417 | 419 | " original_features[anomaly_score_column] = isolation_forest.decision_function(prepared_features) * -1 # higher = more anomalous\n",
|
418 | 420 | " return original_features"
|
419 | 421 | ]
|
|
440 | 442 | " anomaly_label_column: str = \"anomalyLabel\",\n",
|
441 | 443 | " anomaly_score_column: str = \"anomalyScore\"\n",
|
442 | 444 | ") -> pd.DataFrame:\n",
|
443 |
| - " anomalies = anomaly_detected_features[anomaly_detected_features[anomaly_label_column] == -1]\n", |
| 445 | + " anomalies = anomaly_detected_features[anomaly_detected_features[anomaly_label_column] == 1]\n", |
444 | 446 | " return anomalies.sort_values(by=anomaly_score_column, ascending=False).reset_index(drop=True).head(10)"
|
445 | 447 | ]
|
446 | 448 | },
|
|
456 | 458 | },
|
457 | 459 | {
|
458 | 460 | "cell_type": "markdown",
|
459 |
| - "id": "efa822ca", |
| 461 | + "id": "a3936d79", |
460 | 462 | "metadata": {},
|
461 | 463 | "source": [
|
462 |
| - "### 1.4 Plot the 20 most influential features\n", |
463 |
| - "\n", |
464 |
| - "Use Random Forest as a proxy to estimate the importance of each feature contributing to the anomaly score." |
465 |
| - ] |
466 |
| - }, |
467 |
| - { |
468 |
| - "cell_type": "code", |
469 |
| - "execution_count": null, |
470 |
| - "id": "24427977", |
471 |
| - "metadata": {}, |
472 |
| - "outputs": [], |
473 |
| - "source": [ |
474 |
| - "def get_feature_importances(\n", |
475 |
| - " anomaly_detected_features: pd.DataFrame, \n", |
476 |
| - " prepared_features: numpy_typing.NDArray,\n", |
477 |
| - " anomaly_label_column: str = \"anomalyLabel\",\n", |
478 |
| - ") -> numpy_typing.NDArray:\n", |
479 |
| - " \"\"\"\n", |
480 |
| - " Use Random Forest as a proxy model to find out which are the most important features for the anomaly detection model (Isolation Forest).\n", |
481 |
| - " This helps to see if embedding components dominate (top 10 filled with them), and then tune accordingly.\n", |
482 |
| - " \"\"\"\n", |
483 |
| - " # Use IsolationForest labels as a \"pseudo ground truth\"\n", |
484 |
| - " y_pseudo = (anomaly_detected_features[anomaly_label_column] == -1).astype(int)\n", |
485 |
| - "\n", |
486 |
| - " # Fit classifier to match the IF model\n", |
487 |
| - " proxy_random_forest = RandomForestClassifier(n_estimators=100, random_state=42)\n", |
488 |
| - " proxy_random_forest.fit(prepared_features, y_pseudo)\n", |
489 |
| - "\n", |
490 |
| - " return proxy_random_forest.feature_importances_" |
491 |
| - ] |
492 |
| - }, |
493 |
| - { |
494 |
| - "cell_type": "code", |
495 |
| - "execution_count": null, |
496 |
| - "id": "97b21d49", |
497 |
| - "metadata": {}, |
498 |
| - "outputs": [], |
499 |
| - "source": [ |
500 |
| - "java_package_anomaly_detection_importances = get_feature_importances(java_package_anomaly_detection_features, java_package_anomaly_detection_features_prepared)\n", |
501 |
| - "java_package_anomaly_detection_importances_series = pd.Series(java_package_anomaly_detection_importances, index=java_package_anomaly_detection_feature_names).sort_values(ascending=False)\n", |
502 |
| - "#display(java_type_anomaly_detection_importances_series.head(10))" |
503 |
| - ] |
504 |
| - }, |
505 |
| - { |
506 |
| - "cell_type": "code", |
507 |
| - "execution_count": null, |
508 |
| - "id": "14d0b03e", |
509 |
| - "metadata": {}, |
510 |
| - "outputs": [], |
511 |
| - "source": [ |
512 |
| - "def plot_feature_importances(feature_importances_series: pd.Series, title_prefix: str) -> None:\n", |
513 |
| - " feature_importances_series.head(20).plot(\n", |
514 |
| - " kind='barh',\n", |
515 |
| - " figsize=(10, 6),\n", |
516 |
| - " color='skyblue',\n", |
517 |
| - " title=f\"{title_prefix}: Top 20 Feature Importances (Random Forest Proxy)\",\n", |
518 |
| - " xlabel=\"Importance\"\n", |
519 |
| - " )\n", |
520 |
| - " plot.gca().invert_yaxis() # Most important feature at the top\n", |
521 |
| - " plot.tight_layout()\n", |
522 |
| - " plot.show()" |
523 |
| - ] |
524 |
| - }, |
525 |
| - { |
526 |
| - "cell_type": "code", |
527 |
| - "execution_count": null, |
528 |
| - "id": "974a2bae", |
529 |
| - "metadata": {}, |
530 |
| - "outputs": [], |
531 |
| - "source": [ |
532 |
| - "plot_feature_importances(java_package_anomaly_detection_importances_series, title_prefix='Java Packages')" |
533 |
| - ] |
534 |
| - }, |
535 |
| - { |
536 |
| - "cell_type": "markdown", |
537 |
| - "id": "c9dd6246", |
538 |
| - "metadata": {}, |
539 |
| - "source": [ |
540 |
| - "### 1.5. Plot anomalies\n", |
| 464 | + "### 1.4. Plot anomalies\n", |
541 | 465 | "\n",
|
542 | 466 | "Plots clustered nodes and highlights anomalies."
|
543 | 467 | ]
|
544 | 468 | },
|
545 | 469 | {
|
546 | 470 | "cell_type": "code",
|
547 | 471 | "execution_count": null,
|
548 |
| - "id": "ab1e76ab", |
| 472 | + "id": "c5604735", |
549 | 473 | "metadata": {},
|
550 | 474 | "outputs": [],
|
551 | 475 | "source": [
|
|
640 | 564 | {
|
641 | 565 | "cell_type": "code",
|
642 | 566 | "execution_count": null,
|
643 |
| - "id": "aea29887", |
| 567 | + "id": "61ec7904", |
644 | 568 | "metadata": {},
|
645 | 569 | "outputs": [],
|
646 | 570 | "source": [
|
647 | 571 | "plot_anomalies(java_package_anomaly_detection_features, title_prefix=\"Java Package Anomalies\")"
|
648 | 572 | ]
|
649 | 573 | },
|
| 574 | + { |
| 575 | + "cell_type": "markdown", |
| 576 | + "id": "efa822ca", |
| 577 | + "metadata": {}, |
| 578 | + "source": [ |
| 579 | + "### 1.5 Print the 20 most influential features\n", |
| 580 | + "\n", |
| 581 | + "Use Random Forest as a proxy to estimate the importance of each feature contributing to the anomaly score." |
| 582 | + ] |
| 583 | + }, |
| 584 | + { |
| 585 | + "cell_type": "code", |
| 586 | + "execution_count": null, |
| 587 | + "id": "24427977", |
| 588 | + "metadata": {}, |
| 589 | + "outputs": [], |
| 590 | + "source": [ |
| 591 | + "def get_proxy_random_forest(\n", |
| 592 | + " anomaly_detected_features: pd.DataFrame, \n", |
| 593 | + " prepared_features: numpy_typing.NDArray,\n", |
| 594 | + " anomaly_label_column: str = \"anomalyLabel\",\n", |
| 595 | + ") -> RandomForestClassifier:\n", |
| 596 | + " \"\"\"\n", |
| 597 | + " Use Random Forest as a proxy model to find out which are the most important features for the anomaly detection model (Isolation Forest).\n", |
| 598 | + " This helps to see if embedding components dominate (top 10 filled with them), and then tune accordingly.\n", |
| 599 | + " \"\"\"\n", |
| 600 | + " # Use IsolationForest labels as a \"pseudo ground truth\"\n", |
| 601 | + " y_pseudo = anomaly_detected_features[anomaly_label_column]\n", |
| 602 | + "\n", |
| 603 | + " # Fit classifier to match the IF model\n", |
| 604 | + " proxy_random_forest = RandomForestClassifier(n_estimators=100, random_state=42)\n", |
| 605 | + " proxy_random_forest.fit(prepared_features, y_pseudo)\n", |
| 606 | + "\n", |
| 607 | + " return proxy_random_forest" |
| 608 | + ] |
| 609 | + }, |
| 610 | + { |
| 611 | + "cell_type": "code", |
| 612 | + "execution_count": null, |
| 613 | + "id": "97b21d49", |
| 614 | + "metadata": {}, |
| 615 | + "outputs": [], |
| 616 | + "source": [ |
| 617 | + "java_package_proxy_random_forest = get_proxy_random_forest(java_package_anomaly_detection_features, java_package_anomaly_detection_features_prepared)\n", |
| 618 | + "java_package_anomaly_detection_importances = java_package_proxy_random_forest.feature_importances_\n", |
| 619 | + "java_package_anomaly_detection_importances_series = pd.Series(java_package_anomaly_detection_importances, index=java_package_anomaly_detection_feature_names).sort_values(ascending=False)\n", |
| 620 | + "print(java_package_anomaly_detection_importances_series.head(10))" |
| 621 | + ] |
| 622 | + }, |
| 623 | + { |
| 624 | + "cell_type": "raw", |
| 625 | + "id": "14d0b03e", |
| 626 | + "metadata": { |
| 627 | + "vscode": { |
| 628 | + "languageId": "raw" |
| 629 | + } |
| 630 | + }, |
| 631 | + "source": [ |
| 632 | + "# TODO Remove if not used anymore because of a better plot using SHAP\n", |
| 633 | + "def plot_feature_importances(feature_importances_series: pd.Series, title_prefix: str) -> None:\n", |
| 634 | + " feature_importances_series.head(20).plot(\n", |
| 635 | + " kind='barh',\n", |
| 636 | + " figsize=(10, 6),\n", |
| 637 | + " color='skyblue',\n", |
| 638 | + " title=f\"{title_prefix}: Top 20 Feature Importances (Random Forest Proxy)\",\n", |
| 639 | + " xlabel=\"Importance\"\n", |
| 640 | + " )\n", |
| 641 | + " plot.gca().invert_yaxis() # Most important feature at the top\n", |
| 642 | + " plot.tight_layout()\n", |
| 643 | + " plot.show()\n", |
| 644 | + "\n", |
| 645 | + "plot_feature_importances(java_package_anomaly_detection_importances_series, title_prefix='Java Packages')" |
| 646 | + ] |
| 647 | + }, |
| 648 | + { |
| 649 | + "cell_type": "markdown", |
| 650 | + "id": "db03216e", |
| 651 | + "metadata": {}, |
| 652 | + "source": [ |
| 653 | + "### 1.6 Use SHAP to explain the Isolation Forest Model" |
| 654 | + ] |
| 655 | + }, |
| 656 | + { |
| 657 | + "cell_type": "code", |
| 658 | + "execution_count": null, |
| 659 | + "id": "e8c5905d", |
| 660 | + "metadata": {}, |
| 661 | + "outputs": [], |
| 662 | + "source": [ |
| 663 | + "def explain_anomalies_with_shap(\n", |
| 664 | + " random_forest_model: RandomForestClassifier,\n", |
| 665 | + " anomaly_detected_features: pd.DataFrame,\n", |
| 666 | + " prepared_features: numpy_typing.NDArray,\n", |
| 667 | + " feature_names: list[str],\n", |
| 668 | + " title_prefix: str = \"\",\n", |
| 669 | + " anomaly_label_column: str = \"anomalyLabel\",\n", |
| 670 | + ") -> None:\n", |
| 671 | + " \"\"\"\n", |
| 672 | + " Explain anomalies using SHAP values.\n", |
| 673 | + " \"\"\"\n", |
| 674 | + "\n", |
| 675 | + " # Use TreeExplainer for Random Forest\n", |
| 676 | + " explainer = shap.TreeExplainer(random_forest_model)\n", |
| 677 | + " \n", |
| 678 | + " shap_values = explainer.shap_values(prepared_features)\n", |
| 679 | + " print(f\"Input shape: {anomaly_detected_features.shape}\")\n", |
| 680 | + " print(f\"SHAP shape: {np.shape(shap_values)}\")\n", |
| 681 | + "\n", |
| 682 | + " anomaly_rows = anomaly_detected_features[anomaly_label_column] == 1 # Filter anomalies\n", |
| 683 | + " shap.summary_plot(\n", |
| 684 | + " shap_values[anomaly_rows, :, 1], # Class 1 = anomaly\n", |
| 685 | + " prepared_features[anomaly_rows],\n", |
| 686 | + " feature_names=feature_names,\n", |
| 687 | + " plot_type=\"bar\",\n", |
| 688 | + " title=f\"{title_prefix} Anomalies explained using SHAP\",\n", |
| 689 | + " max_display=20,\n", |
| 690 | + " plot_size=(12, 6) # (width, height) in inches\n", |
| 691 | + " )\n", |
| 692 | + "\n", |
| 693 | + " # Create DataFrame of SHAP values for class 1 (anomaly)\n", |
| 694 | + " shap_df = pd.DataFrame(\n", |
| 695 | + " shap_values[:, :, 1], # select SHAP values for class 1\n", |
| 696 | + " columns=feature_names\n", |
| 697 | + " )\n", |
| 698 | + "\n", |
| 699 | + " # Add anomaly label to shap_df\n", |
| 700 | + " shap_df[\"anomalyLabel\"] = anomaly_detected_features[\"anomalyLabel\"].values\n", |
| 701 | + "\n", |
| 702 | + " # Filter to only anomalies using the boolean mask\n", |
| 703 | + " anomaly_shap_df = shap_df[anomaly_rows].drop(columns=[\"anomalyLabel\"])\n", |
| 704 | + "\n", |
| 705 | + " # Get top 3 features per anomaly (by absolute SHAP value)\n", |
| 706 | + " top3_per_anomaly = anomaly_shap_df.apply(\n", |
| 707 | + " lambda row: list(\n", |
| 708 | + " row.abs().sort_values(ascending=False).head(3).index\n", |
| 709 | + " ),\n", |
| 710 | + " axis=1\n", |
| 711 | + " )\n", |
| 712 | + "\n", |
| 713 | + " # Add top 3 influential features to every anomaly row\n", |
| 714 | + " anomaly_detected_features[\"anomalyLabelInfluentialFeatures\"] = None\n", |
| 715 | + " anomaly_detected_features.loc[\n", |
| 716 | + " anomaly_rows, \"anomalyLabelInfluentialFeatures\"\n", |
| 717 | + " ] = top3_per_anomaly.values\n", |
| 718 | + "\n", |
| 719 | + " display(anomaly_detected_features[anomaly_detected_features[\"anomalyLabel\"] == 1].sort_values(by='anomalyScore', ascending=False).head(10))\n" |
| 720 | + ] |
| 721 | + }, |
| 722 | + { |
| 723 | + "cell_type": "code", |
| 724 | + "execution_count": null, |
| 725 | + "id": "7d671e71", |
| 726 | + "metadata": {}, |
| 727 | + "outputs": [], |
| 728 | + "source": [ |
| 729 | + "explain_anomalies_with_shap(\n", |
| 730 | + " random_forest_model=java_package_proxy_random_forest,\n", |
| 731 | + " anomaly_detected_features=java_package_anomaly_detection_features, \n", |
| 732 | + " prepared_features=java_package_anomaly_detection_features_prepared,\n", |
| 733 | + " feature_names=java_package_anomaly_detection_feature_names,\n", |
| 734 | + " title_prefix=\"Java Package\"\n", |
| 735 | + ")" |
| 736 | + ] |
| 737 | + }, |
650 | 738 | {
|
651 | 739 | "cell_type": "markdown",
|
652 | 740 | "id": "5682bb64",
|
|
767 | 855 | "display(get_top_10_anomalies(java_type_anomaly_detection_features))"
|
768 | 856 | ]
|
769 | 857 | },
|
| 858 | + { |
| 859 | + "cell_type": "markdown", |
| 860 | + "id": "68a00628", |
| 861 | + "metadata": {}, |
| 862 | + "source": [ |
| 863 | + "### 2.4. Plot anomalies\n", |
| 864 | + "\n", |
| 865 | + "Plots clustered nodes and highlights anomalies." |
| 866 | + ] |
| 867 | + }, |
| 868 | + { |
| 869 | + "cell_type": "code", |
| 870 | + "execution_count": null, |
| 871 | + "id": "4ecc9fb4", |
| 872 | + "metadata": {}, |
| 873 | + "outputs": [], |
| 874 | + "source": [ |
| 875 | + "plot_anomalies(java_type_anomaly_detection_features, title_prefix=\"Java Type Anomalies\")" |
| 876 | + ] |
| 877 | + }, |
770 | 878 | {
|
771 | 879 | "cell_type": "markdown",
|
772 | 880 | "id": "4e565f84",
|
773 | 881 | "metadata": {},
|
774 | 882 | "source": [
|
775 |
| - "### 2.4 Plot the 20 most influential features\n", |
| 883 | + "### 2.5 Print the 20 most influential features\n", |
776 | 884 | "\n",
|
777 | 885 | "Use Random Forest as a proxy to estimate the importance of each feature contributing to the anomaly score."
|
778 | 886 | ]
|
779 | 887 | },
|
780 | 888 | {
|
781 | 889 | "cell_type": "code",
|
782 | 890 | "execution_count": null,
|
783 |
| - "id": "1b97f299", |
| 891 | + "id": "86945e66", |
784 | 892 | "metadata": {},
|
785 | 893 | "outputs": [],
|
786 | 894 | "source": [
|
787 |
| - "java_type_anomaly_detection_importances = get_feature_importances(java_type_anomaly_detection_features, java_type_anomaly_detection_features_prepared)\n", |
| 895 | + "java_type_proxy_random_forest = get_proxy_random_forest(java_type_anomaly_detection_features, java_type_anomaly_detection_features_prepared)\n", |
| 896 | + "java_type_anomaly_detection_importances = java_type_proxy_random_forest.feature_importances_\n", |
788 | 897 | "java_type_anomaly_detection_importances_series = pd.Series(java_type_anomaly_detection_importances, index=java_type_anomaly_detection_feature_names).sort_values(ascending=False)\n",
|
789 |
| - "#display(java_type_anomaly_detection_importances_series.head(10))\n", |
790 |
| - "\n", |
791 |
| - "plot_feature_importances(java_type_anomaly_detection_importances_series, title_prefix='Java Types')" |
| 898 | + "print(java_type_anomaly_detection_importances_series.head(10))" |
792 | 899 | ]
|
793 | 900 | },
|
794 | 901 | {
|
795 | 902 | "cell_type": "markdown",
|
796 |
| - "id": "68a00628", |
| 903 | + "id": "b12a0379", |
797 | 904 | "metadata": {},
|
798 | 905 | "source": [
|
799 |
| - "### 2.5. Plot anomalies\n", |
800 |
| - "\n", |
801 |
| - "Plots clustered nodes and highlights anomalies." |
| 906 | + "### 2.6 Use SHAP to explain the Isolation Forest Model" |
802 | 907 | ]
|
803 | 908 | },
|
804 | 909 | {
|
805 | 910 | "cell_type": "code",
|
806 | 911 | "execution_count": null,
|
807 |
| - "id": "4ecc9fb4", |
| 912 | + "id": "2d4b35c6", |
808 | 913 | "metadata": {},
|
809 | 914 | "outputs": [],
|
810 | 915 | "source": [
|
811 |
| - "plot_anomalies(java_type_anomaly_detection_features, title_prefix=\"Java Type Anomalies\")" |
| 916 | + "explain_anomalies_with_shap(\n", |
| 917 | + " random_forest_model=java_type_proxy_random_forest,\n", |
| 918 | + " anomaly_detected_features=java_type_anomaly_detection_features, \n", |
| 919 | + " prepared_features=java_type_anomaly_detection_features_prepared,\n", |
| 920 | + " feature_names=java_type_anomaly_detection_feature_names,\n", |
| 921 | + " title_prefix=\"Java Type\"\n", |
| 922 | + ")" |
812 | 923 | ]
|
813 | 924 | }
|
814 | 925 | ],
|
|
0 commit comments