forked from M-Labs/artiq
1
0
Fork 0

compiler: add basic numpy array support (#424).

This commit is contained in:
whitequark 2016-07-06 09:51:57 +00:00
parent 906db876a6
commit 933ea53c77
13 changed files with 109 additions and 30 deletions

View File

@ -44,12 +44,12 @@ def TInt32():
def TInt64(): def TInt64():
return TInt(types.TValue(64)) return TInt(types.TValue(64))
def _int_printer(typ, depth, max_depth): def _int_printer(typ, printer, depth, max_depth):
if types.is_var(typ["width"]): if types.is_var(typ["width"]):
return "numpy.int?" return "numpy.int?"
else: else:
return "numpy.int{}".format(types.get_value(typ.find()["width"])) return "numpy.int{}".format(types.get_value(typ.find()["width"]))
types.TypePrinter.custom_printers['int'] = _int_printer types.TypePrinter.custom_printers["int"] = _int_printer
class TFloat(types.TMono): class TFloat(types.TMono):
def __init__(self): def __init__(self):
@ -73,6 +73,16 @@ class TList(types.TMono):
elt = types.TVar() elt = types.TVar()
super().__init__("list", {"elt": elt}) super().__init__("list", {"elt": elt})
class TArray(types.TMono):
def __init__(self, elt=None):
if elt is None:
elt = types.TVar()
super().__init__("array", {"elt": elt})
def _array_printer(typ, printer, depth, max_depth):
return "numpy.array(elt={})".format(printer.name(typ["elt"], depth, max_depth))
types.TypePrinter.custom_printers["array"] = _array_printer
class TRange(types.TMono): class TRange(types.TMono):
def __init__(self, elt=None): def __init__(self, elt=None):
if elt is None: if elt is None:
@ -124,6 +134,9 @@ def fn_str():
def fn_list(): def fn_list():
return types.TConstructor(TList()) return types.TConstructor(TList())
def fn_array():
return types.TConstructor(TArray())
def fn_Exception(): def fn_Exception():
return types.TExceptionConstructor(TException("Exception")) return types.TExceptionConstructor(TException("Exception"))
@ -231,6 +244,15 @@ def is_list(typ, elt=None):
else: else:
return types.is_mono(typ, "list") return types.is_mono(typ, "list")
def is_array(typ, elt=None):
if elt is not None:
return types.is_mono(typ, "array", elt=elt)
else:
return types.is_mono(typ, "array")
def is_listish(typ, elt=None):
return is_list(typ, elt) or is_array(typ, elt)
def is_range(typ, elt=None): def is_range(typ, elt=None):
if elt is not None: if elt is not None:
return types.is_mono(typ, "range", {"elt": elt}) return types.is_mono(typ, "range", {"elt": elt})
@ -247,7 +269,7 @@ def is_exception(typ, name=None):
def is_iterable(typ): def is_iterable(typ):
typ = typ.find() typ = typ.find()
return isinstance(typ, types.TMono) and \ return isinstance(typ, types.TMono) and \
typ.name in ('list', 'range') typ.name in ('list', 'array', 'range')
def get_iterable_elt(typ): def get_iterable_elt(typ):
if is_iterable(typ): if is_iterable(typ):

View File

@ -187,6 +187,18 @@ class ASTSynthesizer:
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(), return asttyped.ListT(elts=elts, ctx=None, type=builtins.TList(),
begin_loc=begin_loc, end_loc=end_loc, begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc)) loc=begin_loc.join(end_loc))
elif isinstance(value, numpy.ndarray):
begin_loc = self._add("numpy.array([")
elts = []
for index, elt in enumerate(value):
elts.append(self.quote(elt))
if index < len(value) - 1:
self._add(", ")
end_loc = self._add("])")
return asttyped.ListT(elts=elts, ctx=None, type=builtins.TArray(),
begin_loc=begin_loc, end_loc=end_loc,
loc=begin_loc.join(end_loc))
elif inspect.isfunction(value) or inspect.ismethod(value) or \ elif inspect.isfunction(value) or inspect.ismethod(value) or \
isinstance(value, pytypes.BuiltinFunctionType) or \ isinstance(value, pytypes.BuiltinFunctionType) or \
isinstance(value, SpecializedFunction): isinstance(value, SpecializedFunction):

