import ast
import copy
from helper import *
from type_def import *
from inference import *

# Expression type checker.
#
# We assume ``ctx.types`` provides the primitive types bool, int32 and float,
# with their associated operations.
# Not handled yet: named expressions (walrus), type guards.


def parse_expr(ctx: Context, sym_table: dict[str, Type], expr: ast.expr):
    """Infer the type of ``expr``, record it on the node and return it.

    ``expr`` may be a bare expression node or an ``ast.Expression`` wrapper
    (as produced by ``ast.parse(..., mode='eval')``); for the wrapper, its
    ``body`` is checked and annotated.

    Args:
        ctx: type context; primitive types are looked up in ``ctx.types``.
        sym_table: maps variable names in scope to their types.
        expr: the expression node to check.

    Returns:
        The inferred type, which is also stored as ``.type`` on the node.

    Raises:
        CustomError: if the node kind is unsupported or the expression is
            ill-typed.
    """
    body = expr.body if isinstance(expr, ast.Expression) else expr
    if isinstance(body, ast.Constant):
        result = parse_constant(ctx, sym_table, body)
    elif isinstance(body, ast.UnaryOp):
        result = parse_unary_ops(ctx, sym_table, body)
    elif isinstance(body, ast.BinOp):
        result = parse_bin_ops(ctx, sym_table, body)
    elif isinstance(body, ast.Name):
        result = parse_name(ctx, sym_table, body)
    elif isinstance(body, ast.List):
        result = parse_list(ctx, sym_table, body)
    elif isinstance(body, ast.Tuple):
        result = parse_tuple(ctx, sym_table, body)
    elif isinstance(body, ast.Attribute):
        result = parse_attribute(ctx, sym_table, body)
    elif isinstance(body, ast.BoolOp):
        result = parse_bool_ops(ctx, sym_table, body)
    elif isinstance(body, ast.Compare):
        result = parse_compare(ctx, sym_table, body)
    elif isinstance(body, ast.Call):
        result = parse_call(ctx, sym_table, body)
    elif isinstance(body, ast.Subscript):
        result = parse_subscript(ctx, sym_table, body)
    elif isinstance(body, ast.IfExp):
        result = parse_if_expr(ctx, sym_table, body)
    elif isinstance(body, ast.ListComp):
        result = parse_list_comprehension(ctx, sym_table, body)
    else:
        raise CustomError(f'{body} is not yet supported', body)
    body.type = result
    return result


# Method names used to resolve unary operators (``not`` is handled directly
# by parse_unary_ops and is intentionally absent).
_UNARY_OP_NAMES = {
    ast.UAdd: '__pos__',
    ast.USub: '__neg__',
    ast.Invert: '__invert__',
}


def get_unary_op(op):
    """Return the method name that implements unary operator node ``op``.

    Raises:
        Exception: if ``op`` is not ``+``, ``-`` or ``~``.
    """
    try:
        return _UNARY_OP_NAMES[type(op)]
    except KeyError:
        # Previously this referenced an undefined name and raised NameError.
        raise Exception(f'Unknown {op}') from None


# Operators whose Python data-model method name is not simply
# ``__<lowercased AST class name>__``.
_BIN_OP_NAMES = {
    ast.Div: '__truediv__',
    ast.BitAnd: '__and__',
    ast.BitOr: '__or__',
    ast.BitXor: '__xor__',
    ast.Mult: '__mul__',
    ast.MatMult: '__matmul__',
    ast.NotEq: '__ne__',
    ast.LtE: '__le__',
    ast.GtE: '__ge__',
}


def get_bin_ops(op):
    """Return the method name for a binary or comparison operator node.

    Uses the Python data-model spelling (``__mul__``, ``__le__``, ...).
    Operators not listed in ``_BIN_OP_NAMES`` fall back to
    ``__<class name lowercased>__`` (e.g. ``ast.Add`` -> ``__add__``).
    NOTE(review): ``is``/``is not``/``in``/``not in`` also take that fallback
    (``__is__``, ``__in__``, ...); whether those resolve depends on the
    type definitions.
    """
    name = _BIN_OP_NAMES.get(type(op))
    if name is not None:
        return name
    return f'__{type(op).__name__.lower()}__'


def parse_constant(ctx: Context, sym_table: dict[str, Type], node):
    """Return the type of a literal: bool, int32 or float.

    Raises:
        CustomError: for any other constant (strings, None, ...).
    """
    v = node.value
    # bool must be tested before int: bool is a subclass of int in Python.
    if isinstance(v, bool):
        return ctx.types['bool']
    if isinstance(v, int):
        return ctx.types['int32']
    if isinstance(v, float):
        return ctx.types['float']
    raise CustomError(f'unknown constant {v}', node)


def parse_name(ctx: Context, sym_table: dict[str, Type], node):
    """Look up the type of a variable reference in ``sym_table``.

    Raises:
        CustomError: if the variable is not bound.
    """
    try:
        return sym_table[node.id]
    except KeyError:
        raise CustomError(f'unbounded variable {node.id}', node) from None


