From b4afecd0f9565bac8cf2e82651b9e2a0b4ba9a87 Mon Sep 17 00:00:00 2001 From: Adrian Kummerlaender Date: Sun, 16 Jun 2019 14:28:53 +0200 Subject: Replace some explicit dimension branching --- simulation.py | 35 ++++++++++++----------------------- 1 file changed, 12 insertions(+), 23 deletions(-) (limited to 'simulation.py') diff --git a/simulation.py b/simulation.py index ba3945c..672003a 100644 --- a/simulation.py +++ b/simulation.py @@ -15,23 +15,16 @@ class Geometry: self.volume = size_x * size_y * size_z def inner_cells(self): - if self.size_z == 1: - for y in range(1,self.size_y-1): - for x in range(1,self.size_x-1): - yield x, y - else: - for z in range(1,self.size_z-1): - for y in range(1,self.size_y-1): - for x in range(1,self.size_x-1): - yield x, y, z + for idx in numpy.ndindex(self.inner_size()): + yield tuple(map(lambda i: i + 1, idx)) - def span(self): + def size(self): if self.size_z == 1: return (self.size_x, self.size_y) else: return (self.size_x, self.size_y, self.size_z) - def inner_span(self): + def inner_size(self): if self.size_z == 1: return (self.size_x-2, self.size_y-2) else: @@ -75,18 +68,14 @@ class Lattice: }.get((descriptor.d, descriptor.q), None) self.program.equilibrilize( - self.queue, self.geometry.span(), self.layout, self.cl_pop_a, self.cl_pop_b).wait() + self.queue, self.geometry.size(), self.layout, self.cl_pop_a, self.cl_pop_b).wait() - def idx(self, x, y, z = 0): + def gid(self, x, y, z = 0): return z * (self.geometry.size_x*self.geometry.size_y) + y * self.geometry.size_x + x; def setup_geometry(self, material_at): - if self.descriptor.d == 2: - for x, y in self.geometry.inner_cells(): - self.np_material[self.idx(x,y)] = material_at(self.geometry, x, y) - elif self.descriptor.d == 3: - for x, y, z in self.geometry.inner_cells(): - self.np_material[self.idx(x,y,z)] = material_at(self.geometry, x, y, z) + for idx in self.geometry.inner_cells(): + self.np_material[self.gid(*idx)] = material_at(self.geometry, *idx) cl.enqueue_copy(self.queue, self.cl_material, self.np_material).wait(); @@ -117,11 +106,11 @@ class Lattice: if self.tick: self.tick = False self.program.collide_and_stream( - self.queue, self.geometry.span(), self.layout, self.cl_pop_a, self.cl_pop_b, self.cl_material) + self.queue, self.geometry.size(), self.layout, self.cl_pop_a, self.cl_pop_b, self.cl_material) else: self.tick = True self.program.collide_and_stream( - self.queue, self.geometry.span(), self.layout, self.cl_pop_b, self.cl_pop_a, self.cl_material) + self.queue, self.geometry.size(), self.layout, self.cl_pop_b, self.cl_pop_a, self.cl_material) def sync(self): self.queue.finish() @@ -130,9 +119,9 @@ class Lattice: moments = numpy.ndarray(shape=(self.descriptor.d+1, self.geometry.volume), dtype=numpy.float32) if self.tick: self.program.collect_moments( - self.queue, self.geometry.span(), self.layout, self.cl_pop_b, self.cl_moments) + self.queue, self.geometry.size(), self.layout, self.cl_pop_b, self.cl_moments) else: self.program.collect_moments( - self.queue, self.geometry.span(), self.layout, self.cl_pop_a, self.cl_moments) + self.queue, self.geometry.size(), self.layout, self.cl_pop_a, self.cl_moments) cl.enqueue_copy(self.queue, moments, self.cl_moments).wait(); return moments -- cgit v1.2.3