Skip to content
Snippets Groups Projects
Commit 5ce9a782 authored by jh1724's avatar jh1724
Browse files

add try catch

parent ec615fb7
No related branches found
No related tags found
No related merge requests found
......@@ -25,7 +25,7 @@ def data_preprocessing(
- scaler (StandardScaler): A fitted StandardScaler for transforming the test data
Returns:
- Tuple[pd.DataFrame, StandardScaler]: The preprocessed data and the fitted scaler
- Tuple[np.ndarray, StandardScaler]: The preprocessed data and the fitted scaler
"""
# Convert 'sex' feature from string to integer (1 for 'm', 0 for 'f')
......@@ -64,10 +64,18 @@ def main():
flags = parser.parse_args()
# Load training and testing data
training_data = pd.read_csv(flags.training_data)
testing_data = pd.read_csv(flags.input)
try:
training_data = pd.read_csv(flags.training_data)
testing_data = pd.read_csv(flags.input)
except FileNotFoundError as e:
print(f"Error: The filepath was incorrect - {e.filename}")
sys.exit(1)
except pd.errors.ParserError as e:
print(f"Error: There was an issue parsing the CSV file - {e}")
sys.exit(1)
# extract samples and labels
assert 'aki' in training_data.columns, "Error: Training data must contain trained labels."
training_x = training_data.drop(columns=['aki'])
# Ignore 'aki' if not present in test data
testing_x = testing_data.drop(columns=['aki'], errors='ignore')
......@@ -81,6 +89,7 @@ def main():
# only use common columns to train and test
common_cols = training_x.columns.intersection(testing_x.columns)
assert len(common_cols) > 0, "Error: Data must have at least one column in common"
training_x = training_x[common_cols]
testing_x = testing_x[common_cols]
......@@ -93,7 +102,6 @@ def main():
testing_y = testing_y.to_numpy()
# use SMOTE to handle imbalanced data
smote = SMOTE(sampling_strategy='auto', random_state=21)
x_resampled, y_resampled = smote.fit_resample(training_x, training_y)
......@@ -115,10 +123,12 @@ def main():
print(f"Accuracy: {accuracy}")
# Write predictions to aki.csv
predictions_df = pd.DataFrame(predictions, columns=['aki'])
predictions_df['aki'] = predictions_df['aki'].map({0: 'n', 1: 'y'})
predictions_df.to_csv(flags.output, index=False)
try:
predictions_df = pd.DataFrame(predictions, columns=['aki'])
predictions_df['aki'] = predictions_df['aki'].map({0: 'n', 1: 'y'})
predictions_df.to_csv(flags.output, index=False)
except Exception as e:
print(f"Error when writing to ask.csv: {e}")
if __name__ == "__main__":
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment