#!/usr/bin/python

# osh
# Copyright (C) 2005 Jack Orenstein <jao@geophile.com>
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

"""oshtestconnect -c CLUSTER [-u USER]
"""

from osh.oshconfig import *
import getopt
import popen2
import sys
import threading;

def usage():
    """Print the module docstring (the usage synopsis) and exit with status 1."""
    # print(x) with a single parenthesised argument behaves the same as
    # `print x` on Python 2 and is also valid Python 3.
    print(__doc__)
    # Equivalent to sys.exit(1): both raise SystemExit with code 1.
    raise SystemExit(1)

class _Host:

    _name = None
    _address = None

    def __init__(self, x, y = None):
        self._name = x
        if y is None:
            self._address = x
        else:
            self._address = y

    def name(self):
        return self._name

    def address(self):
        return self._address

def cluster_configuration(cluster):
    """Read the osh.remote.<cluster> settings from ~/.oshrc.

    Returns a (user, hosts, install) tuple.  hosts is a list of _Host:
    - a configured list of addresses yields hosts named by their address;
    - a configured dict yields hosts named by key, reached at the value;
    - a missing hosts setting yields [] so that callers probing whether a
      cluster exists (has_profile) get a falsy answer instead of an error.

    Raises Exception if hosts is configured but is neither a list nor a dict.
    """
    user = config_value('remote', cluster, 'user')
    hosts = config_value('remote', cluster, 'hosts')
    if hosts is None:
        # NOTE(review): assumes config_value returns None for an unset key;
        # confirm in osh.oshconfig.  Without this branch has_profile() raises
        # for any unconfigured name and the "treat it as a hostname" fallback
        # in user_and_hosts() can never run.
        hosts = []
    elif isinstance(hosts, list):
        hosts = [_Host(addr) for addr in hosts]
    elif isinstance(hosts, dict):
        hosts = [_Host(name, addr) for name, addr in hosts.items()]
    else:
        raise Exception(('Error in ~/.oshrc: ' +
                         'osh.remote.%s.hosts must be a list or dict') % cluster)
    install = config_value('remote', cluster, 'install')
    return (user, hosts, install)

def has_profile(cluster):
    """Return a truthy value if ~/.oshrc configures anything for cluster.

    The result is the first truthy one of the configured user, hosts and
    install settings, or the install setting itself if none is truthy.
    """
    user, hosts, install = cluster_configuration(cluster)
    return user or hosts or install

########## Adapted from oshobjects.remoteop
def printCommandOutput(host, stdout, stderr):
    """Echo captured output lines to our stdout, labelled by host and stream.

    stdout and stderr are sequences of lines (line endings included); either
    may be empty or None.  Both streams are written to sys.stdout, stdout
    lines first.
    """
    for template, captured in (('%s stdout: %s', stdout),
                               ('%s stderr: %s', stderr)):
        if not captured:
            continue
        for text in captured:
            sys.stdout.write(template % (host.name(), text))

def spawn(host, command):
    """Run command through the local shell and print its output for host.

    Blocks until the command exits, then hands its stdout and stderr lines
    (line endings kept) to printCommandOutput.
    """
    import subprocess
    child = subprocess.Popen(command, shell=True,
                             stdin=subprocess.PIPE,
                             stdout=subprocess.PIPE,
                             stderr=subprocess.PIPE,
                             universal_newlines=True)
    # communicate() closes the child's stdin, drains stdout and stderr
    # together (reading them one after the other can deadlock once the unread
    # pipe fills), and waits for the child so no zombie is left behind.
    out, err = child.communicate()
    printCommandOutput(host, out.splitlines(True), err.splitlines(True))

def ssh(user, host, command):
    """Run command on host as user over ssh and print the output via spawn().

    Host-key checking is disabled and no tty is requested (-T).
    NOTE(review): the address, user and command are interpolated into a
    shell command line without escaping; command is only wrapped in double
    quotes.
    """
    template = 'ssh %s -T -o StrictHostKeyChecking=no -l %s "%s" '
    spawn(host, template % (host.address(), user, command))

def user_and_hosts(cluster, user):
    """Resolve the remote user and list of _Host objects to test.

    cluster (from -c) may be:
    - a cluster defined in ~/.oshrc, e.g. foo;
    - a cluster plus a host-name substring filter, e.g. foo:116;
    - anything else, which is treated as a single hostname.
    If cluster is empty, the default cluster config_value('remote') is used.

    user (from -u) overrides the profile's user; if neither is set, the USER
    environment variable is used.  Calls usage() (which exits) if the cluster,
    user or host list cannot be determined.  Returns (user, hosts).
    """
    import os  # used below; not imported at file level
    hosts = None
    profile_user = None
    if cluster:
        cluster, pattern = (cluster.split(':') + [None])[:2]
        if has_profile(cluster):
            profile_user, hosts, _ = cluster_configuration(cluster)
            if pattern:
                hosts = [host for host in hosts if pattern in host.name()]
        else:
            # Cluster is specified on command line but not configured
            # in .oshrc. Interpret the cluster as a hostname.
            hosts = [_Host(cluster)]
    else:
        # Cluster not specified. Use default in .oshrc.
        cluster = config_value('remote')
        if cluster:
            profile_user, hosts, _ = cluster_configuration(cluster)
    # Precedence: command line, then profile, then environment.
    if not user:
        user = profile_user
    if not user:
        # .get() so a missing USER reaches usage() instead of a KeyError.
        user = os.environ.get('USER')
    if not cluster or not user or not hosts:
        usage()
    return user, hosts
##########

class Tester(threading.Thread):
    """Thread that runs `echo hello` on one host over ssh as a given user."""

    _user = None  # remote login name passed to ssh()
    _host = None  # _Host to connect to

    def __init__(self, user, host):
        super(Tester, self).__init__()
        self._user, self._host = user, host

    def run(self):
        probe = 'echo hello'
        ssh(self._user, self._host, probe)

def test_connection(user, hosts):
    """Probe every host concurrently (one Tester thread each) and wait for all.

    Threads are joined with a 1-second timeout in a loop rather than a
    single blocking join().
    """
    workers = []
    for target in hosts:
        worker = Tester(user, target)
        worker.start()
        workers.append(worker)
    for worker in workers:
        while worker.isAlive():
            worker.join(1.0)

# Script entry: parse -c CLUSTER [-u USER], then ssh to each resolved host.
try:
    options, args = getopt.getopt(sys.argv[1:], 'c:u:')
except getopt.GetoptError:
    # Unknown option or missing argument: show usage rather than a traceback.
    usage()
if args:
    usage()
cluster = None
user = None
for flag, value in options:
    if flag == '-c':
        cluster = value
    elif flag == '-u':
        # Previously compared the option's value, so -u was silently ignored.
        user = value
if not cluster:
    usage()
user, hosts = user_and_hosts(cluster, user)
test_connection(user, hosts)
