/*  
 *  pstruct/hadmix.c
 * 
 *  $Author: baptiste $, $Date: 2008-05-13 15:33:44 $, $Version$
 *
 *  Libgdl : a C library for statistical genetics
 * 
 *  Copyright (C) 2003-2006  Jean-Baptiste Veyrieras, INRA, France.
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA * 
 */
#include <math.h>

#include <gdl/gdl_common.h>
#include <gdl/gdl_util.h>
#include <gdl/gdl_rng.h>
#include <gdl/gdl_errno.h>
#include <gdl/gdl_hview.h>
#include <gdl/gdl_clustering.h>
#include <gdl/gdl_pstruct.h>


/*
 * Workspace of the "hadmix" EM algorithm.
 *
 * The sizes are filled by gdl_em_hadmix_alloc() from the haplotype view.
 * The (f, uf) and (q, uq) pairs are exchanged at the end of every iteration
 * by gdl_em_hadmix_iterate_end_swap(): updates are accumulated into uf/uq,
 * then become the current f/q.
 */
typedef struct
{
  	int start;            /* set to 1 by alloc and at the end of each iteration; not read in this file */
  	size_t K;             /* number of populations (the k argument of gdl_em_hadmix_alloc) */
  	size_t n;             /* gdl_hview_haplotype_size(): number of haplotypes */
  	size_t nn;            /* gdl_hview_accession_size(): number of accessions */
  	size_t p;             /* gdl_hview_ploidy() */
  	size_t l;             /* gdl_hview_locus_size(): number of loci */
  	size_t * na;          /* na[j] = gdl_locus_allele() of locus j, for j in [0, l) */
  	double abs_res;       /* sum of |new - old| over q and f, accumulated in iterate_end_q/_f */
  	double sq_res;        /* sum of (new - old)^2 over q and f, accumulated in iterate_end_q/_f */
  	double ** iq;         /* nn slots; iq[i] is a lazily built array of K values (see gdl_em_hadmix_get_q) */
  	gdl_hnblock  * f;     /* current values indexed (k, locus, allele) */
  	gdl_hnblock  * uf;    /* accumulation buffer for the next f */
  	gdl_block    * q;     /* current values indexed (k, haplotype) */
  	gdl_block    * uq;    /* accumulation buffer for the next q */
  	gdl_block    * z;     /* values indexed (k, haplotype, locus) */
	gdl_hview * hview;    /* data passed to gdl_em_hadmix_alloc; not released by gdl_em_hadmix_free */
	gdl_rng   * rng;      /* generator passed to gdl_em_hadmix_alloc; not released by gdl_em_hadmix_free */
} gdl_em_hadmix_t;

/* Read entry (k, i) of the current q block. */
static double
_gdl_em_hadmix_get_q (const gdl_em_hadmix_t * state, size_t k, size_t i)
{
	const double value = gdl_block_get (state->q, k, i);

	return value;
}

/* Read entry (k, l, a) of the current f block. */
static double
_gdl_em_hadmix_get_f (const gdl_em_hadmix_t * state, size_t k, size_t l, size_t a)
{
	const double value = gdl_hnblock_get (state->f, k, l, a);

	return value;
}


/*
 * Add the contribution x*z to entry (k, l, a) of the f accumulation
 * buffer (state->uf). Always returns 0.
 */
static int
gdl_em_hadmix_update_fallele (gdl_em_hadmix_t * state,
                              size_t k,
                              size_t l,
                              size_t a,
                              double z,
                              double x)
{
	const double increment = x*z;
	const double current   = gdl_hnblock_get (state->uf, k, l, a);

	gdl_hnblock_set (state->uf, k, l, a, current + increment);

	return 0;
}

/*
 * Add z to entry (k, i) of the q accumulation buffer (state->uq).
 * Always returns 0.
 */
