#!/usr/bin/env python
# encoding: utf-8
"""
test_imputation.py

This script performs K fold validation to estimate the accuracy of imputation.
Note that this does not output a complete matrix; use symmetricNN.py for that purpose.

Further information and updates to this implementation will be available at : http://www.bioinformatics.org/emapimputation

Version : 1.1 

$Rev:: 6             $:  Revision of last commit
$Author:: colmryan   $:  Author of last commit
$Date:: 2010-09-13 1#$:  Date of last commit

Created by Colm Ryan on 2010-07-07.
Copyright (c) 2010 Colm Ryan. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, 
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions 
and limitations under the License. 

"""

import sys
from symmetricNN import *
import random
import math

# Number of folds used for validation when -f/--folds is not given.
DEFAULT_FOLDS = 10

# Text printed for -h/--help and when no input file is given.
help_message = '''
This script performs K fold validation to estimate the accuracy of imputation.
Note that this does not output a complete matrix, use symmetricNN.py for that purpose.

Required parameters are as follows:
-i --input (filename): the E-MAP file to use as input

Optional parameters as follows:
-k --neighbours (number): the number of neighbours to use for imputation, defaults to 50
-u --unweighted: run nearest neighbours without any weighting
-s --separator (character): the separator used in the input file, defaults to tab
-m --missing (string): the string to indicate a missing value, defaults to ""
-h --help: displays this message
-f --folds (number): the number of folds to use for validation, defaults to 10.
'''

def calculateNRMSE(profiles,answer,guess):
    r'''
    Calculates the Normalised Root Mean Squared Error as given by the following formula :
    NRMSE = \sqrt{ \frac{mean[(ij_{answer} - ij_{guess})^2] }{ variance[ij_{answer}]}}
    Inputs are the same format as the pearson_correlation method for ease of use.

    profiles : dictionary of dictionaries; profiles[answer] and profiles[guess] map
               each interaction to a numeric score.
    answer   : key of the profile holding the true scores.
    guess    : key of the profile holding the estimated scores; it must contain every
               interaction present in profiles[answer].

    Returns the NRMSE as a float.
    Raises ZeroDivisionError if profiles[answer] is empty or all of its scores are
    identical (zero variance), since the measure is undefined in those cases.
    '''
    truth = profiles[answer]
    estimate = profiles[guess]
    # A float count forces true division, so integer scores are not silently
    # truncated by Python 2's integer division.
    count = float(len(truth))
    MSE = sum(pow(truth[interaction] - estimate[interaction],2) for interaction in truth) / count
    mean = sum(truth[interaction] for interaction in truth) / count
    variance = sum(pow(truth[interaction] - mean,2) for interaction in truth) / count
    return math.sqrt(MSE/variance)
    
def choose_random_elements(count, population) :
    '''
    Returns a set of 'count' distinct elements chosen at random from 'population'.

    count      : number of elements wanted; zero or less gives an empty set.
    population : iterable of hashable elements; a value that appears more than
                 once is only considered once.

    Raises ValueError if 'count' is larger than the number of distinct elements
    in 'population', because such a set cannot be built.
    '''
    if count <= 0 :
        return set()
    # Drop repeats while keeping the original order, so that results stay
    # reproducible for a given random seed.
    seen = set()
    distinct = []
    for element in population :
        if element not in seen :
            seen.add(element)
            distinct.append(element)
    if count > len(distinct) :
        raise ValueError("Cannot choose %d distinct elements from a population of %d" % (count, len(distinct)))
    # Sampling without replacement needs exactly 'count' draws, instead of
    # redrawing until enough distinct elements have turned up.
    return set(random.sample(distinct, count))

