#!/usr/bin/env python
# -*- Mode: python -*-
#
# Usage: pssh [OPTIONS] -h hosts.txt prog [arg0] [arg1] ..
#
# Parallel ssh to the set of nodes in hosts.txt. For each node, this
# essentially does an "ssh host -l user prog [arg0] [arg1] ...". The -o
# option can be used to store stdout from each remote node in a
# directory.  Each output file in that directory will be named by the
# corresponding remote node's hostname or IP address.
#
# Created: 16 August 2003
#
# $Id: pssh 400 2008-10-12 11:48:28Z bnc $
#
import fcntl
import getopt
import os
import pwd
import signal
import subprocess
import sys
import threading

# Split the directory holding this script into its parent (basedir) and its
# own name (bin), then make the parent importable so that the psshlib
# package imported below can be found there.
_script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
basedir, bin = os.path.split(_script_dir)
sys.path.append(basedir)

from psshlib import psshutil
from psshlib.basethread import BaseThread

# Initial value of flags["par"]: the size of the semaphore that bounds how
# many host threads run at once, unless overridden by PSSH_PAR or -p/--par.
_DEFAULT_PARALLELISM = 32
# Initial value of flags["timeout"] (seconds, per host), unless overridden
# by PSSH_TIMEOUT or -t/--timeout; a setting of -1 is stored as None.
_DEFAULT_TIMEOUT     = 60

def print_usage():
    """Write the pssh usage summary, option list and an example to stdout."""
    usage_lines = (
        "Usage: pssh [OPTIONS] -h hosts.txt prog [arg0] ..",
        "",
        '  -h --hosts   hosts file (each line "host[:port] [user]")',
        "  -l --user    username (OPTIONAL)",
        "  -p --par     max number of parallel threads (OPTIONAL)",
        "  -o --outdir  output directory for stdout files (OPTIONAL)",
        "  -e --errdir  output directory for stderr files (OPTIONAL)",
        "  -t --timeout timeout (secs) (-1 = no timeout) per host (OPTIONAL)",
        "  -O --options SSH options (OPTIONAL)",
        "  -v --verbose turn on warning and diagnostic messages (OPTIONAL)",
        "  -P --print   print output as we get it (OPTIONAL)",
        "  -i --inline  inline aggregated output for each server (OPTIONAL)",
        "",
        "Example: pssh -h nodes.txt -l irb2 -o /tmp/foo uptime",
        "",
    )
    for text in usage_lines:
        # One parenthesised string argument: an empty string yields the
        # same blank line a bare print would.
        print(text)

def read_envvars(flags):
    """Overlay settings from PSSH_* environment variables onto flags.

    flags is updated in place; nothing is returned.  A variable that is
    unset or empty leaves the corresponding entry untouched.  PSSH_PAR,
    PSSH_VERBOSE, PSSH_PRINT and PSSH_INLINE are stored as integers, and
    PSSH_TIMEOUT is stored as an integer except that -1 is stored as None.
    A non-numeric value for any of those raises ValueError from int().
    """
    def timeout_value(text):
        # -1 is the "no timeout" marker and is recorded as None.
        seconds = int(text)
        if seconds == -1:
            return None
        return seconds

    # (environment variable, flags key, converter or None to keep the raw
    # string), listed in the order in which they are applied.
    env_table = (
        ("PSSH_HOSTS", "hosts", None),
        ("PSSH_USER", "user", None),
        ("PSSH_PAR", "par", int),
        ("PSSH_OUTDIR", "outdir", None),
        ("PSSH_ERRDIR", "errdir", None),
        ("PSSH_TIMEOUT", "timeout", timeout_value),
        ("PSSH_OPTIONS", "options", None),
        ("PSSH_VERBOSE", "verbose", int),  # "0" or "1"
        ("PSSH_PRINT", "print", int),
        ("PSSH_INLINE", "inline", int),
    )
    for env_name, key, convert in env_table:
        value = os.getenv(env_name)
        if not value:
            continue
        if convert is not None:
            value = convert(value)
        flags[key] = value