static int
gdl_em_hadmix_update_qz (gdl_em_hadmix_t * state,
                         size_t k,
                         size_t i,
                         double z)
{
	const double current = gdl_block_get (state->uq, k, i);

	gdl_block_set (state->uq, k, i, current + z);

	return 0;
}

/*
 * Return the current f value, in population k, of the allele carried by
 * haplotype i at locus l.
 */
static double
gdl_em_hadmix_get_pr (gdl_em_hadmix_t * state,
                      size_t k,
                      size_t i,
                      size_t l)
{
	const gdl_allele * allele =
		gdl_hview_get_hallele (state->hview, gdl_hview_get_haplotype (state->hview, i), l);

	return _gdl_em_hadmix_get_f (state, k, l, allele->idx);
}

/* Store f as entry (k, l, a) of the current f block. */
static void
gdl_em_hadmix_set_f (gdl_em_hadmix_t * state, size_t k, size_t l, size_t a, double f)
{
	gdl_hnblock * target = state->f;

	gdl_hnblock_set (target, k, l, a, f);
}

/* Store q as entry (k, i) of the current q block. */
static void
gdl_em_hadmix_set_q (gdl_em_hadmix_t * state, size_t k, size_t i, double q)
{
	gdl_block * target = state->q;

	gdl_block_set (target, k, i, q);
}

/* Read entry (k, i, l) of the z block. */
static double
gdl_em_hadmix_get_z (const gdl_em_hadmix_t * state, size_t k, size_t i, size_t l)
{
	const double value = gdl_block_get (state->z, k, i, l);

	return value;
}

/* Store z as entry (k, i, l) of the z block. */
static void
gdl_em_hadmix_set_z (gdl_em_hadmix_t * state, size_t k, size_t i, size_t l, double z)
{
	gdl_block * target = state->z;

	gdl_block_set (target, k, i, l, z);
}


/*
 * Prepare a new iteration: zero both accumulation buffers (uq and uf)
 * and reset the two residual sums. Always returns 0.
 */
static int
gdl_em_hadmix_iterate_start (gdl_em_hadmix_t * state)
{
	size_t pop, hap, loc, allele;

	for (pop = 0; pop < state->K; pop++)
	{
		/* q accumulator: one entry per haplotype */
		for (hap = 0; hap < state->n; hap++)
		{
			gdl_block_set (state->uq, pop, hap, 0.);
		}
		/* f accumulator: one entry per (locus, allele) */
		for (loc = 0; loc < state->l; loc++)
		{
			for (allele = 0; allele < state->na[loc]; allele++)
			{
				gdl_hnblock_set (state->uf, pop, loc, allele, 0.);
			}
		}
	}

	state->sq_res  = 0.;
	state->abs_res = 0.;

	return 0;
}

/*
 * Recompute z(k, i, l) for every population k: each entry is set to
 * q(k, i) * f(k, l, allele of haplotype i at l), then the K entries are
 * divided by their sum. Always returns 0.
 *
 * NOTE(review): if every product is zero the sum is zero and the
 * division yields NaN; no guard is applied here.
 */
static int
gdl_em_hadmix_update_z (gdl_em_hadmix_t * state, size_t i, size_t l)
{
	size_t pop;
	double total = 0;

	/* First pass: store the unnormalised products and accumulate their sum. */
	for (pop = 0; pop < state->K; pop++)
	{
		const double allele_f = gdl_em_hadmix_get_pr (state, pop, i, l);
		const double weight   = _gdl_em_hadmix_get_q (state, pop, i);
		const double joint    = weight*allele_f;

		total += joint;
		gdl_em_hadmix_set_z (state, pop, i, l, joint);
	}

	/* Second pass: normalise in place. */
	for (pop = 0; pop < state->K; pop++)
	{
		const double joint = gdl_em_hadmix_get_z (state, pop, i, l);

		gdl_em_hadmix_set_z (state, pop, i, l, joint/total);
	}

	return 0;
}

/*
 * Accumulate the contribution of haplotype i at locus l into the f
 * accumulation buffer: for every population k, the entry of the allele
 * carried by the haplotype receives z(k, i, l) times the haplotype weight.
 *
 * The weight is haplotype_freq(i) * nn * p. It is computed in double
 * precision: the frequency is consumed as a floating-point value elsewhere
 * in this file (log-likelihood, population-level q), so storing it in an
 * integer first would truncate a fractional frequency to zero and cancel
 * every update.
 *
 * Always returns 0.
 */
static int
gdl_em_hadmix_update_f (gdl_em_hadmix_t * state,
                        size_t i,
                        size_t l)
{
	size_t k;
	double weight;
	const gdl_allele * x;

	weight  = gdl_hview_get_haplotype_freq (state->hview, i);
	weight *= (double) (state->nn*state->p);

	/* The allele carried by haplotype i at locus l does not depend on k. */
	x = gdl_hview_get_hallele (state->hview, gdl_hview_get_haplotype (state->hview, i), l);

	for (k = 0; k < state->K; k++)
	{
		double z = gdl_em_hadmix_get_z (state, k, i, l);
		gdl_em_hadmix_update_fallele (state, k, l, x->idx, z, weight);
	}

	return 0;
}

/*
 * Accumulate z(k, i, l) into the q accumulation buffer entry (k, i),
 * for every population k. Always returns 0.
 */
static int
gdl_em_hadmix_update_q (gdl_em_hadmix_t * state, size_t i, size_t l)
{
	size_t pop = 0;

	while (pop < state->K)
	{
		gdl_em_hadmix_update_qz (state, pop, i, gdl_em_hadmix_get_z (state, pop, i, l));
		pop++;
	}

	return 0;
}

/*
 * Exchange the current and accumulation buffers: f <-> uf and q <-> uq.
 * After the call, f/q hold the freshly accumulated values and uf/uq hold
 * the values of the previous iteration. Always returns 0.
 */
static int
gdl_em_hadmix_iterate_end_swap (gdl_em_hadmix_t * state)
{
	gdl_block   * const previous_q = state->q;
	gdl_hnblock * const previous_f = state->f;

	state->q  = state->uq;
	state->uq = previous_q;

	state->f  = state->uf;
	state->uf = previous_f;

	return 0;
}

/*
 * Normalise q so that, for each haplotype, the K values sum to one, and
 * add the change with respect to the previous values (held in uq after
 * the swap) to the absolute and squared residuals.
 *
 * Values whose normalised share is below 1e-20 are set to zero.
 * Always returns 0.
 */
static int
gdl_em_hadmix_iterate_end_q (gdl_em_hadmix_t * state)
{
	size_t hap, pop;

	for (hap = 0; hap < state->n; hap++)
	{
		double total = 0;

		for (pop = 0; pop < state->K; pop++)
		{
			total += _gdl_em_hadmix_get_q (state, pop, hap);
		}

		for (pop = 0; pop < state->K; pop++)
		{
			double raw = _gdl_em_hadmix_get_q (state, pop, hap);
			double delta;

			/* flush negligible shares to exactly zero */
			raw = (raw/total < 1.e-20) ? 0. : raw;

			gdl_em_hadmix_set_q (state, pop, hap, raw/total);

			delta = fabs (raw/total - gdl_block_get (state->uq, pop, hap));
			state->abs_res += delta;
			state->sq_res  += delta*delta;
		}
	}

	return 0;
}

/*
 * Normalise f so that, for each (population, locus), the values over the
 * alleles sum to one, and add the change with respect to the previous
 * values (held in uf after the swap) to the absolute and squared residuals.
 *
 * Values whose normalised share is below 1e-20 are set to zero.
 * Always returns 0.
 */
