##############################################################################
#
# Python MemAid core <Peter.Bienstman@ugent.be>
#
##############################################################################

from pyqt_memaid.ann import *
import random, time, os, string, sys, cPickle, md5, struct



##############################################################################
#
# Global variables.
#
##############################################################################

time_of_start = None
import_time_of_start = None

items = []
import_items = []
revision_queue = []

categories = []
category_by_name = {}



##############################################################################
#
# Configuration parameters.
#
##############################################################################

config = {}
config["path"] = os.path.expanduser("~/.memaid/default.mem")
config["import_dir"] = os.path.expanduser("~/.memaid/")
config["drill_badly_known"] = False
config["threshold"] = 3
config["hide_toolbar"] = False
config["font"] = None
config["swap_QA"] = False
config["run_exec_tags"] = False



##############################################################################
#
# get_config
#
##############################################################################

def get_config(key):
    return config[key]



##############################################################################
#
# set_config
#
##############################################################################

def set_config(key, value):
    global config
    config[key] = value



##############################################################################
#
# load_config
#
##############################################################################

def load_config():
    global config
    config_file = file(os.path.expanduser("~/.memaid/config"), 'r')
    config = cPickle.load(config_file)



##############################################################################
#
# save_config
#
##############################################################################

def save_config():
    config_file = file(os.path.expanduser("~/.memaid/config"), 'w')
    cPickle.dump(config, config_file)



##############################################################################
#
# StartTime
#
##############################################################################

class StartTime:

    def __init__(self, start_time):

        # Reset to 3.30 am

        t = time.gmtime(start_time)
        self.time = time.mktime([t[0],t[1],t[2], 3,30,0, t[6],t[7],t[8]])

    def days_since(self):
        return long( (time.time() - self.time) / 60. / 60. / 24. )

        
    
##############################################################################
#
# Item
#
##############################################################################

class Item:

    ##########################################################################
    #
    # __init__
    #
    ##########################################################################

    def __init__(self):

        self.id        = 0
        self.tm_t_rpt  = 0
        self.stm_t_rpt = 0
        self.l_ivl     = 0
        self.rl_l_ivl  = 0
        self.ivl       = 0
        self.rp        = 0
        self.gr        = 0
        self.cat       = None
        
    ##########################################################################
    #
    # new_id
    #
    #   Convert the first 4 bytes of an MD5 hash to an id value.
    #
    ##########################################################################
    
    def new_id(self):
        digest = md5.new(self.q.encode("utf-8") +
                         self.a.encode("utf-8") + time.ctime()).digest()
        self.id = struct.unpack('L', digest[0:4])[0]
    
    ##########################################################################
    #
    # is_due
    #
    ##########################################################################
    
    def is_due(self):
        return time_of_start.days_since() >= self.tm_t_rpt
        
    ##########################################################################
    #
    # real_interval
    #
    ##########################################################################
    
    def real_interval(self):      
        return time_of_start.days_since() - self.tm_t_rpt + self.ivl
        
    ##########################################################################
    #
    # change_category
    #
    ##########################################################################
    
    def change_category(self, new_cat_name):

        global categories, category_by_name

        # Case 1: a new category was created.

        if new_cat_name not in category_by_name.keys():
            cat = Category(new_cat_name)
            categories.append(cat)
            category_by_name[new_cat_name] = cat

        old_cat = self.cat
        self.cat = category_by_name[new_cat_name]
    
        # Case 2: deleted last item of old_cat.

        if old_cat.in_use() == False:
            del category_by_name[old_cat.name]
            categories.remove(old_cat)



##############################################################################
#
# item_compare
#
##############################################################################

def item_compare(x, y):
    return int(x.tm_t_rpt - y.tm_t_rpt)



##############################################################################
#
# get_items
#
##############################################################################

def get_items():
    return items



##############################################################################
#
# get_item_by_id
#
##############################################################################

def get_item_by_id(id):
    for item in items:
        if item.id == id:
            return item



##############################################################################
#
# number_of_items
#
##############################################################################

def number_of_items():
    return len(items)


    
##############################################################################
#
# reps_for_today
#
##############################################################################

def reps_for_today():

    if len(revision_queue) == 0:
        rebuild_revision_queue()

    reps = 0
    for e in revision_queue:
        if ( e.is_due() and e.cat.scheduled == True ):
                reps += 1
    return reps



##############################################################################
#
# Category
#
##############################################################################

