diff options
-rw-r--r-- | symbolic/generator.py | 6 | ||||
-rw-r--r-- | symbolic/optimizations.py | 10 |
2 files changed, 14 insertions, 2 deletions
diff --git a/symbolic/generator.py b/symbolic/generator.py index 161e2f4..7315ae0 100644 --- a/symbolic/generator.py +++ b/symbolic/generator.py @@ -1,6 +1,8 @@ from sympy import * from sympy.codegen.ast import Assignment +import symbolic.optimizations as optimizations + class LBM: def __init__(self, descriptor): self.descriptor = descriptor @@ -18,7 +20,7 @@ class LBM: Assignment(u_i, sum([ (c_j*self.f_curr[j])[i] for j, c_j in enumerate(self.descriptor.c) ]) / sum(self.f_curr))) if optimize: - return cse(exprs, optimizations='basic', symbols=numbered_symbols(prefix='m')) + return cse(exprs, optimizations=optimizations.custom, symbols=numbered_symbols(prefix='m')) else: return ([], exprs) @@ -41,6 +43,6 @@ class LBM: exprs = [ Assignment(self.f_next[i], self.f_curr[i] + 1/tau * (f_eq_i - self.f_curr[i])) for i, f_eq_i in enumerate(f_eq) ] if optimize: - return cse(exprs, optimizations='basic') + return cse(exprs, optimizations=optimizations.custom) else: return ([], exprs) diff --git a/symbolic/optimizations.py b/symbolic/optimizations.py new file mode 100644 index 0000000..93dad09 --- /dev/null +++ b/symbolic/optimizations.py @@ -0,0 +1,10 @@ +from sympy import * + +from sympy.codegen.rewriting import ReplaceOptim + +expand_square = ReplaceOptim( + lambda e: e.is_Pow and e.exp.is_integer and e.exp == 2, + lambda p: UnevaluatedExpr(Mul(p.base, p.base, evaluate = False)) +) + +custom = [ (expand_square, expand_square) ] + cse_main.basic_optimizations |