diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index ac37dcdd..6f03f0d2 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -33,7 +33,7 @@ impl<'ctx> CodeGenContext<'ctx> { .join(", ") } - fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize { + pub fn get_attr_index(&mut self, ty: Type, attr: &str) -> usize { let obj_id = match &*self.unifier.get_ty(ty) { TypeEnum::TObj { obj_id, .. } => *obj_id, // we cannot have other types, virtual type should be handled by function calls @@ -48,7 +48,7 @@ impl<'ctx> CodeGenContext<'ctx> { index } - fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> { + pub fn get_llvm_type(&mut self, ty: Type) -> BasicTypeEnum<'ctx> { use TypeEnum::*; // we assume the type cache should already contain primitive types, // and they should be passed by value instead of passing as pointer. @@ -275,7 +275,15 @@ impl<'ctx> CodeGenContext<'ctx> { } ExprKind::Name { id, .. } => { let ptr = self.var_assignment.get(id).unwrap(); - self.builder.build_load(*ptr, "load") + let primitives = &self.primitives; + // we should only dereference primitive types + if [primitives.int32, primitives.int64, primitives.float, primitives.bool] + .contains(&self.unifier.get_representative(expr.custom.unwrap())) + { + self.builder.build_load(*ptr, "load") + } else { + (*ptr).into() + } } ExprKind::List { elts, .. } => { // this shall be optimized later for constant primitive lists... @@ -472,7 +480,7 @@ impl<'ctx> CodeGenContext<'ctx> { ops.iter(), ) .fold(None, |prev, (lhs, rhs, op)| { - let ty = lhs.custom.unwrap(); + let ty = self.unifier.get_representative(lhs.custom.unwrap()); let current = if [self.primitives.int32, self.primitives.int64, self.primitives.bool] .contains(&ty) diff --git a/nac3core/src/codegen/helper.rs b/nac3core/src/codegen/helper.rs deleted file mode 100644 index 8b137891..00000000 --- a/nac3core/src/codegen/helper.rs +++ /dev/null @@ -1 +0,0 @@ - diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index 89a10d8c..95ee4bbf 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -1,2 +1,2 @@ mod expr; -mod helper; +mod stmt; diff --git a/nac3core/src/codegen/stmt.rs b/nac3core/src/codegen/stmt.rs new file mode 100644 index 00000000..1fa65060 --- /dev/null +++ b/nac3core/src/codegen/stmt.rs @@ -0,0 +1,88 @@ +use std::convert::TryInto; + +use crate::{top_level::CodeGenContext, typecheck::typedef::Type}; +use inkwell::{ + types::BasicTypeEnum, + values::{BasicValueEnum, PointerValue}, +}; +use rustpython_parser::ast::{Expr, ExprKind, Stmt, StmtKind}; + +impl<'ctx> CodeGenContext<'ctx> { + fn gen_var(&mut self, ty: Type) -> PointerValue<'ctx> { + let ty = self.get_llvm_type(ty); + let ty = if let BasicTypeEnum::PointerType(ty) = ty { + ty.get_element_type().try_into().unwrap() + } else { + ty + }; + self.builder.build_alloca(ty, "tmp") + } + + fn parse_pattern(&mut self, pattern: &Expr>) -> PointerValue<'ctx> { + // very similar to gen_expr, but we don't do an extra load at the end + // and we flatten nested tuples + match &pattern.node { + ExprKind::Name { id, .. } => { + self.var_assignment.get(id).cloned().unwrap_or_else(|| { + let ptr = self.gen_var(pattern.custom.unwrap()); + self.var_assignment.insert(id.clone(), ptr); + ptr + }) + } + ExprKind::Attribute { value, attr, .. } => { + let index = self.get_attr_index(value.custom.unwrap(), attr); + let val = self.gen_expr(value); + let ptr = if let BasicValueEnum::PointerValue(v) = val { + v + } else { + unreachable!(); + }; + unsafe { + ptr.const_in_bounds_gep(&[ + self.ctx.i32_type().const_zero(), + self.ctx.i32_type().const_int(index as u64, false), + ]) + } + } + ExprKind::Subscript { .. } => unimplemented!(), + _ => unreachable!(), + } + } + + fn gen_assignment(&mut self, target: &Expr>, value: BasicValueEnum<'ctx>) { + if let ExprKind::Tuple { elts, .. } = &target.node { + if let BasicValueEnum::PointerValue(ptr) = value { + for (i, elt) in elts.iter().enumerate() { + unsafe { + let t = ptr.const_in_bounds_gep(&[ + self.ctx.i32_type().const_zero(), + self.ctx.i32_type().const_int(i as u64, false), + ]); + let v = self.builder.build_load(t, "tmpload"); + self.gen_assignment(elt, v); + } + } + } else { + unreachable!() + } + } else { + let ptr = self.parse_pattern(target); + self.builder.build_store(ptr, value); + } + } + + pub fn gen_stmt(&mut self, stmt: &Stmt>) { + match &stmt.node { + StmtKind::Expr { value } => { + self.gen_expr(&value); + } + StmtKind::Assign { targets, value, .. } => { + let value = self.gen_expr(&value); + for target in targets.iter() { + self.gen_assignment(target, value); + } + } + _ => unimplemented!(), + } + } +}