使 pyplot 比 gnuplot 更快

发布于 2024-12-07 12:15:31 字数 861 浏览 0 评论 0原文

我最近决定尝试一下 matplotlib.pyplot,同时我已经使用 gnuplot 进行科学数据绘图多年。我首先简单地读取数据文件并绘制两列,就像 gnuplot 使用 plot 'datafile' u 1:2 所做的那样。 我舒适的要求是:

  • 跳过以 # 开头的行并跳过空行。
  • 允许实际数字之间和之前的任意数量的空格
  • 允许任意数量的列
  • 速度快

现在,以下代码是我解决该问题的方法。然而,与 gnuplot 相比,它确实没有那么快。这有点奇怪,因为我读到 py(plot/thon) 相对于 gnuplot 的一大优势是它的速度。

import numpy as np
import matplotlib.pyplot as plt
import sys

datafile = sys.argv[1]
data = []
for line in open(datafile,'r'):
    if line and line[0] != '#':
        cols = filter(lambda x: x!='',line.split(' '))
        for index,col in enumerate(cols):
            if len(data) <= index:
                data.append([])
            data[index].append(float(col))

plt.plot(data[0],data[1])
plt.show()

我该怎么做才能使数据读取更快?我快速浏览了 csv 模块,但它对于文件中的注释似乎不太灵活,并且仍然需要迭代文件中的所有行。

I recently decided to give matplotlib.pyplot a try, while having used gnuplot for scientific data plotting for years. I started out with simply reading a data file and plot two columns, like gnuplot would do with plot 'datafile' u 1:2.
The requirements for my comfort are:

  • Skip lines beginning with a # and skip empty lines.
  • Allow arbitrary numbers of spaces between and before the actual numbers
  • allow arbitrary numbers of columns
  • be fast

Now, the following code is my solution for the problem. However, compared to gnuplot, it really is not as fast. This is a bit odd, since I read that one big advantage of py(plot/thon) over gnuplot is it's speed.

import numpy as np
import matplotlib.pyplot as plt
import sys

datafile = sys.argv[1]
data = []
for line in open(datafile,'r'):
    if line and line[0] != '#':
        cols = filter(lambda x: x!='',line.split(' '))
        for index,col in enumerate(cols):
            if len(data) <= index:
                data.append([])
            data[index].append(float(col))

plt.plot(data[0],data[1])
plt.show()

What would I do to make the data reading faster? I had a quick look at the csv module, but it didn't seem to be very flexible with comments in files and one still needs to iterate over all lines in the file.

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(2

本宫微胖 2024-12-14 12:15:31

由于您已经安装了 matplotlib,因此还必须安装 numpy。 numpy.genfromtxt 满足您的所有要求并且应该比在 Python 循环中自己解析文件快得多:

import numpy as np
import matplotlib.pyplot as plt

import textwrap
fname='/tmp/tmp.dat'
with open(fname,'w') as f:
    f.write(textwrap.dedent('''\
        id col1 col2 col3
        2010 1 2 3 4
        # Foo

        2011 5 6 7 8
        # Bar        
        # Baz
        2012 8 7 6 5
        '''))

data = np.genfromtxt(fname, 
                     comments='#',    # skip comment lines
                     dtype = None,    # guess dtype of each column
                     names=True)      # use first line as column names
print(data)
plt.plot(data['id'],data['col2'])
plt.show()

Since you have matplotlib installed, you must also have numpy installed. numpy.genfromtxt meets all your requirements and should be much faster than parsing the file yourself in a Python loop:

import numpy as np
import matplotlib.pyplot as plt

import textwrap
fname='/tmp/tmp.dat'
with open(fname,'w') as f:
    f.write(textwrap.dedent('''\
        id col1 col2 col3
        2010 1 2 3 4
        # Foo

        2011 5 6 7 8
        # Bar        
        # Baz
        2012 8 7 6 5
        '''))

data = np.genfromtxt(fname, 
                     comments='#',    # skip comment lines
                     dtype = None,    # guess dtype of each column
                     names=True)      # use first line as column names
print(data)
plt.plot(data['id'],data['col2'])
plt.show()
一曲琵琶半遮面シ 2024-12-14 12:15:31

您确实需要分析您的代码以找出瓶颈所在。

以下是一些微观优化:

import numpy as np
import matplotlib.pyplot as plt
import sys

datafile = sys.argv[1]
data = []
# use with to auto-close the file
for line in open(datafile,'r'):
    # line will never be False because it will always have at least a newline
    # maybe you mean line.rstrip()?
    # you can also try line.startswith('#') instead of line[0] != '#'
    if line and line[0] != '#':
        # not sure of the point of this
        # just line.split() will allow any number of spaces
        # if you do need it, use a list comprehension
        # cols = [col for col in line.split(' ') if col]
        # filter on a user-defined function is slow
        cols = filter(lambda x: x!='',line.split(' '))

        for index,col in enumerate(cols):
            # just made data a collections.defaultdict
            # initialized as data = defaultdict(list)
            # and you can skip this 'if' statement entirely
            if len(data) <= index:
                data.append([])
            data[index].append(float(col))

plt.plot(data[0],data[1])
plt.show()

您也许可以执行以下操作:

with open(datafile) as f:
    lines = (line.split() for line in f 
                 if line.rstrip() and not line.startswith('#'))
    data = zip(*[float(col) for col in line for line in lines])

这将为您提供一个由 tuple 组成的 list,而不是 int -keyed listdict,但在其他方面看起来相同。它可以作为一行完成,但我将其分开以使其更易于阅读。

You really need to profile your code to find out what the bottleneck is.

Here are some micro-optimizations:

import numpy as np
import matplotlib.pyplot as plt
import sys

datafile = sys.argv[1]
data = []
# use with to auto-close the file
for line in open(datafile,'r'):
    # line will never be False because it will always have at least a newline
    # maybe you mean line.rstrip()?
    # you can also try line.startswith('#') instead of line[0] != '#'
    if line and line[0] != '#':
        # not sure of the point of this
        # just line.split() will allow any number of spaces
        # if you do need it, use a list comprehension
        # cols = [col for col in line.split(' ') if col]
        # filter on a user-defined function is slow
        cols = filter(lambda x: x!='',line.split(' '))

        for index,col in enumerate(cols):
            # just made data a collections.defaultdict
            # initialized as data = defaultdict(list)
            # and you can skip this 'if' statement entirely
            if len(data) <= index:
                data.append([])
            data[index].append(float(col))

plt.plot(data[0],data[1])
plt.show()

You may be able to do something like:

with open(datafile) as f:
    lines = (line.split() for line in f 
                 if line.rstrip() and not line.startswith('#'))
    data = zip(*[float(col) for col in line for line in lines])

Which will give you a list of tuples instead of an int-keyed dict of lists, but otherwise appears identical. It can be done as a one-liner but I split it up to make it a little easier to read.

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文