PyTorch

El proyecto PyTorch es un paquete de Python que proporciona cálculo de tensores acelerado por GPU y funcionalidad de alto nivel para crear redes de aprendizaje profundo. Para obtener información sobre las licencias, consulte el documento de la licencia de PyTorch en GitHub.

Para supervisar y depurar los modelos de PyTorch, puede usar TensorBoard.

PyTorch está incluido en Databricks Runtime para Machine Learning. Si usa Databricks Runtime, consulte Instalación de PyTorch para obtener instrucciones sobre cómo instalar PyTorch.

Nota:

Esta no es una guía completa de PyTorch. Si desea obtener más información, consulte el sitio web de PyTorch.

Entrenamiento distribuido y nodo único

Para probar y migrar flujos de trabajo de una sola máquina, use un clúster de nodo único.

Para obtener opciones de entrenamiento distribuido para el aprendizaje profundo, consulte Aprendizaje distribuido.

Cuaderno de ejemplo

Cuaderno PyTorch

Obtener el cuaderno

Instalación de PyTorch

Databricks Runtime para ML

Databricks Runtime para Machine Learning incluye PyTorch, que permite crear el clúster y empezar a usar PyTorch. Para saber la versión de PyTorch instalada en la versión de Databricks Runtime ML que está usando, consulte las notas de la versión.

Entorno de tiempo de ejecución de Databricks

Databricks recomienda usar la versión de PyTorch incluida en Databricks Runtime para Machine Learning. Sin embargo, si debe usar Databricks Runtime estándar, PyTorch se puede instalar como una biblioteca PyPI de Databricks. En el ejemplo siguiente, se muestra cómo instalar PyTorch 1.5.0:

  • En clústeres de GPU, instale pytorch y torchvision especificando lo siguiente:

    • torch==1.5.0
    • torchvision==0.6.0
  • En clústeres de CPU, instale pytorch y torchvision usando los siguientes archivos wheel de Python:

    https://download.pytorch.org/whl/cpu/torch-1.5.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
    
    https://download.pytorch.org/whl/cpu/torchvision-0.6.0%2Bcpu-cp37-cp37m-linux_x86_64.whl
    

Errores y solución de problemas de PyTorch distribuido

En las secciones siguientes se describen los mensajes de error comunes y las instrucciones de solución de problemas para las clases: PyTorch DataParallel o PyTorch DistributedDataParallel. Es probable que la mayoría de estos errores se resuelva con TorchDistributor, que está disponible en Databricks Runtime ML 13.0 y versiones posteriores. Sin embargo, si TorchDistributor no es una solución viable, también se proporcionan soluciones recomendadas dentro de cada sección.

A continuación se muestra un ejemplo de cómo usar TorchDistributor:


from pyspark.ml.torch.distributor import TorchDistributor

def train_fn(learning_rate):
        # ...

num_processes=2
distributor = TorchDistributor(num_processes=num_processes, local_mode=True)

distributor.run(train_fn, 1e-3)

process 0 terminated with exit code 1

Este error se produce al usar cuadernos, independientemente del entorno: Databricks, máquina local, etc. Para evitar este error, use torch.multiprocessing.start_processes con start_method=fork en lugar de torch.multiprocessing.spawn.

Por ejemplo:

import torch

def train_fn(rank, learning_rate):
    # required setup, e.g. setup(rank)
        # ...

num_processes = 2
torch.multiprocessing.start_processes(train_fn, args=(1e-3,), nprocs=num_processes, start_method="fork")

The server socket has failed to bind to [::]:{PORT NUMBER} (errno: 98 - Address already in use).

Este error aparece cuando se reinicia el entrenamiento distribuido después de interrumpir la celda mientras se está realizando el entrenamiento.

Para resolverlo, reinicie el clúster. Si esto no soluciona el problema, podría haber un error en el código de la función de entrenamiento.

Puede encontrarse con problemas adicionales con CUDA, ya que start_method=”fork”no es compatible con CUDA. El uso de comandos .cuda en cualquier celda puede provocar errores. Para evitar estos errores, agregue la siguiente comprobación antes de llamar a torch.multiprocessing.start_method:

if torch.cuda.is_initialized():
    raise Exception("CUDA was initialized; distributed training will fail.") # or something similar