Entity Extraction From FDA Label Data Using Amazon Comprehend Medical, Step Functions and AWS Lambda
In this article, we explore how cognitive services provided by AWS can be used to augment our workflows with AI capabilities with very little effort. We will use the Amazon Comprehend Medical service to automate entity extraction from drug label data for downstream visualization. The Food and Drug Administration (FDA) provides a wealth of large data sets via the OpenFDA website and APIs (https://open.fda.gov). For this workflow we will use one of these data sets in particular, the Drug Labels data set (https://open.fda.gov/apis/drug/label/).
Amazon Comprehend Medical (https://aws.amazon.com/comprehend/medical/) is a natural language processing (NLP) service that uses machine learning to extract key medical data from text. It provides a simple API to extract information such as medical conditions, dosages, treatments and procedures from medical text. The service is fully managed by AWS, and you pay based on your API usage.
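Before wiring the service into a pipeline, it helps to see the API in isolation. The snippet below is a minimal sketch of a direct call to Comprehend Medical with boto3; the sample sentence is made up for illustration, and AWS credentials and region configuration are assumed to already be in place.
import boto3
# create a Comprehend Medical client (assumes credentials and a default region are configured)
client = boto3.client('comprehendmedical')
sample_text = 'Patient was prescribed 40 mg of atorvastatin daily for hyperlipidemia.'
# detect_entities returns the entities found in the text along with category, type and confidence score
response = client.detect_entities(Text=sample_text)
for entity in response['Entities']:
    print(entity['Text'], entity['Category'], entity['Type'], round(entity['Score'], 3))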
Medical Entity Extraction Pipeline Technology Stack
Since the goal of this article is to demonstrate the usage of Comprehend Medical, we will assume that the drug label data has already been downloaded and made queryable using Amazon Athena. We will build a pipeline that queries the ‘indications and usage’ field of the data set, runs each record through Comprehend Medical to extract key medical terms, and stores the results in S3. The pipeline uses serverless technologies on the AWS platform, which avoids the need to provision EC2 instances, maintain and patch them, and handle scaling. It relies on Step Functions to orchestrate the steps in the workflow, Comprehend Medical as the NLP engine and S3 for storing the outputs.
The following are the primary AWS services and tools used in this pipeline:
- Comprehend Medical — NLP Service
- Amazon Athena — Serverless query service to query data in S3
- Step Functions — Workflow orchestration
- Serverless Framework — Framework to build serverless applications and provision resources
- Lambda — Serverless compute
- Simple Storage Service (S3) — Store query results and extracted entity output
- EventBridge — Serverless event bus
Architecture
The following diagram shows a high-level architecture of the pipeline to perform key entity extraction from medical data:
Detailed Steps
Workflow Orchestration Using Serverless Framework: The Serverless Framework (https://www.serverless.com/open-source/) provides a convenient way to develop, test and deploy serverless applications. It lets developers specify the various aspects of their application, such as Lambda functions, Step Functions state machines, resources, IAM roles and policies, in a single, easy-to-configure serverless.yaml file. It also provides a CLI that can deploy the entire application to the cloud with a single command, for example serverless deploy --stage dev --region us-east-1. An example serverless.yaml file for this use case is provided below for reference.
service: comprehend-service-test
frameworkVersion: ">=1.49.0"

provider:
  name: aws
  runtime: python3.6
  stage: dev
  region: ${opt:region, 'us-east-1'}
  memorySize: 3008
  timeout: 900
  reservedConcurrency: 1
  deploymentBucket:
    name: <deployment-bucket-name>
    deploymentPrefix: deploy-${self:service}
  environment:
    bucket_name: <project-bucket-name>
  iamRoleStatements:
    - Effect: "Allow"
      Action: "states:StartExecution"
      Resource: "arn:aws:states:#{AWS::Region}:#{AWS::AccountId}:stateMachine:comprehendStateMachine"
    - Effect: "Allow"
      Action: s3:*
      Resource:
        - "arn:aws:s3:::<project-bucket-name>/*"
        - "arn:aws:s3:::<project-bucket-name>"
    - Effect: "Allow"
      Action:
        - comprehendmedical:DetectEntities
        - comprehendmedical:DetectPHI
      Resource:
        - "*"
    - Effect: "Allow"
      Action:
        - athena:*
      Resource:
        - "*"
    - Effect: "Allow"
      Action:
        - glue:GetTable
        - glue:GetPartitions
        - glue:GetDatabase
      Resource:
        - "*"
    - Effect: Allow
      Action:
        - logs:CreateLogGroup
        - logs:CreateLogStream
        - logs:PutLogEvents
      Resource: "arn:aws:logs:#{AWS::Region}:#{AWS::AccountId}:log-group:/aws/lambda/*:*:*"

plugins:
  - serverless-step-functions
  - serverless-pseudo-parameters

package:
  exclude:
    - __pycache__/**
    - node_modules/**
    - tests/**
    - env/**
    - package.json
    - package-lock.json
    - .vscode

functions:
  proxy:
    handler: proxy.handler
    environment:
      statemachine_arn: "arn:aws:states:#{AWS::Region}:#{AWS::AccountId}:stateMachine:comprehendStateMachine"
  queryAthena:
    handler: query_athena.handler
  configureCount:
    handler: configure_count.handler
    environment:
      comprehend_chunksize: 500
    layers:
      - arn:aws:lambda:us-east-1:<account_number>:layer:AWSLambda-Python36-SciPy1x:2 # numpy layer
      - arn:aws:lambda:us-east-1:<account_number>:layer:pandas-xlrd-layer-Python36-Pandas23x:3 # pandas layer
  iterator:
    handler: iterator.handler
  callComprehend:
    handler: call_comprehend.handler
    environment:
      comprehend_output_bucket_name: djm-landing
    layers:
      - arn:aws:lambda:us-east-1:<account_number>:layer:boto3layer:3 # boto3 layer
      - arn:aws:lambda:us-east-1:<account_number>:layer:AWSLambda-Python36-SciPy1x:2 # numpy layer
      - arn:aws:lambda:us-east-1:<account_number>:layer:pandas-xlrd-layer-Python36-Pandas23x:3 # pandas layer

stepFunctions:
  stateMachines:
    comprehendStateMachine:
      name: comprehendStateMachine
      definition:
        Comment: "Comprehend Tutorial state machine"
        StartAt: QueryAthena
        States:
          QueryAthena:
            Comment: "Queries Athena for indications"
            Type: Task
            Resource: "arn:aws:lambda:#{AWS::Region}:#{AWS::AccountId}:function:${self:service}-${opt:stage}-queryAthena"
            ResultPath: "$.data_node.query_results"
            Next: ConfigureCount
          ConfigureCount:
            Type: Task
            Resource: "arn:aws:lambda:#{AWS::Region}:#{AWS::AccountId}:function:${self:service}-${opt:stage}-configureCount"
            ResultPath: "$.iterator"
            Next: Iterator
          Iterator:
            Type: Task
            Resource: "arn:aws:lambda:#{AWS::Region}:#{AWS::AccountId}:function:${self:service}-${opt:stage}-iterator"
            ResultPath: "$.iterator"
            Next: IsRecordCountReached
          IsRecordCountReached:
            Type: Choice
            Choices:
              - Variable: "$.iterator.continue_iterating"
                BooleanEquals: True
                Next: CallComprehend
              - Variable: "$.iterator.continue_iterating"
                BooleanEquals: False
                Next: IterationComplete
          CallComprehend:
            Comment: "Calls comprehend to extract entities"
            Type: Task
            Resource: "arn:aws:lambda:#{AWS::Region}:#{AWS::AccountId}:function:${self:service}-${opt:stage}-callComprehend"
            ResultPath: "$.iterator"
            Next: Iterator
          IterationComplete:
            Type: Pass
            Next: Done
          Done:
            Type: Pass
            End: true

resources:
  Outputs:
    ComprehendStateMachine:
      Description: ARN of ComprehendStateMachine
      Value:
        Ref: ComprehendStateMachine
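The serverless.yaml above also declares a proxy function that receives the state machine ARN through its environment, but its code is not shown in this article. Presumably its job is simply to start an execution of the state machine when invoked; a minimal sketch under that assumption could look like the following (the input payload here is a placeholder, not the project's actual payload).
import json
import os
import boto3
sfn = boto3.client('stepfunctions')
def handler(event, context):
    # start an execution of the comprehend state machine; the ARN is injected via serverless.yaml
    response = sfn.start_execution(
        stateMachineArn=os.environ['statemachine_arn'],
        input=json.dumps({'data_node': {}})  # placeholder input; the real payload depends on the trigger
    )
    return {'executionArn': response['executionArn']}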
Trigger Pipeline Execution: The medical entity extraction state machine is triggered by an EventBridge rule on a set cadence (daily, weekly or monthly). The rule can instead be configured to trigger on events, such as the arrival of new data, rather than on a schedule.
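The schedule itself can be declared in serverless.yaml or created separately. As a rough sketch (the rule name, target id and IAM role ARN below are placeholders, not values from this project), an EventBridge rule that starts the state machine directly once a day could be created with boto3 as follows; the role referenced by RoleArn must allow states:StartExecution on the state machine, and the rule could alternatively invoke the proxy function shown earlier.
import boto3
events = boto3.client('events')
state_machine_arn = 'arn:aws:states:us-east-1:<account_number>:stateMachine:comprehendStateMachine'
# create (or update) a scheduled rule that fires once a day
events.put_rule(
    Name='comprehend-pipeline-daily',
    ScheduleExpression='rate(1 day)',
    State='ENABLED'
)
# point the rule at the state machine so each firing starts a new execution
events.put_targets(
    Rule='comprehend-pipeline-daily',
    Targets=[{
        'Id': 'comprehendStateMachine',
        'Arn': state_machine_arn,
        'RoleArn': 'arn:aws:iam::<account_number>:role/<events-invoke-stepfunctions-role>'
    }]
)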
Query Athena: This step assumes that the FDA drug label data has already been downloaded, processed and made queryable through Athena. The following Lambda function queries the ‘indications and usage’ field for the years 2017 through 2019. The query result is stored in an S3 bucket, and the location of the result file is passed on to the next step in the state machine. It is good practice to pass the S3 location of data files between steps rather than the data itself, because Step Functions limits the size of the payload that can be passed between states; the execution fails if the payload exceeds that limit.
import json
import boto3
import time
import urllib
import os
import csv
import logging

logger = logging.getLogger()
log_handler = logger.handlers[0]
log_handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s", "%Y-%m-%d %H:%M:%S"))
logger.setLevel(logging.INFO)

s3 = boto3.resource('s3')
athena = boto3.client('athena')

# setup configuration
S3_OUTPUT = 's3://<project-bucket-name>/athenaoutput' # replace bucket name
DATABASE = 'fda'
RETRY_COUNT = 10

# query all indications from the years 2017 through 2019
query_all_indications = 'SELECT id, indications_and_usage FROM "fda"."fda_curated" where year >= \'2017\' and year <= \'2019\''

def handler(event, context):
    if event:
        logger.info(">> querying athena database")
        data_node = {}
        data_node['bucket_name'] = os.environ['bucket_name']

        # run query on athena
        response = athena.start_query_execution(
            QueryString = query_all_indications,
            QueryExecutionContext = {
                'Database': DATABASE
            },
            ResultConfiguration = {
                'OutputLocation': S3_OUTPUT,
                'EncryptionConfiguration': {
                    'EncryptionOption': 'SSE_S3'
                }
            }
        )
        query_execution_id = response['QueryExecutionId']
        logger.info(query_execution_id)

        # poll the query execution status to check if the query has completed execution
        for i in range(1, 1 + RETRY_COUNT):
            query_status = athena.get_query_execution(QueryExecutionId = query_execution_id)
            query_execution_status = query_status['QueryExecution']['Status']['State']
            if query_execution_status == 'SUCCEEDED':
                logger.info("STATUS:" + query_execution_status)
                break
            if query_execution_status == 'FAILED':
                logger.info(query_status)
                raise Exception("STATUS:" + query_execution_status)
            else:
                logger.info("STATUS:" + query_execution_status)
                time.sleep(5)
        else:
            athena.stop_query_execution(QueryExecutionId = query_execution_id)
            raise Exception('TIME OVER')

        result = athena.get_query_results(QueryExecutionId = query_execution_id)
        logger.info(result)

        # athena stores the result of the query in the athena output bucket specified earlier.
        # send the location of the athena result file to the next lambda function in the step function
        results_file = "athenaoutput/" + query_execution_id + '.csv'
        query_results = {}
        query_results['results_file'] = results_file
        query_results['query_execution_id'] = query_execution_id
        return query_results
Call Comprehend Medical: The state machine uses an iteration pattern to work through all the records in the query results. This is due to the 15-minute limit on a single Lambda execution: if calling Comprehend Medical for every record took longer than 15 minutes, the workflow would fail. The iterator pattern therefore processes the data in smaller chunks, each of which can be expected to complete well within the limit. The following code snippet shows the function that invokes Comprehend Medical and performs the medical entity extraction; a sketch of the supporting bookkeeping functions follows it.
import json
import boto3
import time
import urllib
import os
import csv
import pandas as pd
import numpy as np
import logging

logger = logging.getLogger()
log_handler = logger.handlers[0]
log_handler.setFormatter(logging.Formatter("[%(asctime)s] %(levelname)s:%(name)s:%(message)s", "%Y-%m-%d %H:%M:%S"))
logger.setLevel(logging.INFO)

s3 = boto3.resource('s3')
client = boto3.client(service_name='comprehendmedical', region_name='us-east-1')

def handler(event, context):
    if event:
        logger.info(">> calling comprehend medical to extract entities")

        # get the location of the athena results file from the event
        bucket_name = os.environ['bucket_name']
        comprehend_output_bucket_name = os.environ['comprehend_output_bucket_name']
        file_keyname = event['data_node']['query_results']['results_file']
        print(">>> bucket name " + bucket_name)
        print(">>> comprehend output bucket name " + comprehend_output_bucket_name)
        print(">>> comprehend output file name " + file_keyname)
        results_file = event['data_node']['query_results']['results_file']
        # query_execution_id = event['data_node']['query_results']['query_execution_id']

        # read the iteration parameters from the event
        startIndex = event['iterator']['comprehendedItems']
        chunksize = event['iterator']['comprehend_chunksize']
        totalItemsToComprehend = event['iterator']['totalItemsToComprehend']
        iteration_num = event['iterator']['iteration_num']

        if ((startIndex + chunksize) <= totalItemsToComprehend):
            finalIndex = startIndex + chunksize
        else:
            finalIndex = totalItemsToComprehend

        # read the csv file into a dataframe and pull out a subset (startIndex - finalIndex)
        data = pd.read_csv(read_file_from_s3(bucket_name, file_keyname))
        data_subset = data[startIndex:finalIndex]
        data_subset = data_subset.fillna('')

        # iterate through the pandas dataframe subset and call comprehend medical
        data_to_persist = {}
        datalist = []
        for row in data_subset.itertuples():
            dataitem = {}
            dataitem['id'] = row.id
            print(row.id)
            text_list = []
            # if the indications_and_usage field is empty -> replace the corresponding fields in the dataitem with NA.
            # also, we do not run the entity extraction on these empty texts
            temptext = row.indications_and_usage
            temptext = temptext[0:20000] # use the first 20000 characters. DetectEntities operation has a size limit of 20000
            print(row.indications_and_usage)
            if temptext:
                result = client.detect_entities(Text = temptext)
                entities = result['Entities']
                for entity in entities:
                    text_list.append(entity['Text'])
                dataitem['extracted_text'] = text_list
            else:
                dataitem['indications_and_usage'] = 'NA'
                dataitem['extracted_text'] = 'NA'
            datalist.append(dataitem)
        data_to_persist['datalist'] = datalist
        print(data_to_persist['datalist'])

        # write out the enriched data to local lambda storage as a json
        with open('/tmp/comprehended.json', 'w') as outfile:
            json.dump(data_to_persist['datalist'], outfile)

        # setup the S3 key to write the comprehended data
        comprehend_output = 'fda-product-indications/comprehendoutput/comprehended-' + str(iteration_num) + '.json'

        # write to s3
        s3.meta.client.upload_file('/tmp/comprehended.json', Bucket = comprehend_output_bucket_name, Key = comprehend_output, ExtraArgs={'ServerSideEncryption':'AES256'})

        # do iteration housekeeping
        prev_comprehended_items = int(event['iterator']['comprehendedItems'])
        event['iterator']['comprehendedItems'] = len(data_subset) + prev_comprehended_items
        event['iterator']['iteration_num'] = iteration_num + 1

        # send the updated iterator state to the next step in the state machine
        return event['iterator']

def read_file_from_s3(bucket_name, key):
    print(">> reading S3 object...")
    response = s3.Object(bucket_name, key).get()
    return response['Body']
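The configure_count and iterator functions referenced by the state machine are not listed in this article. Based on the fields that call_comprehend reads from the event (comprehendedItems, comprehend_chunksize, totalItemsToComprehend, iteration_num, continue_iterating), a minimal sketch of each might look like the following; treat these as illustrations rather than the original implementations.
# configure_count.py -- a sketch that seeds the iterator state from the Athena result file
import os
import boto3
import pandas as pd

s3 = boto3.resource('s3')

def handler(event, context):
    bucket_name = os.environ['bucket_name']
    results_file = event['data_node']['query_results']['results_file']
    # count the records returned by the Athena query
    data = pd.read_csv(s3.Object(bucket_name, results_file).get()['Body'])
    return {
        'comprehendedItems': 0,                                            # records processed so far
        'comprehend_chunksize': int(os.environ['comprehend_chunksize']),   # records per Lambda invocation
        'totalItemsToComprehend': len(data),                               # total records to process
        'iteration_num': 0,
        'continue_iterating': True
    }
The iterator step then only has to decide whether another chunk remains, so the Choice state can either loop back to CallComprehend or finish:
# iterator.py -- a sketch of the loop-control step feeding the Choice state
def handler(event, context):
    iterator = event['iterator']
    iterator['continue_iterating'] = iterator['comprehendedItems'] < iterator['totalItemsToComprehend']
    return iterator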
Once the key entities are extracted and stored in S3, we can build dashboards with QuickSight to visualize the data and gain further insights. This workflow is a common use case for NLP on text data, and it shows how easily we can build a pipeline that embeds NLP capabilities into our applications without having to train any models.