JAX pmap with multiple CPU cores

Posted 2025-01-31 04:02:49

What is the correct method for using multiple CPU cores with jax.pmap?

The following example sets an environment variable that exposes two CPU devices for SPMD, checks that JAX recognises them, and then tries to lock both cores with a function that never returns.

import os
os.environ["XLA_FLAGS"] = '--xla_force_host_platform_device_count=2'

import jax as jx
import jax.numpy as jnp

jx.local_device_count()
# WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
# 2

jx.devices("cpu")
# [CpuDevice(id=0), CpuDevice(id=1)]

def sfunc(x):
  while True:
    pass

jx.pmap(sfunc)(jnp.arange(2))

Executing this from a Jupyter kernel and watching htop shows that only one core is locked.

[Screenshot: htop output when executing from a Jupyter kernel]

I see the same htop output when I omit the first two lines and instead run:

$ env XLA_FLAGS=--xla_force_host_platform_device_count=2 python test.py

Replacing sfunc with

def sfunc(x): return 2.0*x

and calling

jx.pmap(sfunc)(jnp.arange(2))
# ShardedDeviceArray([0., 2.], dtype=float32, weak_type=True)

does return a ShardedDeviceArray.

Clearly I am not correctly configuring JAX/XLA to use two cores. What am I missing and what can I do to diagnose the problem?


Answer by 花开浅夏, 2025-02-07 04:02:49

As far as I can tell, you are configuring the cores correctly (see e.g. Issue #2714). The problem lies in your test function:

def sfunc(x):
  while True:
    pass

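This function gets stuck in an infinite loop at trace time, not at run time. Before pmap can compile anything or dispatch it to the devices, it traces sfunc by running its Python code on abstract placeholder values. Tracing happens in your host Python process on a single CPU core, and with while True it never finishes, so neither device ever receives work (see How to think in JAX for an introduction to the idea of tracing within JAX transformations).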
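A minimal sketch of the trace-time behaviour, assuming a recent JAX release and the same two-device flag as in the question: the Python statements inside `sfunc` run on the host while `pmap` traces it, and they see an abstract per-device value (shape `()` here, since pmap strips the mapped axis) rather than real data. The list `trace_log` is a hypothetical name used only for illustration.

```python
import os

# Same flag as in the question; it must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import jax
import jax.numpy as jnp
import numpy as np

trace_log = []  # hypothetical log, filled only by Python side effects


def sfunc(x):
    # This Python statement runs on the host while pmap traces sfunc.
    # x is an abstract tracer standing in for one device's slice, not real data.
    trace_log.append((type(x).__name__, x.shape))
    return 2.0 * x


out = jax.pmap(sfunc)(jnp.arange(2.0))
print(trace_log)  # tracer type and per-device shape, recorded during tracing
print(out)
```

With `while True: pass` in place of the `append`, this host-side trace never returns, so nothing is compiled or sent to either CPU device, which is consistent with htop showing a single busy core.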

If you want to observe CPU usage at runtime, you'll have to use a function that finishes tracing and begins running. For that you could use any long-running function that actually produces results. Here is a simple example:

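def sfunc(x):
  # This Python loop is unrolled at trace time into 100 matmuls,
  # which then execute on each device at run time.
  for _ in range(100):
    x = x @ x
  return x

# Leading axis of size 2: one 1000x1000 matrix per CPU device.
jx.pmap(sfunc)(jnp.zeros((2, 1000, 1000)))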
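A variant of the example above, sketched under the assumption of a recent JAX release in which pmap outputs are `jax.Array` objects exposing `addressable_shards`. It keeps the loop inside the compiled program with `jax.lax.fori_loop` (so tracing stays cheap instead of unrolling every matmul), waits for JAX's asynchronous dispatch with `block_until_ready()`, and prints which device holds each output slice as a quick diagnostic. The matrix size `n` and the iteration count are arbitrary choices for illustration; identity matrices are used only so the result is easy to check.

```python
import os

# Same flag as in the question; it must be set before JAX is imported.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=2"

import time

import jax
import jax.numpy as jnp
import numpy as np


def sfunc(x):
    # lax.fori_loop stays inside the compiled XLA program, so the repeated
    # matmuls happen at run time rather than being unrolled during tracing.
    return jax.lax.fori_loop(0, 20, lambda _, acc: acc @ acc, x)


n_dev = jax.local_device_count()
n = 500
# One identity matrix per device; I @ I == I, so the output is easy to verify.
x = jnp.broadcast_to(jnp.eye(n, dtype=jnp.float32), (n_dev, n, n))

f = jax.pmap(sfunc)
f(x).block_until_ready()  # first call traces and compiles

start = time.perf_counter()
# pmap dispatches asynchronously; block so the timing covers the computation.
out = f(x).block_until_ready()
elapsed = time.perf_counter() - start

# Which CPU device holds each slice of the output?
shard_devices = sorted(shard.device.id for shard in out.addressable_shards)
print(f"{n_dev} devices, shards on {shard_devices}, {elapsed:.3f} s")
```

If the flag took effect, the shard list names two distinct devices. Increasing `n` or the iteration count keeps the timed call running long enough to watch in htop.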