5 python code-generation sympy codegen python-3.x
我正在使用公共子表达式消除 (CSE) 例程和 ccode 打印机使用 sympy 生成 C 代码。
但是,我希望将幂表达式设置为 (x*x) 而不是 pow(x,2)。
无论如何要这样做吗?
例子:
import sympy as sp
a= sp.MatrixSymbol('a',3,3)
b=sp.Matrix(a)*sp.Matrix(a)
res = sp.cse(b)
lines = []
for tmp in res[0]:
lines.append(sp.ccode(tmp[1], tmp[0]))
for i,result in enumerate(res[1]):
lines.append(sp.ccode(result,"result_%i"%i))
Run Code Online (Sandbox Code Playgroud)
将输出:
x0[0] = a[0];
x0[1] = a[1];
x0[2] = a[2];
x0[3] = a[3];
x0[4] = a[4];
x0[5] = a[5];
x0[6] = a[6];
x0[7] = a[7];
x0[8] = a[8];
x1 = x0[0];
x2 = x0[1];
x3 = x0[3];
x4 = x2*x3;
x5 = x0[2];
x6 = x0[6];
x7 = x5*x6;
x8 = x0[4];
x9 = x0[7];
x10 = x0[5];
x11 = x0[8];
x12 = x10*x9;
result_0[0] = pow(x1, 2) + x4 + x7;
result_0[1] = x1*x2 + x2*x8 + x5*x9;
result_0[2] = x1*x5 + x10*x2 + x11*x5;
result_0[3] = x1*x3 + x10*x6 + x3*x8;
result_0[4] = x12 + x4 + pow(x8, 2);
result_0[5] = x10*x11 + x10*x8 + x3*x5;
result_0[6] = x1*x6 + x11*x6 + x3*x9;
result_0[7] = x11*x9 + x2*x6 + x8*x9;
result_0[8] = pow(x11, 2) + x12 + x7;
Run Code Online (Sandbox Code Playgroud)
此致
您可以对代码打印机进行子类化,并且仅更改您想要不同的一个功能。您需要研究原始的 sympy 代码以找到正确的函数名称和默认实现,这样您就可以确保不会犯错误。只要稍加小心,所需的括号就可以在需要的时间和地点自动生成。
这是一个最小的例子:
import sympy as sp
from sympy.printing.c import C99CodePrinter
from sympy.printing.precedence import precedence
from sympy.abc import x
class CustomCodePrinter(C99CodePrinter):
def _print_Pow(self, expr):
PREC = precedence(expr)
if expr.exp == 2:
return '({0} * {0})'.format(self.parenthesize(expr.base, PREC))
else:
return super()._print_Pow(expr)
default_printer = C99CodePrinter().doprint
custom_printer = CustomCodePrinter().doprint
expressions = [x, (2 + x) ** 2, x ** 3, x ** 15, sp.sqrt(5), sp.sqrt(x)**4, 1 / x, 1 / (x * x)]
print("Default: {}".format(default_printer(expressions)))
print("Custom: {}".format(custom_printer(expressions)))
Run Code Online (Sandbox Code Playgroud)
输出:
Default: [x, pow(x + 2, 2), pow(x, 3), pow(x, 15), sqrt(5), pow(x, 2), 1.0/x, pow(x, -2)]
Custom: [x, ((x + 2) * (x + 2)), pow(x, 3), pow(x, 15), sqrt(5), (x * x), 1.0/x, pow(x, -2)]
Run Code Online (Sandbox Code Playgroud)
PS:为了支持更广泛的指数,您可以使用例如
class CustomCodePrinter(C99CodePrinter):
def _print_Pow(self, expr):
PREC = precedence(expr)
if expr.exp in range(2, 7):
return '*'.join([self.parenthesize(expr.base, PREC)] * int(expr.exp))
elif expr.exp in range(-6, 0):
return '1.0/(' + ('*'.join([self.parenthesize(expr.base, PREC)] * int(-expr.exp))) + ')'
else:
return super()._print_Pow(expr)
Run Code Online (Sandbox Code Playgroud)
| 归档时间: |
|
| 查看次数: |
1235 次 |
| 最近记录: |