#!/usr/bin/env python

#   The contents of this file are subject to the Mozilla Public License
#   Version 1.1 (the "License"); you may not use this file except in
#   compliance with the License. You may obtain a copy of the License at
#   http://www.mozilla.org/MPL/
#
#   Software distributed under the License is distributed on an "AS IS"
#   basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See the
#   License for the specific language governing rights and limitations
#   under the License.
#
#   The Original Code is RabbitMQ Management Plugin.
#
#   The Initial Developer of the Original Code is VMware, Inc.
#   Copyright (c) 2010-2011 VMware, Inc.  All rights reserved.

from optparse import OptionParser
import sys
import httplib
import urllib
import base64
import json
import os
import socket

# Object types accepted by the "list" command.
LISTABLE = ['connections', 'channels', 'exchanges', 'queues', 'bindings',
            'users', 'vhosts', 'permissions', 'nodes']

# Object types accepted by the "show" command.
SHOWABLE = ['overview']

# Columns that column_sort_key() places first, in exactly this order, when a
# listing has to pick its own column order.
PROMOTE_COLUMNS = ['vhost', 'name', 'type',
                   'source', 'destination', 'destination_type', 'routing_key']

# API path templates (relative to /api) per object type. The {placeholders}
# are filled in with str.format() from the name=value command arguments.
# 'binding_del' is the deletion variant of 'binding'; it additionally needs
# the binding's properties_key.
URIS = {
    'exchange':   '/exchanges/{vhost}/{name}',
    'queue':      '/queues/{vhost}/{name}',
    'binding':    '/bindings/{vhost}/e/{source}/{destination_char}/{destination}',
    'binding_del':'/bindings/{vhost}/e/{source}/{destination_char}/{destination}/{properties_key}',
    'vhost':      '/vhosts/{name}',
    'user':       '/users/{name}',
    'permission': '/permissions/{vhost}/{user}'
    }

# Arguments of "declare <type>": 'mandatory' lists the names that must be
# given, 'optional' maps each optional name to its default value.
DECLARABLE = {
    'exchange':   {'mandatory': ['name', 'type'],
                   'optional':  {'auto_delete': 'false', 'durable': 'true',
                                 'internal': 'false'}},
    'queue':      {'mandatory': ['name'],
                   'optional':  {'auto_delete': 'false', 'durable': 'true'}},
    'binding':    {'mandatory': ['source', 'destination_type', 'destination',
                                 'routing_key'],
                   'optional':  {}},
    'vhost':      {'mandatory': ['name'],
                   'optional':  {}},
    'user':       {'mandatory': ['name', 'password', 'tags'],
                   'optional':  {}},
    'permission': {'mandatory': ['vhost', 'user', 'configure', 'write', 'read'],
                   'optional':  {}}
    }

# Arguments of "delete <type>". The 'optional' and 'uri' entries are filled
# in by the loop further down.
DELETABLE = {
    'exchange':   {'mandatory': ['name']},
    'queue':      {'mandatory': ['name']},
    'binding':    {'mandatory': ['source', 'destination_type', 'destination',
                                 'properties_key']},
    'vhost':      {'mandatory': ['name']},
    'user':       {'mandatory': ['name']},
    'permission': {'mandatory': ['vhost', 'user']}
    }

# Arguments and API path of "close <type>".
CLOSABLE = {
    'connection': {'mandatory': ['name'],
                   'optional':  {},
                   'uri':       '/connections/{name}'}
    }

# Arguments and API path of "purge <type>".
PURGABLE = {
    'queue': {'mandatory': ['name'],
              'optional':  {},
              'uri':       '/queues/{vhost}/{name}/contents'}
    }

# Commands that are verbs in their own right (no object type follows them).
# A default of None means "no value unless the user supplies one".
EXTRA_VERBS = {
    'publish': {'mandatory': ['routing_key'],
                'optional':  {'payload': None,
                              'exchange': 'amq.default',
                              'payload_encoding': 'string'},
                'uri':       '/exchanges/{vhost}/{exchange}/publish'},
    'get':     {'mandatory': ['queue'],
                'optional':  {'count': '1', 'requeue': 'true',
                              'payload_file': None},
                'uri':       '/queues/{vhost}/{queue}/get'}
}