def parse_list(ctx: Context, sym_table: dict[str, Type], node):
    """Type a list literal; all elements must share one type.

    An empty list gets element type ``BotType()``.

    Raises:
        CustomError: if element types differ.
    """
    types = [parse_expr(ctx, sym_table, e) for e in node.elts]
    if not types:
        return ListType(BotType())
    if any(t != types[0] for t in types[1:]):
        raise CustomError('inhomogeneous list is not allowed', node)
    return ListType(types[0])


def parse_tuple(ctx: Context, sym_table: dict[str, Type], node):
    """Type a tuple literal as a TupleType of its element types, in order."""
    return TupleType([parse_expr(ctx, sym_table, e) for e in node.elts])


def parse_attribute(ctx: Context, sym_table: dict[str, Type], node):
    """Type a field access ``obj.attr``.

    For a constrained TypeVariable, the field must exist on every constraint
    and have the same type on all of them. Otherwise the field is looked up
    in ``obj.fields``.

    Raises:
        CustomError: if the field is missing or its type differs between
            constraints.
    """
    obj = parse_expr(ctx, sym_table, node.value)
    if isinstance(obj, TypeVariable) and len(obj.constraints) > 0:
        first, *rest = obj.constraints
        if node.attr not in first.fields:
            raise CustomError(f'unknown field {node.attr} in {obj}', node)
        ty = first.fields[node.attr]
        for v in rest:
            if node.attr not in v.fields:
                raise CustomError(f'unknown field {node.attr} in {obj}', node)
            if v.fields[node.attr] != ty:
                raise CustomError(
                    f'unknown field {node.attr} in {obj} (type mismatch)',
                    node)
        return ty
    if node.attr in obj.fields:
        return obj.fields[node.attr]
    raise CustomError(f'unknown field {node.attr} in {obj}', node)


def parse_bool_ops(ctx: Context, sym_table: dict[str, Type], node):
    """Type an ``and``/``or`` expression; every operand must be bool.

    Python folds ``a and b and c`` into a single BoolOp holding all operands,
    so any number of operands is accepted.

    Raises:
        CustomError: if any operand is not bool.
    """
    b = ctx.types['bool']
    operand_types = [parse_expr(ctx, sym_table, v) for v in node.values]
    if any(t != b for t in operand_types):
        raise CustomError('operands of bool ops must be booleans', node)
    return b


def parse_bin_ops(ctx: Context, sym_table: dict[str, Type], node):
    """Type ``left <op> right`` by resolving the operator method on ``left``.

    Raises:
        CustomError: from resolve_call, re-located at ``node``.
    """
    left = parse_expr(ctx, sym_table, node.left)
    right = parse_expr(ctx, sym_table, node.right)
    op = get_bin_ops(node.op)
    try:
        return resolve_call(left, op, [right], {}, ctx)
    except CustomError as e:
        raise e.at(node)


def parse_unary_ops(ctx: Context, sym_table: dict[str, Type], node):
    """Type a unary expression.

    ``not`` requires a bool operand. Other operators are resolved as methods
    on the operand's type.

    Raises:
        CustomError: on a non-bool ``not`` operand, or from resolve_call.
    """
    t = parse_expr(ctx, sym_table, node.operand)
    if isinstance(node.op, ast.Not):
        b = ctx.types['bool']
        if t != b:
            raise CustomError('operands of bool ops must be booleans', node)
        return b
    try:
        return resolve_call(t, get_unary_op(node.op), [], {}, ctx)
    except CustomError as e:
        raise e.at(node)


def parse_compare(ctx: Context, sym_table: dict[str, Type], node):
    """Type a (possibly chained) comparison such as ``a < b <= c``.

    Each adjacent pair is checked with its operator, and each pairwise
    comparison must produce bool.

    Returns:
        The bool type.

    Raises:
        CustomError: if a comparison cannot be resolved or is not bool.
    """
    # Type operands left-to-right so errors are reported in source order.
    items = [parse_expr(ctx, sym_table, node.left)]
    items += [parse_expr(ctx, sym_table, v) for v in node.comparators]
    boolean = ctx.types['bool']
    ops = [get_bin_ops(v) for v in node.ops]
    for a, b, op in zip(items, items[1:], ops):
        try:
            result = resolve_call(a, op, [b], {}, ctx)
            if result != boolean:
                raise CustomError(
                    f'result of comparison must be bool instead of {result}')
        except CustomError as e:
            raise e.at(node)
    return boolean