static int
gdl_em_hadmix_iterate_end_f (gdl_em_hadmix_t * state)
{
	size_t pop, loc, allele;

	for (pop = 0; pop < state->K; pop++)
	{
		for (loc = 0; loc < state->l; loc++)
		{
			const size_t nallele = state->na[loc];
			double total = 0;

			for (allele = 0; allele < nallele; allele++)
			{
				total += _gdl_em_hadmix_get_f (state, pop, loc, allele);
			}

			for (allele = 0; allele < nallele; allele++)
			{
				double raw = _gdl_em_hadmix_get_f (state, pop, loc, allele);
				double delta;

				/* flush negligible shares to exactly zero */
				raw = (raw/total < 1.e-20) ? 0. : raw;

				gdl_em_hadmix_set_f (state, pop, loc, allele, raw/total);

				delta = fabs (raw/total - gdl_hnblock_get (state->uf, pop, loc, allele));
				state->abs_res += delta;
				state->sq_res  += delta*delta;
			}
		}
	}

	return 0;
}

/*
 * Close an iteration: swap the buffers first, so that the accumulated
 * values become current, then normalise q and f (which also updates the
 * residuals). Returns the OR of the three step statuses.
 */
static int
gdl_em_hadmix_iterate_end (gdl_em_hadmix_t * state)
{
	int status = gdl_em_hadmix_iterate_end_swap (state);

	status |= gdl_em_hadmix_iterate_end_q (state);
	status |= gdl_em_hadmix_iterate_end_f (state);

	state->start = 1;

	return status;
}

/*
 * Set up a workspace for K = k populations.
 *
 * vstate : storage for a gdl_em_hadmix_t, provided by the caller.
 * data   : the gdl_hview to analyse (kept by reference, not copied).
 * rng    : random number generator (kept by reference).
 * k      : number of populations.
 *
 * Sizes are read from the haplotype view, then the f/q buffers, their
 * accumulation twins uf/uq, the z block and the per-accession cache iq
 * are allocated. The residuals are zeroed so that they hold a defined
 * value before the first iteration accumulates into them.
 *
 * Returns GDL_SUCCESS.
 */
int
gdl_em_hadmix_alloc (void * vstate, void * data, gdl_rng * rng, size_t k)
{
	size_t i, * tmp;
	gdl_em_hadmix_t * state = (gdl_em_hadmix_t *) vstate;

	state->hview = (gdl_hview *) data;
	state->rng   = rng;

	state->start = 1;
	state->K     = k;

	state->n     = gdl_hview_haplotype_size (state->hview);
	state->nn    = gdl_hview_accession_size (state->hview);
	state->p     = gdl_hview_ploidy (state->hview);
	state->l     = gdl_hview_locus_size (state->hview);

	/* residuals are accumulated with += later on: start from zero */
	state->abs_res = 0.;
	state->sq_res  = 0.;

	/* number of alleles of each locus */
	state->na    = GDL_CALLOC (size_t, state->l);
	for (i = 0; i < state->l; i++)
	{
		gdl_locus * locus = gdl_hview_get_locus (state->hview, i);
		state->na[i]      = gdl_locus_allele (locus);
	}

	/* leading dimensions (K, l) handed to gdl_hnblock_alloc */
	tmp          = GDL_MALLOC (size_t, 2);
	tmp[0]       = state->K;
	tmp[1]       = state->l;
	state->f     = gdl_hnblock_alloc (1, 1, tmp, &state->na);
	state->q     = gdl_block_alloc2 (2, state->K, state->n);
	state->uf    = gdl_hnblock_alloc (1, 1, tmp, &state->na);
	state->uq    = gdl_block_alloc2 (2, state->K, state->n);
	state->z     = gdl_block_alloc2 (3, state->K, state->n, state->l);

	/* one lazily filled slot per accession, see gdl_em_hadmix_get_q */
	state->iq    = GDL_CALLOC (double *, state->nn);

	GDL_FREE (tmp);

	return GDL_SUCCESS;
}

/*
 * Release everything owned by the workspace: the f/uf/q/uq/z blocks, the
 * allele-count array and the per-accession cache (each filled slot, then
 * the slot array itself). The hview and rng are not released here. The
 * state structure itself is not freed.
 *
 * A null vstate is accepted and is a no-op.
 *
 * Returns GDL_SUCCESS.
 */
int
gdl_em_hadmix_free (void * vstate)
{
	size_t i;
	gdl_em_hadmix_t * state = (gdl_em_hadmix_t *) vstate;

	if (state == 0)
	{
		return GDL_SUCCESS;
	}

	gdl_hnblock_free (state->f);
	gdl_block_free (state->q);
	gdl_hnblock_free (state->uf);
	gdl_block_free (state->uq);
	gdl_block_free (state->z);
	GDL_FREE (state->na);

	if (state->iq != 0)
	{
		for (i = 0; i < state->nn; i++)
		{
			GDL_FREE (state->iq[i]);
		}
		/* free the slot array itself, not the element one past its end */
		GDL_FREE (state->iq);
	}

	return GDL_SUCCESS;
}

/*
 * Draw a random starting point.
 *
 * For every (haplotype, locus) the K values of z are drawn uniformly with
 * the workspace generator and normalised to sum to one; their contributions
 * are accumulated into the f and q buffers, which are then swapped in and
 * normalised by gdl_em_hadmix_iterate_end().
 *
 * The accumulation buffers and the residuals are cleared first, so that a
 * (re-)initialisation never mixes in stale or uninitialised values.
 *
 * Returns the OR of the statuses of the individual steps (0 on success).
 */
int
gdl_em_hadmix_init (void * vstate)
{
	int status;
	size_t i, l, k;
	gdl_em_hadmix_t * state = (gdl_em_hadmix_t *) vstate;

	/* start from clean accumulators, exactly like a regular iteration */
	status = gdl_em_hadmix_iterate_start (state);

	// Randomly init z
	for (i = 0; i < state->n; i++)
	{
		for (l = 0; l < state->l; l++)
		{
			double z, s = 0.;
			for (k = 0; k < state->K; k++)
			{
				z = gdl_rng_uniform (state->rng);
				gdl_em_hadmix_set_z (state, k, i, l, z);
				s += z;
			}
			for (k = 0; k < state->K; k++)
			{
				z = gdl_em_hadmix_get_z (state, k, i, l);
				gdl_em_hadmix_set_z (state, k, i, l, z/s);
			}
			status |= gdl_em_hadmix_update_f (state, i, l);
			status |= gdl_em_hadmix_update_q (state, i, l);
		}
	}

	status |= gdl_em_hadmix_iterate_end (state);

	return status;
}

/*
 * Run one full iteration: clear the accumulators, then for every
 * (haplotype, locus) pair recompute z and accumulate its contribution to
 * f and q, and finally swap and normalise the buffers.
 *
 * Returns the OR of the statuses of the individual steps (0 on success).
 */
int
gdl_em_hadmix_iterate (void * vstate)
{
	gdl_em_hadmix_t * state = (gdl_em_hadmix_t *) vstate;
	int status = gdl_em_hadmix_iterate_start (state);
	size_t hap, loc;

	for (hap = 0; hap < state->n; hap++)
	{
		for (loc = 0; loc < state->l; loc++)
		{
			status |= gdl_em_hadmix_update_z (state, hap, loc);
			status |= gdl_em_hadmix_update_f (state, hap, loc);
			status |= gdl_em_hadmix_update_q (state, hap, loc);
		}
	}

	return status | gdl_em_hadmix_iterate_end (state);
}

int
gdl_em_hadmix_init_update (void * vstate, const int * aidx, const int * lidx, const void * db, double (*get_q)(const void * db, size_t k, size_t i), double (*get_f)(const void * db, size_t k, size_t l, size_t a))
{

}
int
gdl_em_hadmix_iterate_update (void * vstate, const int * aidx, const int * lidx)
{
}