View File

@ -12,6 +12,7 @@ def globals():
"int": builtins.fn_int(), "int": builtins.fn_int(),
"float": builtins.fn_float(), "float": builtins.fn_float(),
"list": builtins.fn_list(), "list": builtins.fn_list(),
"array": builtins.fn_array(),
"range": builtins.fn_range(), "range": builtins.fn_range(),
# Exception constructors # Exception constructors

View File

@ -477,7 +477,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
self.continue_target = old_continue self.continue_target = old_continue
def iterable_len(self, value, typ=_size_type): def iterable_len(self, value, typ=_size_type):
if builtins.is_list(value.type): if builtins.is_listish(value.type):
return self.append(ir.Builtin("len", [value], typ, return self.append(ir.Builtin("len", [value], typ,
name="{}.len".format(value.name))) name="{}.len".format(value.name)))
elif builtins.is_range(value.type): elif builtins.is_range(value.type):
@ -492,7 +492,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
def iterable_get(self, value, index): def iterable_get(self, value, index):
# Assuming the value is within bounds. # Assuming the value is within bounds.
if builtins.is_list(value.type): if builtins.is_listish(value.type):
return self.append(ir.GetElem(value, index)) return self.append(ir.GetElem(value, index))
elif builtins.is_range(value.type): elif builtins.is_range(value.type):
start = self.append(ir.GetAttr(value, "start")) start = self.append(ir.GetAttr(value, "start"))
@ -1322,7 +1322,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
for index, elt in enumerate(node.right.type.elts): for index, elt in enumerate(node.right.type.elts):
elts.append(self.append(ir.GetAttr(rhs, index))) elts.append(self.append(ir.GetAttr(rhs, index)))
return self.append(ir.Alloc(elts, node.type)) return self.append(ir.Alloc(elts, node.type))
elif builtins.is_list(node.left.type) and builtins.is_list(node.right.type): elif builtins.is_listish(node.left.type) and builtins.is_listish(node.right.type):
lhs_length = self.iterable_len(lhs) lhs_length = self.iterable_len(lhs)
rhs_length = self.iterable_len(rhs) rhs_length = self.iterable_len(rhs)
@ -1355,9 +1355,9 @@ class ARTIQIRGenerator(algorithm.Visitor):
assert False assert False
elif isinstance(node.op, ast.Mult): # list * int, int * list elif isinstance(node.op, ast.Mult): # list * int, int * list
lhs, rhs = self.visit(node.left), self.visit(node.right) lhs, rhs = self.visit(node.left), self.visit(node.right)
if builtins.is_list(lhs.type) and builtins.is_int(rhs.type): if builtins.is_listish(lhs.type) and builtins.is_int(rhs.type):
lst, num = lhs, rhs lst, num = lhs, rhs
elif builtins.is_int(lhs.type) and builtins.is_list(rhs.type): elif builtins.is_int(lhs.type) and builtins.is_listish(rhs.type):
lst, num = rhs, lhs lst, num = rhs, lhs
else: else:
assert False assert False
@ -1412,7 +1412,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
result = self.append(ir.Select(result, elt_result, result = self.append(ir.Select(result, elt_result,
ir.Constant(False, builtins.TBool()))) ir.Constant(False, builtins.TBool())))
return result return result
elif builtins.is_list(lhs.type) and builtins.is_list(rhs.type): elif builtins.is_listish(lhs.type) and builtins.is_listish(rhs.type):
head = self.current_block head = self.current_block
lhs_length = self.iterable_len(lhs) lhs_length = self.iterable_len(lhs)
rhs_length = self.iterable_len(rhs) rhs_length = self.iterable_len(rhs)
@ -1606,7 +1606,7 @@ class ARTIQIRGenerator(algorithm.Visitor):
return self.append(ir.Coerce(arg, node.type)) return self.append(ir.Coerce(arg, node.type))
else: else:
assert False assert False
elif types.is_builtin(typ, "list"): elif types.is_builtin(typ, "list") or types.is_builtin(typ, "array"):
if len(node.args) == 0 and len(node.keywords) == 0: if len(node.args) == 0 and len(node.keywords) == 0:
length = ir.Constant(0, builtins.TInt32()) length = ir.Constant(0, builtins.TInt32())
return self.append(ir.Alloc([length], node.type)) return self.append(ir.Alloc([length], node.type))
@ -1968,8 +1968,13 @@ class ARTIQIRGenerator(algorithm.Visitor):
else: else:
format_string += "%s" format_string += "%s"
args.append(value) args.append(value)
elif builtins.is_list(value.type): elif builtins.is_listish(value.type):
if builtins.is_list(value.type):
format_string += "["; flush() format_string += "["; flush()
elif builtins.is_array(value.type):
format_string += "array(["; flush()
else:
assert False
length = self.iterable_len(value) length = self.iterable_len(value)
last = self.append(ir.Arith(ast.Sub(loc=None), length, ir.Constant(1, length.type))) last = self.append(ir.Arith(ast.Sub(loc=None), length, ir.Constant(1, length.type)))
@ -1992,7 +1997,10 @@ class ARTIQIRGenerator(algorithm.Visitor):
lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)), lambda index: self.append(ir.Compare(ast.Lt(loc=None), index, length)),
body_gen) body_gen)
if builtins.is_list(value.type):
format_string += "]" format_string += "]"
elif builtins.is_array(value.type):
format_string += "])"
elif builtins.is_range(value.type): elif builtins.is_range(value.type):
format_string += "range("; flush() format_string += "range("; flush()

