Two apparently identical dataclasses are not equal

Posted 2025-01-25 00:44:43

I've defined the following dataclass:

"""This module declares the SubtitleItem dataclass."""

import re

from dataclasses import dataclass
from time_utils import Timestamp

@dataclass
class SubtitleItem:
    """Class for storing all the information for
    a subtitle item."""
    index: int
    start_time: Timestamp
    end_time: Timestamp
    text: str

    @staticmethod
    def load_from_text_item(text_item: str) -> "SubtitleItem":
        """Create new subtitle item from their .srt file text.

        Example, if your .srt file contains the following subtitle item:

        ```
        3
        00:00:05,847 --> 00:00:06,916
        The robot.
        ```

        This function will return:

        ```
        SubtitleItem(
            index=3,
            start_time=Timestamp(seconds=5, milliseconds=847),
            end_time=Timestamp(seconds=6, milliseconds=916),
            text='The robot.')
        ```

        Args:
            text_item (str): The .srt text for a subtitle item.

        Returns:
            SubtitleItem: A corresponding SubtitleItem.
        """

        # Build regex
        index_re = r"\d+"
        timestamp = lambda prefix: rf"(?P<{prefix}_hours>\d\d):" + \
                                   rf"(?P<{prefix}_minutes>\d\d):" + \
                                   rf"(?P<{prefix}_seconds>\d\d)," + \
                                   rf"(?P<{prefix}_milliseconds>\d\d\d)"
        start_timestamp_re = timestamp("start")
        end_timestamp_re = timestamp("end")
        text_re = r".+"
        complete_re = f"^(?P<index>{index_re})\n"
        complete_re += f"{start_timestamp_re} --> {end_timestamp_re}\n"
        complete_re += f"(?P<text>{text_re})$"
        regex = re.compile(complete_re)

        # Match and extract groups
        match = regex.match(text_item)
        if match is None:
            raise ValueError(f"Index item invalid format:\n'{text_item}'")
        groups = match.groupdict()

        # Extract values
        index = int(groups['index'])

        group_items = filter(lambda kv: kv[0].startswith("start_"), groups.items())
        args = { k[len("start_"):]: int(v) for k, v in group_items }
        start = Timestamp(**args)
        group_items = filter(lambda kv: kv[0].startswith("end_"), groups.items())
        args = { k[len("end_"):]: int(v) for k, v in group_items }
        end = Timestamp(**args)

        text = groups['text']

        if start >= end:
            raise ValueError(
                f"Start timestamp must be earlier than end timestamp: start={start}, end={end}")
        return SubtitleItem(index, start, end, text)

    @staticmethod
    def _format_timestamp(t: Timestamp) -> str:
        """Format a timestamp in the .srt format.

        Args:
            t (Timestamp): The timestamp to convert.

        Returns:
            str: The textual representation for the .srt format.
        """
        return f"{t.get_hours()}:{t.get_minutes()}:{t.get_seconds()},{t.get_milliseconds()}"

    def __str__(self):
        res = f"{self.index}\n"
        res += f"{SubtitleItem._format_timestamp(self.start_time)}"
        res += " --> "
        res += f"{SubtitleItem._format_timestamp(self.end_time)}\n"
        res += self.text
        return res

... which I use in the following test:

import unittest
from src.subtitle_item import SubtitleItem
from src.time_utils import Timestamp


class SubtitleItemTest(unittest.TestCase):
    def testLoadFromText(self):
        text = "21\n01:02:03,004 --> 05:06:07,008\nTest subtitle."
        res = SubtitleItem.load_from_text_item(text)
        exp = SubtitleItem(
            21, Timestamp(hours=1, minutes=2, seconds=3, milliseconds=4),
            Timestamp(hours=5, minutes=6, seconds=7, milliseconds=8),
            "Test subtitle."
        )
        self.assertEqual(res, exp)

This test fails, but I don't understand why.

I've checked with the debugger: exp and res have exactly the same fields. The Timestamp class is another, separate dataclass. I've checked equality of each field manually in the debugger, and all fields are identical:

>>> exp == res
False
>>> exp.index == res.index
True
>>> exp.start_time == res.start_time
True
>>> exp.end_time == res.end_time
True
>>> exp.text == res.text
True

Furthermore, asdict() on each object returns identical dictionaries:

>>> dataclasses.asdict(exp) == dataclasses.asdict(res)
True
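A hedged illustration of why these two checks can disagree, using hypothetical names (`make_timestamp_class`, `Item`): `dataclasses.asdict()` recursively turns nested dataclasses into plain dicts, which discards the class of each nested value, while the generated `__eq__` only compares fields when both operands are instances of exactly the same class.

```python
import dataclasses
from dataclasses import dataclass


def make_timestamp_class():
    # Each call runs a fresh class statement, so it returns a *new* class object
    # that merely shares the name (a hypothetical stand-in for a module imported twice).
    @dataclass(frozen=True)
    class Timestamp:
        seconds: int = 0

    return Timestamp


@dataclass
class Item:
    start_time: object


TimestampA = make_timestamp_class()
TimestampB = make_timestamp_class()

x = Item(TimestampA(seconds=3))
y = Item(TimestampB(seconds=3))

print(x.start_time.seconds == y.start_time.seconds)  # True: same field values
print(x == y)  # False: the nested Timestamps belong to different classes
print(dataclasses.asdict(x) == dataclasses.asdict(y))  # True: asdict() drops the class
```

The outer `Item` objects share a class, so their `__eq__` goes on to compare fields; it is the nested comparison that fails, because each `Timestamp.__eq__` returns `NotImplemented` for a foreign class and Python then falls back to identity.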

Is there something I'm misunderstanding regarding the implementation of the equality operator with dataclasses?

Thanks.

EDIT: here is my time_utils module, sorry for not including it earlier:

"""
This module declares the Delta and Timestamp classes.
"""

from dataclasses import dataclass

@dataclass(frozen=True)
class _TimeBase:
    hours:          int = 0
    minutes:        int = 0
    seconds:        int = 0
    milliseconds:   int = 0

    def __post_init__(self):
        BOUNDS_H  = range(0, 100)
        BOUNDS_M  = range(0, 60)
        BOUNDS_S  = range(0, 60)
        BOUNDS_MS = range(0, 1000)
        if self.hours not in BOUNDS_H:
            raise ValueError(
                f"{self.hours=} not in [{BOUNDS_H.start}, {BOUNDS_H.stop})")
        if self.minutes not in BOUNDS_M:
            raise ValueError(
                f"{self.minutes=} not in [{BOUNDS_M.start}, {BOUNDS_M.stop})")
        if self.seconds not in BOUNDS_S:
            raise ValueError(
                f"{self.seconds=} not in [{BOUNDS_S.start}, {BOUNDS_S.stop})")
        if self.milliseconds not in BOUNDS_MS:
            raise ValueError(
                f"{self.milliseconds=} not in [{BOUNDS_MS.start}, {BOUNDS_MS.stop})")

    def _to_ms(self):
        return self.milliseconds + 1000 * (self.seconds + 60 * (self.minutes + 60 * self.hours))


@dataclass(frozen=True)
class Delta(_TimeBase):
    """A time difference, with milliseconds accuracy.
    Must be less than 100h long."""
    sign: int = 1

    def __post_init__(self):
        if self.sign not in (1, -1):
            raise ValueError(
                f"{self.sign=} should either be 1 or -1")
        super().__post_init__()

    def __add__(self, other: "Delta") -> "Delta":
        self_ms = self.sign * self._to_ms()
        other_ms = other.sign * other._to_ms()
        ms_sum = self_ms + other_ms
        sign = -1 if ms_sum < 0 else 1
        ms_sum = abs(ms_sum)

        ms_n, s_rem = ms_sum % 1000, ms_sum // 1000
        s_n, m_rem = s_rem % 60, s_rem // 60
        m_n, h_n = m_rem % 60, m_rem // 60
        return Delta(hours=h_n, minutes=m_n, seconds=s_n, milliseconds=ms_n, sign=sign)

@dataclass(frozen=True)
class Timestamp(_TimeBase):
    """A timestamp with milliseconds accuracy. Must be
    less than 100h long."""

    def __add__(self, other: Delta) -> "Timestamp":
        ms_sum = self._to_ms() + other.sign * other._to_ms()
        ms_n, s_rem = ms_sum % 1000, ms_sum // 1000
        s_n, m_rem = s_rem % 60, s_rem // 60
        m_n, h_n = m_rem % 60, m_rem // 60
        return Timestamp(hours=h_n, minutes=m_n, seconds=s_n, milliseconds=ms_n)

    def __ge__(self, other: "Timestamp") -> bool:
        return self._to_ms() >= other._to_ms()
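As a worked check of the arithmetic in `_TimeBase._to_ms` and `Timestamp.__add__`, here is a standalone sketch using the timestamps from the test. The helper names `to_ms` and `from_ms` are hypothetical; the carry chain is rewritten with `divmod(a, b)`, which returns `(a // b, a % b)` and so matches the `%`/`//` pairs above.

```python
def to_ms(hours: int, minutes: int, seconds: int, milliseconds: int) -> int:
    # Same nesting as _TimeBase._to_ms
    return milliseconds + 1000 * (seconds + 60 * (minutes + 60 * hours))


def from_ms(total_ms: int) -> tuple[int, int, int, int]:
    # Same carry chain as Timestamp.__add__, written with divmod
    s_rem, ms = divmod(total_ms, 1000)
    m_rem, s = divmod(s_rem, 60)
    h, m = divmod(m_rem, 60)
    return h, m, s, ms


print(to_ms(1, 2, 3, 4))  # 3723004
print(to_ms(5, 6, 7, 8))  # 18367008
print(from_ms(3723004))  # (1, 2, 3, 4)
print(to_ms(1, 2, 3, 4) >= to_ms(5, 6, 7, 8))  # False, so the "start >= end" check passes
```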


Comments (2)

爱*していゐ 2025-02-01 00:44:43
class Timestamp:
    # Stores the whole timestamp as a single integer number of milliseconds.
    def __init__(self, hours=0, minutes=0, seconds=0, milliseconds=0):
        self.ms = ((hours * 60 + minutes) * 60 + seconds) * 1000 + milliseconds
    def get_hours(self):
        return self.ms // (60 * 60 * 1000)
    def get_minutes(self):
        return (self.ms // (60 * 1000)) % 60
    def get_seconds(self):
        return (self.ms // 1000) % 60
    def get_milliseconds(self):
        return self.ms % 1000
    def __add__(self, other):
        return Timestamp(milliseconds=self.ms + other.ms)
    def __eq__(self, other):
        return self.ms == other.ms
    def __lt__(self, other):
        return self.ms < other.ms
    def __le__(self, other):
        return self.ms <= other.ms

... your code ...

text = "21\n01:02:03,004 --> 05:06:07,008\nTest subtitle."
res = SubtitleItem.load_from_text_item(text)
exp = SubtitleItem(
    21, Timestamp(hours=1, minutes=2, seconds=3, milliseconds=4),
    Timestamp(hours=5, minutes=6, seconds=7, milliseconds=8),
    "Test subtitle."
)
print(res)
print(exp)
print(res==exp)

Produces:

21
1:2:3,4 --> 5:6:7,8
Test subtitle.
21
1:2:3,4 --> 5:6:7,8
Test subtitle.
True

with no assertion error.

友谊不毕业 2025-02-01 00:44:43

Okay, I think I found what's going wrong here.

First, I made a mistake when I reported the issue before: in the unit test, exp.start_time != res.start_time and exp.end_time != res.end_time. Sorry about that. That narrows down the issue to comparison of timestamps.

My sources are in project/src/, the test that fails is in project/tests/. To make source modules accessible to the test, I had to add the source directory to PYTHONPATH:

$ PYTHONPATH=src/ python -m unittest discover -s tests/ -v

In the unit test, even though res.start_time and exp.start_time do have the same fields, they do not have the same type:

>>> print(type(res.start_time), type(exp.start_time))
<class 'time_utils.Timestamp'> <class 'src.time_utils.Timestamp'>
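A minimal sketch of how one file can end up loaded as two modules, under stated assumptions: a throwaway directory stands in for project/, time_utils.py is trimmed to a single field, and src/ is given an `__init__.py` (the real layout may differ). With both project/ and project/src/ on `sys.path`, `import time_utils` and `import src.time_utils` resolve to the same file but create two entries in `sys.modules`, hence two distinct `Timestamp` classes.

```python
import importlib
import sys
import tempfile
from pathlib import Path

# Hypothetical stand-in for project/src/time_utils.py, trimmed to one field.
TIME_UTILS_SOURCE = (
    "from dataclasses import dataclass\n"
    "\n"
    "@dataclass(frozen=True)\n"
    "class Timestamp:\n"
    "    seconds: int = 0\n"
)

with tempfile.TemporaryDirectory() as project:
    src = Path(project, "src")
    src.mkdir()
    (src / "__init__.py").write_text("")  # assumption: src/ is a regular package
    (src / "time_utils.py").write_text(TIME_UTILS_SOURCE)

    # Start from a clean slate in case these names were imported earlier.
    for name in ("time_utils", "src", "src.time_utils"):
        sys.modules.pop(name, None)
    importlib.invalidate_caches()

    # Mimic running from project/ with PYTHONPATH=src/: both dirs are on sys.path.
    sys.path[:0] = [project, str(src)]
    try:
        a = importlib.import_module("time_utils")      # as subtitle_item.py imports it
        b = importlib.import_module("src.time_utils")  # as the test imports it
    finally:
        del sys.path[:2]

print(a.__file__ == b.__file__)  # True: one file on disk...
print(a is b)  # False: ...loaded as two separate module objects
print(a.Timestamp(seconds=3) == b.Timestamp(seconds=3))  # False: two distinct classes
print(a.Timestamp(seconds=3) == a.Timestamp(seconds=3))  # True
```

Under the same assumptions, importing the module through a single name everywhere should leave only one `Timestamp` class, and the dataclass comparison would then behave as expected.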

I've added a new post with a minimal reproducible example and more details about the file structure here: Minimally reproducible example.
