From 6ab73a223c1d567fce325d3b7227d12e5f4fcbe6 Mon Sep 17 00:00:00 2001 From: ychenfo Date: Wed, 30 Mar 2022 03:14:21 +0800 Subject: [PATCH] nac3core/artiq: support default param of option type --- nac3artiq/src/codegen.rs | 5 +++- nac3artiq/src/symbol_resolver.rs | 24 ++++++++++------- nac3core/src/codegen/expr.rs | 36 +++++++++++++++++++++++-- nac3core/src/symbol_resolver.rs | 4 +++ nac3core/src/toplevel/helper.rs | 45 ++++++++++++++++++++++++++++++-- 5 files changed, 100 insertions(+), 14 deletions(-) diff --git a/nac3artiq/src/codegen.rs b/nac3artiq/src/codegen.rs index 3a31171a..51380fad 100644 --- a/nac3artiq/src/codegen.rs +++ b/nac3artiq/src/codegen.rs @@ -361,7 +361,10 @@ fn rpc_codegen_callback_fn<'ctx, 'a>( } // default value handling for k in keys.into_iter() { - mapping.insert(k.name, ctx.gen_symbol_val(generator, &k.default_value.unwrap()).into()); + mapping.insert( + k.name, + ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into() + ); } // reorder the parameters let mut real_params = fun diff --git a/nac3artiq/src/symbol_resolver.rs b/nac3artiq/src/symbol_resolver.rs index a6293779..51635fb9 100644 --- a/nac3artiq/src/symbol_resolver.rs +++ b/nac3artiq/src/symbol_resolver.rs @@ -922,6 +922,7 @@ impl InnerResolver { py: Python, obj: &PyAny, ) -> PyResult> { + let id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?; let ty_id: u64 = self.helper.id_fn.call1(py, (self.helper.type_fn.call1(py, (obj,))?,))?.extract(py)?; Ok(if ty_id == self.primitive_ids.int || ty_id == self.primitive_ids.int32 { @@ -940,13 +941,17 @@ impl InnerResolver { let elements: &PyTuple = obj.cast_as()?; let elements: Result, String>, _> = elements.iter().map(|elem| self.get_default_param_obj_value(py, elem)).collect(); - let elements = match elements? { - Ok(el) => el, - Err(err) => return Ok(Err(err)), - }; - Ok(SymbolValue::Tuple(elements)) + elements?.map(SymbolValue::Tuple) + } else if ty_id == self.primitive_ids.option { + if id == self.primitive_ids.none { + Ok(SymbolValue::OptionNone) + } else { + self + .get_default_param_obj_value(py, obj.getattr("_nac3_option").unwrap())? + .map(|v| SymbolValue::OptionSome(Box::new(v))) + } } else { - Err("only primitives values and tuple can be default parameter value".into()) + Err("only primitives values, option and tuple can be default parameter value".into()) }) } } @@ -962,8 +967,9 @@ impl SymbolResolver for Resolver { for (key, val) in members.iter() { let key: &str = key.extract()?; if key == id.to_string() { - sym_value = - Some(self.0.get_default_param_obj_value(py, val).unwrap().unwrap()); + if let Ok(Ok(v)) = self.0.get_default_param_obj_value(py, val) { + sym_value = Some(v) + } break; } } @@ -971,7 +977,7 @@ impl SymbolResolver for Resolver { }) .unwrap() } - _ => unimplemented!("other type of expr not supported at {}", expr.location), + _ => unreachable!("only for resolving names"), } } diff --git a/nac3core/src/codegen/expr.rs b/nac3core/src/codegen/expr.rs index 3a1ca2ae..fa640cb8 100644 --- a/nac3core/src/codegen/expr.rs +++ b/nac3core/src/codegen/expr.rs @@ -91,6 +91,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { &mut self, generator: &mut dyn CodeGenerator, val: &SymbolValue, + ty: Type, ) -> BasicValueEnum<'ctx> { match val { SymbolValue::I32(v) => self.ctx.i32_type().const_int(*v as u64, true).into(), @@ -107,7 +108,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { ty.const_named_struct(&[str_ptr, size.into()]).into() } SymbolValue::Tuple(ls) => { - let vals = ls.iter().map(|v| self.gen_symbol_val(generator, v)).collect_vec(); + let vals = ls.iter().map(|v| self.gen_symbol_val(generator, v, ty)).collect_vec(); let fields = vals.iter().map(|v| v.get_type()).collect_vec(); let ty = self.ctx.struct_type(&fields, false); let ptr = self.builder.build_alloca(ty, "tuple"); @@ -124,6 +125,37 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> { } self.builder.build_load(ptr, "tup_val") } + SymbolValue::OptionSome(v) => { + let ty = match self.unifier.get_ty_immutable(ty).as_ref() { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == self.primitives.option.get_obj_id(&self.unifier) => + { + *params.iter().next().unwrap().1 + } + _ => unreachable!("must be option type"), + }; + let val = self.gen_symbol_val(generator, v, ty); + let ptr = self.builder.build_alloca(val.get_type(), "default_opt_some"); + self.builder.build_store(ptr, val); + ptr.into() + } + SymbolValue::OptionNone => { + let ty = match self.unifier.get_ty_immutable(ty).as_ref() { + TypeEnum::TObj { obj_id, params, .. } + if *obj_id == self.primitives.option.get_obj_id(&self.unifier) => + { + *params.iter().next().unwrap().1 + } + _ => unreachable!("must be option type"), + }; + let actual_ptr_type = + self.get_llvm_type(generator, ty).ptr_type(AddressSpace::Generic); + self.builder.build_bitcast( + self.ctx.i8_type().ptr_type(AddressSpace::Generic).const_null(), + actual_ptr_type, + "default_opt_none", + ) + } } } @@ -605,7 +637,7 @@ pub fn gen_call<'ctx, 'a, G: CodeGenerator>( } mapping.insert( k.name, - ctx.gen_symbol_val(generator, &k.default_value.unwrap()).into(), + ctx.gen_symbol_val(generator, &k.default_value.unwrap(), k.ty).into(), ); } // reorder the parameters diff --git a/nac3core/src/symbol_resolver.rs b/nac3core/src/symbol_resolver.rs index 6f92ee47..18d5f699 100644 --- a/nac3core/src/symbol_resolver.rs +++ b/nac3core/src/symbol_resolver.rs @@ -29,6 +29,8 @@ pub enum SymbolValue { Double(f64), Bool(bool), Tuple(Vec), + OptionSome(Box), + OptionNone, } impl Display for SymbolValue { @@ -50,6 +52,8 @@ impl Display for SymbolValue { SymbolValue::Tuple(t) => { write!(f, "({})", t.iter().map(|v| format!("{}", v)).collect::>().join(", ")) } + SymbolValue::OptionSome(v) => write!(f, "Some({})", v), + SymbolValue::OptionNone => write!(f, "none"), } } } diff --git a/nac3core/src/toplevel/helper.rs b/nac3core/src/toplevel/helper.rs index 793bb927..4bbd32d0 100644 --- a/nac3core/src/toplevel/helper.rs +++ b/nac3core/src/toplevel/helper.rs @@ -480,6 +480,33 @@ impl TopLevelComposer { Some("tuple".to_string()) } } + SymbolValue::OptionNone => { + if let TypeAnnotation::CustomClass { id, .. } = ty { + if *id == primitive.option.get_obj_id(unifier) { + None + } else { + Some("option".into()) + } + } else { + Some("option".into()) + } + } + SymbolValue::OptionSome(v) => { + if let TypeAnnotation::CustomClass { id, params } = ty { + if *id == primitive.option.get_obj_id(unifier) { + if params.len() == 1 { + Self::check_default_param_type(v, ¶ms[0], primitive, unifier)?; + None + } else { + Some("option".into()) + } + } else { + Some("option".into()) + } + } else { + Some("option".into()) + } + } }; if let Some(found) = res { Err(format!( @@ -511,6 +538,10 @@ pub fn parse_parameter_default_value( Constant::Tuple(tuple) => Ok(SymbolValue::Tuple( tuple.iter().map(|x| handle_constant(x, loc)).collect::, _>>()?, )), + Constant::None => Err(format!( + "`None` is not supported, use `none` for option type instead ({})", + loc + )), _ => unimplemented!("this constant is not supported at {}", loc), } } @@ -548,6 +579,11 @@ pub fn parse_parameter_default_value( } _ => Err(format!("only allow constant integer here at {}", default.location)) } + ast::ExprKind::Name { id, .. } if *id == "Some".into() => Ok( + SymbolValue::OptionSome( + Box::new(parse_parameter_default_value(&args[0], resolver)?) + ) + ), _ => Err(format!("unsupported default parameter at {}", default.location)), } } @@ -556,15 +592,20 @@ pub fn parse_parameter_default_value( .map(|x| parse_parameter_default_value(x, resolver)) .collect::, _>>()? )), + ast::ExprKind::Name { id, .. } if id == &"none".into() => Ok(SymbolValue::OptionNone), ast::ExprKind::Name { id, .. } => { resolver.get_default_param_value(default).ok_or_else( || format!( - "`{}` cannot be used as a default parameter at {} (not primitive type or tuple / not defined?)", + "`{}` cannot be used as a default parameter at {} \ + (not primitive type, option or tuple / not defined?)", id, default.location ) ) } - _ => Err(format!("unsupported default parameter at {}", default.location)) + _ => Err(format!( + "unsupported default parameter (not primitive type, option or tuple) at {}", + default.location + )) } }