compiler: handle direct calls to class methods.

Fixes #1005.
This commit is contained in:
whitequark 2018-05-25 02:02:18 +00:00
parent c9287cfc69
commit 12d1b9819c
3 changed files with 32 additions and 3 deletions

View File

@ -6,6 +6,6 @@ from .iodelay_estimator import IODelayEstimator
from .artiq_ir_generator import ARTIQIRGenerator
from .dead_code_eliminator import DeadCodeEliminator
from .local_demoter import LocalDemoter
from .llvm_ir_generator import LLVMIRGenerator
from .interleaver import Interleaver
from .typedtree_printer import TypedtreePrinter
from .llvm_ir_generator import LLVMIRGenerator

View File

@ -9,6 +9,7 @@ from pythonparser import ast, diagnostic
from llvmlite_artiq import ir as ll, binding as llvm
from ...language import core as language_core
from .. import types, builtins, ir
from ..embedding import SpecializedFunction
llvoid = ll.VoidType()
@ -1549,8 +1550,16 @@ class LLVMIRGenerator:
# RPC and C functions have no runtime representation.
return ll.Constant(llty, ll.Undefined)
elif types.is_function(typ):
return self.get_function_with_undef_env(typ.find(),
self.embedding_map.retrieve_function(value))
try:
func = self.embedding_map.retrieve_function(value)
except KeyError:
# If a class function was embedded directly (e.g. by a `C.f(...)` call),
# but it also appears in a class hierarchy, we might need to fall back
# to the non-specialized one, since direct invocations do not cause
# monomorphization.
assert isinstance(value, SpecializedFunction)
func = self.embedding_map.retrieve_function(value.host_function)
return self.get_function_with_undef_env(typ.find(), func)
elif types.is_method(typ):
llclosure = self._quote(value.__func__, types.get_method_function(typ),
lambda: path() + ['__func__'])

View File

@ -0,0 +1,20 @@
# RUN: %python -m artiq.compiler.testbench.embedding %s
from artiq.language.core import *
from artiq.language.types import *
class C:
@kernel
def f(self):
pass
class D(C):
@kernel
def f(self):
# super().f() # super() not bound
C.f(self) # KeyError in compile
di = D()
@kernel
def entrypoint():
di.f()