/*  
 *  pstruct/em_gmix.c
 * 
 *  $Author: baptiste $, $Date: 2008-05-13 15:33:44 $, $Version$
 *
 *  Libgdl : a C library for statistical genetics
 * 
 *  Copyright (C) 2003-2006  Jean-Baptiste Veyrieras, INRA, France.
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
 */
#include <math.h>
#include <float.h>

#include <gdl/gdl_common.h>
#include <gdl/gdl_util.h>
#include <gdl/gdl_rng.h>
#include <gdl/gdl_errno.h>
#include <gdl/gdl_allele_block.h>
#include <gdl/gdl_locus_block.h>
#include <gdl/gdl_frequency_block.h>
#include <gdl/gdl_gview.h>
#include <gdl/gdl_gview_wrapper.h>
#include <gdl/gdl_clustering.h>
#include <gdl/gdl_pstruct.h>

/*
 * Workspace of the EM mixture estimator implemented in this file.
 *
 * The estimator keeps two copies of each parameter set (pi/upi and f/uf):
 * one holds the current estimates, the other is the buffer that receives
 * the accumulated updates; gdl_em_gmix_iterate_end_swap exchanges them.
 */
typedef struct
{
	// Set to 1 by gdl_em_gmix_alloc and gdl_em_gmix_iterate_end; not read in this file.
  	int start;
	// Number of mixture components (the `k` passed to gdl_em_gmix_alloc).
  	size_t K;
	// Accession count reported by gdl_gview_wrapper_accession_size_c.
  	size_t n;
	// Ploidy reported by gdl_gview_wrapper_ploidy.
  	size_t p;
	// Locus count reported by gdl_gview_wrapper_locus_size.
  	size_t l;
	// Number of terms accumulated in abs_res / sq_res.
  	size_t nres;
	// Per-locus allele count (array of length l).
  	size_t * na;
	// Sum of absolute parameter changes accumulated at the end of an iteration.
  	double abs_res;
	// Sum of squared parameter changes accumulated at the end of an iteration.
  	double sq_res;
	// Current mixing proportions (length K).
  	double       * pi;
	// Update buffer for the mixing proportions (length K).
	double       * upi;
	// Current allele frequencies, indexed (component, locus, allele).
 	gdl_frequencies_block * f;
	// Update buffer for the allele frequencies.
  	gdl_frequencies_block * uf;
	// Membership values, indexed (component, accession).
  	gdl_block    * q;
	// Scratch buffer filled by gdl_gview_wrapper_get_allele_c.
	gdl_gvalues_get   * gbuf;
	// Data wrapper supplied by the caller; gdl_em_gmix_free does not release it.
	gdl_gview_wrapper * gwrap;
	// Allele block obtained from gwrap; released by gdl_em_gmix_free.
	gdl_allele_block  * gblock;
	// Random number generator supplied by the caller; used to draw initial memberships.
	gdl_rng * rng;
} gdl_em_gmix_t;

/*
 * Read the membership value stored in state->q for component k and
 * accession i.
 */
static double
_gdl_em_gmix_get_q (const gdl_em_gmix_t * state, size_t k, size_t i)
{
	const double membership = gdl_block_get (state->q, k, i);

	return membership;
}

/*
 * Read the current frequency of allele a at locus l in component k
 * (from state->f, not from the update buffer).
 */
static double
_gdl_em_gmix_get_f (const gdl_em_gmix_t * state, size_t k, size_t l, size_t a)
{
	const double freq = gdl_frequencies_block_get (state->f, k, l, a);

	return freq;
}


/*
 * Add the weighted contribution x*z to the update-buffer frequency
 * uf[k][l][a].
 *
 * z is the membership weight and x the allele weight supplied by the
 * caller. Always returns 0.
 */
static int
gdl_em_gmix_update_fallele (gdl_em_gmix_t * state,
                            size_t k,
                            size_t l,
                            size_t a,
                            double z,
                            double x)
{
	const double increment = x*z;
	const double current   = gdl_frequencies_block_get (state->uf, k, l, a);

	gdl_frequencies_block_set (state->uf, k, l, a, current + increment);

	return 0;
}

