"""
tree_compare.py -- module for comparing topologies of phylogenies
"""

def set_labels(n, d = None):
    """
    recursive function that stores the leaf labels found at or below
    node n in n['labels'] (a dict used as a set: label -> None), and
    returns a dict mapping the sorted tuple of those labels to the node

    n -- node object; the code relies on n.istip, n.label (for tips),
         n.children() and item lookup/assignment on the node
    d -- mapping to fill in; a new dict is created when omitted, and
         the same dict is shared by all recursive calls

    if several nodes subtend exactly the same set of leaves, the one
    stored last (the ancestor, since children are processed first)
    is the one left in the mapping
    """
    if d is None:
        d = {}
    if n.istip:
        labels = {n.label: None}
    else:
        labels = {}
        for c in n.children():
            set_labels(c, d)
            labels.update(c['labels'])
    n['labels'] = labels
    # sorted() accepts any iterable, so this does not depend on keys()
    # returning a list that can be sorted in place
    d[tuple(sorted(labels))] = n
    return d

def tree_label_dict(n):
    """
    convenience wrapper around set_labels: labels the tree rooted at
    n and returns the mapping that set_labels builds for it
    """
    label_map = set_labels(n)
    return label_map

def __intersection(d1, d2):
    """
    return a dict (used as a set: key -> None) holding the keys that
    are present in both d1 and d2
    """
    # walk the smaller mapping and probe the larger one
    if len(d1) > len(d2):
        smaller, larger = d2, d1
    else:
        smaller, larger = d1, d2
    intersection = {}
    for x in smaller:
        # 'in' replaces dict.has_key(), which newer Pythons no longer have
        if x in larger:
            intersection[x] = None
    return intersection

def __filter_treedict(lookup_dict, to_filter_out):
    """
    return a new dict built from lookup_dict (label tuple -> node) in
    which the labels listed in to_filter_out have been dropped from
    every key

    when several nodes end up under the same filtered key, the node
    with the fewest entries in its 'labels' dict is kept (on a tie the
    one encountered first stays)
    """
    excluded = set(to_filter_out)
    d = {}
    for k, n in lookup_dict.items():
        tk = tuple([x for x in k if x not in excluded])
        # no test of the unfiltered key k against d here: d holds
        # filtered keys, so such a test made the outcome depend on
        # iteration order and could keep a larger node over a smaller one
        if tk not in d or len(d[tk]['labels']) > len(n['labels']):
            d[tk] = n
    return d
            
def __find_matches(nodes, other_dict, other_filtered, other_labels):
    """
    for every non-tip node in nodes, look up its counterpart in the
    other tree; returns (exact, consistent), two dicts mapping
    node -> matching node of the other tree

    other_dict     -- sorted label tuple -> node, for the other tree
    other_filtered -- same, with labels unique to the other tree
                      removed from the keys
    other_labels   -- dict of all leaf labels of the other tree
    """
    exact = {}
    consistent = {}
    for n in nodes:
        if n.istip:
            continue
        n_labels = n['labels']
        # the lookup dicts are keyed by *sorted* label tuples, so the
        # key must be sorted here as well
        exact_match = other_dict.get(tuple(sorted(n_labels)))
        if exact_match is not None:
            exact[n] = exact_match
        # restrict this node's labels to those the other tree also has
        shared = __intersection(n_labels, other_labels)
        if len(shared) > 1:
            consistent_match = other_filtered.get(tuple(sorted(shared)))
            if consistent_match is not None:
                consistent[n] = consistent_match
    return exact, consistent

