Scr*_*own 5 python python-dataclasses
我定义了以下数据类:
"""This module declares the SubtitleItem dataclass."""
import re
from dataclasses import dataclass
from time_utils import Timestamp
@dataclass
class SubtitleItem:
    """Store all the information for one subtitle item of an .srt file.

    The fields are the item's index, its start and end timestamps and
    the subtitle text.
    """

    index: int
    start_time: Timestamp
    end_time: Timestamp
    text: str

    @staticmethod
    def load_from_text_item(text_item: str) -> "SubtitleItem":
        """Create a new subtitle item from its .srt file text.

        Example, if your .srt file contains the following subtitle item:
        ```
        3
        00:00:05,847 --> 00:00:06,916
        The robot.
        ```
        This function will return:
        ```
        SubtitleItem(
            index=3,
            start_time=Timestamp(seconds=5, milliseconds=847),
            end_time=Timestamp(seconds=6, milliseconds=916),
            text='The robot.')
        ```

        Args:
            text_item (str): The .srt text for a subtitle item.

        Returns:
            SubtitleItem: A corresponding SubtitleItem.

        Raises:
            ValueError: If ``text_item`` does not match the expected
                layout, or if the start timestamp is not strictly
                earlier than the end timestamp.
        """

        def timestamp_re(prefix: str) -> str:
            # One named group per field so the values can be looked up by
            # "<prefix>_<field>" after matching.
            return (
                rf"(?P<{prefix}_hours>\d\d):"
                rf"(?P<{prefix}_minutes>\d\d):"
                rf"(?P<{prefix}_seconds>\d\d),"
                rf"(?P<{prefix}_milliseconds>\d\d\d)"
            )

        # Build regex: index line, timestamp line, then the text.
        # No DOTALL flag is set, so ".+" only accepts text on a single line.
        index_re = r"\d+"
        text_re = r".+"
        complete_re = (
            f"^(?P<index>{index_re})\n"
            f"{timestamp_re('start')} --> {timestamp_re('end')}\n"
            f"(?P<text>{text_re})$"
        )

        # Match and extract groups
        match = re.compile(complete_re).match(text_item)
        if match is None:
            raise ValueError(f"Index item invalid format:\n'{text_item}'")
        groups = match.groupdict()

        def parse_timestamp(prefix: str) -> Timestamp:
            # Drop the prefix so the remaining keys (hours, minutes, ...)
            # can be passed straight through as keyword arguments.
            fields = {
                key[len(prefix):]: int(value)
                for key, value in groups.items()
                if key.startswith(prefix)
            }
            return Timestamp(**fields)

        # Extract values
        index = int(groups['index'])
        start = parse_timestamp("start_")
        end = parse_timestamp("end_")
        text = groups['text']

        # An item must begin strictly before it ends.
        if start >= end:
            raise ValueError(
                f"Start timestamp must be earlier than end timestamp: start={start}, end={end}")
        return SubtitleItem(index, start, end, text)

    @staticmethod
    def _format_timestamp(t: Timestamp) -> str:
        """Format a timestamp as ``hours:minutes:seconds,milliseconds``.

        Args:
            t (Timestamp): The timestamp to convert.

        Returns:
            str: The textual representation of the timestamp.
        """
        # NOTE(review): the fields are written without zero padding, whereas
        # load_from_text_item only accepts two-digit hours/minutes/seconds and
        # three-digit milliseconds - confirm whether padding is wanted here.
        # Assumes Timestamp exposes get_hours()/get_minutes()/get_seconds()/
        # get_milliseconds() accessors - TODO confirm against time_utils.
        return f"{t.get_hours()}:{t.get_minutes()}:{t.get_seconds()},{t.get_milliseconds()}"

    def __str__(self) -> str:
        """Return the item as index line, timestamp line and text."""
        res = f"{self.index}\n"
        res += f"{SubtitleItem._format_timestamp(self.start_time)}"
        res += " --> "
        res += f"{SubtitleItem._format_timestamp(self.end_time)}\n"
        res += self.text
        return res
