Skip to content

Commit

Permalink
interactive py script
Browse files Browse the repository at this point in the history
  • Loading branch information
anthonychen000 authored Sep 1, 2024
1 parent f739b0d commit 5d3acd4
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions example_script.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from tqdm import tqdm
from sklearn.metrics import accuracy_score, f1_score

# Define a custom dataset
# Define custom dataset
class MNISTDataset(Dataset):
def __init__(self, csv_file):
self.data = pd.read_csv(csv_file)
Expand All @@ -17,15 +17,15 @@ def __len__(self):
return len(self.data)

def __getitem__(self, idx):
image = torch.tensor(self.images[idx]).unsqueeze(0)
image = torch.tensor(self.images[ idx]).unsqueeze(0)
label = torch.tensor(self.labels[idx], dtype=torch.long)
return image, label

# Load the data
train_data_path = 'mnist_train.csv' # Replace with your CSV file path
train_data_path = 'dir_1/dir_2/dir_3' # Replace with train CSV file path
train_dataset = MNISTDataset(train_data_path)

test_data_path = 'mnist_test.csv'
test_data_path = 'dir_1/dir_2/dir_3' # Replace with test CSV file path
test_dataset = MNISTDataset(test_data_path)

# DataLoader objects
Expand Down Expand Up @@ -94,4 +94,4 @@ def evaluate_model(model, test_loader):

# Train and evaluate the model
train_model(model, train_loader, criterion, optimizer, num_epochs=5)
evaluate_model(model, test_loader)
evaluate_model(model, test_loader)

0 comments on commit 5d3acd4

Please sign in to comment.