#
# This file is part of GNU Enterprise.
#
# GNU Enterprise is free software; you can redistribute it
# and/or modify it under the terms of the GNU General Public
# License as published by the Free Software Foundation; either
# version 2, or (at your option) any later version.
#
# GNU Enterprise is distributed in the hope that it will be
# useful, but WITHOUT ANY WARRANTY; without even the implied
# warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
# PURPOSE. See the GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public
# License along with program; see the file COPYING. If not,
# write to the Free Software Foundation, Inc., 59 Temple Place
# - Suite 330, Boston, MA 02111-1307, USA.
#
# Copyright 2002-2004 Free Software Foundation
#
# $Id: Scripter.py 5665 2004-04-07 11:04:32Z reinhard $
#

from gnue.common import VERSION
from gnue.common.schema import GSParser
from gnue.common.utils.FileUtils import openResource, dyn_import
from gnue.common.apps.GClientApp import GClientApp
from processors import vendors
from gnue.common.schema.scripter.Definition import *
from time import strftime
from string import join

import sys
import os
import re

# =============================================================================
# Generate SQL files from GNUe Schema Definition files
# =============================================================================

class Scripter (GClientApp):

  VERSION         = VERSION
  COMMAND         = "gnue-schema"
  NAME            = "GNUe Schema Scripter"
  USAGE           = "[options] file [old-schema]"
  SUMMARY = _("GNUe Schema Scripter creates SQL files based on GNUe "
              "Schema Definitions.")

  _PROC_PATH = "gnue.common.schema.scripter.processors.%s"


  # ---------------------------------------------------------------------------
  # Constructor
  # ---------------------------------------------------------------------------

  def __init__ (self, connections = None):
    self.addCommandOption('drop_tables',longOption='drop-tables',
        help=_("Generate commands to drop relevant tables.  * NOT IMPLEMENTED"))

    self.addCommandOption('ignore_schema','S','no-schema',
        help=_("Do not generate schema creation code.  * NOT IMPLEMENTED"))

    self.addCommandOption('ignore_data','D','no-data',
        help=_("Do not generate data insertion code.  * NOT IMPLEMENTED"))

    self.addCommandOption('encoding', 'e', default='UTF-8', argument=_('encoding'),
        help= _("The generated SQL script will be encoded using <encoding>. "
                "Default encoding is UTF-8") )

    self.addCommandOption('upgrade_schema','u','upgrade-schema',
        help= _("Generate code to upgrade an older version of a schema to "
                "the recent version. You must specify a previous schema "
                "on the command line.  * NOT IMPLEMENTED") )

    self.addCommandOption('upgrade_data','U',
        help= _("Generate code to upgrade an older version of schema data to "
                "the recent version. You must specify a previous schema "
                "on the command line.  * NOT IMPLEMENTED") )

    self.addCommandOption('help-vendors',shortOption='l', action=self.__listVendors,
        help=_("List all supported vendors.") )

    self.addCommandOption('output','o', argument='dest',
        help= _("The destination for the created schemas. This can be in several "
                "formats. If <dest> is a file name, then output is written to "
                "this file. If <dest> is a directory, then <dest>/<Vendor>.sql "
                "is created. The default is to create <Vendor>.sql in the "
                "current directory. NOTE: the first form (<dest> as a filename) "
                "is not supported for --vendors all.") )

    self.addCommandOption('vendor','v', default="all", argument='vendor',
        help= _("The vendor to create a script for. If <vendor> is 'all', then "
                "scripts for all supported vendors will be created. <vendor> "
                "can also be a comma-separated list."))

    ConfigOptions = {}
    GClientApp.__init__ (self, connections, 'schema', ConfigOptions)
    self.__vendors = []


  # ---------------------------------------------------------------------------
  # Main program
  # ---------------------------------------------------------------------------
  def run (self):
    self.__check_options ()

    try:
      print _("Loading gsd file '%s' ...") % self.ARGUMENTS [0]

      self.schema = GSParser.loadFile (self.__input)

    except Exception:
      print sys.exc_info () [1]

    else:
      for vendor in self.__vendors:
        self.__runProcessor (vendor)



  # ---------------------------------------------------------------------------
  # Walk through all command line options
  # ---------------------------------------------------------------------------
  def __check_options (self):

    # we need at least one thing to do :)
    if self.OPTIONS ["ignore_schema"] and self.OPTIONS ["ignore_data"]:
      self.handleStartupError (_("--no-schema and --no-data cannot be used "
                                 "together. What to export?"))

    # check for unsupported options
    if self.OPTIONS ["drop_tables"] or self.OPTIONS ["upgrade_schema"] or \
       self.OPTIONS ["upgrade_data"]:
      self.handleStartupError (_("--drop-tables, --upgrade-schema and "
                                 "--upgrade-data\n are not implemented yet."))


    # do we have an accessible input file
    if not len (self.ARGUMENTS):
      self.handleStartupError (_("No input file specified."))

    try:
      self.__input = openResource (self.ARGUMENTS [0])

    except IOError:
      self.handleStartupError (_("Unable to open input file %s.") % \
                                 self.ARGUMENTS [0])

    # check the specified vendors
    if self.OPTIONS ["vendor"].lower () == "all":
      self.__vendors.extend (vendors)
    else:
      self.__vendors.extend (self.OPTIONS ["vendor"].split (","))

    self.__output = self.OPTIONS ["output"]
    if len (self.__vendors) > 1 and self.__output is not None:
      if not os.path.isdir (self.__output):
        self.handleStartupError ( \
          _("If multiply vendors are specified --output must be a "
            "directory or\n left empty."))


  # ---------------------------------------------------------------------------
  # Print a list of all available processors
  # ---------------------------------------------------------------------------

  def __listVendors (self):
    self.printHelpHeader()
    print "The following vendors can be passed as a parameter to the --vendor option."
    print "To specify multiple vendors, separate with commas. "
    print
    print "Supported Database Vendors:"

    modules  = {}
    maxsize  = 0

    for vendor in vendors:
      maxsize = max(maxsize, len (vendor))

      try:
        modules [vendor] = dyn_import (self._PROC_PATH % vendor)

      except ImportError:
        pass

    available = modules.keys ()
    available.sort()

    for vendor in available:
      print "   " + vendor.ljust (maxsize + 4), modules [vendor].description

    print
    sys.exit()

  # ---------------------------------------------------------------------------
  # Get the name of a given processor
  # ---------------------------------------------------------------------------

  def __getProcessorName (self, processor):
    return dyn_import (self._PROC_PATH % processor).name


  # ---------------------------------------------------------------------------
  # Run a given processor
  # ---------------------------------------------------------------------------

  def __runProcessor (self, vendor):
    if not self.__output:
      filename = "%s.sql" % self.__getProcessorName (vendor)

    elif os.path.isdir (self.__output):
      filename = os.path.join (self.__output,
                               "%s.sql" % self.__getProcessorName (vendor))
    else:
      filename = self.__output

    try:
      self.destination = open (filename, "w")

    except IOError:
      sys.stderr.write (_("Unable to create output file %s.") % filename)
      sys.exit (1)


    # Instanciate the given processor and iterate over all schema objects
    aModule = self._PROC_PATH % vendor
    self.processor = dyn_import (aModule).Processor (self.destination, \
                                                     self.ARGUMENTS [0])

    print _("Writing schema to %s ...") % filename

    try:
      self.tables = {}
      self.data   = []

      self.processor.startDump ()
      self.processor.client_encoding (self.OPTIONS ['encoding'])

      self.schema.walk (self.__iterate_objects)

      maxPhase = 0
      for table in self.tables.values ():
        maxPhase = max (maxPhase, max (table.phases.keys ()))

      for phase in range (0, maxPhase + 1):
        for table in self.tables.values ():
          self.processor.writePhase (table, phase)

      for table in self.data:
        self.processor.writeData (table, table.tableDef)

      self.processor.finishDump ()

      # and finally close the output file
      self.destination.close ()

    except Exception, message:
      os.unlink (filename)
      print message




  # ---------------------------------------------------------------------------
  # iteration over all schema objects in the document tree
  # ---------------------------------------------------------------------------

  def __iterate_objects (self, sObject):
    if sObject._type == "GSSchema":
      if not self.OPTIONS ["ignore_schema"]:
        self.processor.startSchema ()

    elif sObject._type == "GSTable":
      if not self.OPTIONS ["ignore_schema"]:
        self.__schema_table (sObject)

    elif sObject._type == "GSData":
      if not self.OPTIONS ["ignore_data"]:
        self.processor.startData ()

    elif sObject._type == "GSTableData":
      if not self.OPTIONS ["ignore_data"]:
        self.__data_table (sObject)

    return


  # ---------------------------------------------------------------------------
  # Process the schema definition of a GSTable object
  # ---------------------------------------------------------------------------
  def __schema_table (self, sObject):
    aTable = TableDefinition (sObject.name, sObject.action)
    self.tables [aTable.name] = aTable
    sObject.walk (self.__schema_fields, tableDef = aTable)

    self.processor.translateTableDefinition (aTable)


  # ---------------------------------------------------------------------------
  # Process the fields of a GSTable
  # ---------------------------------------------------------------------------
  def __schema_fields (self, sObject, tableDef):

    # process a regular field of a table
    if sObject._type == "GSField":
      tableDef.fields.append (sObject)

    elif sObject._type == "GSPrimaryKey":
      pkdef = tableDef.addPrimaryKey (sObject.name)
      sObject.walk (self.__schema_primarykey, tableDef = tableDef, pDef = pkdef)

    # start an index definition and process it's fields
    elif sObject._type == "GSIndex":
      uniq = hasattr (sObject, "unique") and sObject.unique
      index = tableDef.newIndex (sObject.name, uniq)

      # iterate over all index fields
      sObject.walk (self.__schema_index, tableDef = tableDef, indexDef = index)

    # create constraints
    elif sObject._type == "GSConstraint":
      # for unique-constraints we use a 'unique index'
      if sObject.type == "unique":
        cDef = tableDef.newIndex (sObject.name, True)

      # for all other types of constraints we use a ConstraintDefinition
      else:
        cDef = tableDef.newConstraint (sObject.name, sObject.type)

      sObject.walk (self.__schema_constraint, constraint = cDef)



  # ---------------------------------------------------------------------------
  # Iterate over all fields of a primary key
  # ---------------------------------------------------------------------------
  def __schema_primarykey (self, sObject, tableDef, pDef):
    if sObject._type == "GSPKField":
      pDef.fields.append (sObject)


  # ---------------------------------------------------------------------------
  # Iterate over all fields of an index
  # ---------------------------------------------------------------------------
  def __schema_index (self, sObject, tableDef, indexDef):
    if sObject._type == "GSIndexField":
      indexDef.fields.append (sObject)


  # ---------------------------------------------------------------------------
  # Iterate over all children of a constraint definition
  # ---------------------------------------------------------------------------

  def __schema_constraint (self, sObject, constraint):
    if sObject._type == "GSConstraintField":
      constraint.fields.append (sObject)

    elif sObject._type == "GSConstraintRef":
      constraint.reftable = sObject.table
      constraint.reffields.append (sObject)


  # ---------------------------------------------------------------------------
  # Process a tabledata node
  # ---------------------------------------------------------------------------
  def __data_table (self, sObject):
    data = DataDefinition (sObject.tablename)

    self.data.append (data)
    if self.tables.has_key (data.name):
      data.tableDef = self.tables [data.name]

    sObject.walk (self.__data_rows, dataDef = data)


  # ---------------------------------------------------------------------------
  # Iterate over all rows of a tabledata definition
  # ---------------------------------------------------------------------------
  def __data_rows (self, sObject, dataDef):
    if sObject._type == "GSRow":
      row = dataDef.addRow ()
      sObject.walk (self.__data_values, rowDef = row)


  # ---------------------------------------------------------------------------
  # Iterate over all values of a row definition
  # ---------------------------------------------------------------------------
  def __data_values (self, sObject, rowDef):
    if sObject._type == "GSValue":
      if hasattr (sObject, "field"):
        rowDef.columns.append (sObject.field)

      if hasattr (sObject, "type"):
        rowDef.types.append (sObject.type)

      rowDef.values.append (sObject)


# =============================================================================
# If executed directly, start the scripter
# =============================================================================
if __name__ == '__main__':
  # create the application (parses the command line) and start it
  app = Scripter ()
  app.run ()