def parsecmdline(argv):
    """Parse the pssh command line into positional arguments and flags.

    Settings start from built-in defaults, are overlaid by the PSSH_*
    environment variables (via read_envvars) and finally by the options
    found in argv[1:].

    Arguments:
      argv -- full argument vector; argv[0] (the program name) is skipped.

    Returns:
      (args, flags): args is the list of non-option arguments and flags is
      a dict with the keys hosts, user, par, outdir, errdir, timeout,
      options, print, verbose and inline.  A timeout of -1 is stored as
      None.

    Exits with status 3, after printing the usage text, when an option is
    unknown or malformed, a numeric setting is not an integer, the
    parallelism is below 1, no hosts file was given, or no user name was
    given and the current user cannot be looked up.
    """
    shortopts = "h:l:p:o:e:t:O:Pvi"
    longopts = [ "hosts=",
                 "user=",
                 "par=",
                 "outdir=",
                 "errdir=",
                 "timeout=",
                 "options=",
                 "print",
                 "verbose",
                 "inline" ]
    flags = { "hosts" : None,
              "user" : None,
              "par" : _DEFAULT_PARALLELISM,
              "outdir" : None,
              "errdir" : None,
              "timeout" : _DEFAULT_TIMEOUT,
              "options" : None,
              "print" : None,
              "verbose" : None,
              "inline": None }

    def usage_error(message):
        # Report a command-line problem the same way for every failure.
        print("Error: %s\n" % message)
        print_usage()
        sys.exit(3)

    def to_int(option, text):
        # Turn a bad number into a usage error instead of a traceback.
        try:
            return int(text)
        except ValueError:
            usage_error("option %s requires an integer argument, got %r"
                        % (option, text))

    try:
        read_envvars(flags)
    except ValueError:
        # read_envvars converts several PSSH_* variables with int().
        usage_error("invalid integer in a PSSH_* environment variable: %s"
                    % str(sys.exc_info()[1]))
    try:
        opts, args = getopt.getopt(argv[1:], shortopts, longopts)
    except getopt.GetoptError:
        # sys.exc_info() avoids version-specific "except ... , e" syntax.
        usage_error(str(sys.exc_info()[1]))
    for o, v in opts:
        if o in ("-h", "--hosts"):
            flags["hosts"] = v
        elif o in ("-l", "--user"):
            flags["user"] = v
        elif o in ("-p", "--par"):
            flags["par"] = to_int(o, v)
        elif o in ("-o", "--outdir"):
            flags["outdir"] = v
        elif o in ("-e", "--errdir"):
            flags["errdir"] = v
        elif o in ("-t", "--timeout"):
            timeout = to_int(o, v)
            if timeout != -1:
                flags["timeout"] = timeout
            else:
                flags["timeout"] = None
        elif o in ("-O", "--options"):
            flags["options"] = v
        elif o in ("-v", "--verbose"):
            flags["verbose"] = 1
        elif o in ("-P", "--print"):
            flags["print"] = 1
        elif o in ("-i", "--inline"):
            flags["inline"] = 1
    # The parallelism sizes a semaphore that is acquired before each host
    # thread starts; a value below 1 could never let a thread run.
    if flags["par"] < 1:
        usage_error("number of parallel threads must be at least 1")
    # Required flags
    if not flags["hosts"]:
        print_usage()
        sys.exit(3)
    # Default to the current user only when neither the environment nor
    # the command line named one, so an explicit -l still works for a uid
    # that has no passwd entry.
    if not flags["user"]:
        try:
            flags["user"] = pwd.getpwuid(os.getuid())[0]
        except KeyError:
            usage_error("cannot determine the current user name; "
                        "specify one with -l")
    return args, flags

