#!/usr/bin/env python3
"""
This script is designed to make predictions for Acute Kidney Injury (AKI) on a hidden test dataset.
It uses a previously trained and saved model along with normalization constants and hyperparameters.

**Main Components:**
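1. **Load Saved Parameters**:
   - Normalization constants (`final_normalization_constants.json`) are expected to be applied by `DataProcessor` when preprocessing the test data; they are not re-saved during inference.
   - Load the best hyperparameters (`best_hyperparameters.json`) to obtain the tuned decision threshold used to turn model outputs into labels.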

2. **Preprocess Test Data**:
   - Normalize the test dataset using saved normalization constants.
   - Extract relevant features (mean, median, standard deviation, etc.) from creatinine test results.
   - Unlike the training data, no oversampling is performed.

3. **Load and Use the Model**:
   - Load the fully pickled trained model saved as `best_model.pth` (located in the same directory as this script).
   - Apply it with the tuned threshold to predict AKI for each row of the test dataset.

4. **Save Predictions**:
   - Write the predictions to a CSV file (`aki.csv` by default, configurable with `--output`) containing a single `aki` column with "y" or "n" for each input row.
"""

import argparse
import csv
import json
import os
import warnings

import torch
import pandas as pd

from helper_functions import (
    DataProcessor,
    tune_hyperparameters,  # For hyperparameter search
    train_model,           # For final (non-tuning) training
    evaluate               # For evaluation or inference
)


def _load_threshold(hyperparams_path):
    """Return the decision threshold stored under ``"threshold"`` in a JSON file.

    Raises:
        KeyError: if the JSON object has no ``"threshold"`` entry.
    """
    with open(hyperparams_path, "r", encoding="utf-8") as f:
        best_hparams = json.load(f)
    return best_hparams['threshold']


def _load_full_model(model_path, device):
    """Load an entire pickled model (not just a state_dict) onto *device*."""
    # Since PyTorch 2.6, torch.load defaults to weights_only=True, which refuses
    # to unpickle arbitrary classes such as a whole nn.Module. Opt out
    # explicitly (the keyword exists since PyTorch 1.13).
    # SECURITY: weights_only=False runs pickle code; only ever point this at
    # the trusted checkpoint shipped with this script, never at user input.
    return torch.load(model_path, map_location=device, weights_only=False)


def _count_data_rows(csv_path):
    """Return the number of non-empty CSV rows in *csv_path*, excluding the header."""
    # errors="replace": rows are only counted here, never interpreted, so a
    # stray non-UTF-8 byte must not abort the run.
    with open(csv_path, "r", newline="", encoding="utf-8", errors="replace") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        return sum(1 for row in reader if row)


def _write_predictions(predictions, output_path):
    """Write an ``aki`` header followed by one ``y``/``n`` row per prediction.

    A prediction equal to 1 becomes ``y``; anything else becomes ``n``.
    """
    with open(output_path, "w", newline="", encoding="utf-8") as outfile:
        writer = csv.writer(outfile)
        writer.writerow(["aki"])
        writer.writerows(['y' if pred == 1 else 'n'] for pred in predictions)


def main():
    """Predict AKI for every row of ``--input`` and write y/n labels to ``--output``.

    Command-line arguments:
        --input:  unlabeled test CSV (default ``test.csv``).
        --output: destination CSV (default ``aki.csv``). It receives an
                  ``aki`` header and one ``y``/``n`` row per prediction.

    ``best_model.pth`` and ``best_hyperparameters.json`` are resolved relative
    to this script's directory, not the current working directory.
    """
    parser = argparse.ArgumentParser(description="Predict AKI on an unlabeled test set.")
    parser.add_argument("--input", default="test.csv",
                        help="unlabeled test CSV (default: test.csv)")
    parser.add_argument("--output", default="aki.csv",
                        help="output predictions CSV (default: aki.csv)")
    args = parser.parse_args()

    base_dir = os.path.dirname(os.path.abspath(__file__))

    # Preprocess the unlabeled test set (final_model=True => no 'aki' column).
    # save_constants=False keeps inference from overwriting the saved
    # normalization constants.
    processor = DataProcessor()
    unlabeled_test_data = processor.preprocess([args.input], save_constants=False, final_model=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_full_model(os.path.join(base_dir, "best_model.pth"), device)
    # Inference mode (disables dropout / uses running batch-norm stats if the
    # model has them). evaluate() may already do this; eval() is idempotent.
    model.eval()

    threshold = _load_threshold(os.path.join(base_dir, "best_hyperparameters.json"))

    # evaluate(...) returns a dict whose 'predictions' holds one 0/1 value per row.
    results = evaluate(
        model,
        data=unlabeled_test_data,
        device=device,
        threshold=threshold
    )
    predictions = results["predictions"]

    # Output rows are presumably matched to input rows by position, so a
    # length mismatch would silently misalign labels. Warn instead of aborting
    # so that an output file is still produced.
    n_rows = _count_data_rows(args.input)
    if n_rows != len(predictions):
        warnings.warn(
            f"{args.input} has {n_rows} data rows but {len(predictions)} "
            "predictions were produced; output rows may be misaligned.",
            stacklevel=2,
        )

    _write_predictions(predictions, args.output)


if __name__ == "__main__":
    main()