将绘图保存到 numpy 数组

发布于 2024-12-10 21:34:32 字数 87 浏览 0 评论 0原文

在 Python 和 Matplotlib 中,可以轻松地将绘图显示为弹出窗口或将绘图保存为 PNG 文件。如何将绘图保存为 RGB 格式的 numpy 数组?

In Python and Matplotlib, it is easy to either display the plot as a popup window or save the plot as a PNG file. How can I instead save the plot to a numpy array in RGB format?

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(9

始终不够 2024-12-17 21:34:32

当您需要与保存的绘图进行像素到像素的比较时,这对于单元测试等来说是一个方便的技巧。

一种方法是使用 fig.canvas.tostring_rgb,然后使用 numpy.fromstring 以及适当的数据类型。还有其他方法,但这是我倾向于使用的一种。

例如

import matplotlib.pyplot as plt
import numpy as np

# Make a random plot...
fig = plt.figure()
fig.add_subplot(111)

# If we haven't already shown or saved the plot, then we need to
# draw the figure first...
fig.canvas.draw()

# Now we can save it to a numpy array.
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))

This is a handy trick for unit tests and the like, when you need to do a pixel-to-pixel comparison with a saved plot.

One way is to use fig.canvas.tostring_rgb and then numpy.fromstring with the approriate dtype. There are other ways as well, but this is the one I tend to use.

E.g.

import matplotlib.pyplot as plt
import numpy as np

# Make a random plot...
fig = plt.figure()
fig.add_subplot(111)

# If we haven't already shown or saved the plot, then we need to
# draw the figure first...
fig.canvas.draw()

# Now we can save it to a numpy array.
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
只是在用心讲痛 2024-12-17 21:34:32

@JUN_NETWORKS 的答案有一个更简单的选项。您可以使用其他格式,例如 rawrgba 并跳过 cv2,而不是将图形保存为 png解码步骤。

换句话说,实际的绘图到 numpy 的转换归结为:

io_buf = io.BytesIO()
fig.savefig(io_buf, format='raw', dpi=DPI)
io_buf.seek(0)
img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
                     newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
io_buf.close()

希望,这有帮助。

There is a bit simpler option for @JUN_NETWORKS's answer. Instead of saving the figure in png, one can use other format, like raw or rgba and skip the cv2 decoding step.

In other words the actual plot-to-numpy conversion boils down to:

io_buf = io.BytesIO()
fig.savefig(io_buf, format='raw', dpi=DPI)
io_buf.seek(0)
img_arr = np.reshape(np.frombuffer(io_buf.getvalue(), dtype=np.uint8),
                     newshape=(int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1))
io_buf.close()

Hope, this helps.

一生独一 2024-12-17 21:34:32

有人提出了这样的方法当然

np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')

,这段代码可以工作。但是,输出 numpy 数组图像的分辨率很低。

我的提案代码是这样的。

import io
import cv2
import numpy as np
import matplotlib.pyplot as plt

# plot sin wave
fig = plt.figure()
ax = fig.add_subplot(111)

x = np.linspace(-np.pi, np.pi)

ax.set_xlim(-np.pi, np.pi)
ax.set_xlabel("x")
ax.set_ylabel("y")

ax.plot(x, np.sin(x), label="sin")

ax.legend()
ax.set_title("sin(x)")


# define a function which returns an image as numpy array from figure
def get_img_from_fig(fig, dpi=180):
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    return img

# you can get a high-resolution image as numpy array!!
plot_img_np = get_img_from_fig(fig)

这段代码运行良好。
如果您在 dpi 参数上设置较大的数字,则可以获得 numpy 数组形式的高分辨率图像。

Some people propose a method which is like this

np.fromstring(fig.canvas.tostring_rgb(), dtype=np.uint8, sep='')

Ofcourse, this code work. But, output numpy array image is so low resolution.

My proposal code is this.

import io
import cv2
import numpy as np
import matplotlib.pyplot as plt

# plot sin wave
fig = plt.figure()
ax = fig.add_subplot(111)

x = np.linspace(-np.pi, np.pi)

ax.set_xlim(-np.pi, np.pi)
ax.set_xlabel("x")
ax.set_ylabel("y")

ax.plot(x, np.sin(x), label="sin")

ax.legend()
ax.set_title("sin(x)")


# define a function which returns an image as numpy array from figure
def get_img_from_fig(fig, dpi=180):
    buf = io.BytesIO()
    fig.savefig(buf, format="png", dpi=dpi)
    buf.seek(0)
    img_arr = np.frombuffer(buf.getvalue(), dtype=np.uint8)
    buf.close()
    img = cv2.imdecode(img_arr, 1)
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    return img

