# helper_functions.py — AKI data processing, PyTorch dataset/model, training,
# evaluation and Optuna hyperparameter-tuning helpers.
import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
from sklearn.model_selection import train_test_split, StratifiedKFold
from sklearn.metrics import fbeta_score, accuracy_score
import pandas as pd
import numpy as np
import json
import optuna
import os

###############################################################################
# 1. Data Processing
###############################################################################
class DataProcessor:
    """
    Handles data loading, feature engineering, and normalization for AKI data.

    Input CSVs hold one patient per row with ``age``, ``sex``, optionally
    ``aki`` and any number of ``creatinine_date_*`` / ``creatinine_result_*``
    columns. Date and result columns are paired by their order of appearance.
    """

    # Numeric feature columns that are z-score normalized (and jittered when oversampling).
    NUMERIC_COLS = [
        'age', 'latest_creatinine_value', 'median_previous',
        'mean_previous', 'std_dev_previous', 'abs_percentage_diff'
    ]
    # An earlier test counts towards the "previous" statistics only if it was taken
    # at most this many days before the patient's latest test.
    LOOKBACK_DAYS = 365
    # File, stored next to this module, holding the per-column normalization constants.
    CONSTANTS_FILENAME = 'normalization_constants.json'

    def __init__(self):
        """Initialize an empty normalization-constants dictionary."""
        self.normalization_constants = {}

    def preprocess(self, filenames, save_constants=False, final_model=False):
        """
        Preprocess one or more CSV files (rows of all files are concatenated).

        Extracted features, in column order:
        - age
        - sex (1 for 'm'/'M', otherwise 0)
        - aki (1 for 'y'/'Y', otherwise 0) -- only when ``final_model`` is False
        - latest_creatinine_value
        - median_previous, mean_previous, std_dev_previous: summary statistics
          (std is the population std, ddof=0) of the earlier tests taken within
          ``LOOKBACK_DAYS`` of the latest test. If there is no such test, they
          fall back to the latest value (median/mean) and 0.0 (std).
        - abs_percentage_diff: |latest - mean_previous| / mean_previous
          (0.0 when mean_previous is 0)

        A test only counts if both its date and its result are present.

        Args:
            filenames (list of str): List of CSV file paths to process.
            save_constants (bool): If True, compute normalization constants from this
                data and save them; otherwise load previously saved constants.
            final_model (bool): If True, assumes no 'aki' column in the data
                (clinical inference).

        Returns:
            pd.DataFrame: Preprocessed DataFrame with the numeric columns z-score normalized.

        Raises:
            ValueError: If a patient row has no test with both a date and a result.
            FileNotFoundError: If constants must be loaded but were never saved.
        """
        columns = ['age', 'sex', 'latest_creatinine_value', 'median_previous',
                   'mean_previous', 'std_dev_previous', 'abs_percentage_diff']
        if not final_model:
            # Training data carries the label; keep it right after the demographics.
            columns.insert(2, 'aki')

        frames = [self._features_from_csv(name, columns, final_model) for name in filenames]
        df_final = pd.concat(frames, ignore_index=True)
        self._normalize(df_final, save_constants)
        return df_final

    def _features_from_csv(self, filename, columns, final_model):
        """Read one CSV file and return its rows as (un-normalized) engineered features."""
        df = pd.read_csv(filename)
        date_cols = [c for c in df.columns if "creatinine_date_" in c]
        res_cols = [c for c in df.columns if "creatinine_result_" in c]
        for c in date_cols:
            df[c] = pd.to_datetime(df[c], errors='coerce')

        rows = []
        for idx, row in df.iterrows():
            tests = [
                (row[d], row[r]) for d, r in zip(date_cols, res_cols)
                if pd.notna(row[d]) and pd.notna(row[r])
            ]
            try:
                creatinine = self._creatinine_summary(tests)
            except ValueError as err:
                raise ValueError(f"{filename}, row {idx}: {err}") from err

            features = [row['age'], 1 if str(row['sex']).lower() == 'm' else 0]
            # In training the frame has an AKI column; during inference in a clinical
            # setting it is missing, so the caller tells us which case applies.
            if not final_model:
                features.append(1 if str(row['aki']).lower() == 'y' else 0)
            rows.append(features + list(creatinine))
        return pd.DataFrame(rows, columns=columns)

    @classmethod
    def _creatinine_summary(cls, tests):
        """
        Summarize one patient's creatinine history.

        Args:
            tests (list of (pd.Timestamp, float)): Valid (date, result) pairs in any order.

        Returns:
            tuple: (latest_value, median_prev, mean_prev, std_prev, abs_pct_diff).

        Raises:
            ValueError: If ``tests`` is empty.
        """
        if not tests:
            raise ValueError("no creatinine test with both a date and a result")

        # The first occurrence wins ties on the latest date. Any other same-day tests
        # then count as "previous" (0 days old).
        latest_idx = max(range(len(tests)), key=lambda i: tests[i][0])
        latest_date, latest_value = tests[latest_idx]
        latest_value = float(latest_value)

        # Every other test is measured against the FINAL latest date.
        prev_values = [
            value for i, (date, value) in enumerate(tests)
            if i != latest_idx and (latest_date - date).days <= cls.LOOKBACK_DAYS
        ]

        if prev_values:
            prev = np.asarray(prev_values, dtype=float)
            median_prev = float(np.median(prev))
            mean_prev = float(np.mean(prev))
            std_prev = float(np.std(prev))  # population std (ddof=0)
        else:
            # Only the latest test exists: it is its own baseline.
            median_prev = mean_prev = latest_value
            std_prev = 0.0

        abs_pct_diff = abs((latest_value - mean_prev) / mean_prev) if mean_prev != 0 else 0.0
        return latest_value, median_prev, mean_prev, std_prev, abs_pct_diff

    @classmethod
    def _constants_path(cls):
        """Path of the normalization-constants file, next to this module."""
        base_dir = os.path.dirname(os.path.abspath(__file__))
        return os.path.join(base_dir, cls.CONSTANTS_FILENAME)

    def _normalize(self, df, save_constants):
        """
        Z-score normalize ``NUMERIC_COLS`` of ``df`` in place.

        Training (``save_constants=True``) computes the constants from ``df`` and
        saves them. Prediction and testing load the saved constants so the test data
        is normalized exactly like the training data. Saving and loading use the
        same path, so the working directory does not matter.
        """
        if save_constants:
            for col in self.NUMERIC_COLS:
                mean_ = float(df[col].mean())
                std_ = float(df[col].std())
                # A constant column (or a single row) would otherwise divide by 0/NaN.
                if not np.isfinite(std_) or std_ == 0:
                    std_ = 1.0
                self.normalization_constants[col] = {'mean': mean_, 'std': std_}
            with open(self._constants_path(), 'w') as f:
                json.dump(self.normalization_constants, f)
        else:
            with open(self._constants_path(), 'r') as f:
                self.normalization_constants = json.load(f)

        for col in self.NUMERIC_COLS:
            constants = self.normalization_constants[col]
            df[col] = (df[col] - constants['mean']) / constants['std']

    def handle_class_imbalance(self, df: pd.DataFrame, random_state=42, noise_std=0.5):
        """
        Oversample the minority class and add Gaussian noise to the copies.

        Oversampling only happens when minority/majority < 0.5. The minority class is
        then resampled (with replacement) up to the majority count. Because the data
        is already normalized, noise with std ``noise_std`` is added to the numeric
        columns of the new rows to make the model more robust.

        Args:
            df (pd.DataFrame): Input DataFrame containing 'aki' and the numeric columns.
            random_state (int | None): Seed for both the resampling and the noise,
                so the result is reproducible.
            noise_std (float): Standard deviation of the added Gaussian noise.

        Returns:
            pd.DataFrame: Balanced DataFrame with oversampled minority class, or ``df``
            unchanged if it has a single class or is already balanced enough.
        """
        aki_counts = df['aki'].value_counts()
        if len(aki_counts) < 2:
            # If there's only one class, nothing to balance.
            return df

        minority_count, majority_count = int(aki_counts.min()), int(aki_counts.max())
        if minority_count / majority_count >= 0.5:
            return df

        minority_class = df[df['aki'] == aki_counts.idxmin()]
        extra = minority_class.sample(
            n=majority_count - minority_count, replace=True, random_state=random_state
        ).copy()  # explicit copy: we modify it below

        rng = np.random.default_rng(random_state)
        noise = rng.normal(0.0, noise_std, size=(len(extra), len(self.NUMERIC_COLS)))
        extra[self.NUMERIC_COLS] = extra[self.NUMERIC_COLS].to_numpy(dtype=float) + noise

        return pd.concat([df, extra], ignore_index=True)


