import py

from time import time as now
Item = py.test.Item
from py.__.test.terminal.out import getout 

def getrelpath(source, dest):
    """Return *dest* as a path string relative to the directory of *source*.

    *source* is treated like a file: one '..' is emitted for every path
    separator in its position below the common base, followed by *dest*'s
    position below that base, joined with ``dest.sep``.

    Returns None when ``source.common(dest)`` yields no common base.
    """
    base = source.common(dest)
    if not base:
        return None
    # with posix local paths '/' is always a common base
    depth = source.relto(base).count(source.sep)
    below_base = dest.relto(base)
    return dest.sep.join(('..',) * depth + (below_base,))

class TerminalSession(py.test.Session):
    """Test session that reports progress and results on a terminal.

    All output goes through ``self.out``, created with ``getout(file)``;
    *file* defaults to ``sys.stdout``.
    """

    def __init__(self, config, file=None):
        super(TerminalSession, self).__init__(config)
        if file is None:
            file = py.std.sys.stdout
        self._file = file
        self.out = getout(file)
        self._started = {}
        # stack of collectors currently open in --collectonly mode; its
        # length is used as the indentation depth in start()
        self._opencollectors = []

    def main(self, args):
        """Run the session.

        Delegates to the remote runner when ``config.option._remote`` is set.
        Otherwise runs the base Session.
        """
        if self.config.option._remote:
            from py.__.test.terminal import remote
            return remote.main(self.config, self._file, self.config._origargs)
        else:
            return super(TerminalSession, self).main(args)

    # ---------------------
    # PROGRESS information
    # ---------------------

    def start(self, colitem):
        """Report that *colitem* starts.

        In --collectonly mode, print ``repr(colitem)`` indented by the
        current collector nesting depth. Otherwise call the first
        ``start_<ClassName>`` method found along the MRO of *colitem*'s class.
        Then store the start time on ``colitem.start``; finish() reads it.
        """
        super(TerminalSession, self).start(colitem)
        if self.config.option.collectonly:
            cols = self._opencollectors
            self.out.line('    ' * len(cols) + repr(colitem))
            cols.append(colitem)
        else:
            cls = getattr(colitem, '__class__', None)
            if cls is None:
                return
            for typ in py.std.inspect.getmro(cls):
                meth = getattr(self, 'start_%s' % typ.__name__, None)
                if meth:
                    meth(colitem)
                    break
            colitem.start = now()

    def start_Module(self, colitem):
        """Announce a module.

        With verbose == 0, write its cwd-relative path (no newline).
        Otherwise print a ``+ testmodule:`` line.
        """
        if self.config.option.verbose == 0:
            # '.xxx.' is a dummy file name in the current directory, so the
            # result is relative to the cwd
            abbrev_fn = getrelpath(py.path.local('.xxx.'), colitem.fspath)
            if abbrev_fn is None:
                # no common base with the cwd: show the full path instead
                # of the string "None"
                abbrev_fn = colitem.fspath
            self.out.write('%s' % (abbrev_fn, ))
        else:
            self.out.line()
            self.out.line("+ testmodule: %s" % colitem.fspath)

    def startiteration(self, colitem, subitems):
        """Write how many items a Module will yield before running them.

        Only applies to Modules with verbose == 0 and not --collectonly.
        Writes ``[N] `` (or ``[?] `` if counting fails) and returns
        ``self.out.line``. Presumably the caller invokes it to end the
        progress line (confirm in the base Session). Returns None otherwise.
        """
        if (isinstance(colitem, py.test.collect.Module)
            and self.config.option.verbose == 0
            and not self.config.option.collectonly):
            try:
                total = 0
                for sub in subitems:
                    total += len(list(colitem.join(sub).tryiter()))
            except (SystemExit, KeyboardInterrupt):
                raise
            except:
                # counting is best-effort; any collection problem shows '?'
                self.out.write('[?] ')
            else:
                self.out.write('[%d] ' % total)
            return self.out.line

    def start_Item(self, colitem):
        """With verbose >= 1, show ``file:line modpath`` for a test item."""
        if self.config.option.verbose >= 1:
            if isinstance(colitem, py.test.Item):
                realpath, lineno = colitem.getpathlineno()
                location = "%s:%d" % (realpath.basename, lineno+1)
                self.out.rewrite("%-20s %s " % (location, colitem.getmodpath()))

    def finish(self, colitem, outcome):
        """Report that *colitem* finished with *outcome*.

        Records ``colitem.elapsedtime``. Optionally drops into pdb on
        failure (--pdb). Writes the short (one char) or long (name + time)
        progress result for test items.
        """
        end = now()
        super(TerminalSession, self).finish(colitem, outcome)
        if self.config.option.collectonly:
            # close the collector pushed in start()
            self._opencollectors.pop()
            return
        colitem.elapsedtime = end - colitem.start
        if self.config.option.usepdb:
            if isinstance(outcome, Item.Failed):
                # same output as the former py2 'print' statement, but
                # valid on both py2 and py3
                py.std.sys.stdout.write("dispatching to ppdb %s\n" % (colitem,))
                self.repr_failure(colitem, outcome)
                import pdb
                self.out.rewrite('\n%s\n' % (outcome.excinfo.exconly(),))
                pdb.post_mortem(outcome.excinfo._excinfo[2])
        if (isinstance(outcome, py.test.Item.Failed) and
            isinstance(colitem, py.test.collect.Module)):
            self.out.line(" FAILED TO LOAD MODULE")
        if isinstance(colitem, py.test.Item):
            if self.config.option.verbose >= 1:
                resultstring = self.repr_progress_long_result(colitem, outcome)
                resultstring += " (%.2f)" % (colitem.elapsedtime,)
                self.out.line(resultstring)
            else:
                c = self.repr_progress_short_result(colitem, outcome)
                self.out.write(c)

    # -------------------
    # HEADER information
    # -------------------
    def header(self, colitems):
        """Print the session banner.

        The banner shows the testing mode, interpreter, and py lib
        location/revision. With --traceconfig or verbose it also shows the
        test targets and the config's initial config modules. Sets
        ``self.starttime``.
        """
        super(TerminalSession, self).header(colitems)
        self.out.sep("=", "test process starts")
        option = self.config.option
        modes = []
        for name in 'looponfailing', 'exitfirst', 'nomagic':
            if getattr(option, name):
                modes.append(name)
        if option._fromremote:
            modes.insert(0, 'child process')
        else:
            modes.insert(0, 'inprocess')
        mode = "/".join(modes)
        self.out.line("testing-mode: %s" % mode)
        self.out.line("executable:   %s  (%s)" %
                          (py.std.sys.executable, repr_pythonversion()))
        rev = py.__package__.getrev()
        self.out.line("using py lib: %s <rev %s>" % (
                       py.path.local(py.__file__).dirpath(), rev))

        if self.config.option.traceconfig or self.config.option.verbose:
            for x in colitems:
                self.out.line("test target:  %s" % (x.fspath,))
            for i, x in py.builtin.enumerate(self.config._initialconfigmodules):
                self.out.line("initial conf %d: %s" % (i, x.__file__))
        self.out.line()
        self.starttime = now()

    # -------------------
    # FOOTER information
    # -------------------

    def footer(self, colitems):
        """Record ``self.endtime``, then print skip reasons, failures and
        the summary line."""
        super(TerminalSession, self).footer(colitems)
        self.endtime = now()
        self.out.line()
        self.skippedreasons()
        self.failures()
        self.summaryline()

    # --------------------
    # progress information
    # --------------------
    # one-character progress marker per outcome type (verbose == 0)
    typemap = {
        Item.Passed: '.',
        Item.Skipped: 's',
        Item.Failed: 'F',
    }
    # word shown per outcome type (verbose >= 1)
    namemap = {
        Item.Passed: 'ok',
        Item.Skipped: 'SKIP',
        Item.Failed: 'FAIL',
    }

    def repr_progress_short_result(self, item, outcome):
        """Return the ``typemap`` character for *outcome*, or '?'."""
        for outcometype, char in self.typemap.items():
            if isinstance(outcome, outcometype):
                return char
        return '?'

    def repr_progress_long_result(self, item, outcome):
        """Return the ``namemap`` word for *outcome*, or 'UNKNOWN'."""
        for outcometype, name in self.namemap.items():
            if isinstance(outcome, outcometype):
                return name
        return 'UNKNOWN'

    # --------------------
    # summary information
    # --------------------
    def summaryline(self):
        """Print the closing '=' line.

        It shows per-outcome counts and the seconds elapsed between header()
        and footer().
        """
        outlist = []
        for typ in Item.Passed, Item.Failed, Item.Skipped:
            pairs = self.getitemoutcomepairs(typ)
            if pairs:
                outlist.append('%d %s' % (len(pairs), typ.__name__.lower()))
        elapsed = self.endtime - self.starttime
        status = ", ".join(outlist)
        self.out.sep('=', 'tests finished: %s in %4.2f seconds' %
                         (status, elapsed))

    def getlastvisible(self, sourcetraceback):
        """Return the innermost traceback entry not hidden by a true
        ``__tracebackhide__``.

        If every entry is hidden, return the innermost entry.
        """
        traceback = sourcetraceback[:]
        while traceback:
            entry = traceback.pop()
            try:
                x = entry.frame.eval("__tracebackhide__")
            except:
                # name not defined in that frame: entry is visible
                x = False
            if not x:
                return entry
        else:
            return sourcetraceback[-1]

    def skippedreasons(self):
        """Print each distinct skip reason with the locations that raised it."""
        texts = {}
        for colitem, outcome in self.getitemoutcomepairs(Item.Skipped):
            raisingtb = self.getlastvisible(outcome.excinfo.traceback)
            fn = raisingtb.frame.code.path
            lineno = raisingtb.lineno
            locations = texts.setdefault(outcome.excinfo.exconly(), {})
            locations[(fn, lineno)] = outcome

        if texts:
            self.out.line()
            self.out.sep('_', 'reasons for skipped tests')
            for text, locations in texts.items():
                for (fn, lineno), outcome in locations.items():
                    self.out.line('Skipped in %s:%d' % (fn, lineno+1))
                self.out.line("reason: %s" % text)
                self.out.line()

    def failures(self):
        """Print a full failure report for every failed item."""
        failed = self.getitemoutcomepairs(Item.Failed)
        if failed:
            self.out.sep('_')
            for colitem, outcome in failed:
                self.repr_failure(colitem, outcome)

    def repr_failure(self, item, outcome):
        """Print the (cut) traceback of a failure, one section per entry.

        *item* may be falsy. In that case the traceback is not cut and no
        entrypoint or captured output is shown. With RuntimeError and
        without --nomagic, output stops at the first detected recursion.
        """
        excinfo = outcome.excinfo
        traceback = excinfo.traceback
        if item:
            self.cut_traceback(traceback, item)
        if not traceback:
            self.out.line("empty traceback from item %r" % (item,))
            return
        last = traceback[-1]
        first = traceback[0]
        recursioncache = {}
        for entry in traceback:
            if entry == first:
                if item:
                    self.repr_failure_info(item, entry)
                    self.out.line()
            else:
                self.out.line("")
            if entry == last:
                indent = self.repr_source(entry, 'E')
                self.repr_failure_explanation(excinfo, indent)
            else:
                self.repr_source(entry, '>')
            self.out.line("")
            self.out.line("[%s:%d]" % (entry.frame.code.path, entry.lineno+1))
            self.repr_locals(entry)

            # trailing info
            if entry == last:
                if item:
                    # repr_out_err walks item.listchain(); needs a real item
                    self.repr_out_err(item)
                self.out.sep("_")
            else:
                self.out.sep("_ ")
                if not self.config.option.nomagic and excinfo.errisinstance(RuntimeError) \
                       and self.isrecursive(entry, recursioncache):
                    self.out.line("Recursion detected (same locals & position)")
                    self.out.sep("!")
                    break

    def isrecursive(self, entry, recursioncache):
        """Return True if a previous entry at the same (path, lineno) had
        locals equal to *entry*'s.

        The equality test (``co_equal``) is evaluated inside the entry's
        frame. Records *entry*'s locals in *recursioncache* when no match
        is found.
        """
        key = entry.frame.code.path, entry.frame.lineno
        seen = recursioncache.setdefault(key, [])
        if seen:
            f = entry.frame
            loc = f.f_locals
            for otherloc in seen:
                if f.is_true(f.eval(co_equal,
                    __recursioncache_locals_1=loc,
                    __recursioncache_locals_2=otherloc)):
                    return True
        seen.append(entry.frame.f_locals)
        return False

    def repr_failure_info(self, item, entry):
        """Print the ``entrypoint:`` separator naming the failing item.

        The module basename is included when the item's own path differs
        from its fspath.
        """
        root = item.fspath
        modpath = item.getmodpath()
        try:
            fn, lineno = item.getpathlineno()
        except TypeError:
            assert isinstance(item.parent, py.test.collect.Generator)
            # a generative test yielded a non-callable
            fn, lineno = item.parent.getpathlineno()
        if root == fn:
            self.out.sep("_", "entrypoint: %s" % (modpath))
        else:
            self.out.sep("_", "entrypoint: %s %s" % (root.basename, modpath))

    def repr_source(self, entry, marker=">"):
        """Print the source of *entry*, marking its last line with *marker*.

        Returns the column (4 + leading whitespace of the last statement)
        at which the exception text should be aligned. Returns 0 if the
        source is unavailable.
        """
        try:
            source = entry.getsource()
        except py.error.ENOENT:
            self.out.line("[failure to get at sourcelines from %r]\n" % entry)
        else:
            source = source.deindent()
            for line in source[:-1]:
                self.out.line("    " + line)
            lastline = source[-1]
            self.out.line(marker + "   " + lastline)
            try:
                s = str(source.getstatement(len(source)-1))
            except KeyboardInterrupt:
                raise
            except:
                # best-effort: fall back to the plain last line
                s = str(source[-1])
            return 4 + (len(s) - len(s.lstrip()))
        return 0

    def cut_traceback(self, traceback, item=None):
        """Trim *traceback* in place for display.

        Unless --fulltrace is set or *item* is None:

        - Drop the frames before the first frame in the item's file.
        - Then advance while the next frame is still in that file, stopping
          at the item's own function (same first line).
        - Finally drop frames whose ``__tracebackhide__`` is true.
        """
        if self.config.option.fulltrace or item is None:
            return
        newtraceback = traceback[:]
        path, lineno = item.getpathlineno()
        for i, entry in py.builtin.enumerate(newtraceback):
            if entry.frame.code.path == path:
                while i < len(newtraceback)-1:
                    entry = newtraceback[i]
                    following = newtraceback[i+1]
                    if following.frame.code.path != path:
                        break
                    if entry.frame.code.firstlineno == lineno:
                        break
                    # advance; without this the loop never terminates
                    i += 1
                del newtraceback[:i]
                break
        if not newtraceback:
            newtraceback = traceback[:]

        # get rid of all frames marked with __tracebackhide__
        visible = []
        for entry in newtraceback:
            try:
                x = entry.frame.eval("__tracebackhide__")
            except:
                x = None
            if not x:
                visible.append(entry)
        traceback[:] = visible

    def repr_failure_explanation(self, excinfo, indent):
        """Print the exception text below the failing source line.

        The first line is prefixed with '>' and aligned to *indent*
        columns; the remaining lines are indented by *indent*.
        """
        indent = " " * indent
        # get the real exception information out
        lines = excinfo.exconly(tryshort=True).split('\n')
        self.out.line('>' + indent[:-1] + lines.pop(0))
        for x in lines:
            self.out.line(indent + x)
        return

        # NOTE: everything below is intentionally unreachable (disabled
        # assertion reinterpretation).
        # XXX reinstate the following with a --magic option?
        # the following line gets user-supplied messages (e.g.
        # for "assert 0, 'custom message'")
        msg = getattr(getattr(excinfo, 'value', ''), 'msg', '')
        info = None
        if not msg:
            special = excinfo.errisinstance((SyntaxError, SystemExit, KeyboardInterrupt))
            if not self.config.option.nomagic and not special:
                try:
                    info = excinfo.traceback[-1].reinterpret() # very detailed info
                except KeyboardInterrupt:
                    raise
                except:
                    if self.config.option.verbose >= 1:
                        self.out.line("[reinterpretation traceback]")
                        py.std.traceback.print_exc(file=py.std.sys.stdout)
                    else:
                        self.out.line("[reinterpretation failed, increase "
                                      "verbosity to see details]")
        # print reinterpreted info if any
        if info:
            lines = info.split('\n')
            self.out.line('>' + indent[:-1] + lines.pop(0))
            for x in lines:
                self.out.line(indent + x)

    def repr_out_err(self, colitem):
        """Print the stdout/stderr recorded for *colitem* and each item in
        its ``listchain()``."""
        for parent in colitem.listchain():
            for name, obj in zip(['out', 'err'], parent.getouterr()):
                if obj:
                    self.out.sep("- ", "%s: recorded std%s" % (parent.name, name))
                    self.out.line(obj)

    def repr_locals(self, entry):
        """With --showlocals, print the frame's locals.

        Long lists, tuples and dicts (repr >= 70 chars) are pretty-printed.
        """
        if self.config.option.showlocals:
            self.out.sep('- ', 'locals')
            for name, value in entry.frame.f_locals.items():
                if name == '__builtins__':
                    self.out.line("__builtins__ = <builtins>")
                elif len(repr(value)) < 70 or not isinstance(value,
                                                (list, tuple, dict)):
                    self.out.line("%-10s = %r" % (name, value))
                else:
                    self.out.line("%-10s =\\" % (name,))
                    py.std.pprint.pprint(value, stream=self.out)

# Precompiled equality check evaluated inside a traceback frame by
# TerminalSession.isrecursive; the two locals dicts being compared are
# supplied under the names below.
co_equal = compile(
    '__recursioncache_locals_1 == __recursioncache_locals_2',
    '?',
    'eval',
)

def repr_pythonversion(v=None):
    """Return a Python version as a short string.

    :param v: version tuple to format; defaults to ``sys.version_info``.
    :returns: ``"major.minor.micro-releaselevel-serial"`` when *v* has
        exactly five fields, otherwise ``str(v)``.
    """
    if v is None:
        import sys
        v = sys.version_info
    try:
        return "%s.%s.%s-%s-%s" % tuple(v)
    except (TypeError, ValueError):
        # %-formatting raises TypeError (not ValueError) when the number of
        # fields does not match the format, so catch it to reach the fallback
        return str(v)
