"""Auxiliary classes for kplot.py that are not meant to be seen by the user."""

__version__ = 0.3
__author__ = "Martin Wiechert <martin.wiechert@gmx.de>"
__date__ = "December 19, 2001"

import KMatplot as kmp
import Numeric
import kplot

class _kplane: # curve [/ bar graph / arrows]
    """Represents 2d data like a curve or bar graph or plane arrow set.

    subplot -- owning subplot; the dataset is created on its axes
               (subplot.id) through subplot.parent.proc.
    mask    -- string of channel letters, one per row of data.  The channel
               number is the letter's position in 'XYxy' (X, Y, dX, dY).
    data    -- 2-d array; row k is sent to the channel named by mask[k].
    """

    def __init__ (self, subplot, mask, data):
        self.parent = subplot
        proc = subplot.parent.proc
        axes_id = subplot.id
        self.id = proc.add_dataset (axes_id, kmp.PlotCurve)
        # len counts the rows of data; each row becomes one channel.
        for row in range (len (data)):
            channel = 'XYxy'.find (mask [row])
            proc.set_channel (axes_id, self.id, channel, data [row, :])

class _kmap: # surface / contour / pixmap
    """Represents a scalar valued map from the plane.
    (I.e. a surface or contour or pixmap.)

    subplot -- owning subplot; the dataset is created on its axes
               (subplot.id) through subplot.parent.proc.
    type    -- KMatplot dataset type passed to add_dataset.
    data    -- sequence whose entry k is sent to channel (k - 1) % 3, i.e.
               entry 0 goes to channel 2 and entries 1, 2 to channels 0, 1.
               Entries with a false truth value are not sent.
    """

    def __init__ (self, subplot, type, data):
        self.parent = subplot
        proc = subplot.parent.proc
        axes_id = subplot.id
        self.id = proc.add_dataset (axes_id, type)
        for position in range (len (data)):
            entry = data [position]
            # NOTE(review): absence is detected by truth value; confirm that
            # callers pass None for a missing entry, since an array's own
            # truth value depends on its contents.
            if not entry:
                continue # channel not given
            channel = (position - 1) % 3
            proc.set_channel (axes_id, self.id, channel, entry)

class _kpoly: # polygons
    """Represents a polygon set in space.

    subplot -- owning subplot; the dataset is created on its axes
               (subplot.id) through subplot.parent.proc.
    X, Y, Z -- coordinate data sent to channels 0, 1 and 2 of a new
               kmp.PlotFigure dataset, in that order.
    """

    def __init__ (self, subplot, X, Y, Z):
        self.parent = subplot
        proc = subplot.parent.proc
        axes_id = subplot.id
        self.id = proc.add_dataset (axes_id, kmp.PlotFigure)
        coords = (X, Y, Z)
        for channel in range (len (coords)):
            proc.set_channel (axes_id, self.id, channel, coords [channel])

