Maggy PyTorch HParam Tuning Example
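This example tunes the layer sizes and the batch size of a small PyTorch regression network with Maggy's random search, running the trials on Spark.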

# Start the Spark session; on Hopsworks, executing the first cell triggers it
print('Startup')
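# If the platform does not provide a Spark session automatically, one can be
# created explicitly. A minimal sketch using the standard PySpark builder API
# (the app name here is an arbitrary choice):
from pyspark.sql import SparkSession

spark = SparkSession.builder.appName("maggy-hparam-tuning").getOrCreate()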
# Import maggy, define searchspace
from maggy import Searchspace

sp = Searchspace(l1_size=('Integer', [2, 32]), l2_size=('Integer', [2, 32]), batch_size=('Integer', [2, 16]))
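# For intuition: with optimizer='randomsearch' below, each trial draws one
# integer uniformly from every range above. A rough, hypothetical sketch of a
# single draw (not Maggy's actual sampler):
import random

trial = {name: random.randint(low, high)
         for name, (low, high) in [('l1_size', (2, 32)),
                                   ('l2_size', (2, 32)),
                                   ('batch_size', (2, 16))]}
print(trial)  # e.g. {'l1_size': 17, 'l2_size': 5, 'batch_size': 12}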
# Hyperparameter tuning. Create oblivious training function.
from maggy import experiment

def training_function(l1_size, l2_size, batch_size, reporter):
    # Imports live inside the function because Maggy serializes it and ships
    # it to the Spark executors, where the driver's module scope is absent.
    import torch
    import torch.nn as nn
    import torch.optim as optim

    # define torch model
    class NeuralNetwork(nn.Module):
        
        def __init__(self, l1_size, l2_size):
            super().__init__()
            self.linear1 = nn.Linear(2, l1_size)
            self.linear2 = nn.Linear(l1_size, l2_size)
            self.output = nn.Linear(l2_size, 1)
            
        def forward(self, x):
            x = torch.relu(self.linear1(x))
            x = torch.relu(self.linear2(x))
            return self.output(x)
        
    # define training parameters
    net = NeuralNetwork(l1_size, l2_size)
    epochs = 100
    learning_rate = 1e-3
    optimizer = optim.Adam(net.parameters(), lr=learning_rate)
    
    # define random training data: learn f(x1, x2) = x1 * exp(x1^2 - x2^2)
    # on the unit square
    x = torch.rand(1000, 2)
    y = (x[:, 0] * torch.exp(x[:, 0] ** 2 - x[:, 1] ** 2)).reshape(-1, 1)

    dataset = torch.utils.data.TensorDataset(x, y)
    train_ds, test_ds = torch.utils.data.random_split(dataset, [800,200])
    trainloader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    testloader = torch.utils.data.DataLoader(test_ds, batch_size=200)

    for t in range(epochs):
        for batch_idx, (sample, target) in enumerate(trainloader):
            optimizer.zero_grad()
            y_pred = net(sample)
            loss = nn.functional.mse_loss(y_pred, target)
            loss.backward()
            optimizer.step()
            # Broadcasting the metric is only needed for early stopping and
            # live metrics; otherwise this call can be omitted.
            reporter.broadcast(metric=loss.item())
        if t % 25 == 24:
            print("Epoch {}: MSE loss: {:.2e}".format(t, loss.item()))

    # Evaluate on the held-out split; the test loader serves all 200 samples
    # in a single batch, so the loop runs exactly once.
    with torch.no_grad():
        for sample, target in testloader:
            y_pred = net(sample)
            test_loss = nn.functional.mse_loss(y_pred, target)

    print("Test MSE loss of the model: {:.2e}".format(test_loss.item()))

    # The returned value is the metric that Maggy optimizes ('min' below).
    return test_loss.item()
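# Optional: smoke-test the training function locally before launching trials.
# Maggy injects the real reporter at trial time; the stub below is purely
# hypothetical and only mimics the broadcast call used above.
class _StubReporter:
    def broadcast(self, metric):
        pass  # no-op; the real reporter streams metrics back to the driver

# Uncomment to run one local trial with fixed hyperparameters:
# training_function(l1_size=8, l2_size=8, batch_size=4, reporter=_StubReporter())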
# Run the search with Maggy; lagom launches the trials on the Spark executors.
result = experiment.lagom(train_fn=training_function,
                          searchspace=sp,
                          optimizer='randomsearch',
                          direction='min',
                          num_trials=2,
                          name='fctApproxTest',
                          hb_interval=1,
                          es_interval=1,
                          es_min=5)
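# lagom returns a summary of the experiment; its exact structure varies across
# Maggy versions, so printing it is the safest way to inspect the best trial's
# metric and hyperparameters.
print(result)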