Ajay Karthick Senthil Kumar committed on
Commit
5bd622e
·
1 Parent(s): eb57aa2
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: NER Medical Text
3
- emoji: 🐨
4
- colorFrom: blue
5
- colorTo: blue
6
- sdk: streamlit
7
- sdk_version: 1.39.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
  title: NER Medical Text
3
+ emoji: 🐢
4
+ colorFrom: red
5
+ colorTo: indigo
6
+ sdk: streamlit
7
+ sdk_version: 1.39.0
8
  app_file: app.py
9
  pinned: false
10
  ---
__pycache__/config.cpython-39.pyc ADDED
Binary file (2.68 kB). View file
 
__pycache__/metrics.cpython-39.pyc ADDED
Binary file (964 Bytes). View file
 
__pycache__/utils.cpython-39.pyc ADDED
Binary file (3.67 kB). View file
 
app.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import tensorflow as tf
3
+ import os
4
+
5
+ # Import your utility functions
6
+ from utils import (
7
+ predict_multi_line_text,
8
+ tokenizer,
9
+ )
10
+
11
+ from config import index_to_label, acronyms_to_entities, MAX_LENGTH
12
+ from metrics import precision, recall, f1_score
13
+
14
+ # Register the custom metric functions
15
+ tf.keras.utils.get_custom_objects()[precision.__name__] = precision
16
+ tf.keras.utils.get_custom_objects()[recall.__name__] = recall
17
+ tf.keras.utils.get_custom_objects()[f1_score.__name__] = f1_score
18
+
19
+ # Load your trained model
20
+ model_dir = './model' # Adjust the path as needed
21
+ model_1 = tf.keras.models.load_model(os.path.join(model_dir, 'model_1.h5'))
22
+
23
+ # Define label colors for different entity types suitable for dark background
24
+ LABEL_COLORS = {
25
+ 'Activity': '#FF7F50', # Coral
26
+ 'Administration': '#6495ED', # Cornflower Blue
27
+ 'Age': '#FFB6C1', # Light Pink
28
+ 'Area': '#7FFF00', # Chartreuse
29
+ 'Biological_attribute': '#FFD700', # Gold
30
+ 'Biological_structure': '#00FA9A', # Medium Spring Green
31
+ 'Clinical_event': '#BA55D3', # Medium Orchid
32
+ 'Color': '#00CED1', # Dark Turquoise
33
+ 'Coreference': '#FFA07A', # Light Salmon
34
+ 'Date': '#ADFF2F', # Green Yellow
35
+ 'Detailed_description': '#DA70D6', # Orchid
36
+ 'Diagnostic_procedure': '#87CEFA', # Light Sky Blue
37
+ 'Disease_disorder': '#FF4500', # Orange Red
38
+ 'Distance': '#32CD32', # Lime Green
39
+ 'Dosage': '#8A2BE2', # Blue Violet
40
+ 'Duration': '#F08080', # Light Coral
41
+ 'Family_history': '#20B2AA', # Light Sea Green
42
+ 'Frequency': '#FF6347', # Tomato
43
+ 'Height': '#4682B4', # Steel Blue
44
+ 'History': '#EE82EE', # Violet
45
+ 'Lab_value': '#FFDAB9', # Peach Puff
46
+ 'Mass': '#7B68EE', # Medium Slate Blue
47
+ 'Medication': '#00FF7F', # Spring Green
48
+ 'Nonbiological_location': '#FF69B4', # Hot Pink
49
+ 'Occupation': '#BDB76B', # Dark Khaki
50
+ 'Other_entity': '#D3D3D3', # Light Grey
51
+ 'Other_event': '#FF1493', # Deep Pink
52
+ 'Outcome': '#00BFFF', # Deep Sky Blue
53
+ 'Personal_background': '#00FFFF', # Aqua
54
+ 'Qualitative_concept': '#FFA500', # Orange
55
+ 'Quantitative_concept': '#FFA500', # Orange (same as above)
56
+ 'Severity': '#1E90FF', # Dodger Blue
57
+ 'Sex': '#FF00FF', # Magenta
58
+ 'Shape': '#40E0D0', # Turquoise
59
+ 'Sign_symptom': '#FFFF00', # Yellow
60
+ 'Subject': '#F0E68C', # Khaki
61
+ 'Texture': '#98FB98', # Pale Green
62
+ 'Therapeutic_procedure': '#8B008B', # Dark Magenta
63
+ 'Time': '#DC143C', # Crimson
64
+ 'Volume': '#5F9EA0', # Cadet Blue
65
+ 'Weight': '#FA8072', # Salmon
66
+ }
67
+
68
+ # Define the prediction function
69
+ def predict_ner(text):
70
+ try:
71
+ # Predict entities
72
+ entities = predict_multi_line_text(
73
+ text,
74
+ model_1,
75
+ index_to_label,
76
+ acronyms_to_entities,
77
+ MAX_LENGTH
78
+ )
79
+
80
+ # Sort entities by their start position
81
+ entities = sorted(entities, key=lambda x: x[0])
82
+
83
+ # Build HTML string with highlighted entities
84
+ html_output = ""
85
+ last_idx = 0
86
+
87
+ for start, end, label in entities:
88
+ # Append text before the entity
89
+ if last_idx < start:
90
+ html_output += text[last_idx:start]
91
+
92
+ # Get the color for the label, default to light grey if not specified
93
+ color = LABEL_COLORS.get(label, '#D3D3D3') # Light grey
94
+
95
+ # Wrap the entity with a span tag including style
96
+ entity_text = text[start:end]
97
+ # Include the label next to the entity
98
+ html_output += f'''<span style="background-color: {color}; font-weight: bold; padding: 2px; border-radius: 4px; margin: 1px;">{entity_text} <span style="font-size: smaller; font-weight: normal;">[{label}]</span></span>'''
99
+
100
+ last_idx = end
101
+
102
+ # Append any remaining text
103
+ if last_idx < len(text):
104
+ html_output += text[last_idx:]
105
+
106
+ return html_output
107
+
108
+ except Exception as e:
109
+ return f"<p style='color:red;'>Error: {str(e)}</p>"
110
+
111
+ # Set up the Streamlit app with dark theme
112
+ st.set_page_config(page_title="Medical NER", page_icon="🩺", layout="wide")
113
+
114
+ # Apply custom CSS for dark background and text colors
115
+ st.markdown(
116
+ """
117
+ <style>
118
+ /* Main app background */
119
+ .stApp {
120
+ background-color: #2E2E2E;
121
+ color: #FFFFFF;
122
+ }
123
+ /* Text input area */
124
+ .stTextArea textarea {
125
+ background-color: #1E1E1E;
126
+ color: #FFFFFF;
127
+ }
128
+ /* Adjust the Analyze button */
129
+ div.stButton > button:first-child {
130
+ background-color: #1E90FF;
131
+ color: #FFFFFF;
132
+ }
133
+ /* Scrollbar styling */
134
+ ::-webkit-scrollbar {
135
+ width: 10px;
136
+ }
137
+ ::-webkit-scrollbar-track {
138
+ background: #1E1E1E;
139
+ }
140
+ ::-webkit-scrollbar-thumb {
141
+ background: #888;
142
+ }
143
+ ::-webkit-scrollbar-thumb:hover {
144
+ background: #555;
145
+ }
146
+ /* Style for the highlighted entities */
147
+ .highlighted-entity {
148
+ padding: 2px;
149
+ border-radius: 4px;
150
+ margin: 1px;
151
+ font-weight: bold;
152
+ display: inline-block;
153
+ }
154
+ </style>
155
+ """,
156
+ unsafe_allow_html=True
157
+ )
158
+
159
+ st.title("🩺 Medical Named Entity Recognition")
160
+ st.markdown("""
161
+ Enter medical text below to identify and highlight entities such as diseases, medications, and anatomical terms.
162
+ """)
163
+
164
+ # Input text area
165
+ text_input = st.text_area("Enter medical text here:", height=200)
166
+
167
+ # Analyze button
168
+ if st.button("Analyze"):
169
+ if text_input.strip():
170
+ with st.spinner("Analyzing..."):
171
+ result = predict_ner(text_input)
172
+ # Display the result with HTML rendering
173
+ st.markdown(f"<div style='font-size: 18px;'>{result}</div>", unsafe_allow_html=True)
174
+ else:
175
+ st.warning("Please enter some text to analyze.")
config.py ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Map each entity type name to the 3-letter acronym used in the model's
# BIO tags (e.g. 'B-MED', 'I-MED').
entity_to_acronyms = {
    'Activity': 'ACT',
    'Administration': 'ADM',
    'Age': 'AGE',
    'Area': 'ARA',
    'Biological_attribute': 'BAT',
    'Biological_structure': 'BST',
    'Clinical_event': 'CLE',
    'Color': 'COL',
    'Coreference': 'COR',
    'Date': 'DAT',
    'Detailed_description': 'DET',
    'Diagnostic_procedure': 'DIA',
    'Disease_disorder': 'DIS',
    # NOTE(review): 'DIS' is also used for Disease_disorder above, and 'QUC'
    # below is shared by Qualitative_concept/Quantitative_concept. Because
    # acronyms_to_entities inverts this dict, the later key wins: 'DIS'
    # resolves to 'Distance' and 'QUC' to 'Quantitative_concept', so
    # Disease_disorder predictions will be displayed as 'Distance'.
    # Confirm against the label set the model was trained with.
    'Distance': 'DIS',
    'Dosage': 'DOS',
    'Duration': 'DUR',
    'Family_history': 'FAM',
    'Frequency': 'FRE',
    'Height': 'HEI',
    'History': 'HIS',
    'Lab_value': 'LAB',
    'Mass': 'MAS',
    'Medication': 'MED',
    'Nonbiological_location': 'NBL',
    'Occupation': 'OCC',
    'Other_entity': 'OTH',
    'Other_event': 'OTE',
    'Outcome': 'OUT',
    'Personal_background': 'PER',
    'Qualitative_concept': 'QUC',
    'Quantitative_concept': 'QUC',
    'Severity': 'SEV',
    'Sex': 'SEX',
    'Shape': 'SHA',
    'Sign_symptom': 'SIG',
    'Subject': 'SUB',
    'Texture': 'TEX',
    'Therapeutic_procedure': 'THP',
    'Time': 'TIM',
    'Volume': 'VOL',
    'Weight': 'WEI'
}