# Attach the API path template to every declarable type.
for k in DECLARABLE:
    DECLARABLE[k]['uri'] = URIS[k]

# Deletable types share the same templates and take no optional arguments...
for k in DELETABLE:
    DELETABLE[k]['uri'] = URIS[k]
    DELETABLE[k]['optional'] = {}
# ...except bindings, whose deletion path also carries the properties_key.
DELETABLE['binding']['uri'] = URIS['binding_del']

def make_usage():
    """Build the usage text handed to OptionParser.

    The text lists every "list" and "show" target, then one stanza per
    command table (declare, delete, close, purge and the extra verbs), and
    ends with the export/import commands and notes on publish/get.
    """
    pieces = ["""usage: %prog [options] cmd
where cmd is one of:

"""]
    pieces.extend("  list {0} [<column>...]\n".format(name)
                  for name in LISTABLE)
    pieces.extend("  show {0} [<column>...]\n".format(name)
                  for name in SHOWABLE)
    # One stanza per command table; the extra verbs have no leading verb.
    stanzas = [(DECLARABLE,  'declare'),
               (DELETABLE,   'delete'),
               (CLOSABLE,    'close'),
               (PURGABLE,    'purge'),
               (EXTRA_VERBS, '')]
    for (table, verb) in stanzas:
        pieces.append(fmt_usage_stanza(table, verb))
    pieces.append("""
  export <file>
  import <file>

  * If payload is not specified on publish, standard input is used

  * If payload_file is not specified on get, the payload will be shown on
    standard output along with the message metadata

  * If payload_file is specified on get, count must not be set""")
    return "".join(pieces)

def fmt_usage_stanza(root, verb):
    """Format one block of usage lines for a command table.

    root maps a command word to a dict with a 'mandatory' list and an
    'optional' dict of argument names; verb is the word printed before each
    command word ('' when the table's keys are themselves the verbs).
    Returns a string that starts with a blank line and has one line per key
    of root, with optional arguments shown in square brackets.
    """
    def describe(spec):
        required = " ".join("{0}=...".format(name)
                            for name in spec['mandatory'])
        extras = " ".join("{0}=...".format(name)
                          for name in spec['optional'])
        if extras == "":
            return required
        return required + " [{0}]".format(extras)

    # A non-empty verb gets one extra leading space of indentation.
    prefix = verb if verb == "" else " " + verb
    lines = [" {0} {1} {2}\n".format(prefix, word, describe(root[word]))
             for word in root]
    return "\n" + "".join(lines)

# Module-level option parser: make_parser() registers the options on it and
# main() uses it to parse the command line.
parser = OptionParser(usage=make_usage())

def make_parser():
    """Register all command-line options on the module-level parser."""
    format_names = ", ".join(FORMATS.keys())
    # (flags, add_option keyword arguments), in the order they are registered.
    option_specs = [
        (("-H", "--host"),
         dict(dest="hostname", default="localhost",
              help="connect to host HOST [default: %default]",
              metavar="HOST")),
        (("-P", "--port"),
         dict(dest="port", default="55672",
              help="connect to port PORT [default: %default]",
              metavar="PORT")),
        (("-V", "--vhost"),
         dict(dest="vhost",
              help="connect to vhost VHOST [default: all vhosts for list, '/' for declare]",
              metavar="VHOST")),
        (("-u", "--username"),
         dict(dest="username", default="guest",
              help="connect using username USERNAME [default: %default]",
              metavar="USERNAME")),
        (("-p", "--password"),
         dict(dest="password", default="guest",
              help="connect using password PASSWORD [default: %default]",
              metavar="PASSWORD")),
        (("-q", "--quiet"),
         dict(action="store_false", dest="verbose",
              default=True, help="suppress status messages")),
        (("-s", "--ssl"),
         dict(action="store_true", dest="ssl",
              default=False, help="connect with ssl")),
        (("--ssl-key-file",),
         dict(dest="ssl_key_file", default=None,
              help="PEM format key file for SSL [default: none]")),
        (("--ssl-cert-file",),
         dict(dest="ssl_cert_file", default=None,
              help="PEM format certificate file for SSL [default: none]")),
        (("-f", "--format"),
         dict(dest="format", default="table",
              help="format for listing commands - one of [" + format_names + "]  [default: %default]")),
        (("-d", "--depth"),
         dict(dest="depth", default="1",
              help="maximum depth to recurse for listing tables [default: %default]")),
        (("--bash-completion",),
         dict(action="store_true",
              dest="bash_completion", default=False,
              help="Print bash completion script")),
    ]
    for (flags, settings) in option_specs:
        parser.add_option(*flags, **settings)

def assert_usage(expr, error):
    """Exit with status 1 and a usage hint unless expr is true.

    When expr is false, the error text and a pointer to --help (using the
    script's base name) are written via output() before exiting.
    """
    if expr:
        return
    program = os.path.basename(sys.argv[0])
    output("\nERROR: {0}\n".format(error))
    output("{0} --help for help\n".format(program))
    sys.exit(1)

def column_sort_key(col):
    """Return the sort key for a column name.

    Columns listed in PROMOTE_COLUMNS sort first, in that list's order;
    every other column follows, ordered by its name.
    """
    try:
        rank = PROMOTE_COLUMNS.index(col)
    except ValueError:
        # Not a promoted column: sort after all promoted ones, by name.
        return (2, col)
    return (1, rank)

def main():
    """Parse the command line and dispatch to the matching Management method.

    The first positional argument is the action; it is mapped to the
    Management method named "invoke_<action>". The remaining positional
    arguments are handed to Management unchanged. With --bash-completion the
    completion script is printed and the program exits with status 0.
    """
    make_parser()
    (options, args) = parser.parse_args()
    if options.bash_completion:
        print_bash_completion()
        # sys.exit rather than the builtin exit(): the latter is injected by
        # the site module and is not available when Python runs without it.
        sys.exit(0)
    assert_usage(len(args) > 0, 'Action not specified')
    mgmt = Management(options, args[1:])
    mode = "invoke_" + args[0]
    assert_usage(hasattr(mgmt, mode),
                 'Action {0} not understood'.format(args[0]))
    getattr(mgmt, mode)()

def output(s):
    """Print s to standard output, followed by a newline.

    The text is passed through maybe_utf8() first, which decides whether it
    has to be encoded for the stream.
    """
    print maybe_utf8(s, sys.stdout)

def die(s):
    """Write "*** <s>" to standard error and exit with status 1.

    The message is terminated with a newline so that whatever is printed
    next (typically the shell prompt) does not end up on the same line.
    """
    sys.stderr.write(maybe_utf8("*** {0}\n".format(s), sys.stderr))
    # sys.exit rather than the builtin exit(), which only exists when the
    # site module has been loaded.
    sys.exit(1)

def maybe_utf8(s, stream):
    """Return s in a form that can be written to stream.

    s      -- the text to write; either a text (unicode) string or an
              already-encoded byte string.
    stream -- the destination; only its isatty() method is used.

    For a terminal, s is returned unchanged: the stream has an encoding,
    which Python will respect. Otherwise the stream has no encoding and
    Python would pick ASCII by default, so text is encoded as UTF-8 here.
    Byte strings are returned unchanged in both cases: they are already
    encoded, and calling .encode() on them would force an implicit ASCII
    decode that fails on any non-ASCII byte.
    """
    if stream.isatty() or isinstance(s, bytes):
        return s
    return s.encode('utf-8')

