Preservare lo stato di avanzamento dell'addestramento utilizzando Autocheckpoint

Storicamente, quando una VM TPU richiede manutenzione, la procedura viene avviata immediatamente, senza lasciare agli utenti il tempo di eseguire azioni di conservazione dello stato di avanzamento, come il salvataggio di un checkpoint. Questo è mostrato nella Figura 1(a).

Diagramma che mostra l'impatto della manutenzione dell'host con e senza checkpoint automatici

Figura 1. Illustrazione della funzionalità Autocheckpoint: (a) senza Autocheckpoint, lo stato di avanzamento dell'addestramento dall'ultimo checkpoint viene perso quando si verifica un evento di manutenzione imminente. (b) Con Autocheckpoint, lo stato di avanzamento dell'addestramento dall'ultimo checkpoint può essere conservato quando si verifica un evento di manutenzione imminente.

Puoi utilizzare Autocheckpoint (Figura 1(b)) per preservare lo stato di avanzamento dell'addestramento configurando il codice in modo da salvare un checkpoint non pianificato quando si verifica un evento di manutenzione. Quando si verifica un evento di manutenzione, lo stato di avanzamento dall'ultimo checkpoint viene salvato automaticamente. La funzionalità funziona sia su singole slice che su Multislice.

La funzionalità Autocheckpoint funziona con i framework in grado di acquisire i segnali SIGTERM e successivamente salvare un checkpoint. I framework supportati includono:

Utilizzare Autocheckpoint

La funzionalità Autocheckpoint è disattivata per impostazione predefinita. Quando crei una TPU o richiedi una risorsa in coda, puoi abilitare Autocheckpoint aggiungendo il flag --autocheckpoint-enabled durante il provisioning della TPU. Con la funzionalità abilitata, Cloud TPU esegue i seguenti passaggi dopo aver ricevuto la notifica di un evento di manutenzione:

  1. Acquisisci il segnale SIGTERM inviato al processo utilizzando il dispositivo TPU
  2. Attendi l'uscita del processo o il trascorrere di 5 minuti, a seconda di quale si verifica per prima
  3. Esegui la manutenzione delle slice interessate

L'infrastruttura utilizzata da Autocheckpoint è indipendente dal framework di machine learning. Qualsiasi framework di machine learning può supportare Autocheckpoint se è in grado di acquisire il segnale SIGTERM e avviare un processo di checkpoint.

Nel codice dell'applicazione, devi abilitare le funzionalità Autocheckpoint fornite dal framework di machine learning. In Pax, ad esempio, ciò significa abilitare i flag della riga di comando durante l'avvio dell'addestramento. Per ulteriori informazioni, consulta la guida rapida di Autocheckpoint con Pax. Dietro le quinte, i framework salvano un checkpoint non pianificato quando viene ricevuto un segnale SIGTERM e la VM TPU interessata viene sottoposta a manutenzione quando la TPU non è più in uso.

Guida rapida: Autocheckpoint con MaxText

MaxText è una libreria LLM e un'implementazione di riferimento open source, ad alte prestazioni, scalabile in modo arbitrario e ben testata, scritta in Python/JAX puro e destinata a Cloud TPU. MaxText contiene tutta la configurazione necessaria per utilizzare la funzionalità Autocheckpoint.

Il file MaxText README descrive due modi per eseguire MaxText su larga scala:

Quando utilizzi multihost_runner.py, abilita Autocheckpoint impostando il flag autocheckpoint-enabled durante il provisioning della risorsa in coda.

Quando utilizzi multihost_job.py, abilita Autocheckpoint specificando il flag della riga di comando ENABLE_AUTOCHECKPOINT=true durante l'avvio del job.

Guida rapida: Autocheckpoint con Pax su una singola slice

