diff options
Diffstat (limited to 'boltzgas/simulation.py')
-rw-r--r-- | boltzgas/simulation.py | 43 |
1 files changed, 21 insertions, 22 deletions
diff --git a/boltzgas/simulation.py b/boltzgas/simulation.py index 6af9069..d818bab 100644 --- a/boltzgas/simulation.py +++ b/boltzgas/simulation.py @@ -25,7 +25,27 @@ def build_kernel(delta_t, n_particles, radius): radius = radius) class HardSphereSimulation: - def setup_cl(self): + def __init__(self, setup, opengl = False, t_scale = 1.0): + self.np_particle_position = setup.position.astype(np.float32) + self.np_particle_velocity = setup.velocity.astype(np.float32) + + self.n_particles = setup.n_particles + self.radius = setup.radius + self.char_u = setup.char_u + + self.opengl = opengl + self.t_scale = t_scale + + self.np_last_collide = np.ndarray((self.n_particles, 1), dtype=np.uint32) + self.np_last_collide[:,0] = self.n_particles + + self.np_particle_velocity_norms = np.ndarray((self.n_particles, 1), dtype=np.float32) + + self.kernel_src = build_kernel(self.t_scale*self.radius/self.char_u, self.n_particles, self.radius) + + self.tick = True + + def setup(self): self.platform = cl.get_platforms()[0] if self.opengl: self.context = cl.Context( @@ -56,27 +76,6 @@ class HardSphereSimulation: self.cl_last_collide = cl.Buffer(self.context, mf.COPY_HOST_PTR, hostbuf=self.np_last_collide) self.cl_particle_velocity_norms = cl.Buffer(self.context, mf.COPY_HOST_PTR, hostbuf=self.np_particle_velocity_norms) - def __init__(self, setup, opengl = False, t_scale = 1.0): - self.np_particle_position = setup.position.astype(np.float32) - self.np_particle_velocity = setup.velocity.astype(np.float32) - - self.n_particles = setup.n_particles - self.radius = setup.radius - self.char_u = setup.char_u - - self.opengl = opengl - self.t_scale = t_scale - - self.np_last_collide = np.ndarray((self.n_particles, 1), dtype=np.uint32) - self.np_last_collide[:,0] = self.n_particles - - self.np_particle_velocity_norms = np.ndarray((self.n_particles, 1), dtype=np.float32) - - self.kernel_src = build_kernel(self.t_scale*self.radius/self.char_u, self.n_particles, self.radius) - - self.setup_cl() - - self.tick = True def evolve(self): if self.opengl: |