/*
 * Probability of the genotype of accession i under component k.
 *
 * The result is the product, over every phase and locus, of the current
 * frequency of the observed allele. A zero entry in the allele block is
 * treated as "no single allele": the weighted allele values are then
 * fetched from the wrapper and the factor is the value-weighted sum of
 * their frequencies. If the wrapper yields no values the position
 * contributes a factor of 1.
 */
static double
gdl_em_gmix_get_pr (gdl_em_gmix_t * state, size_t k, size_t i)
{
	size_t phase, locus, allele, v;
	double prob = 1.0;
	double mix;
	const gdl_gvalues * gv;

	for (phase = 0; phase < state->p; phase++)
	{
		for (locus = 0; locus < state->l; locus++)
		{
			allele = gdl_allele_block_get (state->gblock, i, locus, phase);
			if (allele != 0)
			{
				// Block entries are stored shifted by one; 0 is reserved.
				prob *= _gdl_em_gmix_get_f (state, k, locus, allele - 1);
				continue;
			}

			gdl_gview_wrapper_get_allele_c (state->gwrap, i, locus, phase, state->gbuf);
			gv = gdl_gvalues_get_gvalues (state->gbuf);
			if (gv == 0)
			{
				continue;
			}

			mix = 0.;
			for (v = 0; v < gv->size; v++)
			{
				mix += _gdl_em_gmix_get_f (state, k, locus, gv->values[v]->idx)*gv->values[v]->value;
			}
			prob *= mix;
		}
	}

	return prob;
}

/*
 * Store f as the current frequency of allele a at locus l in component k.
 */
static void
gdl_em_gmix_set_f (gdl_em_gmix_t * state, size_t k, size_t l, size_t a, double f)
{
	gdl_frequencies_block * current = state->f;

	gdl_frequencies_block_set (current, k, l, a, f);
}

/*
 * Store q as the membership value of accession i in component k.
 */
static void
gdl_em_gmix_set_q (gdl_em_gmix_t * state, size_t k, size_t i, double q)
{
	gdl_block * memberships = state->q;

	gdl_block_set (memberships, k, i, q);
}


/*
 * Prepare the workspace for a new iteration: reset the residual
 * accumulators and zero both update buffers (upi and uf).
 * Always returns 0.
 */
static int
gdl_em_gmix_iterate_start (gdl_em_gmix_t * state)
{
	size_t pop, locus, allele;

	state->nres    = 0;
	state->sq_res  = 0.;
	state->abs_res = 0.;

	for (pop = 0; pop < state->K; pop++)
	{
		state->upi[pop] = 0.0;

		for (locus = 0; locus < state->l; locus++)
		{
			for (allele = 0; allele < state->na[locus]; allele++)
			{
				gdl_frequencies_block_set (state->uf, pop, locus, allele, 0.);
			}
		}
	}

	return 0;
}

/*
 * Recompute the membership values of accession i and add them to the
 * mixing-proportion buffer upi.
 *
 * Each component gets the weight pi[k] * Pr(genotype | k); components
 * whose genotype probability is below DBL_MIN get weight 0. The weights
 * are then normalised by their sum. When the sum itself is not above
 * DBL_MIN the memberships fall back to the uniform value 1/K.
 * Always returns 0.
 */
static int
gdl_em_gmix_update_q (gdl_em_gmix_t * state, size_t i)
{
	size_t k;
	double pr, weight, share, uniform;
	double total = 0;

	for (k = 0; k < state->K; k++)
	{
		pr = gdl_em_gmix_get_pr (state, k, i);
		if (pr < DBL_MIN)
		{
			gdl_em_gmix_set_q (state, k, i, 0.0);
			continue;
		}
		weight = (state->pi[k])*pr;
		gdl_em_gmix_set_q (state, k, i, weight);
		total += weight;
	}

	// Written as a negated '>' so that a NaN total also takes the fallback.
	if (!(total > DBL_MIN))
	{
		for (k = 0; k < state->K; k++)
		{
			uniform = 1.0/(double)state->K;
			gdl_em_gmix_set_q (state, k, i, uniform);
			state->upi[k] += uniform;
		}
		return 0;
	}

	for (k = 0; k < state->K; k++)
	{
		share = _gdl_em_gmix_get_q (state, k, i)/total;
		gdl_em_gmix_set_q (state, k, i, share);
		state->upi[k] += share;
	}

	return 0;
}

