An adaptation of the Introduction to PyTorch* Lightning tutorial using Intel® Gaudi® AI processors.
In this tutorial, we go over the basics of Lightning by preparing models to train on the MNIST Handwritten Digits dataset.
Setup
This tutorial requires some packages besides pytorch-lightning.
Warning: Running pip as the 'root' user can result in broken permissions and conflicts with the system package manager. It is recommended to use a virtual environment instead: https://docs.python.org/3/tutorial/venv.html
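If you are starting from a bare environment, an install along these lines should be sufficient (package names are inferred from the imports used later in this tutorial; the Intel Gaudi PyTorch software stack itself is assumed to already be installed on the machine):

pip install pytorch-lightning torchvision torchmetrics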
Simplest Example
Here’s the simplest, most minimal example, with just a training loop (no validation, no testing).
By using the Trainer you automatically get:
- TensorBoard* logging
- Model checkpointing
- Training and validation loop
- Early-stopping
To enable PyTorch Lightning to use the HPU accelerator, simply pass the accelerator="hpu" parameter to the Trainer class.
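A minimal sketch of such a module and training run might look like the following (the single linear layer, learning rate, and batch size are illustrative choices, not prescriptive):

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule, Trainer
from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST


class MNISTModel(LightningModule):
    def __init__(self):
        super().__init__()
        # A single linear layer is enough for a minimal example
        self.l1 = torch.nn.Linear(28 * 28, 10)

    def forward(self, x):
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        return F.cross_entropy(self(x), y)

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.02)


mnist_model = MNISTModel()
train_ds = MNIST(".", train=True, download=True, transform=transforms.ToTensor())
train_loader = DataLoader(train_ds, batch_size=32)

# accelerator="hpu" routes training to the Intel Gaudi device
trainer = Trainer(accelerator="hpu", devices=1, max_epochs=3)
trainer.fit(mnist_model, train_loader)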
A More Complete MNIST Lightning Module Example
Let's dive in a bit deeper and write a more complete LightningModule for MNIST.
This time, we bake all the dataset-specific pieces directly into the LightningModule. This way, we avoid writing extra code at the beginning of the script every time we want to run it.
Note what the following built-in hooks are doing:
- prepare_data() 💾 — downloads the MNIST data; no state assignments (self.x = ...) are made here.
- setup(stage) ⚙️ — loads the data and creates the train/validation/test splits; the stage argument separates the "fit" and "test" logic.
- x_dataloader() ♻️ — train_dataloader(), val_dataloader(), and test_dataloader() all return PyTorch DataLoader instances created by wrapping the respective datasets prepared in setup().
import os

import torch
import torch.nn.functional as F
from pytorch_lightning import LightningModule
from torch import nn
from torch.utils.data import DataLoader, random_split
from torchmetrics import Accuracy
from torchvision import transforms
from torchvision.datasets import MNIST

PATH_DATASETS = os.environ.get("PATH_DATASETS", ".")  # default dataset location
BATCH_SIZE = 256


class LitMNIST(LightningModule):
    def __init__(self, data_dir=PATH_DATASETS, hidden_size=64, learning_rate=2e-4):
        super().__init__()

        # Set our init args as class attributes
        self.data_dir = data_dir
        self.hidden_size = hidden_size
        self.learning_rate = learning_rate

        # Hardcode some dataset-specific attributes
        self.num_classes = 10
        self.dims = (1, 28, 28)
        channels, width, height = self.dims
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.1307,), (0.3081,)),
            ]
        )

        # Define the PyTorch model
        self.model = nn.Sequential(
            nn.Flatten(),
            nn.Linear(channels * width * height, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, hidden_size),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(hidden_size, self.num_classes),
        )

        self.accuracy = Accuracy(task="multiclass", num_classes=self.num_classes)

    def forward(self, x):
        x = self.model(x)
        return F.log_softmax(x, dim=1)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        preds = torch.argmax(logits, dim=1)
        self.accuracy(preds, y)

        # Calling self.log surfaces the metric in TensorBoard and the progress bar
        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.accuracy, prog_bar=True)
        return loss

    def test_step(self, batch, batch_idx):
        # Reuse the validation step for testing
        return self.validation_step(batch, batch_idx)

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    ####################
    # DATA RELATED HOOKS
    ####################

    def prepare_data(self):
        # Download the data once; no state assignments here
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            mnist_full = MNIST(self.data_dir, train=True, transform=self.transform)
            self.mnist_train, self.mnist_val = random_split(mnist_full, [55000, 5000])

        # Assign the test dataset for use in the test dataloader
        if stage == "test" or stage is None:
            self.mnist_test = MNIST(self.data_dir, train=False, transform=self.transform)

    def train_dataloader(self):
        return DataLoader(self.mnist_train, batch_size=BATCH_SIZE)

    def val_dataloader(self):
        return DataLoader(self.mnist_val, batch_size=BATCH_SIZE)

    def test_dataloader(self):
        return DataLoader(self.mnist_test, batch_size=BATCH_SIZE)
Train the Model on Intel Gaudi Accelerators
Remember to enable PyTorch Lightning to use the HPU accelerator: simply pass the accelerator="hpu" parameter to the Trainer class.
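A run along these lines produced the output below (the exact Trainer instantiation is a sketch; max_epochs=3 and a single HPU device match the log that follows):

from pytorch_lightning import Trainer

model = LitMNIST()
trainer = Trainer(
    accelerator="hpu",  # run on the Intel Gaudi (HPU) device
    devices=1,
    max_epochs=3,
)
trainer.fit(model)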
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: True, using: 1 HPUs

  | Name     | Type       | Params
----------------------------------------
0 | model    | Sequential | 55.1 K
1 | accuracy | Accuracy   | 0
----------------------------------------
55.1 K    Trainable params
0         Non-trainable params
55.1 K    Total params
0.110     Total estimated model params size (MB)

Epoch 0: 100%|██████████| 235/235 [00:13<00:00, 18.00it/s, loss=0.51, v_num=1, val_loss=0.420, val_acc=0.886]
Epoch 1: 100%|██████████| 235/235 [00:10<00:00, 23.18it/s, loss=0.362, v_num=1, val_loss=0.311, val_acc=0.907]
Epoch 2: 100%|██████████| 235/235 [00:10<00:00, 23.27it/s, loss=0.303, v_num=1, val_loss=0.267, val_acc=0.919]
`Trainer.fit` stopped: `max_epochs=3` reached.
Testing
To test a model, call trainer.test(model).
Or, if you have just trained a model, you can simply call trainer.test() and Lightning will automatically test using the best saved checkpoint (conditioned on val_loss).
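For example, reusing the trainer and model from the training run above:

trainer.test(model)
# or, to let Lightning pick the best saved checkpoint from the run above:
trainer.test()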
Copyright (c) 2022 Habana Labs, Ltd. an Intel Company.
All rights reserved.
License
Licensed under a CC BY-SA 4.0 license.
A derivative of Introduction to PyTorch Lightning by the PyTorch Lightning team.