aboutsummaryrefslogtreecommitdiff
path: root/boltzgen/kernel/generator.py
diff options
context:
space:
mode:
Diffstat (limited to 'boltzgen/kernel/generator.py')
-rw-r--r--boltzgen/kernel/generator.py12
1 files changed, 10 insertions, 2 deletions
diff --git a/boltzgen/kernel/generator.py b/boltzgen/kernel/generator.py
index 8da91ba..dd44a56 100644
--- a/boltzgen/kernel/generator.py
+++ b/boltzgen/kernel/generator.py
@@ -1,8 +1,14 @@
from mako.template import Template
+from mako.lookup import TemplateLookup
+
from pathlib import Path
from . import memory
+template_lookup = TemplateLookup(directories = [
+ Path(__file__).parent/"template"
+])
+
class Generator:
def __init__(self, model, target, precision, index, layout):
self.model = model
@@ -25,12 +31,13 @@ class Generator:
if not template_path.exists():
raise Exception("Target '%s' doesn't provide '%s'" % (self.target, template))
- return Template(filename = str(template_path)).render(
+ return Template(filename = str(template_path), lookup = template_lookup).render(
descriptor = self.descriptor,
model = self.model,
geometry = geometry,
index = self.index_impl(geometry),
layout = self.layout_impl(self.descriptor, self.index_impl, geometry),
+ streaming = 'AB',
float_type = self.float_type,
extras = extras
)
@@ -42,12 +49,13 @@ class Generator:
return "\n".join(map(lambda f: self.instantiate(f, geometry, extras), functions))
def custom(self, geometry, source, extras = []):
- return Template(text = source).render(
+ return Template(text = source, lookup = template_lookup).render(
descriptor = self.descriptor,
model = self.model,
geometry = geometry,
index = self.index_impl(geometry),
layout = self.layout_impl(self.descriptor, self.index_impl, geometry),
+ streaming = 'AB',
float_type = self.float_type,
extras = extras
)