class Category:
    
    ##########################################################################
    #
    # __init__
    #
    ##########################################################################
    
    def __init__(self, name, scheduled=True, badly_known=True):

        self.name = name
        self.scheduled = scheduled
        self.badly_known = badly_known

    ##########################################################################
    #
    # in_use
    #
    ##########################################################################

    def in_use(self):

        used = False

        for e in items:
            if self.name == e.cat.name:
                used = True
                break

        return used



##############################################################################
#
# get_categories
#
##############################################################################

def get_categories():
    global categories
    return categories



##############################################################################
#
# new_database
#
##############################################################################

def new_database(path):

    global config, time_of_start
    
    unload_database()

    time_of_start = StartTime(time.time())
    config["path"] = path

    save_database(path)



##############################################################################
#
# load_database
#
##############################################################################

def load_database(path):

    global config, time_of_start, categories, category_by_name, items

    if list_is_loaded():
        unload_database()

    nn_init()

    try:    
        infile = file(path)

        db = cPickle.load(infile)
        
        time_of_start = db[0]
        categories    = db[1]
        items         = db[2]

        infile.close()
    except:
        return False

    for c in categories:
        if not c.in_use():
            categories.remove(c)
        else:
            category_by_name[c.name] = c

    # Sometimes the category links seem to get corrupt.
    # Uncommenting the following lines rebuilds the connections.
    
    #for i in items:
    #    i.cat = category_by_name[i.cat.name]
    
    config["path"] = path
    
    return True



##############################################################################
#
# save_database
#
##############################################################################

def save_database(path):

    global config

    try:
        outfile = file(path,'w')

        db = [time_of_start, categories, items]
        cPickle.dump(db, outfile)

        outfile.close()
    except:
        return False

    config["path"] = path
    
    return True



##############################################################################
#
# unload_database
#
##############################################################################

def unload_database():

    global items, revision_queue, categories, category_by_name

    if list_is_loaded() == False:
        return
    
    nn_deinit()
        
    status = save_database(config["path"])
    if status == False:
        return False
        
    items = []
    revision_queue = []
        
    categories = []
    category_by_name = {}
    return True



##############################################################################
#
# list_is_loaded
#
##############################################################################

def list_is_loaded():
    return len(items) != 0



##############################################################################
#
# escape
#
#   Escapes literal < (unmatched tag) and new line from string.
#
##############################################################################

def escape(old_string):
    
    hanging = []
    open = 0
    pending = 0

    for i in range(len(old_string)):
        if old_string[i] == '<':
            if open != 0:
                hanging.append(pending)
                pending = i
                continue
            open += 1
            pending = i
        elif old_string[i] == '>':
            if open > 0:
                open -= 1

    if open != 0:
        hanging.append(pending)

    new_string = ""
    for i in range(len(old_string)):
        if old_string[i] == '\n':
            new_string += "<br>"
        elif i in hanging:
            new_string += "&lt;"
        else:
            new_string += old_string[i]

    return new_string



##############################################################################
#
# write_item_XML
#
##############################################################################

def write_item_XML(e, outfile):
    print >> outfile, "<item id=\""+str(e.id) + "\"" \
                         + " tm_t_rpt=\""+str(e.tm_t_rpt) + "\"" \
                         + " stm_t_rpt=\""+str(e.stm_t_rpt) + "\"" \
                         + " l_ivl=\""+str(e.l_ivl) + "\"" \
                         + " rl_l_ivl=\""+str(e.rl_l_ivl) + "\""  \
                         + " ivl=\""+str(e.ivl) + "\"" \
                         + " rp=\""+str(e.rp) + "\"" \
                         + " gr=\""+str(e.gr) + "\">"
    print >> outfile, " <cat><![CDATA["+e.cat.name+"]]></cat>"
    print >> outfile, " <Q><![CDATA["+e.q+"]]></Q>"
    print >> outfile, " <A><![CDATA["+e.a+"]]></A>"
    print >> outfile, "</item>"



##############################################################################
#
# write_category_XML
#
##############################################################################

def write_category_XML(category, outfile):
    print >> outfile, "<category scheduled=\""+str(category.scheduled)\
          +"\" drill_badly_known=\""+str(category.badly_known)+"\">"
    print >> outfile, " <name><![CDATA["+category.name+"]]></name>"
    print >> outfile, "</category>"

    

##############################################################################
#
# export_XML
#
##############################################################################

