expression type check
This commit is contained in:
parent
60c1e99285
commit
97fdef2488
@ -98,10 +98,10 @@ def resolve_call(obj,
|
|||||||
raise CustomError('{f} is not a method of {obj}')
|
raise CustomError('{f} is not a method of {obj}')
|
||||||
f_args, f_result = TupleType(f[0][1:]), f[1]
|
f_args, f_result = TupleType(f[0][1:]), f[1]
|
||||||
else:
|
else:
|
||||||
raise CustomError(f"No such method {fn} in {c}")
|
raise CustomError(f"No such method {fn} in {obj}")
|
||||||
elif isinstance(obj, VirtualClassType):
|
elif isinstance(obj, VirtualClassType):
|
||||||
# may need to emit special annotation that this is a virtual method
|
# TODO: may need to emit special annotation that this is a virtual
|
||||||
# call?
|
# method call?
|
||||||
if fn in obj.base.methods:
|
if fn in obj.base.methods:
|
||||||
f = obj.base.methods[fn]
|
f = obj.base.methods[fn]
|
||||||
if len(f[0]) == 0 or not isinstance(f[0][0], SelfType):
|
if len(f[0]) == 0 or not isinstance(f[0][0], SelfType):
|
||||||
|
180
toy-impl/parse_expr.py
Normal file
180
toy-impl/parse_expr.py
Normal file
@ -0,0 +1,180 @@
|
|||||||
|
import ast
|
||||||
|
from helper import *
|
||||||
|
from type_def import *
|
||||||
|
from inference import *
|
||||||
|
|
||||||
|
# we assume having the following types:
|
||||||
|
# bool, int32 with associated operations
|
||||||
|
|
||||||
|
# not handled now: slice, comprehensions, named expression, if expression, type
|
||||||
|
# guard
|
||||||
|
|
||||||
|
def parse_expr(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
expr: ast.expr):
|
||||||
|
if isinstance(expr, ast.Expression):
|
||||||
|
body = expr.body
|
||||||
|
else:
|
||||||
|
body = expr
|
||||||
|
if isinstance(body, ast.Constant):
|
||||||
|
return parse_constant(ctx, sym_table, body)
|
||||||
|
if isinstance(body, ast.UnaryOp):
|
||||||
|
return parse_unary_op(ctx, sym_table, body)
|
||||||
|
if isinstance(body, ast.BinOp):
|
||||||
|
return parse_bin_ops(ctx, sym_table, body)
|
||||||
|
if isinstance(body, ast.Name):
|
||||||
|
return parse_name(ctx, sym_table, body)
|
||||||
|
if isinstance(body, ast.List):
|
||||||
|
return parse_list(ctx, sym_table, body)
|
||||||
|
if isinstance(body, ast.Tuple):
|
||||||
|
return parse_tuple(ctx, sym_table, body)
|
||||||
|
if isinstance(body, ast.Attribute):
|
||||||
|
return parse_attribute(ctx, sym_table, body)
|
||||||
|
if isinstance(body, ast.BoolOp):
|
||||||
|
return parse_bool_ops(ctx, sym_table, body)
|
||||||
|
if isinstance(body, ast.Compare):
|
||||||
|
return parse_compare(ctx, sym_table, body)
|
||||||
|
if isinstance(body, ast.Call):
|
||||||
|
return parse_call(ctx, sym_table, body)
|
||||||
|
if isinstance(body, ast.Subscript):
|
||||||
|
return parse_subscript(ctx, sym_table, body)
|
||||||
|
raise CustomError(f'{body} is not yet supported')
|
||||||
|
|
||||||
|
|
||||||
|
def get_unary_op(op):
|
||||||
|
if isinstance(op, ast.UAdd):
|
||||||
|
return '__pos__'
|
||||||
|
if isinstance(op, ast.USub):
|
||||||
|
return '__neg__'
|
||||||
|
if isinstance(op, ast.Invert):
|
||||||
|
return '__invert__'
|
||||||
|
raise Exception(f'Unknown {expr}')
|
||||||
|
|
||||||
|
def get_bin_ops(op):
|
||||||
|
if isinstance(op, ast.Div):
|
||||||
|
return '__truediv__'
|
||||||
|
if isinstance(op, ast.BitAnd):
|
||||||
|
return '__and__'
|
||||||
|
if isinstance(op, ast.BitOr):
|
||||||
|
return '__or__'
|
||||||
|
if isinstance(op, ast.BitXor):
|
||||||
|
return '__xor__'
|
||||||
|
return f'__{type(op).__name__.lower()}__'
|
||||||
|
|
||||||
|
def parse_constant(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
node):
|
||||||
|
v = node.value
|
||||||
|
if isinstance(v, int):
|
||||||
|
return ctx.types['int32']
|
||||||
|
elif isinstance(v, bool):
|
||||||
|
return ctx.types['bool']
|
||||||
|
else:
|
||||||
|
raise CustomError(f'unknown constant {v}')
|
||||||
|
|
||||||
|
def parse_name(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
node):
|
||||||
|
if node.id in sym_table:
|
||||||
|
return sym_table[node.id]
|
||||||
|
else:
|
||||||
|
raise CustomError(f'unbounded variable {node.id}')
|
||||||
|
|
||||||
|
def parse_list(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
node):
|
||||||
|
types = [parse_expr(ctx, sym_table, e) for e in node.elts]
|
||||||
|
if len(types) == 0:
|
||||||
|
return ListType(BotType())
|
||||||
|
for t in types[1:]:
|
||||||
|
if t != types[0]:
|
||||||
|
raise CustomError(f'inhomogeneous list is not allowed')
|
||||||
|
return ListType(types[0])
|
||||||
|
|
||||||
|
def parse_tuple(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
node):
|
||||||
|
types = [parse_expr(ctx, sym_table, e) for e in node.elts]
|
||||||
|
return TupleType(types)
|
||||||
|
|
||||||
|
def parse_attribute(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
node):
|
||||||
|
obj = parse_expr(node.value)
|
||||||
|
if node.attr in obj.fields:
|
||||||
|
return obj.fields[node.attr]
|
||||||
|
raise CustomError(f'unknown field {node.attr} in {obj}')
|
||||||
|
|
||||||
|
def parse_bool_ops(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
node):
|
||||||
|
assert len(node.values) == 2
|
||||||
|
left = parse_expr(ctx, sym_table, node.values[0])
|
||||||
|
right = parse_expr(ctx, sym_table, node.values[1])
|
||||||
|
b = ctx.types['bool']
|
||||||
|
if left != b or right != b:
|
||||||
|
raise CustomError('operands of bool ops must be booleans')
|
||||||
|
return b
|
||||||
|
|
||||||
|
def parse_bin_ops(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
node):
|
||||||
|
left = parse_expr(ctx, sym_table, node.left)
|
||||||
|
right = parse_expr(ctx, sym_table, node.right)
|
||||||
|
op = get_bin_ops(node.op)
|
||||||
|
return resolve_call(left, op, [right], {}, ctx)
|
||||||
|
|
||||||
|
def parse_unary_ops(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
node):
|
||||||
|
t = parse_expr(node.operand)
|
||||||
|
if isinstance(node.op, ast.Not):
|
||||||
|
b = ctx.types['bool']
|
||||||
|
if t != b:
|
||||||
|
raise CustomError('operands of bool ops must be booleans')
|
||||||
|
return b
|
||||||
|
return resolve_call(t, get_unary_op(node.op), [], {}, ctx)
|
||||||
|
|
||||||
|
def parse_compare(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
node):
|
||||||
|
items = [parse_expr(ctx, sym_table, v) for v in node.comparators]
|
||||||
|
items.insert(0, parse_expr(ctx, sym_table, node.left))
|
||||||
|
boolean = ctx.types['bool']
|
||||||
|
ops = [get_bin_ops(v) for v in node.ops]
|
||||||
|
for a, b, op in zip(items[:-1], items[1:], ops):
|
||||||
|
result = resolve_call(a, op, [b], {}, ctx)
|
||||||
|
if result != boolean:
|
||||||
|
raise CustomError(
|
||||||
|
f'result of comparison must be bool instead of {result}')
|
||||||
|
return boolean
|
||||||
|
|
||||||
|
def parse_call(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
node):
|
||||||
|
if len(node.keywords) > 0:
|
||||||
|
raise CustomError('keyword arguments are not supported')
|
||||||
|
args = [parse_expr(ctx, sym_table, v) for v in node.args]
|
||||||
|
obj = None
|
||||||
|
f = None
|
||||||
|
if isinstance(node.func, ast.Attribute):
|
||||||
|
obj = parse_expr(node.func.value)
|
||||||
|
f = node.func.attr
|
||||||
|
elif isinstance(node.func, ast.Name):
|
||||||
|
f = node.func.id
|
||||||
|
return resolve_call(obj, f, args, {}, ctx)
|
||||||
|
|
||||||
|
def parse_subscript(ctx: Context,
|
||||||
|
sym_table: dict[str, Type],
|
||||||
|
node):
|
||||||
|
value = parse_expr(ctx, sym_table, node.value)
|
||||||
|
if not isinstance(value, ListType):
|
||||||
|
raise CustomError(f'cannot take index of {value}')
|
||||||
|
s = parse_expr(ctx, sym_table, node.slice)
|
||||||
|
i32 = ctx.types['int32']
|
||||||
|
if s == i32:
|
||||||
|
return value.params[0]
|
||||||
|
else:
|
||||||
|
# will support slice
|
||||||
|
raise CustomError(f'index of type {s} is not supported')
|
||||||
|
|
73
toy-impl/test_expr.py
Normal file
73
toy-impl/test_expr.py
Normal file
@ -0,0 +1,73 @@
|
|||||||
|
import ast
|
||||||
|
from type_def import *
|
||||||
|
from inference import *
|
||||||
|
from helper import *
|
||||||
|
from parse_expr import *
|
||||||
|
|
||||||
|
types = {
|
||||||
|
'int32': PrimitiveType('int32'),
|
||||||
|
'int64': PrimitiveType('int64'),
|
||||||
|
'str': PrimitiveType('str'),
|
||||||
|
'bool': PrimitiveType('bool')
|
||||||
|
}
|
||||||
|
|
||||||
|
i32 = types['int32']
|
||||||
|
i64 = types['int64']
|
||||||
|
s = types['str']
|
||||||
|
b = types['bool']
|
||||||
|
|
||||||
|
variables = {
|
||||||
|
'X': TypeVariable('X', []),
|
||||||
|
'Y': TypeVariable('Y', []),
|
||||||
|
'I': TypeVariable('I', [i32, i64]),
|
||||||
|
'A': TypeVariable('A', [i32, i64, s]),
|
||||||
|
}
|
||||||
|
|
||||||
|
X = variables['X']
|
||||||
|
Y = variables['Y']
|
||||||
|
I = variables['I']
|
||||||
|
A = variables['A']
|
||||||
|
|
||||||
|
i32.methods['__init__'] = ([SelfType(), I], None, set())
|
||||||
|
i32.methods['__add__'] = ([SelfType(), i32], i32, set())
|
||||||
|
i32.methods['__sub__'] = ([SelfType(), i32], i32, set())
|
||||||
|
i32.methods['__lt__'] = ([SelfType(), i32], b, set())
|
||||||
|
i32.methods['__gt__'] = ([SelfType(), i32], b, set())
|
||||||
|
i32.methods['__eq__'] = ([SelfType(), i32], b, set())
|
||||||
|
i32.methods['__ne__'] = ([SelfType(), i32], b, set())
|
||||||
|
i32.methods['__le__'] = ([SelfType(), i32], b, set())
|
||||||
|
i32.methods['__ge__'] = ([SelfType(), i32], b, set())
|
||||||
|
|
||||||
|
i64.methods['__init__'] = ([SelfType(), I], None, set())
|
||||||
|
i64.methods['__add__'] = ([SelfType(), i64], i64, set())
|
||||||
|
i64.methods['__sub__'] = ([SelfType(), i64], i64, set())
|
||||||
|
i64.methods['__lt__'] = ([SelfType(), i64], b, set())
|
||||||
|
i64.methods['__gt__'] = ([SelfType(), i64], b, set())
|
||||||
|
i64.methods['__eq__'] = ([SelfType(), i64], b, set())
|
||||||
|
i64.methods['__ne__'] = ([SelfType(), i64], b, set())
|
||||||
|
i64.methods['__le__'] = ([SelfType(), i64], b, set())
|
||||||
|
i64.methods['__ge__'] = ([SelfType(), i64], b, set())
|
||||||
|
|
||||||
|
ctx = Context(variables, types)
|
||||||
|
|
||||||
|
def test_expr(expr, sym_table= {}):
|
||||||
|
print(f'Testing {expr} w.r.t. {stringify_subst(sym_table)}')
|
||||||
|
try:
|
||||||
|
tree = ast.parse(expr, mode='eval')
|
||||||
|
result = parse_expr(ctx, sym_table, tree)
|
||||||
|
print(result)
|
||||||
|
except CustomError as err:
|
||||||
|
print(f'error: {err.msg}')
|
||||||
|
|
||||||
|
test_expr('1 + 1')
|
||||||
|
test_expr('1 - 1')
|
||||||
|
test_expr('int64(1)')
|
||||||
|
test_expr('int64(1) - 1')
|
||||||
|
test_expr('a - a', {'a': I})
|
||||||
|
test_expr('a - a', {'a': A})
|
||||||
|
test_expr('[1, 2, 3][2]')
|
||||||
|
test_expr('[[1], [2], [3]][2]')
|
||||||
|
test_expr('[[1], [2], [3]][a]', {'a': i32})
|
||||||
|
test_expr('a == a == a', {'a': I})
|
||||||
|
test_expr('a == a and 1 == 2', {'a': I})
|
||||||
|
|
Loading…
Reference in New Issue
Block a user