aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorAdrian Kummerlaender2019-06-11 20:19:25 +0200
committerAdrian Kummerlaender2019-06-11 20:19:25 +0200
commit3660fbbd71c8579b60c1e062f9c1d288253c0d04 (patch)
tree1dcc0d78aa1afa000fec783a353f3a65bcc56620
parent75d008822f143ec44cf1abbc6280e44c4bfb9146 (diff)
downloadsymlbm_playground-3660fbbd71c8579b60c1e062f9c1d288253c0d04.tar
symlbm_playground-3660fbbd71c8579b60c1e062f9c1d288253c0d04.tar.gz
symlbm_playground-3660fbbd71c8579b60c1e062f9c1d288253c0d04.tar.bz2
symlbm_playground-3660fbbd71c8579b60c1e062f9c1d288253c0d04.tar.lz
symlbm_playground-3660fbbd71c8579b60c1e062f9c1d288253c0d04.tar.xz
symlbm_playground-3660fbbd71c8579b60c1e062f9c1d288253c0d04.tar.zst
symlbm_playground-3660fbbd71c8579b60c1e062f9c1d288253c0d04.zip
Start to use codegen for actual kernel generation
-rw-r--r--codegen_lbm.py165
-rw-r--r--shell.nix1
2 files changed, 100 insertions, 66 deletions
diff --git a/codegen_lbm.py b/codegen_lbm.py
index d637d49..a115ef1 100644
--- a/codegen_lbm.py
+++ b/codegen_lbm.py
@@ -1,8 +1,6 @@
import pyopencl as cl
mf = cl.mem_flags
-from string import Template
-
import numpy
import time
@@ -10,6 +8,44 @@ import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('AGG')
+from sympy import *
+from sympy.codegen.ast import Assignment
+
+from mako.template import Template
+
+q = 9
+d = 2
+
+c = [ Matrix(x) for x in [(-1, 1), ( 0, 1), ( 1, 1), (-1, 0), ( 0, 0), ( 1, 0), (-1,-1), ( 0, -1), ( 1, -1)] ]
+w = [ Rational(*x) for x in [(1,36), (1,9), (1,36), (1,9), (4,9), (1,9), (1,36), (1,9), (1,36)] ]
+
+c_s = sqrt(Rational(1,3))
+
+rho, tau = symbols('rho tau')
+
+f_next = symarray('f_next', q)
+f_curr = symarray('f_curr', q)
+
+u = Matrix(symarray('u', d))
+
+moments = [ Assignment(rho, sum(f_curr)) ]
+
+for i, u_i in enumerate(u):
+ moments.append(Assignment(u_i, sum([ (c_j*f_curr[j])[i] for j, c_j in enumerate(c) ]) / sum(f_curr)))
+
+moments_opt = cse(moments, optimizations='basic', symbols=numbered_symbols(prefix='m'))
+
+f_eq = []
+
+for i, c_i in enumerate(c):
+ f_eq_i = w[i] * rho * ( 1
+ + c_i.dot(u) / c_s**2
+ + c_i.dot(u)**2 / (2*c_s**4)
+ - u.dot(u) / (2*c_s**2) )
+ f_eq.append(f_eq_i)
+
+collide = [ Assignment(f_next[i], f_curr[i] + 1/tau * ( f_eq_i - f_curr[i] )) for i, f_eq_i in enumerate(f_eq) ]
+collide_opt = cse(collide, optimizations='basic')
kernel = """
unsigned int indexOfDirection(int i, int j) {
@@ -18,11 +54,11 @@ unsigned int indexOfDirection(int i, int j) {
unsigned int indexOfCell(int x, int y)
{
- return y * $nX + x;
+ return y * ${nX} + x;
}
unsigned int idx(int x, int y, int i, int j) {
- return indexOfDirection(i,j)*$nCells + indexOfCell(x,y);
+ return indexOfDirection(i,j)*${nCells} + indexOfCell(x,y);
}
__global float f_i(__global __read_only float* f, int x, int y, int i, int j) {
@@ -53,48 +89,38 @@ __kernel void collide_and_stream(__global __write_only float* f_a,
const float f_curr_7 = f_i(f_b, cell.x , cell.y+1, 0,-1);
const float f_curr_8 = f_i(f_b, cell.x-1, cell.y+1, 1,-1);
- const float ux0 = f_curr_3 + f_curr_6;
- const float ux1 = f_curr_1 + f_curr_2;
- const float ux2 = 1.0/(f_curr_0 + f_curr_4 + f_curr_5 + f_curr_7 + f_curr_8 + ux0 + ux1);
- const float ux3 = f_curr_0 - f_curr_8;
+ const float tau = ${tau};
- float u_x = -ux2*(-f_curr_2 - f_curr_5 + ux0 + ux3);
- float u_y = ux2*(-f_curr_6 - f_curr_7 + ux1 + ux3);
+% for i, expr in enumerate(moments_helper):
+ const float ${expr[0]} = ${ccode(expr[1])};
+% endfor
+
+% for i, expr in enumerate(moments_assignment):
+ float ${ccode(expr)}
+% endfor
if ( m == 2 ) {
- u_x = 0.0;
- u_y = 0.0;
+ u_0 = 0.0;
+ u_1 = 0.0;
}
- const float x0 = f_curr_0 + f_curr_1 + f_curr_2 + f_curr_3 + f_curr_4 + f_curr_5 + f_curr_6 + f_curr_7 + f_curr_8;
- const float x1 = 6*u_y;
- const float x2 = 6*u_x;
- const float x3 = pow(u_y, 2);
- const float x4 = 3*x3;
- const float x5 = pow(u_x, 2);
- const float x6 = 3*x5;
- const float x7 = x6 - 2;
- const float x8 = x4 + x7;
- const float x9 = x2 + x8;
- const float x10 = 1.0/$tau;
- const float x11 = (1.0/72.0)*x10;
- const float x12 = 6*x3;
- const float x13 = x1 - x6 + 2;
- const float x14 = (1.0/18.0)*x10;
- const float x15 = -x4;
- const float x16 = 9*pow(u_x + u_y, 2);
- const float x17 = -x2;
- const float x18 = x15 + 6*x5 + 2;
-
- f_a[0*$nCells + gid] = f_curr_0 - x11*(72*f_curr_0 + x0*(-x1 + x9 - 9*pow(-u_x + u_y, 2)));
- f_a[1*$nCells + gid] = f_curr_1 - x14*(18*f_curr_1 - x0*(x12 + x13));
- f_a[2*$nCells + gid] = f_curr_2 - x11*(72*f_curr_2 - x0*(x13 + x15 + x16 + x2));
- f_a[3*$nCells + gid] = f_curr_3 - x14*(18*f_curr_3 - x0*(x17 + x18));
- f_a[4*$nCells + gid] = f_curr_4 - 1.0/9.0*x10*(9*f_curr_4 + 2*x0*x8);
- f_a[5*$nCells + gid] = f_curr_5 - x14*(18*f_curr_5 - x0*(x18 + x2));
- f_a[6*$nCells + gid] = f_curr_6 - x11*(72*f_curr_6 + x0*(x1 - x16 + x9));
- f_a[7*$nCells + gid] = f_curr_7 - x14*(18*f_curr_7 + x0*(x1 - x12 + x7));
- f_a[8*$nCells + gid] = f_curr_8 - x11*(72*f_curr_8 + x0*(x1 + x17 + x8 - 9*pow(u_x - u_y, 2)));
+% for i, expr in enumerate(collide_helper):
+ const float ${expr[0]} = ${ccode(expr[1])};
+% endfor
+
+% for i, expr in enumerate(collide_assignment):
+ const float ${ccode(expr)}
+% endfor
+
+ f_a[0*${nCells} + gid] = f_next_0;
+ f_a[1*${nCells} + gid] = f_next_1;
+ f_a[2*${nCells} + gid] = f_next_2;
+ f_a[3*${nCells} + gid] = f_next_3;
+ f_a[4*${nCells} + gid] = f_next_4;
+ f_a[5*${nCells} + gid] = f_next_5;
+ f_a[6*${nCells} + gid] = f_next_6;
+ f_a[7*${nCells} + gid] = f_next_7;
+ f_a[8*${nCells} + gid] = f_next_8;
}
__kernel void collect_moments(__global __read_only float* f,
@@ -104,25 +130,27 @@ __kernel void collect_moments(__global __read_only float* f,
const uint2 cell = (uint2)(get_global_id(0), get_global_id(1));
- const float f_curr_0 = f[0*$nCells + gid];
- const float f_curr_1 = f[1*$nCells + gid];
- const float f_curr_2 = f[2*$nCells + gid];
- const float f_curr_3 = f[3*$nCells + gid];
- const float f_curr_4 = f[4*$nCells + gid];
- const float f_curr_5 = f[5*$nCells + gid];
- const float f_curr_6 = f[6*$nCells + gid];
- const float f_curr_7 = f[7*$nCells + gid];
- const float f_curr_8 = f[8*$nCells + gid];
-
- const float ux0 = f_curr_3 + f_curr_6;
- const float ux1 = f_curr_1 + f_curr_2;
- const float ux2 = 1.0/(f_curr_0 + f_curr_4 + f_curr_5 + f_curr_7 + f_curr_8 + ux0 + ux1);
- const float ux3 = f_curr_0 - f_curr_8;
-
- moments[0*$nCells + gid] = f_curr_0 + ux1 + ux0 + f_curr_4 + f_curr_5 + f_curr_7 + f_curr_8;
- moments[1*$nCells + gid] = -ux2*(-f_curr_2 - f_curr_5 + ux0 + ux3);
- moments[2*$nCells + gid] = ux2*(-f_curr_6 - f_curr_7 + ux1 + ux3);
-
+ const float f_curr_0 = f[0*${nCells} + gid];
+ const float f_curr_1 = f[1*${nCells} + gid];
+ const float f_curr_2 = f[2*${nCells} + gid];
+ const float f_curr_3 = f[3*${nCells} + gid];
+ const float f_curr_4 = f[4*${nCells} + gid];
+ const float f_curr_5 = f[5*${nCells} + gid];
+ const float f_curr_6 = f[6*${nCells} + gid];
+ const float f_curr_7 = f[7*${nCells} + gid];
+ const float f_curr_8 = f[8*${nCells} + gid];
+
+% for i, expr in enumerate(moments_helper):
+ const float ${expr[0]} = ${ccode(expr[1])};
+% endfor
+
+% for i, expr in enumerate(moments_assignment):
+ const float ${ccode(expr)}
+% endfor
+
+ moments[0*${nCells} + gid] = rho;
+ moments[1*${nCells} + gid] = u_0;
+ moments[2*${nCells} + gid] = u_1;
}"""
@@ -193,12 +221,17 @@ class D2Q9_BGK_Lattice:
self.np_pop_b[:,self.idx(x,y)] = 1./24.
def build_kernel(self):
- self.program = cl.Program(self.context, Template(kernel).substitute({
- 'nX' : self.nX,
- 'nY' : self.nY,
- 'nCells': self.nCells,
- 'tau': '0.8f'
- })).build() #'-cl-single-precision-constant -cl-fast-relaxed-math')
+ self.program = cl.Program(self.context, Template(kernel).render(
+ nX = self.nX,
+ nY = self.nY,
+ nCells = self.nCells,
+ tau = '0.8f',
+ moments_helper = moments_opt[0],
+ moments_assignment = moments_opt[1],
+ collide_helper = collide_opt[0],
+ collide_assignment = collide_opt[1],
+ ccode = ccode
+ )).build()
def collect_moments(self):
if self.tick:
diff --git a/shell.nix b/shell.nix
index a5021ca..638b66c 100644
--- a/shell.nix
+++ b/shell.nix
@@ -27,6 +27,7 @@ pkgs.stdenvNoCC.mkDerivation rec {
pyopencl
pyopengl
matplotlib
+ Mako
]);
in [