# you can get a high-resolution image as numpy array!!
plot_img_np = get_img_from_fig(fig)

This code works well.
You can get a high-resolution image as a numpy array if you set a large number on the dpi argument.

凉城 2024-12-17 21:34:32

是时候对您的解决方案进行基准测试了。

import io
import matplotlib
matplotlib.use('agg')  # turn off interactive backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(range(10))


def plot1():
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))


def plot2():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='png')
        buff.seek(0)
        im = plt.imread(buff)


def plot3():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='raw')
        buff.seek(0)
        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))
>>> %timeit plot1()
34 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot2()
50.2 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot3()
16.4 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

在这种情况下,IO 原始缓冲区是将 matplotlib 图转换为 numpy 数组最快的。

附加说明:

  • 如果您无权访问该图形,您始终可以从轴中提取它:

    fig = ax.figure

  • 如果您需要 channel x height x 中的数组宽度格式,做

    im = im.transpose((2, 0, 1))

Time to benchmark your solutions.

import io
import matplotlib
matplotlib.use('agg')  # turn off interactive backend
import matplotlib.pyplot as plt
import numpy as np

fig, ax = plt.subplots()
ax.plot(range(10))


def plot1():
    fig.canvas.draw()
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))


def plot2():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='png')
        buff.seek(0)
        im = plt.imread(buff)


def plot3():
    with io.BytesIO() as buff:
        fig.savefig(buff, format='raw')
        buff.seek(0)
        data = np.frombuffer(buff.getvalue(), dtype=np.uint8)
    w, h = fig.canvas.get_width_height()
    im = data.reshape((int(h), int(w), -1))
>>> %timeit plot1()
34 ms ± 4.16 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot2()
50.2 ms ± 234 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
>>> %timeit plot3()
16.4 ms ± 36 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)

Under this scenario, IO raw buffers are the fastest to convert a matplotlib figure to a numpy array.

Additional remarks:

  • if you don't have an access to the figure, you can always extract it from the axes:

    fig = ax.figure

  • if you need the array in the channel x height x width format, do

    im = im.transpose((2, 0, 1)).

一袭白衣梦中忆 2024-12-17 21:34:32

如果有人想要一个即插即用的解决方案,而不修改任何先前的代码(获取对 pyplot 图和所有内容的引用),下面的内容对我有用。只需在所有 pyplot 语句之后添加此内容,即在 pyplot.show() 之前

canvas = pyplot.gca().figure.canvas
canvas.draw()
data = numpy.frombuffer(canvas.tostring_rgb(), dtype=numpy.uint8)
image = data.reshape(canvas.get_width_height()[::-1] + (3,))

In case somebody wants a plug and play solution, without modifying any prior code (getting the reference to pyplot figure and all), the below worked for me. Just add this after all pyplot statements i.e. just before pyplot.show()

canvas = pyplot.gca().figure.canvas
canvas.draw()
data = numpy.frombuffer(canvas.tostring_rgb(), dtype=numpy.uint8)
image = data.reshape(canvas.get_width_height()[::-1] + (3,))
小情绪 2024-12-17 21:34:32

MoviePy 使得将图形转换为 numpy 数组非常简单。它有一个名为 mplfig_to_npimage() 的内置函数。你可以这样使用它:

from moviepy.video.io.bindings import mplfig_to_npimage
import matplotlib.pyplot as plt

fig = plt.figure()  # make a figure
numpy_fig = mplfig_to_npimage(fig)  # convert it to a numpy array

MoviePy makes converting a figure to a numpy array quite simple. It has a built-in function for this called mplfig_to_npimage(). You can use it like this:

from moviepy.video.io.bindings import mplfig_to_npimage
import matplotlib.pyplot as plt

fig = plt.figure()  # make a figure
numpy_fig = mplfig_to_npimage(fig)  # convert it to a numpy array
喵星人汪星人 2024-12-17 21:34:32

正如 Joe Kington 所指出的,一种方法是在画布上绘图,将画布转换为字节字符串,然后将其重新整形为正确的形状。

import matplotlib.pyplot as plt
import numpy as np
import math

plt.switch_backend('Agg')


def canvas2rgb_array(canvas):
    """Adapted from: https://stackoverflow.com/a/21940031/959926"""
    canvas.draw()
    buf = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    ncols, nrows = canvas.get_width_height()
    scale = round(math.sqrt(buf.size / 3 / nrows / ncols))
    return buf.reshape(scale * nrows, scale * ncols, 3)


