with parallel/sequential support
Behavior of parallel and sequential: Each function call (indirectly, can be inside a sequential block) within a parallel block will update the end variable to the maximum now_mu in the block. Each function call directly inside a parallel block will reset the timeline after execution. A parallel block within a sequential block (or not within any block) will set the timeline to the max now_mu within the block (and the outer max now_mu will also be updated). Implementation: We track the start and end separately. - If there is a start variable, it indicates that we are directly inside a parallel block and we have to reset the timeline after every function call. - If there is a end variable, it indicates that we are (indirectly) inside a parallel block, and we should update the max end value. Note: requires testing, it is difficult to inspect the output IR
This commit is contained in:
parent
558c3f03ef
commit
84c5201243
|
@ -5,7 +5,8 @@ from numpy import int32, int64
|
||||||
|
|
||||||
import nac3artiq
|
import nac3artiq
|
||||||
|
|
||||||
__all__ = ["KernelInvariant", "extern", "kernel", "portable", "ms", "us", "ns", "Core", "TTLOut"]
|
__all__ = ["KernelInvariant", "extern", "kernel", "portable", "ms", "us", "ns",
|
||||||
|
"Core", "TTLOut", "parallel", "sequential"]
|
||||||
|
|
||||||
|
|
||||||
import device_db
|
import device_db
|
||||||
|
@ -15,7 +16,6 @@ nac3 = nac3artiq.NAC3(core_arguments["target"])
|
||||||
allow_module_registration = True
|
allow_module_registration = True
|
||||||
registered_ids = set()
|
registered_ids = set()
|
||||||
|
|
||||||
|
|
||||||
def KernelInvariant(t):
|
def KernelInvariant(t):
|
||||||
return t
|
return t
|
||||||
|
|
||||||
|
@ -83,6 +83,14 @@ def rtio_input_timestamp(timeout_mu: int64, channel: int32) -> int64:
|
||||||
def rtio_input_data(channel: int32) -> int32:
|
def rtio_input_data(channel: int32) -> int32:
|
||||||
raise NotImplementedError("syscall not simulated")
|
raise NotImplementedError("syscall not simulated")
|
||||||
|
|
||||||
|
def at_mu(_):
|
||||||
|
raise NotImplementedError("at_mu not simulated")
|
||||||
|
|
||||||
|
def now_mu() -> int32:
|
||||||
|
raise NotImplementedError("now_mu not simulated")
|
||||||
|
|
||||||
|
def delay_mu(_):
|
||||||
|
raise NotImplementedError("delay_mu not simulated")
|
||||||
|
|
||||||
@kernel
|
@kernel
|
||||||
class Core:
|
class Core:
|
||||||
|
@ -169,3 +177,17 @@ class TTLOut:
|
||||||
self.on()
|
self.on()
|
||||||
self.core.delay(duration)
|
self.core.delay(duration)
|
||||||
self.off()
|
self.off()
|
||||||
|
|
||||||
|
@portable
|
||||||
|
class KernelContextManager:
|
||||||
|
@kernel
|
||||||
|
def __enter__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
@kernel
|
||||||
|
def __exit__(self):
|
||||||
|
pass
|
||||||
|
|
||||||
|
parallel = KernelContextManager()
|
||||||
|
sequential = KernelContextManager()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,196 @@
|
||||||
|
use nac3core::{
|
||||||
|
codegen::{expr::gen_call, stmt::gen_with, CodeGenContext, CodeGenerator},
|
||||||
|
toplevel::DefinitionId,
|
||||||
|
typecheck::typedef::{FunSignature, Type},
|
||||||
|
};
|
||||||
|
|
||||||
|
use rustpython_parser::ast::{Expr, ExprKind, Located, Stmt, StmtKind, StrRef};
|
||||||
|
|
||||||
|
use inkwell::values::BasicValueEnum;
|
||||||
|
|
||||||
|
use crate::timeline::TimeFns;
|
||||||
|
|
||||||
|
pub struct ArtiqCodeGenerator<'a> {
|
||||||
|
name: String,
|
||||||
|
name_counter: u32,
|
||||||
|
start: Option<Expr<Option<Type>>>,
|
||||||
|
end: Option<Expr<Option<Type>>>,
|
||||||
|
timeline: &'a (dyn TimeFns + Sync),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'a> ArtiqCodeGenerator<'a> {
|
||||||
|
pub fn new(name: String, timeline: &'a (dyn TimeFns + Sync)) -> ArtiqCodeGenerator<'a> {
|
||||||
|
ArtiqCodeGenerator {
|
||||||
|
name,
|
||||||
|
name_counter: 0,
|
||||||
|
start: None,
|
||||||
|
end: None,
|
||||||
|
timeline,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
||||||
|
fn get_name(&self) -> &str {
|
||||||
|
&self.name
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_call<'ctx, 'a>(
|
||||||
|
&mut self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
obj: Option<(Type, BasicValueEnum<'ctx>)>,
|
||||||
|
fun: (&FunSignature, DefinitionId),
|
||||||
|
params: Vec<(Option<StrRef>, BasicValueEnum<'ctx>)>,
|
||||||
|
) -> Option<BasicValueEnum<'ctx>> {
|
||||||
|
let result = gen_call(self, ctx, obj, fun, params);
|
||||||
|
if let Some(end) = self.end.clone() {
|
||||||
|
let old_end = self.gen_expr(ctx, &end).unwrap();
|
||||||
|
let now = self.timeline.emit_now_mu(ctx);
|
||||||
|
let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| {
|
||||||
|
let i64 = ctx.ctx.i64_type();
|
||||||
|
ctx.module.add_function(
|
||||||
|
"llvm.smax.i64",
|
||||||
|
i64.fn_type(&[i64.into(), i64.into()], false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
let max = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(smax, &[old_end, now], "smax")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
let end_store = self.gen_store_target(ctx, &end);
|
||||||
|
ctx.builder.build_store(end_store, max);
|
||||||
|
}
|
||||||
|
if let Some(start) = self.start.clone() {
|
||||||
|
let start_val = self.gen_expr(ctx, &start).unwrap();
|
||||||
|
self.timeline.emit_at_mu(ctx, start_val);
|
||||||
|
}
|
||||||
|
result
|
||||||
|
}
|
||||||
|
|
||||||
|
fn gen_with<'ctx, 'a>(
|
||||||
|
&mut self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
stmt: &Stmt<Option<Type>>,
|
||||||
|
) -> bool {
|
||||||
|
if let StmtKind::With { items, body, .. } = &stmt.node {
|
||||||
|
if items.len() == 1 && items[0].optional_vars.is_none() {
|
||||||
|
let item = &items[0];
|
||||||
|
// Behavior of parallel and sequential:
|
||||||
|
// Each function call (indirectly, can be inside a sequential block) within a parallel
|
||||||
|
// block will update the end variable to the maximum now_mu in the block.
|
||||||
|
// Each function call directly inside a parallel block will reset the timeline after
|
||||||
|
// execution. A parallel block within a sequential block (or not within any block) will
|
||||||
|
// set the timeline to the max now_mu within the block (and the outer max now_mu will also
|
||||||
|
// be updated).
|
||||||
|
//
|
||||||
|
// Implementation: We track the start and end separately.
|
||||||
|
// - If there is a start variable, it indicates that we are directly inside a
|
||||||
|
// parallel block and we have to reset the timeline after every function call.
|
||||||
|
// - If there is a end variable, it indicates that we are (indirectly) inside a
|
||||||
|
// parallel block, and we should update the max end value.
|
||||||
|
if let ExprKind::Name { id, ctx: name_ctx } = &item.context_expr.node {
|
||||||
|
if id == &"parallel".into() {
|
||||||
|
let old_start = self.start.take();
|
||||||
|
let old_end = self.end.take();
|
||||||
|
let now = if let Some(old_start) = &old_start {
|
||||||
|
self.gen_expr(ctx, old_start).unwrap()
|
||||||
|
} else {
|
||||||
|
self.timeline.emit_now_mu(ctx)
|
||||||
|
};
|
||||||
|
// Emulate variable allocation, as we need to use the CodeGenContext
|
||||||
|
// HashMap to store our variable due to lifetime limitation
|
||||||
|
// Note: we should be able to store variables directly if generic
|
||||||
|
// associative type is used by limiting the lifetime of CodeGenerator to
|
||||||
|
// the LLVM Context.
|
||||||
|
// The name is guaranteed to be unique as users cannot use this as variable
|
||||||
|
// name.
|
||||||
|
self.start = old_start.clone().or_else(|| {
|
||||||
|
let start = format!("with-{}-start", self.name_counter).into();
|
||||||
|
let start_expr = Located {
|
||||||
|
// location does not matter at this point
|
||||||
|
location: stmt.location,
|
||||||
|
node: ExprKind::Name {
|
||||||
|
id: start,
|
||||||
|
ctx: name_ctx.clone(),
|
||||||
|
},
|
||||||
|
custom: Some(ctx.primitives.int64),
|
||||||
|
};
|
||||||
|
let start = self.gen_store_target(ctx, &start_expr);
|
||||||
|
ctx.builder.build_store(start, now);
|
||||||
|
Some(start_expr)
|
||||||
|
});
|
||||||
|
let end = format!("with-{}-end", self.name_counter).into();
|
||||||
|
let end_expr = Located {
|
||||||
|
// location does not matter at this point
|
||||||
|
location: stmt.location,
|
||||||
|
node: ExprKind::Name {
|
||||||
|
id: end,
|
||||||
|
ctx: name_ctx.clone(),
|
||||||
|
},
|
||||||
|
custom: Some(ctx.primitives.int64),
|
||||||
|
};
|
||||||
|
let end = self.gen_store_target(ctx, &end_expr);
|
||||||
|
ctx.builder.build_store(end, now);
|
||||||
|
self.end = Some(end_expr);
|
||||||
|
self.name_counter += 1;
|
||||||
|
let mut exited = false;
|
||||||
|
for stmt in body.iter() {
|
||||||
|
if self.gen_stmt(ctx, stmt) {
|
||||||
|
exited = true;
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// set duration
|
||||||
|
let end_expr = self.end.take().unwrap();
|
||||||
|
let end_val = self.gen_expr(ctx, &end_expr).unwrap();
|
||||||
|
|
||||||
|
// inside an sequential block
|
||||||
|
if old_start.is_none() {
|
||||||
|
self.timeline.emit_at_mu(ctx, end_val);
|
||||||
|
}
|
||||||
|
// inside a parallel block, should update the outer max now_mu
|
||||||
|
if let Some(old_end) = &old_end {
|
||||||
|
let outer_end_val = self.gen_expr(ctx, old_end).unwrap();
|
||||||
|
let smax = ctx.module.get_function("llvm.smax.i64").unwrap_or_else(|| {
|
||||||
|
let i64 = ctx.ctx.i64_type();
|
||||||
|
ctx.module.add_function(
|
||||||
|
"llvm.smax.i64",
|
||||||
|
i64.fn_type(&[i64.into(), i64.into()], false),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
});
|
||||||
|
let max = ctx
|
||||||
|
.builder
|
||||||
|
.build_call(smax, &[end_val, outer_end_val], "smax")
|
||||||
|
.try_as_basic_value()
|
||||||
|
.left()
|
||||||
|
.unwrap();
|
||||||
|
let outer_end = self.gen_store_target(ctx, old_end);
|
||||||
|
ctx.builder.build_store(outer_end, max);
|
||||||
|
}
|
||||||
|
self.start = old_start;
|
||||||
|
self.end = old_end;
|
||||||
|
return exited;
|
||||||
|
} else if id == &"sequential".into() {
|
||||||
|
let start = self.start.take();
|
||||||
|
for stmt in body.iter() {
|
||||||
|
if self.gen_stmt(ctx, stmt) {
|
||||||
|
self.start = start;
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.start = start;
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// not parallel/sequential
|
||||||
|
gen_with(self, ctx, stmt)
|
||||||
|
} else {
|
||||||
|
unreachable!()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -19,18 +19,16 @@ use rustpython_parser::{
|
||||||
use parking_lot::{Mutex, RwLock};
|
use parking_lot::{Mutex, RwLock};
|
||||||
|
|
||||||
use nac3core::{
|
use nac3core::{
|
||||||
codegen::{
|
codegen::{concrete_type::ConcreteTypeStore, CodeGenTask, WithCall, WorkerRegistry},
|
||||||
concrete_type::ConcreteTypeStore, CodeGenTask, DefaultCodeGenerator, WithCall,
|
|
||||||
WorkerRegistry,
|
|
||||||
},
|
|
||||||
symbol_resolver::SymbolResolver,
|
symbol_resolver::SymbolResolver,
|
||||||
toplevel::{composer::TopLevelComposer, DefinitionId, GenCall, TopLevelContext, TopLevelDef},
|
toplevel::{composer::TopLevelComposer, DefinitionId, GenCall, TopLevelContext, TopLevelDef},
|
||||||
typecheck::typedef::{FunSignature, FuncArg},
|
typecheck::typedef::{FunSignature, FuncArg},
|
||||||
typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
|
typecheck::{type_inferencer::PrimitiveStore, typedef::Type},
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::symbol_resolver::Resolver;
|
use crate::{codegen::ArtiqCodeGenerator, symbol_resolver::Resolver};
|
||||||
|
|
||||||
|
mod codegen;
|
||||||
mod symbol_resolver;
|
mod symbol_resolver;
|
||||||
mod timeline;
|
mod timeline;
|
||||||
|
|
||||||
|
@ -436,10 +434,14 @@ impl Nac3 {
|
||||||
)
|
)
|
||||||
.expect("couldn't write module to file");
|
.expect("couldn't write module to file");
|
||||||
})));
|
})));
|
||||||
|
let time_fns: &(dyn TimeFns + Sync) = match isa {
|
||||||
|
Isa::RiscV => &timeline::NOW_PINNING_TIME_FNS,
|
||||||
|
Isa::CortexA9 => &timeline::EXTERN_TIME_FNS,
|
||||||
|
};
|
||||||
let thread_names: Vec<String> = (0..4).map(|i| format!("module{}", i)).collect();
|
let thread_names: Vec<String> = (0..4).map(|i| format!("module{}", i)).collect();
|
||||||
let threads: Vec<_> = thread_names
|
let threads: Vec<_> = thread_names
|
||||||
.iter()
|
.iter()
|
||||||
.map(|s| Box::new(DefaultCodeGenerator::new(s.to_string())))
|
.map(|s| Box::new(ArtiqCodeGenerator::new(s.to_string(), time_fns)))
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
py.allow_threads(|| {
|
py.allow_threads(|| {
|
||||||
|
|
|
@ -125,6 +125,14 @@ pub trait CodeGenerator {
|
||||||
gen_if(self, ctx, stmt)
|
gen_if(self, ctx, stmt)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn gen_with<'ctx, 'a>(
|
||||||
|
&mut self,
|
||||||
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
stmt: &Stmt<Option<Type>>,
|
||||||
|
) -> bool {
|
||||||
|
gen_with(self, ctx, stmt)
|
||||||
|
}
|
||||||
|
|
||||||
/// Generate code for a statement
|
/// Generate code for a statement
|
||||||
/// Return true if the statement must early return
|
/// Return true if the statement must early return
|
||||||
fn gen_stmt<'ctx, 'a>(
|
fn gen_stmt<'ctx, 'a>(
|
||||||
|
|
|
@ -28,9 +28,9 @@ use std::sync::{
|
||||||
use std::thread;
|
use std::thread;
|
||||||
|
|
||||||
pub mod concrete_type;
|
pub mod concrete_type;
|
||||||
mod expr;
|
pub mod expr;
|
||||||
|
pub mod stmt;
|
||||||
mod generator;
|
mod generator;
|
||||||
mod stmt;
|
|
||||||
|
|
||||||
#[cfg(test)]
|
#[cfg(test)]
|
||||||
mod test;
|
mod test;
|
||||||
|
|
|
@ -291,6 +291,15 @@ pub fn gen_if<'ctx, 'a, G: CodeGenerator + ?Sized>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub fn gen_with<'ctx, 'a, G: CodeGenerator + ?Sized>(
|
||||||
|
_: &mut G,
|
||||||
|
_: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
_: &Stmt<Option<Type>>,
|
||||||
|
) -> bool {
|
||||||
|
// TODO: Implement with statement after finishing exceptions
|
||||||
|
unimplemented!()
|
||||||
|
}
|
||||||
|
|
||||||
pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>(
|
pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>(
|
||||||
generator: &mut G,
|
generator: &mut G,
|
||||||
ctx: &mut CodeGenContext<'ctx, 'a>,
|
ctx: &mut CodeGenContext<'ctx, 'a>,
|
||||||
|
@ -330,6 +339,7 @@ pub fn gen_stmt<'ctx, 'a, G: CodeGenerator + ?Sized>(
|
||||||
StmtKind::If { .. } => return generator.gen_if(ctx, stmt),
|
StmtKind::If { .. } => return generator.gen_if(ctx, stmt),
|
||||||
StmtKind::While { .. } => return generator.gen_while(ctx, stmt),
|
StmtKind::While { .. } => return generator.gen_while(ctx, stmt),
|
||||||
StmtKind::For { .. } => return generator.gen_for(ctx, stmt),
|
StmtKind::For { .. } => return generator.gen_for(ctx, stmt),
|
||||||
|
StmtKind::With { .. } => return generator.gen_with(ctx, stmt),
|
||||||
_ => unimplemented!(),
|
_ => unimplemented!(),
|
||||||
};
|
};
|
||||||
false
|
false
|
||||||
|
|
|
@ -219,6 +219,17 @@ impl<'a> Inferencer<'a> {
|
||||||
self.check_block(orelse, &mut defined_identifiers)?;
|
self.check_block(orelse, &mut defined_identifiers)?;
|
||||||
Ok(false)
|
Ok(false)
|
||||||
}
|
}
|
||||||
|
StmtKind::With { items, body, .. } => {
|
||||||
|
let mut new_defined_identifiers = defined_identifiers.clone();
|
||||||
|
for item in items.iter() {
|
||||||
|
self.check_expr(&item.context_expr, defined_identifiers)?;
|
||||||
|
if let Some(var) = item.optional_vars.as_ref() {
|
||||||
|
self.check_pattern(var, &mut new_defined_identifiers)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
self.check_block(body, &mut new_defined_identifiers)?;
|
||||||
|
Ok(false)
|
||||||
|
}
|
||||||
StmtKind::Expr { value } => {
|
StmtKind::Expr { value } => {
|
||||||
self.check_expr(value, defined_identifiers)?;
|
self.check_expr(value, defined_identifiers)?;
|
||||||
Ok(false)
|
Ok(false)
|
||||||
|
|
|
@ -64,6 +64,10 @@ impl fold::Fold<()> for NaiveFolder {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
fn report_error<T>(msg: &str, location: Location) -> Result<T, String> {
|
||||||
|
Err(format!("{} at {}", msg, location))
|
||||||
|
}
|
||||||
|
|
||||||
impl<'a> fold::Fold<()> for Inferencer<'a> {
|
impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
type TargetU = Option<Type>;
|
type TargetU = Option<Type>;
|
||||||
type Error = String;
|
type Error = String;
|
||||||
|
@ -165,6 +169,14 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
}
|
}
|
||||||
fold::fold_stmt(self, node)?
|
fold::fold_stmt(self, node)?
|
||||||
}
|
}
|
||||||
|
ast::StmtKind::With { ref items, .. } => {
|
||||||
|
for item in items.iter() {
|
||||||
|
if let Some(var) = &item.optional_vars {
|
||||||
|
self.infer_pattern(var)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fold::fold_stmt(self, node)?
|
||||||
|
}
|
||||||
_ => fold::fold_stmt(self, node)?,
|
_ => fold::fold_stmt(self, node)?,
|
||||||
};
|
};
|
||||||
match &stmt.node {
|
match &stmt.node {
|
||||||
|
@ -186,19 +198,92 @@ impl<'a> fold::Fold<()> for Inferencer<'a> {
|
||||||
}
|
}
|
||||||
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
ast::StmtKind::AnnAssign { .. } | ast::StmtKind::Expr { .. } => {}
|
||||||
ast::StmtKind::Break | ast::StmtKind::Continue | ast::StmtKind::Pass => {}
|
ast::StmtKind::Break | ast::StmtKind::Continue | ast::StmtKind::Pass => {}
|
||||||
|
ast::StmtKind::With { items, .. } => {
|
||||||
|
for item in items.iter() {
|
||||||
|
let ty = item.context_expr.custom.unwrap();
|
||||||
|
// if we can simply unify without creating new types...
|
||||||
|
let mut fast_path = false;
|
||||||
|
if let TypeEnum::TObj { fields, .. } = &*self.unifier.get_ty(ty) {
|
||||||
|
let fields = fields.borrow();
|
||||||
|
fast_path = true;
|
||||||
|
if let Some(enter) = fields.get(&"__enter__".into()).cloned() {
|
||||||
|
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(enter) {
|
||||||
|
let signature = signature.borrow();
|
||||||
|
if !signature.args.is_empty() {
|
||||||
|
return report_error(
|
||||||
|
"__enter__ method should take no argument other than self",
|
||||||
|
stmt.location
|
||||||
|
)
|
||||||
|
}
|
||||||
|
if let Some(var) = &item.optional_vars {
|
||||||
|
if signature.vars.is_empty() {
|
||||||
|
self.unify(signature.ret, var.custom.unwrap(), &stmt.location)?;
|
||||||
|
} else {
|
||||||
|
fast_path = false;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fast_path = false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return report_error(
|
||||||
|
"__enter__ method is required for context manager",
|
||||||
|
stmt.location
|
||||||
|
);
|
||||||
|
}
|
||||||
|
if let Some(exit) = fields.get(&"__exit__".into()).cloned() {
|
||||||
|
if let TypeEnum::TFunc(signature) = &*self.unifier.get_ty(exit) {
|
||||||
|
let signature = signature.borrow();
|
||||||
|
if !signature.args.is_empty() {
|
||||||
|
return report_error(
|
||||||
|
"__exit__ method should take no argument other than self",
|
||||||
|
stmt.location
|
||||||
|
)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
fast_path = false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return report_error(
|
||||||
|
"__exit__ method is required for context manager",
|
||||||
|
stmt.location
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !fast_path {
|
||||||
|
let enter = TypeEnum::TFunc(RefCell::new(FunSignature {
|
||||||
|
args: vec![],
|
||||||
|
ret: item.optional_vars.as_ref().map_or_else(|| self.unifier.get_fresh_var().0, |var| var.custom.unwrap()),
|
||||||
|
vars: Default::default()
|
||||||
|
}));
|
||||||
|
let enter = self.unifier.add_ty(enter);
|
||||||
|
let exit = TypeEnum::TFunc(RefCell::new(FunSignature {
|
||||||
|
args: vec![],
|
||||||
|
ret: self.unifier.get_fresh_var().0,
|
||||||
|
vars: Default::default()
|
||||||
|
}));
|
||||||
|
let exit = self.unifier.add_ty(exit);
|
||||||
|
let mut fields = HashMap::new();
|
||||||
|
fields.insert("__enter__".into(), enter);
|
||||||
|
fields.insert("__exit__".into(), exit);
|
||||||
|
let record = self.unifier.add_record(fields);
|
||||||
|
self.unify(ty, record, &stmt.location)?;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
ast::StmtKind::Return { value } => match (value, self.function_data.return_type) {
|
ast::StmtKind::Return { value } => match (value, self.function_data.return_type) {
|
||||||
(Some(v), Some(v1)) => {
|
(Some(v), Some(v1)) => {
|
||||||
self.unify(v.custom.unwrap(), v1, &v.location)?;
|
self.unify(v.custom.unwrap(), v1, &v.location)?;
|
||||||
}
|
}
|
||||||
(Some(_), None) => {
|
(Some(_), None) => {
|
||||||
return Err("Unexpected return value".to_string());
|
return report_error("Unexpected return value", stmt.location);
|
||||||
}
|
}
|
||||||
(None, Some(_)) => {
|
(None, Some(_)) => {
|
||||||
return Err("Expected return value".to_string());
|
return report_error("Expected return value", stmt.location);
|
||||||
}
|
}
|
||||||
(None, None) => {}
|
(None, None) => {}
|
||||||
},
|
},
|
||||||
_ => return Err("Unsupported statement type".to_string()),
|
_ => return report_error("Unsupported statement type", stmt.location),
|
||||||
};
|
};
|
||||||
Ok(stmt)
|
Ok(stmt)
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue