scipy griddata 产生样本之间的 nan 值

发布于 2025-01-12 20:59:19 字数 1186 浏览 0 评论 0原文

我正在尝试根据非结构化样本插入网格点。我的样本取自 0.01 和 10(x 轴)之间以及 1e-8 和 1(y 轴)之间的对数空间。当我运行此代码时:

from scipy.interpolate import griddata

data = pd.read_csv('data.csv')

param1, param2, errors = data['param1'].values, data['param2'].values, data['error'].values

x = np.linspace(param1.min(), param1.max(), 100, endpoint=True)
y = np.linspace(param2.min(), param2.max(), 100, endpoint=True)

X, Y = np.meshgrid(x, y)

Z = griddata((param1, param2), errors, (X, Y), method='linear')

fig, ax = plt.subplots(figsize=(10, 7))

cax = ax.contourf(X, Y, Z, 25, cmap='hot')
ax.scatter(param1, param2, s=1, color='black', alpha=0.4)
ax.set(xscale='log', yscale='log')

cbar = fig.colorbar(cax)
fig.tight_layout()

我得到这个结果。白色区域显示 NaN 值。 x 轴和 y 轴均采用对数刻度: 输入图片description here

即使白色区域有样本(散点证明),griddata 也会产生 NaN。 数据中没有 NaN/infs。 我是否遗漏了某些内容,或者这只是 Scipy 中的一个错误?

data.csv

I'm trying to interpolate grid points based on unstructured samples. My samples are taken from a log space between 0.01 and 10 (x axis) and between 1e-8 and 1 (y axis). When I run this code:

from scipy.interpolate import griddata

data = pd.read_csv('data.csv')

param1, param2, errors = data['param1'].values, data['param2'].values, data['error'].values

x = np.linspace(param1.min(), param1.max(), 100, endpoint=True)
y = np.linspace(param2.min(), param2.max(), 100, endpoint=True)

X, Y = np.meshgrid(x, y)

Z = griddata((param1, param2), errors, (X, Y), method='linear')

fig, ax = plt.subplots(figsize=(10, 7))

cax = ax.contourf(X, Y, Z, 25, cmap='hot')
ax.scatter(param1, param2, s=1, color='black', alpha=0.4)
ax.set(xscale='log', yscale='log')

cbar = fig.colorbar(cax)
fig.tight_layout()

I get this result.The white area shows NaN values. Both x and y axes are in log scale:
enter image description here

Even though there are samples in the white area (scatter points prove that), griddata produces NaNs. There are no NaNs/infs in the data. Am I missing something or it's just a bug in Scipy?

data.csv

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

甲如呢乙后呢 2025-01-19 20:59:19

这是由于 XY 插值网格的线性间距和轴的对数缩放造成的。这可以通过几何(“对数”)间隔插值网格来相当容易地解决。

还可以在对数空间中进行插值; IMO 这给出了更好看的结果,但它可能无效。

这是您的图形的更粗略采样的版本,显示了插值网格点如何“聚集”在对数标度图中的右上角。这里,顶行轴显示数据有限的位置,底行是“真实”图:

在此处输入图像描述

您可以看到线性间隔样本网格的最左侧点和最底部点是(只是!)外面套价值观;这尤其糟糕,因为由于对数缩放,下一个最近的点线在视觉上很远。

这是插值网格按几何间隔排列的结果,插值也在该空间中完成。

输入图片此处描述

您可以运行下面的代码来查看其他两个变体。

from itertools import product

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
import pandas as pd

CMAP = None
# crude, to make interpolation grid visible
NX = 11
NY = 11

