Train and register PyTorch models at scale with Azure Machine Learning service

This article shows you how to train and register a PyTorch model using Azure Machine Learning service. The example is based on PyTorch's transfer learning tutorial, which builds a deep neural network (DNN) classifier for images of chickens and turkeys.

PyTorch is an open-source computational framework commonly used to create deep neural networks. With Azure Machine Learning service, you can rapidly scale out open-source training jobs using elastic cloud compute resources. You can also track your training runs, version models, and deploy models.

Whether you're developing a PyTorch model from the ground up or bringing an existing model into the cloud, Azure Machine Learning service can help you build production-ready models.


Prerequisites

Run this code on either of these environments:

Set up the experiment

This section sets up the training experiment by loading the required Python packages, initializing a workspace, creating an experiment, and uploading the training data and training scripts.

Import packages

First, import the necessary Python libraries.

import os
import shutil

from azureml.core.workspace import Workspace
from azureml.core import Experiment

from azureml.core.compute import ComputeTarget, AmlCompute
from azureml.core.compute_target import ComputeTargetException
from azureml.train.dnn import PyTorch

Initialize a workspace

The Azure Machine Learning service workspace is the top-level resource for the service. It provides you with a centralized place to work with all the artifacts you create. In the Python SDK, you can access the workspace artifacts by creating a workspace object.

Create a workspace object from the config.json file created in the prerequisites section.

ws = Workspace.from_config()
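
For reference, the config.json file is a small JSON document that identifies the workspace. The sketch below writes such a file under stated assumptions: the three keys are the ones the SDK reads, while the helper name and any values you pass in are hypothetical placeholders.

```python
import json
import os


def write_workspace_config(folder, subscription_id, resource_group, workspace_name):
    """Write a config.json describing a workspace (hypothetical helper)."""
    os.makedirs(folder, exist_ok=True)
    config = {
        "subscription_id": subscription_id,
        "resource_group": resource_group,
        "workspace_name": workspace_name,
    }
    path = os.path.join(folder, "config.json")
    with open(path, "w") as f:
        json.dump(config, f, indent=4)
    return path
```

Calling the helper with your own subscription ID, resource group, and workspace name should produce a file that Workspace.from_config() can pick up from the current directory.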

Create an experiment

Create an experiment and a folder to hold your training scripts. In this example, create an experiment called "pytorch-birds".

project_folder = './pytorch-birds'
os.makedirs(project_folder, exist_ok=True)

experiment_name = 'pytorch-birds'
experiment = Experiment(ws, name=experiment_name)

Get the data

The dataset consists of about 120 training images each for turkeys and chickens, with 100 validation images for each class. We will download and extract the dataset as part of our training script. The images are a subset of the Open Images v5 Dataset.
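
A minimal sketch of the download-and-extract step, assuming the dataset is published as a zip archive. The helper name and the archive file name are hypothetical, and the dataset URL is not given here, so it is left as a parameter.

```python
import os
import urllib.request
import zipfile


def download_and_extract(url, data_dir):
    """Download a zip archive from `url` and extract it into `data_dir`."""
    os.makedirs(data_dir, exist_ok=True)
    archive_path = os.path.join(data_dir, "dataset.zip")  # hypothetical file name
    urllib.request.urlretrieve(url, archive_path)
    with zipfile.ZipFile(archive_path, "r") as archive:
        archive.extractall(data_dir)
    os.remove(archive_path)  # keep only the extracted files
    return data_dir
```

Running this inside the training script keeps the data next to the compute that uses it, at the cost of downloading it again on every run.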

Prepare training scripts

In this tutorial, the training script is already provided. In practice, you can take any custom training script, as is, and run it with Azure Machine Learning service.

Upload the PyTorch training script:

shutil.copy('', project_folder)

However, if you would like to use Azure Machine Learning service tracking and metrics capabilities, you will have to add a small amount of code inside your training script. Examples of metrics tracking can be found in
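
A minimal sketch of what such tracking code might look like inside a training script, assuming the azureml-core package is installed on the compute target. Run.get_context() and run.log() are SDK calls; the helper name, the print fallback, and the metric names and values are hypothetical placeholders.

```python
def get_metric_logger():
    """Return a function log(name, value) that records a metric for the current run."""
    try:
        from azureml.core.run import Run
    except ImportError:
        # Assumption: fall back to printing when the SDK is not installed
        def log(name, value):
            print('{}: {}'.format(name, value))
        return log
    # Inside a submitted job this returns the active run
    run = Run.get_context()
    return run.log


log = get_metric_logger()

# Placeholder values standing in for what a real training loop would compute
for loss, acc in [(0.9, 0.61), (0.5, 0.82)]:
    log('train_loss', loss)
    log('val_acc', acc)
```

Metrics logged this way are attached to the run and can be inspected alongside its other history.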

Create a compute target

Create a compute target for your PyTorch job to run on. In this example, create a GPU-enabled Azure Machine Learning compute cluster.

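cluster_name = "gpucluster"

try:
    # Reuse the cluster if it already exists in the workspace
    compute_target = ComputeTarget(workspace=ws, name=cluster_name)
    print('Found existing compute target')
except ComputeTargetException:
    print('Creating a new compute target...')
    # max_nodes=4 is an assumed value; the original argument list was truncated
    compute_config = AmlCompute.provisioning_configuration(vm_size='STANDARD_NC6',
                                                           max_nodes=4)

    # Create the cluster with the specified name and configuration
    compute_target = ComputeTarget.create(ws, cluster_name, compute_config)

    # Wait for provisioning to finish and show the output log
    compute_target.wait_for_completion(show_output=True, min_node_count=None, timeout_in_minutes=20)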

For more information on compute targets, see the "what is a compute target" article.

Create a PyTorch estimator

The PyTorch estimator provides a simple way of launching a PyTorch training job on a compute target.

The PyTorch estimator is implemented through the generic estimator class, which can be used to support any framework. For more information about training models using the generic estimator, see train models with Azure Machine Learning using estimator.

If your training script needs additional pip or conda packages to run, you can have the packages installed on the resulting Docker image by passing their names through the pip_packages and conda_packages arguments.

script_params = {
    '--num_epochs': 30,
    '--output_dir': './outputs'
}

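# Arguments after source_directory are assumed; the original call was truncated.
estimator = PyTorch(source_directory=project_folder,
                    script_params=script_params, compute_target=compute_target,
                    entry_script='',  # training script file name, elided in the source
                    use_gpu=True)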
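
The entries in script_params are passed to the training script as command-line arguments. As a sketch of the receiving side, the training script could parse them with argparse. The function name and the default values here are assumptions, not taken from the provided script.

```python
import argparse


def parse_args(argv=None):
    """Parse the command-line arguments supplied through script_params."""
    parser = argparse.ArgumentParser(description='Training script arguments')
    parser.add_argument('--num_epochs', type=int, default=25,
                        help='number of epochs to train (assumed default)')
    parser.add_argument('--output_dir', type=str, default='./outputs',
                        help='folder where the trained model is saved')
    return parser.parse_args(argv)


# Simulate the arguments the estimator would pass to the script
args = parse_args(['--num_epochs', '30', '--output_dir', './outputs'])
print(args.num_epochs, args.output_dir)
```

In the real script you would call parse_args() with no arguments so that it reads sys.argv.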

Submit a run

The Run object provides the interface to the run history while the job is running and after it has completed.

run = experiment.submit(estimator)

As the Run is executed, it goes through the following stages:

  • Preparing: A Docker image is created according to the PyTorch estimator. The image is uploaded to the workspace's container registry and cached for later runs. Logs are also streamed to the run history and can be viewed to monitor progress.

  • Scaling: The cluster attempts to scale up if it requires more nodes to execute the run than are currently available.

  • Running: All scripts in the script folder are uploaded to the compute target, data stores are mounted or copied, and the entry_script is executed. Outputs from stdout and the ./logs folder are streamed to the run history and can be used to monitor the run.

  • Post-Processing: The ./outputs folder of the run is copied over to the run history.
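
To follow a run through these stages from your own code, one option is to block until it finishes and then read its final state. This is a sketch, assuming `run` is the object returned by experiment.submit(); wait_for_completion(), get_status(), and get_metrics() are methods of the SDK's Run class, and the helper name is hypothetical.

```python
def wait_and_summarize(run):
    """Block until the run finishes, then return its final status and logged metrics."""
    # Streams the run's logs to the console while waiting
    run.wait_for_completion(show_output=True)
    return run.get_status(), run.get_metrics()
```

A typical call would be `status, metrics = wait_and_summarize(run)`, after which you can decide whether the model is worth registering.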

Register or download a model

Once you've trained the model, you can register it to your workspace. Model registration lets you store and version your models in your workspace to simplify model management and deployment.

model = run.register_model(model_name='pt-dnn', model_path='outputs/')

You can also download a local copy of the model by using the Run object. The training script saves the model to a folder that is local to the compute target, and the Run object lets you download a copy of it.

# Create a model folder in the current directory
os.makedirs('./model', exist_ok=True)

for f in run.get_file_names():
    if f.startswith('outputs/model'):
        output_file_path = os.path.join('./model', f.split('/')[-1])
        print('Downloading from {} to {} ...'.format(f, output_file_path))
        run.download_file(name=f, output_file_path=output_file_path)

Distributed training

The PyTorch estimator also supports distributed training across CPU and GPU clusters. You can easily run distributed PyTorch jobs and Azure Machine Learning service will manage the orchestration for you.


Horovod

Horovod is an open-source, all-reduce framework for distributed training developed by Uber. It offers an easy path to distributed GPU PyTorch jobs.
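
To illustrate what an all-reduce does, the toy sketch below averages gradients across workers in plain Python so that every worker ends up with the same result. It only simulates the idea in a single process; it is not how Horovod is implemented, and the function name and gradient values are made up.

```python
def allreduce_average(worker_gradients):
    """Simulate an averaging all-reduce: every worker receives the element-wise mean."""
    num_workers = len(worker_gradients)
    # Reduce: element-wise sum across workers
    summed = [sum(values) for values in zip(*worker_gradients)]
    averaged = [value / num_workers for value in summed]
    # Broadcast: each worker gets its own copy of the same result
    return [list(averaged) for _ in range(num_workers)]


grads = [[1.0, 2.0], [3.0, 4.0]]  # made-up gradients from two workers
print(allreduce_average(grads))  # [[2.0, 3.0], [2.0, 3.0]]
```

Because every worker applies the same averaged gradient, the model replicas stay in sync after each step.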

To use Horovod, specify an MpiConfiguration object for the distributed_training parameter in the PyTorch constructor. This parameter ensures that the Horovod library is installed for you to use in your training script.

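from azureml.core.runconfig import MpiConfiguration
from azureml.train.dnn import PyTorch

# Arguments after source_directory are assumed; the original call was truncated.
estimator = PyTorch(source_directory=project_folder, compute_target=compute_target,
                    script_params=script_params, entry_script='',  # script name elided
                    node_count=2, distributed_training=MpiConfiguration(), use_gpu=True)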

Horovod and its dependencies will be installed for you, so you can import it in your training script as follows:

import torch
import horovod
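
Inside the training script, Horovod is usually accessed through its PyTorch module. The following sketch shows one common setup pattern, assuming one process per GPU and a model that has already been moved to the right device. The helper name is hypothetical; the hvd calls are part of horovod.torch.

```python
def setup_horovod(model, optimizer):
    """Prepare a model/optimizer pair for Horovod data-parallel training (sketch)."""
    import torch
    import horovod.torch as hvd

    hvd.init()  # start Horovod in this process
    if torch.cuda.is_available():
        # Pin this process to the GPU that matches its local rank
        torch.cuda.set_device(hvd.local_rank())
    # Average gradients across workers on every optimizer step
    optimizer = hvd.DistributedOptimizer(
        optimizer, named_parameters=model.named_parameters())
    # Start all workers from the same weights
    hvd.broadcast_parameters(model.state_dict(), root_rank=0)
    return optimizer, hvd.rank()
```

The returned rank is handy for restricting checkpointing and logging to a single worker, for example rank 0.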

Export to ONNX

To optimize inference with the ONNX Runtime, convert your trained PyTorch model to the ONNX format. Inference, or model scoring, is the phase where the deployed model is used for prediction, most commonly on production data. See the tutorial for an example.
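
As a rough sketch of the conversion step, torch.onnx.export traces the model with an example input and writes an ONNX file. The helper name, the input and output names, and the 1x3x224x224 input shape are assumptions; use the shape your model actually expects.

```python
def export_to_onnx(model, onnx_path, input_shape=(1, 3, 224, 224)):
    """Export a trained PyTorch model to ONNX using a dummy input (sketch)."""
    import torch

    model.eval()  # switch layers such as dropout to inference behavior
    # Create the dummy input on the same device as the model's parameters
    device = next(model.parameters()).device
    dummy_input = torch.randn(*input_shape, device=device)
    torch.onnx.export(model, dummy_input, onnx_path,
                      input_names=['input'], output_names=['output'])
    return onnx_path
```

The resulting file can then be loaded by ONNX Runtime for scoring.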

Next steps

In this article, you trained and registered a PyTorch model on Azure Machine Learning service. To learn how to deploy a model, continue on to our model deployment article.