/*
 * Add the contribution of accession i, phase j, locus l to the frequency
 * update buffer of every component.
 *
 * The contribution is weighted by the membership value of the accession
 * and by its multiplicity as reported by the wrapper. When the allele
 * block holds 0 for this position, the weighted allele values obtained
 * from the wrapper are distributed instead (nothing is added if the
 * wrapper yields none). Always returns 0.
 */
static int
gdl_em_gmix_update_f (gdl_em_gmix_t * state,
                      size_t i,
                      size_t j,
                      size_t l)
{
	size_t k, v, allele, mult;
	double z;
	const gdl_gvalues * gv;

	mult   = gdl_gview_wrapper_accession_mult_c (state->gwrap, i);
	allele = gdl_allele_block_get (state->gblock, i, l, j);

	if (allele != 0)
	{
		// Block entries are stored shifted by one; 0 is reserved.
		for (k = 0; k < state->K; k++)
		{
			z = _gdl_em_gmix_get_q (state, k, i);
			gdl_em_gmix_update_fallele (state, k, l, allele - 1, z, mult);
		}
		return 0;
	}

	gdl_gview_wrapper_get_allele_c (state->gwrap, i, l, j, state->gbuf);
	gv = gdl_gvalues_get_gvalues (state->gbuf);

	for (k = 0; k < state->K; k++)
	{
		z = _gdl_em_gmix_get_q (state, k, i);
		if (gv == 0)
		{
			continue;
		}
		for (v = 0; v < gv->size; v++)
		{
			gdl_em_gmix_update_fallele (state, k, l, gv->values[v]->idx, z, mult*gv->values[v]->value);
		}
	}

	return 0;
}

/*
 * Exchange the current parameters with the update buffers.
 *
 * pi and upi are always exchanged by swapping the two pointers. For the
 * frequencies, a null lidx swaps the f and uf pointers wholesale;
 * otherwise only the loci with a negative lidx entry have their values
 * exchanged, element by element, and the other loci are left untouched.
 * Always returns 0.
 */
static int
gdl_em_gmix_iterate_end_swap (gdl_em_gmix_t * state, const int * lidx)
{
	size_t locus, allele, pop;
	double saved;
	double * old_pi;
	gdl_frequencies_block * old_f;

	old_pi     = state->pi;
	state->pi  = state->upi;
	state->upi = old_pi;

	if (lidx == 0)
	{
		old_f     = state->f;
		state->f  = state->uf;
		state->uf = old_f;
		return 0;
	}

	for (locus = 0; locus < state->l; locus++)
	{
		if (lidx[locus] >= 0)
		{
			continue;
		}
		for (allele = 0; allele < state->na[locus]; allele++)
		{
			for (pop = 0; pop < state->K; pop++)
			{
				saved = gdl_frequencies_block_get (state->f, pop, locus, allele);
				gdl_frequencies_block_set (state->f, pop, locus, allele,
				                           gdl_frequencies_block_get (state->uf, pop, locus, allele));
				gdl_frequencies_block_set (state->uf, pop, locus, allele, saved);
			}
		}
	}

	return 0;
}

/*
 * Normalise the mixing proportions and accumulate their residuals.
 *
 * Called by gdl_em_gmix_iterate_end right after the buffer swap, so
 * state->pi holds the freshly accumulated values and state->upi the
 * previous estimates. Each pi[k] is divided by the sum of all pi and
 * floored at DBL_MIN; |pi[k] - upi[k]| is then added to the absolute and
 * squared residual accumulators. Always returns 0.
 *
 * NOTE(review): if the sum of pi is zero the quotients are NaN and are
 * stored as such; no fallback is applied here.
 */