class Management:
    """Command implementations plus a small HTTP client for the "/api" paths.

    An instance is created with the parsed options and the command-line
    words that follow the action. Each action is implemented by a method
    named invoke_<action>.
    """

    def __init__(self, options, args):
        self.options = options
        self.args = args

    def get(self, path):
        """GET path and return the response body."""
        return self.http("GET", path, "")

    def put(self, path, body):
        """PUT body to path and return the response body."""
        return self.http("PUT", path, body)

    def post(self, path, body):
        """POST body to path and return the response body."""
        return self.http("POST", path, body)

    def delete(self, path):
        """DELETE path and return the response body."""
        return self.http("DELETE", path, "")

    def http(self, method, path, body):
        """Send one request to "/api" + path and return the response body.

        Uses HTTP basic authentication with the configured username and
        password, and HTTPS when the ssl option is set. A non-empty body is
        sent as application/json.

        Exits through die() when the connection fails or the server answers
        400, 401 or 404; raises Exception for any other status below 200 or
        above 400. The connection is closed on every exit path.
        """
        if self.options.ssl:
            conn = httplib.HTTPSConnection(self.options.hostname,
                                           self.options.port,
                                           self.options.ssl_key_file,
                                           self.options.ssl_cert_file)
        else:
            conn = httplib.HTTPConnection(self.options.hostname,
                                          self.options.port)
        credentials = self.options.username + ":" + self.options.password
        headers = {"Authorization": "Basic " + base64.b64encode(credentials)}
        if body != "":
            headers["Content-Type"] = "application/json"
        try:
            try:
                conn.request(method, "/api%s" % path, body, headers)
            except socket.error as e:
                die("Could not connect: {0}".format(e))
            resp = conn.getresponse()
            if resp.status == 400:
                die(json.loads(resp.read())['reason'])
            if resp.status == 401:
                die("Access refused: {0}".format(path))
            if resp.status == 404:
                die("Not found: {0}".format(path))
            if resp.status < 200 or resp.status > 400:
                raise Exception("Received %d %s for path %s\n%s"
                                % (resp.status, resp.reason, path, resp.read()))
            return resp.read()
        finally:
            # Also runs when die() exits, so the socket is never left open.
            conn.close()

    def verbose(self, string):
        """Print string unless status messages were suppressed (--quiet)."""
        if self.options.verbose:
            output(string)

    def get_arg(self):
        """Return the single command argument; usage error otherwise."""
        assert_usage(len(self.args) == 1, 'Exactly one argument required')
        return self.args[0]

    def invoke_publish(self):
        """Publish a message; the payload comes from stdin when not given."""
        (uri, upload) = self.parse_args(self.args, EXTRA_VERBS['publish'])
        upload['properties'] = {} # TODO do we care here?
        if not upload['payload']:
            # Standard input may hold arbitrary bytes, so ship it as base64.
            data = sys.stdin.read()
            upload['payload'] = base64.b64encode(data)
            upload['payload_encoding'] = 'base64'
        resp = json.loads(self.post(uri, json.dumps(upload)))
        if resp['routed']:
            self.verbose("Message published")
        else:
            self.verbose("Message published but NOT routed")

    def invoke_get(self):
        """Fetch messages from a queue and display them.

        With payload_file set, the payload of the first message is written
        to that file and only the metadata columns are displayed; count must
        then be left at its default of 1.
        """
        (uri, upload) = self.parse_args(self.args, EXTRA_VERBS['get'])
        payload_file = upload['payload_file']
        assert_usage(not payload_file or upload['count'] == '1',
                     'Cannot get multiple messages using payload_file')
        result = self.post(uri, json.dumps(upload))
        if payload_file:
            write_payload_file(payload_file, result)
            columns = ['routing_key', 'exchange', 'message_count',
                       'payload_bytes', 'redelivered']
            format_list(result, columns, self.options)
        else:
            format_list(result, [], self.options)

    def invoke_export(self):
        """Save the response of GET /definitions to the given file."""
        path = self.get_arg()
        definitions = self.get("/definitions")
        with open(path, 'w') as f:
            f.write(definitions)
        self.verbose("Exported definitions for %s to \"%s\""
                     % (self.options.hostname, path))

    def invoke_import(self):
        """POST the contents of the given file to /definitions."""
        path = self.get_arg()
        with open(path, 'r') as f:
            definitions = f.read()
        self.post("/definitions", definitions)
        self.verbose("Imported definitions for %s from \"%s\""
                     % (self.options.hostname, path))

    def invoke_list(self):
        """List objects of a LISTABLE type, optionally limited to columns."""
        cols = self.args[1:]
        uri = self.list_show_uri(LISTABLE, 'list', cols)
        format_list(self.get(uri), cols, self.options)

    def invoke_show(self):
        """Show a single SHOWABLE object, formatted like a one-row list."""
        cols = self.args[1:]
        uri = self.list_show_uri(SHOWABLE, 'show', cols)
        # The API returns one object here; wrap it so it formats as a list.
        format_list('[{0}]'.format(self.get(uri)), cols, self.options)

    def list_show_uri(self, obj_types, verb, cols):
        """Build the API path for "list"/"show".

        obj_types -- the object types that verb accepts
        verb      -- 'list' or 'show'; only used in the error message
        cols      -- column names to request; [] requests everything

        Exits with a usage error when no type was given or the type is not
        in obj_types.
        """
        # Without this check a bare "list"/"show" died with an IndexError.
        assert_usage(len(self.args) > 0, 'Type not specified')
        obj_type = self.args[0]
        assert_usage(obj_type in obj_types,
                     "Don't know how to {0} {1}".format(verb, obj_type))
        uri = "/%s" % obj_type
        if self.options.vhost:
            # Path segment: escape everything, including "/", and write
            # spaces as %20 ("+" only means space in query strings).
            uri += "/%s" % urllib.quote(self.options.vhost, '')
        if cols != []:
            uri += "?columns=" + ",".join(cols)
        return uri

    def invoke_declare(self):
        """Declare an object; bindings are POSTed, everything else is PUT."""
        (obj_type, uri, upload) = self.declare_delete_parse(DECLARABLE)
        if obj_type == 'binding':
            self.post(uri, json.dumps(upload))
        else:
            self.put(uri, json.dumps(upload))
        self.verbose("{0} declared".format(obj_type))

    def invoke_delete(self):
        """Delete an object of a DELETABLE type."""
        (obj_type, uri, upload) = self.declare_delete_parse(DELETABLE)
        self.delete(uri)
        self.verbose("{0} deleted".format(obj_type))

    def invoke_close(self):
        """Close an object of a CLOSABLE type."""
        (obj_type, uri, upload) = self.declare_delete_parse(CLOSABLE)
        self.delete(uri)
        self.verbose("{0} closed".format(obj_type))

    def invoke_purge(self):
        """Purge an object of a PURGABLE type."""
        (obj_type, uri, upload) = self.declare_delete_parse(PURGABLE)
        self.delete(uri)
        self.verbose("{0} purged".format(obj_type))

    def declare_delete_parse(self, root):
        """Parse "<type> name=value..." against the command table root.

        Returns (obj_type, uri, upload); exits with a usage error when the
        type is missing or is not a key of root.
        """
        assert_usage(len(self.args) > 0, 'Type not specified')
        obj_type = self.args[0]
        assert_usage(obj_type in root,
                     'Type {0} not recognised'.format(obj_type))
        obj = root[obj_type]
        (uri, upload) = self.parse_args(self.args[1:], obj)
        return (obj_type, uri, upload)

    def parse_args(self, args, obj):
        """Turn name=value arguments into an API path and a request body.

        args -- list of "name=value" strings
        obj  -- command description with 'mandatory' (list of names),
                'optional' (name -> default value) and 'uri' (path template)

        Returns (uri, upload): upload holds the optional defaults overridden
        by the supplied values, plus 'arguments' and 'vhost'; uri is the
        template filled in with the URL-escaped non-empty values of upload.
        Exits with a usage error on malformed, unknown or missing arguments.
        """
        mandatory = obj['mandatory']
        optional = obj['optional']
        uri_template = obj['uri']
        upload = dict(optional)
        for arg in args:
            assert_usage("=" in arg,
                         'Argument "{0}" not in format name=value'.format(arg))
            (name, value) = arg.split("=", 1)
            assert_usage(name in mandatory or name in optional,
                         'Argument "{0}" not recognised'.format(name))
            upload[name] = value
        for m in mandatory:
            assert_usage(m in upload,
                         'mandatory argument "{0}" required'.format(m))
        upload['arguments'] = {}
        # Fall back to the -V option (or "/") only when the command did not
        # take vhost= itself; otherwise e.g. "declare permission vhost=x"
        # would silently be applied to a different vhost.
        upload.setdefault('vhost', self.options.vhost or '/')
        uri_args = {}
        for k in upload:
            v = upload[k]
            if v:
                # Values become path segments: escape "/" as well and write
                # spaces as %20 rather than "+".
                uri_args[k] = urllib.quote(v, '')
                if k == 'destination_type':
                    # The binding paths use the first letter of the type.
                    uri_args['destination_char'] = v[0]
        uri = uri_template.format(**uri_args)
        return (uri, upload)

