import ast
import copy
from helper import *
from type_def import *
from inference import *

# Type checker for a subset of Python expressions.
#
# We assume ``ctx.types`` provides at least the following types:
#   bool, int32 (together with their associated operations)
# Not handled yet: named expressions (``:=``) and type guards.


def parse_expr(ctx: Context, sym_table: dict[str, Type], expr: ast.expr) -> Type:
    """Infer the type of an expression AST node.

    Args:
        ctx: Typing context; ``ctx.types`` maps type names such as
            ``'bool'`` and ``'int32'`` to type objects.
        sym_table: Types of the variables currently in scope, by name.
        expr: Expression node to check. An ``ast.Expression`` wrapper (as
            produced by ``ast.parse(src, mode='eval')``) is unwrapped first.

    Returns:
        The inferred type of the expression.

    Raises:
        CustomError: If the expression kind is unsupported or ill-typed.
    """
    body = expr.body if isinstance(expr, ast.Expression) else expr
    for node_type, parser in _EXPR_PARSERS:
        if isinstance(body, node_type):
            return parser(ctx, sym_table, body)
    raise CustomError(f'{body} is not yet supported')


# Data-model method names for unary operators. ``not`` is handled separately
# by ``parse_unary_ops`` because it is not dispatched to a method.
_UNARY_OP_NAMES = {
    ast.UAdd: '__pos__',
    ast.USub: '__neg__',
    ast.Invert: '__invert__',
}


def get_unary_op(op: ast.unaryop) -> str:
    """Return the method name implementing unary operator ``op``.

    Raises:
        CustomError: If ``op`` has no method mapping (e.g. ``ast.Not``).
    """
    try:
        return _UNARY_OP_NAMES[type(op)]
    except KeyError:
        raise CustomError(f'Unknown {op}') from None


# Operators whose data-model method name is not simply the lowercased AST
# class name (``ast.Add`` -> ``__add__`` and ``ast.Lt`` -> ``__lt__`` need no
# entry here).
_SPECIAL_OP_NAMES = {
    ast.Div: '__truediv__',
    ast.Mult: '__mul__',
    ast.MatMult: '__matmul__',
    ast.BitAnd: '__and__',
    ast.BitOr: '__or__',
    ast.BitXor: '__xor__',
    ast.NotEq: '__ne__',
    ast.LtE: '__le__',
    ast.GtE: '__ge__',
}


def get_bin_ops(op) -> str:
    """Return the method name implementing binary/comparison operator ``op``.

    Names follow Python's data model (``__add__``, ``__mul__``,
    ``__truediv__``, ``__le__``, ...). Operators without a special entry fall
    back to ``__<lowercased AST class name>__``.
    """
    # NOTE(review): ``in`` / ``not in`` / ``is`` / ``is not`` fall back to
    # ``__in__`` / ``__notin__`` / ``__is__`` / ``__isnot__``, which are not
    # Python data-model names; confirm what the type definitions expect.
    special = _SPECIAL_OP_NAMES.get(type(op))
    if special is not None:
        return special
    return f'__{type(op).__name__.lower()}__'


def parse_constant(ctx: Context, sym_table: dict[str, Type],
                   node: ast.Constant) -> Type:
    """Type a literal: ``True``/``False`` are ``bool``, integers are ``int32``.

    Raises:
        CustomError: For any other literal (floats, strings, ``None``, ...).
    """
    value = node.value
    # bool is a subclass of int, so it must be tested first.
    if isinstance(value, bool):
        return ctx.types['bool']
    if isinstance(value, int):
        return ctx.types['int32']
    raise CustomError(f'unknown constant {value}')


def parse_name(ctx: Context, sym_table: dict[str, Type], node: ast.Name) -> Type:
    """Look up the type of a variable reference in ``sym_table``.

    Raises:
        CustomError: If the name is not bound in ``sym_table``.
    """
    try:
        return sym_table[node.id]
    except KeyError:
        raise CustomError(f'unbounded variable {node.id}') from None


def parse_list(ctx: Context, sym_table: dict[str, Type], node: ast.List) -> Type:
    """Type a list display as ``ListType(element_type)``.

    The empty list gets element type ``BotType()``. All elements must have
    types comparing equal (``!=``); no widening/joining is attempted.

    Raises:
        CustomError: If element types differ.
    """
    elem_types = [parse_expr(ctx, sym_table, e) for e in node.elts]
    if not elem_types:
        return ListType(BotType())
    first = elem_types[0]
    if any(t != first for t in elem_types[1:]):
        raise CustomError('inhomogeneous list is not allowed')
    return ListType(first)


