# Mavric -- a module for manipulating and visualizing phylogenies

# Copyright (C) 2000 Rick Ree
# Email : rree@oeb.harvard.edu
# 	   
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2 
# of the License, or (at your option) any later version.
#   
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details. 
# 
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.

#from toolkit import *
from constants import *
from Mavric import phylo
# Module-level aliases for the phylo helpers used by the layout code below.
nodes_to_tips, length_to_tips = phylo.nodes_to_tips, phylo.length_to_tips

def layout_node(view_id, root, x1,y1, x2,y2, label_width_func,
                scaled, min_unit_hor=15, min_unit_ver=15, direction=EAST):

    """ Position nodes of a horizontally-growing tree in the rectangle
    bounded by x1, y1 and x2, y2, with minimum unscaled branch length
    = min_unit_hor, and minimum vertical distance between leaf nodes =
    min_unit_ver.  Parameter scaled is a boolean flag to specify
    whether to scale the layout so that branch lengths are drawn to
    proportion.  Parameter label_width_func is a callable object that
    should return the width|height of a string argument, and is used
    to determine how much space to allocate for drawing the leaf
    labels (taxon names).

    If root has a back pointer the tree is treated as unrooted: the
    subtree at root is drawn growing east and the subtree at root.back
    growing west, joined by the root edge (with direction == WEST the
    two halves are swapped).  Otherwise the tree is rooted and grows in
    the given direction; label space is reserved on the side where its
    tips are placed (right for EAST, left for WEST).

    The x, y coordinates calculated for each node are placed in the
    node's data dictionary, with parameter view_id as the key.

    The function returns a tuple (unit_hor, unit_ver, scale_factor):
    the 'unit' horizontal and vertical distances computed in the
    layout, and the factor converting branch length to horizontal
    distance (None when scaled is false).

    This should be the only function you need to call from this
    module."""

    assert direction in (EAST, WEST)

    rooted = 1
    if root.back: rooted = 0

    if not rooted and direction == WEST:
        # Swap the halves *before* measuring them, so the label width and
        # depth below describe the subtree that is actually drawn growing
        # east.  (Swapping afterwards sized both halves from one subtree.)
        root = root.back

    # Widest tip label and maximum node depth of the primary subtree.
    max_label = max_leaf_label_width(root, label_width_func)
    max_ntt = max(nodes_to_tips(root))

    if rooted and direction == WEST:
        # A west-growing rooted tree puts its tips on the left border,
        # so its labels need room on the left.
        x1 = x1 + max_label
    else:
        x2 = x2 - max_label
    min_req_hor = max_ntt*min_unit_hor
    if (x2 - min_req_hor) < x1:
        x2 = x1 + min_req_hor

    if rooted:
        return layout_rooted(view_id, root, max_ntt, x1,y1, x2,y2,
                             scaled, min_unit_hor, min_unit_ver, direction)

    # Unrooted: reserve label room on the left for the west half too.
    max_west_label = max_leaf_label_width(root.back, label_width_func)
    max_west_ntt = max(nodes_to_tips(root.back))

    x1 = x1 + max_west_label
    # One extra unit for the root edge joining the two halves.
    min_req_hor = (max_ntt+max_west_ntt+1)*min_unit_hor
    if (x2 - min_req_hor) < x1:
        x2 = x1 + min_req_hor
    return layout_unrooted(view_id, root, max_west_ntt, max_ntt,
                           x1,y1, x2,y2, scaled,
                           min_unit_hor, min_unit_ver)
        

def layout_unrooted(view_id, root, max_west_ntt, max_east_ntt,
                    x1,y1, x2,y2, scaled,
                    min_unit_hor, min_unit_ver):
    """Lay out an unrooted tree split across the edge root -- root.back.

    The subtree at root is drawn growing EAST with its tips on the x2
    border; the subtree at root.back is drawn growing WEST with its tips
    on the x1 border.  The half with fewer leaves is then shifted
    vertically so both ends of the root edge share the same y.  If the
    aligned tree is taller than y2-y1, unit_ver is shrunk proportionally
    (but not below min_unit_ver) and the layout is redone once.

    max_west_ntt / max_east_ntt are the maximum node depths of root.back
    and root; one extra unit_hor is allotted to the root edge.

    Returns (unit_hor, unit_ver, scale_factor); scale_factor is None
    unless scaled is true.
    """
    east_leaves = root.leaves() or [root]
    west_leaves = root.back.leaves() or [root.back]

    east_numlvs = len(east_leaves)
    west_numlvs = len(west_leaves)
    max_numlvs = max(east_numlvs, west_numlvs)
    # Align onto the half with more leaves (east wins ties).
    east_is_larger = east_numlvs >= west_numlvs

    total_ntt = max_east_ntt+max_west_ntt

    # max(..., 1): with one leaf per side (a two-taxon tree) there are no
    # gaps between leaves, and this would otherwise divide by zero.
    unit_ver = (y2-y1)/max(max_numlvs-1, 1)
    if unit_ver < min_unit_ver: unit_ver = min_unit_ver

    unit_hor = (x2-x1)/(total_ntt+1)
    if unit_hor < min_unit_hor: unit_hor = min_unit_hor

    scale_factor = None
    east_px = None; west_px = None
    if scaled:
        max_width = (total_ntt+1)*unit_hor
        max_east_ltt = max(length_to_tips(root))
        max_west_ltt = max(length_to_tips(root.back))
        # First known length of the root edge, else 1.0.  0.0 is a valid
        # length, hence explicit None tests rather than `or`.
        if root.length is not None:
            root_length = root.length
        elif root.back.length is not None:
            root_length = root.back.length
        else:
            root_length = 1.0
        max_length = max_east_ltt + max_west_ltt + root_length
        scale_factor = max_width/max_length
        # x of each end of the root edge, chosen so the deepest tip of each
        # half lands on its border.
        east_px = x2 - max_east_ltt*scale_factor
        west_px = x1 + max_west_ltt*scale_factor

    if not scale_factor:
        west_x2 = x1 + max_west_ntt*unit_hor
        east_x1 = x2 - max_east_ntt*unit_hor
    else:
        west_x2 = west_px
        east_x1 = east_px
    west_box = (x1, y1, west_x2, y2)
    east_box = (east_x1, y1, x2, y2)

    _layout_unrooted_halves(view_id, root, west_box, east_box,
                            unit_hor, unit_ver, scale_factor,
                            west_px, east_px, east_is_larger)

    # Vertical extent of the aligned tree, read from the outermost leaves
    # of each half.
    y_min = min(west_leaves[-1][view_id][1], east_leaves[0][view_id][1])
    y_max = max(west_leaves[0][view_id][1], east_leaves[-1][view_id][1])

    span = y_max - y_min
    if span > (y2 - y1):
        # float(): with integer coordinates, Python 2 `/` truncated this
        # ratio (always < 1 here) to 0, so unit_ver collapsed to
        # min_unit_ver instead of shrinking proportionally.
        unit_ver = unit_ver*(float(y2-y1)/span)
        if unit_ver < min_unit_ver: unit_ver = min_unit_ver

        _layout_unrooted_halves(view_id, root, west_box, east_box,
                                unit_hor, unit_ver, scale_factor,
                                west_px, east_px, east_is_larger)

    return unit_hor, unit_ver, scale_factor


def _layout_unrooted_halves(view_id, root, west_box, east_box,
                            unit_hor, unit_ver, scale_factor,
                            west_px, east_px, east_is_larger):
    """Lay out both halves of an unrooted tree, then align them vertically.

    west_box / east_box are (x1, y1, x2, y2) rectangles for the root.back
    and root subtrees.  If east_is_larger, the root.back half is moved so
    root.back gets root's y; otherwise the root half is moved to match
    root.back.
    """
    wx1, wy1, wx2, wy2 = west_box
    ex1, ey1, ex2, ey2 = east_box

    _, root_back_y = layout_recursive(view_id, root.back, root.back,
                                      wx1, wy1, wx2, wy2,
                                      unit_hor, unit_ver,
                                      WEST,
                                      scale_factor, west_px)
    _, root_y = layout_recursive(view_id, root, root,
                                 ex1, ey1, ex2, ey2,
                                 unit_hor, unit_ver,
                                 EAST,
                                 scale_factor, east_px)

    if east_is_larger:
        move(root.back, view_id, 0, root_y - root_back_y)
    else:
        move(root, view_id, 0, root_back_y - root_y)
    
def layout_rooted(view_id, root, max_ntt, x1,y1, x2,y2, scaled,
                  min_unit_hor, min_unit_ver, direction):
    """Lay out a rooted tree inside the box (x1, y1)-(x2, y2).

    Tips are stacked unit_ver apart and unscaled branches are unit_hor
    long; max_ntt is the maximum node depth from root to a tip.  When
    scaled is true, branch lengths are drawn to proportion.  For EAST the
    root starts at x1 and tips extend rightwards.  For WEST the root
    starts max_ntt*unit_hor right of x1, so the deepest tip lands on x1,
    mirroring the unscaled WEST layout.

    Returns (unit_hor, unit_ver, scale_factor); scale_factor is None
    unless scaled is true.
    """
    numlvs = len(root.leaves())

    # max(..., 1) avoids dividing by zero for a single-leaf tree.
    unit_ver = (y2-y1)/max(numlvs-1, 1)
    if unit_ver < min_unit_ver: unit_ver = min_unit_ver

    unit_hor = (x2-x1)/max_ntt
    if unit_hor < min_unit_hor: unit_hor = min_unit_hor

    scale_factor = None; parent_x = None
    if scaled:
        max_width = max_ntt*unit_hor
        max_length = max(length_to_tips(root))
        scale_factor = max_width/max_length
        if direction == WEST:
            # Branch lengths are subtracted from parent_x when growing west.
            # Starting the root at x1 (as for EAST) pushed the whole tree
            # outside the left edge of the box.
            parent_x = x1 + max_width
        else:
            parent_x = x1

    layout_recursive(view_id, root, root,
                     x1,y1, x2,y2,
                     unit_hor, unit_ver,
                     direction,
                     scale_factor, parent_x)

    return unit_hor, unit_ver, scale_factor

def layout_recursive(view_id, node, root,
                     x1,y1, x2,y2,
                     unit_hor, unit_ver,
                     direction,
                     scale_factor=None, parent_x=None):
    """Recursively assign (x, y) to node and every node below it.

    Each visited node gets node[view_id] = (x, y); the (x, y) of node
    itself is returned.

    Unscaled (scale_factor falsy):
    - Tips go on the x2 border (EAST) or the x1 border (WEST).
    - An internal node sits one unit_hor left (EAST) or right (WEST) of
      its nearest child.

    Scaled:
    - x is parent_x plus (EAST) or minus (WEST) branch length *
      scale_factor.
    - Branch length is the first non-None of node.length,
      node.back.length and 1.0; root itself gets 0.0.

    Vertical placement:
    - A tip's y starts at the y2 border (EAST) or the y1 border (WEST)
      and moves one unit_ver away from it for each tip found later in
      the walk below.
    - An internal node's y is the midpoint of its first and last
      children's y.

    root is the node the layout started from.  It bounds the tip-counting
    walk and gets a zero-length branch.
    """
    if scale_factor:
        # calculate x position of node
        # with scaled branches on
        if node == root:
            # the starting node has no branch of its own
            node_length = 0.0
        else:
            # first non-None of node.length, node.back.length, 1.0;
            # the explicit None test keeps a real 0.0 length (unlike `or`)
            for x in (node.length, node.back.length, 1.0):
                if x != None:
                    node_length = x
                    break
            #node_length = node.length or node.back.length or 1.0

        if direction == EAST:
            x = parent_x+(node_length*scale_factor)
        else:
            x = parent_x-(node_length*scale_factor)
        # this node's x is the parent_x passed down to its children
        parent_x = x
        
    if not node.istip:
        children = node.children()
        child_coords = []
        for child in children:
            cx, cy = layout_recursive(view_id, child, root,
                                      x1,y1, x2,y2,
                                      unit_hor, unit_ver,
                                      direction,
                                      scale_factor, parent_x)
            child_coords.append((cx, cy))


        # y-coord of internal node is midpoint of 1st and last
        # child's y-coords
        # NOTE(review): under Python 2 `/` this truncates when the
        # coordinates are ints
        cy1 = child_coords[0][1]; cy2 = child_coords[-1][1]
        y = min(cy1, cy2) + (abs(cy2-cy1)/2)

        if not scale_factor:
            # x-coord of internal node is min|max of children's x-coords,
            # minus|plus the unit horizontal
            if direction == EAST:
                x = min(map(lambda x: x[0], child_coords))-unit_hor
            else:
                x = max(map(lambda x: x[0], child_coords))+unit_hor

    else: # tip node case
        # when scaled, x was already computed from parent_x above
        if direction == EAST:
            if not scale_factor: x = x2  # right border
            y = y2  # bottom border
        else:
            if not scale_factor: x = x1  # left border
            y = y1  # top border

        # Count the tips that come after this one.  Start at node.back and
        # follow links: nptr.next.back steps into the next branch, and
        # nptr.back climbs back out of a tip.  Stop at the element whose
        # next is root, or at root itself.
        # Each tip passed moves this tip one unit_ver off its border (up for
        # EAST, down for WEST), so the last tip in the walk stays on it.
        # NOTE(review): relies on the next/back node-ring structure from
        # Mavric.phylo -- confirm there.
        nptr = node.back
        while nptr.next != root:
            if nptr == root: break
            if nptr.istip:
                if direction == EAST:
                    y = y-unit_ver
                else:
                    y = y+unit_ver
                nptr = nptr.back
            else: nptr = nptr.next.back

    node[view_id] = (x, y)
    return (x, y)

def max_leaf_label_width(node, label_width_func):
    """Return the largest label_width_func(label) among tips reached from node.

    If node is itself a tip, its own label width is returned.  Otherwise
    the walk starts at node, follows next/back links, and stops when
    reaching the element whose next is node.  Every tip visited must have
    a label (asserted).  Returns 0 if no tip is visited.
    """
    if node.istip:
        return label_width_func(node.label)

    widest = 0
    cursor = node
    while cursor.next != node:
        if not cursor.istip:
            # step into the next branch
            cursor = cursor.next.back
            continue
        assert cursor.label
        widest = max(widest, label_width_func(cursor.label))
        # climb back out of the tip
        cursor = cursor.back
    return widest

def move(node, view_id, x_offset, y_offset):
    """Shift node and all of its descendants by (x_offset, y_offset).

    The coordinates stored under view_id are updated in place, in
    pre-order (a node before its children, children in order).
    """
    pending = [node]
    while pending:
        current = pending.pop()
        old_x, old_y = current[view_id]
        current[view_id] = (old_x + x_offset, old_y + y_offset)
        if not current.istip:
            # Push children reversed so the first child is handled next,
            # matching a recursive pre-order walk.
            pending.extend(reversed(list(current.children())))

    