static int
gdl_em_gmix_iterate_end_pi (gdl_em_gmix_t * state)
{
	// The sum must start at zero: it was previously accumulated into an
	// uninitialised variable, making the normaliser indeterminate.
	double s = 0., q, e;
	size_t k;

	for ( k = 0; k < state->K; k++)
	{
		s += state->pi[k];
	}
	for ( k = 0; k < state->K; k++)
	{
		q = state->pi[k];
		if (q/s < DBL_MIN)
		{
			// Keep every proportion strictly positive.
			state->pi[k] = DBL_MIN;
		}
		else
		{
			state->pi[k] = q/s;
		}
		e = fabs(state->pi[k] - state->upi[k]);
		state->abs_res += e;
		state->sq_res  += e*e;
		(state->nres)++;
	}

	return 0;
}

/*
 * Normalise the allele frequencies and accumulate their residuals.
 *
 * For every component and every locus that is being estimated (all loci
 * when lidx is null, otherwise those with a negative lidx entry) the
 * frequencies are divided by their sum over the alleles of the locus;
 * quotients below DBL_MIN are replaced by 0/sum. The absolute difference
 * with the value held in the uf buffer is added to the residual
 * accumulators. Always returns 0.
 *
 * NOTE(review): when the sum over alleles is zero the stored value is
 * NaN (0/0) - confirm whether a fallback is wanted.
 */
static int
gdl_em_gmix_iterate_end_f (gdl_em_gmix_t * state, const int * lidx)
{
	size_t pop, locus, allele;
	double total, raw, scaled, delta;

	for (pop = 0; pop < state->K; pop++)
	{
		for (locus = 0; locus < state->l; locus++)
		{
			// Loci carried over from a previous run are left as they are.
			if (lidx != 0 && lidx[locus] >= 0)
			{
				continue;
			}

			total = 0;
			for (allele = 0; allele < state->na[locus]; allele++)
			{
				total += _gdl_em_gmix_get_f (state, pop, locus, allele);
			}

			for (allele = 0; allele < state->na[locus]; allele++)
			{
				raw = _gdl_em_gmix_get_f (state, pop, locus, allele);
				if (raw/total < DBL_MIN)
				{
					raw = 0.;
				}
				scaled = raw/total;
				gdl_em_gmix_set_f (state, pop, locus, allele, scaled);

				delta = fabs(scaled - gdl_frequencies_block_get (state->uf, pop, locus, allele));
				state->abs_res += delta;
				state->sq_res  += delta*delta;
				(state->nres)++;
			}
		}
	}

	return 0;
}

/*
 * Finish an iteration: swap the buffers, then normalise the mixing
 * proportions and the allele frequencies (both steps also accumulate the
 * residuals). lidx is forwarded to the swap and frequency steps.
 * Returns the OR of the three step statuses.
 */
static int
gdl_em_gmix_iterate_end (gdl_em_gmix_t * state, const int * lidx)
{
	// The swap must come first: the two normalisation steps work on the
	// values that were accumulated in the update buffers.
	int status = gdl_em_gmix_iterate_end_swap (state, lidx);

	status |= gdl_em_gmix_iterate_end_pi (state);
	status |= gdl_em_gmix_iterate_end_f (state, lidx);

	state->start = 1;

	return status;
}

/*
 * Set up the workspace pointed to by vstate for k mixture components.
 *
 * data must point to the gdl_gview_wrapper holding the genotypes; it and
 * rng are stored by reference. Dimensions (accessions, ploidy, loci,
 * alleles per locus) are read from the wrapper, after which the
 * frequency blocks, the membership block, the mixing-proportion arrays
 * and the allele block are created. Returns GDL_SUCCESS.
 *
 * NOTE(review): allocation results are not checked here - confirm how
 * the GDL allocators report failure.
 */