class _ksubplot:
    """Maintains one set of axes and the datasets plotted into them.

    Attributes:
      parent  -- the plot object passed to the constructor; its 'proc'
                 attribute is used for all KMatplot requests.
      hold    -- if true, new data is added to the datasets already shown;
                 if false, those are removed first.
      _3d     -- None while no axes exist; otherwise a false value for 2d
                 axes and a true value for 3d axes.
      sets    -- dataset objects (_kplane, _kmap, _kpoly) on these axes.
      id      -- axes id returned by proc.add_axes.
      current -- index in 'sets' of the last dataset added (-1 if none).
    'id' and 'current' only exist after the first call to plot,
    plane_map or polygon.
    """

    def __init__ (self, plot):
        self.parent = plot
        self.hold = 1
        self._3d = None
        self.sets = []
        # Attributes 'id' and 'current' are created by the first call to
        # plot, plane_map or polygon.

    def _clear (self):
        """Removes all datasets from the current axes."""
        self.parent.proc.remove_all_datasets (self.id)
        self.sets = []

    def _set_axes (self, want_3d):
        """Provides axes of the requested kind and honours 'hold'.

        want_3d -- true for 3d axes, false for 2d axes.
        Creates axes if there are none.  Axes of the other kind are removed
        together with all their datasets (regardless of 'hold') and
        replaced.  On fitting axes the datasets are removed unless 'hold'
        is set."""
        if self._3d != want_3d: # wrong axes or no axes at all
            if self._3d is not None: # wrong axes - erase all
                self._clear ()
                self.parent.proc.remove_axes (self.id)
            self._3d = want_3d
            self.id = self.parent.proc.add_axes (self._3d)
        elif not self.hold:
            self._clear ()

    def _shrink (self, kw):
        """Replaces dX, dY by x, y (to simplify parsing) and checks
        keyword legality.

        A legal keyword is non-empty and, after the replacement, consists
        only of the letters X, Y, x and y.  It uses each letter at most
        once and contains Y if it has more than one letter.
        Returns the shrunk keyword; raises kplot.Error otherwise."""
        kw = kw.replace ('dX', 'x')
        kw = kw.replace ('dY', 'y')
        # Check legality.
        X = kw.count ('X')
        Y = kw.count ('Y')
        x = kw.count ('x')
        y = kw.count ('y')
        if not kw or X > 1 or Y > 1 or x > 1 or y > 1 or \
           X + Y + x + y != len (kw) or (Y == 0 and len (kw) > 1):
            raise kplot.Error ('Unknown keyword.')
        return kw

    def _plot_curve (self, kw, data):
        """Plots the rows of data as one or more curves.

        kw   -- shrunk keyword (see _shrink) naming the channel of each row.
        data -- array whose rows are taken in consecutive groups of len(kw);
                each group becomes one curve.
        Raises kplot.Error if the number of rows is not a multiple of
        len(kw)."""
        if len (data) % len (kw):
            raise kplot.Error ('Mismatched number of rows.')
        for i in range (0, len (data), len (kw)):
            self.sets.append (_kplane (self, kw, data [i : i + len (kw), :]))

    def plot (self, args, kwargs):
        """Plots one or more curves.

        args   -- positional data: (Y,) or (X, Y[, dX[, dY]]).
        kwargs -- maps keywords to data.  Single-letter keywords (X, Y, dX,
                  dY) supply one kind of data each.  Composite keywords
                  (e.g. 'XYdY') carry a matrix whose rows mix several kinds
                  and are plotted at once.  See kplot.kplot.plot doc string
                  for keyword semantics.
        Raises kplot.Error on illegal, duplicate or inconsistent arguments.
        """

        # 'XYxy' names only four positions; reject more before the axes
        # are touched.
        if len (args) > 4:
            raise kplot.Error ('Too many positional arguments.')

        # First handle hold and axes status (2d axes).
        self._set_axes (0)

        # Now tackle arguments.
        # Composite keywords are independent (as they carry complete data
        # for one (or even multiple) plot(s)). So they are passed instantly.
        # All others (positional and one-letter keywords) are collected and
        # then plotted in one go.
        # Compatibility of positional and one-letter keyword args is checked
        # via args_given.
        coll_args = {}  # Collect arguments here, keyed by shrunk letter.
        synth_kw = ''   # Letters of all collected arguments.
        args_given = 0  # 0: no positional args, 1: coordinates only,
                        # 2: coordinates and error bars.
        if args:
            n = len (args)
            if n == 1:
                synth_kw = 'Y'
            else:
                synth_kw = 'XYxy' [:n]
            for i in range (n):
                coll_args [synth_kw [i]] = args [i]
            if n > 2:
                args_given = 2
            else:
                args_given = 1
        for kw in kwargs.keys ():
            kws = self._shrink (kw)
            if len (kws) > 1: # Mixed data - plot instantly.
                if len (kwargs [kw].shape) != 2:
                    raise kplot.Error ('Matrix expected.')
                self._plot_curve (kws, kwargs [kw])
            # Data given separately for X, Y, dX, dY - keep track, but check
            # for compatibility with positional arguments first.
            elif args_given == 2 or (args_given == 1 and kws in 'XY'):
                raise kplot.Error ('Too many keyword arguments.')
            elif kws in synth_kw: # e.g. both 'x' and 'dX' given
                raise kplot.Error ('Duplicate keyword.')
            else:
                synth_kw = synth_kw + kws
                coll_args [kws] = kwargs [kw]

        # Finally plot collected arguments if there are any.
        if coll_args:
            if 'Y' not in synth_kw:
                raise kplot.Error ('Y data missing.')
            # All data must have the same shape as Y-data.
            # Exception: If multiple curves are given, X may still be just one
            # row, because a common X-axis is feasible.
            sh = coll_args ['Y'].shape
            if len (sh) not in (1, 2):
                raise kplot.Error ('Matrix or vector expected')
            if ('x' in synth_kw and coll_args ['x'].shape != sh) or \
               ('y' in synth_kw and coll_args ['y'].shape != sh) or \
               ('X' in synth_kw and not (coll_args ['X'].shape == sh or
                                         (len (coll_args ['X'].shape) == 1 and
                                          len (coll_args ['X']) == sh [-1]))):
                raise kplot.Error ('Mismatched shapes.')
            n = len (synth_kw)
            if len (sh) == 1: # just one curve
                data = Numeric.zeros ((n, sh [0]), 'd') # container for data mix
                for i in range (n): # mix data ...
                    data [i, :] = coll_args [synth_kw [i]]
                self._plot_curve (synth_kw, data) # ... and send
            else: # multiple curves, one per row of Y
                # A one-row X is shared by all curves (common X-axis).
                common_x = 'X' in synth_kw and len (coll_args ['X'].shape) == 1
                for j in range (sh [0]):
                    # Fresh container per curve: _kplane hands row slices of
                    # it to set_channel, so reusing one buffer would corrupt
                    # earlier curves if the backend keeps references.
                    data = Numeric.zeros ((n, sh [1]), 'd')
                    for i in range (n): # mix data ...
                        if common_x and synth_kw [i] == 'X':
                            data [i, :] = coll_args ['X']
                        else:
                            data [i, :] = coll_args [synth_kw [i]] [j, :]
                    self._plot_curve (synth_kw, data) # ... and send
        self.current = len (self.sets) - 1

    def _fits (self, vector, length):
        """Checks an optional scale vector for plane_map.

        Returns true if 'vector' is not given (has a false truth value) or
        is a 1-d array with 'length' entries."""
        # NOTE(review): absence is detected by truth value; confirm that
        # callers pass None for a missing vector, since an array's own truth
        # value depends on its contents.
        return not vector or (len (vector.shape) == 1 and
                              len (vector) == length)

    def plane_map (self, args, type):
        """Plots a surface or contour or pixmap.

        args -- sequence; args[0] is the data matrix, args[-2] an optional
                scale vector with one entry per column (presumably X) and
                args[-1] an optional scale vector with one entry per row
                (presumably Y).  For pixmaps (kmp.PlotImage) the vectors
                must be one entry longer, because they denote pixel edges.
        type -- KMatplot dataset type; kmp.PlotSurface gets 3d axes, all
                other types 2d axes.
        Arguments are validated before the axes are touched.  Raises
        kplot.Error if args[0] is no matrix or a scale vector does not fit.
        """
        # Work on a copy: the row scale is reshaped below, and the caller's
        # sequence must not change (a tuple could not be changed at all).
        args = list (args)
        if len (args [0].shape) != 2:
            raise kplot.Error ('Matrix expected.')
        # Both scale vectors must fit independently.
        extra = (type == kmp.PlotImage)
        rows, cols = args [0].shape
        if not (self._fits (args [-2], cols + extra) and
                self._fits (args [-1], rows + extra)):
            raise kplot.Error ('Shapes do not match.')

        # Surfaces are 3d, which contours and pixmaps are not.
        self._set_axes (type == kmp.PlotSurface)

        if args [-1]: # KMatplot wants a column vector here.
            args [-1] = args [-1] [:, Numeric.NewAxis]
        self.sets.append (_kmap (self, type, args)) # Send.
        self.current = len (self.sets) - 1

    def polygon (self, X, Y, Z):
        """Plots a set of polygons.

        X, Y, Z -- coordinate arrays of identical shape, either vectors or
                   matrices.  Vectors are sent as column vectors and
                   matrices transposed, so (presumably) each matrix row
                   is one polygon.
        Arguments are validated before the axes are touched.  Raises
        kplot.Error for mismatched shapes or ranks other than 1 or 2."""
        if Y.shape != X.shape or Z.shape != X.shape:
            raise kplot.Error ('Mismatched shapes.')
        if len (X.shape) == 1:
            X, Y, Z = [a [:, Numeric.NewAxis] for a in [X, Y, Z]]
        elif len (X.shape) == 2:
            X, Y, Z = [Numeric.transpose (a) for a in [X, Y, Z]]
        else:
            raise kplot.Error ('Matrix or vector expected.')

        # Polygons live in 3d axes.
        self._set_axes (1)
        self.sets.append (_kpoly (self, X, Y, Z))
        self.current = len (self.sets) - 1
        
