nac3-spec/toy-impl/parse_expr.py

279 lines
9.9 KiB
Python

import ast
import copy
from helper import *
from type_def import *
from inference import *
# We assume the following types are available:
# bool, int32 and float, together with their associated operations.
# Not handled yet: named expressions (walrus `:=`) and type guards.
def parse_expr(ctx: Context,
               sym_table: dict[str, Type],
               expr: ast.expr):
    """Infer the type of an expression AST node by dispatching on its kind.

    Args:
        ctx: inference context, forwarded unchanged to the handler.
        sym_table: mapping from variable name to its type in the current scope.
        expr: the expression node. An ``ast.Expression`` wrapper (as produced
            by ``ast.parse(..., mode='eval')``) is unwrapped to its ``body``.

    Returns:
        Whatever the matching ``parse_*`` handler returns for the node.

    Raises:
        CustomError: if the node kind has no handler here.
    """
    node = expr.body if isinstance(expr, ast.Expression) else expr
    # Checked in order with isinstance; the first matching node kind wins.
    handlers = (
        (ast.Constant, parse_constant),
        (ast.UnaryOp, parse_unary_ops),
        (ast.BinOp, parse_bin_ops),
        (ast.Name, parse_name),
        (ast.List, parse_list),
        (ast.Tuple, parse_tuple),
        (ast.Attribute, parse_attribute),
        (ast.BoolOp, parse_bool_ops),
        (ast.Compare, parse_compare),
        (ast.Call, parse_call),
        (ast.Subscript, parse_subscript),
        (ast.IfExp, parse_if_expr),
        (ast.ListComp, parse_list_comprehension),
    )
    for node_kind, handler in handlers:
        if isinstance(node, node_kind):
            return handler(ctx, sym_table, node)
    raise CustomError(f'{node} is not yet supported', node)
def get_unary_op(op):
    """Return the dunder method name that implements a unary operator.

    ``not`` (``ast.Not``) is not handled here; ``parse_unary_ops`` deals
    with it before calling this function.

    Args:
        op: an ``ast.unaryop`` instance.

    Returns:
        ``'__pos__'`` for ``+x``, ``'__neg__'`` for ``-x``,
        ``'__invert__'`` for ``~x``.

    Raises:
        Exception: if ``op`` is not one of the operators above.
    """
    if isinstance(op, ast.UAdd):
        return '__pos__'
    if isinstance(op, ast.USub):
        return '__neg__'
    if isinstance(op, ast.Invert):
        return '__invert__'
    # Report the offending operator itself (formatting an undefined name here
    # would turn this error into a NameError).
    raise Exception(f'Unknown {op}')
# Operators whose Python dunder method name is NOT simply the lower-cased AST
# class name (e.g. ``ast.Mult`` -> ``__mul__``, not ``__mult__``).
_SPECIAL_BIN_OP_NAMES = (
    (ast.Div, '__truediv__'),
    (ast.Mult, '__mul__'),
    (ast.MatMult, '__matmul__'),
    (ast.BitAnd, '__and__'),
    (ast.BitOr, '__or__'),
    (ast.BitXor, '__xor__'),
    (ast.NotEq, '__ne__'),
    (ast.LtE, '__le__'),
    (ast.GtE, '__ge__'),
)


def get_bin_ops(op):
    """Return the dunder method name for a binary or comparison operator.

    Args:
        op: an ``ast.operator`` (from ``BinOp``) or ``ast.cmpop`` (from
            ``Compare``) instance.

    Returns:
        The Python data-model method name, e.g. ``'__add__'``, ``'__mul__'``,
        ``'__truediv__'``, ``'__le__'``.
    """
    for op_kind, method_name in _SPECIAL_BIN_OP_NAMES:
        if isinstance(op, op_kind):
            return method_name
    # Add, Sub, FloorDiv, Mod, Pow, LShift, RShift, Eq, Lt and Gt map directly
    # onto their lower-cased class names.
    # NOTE(review): Is/IsNot/In/NotIn also fall through here and yield names
    # (``__is__``, ``__in__``, ...) that are not Python dunders; ``in`` would
    # need ``__contains__`` with swapped operands in the caller.
    return f'__{type(op).__name__.lower()}__'
def parse_constant(ctx: Context,
                   sym_table: dict[str, Type],
                   node):
    """Return the type of a literal ``ast.Constant`` node.

    ``bool`` is tested before ``int`` because ``bool`` is a subclass of
    ``int`` in Python; otherwise ``True`` would be typed as ``int32``.

    Args:
        ctx: context whose ``types`` mapping holds 'bool', 'int32', 'float'.
        sym_table: unused here (all ``parse_*`` handlers share this signature).
        node: the ``ast.Constant`` node.

    Raises:
        CustomError: for any other literal kind (str, None, complex, ...).
    """
    literal = node.value
    for py_type, type_name in ((bool, 'bool'), (int, 'int32'), (float, 'float')):
        if isinstance(literal, py_type):
            return ctx.types[type_name]
    raise CustomError(f'unknown constant {literal}', node)
def parse_name(ctx: Context,
               sym_table: dict[str, Type],
               node):
    """Look up the type bound to a variable reference.

    Raises:
        CustomError: if the name is not present in ``sym_table``.
    """
    name = node.id
    if name not in sym_table:
        raise CustomError(f'unbounded variable {name}', node)
    return sym_table[name]
def parse_list(ctx: Context,
               sym_table: dict[str, Type],
               node):
    """Infer the type of a list literal.

    All elements are typed first; every element must have the same type
    (compared with ``!=``) as the first one.

    Returns:
        ``ListType(<element type>)``; an empty literal yields
        ``ListType(BotType())`` (presumably the bottom type — see type_def).

    Raises:
        CustomError: if the element types are not all equal.
    """
    element_types = [parse_expr(ctx, sym_table, elt) for elt in node.elts]
    if not element_types:
        return ListType(BotType())
    head = element_types[0]
    if any(other != head for other in element_types[1:]):
        raise CustomError('inhomogeneous list is not allowed', node)
    return ListType(head)
def parse_tuple(ctx: Context,
                sym_table: dict[str, Type],
                node):
    """Return a ``TupleType`` of the element types, in source order."""
    return TupleType([parse_expr(ctx, sym_table, elt) for elt in node.elts])
def parse_attribute(ctx: Context,
                    sym_table: dict[str, Type],
                    node):
    """Infer the type of an attribute access ``value.attr``.

    For a ``TypeVariable`` with at least one constraint, the attribute must
    exist in every constraint's ``fields`` with the same type, and that
    common type is returned. Any other type is looked up directly in its
    own ``fields``.

    Raises:
        CustomError: if the field is missing, or the constraints of a type
            variable disagree on the field's type.
    """
    obj = parse_expr(ctx, sym_table, node.value)
    attr = node.attr
    constrained = isinstance(obj, TypeVariable) and len(obj.constraints) > 0
    if not constrained:
        # NOTE(review): assumes the type exposes ``fields``; an unconstrained
        # TypeVariable also takes this path — confirm it has ``fields``.
        if attr not in obj.fields:
            raise CustomError(f'unknown field {attr} in {obj}', node)
        return obj.fields[attr]

    constraints = obj.constraints
    first = constraints[0]
    if attr not in first.fields:
        raise CustomError(f'unknown field {attr} in {obj}', node)
    field_type = first.fields[attr]
    for other in constraints[1:]:
        if attr not in other.fields:
            raise CustomError(f'unknown field {attr} in {obj}', node)
        if other.fields[attr] != field_type:
            raise CustomError(
                f'unknown field {attr} in {obj} (type mismatch)', node)
    return field_type
def parse_bool_ops(ctx: Context,
                   sym_table: dict[str, Type],
                   node):
    """Type-check a boolean ``and`` / ``or`` expression.

    Python folds a chain such as ``a and b and c`` into a single
    ``ast.BoolOp`` with more than two ``values``, so every operand is checked
    instead of assuming exactly two (an ``assert`` would also vanish under
    ``python -O``).

    Returns:
        The ``bool`` type from ``ctx.types``.

    Raises:
        CustomError: if any operand is not of type ``bool``.
    """
    # Type every operand first (left to right) so that errors inside an
    # operand are reported before the bool check.
    operand_types = [parse_expr(ctx, sym_table, v) for v in node.values]
    b = ctx.types['bool']
    if any(t != b for t in operand_types):
        raise CustomError('operands of bool ops must be booleans', node)
    return b