# Make a simple plot to test with
t = np.arange(0.0, 2.0, 0.01)
s = 1 + np.sin(2 * np.pi * t)
fig, ax = plt.subplots()
ax.plot(t, s)

# Extract the plot as an array
plt_array = canvas2rgb_array(fig.canvas)
print(plt_array.shape)

但是,由于 canvas.get_width_height() 返回显示坐标中的宽度和高度,因此有时会在此答案中解决缩放问题。

As Joe Kington has pointed out, one way is to draw on the canvas, convert the canvas to a byte string and then reshape it into the correct shape.

import matplotlib.pyplot as plt
import numpy as np
import math

plt.switch_backend('Agg')


def canvas2rgb_array(canvas):
    """Adapted from: https://stackoverflow.com/a/21940031/959926"""
    canvas.draw()
    buf = np.frombuffer(canvas.tostring_rgb(), dtype=np.uint8)
    ncols, nrows = canvas.get_width_height()
    scale = round(math.sqrt(buf.size / 3 / nrows / ncols))
    return buf.reshape(scale * nrows, scale * ncols, 3)


# Make a simple plot to test with
t = np.arange(0.0, 2.0, 0.01)
s = 1 + np.sin(2 * np.pi * t)
fig, ax = plt.subplots()
ax.plot(t, s)

# Extract the plot as an array
plt_array = canvas2rgb_array(fig.canvas)
print(plt_array.shape)

However as canvas.get_width_height() returns width and height in display coordinates, there are sometimes scaling issues that are resolved in this answer.

夜雨飘雪 2024-12-17 21:34:32

Jonan Gueorguiev 答案的清理版本:

with io.BytesIO() as io_buf:
  fig.savefig(io_buf, format='raw', dpi=dpi)
  image = np.frombuffer(io_buf.getvalue(), np.uint8).reshape(
      int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1)

Cleaned up version of the answer by Jonan Gueorguiev:

with io.BytesIO() as io_buf:
  fig.savefig(io_buf, format='raw', dpi=dpi)
  image = np.frombuffer(io_buf.getvalue(), np.uint8).reshape(
      int(fig.bbox.bounds[3]), int(fig.bbox.bounds[2]), -1)
﹂绝世的画 2024-12-17 21:34:32
import numpy as np 
import cv2
import time
import justpyplot as jplt

xs, ys = [], []
while(cv2.waitKey(1) != 27):
    xt = time.perf_counter() - t0
    yx = np.sin(xt)
    xs.append(xt)
    ys.append(yx)
    
    frame = np.full((500,470,3), (255,255,255), dtype=np.uint8)
    
    vals = np.array(ys)

    plotted_in_array = jplt.just_plot(frame, vals,title="sin() from Clock")
    
    cv2.imshow('np array plot', plotted_in_array)

所有 matplotlib 方法的问题是,即使您执行 plt.ioff() 或返回图形,matplotlib 仍然可以渲染和显示绘图,即使您成功,但它在不同平台上的行为不同(因为 matplotlib 将其委托给后端取决于操作系统) - 绘制 numpy 数组时性能会受到影响。
我测量了之前建议的所有 matplotlib 方法,结果需要几毫秒,最常见的是几十毫秒,有时甚至更多毫秒。

我找不到一个简单的库可以做到这一点,不得不自己编写这个东西。完全矢量化的 numpy(不是单个循环)中的 numpy 绘图,用于所有部分,例如散点、连接、轴、网格,包括点的大小和厚度,并且在微秒内完成

https://github.com/bedbad/justpyplot

import numpy as np 
import cv2
import time
import justpyplot as jplt

xs, ys = [], []
while(cv2.waitKey(1) != 27):
    xt = time.perf_counter() - t0
    yx = np.sin(xt)
    xs.append(xt)
    ys.append(yx)
    
    frame = np.full((500,470,3), (255,255,255), dtype=np.uint8)
    
    vals = np.array(ys)

    plotted_in_array = jplt.just_plot(frame, vals,title="sin() from Clock")
    
    cv2.imshow('np array plot', plotted_in_array)

The issues with all the matplotlib approaches is that matplotlib can still render and display plot even if you do plt.ioff() or return the figure and even if you do succeed while it behaves differently on a different platform(because matplotlib delegates it to backend depending on os) - you get a performance hit for getting plotted numpy array.
I measured all previosly suggested matplotlib approaches and it rakes in milliseconds, most often dozens, sometimes even more milliseconds.

I couldn't find a simple library that just does it, had to write the thing myself. A plot to numpy in fully vectorized numpy(not a single loop) for all the parts such as scatter, connected, axis, grid, including size of the points and thickness and it does it in microseconds

https://github.com/bedbad/justpyplot

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文