- Defining a simple neural network model
- Training the model and saving it (.pth)
- Loading the model and using it for prediction
For demonstration we'll use the classic MNIST dataset: 28×28 grayscale images of handwritten digits (0 to 9), split into 60,000 training images and 10,000 test images.
Step 1: Import Libraries and Define the Model
```python
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt

# Define a simple neural network: 784 inputs -> 128 hidden units -> 10 class scores
class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28*28, 128)  # flattened 28x28 image -> 128 hidden units
        self.fc2 = nn.Linear(128, 10)     # hidden units -> one score (logit) per digit

    def forward(self, x):
        x = x.view(-1, 28*28)             # flatten (batch, 1, 28, 28) -> (batch, 784)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)                   # raw logits; CrossEntropyLoss applies log-softmax itself
        return x

# Instantiate the model, define loss function and optimizer
model = SimpleNN()
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
```
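Before training, it can help to confirm that the model returns one score per class for a batch of MNIST-shaped inputs, and to see how many parameters it has. A minimal sketch, assuming only PyTorch is installed; `dummy_batch` is a hypothetical name and its contents are random noise standing in for real images:

```python
import torch
import torch.nn as nn

# Same architecture as SimpleNN in Step 1, repeated so this snippet runs on its own
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        x = torch.relu(self.fc1(x))
        return self.fc2(x)

model = SimpleNN()

# A fake batch of 64 single-channel 28x28 images (random values, not real MNIST data)
dummy_batch = torch.randn(64, 1, 28, 28)
logits = model(dummy_batch)
print(logits.shape)  # torch.Size([64, 10])

# Trainable parameters: (784*128 + 128) for fc1 plus (128*10 + 10) for fc2
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(n_params)  # 101770
```

Nearly all of the roughly 100k parameters sit in `fc1`, the layer that sees every pixel.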
Step 2: Load the Dataset and Train the Model
```python
# Load the MNIST dataset
# ToTensor scales pixels to [0, 1]; Normalize((0.5,), (0.5,)) then maps them to [-1, 1]
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
trainset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=64, shuffle=True)
```
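```python
# Train the model
model.train()  # put the model in training mode
for epoch in range(1):  # Train for 1 epoch for simplicity
    running_loss = 0.0
    for images, labels in trainloader:
        optimizer.zero_grad()              # clear gradients from the previous step
        outputs = model(images)            # forward pass: (batch, 10) logits
        loss = criterion(outputs, labels)
        loss.backward()                    # backpropagate
        optimizer.step()                   # update the weights
        running_loss += loss.item()
    print(f'Epoch {epoch + 1}, average loss: {running_loss / len(trainloader):.4f}')
print('Training complete!')
```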
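`nn.CrossEntropyLoss` expects the raw logits that `SimpleNN.forward` returns, not probabilities: it applies log-softmax itself and then averages the negative log-probability of each sample's correct class. A sketch that verifies this equivalence on made-up logits and labels:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Fake logits for a batch of 4 samples and 10 classes, plus their true labels
logits = torch.randn(4, 10)
labels = torch.tensor([3, 7, 0, 9])

criterion = nn.CrossEntropyLoss()
loss = criterion(logits, labels)

# Same value computed by hand: log-softmax over the class dimension, then the
# mean negative log-probability assigned to each sample's true class
log_probs = F.log_softmax(logits, dim=1)
manual_loss = -log_probs[torch.arange(4), labels].mean()

print(loss.item(), manual_loss.item())
```

This is why the model ends with a plain `nn.Linear` layer: adding a softmax at the end would effectively apply softmax twice during training.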
Step 3: Save the Model
```python
# Save the model state dictionary
torch.save(model.state_dict(), 'simple_nn.pth')
print('Model saved!')
```
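A state dictionary is an ordered mapping from parameter names to tensors. The file stores weights, not the class definition, which is why Step 4 creates a fresh `SimpleNN()` before loading. A minimal sketch of the round trip, using an `io.BytesIO` buffer in place of `simple_nn.pth` (an assumption made so the example writes nothing to disk):

```python
import io
import torch
import torch.nn as nn

# Same architecture as SimpleNN in Step 1 (repeated so the snippet is self-contained)
class SimpleNN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 28 * 28)
        return self.fc2(torch.relu(self.fc1(x)))

model = SimpleNN()

# The state dict maps parameter names to tensors; it holds weights only, not code
for name, tensor in model.state_dict().items():
    print(name, tuple(tensor.shape))
# fc1.weight (128, 784)
# fc1.bias (128,)
# fc2.weight (10, 128)
# fc2.bias (10,)

# Save to an in-memory buffer, then rewind it so it can be read back
buffer = io.BytesIO()
torch.save(model.state_dict(), buffer)
buffer.seek(0)

# Loading requires an instance of the same architecture to receive the weights
restored = SimpleNN()
restored.load_state_dict(torch.load(buffer, weights_only=True))
```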
Step 4: Load the Model and Make Predictions
```python
# Load the saved state dictionary into a fresh instance of the same architecture
loaded_model = SimpleNN()
loaded_model.load_state_dict(torch.load('simple_nn.pth', weights_only=True))
loaded_model.eval()  # Set the model to evaluation mode
```
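`eval()` only changes layers whose behavior depends on the mode, such as `nn.Dropout` and `nn.BatchNorm2d`. `SimpleNN` contains neither, so its predictions would be the same without the call, but it is a good habit for models that do. A small sketch with a standalone dropout layer (not part of the post's model) shows the difference:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# A standalone Dropout layer stands in for a model that has mode-dependent layers
dropout = nn.Dropout(p=0.5)
x = torch.ones(1, 8)

dropout.train()
# Each entry is zeroed with probability p; survivors are scaled by 1 / (1 - p) = 2
train_out = dropout(x)

dropout.eval()
# In evaluation mode dropout is the identity: the input passes through unchanged
eval_out = dropout(x)

print(train_out)
print(eval_out)
```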
```python
# Make a prediction on a single image
test_image, label = trainset[20]  # Use the sample at index 20 of the training set as an example
test_image = test_image.unsqueeze(0)  # Add a batch dimension: (1, 28, 28) -> (1, 1, 28, 28)

# Display the image
plt.imshow(test_image.squeeze(), cmap='gray')
plt.title(f'Actual Label: {label}')
plt.axis('off')
plt.show()

with torch.no_grad():  # gradients are not needed for inference
    output = loaded_model(test_image)
_, predicted = torch.max(output, 1)  # index of the highest logit = predicted digit
print('Predicted label:', predicted.item())
```
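A single image says little about how well the model generalizes, especially one taken from the training set. A hedged sketch of an accuracy loop; the `evaluate_accuracy` helper and the `AlwaysThree` stand-in model are hypothetical names introduced here, and the toy data is random:

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def evaluate_accuracy(model: nn.Module, loader: DataLoader) -> float:
    """Return the fraction of samples in `loader` whose predicted class matches the label."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            predicted = model(images).argmax(dim=1)  # highest logit per sample
            correct += (predicted == labels).sum().item()
            total += labels.size(0)
    return correct / total

# Toy check: a hypothetical model that always predicts class 3, evaluated on 10
# fake samples whose labels are [3, 3, 3, 3, 0, 1, 2, 4, 5, 6], so 4 of 10 are correct
class AlwaysThree(nn.Module):
    def forward(self, x):
        scores = torch.zeros(x.size(0), 10)
        scores[:, 3] = 1.0
        return scores

images = torch.randn(10, 1, 28, 28)
labels = torch.tensor([3, 3, 3, 3, 0, 1, 2, 4, 5, 6])
loader = DataLoader(TensorDataset(images, labels), batch_size=4)
accuracy = evaluate_accuracy(AlwaysThree(), loader)
print(accuracy)  # 0.4
```

With the real data you would pass `loaded_model` and a `DataLoader` over the MNIST test split, e.g. `torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)`, which gives a fairer estimate than an image the model was trained on.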
Hope you found this post helpful and enjoyable.
Thank you!