Skip to content
Snippets Groups Projects
Commit 5a436266 authored by RohitMidha23's avatar RohitMidha23
Browse files

updated code to work with docker

parent 3c0394b7
No related branches found
No related tags found
1 merge request!10Integrates all the parts
FROM ubuntu:jammy
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -yq install python3
COPY simulator.py /simulator/
COPY simulator_test.py /simulator/
COPY dt_model.joblib /model/
WORKDIR /simulator
RUN ./simulator_test.py
COPY requirements.txt /app/
RUN pip3 install -r /app/requirements.txt
# copy model
COPY app/dt_model.joblib /app/
# copy scripts
COPY main.py /app/
COPY constants.py /app/
COPY utils.py /app/
COPY memory_db.py /app/
COPY feed_database.py /app/
RUN chmod +x /app/main.py
COPY messages.mllp /data/
EXPOSE 8440
EXPOSE 8441
CMD /simulator/simulator.py --messages=/data/messages.mllp
CMD /app/main.py --mllp=$MLLP_ADDRESS --pager=$PAGER_ADDRESS
This diff is collapsed.
File moved
This diff is collapsed.
......@@ -3,7 +3,7 @@ MLLP_START_CHAR = b"\x0b"
MLLP_END_CHAR = b"\x1c\x0d"
# Path to load and store the trained Decision Tree model
DT_MODEL_PATH = "model/dt_model.joblib"
DT_MODEL_PATH = "app/dt_model.joblib"
ON_DISK_DB_PATH = "database.db"
# Map for AKI Label
......
import socket
import argparse
from joblib import load
from utils import (
process_mllp_message,
parse_hl7_message,
create_acknowledgement,
parse_system_message,
strip_url,
)
from memory_db import InMemoryDatabase
from constants import DT_MODEL_PATH, REVERSE_LABELS_MAP, FEATURES_COLUMNS
from constants import DT_MODEL_PATH, FEATURES_COLUMNS
from utils import (
populate_test_results_table,
D_value_compute,
......@@ -21,29 +23,26 @@ import pandas as pd
import numpy as np
def start_server(host="0.0.0.0", port=8440, pager_port=8441):
def start_server(mllp_address, pager_address, debug=False):
"""
Starts the TCP server to listen for incoming MLLP messages on the specified port.
"""
latencies = [] # to measure latency
outputs = [] # to measure f3 score
if debug:
latencies = [] # to measure latency
outputs = [] # to measure f3 score
count = 0
mllp_host, mllp_port = strip_url(mllp_address)
# Initialise the in-memory database
db = InMemoryDatabase()
# print(db)
db = InMemoryDatabase() # this also loads the previous history
assert db != None, "In-memory Database is not initialised properly..."
# Populate the in-memory database with processed historical data
# populate_test_results_table(db, "history.csv")
# Load the model once for use through out
dt_model = load(DT_MODEL_PATH)
assert dt_model != None, "Model is not loaded properly..."
count = 0
count1 = 0
# Start the server
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.connect((host, port))
print(f"Connected to simulator on {host}:{port}")
sock.connect((mllp_host, int(mllp_port)))
print(f"Connected to simulator on {mllp_address}")
while True:
data = sock.recv(1024)
......@@ -54,9 +53,7 @@ def start_server(host="0.0.0.0", port=8440, pager_port=8441):
hl7_data = process_mllp_message(data)
if hl7_data:
message = parse_hl7_message(hl7_data)
# print("Parsed HL7 Message:")
# print(message)
# print(type(message))
category, mrn, data = parse_system_message(
message
) # category is type of system message and data consists of age sex if PAS admit or date of blood test and creatanine result
......@@ -69,9 +66,8 @@ def start_server(host="0.0.0.0", port=8440, pager_port=8441):
start_time = datetime.now()
patient_history = db.get_patient_history(str(mrn))
if len(patient_history) != 0:
print(f"patient {mrn} has history... ")
print("calculating features...")
count = count + 1
if debug:
count = count + 1
latest_creatine_result = data[1]
latest_creatine_date = data[0]
D = D_value_compute(
......@@ -95,18 +91,17 @@ def start_server(host="0.0.0.0", port=8440, pager_port=8441):
True,
D,
]
print("features crafted...")
input = pd.DataFrame([features], columns=FEATURES_COLUMNS)
aki = predict_with_dt(dt_model, input)
if aki[0] == "y":
outputs.append((mrn, latest_creatine_date))
print("Calling pager for mrn:", mrn)
send_pager_request(mrn)
if debug:
outputs.append((mrn, latest_creatine_date))
send_pager_request(mrn, pager_address)
end_time = datetime.now()
db.insert_test_result(mrn, data[0], data[1])
latency = end_time - start_time
latencies.append(latency)
print(latency)
if debug:
latency = end_time - start_time
latencies.append(latency)
# Create and send ACK message
ack_message = create_acknowledgement()
......@@ -114,33 +109,52 @@ def start_server(host="0.0.0.0", port=8440, pager_port=8441):
else:
print("No valid MLLP message received.")
# print("No data", count)
print("Patients with Historical Data", count)
if debug:
print("Patients with Historical Data", count)
# Calculate latency metrics
print(latencies)
mean_latency = np.mean(latencies)
median_latency = np.median(latencies)
min_latency = np.min(latencies)
max_latency = np.max(latencies)
percentile_99 = np.percentile(latencies, 99)
# Calculate latency metrics
mean_latency = np.mean(latencies)
median_latency = np.median(latencies)
min_latency = np.min(latencies)
max_latency = np.max(latencies)
percentile_99 = np.percentile(latencies, 99)
metrics = {
"Mean": mean_latency,
"Median": median_latency,
"Minimum": min_latency,
"Maximum": max_latency,
"99% Efficiency": percentile_99,
}
print(metrics)
metrics = {
"Mean": mean_latency,
"Median": median_latency,
"Minimum": min_latency,
"Maximum": max_latency,
"99% Efficiency": percentile_99,
}
print(metrics)
df = pd.DataFrame(outputs, columns=["mrn", "date"])
df["date"] = pd.to_datetime(df["date"]).dt.strftime("%Y-%m-%d %H:%M:%S")
df.to_csv("aki_predicted.csv", index=False)
df = pd.DataFrame(outputs, columns=["mrn", "date"])
df["date"] = pd.to_datetime(df["date"]).dt.strftime("%Y-%m-%d %H:%M:%S")
df.to_csv("aki_predicted.csv", index=False)
def main():
start_server()
parser = argparse.ArgumentParser()
parser.add_argument(
"--mllp",
default="0.0.0.0:8440",
type=str,
help="Port on which to get HL7 messages via MLLP",
)
parser.add_argument(
"--pager",
default="0.0.0.0:8441",
type=str,
help="Post on which to send pager requests via HTTP",
)
parser.add_argument(
"--debug",
default=False,
type=bool,
help="Whether to calculate F3 and Latency Score",
)
flags = parser.parse_args()
start_server(flags.mllp, flags.pager, flags.debug)
if __name__ == "__main__":
......
This diff is collapsed.
%% Cell type:code id: tags:
``` python
from memory_db import InMemoryDatabase
from utils import populate_tables
import csv
import constants
import datetime
```
%% Cell type:code id: tags:
``` python
db = InMemoryDatabase()
```
%% Cell type:code id: tags:
``` python
cursor = db.connection.cursor()
query = 'SELECT * FROM test_results;'
cursor.execute(query)
cursor.fetchall()
```
%% Output
[]
%% Cell type:code id: tags:
``` python
conn.close()
```
%% Cell type:code id: tags:
``` python
cursor = db.connection.cursor()
query = 'SELECT * FROM test_results;'
cursor.execute(query)
cursor.fetchall()
```
%% Output
[]
%% Cell type:code id: tags:
``` python
db.insert_patient('822825', 29, 'f')
db.insert_patient('16318', 42, 'm')
db.insert_patient('440673', 67, 'f')
```
%% Cell type:code id: tags:
``` python
db.get_patient_history('822825')
```
%% Output
[('822825', 29, 'f', '2024-01-01 06:12:00', 68.58),
('822825', 29, 'f', '2024-01-09 10:48:00', 70.58),
('822825', 29, 'f', '2024-01-09 14:20:00', 64.15),
('822825', 29, 'f', '2024-01-10 17:29:00', 48.39),
('822825', 29, 'f', '2024-01-17 06:27:00', 58.01),
('822825', 29, 'f', '2024-01-23 17:55:00', 85.93)]
%% Cell type:code id: tags:
``` python
```
%% Cell type:code id: tags:
``` python
db.get_patient_history('16318')
```
%% Output
[('16318', 42, 'm', '16318', '2024-01-01 09:47:00', 64.44),
('16318', 42, 'm', '16318', '2024-01-04 16:35:00', 1070423.72)]
%% Cell type:code id: tags:
``` python
db.get_patient_history('440673')
```
%% Output
[('440673', 67, 'f', '440673', '2024-01-01 10:47:00', 69.18),
('440673', 67, 'f', '440673', '2024-01-01 12:43:00', 83.32),
('440673', 67, 'f', '440673', '2024-01-05 13:04:00', 84.87),
('440673', 67, 'f', '440673', '2024-01-05 15:47:00', 76.07)]
%% Cell type:code id: tags:
``` python
db.discharge_patient('16318')
```
%% Cell type:code id: tags:
``` python
# db.persist_db()
```
%% Cell type:code id: tags:
``` python
db.close()
```
%% Cell type:code id: tags:
``` python
import pandas as pd
```
%% Cell type:code id: tags:
``` python
# Load the datasets
aki_df = pd.read_csv("aki.csv")
aki_pred_df = pd.read_csv("aki_predicted.csv")
# Assuming 'mrn' and 'date' are the column names in both CSV files
# Convert 'date' to datetime for accurate comparison
aki_df["date"] = pd.to_datetime(aki_df["date"])
aki_pred_df["date"] = pd.to_datetime(aki_pred_df["date"])
```
%% Cell type:code id: tags:
``` python
# Merge the datasets to find true positives
tp_df = pd.merge(aki_df, aki_pred_df, how="inner", on=["mrn", "date"])
```
%% Cell type:code id: tags:
``` python
# Count true positives, false positives, and false negatives
tp = len(tp_df)
fp = len(aki_pred_df) - tp
fn = len(aki_df) - tp
```
%% Cell type:code id: tags:
``` python
# Calculate precision, recall, and F3 score
precision = tp / (tp + fp) if tp + fp else 0
recall = tp / (tp + fn) if tp + fn else 0
beta_squared = 3**2
f3_score = (
(1 + beta_squared) * (precision * recall) / ((beta_squared * precision) + recall)
if (precision + recall)
else 0
)
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F3 Score: {f3_score}")
```
%% Output
Precision: 0.9819587628865979
Recall: 1.0
F3 Score: 0.9981660990306522
%% Cell type:code id: tags:
``` python
from urllib.parse import urlparse
def strip_url(url):
"""
Strips the URL and returns the host and port alone.
"""
url = url.split("://")[-1]
# Split the URL by "/" to separate the host and potentially the port
parts = url.split("/")
# Get the host part
host = parts[0].strip()
# Check if the port is specified
if ":" in host:
host, port = host.split(":")
port = int(port)
return host, port
```
%% Cell type:code id: tags:
``` python
strip_url("0.0.0.0:8440")
```
%% Output
('0.0.0.0', 8440)
%% Cell type:code id: tags:
``` python
```
......
import socket
import hl7
import pandas as pd
import numpy as np
import datetime
from sklearn.preprocessing import LabelEncoder
import joblib
import csv
from statistics import median
from constants import MLLP_START_CHAR, MLLP_END_CHAR, REVERSE_LABELS_MAP
import requests
from datetime import timedelta
def process_mllp_message(data):
......@@ -178,20 +175,20 @@ def D_value_compute(creat_latest_result, d1, lis):
:param row: The row of data from the dataframe.
:return: The computed D value.
"""
d1 = datetime.datetime.strptime(d1, '%Y%m%d%H%M%S')
d1 = datetime.datetime.strptime(d1, "%Y%m%d%H%M%S")
if type(lis[-1][3]) != int:
d2 = datetime.datetime.strptime(lis[-1][3], '%Y-%m-%d %H:%M:%S')
d2 = datetime.datetime.strptime(lis[-1][3], "%Y-%m-%d %H:%M:%S")
else:
d2 = datetime.datetime.strptime(str(lis[-1][3]), '%Y%m%d%H%M%S')
#Calculating the date within 48 hours
past_two_days = d1 - datetime.timedelta(days = 2)
d2 = datetime.datetime.strptime(str(lis[-1][3]), "%Y%m%d%H%M%S")
# Calculating the date within 48 hours
past_two_days = d1 - datetime.timedelta(days=2)
prev_lis_values = []
for i in range(len(lis)):
if type(lis[i][3]) != int:
d_ = datetime.datetime.strptime(lis[i][3], '%Y-%m-%d %H:%M:%S')
d_ = datetime.datetime.strptime(lis[i][3], "%Y-%m-%d %H:%M:%S")
else:
d_ = datetime.datetime.strptime(str(lis[i][3]), '%Y%m%d%H%M%S')
if d_<= past_two_days:
d_ = datetime.datetime.strptime(str(lis[i][3]), "%Y%m%d%H%M%S")
if d_ <= past_two_days:
prev_lis_values.append(lis[i][4])
if len(prev_lis_values) > 0:
# Finding the minimum value in the last two days
......@@ -212,15 +209,15 @@ def RV_compute(creat_latest_result, d1, lis):
:param row: The row of data from the dataframe.
:return: The computed RV value.
"""
#Calculating the difference of days between the two latest tests
d1 = datetime.datetime.strptime(d1, '%Y%m%d%H%M%S')
# Calculating the difference of days between the two latest tests
d1 = datetime.datetime.strptime(d1, "%Y%m%d%H%M%S")
if type(lis[-1][3]) != int:
d2 = datetime.datetime.strptime(lis[-1][3], '%Y-%m-%d %H:%M:%S')
d2 = datetime.datetime.strptime(lis[-1][3], "%Y-%m-%d %H:%M:%S")
else:
d2 = datetime.datetime.strptime(str(lis[-1][3]), '%Y%m%d%H%M%S')
diff = abs(((d2-d1).seconds)/86400 + (d2-d1).days)
#If difference in less than 7 days then use the minimum to compute the ratio
if diff<=7:
d2 = datetime.datetime.strptime(str(lis[-1][3]), "%Y%m%d%H%M%S")
diff = abs(((d2 - d1).seconds) / 86400 + (d2 - d1).days)
# If difference in less than 7 days then use the minimum to compute the ratio
if diff <= 7:
C1 = float(creat_latest_result)
minimum = float(min([float(lis[i][4]) for i in range(len(lis))]))
assert C1 / minimum is not None, "The RV value is None"
......@@ -235,8 +232,8 @@ def RV_compute(creat_latest_result, d1, lis):
elif diff <= 365:
C1 = float(creat_latest_result)
median_ = float(median([float(lis[i][4]) for i in range(len(lis))]))
assert C1/median_ is not None, "The RV value is None"
return C1, 0, 0, median_, C1/median_ #C1, RV1, RV1_ratio, RV2, RV2_ratio
assert C1 / median_ is not None, "The RV value is None"
return C1, 0, 0, median_, C1 / median_ # C1, RV1, RV1_ratio, RV2, RV2_ratio
else:
return 0
......@@ -248,31 +245,33 @@ def label_encode(sex):
:param column: The list of features to be encoded.
:return: List of encoded features.
"""
if sex == 'M' or sex == 'm':
if sex == "M" or sex == "m":
return 0
elif sex == 'F' or sex == 'f':
elif sex == "F" or sex == "f":
return 1
import requests
def send_pager_request(mrn):
def send_pager_request(mrn, pager_address):
# Define the URL for the pager request.
url = "http://0.0.0.0:8441/page"
pager_host, pager_port = strip_url(pager_address)
url = f"http://{pager_host}:{pager_port}/page"
headers = {"Content-Type": "text/plain"}
# Convert the MRN to a string and encode it to bytes, as the body of the POST request.
data = str(mrn).encode('utf-8')
data = str(mrn).encode("utf-8")
# Send the POST request with the MRN as the body.
response = requests.post(url, data=data)
response = requests.post(url, data=data, headers=headers)
# Check the response status code and print appropriate message.
if response.status_code == 200:
print(f"Request successful, server responded: {response.text}")
else:
print(f"Request failed, status code: {response.status_code}, message: {response.text}")
print(
f"Request failed, status code: {response.status_code}, message: {response.text}"
)
# Example usage
#send_pager_request(12345)
def load_model(file_path):
"""
......@@ -293,19 +292,20 @@ def load_model(file_path):
return None
def alert_response_team(host, port, mrn):
def strip_url(url):
"""
Sends a page to the pager server with the given MRN.
Strips the URL and returns the host and port alone.
"""
url = f"http://{host}:{port}/page"
headers = {"Content-Type": "text/plain"}
try:
response = requests.post(url, data=str(mrn), headers=headers)
if response.status_code == 200:
print(f"Successfully paged for MRN: {mrn}")
else:
print(
f"Failed to page for MRN: {mrn}. Status code: {response.status_code}, Response: {response.text}"
)
except requests.RequestException as e:
print(f"Request failed: {e}")
url = url.split("://")[-1]
# Split the URL by "/" to separate the host and potentially the port
parts = url.split("/")
# Get the host part
host = parts[0].strip()
# Check if the port is specified
if ":" in host:
host, port = host.split(":")
port = int(port)
return host, port
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment