如何为“dataclass”定义一个省略默认值字段的“__str__”？

发布于 2025-01-11 09:33:17 字数 261 浏览 2 评论 0 原文

给定一个 dataclass 实例，我希望 print() 或 str() 只列出值不等于默认值的字段。当 dataclass 字段很多、而只有少数字段被修改时，这非常有用。

@dataclasses.dataclass
class X:
  """Example dataclass in which every field has a default.

  Without a custom ``__str__``, ``str()``/``print()`` fall back to the
  dataclass-generated ``__repr__``, which lists every field, changed or not.
  """
  a: int = 1
  b: bool = False
  c: float = 2.0

x = X(b=True)
# As defined above, this prints X(a=1, b=True, c=2.0); the goal is to show
# only the field that differs from its default.
print(x)  # Desired output: X(b=True)

Given a dataclass instance, I would like print() or str() to only list the non-default field values. This is useful when the dataclass has many fields and only a few are changed.

@dataclasses.dataclass
class X:
  """Example dataclass in which every field has a default.

  Without a custom ``__str__``, ``str()``/``print()`` fall back to the
  dataclass-generated ``__repr__``, which lists every field, changed or not.
  """
  a: int = 1
  b: bool = False
  c: float = 2.0

x = X(b=True)
# As defined above, this prints X(a=1, b=True, c=2.0); the goal is to show
# only the field that differs from its default.
print(x)  # Desired output: X(b=True)

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(2)

红衣飘飘貌似仙 2025-01-18 09:33:17

解决方案是添加一个自定义的 __str__() 方法：

@dataclasses.dataclass
class X:
  """Dataclass whose ``str()`` lists only fields that differ from their defaults.

  ``repr()`` is still the dataclass-generated ``__repr__`` and shows every field.
  """

  a: int = 1
  b: bool = False
  c: float = 2.0

  def __str__(self):
    """Returns a string containing only the non-default field values.

    Values are formatted with ``repr()``, like the dataclass ``__repr__``.
    A field declared with ``default_factory`` is compared against a freshly
    built default; a field with no default at all is always shown.
    """
    shown = []
    for field in dataclasses.fields(self):
      value = getattr(self, field.name)
      if field.default_factory is not dataclasses.MISSING:
        # `field.default` is MISSING for factory fields, so comparing against
        # it would always report a change; build the real default instead.
        default = field.default_factory()
      else:
        default = field.default  # MISSING (always shown) when there is no default
      if value != default:
        shown.append(f'{field.name}={value!r}')
    return f'{type(self).__name__}({", ".join(shown)})'

x = X(b=True)
# print(), str(), {x} and {x!s} all go through the custom __str__;
# repr() and {x!r} still use the dataclass-generated __repr__ (every field).
print(x)        # X(b=True)
print(str(x))   # X(b=True)
print(repr(x))  # X(a=1, b=True, c=2.0)
print(f'{x}, {x!s}, {x!r}')  # X(b=True), X(b=True), X(a=1, b=True, c=2.0)

这也可以使用装饰器来实现:

def terse_str(cls):  # Decorator for class.
  """Install a ``__str__`` on *cls* that lists only non-default field values.

  The fields are looked up each time ``str()`` is called, so this decorator
  can sit above or below ``@dataclasses.dataclass``. Values are formatted
  with ``repr()``, matching the dataclass-generated ``__repr__`` (so strings
  keep their quotes). Returns *cls* itself.
  """
  def __str__(self):
    """Returns a string containing only the non-default field values."""
    shown = []
    for field in dataclasses.fields(self):
      value = getattr(self, field.name)
      if field.default_factory is not dataclasses.MISSING:
        # `field.default` is MISSING for factory fields; build the real default.
        default = field.default_factory()
      else:
        default = field.default  # MISSING (always shown) when there is no default
      if value != default:
        shown.append(f'{field.name}={value!r}')
    return f'{type(self).__name__}({", ".join(shown)})'

  cls.__str__ = __str__
  return cls

# Decorators apply bottom-up: terse_str installs __str__ first, then
# @dataclass adds the generated methods; str(X(b=True)) gives X(b=True).
@dataclasses.dataclass
@terse_str
class X:
  a: int = 1
  b: bool = False
  c: float = 2.0

The solution is to add a custom __str__() method:

@dataclasses.dataclass
class X:
  """Dataclass whose ``str()`` lists only fields that differ from their defaults.

  ``repr()`` is still the dataclass-generated ``__repr__`` and shows every field.
  """

  a: int = 1
  b: bool = False
  c: float = 2.0

  def __str__(self):
    """Returns a string containing only the non-default field values.

    Values are formatted with ``repr()``, like the dataclass ``__repr__``.
    A field declared with ``default_factory`` is compared against a freshly
    built default; a field with no default at all is always shown.
    """
    shown = []
    for field in dataclasses.fields(self):
      value = getattr(self, field.name)
      if field.default_factory is not dataclasses.MISSING:
        # `field.default` is MISSING for factory fields, so comparing against
        # it would always report a change; build the real default instead.
        default = field.default_factory()
      else:
        default = field.default  # MISSING (always shown) when there is no default
      if value != default:
        shown.append(f'{field.name}={value!r}')
    return f'{type(self).__name__}({", ".join(shown)})'

x = X(b=True)
# print(), str(), {x} and {x!s} all go through the custom __str__;
# repr() and {x!r} still use the dataclass-generated __repr__ (every field).
print(x)        # X(b=True)
print(str(x))   # X(b=True)
print(repr(x))  # X(a=1, b=True, c=2.0)
print(f'{x}, {x!s}, {x!r}')  # X(b=True), X(b=True), X(a=1, b=True, c=2.0)

This can also be achieved using a decorator:

def terse_str(cls):  # Decorator for class.
  """Install a ``__str__`` on *cls* that lists only non-default field values.

  The fields are looked up each time ``str()`` is called, so this decorator
  can sit above or below ``@dataclasses.dataclass``. Values are formatted
  with ``repr()``, matching the dataclass-generated ``__repr__`` (so strings
  keep their quotes). Returns *cls* itself.
  """
  def __str__(self):
    """Returns a string containing only the non-default field values."""
    shown = []
    for field in dataclasses.fields(self):
      value = getattr(self, field.name)
      if field.default_factory is not dataclasses.MISSING:
        # `field.default` is MISSING for factory fields; build the real default.
        default = field.default_factory()
      else:
        default = field.default  # MISSING (always shown) when there is no default
      if value != default:
        shown.append(f'{field.name}={value!r}')
    return f'{type(self).__name__}({", ".join(shown)})'

  cls.__str__ = __str__
  return cls

# Decorators apply bottom-up: terse_str installs __str__ first, then
# @dataclass adds the generated methods; str(X(b=True)) gives X(b=True).
@dataclasses.dataclass
@terse_str
class X:
  a: int = 1
  b: bool = False
  c: float = 2.0
清醇 2025-01-18 09:33:17

我建议的一项改进是：只计算一次 dataclasses.fields 的结果，并缓存其中各字段的默认值。这有助于提高性能，因为目前每次调用 __str__ 都会重新执行 dataclasses.fields。

这是一个使用元类方法的简单示例。

请注意，我还对其稍作修改，使其能够处理定义了 default_factory 的可变类型字段（例如列表）。

from __future__ import annotations
import dataclasses


# adapted from `dataclasses` module
def _create_fn(name, args, body, *, globals=None, locals=None):
    """Compile and return a function named *name* from source text.

    *args* is an iterable of parameter names and *body* an iterable of
    statement lines written without indentation. Each key of *locals*
    becomes a parameter of a throw-away factory function, so the generated
    function sees those values as closure variables. *globals* is passed
    straight through to ``exec``.
    """
    closure_vars = {} if locals is None else locals
    arg_list = ','.join(args)
    indented_body = '\n'.join(f'  {line}' for line in body)
    inner_def = f' def {name}({arg_list}):\n{indented_body}'
    params = ', '.join(closure_vars.keys())
    # Wrapping the definition in `__create_fn__` turns the entries of
    # `closure_vars` into closure cells instead of global names.
    factory_src = f"def __create_fn__({params}):\n{inner_def}\n return {name}"
    namespace = {}
    exec(factory_src, globals, namespace)
    factory = namespace['__create_fn__']
    return factory(**closure_vars)