###############################################################################
# 2. PyTorch Dataset and Model
###############################################################################
class AKIDataset(Dataset):
    """
    Map-style PyTorch dataset over a preprocessed AKI DataFrame.

    Every column except ``aki`` is an input feature; ``aki`` is the target.
    Items are returned as ``(features, label)`` float32 tensors.
    """
    def __init__(self, data: pd.DataFrame):
        """
        Keep the frame and cache its feature matrix and label vector as numpy arrays.

        Args:
            data (pd.DataFrame): Frame containing an 'aki' column plus feature columns.
        """
        target = 'aki'
        self.data = data
        self.features = data.drop(columns=[target]).values
        self.labels = data[target].values

    def __len__(self):
        """
        Number of rows (samples) in the underlying frame.

        Returns:
            int: Number of samples.
        """
        return self.data.shape[0]

    def __getitem__(self, idx):
        """
        Fetch the sample at position ``idx``.

        Args:
            idx (int): Row position.

        Returns:
            tuple: (features, label), both float32 tensors.
        """
        pair = (self.features[idx], self.labels[idx])
        return tuple(torch.tensor(part, dtype=torch.float32) for part in pair)


class AKINet(nn.Module):
    """
    A simple feed-forward neural network for AKI classification.

    Architecture: ``num_hidden_layers`` x (Linear -> ReLU), followed by
    Linear -> Sigmoid, giving one probability per sample with shape ``(batch, 1)``.
    With ``num_hidden_layers == 0`` the network reduces to logistic regression.
    """
    def __init__(self, input_size, num_hidden_layers, hidden_layer_size):
        """
        Initialize the neural network.

        Args:
            input_size (int): Number of input features.
            num_hidden_layers (int): Number of hidden layers (>= 0).
            hidden_layer_size (int): Number of neurons per hidden layer.

        Raises:
            ValueError: If ``num_hidden_layers`` is negative.
        """
        super().__init__()
        if num_hidden_layers < 0:
            raise ValueError(f"num_hidden_layers must be >= 0, got {num_hidden_layers}")

        layers = []
        in_features = input_size
        for _ in range(num_hidden_layers):
            layers.append(nn.Linear(in_features, hidden_layer_size))
            layers.append(nn.ReLU())
            in_features = hidden_layer_size
        # The output layer takes whatever the previous stage produced: the hidden
        # width, or the raw inputs when there are no hidden layers.
        layers.append(nn.Linear(in_features, 1))
        layers.append(nn.Sigmoid())
        # Same Sequential layout and indices as before, so saved state_dicts still load.
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """
        Forward pass of the network.

        Args:
            x (torch.Tensor): Input tensor of shape (batch, input_size).

        Returns:
            torch.Tensor: Probabilities in [0, 1] of shape (batch, 1).
        """
        return self.net(x)


