HorovodEstimator is an Apache Spark MLlib-style estimator API that leverages the Horovod framework developed by Uber. It enables distributed, multi-GPU training of deep neural networks on Spark DataFrames, making it easier to connect ETL in Spark with model training in TensorFlow. Specifically, HorovodEstimator simplifies launching distributed training with Horovod by:
- Distributing training code and data to each machine on your cluster
- Enabling passwordless SSH between the driver and workers, and launching training via MPI
- Providing data-ingest and model-export logic, so you don't have to write it yourself
- Running model training and evaluation simultaneously
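Once training is launched, Horovod trains in a data-parallel fashion: each worker computes gradients on its own shard of the data, and the workers average those gradients with an allreduce before every worker applies the same update. The pure-Python sketch below simulates that averaging step for a one-parameter linear model. It is not Horovod's API (Horovod performs the allreduce across processes, for example over MPI or NCCL), and the helpers `allreduce_average`, `data_parallel_step`, and `mse_grad` are hypothetical.

```python
from typing import Callable, List, Sequence, Tuple

Shard = Sequence[Tuple[float, float]]  # (x, y) pairs held by one worker


def allreduce_average(worker_grads: List[List[float]]) -> List[float]:
    """Elementwise mean of every worker's gradient vector (what an allreduce-average yields)."""
    n_workers = len(worker_grads)
    summed = [0.0] * len(worker_grads[0])
    for grads in worker_grads:
        for i, g in enumerate(grads):
            summed[i] += g
    return [s / n_workers for s in summed]


def data_parallel_step(
    weights: List[float],
    shards: List[Shard],
    grad_fn: Callable[[List[float], Shard], List[float]],
    lr: float,
) -> List[float]:
    # Each simulated worker computes a gradient on its own shard only.
    local_grads = [grad_fn(weights, shard) for shard in shards]
    # All workers receive the same averaged gradient, so they apply the same update.
    avg = allreduce_average(local_grads)
    return [w - lr * g for w, g in zip(weights, avg)]


def mse_grad(weights: List[float], shard: Shard) -> List[float]:
    # Gradient of mean((w*x - y)^2) with respect to the single weight w.
    w = weights[0]
    return [sum(2 * x * (w * x - y) for x, y in shard) / len(shard)]


# Two simulated workers, each holding half of the data for y = 3x.
shards = [[(1.0, 3.0), (2.0, 6.0)], [(3.0, 9.0), (4.0, 12.0)]]
weights = [0.0]
for _ in range(200):
    weights = data_parallel_step(weights, shards, mse_grad, lr=0.01)
print(round(weights[0], 3))
```

Because both shards have the same size, averaging the per-shard gradients gives exactly the full-batch gradient, so the replicated weight stays identical on every worker and converges to the true slope of 3.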
HorovodEstimator requires Databricks Runtime ML.
You can run HorovodEstimator on clusters of two or more CPU or GPU-enabled machines; we recommend running on GPU instances if possible.
HorovodEstimator expects all GPUs on the cluster to be available, so we do not recommend using the API on shared clusters.
If you are using GPUs, we recommend not opening any other TensorFlow sessions on the cluster you are using with HorovodEstimator. If you open a TensorFlow session, the Python REPL running your notebook claims a GPU, which prevents HorovodEstimator from running. In this case, you may need to detach and reattach your notebook, then rerun your HorovodEstimator code without running any TensorFlow code beforehand.
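Horovod conventionally runs one training process per GPU and identifies each process by a global `rank`, a per-machine `local_rank` (typically used to pick which GPU the process uses), and the total `size`. The hypothetical helper `assign_slots` below sketches such an assignment for an assumed cluster of two workers with two GPUs each (the host names are made up). It shows why every GPU must be free: each GPU is pinned to a training process, so a GPU already held by a notebook's TensorFlow session leaves a rank without a usable device.

```python
from typing import Dict, List, NamedTuple


class Slot(NamedTuple):
    host: str
    local_rank: int  # index of the process (and typically its GPU) on that host
    rank: int        # global index of the process across the whole cluster


def assign_slots(gpus_per_host: Dict[str, int]) -> List[Slot]:
    """Hypothetical helper: one training process per GPU, ranked host by host."""
    slots: List[Slot] = []
    rank = 0
    for host in sorted(gpus_per_host):  # deterministic host order
        for local_rank in range(gpus_per_host[host]):
            slots.append(Slot(host, local_rank, rank))
            rank += 1
    return slots


# Assumed cluster: two workers with two GPUs each (host names are made up).
slots = assign_slots({"worker-1": 2, "worker-0": 2})
size = len(slots)  # total number of training processes
for slot in slots:
    print(slot)
print("size =", size)
```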
Distributed training with HorovodEstimator
HorovodEstimator is a Spark MLlib Estimator and can be used with the Spark MLlib Pipelines API, although estimator persistence is not yet supported.
Fitting a HorovodEstimator returns an MLlib Transformer (a TFTransformer) that can be used for distributed inference on a DataFrame. Fitting also writes the following to the specified model directory: model checkpoints, which can be used to resume training; event files, which contain metrics logged during training; and a tf.SavedModel, which can be used to apply the model for inference outside Spark.
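The fit-returns-a-Transformer contract follows the general MLlib Estimator/Transformer pattern. The toy sketch below illustrates that pattern in plain Python, using lists in place of DataFrames; the classes `MeanEstimator` and `MeanModel` are hypothetical stand-ins, not HorovodEstimator or TFTransformer.

```python
from typing import List, Sequence


class MeanModel:
    """Stand-in for a fitted Transformer: holds state learned during fit()."""

    def __init__(self, mean: float) -> None:
        self.mean = mean

    def transform(self, rows: Sequence[float]) -> List[float]:
        # Inference step: apply the learned parameter to new rows.
        return [x - self.mean for x in rows]


class MeanEstimator:
    """Stand-in for an Estimator: fit() learns from data and returns a Transformer."""

    def fit(self, rows: Sequence[float]) -> MeanModel:
        return MeanModel(sum(rows) / len(rows))


model = MeanEstimator().fit([1.0, 2.0, 3.0])
print(model.transform([4.0, 5.0]))
```

Keeping training (`fit`) and inference (`transform`) in separate objects is what lets the fitted model be applied to new data without retraining.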
HorovodEstimator makes no fault-tolerance guarantees. If an error occurs during training, HorovodEstimator does not attempt to recover, although you can rerun fit() to resume training from the latest checkpoint.
The example notebook below demonstrates how to use HorovodEstimator to train a deep neural network on the MNIST dataset, a large database of handwritten digits.
Training a model to predict a digit is commonly used as the “Hello World” of machine learning.