diff --git a/pattern_library/loader_tags.py b/pattern_library/loader_tags.py index 2efb6cfb..4cc6a9b8 100644 --- a/pattern_library/loader_tags.py +++ b/pattern_library/loader_tags.py @@ -150,3 +150,50 @@ def do_include(parser, token): extra_context=namemap, isolated_context=isolated_context, ) + +def visit_extends(self, node, frame): + """This method overrides the jinja extends tag + Is called as part of the compiler CodeGenerator + and adds a line to use the template_new_context as + part of the runtime render to pull in the dpl context + Handles visiting extends + """ + from .monkey_utils import jinja_visit_Extends + + jinja_visit_Extends(self, node, frame) + # addition to update the context with dpl context + # calls the template_new_context method below when + # invoked at runtime + self.writeline( + "parent_template.new_context(context.get_all(), True," + f" {self.dump_local_context(frame)})" + ) + + +def template_new_context( + self, + vars=None, + shared=False, + locals=None, +): + """This method overrides the jinja include tag + Is called as part of Template.render by jinja2 and is updated + to pull in the dpl context + Create a new :class:`Context` for this template. The vars + provided will be passed to the template. Per default the globals + are added to the context. If shared is set to `True` the data + is passed as is to the context without adding the globals. + + `locals` can be a dict of local variables for internal usage. + """ + from jinja2.runtime import new_context + + if is_pattern_library_context(vars or {}) and ( + pattern_context := get_pattern_context(self.name) + ): + for k, v in pattern_context.items(): + vars.setdefault(k, v) + + return new_context( + self.environment, self.name, self.blocks, vars, shared, self.globals, locals + ) \ No newline at end of file diff --git a/pattern_library/management/commands/render_patterns.py b/pattern_library/management/commands/render_patterns.py index 482ae902..7f3caeed 100644 --- a/pattern_library/management/commands/render_patterns.py +++ b/pattern_library/management/commands/render_patterns.py @@ -7,9 +7,8 @@ from pattern_library import get_base_template_names, get_pattern_base_template_name from pattern_library.utils import ( get_pattern_context, - get_pattern_templates, - get_template_ancestors, render_pattern, + get_renderer, ) @@ -44,7 +43,8 @@ def handle(self, **options): self.wrap_fragments = options["wrap_fragments"] self.output_dir = options["output_dir"] - templates = get_pattern_templates() + renderer = get_renderer() + templates = renderer.get_pattern_templates() factory = RequestFactory() request = factory.get("/") @@ -106,7 +106,8 @@ def render_pattern(self, request, pattern_template_name): if not self.wrap_fragments: return rendered_pattern - pattern_template_ancestors = get_template_ancestors( + renderer = get_renderer() + pattern_template_ancestors = renderer.get_template_ancestors( pattern_template_name, context=get_pattern_context(pattern_template_name), ) diff --git a/pattern_library/monkey_utils.py b/pattern_library/monkey_utils.py index 956e638b..7dbc659f 100644 --- a/pattern_library/monkey_utils.py +++ b/pattern_library/monkey_utils.py @@ -96,3 +96,24 @@ def node_render(context): return original_node return tag_func + +# have to export the original jinja visit Extends +# in the case jinja tags are being overriden +jinja_visit_Extends = None + +def override_jinja_tags(): + """ + Overrides jinja extends and include tags for use in your pattern library. + Call it in your settings to override tags + """ + global jinja_visit_Extends + try: + from jinja2.compiler import CodeGenerator as JinjaCodeGenerator + from jinja2.environment import Template as JinjaTemplate + except ModuleNotFoundError: + ModuleNotFoundError("install jinja2 to override jinja tags") + + from .loader_tags import template_new_context, visit_extends + jinja_visit_Extends = JinjaCodeGenerator.visit_Extends + JinjaTemplate.new_context = template_new_context + JinjaCodeGenerator.visit_Extends = visit_extends \ No newline at end of file diff --git a/pattern_library/utils.py b/pattern_library/utils.py index b1a81b8d..7230b766 100644 --- a/pattern_library/utils.py +++ b/pattern_library/utils.py @@ -22,6 +22,9 @@ from pattern_library.exceptions import TemplateIsNotPattern + +from django.utils.html import escape + def path_to_section(): section_config = get_sections() sections = {} @@ -79,77 +82,6 @@ def get_template_dirs(): return template_dirs -def get_pattern_templates(): - templates = base_dict() - template_dirs = get_template_dirs() - - for lookup_dir in template_dirs: - for root, dirs, files in os.walk(lookup_dir, topdown=True): - # Ignore folders without files - if not files: - continue - - base_path = os.path.relpath(root, lookup_dir) - section, path = section_for(base_path) - - # It has no section, ignore it - if not section: - continue - - found_templates = [] - for current_file in files: - pattern_path = os.path.join(root, current_file) - pattern_path = os.path.relpath(pattern_path, lookup_dir) - - if is_pattern(pattern_path): - template = get_template(pattern_path) - pattern_config = get_pattern_config(pattern_path) - pattern_name = pattern_config.get("name") - pattern_filename = os.path.relpath( - template.origin.template_name, - base_path, - ) - if pattern_name: - template.pattern_name = pattern_name - else: - template.pattern_name = pattern_filename - - template.pattern_filename = pattern_filename - - found_templates.append(template) - - if found_templates: - lookup_dir_relpath = os.path.relpath(root, lookup_dir) - sub_folders = os.path.relpath(lookup_dir_relpath, path) - templates_to_store = templates - for folder in [section, *sub_folders.split(os.sep)]: - try: - templates_to_store = templates_to_store["template_groups"][ - folder - ] - except KeyError: - templates_to_store["template_groups"][folder] = base_dict() - templates_to_store = templates_to_store["template_groups"][ - folder - ] - - templates_to_store["templates_stored"].extend(found_templates) - - # Order the templates alphabetically - for templates_objs in templates["template_groups"].values(): - templates_objs["template_groups"] = order_dict( - templates_objs["template_groups"] - ) - - # Order the top level by the sections - section_order = [section for section, _ in get_sections()] - templates["template_groups"] = order_dict( - templates["template_groups"], key_sort=lambda key: section_order.index(key) - ) - - return templates - - def get_pattern_config_str(template_name): replace_pattern = "{}$".format(get_pattern_template_suffix()) context_path = re.sub(replace_pattern, "", template_name) @@ -227,27 +159,159 @@ def render_pattern(request, template_name, allow_non_patterns=False, config=None return render_to_string(template_name, request=request, context=context) -def get_template_ancestors(template_name, context=None, ancestors=None): - """ - Returns a list of template names, starting with provided name - and followed by the names of any templates that extends until - the most extended template is reached. - """ - if ancestors is None: - ancestors = [template_name] - - if context is None: - context = Context() +def get_renderer(): + return TemplateRenderer + + +class TemplateRenderer: + @classmethod + def get_pattern_templates(cls): + templates = base_dict() + template_dirs = get_template_dirs() + + for lookup_dir in template_dirs: + for root, dirs, files in os.walk(lookup_dir, topdown=True): + # Ignore folders without files + if not files: + continue + + base_path = os.path.relpath(root, lookup_dir) + section, path = section_for(base_path) + + # It has no section, ignore it + if not section: + continue + + found_templates = [] + for current_file in files: + pattern_path = os.path.join(root, current_file) + pattern_path = os.path.relpath(pattern_path, lookup_dir) + + if is_pattern(pattern_path): + template = get_template(pattern_path) + pattern_config = get_pattern_config(pattern_path) + pattern_name = pattern_config.get("name") + pattern_filename = os.path.relpath( + template.origin.template_name, + base_path, + ) + if pattern_name: + template.pattern_name = pattern_name + else: + template.pattern_name = pattern_filename + + template.pattern_filename = pattern_filename + + found_templates.append(template) + + if found_templates: + lookup_dir_relpath = os.path.relpath(root, lookup_dir) + sub_folders = os.path.relpath(lookup_dir_relpath, path) + templates_to_store = templates + for folder in [section, *sub_folders.split(os.sep)]: + try: + templates_to_store = templates_to_store["template_groups"][ + folder + ] + except KeyError: + templates_to_store["template_groups"][folder] = base_dict() + + templates_to_store = templates_to_store["template_groups"][ + folder + ] + + templates_to_store["templates_stored"].extend(found_templates) + + # Order the templates alphabetically + for templates_objs in templates["template_groups"].values(): + templates_objs["template_groups"] = order_dict( + templates_objs["template_groups"] + ) - pattern_template = get_template(template_name) + # Order the top level by the sections + section_order = [section for section, _ in get_sections()] + templates["template_groups"] = order_dict( + templates["template_groups"], key_sort=lambda key: section_order.index(key) + ) - for node in pattern_template.template.nodelist: - if isinstance(node, ExtendsNode): - parent_template_name = node.parent_name.resolve(context) + return templates + + @classmethod + def get_pattern_source(cls, template): + return cls._get_engine(template).get_pattern_source(template) + + @classmethod + def get_template_ancestors(cls, template_name, context=None): + template = get_template(template_name) + return cls._get_engine(template).get_template_ancestors(template_name, context=context) + + @classmethod + def _get_engine(cls, template): + if "jinja" in str(type(template)).lower(): + return JinjaTemplateRenderer + return DTLTemplateRenderer + +class DTLTemplateRenderer: + @staticmethod + def get_pattern_source(template): + return escape(template.template.source) + + @classmethod + def get_template_ancestors(cls, template_name, context=None, ancestors=None): + """ + Returns a list of template names, starting with provided name + and followed by the names of any templates that extends until + the most extended template is reached. + """ + if ancestors is None: + ancestors = [template_name] + + if context is None: + context = Context() + + pattern_template = get_template(template_name) + + for node in pattern_template.template.nodelist: + if isinstance(node, ExtendsNode): + parent_template_name = node.parent_name.resolve(context) + ancestors.append(parent_template_name) + cls.get_template_ancestors( + parent_template_name, context=context, ancestors=ancestors + ) + break + + return ancestors + + +class JinjaTemplateRenderer: + @staticmethod + def get_pattern_source(template): + with open(template.template.filename) as f: + source = escape(f.read()) + return source + + @classmethod + def get_template_ancestors(cls, template_name, context=None, ancestors=None): + """ + Returns a list of template names, starting with provided name + and followed by the names of any templates that extends until + the most extended template is reached. + """ + from jinja2.nodes import Extends + + if ancestors is None: + ancestors = [template_name] + + if context is None: + context = Context() + + pattern_template = get_template(template_name) + #todo - make sure envrionment has context passed in + environment = pattern_template.template.environment + nodelist = environment.parse(pattern_template.name) + parent_template_name = nodelist.find(Extends) + if parent_template_name: ancestors.append(parent_template_name) - get_template_ancestors( - parent_template_name, context=context, ancestors=ancestors - ) - break + cls.get_template_ancestors(parent_template_name, context=context, ancestors=ancestors) - return ancestors + return ancestors diff --git a/pattern_library/views.py b/pattern_library/views.py index 4e3ab01d..030c1d8a 100644 --- a/pattern_library/views.py +++ b/pattern_library/views.py @@ -15,11 +15,10 @@ get_pattern_config_str, get_pattern_context, get_pattern_markdown, - get_pattern_templates, get_sections, - get_template_ancestors, is_pattern, render_pattern, + get_renderer, ) @@ -52,7 +51,8 @@ def get_first_template(self, templates): def get(self, request, pattern_template_name=None): # Get all pattern templates - templates = get_pattern_templates() + renderer = get_renderer() + templates = renderer.get_pattern_templates() if pattern_template_name is None: # Just display the first pattern if a specific one isn't requested @@ -67,7 +67,7 @@ def get(self, request, pattern_template_name=None): context = self.get_context_data() context["pattern_templates"] = templates context["pattern_template_name"] = pattern_template_name - context["pattern_source"] = escape(template.template.source) + context["pattern_source"] = renderer.get_pattern_source(template) context["pattern_config"] = escape( get_pattern_config_str(pattern_template_name) ) @@ -83,7 +83,8 @@ class RenderPatternView(TemplateView): @method_decorator(xframe_options_sameorigin) def get(self, request, pattern_template_name=None): - pattern_template_ancestors = get_template_ancestors( + renderer = get_renderer() + pattern_template_ancestors = renderer.get_template_ancestors( pattern_template_name, context=get_pattern_context(self.kwargs["pattern_template_name"]), ) diff --git a/tests/tests/test_utils.py b/tests/tests/test_utils.py index 9730be5f..7c051631 100644 --- a/tests/tests/test_utils.py +++ b/tests/tests/test_utils.py @@ -5,15 +5,17 @@ from pattern_library.utils import ( get_pattern_config_str, - get_template_ancestors, get_template_dirs, + get_renderer, ) class TestGetTemplateAncestors(SimpleTestCase): + def setUp(self): + self.renderer = get_renderer() def test_page(self): self.assertEqual( - get_template_ancestors("patterns/pages/test_page/test_page.html"), + self.renderer.get_template_ancestors("patterns/pages/test_page/test_page.html"), [ "patterns/pages/test_page/test_page.html", "patterns/base_page.html", @@ -23,7 +25,7 @@ def test_page(self): def test_fragment(self): self.assertEqual( - get_template_ancestors("patterns/atoms/test_atom/test_atom.html"), + self.renderer.get_template_ancestors("patterns/atoms/test_atom/test_atom.html"), [ "patterns/atoms/test_atom/test_atom.html", ], @@ -31,7 +33,7 @@ def test_fragment(self): def test_parent_template_from_variable(self): self.assertEqual( - get_template_ancestors( + self.renderer.get_template_ancestors( "patterns/atoms/test_extends/extended.html", context={"parent_template_name": "patterns/base.html"}, ),