aboutsummaryrefslogtreecommitdiff
path: root/boltzgas/simulation.py
diff options
context:
space:
mode:
Diffstat (limited to 'boltzgas/simulation.py')
-rw-r--r--boltzgas/simulation.py43
1 files changed, 21 insertions, 22 deletions
diff --git a/boltzgas/simulation.py b/boltzgas/simulation.py
index 6af9069..d818bab 100644
--- a/boltzgas/simulation.py
+++ b/boltzgas/simulation.py
@@ -25,7 +25,27 @@ def build_kernel(delta_t, n_particles, radius):
radius = radius)
class HardSphereSimulation:
- def setup_cl(self):
+ def __init__(self, setup, opengl = False, t_scale = 1.0):
+ self.np_particle_position = setup.position.astype(np.float32)
+ self.np_particle_velocity = setup.velocity.astype(np.float32)
+
+ self.n_particles = setup.n_particles
+ self.radius = setup.radius
+ self.char_u = setup.char_u
+
+ self.opengl = opengl
+ self.t_scale = t_scale
+
+ self.np_last_collide = np.ndarray((self.n_particles, 1), dtype=np.uint32)
+ self.np_last_collide[:,0] = self.n_particles
+
+ self.np_particle_velocity_norms = np.ndarray((self.n_particles, 1), dtype=np.float32)
+
+ self.kernel_src = build_kernel(self.t_scale*self.radius/self.char_u, self.n_particles, self.radius)
+
+ self.tick = True
+
+ def setup(self):
self.platform = cl.get_platforms()[0]
if self.opengl:
self.context = cl.Context(
@@ -56,27 +76,6 @@ class HardSphereSimulation:
self.cl_last_collide = cl.Buffer(self.context, mf.COPY_HOST_PTR, hostbuf=self.np_last_collide)
self.cl_particle_velocity_norms = cl.Buffer(self.context, mf.COPY_HOST_PTR, hostbuf=self.np_particle_velocity_norms)
- def __init__(self, setup, opengl = False, t_scale = 1.0):
- self.np_particle_position = setup.position.astype(np.float32)
- self.np_particle_velocity = setup.velocity.astype(np.float32)
-
- self.n_particles = setup.n_particles
- self.radius = setup.radius
- self.char_u = setup.char_u
-
- self.opengl = opengl
- self.t_scale = t_scale
-
- self.np_last_collide = np.ndarray((self.n_particles, 1), dtype=np.uint32)
- self.np_last_collide[:,0] = self.n_particles
-
- self.np_particle_velocity_norms = np.ndarray((self.n_particles, 1), dtype=np.float32)
-
- self.kernel_src = build_kernel(self.t_scale*self.radius/self.char_u, self.n_particles, self.radius)
-
- self.setup_cl()
-
- self.tick = True
def evolve(self):
if self.opengl: