From 79c469301aba7f18e12ef0e831a009c1e57eb7a5 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Sat, 5 Mar 2022 03:45:09 +0800 Subject: [PATCH] basic unsigned integer support --- nac3artiq/src/lib.rs | 4 ++ nac3artiq/src/symbol_resolver.rs | 18 +++++ nac3core/src/codegen/concrete_type.rs | 10 +++ nac3core/src/codegen/expr.rs | 52 +++++++++----- nac3core/src/codegen/irrt/irrt.c | 2 + nac3core/src/codegen/irrt/mod.rs | 9 ++- nac3core/src/codegen/mod.rs | 4 ++ nac3core/src/symbol_resolver.rs | 14 +++- nac3core/src/toplevel/builtins.rs | 68 +++++++++++++++++-- nac3core/src/toplevel/composer.rs | 2 + nac3core/src/toplevel/helper.rs | 68 ++++++++++++++++++- nac3core/src/toplevel/type_annotation.rs | 4 ++ nac3core/src/typecheck/magic_methods.rs | 55 +++++++-------- nac3core/src/typecheck/type_inferencer/mod.rs | 56 ++++++++++++++- 14 files changed, 306 insertions(+), 60 deletions(-) diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index c4e8fc77..ffc0058b 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -59,6 +59,8 @@ pub struct PrimitivePythonId { int: u64, int32: u64, int64: u64, + uint32: u64, + uint64: u64, float: u64, bool: u64, list: u64, @@ -362,6 +364,8 @@ impl Nac3 { int: id_fn.call1((builtins_mod.getattr("int").unwrap(),)).unwrap().extract().unwrap(), int32: id_fn.call1((numpy_mod.getattr("int32").unwrap(),)).unwrap().extract().unwrap(), int64: id_fn.call1((numpy_mod.getattr("int64").unwrap(),)).unwrap().extract().unwrap(), + uint32: id_fn.call1((numpy_mod.getattr("uint32").unwrap(),)).unwrap().extract().unwrap(), + uint64: id_fn.call1((numpy_mod.getattr("uint64").unwrap(),)).unwrap().extract().unwrap(), bool: id_fn.call1((builtins_mod.getattr("bool").unwrap(),)).unwrap().extract().unwrap(), float: id_fn .call1((builtins_mod.getattr("float").unwrap(),)) diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index d59b0fac..152b5975 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -24,6 +24,8 @@ use crate::PrimitivePythonId; pub enum PrimitiveValue { I32(i32), I64(i64), + U32(u32), + U64(u64), F64(f64), Bool(bool), } @@ -115,6 +117,8 @@ impl StaticValue for PythonValue { return Ok(match val { PrimitiveValue::I32(val) => ctx.ctx.i32_type().const_int(*val as u64, false).into(), PrimitiveValue::I64(val) => ctx.ctx.i64_type().const_int(*val as u64, false).into(), + PrimitiveValue::U32(val) => ctx.ctx.i32_type().const_int(*val as u64, false).into(), + PrimitiveValue::U64(val) => ctx.ctx.i64_type().const_int(*val as u64, false).into(), PrimitiveValue::F64(val) => ctx.ctx.f64_type().const_float(*val).into(), PrimitiveValue::Bool(val) => { ctx.ctx.bool_type().const_int(*val as u64, false).into() @@ -238,6 +242,10 @@ impl InnerResolver { Ok(Ok((primitives.int32, true))) } else if ty_id == self.primitive_ids.int64 { Ok(Ok((primitives.int64, true))) + } else if ty_id == self.primitive_ids.uint32 { + Ok(Ok((primitives.uint32, true))) + } else if ty_id == self.primitive_ids.uint64 { + Ok(Ok((primitives.uint64, true))) } else if ty_id == self.primitive_ids.bool { Ok(Ok((primitives.bool, true))) } else if ty_id == self.primitive_ids.float { @@ -615,6 +623,16 @@ impl InnerResolver { format!("{} is not in the range of int64", obj)))?; self.id_to_primitive.write().insert(id, PrimitiveValue::I64(val)); Ok(Some(ctx.ctx.i64_type().const_int(val as u64, false).into())) + } else if ty_id == self.primitive_ids.uint32 { + let val: u32 = obj.extract().map_err(|_| super::CompileError::new_err( + format!("{} is not in the range of uint32", obj)))?; + self.id_to_primitive.write().insert(id, PrimitiveValue::U32(val)); + Ok(Some(ctx.ctx.i32_type().const_int(val as u64, false).into())) + } else if ty_id == self.primitive_ids.uint64 { + let val: u64 = obj.extract().map_err(|_| super::CompileError::new_err( + format!("{} is not in the range of uint64", obj)))?; + self.id_to_primitive.write().insert(id, PrimitiveValue::U64(val)); + Ok(Some(ctx.ctx.i64_type().const_int(val, false).into())) } else if ty_id == self.primitive_ids.bool { let val: bool = obj.extract().map_err(|_| super::CompileError::new_err( format!("{} is not in the range of bool", obj)))?; diff --git a/nac3core/src/codegen/concrete_type.rs b/nac3core/src/codegen/concrete_type.rs index bd7573ce..03a99974 100644 --- a/nac3core/src/codegen/concrete_type.rs +++ b/nac3core/src/codegen/concrete_type.rs @@ -28,6 +28,8 @@ pub struct ConcreteFuncArg { pub enum Primitive { Int32, Int64, + UInt32, + UInt64, Float, Bool, None, @@ -72,6 +74,8 @@ impl ConcreteTypeStore { ConcreteTypeEnum::TPrimitive(Primitive::Range), ConcreteTypeEnum::TPrimitive(Primitive::Str), ConcreteTypeEnum::TPrimitive(Primitive::Exception), + ConcreteTypeEnum::TPrimitive(Primitive::UInt32), + ConcreteTypeEnum::TPrimitive(Primitive::UInt64), ], } } @@ -130,6 +134,10 @@ impl ConcreteTypeStore { ConcreteType(6) } else if unifier.unioned(ty, primitives.exception) { ConcreteType(7) + } else if unifier.unioned(ty, primitives.uint32) { + ConcreteType(8) + } else if unifier.unioned(ty, primitives.uint64) { + ConcreteType(9) } else if let Some(cty) = cache.get(&ty) { if let Some(cty) = cty { *cty @@ -223,6 +231,8 @@ impl ConcreteTypeStore { let ty = match primitive { Primitive::Int32 => primitives.int32, Primitive::Int64 => primitives.int64, + Primitive::UInt32 => primitives.uint32, + Primitive::UInt64 => primitives.uint64, Primitive::Float => primitives.float, Primitive::Bool => primitives.bool, Primitive::None => primitives.none, diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 1b896186..3885804e 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -94,6 +94,8 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { match val { SymbolValue::I32(v) => self.ctx.i32_type().const_int(*v as u64, true).into(), SymbolValue::I64(v) => self.ctx.i64_type().const_int(*v as u64, true).into(), + SymbolValue::U32(v) => self.ctx.i32_type().const_int(*v as u64, false).into(), + SymbolValue::U64(v) => self.ctx.i64_type().const_int(*v as u64, false).into(), SymbolValue::Bool(v) => self.ctx.bool_type().const_int(*v as u64, true).into(), SymbolValue::Double(v) => self.ctx.f64_type().const_float(*v).into(), SymbolValue::Str(v) => { @@ -152,9 +154,13 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ty.const_int(if *v { 1 } else { 0 }, false).into() } Constant::Int(Some(val)) => { - let ty = if self.unifier.unioned(ty, self.primitives.int32) { + let ty = if self.unifier.unioned(ty, self.primitives.int32) + || self.unifier.unioned(ty, self.primitives.uint32) + { self.ctx.i32_type() - } else if self.unifier.unioned(ty, self.primitives.int64) { + } else if self.unifier.unioned(ty, self.primitives.int64) + || self.unifier.unioned(ty, self.primitives.uint64) + { self.ctx.i64_type() } else { unreachable!(); @@ -201,6 +207,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { op: &Operator, lhs: BasicValueEnum<'ctx>, rhs: BasicValueEnum<'ctx>, + signed: bool ) -> BasicValueEnum<'ctx> { let (lhs, rhs) = if let (BasicValueEnum::IntValue(lhs), BasicValueEnum::IntValue(rhs)) = (lhs, rhs) { @@ -208,26 +215,33 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } else { unreachable!() }; - match op { - Operator::Add => self.builder.build_int_add(lhs, rhs, "add").into(), - Operator::Sub => self.builder.build_int_sub(lhs, rhs, "sub").into(), - Operator::Mult => self.builder.build_int_mul(lhs, rhs, "mul").into(), - Operator::Div => { - let float = self.ctx.f64_type(); + let float = self.ctx.f64_type(); + match (op, signed) { + (Operator::Add, _) => self.builder.build_int_add(lhs, rhs, "add").into(), + (Operator::Sub, _) => self.builder.build_int_sub(lhs, rhs, "sub").into(), + (Operator::Mult, _) => self.builder.build_int_mul(lhs, rhs, "mul").into(), + (Operator::Div, true) => { let left = self.builder.build_signed_int_to_float(lhs, float, "i2f"); let right = self.builder.build_signed_int_to_float(rhs, float, "i2f"); self.builder.build_float_div(left, right, "fdiv").into() } - Operator::Mod => self.builder.build_int_signed_rem(lhs, rhs, "mod").into(), - Operator::BitOr => self.builder.build_or(lhs, rhs, "or").into(), - Operator::BitXor => self.builder.build_xor(lhs, rhs, "xor").into(), - Operator::BitAnd => self.builder.build_and(lhs, rhs, "and").into(), - Operator::LShift => self.builder.build_left_shift(lhs, rhs, "lshift").into(), - Operator::RShift => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(), - Operator::FloorDiv => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(), - Operator::Pow => integer_power(self, lhs, rhs).into(), + (Operator::Div, false) => { + let left = self.builder.build_unsigned_int_to_float(lhs, float, "i2f"); + let right = self.builder.build_unsigned_int_to_float(rhs, float, "i2f"); + self.builder.build_float_div(left, right, "fdiv").into() + } + (Operator::Mod, true) => self.builder.build_int_signed_rem(lhs, rhs, "mod").into(), + (Operator::Mod, false) => self.builder.build_int_unsigned_rem(lhs, rhs, "mod").into(), + (Operator::BitOr, _) => self.builder.build_or(lhs, rhs, "or").into(), + (Operator::BitXor, _) => self.builder.build_xor(lhs, rhs, "xor").into(), + (Operator::BitAnd, _) => self.builder.build_and(lhs, rhs, "and").into(), + (Operator::LShift, _) => self.builder.build_left_shift(lhs, rhs, "lshift").into(), + (Operator::RShift, _) => self.builder.build_right_shift(lhs, rhs, true, "rshift").into(), + (Operator::FloorDiv, true) => self.builder.build_int_signed_div(lhs, rhs, "floordiv").into(), + (Operator::FloorDiv, false) => self.builder.build_int_unsigned_div(lhs, rhs, "floordiv").into(), + (Operator::Pow, s) => integer_power(self, lhs, rhs, s).into(), // special implementation? - Operator::MatMult => unreachable!(), + (Operator::MatMult, _) => unreachable!(), } } @@ -807,7 +821,9 @@ pub fn gen_binop_expr<'ctx, 'a, G: CodeGenerator>( // which would be unchanged until further unification, which we would never do // when doing code generation for function instances Ok(if ty1 == ty2 && [ctx.primitives.int32, ctx.primitives.int64].contains(&ty1) { - ctx.gen_int_ops(op, left, right) + ctx.gen_int_ops(op, left, right, true) + } else if ty1 == ty2 && [ctx.primitives.uint32, ctx.primitives.uint64].contains(&ty1) { + ctx.gen_int_ops(op, left, right, false) } else if ty1 == ty2 && ctx.primitives.float == ty1 { ctx.gen_float_ops(op, left, right) } else if ty1 == ctx.primitives.float && ty2 == ctx.primitives.int32 { diff --git a/nac3core/src/codegen/irrt/irrt.c b/nac3core/src/codegen/irrt/irrt.c index fe316992..4c091437 100644 --- a/nac3core/src/codegen/irrt/irrt.c +++ b/nac3core/src/codegen/irrt/irrt.c @@ -26,6 +26,8 @@ typedef unsigned _ExtInt(64) uint64_t; DEF_INT_EXP(int32_t) DEF_INT_EXP(int64_t) +DEF_INT_EXP(uint32_t) +DEF_INT_EXP(uint64_t) int32_t __nac3_slice_index_bound(int32_t i, const int32_t len) { diff --git a/nac3core/src/codegen/irrt/mod.rs b/nac3core/src/codegen/irrt/mod.rs index d972878b..5b55f0ad 100644 --- a/nac3core/src/codegen/irrt/mod.rs +++ b/nac3core/src/codegen/irrt/mod.rs @@ -37,10 +37,13 @@ pub fn integer_power<'ctx, 'a>( ctx: &mut CodeGenContext<'ctx, 'a>, base: IntValue<'ctx>, exp: IntValue<'ctx>, + signed: bool, ) -> IntValue<'ctx> { - let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width()) { - (32, 32) => "__nac3_int_exp_int32_t", - (64, 64) => "__nac3_int_exp_int64_t", + let symbol = match (base.get_type().get_bit_width(), exp.get_type().get_bit_width(), signed) { + (32, 32, true) => "__nac3_int_exp_int32_t", + (64, 64, true) => "__nac3_int_exp_int64_t", + (32, 32, false) => "__nac3_int_exp_uint32_t", + (64, 64, false) => "__nac3_int_exp_uint64_t", _ => unreachable!(), }; let base_type = base.get_type(); diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index a76980fe..ee7d3a0d 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -360,6 +360,8 @@ pub fn gen_func<'ctx, G: CodeGenerator>( let primitives = PrimitiveStore { int32: unifier.get_representative(primitives.int32), int64: unifier.get_representative(primitives.int64), + uint32: unifier.get_representative(primitives.uint32), + uint64: unifier.get_representative(primitives.uint64), float: unifier.get_representative(primitives.float), bool: unifier.get_representative(primitives.bool), none: unifier.get_representative(primitives.none), @@ -371,6 +373,8 @@ pub fn gen_func<'ctx, G: CodeGenerator>( let mut type_cache: HashMap<_, _> = [ (primitives.int32, context.i32_type().into()), (primitives.int64, context.i64_type().into()), + (primitives.uint32, context.i32_type().into()), + (primitives.uint64, context.i64_type().into()), (primitives.float, context.f64_type().into()), (primitives.bool, context.bool_type().into()), (primitives.str, { diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 0a22cc06..186453f2 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -23,6 +23,8 @@ use parking_lot::RwLock; pub enum SymbolValue { I32(i32), I64(i64), + U32(u32), + U64(u64), Str(String), Double(f64), Bool(bool), @@ -34,6 +36,8 @@ impl Display for SymbolValue { match self { SymbolValue::I32(i) => write!(f, "{}", i), SymbolValue::I64(i) => write!(f, "int64({})", i), + SymbolValue::U32(i) => write!(f, "uint32({})", i), + SymbolValue::U64(i) => write!(f, "uint64({})", i), SymbolValue::Str(s) => write!(f, "\"{}\"", s), SymbolValue::Double(d) => write!(f, "{}", d), SymbolValue::Bool(b) => { @@ -141,7 +145,7 @@ pub trait SymbolResolver { } thread_local! { - static IDENTIFIER_ID: [StrRef; 10] = [ + static IDENTIFIER_ID: [StrRef; 12] = [ "int32".into(), "int64".into(), "float".into(), @@ -152,6 +156,8 @@ thread_local! { "tuple".into(), "str".into(), "Exception".into(), + "uint32".into(), + "uint64".into(), ]; } @@ -175,12 +181,18 @@ pub fn parse_type_annotation( let tuple_id = ids[7]; let str_id = ids[8]; let exn_id = ids[9]; + let uint32_id = ids[10]; + let uint64_id = ids[11]; let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| { if *id == int32_id { Ok(primitives.int32) } else if *id == int64_id { Ok(primitives.int64) + } else if *id == uint32_id { + Ok(primitives.uint32) + } else if *id == uint64_id { + Ok(primitives.uint64) } else if *id == float_id { Ok(primitives.float) } else if *id == bool_id { diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 0007286c..c29afcd7 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -12,12 +12,14 @@ type BuiltinInfo = (Vec<(Arc>, Option)>, &'static [&'s pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { let int32 = primitives.0.int32; let int64 = primitives.0.int64; + let uint32 = primitives.0.uint32; + let uint64 = primitives.0.uint64; let float = primitives.0.float; let boolean = primitives.0.bool; let range = primitives.0.range; let string = primitives.0.str; let num_ty = primitives.1.get_fresh_var_with_range( - &[int32, int64, float, boolean], + &[int32, int64, float, boolean, uint32, uint64], Some("N".into()), None, ); @@ -35,12 +37,12 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ("__param2__".into(), int64, true), ]; let div_by_zero = primitives.1.add_ty(TypeEnum::TObj { - obj_id: DefinitionId(10), + obj_id: DefinitionId(12), fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(), params: Default::default(), }); let index_error = primitives.1.add_ty(TypeEnum::TObj { - obj_id: DefinitionId(11), + obj_id: DefinitionId(13), fields: exception_fields.iter().map(|(a, b, c)| (*a, (*b, *c))).collect(), params: Default::default(), }); @@ -125,6 +127,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { resolver: None, loc: None, })), + Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( + 8, + None, + "uint32".into(), + None, + None, + ))), + Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( + 9, + None, + "uint64".into(), + None, + None, + ))), Arc::new(RwLock::new(TopLevelDef::Function { name: "ZeroDivisionError.__init__".into(), simple_name: "__init__".into(), @@ -149,7 +165,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { })), Arc::new(RwLock::new(TopLevelDef::Class { name: "ZeroDivisionError".into(), - object_id: DefinitionId(10), + object_id: DefinitionId(12), type_vars: Default::default(), fields: exception_fields.clone(), methods: vec![("__init__".into(), div_by_zero_signature, DefinitionId(8))], @@ -163,7 +179,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { })), Arc::new(RwLock::new(TopLevelDef::Class { name: "IndexError".into(), - object_id: DefinitionId(11), + object_id: DefinitionId(13), type_vars: Default::default(), fields: exception_fields, methods: vec![("__init__".into(), index_error_signature, DefinitionId(9))], @@ -295,6 +311,46 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "uint32".into(), + simple_name: "uint32".into(), + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], + ret: uint32, + vars: var_map.clone(), + })), + var_id: Default::default(), + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + // TODO: + unimplemented!() + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "uint64".into(), + simple_name: "uint64".into(), + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: num_ty.0, default_value: None }], + ret: uint64, + vars: var_map.clone(), + })), + var_id: Default::default(), + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, fun, args, generator| { + // TODO: + unimplemented!() + }, + )))), + loc: None, + })), Arc::new(RwLock::new(TopLevelDef::Function { name: "float".into(), simple_name: "float".into(), @@ -797,6 +853,8 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "IndexError", "int32", "int64", + "uint32", + "uint64", "float", "round", "round64", diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 694d0a06..f6535723 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -62,6 +62,8 @@ impl TopLevelComposer { "tuple".into(), "int32".into(), "int64".into(), + "uint32".into(), + "uint64".into(), "float".into(), "bool".into(), "none".into(), diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 6c131e6f..58ba9fa6 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -97,7 +97,17 @@ impl TopLevelComposer { .collect::>(), params: HashMap::new(), }); - let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception }; + let uint32 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(8), + fields: HashMap::new(), + params: HashMap::new(), + }); + let uint64 = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(9), + fields: HashMap::new(), + params: HashMap::new(), + }); + let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception, uint32, uint64 }; crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); (primitives, unifier) } @@ -399,6 +409,20 @@ impl TopLevelComposer { Some("int64".to_string()) } } + SymbolValue::U32(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.uint32) { + None + } else { + Some("uint32".to_string()) + } + } + SymbolValue::U64(..) => { + if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.uint64) { + None + } else { + Some("uint64".to_string()) + } + } SymbolValue::Str(..) => { if matches!(ty, TypeAnnotation::Primitive(t) if *t == primitive.str) { None @@ -475,6 +499,48 @@ pub fn parse_parameter_default_value( Err(format!("only allow constant integer here at {}", default.location)) } } + ast::ExprKind::Call { func, args, .. } if { + match &func.node { + ast::ExprKind::Name { id, .. } => *id == "uint32".into(), + _ => false, + } + } => { + if args.len() == 1 { + match &args[0].node { + ast::ExprKind::Constant { value: Constant::Int(Some(v)), .. } => { + let v: Result = (*v).try_into(); + match v { + Ok(v) => Ok(SymbolValue::U32(v)), + _ => Err(format!("default param value out of range at {}", default.location)) + } + } + _ => Err(format!("only allow constant integer here at {}", default.location)) + } + } else { + Err(format!("only allow constant integer here at {}", default.location)) + } + } + ast::ExprKind::Call { func, args, .. } if { + match &func.node { + ast::ExprKind::Name { id, .. } => *id == "uint64".into(), + _ => false, + } + } => { + if args.len() == 1 { + match &args[0].node { + ast::ExprKind::Constant { value: Constant::Int(Some(v)), .. } => { + let v: Result = (*v).try_into(); + match v { + Ok(v) => Ok(SymbolValue::U64(v)), + _ => Err(format!("default param value out of range at {}", default.location)) + } + } + _ => Err(format!("only allow constant integer here at {}", default.location)) + } + } else { + Err(format!("only allow constant integer here at {}", default.location)) + } + } ast::ExprKind::Tuple { elts, .. } => Ok(SymbolValue::Tuple(elts .iter() .map(|x| parse_parameter_default_value(x, resolver)) diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 497218eb..9fc1f80c 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -64,6 +64,10 @@ pub fn parse_ast_to_type_annotation_kinds( Ok(TypeAnnotation::Primitive(primitives.int32)) } else if id == &"int64".into() { Ok(TypeAnnotation::Primitive(primitives.int64)) + } else if id == &"uint32".into() { + Ok(TypeAnnotation::Primitive(primitives.uint32)) + } else if id == &"uint64".into() { + Ok(TypeAnnotation::Primitive(primitives.uint64)) } else if id == &"float".into() { Ok(TypeAnnotation::Primitive(primitives.float)) } else if id == &"bool".into() { diff --git a/nac3core/src/typecheck/magic_methods.rs b/nac3core/src/typecheck/magic_methods.rs index 95dd8436..36d46a6d 100644 --- a/nac3core/src/typecheck/magic_methods.rs +++ b/nac3core/src/typecheck/magic_methods.rs @@ -286,35 +286,32 @@ pub fn impl_eq(unifier: &mut Unifier, store: &PrimitiveStore, ty: Type) { } pub fn set_primitives_magic_methods(store: &PrimitiveStore, unifier: &mut Unifier) { - let PrimitiveStore { int32: int32_t, int64: int64_t, float: float_t, bool: bool_t, .. } = - *store; - /* int32 ======== */ - impl_basic_arithmetic(unifier, store, int32_t, &[int32_t], int32_t); - impl_pow(unifier, store, int32_t, &[int32_t], int32_t); - impl_bitwise_arithmetic(unifier, store, int32_t); - impl_bitwise_shift(unifier, store, int32_t); - impl_div(unifier, store, int32_t, &[int32_t]); - impl_floordiv(unifier, store, int32_t, &[int32_t], int32_t); - impl_mod(unifier, store, int32_t, &[int32_t], int32_t); - impl_sign(unifier, store, int32_t); - impl_invert(unifier, store, int32_t); - impl_not(unifier, store, int32_t); - impl_comparison(unifier, store, int32_t, int32_t); - impl_eq(unifier, store, int32_t); - - /* int64 ======== */ - impl_basic_arithmetic(unifier, store, int64_t, &[int64_t], int64_t); - impl_pow(unifier, store, int64_t, &[int64_t], int64_t); - impl_bitwise_arithmetic(unifier, store, int64_t); - impl_bitwise_shift(unifier, store, int64_t); - impl_div(unifier, store, int64_t, &[int64_t]); - impl_floordiv(unifier, store, int64_t, &[int64_t], int64_t); - impl_mod(unifier, store, int64_t, &[int64_t], int64_t); - impl_sign(unifier, store, int64_t); - impl_invert(unifier, store, int64_t); - impl_not(unifier, store, int64_t); - impl_comparison(unifier, store, int64_t, int64_t); - impl_eq(unifier, store, int64_t); + let PrimitiveStore { + int32: int32_t, + int64: int64_t, + float: float_t, + bool: bool_t, + uint32: uint32_t, + uint64: uint64_t, + .. + } = *store; + /* int ======== */ + for t in [int32_t, int64_t, uint32_t, uint64_t] { + impl_basic_arithmetic(unifier, store, t, &[t], t); + impl_pow(unifier, store, t, &[t], t); + impl_bitwise_arithmetic(unifier, store, t); + impl_bitwise_shift(unifier, store, t); + impl_div(unifier, store, t, &[t]); + impl_floordiv(unifier, store, t, &[t], t); + impl_mod(unifier, store, t, &[t], t); + impl_invert(unifier, store, t); + impl_not(unifier, store, t); + impl_comparison(unifier, store, t, t); + impl_eq(unifier, store, t); + } + for t in [int32_t, int64_t] { + impl_sign(unifier, store, t); + } /* float ======== */ impl_basic_arithmetic(unifier, store, float_t, &[float_t], float_t); diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index 2637128c..d3bae33a 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -32,6 +32,8 @@ impl From for CodeLocation { pub struct PrimitiveStore { pub int32: Type, pub int64: Type, + pub uint32: Type, + pub uint64: Type, pub float: Type, pub bool: Type, pub none: Type, @@ -779,8 +781,56 @@ impl<'a> Inferencer<'a> { &args[0].node { let custom = Some(self.primitives.int64); - if val.is_none() { - return report_error("Integer out of bound", args[0].location); + match val { + Some(val) if { + let v: Result = (*val).try_into(); + v.is_ok() + } => {}, + _ => return report_error("Integer out of bound", args[0].location) + } + return Ok(Located { + location: args[0].location, + custom, + node: ExprKind::Constant { + value: ast::Constant::Int(*val), + kind: kind.clone(), + }, + }); + } + } + if id == "uint32".into() && args.len() == 1 { + if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = + &args[0].node + { + let custom = Some(self.primitives.uint32); + match val { + Some(val) if { + let v: Result = (*val).try_into(); + v.is_ok() + } => {}, + _ => return report_error("Integer out of bound", args[0].location) + } + return Ok(Located { + location: args[0].location, + custom, + node: ExprKind::Constant { + value: ast::Constant::Int(*val), + kind: kind.clone(), + }, + }); + } + } + if id == "uint64".into() && args.len() == 1 { + if let ExprKind::Constant { value: ast::Constant::Int(val), kind } = + &args[0].node + { + let custom = Some(self.primitives.uint64); + match val { + Some(val) if { + let v: Result = (*val).try_into(); + v.is_ok() + } => {}, + _ => return report_error("Integer out of bound", args[0].location) } return Ok(Located { location: args[0].location, @@ -876,7 +926,7 @@ impl<'a> Inferencer<'a> { match val { Some(val) => { let int32: Result = (*val).try_into(); - // int64 is handled separately in functions + // int64 and unsigned integers are handled separately in functions if int32.is_ok() { Ok(self.primitives.int32) } else {