def export_XML(path, cat_names_to_export):

    outfile = file(path,'w')

    print >> outfile, """<?xml version="1.0"?>"""
    print >> outfile, "<memaid core_version=\"8\" time_of_start=\""\
                      +str(long(time_of_start.time))+"\">"
    
    for cat in categories:
        if cat.name in cat_names_to_export:
            write_category_XML(cat, outfile)

    for e in items:
        if e.cat.name in cat_names_to_export:
            write_item_XML(e, outfile)

    print >> outfile, """</memaid>"""

    outfile.close()



##############################################################################
#
# XML_Importer
#
##############################################################################

from xml.sax import saxutils, make_parser
from xml.sax.handler import feature_namespaces
class XML_Importer(saxutils.DefaultHandler):

    def __init__(self, default_cat=None):
        self.reading, self.text = {}, {}
        
        self.reading["cat"] = False
        self.reading["Q"]   = False
        self.reading["A"]   = False

        self.default_cat = default_cat

    def to_bool(self, string):
        if string == '0':
            return False
        else:
            return True
    
    def startElement(self, name, attrs):
        global import_time_of_start
        
        if name == "memaid":
            import_time_of_start = StartTime(long(attrs.get("time_of_start")))
        elif name == "item":
            self.e = Item()
            
            self.e.id        = long(attrs.get("id"))
            
            self.e.tm_t_rpt  = int(attrs.get("tm_t_rpt"))
            self.e.stm_t_rpt = int(attrs.get("stm_t_rpt"))
            self.e.l_ivl     = int(attrs.get("l_ivl"))
            self.e.rl_l_ivl  = int(attrs.get("rl_l_ivl"))
            self.e.ivl       = int(attrs.get("ivl"))
            self.e.rp        = int(attrs.get("rp"))
            self.e.gr        = int(attrs.get("gr"))
            
        elif name == "category":
            self.scheduled   = self.to_bool(attrs.get("scheduled"))
            self.badly_known = self.to_bool(attrs.get("drill_badly_known"))
        else:
            self.reading[name] = True
            self.text[name] = ""

    def characters(self, ch):
        for name in self.reading.keys():
            if self.reading[name] == True:
                self.text[name] += ch

    def endElement(self, name):

        global import_items, categories, category_by_name
    
        self.reading[name] = False
       
        if name == "A":

            self.e.q = self.text["Q"]
            self.e.a = self.text["A"]

            if "cat" in self.text.keys():
                cat_name = self.text["cat"]
                if not cat_name in category_by_name.keys():
                    new_cat = Category(cat_name)
                    categories.append(new_cat)
                    category_by_name[cat_name] = new_cat
                self.e.cat = category_by_name[cat_name]
            else:
                self.e.cat = self.default_cat

            if self.e.id == 0:
                self.e.new_id()

            import_items.append(self.e)

        elif name == "name":

            if self.text["name"] not in category_by_name.keys():
                cat = Category(self.text["name"],
                               self.scheduled, self.badly_known)
                categories.append(cat)
                category_by_name[self.text["name"]] = cat

                

##############################################################################
#
# import_XML
#
##############################################################################

def import_XML(filename, default_cat_name):

    global import_items, categories, category_by_name

    # If no database is active, create one.

    if not time_of_start:
        new_database(config["path"])

    # Create default category if necessary.

    if default_cat_name not in category_by_name.keys():
        default_cat = Category(default_cat_name)
        categories.append(default_cat)
        category_by_name[default_cat_name] = default_cat
    else:
        default_cat = category_by_name[default_cat_name]

    # Parse XML file.

    parser = make_parser()
    parser.setFeature(feature_namespaces, 0)
    parser.setContentHandler(XML_Importer(default_cat))

    try:
        parser.parse(file(filename))
    except:
        return False

    # Calculate offset with current start date.
    
    start_date_0 = time_of_start.time
    start_date_1 = import_time_of_start.time
    
    offset = long(round((start_date_0 - start_date_1) / 60. / 60. / 24.))
        
    # Adjust timings.

    if offset <= 0:
        for e in import_items:
            e.tm_t_rpt  += offset
            e.stm_t_rpt += offset
    else:
        for e in items:
            e.tm_t_rpt  -= offset
            e.stm_t_rpt -= offset

    if start_date_1 < start_date_0:
        time_of_start.time = start_date_1

    # Add new items.
    
    for e in import_items:
        items.append(e)
    
    items.sort(item_compare)

    # Clean up.

    if default_cat.in_use() == False:
        del category_by_name[default_cat.name]
        categories.remove(default_cat)
        
    import_items = []

    return True



