aboutsummaryrefslogtreecommitdiff
path: root/boltzgen/kernel/template/collect_moments.cl.mako
diff options
context:
space:
mode:
Diffstat (limited to 'boltzgen/kernel/template/collect_moments.cl.mako')
-rw-r--r--boltzgen/kernel/template/collect_moments.cl.mako26
1 files changed, 8 insertions, 18 deletions
diff --git a/boltzgen/kernel/template/collect_moments.cl.mako b/boltzgen/kernel/template/collect_moments.cl.mako
index 39317e3..8adf295 100644
--- a/boltzgen/kernel/template/collect_moments.cl.mako
+++ b/boltzgen/kernel/template/collect_moments.cl.mako
@@ -1,36 +1,26 @@
+<%namespace name="pattern" file="${'/pattern/%s.cl.mako' % context['streaming']}"/>
<%
import sympy
+moments_subexpr, moments_assignment = model.moments()
%>
-__kernel void collect_moments_gid(__global ${float_type}* f,
- __global ${float_type}* m,
- unsigned int gid)
-{
- __global ${float_type}* preshifted_f = f + ${layout.cell_preshift('gid')};
- __global ${float_type}* preshifted_m = m + gid*${descriptor.d+1};
-
-% for i in range(0,descriptor.q):
- const ${float_type} f_curr_${i} = preshifted_f[${layout.pop_offset(i)}];
-% endfor
-
-<%
- moments_subexpr, moments_assignment = model.moments()
-%>
-
+<%call expr="pattern.functor_ab('collect_moments', [('__global %s*' % float_type, 'm')])">
% for i, expr in enumerate(moments_subexpr):
const ${float_type} ${expr[0]} = ${sympy.ccode(expr[1])};
% endfor
+ __global ${float_type}* preshifted_m = m + gid*${descriptor.d+1};
+
% for i, expr in enumerate(moments_assignment):
preshifted_m[${i}] = ${sympy.ccode(expr.rhs)};
% endfor
-}
+</%call>
% if 'cell_list_dispatch' in extras:
__kernel void collect_moments_cells(__global ${float_type}* f,
- __global ${float_type}* moments,
+ __global ${float_type}* m,
__global unsigned int* cells)
{
- collect_moments_gid(f, moments, cells[get_global_id(0)]);
+ collect_moments(f, cells[get_global_id(0)], m);
}
% endif