1
0
forked from M-Labs/nac3

meta: Refactor to use more let-else bindings

This commit is contained in:
David Mak 2023-12-12 13:38:27 +08:00
parent 5bf05c6a69
commit a19f1065e3
16 changed files with 2227 additions and 2270 deletions

View File

@ -215,7 +215,10 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
if let StmtKind::With { items, body, .. } = &stmt.node {
let StmtKind::With { items, body, .. } = &stmt.node else {
unreachable!()
};
if items.len() == 1 && items[0].optional_vars.is_none() {
let item = &items[0];
@ -354,9 +357,6 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
// not parallel/sequential
gen_with(self, ctx, stmt)
} else {
unreachable!()
}
}
}

View File

@ -533,14 +533,13 @@ impl Nac3 {
let instance = {
let defs = top_level.definitions.read();
let mut definition = defs[def_id.0].write();
if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } =
&mut *definition
{
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } =
&mut *definition else {
unreachable!()
};
instance_to_symbol.insert(String::new(), "__modinit__".into());
instance_to_stmt[""].clone()
} else {
unreachable!()
}
};
let task = CodeGenTask {

View File

@ -311,7 +311,11 @@ impl InnerResolver {
unreachable!("none cannot be typeid")
} else if let Some(def_id) = self.pyid_to_def.read().get(&ty_id).copied() {
let def = defs[def_id.0].read();
if let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def {
let TopLevelDef::Class { object_id, type_vars, fields, methods, .. } = &*def else {
// only object is supported, functions are not supported
unreachable!("function type is not supported, should not be queried")
};
// do not handle type var param and concrete check here, and no subst
Ok(Ok({
let ty = TypeEnum::TObj {
@ -319,11 +323,11 @@ impl InnerResolver {
params: type_vars
.iter()
.map(|x| {
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) {
(*id, *x)
} else {
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*x) else {
unreachable!()
}
};
(*id, *x)
})
.collect(),
fields: {
@ -338,10 +342,6 @@ impl InnerResolver {
// here also false, later instantiation use python object to check compatible
(unifier.add_ty(ty), false)
}))
} else {
// only object is supported, functions are not supported
unreachable!("function type is not supported, should not be queried")
}
} else if ty_ty_id == self.primitive_ids.typevar {
let name: &str = pyty.getattr("__name__").unwrap().extract().unwrap();
let (constraint_types, is_const_generic) = {
@ -652,24 +652,24 @@ impl InnerResolver {
// if is `none`
let zelf_id: u64 = self.helper.id_fn.call1(py, (obj,))?.extract(py)?;
if zelf_id == self.primitive_ids.none {
if let TypeEnum::TObj { params, .. } =
unifier.get_ty_immutable(primitives.option).as_ref()
{
let ty_enum = unifier.get_ty_immutable(primitives.option);
let TypeEnum::TObj { params, .. } = ty_enum.as_ref() else {
unreachable!("must be tobj")
};
let var_map = params
.iter()
.map(|(id_var, ty)| {
if let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) {
let TypeEnum::TVar { id, range, name, loc, .. } = &*unifier.get_ty(*ty) else {
unreachable!()
};
assert_eq!(*id, *id_var);
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0)
} else {
unreachable!()
}
})
.collect::<HashMap<_, _>>();
return Ok(Ok(unifier.subst(primitives.option, &var_map).unwrap()))
}
unreachable!("must be tobj")
}
let ty = match self.get_obj_type(py, field_data, unifier, defs, primitives)? {
Ok(t) => t,
@ -688,14 +688,13 @@ impl InnerResolver {
let var_map = params
.iter()
.map(|(id_var, ty)| {
if let TypeEnum::TVar { id, range, name, loc, .. } =
&*unifier.get_ty(*ty)
{
let TypeEnum::TVar { id, range, name, loc, .. } =
&*unifier.get_ty(*ty) else {
unreachable!()
};
assert_eq!(*id, *id_var);
(*id, unifier.get_fresh_var_with_range(range, *name, *loc).0)
} else {
unreachable!()
}
})
.collect::<HashMap<_, _>>();
let mut instantiate_obj = || {
@ -900,7 +899,11 @@ impl InnerResolver {
Ok(Some(global.as_pointer_value().into()))
} else if ty_id == self.primitive_ids.tuple {
if let TypeEnum::TTuple { ty } = ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
let expected_ty_enum = ctx.unifier.get_ty_immutable(expected_ty);
let TypeEnum::TTuple { ty } = expected_ty_enum.as_ref() else {
unreachable!()
};
let tup_tys = ty.iter();
let elements: &PyTuple = obj.downcast()?;
assert_eq!(elements.len(), tup_tys.len());
@ -919,9 +922,6 @@ impl InnerResolver {
let val = val?.unwrap();
let val = ctx.ctx.const_struct(&val, false);
Ok(Some(val.into()))
} else {
unreachable!("must expect tuple type")
}
} else if ty_id == self.primitive_ids.option {
let option_val_ty = match ctx.unifier.get_ty_immutable(expected_ty).as_ref() {
TypeEnum::TObj { obj_id, params, .. }
@ -993,7 +993,8 @@ impl InnerResolver {
// should be classes
let definition =
top_level_defs.get(self.pyid_to_def.read().get(&ty_id).unwrap().0).unwrap().read();
if let TopLevelDef::Class { fields, .. } = &*definition {
let TopLevelDef::Class { fields, .. } = &*definition else { unreachable!() };
let values: Result<Option<Vec<_>>, _> = fields
.iter()
.map(|(name, ty, _)| {
@ -1012,9 +1013,6 @@ impl InnerResolver {
} else {
Ok(None)
}
} else {
unreachable!()
}
}
}
@ -1065,8 +1063,11 @@ impl InnerResolver {
impl SymbolResolver for Resolver {
fn get_default_param_value(&self, expr: &ast::Expr) -> Option<SymbolValue> {
match &expr.node {
ast::ExprKind::Name { id, .. } => {
let ast::ExprKind::Name { id, .. } = &expr.node else {
unreachable!("only for resolving names")
};
Python::with_gil(|py| -> PyResult<Option<SymbolValue>> {
let obj: &PyAny = self.0.module.extract(py)?;
let members: &PyDict = obj.getattr("__dict__").unwrap().downcast().unwrap();
@ -1081,11 +1082,7 @@ impl SymbolResolver for Resolver {
}
}
Ok(sym_value)
})
.unwrap()
}
_ => unreachable!("only for resolving names"),
}
}).unwrap()
}
fn get_symbol_type(

View File

@ -29,15 +29,21 @@ impl TimeFns for NowPinningTimeFns64 {
let now_hiptr =
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr");
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
};
if let (BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo)) = (
let (BasicValueEnum::IntValue(now_hi), BasicValueEnum::IntValue(now_lo)) = (
ctx.builder.build_load(now_hiptr, "now.hi"),
ctx.builder.build_load(now_loptr, "now.lo"),
) {
) else {
unreachable!()
};
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "");
let shifted_hi = ctx.builder.build_left_shift(
zext_hi,
@ -46,12 +52,6 @@ impl TimeFns for NowPinningTimeFns64 {
);
let zext_lo = ctx.builder.build_int_z_extend(now_lo, i64_type, "");
ctx.builder.build_or(shifted_hi, zext_lo, "now_mu").into()
} else {
unreachable!();
}
} else {
unreachable!();
}
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -59,7 +59,10 @@ impl TimeFns for NowPinningTimeFns64 {
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
if let BasicValueEnum::IntValue(time) = t {
let BasicValueEnum::IntValue(time) = t else {
unreachable!()
};
let time_hi = ctx.builder.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, "time.hi"),
i32_type,
@ -76,7 +79,10 @@ impl TimeFns for NowPinningTimeFns64 {
"now.hi.addr",
);
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
};
@ -88,12 +94,6 @@ impl TimeFns for NowPinningTimeFns64 {
.build_store(now_loptr, time_lo)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
} else {
unreachable!();
}
} else {
unreachable!();
}
}
fn emit_delay_mu<'ctx>(
@ -110,12 +110,15 @@ impl TimeFns for NowPinningTimeFns64 {
let now_hiptr =
ctx.builder.build_bitcast(now, i32_type.ptr_type(AddressSpace::default()), "now.hi.addr");
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(2, false)], "now.lo.addr")
};
if let (
let (
BasicValueEnum::IntValue(now_hi),
BasicValueEnum::IntValue(now_lo),
BasicValueEnum::IntValue(dt),
@ -123,7 +126,10 @@ impl TimeFns for NowPinningTimeFns64 {
ctx.builder.build_load(now_hiptr, "now.hi"),
ctx.builder.build_load(now_loptr, "now.lo"),
dt,
) {
) else {
unreachable!()
};
let zext_hi = ctx.builder.build_int_z_extend(now_hi, i64_type, "");
let shifted_hi = ctx.builder.build_left_shift(
zext_hi,
@ -154,12 +160,6 @@ impl TimeFns for NowPinningTimeFns64 {
.build_store(now_loptr, time_lo)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
} else {
unreachable!();
}
} else {
unreachable!();
};
}
}
@ -176,14 +176,14 @@ impl TimeFns for NowPinningTimeFns {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "now");
if let BasicValueEnum::IntValue(now_raw) = now_raw {
let BasicValueEnum::IntValue(now_raw) = now_raw else {
unreachable!()
};
let i64_32 = i64_type.const_int(32, false);
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo");
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi");
ctx.builder.build_or(now_lo, now_hi, "now_mu").into()
} else {
unreachable!();
}
}
fn emit_at_mu<'ctx>(&self, ctx: &mut CodeGenContext<'ctx, '_>, t: BasicValueEnum<'ctx>) {
@ -191,7 +191,10 @@ impl TimeFns for NowPinningTimeFns {
let i64_type = ctx.ctx.i64_type();
let i64_32 = i64_type.const_int(32, false);
if let BasicValueEnum::IntValue(time) = t {
let BasicValueEnum::IntValue(time) = t else {
unreachable!()
};
let time_hi = ctx.builder.build_int_truncate(
ctx.builder.build_right_shift(time, i64_32, false, ""),
i32_type,
@ -208,7 +211,10 @@ impl TimeFns for NowPinningTimeFns {
"now.hi.addr",
);
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
};
@ -220,12 +226,6 @@ impl TimeFns for NowPinningTimeFns {
.build_store(now_loptr, time_lo)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
} else {
unreachable!();
}
} else {
unreachable!();
}
}
fn emit_delay_mu<'ctx>(
@ -242,7 +242,10 @@ impl TimeFns for NowPinningTimeFns {
.unwrap_or_else(|| ctx.module.add_global(i64_type, None, "now"));
let now_raw = ctx.builder.build_load(now.as_pointer_value(), "");
if let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, dt) {
let (BasicValueEnum::IntValue(now_raw), BasicValueEnum::IntValue(dt)) = (now_raw, dt) else {
unreachable!()
};
let now_lo = ctx.builder.build_left_shift(now_raw, i64_32, "now.lo");
let now_hi = ctx.builder.build_right_shift(now_raw, i64_32, false, "now.hi");
let now_val = ctx.builder.build_or(now_lo, now_hi, "now_val");
@ -259,7 +262,10 @@ impl TimeFns for NowPinningTimeFns {
"now.hi.addr",
);
if let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr {
let BasicValueEnum::PointerValue(now_hiptr) = now_hiptr else {
unreachable!()
};
let now_loptr = unsafe {
ctx.builder.build_gep(now_hiptr, &[i32_type.const_int(1, false)], "now.lo.addr")
};
@ -271,12 +277,6 @@ impl TimeFns for NowPinningTimeFns {
.build_store(now_loptr, time_lo)
.set_atomic_ordering(AtomicOrdering::SequentiallyConsistent)
.unwrap();
} else {
unreachable!();
}
} else {
unreachable!();
}
}
}

View File

@ -39,11 +39,10 @@ pub fn get_subst_key(
) -> String {
let mut vars = obj
.map(|ty| {
if let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) {
params.clone()
} else {
let TypeEnum::TObj { params, .. } = &*unifier.get_ty(ty) else {
unreachable!()
}
};
params.clone()
})
.unwrap_or_default();
vars.extend(fun_vars.iter());
@ -224,7 +223,7 @@ impl<'ctx, 'a> CodeGenContext<'ctx, 'a> {
{
self.ctx.i64_type()
} else {
unreachable!();
unreachable!()
};
Some(ty.const_int(*val as u64, false).into())
}
@ -599,8 +598,10 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>(
def: &TopLevelDef,
params: Vec<(Option<StrRef>, ValueEnum<'ctx>)>,
) -> Result<BasicValueEnum<'ctx>, String> {
match def {
TopLevelDef::Class { methods, .. } => {
let TopLevelDef::Class { methods, .. } = def else {
unreachable!()
};
// TODO: what about other fields that require alloca?
let fun_id = methods.iter().find(|method| method.0 == "__init__".into()).map(|method| method.2);
let ty = ctx.get_llvm_type(generator, signature.ret).into_pointer_type();
@ -619,9 +620,6 @@ pub fn gen_constructor<'ctx, 'a, G: CodeGenerator>(
}
Ok(zelf)
}
TopLevelDef::Function { .. } => unreachable!(),
}
}
/// See [`CodeGenerator::gen_func_instance`].
pub fn gen_func_instance<'ctx>(
@ -630,14 +628,14 @@ pub fn gen_func_instance<'ctx>(
fun: (&FunSignature, &mut TopLevelDef, String),
id: usize,
) -> Result<String, String> {
if let (
let (
sign,
TopLevelDef::Function {
name, instance_to_symbol, instance_to_stmt, var_id, resolver, ..
},
key,
) = fun
{
) = fun else { unreachable!() };
if let Some(sym) = instance_to_symbol.get(&key) {
return Ok(sym.clone());
}
@ -672,14 +670,14 @@ pub fn gen_func_instance<'ctx>(
if let Some(obj) = &obj {
let zelf =
store.from_unifier_type(&mut ctx.unifier, &ctx.primitives, obj.0, &mut cache);
if let ConcreteTypeEnum::TFunc { args, .. } = &mut signature {
let ConcreteTypeEnum::TFunc { args, .. } = &mut signature else {
unreachable!()
};
args.insert(
0,
ConcreteFuncArg { name: "self".into(), ty: zelf, default_value: None },
);
} else {
unreachable!()
}
}
let signature = store.add_cty(signature);
@ -695,9 +693,6 @@ pub fn gen_func_instance<'ctx>(
id,
});
Ok(symbol)
} else {
unreachable!()
}
}
/// See [`CodeGenerator::gen_call`].
@ -946,7 +941,10 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
ctx: &mut CodeGenContext<'ctx, '_>,
expr: &Expr<Option<Type>>,
) -> Result<Option<BasicValueEnum<'ctx>>, String> {
if let ExprKind::ListComp { elt, generators } = &expr.node {
let ExprKind::ListComp { elt, generators } = &expr.node else {
unreachable!()
};
let current = ctx.builder.get_insert_block().unwrap().get_parent().unwrap();
let init_bb = ctx.ctx.append_basic_block(current, "listcomp.init");
@ -1109,9 +1107,6 @@ pub fn gen_comprehension<'ctx, G: CodeGenerator>(
emit_cont_bb(ctx);
Ok(Some(list.into()))
} else {
unreachable!()
}
}
/// Generates LLVM IR for a [binary operator expression][expr].
@ -1170,9 +1165,11 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
.unwrap_left();
Ok(Some(res.into()))
} else {
let (op_name, id) = if let TypeEnum::TObj { fields, obj_id, .. } =
ctx.unifier.get_ty_immutable(left.custom.unwrap()).as_ref()
{
let left_ty_enum = ctx.unifier.get_ty_immutable(left.custom.unwrap());
let TypeEnum::TObj { fields, obj_id, .. } = left_ty_enum.as_ref() else {
unreachable!("must be tobj")
};
let (op_name, id) = {
let (binop_name, binop_assign_name) = (
binop_name(op).into(),
binop_assign_name(op).into()
@ -1183,34 +1180,33 @@ pub fn gen_binop_expr<'ctx, G: CodeGenerator>(
} else {
(binop_name, *obj_id)
}
} else {
unreachable!("must be tobj")
};
let signature = match ctx.calls.get(&loc.into()) {
Some(call) => ctx.unifier.get_call_signature(*call).unwrap(),
None => {
if let TypeEnum::TObj { fields, .. } =
ctx.unifier.get_ty_immutable(left.custom.unwrap()).as_ref()
{
let fn_ty = fields.get(&op_name).unwrap().0;
if let TypeEnum::TFunc(sig) = ctx.unifier.get_ty_immutable(fn_ty).as_ref() {
sig.clone()
} else {
unreachable!("must be func sig")
}
} else {
let left_enum_ty = ctx.unifier.get_ty_immutable(left.custom.unwrap());
let TypeEnum::TObj { fields, .. } = left_enum_ty.as_ref() else {
unreachable!("must be tobj")
}
};
let fn_ty = fields.get(&op_name).unwrap().0;
let fn_ty_enum = ctx.unifier.get_ty_immutable(fn_ty);
let TypeEnum::TFunc(sig) = fn_ty_enum.as_ref() else {
unreachable!()
};
sig.clone()
},
};
let fun_id = {
let defs = ctx.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read();
if let TopLevelDef::Class { methods, .. } = &*obj_def {
methods.iter().find(|method| method.0 == op_name).unwrap().2
} else {
let TopLevelDef::Class { methods, .. } = &*obj_def else {
unreachable!()
}
};
methods.iter().find(|method| method.0 == op_name).unwrap().2
};
generator
.gen_call(
@ -1290,11 +1286,11 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}
let ty = if elements.is_empty() {
if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(expr.custom.unwrap()) {
ctx.get_llvm_type(generator, *ty)
} else {
let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(expr.custom.unwrap()) else {
unreachable!()
}
};
ctx.get_llvm_type(generator, *ty)
} else {
elements[0].get_type()
};
@ -1636,11 +1632,11 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
ctx.unifier.get_call_signature(*call).unwrap()
} else {
let ty = func.custom.unwrap();
if let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) {
sign.clone()
} else {
let TypeEnum::TFunc(sign) = &*ctx.unifier.get_ty(ty) else {
unreachable!()
}
};
sign.clone()
};
let func = func.as_ref();
match &func.node {
@ -1669,11 +1665,11 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
let fun_id = {
let defs = ctx.top_level.definitions.read();
let obj_def = defs.get(id.0).unwrap().read();
if let TopLevelDef::Class { methods, .. } = &*obj_def {
methods.iter().find(|method| method.0 == *attr).unwrap().2
} else {
let TopLevelDef::Class { methods, .. } = &*obj_def else {
unreachable!()
}
};
methods.iter().find(|method| method.0 == *attr).unwrap().2
};
// directly generate code for option.unwrap
// since it needs to return static value to optimize for kernel invariant
@ -1755,7 +1751,8 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}
}
ExprKind::Subscript { value, slice, .. } => {
if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(value.custom.unwrap()) {
match &*ctx.unifier.get_ty(value.custom.unwrap()) {
TypeEnum::TList { ty } => {
let v = if let Some(v) = generator.gen_expr(ctx, value)? {
v.to_basic_value_enum(ctx, generator, value.custom.unwrap())?.into_pointer_value()
} else {
@ -1848,7 +1845,8 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
);
ctx.build_gep_and_load(arr_ptr, &[index], None).into()
}
} else if let TypeEnum::TTuple { .. } = &*ctx.unifier.get_ty(value.custom.unwrap()) {
}
TypeEnum::TTuple { .. } => {
let index: u32 =
if let ExprKind::Constant { value: Constant::Int(v), .. } = &slice.node {
(*v).try_into().unwrap()
@ -1872,8 +1870,8 @@ pub fn gen_expr<'ctx, G: CodeGenerator>(
}
None => return Ok(None),
}
} else {
unreachable!("should not be other subscriptable types after type check");
}
_ => unreachable!("should not be other subscriptable types after type check"),
}
},
ExprKind::ListComp { .. } => {

View File

@ -451,11 +451,12 @@ fn get_llvm_type<'ctx>(
// a struct with fields in the order of declaration
let top_level_defs = top_level.definitions.read();
let definition = top_level_defs.get(obj_id.0).unwrap();
let ty = if let TopLevelDef::Class { fields: fields_list, .. } =
&*definition.read()
{
let TopLevelDef::Class { fields: fields_list, .. } = &*definition.read() else {
unreachable!()
};
let name = unifier.stringify(ty);
if let Some(t) = module.get_struct_type(&name) {
let ty = if let Some(t) = module.get_struct_type(&name) {
t.ptr_type(AddressSpace::default()).into()
} else {
let struct_type = ctx.opaque_struct_type(&name);
@ -480,11 +481,8 @@ fn get_llvm_type<'ctx>(
.collect_vec();
struct_type.set_body(&fields, false);
struct_type.ptr_type(AddressSpace::default()).into()
}
} else {
unreachable!()
};
return ty;
return ty
}
TTuple { ty } => {
// a struct with fields in the order present in the tuple
@ -661,10 +659,12 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
// NOTE: special handling of option cannot use this type cache since it contains type var,
// handled inside get_llvm_type instead
let (args, ret) = if let ConcreteTypeEnum::TFunc { args, ret, .. } =
task.store.get(task.signature)
{
(
let ConcreteTypeEnum::TFunc { args, ret, .. } =
task.store.get(task.signature) else {
unreachable!()
};
let (args, ret) = (
args.iter()
.map(|arg| FuncArg {
name: arg.name,
@ -673,10 +673,7 @@ pub fn gen_func_impl<'ctx, G: CodeGenerator, F: FnOnce(&mut G, &mut CodeGenConte
})
.collect_vec(),
task.store.to_unifier_type(&mut unifier, &primitives, *ret, &mut cache),
)
} else {
unreachable!()
};
);
let ret_type = if unifier.unioned(ret, primitives.none) {
None
} else {

View File

@ -171,9 +171,11 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
) -> Result<(), String> {
match &target.node {
ExprKind::Tuple { elts, .. } => {
if let BasicValueEnum::StructValue(v) =
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
{
let BasicValueEnum::StructValue(v) =
value.to_basic_value_enum(ctx, generator, target.custom.unwrap())? else {
unreachable!()
};
for (i, elt) in elts.iter().enumerate() {
let v = ctx
.builder
@ -181,14 +183,14 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
.unwrap();
generator.gen_assign(ctx, elt, v.into())?;
}
} else {
unreachable!()
}
}
ExprKind::Subscript { value: ls, slice, .. }
if matches!(&slice.node, ExprKind::Slice { .. }) =>
{
if let ExprKind::Slice { lower, upper, step } = &slice.node {
let ExprKind::Slice { lower, upper, step } = &slice.node else {
unreachable!()
};
let ls = generator
.gen_expr(ctx, ls)?
.unwrap()
@ -201,19 +203,15 @@ pub fn gen_assign<'ctx, G: CodeGenerator>(
let value = value
.to_basic_value_enum(ctx, generator, target.custom.unwrap())?
.into_pointer_value();
let ty =
if let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(target.custom.unwrap()) {
ctx.get_llvm_type(generator, *ty)
} else {
let TypeEnum::TList { ty } = &*ctx.unifier.get_ty(target.custom.unwrap()) else {
unreachable!()
};
let ty = ctx.get_llvm_type(generator, *ty);
let Some(src_ind) = handle_slice_indices(&None, &None, &None, ctx, generator, value)? else {
return Ok(())
};
list_slice_assignment(generator, ctx, ty, ls, (start, end, step), value, src_ind);
} else {
unreachable!()
}
}
_ => {
let name = if let ExprKind::Name { id, .. } = &target.node {
@ -245,7 +243,10 @@ pub fn gen_for<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
if let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node {
let StmtKind::For { iter, target, body, orelse, .. } = &stmt.node else {
unreachable!()
};
// var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch
let var_assignment = ctx.var_assignment.clone();
@ -394,9 +395,7 @@ pub fn gen_for<G: CodeGenerator>(
ctx.builder.position_at_end(cont_bb);
ctx.loop_target = loop_bb;
} else {
unreachable!()
}
Ok(())
}
@ -406,7 +405,10 @@ pub fn gen_while<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
if let StmtKind::While { test, body, orelse, .. } = &stmt.node {
let StmtKind::While { test, body, orelse, .. } = &stmt.node else {
unreachable!()
};
// var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch
let var_assignment = ctx.var_assignment.clone();
@ -432,11 +434,12 @@ pub fn gen_while<G: CodeGenerator>(
return Ok(())
};
if let BasicValueEnum::IntValue(test) = test {
ctx.builder.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb);
} else {
let BasicValueEnum::IntValue(test) = test else {
unreachable!()
};
ctx.builder.build_conditional_branch(generator.bool_to_i1(ctx, test), body_bb, orelse_bb);
ctx.builder.position_at_end(body_bb);
generator.gen_block(ctx, body.iter())?;
for (k, (_, _, counter)) in &var_assignment {
@ -463,9 +466,7 @@ pub fn gen_while<G: CodeGenerator>(
}
ctx.builder.position_at_end(cont_bb);
ctx.loop_target = loop_bb;
} else {
unreachable!()
}
Ok(())
}
@ -475,7 +476,10 @@ pub fn gen_if<G: CodeGenerator>(
ctx: &mut CodeGenContext<'_, '_>,
stmt: &Stmt<Option<Type>>,
) -> Result<(), String> {
if let StmtKind::If { test, body, orelse, .. } = &stmt.node {
let StmtKind::If { test, body, orelse, .. } = &stmt.node else {
unreachable!()
};
// var_assignment static values may be changed in another branch
// if so, remove the static value as it may not be correct in this branch
let var_assignment = ctx.var_assignment.clone();
@ -533,9 +537,7 @@ pub fn gen_if<G: CodeGenerator>(
*static_val = None;
}
}
} else {
unreachable!()
}
Ok(())
}
@ -595,16 +597,16 @@ pub fn exn_constructor<'ctx>(
let int32 = ctx.ctx.i32_type();
let zero = int32.const_zero();
let zelf_id = {
if let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(zelf_ty) {
obj_id.0
} else {
let TypeEnum::TObj { obj_id, .. } = &*ctx.unifier.get_ty(zelf_ty) else {
unreachable!()
}
};
obj_id.0
};
let defs = ctx.top_level.definitions.read();
let def = defs[zelf_id].read();
let zelf_name =
if let TopLevelDef::Class { name, .. } = &*def { *name } else { unreachable!() };
let TopLevelDef::Class { name: zelf_name, .. } = &*def else {
unreachable!()
};
let exception_name = format!("{}:{}", ctx.resolver.get_exception_id(zelf_id), zelf_name);
unsafe {
let id_ptr = ctx.builder.build_in_bounds_gep(zelf, &[zero, zero], "exn.id");
@ -715,7 +717,10 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
ctx: &mut CodeGenContext<'ctx, 'a>,
target: &Stmt<Option<Type>>,
) -> Result<(), String> {
if let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node {
let StmtKind::Try { body, handlers, orelse, finalbody, .. } = &target.node else {
unreachable!()
};
// if we need to generate anything related to exception, we must have personality defined
let personality_symbol = ctx.top_level.personality_symbol.as_ref().unwrap();
let personality = ctx.module.get_function(personality_symbol).unwrap_or_else(|| {
@ -1025,10 +1030,8 @@ pub fn gen_try<'ctx, 'a, G: CodeGenerator>(
}
ctx.builder.position_at_end(tail);
}
Ok(())
} else {
unreachable!()
}
}
/// See [`CodeGenerator::gen_with`].

View File

@ -528,11 +528,11 @@ impl dyn SymbolResolver + Send + Sync {
unifier.internal_stringify(
ty,
&mut |id| {
if let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() {
name.to_string()
} else {
let TopLevelDef::Class { name, .. } = &*top_level_defs[id].read() else {
unreachable!("expected class definition")
}
};
name.to_string()
},
&mut |id| format!("typevar{id}"),
&mut None,

View File

@ -421,11 +421,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
generator,
expect_ty,
)?;
if let BasicValueEnum::PointerValue(ptr) = obj_val {
Ok(Some(ctx.builder.build_is_not_null(ptr, "is_some").into()))
} else {
let BasicValueEnum::PointerValue(ptr) = obj_val else {
unreachable!("option must be ptr")
}
};
Ok(Some(ctx.builder.build_is_not_null(ptr, "is_some").into()))
},
)))),
loc: None,
@ -446,11 +446,11 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
generator,
expect_ty,
)?;
if let BasicValueEnum::PointerValue(ptr) = obj_val {
Ok(Some(ctx.builder.build_is_null(ptr, "is_none").into()))
} else {
let BasicValueEnum::PointerValue(ptr) = obj_val else {
unreachable!("option must be ptr")
}
};
Ok(Some(ctx.builder.build_is_null(ptr, "is_none").into()))
},
)))),
loc: None,
@ -686,7 +686,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
val
} else {
unreachable!();
unreachable!()
};
Ok(Some(res))
},
@ -762,7 +762,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
val
} else {
unreachable!();
unreachable!()
};
Ok(Some(res))
},
@ -1361,7 +1361,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
} else if is_type(m_ty, n_ty) && is_type(n_ty, float) {
("llvm.minnum.f64", llvm_f64)
} else {
unreachable!();
unreachable!()
};
let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| {
let fn_type = arg_ty.fn_type(&[arg_ty.into(), arg_ty.into()], false);
@ -1423,7 +1423,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
} else if is_type(m_ty, n_ty) && is_type(n_ty, float) {
("llvm.maxnum.f64", llvm_f64)
} else {
unreachable!();
unreachable!()
};
let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| {
let fn_type = arg_ty.fn_type(&[arg_ty.into(), arg_ty.into()], false);
@ -1480,7 +1480,7 @@ pub fn get_builtins(primitives: &mut (PrimitiveStore, Unifier)) -> BuiltinInfo {
is_float = true;
("llvm.fabs.f64", llvm_f64)
} else {
unreachable!();
unreachable!()
};
let intrinsic = ctx.module.get_function(fun_name).unwrap_or_else(|| {
let fn_type = if is_float {

View File

@ -300,12 +300,12 @@ impl TopLevelComposer {
// get the methods into the top level class_def
for (name, _, id, ty, ..) in &class_method_name_def_ids {
let mut class_def = class_def_ast.0.write();
if let TopLevelDef::Class { methods, .. } = &mut *class_def {
let TopLevelDef::Class { methods, .. } = &mut *class_def else {
unreachable!()
};
methods.push((*name, *ty, *id));
self.method_class.insert(*id, DefinitionId(class_def_id));
} else {
unreachable!()
}
}
// now class_def_ast and class_method_def_ast_ids are ok, put them into actual def list in correct order
self.definition_ast_list.push(class_def_ast);
@ -385,14 +385,13 @@ impl TopLevelComposer {
let mut class_def = class_def.write();
let (class_bases_ast, class_def_type_vars, class_resolver) = {
if let TopLevelDef::Class { type_vars, resolver, .. } = &mut *class_def {
if let Some(ast::Located {
let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. }, ..
}) = class_ast
{
}) = class_ast else {
unreachable!()
};
(bases, type_vars, resolver)
} else {
unreachable!("must be both class")
}
} else {
return Ok(());
}
@ -515,15 +514,14 @@ impl TopLevelComposer {
ancestors, resolver, object_id, type_vars, ..
} = &mut *class_def
{
if let Some(ast::Located {
let Some(ast::Located {
node: ast::StmtKind::ClassDef { bases, .. },
..
}) = class_ast
{
}) = class_ast else {
unreachable!()
};
(object_id, bases, ancestors, resolver, type_vars)
} else {
unreachable!("must be both class")
}
} else {
return Ok(());
}
@ -659,7 +657,10 @@ impl TopLevelComposer {
.any(|ann| matches!(ann, TypeAnnotation::CustomClass { id, .. } if id.0 == 7))
{
// if inherited from Exception, the body should be a pass
if let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node {
let ast::StmtKind::ClassDef { body, .. } = &class_ast.as_ref().unwrap().node else {
unreachable!()
};
for stmt in body {
if matches!(
stmt.node,
@ -670,21 +671,17 @@ impl TopLevelComposer {
]))
}
}
} else {
unreachable!()
}
}
}
// deal with ancestor of Exception object
if let TopLevelDef::Class { name, ancestors, object_id, .. } =
&mut *self.definition_ast_list[7].0.write()
{
let TopLevelDef::Class { name, ancestors, object_id, .. } =
&mut *self.definition_ast_list[7].0.write() else {
unreachable!()
};
assert_eq!(*name, "Exception".into());
ancestors.push(make_self_type_annotation(&[], *object_id));
} else {
unreachable!();
}
Ok(())
}
@ -775,7 +772,10 @@ impl TopLevelComposer {
}
}
for ty in subst_list.unwrap() {
if let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) {
let TypeEnum::TObj { obj_id, params, fields } = &*unifier.get_ty(ty) else {
unreachable!()
};
let mut new_fields = HashMap::new();
let mut need_subst = false;
for (name, (ty, mutable)) in fields {
@ -793,9 +793,6 @@ impl TopLevelComposer {
errors.insert(e.to_display(unifier).to_string());
}
}
} else {
unreachable!()
}
}
if !errors.is_empty() {
return Err(errors)
@ -833,14 +830,19 @@ impl TopLevelComposer {
return Ok(());
};
if let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } =
function_def
{
let TopLevelDef::Function { signature: dummy_ty, resolver, var_id, .. } = function_def else {
// not top level function def, skip
return Ok(());
};
if matches!(unifier.get_ty(*dummy_ty).as_ref(), TypeEnum::TFunc(_)) {
// already have a function type, is class method, skip
return Ok(());
}
if let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node {
let ast::StmtKind::FunctionDef { args, returns, .. } = &function_ast.node else {
unreachable!("must be both function");
};
let resolver = resolver.as_ref();
let resolver = resolver.unwrap();
let resolver = &**resolver;
@ -853,15 +855,13 @@ impl TopLevelComposer {
if !defined_parameter_name.insert(x.node.arg)
|| keyword_list.contains(&x.node.arg)
{
return Err(HashSet::from([
format!(
return Err(HashSet::from([format!(
"top level function must have unique parameter names \
and names should not be the same as the keywords (at {})",
x.location
),
]))
}
}
}}
let arg_with_default: Vec<(
&ast::Located<ast::ArgData<()>>,
@ -911,11 +911,11 @@ impl TopLevelComposer {
get_type_var_contained_in_type_annotation(&type_annotation)
.into_iter()
.map(|x| -> Result<(u32, Type), HashSet<String>> {
if let TypeAnnotation::TypeVar(ty) = x {
Ok((Self::get_var_id(ty, unifier)?, ty))
} else {
let TypeAnnotation::TypeVar(ty) = x else {
unreachable!("must be type var annotation kind")
}
};
Ok((Self::get_var_id(ty, unifier)?, ty))
})
.collect::<Result<Vec<_>, _>>()?;
for (id, ty) in type_vars_within {
@ -947,8 +947,8 @@ impl TopLevelComposer {
primitives_store,
unifier,
)
.map_err(|err| HashSet::from([
format!("{} (at {})", err, x.location),
.map_err(
|err| HashSet::from([format!("{} (at {})", err, x.location),
]))?;
v
}),
@ -979,11 +979,11 @@ impl TopLevelComposer {
get_type_var_contained_in_type_annotation(&return_ty_annotation)
.into_iter()
.map(|x| -> Result<(u32, Type), HashSet<String>> {
if let TypeAnnotation::TypeVar(ty) = x {
Ok((Self::get_var_id(ty, unifier)?, ty))
} else {
let TypeAnnotation::TypeVar(ty) = x else {
unreachable!("must be type var here")
}
};
Ok((Self::get_var_id(ty, unifier)?, ty))
})
.collect::<Result<Vec<_>, _>>()?;
for (id, ty) in type_vars_within {
@ -1023,13 +1023,6 @@ impl TopLevelComposer {
unifier.unify(*dummy_ty, function_ty).map_err(|e| HashSet::from([
e.at(Some(function_ast.location)).to_display(unifier).to_string(),
]))?;
} else {
unreachable!("must be both function");
}
} else {
// not top level function def, skip
return Ok(());
}
Ok(())
};
for (function_def, function_ast) in def_list.iter().skip(self.builtin_num) {
@ -1057,6 +1050,21 @@ impl TopLevelComposer {
) -> Result<(), HashSet<String>> {
let (keyword_list, core_config) = core_info;
let mut class_def = class_def.write();
let TopLevelDef::Class {
object_id,
ancestors,
fields,
methods,
resolver,
type_vars,
..
} = &mut *class_def else {
unreachable!("here must be toplevel class def");
};
let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast else {
unreachable!("here must be class def ast")
};
let (
class_id,
_class_name,
@ -1067,24 +1075,8 @@ impl TopLevelComposer {
class_methods_def,
class_type_vars_def,
class_resolver,
) = if let TopLevelDef::Class {
object_id,
ancestors,
fields,
methods,
resolver,
type_vars,
..
} = &mut *class_def
{
if let ast::StmtKind::ClassDef { name, bases, body, .. } = &class_ast {
(*object_id, *name, bases, body, ancestors, fields, methods, type_vars, resolver)
} else {
unreachable!("here must be class def ast");
}
} else {
unreachable!("here must be toplevel class def");
};
) = (*object_id, *name, bases, body, ancestors, fields, methods, type_vars, resolver);
let class_resolver = class_resolver.as_ref().unwrap();
let class_resolver = class_resolver.as_ref();
@ -1174,15 +1166,15 @@ impl TopLevelComposer {
get_type_var_contained_in_type_annotation(&type_ann);
// handle the class type var and the method type var
for type_var_within in type_vars_within {
if let TypeAnnotation::TypeVar(ty) = type_var_within {
let TypeAnnotation::TypeVar(ty) = type_var_within else {
unreachable!("must be type var annotation")
};
let id = Self::get_var_id(ty, unifier)?;
if let Some(prev_ty) = method_var_map.insert(id, ty) {
// if already in the list, make sure they are the same?
assert_eq!(prev_ty, ty);
}
} else {
unreachable!("must be type var annotation");
}
}
// finish handling type vars
let dummy_func_arg = FuncArg {
@ -1239,15 +1231,15 @@ impl TopLevelComposer {
get_type_var_contained_in_type_annotation(&annotation);
// handle the class type var and the method type var
for type_var_within in type_vars_within {
if let TypeAnnotation::TypeVar(ty) = type_var_within {
let TypeAnnotation::TypeVar(ty) = type_var_within else {
unreachable!("must be type var annotation");
};
let id = Self::get_var_id(ty, unifier)?;
if let Some(prev_ty) = method_var_map.insert(id, ty) {
// if already in the list, make sure they are the same?
assert_eq!(prev_ty, ty);
}
} else {
unreachable!("must be type var annotation");
}
}
let dummy_return_type = unifier.get_dummy_var().0;
type_var_to_concrete_def.insert(dummy_return_type, annotation.clone());
@ -1264,9 +1256,10 @@ impl TopLevelComposer {
}
};
if let TopLevelDef::Function { var_id, .. } =
&mut *temp_def_list.get(method_id.0).unwrap().write()
{
let TopLevelDef::Function { var_id, .. } =
&mut *temp_def_list.get(method_id.0).unwrap().write() else {
unreachable!()
};
var_id.extend_from_slice(method_var_map
.iter()
.filter_map(|(id, ty)| {
@ -1279,9 +1272,6 @@ impl TopLevelComposer {
.collect_vec()
.as_slice()
);
} else {
unreachable!()
}
let method_type = unifier.add_ty(TypeEnum::TFunc(FunSignature {
args: arg_types,
ret: ret_type,
@ -1336,7 +1326,10 @@ impl TopLevelComposer {
get_type_var_contained_in_type_annotation(&parsed_annotation);
// handle the class type var and the method type var
for type_var_within in type_vars_within {
if let TypeAnnotation::TypeVar(t) = type_var_within {
let TypeAnnotation::TypeVar(t) = type_var_within else {
unreachable!("must be type var annotation")
};
if !class_type_vars_def.contains(&t) {
return Err(HashSet::from([
format!(
@ -1346,9 +1339,6 @@ impl TopLevelComposer {
),
]))
}
} else {
unreachable!("must be type var annotation");
}
}
type_var_to_concrete_def.insert(dummy_field_type, parsed_annotation);
} else {
@ -1391,14 +1381,7 @@ impl TopLevelComposer {
_primitives: &PrimitiveStore,
type_var_to_concrete_def: &mut HashMap<Type, TypeAnnotation>,
) -> Result<(), HashSet<String>> {
let (
_class_id,
class_ancestor_def,
class_fields_def,
class_methods_def,
_class_type_vars_def,
_class_resolver,
) = if let TopLevelDef::Class {
let TopLevelDef::Class {
object_id,
ancestors,
fields,
@ -1406,20 +1389,31 @@ impl TopLevelComposer {
resolver,
type_vars,
..
} = class_def
{
(*object_id, ancestors, fields, methods, type_vars, resolver)
} else {
unreachable!("here must be class def ast");
} = class_def else {
unreachable!("here must be class def ast")
};
let (
_class_id,
class_ancestor_def,
class_fields_def,
class_methods_def,
_class_type_vars_def,
_class_resolver,
) = (*object_id, ancestors, fields, methods, type_vars, resolver);
// since when this function is called, the ancestors of the direct parent
// are supposed to be already handled, so we only need to deal with the direct parent
let base = class_ancestor_def.get(1).unwrap();
if let TypeAnnotation::CustomClass { id, params: _ } = base {
let TypeAnnotation::CustomClass { id, params: _ } = base else {
unreachable!("must be class type annotation")
};
let base = temp_def_list.get(id.0).unwrap();
let base = base.read();
if let TopLevelDef::Class { methods, fields, .. } = &*base {
let TopLevelDef::Class { methods, fields, .. } = &*base else {
unreachable!("must be top level class def")
};
// handle methods override
// since we need to maintain the order, create a new list
let mut new_child_methods: Vec<(StrRef, Type, DefinitionId)> = Vec::new();
@ -1441,10 +1435,8 @@ impl TopLevelComposer {
type_var_to_concrete_def,
);
if !ok {
return Err(HashSet::from([
format!(
"method {class_method_name} has same name as ancestors' method, but incompatible type"
),
return Err(HashSet::from([format!(
"method {class_method_name} has same name as ancestors' method, but incompatible type"),
]))
}
// mark it as added
@ -1480,10 +1472,8 @@ impl TopLevelComposer {
// find if there is a fields with the same name in the child class
for (class_field_name, ..) in &*class_fields_def {
if class_field_name == anc_field_name {
return Err(HashSet::from([
format!(
"field `{class_field_name}` has already declared in the ancestor classes"
),
return Err(HashSet::from([format!(
"field `{class_field_name}` has already declared in the ancestor classes"),
]))
}
}
@ -1496,12 +1486,6 @@ impl TopLevelComposer {
}
class_fields_def.drain(..);
class_fields_def.extend(new_child_fields);
} else {
unreachable!("must be top level class def")
}
} else {
unreachable!("must be class type annotation")
}
Ok(())
}
@ -1626,14 +1610,14 @@ impl TopLevelComposer {
for (name, func_sig, id) in methods {
if *name == init_str_id {
init_id = Some(*id);
if let TypeEnum::TFunc(FunSignature { args, vars, .. }) =
unifier.get_ty(*func_sig).as_ref()
{
let func_ty_enum = unifier.get_ty(*func_sig);
let TypeEnum::TFunc(FunSignature { args, vars, .. }) =
func_ty_enum.as_ref() else {
unreachable!("must be typeenum::tfunc")
};
constructor_args.extend_from_slice(args);
type_vars.extend(vars);
} else {
unreachable!("must be typeenum::tfunc")
}
}
}
(constructor_args, type_vars)
@ -1685,16 +1669,15 @@ impl TopLevelComposer {
}
for (i, signature, id) in constructors {
if let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write()
{
let TopLevelDef::Class { methods, .. } = &mut *self.definition_ast_list[i].0.write() else {
unreachable!()
};
methods.push((
init_str_id,
signature,
DefinitionId(self.definition_ast_list.len() + id),
));
} else {
unreachable!()
}
}
self.definition_ast_list.extend_from_slice(&definition_extension);
@ -1720,16 +1703,22 @@ impl TopLevelComposer {
..
} = &mut *function_def
{
if let TypeEnum::TFunc(FunSignature { args, ret, vars }) =
unifier.get_ty(*signature).as_ref()
{
let signature_ty_enum = unifier.get_ty(*signature);
let TypeEnum::TFunc(FunSignature { args, ret, vars }) =
signature_ty_enum.as_ref() else {
unreachable!("must be typeenum::tfunc")
};
let mut vars = vars.clone();
// None if is not class method
let uninst_self_type = {
if let Some(class_id) = method_class.get(&DefinitionId(id)) {
let class_def = definition_ast_list.get(class_id.0).unwrap();
let class_def = class_def.0.read();
if let TopLevelDef::Class { type_vars, .. } = &*class_def {
let TopLevelDef::Class { type_vars, .. } = &*class_def else {
unreachable!("must be class def")
};
let ty_ann = make_self_type_annotation(type_vars, *class_id);
let self_ty = get_type_from_type_annotation_kinds(
&def_list,
@ -1737,16 +1726,14 @@ impl TopLevelComposer {
&ty_ann,
&mut None
)?;
vars.extend(type_vars.iter().map(|ty|
if let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) {
(*id, *ty)
} else {
vars.extend(type_vars.iter().map(|ty| {
let TypeEnum::TVar { id, .. } = &*unifier.get_ty(*ty) else {
unreachable!()
};
(*id, *ty)
}));
Some((self_ty, type_vars.clone()))
} else {
unreachable!("must be class def")
}
} else {
None
}
@ -1759,14 +1746,13 @@ impl TopLevelComposer {
.values()
.map(|ty| {
unifier.get_instantiations(*ty).unwrap_or_else(|| {
if let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty)
{
let TypeEnum::TVar { name, loc, is_const_generic: false, .. } = &*unifier.get_ty(*ty) else {
unreachable!()
};
let rigid = unifier.get_fresh_rigid_var(*name, *loc).0;
no_ranges.push(rigid);
vec![rigid]
} else {
unreachable!()
}
})
})
.multi_cartesian_product()
@ -1859,10 +1845,10 @@ impl TopLevelComposer {
in_handler: false,
};
let fun_body =
if let ast::StmtKind::FunctionDef { body, decorator_list, .. } =
ast.clone().unwrap().node
{
let ast::StmtKind::FunctionDef { body, decorator_list, .. } =
ast.clone().unwrap().node else {
unreachable!("must be function def ast")
};
if !decorator_list.is_empty()
&& matches!(&decorator_list[0].node,
ast::ExprKind::Name{ id, .. } if id == &"extern".into())
@ -1877,10 +1863,8 @@ impl TopLevelComposer {
instance_to_symbol.insert(String::new(), simple_name.to_string());
continue;
}
body
} else {
unreachable!("must be function def ast")
}
let fun_body = body
.into_iter()
.map(|b| inferencer.fold_stmt(b))
.collect::<Result<Vec<_>, _>>()?;
@ -1908,29 +1892,25 @@ impl TopLevelComposer {
} else {
let base_repr = inferencer.unifier.stringify(*base);
let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(HashSet::from([
format!(
"Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"
),
return Err(HashSet::from([format!(
"Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"),
]))
}
};
let subtype_entry = defs[subtype_id.0].read();
if let TopLevelDef::Class { ancestors, .. } = &*subtype_entry {
let TopLevelDef::Class { ancestors, .. } = &*subtype_entry else {
unreachable!()
};
let m = ancestors.iter()
.find(|kind| matches!(kind, TypeAnnotation::CustomClass { id, .. } if *id == base_id));
if m.is_none() {
let base_repr = inferencer.unifier.stringify(*base);
let subtype_repr = inferencer.unifier.stringify(*subtype);
return Err(HashSet::from([
format!(
"Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"
),
return Err(HashSet::from([format!(
"Expected a subtype of {base_repr}, but got {subtype_repr} (at {loc})"),
]))
}
} else {
unreachable!();
}
}
}
if !unifier.unioned(inst_ret, primitives_ty.none) && !returned {
@ -1938,19 +1918,15 @@ impl TopLevelComposer {
let ret_str = unifier.internal_stringify(
inst_ret,
&mut |id| {
if let TopLevelDef::Class { name, .. } =
&*def_ast_list[id].0.read()
{
let TopLevelDef::Class { name, .. } = &*def_ast_list[id].0.read()
else { unreachable!("must be class id here") };
name.to_string()
} else {
unreachable!("must be class id here")
}
},
&mut |id| format!("typevar{id}"),
&mut None,
);
return Err(HashSet::from([
format!(
return Err(HashSet::from([format!(
"expected return type of `{}` in function `{}` (at {})",
ret_str,
name,
@ -1969,10 +1945,8 @@ impl TopLevelComposer {
},
);
}
} else {
unreachable!("must be typeenum::tfunc")
}
}
Ok(())
};
for (id, (def, ast)) in self.definition_ast_list.iter().enumerate().skip(self.builtin_num) {

View File

@ -233,11 +233,11 @@ impl TopLevelComposer {
};
// check cycle
let no_cycle = result.iter().all(|x| {
if let TypeAnnotation::CustomClass { id, .. } = x {
id.0 != p_id.0
} else {
let TypeAnnotation::CustomClass { id, .. } = x else {
unreachable!("must be class kind annotation")
}
};
id.0 != p_id.0
});
if no_cycle {
result.push(p);
@ -260,15 +260,15 @@ impl TopLevelComposer {
};
let child_def = temp_def_list.get(child_id.0).unwrap();
let child_def = child_def.read();
if let TopLevelDef::Class { ancestors, .. } = &*child_def {
let TopLevelDef::Class { ancestors, .. } = &*child_def else {
unreachable!("child must be top level class def")
};
if ancestors.is_empty() {
None
} else {
Some(ancestors[0].clone())
}
} else {
unreachable!("child must be top level class def")
}
}
/// get the `var_id` of a given `TVar` type
@ -292,11 +292,13 @@ impl TopLevelComposer {
let this = this.as_ref();
let other = unifier.get_ty(other);
let other = other.as_ref();
if let (
let (
TypeEnum::TFunc(FunSignature { args: this_args, ret: this_ret, .. }),
TypeEnum::TFunc(FunSignature { args: other_args, ret: other_ret, .. }),
) = (this, other)
{
) = (this, other) else {
unreachable!("this function must be called with function type")
};
// check args
let args_ok = this_args
.iter()
@ -322,9 +324,6 @@ impl TopLevelComposer {
// return
args_ok && ret_ok
} else {
unreachable!("this function must be called with function type")
}
}
pub fn check_overload_field_type(

View File

@ -163,11 +163,11 @@ pub fn parse_ast_to_type_annotation_kinds<T>(
let type_vars = {
let def_read = top_level_defs[obj_id.0].try_read();
if let Some(def_read) = def_read {
if let TopLevelDef::Class { type_vars, .. } = &*def_read {
type_vars.clone()
} else {
let TopLevelDef::Class { type_vars, .. } = &*def_read else {
unreachable!("must be class here")
}
};
type_vars.clone()
} else {
locked.get(&obj_id).unwrap().clone()
}
@ -497,13 +497,11 @@ pub fn get_type_from_type_annotation_kinds(
TypeAnnotation::Primitive(ty) | TypeAnnotation::TypeVar(ty) => Ok(*ty),
TypeAnnotation::Constant { ty, value, .. } => {
let ty_enum = unifier.get_ty(*ty);
let (ty, loc) = match &*ty_enum {
TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } => {
(ntv_underlying_ty[0], loc)
}
_ => unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name()),
let TypeEnum::TVar { range: ntv_underlying_ty, loc, is_const_generic: true, .. } = &*ty_enum else {
unreachable!("{} ({})", unifier.stringify(*ty), ty_enum.get_type_name());
};
let ty = ntv_underlying_ty[0];
let var = unifier.get_fresh_constant(value.clone(), ty, *loc);
Ok(var)
}
@ -596,15 +594,14 @@ pub fn check_overload_type_annotation_compatible(
let a = &*a;
let b = unifier.get_ty(*b);
let b = &*b;
if let (
let (
TypeEnum::TVar { id: a, fields: None, .. },
TypeEnum::TVar { id: b, fields: None, .. },
) = (a, b)
{
a == b
} else {
) = (a, b) else {
unreachable!("must be type var")
}
};
a == b
}
(TypeAnnotation::Virtual(a), TypeAnnotation::Virtual(b))
| (TypeAnnotation::List(a), TypeAnnotation::List(b)) => {

View File

@ -241,7 +241,10 @@ impl<'a> Fold<()> for Inferencer<'a> {
let targets: Result<Vec<_>, _> = targets
.into_iter()
.map(|target| {
if let ExprKind::Name { id, ctx } = target.node {
let ExprKind::Name { id, ctx } = target.node else {
unreachable!()
};
self.defined_identifiers.insert(id);
let target_ty = if let Some(ty) = self.variable_mapping.get(&id)
{
@ -267,9 +270,6 @@ impl<'a> Fold<()> for Inferencer<'a> {
node: ExprKind::Name { id, ctx },
custom: Some(target_ty),
})
} else {
unreachable!()
}
})
.collect();
let loc = node.location;
@ -465,12 +465,12 @@ impl<'a> Fold<()> for Inferencer<'a> {
let var_map = params
.iter()
.map(|(id_var, ty)| {
if let TypeEnum::TVar { id, range, name, loc, .. } = &*self.unifier.get_ty(*ty) {
let TypeEnum::TVar { id, range, name, loc, .. } = &*self.unifier.get_ty(*ty) else {
unreachable!()
};
assert_eq!(*id, *id_var);
(*id, self.unifier.get_fresh_var_with_range(range, *name, *loc).0)
} else {
unreachable!()
}
})
.collect::<HashMap<_, _>>();
Some(self.unifier.subst(self.primitives.option, &var_map).unwrap())

View File

@ -499,12 +499,9 @@ impl Unifier {
let instantiated = self.instantiate_fun(b, signature);
let r = self.get_ty(instantiated);
let r = r.as_ref();
let signature;
if let TypeEnum::TFunc(s) = r {
signature = s;
} else {
unreachable!();
}
let TypeEnum::TFunc(signature) = r else {
unreachable!()
};
// we check to make sure that all required arguments (those without default
// arguments) are provided, and do not provide the same argument twice.
let mut required = required.to_vec();
@ -940,13 +937,12 @@ impl Unifier {
top_level.as_ref().map_or_else(
|| format!("{id}"),
|top_level| {
if let TopLevelDef::Class { name, .. } =
&*top_level.definitions.read()[id].read()
{
name.to_string()
} else {
let top_level_def = &top_level.definitions.read()[id];
let TopLevelDef::Class { name, .. } = &*top_level_def.read() else {
unreachable!("expected class definition")
}
};
name.to_string()
},
)
},

View File

@ -339,23 +339,21 @@ fn test_recursive_subst() {
let int = *env.type_mapping.get("int").unwrap();
let foo_id = *env.type_mapping.get("Foo").unwrap();
let foo_ty = env.unifier.get_ty(foo_id);
let mapping: HashMap<_, _>;
with_fields(&mut env.unifier, foo_id, |_unifier, fields| {
fields.insert("rec".into(), (foo_id, true));
});
if let TypeEnum::TObj { params, .. } = &*foo_ty {
mapping = params.iter().map(|(id, _)| (*id, int)).collect();
} else {
let TypeEnum::TObj { params, .. } = &*foo_ty else {
unreachable!()
}
};
let mapping = params.iter().map(|(id, _)| (*id, int)).collect();
let instantiated = env.unifier.subst(foo_id, &mapping).unwrap();
let instantiated_ty = env.unifier.get_ty(instantiated);
if let TypeEnum::TObj { fields, .. } = &*instantiated_ty {
let TypeEnum::TObj { fields, .. } = &*instantiated_ty else {
unreachable!()
};
assert!(env.unifier.unioned(fields.get(&"a".into()).unwrap().0, int));
assert!(env.unifier.unioned(fields.get(&"rec".into()).unwrap().0, instantiated));
} else {
unreachable!()
}
}
#[test]

View File

@ -363,12 +363,11 @@ fn main() {
.unwrap_or_else(|_| panic!("cannot find run() entry point"))
.0]
.write();
if let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance {
let TopLevelDef::Function { instance_to_stmt, instance_to_symbol, .. } = &mut *instance else {
unreachable!()
};
instance_to_symbol.insert(String::new(), "run".to_string());
instance_to_stmt[""].clone()
} else {
unreachable!()
}
};
let llvm_options = CodeGenLLVMOptions {