int
gdl_em_gmix_alloc (void * vstate, void * data, gdl_rng * rng, size_t k)
{
	gdl_em_gmix_t * state = (gdl_em_gmix_t *) vstate;
	size_t loc;

	state->gwrap = (gdl_gview_wrapper *) data;
	state->start = 1;
	state->rng   = rng;
	state->K     = k;

	// Problem dimensions
	state->n     = gdl_gview_wrapper_accession_size_c (state->gwrap);
	state->p     = gdl_gview_wrapper_ploidy (state->gwrap);
	state->gbuf  = gdl_gview_wrapper_get_new (state->gwrap);
	state->l     = gdl_gview_wrapper_locus_size (state->gwrap);

	// Number of alleles at each locus
	state->na    = GDL_CALLOC (size_t, state->l);
	for (loc = 0; loc < state->l; loc++)
	{
		gdl_locus * locus = gdl_gview_wrapper_get_locus (state->gwrap, loc);

		state->na[loc] = gdl_locus_allele (locus);
	}

	// Parameters and their update buffers
	state->f      = gdl_frequencies_block_alloc (state->K, state->l, state->na);
	state->uf     = gdl_frequencies_block_alloc (state->K, state->l, state->na);
	state->q      = gdl_block_alloc2 (2, state->K, state->n);
	state->pi     = GDL_CALLOC (double, state->K);
	state->upi    = GDL_CALLOC (double, state->K);

	state->gblock = gdl_gview_wrapper_allele_block_c (state->gwrap, 1);

	return GDL_SUCCESS;
}

/*
 * Release everything gdl_em_gmix_alloc created inside the workspace.
 *
 * The workspace structure itself, the data wrapper and the random number
 * generator are not released here. A null vstate is accepted and is a
 * no-op. Returns GDL_SUCCESS.
 */
int
gdl_em_gmix_free (void * vstate)
{
	gdl_em_gmix_t * state = (gdl_em_gmix_t *) vstate;

	// A bare 'return;' is not allowed in a function returning int; report
	// success explicitly so the caller never reads an undefined value.
	if (state == 0)
	{
		return GDL_SUCCESS;
	}

	gdl_frequencies_block_free (state->f);
	gdl_block_free (state->q);
	gdl_frequencies_block_free (state->uf);
	GDL_FREE (state->pi);
	GDL_FREE (state->upi);
	GDL_FREE (state->na);
	gdl_gvalues_get_free (state->gbuf);

	gdl_allele_block_free (state->gblock);

	return GDL_SUCCESS;
}

/*
 * Draw random memberships for accession i.
 *
 * One uniform deviate is drawn per component; the draws are divided by
 * their sum, stored as the membership values, and added to the
 * mixing-proportion buffer upi.
 */
static void
_gdl_em_gmix_init_rng (gdl_em_gmix_t * state, size_t i)
{
	size_t k;
	double draw, share;
	double total = 0;

	for (k = 0; k < state->K; k++)
	{
		draw = gdl_rng_uniform (state->rng);
		gdl_em_gmix_set_q (state, k, i, draw);
		total += draw;
	}

	for (k = 0; k < state->K; k++)
	{
		share = _gdl_em_gmix_get_q (state, k, i)/total;
		gdl_em_gmix_set_q (state, k, i, share);
		state->upi[k] += share;
	}
}

/*
 * Random initialisation of the estimator.
 *
 * Every accession receives random memberships, which are then used to
 * accumulate the frequency contributions of all its phases and loci.
 * The iteration is closed with a null locus index, i.e. all loci are
 * swapped and normalised. Returns the OR of the step statuses.
 *
 * NOTE(review): the update buffers and residual accumulators are not
 * cleared here (gdl_em_gmix_iterate_start is not called), so this adds
 * to whatever they already contain - confirm callers only use it on a
 * freshly allocated workspace.
 */
int
gdl_em_gmix_init (void * vstate)
{
	gdl_em_gmix_t * state = (gdl_em_gmix_t *) vstate;
	size_t acc, phase, locus;
	int status = 0;

	for (acc = 0; acc < state->n; acc++)
	{
		// Random memberships for this accession
		_gdl_em_gmix_init_rng (state, acc);

		for (phase = 0; phase < state->p; phase++)
		{
			for (locus = 0; locus < state->l; locus++)
			{
				status |= gdl_em_gmix_update_f (state, acc, phase, locus);
			}
		}
	}

	status |= gdl_em_gmix_iterate_end (state, 0);

	return status;
}

