forked from M-Labs/nac3
Compare commits
1 Commits
master
...
feature/rp
Author | SHA1 | Date | |
---|---|---|---|
79931365b7 |
@ -114,13 +114,26 @@ def extern(function):
|
|||||||
|
|
||||||
|
|
||||||
def rpc(arg=None, flags={}):
|
def rpc(arg=None, flags={}):
|
||||||
"""Decorates a function or method to be executed on the host interpreter."""
|
"""Decorates a function to be executed on the host interpreter with kwargs support."""
|
||||||
|
def decorator(function):
|
||||||
|
@wraps(function)
|
||||||
|
def wrapper(*args, **kwargs):
|
||||||
|
# Get function signature
|
||||||
|
sig = inspect.signature(function)
|
||||||
|
|
||||||
|
# Validate kwargs against signature
|
||||||
|
bound_args = sig.bind(*args, **kwargs)
|
||||||
|
bound_args.apply_defaults()
|
||||||
|
|
||||||
|
# Call RPC with both args and kwargs
|
||||||
|
return _do_rpc(function.__name__,
|
||||||
|
bound_args.args,
|
||||||
|
bound_args.kwargs)
|
||||||
|
return wrapper
|
||||||
|
|
||||||
if arg is None:
|
if arg is None:
|
||||||
def inner_decorator(function):
|
return decorator
|
||||||
return rpc(function, flags)
|
return decorator(arg)
|
||||||
return inner_decorator
|
|
||||||
register_function(arg)
|
|
||||||
return arg
|
|
||||||
|
|
||||||
def kernel(function_or_method):
|
def kernel(function_or_method):
|
||||||
"""Decorates a function or method to be executed on the core device."""
|
"""Decorates a function or method to be executed on the core device."""
|
||||||
|
@ -79,8 +79,7 @@ pub struct ArtiqCodeGenerator<'a> {
|
|||||||
|
|
||||||
/// The [`ParallelMode`] of the current parallel context.
|
/// The [`ParallelMode`] of the current parallel context.
|
||||||
///
|
///
|
||||||
/// The current parallel context refers to the nearest `with parallel` or `with legacy_parallel`
|
/// The current parallel context refers to the nearest `with` statement, which is used to determine when and how the timeline should be updated.
|
||||||
/// statement, which is used to determine when and how the timeline should be updated.
|
|
||||||
parallel_mode: ParallelMode,
|
parallel_mode: ParallelMode,
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -373,8 +372,14 @@ impl<'b> CodeGenerator for ArtiqCodeGenerator<'b> {
|
|||||||
fn gen_rpc_tag(
|
fn gen_rpc_tag(
|
||||||
ctx: &mut CodeGenContext<'_, '_>,
|
ctx: &mut CodeGenContext<'_, '_>,
|
||||||
ty: Type,
|
ty: Type,
|
||||||
|
is_kwarg: bool, // Add this parameter
|
||||||
buffer: &mut Vec<u8>,
|
buffer: &mut Vec<u8>,
|
||||||
) -> Result<(), String> {
|
) -> Result<(), String> {
|
||||||
|
// Add kwarg marker if needed
|
||||||
|
if is_kwarg {
|
||||||
|
buffer.push(b'k'); // 'k' for keyword argument
|
||||||
|
}
|
||||||
|
|
||||||
use nac3core::typecheck::typedef::TypeEnum::*;
|
use nac3core::typecheck::typedef::TypeEnum::*;
|
||||||
|
|
||||||
let int32 = ctx.primitives.int32;
|
let int32 = ctx.primitives.int32;
|
||||||
@ -403,14 +408,14 @@ fn gen_rpc_tag(
|
|||||||
buffer.push(b't');
|
buffer.push(b't');
|
||||||
buffer.push(ty.len() as u8);
|
buffer.push(ty.len() as u8);
|
||||||
for ty in ty {
|
for ty in ty {
|
||||||
gen_rpc_tag(ctx, *ty, buffer)?;
|
gen_rpc_tag(ctx, *ty, false, buffer)?; // Pass false for is_kwarg
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
||||||
let ty = iter_type_vars(params).next().unwrap().ty;
|
let ty = iter_type_vars(params).next().unwrap().ty;
|
||||||
|
|
||||||
buffer.push(b'l');
|
buffer.push(b'l');
|
||||||
gen_rpc_tag(ctx, ty, buffer)?;
|
gen_rpc_tag(ctx, ty, false, buffer)?; // Pass false for is_kwarg
|
||||||
}
|
}
|
||||||
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
TObj { obj_id, .. } if *obj_id == PrimDef::NDArray.id() => {
|
||||||
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
let (ndarray_dtype, ndarray_ndims) = unpack_ndarray_var_tys(&mut ctx.unifier, ty);
|
||||||
@ -434,7 +439,7 @@ fn gen_rpc_tag(
|
|||||||
|
|
||||||
buffer.push(b'a');
|
buffer.push(b'a');
|
||||||
buffer.push((ndarray_ndims & 0xFF) as u8);
|
buffer.push((ndarray_ndims & 0xFF) as u8);
|
||||||
gen_rpc_tag(ctx, ndarray_dtype, buffer)?;
|
gen_rpc_tag(ctx, ndarray_dtype, false, buffer)?; // Pass false for is_kwarg
|
||||||
}
|
}
|
||||||
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
_ => return Err(format!("Unsupported type: {:?}", ctx.unifier.stringify(ty))),
|
||||||
}
|
}
|
||||||
@ -808,10 +813,10 @@ fn rpc_codegen_callback_fn<'ctx>(
|
|||||||
tag.push(b'O');
|
tag.push(b'O');
|
||||||
}
|
}
|
||||||
for arg in &fun.0.args {
|
for arg in &fun.0.args {
|
||||||
gen_rpc_tag(ctx, arg.ty, &mut tag)?;
|
gen_rpc_tag(ctx, arg.ty, false, &mut tag)?; // Pass false for is_kwarg
|
||||||
}
|
}
|
||||||
tag.push(b':');
|
tag.push(b':');
|
||||||
gen_rpc_tag(ctx, fun.0.ret, &mut tag)?;
|
gen_rpc_tag(ctx, fun.0.ret, false, &mut tag)?;
|
||||||
|
|
||||||
let mut hasher = DefaultHasher::new();
|
let mut hasher = DefaultHasher::new();
|
||||||
tag.hash(&mut hasher);
|
tag.hash(&mut hasher);
|
||||||
@ -858,8 +863,17 @@ fn rpc_codegen_callback_fn<'ctx>(
|
|||||||
// -- rpc args handling
|
// -- rpc args handling
|
||||||
let mut keys = fun.0.args.clone();
|
let mut keys = fun.0.args.clone();
|
||||||
let mut mapping = HashMap::new();
|
let mut mapping = HashMap::new();
|
||||||
|
let mut is_keyword_arg = HashMap::new();
|
||||||
|
|
||||||
for (key, value) in args {
|
for (key, value) in args {
|
||||||
mapping.insert(key.unwrap_or_else(|| keys.remove(0).name), value);
|
if let Some(key_name) = key {
|
||||||
|
mapping.insert(key_name, value);
|
||||||
|
is_keyword_arg.insert(key_name, true);
|
||||||
|
} else {
|
||||||
|
let arg_name = keys.remove(0).name;
|
||||||
|
mapping.insert(arg_name, value);
|
||||||
|
is_keyword_arg.insert(arg_name, false);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
// default value handling
|
// default value handling
|
||||||
for k in keys {
|
for k in keys {
|
||||||
@ -901,6 +915,14 @@ fn rpc_codegen_callback_fn<'ctx>(
|
|||||||
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
|
ctx.builder.build_store(arg_ptr, arg_slot).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Before calling rpc_send/rpc_send_async, add keyword arg info to tag
|
||||||
|
for arg in &fun.0.args {
|
||||||
|
if *is_keyword_arg.get(&arg.name).unwrap_or(&false) {
|
||||||
|
tag.push(b'k'); // Mark as keyword argument
|
||||||
|
}
|
||||||
|
gen_rpc_tag(ctx, arg.ty, true, &mut tag)?; // Pass true for is_kwarg
|
||||||
|
}
|
||||||
|
|
||||||
// call
|
// call
|
||||||
if is_async {
|
if is_async {
|
||||||
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
|
let rpc_send_async = ctx.module.get_function("rpc_send_async").unwrap_or_else(|| {
|
||||||
@ -1007,7 +1029,7 @@ pub fn attributes_writeback<'ctx>(
|
|||||||
if !is_mutable {
|
if !is_mutable {
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
if gen_rpc_tag(ctx, *field_ty, &mut scratch_buffer).is_ok() {
|
if gen_rpc_tag(ctx, *field_ty, false, &mut scratch_buffer).is_ok() {
|
||||||
attributes.push(name.to_string());
|
attributes.push(name.to_string());
|
||||||
let (index, _) = ctx.get_attr_index(ty, *name);
|
let (index, _) = ctx.get_attr_index(ty, *name);
|
||||||
values.push((
|
values.push((
|
||||||
@ -1030,7 +1052,7 @@ pub fn attributes_writeback<'ctx>(
|
|||||||
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
TypeEnum::TObj { obj_id, params, .. } if *obj_id == PrimDef::List.id() => {
|
||||||
let elem_ty = iter_type_vars(params).next().unwrap().ty;
|
let elem_ty = iter_type_vars(params).next().unwrap().ty;
|
||||||
|
|
||||||
if gen_rpc_tag(ctx, elem_ty, &mut scratch_buffer).is_ok() {
|
if gen_rpc_tag(ctx, elem_ty, false, &mut scratch_buffer).is_ok() {
|
||||||
let pydict = PyDict::new(py);
|
let pydict = PyDict::new(py);
|
||||||
pydict.set_item("obj", val)?;
|
pydict.set_item("obj", val)?;
|
||||||
host_attributes.append(pydict)?;
|
host_attributes.append(pydict)?;
|
||||||
|
@ -4,9 +4,7 @@
|
|||||||
#include "irrt/ndarray.hpp"
|
#include "irrt/ndarray.hpp"
|
||||||
#include "irrt/range.hpp"
|
#include "irrt/range.hpp"
|
||||||
#include "irrt/slice.hpp"
|
#include "irrt/slice.hpp"
|
||||||
#include "irrt/string.hpp"
|
|
||||||
#include "irrt/ndarray/basic.hpp"
|
#include "irrt/ndarray/basic.hpp"
|
||||||
#include "irrt/ndarray/def.hpp"
|
#include "irrt/ndarray/def.hpp"
|
||||||
#include "irrt/ndarray/iter.hpp"
|
#include "irrt/ndarray/iter.hpp"
|
||||||
#include "irrt/ndarray/indexing.hpp"
|
#include "irrt/ndarray/indexing.hpp"
|
||||||
#include "irrt/string.hpp"
|
|
||||||
|
@ -1,23 +0,0 @@
|
|||||||
#pragma once
|
|
||||||
|
|
||||||
#include "irrt/int_types.hpp"
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
template<typename SizeT>
|
|
||||||
SizeT __nac3_str_eq_impl(const char* str1, SizeT len1, const char* str2, SizeT len2) {
|
|
||||||
if (len1 != len2){
|
|
||||||
return 0;
|
|
||||||
}
|
|
||||||
return (__builtin_memcmp(str1, str2, static_cast<SizeT>(len1)) == 0) ? 1 : 0;
|
|
||||||
}
|
|
||||||
} // namespace
|
|
||||||
|
|
||||||
extern "C" {
|
|
||||||
uint32_t nac3_str_eq(const char* str1, uint32_t len1, const char* str2, uint32_t len2) {
|
|
||||||
return __nac3_str_eq_impl<uint32_t>(str1, len1, str2, len2);
|
|
||||||
}
|
|
||||||
|
|
||||||
uint64_t nac3_str_eq64(const char* str1, uint64_t len1, const char* str2, uint64_t len2) {
|
|
||||||
return __nac3_str_eq_impl<uint64_t>(str1, len1, str2, len2);
|
|
||||||
}
|
|
||||||
}
|
|
@ -24,7 +24,7 @@ use super::{
|
|||||||
irrt::*,
|
irrt::*,
|
||||||
llvm_intrinsics::{
|
llvm_intrinsics::{
|
||||||
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
call_expect, call_float_floor, call_float_pow, call_float_powi, call_int_smax,
|
||||||
call_memcpy_generic,
|
call_int_umin, call_memcpy_generic,
|
||||||
},
|
},
|
||||||
macros::codegen_unreachable,
|
macros::codegen_unreachable,
|
||||||
need_sret, numpy,
|
need_sret, numpy,
|
||||||
@ -2045,43 +2045,111 @@ pub fn gen_cmpop_expr_with_values<'ctx, G: CodeGenerator>(
|
|||||||
} else if left_ty == ctx.primitives.str {
|
} else if left_ty == ctx.primitives.str {
|
||||||
assert!(ctx.unifier.unioned(left_ty, right_ty));
|
assert!(ctx.unifier.unioned(left_ty, right_ty));
|
||||||
|
|
||||||
let lhs = lhs.into_struct_value();
|
let llvm_i1 = ctx.ctx.bool_type();
|
||||||
let rhs = rhs.into_struct_value();
|
|
||||||
|
|
||||||
let llvm_i32 = ctx.ctx.i32_type();
|
let llvm_i32 = ctx.ctx.i32_type();
|
||||||
let llvm_usize = generator.get_size_type(ctx.ctx);
|
let llvm_usize = generator.get_size_type(ctx.ctx);
|
||||||
|
|
||||||
|
let lhs = lhs.into_struct_value();
|
||||||
|
let rhs = rhs.into_struct_value();
|
||||||
|
|
||||||
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
let plhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
||||||
ctx.builder.build_store(plhs, lhs).unwrap();
|
ctx.builder.build_store(plhs, lhs).unwrap();
|
||||||
let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
let prhs = generator.gen_var_alloc(ctx, lhs.get_type().into(), None).unwrap();
|
||||||
ctx.builder.build_store(prhs, rhs).unwrap();
|
ctx.builder.build_store(prhs, rhs).unwrap();
|
||||||
|
|
||||||
let lhs_ptr = ctx.build_in_bounds_gep_and_load(
|
|
||||||
plhs,
|
|
||||||
&[llvm_usize.const_zero(), llvm_i32.const_zero()],
|
|
||||||
None,
|
|
||||||
).into_pointer_value();
|
|
||||||
let lhs_len = ctx.build_in_bounds_gep_and_load(
|
let lhs_len = ctx.build_in_bounds_gep_and_load(
|
||||||
plhs,
|
plhs,
|
||||||
&[llvm_usize.const_zero(), llvm_i32.const_int(1, false)],
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
||||||
|
None,
|
||||||
|
).into_int_value();
|
||||||
|
let rhs_len = ctx.build_in_bounds_gep_and_load(
|
||||||
|
prhs,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_int(1, false)],
|
||||||
None,
|
None,
|
||||||
).into_int_value();
|
).into_int_value();
|
||||||
|
|
||||||
let rhs_ptr = ctx.build_in_bounds_gep_and_load(
|
let len = call_int_umin(ctx, lhs_len, rhs_len, None);
|
||||||
prhs,
|
|
||||||
&[llvm_usize.const_zero(), llvm_i32.const_zero()],
|
let current_bb = ctx.builder.get_insert_block().unwrap();
|
||||||
|
let post_foreach_cmp = ctx.ctx.insert_basic_block_after(current_bb, "foreach.cmp.end");
|
||||||
|
|
||||||
|
ctx.builder.position_at_end(post_foreach_cmp);
|
||||||
|
let cmp_phi = ctx.builder.build_phi(llvm_i1, "").unwrap();
|
||||||
|
ctx.builder.position_at_end(current_bb);
|
||||||
|
|
||||||
|
gen_for_callback_incrementing(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
None,
|
||||||
|
llvm_usize.const_zero(),
|
||||||
|
(len, false),
|
||||||
|
|generator, ctx, _, i| {
|
||||||
|
let lhs_char = {
|
||||||
|
let plhs_data = ctx.build_in_bounds_gep_and_load(
|
||||||
|
plhs,
|
||||||
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
None,
|
None,
|
||||||
).into_pointer_value();
|
).into_pointer_value();
|
||||||
let rhs_len = ctx.build_in_bounds_gep_and_load(
|
|
||||||
|
ctx.build_in_bounds_gep_and_load(
|
||||||
|
plhs_data,
|
||||||
|
&[i],
|
||||||
|
None
|
||||||
|
).into_int_value()
|
||||||
|
};
|
||||||
|
let rhs_char = {
|
||||||
|
let prhs_data = ctx.build_in_bounds_gep_and_load(
|
||||||
prhs,
|
prhs,
|
||||||
&[llvm_usize.const_zero(), llvm_i32.const_int(1, false)],
|
&[llvm_i32.const_zero(), llvm_i32.const_zero()],
|
||||||
None,
|
None,
|
||||||
).into_int_value();
|
).into_pointer_value();
|
||||||
let result = call_string_eq(generator, ctx, lhs_ptr, lhs_len, rhs_ptr, rhs_len);
|
|
||||||
|
ctx.build_in_bounds_gep_and_load(
|
||||||
|
prhs_data,
|
||||||
|
&[i],
|
||||||
|
None
|
||||||
|
).into_int_value()
|
||||||
|
};
|
||||||
|
|
||||||
|
gen_if_callback(
|
||||||
|
generator,
|
||||||
|
ctx,
|
||||||
|
|_, ctx| {
|
||||||
|
Ok(ctx.builder.build_int_compare(IntPredicate::NE, lhs_char, rhs_char, "").unwrap())
|
||||||
|
},
|
||||||
|
|_, ctx| {
|
||||||
|
let bb = ctx.builder.get_insert_block().unwrap();
|
||||||
|
cmp_phi.add_incoming(&[(&llvm_i1.const_zero(), bb)]);
|
||||||
|
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
|_, _| Ok(()),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
llvm_usize.const_int(1, false),
|
||||||
|
)?;
|
||||||
|
|
||||||
|
let bb = ctx.builder.get_insert_block().unwrap();
|
||||||
|
let is_len_eq = ctx.builder.build_int_compare(
|
||||||
|
IntPredicate::EQ,
|
||||||
|
lhs_len,
|
||||||
|
rhs_len,
|
||||||
|
"",
|
||||||
|
).unwrap();
|
||||||
|
cmp_phi.add_incoming(&[(&is_len_eq, bb)]);
|
||||||
|
ctx.builder.build_unconditional_branch(post_foreach_cmp).unwrap();
|
||||||
|
|
||||||
|
ctx.builder.position_at_end(post_foreach_cmp);
|
||||||
|
let cmp_phi = cmp_phi.as_basic_value().into_int_value();
|
||||||
|
|
||||||
|
// Invert the final value if __ne__
|
||||||
if *op == Cmpop::NotEq {
|
if *op == Cmpop::NotEq {
|
||||||
ctx.builder.build_not(result, "").unwrap()
|
ctx.builder.build_not(cmp_phi, "").unwrap()
|
||||||
} else {
|
} else {
|
||||||
result
|
cmp_phi
|
||||||
}
|
}
|
||||||
} else if [left_ty, right_ty]
|
} else if [left_ty, right_ty]
|
||||||
.iter()
|
.iter()
|
||||||
|
@ -15,14 +15,12 @@ pub use list::*;
|
|||||||
pub use math::*;
|
pub use math::*;
|
||||||
pub use range::*;
|
pub use range::*;
|
||||||
pub use slice::*;
|
pub use slice::*;
|
||||||
pub use string::*;
|
|
||||||
|
|
||||||
mod list;
|
mod list;
|
||||||
mod math;
|
mod math;
|
||||||
pub mod ndarray;
|
pub mod ndarray;
|
||||||
mod range;
|
mod range;
|
||||||
mod slice;
|
mod slice;
|
||||||
mod string;
|
|
||||||
|
|
||||||
#[must_use]
|
#[must_use]
|
||||||
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
pub fn load_irrt<'ctx>(ctx: &'ctx Context, symbol_resolver: &dyn SymbolResolver) -> Module<'ctx> {
|
||||||
|
@ -1,48 +0,0 @@
|
|||||||
use inkwell::values::{BasicValueEnum, CallSiteValue, IntValue, PointerValue};
|
|
||||||
use itertools::Either;
|
|
||||||
|
|
||||||
use crate::codegen::{macros::codegen_unreachable, CodeGenContext, CodeGenerator};
|
|
||||||
|
|
||||||
/// Generates a call to string equality comparison. Returns an `i1` representing whether the strings are equal.
|
|
||||||
pub fn call_string_eq<'ctx, G: CodeGenerator + ?Sized>(
|
|
||||||
generator: &G,
|
|
||||||
ctx: &CodeGenContext<'ctx, '_>,
|
|
||||||
str1_ptr: PointerValue<'ctx>,
|
|
||||||
str1_len: IntValue<'ctx>,
|
|
||||||
str2_ptr: PointerValue<'ctx>,
|
|
||||||
str2_len: IntValue<'ctx>,
|
|
||||||
) -> IntValue<'ctx> {
|
|
||||||
let (func_name, return_type) = match ctx.ctx.i32_type().get_bit_width() {
|
|
||||||
32 => ("nac3_str_eq", ctx.ctx.i32_type()),
|
|
||||||
64 => ("nac3_str_eq64", ctx.ctx.i64_type()),
|
|
||||||
bw => codegen_unreachable!(ctx, "Unsupported size type bit width: {}", bw),
|
|
||||||
};
|
|
||||||
|
|
||||||
let func = ctx.module.get_function(func_name).unwrap_or_else(|| {
|
|
||||||
ctx.module.add_function(
|
|
||||||
func_name,
|
|
||||||
return_type.fn_type(
|
|
||||||
&[
|
|
||||||
str1_ptr.get_type().into(),
|
|
||||||
str1_len.get_type().into(),
|
|
||||||
str2_ptr.get_type().into(),
|
|
||||||
str2_len.get_type().into(),
|
|
||||||
],
|
|
||||||
false,
|
|
||||||
),
|
|
||||||
None,
|
|
||||||
)
|
|
||||||
});
|
|
||||||
let result = ctx
|
|
||||||
.builder
|
|
||||||
.build_call(
|
|
||||||
func,
|
|
||||||
&[str1_ptr.into(), str1_len.into(), str2_ptr.into(), str2_len.into()],
|
|
||||||
"str_eq_call",
|
|
||||||
)
|
|
||||||
.map(CallSiteValue::try_as_basic_value)
|
|
||||||
.map(|v| v.map_left(BasicValueEnum::into_int_value))
|
|
||||||
.map(Either::unwrap_left)
|
|
||||||
.unwrap();
|
|
||||||
generator.bool_to_i1(ctx, result)
|
|
||||||
}
|
|
@ -1832,20 +1832,47 @@ impl<'a> Inferencer<'a> {
|
|||||||
|
|
||||||
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) {
|
if let TypeEnum::TFunc(sign) = &*self.unifier.get_ty(func.custom.unwrap()) {
|
||||||
if sign.vars.is_empty() {
|
if sign.vars.is_empty() {
|
||||||
|
// Build keyword argument map
|
||||||
|
let mut kwargs_map = HashMap::new();
|
||||||
|
for kw in &keywords {
|
||||||
|
if let Some(name) = &kw.node.arg {
|
||||||
|
// Check if keyword arg exists in function signature
|
||||||
|
if !sign.args.iter().any(|arg| arg.name == *name) {
|
||||||
|
return report_error(
|
||||||
|
&format!("Unexpected keyword argument '{}'", name),
|
||||||
|
kw.location,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
kwargs_map.insert(*name, kw.node.value.custom.unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate that all required args are provided
|
||||||
|
for arg in &sign.args {
|
||||||
|
if arg.default_value.is_none()
|
||||||
|
&& !kwargs_map.contains_key(&arg.name)
|
||||||
|
&& args.len() < sign.args.len()
|
||||||
|
{
|
||||||
|
return report_error(
|
||||||
|
&format!("Missing required argument '{}'", arg.name),
|
||||||
|
location,
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
let call = Call {
|
let call = Call {
|
||||||
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
||||||
kwargs: keywords
|
kwargs: kwargs_map,
|
||||||
.iter()
|
|
||||||
.map(|v| (*v.node.arg.as_ref().unwrap(), v.node.value.custom.unwrap()))
|
|
||||||
.collect(),
|
|
||||||
fun: RefCell::new(None),
|
fun: RefCell::new(None),
|
||||||
ret: sign.ret,
|
ret: sign.ret,
|
||||||
loc: Some(location),
|
loc: Some(location),
|
||||||
operator_info: None,
|
operator_info: None,
|
||||||
};
|
};
|
||||||
|
|
||||||
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
|
self.unifier.unify_call(&call, func.custom.unwrap(), sign).map_err(|e| {
|
||||||
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
|
HashSet::from([e.at(Some(location)).to_display(self.unifier).to_string()])
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
return Ok(Located {
|
return Ok(Located {
|
||||||
location,
|
location,
|
||||||
custom: Some(sign.ret),
|
custom: Some(sign.ret),
|
||||||
@ -1859,7 +1886,7 @@ impl<'a> Inferencer<'a> {
|
|||||||
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
posargs: args.iter().map(|v| v.custom.unwrap()).collect(),
|
||||||
kwargs: keywords
|
kwargs: keywords
|
||||||
.iter()
|
.iter()
|
||||||
.map(|v| (*v.node.arg.as_ref().unwrap(), v.custom.unwrap()))
|
.filter_map(|v| v.node.arg.map(|name| (name, v.node.value.custom.unwrap())))
|
||||||
.collect(),
|
.collect(),
|
||||||
fun: RefCell::new(None),
|
fun: RefCell::new(None),
|
||||||
ret,
|
ret,
|
||||||
|
Loading…
Reference in New Issue
Block a user