Skip to content
Snippets Groups Projects
Commit 3fd05f5c authored by Midha, Rohit's avatar Midha, Rohit
Browse files

Merge branch 'setting-up-in-memory-db' into 'master'

Implemented further database functionality

See merge request !4
parents 4dd93cd8 8842d913
No related branches found
No related tags found
1 merge request!4Implemented further database functionality
......@@ -4,6 +4,7 @@ MLLP_END_CHAR = b"\x1c\x0d"
# Path to load and store the trained Decision Tree model
DT_MODEL_PATH = "model/dt_model.joblib"
ON_DISK_DB_PATH = "database.db"
# Map for AKI Label
LABELS_MAP = {"n": 0, "y": 1}
......
No preview for this file type
import sqlite3
from constants import ON_DISK_DB_PATH
import os
class InMemoryDatabase():
def __init__(self):
self.connection = sqlite3.connect(':memory:')
self.initialise_table()
self.initialise_tables()
def initialise_table(self):
def initialise_tables(self):
"""
Initialises the database with the patient features table.
Initialise the database with the patient features table.
"""
query = """
create_patients = """
CREATE TABLE patients (
mrn TEXT PRIMARY KEY,
age INTEGER,
sex TEXT
);
"""
create_test_results = """
CREATE TABLE test_results (
mrn TEXT,
date DATETIME,
result DECIMAL,
PRIMARY KEY (mrn, date),
FOREIGN KEY (mrn) REFERENCES patients (mrn)
);
"""
create_patient_features = """
CREATE TABLE features (
mrn TEXT PRIMARY KEY,
age INTEGER,
sex TEXT,
......@@ -25,13 +43,15 @@ class InMemoryDatabase():
aki TEXT
);
"""
# features table
self.connection.execute(query)
# create the tables
self.connection.execute(create_patients)
self.connection.execute(create_test_results)
self.connection.execute(create_patient_features)
def insert_patient(self, mrn, age, sex, c1, rv1, rv1_r, rv2, rv2_r, change, D, aki=None):
def insert_patient_features(self, mrn, age, sex, c1, rv1, rv1_r, rv2, rv2_r, change, D, aki=None):
"""
Inserts the obtained features into the in-memory database.
Insert the obtained features into the in-memory database.
Args:
- mrn {str}: Medical Record Number of the patient
- age {int}: Age of the patient
......@@ -45,49 +65,178 @@ class InMemoryDatabase():
- D {float}: Difference between current and lowest previous result (48h)
- aki {str}: Whether the patient has been diagnosed with aki ('y'/'n')
"""
query = """
INSERT INTO patients
INSERT INTO features
(mrn, age, sex, C1, RV1, RV1_ratio, RV2, RV2_ratio, has_changed_48h, D, aki)
VALUES
(?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
"""
# execute the query
self.connection.execute(
query,
(mrn, age, sex, c1, rv1, rv1_r, rv2, rv2_r, change, D, aki)
)
self.connection.commit()
def insert_patient(self, mrn, age, sex):
"""
Insert the patient info from PAS into the in-memory database.
Args:
- mrn {str}: Medical Record Number of the patient
- age {int}: Age of the patient
- sex {str}: Sex of the patient ('m'/'f')
"""
query = """
INSERT INTO patients
(mrn, age, sex)
VALUES
(?, ?, ?)
"""
# execute the query
self.connection.execute(
query,
(mrn, age, sex)
)
self.connection.commit()
def get_patient(self, mrn):
def insert_test_result(self, mrn, date, result):
"""
Insert the patient info from PAS into the in-memory database.
Args:
- mrn {str}: Medical Record Number of the patient
- date {datetime}: creatinine result date
- result {float}: creatinine result
"""
query = """
INSERT INTO test_results
(mrn, date, result)
VALUES
(?, ?, ?)
"""
# execute the query
self.connection.execute(
query,
(mrn, date, result)
)
self.connection.commit()
def get_patient_features(self, mrn):
"""
Query the patient data for a given mrn.
Query the features table for a given mrn.
Args:
- mrn {str}: Medical Record Number
"""
cursor = self.connection.cursor()
cursor.execute('SELECT * FROM features WHERE mrn = ?', (mrn,))
return cursor.fetchone()
def get_patient(self, mrn):
"""
Query the patients table for a given mrn.
Args:
- mrn {str}: Medical Record Number
"""
cursor = self.connection.cursor()
cursor.execute('SELECT * FROM patients WHERE mrn = ?', (mrn,))
return cursor.fetchone()
def update_patient(self, mrn, **kwargs):
def get_test_results(self, mrn):
"""
Query the test results table for a given mrn.
Args:
- mrn {str}: Medical Record Number
"""
cursor = self.connection.cursor()
cursor.execute('SELECT * FROM test_results WHERE mrn = ?', (mrn,))
return cursor.fetchall()
def get_patient_history(self, mrn):
"""
Get patient info along with all their test results and their dates.
Args:
- mrn {str}: Medical Record Number
Returns:
- _ {list}: List of records
"""
query = """
SELECT
*
FROM
patients
JOIN
test_results
ON
patients.mrn = test_results.mrn
WHERE patients.mrn = ?
"""
cursor = self.connection.cursor()
cursor.execute(query, (mrn,))
return cursor.fetchall()
def discharge_patient(self, mrn, update_disk_db=True):
"""
Remove the patient record from patients table in-memory and on-disk. Test
results are kept in the test_results table for historic data.
Args:
- mrn {str}: Medical Record Number
"""
# delete from in-memory
self.connection.execute('DELETE FROM patients WHERE mrn = ?', (mrn,))
self.connection.commit()
# delete from on-disk
if update_disk_db:
disk_conn = sqlite3.connect(ON_DISK_DB_PATH)
disk_conn.execute('DELETE FROM patients WHERE mrn = ?', (mrn,))
disk_conn.commit()
disk_conn.close()
def update_patient_features(self, mrn, **kwargs):
"""
Update patient information based on the provided keyword arguments.
Args:
- mrn {str}: Medical Record Number of the patient to update
- **kwargs {dict}: Where key=column, value=new value
"""
# construct the SET part of the SQL query based on the given args
set_clause = ", ".join([f"{key} = ?" for key in kwargs])
query = f"UPDATE patients SET {set_clause} WHERE mrn = ?"
set_clause = ', '.join([f"{key} = ?" for key in kwargs])
query = f'UPDATE features SET {set_clause} WHERE mrn = ?'
# prepare the values for the placeholders in the SQL statement
values = list(kwargs.values()) + [mrn]
# execute the update query
# execute the query
self.connection.execute(query, values)
self.connection.commit()
def persist_db(self):
"""
Persist the in-memory database to disk.
Args:
- disk_db_path {str}: the path to the database
"""
# create an empty db file if it does not exist already
if not os.path.exists(ON_DISK_DB_PATH):
with open(ON_DISK_DB_PATH, 'w'):
pass
# connect to the disk db
disk_connection = sqlite3.connect(ON_DISK_DB_PATH, check_same_thread=False)
# backs up and closes the connection
with disk_connection:
self.connection.backup(disk_connection)
def load_db(self):
"""
Load the on-disk database into the in-memory database.
"""
pass
def close(self):
......
import unittest
from memory_db import InMemoryDatabase
from datetime import datetime
class TestInMemoryDatabase(unittest.TestCase):
def setUp(self):
......@@ -16,27 +17,67 @@ class TestInMemoryDatabase(unittest.TestCase):
self.db.close()
def test_insert_and_get_for_patient(self):
def test_insert_and_get_for_patient_features(self):
actual_record = ('31251122', 42, 'm', 142.22, 127.45, 1.12, 156.89, 0.91, False, 0, None)
# insert
self.db.insert_patient(*actual_record)
self.db.insert_patient_features(*actual_record)
# get
queried_record = self.db.get_patient('31251122')
queried_record = self.db.get_patient_features('31251122')
self.assertEqual(actual_record, queried_record)
def test_insert_and_update_for_patient(self):
def test_insert_and_update_for_patient_features(self):
actual_record = ['31251122', 42, 'm', 142.22, 127.45, 1.12, 156.89, 0.91, False, 0, None]
# insert
self.db.insert_patient(*actual_record)
self.db.insert_patient_features(*actual_record)
# update
self.db.update_patient('31251122', RV1=114.98, RV1_ratio=1.24)
self.db.update_patient_features('31251122', RV1=114.98, RV1_ratio=1.24)
actual_record[4] = 114.98
actual_record[5] = 1.24
# get patient after update
queried_record = self.db.get_patient('31251122')
queried_record = self.db.get_patient_features('31251122')
self.assertEqual(tuple(actual_record), queried_record)
def test_insert_and_get_for_patients(self):
actual_record = ('0012352', 29, 'f')
# insert
self.db.insert_patient(*actual_record)
# get
queried_record = self.db.get_patient('0012352')
self.assertEqual(actual_record, queried_record)
def test_insert_and_get_for_test_results(self):
actual_record = ('0012352', str(datetime.today()), 109.43)
# insert
self.db.insert_test_result(*actual_record)
# get
queried_record = self.db.get_test_results('0012352')[0]
self.assertEqual(actual_record, queried_record)
def test_get_patient_history(self):
patient = ['0012352', 29, 'f']
test_result = ['0012352', str(datetime.today()), 109.43]
# insert
self.db.insert_patient(*patient)
self.db.insert_test_result(*test_result)
# get
queried_record = self.db.get_patient_history(patient[0])[0]
self.assertEqual(tuple(patient + test_result), queried_record)
def test_discharge_patient(self):
patient = ['0012352', 29, 'f']
# insert
self.db.insert_patient(*patient)
# discharge
self.db.discharge_patient(patient[0], update_disk_db=False)
# get patient
queried_record = self.db.get_patient(patient[0])
self.assertIsNone(queried_record)
if __name__ == '__main__':
unittest.main()
\ No newline at end of file
import socket
import hl7
import datetime
import csv
from constants import MLLP_START_CHAR, MLLP_END_CHAR, REVERSE_LABELS_MAP
......@@ -57,3 +57,30 @@ def predict_with_dt(dt_model, data):
labels = [REVERSE_LABELS_MAP[item] for item in y_pred]
return labels
def populate_test_results_table(db, path):
"""
Reads in the training/testing data from a csv file and returns
a list of that data.
Args:
- db {InMemoryDatabase}: the database object
- path {str}: path to the data
"""
with open(path, newline='') as f:
rows = csv.reader(f)
for i, row in enumerate(rows):
# skip header
if i == 0:
continue
# remove empty strings
while row and row[-1] == '':
row.pop()
mrn = row[0]
# for each date, result pair insert into the table
for j in range(1, len(row), 2):
date = row[j]
result = float(row[j+1])
db.insert_test_result(mrn, date, result)
......@@ -3,7 +3,9 @@ from utils import (
process_mllp_message,
parse_hl7_message,
create_acknowledgement,
populate_test_results_table,
)
from memory_db import InMemoryDatabase
import hl7
......@@ -41,6 +43,17 @@ class TestUtilsClient(unittest.TestCase):
self.assertIn(b"ACK", ack_message)
self.assertIn(b"MSA|AA|", ack_message)
def test_populate_test_results_table(self):
db = InMemoryDatabase()
populate_test_results_table(db, 'history.csv')
# expected result
expected_result = ('822825', '2024-01-01 06:12:00', 68.58)
result = db.get_test_results(expected_result[0])[0]
# close the db
db.close()
self.assertEqual(result, expected_result)
if __name__ == "__main__":
unittest.main()
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment