#!/usr/bin/env python
# encoding: utf-8
"""
symmetricNN.py

This is the implementation of the weighted symmetric nearest neighbour algorithms described in "Missing Value Imputation for Epistatic MAPs"

Those only interested in understanding the algorithm should look at the performWNN function.

Further information and updates to this implementation will be available at : http://www.bioinformatics.org/emapimputation

For any queries please contact colm.ryan@ucd.ie

Version : 1.1 

$Rev:: 6             $:  Revision of last commit
$Author:: colmryan   $:  Author of last commit
$Date:: 2010-09-13 1#$:  Date of last commit

Copyright 2009 Colm Ryan 

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions 
and limitations under the License. 

"""

import sys
import getopt
import csv
import math
import random

# Number of neighbours used for imputation when -k/--neighbours is not given
DEFAULT_K = 50
# String marking a missing value in the input file unless -m/--missing overrides it
MISSING = ""
# Field separator of the input file unless -s/--separator overrides it
SEPARATOR = "\t"
# NOTE(review): not referenced by the code visible in this file; main() sets
# its own weighted = True default instead of reading this constant
WEIGHTED = True

def get_ordered_tuple(item_a, item_b) :
    """
    Takes two items as input & returns them as a tuple with the smaller item first.
    
    Useful for interactions, so that (A,B) & (B,A) map to the same key and
    are not treated separately.
    
    item_a, item_b - two mutually comparable items (gene names in this file)
    
    Items that compare equal are returned in the order they were given.
    
    """
    # A direct comparison is used instead of the cmp() builtin, which no
    # longer exists in Python 3; the resulting order is the same.
    if item_a <= item_b :
        return (item_a, item_b)
    return (item_b, item_a)


def weight(correlation):
    """
    Turns the Pearson correlation between a gene and one of its neighbours into
    the weight that neighbour receives during imputation.
    
    This is the weighting scheme discussed in "Missing Value Imputation for
    Epistatic MAPs" (http://mlg.ucd.ie/emapimputation). For a correlation r:
    
        w = (r^2 / (1 - r^2 + epsilon))^2
    
    The small epsilon keeps the denominator non-zero when r is exactly 1 or -1.
    
    """
    epsilon = 0.000001
    r_squared = pow(correlation, 2)
    ratio = r_squared / (1 - r_squared + epsilon)
    return pow(ratio, 2)


def pearson_correlation(gene_profiles,gene1, gene2) :
    """
    Returns the Pearson correlation between two genes
    
    gene_profiles maps each gene name to its interaction profile. 
    The interaction profiles are dictionaries, where the keys are gene names and the values are interaction scores.
    gene1 and gene2 are gene names
    
    Only partner genes that appear in both profiles contribute to the result.
    Returns 0 when the two profiles share no partner genes, or when either
    profile has no variance over the shared partners (the correlation is
    undefined in that case).
    
    Raises KeyError if gene1 or gene2 has no entry in gene_profiles.
    
    """
    profile1 = gene_profiles[gene1]
    profile2 = gene_profiles[gene2]
    
    sum1 = 0
    sum2 = 0
    sum_squares_1 = 0
    sum_squares_2 = 0
    sum_of_products = 0
    count = 0
    
    for gene in profile1 :
        if gene in profile2 :
            score1 = profile1[gene]
            score2 = profile2[gene]
            sum1 = sum1 + score1
            sum2 = sum2 + score2
            sum_squares_1 = sum_squares_1 + pow(score1,2)
            sum_squares_2 = sum_squares_2 + pow(score2,2)
            sum_of_products = sum_of_products + score1*score2
            count = count + 1
    
    #return 0 if there are no genes in common
    if count == 0 :
        return 0
    
    # Divide by a float so integer scores are not truncated under Python 2
    count = float(count)
    
    top = sum_of_products - (sum1*sum2/count)
    spread1 = sum_squares_1 - pow(sum1,2)/count
    spread2 = sum_squares_2 - pow(sum2,2)/count
    
    # Rounding can push the spread of a constant profile slightly below zero.
    # Checking each spread separately stops math.sqrt from being handed a
    # negative product, and stops two negative spreads from cancelling out.
    if spread1 <= 0 or spread2 <= 0 :
        return 0
    
    bottom = math.sqrt(spread1*spread2)
    
    # The product of two tiny spreads can still underflow to zero
    if bottom == 0 :
        return 0
    
    return top / bottom

def read_square_dataset(file_location, missing_string, separator) :
    """
    Reads in a square data matrix
    
    file_location - path of the delimited text file to read
    missing_string - cell content that marks a missing interaction
    separator - single character separating the cells of a row
    
    The first row holds the gene names (its first cell is ignored) and the
    first cell of each following row holds that row's gene name. Only the
    cells above the diagonal are read; the diagonal and the lower triangle
    are ignored.
    
    Returns a list of three items: a map of present interactions (ordered
    gene pairs) to scores, the interaction profile of each gene that has at
    least one score, and a list of the names of the genes in the header row.
    An empty file gives empty results.
    
    Raises ValueError if a present cell is not a number, and IndexError if
    a row is shorter than the header or rows are missing.
    
    """
    interactions = {}
    gene_profiles = {}
    
    # The delimiter is handed straight to the reader, so no global csv dialect
    # is registered, and the file is closed as soon as it has been parsed.
    with open(file_location, 'r') as handle :
        lines = list(csv.reader(handle, delimiter=separator))
    
    if not lines :
        return [interactions, gene_profiles, []]
    
    header = lines[0]
    gene_count = len(header)
    
    for x in range(1,gene_count) :
        row = lines[x]
        gene_x = row[0]
        
        for y in range(x+1,gene_count) :
            interaction = row[y]
            if interaction == missing_string :
                continue
            
            gene_y = header[y]
            score = float(interaction)
            interactions[get_ordered_tuple(gene_x,gene_y)] = score
            
            # Profiles are symmetric: each gene records its score with the other
            gene_profiles.setdefault(gene_x, {})[gene_y] = score
            gene_profiles.setdefault(gene_y, {})[gene_x] = score
    
    # Drop the corner cell so that only gene names remain
    return [interactions, gene_profiles, header[1:]]


def output_results(filename, separator, gene_list, scores,imputed) :
    """
    Outputs a square data matrix
    
    filename - path of the file to write (overwritten if it exists)
    separator - string written between the cells of a row
    gene_list - gene names, used for both the header row and the row labels
    scores - map of ordered gene pairs to measured scores
    imputed - map of ordered gene pairs to imputed scores
    
    A measured score takes precedence over an imputed one. Pairs found in
    neither map are written as 0.0.
    
    Returns 0.
    
    """
    # The with block closes the file even if a write fails part way through
    with open(filename,'w') as f :
        # Header row: an empty corner cell followed by the gene names
        f.write("".join("%s%s" % (separator,gene) for gene in gene_list))
        f.write('\n')
        
        for gene_i in gene_list :
            f.write("%s" % gene_i)
            for gene_j in gene_list :
                f.write(separator)
                interaction = get_ordered_tuple(gene_i,gene_j)
                if interaction in scores :
                    score = scores[interaction]
                elif interaction in imputed :
                    score = imputed[interaction]
                else : # Self interactions (gene_i == gene_j) and any other pair without a score
                    score = 0.0
                f.write("%s" % score)
            f.write('\n')
    
    return 0

def output_ranked_results(filename, separator, imputed, weights) :
    """
    Outputs strong interactions, ranked by the weights of their neighbours
    
    filename - path of the file to write (overwritten if it exists)
    separator - string written between the columns; earlier versions ignored
                this argument and always wrote tabs, so the output is
                unchanged for the default tab separator
    imputed - map of gene pairs to imputed scores
    weights - map of gene pairs to the total neighbour weight behind each score;
              every key must also be present in imputed
    
    A header row is written first, then one row for each pair whose imputed
    score is <= -2.5 or > 2, from the highest weight to the lowest (ties are
    broken by the gene pair itself, also in descending order).
    
    Returns 0.

    """
    # Sort on (weight, pair). Plain indexing replaces the tuple-unpacking lambda
    # and items() replaces iteritems(), so this runs on Python 2 and 3.
    ranked = sorted(weights.items(), key=lambda item: (item[1], item[0]), reverse=True)
    
    # The with block closes the file even if a write fails part way through
    with open(filename,'w') as f :
        f.write(separator.join(("GeneA","GeneB","Score","Weight")) + "\n")
        for pair, pair_weight in ranked :
            score = imputed[pair]
            if score <= -2.5 or score > 2 :
                fields = (pair[0], pair[1], score, pair_weight)
                f.write(separator.join("%s" % field for field in fields) + "\n")
    return 0
    
def _add_neighbour_scores(score, neighbour_count, neighbour_list, partner, interaction_scores, k, weighted) :
    """
    Adds the contribution of one gene's neighbours to the running totals for a missing interaction.
    
    neighbour_list holds (neighbour, similarity) pairs, most similar first. 
    Neighbours without a measured score against partner are skipped, and at
    most k contributing neighbours are used (at least one if any exists).
    
    Returns the updated (score, neighbour_count) totals.
    """
    used = 0
    for neighbour, similarity in neighbour_list :
        pair = get_ordered_tuple(neighbour, partner)
        if pair not in interaction_scores :
            continue
        
        if weighted :
            # Work out the weight once and use it for both totals
            contribution = weight(similarity)
            neighbour_count = neighbour_count + contribution
            score = score + (interaction_scores[pair]*contribution)
        else :
            neighbour_count = neighbour_count + 1
            score = score + interaction_scores[pair]
        
        used = used + 1
        if used >= k :
            break
    
    return score, neighbour_count


def performWNN(data, k, weighted):
    """
    Performs the symmetric nearest neighbour imputation described in 
    "Missing Value Imputation for Epistatic MAPs"
    
    data - the [interaction scores, interaction profiles, gene list] list
           returned by read_square_dataset
    k - the number of neighbours used for the imputation
    weighted - a boolean, if false then standard (unweighted) nearest neighbours is used
    
    Genes without an interaction profile are reported on standard output and
    removed from the gene list in place, so the caller's list shrinks too.
    
    For a missing interaction (i,j) the estimate is the (weighted) average of
    the measured scores between the k nearest neighbours of i and j, and
    between the k nearest neighbours of j and i.
    
    Returns a tuple of two dictionaries keyed by ordered gene pair: the imputed
    scores, and the total weight behind each imputed score. Missing pairs for
    which no neighbour contributes any weight cannot be estimated and appear
    in neither dictionary.
    """
    interaction_scores = data[0]
    interaction_profiles = data[1]
    gene_list = data[2]
    
    #Remove genes with no interaction scores
    # The survivors are collected first and the list is then replaced in place.
    # Removing entries while looping over the same list skips the gene after
    # each removal, which left unprofiled genes behind.
    profiled = []
    for gene in gene_list :
        if gene in interaction_profiles :
            profiled.append(gene)
        else :
            print("%s has no profile" % gene)
    gene_list[:] = profiled

    # Identify the missing values
    missing_values = set()
    for position, gene_i in enumerate(gene_list) :
        # Pairs are unordered, so each gene is only paired with those after it
        for gene_j in gene_list[position+1:] :
            if gene_i != gene_j :
                ij = get_ordered_tuple(gene_i,gene_j)
                if ij not in interaction_scores :
                    missing_values.add(ij)

    # Cache similarities, to prevent recalculation(AB is same as BA)
    similarities = {}

    # Find ordered list of neighbours for each gene
    neighbours = {}
    for gene_i in gene_list :
        ineighbours = {}
        for gene_j in gene_list :
            if gene_i == gene_j :
                continue
            ij = get_ordered_tuple(gene_i,gene_j)
            if ij not in similarities :
                similarities[ij] = pearson_correlation(interaction_profiles,gene_i,gene_j)
            ineighbours[gene_j] = similarities[ij]
        
        # Most similar first; ties are broken by gene name, also descending
        neighbours[gene_i] = sorted(ineighbours.items(), key=lambda item: (item[1], item[0]), reverse=True)
    
    # Start the imputation.....
    imputed = {}
    weights = {}
    for ij in missing_values :
        i = ij[0]
        j = ij[1]
        
        # Neighbours of i paired with j, then neighbours of j paired with i
        score, neighbour_count = _add_neighbour_scores(
            0.0, 0.0, neighbours.get(i, []), j, interaction_scores, k, weighted)
        score, neighbour_count = _add_neighbour_scores(
            score, neighbour_count, neighbours.get(j, []), i, interaction_scores, k, weighted)
        
        # Without any contributing weight there is nothing to average, and the
        # division below would fail, so the pair is left unimputed
        if neighbour_count == 0 :
            continue
        
        weights[ij] = neighbour_count
        imputed[ij] = score / neighbour_count
    
    return imputed,weights

#Below is the executable part of the program
    
# Usage text shown for -h/--help and when no input file is given; main() passes
# it to Usage, whose message is written to standard error.
help_message = '''
This is an implementation of the weighted symmetric nearest neighbour algorithm,
as described in "Missing Value Imputation for Epistatic MAPs"

Further information, and updates to this implementation will be available at : http://mlg.ucd.ie/emapimputation

Required parameters are as follows:
-i --input (filename): the E-MAP file to use as input

Optional parameters as follows:
-o --output (filename): the file to output the resulting complete matrix to
-r --ranked (filename): the file to output a ranked list of strong(score < -2.5 or score > 2.0) interactions to.
-k --neighbours (number): the number of neighbours to use for imputation, defaults to 50
-u --unweighted: run nearest neighbours without any weighting
-s --separator (character): the separator used in the input file, defaults to tab
-m --missing (string): the string to indicate a missing value, defaults to ""
-h --help: displays this message  

'''

class Usage(Exception):
    """
    Signals a problem with the command line arguments.
    
    msg holds the text to show to the user.
    """
    def __init__(self, msg):
        Exception.__init__(self, msg)
        self.msg = msg

def main(argv=None):
    """
    Command line entry point: reads the E-MAP, imputes the missing values and
    writes whichever outputs were requested.
    
    argv - the full argument list including the program name; defaults to sys.argv
    
    Returns 0 on success. Returns 2, after writing a message to standard error,
    when the arguments are invalid, no input file is given, or help is requested.
    """
    k = DEFAULT_K
    weighted = True
    separator = SEPARATOR
    missing = MISSING
    ranked_output = False
    output = False
    input_file = None
    if argv is None:
        argv = sys.argv
    try:
        try:
            opts, args = getopt.getopt(argv[1:], "hi:o:k:us:m:r:v", ["help","input=","output=","neighbours=","unweighted","separator=","missing=","ranked="])
        except getopt.error as msg:
            raise Usage(msg)
    
        # option processing
        # "-v" is still accepted by getopt for backward compatibility, but it
        # never had any effect and so needs no branch here.
        for option, value in opts:
            if option in ("-h", "--help"):
                raise Usage(help_message)
            elif option in ("-i", "--input"):
                input_file = value
            elif option in ("-o", "--output"):
                output = value
            elif option in ("-k", "--neighbours"):
                try:
                    k = int(value)
                except ValueError:
                    # Report a bad number as a usage error, not a traceback
                    raise Usage("the number of neighbours must be an integer, not '%s'" % value)
            elif option in ("-u", "--unweighted"):
                weighted = False
            elif option in ("-s", "--separator"):
                separator = value
            elif option in ("-m", "--missing"):
                missing = value
            elif option in ("-r", "--ranked"):
                ranked_output = value
        
        if input_file is None:
            print("Input filename must be specified\n")
            raise Usage(help_message)
        
        #The runnable part of the program...
        data = read_square_dataset(input_file,missing,separator)
        imputed,weights = performWNN(data, k,weighted)
        if output :
            # data[2] is the gene list, already stripped of unprofiled genes by performWNN
            output_results(output,separator,data[2],data[0],imputed)
        if ranked_output :
            output_ranked_results(ranked_output,separator,imputed,weights)
            
    except Usage as err:
        sys.stderr.write(sys.argv[0].split("/")[-1] + ": " + str(err.msg) + "\n")
        sys.stderr.write("\t for help use --help\n")
        return 2
    
    return 0

# When run as a script, exit with the code returned by main()
# (0 on success, 2 after a usage error).
if __name__ == "__main__":
    sys.exit(main())