/*
 * Initialise the estimator from a previous solution.
 *
 * aidx[i] >= 0 means accession i is taken from the previous solution
 * under that index: its memberships are copied through get_q. Other
 * accessions get random memberships. Likewise lidx[l] >= 0 means the
 * frequencies of locus l are copied through get_f (done once, while
 * handling the first accession); loci with a negative entry accumulate
 * frequency contributions from every accession and phase. db is passed
 * untouched to both callbacks. Returns the OR of the step statuses.
 */
int
gdl_em_gmix_init_update (void * vstate,
                         const int * aidx,
                         const int * lidx,
                         const void * db,
                         double (*get_q)(const void * db, size_t k, size_t i),
                         double (*get_f)(const void * db, size_t k, size_t l, size_t a))
{
	gdl_em_gmix_t * state = (gdl_em_gmix_t *) vstate;
	size_t acc, phase, locus, pop, allele;
	int status = 0;

	for (acc = 0; acc < state->n; acc++)
	{
		if (aidx[acc] < 0)
		{
			_gdl_em_gmix_init_rng (state, acc);
		}
		else
		{
			// Reuse the memberships of the previous solution
			for (pop = 0; pop < state->K; pop++)
			{
				gdl_em_gmix_set_q (state, pop, acc, (*get_q)(db, pop, aidx[acc]));
				state->upi[pop] += _gdl_em_gmix_get_q (state, pop, acc);
			}
		}

		for (locus = 0; locus < state->l; locus++)
		{
			if (lidx[locus] < 0)
			{
				for (phase = 0; phase < state->p; phase++)
				{
					status |= gdl_em_gmix_update_f (state, acc, phase, locus);
				}
			}
			else if (acc == 0)
			{
				// Reuse the frequencies of the previous solution
				for (allele = 0; allele < state->na[locus]; allele++)
				{
					for (pop = 0; pop < state->K; pop++)
					{
						gdl_em_gmix_set_f (state, pop, locus, allele, (*get_f)(db, pop, lidx[locus], allele));
					}
				}
			}
		}
	}

	status |= gdl_em_gmix_iterate_end (state, lidx);

	return status;
}

/*
 * Run one full iteration: clear the buffers, then for every accession
 * recompute its memberships and accumulate its frequency contributions
 * over all phases and loci, and finally swap and normalise all
 * parameters. Returns the OR of the step statuses.
 */
int
gdl_em_gmix_iterate (void * vstate)
{
	gdl_em_gmix_t * state = (gdl_em_gmix_t *) vstate;
	size_t acc, phase, locus;
	int status = gdl_em_gmix_iterate_start (state);

	for (acc = 0; acc < state->n; acc++)
	{
		status |= gdl_em_gmix_update_q (state, acc);

		for (phase = 0; phase < state->p; phase++)
		{
			for (locus = 0; locus < state->l; locus++)
			{
				status |= gdl_em_gmix_update_f (state, acc, phase, locus);
			}
		}
	}

	status |= gdl_em_gmix_iterate_end (state, 0);

	return status;
}

/*
 * Run one iteration in which part of the solution is held fixed.
 *
 * Accessions with aidx[i] >= 0 keep their memberships (they are only
 * added to the mixing-proportion buffer); the others are re-estimated.
 * Only loci with a negative lidx entry accumulate frequency
 * contributions and are later swapped and normalised.
 * Returns the OR of the step statuses.
 */
