# -------------------------------------------------------------------------
#     This file is part of mMass - the spectrum analysis tool for MS.
#     Copyright (C) 2005-07 Martin Strohalm <mmass@biographics.cz>

#     This program is just a simplified PyPlot library, originally developed
#     and copyrighted by Gordon Williams and Jeff Grimmett. Thank you!

#     This program is free software; you can redistribute it and/or modify
#     it under the terms of the GNU General Public License as published by
#     the Free Software Foundation; either version 2 of the License, or
#     (at your option) any later version.

#     This program is distributed in the hope that it will be useful,
#     but WITHOUT ANY WARRANTY; without even the implied warranty of
#     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#     GNU General Public License for more details.

#     Complete text of GNU GPL can be found in the file LICENSE in the
#     main directory of the program
# -------------------------------------------------------------------------

# Function: Simple plot canvas.

# load libs
import wx
import numpy as num


# ----
class polyPoints:
    """ Base Class for lines and markers.

    Subclasses must provide a class-level ``_attributes`` dict holding the
    default drawing attributes; values passed in ``attr`` override them.
    """

    def __init__(self, points, attr):
        """ Store the points and the merged drawing attributes.

        points -- sequence of (x, y) pairs; copied into a numpy array
        attr   -- dict of attribute overrides applied on top of _attributes
        """
        self.points = num.array(points)
        self.currentScale = (1,1)
        self.currentShift = (0,0)
        # until scaleAndShift() runs, screen coords are the data coords
        self.scaled = self.points
        self.attributes = {}
        self.attributes.update(self._attributes)
        self.attributes.update(attr)

    def boundingBox(self):
        """ Return (minXY, maxXY) of the points.

        An empty point set reports the default box ([-1, -1], [1, 1]).
        """
        if len(self.points) == 0:
            minXY = num.array([-1,-1])
            maxXY = num.array([ 1, 1])
        else:
            minXY = num.minimum.reduce(self.points)
            maxXY = num.maximum.reduce(self.points)
        return minXY, maxXY

    def scaleAndShift(self, scale=(1,1), shift=(0,0)):
        """ Recompute self.scaled = scale * points + shift when the transform changed.

        Does nothing for an empty point set.
        """
        if len(self.points) == 0:
            return
        # compare by value: the former identity test ("is not") missed a
        # scale/shift array modified in place and passed again, leaving stale
        # screen coordinates
        if not (num.array_equal(scale, self.currentScale)
                and num.array_equal(shift, self.currentShift)):
            self.scaled = scale * self.points + shift
            # keep private copies so later in-place edits by the caller are noticed
            self.currentScale = num.array(scale)
            self.currentShift = num.array(shift)
# ----


# ----
class polyLine(polyPoints):
    """ Polyline through all points, drawn with a configurable pen.

    Attributes (defaults in _attributes): colour name, pen width,
    wx pen style and a legend string.
    """

    _attributes = {'colour': 'black',
                   'width': 1,
                   'style': wx.SOLID,
                   'legend': ''}

    def __init__(self, points, **attr):
        polyPoints.__init__(self, points, attr)

    def draw(self, dc):
        """ Draw the scaled points onto dc as connected line segments. """
        attrs = self.attributes
        linePen = wx.Pen(wx.NamedColour(attrs['colour']), attrs['width'], attrs['style'])
        dc.SetPen(linePen)
        dc.DrawLines(self.scaled)
# ----


