#!/usr/bin/env python
# encoding: utf-8
"""
symmetricNN.py

This is an implementation of the local least squares (LLS) missing value imputation algorithm described in "Missing Value Imputation for Epistatic MAPs".

Those only interested in understanding the algorithm should look at the performLLS function.

Further information and updates to this implementation will be available at : http://www.bioinformatics.org/emapimputation

For any queries please contact colm.ryan@ucd.ie

Version : 1.1 

$Rev:: 6             $:  Revision of last commit
$Author:: colmryan   $:  Author of last commit
$Date:: 2010-09-13 1#$:  Date of last commit

Copyright 2009 Colm Ryan 

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions 
and limitations under the License. 

"""

import sys
import getopt
import csv
import math
import numpy as np
import numpy.ma as ma

# Defaults for the command-line options parsed in main().
SEPARATOR = "\t"    # -s/--separator: field delimiter for the input and output files
MISSING = ""        # -m/--missing: cell text that marks an absent score
DEFAULT_K = 20      # -k/--neighbours: number of nearest neighbours used for imputation

def get_ordered_tuple(item_a, item_b) :
    """
    Return the two items as a tuple in ascending order.

    Useful for interactions so that (A, B) and (B, A) map to the same key
    instead of being treated separately.

    Compares with ``<=`` directly instead of the ``cmp`` builtin, which does
    not exist in Python 3; for the string gene names used in this module both
    give the same ordering.
    """
    if item_a <= item_b :
        return (item_a, item_b)
    return (item_b, item_a)

def read_square_dataset(file_location, missing_string, separator) :
    """
    Read a square, labelled interaction matrix from a delimited text file.

    The first row holds the column gene names (its first cell is the corner
    cell and is discarded); every later row starts with its gene name followed
    by the scores. Only cells above the diagonal (column index > row index)
    are read, so values on or below the diagonal are ignored. This assumes
    the rows appear in the same order as the header columns.

    Parameters:
        file_location -- path of the file to read
        missing_string -- cell text that marks an absent score
        separator -- field delimiter used in the file

    Returns a list [interactions, gene_profiles, gene_list] where
        interactions -- maps an ordered (gene, gene) tuple to its float score
        gene_profiles -- maps each gene to a {partner gene: score} dict; only
                         genes with at least one present score appear
        gene_list -- gene names from the header row, corner cell removed
    """
    interactions = {}
    gene_profiles = {}

    csv.register_dialect('custom', delimiter=separator)
    # Read the whole file inside `with` so the handle is always closed,
    # even if parsing fails part way through.
    with open(file_location, 'r') as handle :
        lines = list(csv.reader(handle, 'custom'))

    gene_list = lines[0]
    gene_count = len(gene_list)

    for x in range(1, gene_count) :
        row = lines[x]
        gene_x = row[0]
        # Start at x + 1: skip the diagonal and the mirrored lower triangle.
        for y in range(x + 1, gene_count) :
            interaction = row[y]
            if interaction == missing_string :
                continue
            gene_y = gene_list[y]
            score = float(interaction)
            interactions[get_ordered_tuple(gene_x, gene_y)] = score
            gene_profiles.setdefault(gene_x, {})[gene_y] = score
            gene_profiles.setdefault(gene_y, {})[gene_x] = score

    return [interactions, gene_profiles, gene_list[1:]]


def LLSWrapper(data, k):
    """
    Impute every absent pairwise score with local least squares (LLS).

    Builds a dense square matrix from the parsed dataset, marking absent pairs
    with the sentinel value 999.0. It passes that matrix to performLLS and
    collects the filled-in values for the absent pairs.

    Parameters:
        data -- the [interactions, gene_profiles, gene_list] list returned by
                read_square_dataset (gene_profiles is not used here)
        k -- number of nearest neighbours performLLS uses for each gene

    Returns a dict mapping every ordered (gene, gene) tuple that is absent
    from interactions to its imputed score. This includes self pairs
    (gene, gene) when they are absent.

    NOTE(review): a genuine score equal to 999.0 cannot be told apart from the
    sentinel and would be treated as missing.
    """
    interactions = data[0]
    gene_list = data[2]
    sentinel = 999.0

    # Gene name -> row/column index in the dense matrix.
    gene_map = dict((gene, index) for index, gene in enumerate(gene_list))

    matrix = []
    missing_values = set()
    for geneI in gene_list :
        row = []
        for geneJ in gene_list :
            interaction = get_ordered_tuple(geneI, geneJ)
            if interaction in interactions :
                row.append(interactions[interaction])
            else :
                row.append(sentinel)
                missing_values.add(interaction)
        matrix.append(row)

    filled = performLLS(np.asarray(matrix), sentinel, k)

    imputed = {}
    for pair in missing_values :
        imputed[pair] = filled[gene_map[pair[0]]][gene_map[pair[1]]]
    return imputed

