排查 JAX 问题 - TPU
本指南提供了 JAX 问题排查信息,可帮助您识别和解决在 Cloud TPU 上训练 JAX 模型时可能遇到的问题。
如需查看如何开始使用 Cloud TPU 的通用指南,请参阅 JAX 快速入门。
常见的 JAX 问题
如果您在开发训练模型或使用 JAX 进行训练时遇到问题,请参阅 JAX 常见问题解答。
如需了解使用 JAX 编写训练应用时可能遇到的更多常规编程错误,请参阅 JAX 错误。
分析 JAX 性能
您可以使用剖析 JAX 性能中所述的工具,了解 TPU 资源的使用方式。
排查内存问题
您可以使用 JAX 设备内存分析器监控内存的使用情况,但无法直接管理内存的使用方式。
JAX 设备内存分析器可用于:
- 找出给定时间 TPU 内存中有哪些数组和可执行文件,或者
- 跟踪内存泄漏。
您无法指定如何为特定操作分配 TPU 内存。如需详细了解 JAX 特有的 TPU 性能问题,请参阅将 TPU 与 JAX 搭配使用的性能说明。
排查 TPU 问题
以下部分介绍了如何解决在 TPU 上运行 JAX 程序时可能会遇到的一些常见问题。
如何验证 TPU 是否正在运行?
只要 JAX 未输出“找不到 GPU/TPU,回退到 CPU”,所有一切都会在 TPU 上运行。
您可以通过查看 jax.devices()
来验证 TPU 是否处于活跃状态,您应该会看到显示了多个 TPU 设备,也可以使用 assert jax.devices()[0].platform == 'tpu'
以程序化方式进行验证。
RuntimeError:无法初始化后端“tpu”:UNAVAILABLE:没有可用的 TPU 平台。
此运行时错误消息或在 TPU 虚拟机的 /tmp/tpu_logs/tpu_driver.WARNING
中发现“W1118 17:40:20.985243 23901 tpu_version_flag.cc:57] No hardware is found. Using default TPU version:xxxxxx
”可能表明您运行的是错误的 TPU 虚拟机版本。
验证您是否运行的是当前 JAX 运行时版本,然后重试。
排查 TPU 和 GKE 问题
为了帮助排查问题,请在 GKE 工作负载清单中启用详细日志记录,然后向 GKE 支持团队提供日志。
TPU_MIN_LOG_LEVEL=0 TF_CPP_MIN_LOG_LEVEL=0 TPU_STDERR_LOG_LEVEL=0
以下部分介绍了与 TPU 和 GKE 设置相关的错误消息及解决方法。
没有可用于服务“jobset-webhook-service”的端点
此错误表示作业集未正确安装。检查 jobset-controller-manager 部署 Kubernetes Pod 是否正在运行。如需了解详情,请参阅 JobSet 问题排查文档。
TPU 初始化失败:连接失败
确保 GKE 节点版本为 1.30.4-gke.1348000 或更高版本(不支持 GKE 1.31)。