如何在 Django ORM 中预取或子查询带有条件的深层嵌套对象

fab*_*hel 5 python django django-orm

有这些模型和关系:

Hours --FK--> Task --FK--> Project <--FK-- Period

class Hour(models.Model):
  date = models.DateField(...)
  task = models.ForeignKey(Task, ...)

class Task(models.Model):
  project = models.ForeignKey(Project, ...)
  
class Project(models.Model):
  pass

class Period(models.Model):
  project = models.ForeignKey(Project,...)
  start = models.DateField(...)
  end = models.DateField(...)


Summary :
 Hour has one task
 Task has one project
 Period has one project
 Hour has a date
 Period has a start date and a end date
Run Code Online (Sandbox Code Playgroud)

对于给定日期和给定项目,可能有一个时期或没有时期

我想以相同的方式填充对象period中的字段(使用查询集)Hourprefetch_related

我想要这样的东西:

hours = Hour.objects.prefetch_period().all()
hours.first().period # Period(...)
Run Code Online (Sandbox Code Playgroud)

使用像这样的自定义查询集方法:

class HourQuerySet(models.query.QuerySet):
  def prefetch_related(self):
    return ???
Run Code Online (Sandbox Code Playgroud)

目前我只能使用annotateand成功执行此操作Subquery,但我只能检索 period_id 而不是预取的 period :

def inject_period(self):
    period_qs = (
        Period.objects.filter(
            project__tasks=OuterRef("task"), start__lte=OuterRef("date"), end__gte=OuterRef("date")
        )
        .values("id")[:1]
    )
    return self.annotate(period_id=Subquery(period_qs))
Run Code Online (Sandbox Code Playgroud)

fab*_*hel 0

我找到了以下解决方案,但它有点老套,我不确定是否结束使用它。

我重写了 django 的内部_fetch_all方法,该方法QuerySet在查询集被触发时被调用。然后我进行自定义预取并设置实例的属性。

这可能需要一些进一步的优化。

class HourQuerySet(models.query.QuerySet):
  # With annotate and Subquery, search and define period_id
  def prefetch_period(self):
    period_qs = (
        Period.objects.filter(
            project__tasks=OuterRef("task"), start__lte=OuterRef("date"), end__gte=OuterRef("date")
        )
        .values("id")[:1]
    )
    return self.annotate(period_id=Subquery(period_qs))

  # Override _fetch_all method to manually prefetch and inject period in returned instances (for which who have a period_id defined)
  def _fetch_all(self):
    super()._fetch_all()
    if not self._result_cache or type(self._result_cache[0]) is dict:
        return
    period_ids = [r.period_id for r in self._result_cache]
    if not period_ids:
        return
    periods = {p.id: p for p in Period.objects.filter(id__in=period_ids)}
    for wh in self._result_cache:
        setattr(wh, "period", periods.get(wh.period_id))
Run Code Online (Sandbox Code Playgroud)