def tree_compare(tree1, tree2):
    """
    compare the topologies of the trees rooted at tree1 and tree2

    both trees are labelled with set_labels (so every node gets a
    'labels' entry) and a dict with these keys is returned:

    'tree1 labels', 'tree2 labels'  -- lists of all leaf labels
    'common labels'                 -- list of labels found in both
    'tree1 labels not in tree2',
    'tree2 labels not in tree1'     -- lists of unshared labels
    'tree1 exact matches',
    'tree2 exact matches'           -- dict: non-tip node -> node of
                                       the other tree with exactly the
                                       same leaf labels
    'tree1 consistent matches',
    'tree2 consistent matches'      -- dict: non-tip node -> node of
                                       the other tree with the same
                                       leaf labels once unshared labels
                                       are ignored (only when more than
                                       one shared label is involved)
    """
    info = {}  # where to store the tree comparison data

    # lists of nodes in trees
    tree1_nodes = [tree1] + list(tree1.descendants())
    tree2_nodes = [tree2] + list(tree2.descendants())

    # mapping of sorted leaf-label tuples (lookup dicts) to nodes
    tree1_dict = set_labels(tree1)
    tree2_dict = set_labels(tree2)

    # lookup dicts of leaf labels at root of trees
    tree1_labels = tree1['labels']
    tree2_labels = tree2['labels']
    info['tree1 labels'] = list(tree1_labels)
    info['tree2 labels'] = list(tree2_labels)

    # labels in common between trees
    common = __intersection(tree1_labels, tree2_labels)
    info['common labels'] = list(common)

    # labels outside the intersection between the trees
    tree1_only = [x for x in tree1_labels if x not in common]
    tree2_only = [x for x in tree2_labels if x not in common]
    info['tree1 labels not in tree2'] = tree1_only
    info['tree2 labels not in tree1'] = tree2_only

    tree1_filtered = __filter_treedict(tree1_dict, tree1_only)
    tree2_filtered = __filter_treedict(tree2_dict, tree2_only)

    # each tree's nodes are matched against the *other* tree's dicts
    tree1_exact_matches, tree1_consistent_matches = __find_matches(
        tree1_nodes, tree2_dict, tree2_filtered, tree2_labels)
    tree2_exact_matches, tree2_consistent_matches = __find_matches(
        tree2_nodes, tree1_dict, tree1_filtered, tree1_labels)

    info['tree1 exact matches'] = tree1_exact_matches
    info['tree1 consistent matches'] = tree1_consistent_matches
    info['tree2 exact matches'] = tree2_exact_matches
    info['tree2 consistent matches'] = tree2_consistent_matches

    return info

if __name__ == '__main__':
    # demo: display two trees side by side and adorn the canvas
    # branches of the nodes that tree_compare() reports as exact matches
    from Mavric import newick
    import gtk, GTK, GDK, gnome.ui
    import Mavric.gui

    def _read_tree(path):
        # parse the newick file at path; the file is closed even if
        # reading or parsing fails
        f = open(path)
        try:
            return newick.parse(f.read())
        finally:
            f.close()

    def _adorn_matches(treeview, nodes, color):
        # call adorn_child(color) on the canvas branch of each node;
        # a KeyError (e.g. node missing from node2cbranch) is ignored
        for n in nodes:
            try:
                treeview.canvas.node2cbranch[n].adorn_child(color)
            except KeyError:
                pass

    w = gtk.GtkWindow()
    hp = gtk.GtkHPaned()
    w.add(hp)
    w.set_usize(400,400)
    w.connect('delete_event', gtk.mainquit)

    basepath = '/home/rree/consed'
    t1 = _read_tree(basepath+'/analysis/its/its.culled+mollis.recoded.cons')
    t2 = _read_tree(basepath+'/analysis/its/its.culled+mollis.gapmiss.cons')

    tv1 = Mavric.gui.treeview.TreeView(t1)
    tv2 = Mavric.gui.treeview.TreeView(t2)
    hp.add1(tv1)
    hp.add2(tv2)

    cmp_info = tree_compare(t1, t2)

    _adorn_matches(tv2, cmp_info['tree2 exact matches'], 'green')
    _adorn_matches(tv1, cmp_info['tree1 exact matches'], 'red')
    w.show_all()
    gtk.mainloop()

