从Python AST获取与给定名称的特定变量对应的所有节点

Ato*_*nal 3 python abstract-syntax-tree

考虑下面的代码:

1 | x = 20
2 | 
3 | def f():
4 |     x = 0
5 |     for x in range(10):
6 |         x += 10
7 |     return x
8 | f()
9 |
10| for x in range(10):
11|     pass
12| x += 1
13| print(x)
Run Code Online (Sandbox Code Playgroud)

的值x的代码的执行之后以上10。现在,我怎么能得到所有一流的节点Name,其ids为x并指x这是一个在线路1,10,12和13被使用?

换句话说,的x内部fxs 的其余部分不同。是否可以获取仅具有脚本和脚本的AST而不执行脚本的AST节点?

Mar*_*ers 5

在AST树上行走时,请跟踪上下文;与全球范围内启动,那么当你遇到FunctionDefClassDefLambda节点,记录这方面作为一个堆栈(退出相关节点时再弹出堆栈)。

然后,您可以仅查看Name全局上下文中的节点。您也可以跟踪global标识符(我会在每个堆栈级别使用一组)。

使用NodeVisitor子类

import ast

class GlobalUseCollector(ast.NodeVisitor):
    def __init__(self, name):
        self.name = name
        # track context name and set of names marked as `global`
        self.context = [('global', ())]

    def visit_FunctionDef(self, node):
        self.context.append(('function', set()))
        self.generic_visit(node)
        self.context.pop()

    # treat coroutines the same way
    visit_AsyncFunctionDef = visit_FunctionDef

    def visit_ClassDef(self, node):
        self.context.append(('class', ()))
        self.generic_visit(node)
        self.context.pop()

    def visit_Lambda(self, node):
        # lambdas are just functions, albeit with no statements, so no assignments
        self.context.append(('function', ()))
        self.generic_visit(node)
        self.context.pop()

    def visit_Global(self, node):
        assert self.context[-1][0] == 'function'
        self.context[-1][1].update(node.names)

    def visit_Name(self, node):
        ctx, g = self.context[-1]
        if node.id == self.name and (ctx == 'global' or node.id in g):
            print('{} used at line {}'.format(node.id, node.lineno))
Run Code Online (Sandbox Code Playgroud)

演示(为中的示例代码提供了AST树t):

>>> GlobalUseCollector('x').visit(t)
x used at line 1
x used at line 10
x used at line 12
x used at line 13
Run Code Online (Sandbox Code Playgroud)

global x在函数中使用:

>>> u = ast.parse('''\
... x = 20
...
... def g():
...     global x
...     x = 0
...     for x in range(10):
...         x += 10
...     return x
...
... g()
... for x in range(10):
...     pass
... x += 1
... print(x)
... ''')
>>> GlobalUseCollector('x').visit(u)
x used at line 1
x used at line 5
x used at line 6
x used at line 7
x used at line 8
x used at line 11
x used at line 13
x used at line 14
Run Code Online (Sandbox Code Playgroud)