如何在python3中使用AST递归地简化数学表达式?

Gui*_*sse 4 recursion abstract-syntax-tree python-3.x

我有这个数学表达式:

\n
tree = ast.parse(\'1 + 2 + 3 + x\')\n
Run Code Online (Sandbox Code Playgroud)\n

对应于这个抽象语法树:

\n
Module(body=[Expr(value=BinOp(left=BinOp(left=BinOp(left=Num(n=1), op=Add(), right=Num(n=2)), op=Add(), right=Num(n=3)), op=Add(), right=Name(id=\'x\', ctx=Load())))])\n
Run Code Online (Sandbox Code Playgroud)\n

我想简化它 - 也就是说,得到这个:

\n
Module(body=[Expr(value=BinOp(left=Num(n=6), op=Add(), right=Name(id=\'x\', ctx=Load())))])\n
Run Code Online (Sandbox Code Playgroud)\n

根据文档,我应该使用 NodeTransformer 类。文档中的建议如下:

\n
\n

请记住,如果您正在操作的节点\xe2\x80\x99 有子节点,您必须自己转换子节点或首先调用该节点的generic_visit() 方法。

\n
\n

我尝试实现我自己的变压器:

\n
class Evaluator(ast.NodeTransformer):\n    def visit_BinOp(self, node):\n        print(\'Evaluating \', ast.dump(node))\n        for child in ast.iter_child_nodes(node):\n            self.visit(child)\n\n        if type(node.left) == ast.Num and type(node.right) == ast.Num:\n            print(ast.literal_eval(node))\n            return ast.copy_location(ast.Subscript(value=ast.literal_eval(node)), node)\n        else:\n            return node\n
Run Code Online (Sandbox Code Playgroud)\n

在这种特定情况下,它应该做的是将 1+2 简化为 3,然后将 3 +3 简化为 6。\n它确实简化了我想要简化的二进制运算,但它不会更新原始语法树。我尝试了不同的方法,但我仍然不明白如何递归地简化所有二进制操作(以深度优先的方式)。有人能指出我正确的方向吗?

\n

谢谢。

\n

a_g*_*est 5

这些方法有三种可能的返回值visit_*

  1. None这意味着该节点将被删除,
  2. node(节点本身)这意味着不会应用任何更改,
  3. 一个新节点将取代旧节点。

因此,当您想用 a 替换 时,BinOpNum需要返回一个新Num节点。表达式的求值不能通过 来完成,ast.literal_eval因为该函数仅求值文字(而不是任意表达式)。相反,您可以使用eval例如。

因此,您可以使用以下节点转换器类:

import ast

class Evaluator(ast.NodeTransformer):
    ops = {
        ast.Add: '+',
        ast.Sub: '-',
        ast.Mult: '*',
        ast.Div: '/',
        # define more here
    }

    def visit_BinOp(self, node):
        self.generic_visit(node)
        if isinstance(node.left, ast.Num) and isinstance(node.right, ast.Num):
            # On Python <= 3.6 you can use ast.literal_eval.
            # value = ast.literal_eval(node)
            value = eval(f'{node.left.n} {self.ops[type(node.op)]} {node.right.n}')
            return ast.Num(n=value)
        return node

tree = ast.parse('1 + 2 + 3 + x')
tree = ast.fix_missing_locations(Evaluator().visit(tree))
print(ast.dump(tree))
Run Code Online (Sandbox Code Playgroud)