diff --git a/artiq/compiler/builtins.py b/artiq/compiler/builtins.py index 16e24381c..efac80e9a 100644 --- a/artiq/compiler/builtins.py +++ b/artiq/compiler/builtins.py @@ -205,4 +205,5 @@ def is_allocated(typ): accum or not (is_none(typ) or is_bool(typ) or is_int(typ) or is_float(typ) or is_range(typ) or types.is_c_function(typ) or types.is_rpc_function(typ) or + types.is_method(typ) or types.is_value(typ))) diff --git a/artiq/compiler/transforms/artiq_ir_generator.py b/artiq/compiler/transforms/artiq_ir_generator.py index c61449c35..30c821082 100644 --- a/artiq/compiler/transforms/artiq_ir_generator.py +++ b/artiq/compiler/transforms/artiq_ir_generator.py @@ -711,13 +711,22 @@ class ARTIQIRGenerator(algorithm.Visitor): finally: self.current_assign = old_assign - if node.attr not in node.type.find().attributes: + if node.attr not in obj.type.find().attributes: # A class attribute. Get the constructor (class object) and # extract the attribute from it. - constructor = obj.type.constructor - obj = self.append(ir.GetConstructor(self._env_for(constructor.name), - constructor.name, constructor, - name="constructor." + constructor.name)) + print(node) + print(obj) + constr_type = obj.type.constructor + constr = self.append(ir.GetConstructor(self._env_for(constr_type.name), + constr_type.name, constr_type, + name="constructor." + constr_type.name)) + + if types.is_function(constr.type.attributes[node.attr]): + # A method. Construct a method object instead. + func = self.append(ir.GetAttr(constr, node.attr)) + return self.append(ir.Alloc([func, obj], node.type)) + else: + obj = constr if self.current_assign is None: return self.append(ir.GetAttr(obj, node.attr, @@ -1413,36 +1422,49 @@ class ARTIQIRGenerator(algorithm.Visitor): elif types.is_builtin(typ): return self.visit_builtin_call(node) else: - func = self.visit(node.func) - args = [None] * (len(typ.args) + len(typ.optargs)) + if types.is_function(typ): + func = self.visit(node.func) + self_arg = None + fn_typ = typ + elif types.is_method(typ): + method = self.visit(node.func) + func = self.append(ir.GetAttr(method, "__func__")) + self_arg = self.append(ir.GetAttr(method, "__self__")) + fn_typ = types.get_method_function(typ) + + args = [None] * (len(fn_typ.args) + len(fn_typ.optargs)) for index, arg_node in enumerate(node.args): arg = self.visit(arg_node) - if index < len(typ.args): + if index < len(fn_typ.args): args[index] = arg else: args[index] = self.append(ir.Alloc([arg], ir.TOption(arg.type))) for keyword in node.keywords: arg = self.visit(keyword.value) - if keyword.arg in typ.args: - for index, arg_name in enumerate(typ.args): + if keyword.arg in fn_typ.args: + for index, arg_name in enumerate(fn_typ.args): if keyword.arg == arg_name: assert args[index] is None args[index] = arg break - elif keyword.arg in typ.optargs: - for index, optarg_name in enumerate(typ.optargs): + elif keyword.arg in fn_typ.optargs: + for index, optarg_name in enumerate(fn_typ.optargs): if keyword.arg == optarg_name: - assert args[len(typ.args) + index] is None - args[len(typ.args) + index] = \ + assert args[len(fn_typ.args) + index] is None + args[len(fn_typ.args) + index] = \ self.append(ir.Alloc([arg], ir.TOption(arg.type))) break - for index, optarg_name in enumerate(typ.optargs): - if args[len(typ.args) + index] is None: - args[len(typ.args) + index] = \ - self.append(ir.Alloc([], ir.TOption(typ.optargs[optarg_name]))) + for index, optarg_name in enumerate(fn_typ.optargs): + if args[len(fn_typ.args) + index] is None: + args[len(fn_typ.args) + index] = \ + self.append(ir.Alloc([], ir.TOption(fn_typ.optargs[optarg_name]))) + + if self_arg is not None: + assert args[0] is None + args[0] = self_arg assert None not in args diff --git a/artiq/compiler/transforms/inferencer.py b/artiq/compiler/transforms/inferencer.py index 2400723d0..8711aeedf 100644 --- a/artiq/compiler/transforms/inferencer.py +++ b/artiq/compiler/transforms/inferencer.py @@ -22,7 +22,7 @@ class Inferencer(algorithm.Visitor): self.in_loop = False self.has_return = False - def _unify(self, typea, typeb, loca, locb, makenotes=None): + def _unify(self, typea, typeb, loca, locb, makenotes=None, when=""): try: typea.unify(typeb) except types.UnificationError as e: @@ -45,16 +45,19 @@ class Inferencer(algorithm.Visitor): locb)) highlights = [locb] if locb else [] - if e.typea.find() == typea.find() and e.typeb.find() == typeb.find(): + if e.typea.find() == typea.find() and e.typeb.find() == typeb.find() or \ + e.typeb.find() == typea.find() and e.typea.find() == typeb.find(): diag = diagnostic.Diagnostic("error", - "cannot unify {typea} with {typeb}", - {"typea": printer.name(typea), "typeb": printer.name(typeb)}, + "cannot unify {typea} with {typeb}{when}", + {"typea": printer.name(typea), "typeb": printer.name(typeb), + "when": when}, loca, highlights, notes) else: # give more detail diag = diagnostic.Diagnostic("error", - "cannot unify {typea} with {typeb}: {fraga} is incompatible with {fragb}", + "cannot unify {typea} with {typeb}{when}: {fraga} is incompatible with {fragb}", {"typea": printer.name(typea), "typeb": printer.name(typeb), - "fraga": printer.name(e.typea), "fragb": printer.name(e.typeb)}, + "fraga": printer.name(e.typea), "fragb": printer.name(e.typeb), + "when": when}, loca, highlights, notes) self.engine.process(diag) @@ -88,13 +91,43 @@ class Inferencer(algorithm.Visitor): object_type = node.value.type.find() if not types.is_var(object_type): if node.attr in object_type.attributes: - # assumes no free type variables in .attributes + # Assumes no free type variables in .attributes. self._unify(node.type, object_type.attributes[node.attr], node.loc, None) elif types.is_instance(object_type) and \ node.attr in object_type.constructor.attributes: - # assumes no free type variables in .attributes - self._unify(node.type, object_type.constructor.attributes[node.attr], + # Assumes no free type variables in .attributes. + attr_type = object_type.constructor.attributes[node.attr].find() + if types.is_function(attr_type): + # Convert to a method. + if len(attr_type.args) < 1: + diag = diagnostic.Diagnostic("error", + "function '{attr}{type}' of class '{class}' cannot accept a self argument", + {"attr": node.attr, "type": types.TypePrinter().name(attr_type), + "class": object_type.name}, + node.loc) + self.engine.process(diag) + return + else: + def makenotes(printer, typea, typeb, loca, locb): + return [ + diagnostic.Diagnostic("note", + "expression of type {typea}", + {"typea": printer.name(typea)}, + loca), + diagnostic.Diagnostic("note", + "reference to a class function of type {typeb}", + {"typeb": printer.name(attr_type)}, + locb) + ] + + self._unify(object_type, list(attr_type.args.values())[0], + node.value.loc, node.loc, + makenotes=makenotes, + when=" while inferring the type for self argument") + + attr_type = types.TMethod(object_type, attr_type) + self._unify(node.type, attr_type, node.loc, None) else: diag = diagnostic.Diagnostic("error", @@ -695,7 +728,7 @@ class Inferencer(algorithm.Visitor): return elif types.is_builtin(typ): return self.visit_builtin_call(node) - elif not types.is_function(typ): + elif not (types.is_function(typ) or types.is_method(typ)): diag = diagnostic.Diagnostic("error", "cannot call this expression of type {type}", {"type": types.TypePrinter().name(typ)}, @@ -703,22 +736,34 @@ class Inferencer(algorithm.Visitor): self.engine.process(diag) return + if types.is_function(typ): + typ_arity = typ.arity() + typ_args = typ.args + typ_optargs = typ.optargs + typ_ret = typ.ret + else: + typ = types.get_method_function(typ) + typ_arity = typ.arity() - 1 + typ_args = OrderedDict(list(typ.args.items())[1:]) + typ_optargs = typ.optargs + typ_ret = typ.ret + passed_args = dict() - if len(node.args) > typ.arity(): + if len(node.args) > typ_arity: note = diagnostic.Diagnostic("note", "extraneous argument(s)", {}, - node.args[typ.arity()].loc.join(node.args[-1].loc)) + node.args[typ_arity].loc.join(node.args[-1].loc)) diag = diagnostic.Diagnostic("error", "this function of type {type} accepts at most {num} arguments", {"type": types.TypePrinter().name(node.func.type), - "num": typ.arity()}, + "num": typ_arity}, node.func.loc, [], [note]) self.engine.process(diag) return for actualarg, (formalname, formaltyp) in \ - zip(node.args, list(typ.args.items()) + list(typ.optargs.items())): + zip(node.args, list(typ_args.items()) + list(typ_optargs.items())): self._unify(actualarg.type, formaltyp, actualarg.loc, None) passed_args[formalname] = actualarg.loc @@ -732,15 +777,15 @@ class Inferencer(algorithm.Visitor): self.engine.process(diag) return - if keyword.arg in typ.args: - self._unify(keyword.value.type, typ.args[keyword.arg], + if keyword.arg in typ_args: + self._unify(keyword.value.type, typ_args[keyword.arg], keyword.value.loc, None) - elif keyword.arg in typ.optargs: - self._unify(keyword.value.type, typ.optargs[keyword.arg], + elif keyword.arg in typ_optargs: + self._unify(keyword.value.type, typ_optargs[keyword.arg], keyword.value.loc, None) passed_args[keyword.arg] = keyword.arg_loc - for formalname in typ.args: + for formalname in typ_args: if formalname not in passed_args: note = diagnostic.Diagnostic("note", "the called function is of type {type}", @@ -753,7 +798,7 @@ class Inferencer(algorithm.Visitor): self.engine.process(diag) return - self._unify(node.type, typ.ret, + self._unify(node.type, typ_ret, node.loc, None) def visit_LambdaT(self, node): diff --git a/artiq/compiler/transforms/llvm_ir_generator.py b/artiq/compiler/transforms/llvm_ir_generator.py index e068dcc05..981d76305 100644 --- a/artiq/compiler/transforms/llvm_ir_generator.py +++ b/artiq/compiler/transforms/llvm_ir_generator.py @@ -192,6 +192,10 @@ class LLVMIRGenerator: return llty else: return ll.LiteralStructType([envarg, llty.as_pointer()]) + elif types.is_method(typ): + llfuncty = self.llty_of_type(types.get_method_function(typ)) + llselfty = self.llty_of_type(types.get_method_self(typ)) + return ll.LiteralStructType([llfuncty, llselfty]) elif builtins.is_none(typ): if for_return: return llvoid diff --git a/artiq/compiler/types.py b/artiq/compiler/types.py index 152b71cc8..0226d8fe9 100644 --- a/artiq/compiler/types.py +++ b/artiq/compiler/types.py @@ -350,9 +350,20 @@ class TInstance(TMono): self.attributes = attributes def __repr__(self): - return "py2llvm.types.TInstance({}, {]})".format( + return "py2llvm.types.TInstance({}, {})".format( repr(self.name), repr(self.attributes)) +class TMethod(TMono): + """ + A type of a method. + """ + + def __init__(self, self_type, function_type): + super().__init__("method", {"self": self_type, "fn": function_type}) + self.attributes = OrderedDict([ + ("__func__", function_type), + ("__self__", self_type), + ]) class TValue(Type): """ @@ -452,6 +463,17 @@ def is_instance(typ, name=None): else: return isinstance(typ, TInstance) +def is_method(typ): + return isinstance(typ.find(), TMethod) + +def get_method_self(typ): + if is_method(typ): + return typ.find().params["self"] + +def get_method_function(typ): + if is_method(typ): + return typ.find().params["fn"] + def is_value(typ): return isinstance(typ.find(), TValue) diff --git a/lit-test/test/inferencer/class.py b/lit-test/test/inferencer/class.py index 4b14a2737..2efdf1561 100644 --- a/lit-test/test/inferencer/class.py +++ b/lit-test/test/inferencer/class.py @@ -5,10 +5,15 @@ class c: a = 1 def f(): pass + def m(self): + pass -# CHECK-L: c:NoneType, m: (self:c)->NoneType}> c # CHECK-L: .a:int(width='a) c.a # CHECK-L: .f:()->NoneType c.f + +# CHECK-L: .m:method(self=c, fn=(self:c)->NoneType) +c().m() diff --git a/lit-test/test/inferencer/error_method.py b/lit-test/test/inferencer/error_method.py new file mode 100644 index 000000000..bf9b1fbe8 --- /dev/null +++ b/lit-test/test/inferencer/error_method.py @@ -0,0 +1,16 @@ +# RUN: %python -m artiq.compiler.testbench.inferencer +diag %s >%t +# RUN: OutputCheck %s --file-to-check=%t + +class c: + def f(): + pass + + def g(self): + pass + +# CHECK-L: ${LINE:+1}: error: function 'f()->NoneType' of class 'c' cannot accept a self argument +c().f() + +c.g(1) +# CHECK-L: ${LINE:+1}: error: cannot unify c with int(width='a) while inferring the type for self argument +c().g() diff --git a/lit-test/test/integration/class.py b/lit-test/test/integration/class.py index 8528636b2..3d9048a1b 100644 --- a/lit-test/test/integration/class.py +++ b/lit-test/test/integration/class.py @@ -5,6 +5,9 @@ class c: a = 1 def f(): return 2 + def g(self): + return self.a + 5 assert c.a == 1 assert c.f() == 2 +assert c().g() == 6