aboutsummaryrefslogtreecommitdiff
path: root/codegen_lbm.py
diff options
context:
space:
mode:
Diffstat (limited to 'codegen_lbm.py')
-rw-r--r--codegen_lbm.py32
1 files changed, 8 insertions, 24 deletions
diff --git a/codegen_lbm.py b/codegen_lbm.py
index c1ccd00..3e1865c 100644
--- a/codegen_lbm.py
+++ b/codegen_lbm.py
@@ -50,30 +50,16 @@ collide_opt = cse(collide, optimizations='basic')
kernel = """
__constant float tau = ${tau};
-unsigned int indexOfDirection(int i, int j) {
- return (i+1) + 3*(1-j);
-}
-
-unsigned int indexOfCell(int x, int y)
-{
- return y * ${nX} + x;
-}
-
-unsigned int idx(int x, int y, int i, int j) {
- return indexOfDirection(i,j)*${nCells} + indexOfCell(x,y);
-}
-
-__global float f_i(__global __read_only float* f, int x, int y, int i, int j) {
- return f[idx(x,y,i,j)];
-}
+<%
+def direction_index(c_i):
+ return (c_i[0]+1) + 3*(1-c_i[1])
+%>
__kernel void collide_and_stream(__global __write_only float* f_a,
__global __read_only float* f_b,
__global __read_only int* material)
{
- const unsigned int gid = indexOfCell(get_global_id(0), get_global_id(1));
-
- const uint2 cell = (uint2)(get_global_id(0), get_global_id(1));
+ const unsigned int gid = get_global_id(1)*${nX} + get_global_id(0);
const int m = material[gid];
@@ -82,7 +68,7 @@ __kernel void collide_and_stream(__global __write_only float* f_a,
}
% for i, c_i in enumerate(c):
- const float f_curr_${i} = f_i(f_b, cell.x-(${c_i[0]}), cell.y-(${c_i[1]}), ${c_i[0]}, ${c_i[1]});
+ const float f_curr_${i} = f_b[${direction_index(c_i)*nCells}u + (get_global_id(1)-(${c_i[1]}))*${nX} + get_global_id(0)-(${c_i[0]})];
% endfor
% for i, expr in enumerate(moments_helper):
@@ -114,12 +100,10 @@ __kernel void collide_and_stream(__global __write_only float* f_a,
__kernel void collect_moments(__global __read_only float* f,
__global __write_only float* moments)
{
- const unsigned int gid = indexOfCell(get_global_id(0), get_global_id(1));
-
- const uint2 cell = (uint2)(get_global_id(0), get_global_id(1));
+ const unsigned int gid = get_global_id(1)*${nX} + get_global_id(0);
% for i in range(0,len(c)):
- const float f_curr_${i} = f[${i*nCells} + gid];
+ const float f_curr_${i} = f[${i*nCells}u + gid];
% endfor
% for i, expr in enumerate(moments_helper):