diff --git a/.gitignore b/.gitignore index 62ef8ff..a319a71 100644 --- a/.gitignore +++ b/.gitignore @@ -31,6 +31,7 @@ share/python-wheels/ .installed.cfg *.egg MANIFEST +.DS_Store # PyInstaller # Usually these files are written by a python script from a template diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..261eeb9 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." 
+ + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. 
+ + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. 
We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..ca4abd7 --- /dev/null +++ b/Makefile @@ -0,0 +1,42 @@ +# Public variable to be set by the user in the Makefile +TARGET_DSS_VERSION=7.0 + +# Private variables to be set by the user in the environment +ifndef DKU_PLUGIN_DEVELOPER_ORG +$(error the DKU_PLUGIN_DEVELOPER_ORG environment variable is not set) +endif +ifndef DKU_PLUGIN_DEVELOPER_TOKEN +$(error the DKU_PLUGIN_DEVELOPER_TOKEN environment variable is not set) +endif +ifndef DKU_PLUGIN_DEVELOPER_REPO_URL +$(error the DKU_PLUGIN_DEVELOPER_REPO_URL environment variable is not set) +endif + +# evaluate additional variables +plugin_id=`cat plugin.json | python -c "import sys, json; print(str(json.load(sys.stdin)['id']).replace('/',''))"` +plugin_version=`cat plugin.json | python -c "import sys, json; print(str(json.load(sys.stdin)['version']).replace('/',''))"` +archive_file_name="dss-plugin-${plugin_id}-${plugin_version}.zip" +artifact_repo_target="${DKU_PLUGIN_DEVELOPER_REPO_URL}/${TARGET_DSS_VERSION}/${DKU_PLUGIN_DEVELOPER_ORG}/${plugin_id}/${plugin_version}/${archive_file_name}" +remote_url=`git config --get remote.origin.url` +last_commit_id=`git rev-parse HEAD` + + +plugin: + @echo "[START] Archiving plugin to dist/ folder..." + @cat plugin.json | json_pp > /dev/null + @rm -rf dist + @mkdir dist + @echo "{\"remote_url\":\"${remote_url}\",\"last_commit_id\":\"${last_commit_id}\"}" > release_info.json + @git archive -v -9 --format zip -o dist/${archive_file_name} HEAD + @zip -u dist/${archive_file_name} release_info.json + @rm release_info.json + @echo "[SUCCESS] Archiving plugin to dist/ folder: Done!" + +submit: plugin + @echo "[START] Publishing archive to artifact repository..." + @curl -H "Authorization: Bearer ${DKU_PLUGIN_DEVELOPER_TOKEN}" -X PUT ${artifact_repo_target} -T dist/${archive_file_name} + @echo "[SUCCESS] Publishing archive to artifact repository: Done!" + + +dist-clean: + rm -rf dist diff --git a/README.md b/README.md index d7a7f3c..4f95941 100644 --- a/README.md +++ b/README.md @@ -1 +1,8 @@ -# dss-plugin-amazon-comprehend-medical +# Amazon Comprehend Medical Plugin + +This Dataiku DSS plugin provides several recipes to call the [Amazon Comprehend Medical APIs](https://aws.amazon.com/comprehend/medical/).
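For readers unfamiliar with the underlying service, here is a minimal, hypothetical boto3 sketch (not part of the plugin code) of the two Comprehend Medical operations these recipes wrap: `detect_entities_v2` for Medical Entity Recognition and `detect_phi` for Protected Health Information extraction. It assumes AWS credentials and a region are available from the environment:

```python
import json

import boto3

# Illustrative only: the plugin's recipes call the same two operations and
# serialize each raw response to JSON before formatting it into columns.
client = boto3.client(service_name="comprehendmedical")

text = "Patient was given 500 mg of paracetamol for a headache."

# Medical Entity Recognition recipe -> DetectEntitiesV2
for entity in client.detect_entities_v2(Text=text)["Entities"]:
    print(entity["Text"], entity["Category"], entity["Type"], round(entity["Score"], 2))

# Protected Health Information Extraction recipe -> DetectPHI
print(json.dumps(client.detect_phi(Text=text)["Entities"], indent=2))
```

The formatter classes in `python-lib/` then expand these raw JSON responses into one column per entity type, as shown in the diff below.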
+ +Documentation: https://www.dataiku.com/product/plugins/amazon-comprehend-nlp-medical/ + +### Licence +This plugin is distributed under the Apache License version 2.0 diff --git a/code-env/python/desc.json b/code-env/python/desc.json new file mode 100644 index 0000000..fca0a34 --- /dev/null +++ b/code-env/python/desc.json @@ -0,0 +1,9 @@ +{ + "acceptedPythonInterpreters": [ + "PYTHON36", + "PYTHON35" + ], + "forceConda": false, + "installCorePackages": true, + "installJupyterSupport": false +} \ No newline at end of file diff --git a/code-env/python/spec/requirements.txt b/code-env/python/spec/requirements.txt new file mode 100644 index 0000000..ee7b2d6 --- /dev/null +++ b/code-env/python/spec/requirements.txt @@ -0,0 +1,5 @@ +boto3==1.13.13 +tqdm==4.46.0 +ratelimit==2.2.1 +retry==0.9.2 +more_itertools==8.2.0 \ No newline at end of file diff --git a/custom-recipes/amazon-comprehend-nlp-medical-entity-recognition/recipe.json b/custom-recipes/amazon-comprehend-nlp-medical-entity-recognition/recipe.json new file mode 100644 index 0000000..01d463b --- /dev/null +++ b/custom-recipes/amazon-comprehend-nlp-medical-entity-recognition/recipe.json @@ -0,0 +1,155 @@ +{ + "meta": { + "label": "Medical Entity Recognition", + "description": "Recognize Medical Entities (medical condition, treatment, etc.) in a medical text record", + "icon": "icon-amazon-comprehend icon-cloud", + "displayOrderRank": 2 + }, + "kind": "PYTHON", + "inputRoles": [ + { + "name": "input_dataset", + "label": "Input Dataset", + "description": "Dataset containing the text data to analyze", + "arity": "UNARY", + "required": true, + "acceptsDataset": true + } + ], + "outputRoles": [ + { + "name": "output_dataset", + "label": "Output dataset", + "description": "Dataset with enriched output", + "arity": "UNARY", + "required": true, + "acceptsDataset": true + } + ], + "params": [ + { + "name": "separator_input", + "label": "Input Parameters", + "type": "SEPARATOR" + }, + { + "name": "text_column", + "label": "Text column", + "type": "COLUMN", + "columnRole": "input_dataset", + "mandatory": true, + "allowedColumnTypes": [ + "string" + ] + }, + { + "name": "language", + "label": "Language", + "description": "Only supported language", + "type": "SELECT", + "mandatory": true, + "selectChoices": [ + { + "value": "en", + "label": "English" + } + ], + "defaultValue": "en" + }, + { + "name": "separator_configuration", + "label": "Configuration", + "type": "SEPARATOR" + }, + { + "name": "api_configuration_preset", + "label": "API configuration preset", + "type": "PRESET", + "parameterSetId": "api-configuration", + "mandatory": true + }, + { + "name": "separator_advanced", + "label": "Advanced", + "type": "SEPARATOR" + }, + { + "name": "expert", + "label": "Expert mode", + "type": "BOOLEAN", + "defaultValue": false + }, + { + "name": "entity_types", + "label": "Entity types", + "type": "MULTISELECT", + "visibilityCondition": "model.expert == true", + "description": "List of medical entity types to extract", + "mandatory": true, + "selectChoices": [ + { + "value": "ANATOMY", + "label": "Anatomy" + }, + { + "value": "MEDICAL_CONDITION", + "label": "Medical condition" + }, + { + "value": "MEDICATION", + "label": "Medication" + }, + { + "value": "PROTECTED_HEALTH_INFORMATION", + "label": "Protected health information" + }, + { + "value": "TEST_TREATMENT_PROCEDURE", + "label": "Test treatment procedure" + }, + { + "value": "TIME_EXPRESSION", + "label": "Time expression" + } + ], + "defaultValue": [ + "ANATOMY", + "MEDICAL_CONDITION", + 
"MEDICATION", + "TEST_TREATMENT_PROCEDURE", + "TIME_EXPRESSION" + ] + }, + { + "name": "minimum_score", + "label": "Minimum score", + "description": "Minimum confidence score (from 0 to 1) for the medical entity to be recognized", + "visibilityCondition": "model.expert == true", + "type": "DOUBLE", + "mandatory": true, + "defaultValue": 0, + "minD": 0, + "maxD": 1 + }, + { + "name": "error_handling", + "label": "Error handling", + "type": "SELECT", + "visibilityCondition": "model.expert == true", + "selectChoices": [ + { + "value": "FAIL", + "label": "Fail" + }, + { + "value": "LOG", + "label": "Log" + } + ], + "defaultValue": "LOG", + "mandatory": true, + "description": "Log API errors to the output or fail with an exception on any API error" + } + ], + "resourceKeys": [] +} \ No newline at end of file diff --git a/custom-recipes/amazon-comprehend-nlp-medical-entity-recognition/recipe.py b/custom-recipes/amazon-comprehend-nlp-medical-entity-recognition/recipe.py new file mode 100644 index 0000000..4a8f512 --- /dev/null +++ b/custom-recipes/amazon-comprehend-nlp-medical-entity-recognition/recipe.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +import json +from typing import Dict, AnyStr + +from retry import retry +from ratelimit import limits, RateLimitException + +import dataiku +from dataiku.customrecipe import get_recipe_config, get_input_names_for_role, get_output_names_for_role + +from plugin_io_utils import ErrorHandlingEnum, validate_column_input, set_column_description +from amazon_comprehend_medical_api_formatting import MedicalEntityTypeEnum, MedicalEntityAPIFormatter +from amazon_comprehend_medical_api_client import API_EXCEPTIONS, get_client +from api_parallelizer import api_parallelizer + + +# ============================================================================== +# SETUP +# ============================================================================== + +api_configuration_preset = get_recipe_config().get("api_configuration_preset") +api_quota_rate_limit = api_configuration_preset.get("api_quota_rate_limit") +api_quota_period = api_configuration_preset.get("api_quota_period") +parallel_workers = api_configuration_preset.get("parallel_workers") +text_column = get_recipe_config().get("text_column") +entity_types = [MedicalEntityTypeEnum[i] for i in get_recipe_config().get("entity_types", [])] +minimum_score = float(get_recipe_config().get("minimum_score", 0)) +if minimum_score < 0 or minimum_score > 1: + raise ValueError("Minimum confidence score must be between 0 and 1") +error_handling = ErrorHandlingEnum[get_recipe_config().get("error_handling")] + +input_dataset_name = get_input_names_for_role("input_dataset")[0] +input_dataset = dataiku.Dataset(input_dataset_name) +input_schema = input_dataset.read_schema() +input_columns_names = [col["name"] for col in input_schema] +validate_column_input(text_column, input_columns_names) + +output_dataset_name = get_output_names_for_role("output_dataset")[0] +output_dataset = dataiku.Dataset(output_dataset_name) + +input_df = input_dataset.get_dataframe() +client = get_client(api_configuration_preset) +column_prefix = "medical_entity_api" + + +# ============================================================================== +# RUN +# ============================================================================== + + +@retry((RateLimitException, OSError), delay=api_quota_period, tries=5) +@limits(calls=api_quota_rate_limit, period=api_quota_period) +def call_api_medical_entity_recognition(row: Dict, text_column: AnyStr) -> Dict: + text 
= row[text_column] + if not isinstance(text, str) or str(text).strip() == "": + return "" + responses = client.detect_entities_v2(Text=text) + return json.dumps(responses) + + +df = api_parallelizer( + input_df=input_df, + api_call_function=call_api_medical_entity_recognition, + api_exceptions=API_EXCEPTIONS, + text_column=text_column, + parallel_workers=parallel_workers, + error_handling=error_handling, + column_prefix=column_prefix, +) + +api_formatter = MedicalEntityAPIFormatter( + input_df=input_df, + entity_types=entity_types, + minimum_score=minimum_score, + column_prefix=column_prefix, + error_handling=error_handling, +) +output_df = api_formatter.format_df(df) + +output_dataset.write_with_schema(output_df) +set_column_description( + input_dataset=input_dataset, + output_dataset=output_dataset, + column_description_dict=api_formatter.column_description_dict, +) diff --git a/custom-recipes/amazon-comprehend-nlp-medical-protected-health-information/recipe.json b/custom-recipes/amazon-comprehend-nlp-medical-protected-health-information/recipe.json new file mode 100644 index 0000000..23b166d --- /dev/null +++ b/custom-recipes/amazon-comprehend-nlp-medical-protected-health-information/recipe.json @@ -0,0 +1,114 @@ +{ + "meta": { + "label": "Protected Health Information Extraction", + "description": "Extract Protected Health Information (PHI) in a medical text record", + "icon": "icon-amazon-comprehend icon-cloud", + "displayOrderRank": 1 + }, + "kind": "PYTHON", + "inputRoles": [ + { + "name": "input_dataset", + "label": "Input Dataset", + "description": "Dataset containing the text data to analyze", + "arity": "UNARY", + "required": true, + "acceptsDataset": true + } + ], + "outputRoles": [ + { + "name": "output_dataset", + "label": "Output dataset", + "description": "Dataset with enriched output", + "arity": "UNARY", + "required": true, + "acceptsDataset": true + } + ], + "params": [ + { + "name": "separator_input", + "label": "Input Parameters", + "type": "SEPARATOR" + }, + { + "name": "text_column", + "label": "Text column", + "type": "COLUMN", + "columnRole": "input_dataset", + "mandatory": true, + "allowedColumnTypes": [ + "string" + ] + }, + { + "name": "language", + "label": "Language", + "description": "Only supported language", + "type": "SELECT", + "mandatory": true, + "selectChoices": [ + { + "value": "en", + "label": "English" + } + ], + "defaultValue": "en" + }, + { + "name": "separator_configuration", + "label": "Configuration", + "type": "SEPARATOR" + }, + { + "name": "api_configuration_preset", + "label": "API configuration preset", + "type": "PRESET", + "parameterSetId": "api-configuration", + "mandatory": true + }, + { + "name": "separator_advanced", + "label": "Advanced", + "type": "SEPARATOR" + }, + { + "name": "expert", + "label": "Expert mode", + "type": "BOOLEAN", + "defaultValue": false + }, + { + "name": "minimum_score", + "label": "Minimum score", + "description": "Minimum confidence score (from 0 to 1) for the PHI to be extracted", + "visibilityCondition": "model.expert == true", + "type": "DOUBLE", + "mandatory": true, + "defaultValue": 0, + "minD": 0, + "maxD": 1 + }, + { + "name": "error_handling", + "label": "Error handling", + "type": "SELECT", + "visibilityCondition": "model.expert == true", + "selectChoices": [ + { + "value": "FAIL", + "label": "Fail" + }, + { + "value": "LOG", + "label": "Log" + } + ], + "defaultValue": "LOG", + "mandatory": true, + "description": "Log API errors to the output or fail with an exception on any API error" + } + ], + 
"resourceKeys": [] +} \ No newline at end of file diff --git a/custom-recipes/amazon-comprehend-nlp-medical-protected-health-information/recipe.py b/custom-recipes/amazon-comprehend-nlp-medical-protected-health-information/recipe.py new file mode 100644 index 0000000..74e8cfe --- /dev/null +++ b/custom-recipes/amazon-comprehend-nlp-medical-protected-health-information/recipe.py @@ -0,0 +1,80 @@ +# -*- coding: utf-8 -*- +import json +from typing import Dict, AnyStr + +from retry import retry +from ratelimit import limits, RateLimitException + +import dataiku +from dataiku.customrecipe import get_recipe_config, get_input_names_for_role, get_output_names_for_role + +from plugin_io_utils import ErrorHandlingEnum, validate_column_input, set_column_description +from amazon_comprehend_medical_api_formatting import MedicalPhiAPIFormatter +from amazon_comprehend_medical_api_client import API_EXCEPTIONS, get_client +from api_parallelizer import api_parallelizer + + +# ============================================================================== +# SETUP +# ============================================================================== + +api_configuration_preset = get_recipe_config().get("api_configuration_preset") +api_quota_rate_limit = api_configuration_preset.get("api_quota_rate_limit") +api_quota_period = api_configuration_preset.get("api_quota_period") +parallel_workers = api_configuration_preset.get("parallel_workers") +text_column = get_recipe_config().get("text_column") +minimum_score = float(get_recipe_config().get("minimum_score", 0)) +if minimum_score < 0 or minimum_score > 1: + raise ValueError("Minimum confidence score must be between 0 and 1") +error_handling = ErrorHandlingEnum[get_recipe_config().get("error_handling")] + +input_dataset_name = get_input_names_for_role("input_dataset")[0] +input_dataset = dataiku.Dataset(input_dataset_name) +input_schema = input_dataset.read_schema() +input_columns_names = [col["name"] for col in input_schema] +validate_column_input(text_column, input_columns_names) + +output_dataset_name = get_output_names_for_role("output_dataset")[0] +output_dataset = dataiku.Dataset(output_dataset_name) + +input_df = input_dataset.get_dataframe() +client = get_client(api_configuration_preset) +column_prefix = "medical_phi_api" + + +# ============================================================================== +# RUN +# ============================================================================== + + +@retry((RateLimitException, OSError), delay=api_quota_period, tries=5) +@limits(calls=api_quota_rate_limit, period=api_quota_period) +def call_api_medical_phi_extraction(row: Dict, text_column: AnyStr) -> Dict: + text = row[text_column] + if not isinstance(text, str) or str(text).strip() == "": + return "" + responses = client.detect_phi(Text=text) + return json.dumps(responses) + + +df = api_parallelizer( + input_df=input_df, + api_call_function=call_api_medical_phi_extraction, + api_exceptions=API_EXCEPTIONS, + text_column=text_column, + parallel_workers=parallel_workers, + error_handling=error_handling, + column_prefix=column_prefix, +) + +api_formatter = MedicalPhiAPIFormatter( + input_df=input_df, minimum_score=minimum_score, column_prefix=column_prefix, error_handling=error_handling, +) +output_df = api_formatter.format_df(df) + +output_dataset.write_with_schema(output_df) +set_column_description( + input_dataset=input_dataset, + output_dataset=output_dataset, + column_description_dict=api_formatter.column_description_dict, +) diff --git 
a/parameter-sets/api-configuration/parameter-set.json b/parameter-sets/api-configuration/parameter-set.json new file mode 100644 index 0000000..c98ac64 --- /dev/null +++ b/parameter-sets/api-configuration/parameter-set.json @@ -0,0 +1,78 @@ +{ + "meta": { + "label": "API configuration", + "description": "Define presets for users to call the API with specific credentials, quota and parallelization parameters", + "icon": "icon-amazon-comprehend icon-cloud" + }, + "defaultDefinableAtProjectLevel": true, + "defaultDefinableInline": false, + "params": [ + { + "name": "separator_authentication", + "label": "Authentication", + "description": "Please refer to the AWS documentation: https://boto3.amazonaws.com/v1/documentation/api/latest/guide/configuration.html", + "type": "SEPARATOR" + }, + { + "name": "aws_access_key", + "label": "AWS access key ID", + "description": "If empty, attempts to ascertain credentials from the environment.", + "type": "STRING", + "mandatory": false + }, + { + "name": "aws_secret_key", + "label": "AWS secret access key", + "description": "If empty, attempts to ascertain credentials from the environment.", + "type": "PASSWORD", + "mandatory": false + }, + { + "name": "aws_region", + "label": "AWS region", + "description": "If empty, attempts to ascertain region name from the environment.", + "type": "STRING", + "mandatory": false, + "defaultValue": "us-east-1" + }, + { + "name": "separator_api_quota", + "label": "API quota", + "type": "SEPARATOR", + "description": "Throttling to stay within the quota defined by AWS: https://docs.aws.amazon.com/comprehend/latest/dg/guidelines-and-limits-med.html" + }, + { + "name": "api_quota_period", + "label": "Period", + "description": "Reset period of the quota in seconds. Defined by AWS.", + "type": "INT", + "mandatory": true, + "defaultValue": 1, + "minI": 1 + }, + { + "name": "api_quota_rate_limit", + "label": "Rate limit", + "description": "Maximum number of requests per period for one DSS activity. Reduce for concurrent activities.", + "type": "INT", + "mandatory": true, + "defaultValue": 5, + "minI": 1 + }, + { + "name": "separator_performance", + "label": "Parallelization", + "type": "SEPARATOR" + }, + { + "name": "parallel_workers", + "label": "Concurrency", + "description": "Number of threads calling the API in parallel (maximum 100). 
Increase to speed-up computation within the quota defined above.", + "type": "INT", + "mandatory": true, + "defaultValue": 4, + "minI": 1, + "maxI": 100 + } + ] +} diff --git a/plugin.json b/plugin.json index bdd8ee9..e8cb3b3 100644 --- a/plugin.json +++ b/plugin.json @@ -8,7 +8,7 @@ "author": "Dataiku (Alex COMBESSIE)", "icon": "icon-amazon-comprehend icon-cloud", "licenseInfo": "Apache Software License", - "url": "https://www.dataiku.com/dss/plugins/info/amazon-comprehend-nlp.html", + "url": "https://www.dataiku.com/product/plugins/amazon-comprehend-nlp-medical/", "tags": [ "AWS", "Cloud", @@ -16,4 +16,4 @@ ], "supportLevel": "TIER2_SUPPORT" } -} +} \ No newline at end of file diff --git a/python-lib/amazon_comprehend_medical_api_client.py b/python-lib/amazon_comprehend_medical_api_client.py new file mode 100644 index 0000000..ac28718 --- /dev/null +++ b/python-lib/amazon_comprehend_medical_api_client.py @@ -0,0 +1,27 @@ +# -*- coding: utf-8 -*- +import logging + +import boto3 +from boto3.exceptions import Boto3Error +from botocore.exceptions import BotoCoreError, ClientError + +# ============================================================================== +# CONSTANT DEFINITION +# ============================================================================== + +API_EXCEPTIONS = (Boto3Error, BotoCoreError, ClientError) + +# ============================================================================== +# CLASS AND FUNCTION DEFINITION +# ============================================================================== + + +def get_client(api_configuration_preset): + client = boto3.client( + service_name="comprehendmedical", + aws_access_key_id=api_configuration_preset.get("aws_access_key"), + aws_secret_access_key=api_configuration_preset.get("aws_secret_key"), + region_name=api_configuration_preset.get("aws_region"), + ) + logging.info("Credentials loaded") + return client diff --git a/python-lib/amazon_comprehend_medical_api_formatting.py b/python-lib/amazon_comprehend_medical_api_formatting.py new file mode 100644 index 0000000..cb6afda --- /dev/null +++ b/python-lib/amazon_comprehend_medical_api_formatting.py @@ -0,0 +1,187 @@ +# -*- coding: utf-8 -*- +import logging +from typing import AnyStr, Dict, List +from enum import Enum + +import pandas as pd + +from plugin_io_utils import ( + API_COLUMN_NAMES_DESCRIPTION_DICT, + ErrorHandlingEnum, + build_unique_column_names, + generate_unique, + safe_json_loads, + move_api_columns_to_end, +) + + +# ============================================================================== +# CONSTANT DEFINITION +# ============================================================================== + + +class MedicalEntityTypeEnum(Enum): + ANATOMY = "Anatomy" + MEDICAL_CONDITION = "Medical condition" + MEDICATION = "Medication" + PROTECTED_HEALTH_INFORMATION = "Protected health information" + TEST_TREATMENT_PROCEDURE = "Test treatment procedure" + TIME_EXPRESSION = "Time expression" + + +class MedicalPHITypeEnum(Enum): + ADDRESS = "Address" + AGE = "Age" + DATE = "Date" + NAME = "Name" + PHONE_OR_FAX = "Phone or fax" + EMAIL = "Email" + ID = "ID" + + +# ============================================================================== +# CLASS AND FUNCTION DEFINITION +# ============================================================================== + + +class GenericAPIFormatter: + """ + Generic Formatter class for API responses: + - initialize with generic parameters + - compute generic column descriptions + - apply format_row to dataframe + """ + + def __init__(
self, + input_df: pd.DataFrame, + column_prefix: AnyStr = "api", + error_handling: ErrorHandlingEnum = ErrorHandlingEnum.LOG, + ): + self.input_df = input_df + self.column_prefix = column_prefix + self.error_handling = error_handling + self.api_column_names = build_unique_column_names(input_df, column_prefix) + self.column_description_dict = { + v: API_COLUMN_NAMES_DESCRIPTION_DICT[k] for k, v in self.api_column_names._asdict().items() + } + + def format_row(self, row: Dict) -> Dict: + return row + + def format_df(self, df: pd.DataFrame) -> pd.DataFrame: + logging.info("Formatting API results...") + df = df.apply(func=self.format_row, axis=1) + df = move_api_columns_to_end(df, self.api_column_names, self.error_handling) + logging.info("Formatting API results: Done.") + return df + + +class MedicalPhiAPIFormatter(GenericAPIFormatter): + """ + Formatter class for Protected Health Information API responses: + - make sure response is valid JSON + - expand results to multiple columns + - compute column descriptions + """ + + def __init__( + self, + input_df: pd.DataFrame, + minimum_score: float, + column_prefix: AnyStr = "medical_phi_api", + error_handling: ErrorHandlingEnum = ErrorHandlingEnum.LOG, + ): + super().__init__(input_df, column_prefix, error_handling) + self.minimum_score = float(minimum_score) + self._compute_column_description() + + def _compute_column_description(self): + for entity_enum in MedicalPHITypeEnum: + entity_type_column = generate_unique( + "entity_type_" + str(entity_enum.value).lower() + "_text", self.input_df.keys(), self.column_prefix, + ) + self.column_description_dict[entity_type_column] = "List of '{}' PHI entities extracted by the API".format( + str(entity_enum.value) + ) + + def format_row(self, row: Dict) -> Dict: + raw_response = row[self.api_column_names.response] + response = safe_json_loads(raw_response, self.error_handling) + entities = response.get("Entities", []) + discarded_entities = [ + e + for e in entities + if float(e.get("Score", 0)) < self.minimum_score + and e.get("Type", "") in [e.name for e in MedicalPHITypeEnum] + ] + if len(discarded_entities) != 0: + logging.info("Discarding {} entities below the minimum score threshold".format(len(discarded_entities))) + for entity_enum in MedicalPHITypeEnum: + entity_type_column = generate_unique( + "entity_type_" + str(entity_enum.value).lower() + "_text", row.keys(), self.column_prefix, + ) + row[entity_type_column] = [ + e.get("Text", "") + for e in entities + if e.get("Type", "") == entity_enum.name and float(e.get("Score", 0)) >= self.minimum_score + ] + if len(row[entity_type_column]) == 0: + row[entity_type_column] = "" + return row + + +class MedicalEntityAPIFormatter(GenericAPIFormatter): + """ + Formatter class for Medical Entity Recognition API responses: + - make sure response is valid JSON + - expand results to multiple columns + - compute column descriptions + """ + + def __init__( + self, + input_df: pd.DataFrame, + entity_types: List, + minimum_score: float, + column_prefix: AnyStr = "medical_entity_api", + error_handling: ErrorHandlingEnum = ErrorHandlingEnum.LOG, + ): + super().__init__(input_df, column_prefix, error_handling) + self.entity_types = entity_types + self.minimum_score = float(minimum_score) + self._compute_column_description() + + def _compute_column_description(self): + for entity_enum in MedicalEntityTypeEnum: + entity_type_column = generate_unique( + "entity_type_" + str(entity_enum.value).lower() + "_text", self.input_df.keys(), self.column_prefix, + ) +
self.column_description_dict[ + entity_type_column + ] = "List of '{}' medical entities extracted by the API".format(str(entity_enum.value)) + + def format_row(self, row: Dict) -> Dict: + raw_response = row[self.api_column_names.response] + response = safe_json_loads(raw_response, self.error_handling) + entities = response.get("Entities", []) + discarded_entities = [ + e + for e in entities + if float(e.get("Score", 0)) < self.minimum_score + and e.get("Category", "") in [e.name for e in MedicalEntityTypeEnum] + ] + if len(discarded_entities) != 0: + logging.info("Discarding {} entities below the minimum score threshold".format(len(discarded_entities))) + for entity_enum in MedicalEntityTypeEnum: + entity_type_column = generate_unique( + "entity_type_" + str(entity_enum.value).lower() + "_text", row.keys(), self.column_prefix, + ) + row[entity_type_column] = [ + e.get("Text", "") + for e in entities + if e.get("Category", "") == entity_enum.name and float(e.get("Score", 0)) >= self.minimum_score + ] + if len(row[entity_type_column]) == 0: + row[entity_type_column] = "" + return row diff --git a/python-lib/api_parallelizer.py b/python-lib/api_parallelizer.py new file mode 100644 index 0000000..81c15f4 --- /dev/null +++ b/python-lib/api_parallelizer.py @@ -0,0 +1,192 @@ +# -*- coding: utf-8 -*- +import logging +import inspect +import math + +from typing import Callable, AnyStr, List, Tuple, NamedTuple, Dict, Union +from concurrent.futures import ThreadPoolExecutor, as_completed + +import pandas as pd +from more_itertools import chunked, flatten +from tqdm.auto import tqdm as tqdm_auto + +from plugin_io_utils import ErrorHandlingEnum, build_unique_column_names + + +# ============================================================================== +# CONSTANT DEFINITION +# ============================================================================== + +DEFAULT_PARALLEL_WORKERS = 4 +DEFAULT_BATCH_SIZE = 10 +DEFAULT_API_SUPPORT_BATCH = False +DEFAULT_VERBOSE = False + + +# ============================================================================== +# CLASS AND FUNCTION DEFINITION +# ============================================================================== + + +def api_call_single_row( + api_call_function: Callable, + api_column_names: NamedTuple, + row: Dict, + api_exceptions: Union[Exception, Tuple[Exception]], + error_handling: ErrorHandlingEnum = ErrorHandlingEnum.LOG, + verbose: bool = DEFAULT_VERBOSE, + **api_call_function_kwargs +) -> Dict: + """ + Wraps a single-row API calling function to: + - ensure it has a 'row' parameter which is a dict + (for batches of rows, use the api_call_batch function below) + - return the row with a new 'response' key containing the function result + - handles errors from the function with two methods: + * (default) do not fail on API-related exceptions, just log it + and return the row with new error keys + * fail if there is an error and raise it + """ + if error_handling == ErrorHandlingEnum.FAIL: + response = api_call_function(row=row, **api_call_function_kwargs) + row[api_column_names.response] = response + else: + for k in api_column_names: + row[k] = "" + try: + response = api_call_function(row=row, **api_call_function_kwargs) + row[api_column_names.response] = response + except api_exceptions as e: + logging.warning(str(e)) + module = str(inspect.getmodule(e).__name__) + error_name = str(type(e).__qualname__) + row[api_column_names.error_message] = str(e) + row[api_column_names.error_type] = ".".join([module, error_name]) + 
row[api_column_names.error_raw] = str(e.args) + return row + + +def api_call_batch( + api_call_function: Callable, + api_column_names: NamedTuple, + batch: List[Dict], + batch_api_response_parser: Callable, + api_exceptions: Union[Exception, Tuple[Exception]], + error_handling: ErrorHandlingEnum = ErrorHandlingEnum.LOG, + verbose: bool = DEFAULT_VERBOSE, + **api_call_function_kwargs +) -> List[Dict]: + """ + Wraps a batch API calling function to: + - ensure it has a 'batch' parameter which is a list of dict + - return the batch with a new 'response' key in each dict + containing the function result + - handles errors from the function with two methods: + * (default) do not fail on API-related exceptions, just log it + and return the batch with new error keys in each dict (using batch_api_parser) + * fail if there is an error and raise it + """ + if error_handling == ErrorHandlingEnum.FAIL: + response = api_call_function(batch=batch, **api_call_function_kwargs) + batch = batch_api_response_parser(batch=batch, response=response, api_column_names=api_column_names) + errors = [row[api_column_names.error_message] for row in batch if row[api_column_names.error_message] != ""] + if len(errors) != 0: + raise Exception("API returned errors: " + str(errors)) + else: + try: + response = api_call_function(batch=batch, **api_call_function_kwargs) + batch = batch_api_response_parser(batch=batch, response=response, api_column_names=api_column_names) + except api_exceptions as e: + logging.warning(str(e)) + module = str(inspect.getmodule(e).__name__) + error_name = str(type(e).__qualname__) + for row in batch: + row[api_column_names.response] = "" + row[api_column_names.error_message] = str(e) + row[api_column_names.error_type] = ".".join([module, error_name]) + row[api_column_names.error_raw] = str(e.args) + return batch + + +def convert_api_results_to_df( + input_df: pd.DataFrame, + api_results: List[Dict], + api_column_names: NamedTuple, + error_handling: ErrorHandlingEnum = ErrorHandlingEnum.LOG, + verbose: bool = DEFAULT_VERBOSE, +) -> pd.DataFrame: + """ + Helper function to the "api_parallelizer" main function. + Combine API results (list of dict) with input dataframe, + and convert it to a dataframe. 
+ """ + if error_handling == ErrorHandlingEnum.FAIL: + columns_to_exclude = [v for k, v in api_column_names._asdict().items() if "error" in k] + else: + columns_to_exclude = [] + if not verbose: + columns_to_exclude = [api_column_names.error_raw] + output_schema = {**{v: str for v in api_column_names}, **dict(input_df.dtypes)} + output_schema = {k: v for k, v in output_schema.items() if k not in columns_to_exclude} + record_list = [{col: result.get(col) for col in output_schema.keys()} for result in api_results] + api_column_list = [c for c in api_column_names if c not in columns_to_exclude] + output_column_list = list(input_df.columns) + api_column_list + output_df = pd.DataFrame.from_records(record_list).astype(output_schema).reindex(columns=output_column_list) + assert len(output_df.index) == len(input_df.index) + return output_df + + +def api_parallelizer( + input_df: pd.DataFrame, + api_call_function: Callable, + api_exceptions: Union[Exception, Tuple[Exception]], + column_prefix: AnyStr, + parallel_workers: int = DEFAULT_PARALLEL_WORKERS, + api_support_batch: bool = DEFAULT_API_SUPPORT_BATCH, + batch_size: int = DEFAULT_BATCH_SIZE, + error_handling: ErrorHandlingEnum = ErrorHandlingEnum.LOG, + verbose: bool = DEFAULT_VERBOSE, + **api_call_function_kwargs +) -> pd.DataFrame: + """ + Apply an API call function in parallel to a pandas.DataFrame. + The DataFrame is passed to the function as row dictionaries. + Parallelism works by: + - (default) sending multiple concurrent threads + - if the API supports it, sending batches of row + """ + df_iterator = (i[1].to_dict() for i in input_df.iterrows()) + len_iterator = len(input_df.index) + log_msg = "Calling remote API endpoint with {} rows".format(len_iterator) + if api_support_batch: + log_msg += ", chunked by {}".format(batch_size) + df_iterator = chunked(df_iterator, batch_size) + len_iterator = math.ceil(len_iterator / batch_size) + logging.info(log_msg) + api_column_names = build_unique_column_names(input_df.columns, column_prefix) + pool_kwargs = api_call_function_kwargs.copy() + more_kwargs = [ + "api_call_function", + "error_handling", + "api_exceptions", + "api_column_names", + ] + for k in more_kwargs: + pool_kwargs[k] = locals()[k] + for k in ["fn", "row", "batch"]: # Reserved pool keyword arguments + pool_kwargs.pop(k, None) + api_results = [] + with ThreadPoolExecutor(max_workers=parallel_workers) as pool: + if api_support_batch: + futures = [pool.submit(api_call_batch, batch=batch, **pool_kwargs) for batch in df_iterator] + else: + futures = [pool.submit(api_call_single_row, row=row, **pool_kwargs) for row in df_iterator] + for f in tqdm_auto(as_completed(futures), total=len_iterator): + api_results.append(f.result()) + if api_support_batch: + api_results = flatten(api_results) + output_df = convert_api_results_to_df(input_df, api_results, api_column_names, error_handling, verbose) + num_api_error = sum(output_df[api_column_names.response] == "") + num_api_success = len(input_df.index) - num_api_error + logging.info("Remote API call results: {} rows succeeded, {} rows failed.".format(num_api_success, num_api_error)) + return output_df diff --git a/python-lib/plugin_io_utils.py b/python-lib/plugin_io_utils.py new file mode 100644 index 0000000..4530e8d --- /dev/null +++ b/python-lib/plugin_io_utils.py @@ -0,0 +1,136 @@ +# -*- coding: utf-8 -*- +import logging +import json +import pandas as pd +import dataiku + +from enum import Enum +from typing import AnyStr, List, NamedTuple, Dict +from collections import OrderedDict, 
namedtuple + + +# ============================================================================== +# CONSTANT DEFINITION +# ============================================================================== + +API_COLUMN_NAMES_DESCRIPTION_DICT = OrderedDict( + [ + ("response", "Raw response from the API in JSON format"), + ("error_message", "Error message from the API"), + ("error_type", "Error type (module and class name)"), + ("error_raw", "Raw error from the API"), + ] +) +COLUMN_PREFIX = "api" + + +# ============================================================================== +# CLASS AND FUNCTION DEFINITION +# ============================================================================== + +ApiColumnNameTuple = namedtuple("ApiColumnNameTuple", API_COLUMN_NAMES_DESCRIPTION_DICT.keys()) + + +class ErrorHandlingEnum(Enum): + LOG = "Log" + FAIL = "Fail" + + +def generate_unique(name: AnyStr, existing_names: List, prefix: AnyStr = COLUMN_PREFIX) -> AnyStr: + """ + Generate a unique name among existing ones by suffixing a number. + Can also add an optional prefix. + """ + if prefix is not None: + new_name = prefix + "_" + name + else: + new_name = name + for j in range(1, 1001): + if new_name not in existing_names: + return new_name + new_name = name + "_{}".format(j) + raise Exception("Failed to generated a unique name") + + +def build_unique_column_names(existing_names: List[AnyStr], column_prefix: AnyStr = COLUMN_PREFIX) -> NamedTuple: + """ + Helper function to the "api_parallelizer" main function. + Initializes a named tuple of column names from ApiColumnNameTuple, + adding a prefix and a number suffix to make them unique. + """ + api_column_names = ApiColumnNameTuple( + *[generate_unique(k, existing_names, column_prefix) for k in ApiColumnNameTuple._fields] + ) + return api_column_names + + +def safe_json_loads( + str_to_check: AnyStr, error_handling: ErrorHandlingEnum = ErrorHandlingEnum.LOG, verbose: bool = False, +) -> Dict: + """ + Wrap json.loads with an additional parameter to handle errors: + - 'FAIL' to use json.loads, which throws an exception on invalid data + - 'LOG' to try json.loads and return an empty dict if data is invalid + """ + if error_handling == ErrorHandlingEnum.FAIL: + output = json.loads(str_to_check) + else: + try: + output = json.loads(str_to_check) + except (TypeError, ValueError): + if verbose: + logging.warning("Invalid JSON: '" + str(str_to_check) + "'") + output = {} + return output + + +def validate_column_input(column_name: AnyStr, column_list: List[AnyStr]) -> None: + """ + Validate that user input for column parameter is valid. 
+ """ + if column_name is None or len(column_name) == 0: + raise ValueError("You must specify a valid column name.") + if column_name not in column_list: + raise ValueError("Column '{}' is not present in the input dataset.".format(column_name)) + + +def move_api_columns_to_end( + df: pd.DataFrame, api_column_names: NamedTuple, error_handling: ErrorHandlingEnum = ErrorHandlingEnum.LOG +) -> pd.DataFrame: + """ + Move non-human-readable API columns to the end of the dataframe + """ + api_column_names_dict = api_column_names._asdict() + if error_handling == ErrorHandlingEnum.FAIL: + api_column_names_dict.pop("error_message", None) + api_column_names_dict.pop("error_type", None) + if not any(["error_raw" in k for k in df.keys()]): + api_column_names_dict.pop("error_raw", None) + cols = [c for c in df.keys() if c not in api_column_names_dict.values()] + new_cols = cols + list(api_column_names_dict.values()) + df = df.reindex(columns=new_cols) + return df + + +def set_column_description( + input_dataset: dataiku.Dataset, output_dataset: dataiku.Dataset, column_description_dict: Dict, +) -> None: + """ + Set column descriptions of the output dataset based on a dictionary of column descriptions + and retains the column descriptions from the input dataset if the column name matches + """ + input_dataset_schema = input_dataset.read_schema() + output_dataset_schema = output_dataset.read_schema() + input_columns_names = [col["name"] for col in input_dataset_schema] + for output_col_info in output_dataset_schema: + output_col_name = output_col_info.get("name", "") + output_col_info["comment"] = column_description_dict.get(output_col_name) + if output_col_name in input_columns_names: + matched_comment = [ + input_col_info.get("comment", "") + for input_col_info in input_dataset_schema + if input_col_info.get("name") == output_col_name + ] + if len(matched_comment) != 0: + output_col_info["comment"] = matched_comment[0] + output_dataset.write_schema(output_dataset_schema)