def parse_tuple(ctx: Context, sym_table: dict[str, Type], node: ast.Tuple) -> Type:
    """Type a tuple display as ``TupleType`` of its element types, in order."""
    return TupleType([parse_expr(ctx, sym_table, e) for e in node.elts])


def parse_attribute(ctx: Context, sym_table: dict[str, Type],
                    node: ast.Attribute) -> Type:
    """Type ``obj.attr`` by looking ``attr`` up in the object type's ``fields``.

    Assumes the object's type exposes a ``fields`` mapping from attribute
    name to type -- TODO confirm for every type kind.

    Raises:
        CustomError: If ``attr`` is not a known field of the object's type.
    """
    obj = parse_expr(ctx, sym_table, node.value)
    if node.attr in obj.fields:
        return obj.fields[node.attr]
    raise CustomError(f'unknown field {node.attr} in {obj}')


def parse_bool_ops(ctx: Context, sym_table: dict[str, Type],
                   node: ast.BoolOp) -> Type:
    """Type ``a and b [and c ...]`` / ``a or b [or c ...]`` as ``bool``.

    Python folds chained ``and``/``or`` into one node with two or more
    ``values``; every operand must be ``bool``.

    Raises:
        CustomError: If any operand is not ``bool``.
    """
    b = ctx.types['bool']
    operand_types = [parse_expr(ctx, sym_table, v) for v in node.values]
    if any(t != b for t in operand_types):
        raise CustomError('operands of bool ops must be booleans')
    return b


def parse_bin_ops(ctx: Context, sym_table: dict[str, Type],
                  node: ast.BinOp) -> Type:
    """Type ``left <op> right`` by resolving ``left``'s operator method.

    Delegates to ``resolve_call(left_type, method_name, [right_type], {}, ctx)``.
    """
    left = parse_expr(ctx, sym_table, node.left)
    right = parse_expr(ctx, sym_table, node.right)
    return resolve_call(left, get_bin_ops(node.op), [right], {}, ctx)


def parse_unary_ops(ctx: Context, sym_table: dict[str, Type],
                    node: ast.UnaryOp) -> Type:
    """Type a unary expression.

    ``not x`` requires ``x`` to be ``bool`` and yields ``bool``; ``+x``,
    ``-x`` and ``~x`` are resolved as zero-argument method calls on ``x``'s type.

    Raises:
        CustomError: If ``not`` is applied to a non-``bool`` operand.
    """
    operand = parse_expr(ctx, sym_table, node.operand)
    if isinstance(node.op, ast.Not):
        b = ctx.types['bool']
        if operand != b:
            raise CustomError('operands of bool ops must be booleans')
        return b
    return resolve_call(operand, get_unary_op(node.op), [], {}, ctx)


def parse_compare(ctx: Context, sym_table: dict[str, Type],
                  node: ast.Compare) -> Type:
    """Type a (possibly chained) comparison such as ``a < b <= c`` as ``bool``.

    Each adjacent pair is resolved as a method call on the left operand's
    type, and every pairwise result must be ``bool``.

    Raises:
        CustomError: If any pairwise comparison does not produce ``bool``.
    """
    operands = [parse_expr(ctx, sym_table, node.left)]
    operands += [parse_expr(ctx, sym_table, v) for v in node.comparators]
    boolean = ctx.types['bool']
    for lhs, rhs, op in zip(operands, operands[1:], node.ops):
        result = resolve_call(lhs, get_bin_ops(op), [rhs], {}, ctx)
        if result != boolean:
            raise CustomError(
                f'result of comparison must be bool instead of {result}')
    return boolean


def parse_call(ctx: Context, sym_table: dict[str, Type], node: ast.Call) -> Type:
    """Type a call ``f(args)`` or method call ``obj.f(args)``.

    For plain calls the receiver passed to ``resolve_call`` is ``None``.

    Raises:
        CustomError: For keyword arguments or an unsupported callee expression.
    """
    if node.keywords:
        raise CustomError('keyword arguments are not supported')
    args = [parse_expr(ctx, sym_table, v) for v in node.args]
    if isinstance(node.func, ast.Attribute):
        obj = parse_expr(ctx, sym_table, node.func.value)
        name = node.func.attr
    elif isinstance(node.func, ast.Name):
        obj = None
        name = node.func.id
    else:
        raise CustomError(f'calling {node.func} is not supported')
    return resolve_call(obj, name, args, {}, ctx)


