为什么jax+ Stax模型的GPU内存比所需的更多?

发布于 2025-02-07 02:41:53 字数 3505 浏览 1 评论 0原文

我正在尝试从GPU上的Kaggle内核运行JAX + Stax模型,但由于内存错误而失败。我已经将xla_python_client_preallocate设置为false以避免gpu内存的预先安装,并且还尝试了设置xla_python_client_allocator_allocator to platform> Platform,没有任何帮助。由于我不希望GPU上存储的所有数据,因此将默认设备设置为CPU。模型和批处理数据将手动发送到GPU。变量的大小(模型参数,数据...)可能不是一个问题,因为相同的代码在CPU上平稳运行,而无需OOM错误。我还对模型进行了记忆分析。为了仅获得GPU内存,有必要制作另一个版本的代码,其中GPU为默认设备,所有数据都存储在此处。如果我在CPU为默认的原始代码上运行了分析,则仅获取CPU数据的分析。对于模型完成训练也需要减少批量尺寸为10。分析仅显示存储数据和参数所需的内存(≈5.5GB),但是当我使用其他python函数检查GPU使用时,它更大(≈14.6GB,请注意:使用batch_size = 100运行时内存在第一个迷你批处理期间还达到14.6GB,但无法进一步进行)。

这是我使用的代码的简化版本:

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = 'false'
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = 'platform' # Tried this, didn't help

import jax
from jax.lib import xla_bridge
jax.config.update('jax_platform_name', 'cpu') # If not set default device = CPU then all the device arrays will be saved to GPU by default

# Set the processor to GPU if available
try: print('Available GPU Devices: ', jax.devices("gpu")); device = jax.devices("gpu")[0]; gpu_available = 1
except: device = jax.devices("cpu")[0]; gpu_available = 0

# Load data into jax device arrays of dimensions (2000, 200, 200, 3)...

InitializationFunction, ApplyFunction = stax.serial(
    Conv(out_chan = 64, filter_shape = (3, 3), strides = (1, 1), padding = 'SAME'), Relu,
    Conv(out_chan = 64, filter_shape = (3, 3), strides = (1, 1), padding = 'SAME'), Relu,
    Flatten, Dense(128), Relu, Dense(2),)

key = random.PRNGKey(2793)
output_shape, parameters = jax.device_put(InitializationFunction(rng = key, input_shape = (100, image_width, image_height, number_of_channels)), device)
optimizer = optax.adam(0.001)
optimizer_state = jax.device_put(optimizer.init(parameters), device)

def Loss(parameters, inputs, targets):
    predictions = ApplyFunction(parameters, inputs)
    loss = jnp.mean(optax.softmax_cross_entropy(predictions, targets))
    return loss

@jit
def Step(parameters, optimizer_state, inputs, targets):
    loss, gradients = value_and_grad(Loss)(parameters, inputs, targets)
    updates, optimizer_state = optimizer.update(gradients, optimizer_state, parameters)
    parameters = optax.apply_updates(parameters, updates)
    return parameters, optimizer_state, loss

epochs, batch_size = 2, 100
key, subkey = random.split(key)
keys_epochs = random.split(subkey, epochs)
    