int
gdl_em_gmix_iterate_update (void * vstate, const int * aidx, const int * lidx)
{
	gdl_em_gmix_t * state = (gdl_em_gmix_t *) vstate;
	size_t acc, phase, pop, locus;
	int status = gdl_em_gmix_iterate_start (state);

	for (acc = 0; acc < state->n; acc++)
	{
		if (aidx[acc] >= 0)
		{
			for (pop = 0; pop < state->K; pop++)
			{
				state->upi[pop] += _gdl_em_gmix_get_q (state, pop, acc);
			}
		}
		else
		{
			status |= gdl_em_gmix_update_q (state, acc);
		}

		for (locus = 0; locus < state->l; locus++)
		{
			if (lidx[locus] >= 0)
			{
				continue;
			}
			for (phase = 0; phase < state->p; phase++)
			{
				status |= gdl_em_gmix_update_f (state, acc, phase, locus);
			}
		}
	}

	status |= gdl_em_gmix_iterate_end (state, lidx);

	return status;
}

/*
 * Overwrite the value of every candidate allele in x with its expected
 * frequency for accession i at locus l, i.e. the sum over components of
 * membership * allele frequency.
 *
 * Returns 0. (The function is declared int but previously had no return
 * statement.)
 */
static int
gdl_em_gmix_allele_imputation (gdl_em_gmix_t * state, size_t i, size_t l, gdl_gvalues * x)
{
	size_t j, k;
	double q, f;
	gdl_gvalue * gx;
	
	for (j = 0; j < x->size; j++)
	{
		gx = x->values[j];
		gx->value = 0;
		for (k = 0; k < state->K; k++)
		{
			q = _gdl_em_gmix_get_q (state, k, i);
			f = _gdl_em_gmix_get_f (state, k, l, gx->idx);
			gx->value += q*f;
		}
	}
	
	return 0;
}

/*
 * Impute every missing allele recorded by the data wrapper.
 *
 * For each accession with missing data, each phase and each missing
 * locus of that phase, the candidate allele values held by the wrapper
 * are replaced by their expected frequencies under the current
 * solution. Returns GDL_SUCCESS.
 */
int
gdl_em_gmix_imputation (void * vstate)
{
	gdl_em_gmix_t * state = (gdl_em_gmix_t *) vstate;
	size_t n_missing, n_loci, m, acc, phase, h, locus;
	const size_t * missing_loci;
	gdl_gvalues * candidates;

	n_missing = gdl_gview_wrapper_missing_accession_size (state->gwrap);

	for (m = 0; m < n_missing; m++)
	{
		// Index of the accession used with the membership block
		acc = gdl_gview_wrapper_missing_accession_idx_c (state->gwrap, m);

		for (phase = 0; phase < state->p; phase++)
		{
			n_loci       = gdl_gview_wrapper_missing_hlocus_size (state->gwrap, m, phase);
			missing_loci = gdl_gview_wrapper_missing_hlocus_idx (state->gwrap, m, phase);

			for (h = 0; h < n_loci; h++)
			{
				locus      = missing_loci[h];
				candidates = gdl_gview_wrapper_missing_hget (state->gwrap, m, phase, h);
				gdl_em_gmix_allele_imputation (state, acc, locus, candidates);
			}
		}
	}

	return GDL_SUCCESS;
}

/*
 * Log-likelihood of the data under the current solution.
 *
 * For each accession the genotype probabilities of the components are
 * mixed with the proportions pi (components whose probability is not
 * above DBL_MIN are left out), and the logarithm of the mixture is added
 * to the total. Accessions whose mixture is not above DBL_MIN are
 * silently skipped.
 */
double
gdl_em_gmix_loglikelihood (const void * vstate)
{
	gdl_em_gmix_t * state = (gdl_em_gmix_t *) vstate;
	size_t acc, pop;
	double component, mixture;
	double loglik = 0.;

	for (acc = 0; acc < state->n; acc++)
	{
		mixture = 0.;
		for (pop = 0; pop < state->K; pop++)
		{
			component = gdl_em_gmix_get_pr (state, pop, acc);
			if (component > DBL_MIN)
			{
				mixture += (state->pi[pop])*component;
			}
		}

		// Probabilities too small to be represented are ignored
		if (mixture > DBL_MIN)
		{
			loglik += log (mixture);
		}
	}

	return loglik;
}

