Skip to content
Snippets Groups Projects
Commit d298fd7d authored by Stavros Mitsis's avatar Stavros Mitsis
Browse files

changed model slightly

parent f6b4414a
No related branches found
No related tags found
No related merge requests found
{"num_hidden_layers": 1, "hidden_layer_size": 48, "learning_rate": 0.01, "batch_size": 64, "threshold": 0.25}
\ No newline at end of file
{"num_hidden_layers": 1, "hidden_layer_size": 96, "learning_rate": 0.02, "batch_size": 512, "threshold": 0.4}
\ No newline at end of file
No preview for this file type
......@@ -290,27 +290,27 @@ class Trainer:
self.device = device
def train_and_validate(
self,
train_loader,
val_loader,
threshold=0.5,
patience=20,
max_epochs=500
self,
train_loader,
early_stop_loader,
threshold=0.5,
patience=20,
max_epochs=500,
):
"""
Train and validate the model with early stopping on best F3 score.
Train and validate the model with early stopping on the best mixed metric.
Args:
train_loader (DataLoader): DataLoader for training data.
val_loader (DataLoader): DataLoader for validation data.
threshold (float): Threshold for classification.
patience (int): Early stopping patience.
max_epochs (int): Maximum number of training epochs.
Args:
train_loader (DataLoader): DataLoader for training data.
early_stop_loader (DataLoader): DataLoader for early stopping data.
threshold (float): Threshold for classification.
patience (int): Early stopping patience.
max_epochs (int): Maximum number of training epochs.
Returns:
tuple: (best_model (nn.Module), best_f3 (float)).
Returns:
tuple: (best_model (nn.Module), best_metric (float)).
"""
best_val_f3 = 0.0
best_metric = 0.0
patience_counter = 0
best_model_state = None
......@@ -326,23 +326,51 @@ class Trainer:
loss.backward()
self.optimizer.step()
# Validation of the model at each epoch
val_f3 = self._evaluate_f3(val_loader, threshold)
if val_f3 > best_val_f3:
best_val_f3 = val_f3
# Evaluate on the early stopping set
f3 = self._evaluate_f3(early_stop_loader, threshold)
acc = self._evaluate_accuracy(early_stop_loader, threshold)
# Mixed metric for early stopping
metric = 0.8 * f3 + 0.2 * acc
if metric > best_metric:
best_metric = metric
patience_counter = 0
best_model_state = self.model.state_dict()
else:
patience_counter += 1 # Early stopping counter
patience_counter += 1
if patience_counter >= patience:
print(f"Early stopping at epoch {epoch+1}. Best F3 = {best_val_f3:.4f}")
print(f"Early stopping at epoch {epoch + 1}. Best Metric = {best_metric:.4f}")
break
# Load the best state
if best_model_state:
self.model.load_state_dict(best_model_state)
return self.model, best_val_f3
return self.model, best_metric
def _evaluate_accuracy(self, data_loader, threshold):
    """
    Compute accuracy on a given DataLoader.

    The model is switched to evaluation mode and is NOT switched back to
    training mode here; the caller is expected to call ``model.train()``
    again before the next optimisation step.

    Args:
        data_loader (DataLoader): DataLoader yielding ``(x_batch, y_batch)``
            pairs for evaluation.
        threshold (float): Classification threshold; an output strictly
            greater than this value is counted as a positive prediction.

    Returns:
        float: Accuracy score as computed by ``accuracy_score``.
    """
    self.model.eval()
    preds, labels = [], []

    with torch.no_grad():
        for x_batch, y_batch in data_loader:
            x_batch = x_batch.to(self.device)
            y_batch = y_batch.to(self.device)

            # Flatten with reshape(-1) rather than squeeze(): squeeze() also
            # removes the batch dimension when a batch holds a single sample
            # (e.g. the last, partial batch), yielding a 0-d tensor whose
            # numpy array cannot be iterated by list.extend().
            # NOTE(review): assumes the model emits one score per sample
            # (shape (B,) or (B, 1)) - confirm against the model definition.
            out = self.model(x_batch).reshape(-1)

            preds.extend((out > threshold).cpu().numpy())
            # Flatten the labels the same way so both lists always hold one
            # scalar per sample, whatever the label tensor's trailing shape.
            labels.extend(y_batch.reshape(-1).cpu().numpy())

    return accuracy_score(labels, preds)