def plot_general(log_grid=False, log_interp=False):
    data = pd.read_csv('data.csv')

    param1, param2, errors = data['param1'].values, data['param2'].values, data['error'].values

    if log_grid:
        x = np.geomspace(param1.min(), param1.max(), NX)
        y = np.geomspace(param2.min(), param2.max(), NY)
    else:
        x = np.linspace(param1.min(), param1.max(), NX)
        y = np.linspace(param2.min(), param2.max(), NY)

    X, Y = np.meshgrid(x, y)

    if log_interp:
        Z = griddata((np.log10(param1), np.log10(param2)), errors, (np.log10(X), np.log10(Y)), method='linear')
    else:
        Z = griddata((param1, param2), errors, (X, Y), method='linear')

    fZ = np.isfinite(Z)

    fig, ax = plt.subplots(2, 2)

    ax[0,0].contourf(X, Y, fZ, levels=[0.5,1.5])
    ax[0,0].scatter(param1, param2, s=1, color='black')
    ax[0,0].plot(X.flat, Y.flat, '.', color='blue')

    ax[0,1].contourf(X, Y, fZ, levels=[0.5,1.5])
    ax[0,1].scatter(param1, param2, s=1, color='black')
    ax[0,1].plot(X.flat, Y.flat, '.', color='blue')
    ax[0,1].set(xscale='log', yscale='log')

    ax[1,0].contourf(X, Y, Z, levels=25, cmap=CMAP)
    ax[1,0].scatter(param1, param2, s=1, color='black')
    ax[1,0].plot(X.flat, Y.flat, '.', color='blue')
    ax[1,1].contourf(X, Y, Z, levels=25, cmap=CMAP)
    ax[1,1].scatter(param1, param2, s=1, color='black')
    ax[1,1].set(xscale='log', yscale='log')
    ax[1,1].plot(X.flat, Y.flat, '.', color='blue')

    fig.suptitle(f'{log_grid=}, {log_interp=}')
    fig.tight_layout()
    return fig

plt.close('all')

for log_grid, log_interp in product([False, True],
                                    [False, True]):
    fig = plot_general(log_grid, log_interp)
    #if you want to save results:
    #fig.savefig(f'log_grid{log_grid}-log_interp{log_interp}.png')

This is due to the linear spacing of your X-Y interpolation grid, and logarithmic scaling of axes. This is fairly easily fixed by geometrically ("logarithmically") spacing the interpolation grid.

One can also interpolate in log-space; IMO this gives a better looking result, but it may not be valid.

Here's a more-coarsely-sampled version of your figure, showing how the interpolation grid points are "clumped up" to the top right in the log-scaled plot. Here the top row of axes is shows where the data is finite, the bottom row is the "real" plot:

enter image description here

You can see the extreme left and extreme bottom points of a linearly-spaced sample grid are (just!) outside set of values; this is especially bad because the next closest lines of points are visually far away due to the logarithmic scaling.

Here's a result with the interpolation grid geometrically spaced, and interpolation also done in that space.

enter image description here

You can run the code below to view the other two variants.

from itertools import product

import numpy as np
import matplotlib.pyplot as plt
from scipy.interpolate import griddata
import pandas as pd

CMAP = None
# crude, to make interpolation grid visible
NX = 11
NY = 11

def plot_general(log_grid=False, log_interp=False):
    data = pd.read_csv('data.csv')

    param1, param2, errors = data['param1'].values, data['param2'].values, data['error'].values

    if log_grid:
        x = np.geomspace(param1.min(), param1.max(), NX)
        y = np.geomspace(param2.min(), param2.max(), NY)
    else:
        x = np.linspace(param1.min(), param1.max(), NX)
        y = np.linspace(param2.min(), param2.max(), NY)

    X, Y = np.meshgrid(x, y)

    if log_interp:
        Z = griddata((np.log10(param1), np.log10(param2)), errors, (np.log10(X), np.log10(Y)), method='linear')
    else:
        Z = griddata((param1, param2), errors, (X, Y), method='linear')

    fZ = np.isfinite(Z)

    fig, ax = plt.subplots(2, 2)

    ax[0,0].contourf(X, Y, fZ, levels=[0.5,1.5])
    ax[0,0].scatter(param1, param2, s=1, color='black')
    ax[0,0].plot(X.flat, Y.flat, '.', color='blue')

    ax[0,1].contourf(X, Y, fZ, levels=[0.5,1.5])
    ax[0,1].scatter(param1, param2, s=1, color='black')
    ax[0,1].plot(X.flat, Y.flat, '.', color='blue')
    ax[0,1].set(xscale='log', yscale='log')

    ax[1,0].contourf(X, Y, Z, levels=25, cmap=CMAP)
    ax[1,0].scatter(param1, param2, s=1, color='black')
    ax[1,0].plot(X.flat, Y.flat, '.', color='blue')
    ax[1,1].contourf(X, Y, Z, levels=25, cmap=CMAP)
    ax[1,1].scatter(param1, param2, s=1, color='black')
    ax[1,1].set(xscale='log', yscale='log')
    ax[1,1].plot(X.flat, Y.flat, '.', color='blue')

    fig.suptitle(f'{log_grid=}, {log_interp=}')
    fig.tight_layout()
    return fig

plt.close('all')

for log_grid, log_interp in product([False, True],
                                    [False, True]):
    fig = plot_general(log_grid, log_interp)
    #if you want to save results:
    #fig.savefig(f'log_grid{log_grid}-log_interp{log_interp}.png')
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文