def parse_call(ctx: Context, sym_table: dict[str, Type], node):
    """Type a call ``f(args)`` or a method call ``obj.f(args)``.

    For a method call, the receiver type is passed to resolve_call. For a
    plain call, the receiver is None.

    Raises:
        CustomError: on keyword arguments, or from resolve_call.
    """
    if node.keywords:
        raise CustomError('keyword arguments are not supported', node)
    args = [parse_expr(ctx, sym_table, v) for v in node.args]
    obj = None
    f = None
    if isinstance(node.func, ast.Attribute):
        obj = parse_expr(ctx, sym_table, node.func.value)
        f = node.func.attr
    elif isinstance(node.func, ast.Name):
        f = node.func.id
    # NOTE(review): any other callee form (e.g. ``fs[0](x)``) leaves ``f`` as
    # None; presumably resolve_call rejects that -- confirm.
    try:
        return resolve_call(obj, f, args, {}, ctx)
    except CustomError as e:
        raise e.at(node)


def parse_subscript(ctx: Context, sym_table: dict[str, Type], node):
    """Type ``lst[i]`` (element type) or ``lst[a:b:c]`` (the list type).

    Only lists can be subscripted. Indices and slice bounds must be int32.

    Raises:
        CustomError: on a non-list value or a non-int32 index/bound.
    """
    value = parse_expr(ctx, sym_table, node.value)
    if not isinstance(value, ListType):
        raise CustomError(f'cannot take index of {value}', node)
    i32 = ctx.types['int32']
    if isinstance(node.slice, ast.Slice):
        for bound in (node.slice.lower, node.slice.upper, node.slice.step):
            if bound is not None and parse_expr(ctx, sym_table, bound) != i32:
                raise CustomError('slice index must be int32', bound)
        return value
    s = parse_expr(ctx, sym_table, node.slice)
    if s != i32:
        raise CustomError(f'index of type {s} is not supported', node)
    return value.params[0]


def parse_if_expr(ctx: Context, sym_table: dict[str, Type], node):
    """Type ``body if test else orelse``.

    The test must be bool, and both branches must have the same type.

    Raises:
        CustomError: on a non-bool test or divergent branch types.
    """
    b = ctx.types['bool']
    t = parse_expr(ctx, sym_table, node.test)
    if t != b:
        raise CustomError(
            f'type of conditional must be bool instead of {t}', node)
    ty1 = parse_expr(ctx, sym_table, node.body)
    ty2 = parse_expr(ctx, sym_table, node.orelse)
    if ty1 != ty2:
        raise CustomError(
            f'divergent type for if expression: {ty1} != {ty2}', node)
    return ty1


def parse_simple_binding(name, ty):
    """Bind an assignment target to ``ty`` and return the new bindings.

    Supports plain names (``_`` binds nothing) and arbitrarily nested tuples
    of names matched against a TupleType. Each bound Name node gets
    ``.type`` set.

    Returns:
        dict mapping each bound variable name to its type.

    Raises:
        CustomError: on a shape/length mismatch, a name bound twice, or an
            unsupported target kind.
    """
    if isinstance(name, ast.Name):
        if name.id == '_':
            return {}
        name.type = ty
        return {name.id: ty}
    if isinstance(name, ast.Tuple):
        if not isinstance(ty, TupleType):
            raise CustomError(f'cannot pattern match over {ty}')
        if len(name.elts) != len(ty.params):
            raise CustomError('pattern matching length mismatch')
        result = {}
        for x, y in zip(name.elts, ty.params):
            binding = parse_simple_binding(x, y)
            # A name shared with earlier elements means it is bound twice.
            if result.keys() & binding.keys():
                raise CustomError('variable name clash', x)
            result |= binding
        return result
    raise CustomError(f'binding to {name} is not supported')


def parse_list_comprehension(ctx: Context, sym_table: dict[str, Type], node):
    """Type ``[elt for target in iter if cond ...]`` as a ListType.

    Only a single, non-async ``for`` clause is supported. The iterable must
    be a list, or a type variable with exactly one constraint that is a list.
    The target's bindings are visible in the conditions and the element
    expression.

    Raises:
        CustomError: on unsupported comprehension forms, a non-list
            iterable, a bad target, or a non-bool condition.
    """
    if len(node.generators) != 1:
        raise CustomError(
            'list comprehension with more than 1 for loop is not supported',
            node)
    generator = node.generators[0]
    if generator.is_async:
        raise CustomError('async list comprehension is not supported', node)
    ty = parse_expr(ctx, sym_table, generator.iter)
    # A type variable with a single constraint can only be that type.
    if isinstance(ty, TypeVariable) and len(ty.constraints) == 1:
        ty = ty.constraints[0]
    if not isinstance(ty, ListType):
        raise CustomError(f'unable to iterate over {ty}', node)
    try:
        sym_table2 = sym_table | parse_simple_binding(
            generator.target, ty.params[0])
    except CustomError as e:
        raise e.at(node)
    b = ctx.types['bool']
    for c in generator.ifs:
        if parse_expr(ctx, sym_table2, c) != b:
            raise CustomError('condition should be of boolean type', c)
    return ListType(parse_expr(ctx, sym_table2, node.elt))