def parse_bin_ops(ctx: Context,
                  sym_table: dict[str, Type],
                  node):
    """Infer the result type of a binary operation such as ``a + b``.

    The left operand's type is passed to ``resolve_call`` as the receiver,
    the operator's dunder name (from ``get_bin_ops``) as the method, and the
    right operand's type as the single argument.

    Raises:
        CustomError: from operand inference, or from ``resolve_call``
            (re-raised via ``.at(node)``).
    """
    lhs_type = parse_expr(ctx, sym_table, node.left)
    rhs_type = parse_expr(ctx, sym_table, node.right)
    method = get_bin_ops(node.op)
    try:
        return resolve_call(lhs_type, method, [rhs_type], {}, ctx)
    except CustomError as err:
        # Presumably tags the error with this expression's location.
        raise err.at(node)
def parse_unary_ops(ctx: Context,
                    sym_table: dict[str, Type],
                    node):
    """Infer the result type of a unary operation.

    ``not x`` requires ``x`` to be ``bool`` and yields ``bool``. Any other
    unary operator is resolved as a zero-argument call of its dunder method
    (from ``get_unary_op``) on the operand's type.

    Raises:
        CustomError: if ``not`` is applied to a non-bool, or from
            ``resolve_call`` (re-raised via ``.at(node)``).
    """
    operand_type = parse_expr(ctx, sym_table, node.operand)
    if not isinstance(node.op, ast.Not):
        try:
            return resolve_call(operand_type, get_unary_op(node.op), [], {}, ctx)
        except CustomError as err:
            raise err.at(node)
    bool_type = ctx.types['bool']
    if operand_type != bool_type:
        raise CustomError('operands of bool ops must be booleans', node)
    return bool_type
def parse_compare(ctx: Context,
                  sym_table: dict[str, Type],
                  node):
    """Type-check a (possibly chained) comparison such as ``a < b <= c``.

    Each adjacent operand pair is resolved as a call of the comparison's
    dunder method on the left operand's type; every such call must return
    ``bool``.

    Returns:
        The ``bool`` type from ``ctx.types``.

    Raises:
        CustomError: if a comparison cannot be resolved or does not return
            ``bool`` (re-raised via ``.at(node)``).
    """
    # Comparators are typed before the leftmost operand (same order as the
    # original implementation, which matters for which error surfaces first).
    right_types = [parse_expr(ctx, sym_table, c) for c in node.comparators]
    left_type = parse_expr(ctx, sym_table, node.left)
    operand_types = [left_type, *right_types]
    bool_type = ctx.types['bool']
    method_names = [get_bin_ops(cmp_op) for cmp_op in node.ops]
    for lhs, rhs, method in zip(operand_types, operand_types[1:], method_names):
        try:
            outcome = resolve_call(lhs, method, [rhs], {}, ctx)
            if outcome != bool_type:
                raise CustomError(
                    f'result of comparison must be bool instead of {outcome}')
        except CustomError as err:
            raise err.at(node)
    return bool_type
def parse_call(ctx: Context,
               sym_table: dict[str, Type],
               node):
    """Infer the return type of a call expression.

    Two callee shapes are supported:
      * ``obj.method(args)``: the receiver's type and the method name are
        passed to ``resolve_call``;
      * ``func(args)``: the receiver is ``None`` and the function name is
        passed.

    Keyword arguments and any other callee expression (e.g. ``f()()`` or
    ``xs[0]()``) are rejected.

    Raises:
        CustomError: for keyword arguments, unsupported callees, or errors
            from ``resolve_call`` (re-raised via ``.at(node)``).
    """
    if len(node.keywords) > 0:
        raise CustomError('keyword arguments are not supported', node)
    args = [parse_expr(ctx, sym_table, v) for v in node.args]
    callee = node.func
    if isinstance(callee, ast.Attribute):
        obj = parse_expr(ctx, sym_table, callee.value)
        f = callee.attr
    elif isinstance(callee, ast.Name):
        obj = None
        f = callee.id
    else:
        # Without this, the callee name would stay None and be handed to
        # resolve_call as if it were a function name.
        raise CustomError(
            f'calling a {type(callee).__name__} expression is not supported',
            node)
    try:
        return resolve_call(obj, f, args, {}, ctx)
    except CustomError as e:
        raise e.at(node)