# Model output index -> BIO tag. Index 0 is the padding class and 78 is
# 'O' (outside any entity); 1-39 are B- tags, 40-77 are I- tags.
# NOTE(review): not every B- tag has a matching I- tag (e.g. there is no
# 'I-SEX') — presumably those entities are always single-token; confirm
# against the training label set.
index_to_label = {1: 'B-ACT',
                  2: 'B-ADM',
                  3: 'B-AGE',
                  4: 'B-ARA',
                  5: 'B-BAT',
                  6: 'B-BST',
                  7: 'B-CLE',
                  8: 'B-COL',
                  9: 'B-COR',
                  10: 'B-DAT',
                  11: 'B-DET',
                  12: 'B-DIA',
                  13: 'B-DIS',
                  14: 'B-DOS',
                  15: 'B-DUR',
                  16: 'B-FAM',
                  17: 'B-FRE',
                  18: 'B-HEI',
                  19: 'B-HIS',
                  20: 'B-LAB',
                  21: 'B-MAS',
                  22: 'B-MED',
                  23: 'B-NBL',
                  24: 'B-OCC',
                  25: 'B-OTE',
                  26: 'B-OTH',
                  27: 'B-OUT',
                  28: 'B-PER',
                  29: 'B-QUC',
                  30: 'B-SEV',
                  31: 'B-SEX',
                  32: 'B-SHA',
                  33: 'B-SIG',
                  34: 'B-SUB',
                  35: 'B-TEX',
                  36: 'B-THP',
                  37: 'B-TIM',
                  38: 'B-VOL',
                  39: 'B-WEI',
                  40: 'I-ACT',
                  41: 'I-ADM',
                  42: 'I-AGE',
                  43: 'I-ARA',
                  44: 'I-BAT',
                  45: 'I-BST',
                  46: 'I-CLE',
                  47: 'I-COL',
                  48: 'I-COR',
                  49: 'I-DAT',
                  50: 'I-DET',
                  51: 'I-DIA',
                  52: 'I-DIS',
                  53: 'I-DOS',
                  54: 'I-DUR',
                  55: 'I-FAM',
                  56: 'I-FRE',
                  57: 'I-HEI',
                  58: 'I-HIS',
                  59: 'I-LAB',
                  60: 'I-MAS',
                  61: 'I-MED',
                  62: 'I-NBL',
                  63: 'I-OCC',
                  64: 'I-OTE',
                  65: 'I-OTH',
                  66: 'I-OUT',
                  67: 'I-PER',
                  68: 'I-QUC',
                  69: 'I-SEV',
                  70: 'I-SHA',
                  71: 'I-SIG',
                  72: 'I-SUB',
                  73: 'I-TEX',
                  74: 'I-THP',
                  75: 'I-TIM',
                  76: 'I-VOL',
                  77: 'I-WEI',
                  78: 'O',
                  0: '<PAD>'}

