diff --git a/artiq/protocols/pipe_ipc.py b/artiq/protocols/pipe_ipc.py index a982bdf9b..910fe93c4 100644 --- a/artiq/protocols/pipe_ipc.py +++ b/artiq/protocols/pipe_ipc.py @@ -16,9 +16,6 @@ class _BaseIO: async def read(self, n): return await self.reader.read(n) - def close(self): - self.writer.close() - if os.name != "nt": async def _fds_to_asyncio(rfd, wfd, loop): @@ -26,9 +23,9 @@ if os.name != "nt": reader_protocol = asyncio.StreamReaderProtocol(reader, loop=loop) wf = open(wfd, "wb", 0) - transport, protocol = await loop.connect_write_pipe( + transport, _ = await loop.connect_write_pipe( FlowControlMixin, wf) - writer = asyncio.StreamWriter(transport, protocol, + writer = asyncio.StreamWriter(transport, reader_protocol, None, loop) rf = open(rfd, "rb", 0) @@ -45,6 +42,10 @@ if os.name != "nt": def get_address(self): return "{},{}".format(self.c_rfd, self.c_wfd) + async def _autoclose(self): + await self.process.wait() + self.writer.close() + async def create_subprocess(self, *args, **kwargs): loop = asyncio.get_event_loop() self.process = await asyncio.create_subprocess_exec( @@ -54,6 +55,7 @@ if os.name != "nt": self.reader, self.writer = await _fds_to_asyncio( self.p_rfd, self.p_wfd, loop) + asyncio.ensure_future(self._autoclose()) class AsyncioChildComm(_BaseIO): @@ -65,6 +67,9 @@ if os.name != "nt": self.reader, self.writer = await _fds_to_asyncio( int(rfd), int(wfd), asyncio.get_event_loop()) + def close(self): + self.writer.close() + class ChildComm: def __init__(self, address): @@ -88,7 +93,10 @@ if os.name != "nt": else: # windows class AsyncioParentComm(_BaseIO): - pass + async def _autoclose(self): + await self.process.wait() + self.writer.close() + class AsyncioChildComm(_BaseIO): """Requires ProactorEventLoop""" @@ -100,9 +108,9 @@ else: # windows self.reader = asyncio.StreamReader(loop=loop) reader_protocol = asyncio.StreamReaderProtocol( self.reader, loop=loop) - transport, protocol = await loop.create_pipe_connection( + transport, _ = await loop.create_pipe_connection( self.address, lambda: reader_protocol) - self.writer = asyncio.StreamWriter(transport, protocol, + self.writer = asyncio.StreamWriter(transport, reader_protocol, self.reader, loop) class ChildComm: diff --git a/artiq/test/pipe_ipc.py b/artiq/test/pipe_ipc.py index 9614a7431..b066d4276 100644 --- a/artiq/test/pipe_ipc.py +++ b/artiq/test/pipe_ipc.py @@ -30,7 +30,6 @@ class IPCCase(unittest.TestCase): self.assertEqual(int(s), i+1) ipc.write(b"-1\n") await ipc.process.wait() - ipc.close() def test_blocking(self): self.loop.run_until_complete(self._coro_test(False))