diff --git a/artiq/compiler/embedding.py b/artiq/compiler/embedding.py index 6d46317d1..763accb20 100644 --- a/artiq/compiler/embedding.py +++ b/artiq/compiler/embedding.py @@ -5,10 +5,11 @@ the references to the host objects and translates the functions annotated as ``@kernel`` when they are referenced. """ -import os, re, linecache, inspect, textwrap +import sys, os, re, linecache, inspect, textwrap from collections import OrderedDict, defaultdict from pythonparser import ast, algorithm, source, diagnostic, parse_buffer +from pythonparser import lexer as source_lexer, parser as source_parser from ..language import core as language_core from . import types, builtins, asttyped, prelude @@ -424,7 +425,7 @@ class Stitcher: # Extract function source. embedded_function = function.artiq_embedded.function - source_code = textwrap.dedent(inspect.getsource(embedded_function)) + source_code = inspect.getsource(embedded_function) filename = embedded_function.__code__.co_filename module_name = embedded_function.__globals__['__name__'] first_line = embedded_function.__code__.co_firstlineno @@ -436,10 +437,20 @@ class Stitcher: cell_names = embedded_function.__code__.co_freevars host_environment.update({var: cells[index] for index, var in enumerate(cell_names)}) + # Find out how indented we are. + initial_whitespace = re.search(r"^\s*", source_code).group(0) + initial_indent = len(initial_whitespace.expandtabs()) + # Parse. source_buffer = source.Buffer(source_code, filename, first_line) - parsetree, comments = parse_buffer(source_buffer, engine=self.engine) - function_node = parsetree.body[0] + lexer = source_lexer.Lexer(source_buffer, version=sys.version_info[0:2], + diagnostic_engine=self.engine) + lexer.indent = [(initial_indent, + source.Range(source_buffer, 0, len(initial_whitespace)), + initial_whitespace)] + parser = source_parser.Parser(lexer, version=sys.version_info[0:2], + diagnostic_engine=self.engine) + function_node = parser.file_input().body[0] # Mangle the name, since we put everything into a single module. function_node.name = "{}.{}".format(module_name, function.__qualname__)