for epoch in range(epochs):
    random_indices_order = random.permutation(keys_epochs[epoch], jnp.arange(len(train_set['images'])))

    for batch_number in range(len(train_set['images']) // batch_size):
        start = batch_number * batch_size
        end = (batch_number + 1) * batch_size
        batch_inputs = jax.device_put(jnp.take(train_set['images'], random_indices_order[start:end], 0), device)
        batch_targets = jax.device_put(OneHot(jnp.take(train_set['class_numbers'], random_indices_order[start:end], 0), jnp.max(train_set['class_numbers']) + 1), device)
        parameters, optimizer_state, loss = Step(parameters, optimizer_state, inputs = batch_inputs, targets = batch_targets)          

我的问题是:

  1. 为什么使用变量大小所需的GPU存储器多于使用JAX设备内存分析所需的更多?用于记忆的多余方法是什么,如何跟踪它以及如何预防它?
  2. 执行JAX设备内存分析时,如何同时捕获CPU和GPU内存?它仅在CPU为默认设备时才捕获CPU,尽管GPU也可用,并且也在使用中。

这是GPU设置为默认设备并存储整个数据集时GPU设备内存分析的结果(2X(2000,200 200,200,3)≈1.79GB)。批量大小减少到10。 gpu jax设备存储器批处理大小10

I'm trying to run a JAX + STAX model from Kaggle kernels on GPU but it fails due to Out Of Memory Error. I've set the XLA_PYTHON_CLIENT_PREALLOCATE to false to avoid preallocation of GPU memory and also tried setting XLA_PYTHON_CLIENT_ALLOCATOR to platform, nothing helped. The default device is set to CPU from the beginning as I do not want all the data stored on GPU. Model and batch data are sent to GPU manually. The size of the variables (model parameters, data...) souldn't be a problem as the same code runs smoothly on CPU, without OOM errors. I've also made memory profiling of the model. In order to get only GPU memory it was necessary to make another version of the code where GPU is the default device and all the data is stored there. If I ran the profiling on the original code where CPU is default I only get the profiling for CPU data. Batch size reduction to 10 was also necessary for the model to complete training. The profiling shows only the memory needed for storing the data and parameters (≈ 5.5GB), but when I check the GPU usage with other Python functions it is much larger (≈ 14.6GB, Note: when run with batch_size = 100 the memory also hits 14.6GB during the first mini batch but cannot go further).

Here is the simplified version of the code I used:

import os
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = 'false'
# os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = 'platform' # Tried this, didn't help

import jax
from jax.lib import xla_bridge
jax.config.update('jax_platform_name', 'cpu') # If not set default device = CPU then all the device arrays will be saved to GPU by default

# Set the processor to GPU if available
try: print('Available GPU Devices: ', jax.devices("gpu")); device = jax.devices("gpu")[0]; gpu_available = 1
except: device = jax.devices("cpu")[0]; gpu_available = 0

# Load data into jax device arrays of dimensions (2000, 200, 200, 3)...

InitializationFunction, ApplyFunction = stax.serial(
    Conv(out_chan = 64, filter_shape = (3, 3), strides = (1, 1), padding = 'SAME'), Relu,
    Conv(out_chan = 64, filter_shape = (3, 3), strides = (1, 1), padding = 'SAME'), Relu,
    Flatten, Dense(128), Relu, Dense(2),)

key = random.PRNGKey(2793)
output_shape, parameters = jax.device_put(InitializationFunction(rng = key, input_shape = (100, image_width, image_height, number_of_channels)), device)
optimizer = optax.adam(0.001)
optimizer_state = jax.device_put(optimizer.init(parameters), device)

def Loss(parameters, inputs, targets):
    predictions = ApplyFunction(parameters, inputs)
    loss = jnp.mean(optax.softmax_cross_entropy(predictions, targets))
    return loss

@jit
def Step(parameters, optimizer_state, inputs, targets):
    loss, gradients = value_and_grad(Loss)(parameters, inputs, targets)
    updates, optimizer_state = optimizer.update(gradients, optimizer_state, parameters)
    parameters = optax.apply_updates(parameters, updates)
    return parameters, optimizer_state, loss

epochs, batch_size = 2, 100
key, subkey = random.split(key)
keys_epochs = random.split(subkey, epochs)
    
for epoch in range(epochs):
    random_indices_order = random.permutation(keys_epochs[epoch], jnp.arange(len(train_set['images'])))

    for batch_number in range(len(train_set['images']) // batch_size):
        start = batch_number * batch_size
        end = (batch_number + 1) * batch_size
        batch_inputs = jax.device_put(jnp.take(train_set['images'], random_indices_order[start:end], 0), device)
        batch_targets = jax.device_put(OneHot(jnp.take(train_set['class_numbers'], random_indices_order[start:end], 0), jnp.max(train_set['class_numbers']) + 1), device)
        parameters, optimizer_state, loss = Step(parameters, optimizer_state, inputs = batch_inputs, targets = batch_targets)          

My questions are:

  1. Why is more GPU memory used than needed for the size of the variables and more than captured with jax device memory profiling? What is the excess of the memory used for, how to track it and how to prevent it?
  2. How to capture both CPU and GPU memory when doing jax device memory profiling? It only captures CPU when CPU is default device, although GPU is available and in use too.

Here is the result of device memory profiling for GPU when GPU is set to default device and stores the entire dataset (2x(2000, 200, 200, 3) ≈ 1.79GB). Batch size is reduced to 10.
GPU Jax Device Memory profiling for batch size 10

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文