View File

@ -671,7 +671,8 @@ class Inferencer(algorithm.Visitor):
pass pass
else: else:
diagnose(valid_forms()) diagnose(valid_forms())
elif types.is_builtin(typ, "list"): elif types.is_builtin(typ, "list") or types.is_builtin(typ, "array"):
if types.is_builtin(typ, "list"):
valid_forms = lambda: [ valid_forms = lambda: [
valid_form("list() -> list(elt='a)"), valid_form("list() -> list(elt='a)"),
valid_form("list(x:'a) -> list(elt='b) where 'a is iterable") valid_form("list(x:'a) -> list(elt='b) where 'a is iterable")
@ -679,6 +680,16 @@ class Inferencer(algorithm.Visitor):
self._unify(node.type, builtins.TList(), self._unify(node.type, builtins.TList(),
node.loc, None) node.loc, None)
elif types.is_builtin(typ, "array"):
valid_forms = lambda: [
valid_form("array() -> array(elt='a)"),
valid_form("array(x:'a) -> array(elt='b) where 'a is iterable")
]
self._unify(node.type, builtins.TArray(),
node.loc, None)
else:
assert False
if len(node.args) == 0 and len(node.keywords) == 0: if len(node.args) == 0 and len(node.keywords) == 0:
pass # [] pass # []
@ -708,7 +719,8 @@ class Inferencer(algorithm.Visitor):
{"type": types.TypePrinter().name(arg.type)}, {"type": types.TypePrinter().name(arg.type)},
arg.loc) arg.loc)
diag = diagnostic.Diagnostic("error", diag = diagnostic.Diagnostic("error",
"the argument of list() must be of an iterable type", {}, "the argument of {builtin}() must be of an iterable type",
{"builtin": typ.find().name},
node.func.loc, notes=[note]) node.func.loc, notes=[note])
self.engine.process(diag) self.engine.process(diag)
else: else:
@ -743,7 +755,7 @@ class Inferencer(algorithm.Visitor):
if builtins.is_range(arg.type): if builtins.is_range(arg.type):
self._unify(node.type, builtins.get_iterable_elt(arg.type), self._unify(node.type, builtins.get_iterable_elt(arg.type),
node.loc, None) node.loc, None)
elif builtins.is_list(arg.type): elif builtins.is_listish(arg.type):
# TODO: should be ssize_t-sized # TODO: should be ssize_t-sized
self._unify(node.type, builtins.TInt32(), self._unify(node.type, builtins.TInt32(),
node.loc, None) node.loc, None)

