diff --git a/constants.py b/constants.py index 2bfb8fb4e8e39b27bd5c39707b3566e22225601f..2716befa577d6c61e030cdc2cb247c9474462e73 100644 --- a/constants.py +++ b/constants.py @@ -4,6 +4,7 @@ MLLP_END_CHAR = b"\x1c\x0d" # Path to load and store the trained Decision Tree model DT_MODEL_PATH = "model/dt_model.joblib" +ON_DISK_DB_PATH = "database.db" # Map for AKI Label LABELS_MAP = {"n": 0, "y": 1} diff --git a/database.db b/database.db index d082c74aaf44e4cd2bebdf8ad2960b3ae32dbe6d..b180a82c2c64a41746c80c1f447b50794fe23084 100644 Binary files a/database.db and b/database.db differ diff --git a/memory_db.py b/memory_db.py index 0098ccaa56f2468133759952fcdbab8438fd62e0..575c77d99841878a00d18caf5e7e7c747157085e 100644 --- a/memory_db.py +++ b/memory_db.py @@ -1,17 +1,35 @@ import sqlite3 +from constants import ON_DISK_DB_PATH +import os class InMemoryDatabase(): def __init__(self): self.connection = sqlite3.connect(':memory:') - self.initialise_table() + self.initialise_tables() - def initialise_table(self): + def initialise_tables(self): """ - Initialises the database with the patient features table. + Initialise the database with the patient features table. """ - query = """ + create_patients = """ CREATE TABLE patients ( + mrn TEXT PRIMARY KEY, + age INTEGER, + sex TEXT + ); + """ + create_test_results = """ + CREATE TABLE test_results ( + mrn TEXT, + date DATETIME, + result DECIMAL, + PRIMARY KEY (mrn, date), + FOREIGN KEY (mrn) REFERENCES patients (mrn) + ); + """ + create_patient_features = """ + CREATE TABLE features ( mrn TEXT PRIMARY KEY, age INTEGER, sex TEXT, @@ -25,13 +43,15 @@ class InMemoryDatabase(): aki TEXT ); """ - # features table - self.connection.execute(query) + # create the tables + self.connection.execute(create_patients) + self.connection.execute(create_test_results) + self.connection.execute(create_patient_features) - def insert_patient(self, mrn, age, sex, c1, rv1, rv1_r, rv2, rv2_r, change, D, aki=None): + def insert_patient_features(self, mrn, age, sex, c1, rv1, rv1_r, rv2, rv2_r, change, D, aki=None): """ - Inserts the obtained features into the in-memory database. + Insert the obtained features into the in-memory database. Args: - mrn {str}: Medical Record Number of the patient - age {int}: Age of the patient @@ -45,49 +65,178 @@ class InMemoryDatabase(): - D {float}: Difference between current and lowest previous result (48h) - aki {str}: Whether the patient has been diagnosed with aki ('y'/'n') """ - query = """ - INSERT INTO patients + INSERT INTO features (mrn, age, sex, C1, RV1, RV1_ratio, RV2, RV2_ratio, has_changed_48h, D, aki) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?) """ - + # execute the query self.connection.execute( query, (mrn, age, sex, c1, rv1, rv1_r, rv2, rv2_r, change, D, aki) ) self.connection.commit() + + def insert_patient(self, mrn, age, sex): + """ + Insert the patient info from PAS into the in-memory database. + Args: + - mrn {str}: Medical Record Number of the patient + - age {int}: Age of the patient + - sex {str}: Sex of the patient ('m'/'f') + """ + query = """ + INSERT INTO patients + (mrn, age, sex) + VALUES + (?, ?, ?) + """ + # execute the query + self.connection.execute( + query, + (mrn, age, sex) + ) + self.connection.commit() + - def get_patient(self, mrn): + def insert_test_result(self, mrn, date, result): + """ + Insert the patient info from PAS into the in-memory database. + Args: + - mrn {str}: Medical Record Number of the patient + - date {datetime}: creatinine result date + - result {float}: creatinine result + """ + query = """ + INSERT INTO test_results + (mrn, date, result) + VALUES + (?, ?, ?) + """ + # execute the query + self.connection.execute( + query, + (mrn, date, result) + ) + self.connection.commit() + + + def get_patient_features(self, mrn): """ - Query the patient data for a given mrn. + Query the features table for a given mrn. + Args: + - mrn {str}: Medical Record Number """ + cursor = self.connection.cursor() + cursor.execute('SELECT * FROM features WHERE mrn = ?', (mrn,)) + return cursor.fetchone() + + def get_patient(self, mrn): + """ + Query the patients table for a given mrn. + Args: + - mrn {str}: Medical Record Number + """ cursor = self.connection.cursor() cursor.execute('SELECT * FROM patients WHERE mrn = ?', (mrn,)) return cursor.fetchone() - def update_patient(self, mrn, **kwargs): + def get_test_results(self, mrn): + """ + Query the test results table for a given mrn. + Args: + - mrn {str}: Medical Record Number + """ + cursor = self.connection.cursor() + cursor.execute('SELECT * FROM test_results WHERE mrn = ?', (mrn,)) + return cursor.fetchall() + + + def get_patient_history(self, mrn): + """ + Get patient info along with all their test results and their dates. + Args: + - mrn {str}: Medical Record Number + Returns: + - _ {list}: List of records + """ + query = """ + SELECT + * + FROM + patients + JOIN + test_results + ON + patients.mrn = test_results.mrn + WHERE patients.mrn = ? + """ + cursor = self.connection.cursor() + cursor.execute(query, (mrn,)) + return cursor.fetchall() + + + def discharge_patient(self, mrn, update_disk_db=True): + """ + Remove the patient record from patients table in-memory and on-disk. Test + results are kept in the test_results table for historic data. + Args: + - mrn {str}: Medical Record Number + """ + # delete from in-memory + self.connection.execute('DELETE FROM patients WHERE mrn = ?', (mrn,)) + self.connection.commit() + # delete from on-disk + if update_disk_db: + disk_conn = sqlite3.connect(ON_DISK_DB_PATH) + disk_conn.execute('DELETE FROM patients WHERE mrn = ?', (mrn,)) + disk_conn.commit() + disk_conn.close() + + + def update_patient_features(self, mrn, **kwargs): """ Update patient information based on the provided keyword arguments. Args: - mrn {str}: Medical Record Number of the patient to update - **kwargs {dict}: Where key=column, value=new value """ - # construct the SET part of the SQL query based on the given args - set_clause = ", ".join([f"{key} = ?" for key in kwargs]) - query = f"UPDATE patients SET {set_clause} WHERE mrn = ?" - + set_clause = ', '.join([f"{key} = ?" for key in kwargs]) + query = f'UPDATE features SET {set_clause} WHERE mrn = ?' # prepare the values for the placeholders in the SQL statement values = list(kwargs.values()) + [mrn] - - # execute the update query + # execute the query self.connection.execute(query, values) self.connection.commit() + + + def persist_db(self): + """ + Persist the in-memory database to disk. + Args: + - disk_db_path {str}: the path to the database + """ + # create an empty db file if it does not exist already + if not os.path.exists(ON_DISK_DB_PATH): + with open(ON_DISK_DB_PATH, 'w'): + pass + # connect to the disk db + disk_connection = sqlite3.connect(ON_DISK_DB_PATH, check_same_thread=False) + # backs up and closes the connection + with disk_connection: + self.connection.backup(disk_connection) + + + def load_db(self): + """ + Load the on-disk database into the in-memory database. + """ + pass def close(self): diff --git a/test_memory_db.py b/test_memory_db.py index 86161063656486af2d6ba40cbcdcad213e4f4864..60312a217351cdf4d0192d7a652182e8ce707850 100644 --- a/test_memory_db.py +++ b/test_memory_db.py @@ -1,5 +1,6 @@ import unittest from memory_db import InMemoryDatabase +from datetime import datetime class TestInMemoryDatabase(unittest.TestCase): def setUp(self): @@ -16,27 +17,67 @@ class TestInMemoryDatabase(unittest.TestCase): self.db.close() - def test_insert_and_get_for_patient(self): + def test_insert_and_get_for_patient_features(self): actual_record = ('31251122', 42, 'm', 142.22, 127.45, 1.12, 156.89, 0.91, False, 0, None) # insert - self.db.insert_patient(*actual_record) + self.db.insert_patient_features(*actual_record) # get - queried_record = self.db.get_patient('31251122') + queried_record = self.db.get_patient_features('31251122') self.assertEqual(actual_record, queried_record) - def test_insert_and_update_for_patient(self): + def test_insert_and_update_for_patient_features(self): actual_record = ['31251122', 42, 'm', 142.22, 127.45, 1.12, 156.89, 0.91, False, 0, None] # insert - self.db.insert_patient(*actual_record) + self.db.insert_patient_features(*actual_record) # update - self.db.update_patient('31251122', RV1=114.98, RV1_ratio=1.24) + self.db.update_patient_features('31251122', RV1=114.98, RV1_ratio=1.24) actual_record[4] = 114.98 actual_record[5] = 1.24 # get patient after update - queried_record = self.db.get_patient('31251122') + queried_record = self.db.get_patient_features('31251122') self.assertEqual(tuple(actual_record), queried_record) + def test_insert_and_get_for_patients(self): + actual_record = ('0012352', 29, 'f') + # insert + self.db.insert_patient(*actual_record) + # get + queried_record = self.db.get_patient('0012352') + self.assertEqual(actual_record, queried_record) + + + def test_insert_and_get_for_test_results(self): + actual_record = ('0012352', str(datetime.today()), 109.43) + # insert + self.db.insert_test_result(*actual_record) + # get + queried_record = self.db.get_test_results('0012352')[0] + self.assertEqual(actual_record, queried_record) + + + def test_get_patient_history(self): + patient = ['0012352', 29, 'f'] + test_result = ['0012352', str(datetime.today()), 109.43] + # insert + self.db.insert_patient(*patient) + self.db.insert_test_result(*test_result) + # get + queried_record = self.db.get_patient_history(patient[0])[0] + self.assertEqual(tuple(patient + test_result), queried_record) + + + def test_discharge_patient(self): + patient = ['0012352', 29, 'f'] + # insert + self.db.insert_patient(*patient) + # discharge + self.db.discharge_patient(patient[0], update_disk_db=False) + # get patient + queried_record = self.db.get_patient(patient[0]) + self.assertIsNone(queried_record) + + if __name__ == '__main__': unittest.main() \ No newline at end of file diff --git a/utils.py b/utils.py index 69dccf52d9c6f399547b2928201f42be18274346..8dd74c2e8db0838601c1f47403469aa630bae082 100644 --- a/utils.py +++ b/utils.py @@ -1,7 +1,7 @@ import socket import hl7 import datetime - +import csv from constants import MLLP_START_CHAR, MLLP_END_CHAR, REVERSE_LABELS_MAP @@ -57,3 +57,30 @@ def predict_with_dt(dt_model, data): labels = [REVERSE_LABELS_MAP[item] for item in y_pred] return labels + + +def populate_test_results_table(db, path): + """ + Reads in the training/testing data from a csv file and returns + a list of that data. + Args: + - db {InMemoryDatabase}: the database object + - path {str}: path to the data + """ + with open(path, newline='') as f: + rows = csv.reader(f) + for i, row in enumerate(rows): + # skip header + if i == 0: + continue + + # remove empty strings + while row and row[-1] == '': + row.pop() + + mrn = row[0] + # for each date, result pair insert into the table + for j in range(1, len(row), 2): + date = row[j] + result = float(row[j+1]) + db.insert_test_result(mrn, date, result) diff --git a/utils_test.py b/utils_test.py index baa11ce8de6c59cf7162b2b3cd32f0975e15bd4d..c84c747a983e3ef3a2cd402e521b96b037f296f9 100644 --- a/utils_test.py +++ b/utils_test.py @@ -3,7 +3,9 @@ from utils import ( process_mllp_message, parse_hl7_message, create_acknowledgement, + populate_test_results_table, ) +from memory_db import InMemoryDatabase import hl7 @@ -41,6 +43,17 @@ class TestUtilsClient(unittest.TestCase): self.assertIn(b"ACK", ack_message) self.assertIn(b"MSA|AA|", ack_message) + + def test_populate_test_results_table(self): + db = InMemoryDatabase() + populate_test_results_table(db, 'history.csv') + # expected result + expected_result = ('822825', '2024-01-01 06:12:00', 68.58) + result = db.get_test_results(expected_result[0])[0] + # close the db + db.close() + self.assertEqual(result, expected_result) + if __name__ == "__main__": unittest.main()