mirror of https://github.com/m-labs/artiq.git
protocols/pipe_ipc: implement AsyncioChildComm for Unix (affected by python/asyncio#314)
This commit is contained in:
parent
8befc6a8fc
commit
dcea6780c6
|
@ -1,5 +1,6 @@
|
||||||
import os
|
import os
|
||||||
import asyncio
|
import asyncio
|
||||||
|
from asyncio.streams import FlowControlMixin
|
||||||
|
|
||||||
|
|
||||||
class _BaseIO:
|
class _BaseIO:
|
||||||
|
@ -20,6 +21,22 @@ class _BaseIO:
|
||||||
|
|
||||||
|
|
||||||
if os.name != "nt":
|
if os.name != "nt":
|
||||||
|
async def _fds_to_asyncio(rfd, wfd, loop):
|
||||||
|
reader = asyncio.StreamReader(loop=loop)
|
||||||
|
reader_protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
|
||||||
|
|
||||||
|
wf = open(wfd, "wb", 0)
|
||||||
|
transport, protocol = await loop.connect_write_pipe(
|
||||||
|
FlowControlMixin, wf)
|
||||||
|
writer = asyncio.StreamWriter(transport, protocol,
|
||||||
|
None, loop)
|
||||||
|
|
||||||
|
rf = open(rfd, "rb", 0)
|
||||||
|
await loop.connect_read_pipe(lambda: reader_protocol, rf)
|
||||||
|
|
||||||
|
return reader, writer
|
||||||
|
|
||||||
|
|
||||||
class AsyncioParentComm(_BaseIO):
|
class AsyncioParentComm(_BaseIO):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.c_rfd, self.p_wfd = os.pipe()
|
self.c_rfd, self.p_wfd = os.pipe()
|
||||||
|
@ -35,20 +52,19 @@ if os.name != "nt":
|
||||||
os.close(self.c_rfd)
|
os.close(self.c_rfd)
|
||||||
os.close(self.c_wfd)
|
os.close(self.c_wfd)
|
||||||
|
|
||||||
pipe = open(self.p_rfd, "rb", 0)
|
self.reader, self.writer = await _fds_to_asyncio(
|
||||||
self.reader = asyncio.StreamReader(loop=loop)
|
self.p_rfd, self.p_wfd, loop)
|
||||||
def factory():
|
|
||||||
return asyncio.StreamReaderProtocol(self.reader, loop=loop)
|
|
||||||
await loop.connect_read_pipe(factory, pipe)
|
|
||||||
|
|
||||||
pipe = open(self.p_wfd, "wb", 0)
|
|
||||||
transport, protocol = await loop.connect_write_pipe(
|
|
||||||
asyncio.Protocol, pipe)
|
|
||||||
self.writer = asyncio.StreamWriter(transport, protocol,
|
|
||||||
None, loop)
|
|
||||||
|
|
||||||
class AsyncioChildComm(_BaseIO):
|
class AsyncioChildComm(_BaseIO):
|
||||||
pass
|
def __init__(self, address):
|
||||||
|
self.address = address
|
||||||
|
|
||||||
|
async def connect(self):
|
||||||
|
rfd, wfd = self.address.split(",", maxsplit=1)
|
||||||
|
self.reader, self.writer = await _fds_to_asyncio(
|
||||||
|
int(rfd), int(wfd), asyncio.get_event_loop())
|
||||||
|
|
||||||
|
|
||||||
class ChildComm:
|
class ChildComm:
|
||||||
def __init__(self, address):
|
def __init__(self, address):
|
||||||
|
@ -82,10 +98,10 @@ else: # windows
|
||||||
async def connect(self):
|
async def connect(self):
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
self.reader = asyncio.StreamReader(loop=loop)
|
self.reader = asyncio.StreamReader(loop=loop)
|
||||||
def factory():
|
reader_protocol = asyncio.StreamReaderProtocol(
|
||||||
return asyncio.StreamReaderProtocol(self.reader)
|
self.reader, loop=loop)
|
||||||
transport, protocol = await loop.create_pipe_connection(self.address,
|
transport, protocol = await loop.create_pipe_connection(
|
||||||
factory)
|
self.address, lambda: reader_protocol)
|
||||||
self.writer = asyncio.StreamWriter(transport, protocol,
|
self.writer = asyncio.StreamWriter(transport, protocol,
|
||||||
self.reader, loop)
|
self.reader, loop)
|
||||||
|
|
||||||
|
|
|
@ -1,25 +1,31 @@
|
||||||
import unittest
|
import unittest
|
||||||
import sys
|
import sys
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import os
|
||||||
|
|
||||||
from artiq.protocols import pipe_ipc
|
from artiq.protocols import pipe_ipc
|
||||||
|
|
||||||
|
|
||||||
class IPCCase(unittest.TestCase):
|
class IPCCase(unittest.TestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
self.loop = asyncio.new_event_loop()
|
if os.name == "nt":
|
||||||
|
self.loop = asyncio.ProactorEventLoop()
|
||||||
|
else:
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
asyncio.set_event_loop(self.loop)
|
asyncio.set_event_loop(self.loop)
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.loop.close()
|
self.loop.close()
|
||||||
|
|
||||||
async def _coro_test(self, child_blocking):
|
async def _coro_test(self, child_asyncio):
|
||||||
ipc = pipe_ipc.AsyncioParentComm()
|
ipc = pipe_ipc.AsyncioParentComm()
|
||||||
await ipc.create_subprocess(sys.executable,
|
await ipc.create_subprocess(sys.executable,
|
||||||
sys.modules[__name__].__file__,
|
sys.modules[__name__].__file__,
|
||||||
|
str(child_asyncio),
|
||||||
ipc.get_address())
|
ipc.get_address())
|
||||||
for i in range(10):
|
for i in range(10):
|
||||||
ipc.write("{}\n".format(i).encode())
|
ipc.write("{}\n".format(i).encode())
|
||||||
|
await ipc.drain()
|
||||||
s = (await ipc.readline()).decode()
|
s = (await ipc.readline()).decode()
|
||||||
self.assertEqual(int(s), i+1)
|
self.assertEqual(int(s), i+1)
|
||||||
ipc.write(b"-1\n")
|
ipc.write(b"-1\n")
|
||||||
|
@ -27,16 +33,49 @@ class IPCCase(unittest.TestCase):
|
||||||
ipc.close()
|
ipc.close()
|
||||||
|
|
||||||
def test_blocking(self):
|
def test_blocking(self):
|
||||||
|
self.loop.run_until_complete(self._coro_test(False))
|
||||||
|
|
||||||
|
def test_asyncio(self):
|
||||||
self.loop.run_until_complete(self._coro_test(True))
|
self.loop.run_until_complete(self._coro_test(True))
|
||||||
|
|
||||||
|
|
||||||
def run_child():
|
def run_child_blocking():
|
||||||
child_comm = pipe_ipc.ChildComm(sys.argv[1])
|
child_comm = pipe_ipc.ChildComm(sys.argv[2])
|
||||||
while True:
|
while True:
|
||||||
x = int(child_comm.readline().decode())
|
x = int(child_comm.readline().decode())
|
||||||
if x < 0:
|
if x < 0:
|
||||||
break
|
break
|
||||||
child_comm.write((str(x+1) + "\n").encode())
|
child_comm.write((str(x+1) + "\n").encode())
|
||||||
|
child_comm.close()
|
||||||
|
|
||||||
|
|
||||||
|
async def coro_child():
|
||||||
|
child_comm = pipe_ipc.AsyncioChildComm(sys.argv[2])
|
||||||
|
await child_comm.connect()
|
||||||
|
while True:
|
||||||
|
x = int((await child_comm.readline()).decode())
|
||||||
|
if x < 0:
|
||||||
|
break
|
||||||
|
child_comm.write((str(x+1) + "\n").encode())
|
||||||
|
await child_comm.drain()
|
||||||
|
child_comm.close()
|
||||||
|
|
||||||
|
|
||||||
|
def run_child_asyncio():
|
||||||
|
if os.name == "nt":
|
||||||
|
loop = asyncio.ProactorEventLoop()
|
||||||
|
asyncio.set_event_loop(loop)
|
||||||
|
else:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(coro_child())
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
def run_child():
|
||||||
|
if sys.argv[1] == "True":
|
||||||
|
run_child_asyncio()
|
||||||
|
else:
|
||||||
|
run_child_blocking()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
run_child()
|
run_child()
|
||||||
|
|
Loading…
Reference in New Issue