Questa sezione fornisce un esempio di come configurare e utilizzare Autocheckpoint con Pax su una singola slice. Con la configurazione appropriata:

  • Viene salvato un checkpoint quando si verifica un evento di manutenzione.
  • Cloud TPU esegue la manutenzione delle VM TPU interessate dopo il salvataggio del checkpoint.
  • Al termine della manutenzione di Cloud TPU, puoi utilizzare la VM TPU come di consueto.
  1. Utilizza il flag autocheckpoint-enabled quando crei la VM TPU o richiedi una risorsa in coda.

    Ad esempio:

    1. Imposta le variabili di ambiente:

      export PROJECT_ID=your-project-id
      export TPU_NAME=your-tpu-name
      export ZONE=zone-you-want-to-use
      export ACCELERATOR_TYPE=your-accelerator-type
      export RUNTIME_VERSION=tpu-ubuntu2204-base

      Descrizioni delle variabili di ambiente

      • PROJECT_ID: l'ID progetto. Google Cloud Utilizza un progetto esistente o creane uno nuovo.
      • TPU_NAME: il nome della TPU.
      • ZONE: La zona in cui creare la VM TPU. Per ulteriori informazioni sulle zone supportate, consulta Regioni e zone TPU.
      • ACCELERATOR_TYPE: Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni sui tipi di acceleratore supportati per ogni versione TPU, consulta Versioni TPU.
      • RUNTIME_VERSION: la versione software di Cloud TPU.

    2. Imposta l'ID progetto e la zona nella configurazione attiva:

      gcloud config set project $PROJECT_ID
      gcloud config set compute/zone $ZONE
    3. Crea una TPU:

      gcloud alpha compute tpus tpu-vm create $TPU_NAME \
          --accelerator-type $ACCELERATOR_TYPE \
          --version $RUNTIME_VERSION \
          --autocheckpoint-enabled
  2. Connettiti alla TPU utilizzando SSH:

    gcloud compute tpus tpu-vm ssh $TPU_NAME
    
  3. Installa Pax su una singola slice

    La funzionalità Autocheckpoint funziona con le versioni di Pax 1.1.0 e successive. Nella VM TPU, installa jax[tpu] e l'ultima versione di paxml:

    pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
  4. Configura il LmCloudSpmd2B modello. Prima di eseguire lo script di addestramento, modifica ICI_MESH_SHAPE in [1, 8, 1]:

    @experiment_registry.register
    class LmCloudSpmd2B(LmCloudSpmd):
    
        """SPMD model with 2B params.
    
        Global batch size = 2 * 2 * 1 * 32 = 128
        """
        PERCORE_BATCH_SIZE = 8
    
        NUM_LAYERS = 18
        MODEL_DIMS = 3072
        HIDDEN_DIMS = MODEL_DIMS * 4
    
        CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING
        ICI_MESH_SHAPE = [1, 8, 1]
  5. Avvia l'addestramento con la configurazione appropriata.

    L'esempio seguente mostra come configurare il modello LmCloudSpmd2B per salvare i checkpoint attivati da Autocheckpoint in un bucket Cloud Storage. Sostituisci your-storage-bucket con il nome di un bucket esistente o creane uno nuovo.

    export JOB_LOG_DIR=gs://your-storage-bucket
    
    { python3 .local/lib/python3.10/site-packages/paxml/main.py \
        --jax_fully_async_checkpoint=1 \
        --exit_after_ondemand_checkpoint=1 \
        --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \
        --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt

    Tieni presente i due flag passati al comando:

    • jax_fully_async_checkpoint: con questo flag attivo, verrà utilizzato orbax.checkpoint.AsyncCheckpointer. La classe AsyncCheckpointer salva automaticamente un checkpoint quando lo script di addestramento riceve un segnale SIGTERM.
    • exit_after_ondemand_checkpoint: con questo flag attivo, il processo TPU viene chiuso dopo il salvataggio di Autocheckpoint, il che attiva l'esecuzione immediata della manutenzione. Se non utilizzi questo flag, l'addestramento continuerà dopo il salvataggio del checkpoint e Cloud TPU attenderà il timeout (5 minuti) prima di eseguire la manutenzione richiesta.

Autocheckpoint con Orbax

La funzionalità Autocheckpoint non è limitata a MaxText o Pax. Qualsiasi framework in grado di acquisire il segnale SIGTERM e avviare un processo di checkpoint funziona con l'infrastruttura fornita da Autocheckpoint. Orbax, uno spazio dei nomi che fornisce librerie di utilità comuni per gli utenti JAX, offre queste funzionalità.

Come spiegato nella documentazione di Orbax, queste funzionalità sono abilitate per impostazione predefinita per gli utenti di orbax.checkpoint.CheckpointManager. Il metodo save chiamato dopo ogni passaggio verifica automaticamente se è imminente un evento di manutenzione e, in caso affermativo, salva un checkpoint anche se il numero di passaggi non è un multiplo di save_interval_steps. La documentazione di GitHub illustra anche come chiudere l'addestramento dopo aver salvato un Autocheckpoint, con una modifica nel codice utente.