From f01d833d4899dc45dbe710c12e395dd7cb8586df Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 22 Nov 2023 16:45:58 +0800 Subject: [PATCH 1/9] standalone: Add missing parenthesis --- nac3standalone/demo/demo.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/nac3standalone/demo/demo.c b/nac3standalone/demo/demo.c index c5d2094..f272c35 100644 --- a/nac3standalone/demo/demo.c +++ b/nac3standalone/demo/demo.c @@ -94,13 +94,13 @@ uint64_t dbg_stack_address(__attribute__((unused)) struct cslice *slice) { } uint32_t __nac3_personality(uint32_t state, uint32_t exception_object, uint32_t context) { - printf("__nac3_personality(state: %u, exception_object: %u, context: %u\n", state, exception_object, context); + printf("__nac3_personality(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context); exit(101); __builtin_unreachable(); } uint32_t __nac3_raise(uint32_t state, uint32_t exception_object, uint32_t context) { - printf("__nac3_raise(state: %u, exception_object: %u, context: %u\n", state, exception_object, context); + printf("__nac3_raise(state: %u, exception_object: %u, context: %u)\n", state, exception_object, context); exit(101); __builtin_unreachable(); } -- 2.44.1 From 1c3a823670bf408ca59745f937f05fd91ea14f78 Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 22 Nov 2023 13:35:56 +0800 Subject: [PATCH 2/9] core: Do not discard value names for IRRT --- nac3core/build.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/nac3core/build.rs b/nac3core/build.rs index 281e283..6edcf0b 100644 --- a/nac3core/build.rs +++ b/nac3core/build.rs @@ -17,6 +17,7 @@ fn main() { const FLAG: &[&str] = &[ "--target=wasm32", FILE, + "-fno-discard-value-names", "-O3", "-emit-llvm", "-S", -- 2.44.1 From bd792904f90e8abb2cd56593798b255ea720f4b5 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 15 Dec 2023 14:02:30 +0800 Subject: [PATCH 3/9] core: Add size_t to primitive store Used for ndims in ndarray. 
--- nac3artiq/src/lib.rs | 15 ++++++++++++++- nac3core/src/codegen/mod.rs | 1 + nac3core/src/toplevel/composer.rs | 12 +++++------- nac3core/src/toplevel/helper.rs | 3 ++- nac3core/src/typecheck/type_inferencer/mod.rs | 12 ++++++++++++ nac3standalone/src/main.rs | 7 ++++--- 6 files changed, 38 insertions(+), 12 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 3e14684..9a0ad86 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -63,6 +63,17 @@ enum Isa { CortexA9, } +impl Isa { + /// Returns the number of bits in `size_t` for the [`Isa`]. + fn get_size_type(&self) -> u32 { + if self == &Isa::Host { + 64u32 + } else { + 32u32 + } + } +} + #[derive(Clone)] pub struct PrimitivePythonId { int: u64, @@ -277,9 +288,11 @@ impl Nac3 { py: Python, link_fn: &dyn Fn(&Module) -> PyResult, ) -> PyResult { + let size_t = self.isa.get_size_type(); let (mut composer, mut builtins_def, mut builtins_ty) = TopLevelComposer::new( self.builtins.clone(), ComposerConfig { kernel_ann: Some("Kernel"), kernel_invariant_ann: "KernelInvariant" }, + size_t, ); let builtins = PyModule::import(py, "builtins")?; @@ -792,7 +805,7 @@ impl Nac3 { Isa::RiscV32IMA => &timeline::NOW_PINNING_TIME_FNS, Isa::CortexA9 | Isa::Host => &timeline::EXTERN_TIME_FNS, }; - let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0; + let primitive: PrimitiveStore = TopLevelComposer::make_primitives(isa.get_size_type()).0; let builtins = vec![ ( "now_mu".into(), diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 51496e6..41bd2a8 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -614,6 +614,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte str: unifier.get_representative(primitives.str), exception: unifier.get_representative(primitives.exception), option: unifier.get_representative(primitives.option), + ..primitives }; let mut type_cache: HashMap<_, _> = [ diff --git 
a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 8e612a6..48f7c1a 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -37,12 +37,8 @@ pub struct TopLevelComposer { // number of built-in function and classes in the definition list, later skip pub builtin_num: usize, pub core_config: ComposerConfig, -} - -impl Default for TopLevelComposer { - fn default() -> Self { - Self::new(vec![], ComposerConfig::default()).0 - } + /// The size of a native word on the target platform. + pub size_t: u32, } impl TopLevelComposer { @@ -52,8 +48,9 @@ impl TopLevelComposer { pub fn new( builtins: Vec<(StrRef, FunSignature, Arc)>, core_config: ComposerConfig, + size_t: u32, ) -> (Self, HashMap, HashMap) { - let mut primitives = Self::make_primitives(); + let mut primitives = Self::make_primitives(size_t); let mut definition_ast_list = builtins::get_builtins(&mut primitives); let primitives_ty = primitives.0; let mut unifier = primitives.1; @@ -146,6 +143,7 @@ impl TopLevelComposer { defined_names, method_class, core_config, + size_t, }, builtin_id, builtin_ty, diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 0f4d375..8e440c1 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -44,7 +44,7 @@ impl TopLevelDef { impl TopLevelComposer { #[must_use] - pub fn make_primitives() -> (PrimitiveStore, Unifier) { + pub fn make_primitives(size_t: u32) -> (PrimitiveStore, Unifier) { let mut unifier = Unifier::new(); let int32 = unifier.add_ty(TypeEnum::TObj { obj_id: DefinitionId(0), @@ -144,6 +144,7 @@ impl TopLevelComposer { str, exception, option, + size_t, }; unifier.put_primitive_store(&primitives); crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 0b9b5a0..c436076 100644 --- 
a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -41,6 +41,18 @@ pub struct PrimitiveStore { pub str: Type, pub exception: Type, pub option: Type, + pub size_t: u32, +} + +impl PrimitiveStore { + /// Returns a [Type] representing `size_t`. + pub fn usize(&self) -> Type { + match self.size_t { + 32 => self.uint32, + 64 => self.uint64, + _ => unreachable!(), + } + } } pub struct FunctionData { diff --git a/nac3standalone/src/main.rs b/nac3standalone/src/main.rs index eb746a8..8973cf3 100644 --- a/nac3standalone/src/main.rs +++ b/nac3standalone/src/main.rs @@ -286,6 +286,7 @@ fn main() { // The default behavior for -O where n>3 defaults to O3 for both Clang and GCC _ => OptimizationLevel::Aggressive, }; + const SIZE_T: u32 = 64; let program = match fs::read_to_string(file_name.clone()) { Ok(program) => program, @@ -295,9 +296,9 @@ fn main() { } }; - let primitive: PrimitiveStore = TopLevelComposer::make_primitives().0; + let primitive: PrimitiveStore = TopLevelComposer::make_primitives(SIZE_T).0; let (mut composer, builtins_def, builtins_ty) = - TopLevelComposer::new(vec![], ComposerConfig::default()); + TopLevelComposer::new(vec![], ComposerConfig::default(), SIZE_T); let internal_resolver: Arc = ResolverInternal { id_to_type: builtins_ty.into(), @@ -400,7 +401,7 @@ fn main() { membuffer.lock().push(buffer); }))); let threads = (0..threads) - .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), 64))) + .map(|i| Box::new(DefaultCodeGenerator::new(format!("module{i}"), SIZE_T))) .collect(); let (registry, handles) = WorkerRegistry::create_workers(threads, top_level, &llvm_options, &f); registry.add_task(task); -- 2.44.1 From e435b2575602c9042efd3d91c5f2fe1d7685760c Mon Sep 17 00:00:00 2001 From: David Mak Date: Wed, 20 Dec 2023 18:30:44 +0800 Subject: [PATCH 4/9] core: Allow implicit promotions of integral literals It should not matter, since it is the value of the literal that matters with respect to 
the const generic variable. --- nac3core/src/typecheck/typedef/mod.rs | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index d275d8a..04a905d 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -789,7 +789,23 @@ impl Unifier { (TLiteral { values: val1, .. }, TLiteral { values: val2, .. }) => { for (v1, v2) in zip(val1, val2) { if v1 != v2 { - return self.incompatible_types(a, b) + let symbol_value_to_int = |value: &SymbolValue| -> Option { + match value { + SymbolValue::I32(v) => Some(*v as i128), + SymbolValue::I64(v) => Some(*v as i128), + SymbolValue::U32(v) => Some(*v as i128), + SymbolValue::U64(v) => Some(*v as i128), + _ => None, + } + }; + + // Try performing integer promotion on literals + let v1i = symbol_value_to_int(v1); + let v2i = symbol_value_to_int(v2); + + if v1i != v2i { + return self.incompatible_types(a, b) + } } } -- 2.44.1 From 03870f222d96988f83a46d4314b0d5a51e9679e8 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 15 Dec 2023 16:57:23 +0800 Subject: [PATCH 5/9] core: Extract special method handling in type inferencer To prepare for more special handling with methods. --- nac3core/src/typecheck/type_inferencer/mod.rs | 232 ++++++++++-------- 1 file changed, 126 insertions(+), 106 deletions(-) diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index c436076..7ff3e0b 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -10,7 +10,13 @@ use itertools::izip; use nac3parser::ast::{ self, fold::{self, Fold}, - Arguments, Comprehension, ExprContext, ExprKind, Located, Location, StrRef, + Arguments, + Comprehension, + ExprContext, + ExprKind, + Located, + Location, + StrRef }; #[cfg(test)] @@ -773,6 +779,120 @@ impl<'a> Inferencer<'a> { }) } + /// Tries to fold a special call. 
Returns [`Some`] if the call expression `func` is a special call, otherwise + /// returns [`None`]. + fn try_fold_special_call( + &mut self, + location: Location, + func: &ast::Expr<()>, + args: &mut Vec>, + keywords: &Vec>, + ) -> Result>>, HashSet> { + let Located { location: func_location, node: ExprKind::Name { id, ctx }, .. } = func else { + return Ok(None) + }; + + // handle special functions that cannot be typed in the usual way... + if id == &"virtual".into() { + if args.is_empty() || args.len() > 2 || !keywords.is_empty() { + return report_error( + "`virtual` can only accept 1/2 positional arguments", + *func_location, + ) + } + let arg0 = self.fold_expr(args.remove(0))?; + let ty = if let Some(arg) = args.pop() { + let top_level_defs = self.top_level.definitions.read(); + self.function_data.resolver.parse_type_annotation( + top_level_defs.as_slice(), + self.unifier, + self.primitives, + &arg, + )? + } else { + self.unifier.get_dummy_var().0 + }; + self.virtual_checks.push((arg0.custom.unwrap(), ty, *func_location)); + let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); + return Ok(Some(Located { + location, + custom, + node: ExprKind::Call { + func: Box::new(Located { + custom: None, + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0], + keywords: vec![], + }, + })) + } + // int64 is special because its argument can be a constant larger than int32 + if id == &"int64".into() && args.len() == 1 { + if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = + &args[0].node + { + let custom = Some(self.primitives.int64); + let v: Result = (*val).try_into(); + return if v.is_ok() { + Ok(Some(Located { + location: args[0].location, + custom, + node: ExprKind::Constant { + value: ast::Constant::Int(*val), + kind: kind.clone(), + }, + })) + } else { + report_error("Integer out of bound", args[0].location) + } + } + } + if id == &"uint32".into() && args.len() == 1 { + if let 
ExprKind::Constant { value: ast::Constant::Int(val), kind } = + &args[0].node + { + let custom = Some(self.primitives.uint32); + let v: Result = (*val).try_into(); + return if v.is_ok() { + Ok(Some(Located { + location: args[0].location, + custom, + node: ExprKind::Constant { + value: ast::Constant::Int(*val), + kind: kind.clone(), + }, + })) + } else { + report_error("Integer out of bound", args[0].location) + } + } + } + if id == &"uint64".into() && args.len() == 1 { + if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = + &args[0].node + { + let custom = Some(self.primitives.uint64); + let v: Result = (*val).try_into(); + return if v.is_ok() { + Ok(Some(Located { + location: args[0].location, + custom, + node: ExprKind::Constant { + value: ast::Constant::Int(*val), + kind: kind.clone(), + }, + })) + } else { + report_error("Integer out of bound", args[0].location) + } + } + } + + Ok(None) + } + fn fold_call( &mut self, location: Location, @@ -780,111 +900,11 @@ impl<'a> Inferencer<'a> { mut args: Vec>, keywords: Vec>, ) -> Result>, HashSet> { - let func = - if let Located { location: func_location, custom, node: ExprKind::Name { id, ctx } } = - func - { - // handle special functions that cannot be typed in the usual way... - if id == "virtual".into() { - if args.is_empty() || args.len() > 2 || !keywords.is_empty() { - return report_error( - "`virtual` can only accept 1/2 positional arguments", - func_location, - ); - } - let arg0 = self.fold_expr(args.remove(0))?; - let ty = if let Some(arg) = args.pop() { - let top_level_defs = self.top_level.definitions.read(); - self.function_data.resolver.parse_type_annotation( - top_level_defs.as_slice(), - self.unifier, - self.primitives, - &arg, - )? 
- } else { - self.unifier.get_dummy_var().0 - }; - self.virtual_checks.push((arg0.custom.unwrap(), ty, func_location)); - let custom = Some(self.unifier.add_ty(TypeEnum::TVirtual { ty })); - return Ok(Located { - location, - custom, - node: ExprKind::Call { - func: Box::new(Located { - custom: None, - location: func.location, - node: ExprKind::Name { id, ctx }, - }), - args: vec![arg0], - keywords: vec![], - }, - }); - } - // int64 is special because its argument can be a constant larger than int32 - if id == "int64".into() && args.len() == 1 { - if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = - &args[0].node - { - let custom = Some(self.primitives.int64); - let v: Result = (*val).try_into(); - return if v.is_ok() { - Ok(Located { - location: args[0].location, - custom, - node: ExprKind::Constant { - value: ast::Constant::Int(*val), - kind: kind.clone(), - }, - }) - } else { - report_error("Integer out of bound", args[0].location) - } - } - } - if id == "uint32".into() && args.len() == 1 { - if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = - &args[0].node - { - let custom = Some(self.primitives.uint32); - let v: Result = (*val).try_into(); - return if v.is_ok() { - Ok(Located { - location: args[0].location, - custom, - node: ExprKind::Constant { - value: ast::Constant::Int(*val), - kind: kind.clone(), - }, - }) - } else { - report_error("Integer out of bound", args[0].location) - } - } - } - if id == "uint64".into() && args.len() == 1 { - if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = - &args[0].node - { - let custom = Some(self.primitives.uint64); - let v: Result = (*val).try_into(); - return if v.is_ok() { - Ok(Located { - location: args[0].location, - custom, - node: ExprKind::Constant { - value: ast::Constant::Int(*val), - kind: kind.clone(), - }, - }) - } else { - report_error("Integer out of bound", args[0].location) - } - } - } - Located { location: func_location, custom, node: ExprKind::Name { 
id, ctx } } - } else { - func - }; + let func = if let Some(spec_call_func) = self.try_fold_special_call(location, &func, &mut args, &keywords)? { + return Ok(spec_call_func) + } else { + func + }; let func = Box::new(self.fold_expr(func)?); let args = args.into_iter().map(|v| self.fold_expr(v)).collect::, _>>()?; let keywords = keywords -- 2.44.1 From c39547209479055745437ddc60de7172c6842338 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 6 Nov 2023 18:03:52 +0800 Subject: [PATCH 6/9] core: Initial infrastructure for ndarray --- nac3artiq/src/codegen.rs | 11 ++ nac3artiq/src/lib.rs | 2 + nac3artiq/src/symbol_resolver.rs | 44 ++++++- nac3core/src/codegen/concrete_type.rs | 14 ++ nac3core/src/codegen/expr.rs | 3 + nac3core/src/codegen/mod.rs | 18 +++ nac3core/src/codegen/stmt.rs | 122 +++++++++--------- nac3core/src/symbol_resolver.rs | 29 ++++- nac3core/src/toplevel/builtins.rs | 59 ++++++--- nac3core/src/toplevel/type_annotation.rs | 23 +++- nac3core/src/typecheck/type_inferencer/mod.rs | 27 +++- nac3core/src/typecheck/typedef/mod.rs | 57 +++++++- nac3core/src/typecheck/typedef/test.rs | 1 + 13 files changed, 314 insertions(+), 96 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 142fe51..98f51bd 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -400,6 +400,9 @@ fn gen_rpc_tag( buffer.push(b'l'); gen_rpc_tag(ctx, *ty, buffer)?; } + TNDArray { .. } => { + todo!() + } _ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))), } } @@ -673,6 +676,14 @@ pub fn attributes_writeback( values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap())); } }, + TypeEnum::TNDArray { ty: elem_ty, .. 
} => { + if gen_rpc_tag(ctx, *elem_ty, &mut scratch_buffer).is_ok() { + let pydict = PyDict::new(py); + pydict.set_item("obj", val)?; + host_attributes.append(pydict)?; + values.push((ty, inner_resolver.get_obj_value(py, val, ctx, generator, ty)?.unwrap())); + } + }, _ => {} } } diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 9a0ad86..7a1183d 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -85,6 +85,7 @@ pub struct PrimitivePythonId { float64: u64, bool: u64, list: u64, + ndarray: u64, tuple: u64, typevar: u64, const_generic_marker: u64, @@ -879,6 +880,7 @@ impl Nac3 { float: get_attr_id(builtins_mod, "float"), float64: get_attr_id(numpy_mod, "float64"), list: get_attr_id(builtins_mod, "list"), + ndarray: get_attr_id(numpy_mod, "NDArray"), tuple: get_attr_id(builtins_mod, "tuple"), exception: get_attr_id(builtins_mod, "Exception"), option: get_id(artiq_builtins.get_item("Option").ok().flatten().unwrap()), diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index 631e2de..aeb99fa 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -302,6 +302,12 @@ impl InnerResolver { let var = unifier.get_dummy_var().0; let list = unifier.add_ty(TypeEnum::TList { ty: var }); Ok(Ok((list, false))) + } else if ty_id == self.primitive_ids.ndarray { + // do not handle type var param and concrete check here + let var = unifier.get_dummy_var().0; + let ndims = unifier.get_fresh_const_generic_var(primitives.usize(), None, None).0; + let ndarray = unifier.add_ty(TypeEnum::TNDArray { ty: var, ndims }); + Ok(Ok((ndarray, false))) } else if ty_id == self.primitive_ids.tuple { // do not handle type var param and concrete check here Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) @@ -446,6 +452,16 @@ impl InnerResolver { ))); } } + TypeEnum::TNDArray { .. 
} => { + if args.len() != 2 { + return Ok(Err(format!( + "type list needs exactly 2 type parameters, found {}", + args.len() + ))); + } + + todo!() + } TypeEnum::TTuple { .. } => { let args = match args .iter() @@ -607,7 +623,7 @@ impl InnerResolver { Err(e) => return Ok(Err(e)), }; match (&*unifier.get_ty(extracted_ty), inst_check) { - // do the instantiation for these three types + // do the instantiation for these four types (TypeEnum::TList { ty }, false) => { let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; if len == 0 { @@ -632,6 +648,30 @@ impl InnerResolver { } } } + (TypeEnum::TNDArray { ty, ndims }, false) => { + let len: usize = self.helper.len_fn.call1(py, (obj,))?.extract(py)?; + if len == 0 { + assert!(matches!( + &*unifier.get_ty(*ty), + TypeEnum::TVar { fields: None, range, .. } + if range.is_empty() + )); + Ok(Ok(extracted_ty)) + } else { + let actual_ty = + self.get_list_elem_type(py, obj, len, unifier, defs, primitives)?; + match actual_ty { + Ok(t) => match unifier.unify(*ty, t) { + Ok(_) => Ok(Ok(unifier.add_ty(TypeEnum::TNDArray { ty: *ty, ndims: *ndims }))), + Err(e) => Ok(Err(format!( + "type error ({}) for the ndarray", + e.to_display(unifier).to_string() + ))), + }, + Err(e) => Ok(Err(e)), + } + } + } (TypeEnum::TTuple { .. 
}, false) => { let elements: &PyTuple = obj.downcast()?; let types: Result, _>, _> = elements @@ -898,6 +938,8 @@ impl InnerResolver { global.set_initializer(&val); Ok(Some(global.as_pointer_value().into())) + } else if ty_id == self.primitive_ids.ndarray { + todo!() } else if ty_id == self.primitive_ids.tuple { let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty); let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else { diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index 7745160..a440276 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -47,6 +47,10 @@ pub enum ConcreteTypeEnum { TList { ty: ConcreteType, }, + TNDArray { + ty: ConcreteType, + ndims: ConcreteType, + }, TObj { obj_id: DefinitionId, fields: HashMap, @@ -167,6 +171,10 @@ impl ConcreteTypeStore { TypeEnum::TList { ty } => ConcreteTypeEnum::TList { ty: self.from_unifier_type(unifier, primitives, *ty, cache), }, + TypeEnum::TNDArray { ty, ndims } => ConcreteTypeEnum::TNDArray { + ty: self.from_unifier_type(unifier, primitives, *ty, cache), + ndims: self.from_unifier_type(unifier, primitives, *ndims, cache), + }, TypeEnum::TObj { obj_id, fields, params } => ConcreteTypeEnum::TObj { obj_id: *obj_id, fields: fields @@ -260,6 +268,12 @@ impl ConcreteTypeStore { ConcreteTypeEnum::TList { ty } => { TypeEnum::TList { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } } + ConcreteTypeEnum::TNDArray { ty, ndims } => { + TypeEnum::TNDArray { + ty: self.to_unifier_type(unifier, primitives, *ty, cache), + ndims: self.to_unifier_type(unifier, primitives, *ndims, cache), + } + } ConcreteTypeEnum::TVirtual { ty } => { TypeEnum::TVirtual { ty: self.to_unifier_type(unifier, primitives, *ty, cache) } } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3a894a4..79ebe4a 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -1846,6 +1846,9 @@ pub fn 
gen_expr<'ctx, G: CodeGenerator>( ctx.build_gep_and_load(arr_ptr, &[index], None).into() } } + TypeEnum::TNDArray { .. } => { + return Err(String::from("subscript operator for ndarray not implemented")) + } TypeEnum::TTuple { .. } => { let index: u32 = if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node { diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 41bd2a8..21943d4 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -507,6 +507,24 @@ fn get_llvm_type<'ctx>( ]; ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into() } + TNDArray { ty, .. } => { + let llvm_usize = generator.get_size_type(ctx); + let element_type = get_llvm_type( + ctx, module, generator, unifier, top_level, type_cache, primitives, *ty, + ); + + // struct NDArray { num_dims: size_t, dims: size_t*, data: T* } + // + // * num_dims: Number of dimensions in the array + // * dims: Pointer to an array containing the size of each dimension + // * data: Pointer to an array containing the array data + let fields = [ + llvm_usize.into(), + llvm_usize.ptr_type(AddressSpace::default()).into(), + element_type.ptr_type(AddressSpace::default()).into(), + ]; + ctx.struct_type(&fields, false).ptr_type(AddressSpace::default()).into() + } TVirtual { .. } => unimplemented!(), _ => unreachable!("{}", ty_enum.get_type_name()), }; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index f15c2b9..e53ea32 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -99,63 +99,69 @@ pub fn gen_store_target<'ctx, G: CodeGenerator>( } } ExprKind::Subscript { value, slice, .. } => { - assert!(matches!( - ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref(), - TypeEnum::TList { .. }, - )); - let i32_type = ctx.ctx.i32_type(); - let zero = i32_type.const_zero(); - let v = if let Some(v) = generator.gen_expr(ctx, value)? 
{ - v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value() - } else { - return Ok(None) - }; - let len = ctx - .build_gep_and_load(v, &[zero, i32_type.const_int(1, false)], Some("len")) - .into_int_value(); - let raw_index = if let Some(v) = generator.gen_expr(ctx, slice)? { - v.to_basic_value_enum(ctx, generator, slice.custom.unwrap())?.into_int_value() - } else { - return Ok(None) - }; - let raw_index = ctx.builder.build_int_s_extend( - raw_index, - generator.get_size_type(ctx.ctx), - "sext", - ); - // handle negative index - let is_negative = ctx.builder.build_int_compare( - IntPredicate::SLT, - raw_index, - generator.get_size_type(ctx.ctx).const_zero(), - "is_neg", - ); - let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted"); - let index = ctx - .builder - .build_select(is_negative, adjusted, raw_index, "index") - .into_int_value(); - // unsigned less than is enough, because negative index after adjustment is - // bigger than the length (for unsigned cmp) - let bound_check = ctx.builder.build_int_compare( - IntPredicate::ULT, - index, - len, - "inbound", - ); - ctx.make_assert( - generator, - bound_check, - "0:IndexError", - "index {0} out of bounds 0:{1}", - [Some(raw_index), Some(len), None], - slice.location, - ); - unsafe { - let arr_ptr = ctx - .build_gep_and_load(v, &[i32_type.const_zero(), i32_type.const_zero()], Some("arr.addr")) - .into_pointer_value(); - ctx.builder.build_gep(arr_ptr, &[index], name.unwrap_or("")) + match ctx.unifier.get_ty_immutable(value.custom.unwrap()).as_ref() { + TypeEnum::TList { .. } => { + let i32_type = ctx.ctx.i32_type(); + let zero = i32_type.const_zero(); + let v = generator + .gen_expr(ctx, value)? + .unwrap() + .to_basic_value_enum(ctx, generator, value.custom.unwrap())? + .into_pointer_value(); + let len = ctx + .build_gep_and_load(v, &[zero, i32_type.const_int(1, false)], Some("len")) + .into_int_value(); + let raw_index = generator + .gen_expr(ctx, slice)? 
+ .unwrap() + .to_basic_value_enum(ctx, generator, slice.custom.unwrap())? + .into_int_value(); + let raw_index = ctx.builder.build_int_s_extend( + raw_index, + generator.get_size_type(ctx.ctx), + "sext", + ); + // handle negative index + let is_negative = ctx.builder.build_int_compare( + IntPredicate::SLT, + raw_index, + generator.get_size_type(ctx.ctx).const_zero(), + "is_neg", + ); + let adjusted = ctx.builder.build_int_add(raw_index, len, "adjusted"); + let index = ctx + .builder + .build_select(is_negative, adjusted, raw_index, "index") + .into_int_value(); + // unsigned less than is enough, because negative index after adjustment is + // bigger than the length (for unsigned cmp) + let bound_check = ctx.builder.build_int_compare( + IntPredicate::ULT, + index, + len, + "inbound", + ); + ctx.make_assert( + generator, + bound_check, + "0:IndexError", + "index {0} out of bounds 0:{1}", + [Some(raw_index), Some(len), None], + slice.location, + ); + unsafe { + let arr_ptr = ctx + .build_gep_and_load(v, &[i32_type.const_zero(), i32_type.const_zero()], Some("arr.addr")) + .into_pointer_value(); + ctx.builder.build_gep(arr_ptr, &[index], name.unwrap_or("")) + } + } + + TypeEnum::TNDArray { .. } => { + todo!() + } + + _ => unreachable!(), } } _ => unreachable!(), @@ -203,7 +209,7 @@ pub fn gen_assign<'ctx, G: CodeGenerator>( let value = value .to_basic_value_enum(ctx, generator, target.custom.unwrap())? .into_pointer_value(); - let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(target.custom.unwrap()) else { + let (TypeEnum::TList { ty } | TypeEnum::TNDArray { ty, .. }) = &*ctx.unifier.get_ty(target.custom.unwrap()) else { unreachable!() }; diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 0932bea..3ff5e0e 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -354,13 +354,14 @@ pub trait SymbolResolver { } thread_local! 
{ - static IDENTIFIER_ID: [StrRef; 11] = [ + static IDENTIFIER_ID: [StrRef; 12] = [ "int32".into(), "int64".into(), "float".into(), "bool".into(), "virtual".into(), "list".into(), + "ndarray".into(), "tuple".into(), "str".into(), "Exception".into(), @@ -385,11 +386,12 @@ pub fn parse_type_annotation( let bool_id = ids[3]; let virtual_id = ids[4]; let list_id = ids[5]; - let tuple_id = ids[6]; - let str_id = ids[7]; - let exn_id = ids[8]; - let uint32_id = ids[9]; - let uint64_id = ids[10]; + let ndarray_id = ids[6]; + let tuple_id = ids[7]; + let str_id = ids[8]; + let exn_id = ids[9]; + let uint32_id = ids[10]; + let uint64_id = ids[11]; let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| { if *id == int32_id { @@ -460,6 +462,21 @@ pub fn parse_type_annotation( } else if *id == list_id { let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, slice)?; Ok(unifier.add_ty(TypeEnum::TList { ty })) + } else if *id == ndarray_id { + let Tuple { elts, .. } = &slice.node else { + return Err(HashSet::from([ + String::from("Expected 2 type arguments for ndarray"), + ])) + }; + if elts.len() < 2 { + return Err(HashSet::from([ + String::from("Expected 2 type arguments for ndarray"), + ])) + } + + let ty = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[0])?; + let ndims = parse_type_annotation(resolver, top_level_defs, unifier, primitives, &elts[1])?; + Ok(unifier.add_ty(TypeEnum::TNDArray { ty, ndims })) } else if *id == tuple_id { if let Tuple { elts, .. 
} = &slice.node { let ty = elts diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 0b8ccde..00da280 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -470,6 +470,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + { + let tvar = primitives.1.get_fresh_var(Some("T".into()), None); + let ndims = primitives.1.get_fresh_const_generic_var(primitives.0.uint64, Some("N".into()), None); + + Arc::new(RwLock::new(TopLevelDef::Class { + name: "ndarray".into(), + object_id: DefinitionId(14), + type_vars: vec![tvar.0, ndims.0], + fields: Vec::default(), + methods: Vec::default(), + ancestors: Vec::default(), + constructor: None, + resolver: None, + loc: None, + })) + }, Arc::new(RwLock::new(TopLevelDef::Function { name: "int32".into(), simple_name: "int32".into(), @@ -1265,10 +1281,13 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { }), ), Arc::new(RwLock::new({ - let list_var = primitives.1.get_fresh_var(Some("L".into()), None); - let list = primitives.1.add_ty(TypeEnum::TList { ty: list_var.0 }); + let tvar = primitives.1.get_fresh_var(Some("L".into()), None); + let list = primitives.1.add_ty(TypeEnum::TList { ty: tvar.0 }); + let ndims = primitives.1.get_fresh_const_generic_var(primitives.0.uint64, Some("N".into()), None); + let ndarray = primitives.1.add_ty(TypeEnum::TNDArray { ty: tvar.0, ndims: ndims.0 }); + let arg_ty = primitives.1.get_fresh_var_with_range( - &[list, primitives.0.range], + &[list, ndarray, primitives.0.range], Some("I".into()), None, ); @@ -1278,7 +1297,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { args: vec![FuncArg { name: "ls".into(), ty: arg_ty.0, default_value: None }], ret: int32, - vars: vec![(list_var.1, list_var.0), (arg_ty.1, arg_ty.0)] + vars: vec![(tvar.1, tvar.0), 
(arg_ty.1, arg_ty.0)] .into_iter() .collect(), })), @@ -1296,19 +1315,25 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let (start, end, step) = destructure_range(ctx, arg); Some(calculate_len_for_slice_range(generator, ctx, start, end, step).into()) } else { - let int32 = ctx.ctx.i32_type(); - let zero = int32.const_zero(); - let len = ctx - .build_gep_and_load( - arg.into_pointer_value(), - &[zero, int32.const_int(1, false)], - None, - ) - .into_int_value(); - if len.get_type().get_bit_width() == 32 { - Some(len.into()) - } else { - Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into()) + match &*ctx.unifier.get_ty_immutable(arg_ty) { + TypeEnum::TList { .. } => { + let int32 = ctx.ctx.i32_type(); + let zero = int32.const_zero(); + let len = ctx + .build_gep_and_load( + arg.into_pointer_value(), + &[zero, int32.const_int(1, false)], + None, + ) + .into_int_value(); + if len.get_type().get_bit_width() == 32 { + Some(len.into()) + } else { + Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into()) + } + } + TypeEnum::TNDArray { .. 
} => todo!(), + _ => unreachable!(), } }) }, diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 4482bc5..94897ff 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -491,11 +491,24 @@ pub fn get_type_from_type_annotation_kinds( (*name, (subst_ty, *mutability)) })); let need_subst = !subst.is_empty(); - let ty = unifier.add_ty(TypeEnum::TObj { - obj_id: *obj_id, - fields: tobj_fields, - params: subst, - }); + let ty = if obj_id == &DefinitionId(14) { + assert_eq!(subst.len(), 2); + let tv_tys = subst.iter() + .sorted_by_key(|(k, _)| *k) + .map(|(_, v)| v) + .collect_vec(); + + unifier.add_ty(TypeEnum::TNDArray { + ty: *tv_tys[0], + ndims: *tv_tys[1], + }) + } else { + unifier.add_ty(TypeEnum::TObj { + obj_id: *obj_id, + fields: tobj_fields, + params: subst, + }) + }; if need_subst { if let Some(wl) = subst_list.as_mut() { wl.push(ty); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 7ff3e0b..f9d3396 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -223,8 +223,12 @@ impl<'a> Fold<()> for Inferencer<'a> { if self.unifier.unioned(iter.custom.unwrap(), self.primitives.range) { self.unify(self.primitives.int32, target.custom.unwrap(), &target.location)?; } else { - let list = self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }); - self.unify(list, iter.custom.unwrap(), &iter.location)?; + let list_like_ty = match &*self.unifier.get_ty(iter.custom.unwrap()) { + TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty: target.custom.unwrap() }), + TypeEnum::TNDArray { .. 
} => todo!(), + _ => unreachable!(), + }; + self.unify(list_like_ty, iter.custom.unwrap(), &iter.location)?; } let body = body.into_iter().map(|b| self.fold_stmt(b)).collect::, _>>()?; @@ -1137,9 +1141,13 @@ impl<'a> Inferencer<'a> { for v in [lower.as_ref(), upper.as_ref(), step.as_ref()].iter().flatten() { self.constrain(v.custom.unwrap(), self.primitives.int32, &v.location)?; } - let list = self.unifier.add_ty(TypeEnum::TList { ty }); - self.constrain(value.custom.unwrap(), list, &value.location)?; - Ok(list) + let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { + TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), + TypeEnum::TNDArray { ndims, .. } => self.unifier.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims }), + _ => unreachable!() + }; + self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?; + Ok(list_like_ty) } ExprKind::Constant { value: ast::Constant::Int(val), .. } => { // the index is a constant, so value can be a sequence. @@ -1159,10 +1167,15 @@ impl<'a> Inferencer<'a> { { return report_error("Tuple index must be a constant (KernelInvariant is also not supported)", slice.location) } + // the index is not a constant, so value can only be a list self.constrain(slice.custom.unwrap(), self.primitives.int32, &slice.location)?; - let list = self.unifier.add_ty(TypeEnum::TList { ty }); - self.constrain(value.custom.unwrap(), list, &value.location)?; + let list_like_ty = match &*self.unifier.get_ty(value.custom.unwrap()) { + TypeEnum::TList { .. } => self.unifier.add_ty(TypeEnum::TList { ty }), + TypeEnum::TNDArray { .. 
} => todo!(), + _ => unreachable!(), + }; + self.constrain(value.custom.unwrap(), list_like_ty, &value.location)?; Ok(ty) } } diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 04a905d..e0a72e8 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -159,6 +159,11 @@ pub enum TypeEnum { ty: Type, }, + TNDArray { + ty: Type, + ndims: Type, + }, + /// An object type. TObj { /// The [DefintionId] of this object type. @@ -193,6 +198,7 @@ impl TypeEnum { TypeEnum::TLiteral { .. } => "TConstant", TypeEnum::TTuple { .. } => "TTuple", TypeEnum::TList { .. } => "TList", + TypeEnum::TNDArray { .. } => "TNDArray", TypeEnum::TObj { .. } => "TObj", TypeEnum::TVirtual { .. } => "TVirtual", TypeEnum::TCall { .. } => "TCall", @@ -418,6 +424,9 @@ impl Unifier { TypeEnum::TList { ty } => self .get_instantiations(*ty) .map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TList { ty })).collect_vec()), + TypeEnum::TNDArray { ty, ndims } => self + .get_instantiations(*ty) + .map(|ty| ty.iter().map(|&ty| self.add_ty(TypeEnum::TNDArray { ty, ndims: *ndims })).collect_vec()), TypeEnum::TVirtual { ty } => self.get_instantiations(*ty).map(|ty| { ty.iter().map(|&ty| self.add_ty(TypeEnum::TVirtual { ty })).collect_vec() }), @@ -470,6 +479,7 @@ impl Unifier { TVar { .. } => allowed_typevars.iter().any(|b| self.unification_table.unioned(a, *b)), TCall { .. } => false, TList { ty } | TVirtual { ty } => self.is_concrete(*ty, allowed_typevars), + TNDArray { ty, .. } => self.is_concrete(*ty, allowed_typevars), TTuple { ty } => ty.iter().all(|ty| self.is_concrete(*ty, allowed_typevars)), TObj { params: vars, .. } => { vars.values().all(|ty| self.is_concrete(*ty, allowed_typevars)) @@ -717,7 +727,8 @@ impl Unifier { self.unify_impl(x, b, false)?; self.set_a_to_b(a, x); } - (TVar { fields: Some(fields), range, is_const_generic: false, .. 
}, TList { ty }) => { + (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TList { ty }) | + (TVar { fields: Some(fields), range, is_const_generic: false, .. }, TNDArray { ty, .. }) => { for (k, v) in fields { match *k { RecordKey::Int(_) => { @@ -829,6 +840,15 @@ impl Unifier { } self.set_a_to_b(a, b); } + (TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => { + if self.unify_impl(*ty1, *ty2, false).is_err() { + return self.incompatible_types(a, b) + } + if self.unify_impl(*ndims1, *ndims2, false).is_err() { + return self.incompatible_types(a, b) + } + self.set_a_to_b(a, b); + } (TVar { fields: Some(map), range, .. }, TObj { fields, .. }) => { for (k, field) in map { match *k { @@ -1076,6 +1096,13 @@ impl Unifier { TypeEnum::TList { ty } => { format!("list[{}]", self.internal_stringify(*ty, obj_to_name, var_to_name, notes)) } + TypeEnum::TNDArray { ty, ndims } => { + format!( + "ndarray[{}, {}]", + self.internal_stringify(*ty, obj_to_name, var_to_name, notes), + self.internal_stringify(*ndims, obj_to_name, var_to_name, notes), + ) + } TypeEnum::TVirtual { ty } => { format!( "virtual[{}]", @@ -1195,7 +1222,7 @@ impl Unifier { // variables, i.e. things like TRecord, TCall should not occur, and we // should be safe to not implement the substitution for those variants. match &*ty { - TypeEnum::TRigidVar { .. } => None, + TypeEnum::TRigidVar { .. } | TypeEnum::TLiteral { .. } => None, TypeEnum::TVar { id, .. 
} => mapping.get(id).copied(), TypeEnum::TTuple { ty } => { let mut new_ty = Cow::from(ty); @@ -1213,6 +1240,19 @@ impl Unifier { TypeEnum::TList { ty } => { self.subst_impl(*ty, mapping, cache).map(|t| self.add_ty(TypeEnum::TList { ty: t })) } + TypeEnum::TNDArray { ty, ndims } => { + let new_ty = self.subst_impl(*ty, mapping, cache); + let new_ndims = self.subst_impl(*ndims, mapping, cache); + + if new_ty.is_some() || new_ndims.is_some() { + Some(self.add_ty(TypeEnum::TNDArray { + ty: new_ty.unwrap_or(*ty), + ndims: new_ndims.unwrap_or(*ndims) + })) + } else { + None + } + } TypeEnum::TVirtual { ty } => self .subst_impl(*ty, mapping, cache) .map(|t| self.add_ty(TypeEnum::TVirtual { ty: t })), @@ -1383,6 +1423,19 @@ impl Unifier { (TList { ty: ty1 }, TList { ty: ty2 }) => { Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TList { ty }))) } + (TNDArray { ty: ty1, ndims: ndims1 }, TNDArray { ty: ty2, ndims: ndims2 }) => { + let ty = self.get_intersection(*ty1, *ty2)?; + let ndims = self.get_intersection(*ndims1, *ndims2)?; + + Ok(if ty.is_some() || ndims.is_some() { + Some(self.add_ty(TNDArray { + ty: ty.unwrap_or(*ty1), + ndims: ndims.unwrap_or(*ndims1), + })) + } else { + None + }) + } (TVirtual { ty: ty1 }, TVirtual { ty: ty2 }) => { Ok(self.get_intersection(*ty1, *ty2)?.map(|ty| self.add_ty(TVirtual { ty }))) } diff --git a/nac3core/src/typecheck/typedef/test.rs b/nac3core/src/typecheck/typedef/test.rs index 3069b57..eb44c22 100644 --- a/nac3core/src/typecheck/typedef/test.rs +++ b/nac3core/src/typecheck/typedef/test.rs @@ -33,6 +33,7 @@ impl Unifier { && ty1.iter().zip(ty2.iter()).all(|(t1, t2)| self.eq(*t1, *t2)) } (TypeEnum::TList { ty: ty1 }, TypeEnum::TList { ty: ty2 }) + | (TypeEnum::TNDArray { ty: ty1 }, TypeEnum::TNDArray { ty: ty2 }) | (TypeEnum::TVirtual { ty: ty1 }, TypeEnum::TVirtual { ty: ty2 }) => { self.eq(*ty1, *ty2) } -- 2.44.1 From afa7d9b100321b60e1e58b9f81bb6616870179ca Mon Sep 17 00:00:00 2001 From: David Mak Date: Thu, 21 Dec 2023 
14:43:19 +0800 Subject: [PATCH 7/9] core: Implement helper for creation of generic ndarray --- nac3core/src/typecheck/typedef/mod.rs | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index e0a72e8..ae84b0a 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -205,6 +205,27 @@ impl TypeEnum { TypeEnum::TFunc { .. } => "TFunc", } } + + /// Returns a [TypeEnum] representing a generic `ndarray` type. + /// + /// * `dtype` - The datatype of the `ndarray`, or `None` if the datatype is generic. + /// * `ndims` - The number of dimensions of the `ndarray`, or `None` if the number of dimensions is generic. + #[must_use] + pub fn ndarray( + unifier: &mut Unifier, + dtype: Option, + ndims: Option, + primitives: &PrimitiveStore + ) -> TypeEnum { + let dtype = dtype.unwrap_or_else(|| unifier.get_fresh_var(Some("T".into()), None).0); + let ndims = ndims + .unwrap_or_else(|| unifier.get_fresh_const_generic_var(primitives.usize(), Some("N".into()), None).0); + + TypeEnum::TNDArray { + ty: dtype, + ndims, + } + } } pub type SharedUnifier = Arc, u32, Vec)>>; -- 2.44.1 From 27fcf8926ef17d70b2cd470d18d8c8fc5ec21353 Mon Sep 17 00:00:00 2001 From: David Mak Date: Fri, 17 Nov 2023 17:30:27 +0800 Subject: [PATCH 8/9] core: Implement ndarray constructor and numpy.empty --- nac3core/src/codegen/irrt/irrt.c | 44 ++++ nac3core/src/codegen/irrt/mod.rs | 176 ++++++++++++++++ nac3core/src/codegen/stmt.rs | 76 ++++++- nac3core/src/toplevel/builtins.rs | 34 ++- nac3core/src/toplevel/mod.rs | 1 + nac3core/src/toplevel/numpy.rs | 198 ++++++++++++++++++ nac3core/src/typecheck/type_inferencer/mod.rs | 49 ++++- nac3standalone/demo/interpret_demo.py | 23 +- nac3standalone/demo/src/ndarray.py | 22 ++ 9 files changed, 619 insertions(+), 4 deletions(-) create mode 100644 nac3core/src/toplevel/numpy.rs create mode 100644 nac3standalone/demo/src/ndarray.py diff 
--git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index d68b344..80e48aa 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -196,4 +196,48 @@ double __nac3_j0(double x) { } return j0(x); +} + +uint32_t __nac3_ndarray_calc_size( + const int32_t *list_data, + uint32_t list_len +) { + uint32_t num_elems = 1; + for (uint32_t i = 0; i < list_len; ++i) { + int32_t val = list_data[i]; + __builtin_assume(val >= 0); + num_elems *= (uint32_t) list_data[i]; + } + return num_elems; +} + +uint64_t __nac3_ndarray_calc_size64( + const int32_t *list_data, + uint64_t list_len +) { + uint64_t num_elems = 1; + for (uint64_t i = 0; i < list_len; ++i) { + int32_t val = list_data[i]; + __builtin_assume(val >= 0); + num_elems *= (uint64_t) list_data[i]; + } + return num_elems; +} + +void __nac3_ndarray_init_dims( + uint32_t *ndarray_dims, + const int32_t *shape_data, + uint32_t shape_len +) { + __builtin_memcpy(ndarray_dims, shape_data, shape_len * sizeof(int32_t)); +} + +void __nac3_ndarray_init_dims64( + uint64_t *ndarray_dims, + const int32_t *shape_data, + uint64_t shape_len +) { + for (uint64_t i = 0; i < shape_len; ++i) { + ndarray_dims[i] = (uint64_t) shape_data[i]; + } } \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index 0cd0ebd..d6906c6 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -12,6 +12,9 @@ use inkwell::{ }; use nac3parser::ast::Expr; +#[cfg(debug_assertions)] +use inkwell::types::AnyTypeEnum; + #[must_use] pub fn load_irrt(ctx: &Context) -> Module { let bitcode_buf = MemoryBuffer::create_from_memory_range( @@ -546,3 +549,176 @@ pub fn call_j0<'ctx>( .unwrap_left() .into_float_value() } + +/// Checks whether the pointer `value` refers to a `list` in LLVM. 
+fn assert_is_list(value: PointerValue) -> PointerValue { + #[cfg(debug_assertions)] + { + let llvm_shape_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else { + panic!("Expected struct type for `list` type, but got {llvm_shape_ty}") + }; + assert_eq!(llvm_shape_ty.count_fields(), 2); + assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..)))); + assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..)))); + } + + value +} + +/// Checks whether the pointer `value` refers to an `NDArray` in LLVM. +fn assert_is_ndarray(value: PointerValue) -> PointerValue { + #[cfg(debug_assertions)] + { + let llvm_ndarray_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}") + }; + + assert_eq!(llvm_ndarray_ty.count_fields(), 3); + assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..)))); + let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else { + unreachable!() + }; + let BasicTypeEnum::PointerType(dims) = ndarray_dims else { + panic!("Expected pointer type for `list.1`, but got {ndarray_dims}") + }; + assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..))); + assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..)))); + } + + value +} + +/// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the +/// calculated total size. +/// +/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM +/// representation of a `list`. 
+pub fn call_ndarray_calc_size<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + shape: PointerValue<'ctx>, +) -> IntValue<'ctx> { + assert_is_list(shape); + + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + + let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() { + 32 => "__nac3_ndarray_calc_size", + 64 => "__nac3_ndarray_calc_size64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_calc_size_fn_t = llvm_usize.fn_type( + &[ + llvm_pi32.into(), + llvm_usize.into(), + ], + false, + ); + let ndarray_calc_size_fn = ctx.module.get_function(ndarray_calc_size_fn_name) + .unwrap_or_else(|| { + ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) + }); + + let ( + shape_data, + shape_len, + ) = unsafe { + ( + ctx.builder.build_in_bounds_gep( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "" + ), + ctx.builder.build_in_bounds_gep( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "" + ), + ) + }; + + ctx.builder + .build_call( + ndarray_calc_size_fn, + &[ + ctx.builder.build_load(shape_data, "").into(), + ctx.builder.build_load(shape_len, "").into(), + ], + "", + ) + .try_as_basic_value() + .unwrap_left() + .into_int_value() +} + +/// Generates a call to `__nac3_ndarray_init_dims`. +/// +/// * `ndarray` - LLVM pointer to the NDArray. This value must be the LLVM representation of an +/// `NDArray`. +/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM +/// representation of a `list`. 
+pub fn call_ndarray_init_dims<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarray: PointerValue<'ctx>, + shape: PointerValue<'ctx>, +) { + assert_is_ndarray(ndarray); + assert_is_list(shape); + + let llvm_void = ctx.ctx.void_type(); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_init_dims_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() { + 32 => "__nac3_ndarray_init_dims", + 64 => "__nac3_ndarray_init_dims64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_init_dims_fn = ctx.module.get_function(ndarray_init_dims_fn_name).unwrap_or_else(|| { + let fn_type = llvm_void.fn_type( + &[ + llvm_pusize.into(), + llvm_pi32.into(), + llvm_usize.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_init_dims_fn_name, fn_type, None) + }); + + let ndarray_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + None, + ); + let shape_data = ctx.build_gep_and_load( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None + ); + let ndarray_num_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_int_value(); + + ctx.builder.build_call( + ndarray_init_dims_fn, + &[ + ndarray_dims.into(), + shape_data.into(), + ndarray_num_dims.into(), + ], + "", + ); +} \ No newline at end of file diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index e53ea32..111fb6c 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -16,7 +16,7 @@ use inkwell::{ attributes::{Attribute, AttributeLoc}, basic_block::BasicBlock, types::BasicTypeEnum, - values::{BasicValue, BasicValueEnum, FunctionValue, PointerValue}, + values::{BasicValue, BasicValueEnum, 
FunctionValue, IntValue, PointerValue}, IntPredicate, }; use nac3parser::ast::{ @@ -405,6 +405,80 @@ pub fn gen_for( Ok(()) } +/// Generates a C-style `for` construct using lambdas, similar to the following C code: +/// +/// ```c +/// for (x... = init(); cond(x...); update(x...)) { +/// body(x...); +/// } +/// ``` +/// +/// * `init` - A lambda containing IR statements declaring and initializing loop variables. The +/// return value is a [Clone] value which will be passed to the other lambdas. +/// * `cond` - A lambda containing IR statements checking whether the loop should continue +/// executing. The result value must be an `i1` indicating if the loop should continue. +/// * `body` - A lambda containing IR statements within the loop body. +/// * `update` - A lambda containing IR statements updating loop variables. +pub fn gen_for_callback<'ctx, 'a, I, InitFn, CondFn, BodyFn, UpdateFn>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + init: InitFn, + cond: CondFn, + body: BodyFn, + update: UpdateFn, +) -> Result<(), String> + where + I: Clone, + InitFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>) -> Result, + CondFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result, String>, + BodyFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, + UpdateFn: FnOnce(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, I) -> Result<(), String>, +{ + let current = ctx.builder.get_insert_block().and_then(|bb| bb.get_parent()).unwrap(); + let init_bb = ctx.ctx.append_basic_block(current, "for.init"); + // The BB containing the loop condition check + let cond_bb = ctx.ctx.append_basic_block(current, "for.cond"); + let body_bb = ctx.ctx.append_basic_block(current, "for.body"); + // The BB containing the increment expression + let update_bb = ctx.ctx.append_basic_block(current, "for.update"); + let cont_bb = ctx.ctx.append_basic_block(current, "for.end"); + + // store loop 
bb information and restore it later + let loop_bb = ctx.loop_target.replace((update_bb, cont_bb)); + + ctx.builder.build_unconditional_branch(init_bb); + + let loop_var = { + ctx.builder.position_at_end(init_bb); + let result = init(generator, ctx)?; + ctx.builder.build_unconditional_branch(cond_bb); + + result + }; + + ctx.builder.position_at_end(cond_bb); + let cond = cond(generator, ctx, loop_var.clone())?; + assert_eq!(cond.get_type().get_bit_width(), ctx.ctx.bool_type().get_bit_width()); + ctx.builder.build_conditional_branch( + cond, + body_bb, + cont_bb + ); + + ctx.builder.position_at_end(body_bb); + body(generator, ctx, loop_var.clone())?; + ctx.builder.build_unconditional_branch(update_bb); + + ctx.builder.position_at_end(update_bb); + update(generator, ctx, loop_var)?; + ctx.builder.build_unconditional_branch(cond_bb); + + ctx.builder.position_at_end(cont_bb); + ctx.loop_target = loop_bb; + + Ok(()) +} + /// See [`CodeGenerator::gen_while`]. pub fn gen_while( generator: &mut G, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 00da280..cb3f650 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -13,11 +13,12 @@ use crate::{ stmt::exn_constructor, }, symbol_resolver::SymbolValue, + toplevel::numpy::gen_ndarray_empty, }; use inkwell::{ attributes::{Attribute, AttributeLoc}, types::{BasicType, BasicMetadataTypeEnum}, - values::BasicMetadataValueEnum, + values::{BasicValue, BasicMetadataValueEnum}, FloatPredicate, IntPredicate }; @@ -278,6 +279,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let boolean = primitives.0.bool; let range = primitives.0.range; let string = primitives.0.str; + let ndarray_float = { + let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0); + primitives.1.add_ty(ndarray_ty_enum) + }; + let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 }); let num_ty = 
primitives.1.get_fresh_var_with_range( &[int32, int64, float, boolean, uint32, uint64], Some("N".into()), @@ -837,6 +843,32 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + create_fn_by_codegen( + primitives, + &var_map, + "np_ndarray", + ndarray_float, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(list_int32, "shape")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_empty(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), + create_fn_by_codegen( + primitives, + &var_map, + "np_empty", + ndarray_float, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(list_int32, "shape")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_empty(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), create_fn_by_codegen( primitives, &var_map, diff --git a/nac3core/src/toplevel/mod.rs b/nac3core/src/toplevel/mod.rs index c62c5c0..c204819 100644 --- a/nac3core/src/toplevel/mod.rs +++ b/nac3core/src/toplevel/mod.rs @@ -25,6 +25,7 @@ pub struct DefinitionId(pub usize); pub mod builtins; pub mod composer; pub mod helper; +pub mod numpy; pub mod type_annotation; use composer::*; use type_annotation::*; diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs new file mode 100644 index 0000000..9b91e82 --- /dev/null +++ b/nac3core/src/toplevel/numpy.rs @@ -0,0 +1,198 @@ +use inkwell::{ + IntPredicate, + types::BasicType, + values::PointerValue, +}; +use nac3parser::ast::StrRef; +use crate::{ + codegen::{ + CodeGenContext, + CodeGenerator, + irrt::{call_ndarray_calc_size, call_ndarray_init_dims}, + stmt::gen_for_callback + }, + symbol_resolver::ValueEnum, + toplevel::DefinitionId, + typecheck::typedef::{FunSignature, Type, TypeEnum}, +}; + +/// LLVM-typed implementation for generating 
the implementation for constructing an `NDArray`. +/// +/// * `elem_ty` - The element type of the NDArray. +/// * `var_name` - The variable name of the NDArray. +/// * `shape` - The `shape` parameter used to construct the NDArray. +fn call_ndarray_impl<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + var_name: Option<&str>, + shape: PointerValue<'ctx>, +) -> Result, String> { + let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); + let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); + + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); + let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); + let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); + assert!(llvm_ndarray_data_t.is_sized()); + + // Assert that all dimensions are non-negative + gen_for_callback( + generator, + ctx, + |_, ctx| { + let i = ctx.builder.build_alloca(llvm_usize, ""); + ctx.builder.build_store(i, llvm_usize.const_zero()); + + Ok(i) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let shape_len = ctx.build_gep_and_load( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + None, + ).into_int_value(); + + Ok(ctx.builder.build_int_compare(IntPredicate::ULE, i, shape_len, "")) + }, + |generator, ctx, i_addr| { + let shape_elems = ctx.build_gep_and_load( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None + ).into_pointer_value(); + + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let shape_dim = ctx.build_gep_and_load( + shape_elems, + &[i], + None + ).into_int_value(); + + let shape_dim_gez = ctx.builder.build_int_compare( + IntPredicate::SGE, + shape_dim, + llvm_i32.const_zero(), + "" + ); + + ctx.make_assert( + 
generator, + shape_dim_gez, + "0:ValueError", + "negative dimensions not supported", + [None, None, None], + ctx.current_loc, + ); + + Ok(()) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), ""); + ctx.builder.build_store(i_addr, i); + + Ok(()) + }, + )?; + + let ndarray = ctx.builder.build_alloca( + llvm_ndarray_t, + var_name.unwrap_or_default() + ); + + let num_dims = ctx.build_gep_and_load( + shape, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + None + ).into_int_value(); + + let ndarray_num_dims = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "", + ) + }; + ctx.builder.build_store(ndarray_num_dims, num_dims); + + let ndarray_dims = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "", + ) + }; + + let ndarray_num_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_int_value(); + + ctx.builder.build_store( + ndarray_dims, + ctx.builder.build_array_alloca( + llvm_usize, + ndarray_num_dims, + "", + ), + ); + + call_ndarray_init_dims(generator, ctx, ndarray, shape); + + let ndarray_num_elems = call_ndarray_calc_size(generator, ctx, shape); + + let ndarray_data = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], + "", + ) + }; + ctx.builder.build_store( + ndarray_data, + ctx.builder.build_array_alloca( + llvm_ndarray_data_t, + ndarray_num_elems, + "", + ), + ); + + Ok(ndarray) +} + +/// Generates LLVM IR for `ndarray.empty`. 
+pub fn gen_ndarray_empty<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let shape_ty = fun.0.args[0].ty; + let shape_arg_name = args[0].0; + let shape_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, shape_ty)?; + + call_ndarray_impl( + generator, + context, + context.primitives.float, + shape_arg_name.map(|name| name.to_string()).as_deref(), + shape_arg.into_pointer_value(), + ) +} \ No newline at end of file diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index f9d3396..cfe000f 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -5,7 +5,7 @@ use std::{cell::RefCell, sync::Arc}; use super::typedef::{Call, FunSignature, FuncArg, RecordField, Type, TypeEnum, Unifier}; use super::{magic_methods::*, typedef::CallId}; -use crate::{symbol_resolver::SymbolResolver, toplevel::TopLevelContext}; +use crate::{symbol_resolver::{SymbolResolver, SymbolValue}, toplevel::TopLevelContext}; use itertools::izip; use nac3parser::ast::{ self, @@ -894,6 +894,53 @@ impl<'a> Inferencer<'a> { } } + // 1-argument ndarray n-dimensional creation functions + if [ + "np_ndarray".into(), + "np_empty".into(), + ].contains(id) && args.len() == 1 { + let ExprKind::List { elts, .. 
} = &args[0].node else { + return report_error("Expected List literal for first argument of np_ndarray", args[0].location) + }; + + let ndims = elts.len() as u64; + + let arg0 = self.fold_expr(args.remove(0))?; + let ndims = self.unifier.get_fresh_literal( + vec![SymbolValue::U64(ndims)], + None, + ); + let ret = self.unifier.add_ty(TypeEnum::TNDArray { + ty: self.primitives.float, + ndims + }); + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "shape".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }, + ], + ret, + vars: HashMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0], + keywords: vec![], + }, + })) + } + Ok(None) } diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index 830e86d..abdeda9 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -5,11 +5,12 @@ import importlib.util import importlib.machinery import math import numpy as np +import numpy.typing as npt import pathlib from numpy import int32, int64, uint32, uint64 from scipy import special -from typing import TypeVar, Generic, Literal +from typing import TypeVar, Generic, Literal, Union T = TypeVar('T') class Option(Generic[T]): @@ -50,6 +51,13 @@ class _ConstGenericMarker: def ConstGeneric(name, constraint): return TypeVar(name, _ConstGenericMarker, constraint) +N = TypeVar("N", bound=np.uint64) +class _NDArrayDummy(Generic[T, N]): + pass + +# https://stackoverflow.com/questions/67803260/how-to-create-a-type-alias-with-a-throw-away-generic +NDArray = Union[npt.NDArray[T], _NDArrayDummy[T, N]] + def round_away_zero(x): if x >= 0.0: return math.floor(x + 0.5) @@ -124,6 +132,16 @@ def patch(module): module.ceil64 = math.ceil module.np_ceil = 
np.ceil + # NumPy ndarray functions + module.ndarray = NDArray + module.np_ndarray = np.ndarray + module.np_empty = np.empty + module.np_zeros = np.zeros + module.np_ones = np.ones + module.np_full = np.full + module.np_eye = np.eye + module.np_identity = np.identity + # NumPy Math functions module.np_isnan = np.isnan module.np_isinf = np.isinf @@ -166,6 +184,9 @@ def patch(module): module.sp_spec_j0 = special.j0 module.sp_spec_j1 = special.j1 + # NumPy NDArray Functions + module.np_ndarray = np.ndarray + module.np_empty = np.empty def file_import(filename, prefix="file_import_"): filename = pathlib.Path(filename) diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py new file mode 100644 index 0000000..1237d06 --- /dev/null +++ b/nac3standalone/demo/src/ndarray.py @@ -0,0 +1,22 @@ +def consume_ndarray_1(n: ndarray[float, Literal[1]]): + pass + +def consume_ndarray_i32_1(n: ndarray[int32, Literal[1]]): + pass + +def consume_ndarray_2(n: ndarray[float, Literal[2]]): + pass + +def test_ndarray_ctor(): + n = np_ndarray([1]) + consume_ndarray_1(n) + +def test_ndarray_empty(): + n = np_empty([1]) + consume_ndarray_1(n) + +def run() -> int32: + test_ndarray_ctor() + test_ndarray_empty() + + return 0 -- 2.44.1 From 140f8f8a08fbfeea4fa5ced45fad177e7d4289f3 Mon Sep 17 00:00:00 2001 From: David Mak Date: Mon, 27 Nov 2023 13:25:53 +0800 Subject: [PATCH 9/9] core: Implement most ndarray-creation functions --- nac3core/src/codegen/generator.rs | 12 + nac3core/src/codegen/irrt/irrt.c | 40 +- nac3core/src/codegen/irrt/mod.rs | 146 ++-- nac3core/src/codegen/mod.rs | 45 +- nac3core/src/codegen/stmt.rs | 33 +- nac3core/src/toplevel/builtins.rs | 129 +++- nac3core/src/toplevel/numpy.rs | 729 +++++++++++++++++- nac3core/src/typecheck/type_inferencer/mod.rs | 63 +- nac3standalone/demo/interpret_demo.py | 5 + nac3standalone/demo/src/ndarray.py | 33 + 10 files changed, 1130 insertions(+), 105 deletions(-) diff --git a/nac3core/src/codegen/generator.rs 
b/nac3core/src/codegen/generator.rs index 7a86c0b..c5c6aed 100644 --- a/nac3core/src/codegen/generator.rs +++ b/nac3core/src/codegen/generator.rs @@ -92,6 +92,18 @@ pub trait CodeGenerator { gen_var(ctx, ty, name) } + /// Allocate memory for a variable and return a pointer pointing to it. + /// The default implementation places the allocations at the start of the function. + fn gen_array_var_alloc<'ctx, 'a>( + &mut self, + ctx: &mut CodeGenContext<'ctx, 'a>, + ty: BasicTypeEnum<'ctx>, + size: IntValue<'ctx>, + name: Option<&str>, + ) -> Result, String> { + gen_array_var(ctx, ty, size, name) + } + /// Return a pointer pointing to the target of the expression. fn gen_store_target<'ctx>( &mut self, diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index 80e48aa..8b28bc1 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -199,27 +199,27 @@ double __nac3_j0(double x) { } uint32_t __nac3_ndarray_calc_size( - const int32_t *list_data, + const uint64_t *list_data, uint32_t list_len ) { uint32_t num_elems = 1; for (uint32_t i = 0; i < list_len; ++i) { - int32_t val = list_data[i]; + uint64_t val = list_data[i]; __builtin_assume(val >= 0); - num_elems *= (uint32_t) list_data[i]; + num_elems *= list_data[i]; } return num_elems; } uint64_t __nac3_ndarray_calc_size64( - const int32_t *list_data, + const uint64_t *list_data, uint64_t list_len ) { uint64_t num_elems = 1; for (uint64_t i = 0; i < list_len; ++i) { - int32_t val = list_data[i]; + uint64_t val = list_data[i]; __builtin_assume(val >= 0); - num_elems *= (uint64_t) list_data[i]; + num_elems *= list_data[i]; } return num_elems; } @@ -240,4 +240,32 @@ void __nac3_ndarray_init_dims64( for (uint64_t i = 0; i < shape_len; ++i) { ndarray_dims[i] = (uint64_t) shape_data[i]; } +} + +void __nac3_ndarray_calc_nd_indices( + uint32_t index, + const uint32_t* dims, + uint32_t num_dims, + uint32_t* idxs +) { + uint32_t stride = 1; + for (uint32_t dim = 0; dim < 
num_dims; dim++) { + uint32_t i = num_dims - dim - 1; + idxs[i] = (index / stride) % dims[i]; + stride *= dims[i]; + } +} + +void __nac3_ndarray_calc_nd_indices64( + uint64_t index, + const uint64_t* dims, + uint64_t num_dims, + uint64_t* idxs +) { + uint64_t stride = 1; + for (uint64_t dim = 0; dim < num_dims; dim++) { + uint64_t i = num_dims - dim - 1; + idxs[i] = (index / stride) % dims[i]; + stride *= dims[i]; + } } \ No newline at end of file diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index d6906c6..e2add43 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -1,6 +1,6 @@ use crate::typecheck::typedef::Type; -use super::{CodeGenContext, CodeGenerator}; +use super::{assert_is_list, assert_is_ndarray, CodeGenContext, CodeGenerator}; use inkwell::{ attributes::{Attribute, AttributeLoc}, context::Context, @@ -12,9 +12,6 @@ use inkwell::{ }; use nac3parser::ast::Expr; -#[cfg(debug_assertions)] -use inkwell::types::AnyTypeEnum; - #[must_use] pub fn load_irrt(ctx: &Context) -> Module { let bitcode_buf = MemoryBuffer::create_from_memory_range( @@ -550,62 +547,21 @@ pub fn call_j0<'ctx>( .into_float_value() } -/// Checks whether the pointer `value` refers to a `list` in LLVM. -fn assert_is_list(value: PointerValue) -> PointerValue { - #[cfg(debug_assertions)] - { - let llvm_shape_ty = value.get_type().get_element_type(); - let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else { - panic!("Expected struct type for `list` type, but got {llvm_shape_ty}") - }; - assert_eq!(llvm_shape_ty.count_fields(), 2); - assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..)))); - assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..)))); - } - - value -} - -/// Checks whether the pointer `value` refers to an `NDArray` in LLVM. 
-fn assert_is_ndarray(value: PointerValue) -> PointerValue { - #[cfg(debug_assertions)] - { - let llvm_ndarray_ty = value.get_type().get_element_type(); - let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { - panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}") - }; - - assert_eq!(llvm_ndarray_ty.count_fields(), 3); - assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..)))); - let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else { - unreachable!() - }; - let BasicTypeEnum::PointerType(dims) = ndarray_dims else { - panic!("Expected pointer type for `list.1`, but got {ndarray_dims}") - }; - assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..))); - assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..)))); - } - - value -} - /// Generates a call to `__nac3_ndarray_calc_size`. Returns an [IntValue] representing the /// calculated total size. /// -/// * `shape` - LLVM pointer to the `shape` of the NDArray. This value must be the LLVM -/// representation of a `list`. +/// * `num_dims` - An [IntValue] containing the number of dimensions. +/// * `dims` - A [PointerValue] to an array containing the size of each dimension. 
pub fn call_ndarray_calc_size<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, - shape: PointerValue<'ctx>, + num_dims: IntValue<'ctx>, + dims: PointerValue<'ctx>, ) -> IntValue<'ctx> { - assert_is_list(shape); - - let llvm_i32 = ctx.ctx.i32_type(); + let llvm_i64 = ctx.ctx.i64_type(); let llvm_usize = generator.get_size_type(ctx.ctx); - let llvm_pi32 = llvm_i32.ptr_type(AddressSpace::default()); + let llvm_pi64 = llvm_i64.ptr_type(AddressSpace::default()); let ndarray_calc_size_fn_name = match generator.get_size_type(ctx.ctx).get_bit_width() { 32 => "__nac3_ndarray_calc_size", @@ -614,7 +570,7 @@ pub fn call_ndarray_calc_size<'ctx, 'a>( }; let ndarray_calc_size_fn_t = llvm_usize.fn_type( &[ - llvm_pi32.into(), + llvm_pi64.into(), llvm_usize.into(), ], false, @@ -624,30 +580,12 @@ pub fn call_ndarray_calc_size<'ctx, 'a>( ctx.module.add_function(ndarray_calc_size_fn_name, ndarray_calc_size_fn_t, None) }); - let ( - shape_data, - shape_len, - ) = unsafe { - ( - ctx.builder.build_in_bounds_gep( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_zero()], - "" - ), - ctx.builder.build_in_bounds_gep( - shape, - &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], - "" - ), - ) - }; - ctx.builder .build_call( ndarray_calc_size_fn, &[ - ctx.builder.build_load(shape_data, "").into(), - ctx.builder.build_load(shape_len, "").into(), + dims.into(), + num_dims.into(), ], "", ) @@ -721,4 +659,68 @@ pub fn call_ndarray_init_dims<'ctx, 'a>( ], "", ); +} + +pub fn call_ndarray_calc_nd_indices<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + index: IntValue<'ctx>, + ndarray: PointerValue<'ctx>, +) -> Result, String> { + assert_is_ndarray(ndarray); + + let llvm_void = ctx.ctx.void_type(); + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pusize = llvm_usize.ptr_type(AddressSpace::default()); + + let ndarray_calc_nd_indices_dn_name = match 
generator.get_size_type(ctx.ctx).get_bit_width() { + 32 => "__nac3_ndarray_calc_nd_indices", + 64 => "__nac3_ndarray_calc_nd_indices64", + bw => unreachable!("Unsupported size type bit width: {}", bw) + }; + let ndarray_calc_nd_indices_fn = ctx.module.get_function(ndarray_calc_nd_indices_dn_name).unwrap_or_else(|| { + let fn_type = llvm_void.fn_type( + &[ + llvm_usize.into(), + llvm_pusize.into(), + llvm_usize.into(), + llvm_pusize.into(), + ], + false, + ); + + ctx.module.add_function(ndarray_calc_nd_indices_dn_name, fn_type, None) + }); + + let ndarray_num_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_int_value(); + let ndarray_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + None, + ).into_pointer_value(); + + let indices = ctx.builder.build_array_alloca( + llvm_usize, + ndarray_num_dims, + "", + ); + + ctx.builder.build_call( + ndarray_calc_nd_indices_fn, + &[ + index.into(), + ndarray_dims.into(), + ndarray_num_dims.into(), + indices.into(), + ], + "", + ); + + Ok(indices) } \ No newline at end of file diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 21943d4..b1836a0 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -34,6 +34,9 @@ use std::sync::{ }; use std::thread; +#[cfg(debug_assertions)] +use inkwell::types::AnyTypeEnum; + pub mod concrete_type; pub mod expr; mod generator; @@ -236,7 +239,7 @@ pub struct WorkerRegistry { static_value_store: Arc>, /// LLVM-related options for code generation. - llvm_options: CodeGenLLVMOptions, + pub llvm_options: CodeGenLLVMOptions, } impl WorkerRegistry { @@ -995,3 +998,43 @@ fn gen_in_range_check<'ctx>( ctx.builder.build_int_compare(IntPredicate::SLT, lo, hi, "cmp") } + +/// Checks whether the pointer `value` refers to a `list` in LLVM. 
+fn assert_is_list(value: PointerValue) -> PointerValue { + #[cfg(debug_assertions)] + { + let llvm_shape_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_shape_ty) = llvm_shape_ty else { + panic!("Expected struct type for `list` type, but got {llvm_shape_ty}") + }; + assert_eq!(llvm_shape_ty.count_fields(), 2); + assert!(matches!(llvm_shape_ty.get_field_type_at_index(0), Some(BasicTypeEnum::PointerType(..)))); + assert!(matches!(llvm_shape_ty.get_field_type_at_index(1), Some(BasicTypeEnum::IntType(..)))); + } + + value +} + +/// Checks whether the pointer `value` refers to an `NDArray` in LLVM. +fn assert_is_ndarray(value: PointerValue) -> PointerValue { + #[cfg(debug_assertions)] + { + let llvm_ndarray_ty = value.get_type().get_element_type(); + let AnyTypeEnum::StructType(llvm_ndarray_ty) = llvm_ndarray_ty else { + panic!("Expected struct type for `NDArray` type, but got {llvm_ndarray_ty}") + }; + + assert_eq!(llvm_ndarray_ty.count_fields(), 3); + assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(0), Some(BasicTypeEnum::IntType(..)))); + let Some(ndarray_dims) = llvm_ndarray_ty.get_field_type_at_index(1) else { + unreachable!() + }; + let BasicTypeEnum::PointerType(dims) = ndarray_dims else { + panic!("Expected pointer type for `list.1`, but got {ndarray_dims}") + }; + assert!(matches!(dims.get_element_type(), AnyTypeEnum::IntType(..))); + assert!(matches!(llvm_ndarray_ty.get_field_type_at_index(2), Some(BasicTypeEnum::PointerType(..)))); + } + + value +} diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs index 111fb6c..1cf57b2 100644 --- a/nac3core/src/codegen/stmt.rs +++ b/nac3core/src/codegen/stmt.rs @@ -15,7 +15,7 @@ use crate::{ use inkwell::{ attributes::{Attribute, AttributeLoc}, basic_block::BasicBlock, - types::BasicTypeEnum, + types::{BasicType, BasicTypeEnum}, values::{BasicValue, BasicValueEnum, FunctionValue, IntValue, PointerValue}, IntPredicate, }; @@ -54,6 +54,37 @@ pub fn gen_var<'ctx>( 
Ok(ptr) } +/// See [CodeGenerator::gen_array_var_alloc]. +pub fn gen_array_var<'ctx, 'a, T: BasicType<'ctx>>( + ctx: &mut CodeGenContext<'ctx, 'a>, + ty: T, + size: IntValue<'ctx>, + name: Option<&str>, +) -> Result, String> { + // Restore debug location + let di_loc = ctx.debug_info.0.create_debug_location( + ctx.ctx, + ctx.current_loc.row as u32, + ctx.current_loc.column as u32, + ctx.debug_info.2, + None, + ); + + // put the alloca in init block + let current = ctx.builder.get_insert_block().unwrap(); + + // position before the last branching instruction... + ctx.builder.position_before(&ctx.init_bb.get_last_instruction().unwrap()); + ctx.builder.set_current_debug_location(di_loc); + + let ptr = ctx.builder.build_array_alloca(ty, size, name.unwrap_or("")); + + ctx.builder.position_at_end(current); + ctx.builder.set_current_debug_location(di_loc); + + Ok(ptr) +} + /// See [`CodeGenerator::gen_store_target`]. pub fn gen_store_target<'ctx, G: CodeGenerator>( generator: &mut G, diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index cb3f650..d2eb458 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -13,7 +13,13 @@ use crate::{ stmt::exn_constructor, }, symbol_resolver::SymbolValue, - toplevel::numpy::gen_ndarray_empty, + toplevel::numpy::{ + gen_ndarray_empty, + gen_ndarray_eye, + gen_ndarray_full, + gen_ndarray_ones, + gen_ndarray_zeros, + }, }; use inkwell::{ attributes::{Attribute, AttributeLoc}, @@ -22,6 +28,7 @@ use inkwell::{ FloatPredicate, IntPredicate }; +use crate::toplevel::numpy::gen_ndarray_identity; type BuiltinInfo = Vec<(Arc>, Option)>; @@ -279,10 +286,30 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let boolean = primitives.0.bool; let range = primitives.0.range; let string = primitives.0.str; + let ndarray = { + let ndarray_ty = TypeEnum::ndarray(&mut primitives.1, None, None, &primitives.0); + primitives.1.add_ty(ndarray_ty) + }; let 
ndarray_float = { let ndarray_ty_enum = TypeEnum::ndarray(&mut primitives.1, Some(float), None, &primitives.0); primitives.1.add_ty(ndarray_ty_enum) }; + let ndarray_float_2d = { + let value = match primitives.0.size_t { + 64 => SymbolValue::U64(2u64), + 32 => SymbolValue::U32(2u32), + _ => unreachable!(), + }; + let ndims = primitives.1.add_ty(TypeEnum::TLiteral { + values: vec![value], + loc: None, + }); + + primitives.1.add_ty(TypeEnum::TNDArray { + ty: float, + ndims, + }) + }; let list_int32 = primitives.1.add_ty(TypeEnum::TList { ty: int32 }); let num_ty = primitives.1.get_fresh_var_with_range( &[int32, int64, float, boolean, uint32, uint64], @@ -869,6 +896,89 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { .map(|val| Some(val.as_basic_value_enum())) }), ), + create_fn_by_codegen( + primitives, + &var_map, + "np_zeros", + ndarray_float, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(list_int32, "shape")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_zeros(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), + create_fn_by_codegen( + primitives, + &var_map, + "np_ones", + ndarray_float, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(list_int32, "shape")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_ones(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), + { + let tv = primitives.1.get_fresh_var(Some("T".into()), None).0; + + create_fn_by_codegen( + primitives, + &var_map, + "np_full", + ndarray, + // We are using List[int32] here, as I don't know a way to specify an n-tuple bound on a + // type variable + &[(list_int32, "shape"), (tv, "fill_value")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_full(ctx, obj, fun, args, generator) + .map(|val| 
Some(val.as_basic_value_enum())) + }), + ) + }, + Arc::new(RwLock::new(TopLevelDef::Function { + name: "np_eye".into(), + simple_name: "np_eye".into(), + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { name: "N".into(), ty: int32, default_value: None }, + // TODO(Derppening): Default values currently do not work? + FuncArg { + name: "M".into(), + ty: int32, + default_value: Some(SymbolValue::OptionNone) + }, + FuncArg { name: "k".into(), ty: int32, default_value: Some(SymbolValue::I32(0)) }, + ], + ret: ndarray_float_2d, + vars: var_map.clone(), + })), + var_id: Default::default(), + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, fun, args, generator| { + gen_ndarray_eye(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }, + )))), + loc: None, + })), + create_fn_by_codegen( + primitives, + &var_map, + "np_identity", + ndarray_float_2d, + &[(int32, "n")], + Box::new(|ctx, obj, fun, args, generator| { + gen_ndarray_identity(ctx, obj, fun, args, generator) + .map(|val| Some(val.as_basic_value_enum())) + }), + ), create_fn_by_codegen( primitives, &var_map, @@ -1364,7 +1474,22 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { Some(ctx.builder.build_int_truncate(len, int32, "len2i32").into()) } } - TypeEnum::TNDArray { .. } => todo!(), + TypeEnum::TNDArray { .. 
} => { + let llvm_i32 = ctx.ctx.i32_type(); + let i32_zero = llvm_i32.const_zero(); + + let len = ctx.build_gep_and_load( + arg.into_pointer_value(), + &[i32_zero, i32_zero], + None, + ).into_int_value(); + + if len.get_type().get_bit_width() != 32 { + Some(ctx.builder.build_int_truncate(len, llvm_i32, "len").into()) + } else { + Some(len.into()) + } + } _ => unreachable!(), } }) diff --git a/nac3core/src/toplevel/numpy.rs b/nac3core/src/toplevel/numpy.rs index 9b91e82..13bb8a5 100644 --- a/nac3core/src/toplevel/numpy.rs +++ b/nac3core/src/toplevel/numpy.rs @@ -1,14 +1,15 @@ -use inkwell::{ - IntPredicate, - types::BasicType, - values::PointerValue, -}; +use inkwell::{AddressSpace, IntPredicate, types::BasicType, values::{BasicValueEnum, PointerValue}}; +use inkwell::values::{ArrayValue, IntValue}; use nac3parser::ast::StrRef; use crate::{ codegen::{ CodeGenContext, CodeGenerator, - irrt::{call_ndarray_calc_size, call_ndarray_init_dims}, + irrt::{ + call_ndarray_calc_nd_indices, + call_ndarray_calc_size, + call_ndarray_init_dims, + }, stmt::gen_for_callback }, symbol_resolver::ValueEnum, @@ -16,16 +17,201 @@ use crate::{ typecheck::typedef::{FunSignature, Type, TypeEnum}, }; -/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. +/// Creates an `NDArray` instance from a constant shape. /// -/// * `elem_ty` - The element type of the NDArray. -/// * `var_name` - The variable name of the NDArray. -/// * `shape` - The `shape` parameter used to construct the NDArray. -fn call_ndarray_impl<'ctx, 'a>( +/// * `elem_ty` - The element type of the `NDArray`. +/// * `shape` - The shape of the `NDArray`, represented as an LLVM [ArrayValue]. 
+fn create_ndarray_const_shape<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + shape: ArrayValue<'ctx> +) -> Result, String> { + let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); + let ndarray_ty = ctx.unifier.add_ty(ndarray_ty_enum); + + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let llvm_pndarray_t = ctx.get_llvm_type(generator, ndarray_ty).into_pointer_type(); + let llvm_ndarray_t = llvm_pndarray_t.get_element_type().into_struct_type(); + let llvm_ndarray_data_t = ctx.get_llvm_type(generator, elem_ty).as_basic_type_enum(); + assert!(llvm_ndarray_data_t.is_sized()); + + for i in 0..shape.get_type().len() { + let shape_dim = ctx.builder.build_extract_value( + shape, + i, + "", + ).unwrap(); + + let shape_dim_gez = ctx.builder.build_int_compare( + IntPredicate::SGE, + shape_dim.into_int_value(), + llvm_usize.const_zero(), + "" + ); + + ctx.make_assert( + generator, + shape_dim_gez, + "0:ValueError", + "negative dimensions not supported", + [None, None, None], + ctx.current_loc, + ); + } + + let ndarray = generator.gen_var_alloc( + ctx, + llvm_ndarray_t.into(), + None, + )?; + + let num_dims = llvm_usize.const_int(shape.get_type().len() as u64, false); + + let ndarray_num_dims = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "", + ) + }; + ctx.builder.build_store(ndarray_num_dims, num_dims); + + let ndarray_dims = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "", + ) + }; + + let ndarray_num_dims = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + None, + ).into_int_value(); + + ctx.builder.build_store( + ndarray_dims, + ctx.builder.build_array_alloca( + llvm_usize, + ndarray_num_dims, + "", + ), + ); + + for i in 0..shape.get_type().len() { + let 
ndarray_dim = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + None, + ).into_pointer_value(); + let ndarray_dim = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray_dim, + &[llvm_i32.const_int(i as u64, true)], + "", + ) + }; + let shape_dim = ctx.builder.build_extract_value(shape, i, "") + .map(|val| val.into_int_value()) + .unwrap(); + + ctx.builder.build_store(ndarray_dim, shape_dim); + } + + let (ndarray_num_dims, ndarray_dims) = unsafe { + ( + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "" + ), + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "" + ), + ) + }; + let ndarray_num_elems = call_ndarray_calc_size( + generator, + ctx, + ctx.builder.build_load(ndarray_num_dims, "").into_int_value(), + ctx.builder.build_load(ndarray_dims, "").into_pointer_value(), + ); + + let ndarray_data = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], + "", + ) + }; + ctx.builder.build_store( + ndarray_data, + ctx.builder.build_array_alloca( + llvm_ndarray_data_t, + ndarray_num_elems, + "" + ), + ); + + Ok(ndarray) +} + +fn ndarray_zero_value<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + ctx.ctx.i32_type().const_zero().into() + } else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + ctx.ctx.i64_type().const_zero().into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { + ctx.ctx.f64_type().const_zero().into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { + ctx.ctx.bool_type().const_zero().into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { + ctx.gen_string(generator, 
"").into() + } else { + unreachable!() + } +} + +fn ndarray_one_value<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, +) -> BasicValueEnum<'ctx> { + if [ctx.primitives.int32, ctx.primitives.uint32].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int32); + ctx.ctx.i32_type().const_int(1, is_signed).into() + } else if [ctx.primitives.int64, ctx.primitives.uint64].iter().any(|ty| ctx.unifier.unioned(elem_ty, *ty)) { + let is_signed = ctx.unifier.unioned(elem_ty, ctx.primitives.int64); + ctx.ctx.i64_type().const_int(1, is_signed).into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.float) { + ctx.ctx.f64_type().const_float(1.0).into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.bool) { + ctx.ctx.bool_type().const_int(1, false).into() + } else if ctx.unifier.unioned(elem_ty, ctx.primitives.str) { + ctx.gen_string(generator, "1").into() + } else { + unreachable!() + } +} + +/// LLVM-typed implementation for generating the implementation for constructing an `NDArray`. +/// +/// * `elem_ty` - The element type of the NDArray. +/// * `shape` - The `shape` parameter used to construct the NDArray. 
+fn call_ndarray_empty_impl<'ctx, 'a>( generator: &mut dyn CodeGenerator, ctx: &mut CodeGenContext<'ctx, 'a>, elem_ty: Type, - var_name: Option<&str>, shape: PointerValue<'ctx>, ) -> Result, String> { let ndarray_ty_enum = TypeEnum::ndarray(&mut ctx.unifier, Some(elem_ty), None, &ctx.primitives); @@ -43,8 +229,8 @@ fn call_ndarray_impl<'ctx, 'a>( gen_for_callback( generator, ctx, - |_, ctx| { - let i = ctx.builder.build_alloca(llvm_usize, ""); + |generator, ctx| { + let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; ctx.builder.build_store(i, llvm_usize.const_zero()); Ok(i) @@ -106,10 +292,11 @@ fn call_ndarray_impl<'ctx, 'a>( }, )?; - let ndarray = ctx.builder.build_alloca( - llvm_ndarray_t, - var_name.unwrap_or_default() - ); + let ndarray = generator.gen_var_alloc( + ctx, + llvm_ndarray_t.into(), + None, + )?; let num_dims = ctx.build_gep_and_load( shape, @@ -151,7 +338,26 @@ fn call_ndarray_impl<'ctx, 'a>( call_ndarray_init_dims(generator, ctx, ndarray, shape); - let ndarray_num_elems = call_ndarray_calc_size(generator, ctx, shape); + let (ndarray_num_dims, ndarray_dims) = unsafe { + ( + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "" + ), + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "" + ), + ) + }; + let ndarray_num_elems = call_ndarray_calc_size( + generator, + ctx, + ctx.builder.build_load(ndarray_num_dims, "").into_int_value(), + ctx.builder.build_load(ndarray_dims, "").into_pointer_value(), + ); let ndarray_data = unsafe { ctx.builder.build_in_bounds_gep( @@ -172,6 +378,342 @@ fn call_ndarray_impl<'ctx, 'a>( Ok(ndarray) } +/// Generates LLVM IR for populating the entire `NDArray` using a lambda with its flattened index as +/// its input. +/// +/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements +/// with the given value (as opposed to all elements within the array). 
+fn ndarray_fill_flattened<'ctx, 'a, ValueFn>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarray: PointerValue<'ctx>, + value_fn: ValueFn, +) -> Result<(), String> + where + ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, IntValue<'ctx>) -> Result, String>, +{ + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + + let (num_dims, dims) = unsafe { + ( + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_zero()], + "" + ), + ctx.builder.build_in_bounds_gep( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(1, true)], + "" + ), + ) + }; + + let ndarray_num_elems = call_ndarray_calc_size( + generator, + ctx, + ctx.builder.build_load(num_dims, "").into_int_value(), + ctx.builder.build_load(dims, "").into_pointer_value(), + ); + + gen_for_callback( + generator, + ctx, + |generator, ctx| { + let i = generator.gen_var_alloc(ctx, llvm_usize.into(), None)?; + ctx.builder.build_store(i, llvm_usize.const_zero()); + + Ok(i) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + + Ok(ctx.builder.build_int_compare(IntPredicate::ULT, i, ndarray_num_elems, "")) + }, + |generator, ctx, i_addr| { + let ndarray_data = ctx.build_gep_and_load( + ndarray, + &[llvm_i32.const_zero(), llvm_i32.const_int(2, true)], + None + ).into_pointer_value(); + + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let elem = unsafe { + ctx.builder.build_in_bounds_gep( + ndarray_data, + &[i], + "" + ) + }; + + let value = value_fn(generator, ctx, i)?; + ctx.builder.build_store(elem, value); + + Ok(()) + }, + |_, ctx, i_addr| { + let i = ctx.builder + .build_load(i_addr, "") + .into_int_value(); + let i = ctx.builder.build_int_add(i, llvm_usize.const_int(1, true), ""); + ctx.builder.build_store(i_addr, i); + + Ok(()) + }, + ) +} + +/// Generates LLVM IR for populating the entire `NDArray` using a lambda with the 
dimension-indices +/// as its input +/// +/// Note that this differs from `ndarray.fill`, which instead replaces all first-dimension elements +/// with the given value (as opposed to all elements within the array). +fn ndarray_fill_indexed<'ctx, 'a, ValueFn>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + ndarray: PointerValue<'ctx>, + value_fn: ValueFn, +) -> Result<(), String> + where + ValueFn: Fn(&mut dyn CodeGenerator, &mut CodeGenContext<'ctx, 'a>, PointerValue<'ctx>) -> Result, String>, +{ + ndarray_fill_flattened( + generator, + ctx, + ndarray, + |generator, ctx, idx| { + let indices = call_ndarray_calc_nd_indices( + generator, + ctx, + idx, + ndarray, + )?; + + value_fn(generator, ctx, indices) + } + ) +} + +/// LLVM-typed implementation for generating the implementation for `ndarray.zeros`. +/// +/// * `elem_ty` - The element type of the NDArray. +/// * `shape` - The `shape` parameter used to construct the NDArray. +fn call_ndarray_zeros_impl<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + shape: PointerValue<'ctx>, +) -> Result, String> { + let supported_types = [ + ctx.primitives.int32, + ctx.primitives.int64, + ctx.primitives.uint32, + ctx.primitives.uint64, + ctx.primitives.float, + ctx.primitives.bool, + ctx.primitives.str, + ]; + assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); + + let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; + ndarray_fill_flattened( + generator, + ctx, + ndarray, + |generator, ctx, _| { + let value = ndarray_zero_value(generator, ctx, elem_ty); + + Ok(value) + } + )?; + + Ok(ndarray) +} + +/// LLVM-typed implementation for generating the implementation for `ndarray.ones`. +/// +/// * `elem_ty` - The element type of the NDArray. +/// * `shape` - The `shape` parameter used to construct the NDArray. 
+fn call_ndarray_ones_impl<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + shape: PointerValue<'ctx>, +) -> Result, String> { + let supported_types = [ + ctx.primitives.int32, + ctx.primitives.int64, + ctx.primitives.uint32, + ctx.primitives.uint64, + ctx.primitives.float, + ctx.primitives.bool, + ctx.primitives.str, + ]; + assert!(supported_types.iter().any(|supported_ty| ctx.unifier.unioned(*supported_ty, elem_ty))); + + let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; + ndarray_fill_flattened( + generator, + ctx, + ndarray, + |generator, ctx, _| { + let value = ndarray_one_value(generator, ctx, elem_ty); + + Ok(value) + } + )?; + + Ok(ndarray) +} + +/// LLVM-typed implementation for generating the implementation for `ndarray.full`. +/// +/// * `elem_ty` - The element type of the NDArray. +/// * `shape` - The `shape` parameter used to construct the NDArray. +fn call_ndarray_full_impl<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + shape: PointerValue<'ctx>, + fill_value: BasicValueEnum<'ctx>, +) -> Result, String> { + let ndarray = call_ndarray_empty_impl(generator, ctx, elem_ty, shape)?; + ndarray_fill_flattened( + generator, + ctx, + ndarray, + |generator, ctx, _| { + let value = if fill_value.is_pointer_value() { + let llvm_void = ctx.ctx.void_type(); + let llvm_i1 = ctx.ctx.bool_type(); + let llvm_i8 = ctx.ctx.i8_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_pi8 = llvm_i8.ptr_type(AddressSpace::default()); + + let copy = generator.gen_var_alloc(ctx, fill_value.get_type(), None)?; + + let memcpy_fn_name = format!( + "llvm.memcpy.p0i8.p0i8.i{}", + generator.get_size_type(ctx.ctx).get_bit_width(), + ); + let memcpy_fn = ctx.module.get_function(memcpy_fn_name.as_str()).unwrap_or_else(|| { + let fn_type = llvm_void.fn_type( + &[ + llvm_pi8.into(), + llvm_pi8.into(), + llvm_usize.into(), + llvm_i1.into(), 
+ ], + false, + ); + + ctx.module.add_function(memcpy_fn_name.as_str(), fn_type, None) + }); + + ctx.builder.build_call( + memcpy_fn, + &[ + copy.into(), + fill_value.into(), + fill_value.get_type().size_of().unwrap().into(), + llvm_i1.const_zero().into(), + ], + "", + ); + + copy.into() + } else if fill_value.is_int_value() || fill_value.is_float_value() { + fill_value.into() + } else { + unreachable!() + }; + + Ok(value) + } + )?; + + Ok(ndarray) +} + +/// LLVM-typed implementation for generating the implementation for `ndarray.eye`. +/// +/// * `elem_ty` - The element type of the NDArray. +fn call_ndarray_eye_impl<'ctx, 'a>( + generator: &mut dyn CodeGenerator, + ctx: &mut CodeGenContext<'ctx, 'a>, + elem_ty: Type, + nrows: IntValue<'ctx>, + ncols: IntValue<'ctx>, + offset: IntValue<'ctx>, +) -> Result, String> { + let llvm_i32 = ctx.ctx.i32_type(); + let llvm_usize = generator.get_size_type(ctx.ctx); + let llvm_usize_2 = llvm_usize.array_type(2); + + let shape_addr = generator.gen_var_alloc(ctx, llvm_usize_2.into(), None)?; + + let shape = ctx.builder.build_load(shape_addr, "") + .into_array_value(); + + let nrows = ctx.builder.build_int_z_extend_or_bit_cast(nrows, llvm_usize, ""); + let shape = ctx.builder + .build_insert_value(shape, nrows, 0, "") + .map(|val| val.into_array_value()) + .unwrap(); + + let ncols = ctx.builder.build_int_z_extend_or_bit_cast(ncols, llvm_usize, ""); + let shape = ctx.builder + .build_insert_value(shape, ncols, 1, "") + .map(|val| val.into_array_value()) + .unwrap(); + + let ndarray = create_ndarray_const_shape(generator, ctx, elem_ty, shape)?; + + ndarray_fill_indexed( + generator, + ctx, + ndarray, + |generator, ctx, indices| { + let row = ctx.build_gep_and_load( + indices, + &[llvm_i32.const_zero()], + None, + ).into_int_value(); + let col = ctx.build_gep_and_load( + indices, + &[llvm_i32.const_int(1, true)], + None, + ).into_int_value(); + + let col_with_offset = ctx.builder.build_int_add( + col, + 
ctx.builder.build_int_z_extend_or_bit_cast(offset, llvm_usize, ""), + "" + ); + let is_on_diag = ctx.builder.build_int_compare( + IntPredicate::EQ, + row, + col_with_offset, + "" + ); + + let zero = ndarray_zero_value(generator, ctx, elem_ty); + let one = ndarray_one_value(generator, ctx, elem_ty); + + let value = ctx.builder.build_select(is_on_diag, one, zero, ""); + + Ok(value) + }, + )?; + + Ok(ndarray) +} + /// Generates LLVM IR for `ndarray.empty`. pub fn gen_ndarray_empty<'ctx, 'a>( context: &mut CodeGenContext<'ctx, 'a>, @@ -184,15 +726,158 @@ pub fn gen_ndarray_empty<'ctx, 'a>( assert_eq!(args.len(), 1); let shape_ty = fun.0.args[0].ty; - let shape_arg_name = args[0].0; let shape_arg = args[0].1.clone() .to_basic_value_enum(context, generator, shape_ty)?; - call_ndarray_impl( + call_ndarray_empty_impl( generator, context, context.primitives.float, - shape_arg_name.map(|name| name.to_string()).as_deref(), shape_arg.into_pointer_value(), ) +} + +/// Generates LLVM IR for `ndarray.zeros`. +pub fn gen_ndarray_zeros<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let shape_ty = fun.0.args[0].ty; + let shape_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, shape_ty)?; + + call_ndarray_zeros_impl( + generator, + context, + context.primitives.float, + shape_arg.into_pointer_value(), + ) +} + +/// Generates LLVM IR for `ndarray.ones`. 
+pub fn gen_ndarray_ones<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let shape_ty = fun.0.args[0].ty; + let shape_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, shape_ty)?; + + call_ndarray_ones_impl( + generator, + context, + context.primitives.float, + shape_arg.into_pointer_value(), + ) +} + +/// Generates LLVM IR for `ndarray.full`. +pub fn gen_ndarray_full<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 2); + + let shape_ty = fun.0.args[0].ty; + let shape_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, shape_ty)?; + let fill_value_ty = fun.0.args[1].ty; + let fill_value_arg = args[1].1.clone() + .to_basic_value_enum(context, generator, fill_value_ty)?; + + call_ndarray_full_impl( + generator, + context, + fill_value_ty, + shape_arg.into_pointer_value(), + fill_value_arg, + ) +} + +/// Generates LLVM IR for `ndarray.eye`. 
+pub fn gen_ndarray_eye<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert!(matches!(args.len(), 1..=3)); + + let nrows_ty = fun.0.args[0].ty; + let nrows_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, nrows_ty)?; + + let ncols_ty = fun.0.args[1].ty; + let ncols_arg = args.iter() + .find(|arg| arg.0.map(|name| name == fun.0.args[1].name).unwrap_or(false)) + .map(|arg| arg.1.clone().to_basic_value_enum(context, generator, ncols_ty)) + .unwrap_or_else(|| { + args[0].1.clone().to_basic_value_enum(context, generator, nrows_ty) + })?; + + let offset_ty = fun.0.args[2].ty; + let offset_arg = args.iter() + .find(|arg| arg.0.map(|name| name == fun.0.args[2].name).unwrap_or(false)) + .map(|arg| arg.1.clone().to_basic_value_enum(context, generator, offset_ty)) + .unwrap_or_else(|| { + Ok(context.gen_symbol_val( + generator, + fun.0.args[2].default_value.as_ref().unwrap(), + offset_ty + )) + })?; + + call_ndarray_eye_impl( + generator, + context, + context.primitives.float, + nrows_arg.into_int_value(), + ncols_arg.into_int_value(), + offset_arg.into_int_value(), + ) +} + +/// Generates LLVM IR for `ndarray.identity`. 
+pub fn gen_ndarray_identity<'ctx, 'a>( + context: &mut CodeGenContext<'ctx, 'a>, + obj: Option<(Type, ValueEnum<'ctx>)>, + fun: (&FunSignature, DefinitionId), + args: Vec<(Option, ValueEnum<'ctx>)>, + generator: &mut dyn CodeGenerator, +) -> Result, String> { + assert!(obj.is_none()); + assert_eq!(args.len(), 1); + + let llvm_usize = generator.get_size_type(context.ctx); + + let n_ty = fun.0.args[0].ty; + let n_arg = args[0].1.clone() + .to_basic_value_enum(context, generator, n_ty)?; + + call_ndarray_eye_impl( + generator, + context, + context.primitives.float, + n_arg.into_int_value(), + n_arg.into_int_value(), + llvm_usize.const_zero(), + ) } \ No newline at end of file diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index cfe000f..1462ca6 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -898,9 +898,14 @@ impl<'a> Inferencer<'a> { if [ "np_ndarray".into(), "np_empty".into(), + "np_zeros".into(), + "np_ones".into(), ].contains(id) && args.len() == 1 { let ExprKind::List { elts, .. } = &args[0].node else { - return report_error("Expected List literal for first argument of np_ndarray", args[0].location) + return report_error( + format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(), + args[0].location + ) }; let ndims = elts.len() as u64; @@ -941,6 +946,62 @@ impl<'a> Inferencer<'a> { })) } + // 2-argument ndarray n-dimensional creation functions + if id == &"np_full".into() && args.len() == 2 { + let ExprKind::List { elts, .. 
} = &args[0].node else { + return report_error( + format!("Expected List literal for first argument of {id}, got {}", args[0].node.name()).as_str(), + args[0].location + ) + }; + + let ndims = elts.len() as u64; + + let arg0 = self.fold_expr(args.remove(0))?; + let arg1 = self.fold_expr(args.remove(0))?; + + let ty = arg1.custom.unwrap(); + let ndims = self.unifier.get_fresh_literal( + vec![SymbolValue::U64(ndims)], + None, + ); + + let ret = self.unifier.add_ty(TypeEnum::TNDArray { + ty, + ndims + }); + let custom = self.unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![ + FuncArg { + name: "shape".into(), + ty: arg0.custom.unwrap(), + default_value: None, + }, + FuncArg { + name: "fill_value".into(), + ty: arg1.custom.unwrap(), + default_value: None, + }, + ], + ret, + vars: HashMap::new(), + })); + + return Ok(Some(Located { + location, + custom: Some(ret), + node: ExprKind::Call { + func: Box::new(Located { + custom: Some(custom), + location: func.location, + node: ExprKind::Name { id: *id, ctx: ctx.clone() }, + }), + args: vec![arg0, arg1], + keywords: vec![], + }, + })) + } + Ok(None) } diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index abdeda9..03deff4 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -187,6 +187,11 @@ def patch(module): # NumPy NDArray Functions module.np_ndarray = np.ndarray module.np_empty = np.empty + module.np_zeros = np.zeros + module.np_ones = np.ones + module.np_full = np.full + module.np_eye = np.eye + module.np_identity = np.identity def file_import(filename, prefix="file_import_"): filename = pathlib.Path(filename) diff --git a/nac3standalone/demo/src/ndarray.py b/nac3standalone/demo/src/ndarray.py index 1237d06..1ab153b 100644 --- a/nac3standalone/demo/src/ndarray.py +++ b/nac3standalone/demo/src/ndarray.py @@ -7,6 +7,12 @@ def consume_ndarray_i32_1(n: ndarray[int32, Literal[1]]): def consume_ndarray_2(n: ndarray[float, 
Literal[2]]): pass +def consume_ndarray_i32_1(n: ndarray[int32, 1]): + pass + +def consume_ndarray_2(n: ndarray[float, 2]): + pass + def test_ndarray_ctor(): n = np_ndarray([1]) consume_ndarray_1(n) @@ -15,8 +21,35 @@ def test_ndarray_empty(): n = np_empty([1]) consume_ndarray_1(n) +def test_ndarray_zeros(): + n = np_zeros([1]) + consume_ndarray_1(n) + +def test_ndarray_ones(): + n = np_ones([1]) + consume_ndarray_1(n) + +def test_ndarray_full(): + n_float = np_full([1], 2.0) + consume_ndarray_1(n_float) + n_i32 = np_full([1], 2) + consume_ndarray_i32_1(n_i32) + +def test_ndarray_eye(): + n = np_eye(2) + consume_ndarray_2(n) + +def test_ndarray_identity(): + n = np_identity(2) + consume_ndarray_2(n) + def run() -> int32: test_ndarray_ctor() test_ndarray_empty() + test_ndarray_zeros() + test_ndarray_ones() + test_ndarray_full() + test_ndarray_eye() + test_ndarray_identity() return 0 -- 2.44.1