What is the correct way to use multiple CPU cores with jax.pmap?
The following example sets an environment variable that exposes two CPU devices for SPMD, checks that JAX recognises both devices, and then tries to lock (fully occupy) each core with an infinite loop.
import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'
import jax as jx
import jax.numpy as jnp
jx.local_device_count()
# WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# 2
jx.devices("cpu")
# [CpuDevice(id=0), CpuDevice(id=1)]
def sfunc(x):
    while True:
        pass
jx.pmap(sfunc)(jnp.arange(2))
Executing this from a Jupyter kernel and observing htop shows that only one core is locked:
data:image/s3,"s3://crabby-images/998d6/998d6a9fd79b9fd4a74df490bbe1d0b70aa5b4f9" alt="execute from jupyter kernel"
I get the same htop output when I omit the first two lines and run:
$ env XLA_FLAGS=--xla_force_host_platform_device_count=2 python test.py
Replacing sfunc with
def sfunc(x): return 2.0*x
and calling
jx.pmap(sfunc)(jnp.arange(2))
# ShardedDeviceArray([0., 2.], dtype=float32, weak_type=True)
does return a ShardedDeviceArray.
Clearly I am not correctly configuring JAX/XLA to use two cores. What am I missing and what can I do to diagnose the problem?
As far as I can tell, you are configuring the cores correctly (see e.g. Issue #2714). The problem lies in your test function:
This function gets stuck in an infinite loop at trace-time, not at run-time. Tracing happens in your host Python process on a single CPU (see How to think in JAX for an introduction to the idea of tracing within JAX transformations).
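A minimal sketch of the trace-time/run-time distinction, using jax.jit for brevity (pmap traces the mapped function the same way before dispatching it to the devices). The global counter trace_count is a hypothetical illustration: it is ordinary Python code, so it only executes while JAX traces the function, not each time the compiled program runs.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # hypothetical counter, used only to show when Python code runs


def sfunc(x):
    global trace_count
    # Plain Python side effect: executes during tracing, not at run time.
    trace_count += 1
    return 2.0 * x


f = jax.jit(sfunc)
x = jnp.arange(3.0)
f(x)  # first call: traces sfunc (trace_count becomes 1), compiles, runs
f(x)  # same shape and dtype: reuses the compiled program without retracing
print(trace_count)  # 1
```

By the same logic, a Python `while True: pass` inside sfunc spins forever in this tracing step on the host, so no program is ever sent to the two CPU devices.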
If you want to observe CPU usage at runtime, you'll have to use a function that finishes tracing and begins running. For that you could use any long-running function that actually produces results. Here is a simple example:
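A minimal sketch of such a function, assuming a recent JAX version. The loop is written with jax.lax.fori_loop, so it is traced once and then iterates inside the compiled program on each device. The function name busy, the matrix size and the step count are illustrative choices; raise steps or the matrix size to keep the cores busy long enough to watch in htop.

```python
import os

# Must be set before JAX is imported, as in the question.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp

N_DEVICES = jax.local_device_count()


def busy(x, steps=100):
    """Repeatedly square a matrix inside the compiled program."""

    def body(_, acc):
        # Dividing by the size and applying tanh keeps the entries bounded.
        return jnp.tanh(acc @ acc / acc.shape[0])

    # lax.fori_loop is staged into the XLA program: the loop runs at
    # run time on each device instead of in the Python tracer.
    return jax.lax.fori_loop(0, steps, body, x)


# One 256x256 matrix per device along the mapped (leading) axis.
x = jnp.ones((N_DEVICES, 256, 256)) * 0.5
out = jax.pmap(busy)(x)
out.block_until_ready()  # wait for the devices to finish
print(N_DEVICES, out.shape)
```

Note that XLA's CPU backend may also use more than one thread per device for large matrix products, so htop is only a rough indicator of how the work is split across devices.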