From 97fdef24885b590da906c719b100241011e958ca Mon Sep 17 00:00:00 2001 From: pca006132 Date: Fri, 18 Dec 2020 16:40:32 +0800 Subject: [PATCH] expression type check --- toy-impl/inference.py | 6 +- toy-impl/parse_expr.py | 180 +++++++++++++++++++++++++++++++++++++++++ toy-impl/test_expr.py | 73 +++++++++++++++++ 3 files changed, 256 insertions(+), 3 deletions(-) create mode 100644 toy-impl/parse_expr.py create mode 100644 toy-impl/test_expr.py diff --git a/toy-impl/inference.py b/toy-impl/inference.py index c2aa954..69bef9b 100644 --- a/toy-impl/inference.py +++ b/toy-impl/inference.py @@ -98,10 +98,10 @@ def resolve_call(obj, raise CustomError('{f} is not a method of {obj}') f_args, f_result = TupleType(f[0][1:]), f[1] else: - raise CustomError(f"No such method {fn} in {c}") + raise CustomError(f"No such method {fn} in {obj}") elif isinstance(obj, VirtualClassType): - # may need to emit special annotation that this is a virtual method - # call? + # TODO: may need to emit special annotation that this is a virtual + # method call? if fn in obj.base.methods: f = obj.base.methods[fn] if len(f[0]) == 0 or not isinstance(f[0][0], SelfType): diff --git a/toy-impl/parse_expr.py b/toy-impl/parse_expr.py new file mode 100644 index 0000000..3e26a92 --- /dev/null +++ b/toy-impl/parse_expr.py @@ -0,0 +1,180 @@ +import ast +from helper import * +from type_def import * +from inference import * + +# we assume having the following types: +# bool, int32 with associated operations + +# not handled now: slice, comprehensions, named expression, if expression, type +# guard + +def parse_expr(ctx: Context, + sym_table: dict[str, Type], + expr: ast.expr): + if isinstance(expr, ast.Expression): + body = expr.body + else: + body = expr + if isinstance(body, ast.Constant): + return parse_constant(ctx, sym_table, body) + if isinstance(body, ast.UnaryOp): + return parse_unary_op(ctx, sym_table, body) + if isinstance(body, ast.BinOp): + return parse_bin_ops(ctx, sym_table, body) + if isinstance(body, ast.Name): + return parse_name(ctx, sym_table, body) + if isinstance(body, ast.List): + return parse_list(ctx, sym_table, body) + if isinstance(body, ast.Tuple): + return parse_tuple(ctx, sym_table, body) + if isinstance(body, ast.Attribute): + return parse_attribute(ctx, sym_table, body) + if isinstance(body, ast.BoolOp): + return parse_bool_ops(ctx, sym_table, body) + if isinstance(body, ast.Compare): + return parse_compare(ctx, sym_table, body) + if isinstance(body, ast.Call): + return parse_call(ctx, sym_table, body) + if isinstance(body, ast.Subscript): + return parse_subscript(ctx, sym_table, body) + raise CustomError(f'{body} is not yet supported') + + +def get_unary_op(op): + if isinstance(op, ast.UAdd): + return '__pos__' + if isinstance(op, ast.USub): + return '__neg__' + if isinstance(op, ast.Invert): + return '__invert__' + raise Exception(f'Unknown {expr}') + +def get_bin_ops(op): + if isinstance(op, ast.Div): + return '__truediv__' + if isinstance(op, ast.BitAnd): + return '__and__' + if isinstance(op, ast.BitOr): + return '__or__' + if isinstance(op, ast.BitXor): + return '__xor__' + return f'__{type(op).__name__.lower()}__' + +def parse_constant(ctx: Context, + sym_table: dict[str, Type], + node): + v = node.value + if isinstance(v, int): + return ctx.types['int32'] + elif isinstance(v, bool): + return ctx.types['bool'] + else: + raise CustomError(f'unknown constant {v}') + +def parse_name(ctx: Context, + sym_table: dict[str, Type], + node): + if node.id in sym_table: + return sym_table[node.id] + else: + raise CustomError(f'unbounded variable {node.id}') + +def parse_list(ctx: Context, + sym_table: dict[str, Type], + node): + types = [parse_expr(ctx, sym_table, e) for e in node.elts] + if len(types) == 0: + return ListType(BotType()) + for t in types[1:]: + if t != types[0]: + raise CustomError(f'inhomogeneous list is not allowed') + return ListType(types[0]) + +def parse_tuple(ctx: Context, + sym_table: dict[str, Type], + node): + types = [parse_expr(ctx, sym_table, e) for e in node.elts] + return TupleType(types) + +def parse_attribute(ctx: Context, + sym_table: dict[str, Type], + node): + obj = parse_expr(node.value) + if node.attr in obj.fields: + return obj.fields[node.attr] + raise CustomError(f'unknown field {node.attr} in {obj}') + +def parse_bool_ops(ctx: Context, + sym_table: dict[str, Type], + node): + assert len(node.values) == 2 + left = parse_expr(ctx, sym_table, node.values[0]) + right = parse_expr(ctx, sym_table, node.values[1]) + b = ctx.types['bool'] + if left != b or right != b: + raise CustomError('operands of bool ops must be booleans') + return b + +def parse_bin_ops(ctx: Context, + sym_table: dict[str, Type], + node): + left = parse_expr(ctx, sym_table, node.left) + right = parse_expr(ctx, sym_table, node.right) + op = get_bin_ops(node.op) + return resolve_call(left, op, [right], {}, ctx) + +def parse_unary_ops(ctx: Context, + sym_table: dict[str, Type], + node): + t = parse_expr(node.operand) + if isinstance(node.op, ast.Not): + b = ctx.types['bool'] + if t != b: + raise CustomError('operands of bool ops must be booleans') + return b + return resolve_call(t, get_unary_op(node.op), [], {}, ctx) + +def parse_compare(ctx: Context, + sym_table: dict[str, Type], + node): + items = [parse_expr(ctx, sym_table, v) for v in node.comparators] + items.insert(0, parse_expr(ctx, sym_table, node.left)) + boolean = ctx.types['bool'] + ops = [get_bin_ops(v) for v in node.ops] + for a, b, op in zip(items[:-1], items[1:], ops): + result = resolve_call(a, op, [b], {}, ctx) + if result != boolean: + raise CustomError( + f'result of comparison must be bool instead of {result}') + return boolean + +def parse_call(ctx: Context, + sym_table: dict[str, Type], + node): + if len(node.keywords) > 0: + raise CustomError('keyword arguments are not supported') + args = [parse_expr(ctx, sym_table, v) for v in node.args] + obj = None + f = None + if isinstance(node.func, ast.Attribute): + obj = parse_expr(node.func.value) + f = node.func.attr + elif isinstance(node.func, ast.Name): + f = node.func.id + return resolve_call(obj, f, args, {}, ctx) + +def parse_subscript(ctx: Context, + sym_table: dict[str, Type], + node): + value = parse_expr(ctx, sym_table, node.value) + if not isinstance(value, ListType): + raise CustomError(f'cannot take index of {value}') + s = parse_expr(ctx, sym_table, node.slice) + i32 = ctx.types['int32'] + if s == i32: + return value.params[0] + else: + # will support slice + raise CustomError(f'index of type {s} is not supported') + diff --git a/toy-impl/test_expr.py b/toy-impl/test_expr.py new file mode 100644 index 0000000..bd97f3f --- /dev/null +++ b/toy-impl/test_expr.py @@ -0,0 +1,73 @@ +import ast +from type_def import * +from inference import * +from helper import * +from parse_expr import * + +types = { + 'int32': PrimitiveType('int32'), + 'int64': PrimitiveType('int64'), + 'str': PrimitiveType('str'), + 'bool': PrimitiveType('bool') +} + +i32 = types['int32'] +i64 = types['int64'] +s = types['str'] +b = types['bool'] + +variables = { + 'X': TypeVariable('X', []), + 'Y': TypeVariable('Y', []), + 'I': TypeVariable('I', [i32, i64]), + 'A': TypeVariable('A', [i32, i64, s]), +} + +X = variables['X'] +Y = variables['Y'] +I = variables['I'] +A = variables['A'] + +i32.methods['__init__'] = ([SelfType(), I], None, set()) +i32.methods['__add__'] = ([SelfType(), i32], i32, set()) +i32.methods['__sub__'] = ([SelfType(), i32], i32, set()) +i32.methods['__lt__'] = ([SelfType(), i32], b, set()) +i32.methods['__gt__'] = ([SelfType(), i32], b, set()) +i32.methods['__eq__'] = ([SelfType(), i32], b, set()) +i32.methods['__ne__'] = ([SelfType(), i32], b, set()) +i32.methods['__le__'] = ([SelfType(), i32], b, set()) +i32.methods['__ge__'] = ([SelfType(), i32], b, set()) + +i64.methods['__init__'] = ([SelfType(), I], None, set()) +i64.methods['__add__'] = ([SelfType(), i64], i64, set()) +i64.methods['__sub__'] = ([SelfType(), i64], i64, set()) +i64.methods['__lt__'] = ([SelfType(), i64], b, set()) +i64.methods['__gt__'] = ([SelfType(), i64], b, set()) +i64.methods['__eq__'] = ([SelfType(), i64], b, set()) +i64.methods['__ne__'] = ([SelfType(), i64], b, set()) +i64.methods['__le__'] = ([SelfType(), i64], b, set()) +i64.methods['__ge__'] = ([SelfType(), i64], b, set()) + +ctx = Context(variables, types) + +def test_expr(expr, sym_table= {}): + print(f'Testing {expr} w.r.t. {stringify_subst(sym_table)}') + try: + tree = ast.parse(expr, mode='eval') + result = parse_expr(ctx, sym_table, tree) + print(result) + except CustomError as err: + print(f'error: {err.msg}') + +test_expr('1 + 1') +test_expr('1 - 1') +test_expr('int64(1)') +test_expr('int64(1) - 1') +test_expr('a - a', {'a': I}) +test_expr('a - a', {'a': A}) +test_expr('[1, 2, 3][2]') +test_expr('[[1], [2], [3]][2]') +test_expr('[[1], [2], [3]][a]', {'a': i32}) +test_expr('a == a == a', {'a': I}) +test_expr('a == a and 1 == 2', {'a': I}) +