Elegant way to test Python ASTs for equality (not reference or object identity)

Posted 2024-09-11 00:37:54

Not sure of the terminology here, but this would be the difference between eq? and equal? in Scheme, or between == and strncmp on C strings: in each case, for two distinct strings with the same content, the first says they are not equal, while the second says they are (equal? returns #t, strncmp returns 0).

I'm looking for the latter operation, for Python's ASTs.

Right now, I'm doing this:

import ast
def AST_eq(a, b):
    return ast.dump(a) == ast.dump(b)

which apparently works but feels like a disaster waiting to happen. Anyone know of a better way?
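To make it concrete what the dump-based check does and doesn't see, here is a minimal sketch assuming Python 3 (the sample sources are made up). By default `ast.dump` leaves out position attributes such as `lineno` unless you pass `include_attributes=True`, so the same statement at a different line still compares equal:

```python
import ast

def AST_eq(a: ast.AST, b: ast.AST) -> bool:
    # ast.dump omits lineno/col_offset by default (include_attributes=False)
    return ast.dump(a) == ast.dump(b)

a = ast.parse("x = 1")
b = ast.parse("\n\nx = 1")  # same statement, two lines further down

print(AST_eq(a, b))  # True
# With positions included, the dumps differ (lineno=1 vs lineno=3)
print(ast.dump(a, include_attributes=True) == ast.dump(b, include_attributes=True))  # False
```

One real fragility: the exact `ast.dump` output format has changed between Python releases, so dumps produced by different interpreter versions should not be compared with each other.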

Edit: unfortunately, when I compare the two ASTs' __dict__s, that comparison falls back to the individual elements' __eq__ methods. ASTs are implemented as trees of other ASTs, and their __eq__ apparently checks for reference identity. So neither a straight == nor the solution in Thomas's link works. (Besides, I don't want to subclass every AST node type just to add a custom __eq__.)
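You can get structural equality without subclassing any node type by using a free function that recurses over each node's declared fields. Below is a minimal sketch assuming Python 3, where `ast_equal` is a hypothetical name. `ast.iter_fields` yields only the names listed in a node's `_fields`, so position attributes (kept separately in `_attributes`) are skipped without a hardcoded list. `ctx` is still compared, by node type:

```python
import ast
from typing import Any

def ast_equal(a: Any, b: Any) -> bool:
    """Structural equality: same node types and equal fields; positions are ignored."""
    if type(a) is not type(b):
        return False
    if isinstance(a, ast.AST):
        # iter_fields walks _fields only; lineno/col_offset live in _attributes
        fields_a = dict(ast.iter_fields(a))
        fields_b = dict(ast.iter_fields(b))
        if fields_a.keys() != fields_b.keys():
            return False
        return all(ast_equal(fields_a[k], fields_b[k]) for k in fields_a)
    if isinstance(a, list):
        return len(a) == len(b) and all(ast_equal(x, y) for x, y in zip(a, b))
    return a == b

print(ast_equal(ast.parse("x = 1"), ast.parse("\n\nx  =  1")))  # True
print(ast_equal(ast.parse("f(a)"), ast.parse("f(a, b)")))       # False
```

Checking `type(a) is not type(b)` first also keeps constants like `1` and `True` (or `1` and `1.0`) apart, which a plain `==` on the values would treat as equal.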

对岸观火 2024-09-18 00:37:54

I ran into the same problem. My approach was to first reduce the AST to a simpler representation (a tree of dicts):

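import ast

# Source-position attributes (end_lineno/end_col_offset exist on Python 3.8+)
# plus the Load/Store/Del context
IGNORED = ('lineno', 'col_offset', 'end_lineno', 'end_col_offset', 'ctx')

def simplify(node):
    if isinstance(node, ast.AST):
        res = vars(node).copy()
        for k in IGNORED:
            res.pop(k, None)
        for k, v in list(res.items()):
            res[k] = simplify(v)
        res['__type__'] = type(node).__name__
        return res
    elif isinstance(node, list):
        return [simplify(n) for n in node]
    else:
        return node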

and then you can just compare these representations:

with open("/usr/lib/python2.7/ast.py") as f:
    data = f.read()
a1 = ast.parse(data)
a2 = ast.parse(data)
print(simplify(a1) == simplify(a2))

which prints True.
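The dict tree from `simplify` isn't hashable, so it can't go into a set or serve as a dict key. If you need that (for example, to deduplicate repeated expressions), a variant along the same lines can emit nested tuples instead. This is a sketch under our own assumptions (Python 3; `ast_key` and `IGNORED` are hypothetical names), not part of the original answer. Leaf values are tagged with their type name so that `1` and `True` don't collapse into one key:

```python
import ast
from typing import Any

# Source positions plus the Load/Store/Del context
IGNORED = {"lineno", "col_offset", "end_lineno", "end_col_offset", "ctx"}

def ast_key(node: Any) -> Any:
    """Turn an AST into nested tuples that compare and hash structurally."""
    if isinstance(node, ast.AST):
        fields = sorted(
            ((k, ast_key(v)) for k, v in vars(node).items() if k not in IGNORED),
            key=lambda kv: kv[0],
        )
        return (type(node).__name__, tuple(fields))
    if isinstance(node, list):
        return tuple(ast_key(n) for n in node)
    # Tag leaves with their type so 1, 1.0 and True get different keys
    return (type(node).__name__, node)

sources = ["a + b", "a+b", "(a) + (b)", "b + a"]
exprs = [ast.parse(src, mode="eval").body for src in sources]
unique = {ast_key(e) for e in exprs}
print(len(unique))  # 2
```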

EDIT

I just realized there's no need to build the dicts at all, so you can simply do:

import ast

def compare_ast(node1, node2):
    if type(node1) is not type(node2):
        return False
    if isinstance(node1, ast.AST):
        for k, v in vars(node1).items():
            # skip source positions and the Load/Store/Del context
            if k in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset', 'ctx'):
                continue
            if not compare_ast(v, getattr(node2, k, None)):
                return False
        return True
    elif isinstance(node1, list):
        # zip (like izip) stops at the shorter list, so compare lengths first
        return len(node1) == len(node2) and all(
            compare_ast(n1, n2) for n1, n2 in zip(node1, node2))
    else:
        return node1 == node2
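A caveat about the list branch: Python 2's `itertools.izip`, like Python 3's built-in `zip`, stops at the end of the shorter input. A function that only checks the zipped pairs therefore treats a list and its prefix as equal. For ASTs, that is the difference between `f(a)` and `f(a, b)` comparing equal or not, because call arguments are stored in a list. Here is a tiny self-contained illustration (both helper names are hypothetical):

```python
def lists_equal_naive(xs, ys):
    # zip() silently drops the extra elements of the longer list
    return all(x == y for x, y in zip(xs, ys))

def lists_equal(xs, ys):
    # Checking the lengths first closes that gap
    return len(xs) == len(ys) and all(x == y for x, y in zip(xs, ys))

print(lists_equal_naive([1], [1, 2]))  # True, although the lists differ
print(lists_equal([1], [1, 2]))        # False
```

On Python 3.10+, `zip(xs, ys, strict=True)` raises `ValueError` on a length mismatch, if you'd rather get an exception.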
阿楠 2024-09-18 00:37:54

I modified @Yorik.sar's answer for Python 3.9+:

import ast
from typing import Any


def compare_ast(node1: Any, node2: Any) -> bool:
    # Arguments may be nodes, lists, or plain field values (str, int, None, ...)
    if type(node1) is not type(node2):
        return False

    if isinstance(node1, ast.AST):
        for k, v in vars(node1).items():
            if k in {"lineno", "end_lineno", "col_offset", "end_col_offset", "ctx"}:
                continue
            if not compare_ast(v, getattr(node2, k)):
                return False
        return True

    elif isinstance(node1, list) and isinstance(node2, list):
        # Compare lengths explicitly: zip_longest pads with None, which would
        # match a trailing None element (lists such as Dict.keys can hold None)
        return len(node1) == len(node2) and all(
            compare_ast(n1, n2) for n1, n2 in zip(node1, node2))
    else:
        return node1 == node2
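The hardcoded set in this answer has to keep up with the interpreter: `end_lineno` and `end_col_offset` only appeared in Python 3.8. Each node class lists its position attributes in an `_attributes` tuple, separate from its `_fields`. It is underscore-prefixed, so treat it as an implementation detail. One option is to derive the set per node. A sketch, assuming Python 3.8+, with `ignored_keys` as a hypothetical helper:

```python
import ast

def ignored_keys(node: ast.AST) -> set:
    # _attributes holds the position info for this node class; ctx is added
    # to mirror the hardcoded set in the answer
    return set(node._attributes) | {"ctx"}

name = ast.parse("x", mode="eval").body
print(name._fields)                # ('id', 'ctx')
print(sorted(ignored_keys(name)))  # ['col_offset', 'ctx', 'end_col_offset', 'end_lineno', 'lineno']
```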
隐诗 2024-09-18 00:37:54

The following works with Python 2 or 3 and is faster than using itertools:

EDIT: WARNING:

Apparently this code can hang in some (weird) situations. As a result, I cannot recommend it.

import ast

def compare_ast(node1, node2):
    if type(node1) != type(node2):
        return False
    elif isinstance(node1, ast.AST):
        for kind, var in vars(node1).items():
            # end_lineno/end_col_offset only exist on Python 3.8+
            if kind not in ('lineno', 'col_offset', 'end_lineno', 'end_col_offset', 'ctx'):
                var2 = vars(node2).get(kind)
                if not compare_ast(var, var2):
                    return False
        return True
    elif isinstance(node1, list):
        if len(node1) != len(node2):
            return False
        for n1, n2 in zip(node1, node2):
            if not compare_ast(n1, n2):
                return False
        return True
    else:
        return node1 == node2
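The author doesn't say what triggers the hang, and we won't guess. A separate, well-known limit of recursive comparisons like these is Python's recursion limit: a very deeply nested tree, such as a long chain of `+`, can raise `RecursionError`. Below is a sketch of an iterative version with an explicit stack. It assumes Python 3 and the same ignore list; `compare_ast_iter` and `chain` are hypothetical names:

```python
import ast
from typing import Any

IGNORED = {"lineno", "col_offset", "end_lineno", "end_col_offset", "ctx"}

def compare_ast_iter(node1: Any, node2: Any) -> bool:
    """Structural AST equality using an explicit stack instead of recursion."""
    stack = [(node1, node2)]
    while stack:
        a, b = stack.pop()
        if type(a) is not type(b):
            return False
        if isinstance(a, ast.AST):
            va, vb = vars(a), vars(b)
            keys = va.keys() - IGNORED
            if keys != vb.keys() - IGNORED:
                return False
            stack.extend((va[k], vb[k]) for k in keys)
        elif isinstance(a, list):
            if len(a) != len(b):
                return False
            stack.extend(zip(a, b))
        elif a != b:
            return False
    return True

def chain(n: int) -> ast.expr:
    # Build 1 + 1 + ... + 1 with n additions, nested n levels deep
    node = ast.Constant(value=1)
    for _ in range(n):
        node = ast.BinOp(left=node, op=ast.Add(), right=ast.Constant(value=1))
    return node

print(compare_ast_iter(chain(3000), chain(3000)))  # True
print(compare_ast_iter(chain(3000), chain(2999)))  # False
```

Nodes are visited in a different order than in the recursive versions. That only changes which mismatch is found first, not the result.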
舟遥客 2024-09-18 00:37:54

In Python, object identity is compared with the is operator (which, unlike ==, cannot be overloaded). A class that implements __eq__ sensibly makes == compare equality rather than identity; only a class that defines no __eq__ at all falls back to identity. The built-in str class certainly compares contents.
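A small illustration of the distinction in plain Python 3 (the class name `Plain` is made up). A class that doesn't define `__eq__` inherits identity comparison, which is the behavior the question reports for AST nodes:

```python
a = "".join(["ab", "c"])  # built at runtime, a separate object from the literal
b = "abc"

print(a == b)  # True: str.__eq__ compares contents
print(a is b)  # False in CPython: two distinct objects

class Plain:
    pass  # defines no __eq__, so == falls back to identity

p, q = Plain(), Plain()
print(p == q)  # False
print(p == p)  # True
```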

There may be another problem with your implementation, though: because dump produces very precise information (suitable for debugging), two ASTs that differ only in, e.g., a variable's name will compare as unequal. This may or may not be what you want.
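If differently named variables should count as equal, you can canonicalize names before comparing. A rough sketch, assuming Python 3.8+, with `_Canonicalize` and `dump_ignoring_names` as hypothetical names. It only rewrites `ast.Name` nodes, so function parameters, attribute names and imports keep their original names. Builtins such as `print` are renamed along with everything else:

```python
import ast
import copy

class _Canonicalize(ast.NodeTransformer):
    """Rename each distinct Name id to v0, v1, ... in order of first appearance."""

    def __init__(self) -> None:
        self.names: dict = {}

    def visit_Name(self, node: ast.Name) -> ast.AST:
        if node.id not in self.names:
            self.names[node.id] = f"v{len(self.names)}"
        new = ast.Name(id=self.names[node.id], ctx=node.ctx)
        return ast.copy_location(new, node)

def dump_ignoring_names(tree: ast.AST) -> str:
    # Work on a copy, since NodeTransformer modifies the tree in place
    return ast.dump(_Canonicalize().visit(copy.deepcopy(tree)))

same = dump_ignoring_names(ast.parse("a = 1\nb = a")) == dump_ignoring_names(ast.parse("x = 1\ny = x"))
print(same)  # True
```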
