Skip to content
Snippets Groups Projects
Commit 424b9a8d authored by RohitMidha23's avatar RohitMidha23
Browse files

Added function to use DT model for prediction

parent c0040ea1
No related branches found
No related tags found
1 merge request!3Use DT Model for Prediction
......@@ -2,9 +2,10 @@ FROM ubuntu:jammy
RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get -yq install python3
COPY simulator.py /simulator/
COPY simulator_test.py /simulator/
COPY dt_model.joblib /model/
WORKDIR /simulator
RUN ./simulator_test.py
COPY messages.mllp /data/
EXPOSE 8440
EXPOSE 8441
CMD /simulator/simulator.py --messages=/data/messages.mllp
\ No newline at end of file
CMD /simulator/simulator.py --messages=/data/messages.mllp
# MLLP constants
MLLP_START_CHAR = b"\x0b"
MLLP_END_CHAR = b"\x1c\x0d"
# Path to load and store the trained Decision Tree model
DT_MODEL_PATH = "model/dt_model.joblib"
# Map for AKI Label
LABELS_MAP = {"n": 0, "y": 1}
# Reverse labels map for writing the final output
REVERSE_LABELS_MAP = {v: k for k, v in LABELS_MAP.items()}
File added
import socket
from joblib import load
from utils import process_mllp_message, parse_hl7_message, create_acknowledgement
from constants import DT_MODEL_PATH, REVERSE_LABELS_MAP
def start_server(host="0.0.0.0", port=8440):
"""
Starts the TCP server to listen for incoming MLLP messages on the specified port.
"""
# Load the model once for use through out
dt_model = load(DT_MODEL_PATH)
assert dt_model != None, "Model is not loaded properly..."
# Start the server
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
sock.connect((host, port))
print(f"Connected to simulator on {host}:{port}")
......
......@@ -2,7 +2,7 @@ import socket
import hl7
import datetime
from constants import MLLP_START_CHAR, MLLP_END_CHAR
from constants import MLLP_START_CHAR, MLLP_END_CHAR, REVERSE_LABELS_MAP
def process_mllp_message(data):
......@@ -24,7 +24,7 @@ def parse_hl7_message(hl7_data):
return message
def create_acknowledgement(hl7_msg):
def create_acknowledgement():
"""
Creates an HL7 ACK message for the received message.
"""
......@@ -33,3 +33,27 @@ def create_acknowledgement(hl7_msg):
framed_ack = MLLP_START_CHAR + ack_msg.encode() + MLLP_END_CHAR
return framed_ack
def predict_with_dt(dt_model, data):
"""
Following data needs to be passed:
[
"age",
"sex",
"C1",
"RV1",
"RV1_ratio",
"RV2",
"RV2_ratio",
"change_within_48hrs",
"D"
]
Predict with the DT Model on the data.
Returns the predicted labels.
"""
y_pred = dt_model.predict(data)
# Map the predictions to labels
labels = [REVERSE_LABELS_MAP[item] for item in y_pred]
return labels
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment