Skip to content
Snippets Groups Projects
Commit 5c3fb721 authored by OnurZa's avatar OnurZa
Browse files

Add comments to model.py.

parent c08d7e8e
Branches master
No related tags found
No related merge requests found
......@@ -311,7 +311,7 @@ n
n
n
n
n
y
n
n
n
......@@ -611,7 +611,7 @@ n
n
n
n
y
n
n
n
n
......@@ -863,7 +863,7 @@ n
n
n
n
y
n
n
n
n
......@@ -1742,7 +1742,7 @@ n
y
n
n
n
y
n
n
n
......@@ -2078,7 +2078,7 @@ y
n
n
y
y
n
y
y
n
......@@ -2134,13 +2134,13 @@ n
n
n
n
n
y
y
y
n
n
n
n
y
n
y
y
......@@ -2241,7 +2241,7 @@ n
n
y
n
y
n
n
n
n
......@@ -2259,7 +2259,7 @@ n
n
y
n
y
n
n
n
n
......@@ -2752,7 +2752,7 @@ n
n
n
n
y
n
n
n
n
......@@ -3511,7 +3511,7 @@ y
n
n
n
n
y
n
n
n
......@@ -3776,7 +3776,7 @@ n
y
y
y
n
y
y
n
y
......@@ -3817,7 +3817,7 @@ n
y
n
n
y
n
y
n
n
......@@ -3933,7 +3933,7 @@ n
y
y
n
y
n
n
y
n
......@@ -3967,7 +3967,7 @@ n
y
n
n
n
y
n
y
y
......@@ -4032,7 +4032,7 @@ n
n
n
n
y
n
n
n
y
......@@ -4358,7 +4358,7 @@ y
n
n
n
n
y
n
y
n
......@@ -4760,7 +4760,7 @@ n
n
y
y
y
n
y
n
y
......@@ -4935,7 +4935,7 @@ n
n
n
n
y
n
y
n
y
......@@ -5204,7 +5204,7 @@ n
n
n
n
y
n
n
n
y
......@@ -5255,7 +5255,7 @@ n
n
n
n
y
n
n
y
y
......@@ -5683,7 +5683,7 @@ n
n
n
n
y
n
n
n
n
......@@ -5923,7 +5923,7 @@ n
y
n
n
y
n
y
n
n
......@@ -6294,7 +6294,7 @@ y
n
n
n
n
y
n
n
n
......@@ -6515,7 +6515,7 @@ y
n
n
n
y
n
n
n
n
......
......@@ -3,37 +3,60 @@
import argparse
import csv
import pandas as pd
from sklearn.metrics import fbeta_score, classification_report
from sklearn.metrics import fbeta_score
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LogisticRegression
def preprocess_data(input_file):
data = pd.read_csv(input_file)
def preprocess_data(data):
"""
Preprocesses the input data.
This function performs the following steps:
1. Reads the input data from a CSV file.
2. Computes the mean, minimum, maximum, and variance of all columns containing "creatinine_result".
3. Drops all columns containing "creatinine".
4. Fills missing values in the "age" column with the median age.
5. Maps the "sex" column to binary values (1 for male, 0 for female).
6. Fills any remaining missing values in the "age", "sex", and computed creatinine columns with 0.
7. Separates the target variable "aki" (if present) from the features.
Args:
data (str): The file path to the CSV data file.
Returns:
tuple: A tuple containing:
- X (pd.DataFrame): The preprocessed feature data.
- y (pd.Series or None): The target variable if "aki" is present in the data, otherwise None.
"""
data = pd.read_csv(data)
# Extract the creatinine results, and all columns that contain the word "creatinine"
creatinine_results = [col for col in data.columns if "creatinine_result" in col]
all_creatinine_cols = [col for col in data.columns if "creatinine" in col]
# Compute the mean, min, max, and variance of the creatinine results
data["creatinine_mean"] = data[creatinine_results].mean(axis=1)
data["creatinine_min"] = data[creatinine_results].min(axis=1)
data["creatinine_max"] = data[creatinine_results].max(axis=1)
data["creatinine_variance"] = data[creatinine_results].var(axis=1)
data["creatinine_median"] = data[creatinine_results].median(axis=1)
# Once computed, drop the original creatinine columns to
# keep only the computed features
data = data.drop(columns=all_creatinine_cols)
# Fill missing values in the "age" column with the median age
data["age"] = data["age"].fillna(data["age"].median())
# Encode the 'sex' column as binary values
data["sex"] = data["sex"].map({"m": 1, "f": 0})
# Fill any remaining missing values with 0
data["creatinine_variance"] = data["creatinine_variance"].fillna(0)
data["creatinine_mean"] = data["creatinine_mean"].fillna(0)
data["creatinine_min"] = data["creatinine_min"].fillna(0)
data["creatinine_max"] = data["creatinine_max"].fillna(0)
data["creatinine_median"] = data["creatinine_median"].fillna(0)
data["age"] = data["age"].fillna(0)
data["sex"] = data["sex"].fillna(0)
# Separate the target variable if present. Return the
# features and target variable, if present
y = None
X = data
if "aki" in data.columns:
......@@ -43,31 +66,42 @@ def preprocess_data(input_file):
return X, y
def main():
"""
Trains a logistic regression model to predict the presence of AKI.
Logistic regression is a powerful but efficient linear model that is often used as a
baseline model for binary classification tasks, hance making it a suitable choice here.
"""
parser = argparse.ArgumentParser()
parser.add_argument("--input", default="test.csv")
parser.add_argument("--output", default="aki.csv")
parser.add_argument("--train", default="data/training.csv")
flags = parser.parse_args()
# Preprocess the training data
X_train, y_train = preprocess_data(flags.train)
X_test, y_test = preprocess_data(flags.input)
# Standardize the data
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
# Train a logistic regression model with balanced class weights.
model = LogisticRegression(class_weight="balanced", random_state=42)
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
preds = ["y" if p == 1 else "n" for p in y_pred]
# Make the predictions and write to a CSV file
preds = ["y" if p == 1 else "n" for p in y_pred]
with open(flags.output, "w", newline="") as f:
writer = csv.writer(f)
writer.writerow(["aki"])
for pred in preds:
writer.writerow([pred])
# Evalute the model by computing the F3 score
if y_test is not None:
f3 = fbeta_score(y_test, y_pred, beta=3)
print("F3 score: ", f3)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment