Custom Defined Operators Behavior Support #268

Merged
sb10q merged 2 commits from operators into master 2022-04-18 15:31:56 +08:00
6 changed files with 369 additions and 21 deletions

View File

@ -10,7 +10,10 @@ use crate::{
}, },
symbol_resolver::{SymbolValue, ValueEnum}, symbol_resolver::{SymbolValue, ValueEnum},
toplevel::{DefinitionId, TopLevelDef}, toplevel::{DefinitionId, TopLevelDef},
typecheck::typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier}, typecheck::{
typedef::{FunSignature, FuncArg, Type, TypeEnum, Unifier},
magic_methods::{binop_name, binop_assign_name},
},
}; };
use inkwell::{ use inkwell::{
AddressSpace, AddressSpace,
@ -927,21 +930,29 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>(
left: &Expr<Option<Type>>, left: &Expr<Option<Type>>,
op: &Operator, op: &Operator,
right: &Expr<Option<Type>>, right: &Expr<Option<Type>>,
) -> Result<ValueEnum<'ctx>, String> { loc: Location,
is_aug_assign: bool,
) -> Result<Option<ValueEnum<'ctx>>, String> {
let ty1 = ctx.unifier.get_representative(left.custom.unwrap()); let ty1 = ctx.unifier.get_representative(left.custom.unwrap());
let ty2 = ctx.unifier.get_representative(right.custom.unwrap()); let ty2 = ctx.unifier.get_representative(right.custom.unwrap());
let left = generator.gen_expr(ctx, left)?.unwrap().to_basic_value_enum(ctx, generator, left.custom.unwrap())?; let left_val = generator
let right = generator.gen_expr(ctx, right)?.unwrap().to_basic_value_enum(ctx, generator, right.custom.unwrap())?; .gen_expr(ctx, left)?
.unwrap()
.to_basic_value_enum(ctx, generator, left.custom.unwrap())?;
let right_val = generator
.gen_expr(ctx, right)?
.unwrap()
.to_basic_value_enum(ctx, generator, right.custom.unwrap())?;
// we can directly compare the types, because we've got their representatives // we can directly compare the types, because we've got their representatives
// which would be unchanged until further unification, which we would never do // which would be unchanged until further unification, which we would never do
// when doing code generation for function instances // when doing code generation for function instances
Ok(if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) {
ctx.gen_int_ops(generator, op, left, right, true) Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, true).into()))
} else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) {
ctx.gen_int_ops(generator, op, left, right, false) Ok(Some(ctx.gen_int_ops(generator, op, left_val, right_val, false).into()))
} else if ty1 == ty2 && ctx.primitives.float == ty1 { } else if ty1 == ty2 && ctx.primitives.float == ty1 {
ctx.gen_float_ops(op, left, right) Ok(Some(ctx.gen_float_ops(op, left_val, right_val).into()))
} else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 {
// Pow is the only operator that would pass typecheck between float and int // Pow is the only operator that would pass typecheck between float and int
assert!(*op == Operator::Pow); assert!(*op == Operator::Pow);
@ -951,14 +962,68 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>(
let ty = f64_t.fn_type(&[f64_t.into(), i32_t.into()], false); let ty = f64_t.fn_type(&[f64_t.into(), i32_t.into()], false);
ctx.module.add_function("llvm.powi.f64.i32", ty, None) ctx.module.add_function("llvm.powi.f64.i32", ty, None)
}); });
ctx.builder let res = ctx.builder
.build_call(pow_intr, &[left.into(), right.into()], "f_pow_i") .build_call(pow_intr, &[left_val.into(), right_val.into()], "f_pow_i")
.try_as_basic_value() .try_as_basic_value()
.unwrap_left() .unwrap_left();
Ok(Some(res.into()))
} else { } else {
unimplemented!() let (op_name, id) = if let TypeEnum::TObj { fields, obj_id, .. } =
ctx.unifier.get_ty_immutable(left.custom.unwrap()).as_ref()
{
let (binop_name, binop_assign_name) = (
binop_name(op).into(),
binop_assign_name(op).into()
);
// if is aug_assign, try aug_assign operator first
if is_aug_assign && fields.contains_key(&binop_assign_name) {
(binop_assign_name, *obj_id)
} else {
(binop_name, *obj_id)
}
} else {
unreachable!("must be tobj")
};
let signature = match ctx.calls.get(&loc.into()) {
Some(call) => ctx.unifier.get_call_signature(*call).unwrap(),
None => {
if let TypeEnum::TObj { fields, .. } =
ctx.unifier.get_ty_immutable(left.custom.unwrap()).as_ref()
{
let fn_ty = fields.get(&op_name).unwrap().0;
if let TypeEnum::TFunc(sig) = ctx.unifier.get_ty_immutable(fn_ty).as_ref() {
sig.clone()
} else {
unreachable!("must be func sig")
}
} else {
unreachable!("must be tobj")
}
},
};
let fun_id = {
let defs = ctx.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read();
if let TopLevelDef::Class { methods, .. } = &*obj_def {
let mut fun_id = None;
for (name, _, id) in methods.iter() {
if name == &op_name {
Review

Iterator::find

Iterator::find
fun_id = Some(*id);
}
}
fun_id.unwrap()
} else {
unreachable!()
}
};
generator
.gen_call(
ctx,
Some((left.custom.unwrap(), left_val.into())),
(&signature, fun_id),
vec![(None, right_val.into())],
).map(|f| f.map(|f| f.into()))
} }
.into())
} }
pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
@ -1125,7 +1190,9 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>(
phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]); phi.add_incoming(&[(&a, a_bb), (&b, b_bb)]);
phi.as_basic_value().into() phi.as_basic_value().into()
} }
ExprKind::BinOp { op, left, right } => gen_binop_expr(generator, ctx, left, op, right)?, ExprKind::BinOp { op, left, right } => {
return gen_binop_expr(generator, ctx, left, op, right, expr.location, false);
}
ExprKind::UnaryOp { op, operand } => { ExprKind::UnaryOp { op, operand } => {
let ty = ctx.unifier.get_representative(operand.custom.unwrap()); let ty = ctx.unifier.get_representative(operand.custom.unwrap());
let val = let val =

View File

@ -1020,8 +1020,8 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator>(
StmtKind::For { .. } => generator.gen_for(ctx, stmt)?, StmtKind::For { .. } => generator.gen_for(ctx, stmt)?,
StmtKind::With { .. } => generator.gen_with(ctx, stmt)?, StmtKind::With { .. } => generator.gen_with(ctx, stmt)?,
StmtKind::AugAssign { target, op, value, .. } => { StmtKind::AugAssign { target, op, value, .. } => {
let value = gen_binop_expr(generator, ctx, target, op, value)?; let value = gen_binop_expr(generator, ctx, target, op, value, stmt.location, true)?;
generator.gen_assign(ctx, target, value)?; generator.gen_assign(ctx, target, value.unwrap())?;
} }
StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?, StmtKind::Try { .. } => gen_try(generator, ctx, stmt)?,
StmtKind::Raise { exc, .. } => { StmtKind::Raise { exc, .. } => {

View File

@ -83,7 +83,7 @@ where
pub fn impl_binop( pub fn impl_binop(
unifier: &mut Unifier, unifier: &mut Unifier,
store: &PrimitiveStore, _store: &PrimitiveStore,
ty: Type, ty: Type,
other_ty: &[Type], other_ty: &[Type],
ret_ty: Type, ret_ty: Type,
@ -120,7 +120,7 @@ pub fn impl_binop(
fields.insert(binop_assign_name(op).into(), { fields.insert(binop_assign_name(op).into(), {
( (
unifier.add_ty(TypeEnum::TFunc(FunSignature { unifier.add_ty(TypeEnum::TFunc(FunSignature {
ret: store.none, ret: ret_ty,
vars: function_vars.clone(), vars: function_vars.clone(),
args: vec![FuncArg { args: vec![FuncArg {
ty: other_ty, ty: other_ty,

View File

@ -423,7 +423,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
(None, None) => {} (None, None) => {}
}, },
ast::StmtKind::AugAssign { target, op, value, .. } => { ast::StmtKind::AugAssign { target, op, value, .. } => {
let res_ty = self.infer_bin_ops(stmt.location, target, op, value)?; let res_ty = self.infer_bin_ops(stmt.location, target, op, value, true)?;
self.unify(res_ty, target.custom.unwrap(), &stmt.location)?; self.unify(res_ty, target.custom.unwrap(), &stmt.location)?;
} }
ast::StmtKind::Assert { test, msg, .. } => { ast::StmtKind::Assert { test, msg, .. } => {
@ -505,7 +505,7 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
} }
ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?), ast::ExprKind::BoolOp { values, .. } => Some(self.infer_bool_ops(values)?),
ast::ExprKind::BinOp { left, op, right } => { ast::ExprKind::BinOp { left, op, right } => {
Some(self.infer_bin_ops(expr.location, left, op, right)?) Some(self.infer_bin_ops(expr.location, left, op, right, false)?)
} }
ast::ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?), ast::ExprKind::UnaryOp { op, operand } => Some(self.infer_unary_ops(op, operand)?),
ast::ExprKind::Compare { left, ops, comparators } => { ast::ExprKind::Compare { left, ops, comparators } => {
@ -1028,8 +1028,24 @@ impl<'a> Inferencer<'a> {
left: &ast::Expr<Option<Type>>, left: &ast::Expr<Option<Type>>,
op: &ast::Operator, op: &ast::Operator,
right: &ast::Expr<Option<Type>>, right: &ast::Expr<Option<Type>>,
is_aug_assign: bool,
) -> InferenceResult { ) -> InferenceResult {
let method = binop_name(op).into(); let method = if let TypeEnum::TObj { fields, .. } =
self.unifier.get_ty_immutable(left.custom.unwrap()).as_ref()
{
let (binop_name, binop_assign_name) = (
binop_name(op).into(),
binop_assign_name(op).into()
);
// if is aug_assign, try aug_assign operator first
if is_aug_assign && fields.contains_key(&binop_assign_name) {
binop_assign_name
} else {
binop_name
}
} else {
binop_name(op).into()
};
self.build_method_call( self.build_method_call(
location, location,
method, method,

View File

@ -0,0 +1,257 @@
from __future__ import annotations
@extern
def output_int32(x: int32):
...
@extern
def output_uint32(x: uint32):
...
@extern
def output_int64(x: int64):
...
@extern
def output_uint64(x: uint64):
...
@extern
def output_float64(x: float):
...
def run() -> int32:
test_int32()
test_uint32()
test_int64()
test_uint64()
test_A()
test_B()
return 0
def test_int32():
a = 17
b = 3
output_int32(a + b)
output_int32(a - b)
output_int32(a * b)
output_int32(a // b)
output_int32(a % b)
output_int32(a | b)
output_int32(a ^ b)
output_int32(a & b)
output_int32(a << b)
output_int32(a >> b)
output_float64(a / b)
a += b
output_int32(a)
a -= b
output_int32(a)
a *= b
output_int32(a)
a //= b
output_int32(a)
a %= b
output_int32(a)
a |= b
output_int32(a)
a ^= b
output_int32(a)
a &= b
output_int32(a)
a <<= b
output_int32(a)
a >>= b
output_int32(a)
# fail because (a / b) is float
# a /= b
def test_uint32():
a = uint32(17)
b = uint32(3)
output_uint32(a + b)
output_uint32(a - b)
output_uint32(a * b)
output_uint32(a // b)
output_uint32(a % b)
output_uint32(a | b)
output_uint32(a ^ b)
output_uint32(a & b)
output_uint32(a << b)
output_uint32(a >> b)
output_float64(a / b)
a += b
output_uint32(a)
a -= b
output_uint32(a)
a *= b
output_uint32(a)
a //= b
output_uint32(a)
a %= b
output_uint32(a)
a |= b
output_uint32(a)
a ^= b
output_uint32(a)
a &= b
output_uint32(a)
a <<= b
output_uint32(a)
a >>= b
output_uint32(a)
def test_int64():
a = int64(17)
b = int64(3)
output_int64(a + b)
output_int64(a - b)
output_int64(a * b)
output_int64(a // b)
output_int64(a % b)
output_int64(a | b)
output_int64(a ^ b)
output_int64(a & b)
output_int64(a << b)
output_int64(a >> b)
output_float64(a / b)
a += b
output_int64(a)
a -= b
output_int64(a)
a *= b
output_int64(a)
a //= b
output_int64(a)
a %= b
output_int64(a)
a |= b
output_int64(a)
a ^= b
output_int64(a)
a &= b
output_int64(a)
a <<= b
output_int64(a)
a >>= b
output_int64(a)
def test_uint64():
a = uint64(17)
b = uint64(3)
output_uint64(a + b)
output_uint64(a - b)
output_uint64(a * b)
output_uint64(a // b)
output_uint64(a % b)
output_uint64(a | b)
output_uint64(a ^ b)
output_uint64(a & b)
output_uint64(a << b)
output_uint64(a >> b)
output_float64(a / b)
a += b
output_uint64(a)
a -= b
output_uint64(a)
a *= b
output_uint64(a)
a //= b
output_uint64(a)
a %= b
output_uint64(a)
a |= b
output_uint64(a)
a ^= b
output_uint64(a)
a &= b
output_uint64(a)
a <<= b
output_uint64(a)
a >>= b
output_uint64(a)
class A:
a: int32
def __init__(self, a: int32):
self.a = a
def __add__(self, other: A) -> A:
output_int32(self.a + other.a)
return A(self.a + other.a)
def __sub__(self, other: A) -> A:
output_int32(self.a - other.a)
return A(self.a - other.a)
def test_A():
a = A(17)
b = A(3)
c = a + b
# fail due to alloca in __add__ function
# output_int32(c.a)
a += b
# fail due to alloca in __add__ function
# output_int32(a.a)
a = A(17)
b = A(3)
d = a - b
# fail due to alloca in __add__ function
# output_int32(c.a)
a -= b
# fail due to alloca in __add__ function
# output_int32(a.a)
a = A(17)
b = A(3)
a.__add__(b)
a.__sub__(b)
class B:
a: int32
def __init__(self, a: int32):
self.a = a
def __add__(self, other: B) -> B:
output_int32(self.a + other.a)
return B(self.a + other.a)
def __sub__(self, other: B) -> B:
output_int32(self.a - other.a)
return B(self.a - other.a)
def __iadd__(self, other: B) -> B:
output_int32(self.a + other.a + 24)
return B(self.a + other.a + 24)
def __isub__(self, other: B) -> B:
output_int32(self.a - other.a - 24)
return B(self.a - other.a - 24)
def test_B():
a = B(17)
b = B(3)
c = a + b
# fail due to alloca in __add__ function
# output_int32(c.a)
a += b
# fail due to alloca in __add__ function
# output_int32(a.a)
a = B(17)
b = B(3)
d = a - b
# fail due to alloca in __add__ function
# output_int32(c.a)
a -= b
# fail due to alloca in __add__ function
# output_int32(a.a)
a = B(17)
b = B(3)
a.__add__(b)
a.__sub__(b)

View File

@ -205,6 +205,14 @@ fn main() {
continue; continue;
} }
// still needs to skip this `from __future__ import annotations` because this seems to be
// magic in python and there seems no way to patch it from another module..
if matches!(
&stmt.node,
StmtKind::ImportFrom { module, names, .. }
if module == &Some("__future__".into()) && names[0].name == "annotations".into()
Review

this also allows ... import annotations, some_other_thing

this also allows ``... import annotations, some_other_thing``
) { continue; }
let (name, def_id, ty) = let (name, def_id, ty) =
composer.register_top_level(stmt, Some(resolver.clone()), "__main__".into()).unwrap(); composer.register_top_level(stmt, Some(resolver.clone()), "__main__".into()).unwrap();