def format_list(json_list, columns, options):
    """Display a JSON list (given as a string) in the format options.format.

    json_list -- JSON text containing a list of objects
    columns   -- column names to show; passed on to the formatter
    options   -- parsed command-line options; options.format selects the
                 output format

    "raw_json" prints the text unchanged and "pretty_json" re-encodes it
    with sorted keys and an indent of 2. Every other name is looked up in
    FORMATS; an unknown name ends the program with a usage error.
    """
    format_name = options.format
    if format_name == "raw_json":
        output(json_list)
        return
    if format_name == "pretty_json":
        enc = json.JSONEncoder(skipkeys=False, ensure_ascii=False,
                               check_circular=True, allow_nan=True,
                               sort_keys=True, indent=2)
        output(enc.encode(json.loads(json_list)))
        return
    # .get() rather than indexing: an unknown format name must reach the
    # usage error below instead of raising KeyError.
    formatter = FORMATS.get(format_name)
    assert_usage(formatter is not None,
                 "Format {0} not recognised".format(format_name))
    formatter(columns, options).display(json_list)

class Lister:
    """Shared logic of the tabular output formats.

    Subclasses set self.columns and self.options and implement
    display_list(columns, table).
    """

    def verbose(self, string):
        """Print string unless status messages were suppressed."""
        if not self.options.verbose:
            return
        output(string)

    def display(self, json_list):
        """Parse json_list and show it, or report "No items" when empty."""
        # The depth option only applies when no columns were requested.
        if len(self.columns) == 0:
            depth = int(self.options.depth)
        else:
            depth = sys.maxint
        (columns, table) = self.list_to_table(json.loads(json_list), depth)
        if len(table) == 0:
            self.verbose("No items")
        else:
            self.display_list(columns, table)

    def list_to_table(self, items, max_depth):
        """Flatten a list of (possibly nested) dicts into columns and rows.

        Nested dicts become dotted column names and are followed while the
        nesting depth is below max_depth; list values are skipped. Returns
        (columns, table), where table has one row of strings per item and
        cells without a value are ''.
        """
        def walk(prefix, depth, item, visit):
            for key in item:
                value = item[key]
                if prefix == '' and key:
                    column = key
                else:
                    column = prefix + '.' + key
                kind = type(value)
                if kind == dict:
                    if depth < max_depth:
                        walk(column, depth + 1, value, visit)
                elif kind == list:
                    # TODO - ATM this is only applications inside nodes. Not
                    # sure how to handle this sensibly.
                    pass
                else:
                    visit(column, value)

        if len(self.columns) == 0:
            # No columns requested: use every column found in any item.
            found = {}

            def remember(col, val):
                found[col] = True

            for item in items:
                walk('', 1, item, remember)
            columns = sorted(found.keys(), key=column_sort_key)
        else:
            columns = self.columns

        positions = dict((col, ix) for (ix, col) in enumerate(columns))

        def filler_for(row):
            def fill(col, val):
                if col in positions:
                    row[positions[col]] = unicode(val)
            return fill

        table = []
        for item in items:
            row = [''] * len(columns)
            walk('', 1, item, filler_for(row))
            table.append(row)

        return (columns, table)

class TSVList(Lister):
    """Tab-separated output: every header and cell is followed by a tab."""

    def __init__(self, columns, options):
        self.options = options
        self.columns = columns

    def display_list(self, columns, table):
        """Print the header (unless quiet) and then one line per row."""
        self.verbose("".join(col + "\t" for col in columns))
        for row in table:
            output("".join(cell + "\t" for cell in row))

class LongList(Lister):
    """One "column: value" line per cell, with a rule between rows."""

    def __init__(self, columns, options):
        self.options = options
        self.columns = columns

    def display_list(self, columns, table):
        """Print every row as right-aligned "name: value" lines."""
        rule = "\n" + "-" * 80 + "\n"
        # Right-align all names to the width of the longest one.
        widest = max([0] + [len(col) for col in columns])
        template = "{0:>" + unicode(widest) + "}: {1}"
        output(rule)
        for row in table:
            for (j, heading) in enumerate(columns):
                output(template.format(heading, row[j]))
            output(rule)

class TableList(Lister):
    """ASCII-art table with a centred header row."""

    def __init__(self, columns, options):
        self.options = options
        self.columns = columns

    def display_list(self, columns, table):
        """Print the column names and the rows as one table."""
        rows = [columns]
        rows += table
        self.ascii_table(rows)

    def ascii_table(self, rows):
        """Print rows as a table; rows[0] is the header."""
        header = rows[0]
        # Each column is as wide as its widest cell, header included.
        col_widths = [max(len(r[i]) for r in rows)
                      for i in range(len(header))]
        self.ascii_bar(col_widths)
        self.ascii_row(col_widths, header, "^")
        self.ascii_bar(col_widths)
        for body_row in rows[1:]:
            self.ascii_row(col_widths, body_row, "<")
        self.ascii_bar(col_widths)

    def ascii_row(self, col_widths, row, align):
        """Print one row; align is a format alignment character."""
        cells = [(" {0:" + align + unicode(width) + "} ").format(row[i])
                 for (i, width) in enumerate(col_widths)]
        output("|" + "".join(cell + "|" for cell in cells))

    def ascii_bar(self, col_widths):
        """Print a horizontal separator matching the column widths."""
        output("+" + "".join("-" * (width + 2) + "+" for width in col_widths))