Run Code Online (Sandbox Code Playgroud)
...我在以下测试中使用它:
import unittest
from src.subtitle_item import SubtitleItem
from src.time_utils import Timestamp
class SubtitleItemTest(unittest.TestCase):
    """Unit tests for SubtitleItem."""

    def testLoadFromText(self):
        """Parsing a well-formed .srt item yields the equivalent SubtitleItem."""
        text = "21\n01:02:03,004 --> 05:06:07,008\nTest subtitle."

        # Expected value, spelled out field by field.
        exp = SubtitleItem(
            index=21,
            start_time=Timestamp(hours=1, minutes=2, seconds=3, milliseconds=4),
            end_time=Timestamp(hours=5, minutes=6, seconds=7, milliseconds=8),
            text="Test subtitle.",
        )

        res = SubtitleItem.load_from_text_item(text)

        self.assertEqual(res, exp)
Run Code Online (Sandbox Code Playgroud)
这个测试失败了,但我不明白为什么。
我已经在调试器中检查过:exp 和 res 具有完全相同的字段。Timestamp 类是另一个单独的数据类。我已在调试器中手动检查了每个字段的相等性,所有字段都相同:
>>> exp == res
False
>>> exp.index == res.index
True
>>> exp.start_time == res.start_time
True
>>> exp.end_time == res.end_time
True
>>> exp.text == res.text
True
Run Code Online (Sandbox Code Playgroud)
此外,对每个对象调用 asdict() 都返回相同的字典:
>>> dataclasses.asdict(exp) == dataclasses.asdict(res)
True
Run Code Online (Sandbox Code Playgroud)
关于使用数据类实现相等运算符,我是否存在误解?
谢谢。
编辑:我的time_utils模块,很抱歉没有早点包含它
"""
This module declares the Delta and Timestamp classes.
"""
from dataclasses import dataclass
@dataclass(frozen=True)
class _TimeBase:
    """Immutable time value split into hours, minutes, seconds and milliseconds.

    Every field defaults to 0 and is range-checked on construction.
    """

    hours: int = 0
    minutes: int = 0
    seconds: int = 0
    milliseconds: int = 0

    def __post_init__(self):
        """Validate that every field lies inside its half-open range.

        Raises:
            ValueError: If hours is not in [0, 100), minutes or seconds
                is not in [0, 60), or milliseconds is not in [0, 1000).
                Fields are checked in that order; the first violation
                is reported.
        """
        bounds = (
            ("hours", range(0, 100)),
            ("minutes", range(0, 60)),
            ("seconds", range(0, 60)),
            ("milliseconds", range(0, 1000)),
        )
        for name, allowed in bounds:
            value = getattr(self, name)
            if value not in allowed:
                # Render the interval as "[start, stop)"; start and stop are
                # interpolated separately so they are not printed as a tuple.
                raise ValueError(
                    f"self.{name}={value!r} not in [{allowed.start}, {allowed.stop})")

    def _to_ms(self) -> int:
        """Return the value as a total number of milliseconds."""
        return self.milliseconds + 1000 * (self.seconds + 60 * (self.minutes + 60 * self.hours))
@dataclass(frozen=True)
class Delta(_TimeBase):
    """A signed time difference with millisecond accuracy.

    The magnitude lives in the inherited fields and must be less than
    100h long; ``sign`` is either 1 or -1.
    """

    sign: int = 1

    def __post_init__(self):
        """Check the sign first, then run the inherited field validation.

        Raises:
            ValueError: If ``sign`` is neither 1 nor -1, or if the base
                class rejects one of the time fields.
        """
        if self.sign not in (1, -1):
            raise ValueError(
                f"{self.sign=} should either be 1 or -1")
        super().__post_init__()

    def __add__(self, other: "Delta") -> "Delta":
        """Return the signed sum of two deltas as a new Delta."""
        # Work on signed millisecond totals, then split the magnitude back
        # into fields; a zero total is given a positive sign.
        total_ms = self.sign * self._to_ms() + other.sign * other._to_ms()
        result_sign = 1 if total_ms >= 0 else -1
        whole_seconds, millis = divmod(abs(total_ms), 1000)
        whole_minutes, secs = divmod(whole_seconds, 60)
        hrs, mins = divmod(whole_minutes, 60)
        return Delta(hours=hrs, minutes=mins, seconds=secs, milliseconds=millis, sign=result_sign)
