检测上下文管理器嵌套

Qoo*_*ooS 3 python nested with-statement contextmanager

我最近一直在想是否有一种方法可以检测上下文管理器是否嵌套。

我创建了Timer和TimerGroup类:

class Timer:
    def __init__(self, name="Timer"):
        self.name = name
        self.start_time = clock()

    @staticmethod
    def seconds_to_str(t):
        return str(timedelta(seconds=t))

    def end(self):
        return clock() - self.start_time

    def print(self, t):
        print(("{0:<" + str(line_width - 18) + "} >> {1}").format(self.name, self.seconds_to_str(t)))

    def __enter__(self):
        return self

    def __exit__(self, exc_type, value, traceback):
        self.print(self.end())


class TimerGroup(Timer):
    def __enter__(self):
        print(('= ' + self.name + ' ').ljust(line_width, '='))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        total_time = self.seconds_to_str(self.end())
        print(" Total: {0}".format(total_time).rjust(line_width, '='))
        print()
Run Code Online (Sandbox Code Playgroud)

此代码以可读格式打印时间:

with TimerGroup("Collecting child documents for %s context" % context_name):
    with Timer("Collecting context features"):
        # some code...
    with Timer("Collecting child documents"):
        # some code...


= Collecting child documents for Global context ============
Collecting context features                >> 0:00:00.001063
Collecting child documents                 >> 0:00:10.611130
====================================== Total: 0:00:10.612292
Run Code Online (Sandbox Code Playgroud)

但是,当我嵌套TimerGroups时,它搞砸了:

with TimerGroup("Choosing the best classifier for %s context" % context_name):
    with Timer("Splitting datasets"):
        # some code...
    for cname, cparams in classifiers.items():
        with TimerGroup("%s classifier" % cname):
            with Timer("Training"):
                # some code...
            with Timer("Calculating accuracy on testing set"):
                # some code


= Choosing the best classifier for Global context ==========
Splitting datasets                         >> 0:00:00.002054
= Naive Bayes classifier ===================================
Training                                   >> 0:00:34.184903
Calculating accuracy on testing set        >> 0:05:08.481904
====================================== Total: 0:05:42.666949

====================================== Total: 0:05:42.669078
Run Code Online (Sandbox Code Playgroud)

我要做的就是以某种方式缩进嵌套的Timers和TimerGroups。我应该将任何参数传递给它们的构造函数吗?还是可以从班级内部检测到?

Mar*_*ers 5

没有检测嵌套上下文管理器的特殊功能。您必须自己处理。您可以在自己的上下文管理器中执行此操作:

import threading


class TimerGroup(Timer):
    _active_group = threading.local()

    def __enter__(self):
        if getattr(TimerGroup._active_group, 'current', False):
            raise RuntimeError("Can't nest TimerGroup context managers")
        TimerGroup._active_group.current = self
        print(('= ' + self.name + ' ').ljust(line_width, '='))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        TimerGroup._active_group.current = None
        total_time = self.seconds_to_str(self.end())
        print(" Total: {0}".format(total_time).rjust(line_width, '='))
        print()
Run Code Online (Sandbox Code Playgroud)

然后,您可以在TimerGroup._active_group其他位置使用该属性来获取当前活动的组。我使用了一个线程局部对象来确保可以在多个执行线程之间使用。

另外,您可以使一个堆栈计数器仅在嵌套__enter__调用中递增和递减,或者使一个堆栈列表并推入self该堆栈,并在您执行__exit__以下操作时再次弹出:

import threading


class TimerGroup(Timer):
    _active_group = threading.local()

    def __enter__(self):
        if not hasattr(TimerGroup._active_group, 'current'):
            TimerGroup._active_group.current = []
        stack = TimerGroup._active_group.current
        if stack:
            # nested context manager.
            # do something with stack[-1] or stack[0]
        TimerGroup._active_group.current.append(self)

        print(('= ' + self.name + ' ').ljust(line_width, '='))
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        last = TimerGroup._active_group.current.pop()
        assert last == self, "Context managers being exited out of order"
        total_time = self.seconds_to_str(self.end())
        print(" Total: {0}".format(total_time).rjust(line_width, '='))
        print()
Run Code Online (Sandbox Code Playgroud)