def parse_subscript(ctx: Context, sym_table: dict[str, Type],
                    node: ast.Subscript) -> Type:
    """Type ``lst[i]`` (the element type) or ``lst[a:b:c]`` (the list type).

    Only ``ListType`` values can be subscripted; indices and any present
    slice bounds must be ``int32``.

    Raises:
        CustomError: On a non-list value or a non-``int32`` index/bound.
    """
    value = parse_expr(ctx, sym_table, node.value)
    if not isinstance(value, ListType):
        raise CustomError(f'cannot take index of {value}')
    i32 = ctx.types['int32']
    if isinstance(node.slice, ast.Slice):
        for bound in (node.slice.lower, node.slice.upper, node.slice.step):
            if bound is not None and parse_expr(ctx, sym_table, bound) != i32:
                raise CustomError('slice index must be int32')
        return value
    index = parse_expr(ctx, sym_table, node.slice)
    if index != i32:
        raise CustomError(f'index of type {index} is not supported')
    return value.params[0]


def parse_if_expr(ctx: Context, sym_table: dict[str, Type], node: ast.IfExp) -> Type:
    """Type ``body if test else orelse``.

    ``test`` must be ``bool`` and both branches must have equal types.

    Raises:
        CustomError: On a non-``bool`` test or divergent branch types.
    """
    b = ctx.types['bool']
    test = parse_expr(ctx, sym_table, node.test)
    if test != b:
        raise CustomError(f'type of conditional must be bool instead of {test}')
    ty1 = parse_expr(ctx, sym_table, node.body)
    ty2 = parse_expr(ctx, sym_table, node.orelse)
    if ty1 != ty2:
        raise CustomError(f'divergent type for if expression: {ty1} != {ty2}')
    return ty1


def parse_binding(name, ty) -> dict:
    """Bind a loop-target pattern to ``ty`` and return the new name -> type map.

    A bare name binds directly. A tuple pattern requires ``ty`` to be a
    ``TupleType`` of the same length and is matched element-wise (recursively,
    so nested tuple patterns work).

    Raises:
        CustomError: On a shape mismatch or an unsupported target kind.
    """
    if isinstance(name, ast.Name):
        return {name.id: ty}
    if not isinstance(name, ast.Tuple):
        raise CustomError(f'binding to {name} is not supported')
    if not isinstance(ty, TupleType):
        raise CustomError(f'cannot pattern match over {ty}')
    if len(name.elts) != len(ty.params):
        raise CustomError('pattern matching length mismatch')
    bindings = {}
    for sub_target, sub_type in zip(name.elts, ty.params):
        bindings |= parse_binding(sub_target, sub_type)
    return bindings


def parse_list_comprehension(ctx: Context, sym_table: dict[str, Type],
                             node: ast.ListComp) -> Type:
    """Type ``[elt for target in iter if cond ...]`` as ``ListType(elt_type)``.

    Only a single, synchronous ``for`` clause over a ``ListType`` is
    supported. The loop target is bound in a copy of ``sym_table`` (the
    caller's table is not modified) and every ``if`` filter must be ``bool``.

    Raises:
        CustomError: For unsupported comprehension shapes or type errors.
    """
    if len(node.generators) != 1:
        raise CustomError('list comprehension with more than 1 for loop is not supported')
    gen = node.generators[0]
    if gen.is_async:
        raise CustomError('async list comprehension is not supported')
    iter_type = parse_expr(ctx, sym_table, gen.iter)
    if not isinstance(iter_type, ListType):
        raise CustomError(f'unable to iterate over {iter_type}')
    inner_table = sym_table | parse_binding(gen.target, iter_type.params[0])
    b = ctx.types['bool']
    for cond in gen.ifs:
        if parse_expr(ctx, inner_table, cond) != b:
            raise CustomError('condition should be of boolean type')
    return ListType(parse_expr(ctx, inner_table, node.elt))


# (AST node class, parser) pairs consulted in order by ``parse_expr``. Defined
# after the parsers so every referenced function already exists.
_EXPR_PARSERS = (
    (ast.Constant, parse_constant),
    (ast.UnaryOp, parse_unary_ops),
    (ast.BinOp, parse_bin_ops),
    (ast.Name, parse_name),
    (ast.List, parse_list),
    (ast.Tuple, parse_tuple),
    (ast.Attribute, parse_attribute),
    (ast.BoolOp, parse_bool_ops),
    (ast.Compare, parse_compare),
    (ast.Call, parse_call),
    (ast.Subscript, parse_subscript),
    (ast.IfExp, parse_if_expr),
    (ast.ListComp, parse_list_comprehension),
)