# helper_functions.py
# Contributor: Stavros Mitsis
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import fbeta_score, accuracy_score
import pandas as pd
import numpy as np
import json
import optuna
import os

###############################################################################
# 1. Data Processing
###############################################################################
class DataProcessor:
    """
    Handles data loading, feature engineering, and normalization for AKI data.

    Columns produced by ``preprocess`` (in this order): age, sex, [aki],
    latest_creatinine_value, median_previous, mean_previous,
    std_dev_previous, abs_percentage_diff.
    """

    # Numeric features that preprocess() z-score normalizes and
    # handle_class_imbalance() jitters; 'sex' and 'aki' stay 0/1.
    NUMERIC_COLS = [
        'age', 'latest_creatinine_value', 'median_previous',
        'mean_previous', 'std_dev_previous', 'abs_percentage_diff'
    ]
    # Earlier measurements more than this many days before the latest one are
    # excluded from the "previous" statistics.
    LOOKBACK_DAYS = 365
    CONSTANTS_FILENAME = 'normalization_constants.json'

    def __init__(self):
        # Maps each column in NUMERIC_COLS to {'mean': float, 'std': float}.
        self.normalization_constants = {}

    @classmethod
    def default_constants_path(cls):
        """Return the normalization-constants JSON path next to this module."""
        base_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(base_dir, cls.CONSTANTS_FILENAME)

    def preprocess(self, filenames, save_constants=False, final_model=False,
                   constants_path=None):
        """
        Preprocess one or more CSV files (merge if multiple). Extract relevant features:
        - age
        - sex (1 if 'm'/'M', else 0)
        - (optionally) aki (1 if 'y'/'Y', else 0)
        - latest_creatinine_value (the measurement with the latest date)
        - median_previous, mean_previous, std_dev_previous of the other
          measurements dated at most LOOKBACK_DAYS (365) days before the latest
          one; with no such measurement, median/mean equal the latest value
          and std is 0
        - abs_percentage_diff = |(latest - mean_previous)/mean_previous|
          (0 when mean_previous is 0)

        Args:
            filenames (list of str): CSV files to load and combine.
            save_constants (bool): If True, compute normalization constants from
                this data and write them to ``constants_path``; otherwise load
                them from ``constants_path``.
            final_model (bool): If True, we do NOT assume an 'aki' column (for unlabeled test).
            constants_path (str, optional): JSON file holding the constants.
                Defaults to ``normalization_constants.json`` next to this module,
                so saving and loading use the same file whatever the working
                directory is.

        Returns:
            pd.DataFrame: Preprocessed DataFrame with standardized numeric columns.

        Raises:
            ValueError: If a row has no parseable creatinine measurement date.
            FileNotFoundError: If loading constants and the JSON file is missing.
        """
        if constants_path is None:
            constants_path = self.default_constants_path()

        df_final = pd.concat(
            [self._extract_features(filename, final_model) for filename in filenames],
            ignore_index=True,
        )

        if save_constants:
            self.normalization_constants.update({
                col: {
                    'mean': float(df_final[col].mean()),
                    'std': float(self._usable_std(df_final[col].std())),
                }
                for col in self.NUMERIC_COLS
            })
            with open(constants_path, 'w') as f:
                json.dump(self.normalization_constants, f)
        else:
            with open(constants_path, 'r') as f:
                self.normalization_constants = json.load(f)

        for col in self.NUMERIC_COLS:
            stats = self.normalization_constants[col]
            # Guard again here: constants saved by older code may hold std == 0.
            df_final[col] = (df_final[col] - stats['mean']) / self._usable_std(stats['std'])
        return df_final

    @staticmethod
    def _usable_std(std_):
        """Return std_, or 1.0 when it is 0 or NaN (constant column) to avoid 0/0."""
        return 1.0 if pd.isna(std_) or std_ == 0 else std_

    def _extract_features(self, filename, final_model):
        """Load one CSV file and return one row of summary features per input row."""
        df = pd.read_csv(filename)
        date_cols = [c for c in df.columns if "creatinine_date_" in c]
        res_cols = [c for c in df.columns if "creatinine_result_" in c]
        for c in date_cols:
            df[c] = pd.to_datetime(df[c], errors='coerce')

        columns = ['age', 'sex'] + ([] if final_model else ['aki']) + [
            'latest_creatinine_value', 'median_previous', 'mean_previous',
            'std_dev_previous', 'abs_percentage_diff'
        ]

        new_rows = []
        for idx, row in df.iterrows():
            summary = self._creatinine_summary(row, date_cols, res_cols)
            if summary is None:
                raise ValueError(
                    f"{filename}: row {idx} has no valid creatinine measurement date"
                )
            features = [row['age'], 1 if str(row['sex']).lower() == 'm' else 0]
            if not final_model:
                features.append(1 if str(row['aki']).lower() == 'y' else 0)
            new_rows.append(features + list(summary))
        return pd.DataFrame(new_rows, columns=columns)

    @classmethod
    def _creatinine_summary(cls, row, date_cols, res_cols):
        """
        Summarize one row's creatinine history.

        Returns:
            tuple or None: (latest_value, median_prev, mean_prev, std_prev,
            abs_pct_diff), or None if no date column holds a valid date.
        """
        # The i-th date column is paired with the i-th result column; entries
        # with a missing/unparseable date are skipped.
        measurements = [
            (row[d], row[r]) for d, r in zip(date_cols, res_cols) if pd.notna(row[d])
        ]
        if not measurements:
            return None

        # Pick the latest measurement first (max keeps the first of equal dates),
        # then apply the look-back window relative to that date, so the result
        # does not depend on the order of the columns.
        latest_idx = max(range(len(measurements)), key=lambda i: measurements[i][0])
        latest_date, latest_value = measurements[latest_idx]
        prev_values = [
            value for i, (date, value) in enumerate(measurements)
            if i != latest_idx and (latest_date - date).days <= cls.LOOKBACK_DAYS
        ]

        if prev_values:
            prev = pd.Series(prev_values)
            median_prev = float(prev.median())
            mean_prev = float(prev.mean())
            std_prev = float(prev.std(ddof=0))
        else:
            median_prev = mean_prev = latest_value
            std_prev = 0.0

        abs_pct_diff = abs((latest_value - mean_prev) / mean_prev) if mean_prev != 0 else 0
        return latest_value, median_prev, mean_prev, std_prev, abs_pct_diff

    def handle_class_imbalance(self, df: pd.DataFrame, noise_std=0.5, random_state=42):
        """
        Oversample minority class by random sampling with added noise.

        Only acts when minority/majority < 0.5; it then adds enough resampled
        minority rows to equalize the class counts. Gaussian noise is added to
        the NUMERIC_COLS of the new rows only.

        Args:
            df: DataFrame with an 'aki' column and the NUMERIC_COLS features.
            noise_std: Standard deviation of the added Gaussian noise.
            random_state: Seed for both the row sampling and the noise, so
                results are reproducible (None for unseeded).

        Returns:
            pd.DataFrame: ``df`` unchanged, or ``df`` with the oversampled rows
            appended (fresh 0..n-1 index).
        """
        aki_counts = df['aki'].value_counts()
        if len(aki_counts) < 2:
            # If there's only one class, nothing to balance
            return df
        if aki_counts.min() / aki_counts.max() >= 0.5:
            return df

        minority_class = df[df['aki'] == aki_counts.idxmin()]
        num_samples = int(aki_counts.max() - aki_counts.min())
        oversampled = minority_class.sample(
            n=num_samples, replace=True, random_state=random_state
        ).copy()

        rng = np.random.default_rng(random_state)
        cols = self.NUMERIC_COLS
        oversampled[cols] = oversampled[cols] + rng.normal(
            0.0, noise_std, size=oversampled[cols].shape
        )
        return pd.concat([df, oversampled], ignore_index=True)


###############################################################################
# 2. PyTorch Dataset and Model
###############################################################################
class AKIDataset(Dataset):
    """
    PyTorch Dataset for AKI data, returning (features, label) float32 tensors.

    Every column except 'aki' is a feature, taken in DataFrame column order.
    """
    def __init__(self, data: pd.DataFrame):
        """
        Args:
            data: DataFrame of numeric columns. If it has an 'aki' column that
                column is the label; otherwise (unlabeled inference data) every
                label is 0, equivalent to passing a dummy 'aki' column of 0.
        """
        self.data = data
        # Convert once here instead of on every __getitem__ call.
        self.features = data.drop(columns=['aki'], errors='ignore').to_numpy(dtype=np.float32)
        if 'aki' in data.columns:
            self.labels = data['aki'].to_numpy(dtype=np.float32)
        else:
            self.labels = np.zeros(len(data), dtype=np.float32)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # torch.tensor copies, so callers cannot mutate the dataset's arrays.
        return torch.tensor(self.features[idx]), torch.tensor(self.labels[idx])


class AKINet(nn.Module):
    """
    A simple feed-forward neural network for AKI classification.

    Layout: ``num_hidden_layers`` x (Linear -> ReLU), then Linear -> Sigmoid,
    so the output is a probability per sample with a trailing size-1 dim.
    """
    def __init__(self, input_size, num_hidden_layers, hidden_layer_size):
        super().__init__()
        layers = []
        in_features = input_size
        for _ in range(num_hidden_layers):
            layers += [nn.Linear(in_features, hidden_layer_size), nn.ReLU()]
            in_features = hidden_layer_size
        # With zero hidden layers the output layer consumes the raw input,
        # not hidden_layer_size features.
        layers += [nn.Linear(in_features, 1), nn.Sigmoid()]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Map a (batch, input_size) tensor to (batch, 1) probabilities."""
        return self.net(x)


###############################################################################
# 3. Trainer Class
###############################################################################
class Trainer:
    """
    Encapsulates the training and validation loop, including early stopping
    based on best F3.

    The model's output is flattened to 1-D so it matches the 1-D label batch,
    e.g. AKINet's (batch, 1) probabilities.
    """
    def __init__(self, model, optimizer, criterion, device):
        """
        Args:
            model (nn.Module): Network to train; it is updated in place.
            optimizer: Optimizer over ``model``'s parameters.
            criterion: Loss called as ``criterion(predictions, labels)``, e.g. nn.BCELoss.
            device (torch.device): Device that every batch is moved to.
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

    def train_and_validate(
        self,
        train_loader,
        val_loader,
        threshold=0.5,
        patience=20,
        max_epochs=500
    ):
        """
        Train and validate the model with early stopping on best F3 score.

        Training stops after ``patience`` consecutive epochs without a strict
        improvement in validation F3, or after ``max_epochs`` epochs.

        Returns:
            (nn.Module, float): The model with the best-F3 weights loaded (its
            final weights if F3 never exceeded 0) and the best F3 achieved.
        """
        best_val_f3 = 0.0
        patience_counter = 0
        best_model_state = None

        for epoch in range(max_epochs):
            # Training
            self.model.train()
            for x_batch, y_batch in train_loader:
                x_batch = x_batch.to(self.device)
                y_batch = y_batch.to(self.device)
                self.optimizer.zero_grad()
                # reshape(-1) instead of squeeze(): a final batch of one sample
                # must stay 1-D, otherwise BCELoss rejects the size mismatch.
                preds = self.model(x_batch).reshape(-1)
                loss = self.criterion(preds, y_batch)
                loss.backward()
                self.optimizer.step()

            # Validation
            val_f3 = self._evaluate_f3(val_loader, threshold)
            if val_f3 > best_val_f3:
                best_val_f3 = val_f3
                patience_counter = 0
                # state_dict() returns references to the live tensors; clone
                # them or later optimizer steps overwrite this "best" snapshot.
                best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in self.model.state_dict().items()
                }
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}. Best F3 = {best_val_f3:.4f}")
                break

        # Load the best state
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
        return self.model, best_val_f3

    def _evaluate_f3(self, data_loader, threshold):
        """
        Compute F3 (F-beta with beta=3, weighting recall over precision) on a
        given DataLoader; a sample is positive when the output exceeds ``threshold``.
        """
        self.model.eval()
        preds, labels = [], []
        with torch.no_grad():
            for x_batch, y_batch in data_loader:
                # reshape(-1): a one-sample batch must not become a 0-d array,
                # which list.extend cannot iterate.
                out = self.model(x_batch.to(self.device)).reshape(-1)
                preds.extend((out > threshold).cpu().numpy())
                labels.extend(y_batch.cpu().numpy())
        return fbeta_score(labels, preds, beta=3)


###############################################################################
# 4. Helper Functions for Training / Evaluation
###############################################################################

def cross_validate_model(
    data: pd.DataFrame,
    num_hidden_layers: int,
    hidden_layer_size: int,
    learning_rate: float,
    batch_size: int,
    threshold: float,
    device: torch.device,
    n_splits=5,
    patience=20,
    max_epochs=500,
    random_state=42,
):
    """
    Perform stratified n-fold cross-validation, returning the average of the
    best validation F3 reached in each fold.
    This is used by the objective function during hyperparam tuning.

    Args:
        data: DataFrame with an 'aki' label column; all other columns are features.
        num_hidden_layers, hidden_layer_size: AKINet architecture.
        learning_rate: Adam learning rate.
        batch_size: Batch size for the training and validation loaders.
        threshold: Probability threshold used when scoring F3.
        device: Device to train on.
        n_splits: Number of folds.
        patience, max_epochs: Early-stopping settings passed to Trainer.
        random_state: Seed for the fold shuffling.

    Returns:
        float: Mean best-F3 across folds.

    NOTE: if ``data`` was oversampled (handle_class_imbalance) before this call,
    noisy copies of one minority row can land in both a training and a
    validation fold, which can make the CV score optimistic.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    f3_scores = []

    for train_idx, val_idx in skf.split(data, data['aki']):
        train_fold = data.iloc[train_idx].copy()
        val_fold = data.iloc[val_idx].copy()

        train_loader = DataLoader(AKIDataset(train_fold), batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(AKIDataset(val_fold), batch_size=batch_size, shuffle=False)

        # Fresh model and optimizer per fold so folds are independent.
        model = AKINet(
            input_size=train_fold.shape[1] - 1,  # every column except 'aki'
            num_hidden_layers=num_hidden_layers,
            hidden_layer_size=hidden_layer_size
        ).to(device)

        criterion = nn.BCELoss()
        optimizer = Adam(model.parameters(), lr=learning_rate)
        trainer = Trainer(model, optimizer, criterion, device)

        _, best_f3 = trainer.train_and_validate(
            train_loader,
            val_loader,
            threshold=threshold,
            patience=patience,
            max_epochs=max_epochs
        )
        f3_scores.append(best_f3)

    return np.mean(f3_scores)


def train_model(
    data: pd.DataFrame,
    hyperparams: dict,
    device: torch.device,
    val_split=0.2,
    model_path="best_model.pth",
    patience=20,
    max_epochs=500,
    random_state=42,
):
    """
    Train a final model (no hyperparameter search) using the given hyperparams.
    Splits `data` into (train, val) by `val_split` ratio for early stopping.

    Args:
        data: DataFrame with an 'aki' label column; all other columns are features.
        hyperparams: Dict with keys num_hidden_layers, hidden_layer_size,
            learning_rate, batch_size and threshold (as saved by tune_hyperparameters).
        device: Device to train on.
        val_split: Fraction of ``data`` held out (stratified) for early stopping.
        model_path: Where the entire pickled model is saved.
        patience, max_epochs: Early-stopping settings passed to Trainer.
        random_state: Seed for the train/validation split.

    Saves the entire model to ``model_path`` (so we don't need to re-instantiate
    later). On PyTorch >= 2.6 reload it with ``torch.load(path, weights_only=False)``,
    and only load trusted files, since this unpickles arbitrary objects.

    Returns:
        nn.Module: The trained model with its best-validation-F3 weights.
    """
    # Unpack hyperparams
    num_hidden_layers = hyperparams["num_hidden_layers"]
    hidden_layer_size = hyperparams["hidden_layer_size"]
    learning_rate = hyperparams["learning_rate"]
    batch_size = hyperparams["batch_size"]
    threshold = hyperparams["threshold"]  # for early stopping measurement

    # Train-val split
    train_df, val_df = train_test_split(
        data, test_size=val_split, stratify=data['aki'], random_state=random_state
    )

    # Create loaders
    train_loader = DataLoader(AKIDataset(train_df), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(AKIDataset(val_df), batch_size=batch_size, shuffle=False)

    # Instantiate model
    model = AKINet(
        input_size=train_df.shape[1] - 1,  # every column except 'aki'
        num_hidden_layers=num_hidden_layers,
        hidden_layer_size=hidden_layer_size
    ).to(device)

    criterion = nn.BCELoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)
    trainer = Trainer(model, optimizer, criterion, device)

    # Train with early stopping
    best_model, best_f3 = trainer.train_and_validate(
        train_loader,
        val_loader,
        threshold=threshold,
        patience=patience,
        max_epochs=max_epochs
    )

    print(f"Final model after training. Best F3 on validation: {best_f3:.4f}")
    # Save entire model (no need for separate architecture instantiation later)
    torch.save(best_model, model_path)
    print(f"Saved entire model to {model_path}")
    return best_model


def evaluate(model, data: pd.DataFrame, device: torch.device, threshold=0.5, batch_size=64):
    """
    Evaluate a model on a dataset. If data contains 'aki' column, compute F3 & accuracy.
    Otherwise, just return predictions.

    Args:
        model (nn.Module): The trained model (moved to ``device`` here).
        data (pd.DataFrame): Data to evaluate or infer. Must contain 'aki' if you want metrics.
        device (torch.device): 'cpu' or 'cuda'.
        threshold (float): Classification threshold; outputs above it are predicted 1.
        batch_size (int): Inference batch size.

    Returns:
        dict with possible keys:
            - 'predictions': list of 0/1 int predictions, in row order
            - 'f3': F3 score (if 'aki' in data)
            - 'accuracy': Accuracy score (if 'aki' in data)
    """
    # If 'aki' is missing, create a dummy column for Dataset
    has_labels = 'aki' in data.columns
    if not has_labels:
        data = data.copy()
        data['aki'] = 0  # dummy

    loader = DataLoader(AKIDataset(data), batch_size=batch_size, shuffle=False)
    model.eval()
    model.to(device)

    all_preds = []
    all_labels = []

    with torch.no_grad():
        for x_batch, y_batch in loader:
            # reshape(-1) instead of squeeze(): a single-row input (or a final
            # batch of one) must stay 1-D; squeeze() gives a 0-d array that
            # list.extend cannot iterate.
            out = model(x_batch.to(device)).reshape(-1)
            all_preds.extend((out > threshold).float().cpu().numpy())
            all_labels.extend(y_batch.cpu().numpy())

    result = {
        "predictions": np.array(all_preds, dtype=int).tolist()
    }

    if has_labels:
        result["f3"] = fbeta_score(all_labels, all_preds, beta=3)
        result["accuracy"] = accuracy_score(all_labels, all_preds)

    return result


###############################################################################
# 5. Hyperparameter Tuning (Optuna)
###############################################################################

def objective(trial, data, device):
    """
    Optuna objective: draw one hyperparameter configuration from `trial` and
    score it by its 5-fold cross-validated F3 on `data` (higher is better).
    """
    # The dict literal is evaluated top to bottom, so the parameters are
    # suggested in the same fixed order on every trial.
    sampled = {
        "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 5),
        "hidden_layer_size": trial.suggest_categorical(
            "hidden_layer_size", [32, 48, 64, 96, 128, 184, 256, 512, 1024]
        ),
        "learning_rate": trial.suggest_categorical(
            "learning_rate", [0.001, 0.005, 0.01, 0.02, 0.05]
        ),
        "batch_size": trial.suggest_categorical(
            "batch_size", [32, 64, 128, 256, 512, 1024]
        ),
        "threshold": trial.suggest_categorical(
            "threshold", [0.25, 0.375, 0.5, 0.625, 0.75]
        ),
    }
    return cross_validate_model(data=data, device=device, n_splits=5, **sampled)


def tune_hyperparameters(data: pd.DataFrame, n_trials=30, device=None,
                         output_path="best_hyperparameters.json"):
    """
    Run Optuna hyperparameter tuning on `data`, maximizing F3 via cross-validation.

    Args:
        data: DataFrame with an 'aki' label column; all other columns are features.
        n_trials: Number of Optuna trials.
        device: Device to train on; defaults to CUDA when available, else CPU.
        output_path: JSON file the best hyperparameters are written to.

    Returns:
        optuna.trial.FrozenTrial: The best trial (``.params``, ``.value``).
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    study = optuna.create_study(direction="maximize")
    study.optimize(lambda trial: objective(trial, data, device), n_trials=n_trials)
    best_trial = study.best_trial

    # Print best results
    print("\nBest trial:")
    print(f"Hyperparameters: {best_trial.params}")
    print(f"Avg F3 Score (CV): {best_trial.value:.4f}")

    # Save best hyperparameters
    with open(output_path, "w") as f:
        json.dump(best_trial.params, f)
    print(f"Saved best hyperparameters to {output_path}")

    return best_trial


###############################################################################
# 6. Example Usage Notes
###############################################################################
"""
How to use:

1) Hyperparameter Tuning:
   data_processor = DataProcessor()
   train_data = data_processor.preprocess(["training.csv"], save_constants=True)
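   train_data = data_processor.handle_class_imbalance(train_data)
   # Note: oversampling before cross-validation lets noisy copies of a minority
   # row fall into both training and validation folds, which can inflate CV F3.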
   # tune:
   best_trial = tune_hyperparameters(train_data, n_trials=10)

2) Train Final Model (using best hyperparameters):
   # load best hyperparams:
   with open("best_hyperparameters.json", "r") as f:
       best_hparams = json.load(f)

   final_model = train_model(train_data, best_hparams, device=torch.device("cpu"))
   # This saves "best_model.pth" to disk, containing the entire model.

3) Evaluate on a separate test set:
   test_data = data_processor.preprocess(["test.csv"], save_constants=False)
   # preprocess(..., final_model=False) reads the 'aki' column, so test.csv must be
   # labeled here; use final_model=True for unlabeled files (see step 4).
   results = evaluate(final_model, test_data, device=torch.device("cpu"), threshold=best_hparams["threshold"])
   print(results)  # => {'predictions': [...], 'f3': ..., 'accuracy': ...}

4) Final Inference (unlabeled):
   # If 'aki' column is missing, evaluate(...) won't compute metrics, but returns predictions only.
   test_data_unlabeled = data_processor.preprocess(["test.csv"], save_constants=False, final_model=True)
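   # The file holds a whole pickled model: PyTorch >= 2.6 defaults to weights_only=True,
   # which rejects it, so pass weights_only=False -- and only load files you trust.
   model = torch.load("best_model.pth", weights_only=False)  # load once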
   inference_results = evaluate(model, test_data_unlabeled, device=torch.device("cpu"), threshold=best_hparams["threshold"])
   print(inference_results)  # => {'predictions': [...]}

5) Single-Row Prediction:
   # Pass a one-row DataFrame to evaluate(...); it adds a dummy 'aki' column itself
   # when missing. E.g. test_data_unlabeled.iloc[[i]] (double brackets keep a DataFrame).
"""