#!/usr/bin/python

# osh
# Copyright (C) 2005 Jack Orenstein <jao@geophile.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.


# stdin carries 2 or 3 pickled inputs:
# 1) verbosity flag
# 2) pipeline to be executed
# 3) optional: kill signal
#
# The main thread processes stdin. Command execution takes place
# in a separate thread (_PipelineRunner). Termination occurs in
# one of two ways:
# 1) A kill signal arrives: the signal is sent to this process and all
#    of its descendants.
# 2) The command terminates (normally or with an exception): this
#    process and all of its descendants are killed anyway, from
#    _PipelineRunner. This is the only way to end the main thread's
#    wait for a kill signal if the client doesn't end first.
#
# The _PipelineRunner thread is never joined, because the process is
# always killed by one thread or the other.

import cPickle
import os
import sys
import threading
import traceback

import osh.oshobjects
import osh.oshtypes.process

# Process id of this script. _kill_self_and_descendents() uses it as the
# root of the process tree it signals.
pid = os.getpid()

# Flipped to True by _shutdown() once stdout/stderr have been closed.
closed_streams = False

# Debug tracing: while TRACE is true, trace() lazily opens /tmp/trace and
# keeps the handle in _tracefile.
# NOTE(review): tracing is hard-wired on and writes to a fixed path in /tmp;
# consider defaulting it off.
TRACE, _tracefile = True, None

def trace(line):
    """Append ``line`` (converted with %s) to the debug trace file.

    Does nothing unless TRACE is true. The trace file (/tmp/trace) is
    opened lazily on first use and flushed after every line, so the trace
    survives this script killing itself (possibly with signal 9).

    Tracing is strictly a debugging aid and is called from exception
    handlers and the shutdown/kill path. If the file cannot be opened or
    written (e.g. /tmp/trace exists and belongs to another user), the line
    is dropped instead of raising into those callers. A later call retries
    opening the file.

    NOTE(review): a fixed, predictable path in a shared /tmp can be
    pre-created or symlinked by another user; consider tempfile or
    disabling TRACE by default.
    """
    global _tracefile
    if not TRACE:
        return
    try:
        if _tracefile is None:
            _tracefile = open('/tmp/trace', 'w')
        # Explicit write instead of the Python-2-only ``print >>f`` form.
        _tracefile.write('%s\n' % (line,))
        _tracefile.flush()
    except (IOError, OSError):
        # Best effort only: never let tracing break the pipeline.
        pass

def flush_trace():
    """Flush the trace file, but only if tracing is on and it has been opened."""
    if TRACE and _tracefile:
        _tracefile.flush()

def _kill_self_and_descendents(kill_signal = None):
    """Send ``kill_signal`` to every descendant of this process, then to itself.

    Descendants are signalled first, and this process last, because
    signalling this process presumably ends execution here.
    ``kill_signal`` is passed straight to ``process.kill``; the meaning of
    None is that method's default, which is not visible from this file.

    This is a best-effort cleanup path. Any Exception is recorded via
    trace()/traceback and then swallowed rather than propagated.
    """
    trace('>>> In _kill_self_and_descendents for process %s' % pid)
    try:
        this_process = osh.oshtypes.process.process(pid)
        for descendent in this_process.descendents():
            trace('>>> killing %s' % descendent.pid())
            descendent.kill(kill_signal)
        trace('>>> killing %s' % this_process.pid())
        this_process.kill(kill_signal)
    except Exception:
        # Bind the exception explicitly. sys.exc_info() works on every
        # Python 2.x as well as Python 3, unlike ``except X, e`` /
        # ``except X as e``.
        e = sys.exc_info()[1]
        trace('>>> exception while killing self: %s' % str(e))
        traceback.print_exc(file = _tracefile)
        flush_trace()
        flush_trace()

def _shutdown():
    """Flush, then close, sys.stdout and sys.stderr.

    Both streams are flushed before either is closed. Once a call has
    completed, closed_streams is set and later calls return immediately.
    This function is invoked both from _PicklingOpset.receive_complete and
    from the pipeline runner's cleanup.
    """
    global closed_streams
    if closed_streams:
        return
    trace('Closing stdout and stderr')
    streams = (sys.stdout, sys.stderr)
    for stream in streams:
        stream.flush()
    for stream in streams:
        stream.close()
    closed_streams = True