# Maximum token-sequence length fed to the model; pad_sequences pads or
# truncates each sentence to this length.
MAX_LENGTH = 100

# Reverse lookup: acronym -> entity name. Later duplicate values in
# entity_to_acronyms win here (see NOTE above on 'DIS'/'QUC').
acronyms_to_entities = {v: k for k, v in entity_to_acronyms.items()}


# Registry of available models and their display titles.
models = {
    "model_1": {
        "path": "model/model_1.h5",
        "title": "Bidirectional LSTM Model with single LSTM layer"
    },
}
data/tokenizer.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f6c875d9180a973c5f297a0e332404c3b79816ace97a91ae39885d20440258a
3
+ size 277589
metrics.py ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from keras import backend as K

def precision(y_true, y_pred):
    """Compute precision metric: true positives / predicted positives.

    Operates element-wise on (rounded, clipped) one-hot tensors; K.epsilon()
    guards against division by zero when nothing is predicted positive.
    """
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    return true_positives / (predicted_positives + K.epsilon())

def recall(y_true, y_pred):
    """Compute recall metric: true positives / actual positives."""
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())

def f1_score(y_true, y_pred):
    """Compute f1-score metric (harmonic mean of precision and recall)."""
    _precision = precision(y_true, y_pred)
    _recall = recall(y_true, y_pred)
    # Named `f1` to avoid shadowing this function's own name.
    f1 = 2 * ((_precision * _recall) / (_precision + _recall + K.epsilon()))
    return f1