/*
 * Log-likelihood of the data under the current q and f.
 *
 * For each (haplotype, locus) the mixture sum over k of q(k, i) * f(k, l,
 * observed allele) is formed; its logarithm is weighted by
 * haplotype_freq(i) * nn * p and the weighted terms are summed.
 */
double
gdl_em_hadmix_loglikelihood (const void * vstate)
{
	gdl_em_hadmix_t * state = (gdl_em_hadmix_t *) vstate;
	const size_t copies = state->nn*state->p;
	double loglik = 0.;
	size_t hap, loc, pop;

	for (hap = 0; hap < state->n; hap++)
	{
		for (loc = 0; loc < state->l; loc++)
		{
			double mixture = 0;

			for (pop = 0; pop < state->K; pop++)
			{
				const double allele_f = gdl_em_hadmix_get_pr (state, pop, hap, loc);
				const double weight   = _gdl_em_hadmix_get_q (state, pop, hap);

				mixture += allele_f*weight;
			}

			loglik += log(mixture)*gdl_hview_get_haplotype_freq (state->hview, hap)*copies;
		}
	}

	return loglik;
}

/* Return the squared residual accumulated by the last completed iteration. */
double
gdl_em_hadmix_residual_sq (const void * vstate)
{
	return ((gdl_em_hadmix_t *) vstate)->sq_res;
}

/* Return the absolute residual accumulated by the last completed iteration. */
double
gdl_em_hadmix_residual_abs (const void * vstate)
{
	return ((gdl_em_hadmix_t *) vstate)->abs_res;
}

/*
 * Population-level value of q for population k: the sum over all
 * haplotypes of q(k, i) weighted by the haplotype frequency.
 */
double
gdl_em_hadmix_get_pop_q (const void * vstate, size_t k)
{
	gdl_em_hadmix_t * state = (gdl_em_hadmix_t *) vstate;
	const size_t count = gdl_hview_haplotype_size (state->hview);
	double total = 0;
	size_t hap;

	for (hap = 0; hap < count; hap++)
	{
		const double weight = gdl_hview_get_haplotype_freq (state->hview, hap);

		total += _gdl_em_hadmix_get_q (state, k, hap)*weight;
	}

	return total;
}

/*
 * Value of q for population k and accession i.
 *
 * The K values of an accession are computed on first request and cached in
 * state->iq[i]: for each haplotype configuration of the accession, and each
 * of its p haplotypes, the haplotype-level q is weighted by the
 * configuration probability; the K sums are then divided by their total.
 *
 * NOTE(review): a filled cache slot is never reset in this file, so values
 * requested before a later iteration are returned unchanged afterwards;
 * confirm callers only query after the run has finished.
 */
double
gdl_em_hadmix_get_q (const void * vstate, size_t k, size_t i)
{
	size_t conf, copy, pop, nconf;
	double total = 0;
	double * mix;
	gdl_em_hadmix_t * state = (gdl_em_hadmix_t *) vstate;

	/* already computed for this accession */
	if (state->iq[i] != 0)
	{
		return state->iq[i][k];
	}

	mix          = GDL_CALLOC (double, state->K);
	state->iq[i] = mix;

	nconf = gdl_hview_hconfig_size (state->hview, i);

	for (conf = 0; conf < nconf; conf++)
	{
		gdl_hconfig * hc   = gdl_hview_get_hconfig (state->hview, i, conf);
		const double proba = gdl_hconfig_get_proba (hc);

		for (copy = 0; copy < state->p; copy++)
		{
			const size_t hap = gdl_hconfig_get_haplotype (hc, copy);

			for (pop = 0; pop < state->K; pop++)
			{
				const double term = proba*_gdl_em_hadmix_get_q (state, pop, hap);

				mix[pop] += term;
				total    += term;
			}
		}
	}

	for (pop = 0; pop < state->K; pop++)
	{
		mix[pop] /= total;
	}

	return mix[k];
}

/* Public accessor: current f value of allele a at locus l in population k. */
double
gdl_em_hadmix_get_f (const void * vstate, size_t k, size_t l, size_t a)
{
	return _gdl_em_hadmix_get_f ((gdl_em_hadmix_t *) vstate, k, l, a);
}

/*
 * Index of the allele with the largest f value at locus loc in population
 * pop. Ties keep the lowest index; 0 is returned when the locus has no
 * allele.
 */
size_t
gdl_em_hadmix_get_f_max (const void * vstate, size_t pop, size_t loc)
{
	gdl_em_hadmix_t * state = (gdl_em_hadmix_t *) vstate;
	const size_t nallele = state->na[loc];
	double best = GDL_NEGINF;
	size_t best_idx = 0;
	size_t allele;

	for (allele = 0; allele < nallele; allele++)
	{
		const double candidate = _gdl_em_hadmix_get_f (state, pop, loc, allele);

		if (candidate > best)
		{
			best_idx = allele;
			best     = candidate;
		}
	}

	return best_idx;
}

/* Number of alleles recorded for locus loc. */
size_t
gdl_em_hadmix_get_f_size (const void * vstate, size_t loc)
{
	return ((gdl_em_hadmix_t *) vstate)->na[loc];
}

/*
 * Read a workspace from a binary stream, in the exact order used by
 * gdl_em_hadmix_fwrite(): K, n, p, l, abs_res, sq_res, the na array, then
 * the f, uf, q, uq and z blocks.
 *
 * Each step is checked with GDL_FREAD_STATUS (presumably it leaves the
 * function with an error when its two arguments differ - confirm against
 * the macro definition).
 *
 * NOTE(review): nn, iq, hview, rng and start are neither stored in the
 * stream nor assigned here; they keep whatever value the state had before
 * the call.
 *
 * Returns GDL_SUCCESS when every step went through.
 */
int
gdl_em_hadmix_fread (FILE * stream, void * vstate)
{
	int status;
	gdl_em_hadmix_t * state = (gdl_em_hadmix_t *) vstate;
	
	/* scalar sizes: one item expected from each fread */
	status = fread (&state->K, sizeof (size_t), 1, stream);
	GDL_FREAD_STATUS (status, 1);
	status = fread (&state->n, sizeof (size_t), 1, stream);
	GDL_FREAD_STATUS (status, 1);
	status = fread (&state->p, sizeof (size_t), 1, stream);
	GDL_FREAD_STATUS (status, 1);
	status = fread (&state->l, sizeof (size_t), 1, stream);
	GDL_FREAD_STATUS (status, 1);
	/* residuals of the last iteration */
	status = fread (&state->abs_res, sizeof (double), 1, stream);
	GDL_FREAD_STATUS (status, 1);
	status = fread (&state->sq_res, sizeof (double), 1, stream);
	GDL_FREAD_STATUS (status, 1);
	/* allele counts, one per locus; the (ptr == 0) test must evaluate to 0 */
	state->na = GDL_ARRAY_FREAD (size_t, state->l, stream);
	GDL_FREAD_STATUS (state->na==0, 0);
	/* parameter blocks and their accumulation twins */
	state->f  = gdl_hnblock_fread (stream);
	GDL_FREAD_STATUS (state->f == 0, 0);
	state->uf = gdl_hnblock_fread (stream);
	GDL_FREAD_STATUS (state->uf == 0, 0);
	state->q  = gdl_block_fread (stream);
	GDL_FREAD_STATUS (state->q == 0, 0);
	state->uq = gdl_block_fread (stream);
	GDL_FREAD_STATUS (state->uq == 0, 0);
	state->z  = gdl_block_fread (stream);
	GDL_FREAD_STATUS (state->z == 0, 0);
	
	return GDL_SUCCESS;
}