###############################################################################
# 3. Trainer Class
###############################################################################
class Trainer:
    """
    Encapsulates the training and validation loop, including early stopping
    on a mixed metric (0.8 * F3 + 0.2 * accuracy) measured on a held-out set.
    """
    def __init__(self, model, optimizer, criterion, device):
        """
        Initialize the Trainer with model, optimizer, criterion, and device.

        Args:
            model (nn.Module): PyTorch model to train; outputs probabilities (batch, 1).
            optimizer (torch.optim.Optimizer): Optimizer for the model.
            criterion (nn.Module): Loss function applied to (probabilities, targets).
            device (torch.device): Device to use ('cpu' or 'cuda').
        """
        self.model = model
        self.optimizer = optimizer
        self.criterion = criterion
        self.device = device

    def train_and_validate(
            self,
            train_loader,
            early_stop_loader,
            threshold=0.5,
            patience=20,
            max_epochs=500,
    ):
        """
        Train the model with early stopping on the best mixed metric.

        After every epoch the model is scored on ``early_stop_loader`` with
        ``0.8 * F3 + 0.2 * accuracy``. Training stops after ``patience`` epochs
        without improvement. The weights of the best epoch are restored before
        returning.

        Args:
            train_loader (DataLoader): DataLoader for training data.
            early_stop_loader (DataLoader): DataLoader for early stopping data.
            threshold (float): Threshold for classification.
            patience (int): Early stopping patience.
            max_epochs (int): Maximum number of training epochs.

        Returns:
            tuple: (best_model (nn.Module), best_metric (float)).
        """
        best_metric = 0.0
        patience_counter = 0
        best_model_state = None

        for epoch in range(max_epochs):
            # Training loop
            self.model.train()
            for x_batch, y_batch in train_loader:
                x_batch = x_batch.to(self.device)
                y_batch = y_batch.to(self.device)
                self.optimizer.zero_grad()
                # reshape(-1) instead of squeeze(): a 1-sample batch must stay 1-D so
                # its shape matches the targets.
                preds = self.model(x_batch).reshape(-1)
                loss = self.criterion(preds, y_batch.reshape(-1))
                loss.backward()
                self.optimizer.step()

            # Evaluate on the early stopping set (a single inference pass for both metrics).
            labels, preds = self._predict(early_stop_loader, threshold)
            f3 = fbeta_score(labels, preds, beta=3, zero_division=0)
            acc = accuracy_score(labels, preds)
            metric = 0.8 * f3 + 0.2 * acc

            if best_model_state is None or metric > best_metric:
                best_metric = metric
                patience_counter = 0
                # state_dict() returns references to the live parameter tensors. Clone
                # them so later optimizer steps cannot overwrite the best snapshot.
                best_model_state = {
                    k: v.detach().clone() for k, v in self.model.state_dict().items()
                }
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print(f"Early stopping at epoch {epoch + 1}. Best Metric = {best_metric:.4f}")
                break

        # Load the best state
        if best_model_state is not None:
            self.model.load_state_dict(best_model_state)
        return self.model, best_metric

    def _predict(self, data_loader, threshold):
        """
        Run inference over ``data_loader``.

        Args:
            data_loader (DataLoader): Yields (features, labels) batches.
            threshold (float): Classification threshold on the model's probability.

        Returns:
            tuple: (labels, predictions) as flat lists of 0.0/1.0 floats.
        """
        self.model.eval()
        preds, labels = [], []
        with torch.no_grad():
            for x_batch, y_batch in data_loader:
                out = self.model(x_batch.to(self.device)).reshape(-1)
                preds.extend((out > threshold).float().cpu().tolist())
                labels.extend(y_batch.reshape(-1).cpu().tolist())
        return labels, preds

    def _evaluate_accuracy(self, data_loader, threshold):
        """
        Compute Accuracy on a given DataLoader.

        Args:
            data_loader (DataLoader): DataLoader for evaluation.
            threshold (float): Classification threshold.

        Returns:
            float: Accuracy score.
        """
        labels, preds = self._predict(data_loader, threshold)
        return accuracy_score(labels, preds)

    def _evaluate_f3(self, data_loader, threshold):
        """
        Compute F3 (F-beta with beta=3, favouring recall) on a given DataLoader.

        Args:
            data_loader (DataLoader): DataLoader for evaluation.
            threshold (float): Classification threshold.

        Returns:
            float: F3 score (0.0 when it is undefined, e.g. no positives at all).
        """
        labels, preds = self._predict(data_loader, threshold)
        return fbeta_score(labels, preds, beta=3, zero_division=0)