class _PicklingOpset(osh.oshobjects.Opset):
    """Terminal opset that pickles each (label, object) pair to stdout.

    This is how pipeline output travels back to the client. All messages
    go through one Pickler. The client presumably reads them with a single
    Unpickler, as this script does for stdin -- confirm on the client side.
    """

    # cPickle.Pickler writing to sys.stdout; created in __init__.
    _output = None

    def __init__(self):
        osh.oshobjects.Opset.__init__(self)
        self._output = cPickle.Pickler(sys.stdout)

    def __repr__(self):
        return '_PicklingOpset#%s' % self.id()

    def generator(self):
        # This isn't really a generator, but we don't need an
        # error handler generated for the opset
        return True

    def receive(self, label, object):
        """Trace each element of ``object``, then pickle ``(label, object)``.

        ``object`` is iterated element by element for tracing; presumably
        it is a tuple of values -- confirm against Opset callers.
        """
        if TRACE:
            # Skip the per-element formatting entirely when tracing is off.
            for i, oi in enumerate(object):
                trace('label: %s, object[%s]: (%s) %s' % (label, i, oi.__class__, oi))
        self._output.dump((label, object))
        # A Pickler's memo keeps every object it has dumped alive, and later
        # dumps of the same object emit only a back-reference. Without
        # clearing it, memory grows for the whole stream, and an object that
        # is mutated and re-sent arrives with its old contents. Clearing it
        # is compatible with both a single reused Unpickler and one
        # Unpickler per message on the receiving side.
        self._output.clear_memo()

    def receive_complete(self):
        # End of stream: flush and close stdout/stderr.
        _shutdown()


class _DebugOpset(osh.oshobjects.Opset):
    """Diagnostic opset that only traces what it receives.

    Nothing is written to stdout. Each received (label, object) pair is
    reported through trace() instead of being pickled.
    """

    # Never assigned by this class; presumably mirrors _PicklingOpset._output.
    _output = None

    def __init__(self):
        osh.oshobjects.Opset.__init__(self)

    def __repr__(self):
        return '_DebugOpset#%s' % self.id()

    def generator(self):
        # Not a real generator. Returning True means no error handler is
        # generated for this opset.
        return True

    def send_complete(self):
        """Deliberately does nothing, overriding the base-class behaviour."""

    def receive(self, label, object):
        message = 'about to pickle for %s: %s' % (label, object)
        trace(message)


class _PipelineRunner(threading.Thread):
    """Thread that executes the pipeline received on stdin.

    A _PicklingOpset is appended so the pipeline's output is pickled to
    stdout. However execution ends, run() closes stdout/stderr and then
    kills this process and all its descendants. That kill is what ends the
    main thread's wait for a kill signal, so this thread is never joined.
    """

    # The pipeline object unpickled from stdin; set in __init__.
    _pipeline = None

    def __init__(self, pipeline):
        threading.Thread.__init__(self)
        self._pipeline = pipeline

    def run(self):
        try:
            try:
                self._pipeline.append_opset(_PicklingOpset())
                trace('pipeline (before setup): %s' % self._pipeline.dump())
                self._pipeline.setup()
                trace('pipeline (after setup): %s' % self._pipeline.dump())
                self._pipeline.execute()
                trace('done')
            except Exception:
                # Record the failure: traceback goes to the trace file, or
                # to stderr if no trace file is open. The finally clause
                # below still shuts down and kills.
                e = sys.exc_info()[1]
                trace('Caught exception during execution: %s' % str(e))
                traceback.print_exc(file = _tracefile)
                flush_trace()
        finally:
            # Nested so that an I/O error while closing the streams (e.g.
            # flushing stdout to a client that has gone away) cannot skip
            # the kill, which is the only thing that terminates this process
            # when no kill signal arrives.
            try:
                _shutdown()
            finally:
                trace('About to kill self and descendents')
                _kill_self_and_descendents()
        

# ---- Main program -------------------------------------------------------
# Read the pickled inputs from stdin in order: verbosity, pipeline, and an
# optional kill signal. Start the pipeline in a _PipelineRunner thread, then
# block waiting for the kill signal.
# NOTE(review): everything on stdin is unpickled, and unpickling can execute
# arbitrary code. This is only safe if stdin comes from a trusted osh
# client -- confirm how this script is launched.
input = cPickle.Unpickler(sys.stdin)
try:
    osh.oshobjects.verbosity = input.load()
    trace('verbosity: %s' % osh.oshobjects.verbosity)
    # osh_usage controls error handling. On remote side (i.e., here),
    # do CLI error handling -- write to e stream. Caller will deal with it.
    osh.oshobjects.osh_usage = osh.oshobjects.USAGE_CLI
    # Optional first command-line argument: the default database profile.
    if len(sys.argv) > 1:
        osh.oshobjects.default_db_profile = sys.argv[1]
    pipeline = input.load()
    pipeline_runner = _PipelineRunner(pipeline)
    pipeline_runner.start()
    # Wait for kill signal that may never come. If the pipeline finishes
    # first, _PipelineRunner kills this whole process, ending the wait.
    try:
        kill_signal = input.load()
        trace('Received kill signal %s' % kill_signal)
        _kill_self_and_descendents(kill_signal)
    except EOFError, e:
        # stdin closed without sending a kill signal (presumably the client
        # went away): force-kill with signal 9 (SIGKILL on POSIX systems).
        trace('EOFError waiting for kill signal')
        trace(str(e))
        traceback.print_exc(file = _tracefile)
        flush_trace()
        _kill_self_and_descendents(9)
except Exception, e:
    # Failure while reading the inputs or starting the runner. It is only
    # logged (trace file, or stderr when no trace file is open). A runner
    # thread that was already started is not a daemon thread, so the
    # process stays alive until it finishes and kills itself.
    trace(str(e))
    traceback.print_exc(file = _tracefile)
    flush_trace()