def _evaluate_f3(self, data_loader, threshold):
"""
......@@ -397,39 +425,56 @@ def cross_validate_model(
Returns:
float: Average F3 score across all folds.
"""
# 5 FOLD CROSS VALIDATION
skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=42)
f3_scores = []
for train_idx, val_idx in skf.split(data, data['aki']):
train_fold = data.iloc[train_idx].copy()
val_fold = data.iloc[val_idx].copy()
# Split into train (80%) and validation (20%)
full_train_data = data.iloc[train_idx].copy()
val_data = data.iloc[val_idx].copy()
# Further split the train data into 90% train and 10% train EA
train_data, train_ea_data = train_test_split(
full_train_data,
test_size=0.1,
stratify=full_train_data['aki'],
random_state=42
)
train_loader = DataLoader(AKIDataset(train_fold), batch_size=batch_size, shuffle=True)
val_loader = DataLoader(AKIDataset(val_fold), batch_size=batch_size, shuffle=False)
# Create data loaders
train_loader = DataLoader(AKIDataset(train_data), batch_size=batch_size, shuffle=True)
train_ea_loader = DataLoader(AKIDataset(train_ea_data), batch_size=batch_size, shuffle=False)
val_loader = DataLoader(AKIDataset(val_data), batch_size=batch_size, shuffle=False)
# Instantiate the model
model = AKINet(
input_size=train_fold.shape[1] - 1,
input_size=train_data.shape[1] - 1,
num_hidden_layers=num_hidden_layers,
hidden_layer_size=hidden_layer_size
).to(device)
criterion = nn.BCELoss()
optimizer = Adam(model.parameters(), lr=learning_rate)
trainer = Trainer(model, optimizer, criterion, device)
# For each fold and for each set of hyper parameters, compute the F3 score on the resulting trained model
_, best_f3 = trainer.train_and_validate(
trainer = Trainer(model, optimizer, criterion, device)
# Train the model with the new early stopping criterion
best_model, best_f3_acc = trainer.train_and_validate(
train_loader,
val_loader,
train_ea_loader,
threshold=threshold,
patience=20,
max_epochs=500
)
f3_scores.append(best_f3)
# Return the average F3 scores across all folds for a given hyper parameter set
# Evaluate on the validation set
val_f3 = trainer._evaluate_f3(val_loader, threshold)
f3_scores.append(val_f3)
# Return the average F3 score across all folds
return np.mean(f3_scores)
def train_model(
data: pd.DataFrame,
hyperparams: dict,
......
"""
main.py
Example usage of the refactored helper_functions for AKI prediction.
"""
import argparse
......@@ -27,7 +26,7 @@ def hyper_parameter_tuning():
processor = DataProcessor()
data = processor.preprocess(['training.csv'], save_constants=True)
balanced_data = processor.handle_class_imbalance(data)
best_trial = tune_hyperparameters(balanced_data, n_trials=100)
best_trial = tune_hyperparameters(balanced_data, n_trials=5)
return best_trial
......
{"age": {"mean": 37.24667853718669, "std": 21.668772847258783}, "latest_creatinine_value": {"mean": 165.47376797698945, "std": 92.37864985819138}, "median_previous": {"mean": 134.49263183125598, "std": 46.189196224629306}, "mean_previous": {"mean": 134.50308092162888, "std": 46.03514202518719}, "std_dev_previous": {"mean": 10.672707528120002, "std": 7.2736730536565775}, "abs_percentage_diff": {"mean": 0.2868546764758316, "std": 0.4044325731446869}}
\ No newline at end of file
{"age": {"mean": 37.26280350948786, "std": 21.842276857225034}, "latest_creatinine_value": {"mean": 166.31846085832825, "std": 94.44752591026013}, "median_previous": {"mean": 134.41235802217233, "std": 46.13848554494386}, "mean_previous": {"mean": 134.4250456736426, "std": 46.02537337310225}, "std_dev_previous": {"mean": 10.737323990705239, "std": 7.306492442412928}, "abs_percentage_diff": {"mean": 0.2897293206129817, "std": 0.40866939525398815}}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment