Add support for IfExp.

This commit is contained in:
whitequark 2015-06-11 03:22:20 +03:00
parent 9953302cb6
commit ba9a7d087d
2 changed files with 35 additions and 7 deletions

View File

@ -165,17 +165,18 @@ class Inferencer(algorithm.Transformer):
{"typeb": printer.name(typeb)}, {"typeb": printer.name(typeb)},
locb) locb)
highlights = [locb] if locb else []
if e.typea.find() == typea.find() and e.typeb.find() == typeb.find(): if e.typea.find() == typea.find() and e.typeb.find() == typeb.find():
diag = diagnostic.Diagnostic("fatal", diag = diagnostic.Diagnostic("fatal",
"cannot unify {typea} with {typeb}", "cannot unify {typea} with {typeb}",
{"typea": printer.name(typea), "typeb": printer.name(typeb)}, {"typea": printer.name(typea), "typeb": printer.name(typeb)},
loca, [locb], notes=[note1, note2]) loca, highlights, notes=[note1, note2])
else: # give more detail else: # give more detail
diag = diagnostic.Diagnostic("fatal", diag = diagnostic.Diagnostic("fatal",
"cannot unify {typea} with {typeb}: {fraga} is incompatible with {fragb}", "cannot unify {typea} with {typeb}: {fraga} is incompatible with {fragb}",
{"typea": printer.name(typea), "typeb": printer.name(typeb), {"typea": printer.name(typea), "typeb": printer.name(typeb),
"fraga": printer.name(e.typea), "fragb": printer.name(e.typeb)}, "fraga": printer.name(e.typea), "fragb": printer.name(e.typeb)},
loca, [locb], notes=[note1, note2]) loca, highlights, notes=[note1, note2])
self.engine.process(diag) self.engine.process(diag)
def _find_name(self, name, loc): def _find_name(self, name, loc):
@ -281,6 +282,29 @@ class Inferencer(algorithm.Transformer):
node.loc, node.value.loc, kind="expects") node.loc, node.value.loc, kind="expects")
return node return node
def visit_IfExp(self, node):
node = self.generic_visit(node)
self._unify(node.body.type, node.orelse.type,
node.body.loc, node.orelse.loc)
return asttyped.IfExpT(type=node.body.type,
test=node.test, body=node.body, orelse=node.orelse,
if_loc=node.if_loc, else_loc=node.else_loc, loc=node.loc)
def visit_BoolOp(self, node):
node = self.generic_visit(node)
for value, op_loc in zip(node.values, node.op_locs):
def makenotes(printer, typea, typeb, loca, locb):
return [
diagnostic.Diagnostic("note",
"py2llvm requires boolean operations to have boolean operands", {},
op_loc)
]
self._unify(value.type, types.TBool(),
value.loc, None, makenotes)
return asttyped.BoolOpT(type=types.TBool(),
op=node.op, values=node.values,
op_locs=node.op_locs, loc=node.loc)
# Visitors that just unify types # Visitors that just unify types
# #
def visit_Assign(self, node): def visit_Assign(self, node):
@ -316,14 +340,12 @@ class Inferencer(algorithm.Transformer):
visit_Attribute = visit_unsupported visit_Attribute = visit_unsupported
visit_BinOp = visit_unsupported visit_BinOp = visit_unsupported
visit_BoolOp = visit_unsupported
visit_Call = visit_unsupported visit_Call = visit_unsupported
visit_Compare = visit_unsupported visit_Compare = visit_unsupported
visit_Dict = visit_unsupported visit_Dict = visit_unsupported
visit_DictComp = visit_unsupported visit_DictComp = visit_unsupported
visit_Ellipsis = visit_unsupported visit_Ellipsis = visit_unsupported
visit_GeneratorExp = visit_unsupported visit_GeneratorExp = visit_unsupported
visit_IfExp = visit_unsupported
visit_Lambda = visit_unsupported visit_Lambda = visit_unsupported
visit_ListComp = visit_unsupported visit_ListComp = visit_unsupported
visit_Set = visit_unsupported visit_Set = visit_unsupported
@ -343,18 +365,18 @@ class Printer(algorithm.Visitor):
return self.rewriter.rewrite() return self.rewriter.rewrite()
def visit_FunctionDefT(self, node): def visit_FunctionDefT(self, node):
super().generic_visit(node)
self.rewriter.insert_before(node.colon_loc, self.rewriter.insert_before(node.colon_loc,
"->{}".format(self.type_printer.name(node.return_type))) "->{}".format(self.type_printer.name(node.return_type)))
def generic_visit(self, node):
super().generic_visit(node) super().generic_visit(node)
def generic_visit(self, node):
if hasattr(node, "type"): if hasattr(node, "type"):
self.rewriter.insert_after(node.loc, self.rewriter.insert_after(node.loc,
":{}".format(self.type_printer.name(node.type))) ":{}".format(self.type_printer.name(node.type)))
super().generic_visit(node)
def main(): def main():
import sys, fileinput, os import sys, fileinput, os

View File

@ -32,3 +32,9 @@ i[0] = 1
j = [] j = []
j += [1.0] j += [1.0]
# CHECK-L: j:list(elt=float) # CHECK-L: j:list(elt=float)
1 if a else 2
# CHECK-L: 1:int(width='f) if a:int(width='a) else 2:int(width='f):int(width='f)
True and False
# CHECK-L: True:bool and False:bool:bool