# ----
class polyMarker(polyPoints):
    """ Class to define marker type and style.

    Attributes (defaults in _attributes):
        colour     -- outline colour name (passed to wx.NamedColour)
        width      -- outline pen width
        size       -- marker size multiplier
        fillcolour -- fill colour name; when falsy the outline colour is used
        fillstyle  -- wx brush style
        marker     -- 'circle', 'dot', 'square', 'triangle', 'triangle_down',
                      'cross' or 'plus' (selects method self._<marker>)
        legend     -- legend text
    """

    _attributes = {'colour': 'black',
                   'width': 1,
                   'size': 2,
                   'fillcolour': None,
                   'fillstyle': wx.SOLID,
                   'marker': 'circle',
                   'legend': ''}

    def __init__(self, points, **attr):
        polyPoints.__init__(self, points, attr)

    def draw(self, dc):
        """ Draw one marker at every scaled point onto dc. """
        # nothing to draw; an empty (0,)-shaped point array also cannot be
        # broadcast against the 2-column offsets used by the marker helpers
        if len(self.points) == 0:
            return

        colour = self.attributes['colour']
        width = self.attributes['width']
        size = self.attributes['size']
        fillcolour = self.attributes['fillcolour']
        fillstyle = self.attributes['fillstyle']
        marker = self.attributes['marker']

        dc.SetPen(wx.Pen(wx.NamedColour(colour), width))
        # fill with the outline colour unless a separate fill colour is set
        if fillcolour:
            dc.SetBrush(wx.Brush(wx.NamedColour(fillcolour), fillstyle))
        else:
            dc.SetBrush(wx.Brush(wx.NamedColour(colour), fillstyle))
        self._drawmarkers(dc, self.scaled, marker, size)

    def _drawmarkers(self, dc, coords, marker, size=1):
        """ Dispatch to self._<marker>; raises AttributeError for unknown markers. """
        # getattr performs the same lookup as the former eval() without
        # evaluating the attribute string as Python code
        drawer = getattr(self, '_' + marker)
        drawer(dc, coords, size)

    def _boxes(self, coords, size):
        """ Return int32 (n, 4) x, y, w, h boxes of side 5*size centred on coords. """
        half = 2.5 * size
        side = 5.0 * size
        rect = num.zeros((len(coords), 4), float) + [0.0, 0.0, side, side]
        rect[:, 0:2] = coords - [half, half]
        return rect.astype(num.int32)

    def _polygons(self, dc, coords, shape):
        """ Draw polygon `shape` (list of (dx, dy) vertex offsets) at every point. """
        coords = num.asarray(coords, dtype=float)
        nVertices = len(shape)
        # repeat each point once per vertex along axis 0; without axis=
        # numpy flattens the array and the x/y values get mixed up
        poly = num.repeat(coords, nVertices, axis=0).reshape((len(coords), nVertices, 2))
        poly = poly + shape
        dc.DrawPolygonList(poly.astype(num.int32))

    def _circle(self, dc, coords, size=1):
        dc.DrawEllipseList(self._boxes(coords, size))

    def _dot(self, dc, coords, size=1):
        dc.DrawPointList(coords)

    def _square(self, dc, coords, size=1):
        dc.DrawRectangleList(self._boxes(coords, size))

    def _triangle(self, dc, coords, size=1):
        shape = [(-2.5*size, 1.44*size), (2.5*size, 1.44*size), (0.0, -2.88*size)]
        self._polygons(dc, coords, shape)

    def _triangle_down(self, dc, coords, size=1):
        shape = [(-2.5*size, -1.44*size), (2.5*size, -1.44*size), (0.0, 2.88*size)]
        self._polygons(dc, coords, shape)

    def _cross(self, dc, coords, size=1):
        fact = 2.5*size
        for f in [[-fact, -fact, fact, fact], [-fact, fact, fact, -fact]]:
            lines = num.concatenate((coords, coords), axis=1) + f
            dc.DrawLineList(lines.astype(num.int32))

    def _plus(self, dc, coords, size=1):
        fact = 2.5*size
        for f in [[-fact, 0, fact, 0], [0, -fact, 0, fact]]:
            lines = num.concatenate((coords, coords), axis=1) + f
            dc.DrawLineList(lines.astype(num.int32))
# ----


# ----
class plotGraphics:
    """ Container to hold PolyXXX objects. """

    def __init__(self, objects):
        """ objects -- sequence of objects with boundingBox(), scaleAndShift() and draw(). """
        self.objects = objects

    def getBoundingBox(self):
        """ Return (minXY, maxXY) enclosing the bounding boxes of all objects.

        An empty container returns the default box ([-1, -1], [1, 1]) instead
        of raising IndexError; this matches what polyPoints reports for empty data.
        """
        if len(self.objects) == 0:
            return num.array([-1, -1]), num.array([1, 1])
        p1, p2 = self.objects[0].boundingBox()
        for obj in self.objects[1:]:
            p1o, p2o = obj.boundingBox()
            p1 = num.minimum(p1, p1o)
            p2 = num.maximum(p2, p2o)
        return p1, p2

    def scaleAndShift(self, scale=(1,1), shift=(0,0)):
        """ Apply the same scale and shift to every object. """
        for obj in self.objects:
            obj.scaleAndShift(scale, shift)

    def draw(self, dc):
        """ Draw every object onto dc, in container order. """
        for obj in self.objects:
            obj.draw(dc)

    def __len__(self):
        return len(self.objects)

    def __getitem__(self, item):
        return self.objects[item]
# ----


# ----
class plotCanvas(wx.Window):
    """ Plot canvas class.

    Renders a plotGraphics container into the off-screen bitmap
    self.plotBuffer, together with axes, ticks, grid, axis labels and a
    title. The bitmap is copied to the window on paint events.
    """

    # ----
    def __init__(self, parent, id = -1, pos=wx.DefaultPosition, size=wx.DefaultSize, style=wx.DEFAULT_FRAME_STYLE):
        wx.Window.__init__(self, parent, id, pos, size, style)

        # set default canvas params
        self.SetBackgroundColour('white')

        self.xLabel = ''
        self.yLabel = ''
        self.title = ' '

        # axis range specs, interpreted by getAxisInterval()
        self.xSpec = 'auto'
        self.ySpec = 'auto'

        self.axisColour = wx.Colour(0, 0, 0)
        self.gridColour = wx.Colour(240, 240, 240)
        if wx.Platform == "__WXMAC__":
            self.mainFont = wx.Font(10, wx.SWISS, wx.NORMAL, wx.NORMAL, 0)
        else:
            self.mainFont = wx.Font(8, wx.SWISS, wx.NORMAL, wx.NORMAL, 0)

        # initial values
        self.lastDraw = None
        self.pointScale = 1
        self.pointShift = 0

        self.Bind(wx.EVT_PAINT, self.onPaint)
        self.Bind(wx.EVT_SIZE, self.onSize)

        # initialize bitmap buffer and set initial size based on client size
        self.onSize(0)
    # ----


    # ----
    def onPaint(self, evt):
        """ Repaint spectrum. """

        # draw buffer to screen; the buffered DC copies self.plotBuffer to the
        # window, so it is kept in a local until the handler returns
        dc = wx.BufferedPaintDC(self, self.plotBuffer)
    # ----


    # ----
    def onSize(self, evt):
        """ Repaint spectrum when size changed.

        Also called directly from __init__ (evt is unused).
        """

        # get size; a zero-sized bitmap is not allowed, so fall back to 1x1
        width, height = self.GetClientSize()
        if width <= 0 or height <= 0:
            width = 1
            height = 1

        # make new offscreen bitmap
        self.plotBuffer = wx.EmptyBitmap(width, height)
        self.setSize()

        # redraw plot or clear area
        if self.lastDraw:
            self.draw(self.lastDraw[0])
        else:
            self.clear()
    # ----


    # ----
    def setTitle(self, value):
        """ Set title. """
        self.title = value
    # ----


    # ----
    def setXLabel(self, value):
        """ Set label for X axis. """
        self.xLabel = value
    # ----


    # ----
    def setYLabel(self, value):
        """ Set label for Y axis. """
        self.yLabel = value
    # ----


    # ----
    def setAxisFont(self, font):
        """ Set font for axis. """
        self.mainFont = font
    # ----


    # ----
    def setAxisColour(self, red, green, blue):
        """ Set color for axis. """
        self.axisColour = wx.Colour(red, green, blue)
    # ----


    # ----
    def setGridColour(self, red, green, blue):
        """ Set color for grid. """
        self.gridColour = wx.Colour(red, green, blue)
    # ----


    # ----
    def SetXSpec(self, type='auto'):
        """ Set x-axis range spec: 'auto', 'min', 'none' or an explicit (lower, upper). """
        self.xSpec = type
    # ----


    # ----
    def SetYSpec(self, type= 'auto'):
        """ Set y-axis range spec: 'auto', 'min', 'none' or an explicit (lower, upper). """
        self.ySpec = type
    # ----


    # ----
    def setSize(self, width=None, height=None):
        """ Set DC width and height (defaults to the current client size). """

        # get size
        if width is None:
            (self.width, self.height) = self.GetClientSize()
        else:
            self.width, self.height = width, height

        # the plot box fills the whole client area; its origin is the
        # lower-left corner in screen coordinates (y grows downwards)
        self.plotBoxSize = num.array([self.width, self.height])
        xo = 0.5 * (self.width - self.plotBoxSize[0])
        yo = self.height - 0.5 * (self.height - self.plotBoxSize[1])
        self.plotBoxOrigin = num.array([xo, yo])
    # ----


    # ----
    def draw(self, graphics):
        """ Draw objects in graphics.

        graphics -- container providing getBoundingBox(),
                    scaleAndShift(scale, shift) and draw(dc), e.g. plotGraphics

        Saves (graphics, xAxis, yAxis) to self.lastDraw so onSize() can
        redraw, and the data-to-screen transform to self.pointScale and
        self.pointShift.
        """

        # sets new dc and clears it
        dc = wx.BufferedDC(wx.ClientDC(self), self.plotBuffer)
        dc.Clear()
        dc.BeginDrawing()

        # set axis font
        dc.SetFont(self.mainFont)

        # get lower left and upper right corners of plot; build new float
        # arrays because writing fractional axis limits into the bounding-box
        # arrays would truncate them when the data points are integers
        p1, p2 = graphics.getBoundingBox()
        xAxis = self.getAxisInterval(self.xSpec, p1[0], p2[0])
        yAxis = self.getAxisInterval(self.ySpec, p1[1], p2[1])
        p1 = num.array([xAxis[0], yAxis[0]], dtype=float)
        p2 = num.array([xAxis[1], yAxis[1]], dtype=float)

        # save most recent values
        self.lastDraw = (graphics, xAxis, yAxis)

        # get axis ticks
        xAxisTicks = self.makeAxisTicks(xAxis[0], xAxis[1])
        yAxisTicks = self.makeAxisTicks(yAxis[0], yAxis[1])

        # get text extents for axis ticks (only first and last labels are measured)
        xLeft = dc.GetTextExtent(xAxisTicks[0][1])
        xRight = dc.GetTextExtent(xAxisTicks[-1][1])
        yBottom = dc.GetTextExtent(yAxisTicks[0][1])
        yTop = dc.GetTextExtent(yAxisTicks[-1][1])
        xAxisTextExtent = (max(xLeft[0], xRight[0]), max(xLeft[1], xRight[1]))
        yAxisTextExtent = (max(yBottom[0], yTop[0]), max(yBottom[1], yTop[1]))

        # get text extents for axis labels and title
        titleWH = dc.GetTextExtent(self.title)
        xAxisLabelWH = dc.GetTextExtent(self.xLabel)
        yAxisLabelWH = dc.GetTextExtent(self.yLabel)

        # get room around graph area
        spaceRight = xAxisTextExtent[0]
        spaceLeft = yAxisTextExtent[0] + yAxisLabelWH[1] + 20
        spaceBottom = xAxisTextExtent[1] + xAxisLabelWH[1] + 15
        spaceTop = 5 + yAxisTextExtent[1]/2. + titleWH[1]

        # draw axis labels and title (wx drawing calls take integer pixels)
        plotCentreX = self.plotBoxOrigin[0] + spaceLeft + (self.plotBoxSize[0] - spaceLeft - spaceRight)/2.
        titlePos = (plotCentreX - titleWH[0]/2., self.plotBoxOrigin[1] - self.plotBoxSize[1] + 5)
        dc.DrawText(self.title, int(titlePos[0]), int(titlePos[1]))
        xLabelPos = (plotCentreX - xAxisLabelWH[0]/2., self.plotBoxOrigin[1] - xAxisLabelWH[1] - 5)
        dc.DrawText(self.xLabel, int(xLabelPos[0]), int(xLabelPos[1]))
        yLabelPos = (self.plotBoxOrigin[0] + 5, self.plotBoxOrigin[1] - spaceBottom - (self.plotBoxSize[1] - spaceBottom - spaceTop)/2.+ yAxisLabelWH[0]/2.)
        dc.DrawRotatedText(self.yLabel, int(yLabelPos[0]), int(yLabelPos[1]), 90)

        # scaling and shifting plotted points; the y factor is negated
        # because screen y grows downwards
        textSizeScale = num.array([spaceRight + spaceLeft, spaceBottom + spaceTop])
        textSizeShift = num.array([spaceLeft, spaceBottom])
        scale = (self.plotBoxSize - textSizeScale) / (p2 - p1) * num.array((1, -1))
        shift = - p1 * scale + self.plotBoxOrigin + textSizeShift * num.array((1, -1))
        self.pointScale = scale
        self.pointShift = shift

        # draw plot axis
        self.drawAxis(dc, p1, p2, scale, shift, xAxisTicks, yAxisTicks)

        # set clipping area
        x, y, width, height = self.point2ClientCoord(p1, p2)
        dc.SetClippingRegion(x, y, width, height)

        # recalculate plot lines
        graphics.scaleAndShift(scale, shift)

        # draw plot lines
        graphics.draw(dc)

        # remove the clipping region and end drawing
        dc.DestroyClippingRegion()
        dc.EndDrawing()
    # ----


    # ----
    def clear(self):
        """Erase the window."""

        dc = wx.BufferedDC(wx.ClientDC(self), self.plotBuffer)
        dc.Clear()
        self.lastDraw = None
    # ----


    # ----
    def getAxisInterval(self, spec, lower, upper):
        """Returns sensible axis range for given spec.

        spec -- 'none' or 'min': exact data range;
                'auto': range extended outwards to a round grid step;
                (lower, upper) tuple or list: that fixed range.
        A zero-width data range is widened by 0.5 on each side.
        Raises ValueError for any other spec.
        """

        # explicit fixed range
        if isinstance(spec, (tuple, list)):
            fixedLower, fixedUpper = spec
            return fixedLower, fixedUpper

        # exact range
        if spec == 'none' or spec == 'min':
            if lower == upper:
                return lower - 0.5, upper + 0.5
            else:
                return lower, upper

        # extended range
        elif spec == 'auto':
            span = upper - lower
            if span == 0.:
                return lower - 0.5, upper + 0.5
            log = num.log10(span)
            power = num.floor(log)
            fraction = log - power
            if fraction <= 0.05:
                power = power - 1
            grid = 10.**power
            lower = lower - lower % grid
            mod = upper % grid
            if mod != 0:
                upper = upper - mod + grid
            return lower, upper

        # previously fell through and returned None, failing later in draw()
        raise ValueError('Unknown axis spec: %r' % (spec,))
    # ----


    # ----
    def drawAxis(self, dc, p1, p2, scale, shift, xticks, yticks):
        """ Draw plot axis.

        Draws tick marks and labels outside the plot rectangle, grid lines
        inside it, and the rectangle outline. p1/p2 are the data-space
        lower-left/upper-right corners; scale/shift map data to screen.
        """

        # create pens once instead of once per tick
        penWidth = 1
        axisPen = wx.Pen(self.axisColour, penWidth)
        gridPen = wx.Pen(self.gridColour, penWidth)
        dc.SetPen(axisPen)
        dc.SetTextForeground(self.axisColour)

        # get plot coordinates, expanded by one pixel on each side
        plotX1, plotY1, plotWidth, plotHeight = self.point2ClientCoord(p1, p2)
        plotX1 -= 1
        plotX2 = plotX1 + plotWidth + 1
        plotY1 -= 1
        plotY2 = plotY1 + plotHeight + 1

        # set length of tick marks
        tickLength = 5

        # x axis
        for x, label in xticks:
            pt = scale * num.array([x, p1[1]]) + shift
            xPos = int(pt[0])
            dc.DrawLine(xPos, plotY2, xPos, plotY2 + tickLength)
            dc.DrawText(label, int(pt[0] - dc.GetTextExtent(label)[0] // 2), int(plotY2 + tickLength*1.4))

            # draw grid
            dc.SetPen(gridPen)
            dc.DrawLine(xPos, plotY1, xPos, plotY2)
            dc.SetPen(axisPen)

        # y axis
        charHeight = dc.GetCharHeight()
        for y, label in yticks:
            pt = scale * num.array([p1[0], y]) + shift
            yPos = int(pt[1])
            dc.DrawLine(plotX1, yPos, plotX1 - tickLength, yPos)
            dc.DrawText(label, int(plotX1 - dc.GetTextExtent(label)[0] - tickLength*1.4), int(pt[1] - 0.5 * charHeight))

            # draw grid
            dc.SetPen(gridPen)
            dc.DrawLine(plotX1, yPos, plotX2, yPos)
            dc.SetPen(axisPen)

        # draw plot outline
        dc.SetBrush(wx.Brush(wx.BLACK, wx.TRANSPARENT))
        dc.DrawRectangle(plotX1, plotY1, plotWidth + 2, plotHeight + 2)
    # ----


    # ----
    def makeAxisTicks(self, lower, upper):
        """ Count axis ticks - fce from PyPlot.

        Returns a list of (value, label) tuples. Ticks sit on multiples of a
        grid step of 1, 2 or 5 times a power of ten, chosen so that roughly
        seven intervals span lower..upper.
        Raises ValueError unless lower < upper and both are finite.
        """

        # a degenerate or non-finite range used to yield no ticks (or a NaN
        # conversion error), which only failed later in draw()
        if not (num.isfinite(lower) and num.isfinite(upper) and upper > lower):
            raise ValueError('Invalid axis range: %r..%r' % (lower, upper))

        ideal = (upper-lower)/7.
        log = num.log10(ideal)
        power = num.floor(log)
        fraction = log-power
        factor = 1.
        error = fraction
        multiples = [(2., num.log10(2.)), (5., num.log10(5.))]

        # pick the multiple whose log is closest to the fractional part
        for f, lf in multiples:
            e = num.fabs(fraction-lf)
            if e < error:
                error = e
                factor = f
        grid = factor * 10.**power

        # label format depends on the magnitude of the grid step
        # (%-formatting instead of Python 2 backtick repr)
        if power > 4 or power < -4:
            fmt = '%+7.1e'
        elif power >= 0:
            digits = max(1, int(power))
            fmt = '%%%d.0f' % digits
        else:
            digits = -int(power)
            fmt = '%%%d.%df' % (digits + 2, digits)

        # compute every tick from its index rather than accumulating
        # "t = t + grid": accumulation drifts, which printed values like
        # -2.8e-17 as "-0.0" and could drop the final tick
        first = -grid*num.floor(-lower/grid)    # smallest multiple of grid >= lower
        tolerance = grid * 1e-9
        ticks = []
        index = 0
        t = first
        while t <= upper + tolerance:
            if abs(t) < tolerance:
                t = 0.0
            ticks.append((t, fmt % (t,)))
            index += 1
            t = first + index * grid

        return ticks
    # ----


    # ----
    def point2ClientCoord(self, corner1, corner2):
        """ Convert user coords to client screen coords x,y,width,height.

        Uses self.pointScale / self.pointShift from the last draw(); the
        result is rounded to integer pixels with non-negative width/height.
        """

        c1 = num.array(corner1)
        c2 = num.array(corner2)

        # convert to screen coords
        pt1 = c1 * self.pointScale + self.pointShift
        pt2 = c2 * self.pointScale + self.pointShift

        # make height and width positive
        pointUpperLeft = num.minimum(pt1, pt2)
        pointLowerRight = num.maximum(pt1, pt2)
        rectWidth, rectHeight = pointLowerRight - pointUpperLeft
        pointX, pointY = pointUpperLeft

        # wx drawing calls expect integers
        return int(round(pointX)), int(round(pointY)), int(round(rectWidth)), int(round(rectHeight))
    # ----