def terse_str(cls_name, bases, cls_dict):  # Metaclass for class
    """Create class *cls_name* with a lazily generated, cached terse ``__str__``.

    Intended as ``class X(metaclass=terse_str)`` underneath
    ``@dataclasses.dataclass``. The class is built with plain ``type``; its
    ``__str__`` lists only fields whose values differ from their defaults.

    On the first ``str()`` call for a given concrete class, the field
    defaults are read once (each ``default_factory`` is called once), a
    specialised ``__str__`` is generated with ``_create_fn`` and installed on
    that class, so later calls skip ``dataclasses.fields()`` entirely.
    """

    def __str__(self):
        cls = type(self)
        # `_cls`/`_rebuild` let the generated function notice that it was
        # inherited by a subclass (different name and possibly more fields)
        # and regenerate for that subclass instead of printing the parent's view.
        _locals = {'_cls': cls, '_rebuild': __str__}
        _body_lines = [
            'if type(self) is not _cls:',
            ' return _rebuild(self)',
            'lines=[]',
        ]
        for f in dataclasses.fields(self):
            name = f.name
            dflt_name = f'_dflt_{name}'
            if f.default_factory is not dataclasses.MISSING:
                # Evaluated once; the value is only compared, never handed out.
                _locals[dflt_name] = f.default_factory()
            else:
                # dataclasses.MISSING when there is no default, which an
                # ordinary value does not equal, so such fields are always shown.
                _locals[dflt_name] = f.default
            _body_lines.append(f'value=self.{name}')
            _body_lines.append(f'if value != {dflt_name}:')
            _body_lines.append(f' lines.append(f"{name}={{value!r}}")')
        # Use the concrete class's name, not the name this function was called with.
        _body_lines.append(f'return f\'{cls.__name__}({{", ".join(lines)}})\'')
        generated = _create_fn('__str__', ('self',), _body_lines, locals=_locals)
        # Cache on the concrete class so later calls bypass this builder.
        setattr(cls, '__str__', generated)
        return generated(self)

    cls_dict['__str__'] = __str__
    return type(cls_name, bases, cls_dict)


@dataclasses.dataclass
class X(metaclass=terse_str):
    """Example dataclass whose ``str()`` comes from the caching ``terse_str``.

    ``terse_str`` builds the class first, then ``@dataclasses.dataclass``
    processes it; ``str()`` lists only fields that differ from their defaults.
    """
    a: int = 1
    b: bool = False
    c: float = 2.0
    # default_factory gives each instance its own list rather than one shared list.
    d: list[int] = dataclasses.field(default_factory=lambda: [1, 2, 3])


x1 = X(b=True)
x2 = X(b=False, c=3, d=[1, 2])

# b=False equals its default and is omitted; c=3 differs from 2.0 and
# d=[1, 2] differs from the factory default [1, 2, 3].
print(x1)    # X(b=True)
print(x2)    # X(c=3, d=[1, 2])

最后，这是一个粗略的测试，用于确认缓存确实有利于重复调用 str() 或 print：

import dataclasses
from timeit import timeit


def terse_str(cls):  # Decorator for class.
    """Install a ``__str__`` on *cls* that lists only non-default field values.

    The fields are looked up each time ``str()`` is called, so this decorator
    can sit above or below ``@dataclasses.dataclass``. Values are formatted
    with ``repr()``, matching the dataclass-generated ``__repr__`` and the
    cached metaclass variant. Returns *cls* itself.
    """
    def __str__(self):
        """Returns a string containing only the non-default field values."""
        shown = []
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if field.default_factory is not dataclasses.MISSING:
                # `field.default` is MISSING for factory fields; build the real default.
                default = field.default_factory()
            else:
                default = field.default  # MISSING (always shown) when there is no default
            if value != default:
                shown.append(f'{field.name}={value!r}')
        return f'{type(self).__name__}({", ".join(shown)})'

    cls.__str__ = __str__
    return cls


# adapted from `dataclasses` module
def _create_fn(name, args, body, *, globals=None, locals=None):
    """Compile and return a function named *name* from source text.

    *args* is an iterable of parameter names and *body* an iterable of
    statement lines written without indentation. Each key of *locals*
    becomes a parameter of a throw-away factory function, so the generated
    function sees those values as closure variables. *globals* is passed
    straight through to ``exec``.
    """
    closure_vars = {} if locals is None else locals
    arg_list = ','.join(args)
    indented_body = '\n'.join(f'  {line}' for line in body)
    inner_def = f' def {name}({arg_list}):\n{indented_body}'
    params = ', '.join(closure_vars.keys())
    # Wrapping the definition in `__create_fn__` turns the entries of
    # `closure_vars` into closure cells instead of global names.
    factory_src = f"def __create_fn__({params}):\n{inner_def}\n return {name}"
    namespace = {}
    exec(factory_src, globals, namespace)
    factory = namespace['__create_fn__']
    return factory(**closure_vars)