def buffer_input():
    """Return whatever can be read from stdin without blocking.

    stdin is switched to non-blocking mode for the duration of a single
    read() and its original file status flags are always put back, even
    when the read fails with an unexpected exception.

    Returns:
      The data read, or "" when the read raises IOError (for example
      because nothing is available yet).
    """
    fd = sys.stdin.fileno()
    origfl = fcntl.fcntl(fd, fcntl.F_GETFL)
    fcntl.fcntl(fd, fcntl.F_SETFL, origfl | os.O_NONBLOCK)
    try:
        try:
            stdin = sys.stdin.read()
        except IOError: # Stdin contained no information
            stdin = ""
    finally:
        # Restore the flags on every exit path: the non-blocking setting
        # would otherwise outlive this call (and, per POSIX, is shared by
        # every holder of the same open file description).
        fcntl.fcntl(fd, fcntl.F_SETFL, origfl)
    return stdin

def do_pssh(hosts, ports, users, cmdline, flags):
    """Start one ssh worker thread per host and wait for all of them.

    Arguments:
      hosts, ports, users -- parallel sequences; entry i of each describes
                             one remote node.
      cmdline             -- remote command, placed in double quotes at the
                             end of the ssh command string.
      flags               -- settings dict; the keys outdir, errdir, par,
                             verbose and options are read here, and the
                             whole dict is passed on to each BaseThread.

    The output directories named by flags["outdir"] and flags["errdir"] are
    created when set and missing.  Whatever buffer_input() returns is
    handed to every thread.
    """
    for dir_key in ("outdir", "errdir"):
        target = flags[dir_key]
        if target and not os.path.exists(target):
            os.makedirs(target)
    stdin = buffer_input()
    # One slot is taken before each thread is created, so at most
    # flags["par"] slots are held at a time.
    # NOTE(review): each BaseThread receives sem and presumably releases it
    # when its host finishes -- confirm in psshlib.basethread.
    sem = threading.Semaphore(flags["par"])
    threads = []
    for idx, host in enumerate(hosts):
        sem.acquire()
        # ssh runs quietly unless verbose output was requested.
        quietswitch = "-q"
        if flags["verbose"]:
            quietswitch = ""
        prefix = "ssh"
        if flags["options"]:
            prefix = "ssh -o \"%s\"" % (flags["options"],)
        cmd = "%s %s -p %s -l %s %s \"%s\"" % \
                   (prefix, host, ports[idx], users[idx], quietswitch,
                    cmdline)
        worker = BaseThread(host, ports[idx], cmd, flags, sem, stdin)
        worker.start()
        threads.append(worker)
    for worker in threads:
        worker.join()
   
# Script entry point: parse the command line, load the host list and run
# the command on every host.
if __name__ == "__main__":
    # Replace the subprocess module's private _cleanup hook with a no-op.
    # NOTE(review): presumably this stops subprocess from reaping children
    # that the worker threads are still waiting on -- confirm against
    # psshlib.basethread.
    subprocess._cleanup = lambda : None
    args, flags = parsecmdline(sys.argv)
    # At least one positional argument (the remote program) is required.
    if len(args) == 0:
        print_usage()
        sys.exit(3)
    # The positional arguments are joined with single spaces into the
    # remote command line.
    cmdline = " ".join(args)
    hosts, ports, users = psshutil.read_hosts(flags["hosts"])
    # NOTE(review): presumably fills in flags["user"] for hosts-file entries
    # that name no user -- confirm in psshutil.patch_users.
    psshutil.patch_users(hosts, ports, users, flags["user"])
    # Reset SIGCHLD to its default disposition.
    signal.signal(signal.SIGCHLD, signal.SIG_DFL)
    # Make this process the leader of its own process group.
    # NOTE(review): presumably so the whole group can be signalled on
    # timeout or interrupt -- confirm in psshlib.
    os.setpgid(0, 0)
    do_pssh(hosts, ports, users, cmdline, flags)