@dataclass(frozen=True)
class Timestamp(_TimeBase):
    """A timestamp with millisecond accuracy; must be less than 100h long."""

    def __add__(self, other: Delta) -> "Timestamp":
        """Return this timestamp shifted by the signed delta ``other``."""
        shifted_ms = self._to_ms() + other.sign * other._to_ms()
        # Split the millisecond total back into fields, smallest unit first.
        whole_seconds, millis = divmod(shifted_ms, 1000)
        whole_minutes, secs = divmod(whole_seconds, 60)
        hrs, mins = divmod(whole_minutes, 60)
        return Timestamp(hours=hrs, minutes=mins, seconds=secs, milliseconds=millis)

    def __ge__(self, other: "Timestamp") -> bool:
        """Return True when this timestamp is at or after ``other``."""
        return other._to_ms() <= self._to_ms()
Run Code Online (Sandbox Code Playgroud)
class Timestamp:
    """A point in time stored as a single millisecond count in ``ms``."""

    def __init__(self, hours=0, minutes=0, seconds=0, milliseconds=0):
        """Collapse the given fields into one millisecond total."""
        self.ms = ((hours * 60 + minutes) * 60 + seconds) * 1000 + milliseconds

    def get_hours(self):
        """Return the whole hours contained in the total."""
        return self.ms // (60 * 60 * 1000)

    def get_minutes(self):
        """Return the minutes component (0-59 for a non-negative total)."""
        return (self.ms // (60 * 1000)) % 60

    def get_seconds(self):
        """Return the seconds component (0-59 for a non-negative total)."""
        return (self.ms // 1000) % 60

    def get_milliseconds(self):
        """Return the milliseconds component (0-999 for a non-negative total)."""
        return self.ms % 1000

    def __add__(self, other):
        """Return a new Timestamp whose total is the sum of both ``ms`` values."""
        # The operand's count lives on `other`; reading it from `self`
        # raised AttributeError on every addition.
        return Timestamp(milliseconds=self.ms + other.ms)

    def __eq__(self, other):
        """Compare millisecond totals; defer to Python for foreign types."""
        if not isinstance(other, Timestamp):
            # Lets `ts == None` and similar evaluate to False instead of
            # failing on the missing `ms` attribute.
            return NotImplemented
        return self.ms == other.ms

    def __lt__(self, other):
        """Return True when this timestamp is strictly earlier than ``other``."""
        return self.ms < other.ms

    def __le__(self, other):
        """Return True when this timestamp is earlier than or equal to ``other``."""
        return self.ms <= other.ms
... your code ...
# Parse a sample .srt item and compare it with a hand-built expectation.
text = "21\n01:02:03,004 --> 05:06:07,008\nTest subtitle."
exp = SubtitleItem(
    index=21,
    start_time=Timestamp(hours=1, minutes=2, seconds=3, milliseconds=4),
    end_time=Timestamp(hours=5, minutes=6, seconds=7, milliseconds=8),
    text="Test subtitle.",
)
res = SubtitleItem.load_from_text_item(text)
# Print both items, then whether they compare equal, one per line.
print(res, exp, res == exp, sep="\n")
Run Code Online (Sandbox Code Playgroud)
输出:
21
1:2:3,4 --> 5:6:7,8
Test subtitle.
21
1:2:3,4 --> 5:6:7,8
Test subtitle.
True
Run Code Online (Sandbox Code Playgroud)
没有断言异常。