class KeyValueList(Lister):
    """One line per row, made of space-separated column="value" pairs."""

    def __init__(self, columns, options):
        self.columns = columns
        self.options = options

    def display_list(self, columns, table):
        """Print every row of table as column="value" pairs."""
        for row in table:
            # Unicode template: the cells are unicode strings, and
            # formatting them into a byte-string template fails on any
            # non-ASCII character.
            pairs = [u"{0}=\"{1}\"".format(col, row[j])
                     for (j, col) in enumerate(columns)]
            output(u" ".join(pairs))

# TODO handle spaces etc in completable names
class BashList(Lister):
    """Prints only the 'name' column, space-separated on a single line."""

    def __init__(self, columns, options):
        self.options = options
        self.columns = columns

    def display_list(self, columns, table):
        """Print the names of all rows; print nothing without a 'name' column."""
        name_positions = [i for (i, col) in enumerate(columns)
                          if col == 'name']
        if not name_positions:
            return
        # If 'name' occurs more than once, the last occurrence is used.
        ix = name_positions[-1]
        output(" ".join([row[ix] for row in table]))
# Output formats selectable with -f/--format, mapped to the Lister subclass
# that renders them. The two JSON formats have no class because
# format_list() handles them itself.
FORMATS = {
    'raw_json'    : None, # Special cased
    'pretty_json' : None, # Ditto
    'tsv'         : TSVList,
    'long'        : LongList,
    'table'       : TableList,
    'kvp'         : KeyValueList,
    'bash'        : BashList
}

def write_payload_file(payload_file, json_list):
    """Write the payload of the first message in json_list to payload_file.

    payload_file -- path of the file to create or overwrite
    json_list    -- JSON text containing a list of message objects, each
                    with 'payload' and 'payload_encoding' keys

    A payload whose encoding is 'base64' is decoded first; any other payload
    is written as UTF-8. When the list is empty there is nothing to write
    and no file is created.
    """
    messages = json.loads(json_list)
    if not messages:
        return
    message = messages[0]
    payload = message['payload']
    if message['payload_encoding'] == 'base64':
        data = base64.b64decode(payload)
    else:
        data = payload
    if not isinstance(data, bytes):
        # json.loads yields text; encode it explicitly instead of relying
        # on an implicit (ASCII) conversion when writing.
        data = data.encode('utf-8')
    # Binary mode: the payload must reach the file byte for byte, without
    # newline translation.
    with open(payload_file, 'wb') as f:
        f.write(data)