def performLLS(input_matrix, missing_string, k) :
    """
    Fill the missing entries of a square matrix by local least squares (LLS).

    Parameters:
        input_matrix -- square 2-D numpy array; it is not modified
        missing_string -- despite its name, the sentinel *value* (e.g. 999.0)
                          compared with ``==`` against the array to find
                          missing entries
        k -- maximum number of neighbouring rows used to estimate each row

    Returns a new numpy.ndarray (a copy, not a masked array) equal to
    input_matrix, except that every missing entry is replaced by an estimate
    and the diagonal is set to 0.

    Outline:
      1. Each row's neighbours are the other rows, ordered by decreasing
         absolute correlation (numpy.ma.corrcoef over the masked rows;
         correlations it cannot compute count as 0).
      2. Each missing entry (x, y) is first set to the mean of row x's
         average and row y's average. A NaN average counts as 0.
      3. For each row with both missing and present entries, its present
         values are fitted (via the pseudo-inverse) as a linear combination
         of the same columns of its k nearest neighbours in the pre-filled
         matrix. The fitted coefficients then predict its missing columns.
      4. Entry (x, y) becomes the mean of row x's estimate for y and row y's
         estimate for x when both rows produced estimates. Otherwise the
         step-2 value is kept.
    """
    # Mask the sentinel so the statistics below ignore missing entries.
    masked_input = ma.masked_equal(input_matrix,missing_string)
    gene_count = len(input_matrix)
    missing_values = np.where(input_matrix==missing_string)

    # Row-by-row similarity; entries numpy.ma leaves masked become 0.
    similarity = ma.corrcoef(masked_input)
    similarity = ma.filled(similarity,fill_value=0.0)

    # Neighbour list for each row, most similar (by |correlation|) first.
    neighbours = {}
    for i in range(gene_count) :
        neighbours[i] = np.argsort(np.absolute(similarity[i]))[::-1]
        # A row is never its own neighbour.
        neighbours[i] = np.delete(neighbours[i],np.where(neighbours[i]==i))

    # Average of the present values in each row.
    gene_averages = np.zeros(gene_count)
    for i in range(len(masked_input)) :
        mean = ma.mean(masked_input[i])
        if mean != mean : # NaN is the only value not equal to itself
            mean = 0
        gene_averages[i] = mean

    # Initial estimate of each missing value: mean of the row and column averages.
    filled = input_matrix.copy()
    for i in range(len(missing_values[0])) :
        x = missing_values[0][i]
        y = missing_values[1][i]
        filled[x][y] = (gene_averages[x] + gene_averages[y]) / 2

    estimates = {}
    # A set rather than a list keeps the membership tests in the final loop
    # O(1) instead of O(rows) per missing entry.
    unused_rows = set()
    for i in range(gene_count) :
        missing = np.where(input_matrix[i]==missing_string)[0]
        non_missing = np.where(input_matrix[i]!=missing_string)[0]
        if (len(missing) > 0) and (len(non_missing) > 0):
            estimates[i] = {}
            # The k most similar rows, with their own gaps pre-filled (step 2).
            ineighbours = np.take(filled,neighbours[i][0:k],axis=0)
            A = np.take(ineighbours,non_missing,axis=1).T
            B = np.take(ineighbours,missing, axis=1).T
            w = np.take(input_matrix[i],non_missing)
            # Least-squares coefficients expressing row i's present values as
            # a combination of its neighbours' values in the same columns.
            X = np.dot(np.linalg.pinv(A),w)
            values = np.dot(B,X)
            for j in range(len(missing)) :
                estimates[i][missing[j]] = values[j]
        else :
            # Nothing to impute, or nothing to fit against.
            unused_rows.add(i)

    # Symmetrise: average row x's estimate for y with row y's estimate for x.
    # NOTE(review): estimates[y][x] exists only when (y, x) is missing too, so
    # this relies on a symmetric missing pattern (as built by LLSWrapper).
    for i in range(len(missing_values[0])) :
        x = missing_values[0][i]
        y = missing_values[1][i]
        if (x not in unused_rows) and (y not in unused_rows) :
            filled[x][y] = (estimates[x][y] + estimates[y][x]) / 2

    # Self interactions are set to 0
    for i in range(len(filled)) :
        filled[i][i] = 0

    return filled

def output_results(filename, separator, gene_list, scores,imputed) :
    """
    Write a square, labelled score matrix to a delimited text file.

    The header row is an empty corner cell followed by the gene names. Each
    following row is a gene name followed by that gene's scores against every
    gene in gene_list order. A pair's score is taken from scores if present,
    otherwise from imputed, otherwise 0.0.

    Parameters:
        filename -- path of the file to create or overwrite
        separator -- field delimiter
        gene_list -- gene names, in row/column order
        scores -- dict of ordered (gene, gene) tuple -> observed score
        imputed -- dict of ordered (gene, gene) tuple -> imputed score

    Returns 0.
    """
    # `with` guarantees the file is closed even if a write fails.
    with open(filename, 'w') as f :
        f.write("".join("%s%s" % (separator, gene) for gene in gene_list))
        f.write('\n')

        for gene_i in gene_list :
            cells = ["%s" % gene_i]
            for gene_j in gene_list :
                interaction = get_ordered_tuple(gene_i, gene_j)
                if interaction in scores :
                    score = scores[interaction]
                elif interaction in imputed :
                    score = imputed[interaction]
                else : # e.g. self interactions (gene_i == gene_j) not in imputed
                    score = 0.0
                cells.append("%s" % score)
            f.write(separator.join(cells))
            f.write('\n')

    return 0


# ---------------------------------------------------------------------------
# Below is the executable (command-line) part of the program.
# ---------------------------------------------------------------------------

# Usage text that main() reports for -h/--help and when the input/output
# filenames are not both given.
# NOTE(review): main() also accepts -u/--unweighted and -v, which are not
# listed here and do not change the computation. The URL below also differs
# from the one in the module docstring; confirm which is current.
help_message = '''
This is an implementation of the weighted local least squares algorithm,
as described in "Missing Value Imputation for Epistatic MAPs"

Requires Numpy - tested on version 1.4.0

Further information, and updates to this implementation will be available at : http://mlg.ucd.ie/emapimputation

Required parameters are as follows:
-i --input (filename): the E-MAP file to use as input
-o --output (filename): the file to output the resulting complete matrix to

Optional parameters as follows:
-k --neighbours (number): the number of neighbours to use for imputation, defaults to 20
-s --separator (character): the separator used in the input file, defaults to tab
-m --missing (string): the string to indicate a missing value, defaults to ""
-h --help: displays this message  

'''

class Usage(Exception):
    """
    Raised inside main() to report a command-line usage problem.

    ``msg`` holds the text main() writes to stderr before returning 2.
    """
    def __init__(self, msg):
        # Pass msg on to Exception so str(err) and err.args carry it as well.
        super(Usage, self).__init__(msg)
        self.msg = msg

def main(argv=None):
    """
    Command-line entry point: parse options, impute missing scores and write
    the completed matrix.

    Parameters:
        argv -- full argument vector including the program name; defaults
                to sys.argv

    Returns 0 on success, or 2 after reporting a usage error (this includes
    -h/--help).
    """
    k = DEFAULT_K
    separator = SEPARATOR
    missing = MISSING
    input_file = None
    output = None

    if argv is None:
        argv = sys.argv
    try:
        try:
            opts, _args = getopt.getopt(argv[1:], "hi:o:k:us:m:v", ["help","input=","output=","neighbours=","unweighted","separator=","missing="])
        except getopt.error as msg:
            raise Usage(msg)

        # option processing
        for option, value in opts:
            if option in ("-h", "--help"):
                raise Usage(help_message)
            elif option in ("-i", "--input"):
                input_file = value
            elif option in ("-o", "--output"):
                output = value
            elif option in ("-k", "--neighbours"):
                try:
                    k = int(value)
                except ValueError:
                    raise Usage("the number of neighbours must be an integer, got %r" % value)
                if k < 1:
                    raise Usage("the number of neighbours must be at least 1, got %d" % k)
            elif option in ("-s", "--separator"):
                separator = value
            elif option in ("-m", "--missing"):
                missing = value
            # -u/--unweighted and -v are accepted for compatibility but do not
            # currently change the imputation.

        if input_file is None or output is None:
            sys.stdout.write("Input and output filenames must be specified\n\n")
            raise Usage(help_message)

        #The runnable part of the program...
        data = read_square_dataset(input_file, missing, separator)
        imputed = LLSWrapper(data, k)
        output_results(output, separator, data[2], data[0], imputed)

    except Usage as err:
        # sys.stderr.write keeps this valid on both Python 2 and Python 3.
        program = sys.argv[0].split("/")[-1]
        sys.stderr.write(program + ": " + str(err.msg) + "\n")
        sys.stderr.write("\t for help use --help\n")
        return 2

    return 0

if __name__ == "__main__":
    # Use main()'s return value (0 on success, 2 on a usage error) as the
    # process exit status.
    raise SystemExit(main())