def k_fold_test(k,dataset,gene_list, neighbours, weighted) :
    interactions = dataset.copy()
    guess_tuples = {}
    fold_size = int(len(interactions) / k)
    print "Fold size = ", fold_size
    folds = {}
    fake_profile = {}
    fake_profile['guess'] = {}
    fake_profile['answer'] = {}
    for i in range(k-1) :
        folds[i] = choose_random_elements(fold_size,interactions.keys())
        for interaction in folds[i] :
            del interactions[interaction]  
    folds[k-1] = interactions.keys()
    print "Imputing fold ",
    for fold in folds :
        print fold, ",",
        sys.stdout.flush()
        testset = dataset.copy()
        for pair in folds[fold] :
            del testset[pair]
        gene_profiles = {}    
        for geneI in gene_list :
            profile = {}
            for geneJ in gene_list :
                interaction = get_ordered_tuple(geneI,geneJ)
                if interaction in testset :
                    profile[geneJ] = testset[interaction]
            gene_profiles[geneI] = profile
                        
        data = (testset,gene_profiles,gene_list)
        filled = performWNN(data,neighbours,weighted)[0]
        for pair in folds[fold] :
            if pair in filled :
                fake_profile['guess'][pair] = filled[pair]
                fake_profile['answer'][pair] = dataset[pair]
    print "Estimated correlation = " , pearson_correlation(fake_profile,'answer','guess')
    print "Estimated NRMSE = " , calculateNRMSE(fake_profile,'answer','guess')
    return

class Usage(Exception):
    '''
    Raised to stop option handling; 'msg' holds the text to show to the user.
    '''
    def __init__(self, msg):
        # Hand msg to Exception as well, so that str() and tracebacks carry the
        # message on every Python version rather than only the msg attribute.
        Exception.__init__(self, msg)
        self.msg = msg

def main(argv=None):
    '''
    Command line entry point: parses the options, reads the input file and runs
    the k fold validation.

    argv : list of arguments including the program name; defaults to sys.argv.

    Returns 0 on success, or 2 after printing a message to stderr when the
    options are invalid, help is requested or no input file is given.
    '''
    # This file does not import getopt itself; without this line it is only
    # available if the star import from symmetricNN happens to provide it.
    import getopt

    k = DEFAULT_K
    weighted = True
    separator = SEPARATOR
    missing = MISSING
    folds = DEFAULT_FOLDS
    input_file = None
    if argv is None:
        argv = sys.argv
    try:
        try:
            opts, args = getopt.getopt(argv[1:], "hi:o:k:us:m:r:f:v", ["help","input=","output=","neighbours=","unweighted","separator=","missing=","ranked=","folds="])
        except getopt.error, msg:
            raise Usage(msg)

        # option processing
        # -o/--output, -r/--ranked and -v are still accepted so existing command
        # lines keep working, but this script has no use for them.
        for option, value in opts:
            if option in ("-h", "--help"):
                raise Usage(help_message)
            elif option in ("-i", "--input"):
                input_file = value
            elif option in ("-k", "--neighbours"):
                try:
                    k = int(value)
                except ValueError:
                    raise Usage("option %s requires a whole number, got '%s'" % (option, value))
            elif option in ("-u", "--unweighted"):
                weighted = False
            elif option in ("-s", "--separator"):
                separator = value
            elif option in ("-m", "--missing"):
                missing = value
            elif option in ("-f","--folds") :
                try:
                    folds = int(value)
                except ValueError:
                    raise Usage("option %s requires a whole number, got '%s'" % (option, value))

        if folds < 1 :
            raise Usage("the number of folds must be at least 1")

        if input_file is None :
            print "Input filename must be specified\n"
            raise Usage(help_message)

        data = read_square_dataset(input_file,missing,separator)
        k_fold_test(folds,data[0],data[2],k,weighted)

    except Usage, err:
        print >> sys.stderr, sys.argv[0].split("/")[-1] + ": " + str(err.msg)
        print >> sys.stderr, "\t for help use --help"
        return 2


    return 0

# Run the validation only when executed as a script; the value returned by
# main() is used as the process exit status.
if __name__ == "__main__":
    sys.exit(main())