def print_bash_completion():
    """Print a bash completion script for rabbitmqadmin via output().

    The script is assembled from a fixed template into which the names from
    LISTABLE, SHOWABLE, DECLARABLE, DELETABLE, CLOSABLE, PURGABLE and FORMATS
    are inserted, plus one generated case per LISTABLE entry.
    """
    # Fixed part: completions for the word following each verb and option.
    script = """# This is a bash completion script for rabbitmqadmin.
# Redirect it to a file, then source it or copy it to /etc/bash_completion.d
# to get tab completion. rabbitmqadmin must be on your PATH for this to work.
_rabbitmqadmin()
{
    local cur prev opts base
    COMPREPLY=()
    cur="${COMP_WORDS[COMP_CWORD]}"
    prev="${COMP_WORDS[COMP_CWORD-1]}"

    opts="list show declare delete close purge import export get publish"
    fargs="--help --host --port --vhost --username --password --format --depth"

    case "${prev}" in
	list)
	    COMPREPLY=( $(compgen -W '""" + " ".join(LISTABLE) + """' -- ${cur}) )
            return 0
            ;;
	show)
	    COMPREPLY=( $(compgen -W '""" + " ".join(SHOWABLE) + """' -- ${cur}) )
            return 0
            ;;
	declare)
	    COMPREPLY=( $(compgen -W '""" + " ".join(DECLARABLE.keys()) + """' -- ${cur}) )
            return 0
            ;;
	delete)
	    COMPREPLY=( $(compgen -W '""" + " ".join(DELETABLE.keys()) + """' -- ${cur}) )
            return 0
            ;;
	close)
	    COMPREPLY=( $(compgen -W '""" + " ".join(CLOSABLE.keys()) + """' -- ${cur}) )
            return 0
            ;;
	purge)
	    COMPREPLY=( $(compgen -W '""" + " ".join(PURGABLE.keys()) + """' -- ${cur}) )
            return 0
            ;;
	export)
	    COMPREPLY=( $(compgen -f ${cur}) )
            return 0
            ;;
	import)
	    COMPREPLY=( $(compgen -f ${cur}) )
            return 0
            ;;
	-@(H|-host))
	    COMPREPLY=( $(compgen -A hostname ${cur}) )
            return 0
            ;;
	-@(V|-vhost))
            opts="$(rabbitmqadmin -q -f bash list vhosts)"
	    COMPREPLY=( $(compgen -W "${opts}"  -- ${cur}) )
            return 0
            ;;
	-@(u|-username))
            opts="$(rabbitmqadmin -q -f bash list users)"
	    COMPREPLY=( $(compgen -W "${opts}"  -- ${cur}) )
            return 0
            ;;
	-@(f|-format))
	    COMPREPLY=( $(compgen -W \"""" + " ".join(FORMATS.keys()) + """\"  -- ${cur}) )
            return 0
            ;;

"""
    # Generated part: one case per listable type, keyed on the type name
    # with its last character removed (e.g. "queues" -> "queue"); it
    # completes with the names returned by "list <type>".
    for l in LISTABLE:
        key = l[0:len(l) - 1]
        script += "        " + key + """)
            opts="$(rabbitmqadmin -q -f bash list """ + l + """)"
	    COMPREPLY=( $(compgen -W "${opts}"  -- ${cur}) )
            return 0
            ;;
"""
    # Fallback: offer the verbs and the long options.
    script += """        *)
        ;;
    esac

   COMPREPLY=($(compgen -W "${opts} ${fargs}" -- ${cur}))
   return 0
}
complete -F _rabbitmqadmin rabbitmqadmin
"""
    output(script)

# Run the command-line tool only when this file is executed as a script.
if __name__ == "__main__":
    main()
