diff --git a/nac3artiq/demo/embedding_map.py b/nac3artiq/demo/embedding_map.py index 74b9aa1a..dc64d7df 100644 --- a/nac3artiq/demo/embedding_map.py +++ b/nac3artiq/demo/embedding_map.py @@ -16,7 +16,8 @@ class EmbeddingMap: "CacheError", "SPIError", "0:ZeroDivisionError", - "0:IndexError"]) + "0:IndexError", + "0:UnwrapNoneError"]) def preallocate_runtime_exception_names(self, names): for i, name in enumerate(names): diff --git a/nac3artiq/demo/min_artiq.py b/nac3artiq/demo/min_artiq.py index 5302cf19..1dd5786a 100644 --- a/nac3artiq/demo/min_artiq.py +++ b/nac3artiq/demo/min_artiq.py @@ -11,6 +11,7 @@ from embedding_map import EmbeddingMap __all__ = [ "Kernel", "KernelInvariant", "virtual", + "Option", "Some", "none", "UnwrapNoneError", "round64", "floor64", "ceil64", "extern", "kernel", "portable", "nac3", "rpc", "ms", "us", "ns", @@ -32,6 +33,39 @@ class KernelInvariant(Generic[T]): class virtual(Generic[T]): pass +class Option(Generic[T]): + _nac3_option: T + + def __init__(self, v: T): + self._nac3_option = v + + def is_none(self): + return self._nac3_option is None + + def is_some(self): + return not self.is_none() + + def unwrap(self): + if self.is_none(): + raise UnwrapNoneError() + return self._nac3_option + + def __repr__(self) -> str: + if self.is_none(): + return "none" + else: + return "Some({})".format(repr(self._nac3_option)) + + def __str__(self) -> str: + if self.is_none(): + return "none" + else: + return "Some({})".format(str(self._nac3_option)) + +def Some(v: T) -> Option[T]: + return Option(v) + +none = Option(None) def round64(x): return round(x) @@ -240,5 +274,10 @@ class KernelContextManager: def __exit__(self): pass +@nac3 +class UnwrapNoneError(Exception): + """raised when unwrapping a none value""" + artiq_builtin = True + parallel = KernelContextManager() sequential = KernelContextManager() diff --git a/nac3artiq/src/lib.rs b/nac3artiq/src/lib.rs index 13bc850d..16827588 100644 --- a/nac3artiq/src/lib.rs +++ b/nac3artiq/src/lib.rs @@ -71,6 +71,7 @@ pub struct PrimitivePythonId { exception: u64, generic_alias: (u64, u64), virtual_id: u64, + option: u64, } type TopLevelComponent = (Stmt, String, PyObject); @@ -373,7 +374,17 @@ impl Nac3 { get_attr_id(typing_mod, "_GenericAlias"), get_attr_id(types_mod, "GenericAlias"), ), - none: get_attr_id(builtins_mod, "None"), + none: id_fn + .call1((builtins_mod + .getattr("globals") + .unwrap() + .call0() + .unwrap() + .get_item("none") + .unwrap(),)) + .unwrap() + .extract() + .unwrap(), typevar: get_attr_id(typing_mod, "TypeVar"), int: get_attr_id(builtins_mod, "int"), int32: get_attr_id(numpy_mod, "int32"), @@ -385,6 +396,17 @@ impl Nac3 { list: get_attr_id(builtins_mod, "list"), tuple: get_attr_id(builtins_mod, "tuple"), exception: get_attr_id(builtins_mod, "Exception"), + option: id_fn + .call1((builtins_mod + .getattr("globals") + .unwrap() + .call0() + .unwrap() + .get_item("Option") + .unwrap(),)) + .unwrap() + .extract() + .unwrap(), }; let working_directory = tempfile::Builder::new().prefix("nac3-").tempdir().unwrap(); @@ -474,7 +496,8 @@ impl Nac3 { "KeyError", "NotImplementedError", "OverflowError", - "IOError" + "IOError", + "UnwrapNoneError", ]; add_exceptions(&mut composer, &mut builtins_def, &mut builtins_ty, &exception_names); diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index e26435c0..eb123d75 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -279,6 +279,10 @@ impl InnerResolver { } else if ty_id == self.primitive_ids.tuple { // do not handle type var param and concrete check here Ok(Ok((unifier.add_ty(TypeEnum::TTuple { ty: vec![] }), false))) + } else if ty_id == self.primitive_ids.option { + Ok(Ok((primitives.option, false))) + } else if ty_id == self.primitive_ids.none { + unreachable!("none cannot be typeid") } else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).cloned() { let def = defs[def_id.0].read(); if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def { @@ -569,6 +573,52 @@ impl InnerResolver { let types = types?; Ok(types.map(|types| unifier.add_ty(TypeEnum::TTuple { ty: types }))) } + // special handling for option type since its class member layout in python side + // is special and cannot be mapped directly to a nac3 type as below + (TypeEnum::TObj { obj_id, params, .. }, false) + if *obj_id == primitives.option.get_obj_id(unifier) => + { + let field_data = match obj.getattr("_nac3_option") { + Ok(d) => d, + // we use `none = Option(None)`, so the obj always have attr `_nac3_option` + Err(_) => unreachable!("cannot be None") + }; + // if is `none` + let zelf_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; + if zelf_id == self.primitive_ids.none { + if let TypeEnum::TObj { params, .. } = + unifier.get_ty_immutable(primitives.option).as_ref() + { + let var_map = params + .iter() + .map(|(id_var, ty)| { + if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) { + assert_eq!(*id, *id_var); + (*id, unifier.get_fresh_var_with_range(range, *name, *loc).0) + } else { + unreachable!() + } + }) + .collect::>(); + return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap())) + } else { + unreachable!("must be tobj") + } + } + + let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? { + Ok(t) => t, + Err(e) => { + return Ok(Err(format!( + "error when getting type of the option object ({})", + e + ))) + } + }; + let new_var_map: HashMap<_, _> = params.iter().map(|(id, _)| (*id, ty)).collect(); + let res = unifier.subst(extracted_ty, &new_var_map).unwrap_or(extracted_ty); + Ok(Ok(res)) + } (TypeEnum::TObj { params, fields, .. }, false) => { let var_map = params .iter() @@ -795,6 +845,39 @@ impl InnerResolver { let global = ctx.module.add_global(ty, Some(AddressSpace::Generic), &id_str); global.set_initializer(&val); Ok(Some(global.as_pointer_value().into())) + } else if ty_id == self.primitive_ids.option { + if id == self.primitive_ids.none { + // for option type, just a null ptr, whose type needs to be casted in codegen + // according to the type info attached in the ast + Ok(Some(ctx.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null().into())) + } else { + match self + .get_obj_value(py, obj.getattr("_nac3_option").unwrap(), ctx, generator) + .map_err(|e| { + super::CompileError::new_err(format!( + "Error getting value of Option object: {}", + e + )) + })? { + Some(v) => { + let global_str = format!("{}_option", id); + { + if self.global_value_ids.read().contains(&id) { + let global = ctx.module.get_global(&global_str).unwrap_or_else(|| { + ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str) + }); + return Ok(Some(global.as_pointer_value().into())); + } else { + self.global_value_ids.write().insert(id); + } + } + let global = ctx.module.add_global(v.get_type(), Some(AddressSpace::Generic), &global_str); + global.set_initializer(&v); + Ok(Some(global.as_pointer_value().into())) + }, + None => Ok(None), + } + } } else { let id_str = id.to_string(); diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 5afeba8c..7991b55c 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -138,6 +138,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { &mut self.unifier, self.top_level, &mut self.type_cache, + &self.primitives, ty, ) } @@ -934,6 +935,22 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( let ty = expr.custom.unwrap(); ctx.gen_const(generator, value, ty).into() } + ExprKind::Name { id, .. } if id == &"none".into() => { + match ( + ctx.unifier.get_ty(expr.custom.unwrap()).as_ref(), + ctx.unifier.get_ty(ctx.primitives.option).as_ref(), + ) { + ( + TypeEnum::TObj { obj_id, params, .. }, + TypeEnum::TObj { obj_id: opt_id, .. }, + ) if *obj_id == *opt_id => ctx + .get_llvm_type(generator, *params.iter().next().unwrap().1) + .ptr_type(AddressSpace::Generic) + .const_null() + .into(), + _ => unreachable!("must be option type"), + } + } ExprKind::Name { id, .. } => match ctx.var_assignment.get(id) { Some((ptr, None, _)) => ctx.builder.build_load(*ptr, "load").into(), Some((_, Some(static_value), _)) => ValueEnum::Static(static_value.clone()), @@ -941,16 +958,27 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( let resolver = ctx.resolver.clone(); let val = resolver.get_symbol_value(*id, ctx).unwrap(); // if is tuple, need to deref it to handle tuple as value - if let (TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) = ( + // if is option, need to cast pointer to handle None + match ( &*ctx.unifier.get_ty(expr.custom.unwrap()), resolver .get_symbol_value(*id, ctx) .unwrap() .to_basic_value_enum(ctx, generator)?, ) { - ctx.builder.build_load(ptr, "tup_val").into() - } else { - val + (TypeEnum::TTuple { .. }, BasicValueEnum::PointerValue(ptr)) => { + ctx.builder.build_load(ptr, "tup_val").into() + } + (TypeEnum::TObj { obj_id, params, .. }, BasicValueEnum::PointerValue(ptr)) + if *obj_id == ctx.primitives.option.get_obj_id(&ctx.unifier) => { + let actual_ptr_ty = ctx.get_llvm_type( + generator, + *params.iter().next().unwrap().1, + ) + .ptr_type(AddressSpace::Generic); + ctx.builder.build_bitcast(ptr, actual_ptr_ty, "option_ptr_cast").into() + } + _ => val, } } }, @@ -1281,6 +1309,26 @@ pub fn gen_expr<'ctx, 'a, G: CodeGenerator>( unreachable!() } }; + // directly generate code for option.unwrap + // since it needs location information from ast + if attr == &"unwrap".into() + && id == ctx.primitives.option.get_obj_id(&ctx.unifier) + { + if let BasicValueEnum::PointerValue(ptr) = val.to_basic_value_enum(ctx, generator)? { + let not_null = ctx.builder.build_is_not_null(ptr, "unwrap_not_null"); + ctx.make_assert( + generator, + not_null, + "0:UnwrapNoneError", + "", + [None, None, None], + expr.location, + ); + return Ok(Some(ctx.builder.build_load(ptr, "unwrap_some").into())) + } else { + unreachable!("option must be ptr") + } + } return Ok(generator .gen_call( ctx, diff --git a/nac3core/src/codegen/mod.rs b/nac3core/src/codegen/mod.rs index fd96e5d9..bce050c2 100644 --- a/nac3core/src/codegen/mod.rs +++ b/nac3core/src/codegen/mod.rs @@ -259,6 +259,7 @@ fn get_llvm_type<'ctx>( unifier: &mut Unifier, top_level: &TopLevelContext, type_cache: &mut HashMap>, + primitives: &PrimitiveStore, ty: Type, ) -> BasicTypeEnum<'ctx> { use TypeEnum::*; @@ -268,9 +269,28 @@ fn get_llvm_type<'ctx>( let ty_enum = unifier.get_ty(ty); let result = match &*ty_enum { TObj { obj_id, fields, .. } => { - // check to avoid treating primitives as classes - if obj_id.0 <= 7 { - unreachable!(); + // check to avoid treating primitives other than Option as classes + if obj_id.0 <= 10 { + match (unifier.get_ty(ty).as_ref(), unifier.get_ty(primitives.option).as_ref()) + { + ( + TypeEnum::TObj { obj_id, params, .. }, + TypeEnum::TObj { obj_id: opt_id, .. }, + ) if *obj_id == *opt_id => { + return get_llvm_type( + ctx, + generator, + unifier, + top_level, + type_cache, + primitives, + *params.iter().next().unwrap().1, + ) + .ptr_type(AddressSpace::Generic) + .into(); + } + _ => unreachable!("must be option type"), + } } // a struct with fields in the order of declaration let top_level_defs = top_level.definitions.read(); @@ -289,6 +309,7 @@ fn get_llvm_type<'ctx>( unifier, top_level, type_cache, + primitives, fields[&f.0].0, ) }) @@ -304,14 +325,14 @@ fn get_llvm_type<'ctx>( // a struct with fields in the order present in the tuple let fields = ty .iter() - .map(|ty| get_llvm_type(ctx, generator, unifier, top_level, type_cache, *ty)) + .map(|ty| get_llvm_type(ctx, generator, unifier, top_level, type_cache, primitives, *ty)) .collect_vec(); ctx.struct_type(&fields, false).into() } TList { ty } => { // a struct with an integer and a pointer to an array let element_type = - get_llvm_type(ctx, generator, unifier, top_level, type_cache, *ty); + get_llvm_type(ctx, generator, unifier, top_level, type_cache, primitives, *ty); let fields = [ element_type.ptr_type(AddressSpace::Generic).into(), generator.get_size_type(ctx).into(), @@ -385,6 +406,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>( range: unifier.get_representative(primitives.range), str: unifier.get_representative(primitives.str), exception: unifier.get_representative(primitives.exception), + option: unifier.get_representative(primitives.option), }; let mut type_cache: HashMap<_, _> = [ @@ -417,6 +439,8 @@ pub fn gen_func<'ctx, G: CodeGenerator>( exception.set_body(&fields, false); exception.ptr_type(AddressSpace::Generic).into() }); + // NOTE: special handling of option cannot use this type cache since it contains type var, + // handled inside get_llvm_type instead let (args, ret) = if let ConcreteTypeEnum::TFunc { args, ret, .. } = task.store.get(task.signature) @@ -437,7 +461,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>( let ret_type = if unifier.unioned(ret, primitives.none) { None } else { - Some(get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, ret)) + Some(get_llvm_type(context, generator, &mut unifier, top_level_ctx.as_ref(), &mut type_cache, &primitives, ret)) }; let has_sret = ret_type.map_or(false, |ty| need_sret(context, ty)); @@ -450,6 +474,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>( &mut unifier, top_level_ctx.as_ref(), &mut type_cache, + &primitives, arg.ty, ) .into() @@ -497,6 +522,7 @@ pub fn gen_func<'ctx, G: CodeGenerator>( &mut unifier, top_level_ctx.as_ref(), &mut type_cache, + &primitives, arg.ty, ), &arg.name.to_string(), diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 96f941d2..b8fc9f27 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -153,12 +153,11 @@ pub trait SymbolResolver { } thread_local! { - static IDENTIFIER_ID: [StrRef; 12] = [ + static IDENTIFIER_ID: [StrRef; 11] = [ "int32".into(), "int64".into(), "float".into(), "bool".into(), - "None".into(), "virtual".into(), "list".into(), "tuple".into(), @@ -183,14 +182,13 @@ pub fn parse_type_annotation( let int64_id = ids[1]; let float_id = ids[2]; let bool_id = ids[3]; - let none_id = ids[4]; - let virtual_id = ids[5]; - let list_id = ids[6]; - let tuple_id = ids[7]; - let str_id = ids[8]; - let exn_id = ids[9]; - let uint32_id = ids[10]; - let uint64_id = ids[11]; + let virtual_id = ids[4]; + let list_id = ids[5]; + let tuple_id = ids[6]; + let str_id = ids[7]; + let exn_id = ids[8]; + let uint32_id = ids[9]; + let uint64_id = ids[10]; let name_handling = |id: &StrRef, loc: Location, unifier: &mut Unifier| { if *id == int32_id { @@ -205,8 +203,6 @@ pub fn parse_type_annotation( Ok(primitives.float) } else if *id == bool_id { Ok(primitives.bool) - } else if *id == none_id { - Ok(primitives.none) } else if *id == str_id { Ok(primitives.str) } else if *id == exn_id { diff --git a/nac3core/src/toplevel/builtins.rs b/nac3core/src/toplevel/builtins.rs index 2d97d612..6b171286 100644 --- a/nac3core/src/toplevel/builtins.rs +++ b/nac3core/src/toplevel/builtins.rs @@ -105,6 +105,20 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { ("__param2__".into(), int64, true), ]; + // for Option, is_some and is_none share the same type: () -> bool, + // and they are methods under the same class `Option` + let (is_some_ty, unwrap_ty, (option_ty_var, option_ty_var_id)) = + if let TypeEnum::TObj { fields, params, .. } = + primitives.1.get_ty(primitives.0.option).as_ref() + { + ( + *fields.get(&"is_some".into()).unwrap(), + *fields.get(&"unwrap".into()).unwrap(), + (*params.iter().next().unwrap().1, *params.iter().next().unwrap().0), + ) + } else { + unreachable!() + }; let top_level_def_list = vec![ Arc::new(RwLock::new(TopLevelComposer::make_top_level_class_def( 0, @@ -180,6 +194,81 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { None, None, ))), + Arc::new(RwLock::new({ + TopLevelDef::Class { + name: "Option".into(), + object_id: DefinitionId(10), + type_vars: vec![option_ty_var], + fields: vec![], + methods: vec![ + ("is_some".into(), is_some_ty.0, DefinitionId(11)), + ("is_none".into(), is_some_ty.0, DefinitionId(12)), + ("unwrap".into(), unwrap_ty.0, DefinitionId(13)), + ], + ancestors: vec![TypeAnnotation::CustomClass { + id: DefinitionId(10), + params: Default::default(), + }], + constructor: None, + resolver: None, + loc: None, + } + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "Option.is_some".into(), + simple_name: "is_some".into(), + signature: is_some_ty.0, + var_id: vec![option_ty_var_id], + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, _, _, generator| { + let obj_val = obj.unwrap().1.clone().to_basic_value_enum(ctx, generator)?; + if let BasicValueEnum::PointerValue(ptr) = obj_val { + Ok(Some(ctx.builder.build_is_not_null(ptr, "is_some").into())) + } else { + unreachable!("option must be ptr") + } + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "Option.is_none".into(), + simple_name: "is_none".into(), + signature: is_some_ty.0, + var_id: vec![option_ty_var_id], + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, _, _, generator| { + let obj_val = obj.unwrap().1.clone().to_basic_value_enum(ctx, generator)?; + if let BasicValueEnum::PointerValue(ptr) = obj_val { + Ok(Some(ctx.builder.build_is_null(ptr, "is_none").into())) + } else { + unreachable!("option must be ptr") + } + }, + )))), + loc: None, + })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "Option.unwrap".into(), + simple_name: "unwrap".into(), + signature: unwrap_ty.0, + var_id: vec![option_ty_var_id], + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, obj, _, _, generator| { + unreachable!("handled in gen_expr") + }, + )))), + loc: None, + })), Arc::new(RwLock::new(TopLevelDef::Function { name: "int32".into(), simple_name: "int32".into(), @@ -1098,6 +1187,28 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { )))), loc: None, })), + Arc::new(RwLock::new(TopLevelDef::Function { + name: "Some".into(), + simple_name: "Some".into(), + signature: primitives.1.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![FuncArg { name: "n".into(), ty: option_ty_var, default_value: None }], + ret: primitives.0.option, + vars: HashMap::from([(option_ty_var_id, option_ty_var)]), + })), + var_id: vec![option_ty_var_id], + instance_to_symbol: Default::default(), + instance_to_stmt: Default::default(), + resolver: None, + codegen_callback: Some(Arc::new(GenCall::new(Box::new( + |ctx, _, _fun, args, generator| { + let arg_val = args[0].1.clone().to_basic_value_enum(ctx, generator)?; + let alloca = ctx.builder.build_alloca(arg_val.get_type(), "alloca_some"); + ctx.builder.build_store(alloca, arg_val); + Ok(Some(alloca.into())) + }, + )))), + loc: None, + })), ]; let ast_list: Vec>> = @@ -1123,6 +1234,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo { "min", "max", "abs", + "Some", ], ) } diff --git a/nac3core/src/toplevel/composer.rs b/nac3core/src/toplevel/composer.rs index 991926a6..bfed6690 100644 --- a/nac3core/src/toplevel/composer.rs +++ b/nac3core/src/toplevel/composer.rs @@ -74,6 +74,8 @@ impl TopLevelComposer { "self".into(), "Kernel".into(), "KernelInvariant".into(), + "Some".into(), + "Option".into(), ]); let defined_names: HashSet = Default::default(); let method_class: HashMap = Default::default(); @@ -92,7 +94,6 @@ impl TopLevelComposer { } else if let TopLevelDef::Class { name, constructor, object_id, type_vars, .. } = &*def { assert!(id == object_id.0); - assert!(type_vars.is_empty()); if let Some(constructor) = constructor { builtin_ty.insert(*name, *constructor); } @@ -1783,9 +1784,7 @@ impl TopLevelComposer { }) }; let mut identifiers = { - // NOTE: none and function args? let mut result: HashSet<_> = HashSet::new(); - result.insert("None".into()); if self_type.is_some() { result.insert("self".into()); } @@ -1808,9 +1807,7 @@ impl TopLevelComposer { }, unifier, variable_mapping: { - // NOTE: none and function args? let mut result: HashMap = HashMap::new(); - result.insert("None".into(), primitives_ty.none); if let Some(self_ty) = self_type { result.insert("self".into(), self_ty); } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index a76e7b3a..793bb927 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -107,7 +107,43 @@ impl TopLevelComposer { fields: HashMap::new(), params: HashMap::new(), }); - let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception, uint32, uint64 }; + + let option_type_var = unifier.get_fresh_var(Some("option_type_var".into()), None); + let is_some_type_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![], + ret: bool, + vars: HashMap::from([(option_type_var.1, option_type_var.0)]), + })); + let unwrap_fun_ty = unifier.add_ty(TypeEnum::TFunc(FunSignature { + args: vec![], + ret: option_type_var.0, + vars: HashMap::from([(option_type_var.1, option_type_var.0)]), + })); + let option = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(10), + fields: vec![ + ("is_some".into(), (is_some_type_fun_ty, true)), + ("is_none".into(), (is_some_type_fun_ty, true)), + ("unwrap".into(), (unwrap_fun_ty, true)), + ] + .into_iter() + .collect::>(), + params: HashMap::from([(option_type_var.1, option_type_var.0)]), + }); + + let primitives = PrimitiveStore { + int32, + int64, + float, + bool, + none, + range, + str, + exception, + uint32, + uint64, + option, + }; crate::typecheck::magic_methods::set_primitives_magic_methods(&primitives, &mut unifier); (primitives, unifier) } diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap index 5a9c4fe1..a5efb1a1 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__generic_class.snap @@ -7,7 +7,7 @@ expression: res_vec [ "Class {\nname: \"Generic_A\",\nancestors: [\"{class: Generic_A, params: [\\\"V\\\"]}\", \"{class: B, params: []}\"],\nfields: [\"aa\", \"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\"), (\"fun\", \"fn[[a:int32], V]\")],\ntype_vars: [\"V\"]\n}\n", "Function {\nname: \"Generic_A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [17]\n}\n", + "Function {\nname: \"Generic_A.fun\",\nsig: \"fn[[a:int32], V]\",\nvar_id: [18]\n}\n", "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\"],\nfields: [\"aa\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"foo\", \"fn[[b:T], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.foo\",\nsig: \"fn[[b:T], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap index 7879cf36..02f829c5 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__inheritance_override.snap @@ -9,7 +9,7 @@ expression: res_vec "Function {\nname: \"A.__init__\",\nsig: \"fn[[t:T], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Function {\nname: \"A.foo\",\nsig: \"fn[[c:C], none]\",\nvar_id: []\n}\n", - "Class {\nname: \"B\",\nancestors: [\"{class: B, params: [\\\"var6\\\"]}\", \"{class: A, params: [\\\"float\\\"]}\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"var6\"]\n}\n", + "Class {\nname: \"B\",\nancestors: [\"{class: B, params: [\\\"var7\\\"]}\", \"{class: A, params: [\\\"float\\\"]}\"],\nfields: [\"a\", \"b\", \"c\", \"d\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: [\"var7\"]\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"B.fun\",\nsig: \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"{class: C, params: []}\", \"{class: B, params: [\\\"bool\\\"]}\", \"{class: A, params: [\\\"float\\\"]}\"],\nfields: [\"a\", \"b\", \"c\", \"d\", \"e\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:int32, b:T], list[virtual[B[bool]]]]\"), (\"foo\", \"fn[[c:C], none]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap index 1a33c0f2..037a0e71 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__list_tuple_generic.snap @@ -7,8 +7,8 @@ expression: res_vec [ "Function {\nname: \"foo\",\nsig: \"fn[[a:list[int32], b:tuple[T, float]], A[B, bool]]\",\nvar_id: []\n}\n", "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"T\\\", \\\"V\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[v:V], none]\"), (\"fun\", \"fn[[a:T], V]\")],\ntype_vars: [\"T\", \"V\"]\n}\n", - "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [19]\n}\n", - "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [24]\n}\n", + "Function {\nname: \"A.__init__\",\nsig: \"fn[[v:V], none]\",\nvar_id: [20]\n}\n", + "Function {\nname: \"A.fun\",\nsig: \"fn[[a:T], V]\",\nvar_id: [25]\n}\n", "Function {\nname: \"gfun\",\nsig: \"fn[[a:A[int32, list[float]]], none]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\"],\nfields: [],\nmethods: [(\"__init__\", \"fn[[], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap index 50a7a665..b8400356 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__self1.snap @@ -5,7 +5,7 @@ expression: res_vec --- [ - "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"var5\\\", \\\"var6\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"var5\", \"var6\"]\n}\n", + "Class {\nname: \"A\",\nancestors: [\"{class: A, params: [\\\"var6\\\", \\\"var7\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[a:A[bool, float], b:B], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\")],\ntype_vars: [\"var6\", \"var7\"]\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[a:A[bool, float], b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[a:A[bool, float]], A[bool, int32]]\",\nvar_id: []\n}\n", "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\", \"{class: A, params: [\\\"int64\\\", \\\"bool\\\"]}\"],\nfields: [\"a\", \"b\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[a:A[bool, float]], A[bool, int32]]\"), (\"foo\", \"fn[[b:B], B]\"), (\"bar\", \"fn[[a:A[int32, list[B]]], tuple[A[bool, virtual[A[B, int32]]], B]]\")],\ntype_vars: []\n}\n", diff --git a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap index cf182407..4ed79c4d 100644 --- a/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap +++ b/nac3core/src/toplevel/snapshots/nac3core__toplevel__test__test_analyze__simple_class_compose.snap @@ -1,6 +1,6 @@ --- source: nac3core/src/toplevel/test.rs -assertion_line: 540 +assertion_line: 549 expression: res_vec --- @@ -8,12 +8,12 @@ expression: res_vec "Class {\nname: \"A\",\nancestors: [\"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"A.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"A.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [25]\n}\n", + "Function {\nname: \"A.foo\",\nsig: \"fn[[a:T, b:V], none]\",\nvar_id: [26]\n}\n", "Class {\nname: \"B\",\nancestors: [\"{class: B, params: []}\", \"{class: C, params: []}\", \"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"B.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Class {\nname: \"C\",\nancestors: [\"{class: C, params: []}\", \"{class: A, params: []}\"],\nfields: [\"a\"],\nmethods: [(\"__init__\", \"fn[[], none]\"), (\"fun\", \"fn[[b:B], none]\"), (\"foo\", \"fn[[a:T, b:V], none]\")],\ntype_vars: []\n}\n", "Function {\nname: \"C.__init__\",\nsig: \"fn[[], none]\",\nvar_id: []\n}\n", "Function {\nname: \"C.fun\",\nsig: \"fn[[b:B], none]\",\nvar_id: []\n}\n", "Function {\nname: \"foo\",\nsig: \"fn[[a:A], none]\",\nvar_id: []\n}\n", - "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [33]\n}\n", + "Function {\nname: \"ff\",\nsig: \"fn[[a:T], V]\",\nvar_id: [34]\n}\n", ] diff --git a/nac3core/src/toplevel/type_annotation.rs b/nac3core/src/toplevel/type_annotation.rs index 3f5405db..423a418d 100644 --- a/nac3core/src/toplevel/type_annotation.rs +++ b/nac3core/src/toplevel/type_annotation.rs @@ -72,8 +72,6 @@ pub fn parse_ast_to_type_annotation_kinds( Ok(TypeAnnotation::Primitive(primitives.float)) } else if id == &"bool".into() { Ok(TypeAnnotation::Primitive(primitives.bool)) - } else if id == &"None".into() { - Ok(TypeAnnotation::Primitive(primitives.none)) } else if id == &"str".into() { Ok(TypeAnnotation::Primitive(primitives.str)) } else if id == &"Exception".into() { @@ -223,6 +221,29 @@ pub fn parse_ast_to_type_annotation_kinds( Ok(TypeAnnotation::List(def_ann.into())) } + // option + ast::ExprKind::Subscript { value, slice, .. } + if { + matches!(&value.node, ast::ExprKind::Name { id, .. } if id == &"Option".into()) + } => + { + let def_ann = parse_ast_to_type_annotation_kinds( + resolver, + top_level_defs, + unifier, + primitives, + slice.as_ref(), + locked, + )?; + let id = + if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty(primitives.option).as_ref() { + *obj_id + } else { + unreachable!() + }; + Ok(TypeAnnotation::CustomClass { id, params: vec![def_ann] }) + } + // tuple ast::ExprKind::Subscript { value, slice, .. } if { diff --git a/nac3core/src/typecheck/function_check.rs b/nac3core/src/typecheck/function_check.rs index b23821ae..c2dc884d 100644 --- a/nac3core/src/typecheck/function_check.rs +++ b/nac3core/src/typecheck/function_check.rs @@ -20,6 +20,8 @@ impl<'a> Inferencer<'a> { defined_identifiers: &mut HashSet, ) -> Result<(), String> { match &pattern.node { + ast::ExprKind::Name { id, .. } if id == &"none".into() => + Err(format!("cannot assign to a `none` (at {})", pattern.location)), ExprKind::Name { id, .. } => { if !defined_identifiers.contains(id) { defined_identifiers.insert(*id); @@ -70,6 +72,9 @@ impl<'a> Inferencer<'a> { } match &expr.node { ExprKind::Name { id, .. } => { + if id == &"none".into() { + return Ok(()); + } self.should_have_value(expr)?; if !defined_identifiers.contains(id) { match self.function_data.resolver.get_symbol_type( diff --git a/nac3core/src/typecheck/type_inferencer/mod.rs b/nac3core/src/typecheck/type_inferencer/mod.rs index c012f2ed..1e99ae29 100644 --- a/nac3core/src/typecheck/type_inferencer/mod.rs +++ b/nac3core/src/typecheck/type_inferencer/mod.rs @@ -40,6 +40,7 @@ pub struct PrimitiveStore { pub range: Type, pub str: Type, pub exception: Type, + pub option: Type, } pub struct FunctionData { @@ -448,25 +449,47 @@ impl<'a> fold::Fold<()> for Inferencer<'a> { Some(self.infer_constant(value, &expr.location)?) } ast::ExprKind::Name { id, .. } => { - if !self.defined_identifiers.contains(id) { - match self.function_data.resolver.get_symbol_type( - self.unifier, - &self.top_level.definitions.read(), - self.primitives, - *id, - ) { - Ok(_) => { - self.defined_identifiers.insert(*id); - } - Err(e) => { - return report_error( - &format!("type error at identifier `{}` ({})", id, e), - expr.location, - ); + // the name `none` is special since it may have different types + if id == &"none".into() { + if let TypeEnum::TObj { params, .. } = + self.unifier.get_ty_immutable(self.primitives.option).as_ref() + { + let var_map = params + .iter() + .map(|(id_var, ty)| { + if let TypeEnum::TVar { id, range, name, loc, .. } = &*self.unifier.get_ty(*ty) { + assert_eq!(*id, *id_var); + (*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).0) + } else { + unreachable!() + } + }) + .collect::>(); + Some(self.unifier.subst(self.primitives.option, &var_map).unwrap()) + } else { + unreachable!("must be tobj") + } + } else { + if !self.defined_identifiers.contains(id) { + match self.function_data.resolver.get_symbol_type( + self.unifier, + &self.top_level.definitions.read(), + self.primitives, + *id, + ) { + Ok(_) => { + self.defined_identifiers.insert(*id); + } + Err(e) => { + return report_error( + &format!("type error at identifier `{}` ({})", id, e), + expr.location, + ); + } } } + Some(self.infer_identifier(*id)?) } - Some(self.infer_identifier(*id)?) } ast::ExprKind::List { elts, .. } => Some(self.infer_list(elts)?), ast::ExprKind::Tuple { elts, .. } => Some(self.infer_tuple(elts)?), @@ -932,6 +955,8 @@ impl<'a> Inferencer<'a> { Ok(self.unifier.add_ty(TypeEnum::TTuple { ty: ty? })) } ast::Constant::Str(_) => Ok(self.primitives.str), + ast::Constant::None + => report_error("CPython `None` not supported (nac3 uses `none` instead)", *loc), _ => report_error("not supported", *loc), } } diff --git a/nac3core/src/typecheck/type_inferencer/test.rs b/nac3core/src/typecheck/type_inferencer/test.rs index 814cc38c..2c5a6ded 100644 --- a/nac3core/src/typecheck/type_inferencer/test.rs +++ b/nac3core/src/typecheck/type_inferencer/test.rs @@ -129,7 +129,24 @@ impl TestEnvironment { fields: HashMap::new(), params: HashMap::new(), }); - let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception, uint32, uint64 }; + let option = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(10), + fields: HashMap::new(), + params: HashMap::new(), + }); + let primitives = PrimitiveStore { + int32, + int64, + float, + bool, + none, + range, + str, + exception, + uint32, + uint64, + option, + }; set_primitives_magic_methods(&primitives, &mut unifier); let id_to_name = [ @@ -237,6 +254,11 @@ impl TestEnvironment { fields: HashMap::new(), params: HashMap::new(), }); + let option = unifier.add_ty(TypeEnum::TObj { + obj_id: DefinitionId(10), + fields: HashMap::new(), + params: HashMap::new(), + }); identifier_mapping.insert("None".into(), none); for (i, name) in ["int32", "int64", "float", "bool", "none", "range", "str", "Exception"] .iter() @@ -259,7 +281,19 @@ impl TestEnvironment { } let defs = 7; - let primitives = PrimitiveStore { int32, int64, float, bool, none, range, str, exception, uint32, uint64 }; + let primitives = PrimitiveStore { + int32, + int64, + float, + bool, + none, + range, + str, + exception, + uint32, + uint64, + option, + }; let (v0, id) = unifier.get_dummy_var(); diff --git a/nac3core/src/typecheck/typedef/mod.rs b/nac3core/src/typecheck/typedef/mod.rs index 67b89248..9df1cd1c 100644 --- a/nac3core/src/typecheck/typedef/mod.rs +++ b/nac3core/src/typecheck/typedef/mod.rs @@ -54,6 +54,18 @@ pub enum RecordKey { Int(i32), } +impl Type { + // a wrapper function for cleaner code so that we don't need to + // write this long pattern matching just to get the field `obj_id` + pub fn get_obj_id(self, unifier: &Unifier) -> DefinitionId { + if let TypeEnum::TObj { obj_id, .. } = unifier.get_ty_immutable(self).as_ref() { + *obj_id + } else { + unreachable!("expect a object type") + } + } +} + impl From<&RecordKey> for StrRef { fn from(r: &RecordKey) -> Self { match r { diff --git a/nac3standalone/demo/interpret_demo.py b/nac3standalone/demo/interpret_demo.py index d892833c..71bb4264 100755 --- a/nac3standalone/demo/interpret_demo.py +++ b/nac3standalone/demo/interpret_demo.py @@ -8,6 +8,38 @@ import pathlib from numpy import int32, int64, uint32, uint64 from typing import TypeVar, Generic +T = TypeVar('T') +class Option(Generic[T]): + _nac3_option: T + + def __init__(self, v: T): + self._nac3_option = v + + def is_none(self): + return self._nac3_option is None + + def is_some(self): + return not self.is_none() + + def unwrap(self): + return self._nac3_option + + def __repr__(self) -> str: + if self.is_none(): + return "none" + else: + return "Some({})".format(repr(self._nac3_option)) + + def __str__(self) -> str: + if self.is_none(): + return "none" + else: + return "Some({})".format(str(self._nac3_option)) + +def Some(v: T) -> Option[T]: + return Option(v) + +none = Option(None) def patch(module): def output_asciiart(x): @@ -39,6 +71,9 @@ def patch(module): module.TypeVar = TypeVar module.Generic = Generic module.extern = extern + module.Option = Option + module.Some = Some + module.none = none def file_import(filename, prefix="file_import_"): diff --git a/nac3standalone/demo/src/option.py b/nac3standalone/demo/src/option.py new file mode 100644 index 00000000..8cc0d5aa --- /dev/null +++ b/nac3standalone/demo/src/option.py @@ -0,0 +1,40 @@ +@extern +def output_int32(x: int32): + ... + +class A: + d: Option[int32] + e: Option[Option[int32]] + def __init__(self, a: Option[int32], b: Option[Option[int32]]): + self.d = a + self.e = b + +def run() -> int32: + a = Some(3) + if a.is_some(): + d = a.unwrap() + output_int32(a.unwrap()) + a = none + if a.is_none(): + output_int32(d + 2) + else: + a = Some(5) + c = Some(6) + output_int32(a.unwrap() + c.unwrap()) + + f = Some(4.3) + output_int32(int32(f.unwrap())) + + obj = A(Some(6), none) + output_int32(obj.d.unwrap()) + + obj2 = Some(A(Some(7), none)) + output_int32(obj2.unwrap().d.unwrap()) + + obj3 = Some(A(Some(8), Some(none))) + if obj3.unwrap().e.unwrap().is_none(): + obj3.unwrap().e = Some(Some(9)) + output_int32(obj3.unwrap().d.unwrap()) + output_int32(obj3.unwrap().e.unwrap().unwrap()) + + return 0