Preservare lo stato di avanzamento dell'addestramento utilizzando Autocheckpoint
Storicamente, quando una VM TPU richiede manutenzione, la procedura viene avviata immediatamente, senza lasciare agli utenti il tempo di eseguire azioni di conservazione dello stato di avanzamento, come il salvataggio di un checkpoint. Questo è mostrato nella Figura 1(a).
Figura 1. Illustrazione della funzionalità Autocheckpoint: (a) senza Autocheckpoint, lo stato di avanzamento dell'addestramento dall'ultimo checkpoint viene perso quando si verifica un evento di manutenzione imminente. (b) Con Autocheckpoint, lo stato di avanzamento dell'addestramento dall'ultimo checkpoint può essere conservato quando si verifica un evento di manutenzione imminente.
Puoi utilizzare Autocheckpoint (Figura 1(b)) per preservare lo stato di avanzamento dell'addestramento configurando il codice in modo da salvare un checkpoint non pianificato quando si verifica un evento di manutenzione. Quando si verifica un evento di manutenzione, lo stato di avanzamento dall'ultimo checkpoint viene salvato automaticamente. La funzionalità funziona sia su singole slice che su Multislice.
La funzionalità Autocheckpoint funziona con i framework in grado di acquisire i segnali SIGTERM e successivamente salvare un checkpoint. I framework supportati includono:
Utilizzare Autocheckpoint
La funzionalità Autocheckpoint è disattivata per impostazione predefinita. Quando crei una
TPU o richiedi una risorsa in coda,
puoi abilitare Autocheckpoint aggiungendo il flag --autocheckpoint-enabled durante il provisioning
della TPU.
Con la funzionalità abilitata, Cloud TPU esegue i seguenti passaggi dopo aver ricevuto la notifica di un evento di manutenzione:
- Acquisisci il segnale SIGTERM inviato al processo utilizzando il dispositivo TPU
- Attendi l'uscita del processo o il trascorrere di 5 minuti, a seconda di quale si verifica per prima
- Esegui la manutenzione delle slice interessate
L'infrastruttura utilizzata da Autocheckpoint è indipendente dal framework di machine learning. Qualsiasi framework di machine learning può supportare Autocheckpoint se è in grado di acquisire il segnale SIGTERM e avviare un processo di checkpoint.
Nel codice dell'applicazione, devi abilitare le funzionalità Autocheckpoint fornite dal framework di machine learning. In Pax, ad esempio, ciò significa abilitare i flag della riga di comando durante l'avvio dell'addestramento. Per ulteriori informazioni, consulta la guida rapida di Autocheckpoint con Pax. Dietro le quinte, i framework salvano un checkpoint non pianificato quando viene ricevuto un segnale SIGTERM e la VM TPU interessata viene sottoposta a manutenzione quando la TPU non è più in uso.
Guida rapida: Autocheckpoint con MaxText
MaxText è una libreria LLM e un'implementazione di riferimento open source, ad alte prestazioni, scalabile in modo arbitrario e ben testata, scritta in Python/JAX puro e destinata a Cloud TPU. MaxText contiene tutta la configurazione necessaria per utilizzare la funzionalità Autocheckpoint.
Il file MaxText README
descrive due modi per eseguire MaxText su larga scala:
- Utilizzando
multihost_runner.py, consigliato per la sperimentazione - Utilizzando
multihost_job.py, consigliato per la produzione
Quando utilizzi multihost_runner.py, abilita Autocheckpoint impostando il flag autocheckpoint-enabled durante il provisioning della risorsa in coda.
Quando utilizzi multihost_job.py, abilita Autocheckpoint specificando il flag della riga di comando ENABLE_AUTOCHECKPOINT=true durante l'avvio del job.
Guida rapida: Autocheckpoint con Pax su una singola slice
Questa sezione fornisce un esempio di come configurare e utilizzare Autocheckpoint con Pax su una singola slice. Con la configurazione appropriata:
- Viene salvato un checkpoint quando si verifica un evento di manutenzione.
- Cloud TPU esegue la manutenzione delle VM TPU interessate dopo il salvataggio del checkpoint.
- Al termine della manutenzione di Cloud TPU, puoi utilizzare la VM TPU come di consueto.
Utilizza il flag
autocheckpoint-enabledquando crei la VM TPU o richiedi una risorsa in coda.Ad esempio:
Imposta le variabili di ambiente:
export PROJECT_ID=your-project-id export TPU_NAME=your-tpu-name export ZONE=zone-you-want-to-use export ACCELERATOR_TYPE=your-accelerator-type export RUNTIME_VERSION=tpu-ubuntu2204-base
Descrizioni delle variabili di ambiente
PROJECT_ID: l'ID progetto. Google Cloud Utilizza un progetto esistente o creane uno nuovo.TPU_NAME: il nome della TPU.ZONE: La zona in cui creare la VM TPU. Per ulteriori informazioni sulle zone supportate, consulta Regioni e zone TPU.ACCELERATOR_TYPE: Il tipo di acceleratore specifica la versione e le dimensioni della Cloud TPU che vuoi creare. Per ulteriori informazioni sui tipi di acceleratore supportati per ogni versione TPU, consulta Versioni TPU.RUNTIME_VERSION: la versione software di Cloud TPU.
Imposta l'ID progetto e la zona nella configurazione attiva:
gcloud config set project $PROJECT_ID gcloud config set compute/zone $ZONE
Crea una TPU:
gcloud alpha compute tpus tpu-vm create $TPU_NAME \ --accelerator-type $ACCELERATOR_TYPE \ --version $RUNTIME_VERSION \ --autocheckpoint-enabled
Connettiti alla TPU utilizzando SSH:
gcloud compute tpus tpu-vm ssh $TPU_NAMEInstalla Pax su una singola slice
La funzionalità Autocheckpoint funziona con le versioni di Pax 1.1.0 e successive. Nella VM TPU, installa
jax[tpu]e l'ultima versione dipaxml:pip install paxml && pip install jax[tpu] -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
Configura il
LmCloudSpmd2Bmodello. Prima di eseguire lo script di addestramento, modificaICI_MESH_SHAPEin[1, 8, 1]:@experiment_registry.register class LmCloudSpmd2B(LmCloudSpmd): """SPMD model with 2B params. Global batch size = 2 * 2 * 1 * 32 = 128 """ PERCORE_BATCH_SIZE = 8 NUM_LAYERS = 18 MODEL_DIMS = 3072 HIDDEN_DIMS = MODEL_DIMS * 4 CHECKPOINT_POLICY = layers.AutodiffCheckpointType.SAVE_NOTHING ICI_MESH_SHAPE = [1, 8, 1]
Avvia l'addestramento con la configurazione appropriata.
L'esempio seguente mostra come configurare il modello
LmCloudSpmd2Bper salvare i checkpoint attivati da Autocheckpoint in un bucket Cloud Storage. Sostituisci your-storage-bucket con il nome di un bucket esistente o creane uno nuovo.export JOB_LOG_DIR=gs://your-storage-bucket { python3 .local/lib/python3.10/site-packages/paxml/main.py \ --jax_fully_async_checkpoint=1 \ --exit_after_ondemand_checkpoint=1 \ --exp=tasks.lm.params.lm_cloud.LmCloudSpmd2B \ --job_log_dir=$JOB_LOG_DIR; } 2>&1 | tee pax_logs.txt
Tieni presente i due flag passati al comando:
jax_fully_async_checkpoint: con questo flag attivo, verrà utilizzatoorbax.checkpoint.AsyncCheckpointer. La classeAsyncCheckpointersalva automaticamente un checkpoint quando lo script di addestramento riceve un segnale SIGTERM.exit_after_ondemand_checkpoint: con questo flag attivo, il processo TPU viene chiuso dopo il salvataggio di Autocheckpoint, il che attiva l'esecuzione immediata della manutenzione. Se non utilizzi questo flag, l'addestramento continuerà dopo il salvataggio del checkpoint e Cloud TPU attenderà il timeout (5 minuti) prima di eseguire la manutenzione richiesta.
Autocheckpoint con Orbax
La funzionalità Autocheckpoint non è limitata a MaxText o Pax. Qualsiasi framework in grado di acquisire il segnale SIGTERM e avviare un processo di checkpoint funziona con l'infrastruttura fornita da Autocheckpoint. Orbax, uno spazio dei nomi che fornisce librerie di utilità comuni per gli utenti JAX, offre queste funzionalità.
Come spiegato nella documentazione di Orbax,
queste funzionalità sono abilitate per impostazione predefinita per gli utenti
di orbax.checkpoint.CheckpointManager. Il metodo save chiamato
dopo ogni passaggio verifica automaticamente se è imminente un evento di manutenzione
e, in caso affermativo, salva un checkpoint anche se il numero di passaggi
non è un multiplo di save_interval_steps.
La documentazione di GitHub
illustra anche come chiudere l'addestramento dopo aver salvato un
Autocheckpoint, con una modifica nel codice utente.