Fix type annotations with mixed tuples

The type checker/inferer visits every node in an AST tree, including
function return annotations. This means for a function definition like

    def f() -> TTuple([TInt32, TBool]):
      ...

We attempt to type check the list [TInt32, TBool], which generates the
unification constraint builtins.TBool ~ builtins.TInt. This causes an
internal error due to compiler weirdness.

We can avoid this by just nulling-out the return annotation in the
embedding stage. The return type isn't actually used anywhere (it's
extracted via the inspect module instead), so this is entirely safe.

Arguments aren't affected by this, as we already nulled out the
annotation (see visit_arg in embedding.py).

Signed-off-by: Jonathan Coates <jonathan.coates@oxionics.com>
This commit is contained in:
Jonathan Coates 2023-08-21 11:02:08 +01:00 committed by Sebastien Bourdeauducq
parent 3f27c76619
commit cc81464f53
2 changed files with 17 additions and 1 deletions

View File

@ -546,7 +546,7 @@ class StitchingASTTypedRewriter(ASTTypedRewriter):
node = asttyped.QuotedFunctionDefT( node = asttyped.QuotedFunctionDefT(
typing_env=extractor.typing_env, globals_in_scope=extractor.global_, typing_env=extractor.typing_env, globals_in_scope=extractor.global_,
signature_type=types.TVar(), return_type=types.TVar(), signature_type=types.TVar(), return_type=types.TVar(),
name=node.name, args=node.args, returns=node.returns, name=node.name, args=node.args, returns=None,
body=node.body, decorator_list=node.decorator_list, body=node.body, decorator_list=node.decorator_list,
keyword_loc=node.keyword_loc, name_loc=node.name_loc, keyword_loc=node.keyword_loc, name_loc=node.name_loc,
arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs, arrow_loc=node.arrow_loc, colon_loc=node.colon_loc, at_locs=node.at_locs,

View File

@ -0,0 +1,16 @@
# RUN: %python -m artiq.compiler.testbench.embedding %s
from artiq.language.core import *
from artiq.language.types import *
@kernel
def consume_tuple(x: TTuple([TInt32, TBool])):
print(x)
@kernel
def return_tuple() -> TTuple([TInt32, TBool]):
return (123, False)
@kernel
def entrypoint():
consume_tuple(return_tuple())