View File

@ -218,7 +218,7 @@ class LLVMIRGenerator:
return lldouble return lldouble
elif builtins.is_str(typ) or ir.is_exn_typeinfo(typ): elif builtins.is_str(typ) or ir.is_exn_typeinfo(typ):
return llptr return llptr
elif builtins.is_list(typ): elif builtins.is_listish(typ):
lleltty = self.llty_of_type(builtins.get_iterable_elt(typ)) lleltty = self.llty_of_type(builtins.get_iterable_elt(typ))
return ll.LiteralStructType([lli32, lleltty.as_pointer()]) return ll.LiteralStructType([lli32, lleltty.as_pointer()])
elif builtins.is_range(typ): elif builtins.is_range(typ):
@ -610,7 +610,7 @@ class LLVMIRGenerator:
name=insn.name) name=insn.name)
else: else:
assert False assert False
elif builtins.is_list(insn.type): elif builtins.is_listish(insn.type):
llsize = self.map(insn.operands[0]) llsize = self.map(insn.operands[0])
llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined) llvalue = ll.Constant(self.llty_of_type(insn.type), ll.Undefined)
llvalue = self.llbuilder.insert_value(llvalue, llsize, 0) llvalue = self.llbuilder.insert_value(llvalue, llsize, 0)
@ -1162,6 +1162,9 @@ class LLVMIRGenerator:
elif builtins.is_list(typ): elif builtins.is_list(typ):
return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ), return b"l" + self._rpc_tag(builtins.get_iterable_elt(typ),
error_handler) error_handler)
elif builtins.is_array(typ):
return b"a" + self._rpc_tag(builtins.get_iterable_elt(typ),
error_handler)
elif builtins.is_range(typ): elif builtins.is_range(typ):
return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ), return b"r" + self._rpc_tag(builtins.get_iterable_elt(typ),
error_handler) error_handler)
@ -1405,13 +1408,13 @@ class LLVMIRGenerator:
elif builtins.is_str(typ): elif builtins.is_str(typ):
assert isinstance(value, (str, bytes)) assert isinstance(value, (str, bytes))
return self.llstr_of_str(value) return self.llstr_of_str(value)
elif builtins.is_list(typ): elif builtins.is_listish(typ):
assert isinstance(value, list) assert isinstance(value, (list, numpy.ndarray))
elt_type = builtins.get_iterable_elt(typ) elt_type = builtins.get_iterable_elt(typ)
llelts = [self._quote(value[i], elt_type, lambda: path() + [str(i)]) llelts = [self._quote(value[i], elt_type, lambda: path() + [str(i)])
for i in range(len(value))] for i in range(len(value))]
lleltsary = ll.Constant(ll.ArrayType(self.llty_of_type(elt_type), len(llelts)), lleltsary = ll.Constant(ll.ArrayType(self.llty_of_type(elt_type), len(llelts)),
llelts) list(llelts))
llglobal = ll.GlobalVariable(self.llmodule, lleltsary.type, llglobal = ll.GlobalVariable(self.llmodule, lleltsary.type,
self.llmodule.scope.deduplicate("quoted.list")) self.llmodule.scope.deduplicate("quoted.list"))

View File