model/model_1.h5 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f98117d817b0603a5d8daaaeaca952d6fd6e427bb45167b3602bde1ccbf6823
3
+ size 19859592
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ tensorflow==2.13.0
2
+ numpy>=1.21.0,<1.24.0
3
+ pandas>=1.3.0,<1.5.0
4
+ streamlit
5
+ nltk>=3.6.0
6
+ pickle-mixin
utils.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import re
import os
import pickle
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
import nltk
# Fetch NLTK data at import time (no-op if already cached locally).
nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords

# English stopwords used by clean_word() to drop uninformative tokens.
STOP_WORDS = stopwords.words('english')

# Load the tokenizer from file
# NOTE(review): pickle.load is only safe because this artifact ships with
# the app — never unpickle untrusted data.
with open('./data/tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)
16
+
17
def clean_word(word):
    """
    Normalize a single token: strip non-alphanumeric characters, collapse
    runs of whitespace, lowercase, and drop English stopwords.

    Args:
    - word (str): the word to clean

    Returns:
    - str: the cleaned word, or an empty string if it is a stopword
    """
    # Strip punctuation/symbols, then collapse whitespace, then lowercase.
    cleaned = re.sub(r'\s+', ' ', re.sub(r'[^\w\s]', '', word)).lower()

    # Stopwords contribute nothing to NER — signal removal with ''.
    return '' if cleaned in STOP_WORDS else cleaned
39
+
40
def tokenize_text(text):
    """
    Tokenizes a text into a list of cleaned words, sentence by sentence.

    Args:
    - text (str): the text to tokenize

    Returns:
    - tokens (list of list of str): cleaned words, grouped per sentence
    - start_end_ranges (list of list of tuples): (start, end) character
      positions of each token in *text*, grouped the same way
    """
    regex_match = r'[^\s\u200a\-\u2010-\u2015\u2212\uff0d]+'  # Regex to match words
    tokens = []
    start_end_ranges = []
    # Tokenize the sentences in the text
    sentences = nltk.sent_tokenize(text)

    # Hoisted: was recomputed inside the inner loop for every word.
    lowered = text.lower()

    start = 0
    for sentence in sentences:

        sentence_tokens = re.findall(regex_match, sentence)
        curr_sent_tokens = []
        curr_sent_ranges = []

        for word in sentence_tokens:
            word = clean_word(word)
            if word.strip():
                idx = lowered.find(word, start)
                if idx == -1:
                    # The cleaned word is not a verbatim substring (e.g.
                    # punctuation was stripped mid-token). Fall back to the
                    # current cursor instead of -1 so subsequent offsets
                    # are not corrupted.
                    idx = start
                start = idx
                end = start + len(word)
                curr_sent_ranges.append((start, end))
                curr_sent_tokens.append(word)
                start = end
        if len(curr_sent_tokens) > 0:
            tokens.append(curr_sent_tokens)
            start_end_ranges.append(curr_sent_ranges)

    return tokens, start_end_ranges
77
+
78
def predict_multi_line_text(text, model, index_to_label, acronyms_to_entities, MAX_LENGTH):
    """
    Predicts named entities for multi-line input text.

    Args:
    - text (str): The input text
    - model: The trained NER model
    - index_to_label: Dictionary mapping index to label
    - acronyms_to_entities: Dictionary mapping acronyms to entity names
    - MAX_LENGTH: Maximum input length for the model

    Returns:
    - entities: A list of named entities in the format (start, end, label)
    """

    sent_tokens, sent_start_end = tokenize_text(text)

    # One batched call (one sequence per sentence) instead of a Python loop
    # of single-item texts_to_sequences calls.
    sequences = tokenizer.texts_to_sequences(
        [' '.join(tokens) for tokens in sent_tokens]
    )

    padded_sequence = pad_sequences(sequences, maxlen=MAX_LENGTH, padding='post')

    # Make the prediction
    prediction = model.predict(np.array(padded_sequence))

    # Decode the prediction: most likely tag index per token position
    predicted_indices = np.argmax(prediction, axis=-1)

    predicted_labels = [
        [index_to_label[i] for i in sent_indices]
        for sent_indices in predicted_indices
    ]

    entities = []
    for sent_pred_labels, start_end_ranges in zip(predicted_labels, sent_start_end):
        # zip is bounded by the shorter of labels/ranges, so sentences longer
        # than MAX_LENGTH only yield entities for the predicted prefix.
        for label, (start, end) in zip(sent_pred_labels, start_end_ranges):
            if label not in ('O', '<PAD>'):
                # label is e.g. 'B-DIS'/'I-DIS': strip the BIO prefix and map
                # the acronym back to the full entity name.
                entity_type = acronyms_to_entities[label[2:]]
                entities.append((start, end, entity_type))

    return entities