Maggy distributed training: ResNet-50 on ImageNet (Petastorm)

Training ResNet-50 on ImageNet from a Petastorm dataset

In this notebook, we train a ResNet-50 network on a 10-label subset of the original ImageNet dataset. To improve I/O time compared to standard ImageNet training, we use the Petastorm version of the dataset created in the ImageNet_to_petastorm notebook.

from hops import hdfs
from torchvision import models
Starting Spark application
ID: 186  YARN Application ID: application_1617699042861_0013  Kind: pyspark  State: idle
SparkSession available as 'spark'.

Defining the training function

Just as in the previous notebooks, we define our training function. Instead of the default PyTorch DataLoader, however, we use the MaggyPetastormDataLoader. This DataLoader mimics its PyTorch counterpart as closely as possible. One key difference is that transforms have to be passed to the DataLoader instead of to the dataset, which is done via the transform_spec argument.
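
For comparison, a plain single-process PyTorch pipeline attaches the transform to the dataset rather than to the loader. The snippet below only illustrates that difference and is not part of the Maggy training function; ImageFolder and the local path are illustrative placeholders.

# Plain PyTorch for contrast: the transform is attached to the dataset, not the DataLoader.
# ImageFolder and the path are placeholders for illustration only.
from torch.utils.data import DataLoader
from torchvision import datasets, transforms as T

plain_transform = T.Compose([T.ToTensor(), T.RandomHorizontalFlip()])
plain_dataset = datasets.ImageFolder("/tmp/imagenette/train", transform=plain_transform)
plain_loader = DataLoader(plain_dataset, batch_size=64, shuffle=True)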

def train_fn(module, hparams, train_set, test_set):
    
    import time    
    import torch
    from torchvision import transforms as T
    
    from maggy.core.patching import MaggyPetastormDataLoader

    
    model = module(**hparams)
    
    n_epochs = 10
    batch_size = 64
    # Linear scaling rule: base lr of 0.1 scaled by the effective batch size
    # (2 workers x 64 samples per worker) relative to the reference batch size of 256.
    lr_base = 0.1 * 2*batch_size/256
    
    def transform(image_net_row):
        # The transform is applied row-wise by the Petastorm reader.
        pipeline = T.Compose([
            T.ToTensor(),
            T.RandomHorizontalFlip()
        ])
        return {"image": pipeline(image_net_row['image']), "label": image_net_row['label']}

    # Parameters as in https://arxiv.org/pdf/1706.02677.pdf
    optimizer = torch.optim.SGD(model.parameters(), lr=lr_base, momentum=0.9, weight_decay=0.0001, nesterov=True)
    loss_criterion = torch.nn.CrossEntropyLoss()
    
    # Unlike torch.utils.data.DataLoader, the transform is handed to the loader via transform_spec.
    train_loader = MaggyPetastormDataLoader(train_set, batch_size=batch_size, transform_spec=transform)
    test_loader = MaggyPetastormDataLoader(test_set, batch_size=batch_size, transform_spec=transform)
        
    def eval_model(model, test_loader):
        # Evaluate the model on this worker's shard of the test set and return the accuracy.
        acc = 0
        model.eval()
        img_cnt = 0
        with torch.no_grad():
            for data in test_loader:
                img, label = data["image"].float(), data["label"].float()
                prediction = model(img)
                acc += torch.sum(torch.argmax(prediction, dim=1) == label).detach()
                img_cnt += len(label.detach())
        acc = acc/float(img_cnt)
        print("Test accuracy: {:.3f}\n".format(acc) + 20*"-")
        return acc

    model.train()
    t_0 = time.time()
    for epoch in range(n_epochs):
        print("-"*20 + "\nStarting new epoch\n")
        model.train()
        t_start = time.time()
        for idx, data in enumerate(train_loader):
            optimizer.zero_grad()
            img, label = data["image"].float(), data["label"].float()
            prediction = model(img)
            loss = loss_criterion(prediction, label.long())
            loss.backward()
            optimizer.step()
            if idx%10 == 0:
                print(f"Working on batch {idx}.")
        t_end = time.time()
        print("Epoch training took {:.0f}s.\n".format(t_end - t_start))
        acc = eval_model(model, test_loader)
    t_1 = time.time()
    minutes, seconds = divmod(t_1 - t_0, 60)
    hours, minutes = divmod(minutes, 60)
    print("-"*20 + "\nTotal training time: {:.0f}h {:.0f}m {:.0f}s.".format(hours, minutes, seconds))
    return float(acc)
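
Loading the Petastorm datasets

The Petastorm train and test sets were written to the project's HDFS by the ImageNet_to_petastorm notebook. We only need to pass their paths to the experiment; as the logs further down show, Maggy detects the Petastorm format on each worker.
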
train_ds = hdfs.project_path() + "DataSets/ImageNet/PetastormImageNette/train"
test_ds = hdfs.project_path() + "DataSets/ImageNet/PetastormImageNette/test"
print(hdfs.exists(train_ds), hdfs.exists(test_ds))
True True

Configuring the experiment

In this example we use the ResNet-50 implementation provided by torchvision, so we do not need to define our own module. Any constructor arguments for the network can be passed through the hparams argument. This is more a convenience mechanism than a necessity; you could also define them in the training function itself. Passing them in the config, however, has the advantage that you can automate distributed training, e.g. after a hyperparameter search.
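
For instance, torchvision's resnet50 constructor also accepts a num_classes argument, so one could build a 10-way output layer for the 10-label subset by extending hparams. This is only a hypothetical alternative; the run below keeps the default 1000-way head and sets pretrained=False.

# Hypothetical alternative: any resnet50 constructor argument can be passed via hparams,
# e.g. a 10-class output layer for the 10-label subset.
hparams_10_classes = {"pretrained": False, "num_classes": 10}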

from maggy import experiment
from maggy.experiment_config import TorchDistributedConfig

config = TorchDistributedConfig(name='ImageNet_petastorm', module=models.resnet50,
                                hparams={"pretrained": False}, train_set=train_ds,
                                test_set=test_ds, backend="torch")
result = experiment.lagom(train_fn, config)


0: Awaiting worker reservations.
1: Awaiting worker reservations.
0: All executors registered: True
0: Reservations complete, configuring PyTorch.
0: Torch config is {'MASTER_ADDR': '10.0.0.4', 'MASTER_PORT': '46351', 'WORLD_SIZE': '2', 'RANK': '0', 'LOCAL_RANK': '0', 'NCCL_BLOCKING_WAIT': '1', 'NCCL_DEBUG': 'INFO'}
0: Starting distributed training.
1: All executors registered: True
1: Reservations complete, configuring PyTorch.
1: Torch config is {'MASTER_ADDR': '10.0.0.4', 'MASTER_PORT': '46351', 'WORLD_SIZE': '2', 'RANK': '1', 'LOCAL_RANK': '0', 'NCCL_BLOCKING_WAIT': '1', 'NCCL_DEBUG': 'INFO'}
1: Starting distributed training.
0: Petastorm dataset detected in folder hdfs://rpc.namenode.service.consul:8020/Projects/PyTorch_spark_minimal/DataSets/ImageNet/PetastormImageNette/train
0: Petastorm dataset detected in folder hdfs://rpc.namenode.service.consul:8020/Projects/PyTorch_spark_minimal/DataSets/ImageNet/PetastormImageNette/test
1: Petastorm dataset detected in folder hdfs://rpc.namenode.service.consul:8020/Projects/PyTorch_spark_minimal/DataSets/ImageNet/PetastormImageNette/train
1: Petastorm dataset detected in folder hdfs://rpc.namenode.service.consul:8020/Projects/PyTorch_spark_minimal/DataSets/ImageNet/PetastormImageNette/test
0: --------------------
Starting new epoch

1: --------------------
Starting new epoch

0: Working on batch 0.
1: Working on batch 0.
0: Working on batch 10.
1: Working on batch 10.
0: Working on batch 20.
1: Working on batch 20.
0: Working on batch 30.
1: Working on batch 30.
0: Working on batch 40.
1: Working on batch 40.
0: Working on batch 50.
1: Working on batch 50.
0: Working on batch 60.
1: Working on batch 60.
0: Working on batch 70.
1: Working on batch 70.
0: Epoch training took 149s.

1: Epoch training took 149s.

0: Test accuracy: 0.134
--------------------
0: --------------------
Starting new epoch

1: Test accuracy: 0.093
--------------------
1: --------------------
Starting new epoch

1: Working on batch 0.
0: Working on batch 0.
1: Working on batch 10.
0: Working on batch 10.
1: Working on batch 20.
0: Working on batch 20.
1: Working on batch 30.
0: Working on batch 30.
1: Working on batch 40.
0: Working on batch 40.
0: Working on batch 50.
1: Working on batch 50.
1: Working on batch 60.
0: Working on batch 60.
1: Working on batch 70.
0: Working on batch 70.
1: Epoch training took 149s.

0: Epoch training took 150s.

0: Test accuracy: 0.171
--------------------
0: --------------------
Starting new epoch

1: Test accuracy: 0.126
--------------------
1: --------------------
Starting new epoch

0: Working on batch 0.
1: Working on batch 0.
0: Working on batch 10.
1: Working on batch 10.
1: Working on batch 20.
0: Working on batch 20.
0: Working on batch 30.
1: Working on batch 30.
1: Working on batch 40.
0: Working on batch 40.
0: Working on batch 50.
1: Working on batch 50.
0: Working on batch 60.
1: Working on batch 60.
1: Working on batch 70.
0: Working on batch 70.
1: Epoch training took 150s.

0: Epoch training took 150s.

0: Test accuracy: 0.235
--------------------
0: --------------------
Starting new epoch

1: Test accuracy: 0.192
--------------------
1: --------------------
Starting new epoch

1: Working on batch 0.
0: Working on batch 0.
0: Working on batch 10.
1: Working on batch 10.
1: Working on batch 20.
0: Working on batch 20.
0: Working on batch 30.
1: Working on batch 30.
0: Working on batch 40.
1: Working on batch 40.
0: Working on batch 50.
1: Working on batch 50.
0: Working on batch 60.
1: Working on batch 60.
1: Working on batch 70.
0: Working on batch 70.
0: Epoch training took 150s.

1: Epoch training took 149s.

1: Test accuracy: 0.277
--------------------
1: --------------------
Starting new epoch

0: Test accuracy: 0.188
--------------------
0: --------------------
Starting new epoch

1: Working on batch 0.
0: Working on batch 0.
1: Working on batch 10.
0: Working on batch 10.
1: Working on batch 20.
0: Working on batch 20.
1: Working on batch 30.
0: Working on batch 30.
1: Working on batch 40.
0: Working on batch 40.
1: Working on batch 50.
0: Working on batch 50.
1: Working on batch 60.
0: Working on batch 60.
1: Working on batch 70.
0: Working on batch 70.
1: Epoch training took 150s.

0: Epoch training took 151s.

0: Test accuracy: 0.305
--------------------
0: --------------------
Starting new epoch

1: Test accuracy: 0.311
--------------------
1: --------------------
Starting new epoch

0: Working on batch 0.
1: Working on batch 0.
1: Working on batch 10.
0: Working on batch 10.
0: Working on batch 20.
1: Working on batch 20.
1: Working on batch 30.
0: Working on batch 30.
1: Working on batch 40.
0: Working on batch 40.
0: Working on batch 50.
1: Working on batch 50.
1: Working on batch 60.
0: Working on batch 60.
0: Working on batch 70.
1: Working on batch 70.
0: Epoch training took 151s.

1: Epoch training took 150s.

0: Test accuracy: 0.377
--------------------
0: --------------------
Starting new epoch

1: Test accuracy: 0.335
--------------------
1: --------------------
Starting new epoch

0: Working on batch 0.
1: Working on batch 0.
0: Working on batch 10.
1: Working on batch 10.
0: Working on batch 20.
1: Working on batch 20.
0: Working on batch 30.
1: Working on batch 30.
0: Working on batch 40.
1: Working on batch 40.