from numpy import diag, ones, inf, any, copy, zeros, dot, where, all, tile, sum, nan, isfinite, float64, isnan, log10, max, sign
from numpy.linalg import norm
try:
    from numpy.linalg import cond
except:
    print 'warning: no cond in numpy.linalg, matrix B rejuvenation check will be omitted'
    cond = lambda Matrix: 1

from scikits.openopt.Kernel.BaseAlg import *
from scikits.openopt.Kernel.Point import Point
from scikits.openopt.Kernel.setDefaultIterFuncs import SMALL_DELTA_X,  SMALL_DELTA_F,  SMALL_DF,  IS_LINE_SEARCH_FAILED
from UkrOptMisc import getConstrDirection

class ralg(BaseAlg):
    """Naum Z. Shor r-algorithm with adaptive space dilation (OpenOpt solver).

    The solver keeps a square space-transformation matrix ``b`` (initially the
    identity). Each iteration it performs a forward (and optionally backward)
    line search along the direction ``b b^T d / ||b^T d||``, where ``d`` is the
    objective gradient or a constraint-violation direction, and then dilates
    the space along the difference of consecutive directions.
    """
    __name__ = 'ralg'
    __license__ = "BSD"
    __authors__ = "Dmitrey"
    __alg__ = "Naum Z. Shor R-algorithm with adaptive space dilation & some modifications"
    __optionalDataThatCanBeHandled__ = ['A', 'Aeq', 'b', 'beq', 'lb', 'ub', 'c', 'h']
    __iterfcnConnected__ = True

    # ralg default parameters
    # alp: dilation coefficient, used as w = 1/alp - 1 in the update B += w*(B g) g^T
    # h0:  initial step size
    # nh:  number of forward line-search steps after which the step is multiplied by q2
    # q1:  step multiplier applied after an iteration with ls <= 0
    #      (the string default is resolved in __solver__ according to the problem type)
    # q2:  step multiplier applied every nh forward line-search steps
    alp, h0, nh, q1, q2  = 2.0, 1.0, 3, 'default:0.9 for NLP, 1.0 for NSP', 1.1
    hmult = 0.5          # step multiplier used by the backward line search
    S = 0                # weight of df added to dmr in __getRalgDirection22__
    T = float64          # float type of the matrix b, the step size and x0
    dilationType = 'auto'  # 'normalized' normalizes both directions before dilation

    showLS = False
    showRej = False
    showRes = False
    show_nnan = False
    doBackwardSearch = 1
    check_b_cond = 250   # check cond(b) every check_b_cond iterations (< 1 disables)
    maxRejNum = 15       # max number of b rejuvenations (resets to identity) per solve
    rejNum = 0           # number of rejuvenations performed so far (reset in __solver__)

    def needRej(self, p, b, g, g_dilated):
        """Return True when matrix ``b`` should be reset to the identity.

        ``cond(b)`` is checked only every ``check_b_cond`` iterations, or when a
        stop has been requested (``p.istop``). A reset is requested when the
        condition number exceeds 1e13. At most ``maxRejNum`` resets are granted
        per solve. ``g`` and ``g_dilated`` are currently unused. They are kept
        for interface compatibility.
        """
        # Only actual rejuvenations count against the cap. Previously every call
        # counted, and the method is called each iteration, so the periodic
        # condition check was disabled after the first maxRejNum iterations.
        if self.rejNum >= self.maxRejNum:
            return False
        r = False
        if self.check_b_cond >= 1 and (p.iter % self.check_b_cond == 0 or p.istop):
            cb = cond(b)
            r = cb > 1e13
            if self.showRej:
                msg = 'ralg iter%d log10(cond b) = %0.1f rej = ' % (p.iter, log10(cb+1e-100))
                if r: msg += 'True'
                else: msg += 'False'
                p.info(msg)
        if r:
            self.rejNum += 1
        return r

    def __init__(self): pass

    def __solver__(self, p):
        """Run the r-algorithm on problem ``p``, reporting through ``p.iterfcn``.

        The method returns nothing. The result and stop reason are communicated
        via ``p`` (``p.istop``, ``p.msg`` and the points passed to ``p.iterfcn``).
        """
        # atleast_1d is not in the module-level numpy import list.
        from numpy import atleast_1d

        alp, h0, nh, q1, q2 = self.alp, self.h0, self.nh, self.q1, self.q2
        self.check_b_cond = int(self.check_b_cond)
        self.rejNum = 0
        if isinstance(q1, str):
            if p.probType== 'NLP' and p.isUC: q1 = 0.9
            else: q1 = 1.0
        T = self.T

        n = p.n
        b = diag(ones(n,  T))       # space-transformation matrix, starts as identity
        hs = T(h0)                  # current step size
        w = T(1.0/alp-1.0)          # dilation weight: B += w * (B g) g^T
        x0 = atleast_1d(T(copy(p.x0)))

        # ---------------------------- Shor r-alg engine ----------------------------

        x = x0.copy()
        iterPoint = p.point(x)
        prevIterPoint = p.point(x)

        g = self.__getRalgDirection__(iterPoint)
        moveDirection = g
        if not any(g):
            p.istop = SMALL_DF
            p.msg = '|| gradient F(X[k]) || < gradtol'
            return

        # ---------------------------- Ralg main cycle ------------------------------

        for itn in xrange(1500000):
            doDilation = True

            ls1 = 0
            # Search direction g1 = b b^T d / ||b^T d||.
            # TODO: is (g^T b)^T better?
            g_tmp = self.__economyMult__(b.T, moveDirection)
            if any(g_tmp): g_tmp /= p.norm(g_tmp)
            g1 = p.matmult(b, g_tmp)

            # ------------------------- Forward line search -------------------------
            # Step along -g1 while the new point is better than the previous one.
            # NOTE(review): x is updated in place. This is only safe if p.point()
            # does not keep a reference to x. Confirm against Point.
            for ls in xrange(p.maxLineSearch):
                ls1 += 1
                x -= hs * g1

                if itn == 0 and ls1 > 1:
                    hs *= 2.0
                elif ls1 > nh:
                    hs *= q2
                    ls1 = 0

                newPoint = p.point(x)
                if self.show_nnan: p.info('ls: %d nnan: %d' % (ls, newPoint.__nnan__()))

                if ls == 0:
                    oldPoint = prevIterPoint

                if newPoint.betterThan(oldPoint):
                    # TODO: 1. handle possible noise here; 2. handle the case ls > 15
                    oldPoint, newPoint = newPoint,  None
                else:
                    break
            else:
                # The loop ran out of steps without finding a non-improving point.
                # A break on the last permitted step is a success, not a failure.
                p.istop,  p.msg = IS_LINE_SEARCH_FAILED,  'maxLineSearch (' + str(p.maxLineSearch) + ') has been exceeded'
                return

            g2 = self.__getRalgDirection__(newPoint) # used for dilation direction obtaining

            iterPoint  = newPoint

            # ------------------------- Backward line search ------------------------
            # ls == 0: the very first forward trial point was not better than
            # prevIterPoint, so try shrinking the step towards it.
            if ls == 0 and self.doBackwardSearch:
                x_tmp = newPoint.x.copy()
                PrevPoint = iterPoint
                while 1:
                    hs *= self.hmult
                    hs_prev = hs

                    if itn == 0:
                        x_tmp = p.xk - hs * g1
                    else:
                        x_tmp = 0.5*x_tmp+0.5*prevIterPoint.x
                        hs /= self.hmult

                    newPoint = p.point(x_tmp)

                    if PrevPoint.betterThan(newPoint) or newPoint.f() > PrevPoint.f() or abs(newPoint.f() - p.fk) < 15 * p.ftol or p.norm(newPoint.x - p.xk) < 15 * p.xtol:
                        iterPoint, hs = PrevPoint, hs_prev
                        break

                    PrevPoint = newPoint

                    ls -= 1
                    if itn != 0:
                        break

                iterPoint = PrevPoint

            # ---------------------- iterPoint has been obtained ---------------------

            moveDirection = self.__getRalgDirection__(iterPoint)
            x = iterPoint.x.copy()
            if ls <= 0: hs *= q1
            if itn == 0:
                p.debugmsg('hs: ' + str(hs))
                p.debugmsg('ls: ' + str(ls))
            if self.showLS: p.info('ls: ' + str(ls))
            if self.show_nnan: p.info('nnan: ' + str(iterPoint.__nnan__()))
            if self.showRes:
                r, fname, ind = iterPoint.mr(True)
                p.info(fname+str(ind))

            # ------------------------ Set dilation direction ------------------------

            prevIterPointIsFeasible = prevIterPoint.isFeas()
            currIterPointIsFeasible = iterPoint.isFeas()
            # mr(True) returns (residual, constraint name, index). This is the
            # same order used by the showRes branch above and by
            # __getRalgDirection__. The order used to be swapped here.
            r_p, fname_p, ind_p = prevIterPoint.mr(1)
            r_, fname_, ind_ = iterPoint.mr(1)

            linearTypes = ('lb', 'ub', 'lin_eq', 'lin_ineq')
            if self.dilationType == 'normalized' and (fname_p not in linearTypes or fname_ not in linearTypes) and (fname_p != fname_  or ind_p != ind_):
                G2,  G = g2/norm(g2), g/norm(g)
            else:
                G2,  G = g2, g

            if prevIterPointIsFeasible == currIterPointIsFeasible == True:
                g1 = G2 - G
            elif prevIterPointIsFeasible == currIterPointIsFeasible == False:
                g1 = G2 - G
            elif prevIterPointIsFeasible:
                g1 = G2.copy()
            else:
                g1 = -G.copy() # signum doesn't matter here

            # --------------------------- Perform dilation ---------------------------

            g = self.__economyMult__(b.T, g1)
            ng = p.norm(g)
            p._df = g2.copy()

            if self.needRej(p, b, g1, g):
                if self.showRej or p.debug:
                    p.info('debug msg: matrix B restoration in ralg solver')
                b = diag(ones(n, T))  # keep the configured float type
                hs = 0.5*p.norm(prevIterPoint.x - iterPoint.x)
            if all(isfinite(g)) and ng > 1e-50 and doDilation:
                # B <- B + w * (B g) g^T with unit-norm g
                g = (g / ng).reshape(-1,1)
                vec1 = self.__economyMult__(b, g).reshape(-1,1)
                vec2 = w * g.T
                b += p.matmult(vec1, vec2)

            # ---------------------------- Call OO iterfcn ---------------------------

            if iterPoint.isFeas():
                if hasattr(iterPoint, '_df'):
                    p._df = iterPoint._df
            p.iterfcn(iterPoint)

            # -------------------------- Check stop criteria -------------------------

            # Used by the SMALL_DELTA_X / NaN filters below. It was previously
            # only defined in commented-out code, so those paths raised NameError.
            cond_same_point = all(iterPoint.x == prevIterPoint.x)

            s2 = 0
            if not p.istop and not p.userStop:
                p.debugmsg('istop:'+str(p.istop))
                if SMALL_DF in p.stopdict:
                    if currIterPointIsFeasible: s2 = p.istop
                    p.stopdict.pop(SMALL_DF)
                if SMALL_DELTA_F in p.stopdict:
                    if currIterPointIsFeasible: s2 = p.istop
                    p.stopdict.pop(SMALL_DELTA_F)
                if SMALL_DELTA_X in p.stopdict:
                    if currIterPointIsFeasible or not prevIterPointIsFeasible or cond_same_point: s2 = p.istop
                    p.stopdict.pop(SMALL_DELTA_X)
                if s2 and (any(isnan(iterPoint.c())) or any(isnan(iterPoint.h()))) \
                and not p.isNaNInConstraintsAllowed\
                and not cond_same_point:
                    s2 = 0
                if not s2 and any(p.stopdict.values()):
                    for key,  val in p.stopdict.iteritems():
                        if val == True:
                            s2 = key
                            break
                p.istop = s2

            # ---------------------------- If stop required --------------------------

            if p.istop:
                if self.needRej(p, b, g1, g):
                    # Reset the metric and keep iterating instead of stopping.
                    b = diag(ones(n, T))
                    hs = 0.5*p.norm(prevIterPoint.x - iterPoint.x)
                    p.istop = 0
                else:
                    if newPoint.betterThan(oldPoint):
                        optimIterPoint = newPoint
                    else:
                        optimIterPoint = oldPoint
                    if any(optimIterPoint.x != iterPoint.x): p.iterfcn(optimIterPoint)
                    return

            # ----------------- Some final things for ralg main cycle ----------------
            g = g2.copy()

            prevIterPoint, iterPoint = iterPoint, None


    def __getRalgDirection22__(self, point):
        """Alternative direction selector (not called anywhere in this file).

        Returns ``point.dmr()`` (plus ``S * df`` when ``S != 0``) when the point
        is infeasible, has disallowed NaNs in c/h, or has a non-finite objective.
        Otherwise it returns ``point.df()``, falling back to ``dmr`` when df
        contains NaNs and the residual is positive.
        """
        # TODO: what if df and/or dmr has some NaNs?
        maxRes = point.mr()

        # NaNs in either c or h matter only when NaNs are not allowed. The old
        # and/or precedence applied the "allowed" flag to c only.
        if maxRes > point.p.contol \
        or (not point.p.isNaNInConstraintsAllowed and (any(isnan(point.c())) or any(isnan(point.h()))))\
        or not isfinite(point.f()):
            dmr = point.dmr()
            if self.S == 0:
                d = dmr
            else:
                d = point.df() + self.S*dmr
        else:
            d = point.df()
            if any(isnan(d)) and maxRes>0: # comparison to 0, not contol
                d = point.dmr()
        return d

    def __getRalgDirection__(self, point):
        """Return the 1-D search direction at ``point``.

        If the max residual is within ``contol`` and the objective gradient is
        finite, that gradient is returned. Otherwise the direction targets the
        most violated constraint type reported by ``point.mr(retAll=True)``:

        * lb / ub / lin_eq / lin_ineq: gradient of half the sum of squared
          violations over all violated linear/box constraints;
        * 'c': gradient of that single inequality constraint;
        * 'h': sign-adjusted gradient of that single equality constraint.
        """
        p = point.p
        contol = p.contol
        maxRes, fname, ind = point.mr(retAll=True)
        if maxRes <= contol:
            df = point.df()
            if all(isfinite(df)):
                return df

        d = zeros(p.n)
        if fname in ['lb',  'ub',  'lin_eq',  'lin_ineq']:
            lb = point.lb()
            ub = point.ub()
            lin_ineq = point.lin_ineq()
            lin_eq = point.lin_eq()

            threshold = contol
            # Box bounds use 0 as the threshold; linear constraints use contol.
            ind_lb = where(lb>0)[0]
            ind_ub = where(ub>0)[0]
            ind_lin_ineq = where(lin_ineq>threshold)[0]
            ind_lin_eq = where(abs(lin_eq)>threshold)[0]

            if ind_lb.size != 0:
                d[ind_lb] -= lb[ind_lb]# 0.5*d/dx((x-lb)^2) for violated constraints
            if ind_ub.size != 0:
                d[ind_ub] += ub[ind_ub]# 0.5*d/dx((x-ub)^2) for violated constraints
            if ind_lin_ineq.size != 0:
                a = p.A[ind_lin_ineq]
                b = p.b[ind_lin_ineq]
                d += dot(a.T, dot(a, point.x)  - b) # 0.5*d/dx((Ax-b)^2)
            if ind_lin_eq.size != 0:
                aeq = p.Aeq[ind_lin_eq]
                beq = p.beq[ind_lin_eq]
                d += dot(aeq.T, dot(aeq, point.x)  - beq) # 0.5*d/dx((Ax-b)^2)
        elif fname == 'c':
            d = point.dc(ind) # TODO: not recalculate all dc!
        elif fname == 'h':
            d = sign(point.h(ind))*point.dh(ind) # TODO: not recalculate all dh!
        else:
            # str() so that a non-string fname (e.g. None) cannot turn the
            # error report itself into a TypeError.
            p.err('error in getRalgDirection (unknown residual type ' + str(fname) + ' ), you should report the bug')
        return d.flatten()

    def __economyMult__(self, M, V):
        """Return ``dot(M, V)``, skipping the columns of M where V is zero.

        When V has no zero entries, a plain ``dot`` is used. Otherwise V is
        flattened and only its non-zero entries (and the matching columns of M)
        take part in the product.
        """
        if all(V): # all v coords are non-zeros
            return dot(M, V)
        else:
            v = V.flatten()
            ind = where(v != 0)[0]
            r = dot(M[:,ind], v[ind])
            return r