@ -697,7 +697,7 @@ class TypePrinter(object):
return "<instance {} {{}}>".format(typ.name) return "<instance {} {{}}>".format(typ.name)
elif isinstance(typ, TMono): elif isinstance(typ, TMono):
if typ.name in self.custom_printers: if typ.name in self.custom_printers:
return self.custom_printers[typ.name](typ, depth + 1, max_depth) return self.custom_printers[typ.name](typ, self, depth + 1, max_depth)
elif typ.params == {}: elif typ.params == {}:
return typ.name return typ.name
else: else:

View File

@ -331,6 +331,9 @@ class CommGeneric:
elif tag == "l": elif tag == "l":
length = self._read_int32() length = self._read_int32()
return [self._receive_rpc_value(embedding_map) for _ in range(length)] return [self._receive_rpc_value(embedding_map) for _ in range(length)]
elif tag == "a":
length = self._read_int32()
return numpy.array([self._receive_rpc_value(embedding_map) for _ in range(length)])
elif tag == "r": elif tag == "r":
start = self._receive_rpc_value(embedding_map) start = self._receive_rpc_value(embedding_map)
stop = self._receive_rpc_value(embedding_map) stop = self._receive_rpc_value(embedding_map)

View File

@ -607,6 +607,7 @@ static void skip_rpc_value(const char **tag) {
} }
case 'l': case 'l':
case 'a':
skip_rpc_value(tag); skip_rpc_value(tag);
break; break;
@ -650,6 +651,7 @@ static int sizeof_rpc_value(const char **tag)
return sizeof(char *); return sizeof(char *);
case 'l': // list(elt='a) case 'l': // list(elt='a)
case 'a': // array(elt='a)
skip_rpc_value(tag); skip_rpc_value(tag);
return sizeof(struct { int32_t length; struct {} *elements; }); return sizeof(struct { int32_t length; struct {} *elements; });
@ -733,7 +735,8 @@ static int receive_rpc_value(const char **tag, void **slot)
break; break;
} }
case 'l': { // list(elt='a) case 'l': // list(elt='a)
case 'a': { // array(elt='a)
struct { int32_t length; struct {} *elements; } *list = *slot; struct { int32_t length; struct {} *elements; } *list = *slot;
list->length = in_packet_int32(); list->length = in_packet_int32();
@ -824,7 +827,8 @@ static int send_rpc_value(const char **tag, void **value)
return out_packet_string(*((*(const char***)value)++)); return out_packet_string(*((*(const char***)value)++));
} }
case 'l': { // list(elt='a) case 'l': // list(elt='a)
case 'a': { // array(elt='a)
struct { uint32_t length; struct {} *elements; } *list = *value; struct { uint32_t length; struct {} *elements; } *list = *value;
void *element = list->elements; void *element = list->elements;

View File

@ -41,6 +41,9 @@ class RoundtripTest(ExperimentCase):
def test_list(self): def test_list(self):
self.assertRoundtrip([10]) self.assertRoundtrip([10])
def test_array(self):
self.assertRoundtrip(numpy.array([10]))
def test_object(self): def test_object(self):
obj = object() obj = object()
self.assertRoundtrip(obj) self.assertRoundtrip(obj)

View File

@ -57,6 +57,9 @@ lambda x, y=1: x
k = "x" k = "x"
# CHECK-L: k:str # CHECK-L: k:str
l = array([1])
# CHECK-L: l:numpy.array(elt=numpy.int?)
IndexError() IndexError()
# CHECK-L: IndexError:<constructor IndexError {}>():IndexError # CHECK-L: IndexError:<constructor IndexError {}>():IndexError

View File

@ -0,0 +1,5 @@
# RUN: %python -m artiq.compiler.testbench.jit %s
# REQUIRES: exceptions
ary = array([1, 2, 3])
assert [x*x for x in ary] == [1, 4, 9]

View File

@ -30,3 +30,6 @@ print([[1, 2], [3]])
# CHECK-L: range(0, 10, 1) # CHECK-L: range(0, 10, 1)
print(range(10)) print(range(10))
# CHECK-L: array([1, 2])
print(array([1, 2]))