diff --git a/fraud_detection/fraud_detection_model.py b/fraud_detection/fraud_detection_model.py index 7182d5030..edac1a2fa 100644 --- a/fraud_detection/fraud_detection_model.py +++ b/fraud_detection/fraud_detection_model.py @@ -3,14 +3,17 @@ from sklearn.ensemble import RandomForestClassifier from sklearn.model_selection import train_test_split + class FraudDetectionModel: def __init__(self, data): self.data = data - self.X = self.data.drop('fraudulent', axis=1) - self.y = self.data['fraudulent'] + self.X = self.data.drop("fraudulent", axis=1) + self.y = self.data["fraudulent"] def train_model(self): - X_train, X_test, y_train, y_test = train_test_split(self.X, self.y, test_size=0.2, random_state=42) + X_train, X_test, y_train, y_test = train_test_split( + self.X, self.y, test_size=0.2, random_state=42 + ) self.model = RandomForestClassifier(n_estimators=100, random_state=42) self.model.fit(X_train, y_train)