# helper_functions.py
# (Exported from a repository web page; committed by Stavros Mitsis.)
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import fbeta_score, accuracy_score
import pandas as pd
import numpy as np
import json
import optuna
# (web-export residue: commit attribution, Stavros Mitsis)
import os
# (web-export residue: commit attribution, Stavros Mitsis)

###############################################################################
# 1. Data Processing
###############################################################################
class DataProcessor:
    """
    Handles data loading, feature engineering, and normalization for AKI data.

    The normalization constants (per-column mean / std) are persisted in
    ``normalization_constants.json`` next to this module, so that a later call
    with ``save_constants=False`` finds exactly the file that an earlier call
    with ``save_constants=True`` wrote, regardless of the working directory.
    """

    # Columns that are z-score normalized (and perturbed when oversampling).
    NUMERIC_COLS = [
        'age', 'latest_creatinine_value', 'median_previous',
        'mean_previous', 'std_dev_previous', 'abs_percentage_diff'
    ]
    # How far back (in days) from the latest measurement a value still counts
    # as a "previous" measurement.
    LOOKBACK_DAYS = 365

    def __init__(self):
        self.normalization_constants = {}

    def preprocess(self, filenames, save_constants=False, final_model=False):
        """
        Preprocess one or more CSV files (merge if multiple). Extract relevant features:
        - age
        - sex (1 for 'm', 0 otherwise)
        - (optionally) aki (1 for 'y', 0 otherwise)
        - latest_creatinine_value
        - median_previous, mean_previous, std_dev_previous (within 365 days
          before the latest measurement)
        - abs_percentage_diff = |(latest - mean_previous)/mean_previous|

        Args:
            filenames (list of str): CSV files to load and combine. A single
                path given as a plain string is also accepted.
            save_constants (bool): If True, compute normalization constants from
                this data and save them; otherwise load the saved constants.
            final_model (bool): If True, we do NOT assume an 'aki' column (for unlabeled test).

        Returns:
            pd.DataFrame: Preprocessed DataFrame with standardized columns.

        Raises:
            ValueError: If no filename is given.
        """
        if isinstance(filenames, (str, os.PathLike)):
            filenames = [filenames]
        if not filenames:
            raise ValueError("preprocess() needs at least one CSV filename")

        columns = ['age', 'sex']
        if not final_model:
            columns.append('aki')
        columns += [
            'latest_creatinine_value', 'median_previous', 'mean_previous',
            'std_dev_previous', 'abs_percentage_diff'
        ]

        frames = []
        for filename in filenames:
            raw = pd.read_csv(filename)
            # Date and result columns are paired positionally, in file order.
            date_cols = [c for c in raw.columns if "creatinine_date_" in c]
            res_cols = [c for c in raw.columns if "creatinine_result_" in c]

            # Unparseable dates become NaT and are skipped further down.
            for c in date_cols:
                raw[c] = pd.to_datetime(raw[c], errors='coerce')

            records = []
            for _, row in raw.iterrows():
                record = [row['age'], 1 if str(row['sex']).lower() == 'm' else 0]
                if not final_model:
                    record.append(1 if str(row['aki']).lower() == 'y' else 0)
                record.extend(self._creatinine_features(row, date_cols, res_cols))
                records.append(record)

            frames.append(pd.DataFrame(records, columns=columns))

        df_final = pd.concat(frames, ignore_index=True)
        return self._normalize(df_final, save_constants)

    def _creatinine_features(self, row, date_cols, res_cols):
        """
        Summarize one patient's creatinine history.

        Returns a tuple ``(latest, median_prev, mean_prev, std_prev, abs_pct_diff)``.
        "Previous" values are all measurements other than the latest one that
        were taken at most LOOKBACK_DAYS before the latest date. If there are
        none, the latest value stands in for median/mean and std is 0. If the
        row has no usable measurement at all, the value features are NaN.
        """
        # Keep only measurements that have both a date and a result.
        measurements = [
            (row[d], float(row[r]))
            for d, r in zip(date_cols, res_cols)
            if pd.notna(row[d]) and pd.notna(row[r])
        ]

        if measurements:
            # First occurrence wins on equal dates; later duplicates of the
            # same date are treated as previous values.
            latest_idx = max(range(len(measurements)), key=lambda i: measurements[i][0])
            latest_date, latest_value = measurements[latest_idx]
            # The window is evaluated against the FINAL latest date, so values
            # older than the look-back period are never included.
            prev_values = [
                value
                for i, (date, value) in enumerate(measurements)
                if i != latest_idx and (latest_date - date).days <= self.LOOKBACK_DAYS
            ]
        else:
            latest_value = float('nan')
            prev_values = []

        if prev_values:
            median_prev = float(np.median(prev_values))
            mean_prev = float(np.mean(prev_values))
            std_prev = float(np.std(prev_values))  # population std (ddof=0)
        else:
            median_prev = latest_value
            mean_prev = latest_value
            std_prev = 0.0

        if mean_prev != 0:
            abs_pct_diff = abs((latest_value - mean_prev) / mean_prev)
        else:
            # Avoid dividing by zero; report "no change" instead.
            abs_pct_diff = 0.0

        return latest_value, median_prev, mean_prev, std_prev, abs_pct_diff

    @staticmethod
    def _constants_path():
        """Return the path of the constants file, located next to this module."""
        base_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(base_dir, 'normalization_constants.json')

    def _normalize(self, df, save_constants):
        """
        Z-score normalize NUMERIC_COLS of `df` in place and return it.

        With save_constants=True the mean/std are computed from `df` and written
        to the constants file; otherwise they are read from that file.
        """
        constants_path = self._constants_path()
        if save_constants:
            for col in self.NUMERIC_COLS:
                self.normalization_constants[col] = {
                    'mean': float(df[col].mean()),
                    'std': float(df[col].std()),
                }
            with open(constants_path, 'w') as f:
                json.dump(self.normalization_constants, f)
        else:
            with open(constants_path, 'r') as f:
                self.normalization_constants = json.load(f)

        for col in self.NUMERIC_COLS:
            mean_ = self.normalization_constants[col]['mean']
            std_ = self.normalization_constants[col]['std']
            # A constant column (std 0) or a single-row fit (std NaN) would
            # turn the whole column into inf/NaN; only center it instead.
            if pd.isna(std_) or std_ == 0:
                std_ = 1.0
            df[col] = (df[col] - mean_) / std_

        return df

    def handle_class_imbalance(self, df: pd.DataFrame):
        """
        Oversample minority class by random sampling with added noise.

        Nothing is done when only one class is present or when the minority
        class is at least half the size of the majority class. Otherwise,
        minority rows are resampled (with replacement) until both classes have
        the same size, and Gaussian noise (std 0.5) is added to the numeric
        columns of the resampled rows. Both the sampling and the noise use a
        fixed seed, so the result is reproducible.

        Returns:
            pd.DataFrame: The input rows followed by the synthetic rows, or the
            input itself when no balancing is needed.
        """
        aki_counts = df['aki'].value_counts()
        if len(aki_counts) < 2:
            # If there's only one class, nothing to balance
            return df

        imbalance_ratio = aki_counts.min() / aki_counts.max()
        if imbalance_ratio >= 0.5:
            return df

        minority_class = df[df['aki'] == aki_counts.idxmin()]
        num_samples = int(aki_counts.max() - aki_counts.min())
        oversampled = minority_class.sample(
            n=num_samples, replace=True, random_state=42
        ).copy()

        numeric_cols = list(self.NUMERIC_COLS)
        rng = np.random.default_rng(42)
        noise = rng.normal(0, 0.5, size=(len(oversampled), len(numeric_cols)))
        oversampled[numeric_cols] = oversampled[numeric_cols].to_numpy(dtype=float) + noise

        return pd.concat([df, oversampled], ignore_index=True)


###############################################################################
# 2. PyTorch Dataset and Model
###############################################################################
class AKIDataset(Dataset):
    """
    Map-style PyTorch Dataset over an AKI DataFrame; each item is a
    (features, label) pair of float32 tensors.
    """
    def __init__(self, data: pd.DataFrame):
        """
        `data` must contain an 'aki' column; every other column is a feature.
        For inference on unlabeled data, add a placeholder 'aki' column
        (e.g. 0 or np.nan) before constructing the dataset.
        """
        self.data = data
        # Split once into plain arrays so item access is a cheap index.
        feature_frame = data.drop(columns=['aki'])
        self.features = feature_frame.values
        self.labels = data['aki'].values

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        feature_tensor = torch.tensor(self.features[idx], dtype=torch.float32)
        label_tensor = torch.tensor(self.labels[idx], dtype=torch.float32)
        return feature_tensor, label_tensor


class AKINet(nn.Module):
    """
    A simple feed-forward neural network for AKI classification.

    Architecture: `num_hidden_layers` blocks of Linear + ReLU, followed by a
    single-unit Linear layer and a Sigmoid, so the output is a probability of
    shape (batch, 1).

    Args:
        input_size (int): Number of input features.
        num_hidden_layers (int): Number of hidden Linear+ReLU blocks. With 0
            the network is a plain logistic regression on the inputs.
        hidden_layer_size (int): Width of every hidden layer.
    """
    def __init__(self, input_size, num_hidden_layers, hidden_layer_size):
        super().__init__()
        layers = []
        # Track the width feeding the next layer so the output layer is wired
        # correctly even when there are no hidden layers.
        in_features = input_size
        for _ in range(num_hidden_layers):
            layers.append(nn.Linear(in_features, hidden_layer_size))
            layers.append(nn.ReLU())
            in_features = hidden_layer_size
        layers.append(nn.Linear(in_features, 1))
        layers.append(nn.Sigmoid())
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Return the predicted AKI probability for each row of `x`."""
        return self.net(x)


###############################################################################
# 3. Trainer Class
###############################################################################
class Trainer:
    """
    Encapsulates the training and validation loop, including early stopping
    based on best F3.
    """
    def __init__(self, model, optimizer, criterion, device):
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

    def train_and_validate(
        self,
        train_loader,
        val_loader,
        threshold=0.5,
        patience=20,
        max_epochs=500
    ):
        """
        Train and validate the model with early stopping on best F3 score.

        Args:
            train_loader: Iterable of (features, labels) training batches.
            val_loader: Iterable of (features, labels) validation batches.
            threshold (float): Probability cut-off used when computing F3.
            patience (int): Number of consecutive epochs without an F3
                improvement after which training stops.
            max_epochs (int): Upper bound on the number of epochs.

        Returns:
            (nn.Module, float): The model, restored to the weights of its best
            epoch (if any epoch scored above 0), and the best F3 achieved.
        """
        best_val_f3 = 0.0
        patience_counter = 0
        best_model_state = None

        for epoch in range(max_epochs):
            # Training
            self.model.train()
            for x_batch, y_batch in train_loader:
                x_batch = x_batch.to(self.device)
                y_batch = y_batch.to(self.device)
                self.optimizer.zero_grad()
                # reshape(-1) rather than squeeze(): a batch of one sample must
                # stay 1-D so it still matches the shape of the targets.
                preds = self.model(x_batch).reshape(-1)
                loss = self.criterion(preds, y_batch.reshape(-1))
                loss.backward()
                self.optimizer.step()

            # Validation
            val_f3 = self._evaluate_f3(val_loader, threshold)
            if val_f3 > best_val_f3:
                best_val_f3 = val_f3
                patience_counter = 0
                # state_dict() returns references to the live parameters, so
                # take a copy; otherwise later epochs overwrite the checkpoint.
                best_model_state = {
                    name: tensor.detach().clone()
                    for name, tensor in self.model.state_dict().items()
                }
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch+1}. Best F3 = {best_val_f3:.4f}")
                break

        # Load the best state
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
        return self.model, best_val_f3

    def _evaluate_f3(self, data_loader, threshold):
        """
        Compute F3 on a given DataLoader.

        Puts the model in eval mode, thresholds its outputs at `threshold`
        and scores the resulting 0/1 predictions against the labels.
        """
        self.model.eval()
        preds, labels = [], []
        with torch.no_grad():
            for x_batch, y_batch in data_loader:
                x_batch = x_batch.to(self.device)
                out = self.model(x_batch).reshape(-1)
                preds.extend((out > threshold).int().cpu().tolist())
                labels.extend(y_batch.reshape(-1).cpu().tolist())
        return fbeta_score(labels, preds, beta=3)


###############################################################################
# 4. Helper Functions for Training / Evaluation
###############################################################################

def cross_validate_model(
    data: pd.DataFrame,
    num_hidden_layers: int,
    hidden_layer_size: int,
    learning_rate: float,
    batch_size: int,
    threshold: float,
    device: torch.device,
    n_splits=5
):
    """
    Run stratified k-fold cross-validation and return the mean of the per-fold
    best F3 scores. A fresh model is trained for every fold.
    Called by the Optuna objective during hyperparameter tuning.
    """
    def run_fold(train_rows, val_rows):
        # Train one fresh network on this fold and report its best validation F3.
        fold_train = data.iloc[train_rows].copy()
        fold_val = data.iloc[val_rows].copy()

        loader_train = DataLoader(AKIDataset(fold_train), batch_size=batch_size, shuffle=True)
        loader_val = DataLoader(AKIDataset(fold_val), batch_size=batch_size, shuffle=False)

        # Every column except 'aki' is an input feature.
        net = AKINet(
            input_size=fold_train.shape[1] - 1,
            num_hidden_layers=num_hidden_layers,
            hidden_layer_size=hidden_layer_size
        ).to(device)

        loss_fn = nn.BCELoss()
        fold_trainer = Trainer(net, Adam(net.parameters(), lr=learning_rate), loss_fn, device)

        _, fold_best = fold_trainer.train_and_validate(
            loader_train,
            loader_val,
            threshold=threshold,
            patience=20,
            max_epochs=500
        )
        return fold_best

    splitter = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
    fold_scores = [
        run_fold(train_rows, val_rows)
        for train_rows, val_rows in splitter.split(data, data['aki'])
    ]
    return np.mean(fold_scores)


def train_model(
    data: pd.DataFrame,
    hyperparams: dict,
    device: torch.device,
    val_split=0.2
):
    """
    Fit the final model with a fixed set of hyperparameters (no search).

    `data` is split (stratified on 'aki') into train and validation parts
    according to `val_split`; the validation part drives early stopping.

    The whole model object is pickled to 'best_model.pth', so it can be loaded
    later without rebuilding the architecture. Returns the trained model.
    """
    # Stratified hold-out split used for early stopping.
    train_part, val_part = train_test_split(
        data, test_size=val_split, stratify=data['aki'], random_state=42
    )

    loader_kwargs = {"batch_size": hyperparams["batch_size"]}
    loader_train = DataLoader(AKIDataset(train_part), shuffle=True, **loader_kwargs)
    loader_val = DataLoader(AKIDataset(val_part), shuffle=False, **loader_kwargs)

    # Every column except 'aki' is an input feature.
    net = AKINet(
        input_size=train_part.shape[1] - 1,
        num_hidden_layers=hyperparams["num_hidden_layers"],
        hidden_layer_size=hyperparams["hidden_layer_size"]
    ).to(device)

    loss_fn = nn.BCELoss()
    adam = Adam(net.parameters(), lr=hyperparams["learning_rate"])
    fitter = Trainer(net, adam, loss_fn, device)

    # The tuned threshold is the cut-off at which early stopping measures F3.
    best_model, best_f3 = fitter.train_and_validate(
        loader_train,
        loader_val,
        threshold=hyperparams["threshold"],
        patience=20,
        max_epochs=500
    )

    print(f"Final model after training. Best F3 on validation: {best_f3:.4f}")
    # Persist the full module object rather than only its state dict.
    torch.save(best_model, "best_model.pth")
    print("Saved entire model to best_model.pth")
    return best_model


def evaluate(model, data: pd.DataFrame, device: torch.device, threshold=0.5):
    """
    Evaluate a model on a dataset. If data contains 'aki' column, compute F3 & accuracy.
    Otherwise, just return predictions.

    Works for any number of rows, including a single-row DataFrame.
    The input DataFrame is not modified.

    Args:
        model (nn.Module): The trained model; it is switched to eval mode and
            moved to `device`.
        data (pd.DataFrame): Data to evaluate or infer. Must contain 'aki' if you want metrics.
            All other columns are fed to the model as features.
        device (torch.device): 'cpu' or 'cuda'.
        threshold (float): Classification threshold.

    Returns:
        dict with possible keys:
            - 'predictions': List of 0/1 predictions
            - 'f3': F3 score (if 'aki' in data)
            - 'accuracy': Accuracy score (if 'aki' in data)
    """
    has_labels = 'aki' in data.columns
    feature_frame = data.drop(columns=['aki']) if has_labels else data
    features = torch.tensor(feature_frame.to_numpy(dtype=np.float32))

    model.eval()
    model.to(device)

    batch_size = 64
    pred_chunks = []
    with torch.no_grad():
        for start in range(0, features.shape[0], batch_size):
            x_batch = features[start:start + batch_size].to(device)
            # reshape(-1) rather than squeeze(): a batch holding one row must
            # stay 1-D, otherwise its prediction cannot be collected.
            out = model(x_batch).reshape(-1)
            pred_chunks.append((out > threshold).float().cpu().numpy())

    if pred_chunks:
        all_preds = np.concatenate(pred_chunks)
    else:
        all_preds = np.empty(0, dtype=np.float32)

    result = {
        "predictions": all_preds.astype(int).tolist()
    }

    if has_labels:
        all_labels = data['aki'].to_numpy(dtype=np.float32)
        result["f3"] = fbeta_score(all_labels, all_preds, beta=3)
        result["accuracy"] = accuracy_score(all_labels, all_preds)

    return result


###############################################################################
# 5. Hyperparameter Tuning (Optuna)
###############################################################################

def objective(trial, data, device):
    """
    Optuna objective: sample one hyperparameter configuration from `trial`
    and score it by the cross-validated average F3 on `data`.
    """
    # Sample the search space (all but the layer count are categorical grids).
    sampled = {
        "num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 5),
        "hidden_layer_size": trial.suggest_categorical(
            "hidden_layer_size", [32, 64, 128, 256, 512, 1024]
        ),
        "learning_rate": trial.suggest_categorical(
            "learning_rate", [0.001, 0.005, 0.01, 0.02, 0.05]
        ),
        "batch_size": trial.suggest_categorical(
            "batch_size", [32, 64, 128, 256, 512, 1024]
        ),
        "threshold": trial.suggest_categorical(
            "threshold", [0.25, 0.375, 0.5, 0.625, 0.75]
        ),
    }

    # Score the configuration with 5-fold cross-validation.
    return cross_validate_model(data=data, device=device, n_splits=5, **sampled)


def tune_hyperparameters(data: pd.DataFrame, n_trials=10):
    """
    Search hyperparameters with Optuna for `n_trials` trials, maximizing the
    cross-validated F3 on `data`. Uses CUDA when available, otherwise the CPU.

    Writes the winning hyperparameters to best_hyperparameters.json and
    returns the best trial.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    def score_trial(trial):
        # Bind the dataset and device so Optuna only has to pass the trial.
        return objective(trial, data, device)

    study = optuna.create_study(direction="maximize")
    study.optimize(score_trial, n_trials=n_trials)

    # Report the winner.
    winner = study.best_trial
    print("\nBest trial:")
    print(f"Hyperparameters: {winner.params}")
    print(f"Avg F3 Score (CV): {winner.value:.4f}")

    # Persist the winning configuration for the final training run.
    with open("best_hyperparameters.json", "w") as f:
        json.dump(winner.params, f)
    print("Saved best hyperparameters to best_hyperparameters.json")

    return study.best_trial


###############################################################################
# 6. Example Usage Notes
###############################################################################
"""
How to use:

1) Hyperparameter Tuning:
   data_processor = DataProcessor()
   train_data = data_processor.preprocess(["training.csv"], save_constants=True)
   train_data = data_processor.handle_class_imbalance(train_data)
   # tune:
   best_trial = tune_hyperparameters(train_data, n_trials=10)

2) Train Final Model (using best hyperparameters):
   # load best hyperparams:
   with open("best_hyperparameters.json", "r") as f:
       best_hparams = json.load(f)

   final_model = train_model(train_data, best_hparams, device=torch.device("cpu"))
   # This saves "best_model.pth" to disk, containing the entire model.

3) Evaluate on a separate test set:
   test_data = data_processor.preprocess(["test.csv"], save_constants=False)
   # Make sure test_data also has an 'aki' column if you want metrics.
   results = evaluate(final_model, test_data, device=torch.device("cpu"), threshold=best_hparams["threshold"])
   print(results)  # => {'predictions': [...], 'f3': ..., 'accuracy': ...}

4) Final Inference (unlabeled):
   # If 'aki' column is missing, evaluate(...) won't compute metrics, but returns predictions only.
   test_data_unlabeled = data_processor.preprocess(["test.csv"], save_constants=False, final_model=True)
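   model = torch.load("best_model.pth", weights_only=False)  # load once; the file is a fully pickled model, and PyTorch >= 2.6 defaults to weights_only=True (only load files you trust)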
   inference_results = evaluate(model, test_data_unlabeled, device=torch.device("cpu"), threshold=best_hparams["threshold"])
   print(inference_results)  # => {'predictions': [...]}

5) Single-Row Prediction:
   # Just pass a single-row DataFrame (with a dummy 'aki' if necessary).
   # Or slice the test_data_unlabeled at row i, then evaluate.
"""