/*
 * Write a workspace to a binary stream: K, n, p, l, abs_res, sq_res, the
 * na array, then the f, uf, q, uq and z blocks. gdl_em_hadmix_fread()
 * reads the same sequence back.
 *
 * Each step is checked with GDL_FWRITE_STATUS (presumably it leaves the
 * function with an error when its two arguments differ - confirm against
 * the macro definition).
 *
 * NOTE(review): nn, iq, hview, rng and start are not written.
 *
 * Returns GDL_SUCCESS when every step went through.
 */
int
gdl_em_hadmix_fwrite (FILE * stream, const void * vstate)
{
	int status;
	gdl_em_hadmix_t * state = (gdl_em_hadmix_t *) vstate;
	
	/* scalar sizes: one item expected from each fwrite */
	status = fwrite (&state->K, sizeof (size_t), 1, stream);
	GDL_FWRITE_STATUS (status, 1);
	status = fwrite (&state->n, sizeof (size_t), 1, stream);
	GDL_FWRITE_STATUS (status, 1);
	status = fwrite (&state->p, sizeof (size_t), 1, stream);
	GDL_FWRITE_STATUS (status, 1);
	status = fwrite (&state->l, sizeof (size_t), 1, stream);
	GDL_FWRITE_STATUS (status, 1);
	/* residuals of the last iteration */
	status = fwrite (&state->abs_res, sizeof (double), 1, stream);
	GDL_FWRITE_STATUS (status, 1);
	status = fwrite (&state->sq_res, sizeof (double), 1, stream);
	GDL_FWRITE_STATUS (status, 1);
	/* allele counts, one per locus */
	status = GDL_ARRAY_FWRITE (state->na, size_t, state->l, stream);
	GDL_FWRITE_STATUS (status, GDL_SUCCESS);
	/* parameter blocks and their accumulation twins */
	status = gdl_hnblock_fwrite (stream, state->f);
	GDL_FWRITE_STATUS (status, GDL_SUCCESS);
	status = gdl_hnblock_fwrite (stream, state->uf);
	GDL_FWRITE_STATUS (status, GDL_SUCCESS);
	status  = gdl_block_fwrite (stream, state->q);
	GDL_FWRITE_STATUS (status, GDL_SUCCESS);
	status = gdl_block_fwrite (stream, state->uq);
	GDL_FWRITE_STATUS (status, GDL_SUCCESS);
	status  = gdl_block_fwrite (stream, state->z);
	GDL_FWRITE_STATUS (status, GDL_SUCCESS);
	
	return GDL_SUCCESS;	
}


/*
 * Descriptor of this workspace: a name, the size of the state structure
 * and the functions defined in this file. The initializer is positional,
 * so the order of the entries must follow the field order of
 * gdl_pstruct_workspace_type (declared elsewhere).
 */
static const gdl_pstruct_workspace_type _hview =
{
	"gdl_pstruct_em_hadmixture",
	sizeof (gdl_em_hadmix_t),
	&gdl_em_hadmix_alloc,
	&gdl_em_hadmix_free,
	&gdl_em_hadmix_init,
	&gdl_em_hadmix_init_update,
	&gdl_em_hadmix_iterate,
	&gdl_em_hadmix_iterate_update,
	NULL, /* NOTE(review): hook left unset; which field this is must be checked against the type declaration */
	&gdl_em_hadmix_loglikelihood,
	&gdl_em_hadmix_residual_abs,
	&gdl_em_hadmix_residual_sq,
	&gdl_em_hadmix_get_pop_q,
	&gdl_em_hadmix_get_q,
	&gdl_em_hadmix_get_f,
	&gdl_em_hadmix_get_f_max,
	&gdl_em_hadmix_get_f_size
};

/* Public handle on the descriptor above. */
const gdl_pstruct_workspace_type * gdl_pstruct_em_hadmixture = &_hview;