def terse_str_meta(cls_name, bases, cls_dict):  # Metaclass for class
    """Create class *cls_name* with a lazily generated, cached terse ``__str__``.

    Intended as ``class X(metaclass=terse_str_meta)`` underneath
    ``@dataclasses.dataclass``. The class is built with plain ``type``; its
    ``__str__`` lists only fields whose values differ from their defaults.

    On the first ``str()`` call for a given concrete class, the field
    defaults are read once (each ``default_factory`` is called once), a
    specialised ``__str__`` is generated with ``_create_fn`` and installed on
    that class, so later calls skip ``dataclasses.fields()`` entirely.
    """

    def __str__(self):
        cls = type(self)
        # `_cls`/`_rebuild` let the generated function notice that it was
        # inherited by a subclass (different name and possibly more fields)
        # and regenerate for that subclass instead of printing the parent's view.
        _locals = {'_cls': cls, '_rebuild': __str__}
        _body_lines = [
            'if type(self) is not _cls:',
            ' return _rebuild(self)',
            'lines=[]',
        ]
        for f in dataclasses.fields(self):
            name = f.name
            dflt_name = f'_dflt_{name}'
            if f.default_factory is not dataclasses.MISSING:
                # Evaluated once; the value is only compared, never handed out.
                _locals[dflt_name] = f.default_factory()
            else:
                # dataclasses.MISSING when there is no default, which an
                # ordinary value does not equal, so such fields are always shown.
                _locals[dflt_name] = f.default
            _body_lines.append(f'value=self.{name}')
            _body_lines.append(f'if value != {dflt_name}:')
            _body_lines.append(f' lines.append(f"{name}={{value!r}}")')
        # Use the concrete class's name, not the name this function was called with.
        _body_lines.append(f'return f\'{cls.__name__}({{", ".join(lines)}})\'')
        generated = _create_fn('__str__', ('self',), _body_lines, locals=_locals)
        # Cache on the concrete class so later calls bypass this builder.
        setattr(cls, '__str__', generated)
        return generated(self)

    cls_dict['__str__'] = __str__
    return type(cls_name, bases, cls_dict)


@dataclasses.dataclass
@terse_str
class X:
    """Baseline for the benchmark: ``__str__`` calls ``dataclasses.fields()`` on every call."""
    a: int = 1
    b: bool = False
    c: float = 2.0


@dataclasses.dataclass
class X_Cached(metaclass=terse_str_meta):
    """Same fields as ``X``; ``__str__`` is generated on first use and cached on the class."""
    a: int = 1
    b: bool = False
    c: float = 2.0


# Each timed statement constructs an instance and formats it, so both numbers
# include __init__ cost as well as __str__ cost.
print(f"Simple:  {timeit('str(X(b=True))', globals=globals()):.3f}")
print(f"Cached:  {timeit('str(X_Cached(b=True))', globals=globals()):.3f}")

# Sanity check: each version prints only the changed field `b`.
print()
print(X(b=True))
print(X_Cached(b=True))

结果:

Simple:  1.038
Cached:  0.289

One improvement I would suggest is to compute the result of dataclasses.fields once and cache the default values from it. This helps performance because the current version calls dataclasses.fields again every time __str__ is invoked.

Here's a simple example using a metaclass approach.

Note that I've also modified it slightly so that it handles fields of mutable types that define a default_factory (a list, for instance).

from __future__ import annotations
import dataclasses


