Ajay Karthick Senthil Kumar committed
Commit 5bd622e · 1 Parent(s): eb57aa2
update
Browse files
- README.md +5 -5
- __pycache__/config.cpython-39.pyc +0 -0
- __pycache__/metrics.cpython-39.pyc +0 -0
- __pycache__/utils.cpython-39.pyc +0 -0
- app.py +175 -0
- config.py +135 -0
- data/tokenizer.pickle +3 -0
- metrics.py +20 -0
- model/model_1.h5 +3 -0
- requirements.txt +6 -0
- utils.py +123 -0
README.md
CHANGED
@@ -1,10 +1,10 @@
 ---
 title: NER Medical Text
-emoji:
-colorFrom:
-colorTo:
-sdk:
-sdk_version:
+emoji: 🐢
+colorFrom: red
+colorTo: indigo
+sdk: gradio
+sdk_version: 3.24.1
 app_file: app.py
 pinned: false
 ---
__pycache__/config.cpython-39.pyc
ADDED
Binary file (2.68 kB)

__pycache__/metrics.cpython-39.pyc
ADDED
Binary file (964 Bytes)

__pycache__/utils.cpython-39.pyc
ADDED
Binary file (3.67 kB)
app.py
ADDED
@@ -0,0 +1,175 @@
import streamlit as st
import tensorflow as tf
import os

# Import your utility functions
from utils import (
    predict_multi_line_text,
    tokenizer,
)

from config import index_to_label, acronyms_to_entities, MAX_LENGTH
from metrics import precision, recall, f1_score

# Register the custom metric functions
tf.keras.utils.get_custom_objects()[precision.__name__] = precision
tf.keras.utils.get_custom_objects()[recall.__name__] = recall
tf.keras.utils.get_custom_objects()[f1_score.__name__] = f1_score

# Load your trained model
model_dir = './model'  # Adjust the path as needed
model_1 = tf.keras.models.load_model(os.path.join(model_dir, 'model_1.h5'))

# Define label colors for different entity types suitable for dark background
LABEL_COLORS = {
    'Activity': '#FF7F50',               # Coral
    'Administration': '#6495ED',         # Cornflower Blue
    'Age': '#FFB6C1',                    # Light Pink
    'Area': '#7FFF00',                   # Chartreuse
    'Biological_attribute': '#FFD700',   # Gold
    'Biological_structure': '#00FA9A',   # Medium Spring Green
    'Clinical_event': '#BA55D3',         # Medium Orchid
    'Color': '#00CED1',                  # Dark Turquoise
    'Coreference': '#FFA07A',            # Light Salmon
    'Date': '#ADFF2F',                   # Green Yellow
    'Detailed_description': '#DA70D6',   # Orchid
    'Diagnostic_procedure': '#87CEFA',   # Light Sky Blue
    'Disease_disorder': '#FF4500',       # Orange Red
    'Distance': '#32CD32',               # Lime Green
    'Dosage': '#8A2BE2',                 # Blue Violet
    'Duration': '#F08080',               # Light Coral
    'Family_history': '#20B2AA',         # Light Sea Green
    'Frequency': '#FF6347',              # Tomato
    'Height': '#4682B4',                 # Steel Blue
    'History': '#EE82EE',                # Violet
    'Lab_value': '#FFDAB9',              # Peach Puff
    'Mass': '#7B68EE',                   # Medium Slate Blue
    'Medication': '#00FF7F',             # Spring Green
    'Nonbiological_location': '#FF69B4', # Hot Pink
    'Occupation': '#BDB76B',             # Dark Khaki
    'Other_entity': '#D3D3D3',           # Light Grey
    'Other_event': '#FF1493',            # Deep Pink
    'Outcome': '#00BFFF',                # Deep Sky Blue
    'Personal_background': '#00FFFF',    # Aqua
    'Qualitative_concept': '#FFA500',    # Orange
    'Quantitative_concept': '#FFA500',   # Orange (same as above)
    'Severity': '#1E90FF',               # Dodger Blue
    'Sex': '#FF00FF',                    # Magenta
    'Shape': '#40E0D0',                  # Turquoise
    'Sign_symptom': '#FFFF00',           # Yellow
    'Subject': '#F0E68C',                # Khaki
    'Texture': '#98FB98',                # Pale Green
    'Therapeutic_procedure': '#8B008B',  # Dark Magenta
    'Time': '#DC143C',                   # Crimson
    'Volume': '#5F9EA0',                 # Cadet Blue
    'Weight': '#FA8072',                 # Salmon
}

# Define the prediction function
def predict_ner(text):
    try:
        # Predict entities
        entities = predict_multi_line_text(
            text,
            model_1,
            index_to_label,
            acronyms_to_entities,
            MAX_LENGTH
        )

        # Sort entities by their start position
        entities = sorted(entities, key=lambda x: x[0])

        # Build HTML string with highlighted entities
        html_output = ""
        last_idx = 0

        for start, end, label in entities:
            # Append text before the entity
            if last_idx < start:
                html_output += text[last_idx:start]

            # Get the color for the label, default to light grey if not specified
            color = LABEL_COLORS.get(label, '#D3D3D3')  # Light grey

            # Wrap the entity with a span tag including style
            entity_text = text[start:end]
            # Include the label next to the entity
            html_output += f'''<span style="background-color: {color}; font-weight: bold; padding: 2px; border-radius: 4px; margin: 1px;">{entity_text} <span style="font-size: smaller; font-weight: normal;">[{label}]</span></span>'''

            last_idx = end

        # Append any remaining text
        if last_idx < len(text):
            html_output += text[last_idx:]

        return html_output

    except Exception as e:
        return f"<p style='color:red;'>Error: {str(e)}</p>"

# Set up the Streamlit app with dark theme
st.set_page_config(page_title="Medical NER", page_icon="🩺", layout="wide")

# Apply custom CSS for dark background and text colors
st.markdown(
    """
    <style>
    /* Main app background */
    .stApp {
        background-color: #2E2E2E;
        color: #FFFFFF;
    }
    /* Text input area */
    .stTextArea textarea {
        background-color: #1E1E1E;
        color: #FFFFFF;
    }
    /* Adjust the Analyze button */
    div.stButton > button:first-child {
        background-color: #1E90FF;
        color: #FFFFFF;
    }
    /* Scrollbar styling */
    ::-webkit-scrollbar {
        width: 10px;
    }
    ::-webkit-scrollbar-track {
        background: #1E1E1E;
    }
    ::-webkit-scrollbar-thumb {
        background: #888;
    }
    ::-webkit-scrollbar-thumb:hover {
        background: #555;
    }
    /* Style for the highlighted entities */
    .highlighted-entity {
        padding: 2px;
        border-radius: 4px;
        margin: 1px;
        font-weight: bold;
        display: inline-block;
    }
    </style>
    """,
    unsafe_allow_html=True
)

st.title("🩺 Medical Named Entity Recognition")
st.markdown("""
Enter medical text below to identify and highlight entities such as diseases, medications, and anatomical terms.
""")

# Input text area
text_input = st.text_area("Enter medical text here:", height=200)

# Analyze button
if st.button("Analyze"):
    if text_input.strip():
        with st.spinner("Analyzing..."):
            result = predict_ner(text_input)
            # Display the result with HTML rendering
            st.markdown(f"<div style='font-size: 18px;'>{result}</div>", unsafe_allow_html=True)
    else:
        st.warning("Please enter some text to analyze.")
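The highlighting loop in predict_ner can be sanity-checked without the model or tokenizer artifacts. Below is a minimal, self-contained sketch of the same span-merging logic, with a hard-coded entity list standing in for predict_multi_line_text output; the sample text and offsets are illustrative only:

# Minimal sketch of the highlighting loop above, with hard-coded
# (start, end, label) entities standing in for model predictions.
text = "Patient reports severe headache since Monday."
entities = [(16, 22, 'Severity'), (23, 31, 'Sign_symptom'), (38, 44, 'Date')]

html_output, last_idx = "", 0
for start, end, label in sorted(entities, key=lambda x: x[0]):
    html_output += text[last_idx:start]   # plain text before the entity
    html_output += f'<mark title="{label}">{text[start:end]}</mark>'
    last_idx = end
html_output += text[last_idx:]            # trailing plain text
print(html_output)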
config.py
ADDED
@@ -0,0 +1,135 @@
entity_to_acronyms = {
    'Activity': 'ACT',
    'Administration': 'ADM',
    'Age': 'AGE',
    'Area': 'ARA',
    'Biological_attribute': 'BAT',
    'Biological_structure': 'BST',
    'Clinical_event': 'CLE',
    'Color': 'COL',
    'Coreference': 'COR',
    'Date': 'DAT',
    'Detailed_description': 'DET',
    'Diagnostic_procedure': 'DIA',
    'Disease_disorder': 'DIS',
    'Distance': 'DIS',
    'Dosage': 'DOS',
    'Duration': 'DUR',
    'Family_history': 'FAM',
    'Frequency': 'FRE',
    'Height': 'HEI',
    'History': 'HIS',
    'Lab_value': 'LAB',
    'Mass': 'MAS',
    'Medication': 'MED',
    'Nonbiological_location': 'NBL',
    'Occupation': 'OCC',
    'Other_entity': 'OTH',
    'Other_event': 'OTE',
    'Outcome': 'OUT',
    'Personal_background': 'PER',
    'Qualitative_concept': 'QUC',
    'Quantitative_concept': 'QUC',
    'Severity': 'SEV',
    'Sex': 'SEX',
    'Shape': 'SHA',
    'Sign_symptom': 'SIG',
    'Subject': 'SUB',
    'Texture': 'TEX',
    'Therapeutic_procedure': 'THP',
    'Time': 'TIM',
    'Volume': 'VOL',
    'Weight': 'WEI'
}

index_to_label = {1: 'B-ACT',
                  2: 'B-ADM',
                  3: 'B-AGE',
                  4: 'B-ARA',
                  5: 'B-BAT',
                  6: 'B-BST',
                  7: 'B-CLE',
                  8: 'B-COL',
                  9: 'B-COR',
                  10: 'B-DAT',
                  11: 'B-DET',
                  12: 'B-DIA',
                  13: 'B-DIS',
                  14: 'B-DOS',
                  15: 'B-DUR',
                  16: 'B-FAM',
                  17: 'B-FRE',
                  18: 'B-HEI',
                  19: 'B-HIS',
                  20: 'B-LAB',
                  21: 'B-MAS',
                  22: 'B-MED',
                  23: 'B-NBL',
                  24: 'B-OCC',
                  25: 'B-OTE',
                  26: 'B-OTH',
                  27: 'B-OUT',
                  28: 'B-PER',
                  29: 'B-QUC',
                  30: 'B-SEV',
                  31: 'B-SEX',
                  32: 'B-SHA',
                  33: 'B-SIG',
                  34: 'B-SUB',
                  35: 'B-TEX',
                  36: 'B-THP',
                  37: 'B-TIM',
                  38: 'B-VOL',
                  39: 'B-WEI',
                  40: 'I-ACT',
                  41: 'I-ADM',
                  42: 'I-AGE',
                  43: 'I-ARA',
                  44: 'I-BAT',
                  45: 'I-BST',
                  46: 'I-CLE',
                  47: 'I-COL',
                  48: 'I-COR',
                  49: 'I-DAT',
                  50: 'I-DET',
                  51: 'I-DIA',
                  52: 'I-DIS',
                  53: 'I-DOS',
                  54: 'I-DUR',
                  55: 'I-FAM',
                  56: 'I-FRE',
                  57: 'I-HEI',
                  58: 'I-HIS',
                  59: 'I-LAB',
                  60: 'I-MAS',
                  61: 'I-MED',
                  62: 'I-NBL',
                  63: 'I-OCC',
                  64: 'I-OTE',
                  65: 'I-OTH',
                  66: 'I-OUT',
                  67: 'I-PER',
                  68: 'I-QUC',
                  69: 'I-SEV',
                  70: 'I-SHA',
                  71: 'I-SIG',
                  72: 'I-SUB',
                  73: 'I-TEX',
                  74: 'I-THP',
                  75: 'I-TIM',
                  76: 'I-VOL',
                  77: 'I-WEI',
                  78: 'O',
                  0: '<PAD>'}

MAX_LENGTH = 100

acronyms_to_entities = {v: k for k, v in entity_to_acronyms.items()}


models = {
    "model_1": {
        "path": "model/model_1.h5",
        "title": "Bidirectional LSTM Model with single LSTM layer"
    },
}
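Note that 'DIS' and 'QUC' are each assigned to two entity types in entity_to_acronyms, so the inverted acronyms_to_entities keeps only whichever entry comes later ('Distance' and 'Quantitative_concept'). A quick sanity check of the mappings, assuming config.py is importable from the working directory:

# Round-trip a predicted class index to its entity name
# (assumes config.py is on the import path).
from config import index_to_label, acronyms_to_entities

label = index_to_label[33]                 # 'B-SIG'
entity = acronyms_to_entities[label[2:]]   # strip the 'B-'/'I-' prefix
print(label, '->', entity)                 # B-SIG -> Sign_symptom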
data/tokenizer.pickle
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4f6c875d9180a973c5f297a0e332404c3b79816ace97a91ae39885d20440258a
size 277589
metrics.py
ADDED
@@ -0,0 +1,20 @@
from keras import backend as K

def precision(y_true, y_pred):
    """Compute precision metric"""
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
    _precision = true_positives / (predicted_positives + K.epsilon())
    return _precision

def recall(y_true, y_pred):
    """Compute recall metric"""
    true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
    possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
    return true_positives / (possible_positives + K.epsilon())

def f1_score(y_true, y_pred):
    """Compute f1-score metric"""
    _precision = precision(y_true, y_pred)
    _recall = recall(y_true, y_pred)
    return 2 * ((_precision * _recall) / (_precision + _recall + K.epsilon()))
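These metrics operate elementwise on the full label tensors, so they can be exercised eagerly on dummy one-hot data. A small check, assuming TensorFlow 2.x is installed:

# Exercise the custom metrics on dummy one-hot tensors.
import tensorflow as tf
from metrics import precision, recall, f1_score

y_true = tf.constant([[0., 1.], [1., 0.], [0., 1.]])
y_pred = tf.constant([[0., 1.], [0., 1.], [0., 1.]])

print(float(precision(y_true, y_pred)))  # 2 of 3 predicted positives correct -> ~0.667
print(float(recall(y_true, y_pred)))     # 2 of 3 actual positives recovered  -> ~0.667
print(float(f1_score(y_true, y_pred)))   # harmonic mean of the two           -> ~0.667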
model/model_1.h5
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f98117d817b0603a5d8daaaeaca952d6fd6e427bb45167b3602bde1ccbf6823
size 19859592
requirements.txt
ADDED
@@ -0,0 +1,6 @@
tensorflow==2.13.0
numpy>=1.21.0,<1.24.0
pandas>=1.3.0,<1.5.0
streamlit
nltk>=3.6.0
pickle-mixin
utils.py
ADDED
@@ -0,0 +1,123 @@
import re
import os
import pickle
import numpy as np
from tensorflow.keras.preprocessing.sequence import pad_sequences
import nltk
nltk.download('punkt')
nltk.download('stopwords')
from nltk.corpus import stopwords

STOP_WORDS = stopwords.words('english')

# Load the tokenizer from file
with open('./data/tokenizer.pickle', 'rb') as handle:
    tokenizer = pickle.load(handle)

def clean_word(word):
    """
    Cleans a word by removing non-alphanumeric characters and extra whitespaces,
    converting it to lowercase, and checking if it is a stopword.

    Args:
    - word (str): the word to clean

    Returns:
    - str: the cleaned word, or an empty string if it is a stopword
    """
    # remove non-alphanumeric characters and extra whitespaces
    word = re.sub(r'[^\w\s]', '', word)
    word = re.sub(r'\s+', ' ', word)

    # convert to lowercase
    word = word.lower()

    if word not in STOP_WORDS:
        return word

    return ''

def tokenize_text(text):
    """
    Tokenizes a text into a list of cleaned words.

    Args:
    - text (str): the text to tokenize

    Returns:
    - tokens (list of str): the list of cleaned words
    - start_end_ranges (list of tuples): the start and end character positions for each token
    """
    regex_match = r'[^\s\u200a\-\u2010-\u2015\u2212\uff0d]+'  # Regex to match words
    tokens = []
    start_end_ranges = []
    # Tokenize the sentences in the text
    sentences = nltk.sent_tokenize(text)

    start = 0
    for sentence in sentences:

        sentence_tokens = re.findall(regex_match, sentence)
        curr_sent_tokens = []
        curr_sent_ranges = []

        for word in sentence_tokens:
            word = clean_word(word)
            if word.strip():
                start = text.lower().find(word, start)
                end = start + len(word)
                curr_sent_ranges.append((start, end))
                curr_sent_tokens.append(word)
                start = end
        if len(curr_sent_tokens) > 0:
            tokens.append(curr_sent_tokens)
            start_end_ranges.append(curr_sent_ranges)

    return tokens, start_end_ranges

def predict_multi_line_text(text, model, index_to_label, acronyms_to_entities, MAX_LENGTH):
    """
    Predicts named entities for multi-line input text.

    Args:
    - text (str): The input text
    - model: The trained NER model
    - index_to_label: Dictionary mapping index to label
    - acronyms_to_entities: Dictionary mapping acronyms to entity names
    - MAX_LENGTH: Maximum input length for the model

    Returns:
    - entities: A list of named entities in the format (start, end, label)
    """

    sequences = []
    sent_tokens, sent_start_end = tokenize_text(text)

    for i in range(len(sent_tokens)):
        sequence = tokenizer.texts_to_sequences([' '.join(token for token in sent_tokens[i])])
        sequences.extend(sequence)

    padded_sequence = pad_sequences(sequences, maxlen=MAX_LENGTH, padding='post')

    # Make the prediction
    prediction = model.predict(np.array(padded_sequence))

    # Decode the prediction
    predicted_labels = np.argmax(prediction, axis=-1)

    predicted_labels = [
        [index_to_label[i] for i in sent_predicted_labels]
        for sent_predicted_labels in predicted_labels
    ]

    entities = []
    for tokens, sent_pred_labels, start_end_ranges in zip(sent_tokens, predicted_labels, sent_start_end):
        for i, (token, label, start_end_range) in enumerate(zip(tokens, sent_pred_labels, start_end_ranges)):
            start = start_end_range[0]
            end = start_end_range[1]
            if label not in ['O', '<PAD>']:
                entity_type = acronyms_to_entities[label[2:]]
                entity = (start, end, entity_type)
                entities.append(entity)

    return entities
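For completeness, a sketch of how app.py wires these pieces together outside Streamlit. It mirrors the calls app.py makes and assumes the LFS artifacts data/tokenizer.pickle and model/model_1.h5 have been pulled; the sample sentence is illustrative only:

# End-to-end usage sketch mirroring app.py (requires the LFS
# artifacts data/tokenizer.pickle and model/model_1.h5).
import tensorflow as tf
from config import index_to_label, acronyms_to_entities, MAX_LENGTH
from metrics import precision, recall, f1_score
from utils import predict_multi_line_text

# Register the custom metrics so load_model can deserialize them
for fn in (precision, recall, f1_score):
    tf.keras.utils.get_custom_objects()[fn.__name__] = fn
model = tf.keras.models.load_model('model/model_1.h5')

text = "The patient was given 50 mg of atenolol daily."
for start, end, entity_type in predict_multi_line_text(
        text, model, index_to_label, acronyms_to_entities, MAX_LENGTH):
    print(f"{text[start:end]!r} -> {entity_type}")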