import os |
import threading |
import Queue |
import traceback |
import atexit |
import weakref |
import __future__ |
|
|
|
|
|
|
|
|
|
|
|
# Only import the helper names from the py library when 'ThreadOut' is not
# already defined in this module's namespace.
# NOTE(review): presumably this source can also be executed in a namespace
# where these names were injected beforehand -- confirm against the loader.
if 'ThreadOut' not in globals():
    import py
    from py.code import Source
    from py.__.execnet.channel import ChannelFactory, Channel
    from py.__.execnet.message import Message
    ThreadOut = py._thread.ThreadOut
    WorkerPool = py._thread.WorkerPool
    NamedThreadPool = py._thread.NamedThreadPool

# 'os' is already imported at the top of the file; this repeat is harmless.
import os
# Trace sink: a false value disables tracing; otherwise it must be a
# file-like object, since Gateway._trace and cleanup_atexit print to it
# and call its flush().
debug = 0

# Exceptions that the gateway code never swallows in its catch-all handlers.
sysex = (KeyboardInterrupt, SystemExit)
|
class Gateway(object):
    """ one end of a Message connection to a peer.

        Messages are read from and written to the 'io' object by a
        'receiver' and a 'sender' function handed to a NamedThreadPool;
        source received for execution is dispatched to a WorkerPool.
    """
    # class used to redirect per-thread stdout/stderr; the sources sent by
    # _remote_redirect reach it on the other side as 'gateway._ThreadOut'
    _ThreadOut = ThreadOut
    # shown in brackets by __repr__ when non-empty
    remoteaddress = ""

    def __init__(self, io, execthreads=None, _startcount=2):
        """ initialize core gateway, using the given
            inputoutput object and 'execthreads' execution
            threads.

            'io' is the object Messages are read from / written to,
            'execthreads' is passed as 'maxthreads' to the WorkerPool and
            '_startcount' is passed on to the ChannelFactory.
        """
        global registered_cleanup
        self._execpool = WorkerPool(maxthreads=execthreads)
        self._io = io
        self._outgoing = Queue.Queue()
        self._channelfactory = ChannelFactory(self, _startcount)
        # the atexit hook is shared by all gateways: register it only once
        if not registered_cleanup:
            atexit.register(cleanup_atexit)
            registered_cleanup = True
        # remember the send queue so that cleanup_atexit (or exit()) can
        # post the None shutdown marker to it
        _active_sendqueues[self._outgoing] = True
        self._pool = NamedThreadPool(receiver = self._thread_receiver,
                                     sender = self._thread_sender)

    def __repr__(self):
        """ return string representing gateway type and status. """
        addr = self.remoteaddress
        if addr:
            addr = '[%s]' % (addr,)
        else:
            addr = ''
        try:
            if len(self._pool.getstarted('receiver')):
                r = "receiving"
            else:
                r = "not receiving"
            if len(self._pool.getstarted('sender')):
                s = "sending"
            else:
                s = "not sending"
            i = len(self._channelfactory.channels())
        except AttributeError:
            # attributes set by __init__ are missing
            r = s = "uninitialized"
            i = "no"
        return "<%s%s %s/%s (%s active channels)>" %(
                self.__class__.__name__, addr, r, s, i)

    def _trace(self, *args):
        """ write the given strings to the module-level 'debug' sink.

            Does nothing while 'debug' is false.  The arguments are joined
            with newlines, split at os.linesep and written one line at a
            time.  Failures while tracing are printed via traceback and
            otherwise ignored; KeyboardInterrupt/SystemExit propagate.
        """
        if not debug:
            return
        try:
            lines = "\n".join(args).split(os.linesep)
            for line in lines:
                debug.write(line + "\n")
            debug.flush()
        except sysex:
            raise
        except:
            # tracing is best-effort and must never break the caller
            traceback.print_exc()

    def _traceex(self, excinfo):
        """ trace the exception described by the 'excinfo' triple
            (as returned from sys.exc_info()).
        """
        try:
            l = traceback.format_exception(*excinfo)
            errortext = "".join(l)
        except:
            # formatting failed: fall back to "ExceptionName: value"
            errortext = '%s: %s' % (excinfo[0].__name__,
                                    excinfo[1])
        self._trace(errortext)

    def _thread_receiver(self):
        """ thread to read and handle Messages half-sync-half-async.

            Loops until reading or handling a message raises: EOFError
            and KeyboardInterrupt/SystemExit end the loop silently, any
            other exception is traced first.  On the way out the sender
            is told to stop (None is put on the outgoing queue) and the
            channel factory is notified that receiving has finished.
        """
        try:
            from sys import exc_info
            while 1:
                try:
                    msg = Message.readfrom(self._io)
                    self._trace("received <- %r" % msg)
                    msg.received(self)
                except sysex:
                    break
                except EOFError:
                    break
                except:
                    self._traceex(exc_info())
                    break
        finally:
            self._outgoing.put(None)
            self._channelfactory._finished_receiving()
            self._trace('leaving %r' % threading.currentThread())

    def _thread_sender(self):
        """ thread to send Messages over the wire.

            Takes items from the outgoing queue until it sees None, the
            shutdown marker, upon which the write side of the io object
            is closed.  If writing (or closing) raises, the exception is
            traced, a real message gets msg.post_sent(self, excinfo) and
            the loop ends.
        """
        try:
            from sys import exc_info
            while 1:
                msg = self._outgoing.get()
                try:
                    if msg is None:
                        self._io.close_write()
                        break
                    msg.writeto(self._io)
                except:
                    excinfo = exc_info()
                    self._traceex(excinfo)
                    if msg is not None:
                        msg.post_sent(self, excinfo)
                    break
                else:
                    self._trace('sent -> %r' % msg)
                    msg.post_sent(self)
        finally:
            self._trace('leaving %r' % threading.currentThread())

    def _local_redirect_thread_output(self, outid, errid):
        """ redirect this thread's stdout/stderr to channels.

            For each of 'outid' (stdout) and 'errid' (stderr) that is
            true, a channel with that id is obtained from the channel
            factory and installed as write function of a _ThreadOut for
            the corresponding stream.  Returns a function that removes
            the write functions again and closes the channels.
        """
        # 'sys' is not imported at module level
        import sys
        redirections = []
        for name, chanid in ('stdout', outid), ('stderr', errid):
            if chanid:
                # each stream uses its own channel id (stderr -> errid)
                channel = self._channelfactory.new(chanid)
                out = self._ThreadOut(sys, name)
                out.setwritefunc(channel.send)
                redirections.append((out, channel))
        def close():
            for out, channel in redirections:
                out.delwritefunc()
                channel.close()
        return close

    def _thread_executor(self, channel, sourcetask):
        """ worker thread to execute source objects from the execution queue.

            'sourcetask' is a (source, outid, errid) triple.  The source is
            compiled and executed in a fresh namespace that only contains
            'channel'; output is redirected according to outid/errid for
            the duration of the execution.  On success the channel is
            closed; on an exception it is closed with the formatted
            traceback as error text.  KeyboardInterrupt/SystemExit end the
            execution without closing the channel.
        """
        # explicit unpacking instead of a tuple parameter; positional
        # callers (see _local_schedulexec) are unaffected
        source, outid, errid = sourcetask
        from sys import exc_info
        try:
            loc = { 'channel' : channel }
            self._trace("execution starts:", repr(source)[:50])
            close = self._local_redirect_thread_output(outid, errid)
            try:
                co = compile(source+'\n', '', 'exec',
                             __future__.CO_GENERATOR_ALLOWED)
                # NOTE: this runs arbitrary source received from the peer;
                # that is the purpose of the gateway, so the peer must be
                # trusted.
                exec(co, loc)
            finally:
                close()
                self._trace("execution finished:", repr(source)[:50])
        except sysex:
            pass
        except:
            excinfo = exc_info()
            l = traceback.format_exception(*excinfo)
            errortext = "".join(l)
            channel.close(errortext)
            self._trace(errortext)
        else:
            channel.close()

    def _local_schedulexec(self, channel, sourcetask):
        """ hand 'sourcetask' to the execution pool, to be run by
            _thread_executor with the given channel.
        """
        self._trace("dispatching exec")
        self._execpool.dispatch(self._thread_executor, channel, sourcetask)

    def _newredirectchannelid(self, callback):
        """ return the id of a new channel whose callback is 'callback'
            (or its 'write' attribute, if it has one); return None if
            'callback' is None.
        """
        if callback is None:
            return
        if hasattr(callback, 'write'):
            callback = callback.write
        assert callable(callback)
        chan = self.newchannel()
        chan.setcallback(callback)
        return chan.id

    def newchannel(self):
        """ return new channel object.  """
        return self._channelfactory.new()

    def remote_exec(self, source, stdout=None, stderr=None):
        """ return channel object and connect it to a remote
            execution thread where the given 'source' executes
            and has the sister 'channel' object in its global
            namespace.  The callback functions 'stdout' and
            'stderr' get called on receival of remote
            stdout/stderr output strings.
        """
        try:
            source = str(Source(source))
        except NameError:
            # 'Source' is not defined in this namespace: try the py library
            # and otherwise send the source exactly as given
            try:
                import py
                source = str(py.code.Source(source))
            except ImportError:
                pass
        channel = self.newchannel()
        outid = self._newredirectchannelid(stdout)
        errid = self._newredirectchannelid(stderr)
        self._outgoing.put(Message.CHANNEL_OPEN(channel.id,
                               (source, outid, errid)))
        return channel

    def _remote_redirect(self, stdout=None, stderr=None):
        """ return a handle representing a redirection of a remote
            end's stdout to a local file object.  with handle.close()
            the redirection will be reverted.

            'stdout' and 'stderr' may each be a callable or an object
            with a 'write' method.  The call returns after the remote
            setup code for every requested stream has finished.
        """
        clist = []
        for name, out in ('stdout', stdout), ('stderr', stderr):
            if out:
                outchannel = self.newchannel()
                outchannel.setcallback(getattr(out, 'write', out))
                # the remote source is kept flush-left so that it compiles
                # even when remote_exec could not normalize it via Source
                channel = self.remote_exec("""
import sys
outchannel = channel.receive()
outchannel.gateway._ThreadOut(sys, %r).setdefaultwriter(outchannel.send)
""" % name)
                channel.send(outchannel)
                clist.append(channel)
        for c in clist:
            c.waitclose()
        class Handle:
            # 'self' below is the gateway from the enclosing scope,
            # '_' is the (unused) Handle instance
            def close(_):
                for name, out in ('stdout', stdout), ('stderr', stderr):
                    if out:
                        c = self.remote_exec("""
import sys
channel.gateway._ThreadOut(sys, %r).resetdefault()
""" % name)
                        c.waitclose()
        return Handle()

    def exit(self):
        """ Try to stop all IO activity.

            Removes the outgoing queue from the active send queues and
            posts the None shutdown marker to it; calling exit() again
            (or after cleanup_atexit ran) does nothing.
        """
        try:
            del _active_sendqueues[self._outgoing]
        except KeyError:
            pass
        else:
            self._outgoing.put(None)

    def join(self, joinexec=True):
        """ Wait for all IO (and by default all execution activity)
            to stop.

            The calling thread is skipped when joining the IO threads;
            pass joinexec=False to not wait for the execution pool.
        """
        current = threading.currentThread()
        for x in self._pool.getstarted():
            if x != current:
                self._trace("joining %s" % x)
                x.join()
        self._trace("joining sender/reciver threads finished, current %r" % current)
        if joinexec:
            self._execpool.join()
            self._trace("joining execution threads finished, current %r" % current)
|
def getid(gw, cache={}):
    """ return a string "<pid>:<ClassName>.<n>" identifying 'gw'.

        The mutable default 'cache' is intentional: it is the memo table
        shared by all calls, mapping class name -> {id(gw): string}.
        <n> is the number of objects of that class already recorded when
        'gw' is seen for the first time.
    """
    clsname = gw.__class__.__name__
    known = cache.setdefault(clsname, {})
    key = id(gw)
    if key not in known:
        known[key] = "%s:%s.%d" % (os.getpid(), clsname, len(known))
    return known[key]
|
# Becomes True once cleanup_atexit has been registered with atexit
# (done by the first Gateway.__init__).
registered_cleanup = False
# Outgoing queues of gateways that were created and have not exited yet.
# Keys are referenced weakly; the values are not used.
_active_sendqueues = weakref.WeakKeyDictionary()
def cleanup_atexit():
    """ atexit hook: post the None shutdown marker to every send queue
        that is still registered in _active_sendqueues, emptying the
        registry in the process.  Writes a banner line to 'debug' first
        if tracing is enabled.
    """
    if debug:
        rule = "=" * 20
        print >>debug, rule + "cleaning up" + rule
        debug.flush()
    while True:
        try:
            entry = _active_sendqueues.popitem()
        except KeyError:
            # registry is empty: every queue has been notified
            return
        sendqueue = entry[0]
        sendqueue.put(None)
|