def parse_subscript(ctx: Context,
                    sym_table: dict[str, Type],
                    node):
    """Infer the type of ``xs[i]`` or ``xs[lo:hi:step]`` on a list.

    Returns:
        The element type (``value.params[0]``) for an ``int32`` index, or the
        list type itself for a slice.

    Raises:
        CustomError: if the subscripted value is not a ``ListType``, an index
            is not ``int32``, or a present slice bound is not ``int32``.
    """
    container = parse_expr(ctx, sym_table, node.value)
    if not isinstance(container, ListType):
        raise CustomError(f'cannot take index of {container}', node)
    int32 = ctx.types['int32']
    index = node.slice

    if not isinstance(index, ast.Slice):
        index_type = parse_expr(ctx, sym_table, index)
        if index_type == int32:
            return container.params[0]
        raise CustomError(f'index of type {index_type} is not supported', node)

    # Bounds are checked in lower, upper, step order; omitted bounds are None.
    for bound in (index.lower, index.upper, index.step):
        if bound is not None and parse_expr(ctx, sym_table, bound) != int32:
            raise CustomError('slice index must be int32', bound)
    return container
def parse_if_expr(ctx: Context,
                  sym_table: dict[str, Type],
                  node):
    """Infer the type of a conditional expression ``a if cond else b``.

    The condition must be ``bool`` and both branches must have equal types.

    Raises:
        CustomError: if the condition is not ``bool`` or the branch types
            differ.
    """
    bool_type = ctx.types['bool']
    cond_type = parse_expr(ctx, sym_table, node.test)
    if cond_type != bool_type:
        raise CustomError(
            f'type of conditional must be bool instead of {cond_type}', node)
    then_type = parse_expr(ctx, sym_table, node.body)
    else_type = parse_expr(ctx, sym_table, node.orelse)
    if then_type != else_type:
        raise CustomError(
            f'divergent type for if expression: {then_type} != {else_type}', node)
    return then_type
def parse_simple_binding(name, ty):
    """Bind the variables of a simple target pattern to their types.

    Args:
        name: the target AST node, either an ``ast.Name`` or a (possibly
            nested) ``ast.Tuple`` of such targets.
        ty: the type being bound; must be a ``TupleType`` of matching
            length when ``name`` is a tuple.

    Returns:
        dict mapping each bound variable name to its type. The wildcard
        ``_`` binds nothing.

    Raises:
        CustomError: if a tuple pattern meets a non-tuple type, the lengths
            differ, a name is bound twice, or the target kind is unsupported.
    """
    if isinstance(name, ast.Name):
        if name.id == '_':
            return {}
        return {name.id: ty}
    elif isinstance(name, ast.Tuple):
        if not isinstance(ty, TupleType):
            raise CustomError(f'cannot pattern match over {ty}')
        if len(name.elts) != len(ty.params):
            raise CustomError(f'pattern matching length mismatch')
        result = {}
        for x, y in zip(name.elts, ty.params):
            # Compute each sub-binding once and reuse it; recomputing it
            # doubles the work at every level of tuple nesting.
            binding = parse_simple_binding(x, y)
            expected = len(result) + len(binding)
            result |= binding
            # The merged dict is smaller only if a name was already bound.
            if len(result) != expected:
                raise CustomError('variable name clash', x)
        return result
    else:
        raise CustomError(f'binding to {name} is not supported')
def parse_list_comprehension(ctx: Context,
                             sym_table: dict[str, Type],
                             node):
    """Infer the type of ``[elt for target in iter if cond ...]``.

    Only a single, non-async ``for`` clause over a ``ListType`` is supported.
    The target pattern is bound to the list's element type in a copy of the
    symbol table (the caller's table is not modified), every ``if`` filter
    must be ``bool``, and the result is a list of ``elt``'s type.

    Raises:
        CustomError: for unsupported generator shapes, a non-list iterable,
            a bad target pattern (re-raised via ``.at(node)``), or a
            non-bool filter.
    """
    generators = node.generators
    if len(generators) != 1:
        raise CustomError(
            'list comprehension with more than 1 for loop is not supported', node)
    gen = generators[0]
    if gen.is_async:
        raise CustomError('async list comprehension is not supported', node)
    iter_type = parse_expr(ctx, sym_table, gen.iter)
    if not isinstance(iter_type, ListType):
        raise CustomError(f'unable to iterate over {iter_type}', node)
    try:
        bindings = parse_simple_binding(gen.target, iter_type.params[0])
    except CustomError as err:
        raise err.at(node)
    # ``|`` builds a new dict, so comprehension variables stay local.
    inner_table = sym_table | bindings
    bool_type = ctx.types['bool']
    for cond in gen.ifs:
        if parse_expr(ctx, inner_table, cond) != bool_type:
            raise CustomError('condition should be of boolean type', cond)
    return ListType(parse_expr(ctx, inner_table, node.elt))