# adapted from `dataclasses` module
def _create_fn(name, args, body, *, globals=None, locals=None):
    """Compile and return a function named *name* from source text.

    *args* is an iterable of parameter names and *body* an iterable of
    statement lines written without indentation. Each key of *locals*
    becomes a parameter of a throw-away factory function, so the generated
    function sees those values as closure variables. *globals* is passed
    straight through to ``exec``.
    """
    closure_vars = {} if locals is None else locals
    arg_list = ','.join(args)
    indented_body = '\n'.join(f'  {line}' for line in body)
    inner_def = f' def {name}({arg_list}):\n{indented_body}'
    params = ', '.join(closure_vars.keys())
    # Wrapping the definition in `__create_fn__` turns the entries of
    # `closure_vars` into closure cells instead of global names.
    factory_src = f"def __create_fn__({params}):\n{inner_def}\n return {name}"
    namespace = {}
    exec(factory_src, globals, namespace)
    factory = namespace['__create_fn__']
    return factory(**closure_vars)


def terse_str(cls_name, bases, cls_dict):  # Metaclass for class
    """Create class *cls_name* with a lazily generated, cached terse ``__str__``.

    Intended as ``class X(metaclass=terse_str)`` underneath
    ``@dataclasses.dataclass``. The class is built with plain ``type``; its
    ``__str__`` lists only fields whose values differ from their defaults.

    On the first ``str()`` call for a given concrete class, the field
    defaults are read once (each ``default_factory`` is called once), a
    specialised ``__str__`` is generated with ``_create_fn`` and installed on
    that class, so later calls skip ``dataclasses.fields()`` entirely.
    """

    def __str__(self):
        cls = type(self)
        # `_cls`/`_rebuild` let the generated function notice that it was
        # inherited by a subclass (different name and possibly more fields)
        # and regenerate for that subclass instead of printing the parent's view.
        _locals = {'_cls': cls, '_rebuild': __str__}
        _body_lines = [
            'if type(self) is not _cls:',
            ' return _rebuild(self)',
            'lines=[]',
        ]
        for f in dataclasses.fields(self):
            name = f.name
            dflt_name = f'_dflt_{name}'
            if f.default_factory is not dataclasses.MISSING:
                # Evaluated once; the value is only compared, never handed out.
                _locals[dflt_name] = f.default_factory()
            else:
                # dataclasses.MISSING when there is no default, which an
                # ordinary value does not equal, so such fields are always shown.
                _locals[dflt_name] = f.default
            _body_lines.append(f'value=self.{name}')
            _body_lines.append(f'if value != {dflt_name}:')
            _body_lines.append(f' lines.append(f"{name}={{value!r}}")')
        # Use the concrete class's name, not the name this function was called with.
        _body_lines.append(f'return f\'{cls.__name__}({{", ".join(lines)}})\'')
        generated = _create_fn('__str__', ('self',), _body_lines, locals=_locals)
        # Cache on the concrete class so later calls bypass this builder.
        setattr(cls, '__str__', generated)
        return generated(self)

    cls_dict['__str__'] = __str__
    return type(cls_name, bases, cls_dict)


@dataclasses.dataclass
class X(metaclass=terse_str):
    """Example dataclass whose ``str()`` comes from the caching ``terse_str``.

    ``terse_str`` builds the class first, then ``@dataclasses.dataclass``
    processes it; ``str()`` lists only fields that differ from their defaults.
    """
    a: int = 1
    b: bool = False
    c: float = 2.0
    # default_factory gives each instance its own list rather than one shared list.
    d: list[int] = dataclasses.field(default_factory=lambda: [1, 2, 3])


x1 = X(b=True)
x2 = X(b=False, c=3, d=[1, 2])

# b=False equals its default and is omitted; c=3 differs from 2.0 and
# d=[1, 2] differs from the factory default [1, 2, 3].
print(x1)    # X(b=True)
print(x2)    # X(c=3, d=[1, 2])

Finally, here's a quick and dirty test to confirm that caching is actually beneficial for repeated calls to str() or print:

import dataclasses
from timeit import timeit


def terse_str(cls):  # Decorator for class.
    """Install a ``__str__`` on *cls* that lists only non-default field values.

    The fields are looked up each time ``str()`` is called, so this decorator
    can sit above or below ``@dataclasses.dataclass``. Values are formatted
    with ``repr()``, matching the dataclass-generated ``__repr__`` and the
    cached metaclass variant. Returns *cls* itself.
    """
    def __str__(self):
        """Returns a string containing only the non-default field values."""
        shown = []
        for field in dataclasses.fields(self):
            value = getattr(self, field.name)
            if field.default_factory is not dataclasses.MISSING:
                # `field.default` is MISSING for factory fields; build the real default.
                default = field.default_factory()
            else:
                default = field.default  # MISSING (always shown) when there is no default
            if value != default:
                shown.append(f'{field.name}={value!r}')
        return f'{type(self).__name__}({", ".join(shown)})'

    cls.__str__ = __str__
    return cls