###############################################################################
# 4. Helper Functions for Training / Evaluation
###############################################################################

def cross_validate_model(
    data: pd.DataFrame,
    num_hidden_layers: int,
    hidden_layer_size: int,
    learning_rate: float,
    batch_size: int,
    threshold: float,
    device: torch.device,
    n_splits=5,
    patience=20,
    max_epochs=500,
    random_state=42,
):
    """
    Perform stratified n-fold cross-validation, returning the average F3 across folds.

    In each fold, (n_splits - 1)/n_splits of the data is used for training and the
    rest is held out for scoring. 10% of the training part is set aside further as
    the early-stopping set. A fresh model is trained per fold and its F3 on the
    held-out fold is recorded.

    Args:
        data (pd.DataFrame): Input data with features and 'aki' column.
        num_hidden_layers (int): Number of hidden layers.
        hidden_layer_size (int): Number of neurons per hidden layer.
        learning_rate (float): Learning rate for the optimizer.
        batch_size (int): Batch size for DataLoader.
        threshold (float): Classification threshold.
        device (torch.device): Device to use ('cpu' or 'cuda').
        n_splits (int): Number of folds for cross-validation.
        patience (int): Early-stopping patience per fold.
        max_epochs (int): Maximum training epochs per fold.
        random_state (int): Seed for the fold split and the early-stopping split.

    Returns:
        float: Average F3 score across all folds.
    """
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=random_state)
    input_size = data.shape[1] - 1  # every column except the 'aki' label
    f3_scores = []

    for train_idx, val_idx in skf.split(data, data['aki']):
        full_train_data = data.iloc[train_idx]
        val_data = data.iloc[val_idx]

        # Reserve 10% of the training fold for early stopping.
        train_data, early_stop_data = train_test_split(
            full_train_data,
            test_size=0.1,
            stratify=full_train_data['aki'],
            random_state=random_state
        )

        train_loader = DataLoader(AKIDataset(train_data), batch_size=batch_size, shuffle=True)
        early_stop_loader = DataLoader(AKIDataset(early_stop_data), batch_size=batch_size, shuffle=False)
        val_loader = DataLoader(AKIDataset(val_data), batch_size=batch_size, shuffle=False)

        model = AKINet(
            input_size=input_size,
            num_hidden_layers=num_hidden_layers,
            hidden_layer_size=hidden_layer_size
        ).to(device)
        trainer = Trainer(model, Adam(model.parameters(), lr=learning_rate), nn.BCELoss(), device)

        # Train with early stopping; the trainer's model ends up holding the best weights.
        trainer.train_and_validate(
            train_loader,
            early_stop_loader,
            threshold=threshold,
            patience=patience,
            max_epochs=max_epochs
        )

        # Score the held-out fold.
        f3_scores.append(trainer._evaluate_f3(val_loader, threshold))

    return float(np.mean(f3_scores))


