如何定义一个数据类,使其每个属性都是其子类属性的列表?

lho*_*ert 6 python python-dataclasses

我有这个代码:

from dataclasses import dataclass
from typing import List

@dataclass
class Position:
    name: str
    lon: float
    lat: float

@dataclass
class Section:
    positions: List[Position]

pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2 , pos3])

print(sec.positions)
Run Code Online (Sandbox Code Playgroud)

如何在数据类中创建其他属性,Section以便它们成为其子类的属性列表Position

在我的示例中,我希望 section 对象也返回:

sec.name = ['a', 'b', 'c']   #[pos1.name,pos2.name,pos3.name]
sec.lon = [52, 46, 45]       #[pos1.lon,pos2.lon,pos3.lon]
sec.lat = [10, -10, -10]     #[pos1.lat,pos2.lat,pos3.lat]
Run Code Online (Sandbox Code Playgroud)

我试图将数据类定义为:

@dataclass
class Section:
    positions: List[Position]
    names :  List[Position.name]
Run Code Online (Sandbox Code Playgroud)

但它不起作用,因为名称不是位置属性。我可以在代码中定义稍后属性的对象(例如通过执行secs.name = [x.name for x in section.positions])。但是如果可以在数据类定义级别完成它会更好。

发布这个问题后,我找到了答案的开始(/sf/answers/4565581051/)。

但我想知道是否没有更通用/“自动”的方式来定义 Section 方法: .names(), .lons(), .lats(), ... ?因此,开发人员不必单独定义每个方法,而是根据 Positions 对象属性创建这些方法?

Mau*_*yer 6

__init__您可以在调用后创建一个新字段:

from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]
    _pos: dict = field(init=False, repr=False)

    def __post_init__(self):
        # create _pos after init is done, read only!
        Section._pos = property(Section._get_positions)

    def _get_positions(self):
        _pos = {}

        # iterate over all fields and add to _pos
        for field in [f.name for f in fields(self.positions[0])]:
            if field not in _pos:
                _pos[field] = []

            for p in self.positions:
                _pos[field].append(getattr(p, field))
        return _pos


pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.positions)
print(sec._pos['name'])
print(sec._pos['lon'])
print(sec._pos['lat'])
Run Code Online (Sandbox Code Playgroud)

出去:

[Position(name='a', lon=52, lat=10), Position(name='b', lon=46, lat=-10), Position(name='c', lon=45, lat=-10)]
['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
Run Code Online (Sandbox Code Playgroud)

编辑

如果您只需要更通用,您可以覆盖__getattr__

from dataclasses import dataclass, field, fields
from typing import List


@dataclass
class Position:
    name: str
    lon: float
    lat: float


@dataclass
class Section:
    positions: List[Position]

    def __getattr__(self, keyName):
        for f in fields(self.positions[0]):
            if f"{f.name}s" == keyName:
                return [getattr(x, f.name) for x in self.positions]
        # Error handling here: Return empty list, raise AttributeError, ...

pos1 = Position('a', 52, 10)
pos2 = Position('b', 46, -10)
pos3 = Position('c', 45, -10)

sec = Section([pos1, pos2, pos3])

print(sec.names)
print(sec.lons)
print(sec.lats)
Run Code Online (Sandbox Code Playgroud)

出去:

['a', 'b', 'c']
[52, 46, 45]
[10, -10, -10]
Run Code Online (Sandbox Code Playgroud)