diff options
-rw-r--r-- | codegen_lbm.py | 32 |
1 files changed, 8 insertions, 24 deletions
diff --git a/codegen_lbm.py b/codegen_lbm.py index c1ccd00..3e1865c 100644 --- a/codegen_lbm.py +++ b/codegen_lbm.py @@ -50,30 +50,16 @@ collide_opt = cse(collide, optimizations='basic') kernel = """ __constant float tau = ${tau}; -unsigned int indexOfDirection(int i, int j) { - return (i+1) + 3*(1-j); -} - -unsigned int indexOfCell(int x, int y) -{ - return y * ${nX} + x; -} - -unsigned int idx(int x, int y, int i, int j) { - return indexOfDirection(i,j)*${nCells} + indexOfCell(x,y); -} - -__global float f_i(__global __read_only float* f, int x, int y, int i, int j) { - return f[idx(x,y,i,j)]; -} +<% +def direction_index(c_i): + return (c_i[0]+1) + 3*(1-c_i[1]) +%> __kernel void collide_and_stream(__global __write_only float* f_a, __global __read_only float* f_b, __global __read_only int* material) { - const unsigned int gid = indexOfCell(get_global_id(0), get_global_id(1)); - - const uint2 cell = (uint2)(get_global_id(0), get_global_id(1)); + const unsigned int gid = get_global_id(1)*${nX} + get_global_id(0); const int m = material[gid]; @@ -82,7 +68,7 @@ __kernel void collide_and_stream(__global __write_only float* f_a, } % for i, c_i in enumerate(c): - const float f_curr_${i} = f_i(f_b, cell.x-(${c_i[0]}), cell.y-(${c_i[1]}), ${c_i[0]}, ${c_i[1]}); + const float f_curr_${i} = f_b[${direction_index(c_i)*nCells}u + (get_global_id(1)-(${c_i[1]}))*${nX} + get_global_id(0)-(${c_i[0]})]; % endfor % for i, expr in enumerate(moments_helper): @@ -114,12 +100,10 @@ __kernel void collide_and_stream(__global __write_only float* f_a, __kernel void collect_moments(__global __read_only float* f, __global __write_only float* moments) { - const unsigned int gid = indexOfCell(get_global_id(0), get_global_id(1)); - - const uint2 cell = (uint2)(get_global_id(0), get_global_id(1)); + const unsigned int gid = get_global_id(1)*${nX} + get_global_id(0); % for i in range(0,len(c)): - const float f_curr_${i} = f[${i*nCells} + gid]; + const float f_curr_${i} = f[${i*nCells}u + gid]; % endfor % for i, expr in enumerate(moments_helper): |