##############################################################################
#
# add_new_item
#
##############################################################################

def add_new_item(grade, question, answer, cat_name):

    global items, categories, category_by_name

    if cat_name not in category_by_name.keys():
        cat = Category(cat_name)
        categories.append(cat)
        category_by_name[cat_name] = cat
    else:
        cat = category_by_name[cat_name]

    interval = ma_new_interval(0,0,0,grade)
    
    e = Item()

    e.tm_t_rpt  = time_of_start.days_since() + interval
    e.stm_t_rpt = time_of_start.days_since() + interval
    e.ivl       = interval
    e.gr        = grade
    e.q         = question
    e.a         = answer
    e.cat       = cat

    e.new_id()

    items.append(e)
    items.sort(item_compare)



##############################################################################
#
# delete_item
#
##############################################################################

def delete_item(e):

    old_cat = e.cat
    
    items.remove(e)
    rebuild_revision_queue()
    
    if old_cat.in_use() == False:
        del category_by_name[old_cat.name]
        categories.remove(old_cat)



##############################################################################
#
# rebuild_revision_queue
#
##############################################################################

filtered_questions = {}

def rebuild_revision_queue():
            
    global revision_queue, filtered_questions
    revision_queue = []

    if len(items) == 0:
        return

    if not items[0].is_due() and config["drill_badly_known"] == True:
        for e in items:
            if e.gr <= config["threshold"] and e.cat.badly_known == True:
                revision_queue.append(e)
    else:
        for e in items:
            if ( e.is_due() and e.cat.scheduled == True ):
                revision_queue.append(e)
                
    # Filter out clearly inverse (and identical) problems from the queue,
    # leaving the one that is scheduled earlier since asking them in the
    # same session would skew the difficulty of the one that is asked later and
    # in doing so, schedule it badly and train the ANN incorrectly.

    i = 0
    rq = revision_queue
    while i <= len(rq):
        for j in range(len(rq)-1, i, -1):
            if (rq[i].q == rq[j].q and rq[i].a == rq[j].a) or \
               (rq[i].q == rq[j].a and rq[i].a == rq[j].q):
                removed_id = rq[j].id
                if rq[j].tm_t_rpt < rq[i].tm_t_rpt:
                    removed_id = rq[i].id
                    rq[i] = rq[j]
                if not filtered_questions.has_key(removed_id):
                    filtered_questions[removed_id] = time.time()
                del rq[j]
        i += 1

    # Finally filter out questions that have been previously filtered
    # less than 3 hours (arbitrary!) before as this function is
    # called more than once during each execution.
   
    for i in range(len(rq)-1, -1, -1):
        if filtered_questions.has_key(rq[i].id):
            if time.time() > filtered_questions[rq[i].id] + 60*60*3:
                del filtered_questions[rq[i].id]
            else:
                del rq[i]



##############################################################################
#
# in_revision_queue
#
##############################################################################

def in_revision_queue(item):
    return item in revision_queue



##############################################################################
#
# get_new_question
#
##############################################################################

def get_new_question():
            
    # Populate list if it is empty.
        
    if len(revision_queue) == 0:
        rebuild_revision_queue()
        if len(revision_queue) == 0:
            return None

    # Pick a random question and remove it from the queue.

    item = random.choice(revision_queue)
    revision_queue.remove(item)
    return item



##############################################################################
#
# process_answer
#
##############################################################################

def process_answer(item, grade):

    if item.is_due() == True \
        or (config["drill_badly_known"] == True and \
            grade > config["threshold"]):

	feedback_to_ann(item.l_ivl, item.rl_l_ivl, item.rp, item.gr,
                        item.ivl, item.real_interval(), grade)

        new_ivl = ma_new_interval(item.ivl, item.real_interval(),
                                  item.rp, grade)
	rl_ivl = item.real_interval()
	n_tm = time_of_start.days_since() + new_ivl
        
        item.rp += 1
        item.stm_t_rpt = n_tm
	item.tm_t_rpt = n_tm
	item.l_ivl = item.ivl
	item.rl_l_ivl = rl_ivl
	item.ivl = new_ivl
	item.gr = grade
        
        items.sort(item_compare)



##############################################################################
#
# Train neural net for a number of seconds.
#
##############################################################################

def train_ann_for_x_secs(x):
    ma_train_ann_for_x_secs(x)


        


    
