aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--symbolic/generator.py6
-rw-r--r--symbolic/optimizations.py10
2 files changed, 14 insertions, 2 deletions
diff --git a/symbolic/generator.py b/symbolic/generator.py
index 161e2f4..7315ae0 100644
--- a/symbolic/generator.py
+++ b/symbolic/generator.py
@@ -1,6 +1,8 @@
from sympy import *
from sympy.codegen.ast import Assignment
+import symbolic.optimizations as optimizations
+
class LBM:
def __init__(self, descriptor):
self.descriptor = descriptor
@@ -18,7 +20,7 @@ class LBM:
Assignment(u_i, sum([ (c_j*self.f_curr[j])[i] for j, c_j in enumerate(self.descriptor.c) ]) / sum(self.f_curr)))
if optimize:
- return cse(exprs, optimizations='basic', symbols=numbered_symbols(prefix='m'))
+ return cse(exprs, optimizations=optimizations.custom, symbols=numbered_symbols(prefix='m'))
else:
return ([], exprs)
@@ -41,6 +43,6 @@ class LBM:
exprs = [ Assignment(self.f_next[i], self.f_curr[i] + 1/tau * (f_eq_i - self.f_curr[i])) for i, f_eq_i in enumerate(f_eq) ]
if optimize:
- return cse(exprs, optimizations='basic')
+ return cse(exprs, optimizations=optimizations.custom)
else:
return ([], exprs)
diff --git a/symbolic/optimizations.py b/symbolic/optimizations.py
new file mode 100644
index 0000000..93dad09
--- /dev/null
+++ b/symbolic/optimizations.py
@@ -0,0 +1,10 @@
+from sympy import *
+
+from sympy.codegen.rewriting import ReplaceOptim
+
+expand_square = ReplaceOptim(
+ lambda e: e.is_Pow and e.exp.is_integer and e.exp == 2,
+ lambda p: UnevaluatedExpr(Mul(p.base, p.base, evaluate = False))
+)
+
+custom = [ (expand_square, expand_square) ] + cse_main.basic_optimizations