def train_model(
    data: pd.DataFrame,
    hyperparams: dict,
    device: torch.device,
    val_split=0.1,
    save_path="best_model.pth",
):
    """
    Train a final model using the given hyperparameters.

    The data is split (stratified on 'aki') into a training set and a validation
    set of size ``val_split``. The validation set is used only for early stopping.
    The whole trained model object is saved with ``torch.save``.

    Args:
        data (pd.DataFrame): Input data with features and 'aki' column.
        hyperparams (dict): Must contain 'num_hidden_layers', 'hidden_layer_size',
            'learning_rate', 'batch_size' and 'threshold'.
        device (torch.device): Device to use ('cpu' or 'cuda').
        val_split (float): Proportion of data for validation.
        save_path (str): Where to save the trained model.

    Returns:
        nn.Module: Trained PyTorch model (weights of the best early-stopping epoch).
    """
    num_hidden_layers = hyperparams["num_hidden_layers"]
    hidden_layer_size = hyperparams["hidden_layer_size"]
    learning_rate = hyperparams["learning_rate"]
    batch_size = hyperparams["batch_size"]
    threshold = hyperparams["threshold"]  # used for the early-stopping measurement

    train_df, val_df = train_test_split(data, test_size=val_split, stratify=data['aki'], random_state=42)

    train_loader = DataLoader(AKIDataset(train_df), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(AKIDataset(val_df), batch_size=batch_size, shuffle=False)

    model = AKINet(
        input_size=train_df.shape[1] - 1,
        num_hidden_layers=num_hidden_layers,
        hidden_layer_size=hidden_layer_size
    ).to(device)

    criterion = nn.BCELoss()
    optimizer = Adam(model.parameters(), lr=learning_rate)
    trainer = Trainer(model, optimizer, criterion, device)

    best_model, best_metric = trainer.train_and_validate(
        train_loader,
        val_loader,
        threshold=threshold,
        patience=20,
        max_epochs=500
    )

    # The early-stopping score is the mixed metric, not plain F3; report both.
    val_f3 = trainer._evaluate_f3(val_loader, threshold)
    print(
        "Final model after training. Best early-stopping metric (0.8*F3 + 0.2*accuracy) "
        f"on validation: {best_metric:.4f}, F3: {val_f3:.4f}"
    )
    # Save the entire model, so the architecture need not be re-instantiated later.
    torch.save(best_model, save_path)
    print(f"Saved entire model to {save_path}")
    return best_model


def evaluate(model, data: pd.DataFrame, device: torch.device, threshold=0.5, batch_size=64):
    """
    Evaluate a trained model on a dataset, or run inference when there are no labels.

    Unlike ``Trainer._evaluate_f3``, this takes the model explicitly, so it can be
    used with a saved or final model. During hyperparameter tuning no models are
    saved, and the Trainer's internal method is used instead.

    Every column except 'aki' is fed to the model as a feature, in column order.

    Args:
        model (nn.Module): Trained PyTorch model returning probabilities.
        data (pd.DataFrame): Input data for evaluation or inference.
        device (torch.device): Device to use ('cpu' or 'cuda').
        threshold (float): Classification threshold.
        batch_size (int): Number of rows per forward pass.

    Returns:
        dict with possible keys:
            - 'predictions': list of 0/1 ints, one per row
            - 'f3': F3 score (if 'aki' in data)
            - 'accuracy': Accuracy score (if 'aki' in data)
    """
    has_labels = 'aki' in data.columns
    features = data.drop(columns=['aki']) if has_labels else data
    x_all = torch.tensor(features.to_numpy(dtype=np.float32), dtype=torch.float32)

    model.eval()
    model.to(device)

    predictions = []
    with torch.no_grad():
        for start in range(0, len(x_all), batch_size):
            # reshape(-1) keeps a 1-row batch 1-D (squeeze() would make it 0-d).
            out = model(x_all[start:start + batch_size].to(device)).reshape(-1)
            predictions.extend((out > threshold).int().cpu().tolist())

    result = {"predictions": predictions}

    if has_labels:
        labels = data['aki'].to_numpy(dtype=np.float32)
        result["f3"] = fbeta_score(labels, predictions, beta=3, zero_division=0)
        result["accuracy"] = accuracy_score(labels, predictions)

    return result


###############################################################################
# 5. Hyperparameter Tuning (Optuna)
###############################################################################

def objective(trial, data, device):
    """
    Optuna objective: sample one configuration and score it by cross-validation.

    The hyperparameters are suggested in a fixed order (layers, width, learning
    rate, batch size, threshold) and then scored with 5-fold
    ``cross_validate_model``.

    Args:
        trial (optuna.Trial): Optuna trial object.
        data (pd.DataFrame): Input data for cross-validation.
        device (torch.device): Device to use ('cpu' or 'cuda').

    Returns:
        float: Average F3 score from cross-validation.
    """
    sampled = {"num_hidden_layers": trial.suggest_int("num_hidden_layers", 1, 5)}

    # NOTE(review): 184 breaks the otherwise regular width grid (maybe 192 was meant) — confirm.
    categorical_space = (
        ("hidden_layer_size", [32, 48, 64, 96, 128, 184, 256, 512, 1024]),
        ("learning_rate", [0.001, 0.005, 0.01, 0.02, 0.05]),
        ("batch_size", [32, 64, 128, 256, 512, 1024]),
        ("threshold", [0.25, 0.35, 0.375, 0.4, 0.45, 0.5, 0.55, 0.625, 0.75]),
    )
    for param_name, choices in categorical_space:
        sampled[param_name] = trial.suggest_categorical(param_name, choices)

    return cross_validate_model(data=data, device=device, n_splits=5, **sampled)


def tune_hyperparameters(data: pd.DataFrame, n_trials=30):
    """
    Search hyperparameters with Optuna, maximizing cross-validated F3.

    Uses CUDA when available. Prints the best configuration and writes its
    parameters to ``best_hyperparameters.json`` in the working directory.

    Args:
        data (pd.DataFrame): Input data for tuning.
        n_trials (int): Number of trials for Optuna optimization.

    Returns:
        optuna.trial.FrozenTrial: Best trial.
    """
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def _score(trial):
        return objective(trial, data, device)

    study = optuna.create_study(direction="maximize")
    study.optimize(_score, n_trials=n_trials)

    winner = study.best_trial
    print("\nBest trial:")
    print(f"Hyperparameters: {winner.params}")
    print(f"Avg F3 Score (CV): {winner.value:.4f}")

    with open("best_hyperparameters.json", "w") as out_file:
        json.dump(winner.params, out_file)
    print("Saved best hyperparameters to best_hyperparameters.json")

    return winner