/*
 * Mean squared parameter change: the accumulated squared residual
 * divided by the number of accumulated terms.
 */
double
gdl_em_gmix_residual_sq (const void * vstate)
{
	const gdl_em_gmix_t * state = (const gdl_em_gmix_t *) vstate;
	const double mean_sq = state->sq_res/state->nres;

	return mean_sq;
}

/*
 * Mean absolute parameter change: the accumulated absolute residual
 * divided by the number of accumulated terms.
 */
double
gdl_em_gmix_residual_abs (const void * vstate)
{
	const gdl_em_gmix_t * state = (const gdl_em_gmix_t *) vstate;
	const double mean_abs = state->abs_res/state->nres;

	return mean_abs;
}

/*
 * Current mixing proportion of component k.
 */
double
gdl_em_gmix_get_pop_q (const void * vstate, size_t k)
{
	const gdl_em_gmix_t * state = (const gdl_em_gmix_t *) vstate;

	return state->pi[k];
}

/*
 * Membership value of accession i in component k.
 *
 * The index i is first translated through the clustering held by the
 * data wrapper; the translated index is the one used with the
 * membership block.
 */
double
gdl_em_gmix_get_q (const void * vstate, size_t k, size_t i)
{
	gdl_em_gmix_t * state = (gdl_em_gmix_t *) vstate;
	const size_t cluster = gdl_clustering_cluster (gdl_gview_wrapper_clustering (state->gwrap), i);

	return _gdl_em_gmix_get_q (state, k, cluster);
}

/*
 * Current frequency of allele a at locus l in component k.
 */
double
gdl_em_gmix_get_f (const void * vstate, size_t k, size_t l, size_t a)
{
	const gdl_em_gmix_t * state = (const gdl_em_gmix_t *) vstate;

	return _gdl_em_gmix_get_f (state, k, l, a);
}

/*
 * Index of the allele with the highest current frequency at locus loc in
 * component pop. Ties keep the lowest index; 0 is returned when no
 * frequency compares greater than GDL_NEGINF (for instance when the
 * locus has no allele).
 */
size_t
gdl_em_gmix_get_f_max (const void * vstate, size_t pop, size_t loc)
{
	gdl_em_gmix_t * state = (gdl_em_gmix_t *) vstate;
	size_t allele;
	size_t best_allele = 0;
	double freq;
	double best_freq = GDL_NEGINF;

	for (allele = 0; allele < state->na[loc]; allele++)
	{
		freq = _gdl_em_gmix_get_f (state, pop, loc, allele);
		if (freq > best_freq)
		{
			best_allele = allele;
			best_freq   = freq;
		}
	}

	return best_allele;
}

/*
 * Number of alleles at locus loc.
 */
size_t
gdl_em_gmix_get_f_size (const void * vstate, size_t loc)
{
	const gdl_em_gmix_t * state = (const gdl_em_gmix_t *) vstate;

	return state->na[loc];
}

/*
 * Workspace type descriptor for this estimator: its name, the size of
 * its state structure and its entry points. The entries are positional;
 * their order follows the declaration of gdl_pstruct_workspace_type,
 * which lives outside this file.
 */
static const gdl_pstruct_workspace_type _em_gmix =
{
	"gdl_pstruct_em_gmixture",
	sizeof (gdl_em_gmix_t),
	gdl_em_gmix_alloc,
	gdl_em_gmix_free,
	gdl_em_gmix_init,
	gdl_em_gmix_init_update,
	gdl_em_gmix_iterate,
	gdl_em_gmix_iterate_update,
	gdl_em_gmix_imputation,
	gdl_em_gmix_loglikelihood,
	gdl_em_gmix_residual_abs,
	gdl_em_gmix_residual_sq,
	gdl_em_gmix_get_pop_q,
	gdl_em_gmix_get_q,
	gdl_em_gmix_get_f,
	gdl_em_gmix_get_f_max,
	gdl_em_gmix_get_f_size
};

/* Public handle on the descriptor above. */
const gdl_pstruct_workspace_type * gdl_pstruct_em_gmixture = &_em_gmix;