# adapted from `dataclasses` module
def _create_fn(name, args, body, *, globals=None, locals=None):
    """Compile and return a function named *name* from source text.

    *args* is an iterable of parameter names and *body* an iterable of
    statement lines written without indentation. Each key of *locals*
    becomes a parameter of a throw-away factory function, so the generated
    function sees those values as closure variables. *globals* is passed
    straight through to ``exec``.
    """
    closure_vars = {} if locals is None else locals
    arg_list = ','.join(args)
    indented_body = '\n'.join(f'  {line}' for line in body)
    inner_def = f' def {name}({arg_list}):\n{indented_body}'
    params = ', '.join(closure_vars.keys())
    # Wrapping the definition in `__create_fn__` turns the entries of
    # `closure_vars` into closure cells instead of global names.
    factory_src = f"def __create_fn__({params}):\n{inner_def}\n return {name}"
    namespace = {}
    exec(factory_src, globals, namespace)
    factory = namespace['__create_fn__']
    return factory(**closure_vars)


def terse_str_meta(cls_name, bases, cls_dict):  # Metaclass for class
    """Create class *cls_name* with a lazily generated, cached terse ``__str__``.

    Intended as ``class X(metaclass=terse_str_meta)`` underneath
    ``@dataclasses.dataclass``. The class is built with plain ``type``; its
    ``__str__`` lists only fields whose values differ from their defaults.

    On the first ``str()`` call for a given concrete class, the field
    defaults are read once (each ``default_factory`` is called once), a
    specialised ``__str__`` is generated with ``_create_fn`` and installed on
    that class, so later calls skip ``dataclasses.fields()`` entirely.
    """

    def __str__(self):
        cls = type(self)
        # `_cls`/`_rebuild` let the generated function notice that it was
        # inherited by a subclass (different name and possibly more fields)
        # and regenerate for that subclass instead of printing the parent's view.
        _locals = {'_cls': cls, '_rebuild': __str__}
        _body_lines = [
            'if type(self) is not _cls:',
            ' return _rebuild(self)',
            'lines=[]',
        ]
        for f in dataclasses.fields(self):
            name = f.name
            dflt_name = f'_dflt_{name}'
            if f.default_factory is not dataclasses.MISSING:
                # Evaluated once; the value is only compared, never handed out.
                _locals[dflt_name] = f.default_factory()
            else:
                # dataclasses.MISSING when there is no default, which an
                # ordinary value does not equal, so such fields are always shown.
                _locals[dflt_name] = f.default
            _body_lines.append(f'value=self.{name}')
            _body_lines.append(f'if value != {dflt_name}:')
            _body_lines.append(f' lines.append(f"{name}={{value!r}}")')
        # Use the concrete class's name, not the name this function was called with.
        _body_lines.append(f'return f\'{cls.__name__}({{", ".join(lines)}})\'')
        generated = _create_fn('__str__', ('self',), _body_lines, locals=_locals)
        # Cache on the concrete class so later calls bypass this builder.
        setattr(cls, '__str__', generated)
        return generated(self)

    cls_dict['__str__'] = __str__
    return type(cls_name, bases, cls_dict)


@dataclasses.dataclass
@terse_str
class X:
    """Baseline for the benchmark: ``__str__`` calls ``dataclasses.fields()`` on every call."""
    a: int = 1
    b: bool = False
    c: float = 2.0


@dataclasses.dataclass
class X_Cached(metaclass=terse_str_meta):
    """Same fields as ``X``; ``__str__`` is generated on first use and cached on the class."""
    a: int = 1
    b: bool = False
    c: float = 2.0


# Each timed statement constructs an instance and formats it, so both numbers
# include __init__ cost as well as __str__ cost.
print(f"Simple:  {timeit('str(X(b=True))', globals=globals()):.3f}")
print(f"Cached:  {timeit('str(X_Cached(b=True))', globals=globals()):.3f}")

# Sanity check: each version prints only the changed field `b`.
print()
print(X(b=True))
print(X_Cached(b=True))

Results:

Simple:  1.038
Cached:  0.289
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文