From c85e4122063cbeb5f951c3558f55d16580165daa Mon Sep 17 00:00:00 2001 From: David Mak Date: Tue, 2 Jul 2024 16:14:55 +0800 Subject: [PATCH] core: Implement list::__mul__ --- nac3core/src/codegen/expr.rs | 86 ++++++++++++++++++- ...el__test__test_analyze__generic_class.snap | 2 +- ...t__test_analyze__inheritance_override.snap | 2 +- ...est__test_analyze__list_tuple_generic.snap | 4 +- ...__toplevel__test__test_analyze__self1.snap | 2 +- ...t__test_analyze__simple_class_compose.snap | 4 +- nac3core/src/typecheck/magic_methods.rs | 21 ++++- nac3core/src/typecheck/typedef/mod.rs | 39 +++------ 8 files changed, 124 insertions(+), 36 deletions(-) diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index e768d873..c8276d49 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -10,10 +10,11 @@ use crate::{ gen_in_range_check, get_llvm_abi_type, get_llvm_type, irrt::*, llvm_intrinsics::{ - call_expect, call_float_floor, call_float_pow, call_float_powi, call_memcpy_generic, + call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax, + call_memcpy_generic, }, need_sret, numpy, - stmt::{gen_if_else_expr_callback, gen_raise, gen_var}, + stmt::{gen_for_callback_incrementing, gen_if_else_expr_callback, gen_raise, gen_var}, CodeGenContext, CodeGenTask, CodeGenerator, }, symbol_resolver::{SymbolValue, ValueEnum}, @@ -1193,6 +1194,87 @@ pub fn gen_binop_expr_with_values<'ctx, G: CodeGenerator>( Some("f_pow_i"), ); Ok(Some(res.into())) + } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) + || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) + { + let llvm_usize = generator.get_size_type(ctx.ctx); + + if is_aug_assign || op != Operator::Mult { + todo!("Only __mul__ is implemented for lists") + } + + let (elem_ty, list_val, int_val) = + if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { + let elem_ty = + if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty1) { + *params.iter().next().unwrap().1 + } else { + unreachable!() + }; + + (elem_ty, left_val, right_val) + } else if ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::List.id()) { + let elem_ty = + if let TypeEnum::TObj { params, .. } = &*ctx.unifier.get_ty_immutable(ty2) { + *params.iter().next().unwrap().1 + } else { + unreachable!() + }; + + (elem_ty, right_val, left_val) + } else { + unreachable!() + }; + let list_val = ListValue::from_ptr_val(list_val.into_pointer_value(), llvm_usize, None); + let int_val = + ctx.builder.build_int_s_extend(int_val.into_int_value(), llvm_usize, "").unwrap(); + // [...] * (i where i < 0) => [] + let int_val = call_int_smax(ctx, int_val, llvm_usize.const_zero(), None); + + let elem_llvm_ty = ctx.get_llvm_type(generator, elem_ty); + + let new_list = allocate_list( + generator, + ctx, + Some(elem_llvm_ty), + ctx.builder.build_int_mul(list_val.load_size(ctx, None), int_val, "").unwrap(), + None, + ); + + gen_for_callback_incrementing( + generator, + ctx, + llvm_usize.const_zero(), + (int_val, false), + |generator, ctx, _, i| { + let offset = + ctx.builder.build_int_mul(i, list_val.load_size(ctx, None), "").unwrap(); + let ptr = + unsafe { new_list.data().ptr_offset_unchecked(ctx, generator, &offset, None) }; + + let memcpy_sz = ctx + .builder + .build_int_mul( + list_val.load_size(ctx, None), + elem_llvm_ty.size_of().unwrap(), + "", + ) + .unwrap(); + + call_memcpy_generic( + ctx, + ptr, + list_val.data().base_ptr(ctx, generator), + memcpy_sz, + ctx.ctx.bool_type().const_zero(), + ); + + Ok(()) + }, + llvm_usize.const_int(1, false), + )?; + + Ok(Some(new_list.as_base_value().into())) } else if ty1.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) || ty2.obj_id(&ctx.unifier).is_some_and(|id| id == PrimDef::NDArray.id()) { diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index ed94a833..c8ff7dba 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -5,7 +5,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"Generic_A[V]\", \"B\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(242)]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [TypeVarId(245)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 2c59687c..b67596d8 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -7,7 +7,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"B[typevar231]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar231\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"B[typevar234]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"typevar234\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"B[bool]\", \"A[float]\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 2aa2ead6..08f254f5 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -5,8 +5,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"A[T, V]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(244)]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(249)]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [TypeVarId(247)]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(252)]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[list[float], int32]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index a350d18a..ce3b02ed 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -3,7 +3,7 @@ source: nac3core/src/toplevel/test.rs expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"A[typevar230, typevar231]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar230\", \"typevar231\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"A[typevar233, typevar234]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[float, bool], b:B], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\")],\ntype_vars: [\"typevar233\", \"typevar234\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[float, bool], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[float, bool]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"A[int64, bool]\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[float, bool]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[list[B], int32]], tuple[A[virtual[A[B, int32]], bool], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index 5fe92b13..ae002764 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -6,12 +6,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(250)]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [TypeVarId(253)]\n}\n", "Class {\nname: \"B\",\nancestors: [\"B\", \"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"C\", \"A\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(258)]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [TypeVarId(261)]\n}\n", ] diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index f2b995e2..f2464c03 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -429,12 +429,27 @@ pub fn typeof_binop( lhs: Type, rhs: Type, ) -> Result, String> { + let is_left_list = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id()); + let is_right_list = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::List.id()); let is_left_ndarray = lhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); let is_right_ndarray = rhs.obj_id(unifier).is_some_and(|id| id == PrimDef::NDArray.id()); Ok(Some(match op { Operator::Add | Operator::Sub | Operator::Mult | Operator::Mod | Operator::FloorDiv => { - if is_left_ndarray || is_right_ndarray { + if is_left_list || is_right_list { + if op != Operator::Mult { + return Err(format!( + "Binary operator {} not supported for list", + binop_name(op) + )); + } + + if is_left_list { + lhs + } else { + rhs + } + } else if is_left_ndarray || is_right_ndarray { typeof_ndarray_broadcast(unifier, primitives, lhs, rhs)? } else if unifier.unioned(lhs, rhs) { lhs @@ -604,6 +619,7 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie bool: bool_t, uint32: uint32_t, uint64: uint64_t, + list: list_t, ndarray: ndarray_t, .. } = *store; @@ -648,6 +664,9 @@ pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifie impl_sign(unifier, store, bool_t, Some(int32_t)); impl_eq(unifier, store, bool_t, &[bool_t, ndarray_bool_t], None); + /* list ======== */ + impl_binop(unifier, store, list_t, &[int32_t, int64_t], Some(list_t), &[Operator::Mult]); + /* ndarray ===== */ let ndarray_usized_ndims_tvar = unifier.get_fresh_const_generic_var(size_t, Some("ndarray_ndims".into()), None); diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 0382830b..79ca48f2 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -894,27 +894,6 @@ impl Unifier { self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } - ( - TVar { fields: Some(fields), range, is_const_generic: false, .. }, - TObj { obj_id, params, .. }, - ) if *obj_id == PrimDef::List.id() => { - let ty = iter_type_vars(params).nth(0).unwrap().ty; - - for (k, v) in fields { - match *k { - RecordKey::Int(_) => { - self.unify_impl(v.ty, ty, false).map_err(|e| e.at(v.loc))?; - } - RecordKey::Str(_) => { - return Err(TypeError::new(TypeErrorKind::NoSuchField(*k, b), v.loc)) - } - } - } - let x = self.check_var_compatibility(b, range)?.unwrap_or(b); - self.unify_impl(x, b, false)?; - self.set_a_to_b(a, x); - } - ( TVar { id: id1, range: ty1, is_const_generic: true, .. }, TVar { id: id2, range: ty2, .. }, @@ -993,7 +972,7 @@ impl Unifier { } self.set_a_to_b(a, b); } - (TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => { + (TVar { fields: Some(map), range, .. }, TObj { obj_id, fields, params }) => { for (k, field) in map { match *k { RecordKey::Str(s) => { @@ -1012,10 +991,18 @@ impl Unifier { self.unify_impl(field.ty, ty, false).map_err(|v| v.at(field.loc))?; } RecordKey::Int(_) => { - return Err(TypeError::new( - TypeErrorKind::NoSuchField(*k, b), - field.loc, - )) + // Allow expressions such as list[0] + if *obj_id == PrimDef::List.id() { + let ty = iter_type_vars(params).nth(0).unwrap().ty; + + self.unify_impl(field.ty, ty, false) + .map_err(|e| e.at(field.loc))?; + } else { + return Err(TypeError::new( + TypeErrorKind::NoSuchField(*k, b), + field.loc, + )); + } } } }