266 lines
9.3 KiB
Python
266 lines
9.3 KiB
Python
import ast
|
|
import copy
|
|
from helper import *
|
|
from type_def import *
|
|
from inference import *
|
|
|
|
# We assume the following types are available:
# bool, int32 and float, together with their associated operations.
|
|
|
|
# Not handled yet: named expressions (the `:=` operator) and type guards.
|
|
|
|
def parse_expr(ctx: Context,
               sym_table: dict[str, Type],
               expr: ast.expr):
    """Infer the static type of an expression node.

    ``expr`` may be a bare expression node or an ``ast.Expression`` wrapper
    (what ``ast.parse(..., mode='eval')`` returns); the wrapper is unwrapped
    first. The node is then dispatched on its AST class to the matching
    ``parse_*`` helper, which receives ``ctx`` and ``sym_table`` unchanged.

    Raises:
        CustomError: if the node kind has no handler.
    """
    node = expr.body if isinstance(expr, ast.Expression) else expr

    # Ordered (AST class, handler) pairs, tested with isinstance in sequence.
    handlers = (
        (ast.Constant, parse_constant),
        (ast.UnaryOp, parse_unary_ops),
        (ast.BinOp, parse_bin_ops),
        (ast.Name, parse_name),
        (ast.List, parse_list),
        (ast.Tuple, parse_tuple),
        (ast.Attribute, parse_attribute),
        (ast.BoolOp, parse_bool_ops),
        (ast.Compare, parse_compare),
        (ast.Call, parse_call),
        (ast.Subscript, parse_subscript),
        (ast.IfExp, parse_if_expr),
        (ast.ListComp, parse_list_comprehension),
    )
    for node_cls, handler in handlers:
        if isinstance(node, node_cls):
            return handler(ctx, sym_table, node)

    raise CustomError(f'{node} is not yet supported', node)
|
|
|
|
|
|
def get_unary_op(op):
    """Return the dunder method name that implements unary operator ``op``.

    ``ast.Not`` is deliberately not mapped: ``parse_unary_ops`` types logical
    negation itself before it ever calls this helper.

    Raises:
        ValueError: if ``op`` is not ``ast.UAdd``, ``ast.USub`` or
            ``ast.Invert``.
    """
    if isinstance(op, ast.UAdd):
        return '__pos__'
    if isinstance(op, ast.USub):
        return '__neg__'
    if isinstance(op, ast.Invert):
        return '__invert__'
    # Report the offending operator itself; ``op`` is the only name in scope.
    raise ValueError(f'Unknown unary operator {type(op).__name__}')
|
|
|
|
# Operators whose Python data-model method name is NOT simply
# ``__<lowercased AST class name>__``.
_BIN_OP_DUNDERS = {
    ast.Div: '__truediv__',
    ast.BitAnd: '__and__',
    ast.BitOr: '__or__',
    ast.BitXor: '__xor__',
    ast.Mult: '__mul__',
    ast.MatMult: '__matmul__',
    ast.NotEq: '__ne__',
    ast.LtE: '__le__',
    ast.GtE: '__ge__',
}


def get_bin_ops(op):
    """Return the dunder method name for a binary or comparison operator.

    Operators listed in ``_BIN_OP_DUNDERS`` use Python's data-model name.
    Every other operator falls back to ``__<lowercased class name>__``, for
    example ``ast.Add`` -> ``__add__`` and ``ast.Lt`` -> ``__lt__``.

    NOTE(review): this assumes the operator methods declared for the types
    follow Python's data-model naming, as the ``__truediv__``/``__and__``
    mappings already imply. ``Is``/``IsNot``/``In``/``NotIn`` still map to
    ``__is__``, ``__isnot__``, ``__in__`` and ``__notin__``, which have no
    data-model counterpart. Confirm that ``resolve_call`` expects these.
    """
    dunder = _BIN_OP_DUNDERS.get(type(op))
    if dunder is not None:
        return dunder
    return f'__{type(op).__name__.lower()}__'
|
|
|
|
def parse_constant(ctx: Context,
                   sym_table: dict[str, Type],
                   node):
    """Type a literal constant node.

    ``bool`` is tested before ``int`` because ``bool`` subclasses ``int`` in
    Python. ``sym_table`` is unused here; it is accepted so that every
    ``parse_*`` helper shares the same signature.

    Raises:
        CustomError: for constants other than bool/int/float (e.g. str, None).
    """
    literal = node.value
    for py_type, type_name in ((bool, 'bool'), (int, 'int32'), (float, 'float')):
        if isinstance(literal, py_type):
            return ctx.types[type_name]
    raise CustomError(f'unknown constant {literal}', node)
|
|
|
|
def parse_name(ctx: Context,
               sym_table: dict[str, Type],
               node):
    """Return the type bound to the variable ``node.id`` in ``sym_table``.

    Raises:
        CustomError: if the name has no binding in ``sym_table``.
    """
    name = node.id
    if name not in sym_table:
        raise CustomError(f'unbounded variable {name}', node)
    return sym_table[name]
|
|
|
|
def parse_list(ctx: Context,
               sym_table: dict[str, Type],
               node):
    """Type a list display ``[e1, e2, ...]`` as ``ListType`` of its element type.

    All elements are typed first, left to right. Each remaining element type
    must then compare equal to the first. An empty display is typed
    ``ListType(BotType())``.

    Raises:
        CustomError: if the element types are not all equal.
    """
    elem_types = [parse_expr(ctx, sym_table, elt) for elt in node.elts]
    if not elem_types:
        return ListType(BotType())
    head, *tail = elem_types
    if any(other != head for other in tail):
        raise CustomError(f'inhomogeneous list is not allowed', node)
    return ListType(head)
|
|
|
|
def parse_tuple(ctx: Context,
                sym_table: dict[str, Type],
                node):
    """Type a tuple display as a ``TupleType`` of its element types, in order."""
    return TupleType([parse_expr(ctx, sym_table, elt) for elt in node.elts])
|
|
|
|
def parse_attribute(ctx: Context,
                    sym_table: dict[str, Type],
                    node):
    """Type a field access ``value.attr``.

    The type of ``value`` is inferred first. ``attr`` is then looked up in
    that type's ``fields`` mapping.

    NOTE(review): this assumes every type object exposes ``fields``. A type
    without that attribute would raise AttributeError here rather than
    CustomError. Confirm against type_def.

    Raises:
        CustomError: if the type has no field named ``attr``.
    """
    owner = parse_expr(ctx, sym_table, node.value)
    if node.attr not in owner.fields:
        raise CustomError(f'unknown field {node.attr} in {owner}', node)
    return owner.fields[node.attr]
|
|
|
|
def parse_bool_ops(ctx: Context,
                   sym_table: dict[str, Type],
                   node):
    """Type an ``and``/``or`` expression; every operand must be ``bool``.

    Python folds a chain such as ``a and b and c`` into a single
    ``ast.BoolOp`` whose ``values`` holds all operands. Every operand is
    therefore checked, instead of assuming exactly two.

    Returns:
        ``ctx.types['bool']``.

    Raises:
        CustomError: if any operand is not ``bool``.
    """
    operand_types = [parse_expr(ctx, sym_table, v) for v in node.values]
    b = ctx.types['bool']
    if any(t != b for t in operand_types):
        raise CustomError('operands of bool ops must be booleans', node)
    return b
|
|
|
|
def parse_bin_ops(ctx: Context,
                  sym_table: dict[str, Type],
                  node):
    """Type ``left <op> right`` by resolving ``op``'s dunder method on ``left``.

    The operator is mapped to a method name with ``get_bin_ops``. The call
    is then delegated to ``resolve_call`` with the right operand's type as
    the only positional argument.

    Raises:
        CustomError: whatever ``resolve_call`` raises, re-anchored at ``node``.
    """
    lhs = parse_expr(ctx, sym_table, node.left)
    rhs = parse_expr(ctx, sym_table, node.right)
    dunder = get_bin_ops(node.op)
    try:
        result = resolve_call(lhs, dunder, [rhs], {}, ctx)
    except CustomError as err:
        raise err.at(node)
    return result
|
|
|
|
def parse_unary_ops(ctx: Context,
                    sym_table: dict[str, Type],
                    node):
    """Type a unary operation.

    Logical ``not`` is typed directly and requires a ``bool`` operand. Every
    other unary operator is resolved as a zero-argument dunder call (see
    ``get_unary_op``) on the operand's type.

    Raises:
        CustomError: if ``not`` is applied to a non-bool operand, or if
            ``resolve_call`` rejects the operator (re-anchored at ``node``).
    """
    operand = parse_expr(ctx, sym_table, node.operand)

    if not isinstance(node.op, ast.Not):
        try:
            return resolve_call(operand, get_unary_op(node.op), [], {}, ctx)
        except CustomError as err:
            raise err.at(node)

    boolean = ctx.types['bool']
    if operand != boolean:
        raise CustomError('operands of bool ops must be booleans', node)
    return boolean
|
|
|
|
def parse_compare(ctx: Context,
                  sym_table: dict[str, Type],
                  node):
    """Type a (possibly chained) comparison such as ``a < b <= c``.

    Each adjacent pair of operands is checked with the comparison's dunder
    method (via ``get_bin_ops``/``resolve_call``), and every link in the
    chain must produce ``bool``.

    Returns:
        ``ctx.types['bool']``.

    Raises:
        CustomError: if a link cannot be resolved or does not yield ``bool``,
            re-anchored at ``node``.
    """
    # The right-hand comparators are typed before the left operand.
    rest = [parse_expr(ctx, sym_table, v) for v in node.comparators]
    operands = [parse_expr(ctx, sym_table, node.left), *rest]
    boolean = ctx.types['bool']
    methods = [get_bin_ops(v) for v in node.ops]

    for lhs, rhs, method in zip(operands, operands[1:], methods):
        try:
            outcome = resolve_call(lhs, method, [rhs], {}, ctx)
            if outcome != boolean:
                raise CustomError(
                    f'result of comparison must be bool instead of {outcome}')
        except CustomError as err:
            raise err.at(node)

    return boolean
|
|
|
|
def parse_call(ctx: Context,
               sym_table: dict[str, Type],
               node):
    """Type a call ``f(args)`` or ``obj.method(args)``.

    The positional argument types are inferred first. A plain-name callee is
    resolved with ``obj=None``. For an attribute callee ``obj.method``, the
    type of ``obj`` is inferred and ``method`` is resolved on it. Both cases
    go through ``resolve_call``.

    Raises:
        CustomError: for keyword arguments, for callees that are neither a
            name nor an attribute (e.g. ``f()()`` or ``fs[0]()``), or when
            ``resolve_call`` rejects the call (re-anchored at ``node``).
    """
    if node.keywords:
        raise CustomError('keyword arguments are not supported', node)

    args = [parse_expr(ctx, sym_table, v) for v in node.args]

    callee = node.func
    if isinstance(callee, ast.Attribute):
        obj = parse_expr(ctx, sym_table, callee.value)
        f = callee.attr
    elif isinstance(callee, ast.Name):
        obj = None
        f = callee.id
    else:
        # Without a name there is nothing to look up; passing ``None`` as the
        # method name to resolve_call would yield a meaningless error.
        raise CustomError(
            'only calls to names or attributes are supported', node)

    try:
        return resolve_call(obj, f, args, {}, ctx)
    except CustomError as e:
        raise e.at(node)
|
|
|
|
def parse_subscript(ctx: Context,
                    sym_table: dict[str, Type],
                    node):
    """Type ``value[index]`` or ``value[lower:upper:step]`` on a list.

    Only ``ListType`` values can be subscripted. With a slice, each bound
    that is present must be ``int32`` and the result is the list type itself.
    With a plain index, the index must be ``int32`` and the result is the
    element type ``value.params[0]``.

    Raises:
        CustomError: for non-list values, non-int32 slice bounds (anchored at
            the bound), or non-int32 indices.
    """
    container = parse_expr(ctx, sym_table, node.value)
    if not isinstance(container, ListType):
        raise CustomError(f'cannot take index of {container}', node)

    i32 = ctx.types['int32']
    index = node.slice

    if not isinstance(index, ast.Slice):
        index_type = parse_expr(ctx, sym_table, index)
        if index_type == i32:
            return container.params[0]
        raise CustomError(f'index of type {index_type} is not supported', node)

    # Bounds are checked in source order: lower, upper, step.
    for bound in (index.lower, index.upper, index.step):
        if bound is not None and parse_expr(ctx, sym_table, bound) != i32:
            raise CustomError(f'slice index must be int32', bound)
    return container
|
|
|
|
def parse_if_expr(ctx: Context,
                  sym_table: dict[str, Type],
                  node):
    """Type a conditional expression ``body if test else orelse``.

    ``test`` must be ``bool``, and both branches must have equal types. That
    shared type is returned.

    Raises:
        CustomError: if ``test`` is not ``bool`` or the branch types differ.
    """
    boolean = ctx.types['bool']
    cond_type = parse_expr(ctx, sym_table, node.test)
    if cond_type != boolean:
        raise CustomError(
            f'type of conditional must be bool instead of {cond_type}', node)

    then_type, else_type = [parse_expr(ctx, sym_table, branch)
                            for branch in (node.body, node.orelse)]
    if then_type != else_type:
        raise CustomError(
            f'divergent type for if expression: {then_type} != {else_type}',
            node)
    return then_type
|
|
|
|
def parse_simple_binding(name, ty):
    """Bind the names in an assignment-style target to the types they receive.

    ``name`` is an ``ast.Name`` or an (arbitrarily nested) ``ast.Tuple`` of
    such targets, and ``ty`` is the type being destructured. The wildcard
    ``_`` binds nothing, so it may appear any number of times.

    Returns:
        A dict mapping each bound variable name to its type.

    Raises:
        CustomError: if a tuple target meets a non-``TupleType``, the arities
            differ, a name is bound twice (anchored at the clashing target),
            or the target kind is unsupported. The other errors carry no
            node, so callers are expected to attach one.
    """
    if isinstance(name, ast.Name):
        if name.id == '_':
            return {}
        return {name.id: ty}
    elif isinstance(name, ast.Tuple):
        if not isinstance(ty, TupleType):
            raise CustomError(f'cannot pattern match over {ty}')
        if len(name.elts) != len(ty.params):
            raise CustomError('pattern matching length mismatch')
        result = {}
        for x, y in zip(name.elts, ty.params):
            # Bind each element exactly once. Recomputing the sub-binding made
            # nested tuple targets cost O(2 ** depth).
            binding = parse_simple_binding(x, y)
            if result.keys() & binding.keys():
                raise CustomError('variable name clash', x)
            result |= binding
        return result
    else:
        raise CustomError(f'binding to {name} is not supported')
|
|
|
|
def parse_list_comprehension(ctx: Context,
                             sym_table: dict[str, Type],
                             node):
    """Type ``[elt for target in iter if cond ...]`` as ``ListType`` of ``elt``.

    Only one synchronous ``for`` clause over a ``ListType`` is accepted. The
    loop target is destructured with ``parse_simple_binding`` against the
    list's element type. Those bindings shadow ``sym_table`` while the
    ``if`` filters and the element expression are typed.

    Raises:
        CustomError: for multiple or async generators, a non-list iterable, a
            bad target pattern (anchored at ``node``), or a non-bool filter.
    """
    if len(node.generators) != 1:
        raise CustomError(
            'list comprehension with more than 1 for loop is not supported', node)
    gen = node.generators[0]
    if gen.is_async:
        raise CustomError('async list comprehension is not supported', node)

    iter_type = parse_expr(ctx, sym_table, gen.iter)
    if not isinstance(iter_type, ListType):
        raise CustomError(f'unable to iterate over {iter_type}', node)

    try:
        inner_scope = sym_table | parse_simple_binding(gen.target,
                                                       iter_type.params[0])
    except CustomError as err:
        raise err.at(node)

    boolean = ctx.types['bool']
    for cond in gen.ifs:
        if parse_expr(ctx, inner_scope, cond) != boolean:
            raise CustomError(f'condition should be of boolean type', cond)

    return ListType(parse_expr(ctx, inner_scope, node.elt))
|
|
|