/*  
 * 	phase/em_hv.c
 * 
 *  $Author: baptiste $, $Date: 2008-05-13 15:33:53 $, $Version$
 *
 *  Libgdl : a C library for statistical genetics
 * 
 *  Copyright (C) 2003-2006  Jean-Baptiste Veyrieras, INRA, France.
 *
 *  This program is free software; you can redistribute it and/or modify
 *  it under the terms of the GNU General Public License as published by
 *  the Free Software Foundation; either version 2 of the License, or
 *  (at your option) any later version.
 *
 *  This program is distributed in the hope that it will be useful,
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *  GNU General Public License for more details.
 *
 *  You should have received a copy of the GNU General Public License
 *  along with this program; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA * 
 */
#include <float.h>

#include <gdl/gdl_common.h>
#include <gdl/gdl_errno.h>
#include <gdl/gdl_rng.h>
#include <gdl/gdl_randist.h>
#include <gdl/gdl_gview.h>
#include <gdl/gdl_hview.h>
#include <gdl/gdl_view.h>
#include <gdl/gdl_phase_em.h>

/*
 * Private state of the hview-based EM workspace.
 *
 * "Ambiguous" below means an accession of the haplotype view that has
 * more than one haplotype configuration (hconfig).
 */
typedef struct
{
	/* number of ambiguous accessions, i.e. length of zidx and ztot */
	size_t nz;
	/* ploidy of the genotype view: haplotypes read per hconfig */
	size_t np;
	/* genotype view obtained from the parameters (not freed here) */
	const gdl_gview * g;
	/* mask taken from the parameters */
	const gdl_mask * m;
	/* random number generator taken from the parameters */
	const gdl_rng   * r;
	/* haplotype view being estimated; handed back by the result accessor */
	gdl_hview       * h;
	/* sum of |old - new| haplotype frequencies over the last M step */
	double res_abs;
	/* sum of (old - new)^2 haplotype frequencies over the last M step */
	double res_sq;
	/* copied from the parameters; not read anywhere else in this file */
	double pmt;
	/* accession index of each ambiguous accession */
	size_t * zidx;
	/* per ambiguous accession: sum of the unnormalized hconfig weights (E step) */
	double * ztot;
	/* per haplotype: expected count accumulator used by the M step */
	double * uf;
} gdl_phase_em_hview_t;

int
gdl_phase_em_hview_alloc (void * vstate, const gdl_phase_em_param * P)
{
	gdl_phase_em_hview_t * state;
	
	state = (gdl_phase_em_hview_t *) vstate;
	
	state->g = gdl_view_get_gview (P->v);
	
	if (state->g == 0)
	{
		return GDL_FAILURE;	
	}
	
	state->np = gdl_gview_ploidy (state->g);
	
	state->m = P->m;
	
	state->r = P->r;
	
	state->pmt = P->pmt;
	
	if (state->r == 0)
	{
		return GDL_FAILURE;	
	}
	
	state->uf   = NULL;
	state->zidx = NULL;
	state->ztot = NULL;
	
	return GDL_SUCCESS;
}

/*
 * Releases the scratch buffers (zidx, ztot, uf) owned by the state.
 *
 * The haplotype view state->h is not released here; it is what the
 * result accessor hands back to the caller.
 *
 * A NULL vstate is treated as "nothing to free".
 *
 * Always returns GDL_SUCCESS.
 */
int
gdl_phase_em_hview_free (void * vstate)
{
	gdl_phase_em_hview_t * state = (gdl_phase_em_hview_t *) vstate;
	
	if (state == 0)
	{
		/* Previously control fell off the end of the function here,
		 * leaving the returned status indeterminate. */
		return GDL_SUCCESS;
	}
	
	GDL_FREE (state->zidx);
	GDL_FREE (state->ztot);
	GDL_FREE (state->uf);
	
	return GDL_SUCCESS;
}

/*
 * Random starting point: for every ambiguous accession, draw one of
 * its hconfigs using the current hconfig probabilities as weights,
 * then give the drawn hconfig probability 1 and all the others 0.
 */
static void
gdl_phase_em_hview_rng_init (gdl_phase_em_hview_t * state)
{
	size_t zi;
	
	for (zi = 0; zi < state->nz; zi++)
	{
		const size_t acc   = state->zidx[zi];
		const size_t nconf = gdl_hview_hconfig_size_c (state->h, acc);
		double * weights   = GDL_MALLOC (double, nconf);
		gdl_ran_discrete_t * sampler;
		gdl_hconfig * conf;
		size_t c, drawn;
		
		/* Collect the current probabilities as sampling weights. */
		for (c = 0; c < nconf; c++)
		{
			conf = gdl_hview_get_hconfig_c (state->h, acc, c);
			weights[c] = gdl_hconfig_get_proba (conf);
		}
		
		sampler = gdl_ran_discrete_preproc (nconf, weights);
		drawn   = gdl_ran_discrete (state->r, sampler);
		
		/* Collapse the distribution onto the drawn configuration. */
		for (c = 0; c < nconf; c++)
		{
			conf = gdl_hview_get_hconfig_c (state->h, acc, c);
			if (c == drawn)
			{
				gdl_hconfig_set_proba (conf, 1.0);
			}
			else
			{
				gdl_hconfig_set_proba (conf, 0.0);
			}
		}
		
		gdl_ran_discrete_free (sampler);
		GDL_FREE (weights);
	}
}

/*
 * (Re)builds the scratch buffers for the haplotype view state->h:
 *
 *   uf   : one accumulator per haplotype, zero-filled
 *   zidx : accession index of every accession with more than one hconfig
 *   ztot : one slot per entry of zidx (filled by the E step)
 *
 * state->nz is set to the number of entries in zidx/ztot.
 *
 * Returns GDL_SUCCESS, or GDL_FAILURE when a non-empty buffer could
 * not be allocated; in that case state->nz is left at 0.
 */
static int
gdl_phase_em_hview_mem_init (gdl_phase_em_hview_t * state)
{
	size_t i, j, na, n, nh;
	
	GDL_FREE (state->uf);
	GDL_FREE (state->zidx);
	GDL_FREE (state->ztot);
	
	/* Make the "nothing allocated" state explicit so that an early
	 * failure below cannot leave stale pointers or a stale count. */
	state->uf   = NULL;
	state->zidx = NULL;
	state->ztot = NULL;
	state->nz   = 0;
		
	n = gdl_hview_haplotype_size (state->h);
	
	state->uf = GDL_MALLOC (double, n);
	
	if (n > 0 && state->uf == NULL)
	{
		return GDL_FAILURE;
	}
	
	/* The M step accumulates into uf with += and only resets it
	 * afterwards, so the very first iteration needs zeros here. */
	for (i = 0; i < n; i++)
	{
		state->uf[i] = 0.;
	}
	
	na = gdl_hview_accession_size_c (state->h);
	
	/* First pass: count the accessions with more than one hconfig. */
	for (n = i = 0; i < na; i++)
	{
		nh = gdl_hview_hconfig_size_c (state->h, i);
		if (nh > 1)
		{
			n++;
		}
	}
	
	state->zidx = GDL_MALLOC (size_t, n);
	state->ztot = GDL_MALLOC (double, n);
	
	if (n > 0 && (state->zidx == NULL || state->ztot == NULL))
	{
		return GDL_FAILURE;
	}
	
	state->nz = n;
	
	/* Second pass: record their indices, stopping once all are found. */
	for (j = i = 0; i < na; i++)
	{
		nh = gdl_hview_hconfig_size_c (state->h, i);
		if (nh > 1)
		{
			state->zidx[j++]=i;
			if (j == n)
			{
				break;	
			}
		}
	}
	
	return GDL_SUCCESS;
}

/*
 * Initializes the workspace from a random starting point.
 *
 * Allocates and creates a haplotype view from the stored genotype view
 * and mask, builds the scratch buffers, then randomly picks one
 * hconfig per ambiguous accession.
 *
 * Returns GDL_SUCCESS when every step succeeded. Returns GDL_FAILURE
 * when the haplotype view could not be allocated, and otherwise the
 * status of the first step that failed.
 */
int
gdl_phase_em_hview_init_rng (void * vstate)
{
	int status;
	gdl_phase_em_hview_t * state 
		    = (gdl_phase_em_hview_t *) vstate;
	
	state->h = gdl_hview_alloc (state->g, state->m);
	
	if (state->h == 0)
	{
		return GDL_FAILURE;
	}
	
	status = gdl_hview_create (state->h);
	
	if (status != GDL_SUCCESS)
	{
		return status;
	}
	
	status = gdl_phase_em_hview_mem_init (state);
	
	if (status != GDL_SUCCESS)
	{
		return status;
	}
	
	gdl_phase_em_hview_rng_init (state);
	
	/* This function used to report GDL_FAILURE unconditionally,
	 * even after a fully successful initialization. */
	return GDL_SUCCESS;
}

/*
 * Initializes the workspace from a caller-supplied haplotype view.
 *
 *   start : a gdl_hview; it must be non-null and must refer to the
 *           same genotype view and the same mask as the state.
 *
 * Note that state->h is assigned before the checks, so it keeps
 * pointing at start even when GDL_FAILURE is returned.
 *
 * Returns GDL_FAILURE when start is rejected, otherwise the status of
 * the scratch buffer initialization.
 */
int
gdl_phase_em_hview_init_static (void * vstate, void * start)
{
	gdl_phase_em_hview_t * state = (gdl_phase_em_hview_t *) vstate;
	
	state->h = (gdl_hview *) start;
	
	/* Short-circuit keeps the original order: null check first,
	 * then the genotype view, then the mask. */
	if (state->h == 0
	    || gdl_hview_get_gview (state->h) != state->g
	    || gdl_hview_get_gmask (state->h) != state->m)
	{
		return GDL_FAILURE;
	}
	
	return gdl_phase_em_hview_mem_init (state);
}

/*
 * E step.
 *
 * For every ambiguous accession, each hconfig receives as (still
 * unnormalized) probability the product of the current frequencies of
 * its state->np haplotypes. Products that do not exceed DBL_MIN are
 * stored as 0 and are left out of the accession total kept in ztot.
 */
static void
Estep (gdl_phase_em_hview_t * state)
{
	size_t zi;
	
	for (zi = 0; zi < state->nz; zi++)
	{
		const size_t acc   = state->zidx[zi];
		const size_t nconf = gdl_hview_hconfig_size_c (state->h, acc);
		size_t c;
		
		state->ztot[zi] = 0;
		
		for (c = 0; c < nconf; c++)
		{
			gdl_hconfig * conf = gdl_hview_get_hconfig_c (state->h, acc, c);
			double weight = 1.;
			size_t p;
			
			for (p = 0; p < state->np; p++)
			{
				size_t hap = gdl_hconfig_get_haplotype (conf, p);
				weight *= gdl_hview_get_haplotype_freq (state->h, hap);
			}
			
			if (!(weight > DBL_MIN))
			{
				/* Negligible weight: drop this configuration. */
				gdl_hconfig_set_proba (conf, 0.0);
				continue;
			}
			
			state->ztot[zi] += weight;
			gdl_hconfig_set_proba (conf, weight);
		}
	}
}

/*
 * M step.
 *
 * 1. Normalizes the hconfig probabilities of each ambiguous accession
 *    by the total computed in the E step, and accumulates expected
 *    haplotype counts into state->uf, weighted by the accession
 *    multiplicity. Accessions with at most one hconfig contribute
 *    their multiplicity to each haplotype of hconfig 0.
 * 2. Replaces every haplotype frequency by its share of the total
 *    count (values not above DBL_MIN become 0), adds the absolute and
 *    squared change to state->res_abs / state->res_sq, and resets the
 *    accumulator for the next iteration.
 *
 * The residuals are only added to here; the caller resets them.
 *
 * Always returns GDL_SUCCESS.
 */
static int
Mstep (gdl_phase_em_hview_t * state)
{
	const size_t nacc = gdl_hview_accession_size_c (state->h);
	size_t a, x, nhap;
	size_t zi = 0;      /* position in state->ztot of the current ambiguous accession */
	double total = 0;   /* sum of everything added to state->uf */
	
	for (a = 0; a < nacc; a++)
	{
		const size_t nconf = gdl_hview_hconfig_size_c (state->h, a);
		const size_t mult  = gdl_hview_accession_mult_c (state->h, a);
		size_t c, p, hap;
		
		if (!(nconf > 1))
		{
			/* Single configuration: counted with full weight. */
			gdl_hconfig * only = gdl_hview_get_hconfig_c (state->h, a, 0);
			
			for (p = 0; p < state->np; p++)
			{
				hap = gdl_hconfig_get_haplotype (only, p);
				state->uf[hap] += (double)mult;
				total += (double)mult;
			}
			continue;
		}
		
		for (c = 0; c < nconf; c++)
		{
			gdl_hconfig * conf = gdl_hview_get_hconfig_c (state->h, a, c);
			double post = gdl_hconfig_get_proba (conf);
			
			post /= state->ztot[zi];
			gdl_hconfig_set_proba (conf, post);
			
			if (post > DBL_MIN)
			{
				const double share = post*mult;
				
				for (p = 0; p < state->np; p++)
				{
					hap = gdl_hconfig_get_haplotype (conf, p);
					state->uf[hap] += share;
					total += share;
				}
			}
		}
		zi++;
	}
	
	nhap = gdl_hview_haplotype_size (state->h);
	
	for (x = 0; x < nhap; x++)
	{
		const double previous = gdl_hview_get_haplotype_freq (state->h, x);
		double updated = state->uf[x] / total;
		double delta;
		
		if (updated <= DBL_MIN)
		{
			updated = 0.;
		}
		
		delta = previous - updated;
		state->res_abs += fabs (delta);
		state->res_sq  += GDL_SQR (delta);
		
		gdl_hview_set_haplotype_freq (state->h, x, updated);
		state->uf[x] = 0.;
	}
	
	return GDL_SUCCESS;
}

/*
 * Runs one EM iteration: E step, reset of both residuals, M step.
 *
 * Returns the status of the M step.
 */
int 
gdl_phase_em_hview_iterate (void * vstate)
{
	gdl_phase_em_hview_t * state = (gdl_phase_em_hview_t *) vstate;
	
	Estep (state);
	
	/* The M step only adds to the residuals, so clear them first. */
	state->res_sq  = 0;
	state->res_abs = 0;
	
	return Mstep (state);
}

/*
 * Log-likelihood figure of the current state.
 *
 * For each accession, sums over its hconfigs the hconfig probability
 * times the product of the frequencies of its state->np haplotypes,
 * takes the logarithm of that sum and weights it by the accession
 * multiplicity. Returns the total over all accessions.
 */
double
gdl_phase_em_hview_loglikelihood (void * vstate)
{
	gdl_phase_em_hview_t * state = (gdl_phase_em_hview_t *) vstate;
	const size_t nacc = gdl_hview_accession_size_c (state->h);
	double loglik = 0;
	size_t a;
	
	for (a = 0; a < nacc; a++)
	{
		const size_t mult  = gdl_hview_accession_mult_c (state->h, a);
		const size_t nconf = gdl_hview_hconfig_size_c (state->h, a);
		double mixture = 0;
		size_t c;
		
		for (c = 0; c < nconf; c++)
		{
			gdl_hconfig * conf = gdl_hview_get_hconfig_c (state->h, a, c);
			const double post  = gdl_hconfig_get_proba (conf);
			double joint = 1.0;
			size_t p;
			
			for (p = 0; p < state->np; p++)
			{
				size_t hap = gdl_hconfig_get_haplotype (conf, p);
				joint *= gdl_hview_get_haplotype_freq (state->h, hap);
			}
			
			mixture += post*joint;
		}
		
		loglik += mult*log(mixture);
	}
	
	return loglik;
}

/*
 * Returns the sum of absolute haplotype frequency changes recorded
 * in the state (state->res_abs).
 */
double
gdl_phase_em_hview_residual_abs (void * vstate)
{
	return ((gdl_phase_em_hview_t *) vstate)->res_abs;
}

/*
 * Returns the sum of squared haplotype frequency changes recorded
 * in the state (state->res_sq).
 */
double
gdl_phase_em_hview_residual_sq (void * vstate)
{
	return ((gdl_phase_em_hview_t *) vstate)->res_sq;
}

/*
 * Returns the haplotype view held by the state (a gdl_hview pointer
 * returned as void *).
 */
void *
gdl_phase_em_hview_result  (void * vstate)
{
	return ((gdl_phase_em_hview_t *) vstate)->h;
}

/*
 * Reading a workspace state from a stream is not implemented.
 *
 * Returns GDL_FAILURE so that a caller checking the status cannot
 * mistake this no-op for a successful read. The function previously
 * had an empty body, which left its return value indeterminate.
 */
int
gdl_phase_em_hview_fread  (FILE * stream, void * vstate)
{
	(void) stream;
	(void) vstate;
	
	return GDL_FAILURE;
}

/*
 * Writing a workspace state to a stream is not implemented.
 *
 * Returns GDL_FAILURE so that a caller checking the status cannot
 * mistake this no-op for a successful write. The function previously
 * had an empty body, which left its return value indeterminate.
 */
int
gdl_phase_em_hview_fwrite (FILE * stream, const void * vstate)
{
	(void) stream;
	(void) vstate;
	
	return GDL_FAILURE;
}

/*
 * Workspace type descriptor for the hview-based EM implementation.
 *
 * The initializer is positional: a name string, the size of the
 * private state, then the function entry points of this file. The
 * order has to match the field order of gdl_phase_em_workspace_type
 * (declared outside this file, presumably in gdl/gdl_phase_em.h).
 */
static const gdl_phase_em_workspace_type _em_hview =
{
    "gdl_phase_em_hview",
    sizeof (gdl_phase_em_hview_t),
    &gdl_phase_em_hview_alloc,
    &gdl_phase_em_hview_free,
    &gdl_phase_em_hview_init_rng,
    &gdl_phase_em_hview_init_static,
    &gdl_phase_em_hview_iterate,
    &gdl_phase_em_hview_loglikelihood,
    &gdl_phase_em_hview_residual_abs,
    &gdl_phase_em_hview_residual_sq,
    &gdl_phase_em_hview_result,
    &gdl_phase_em_hview_fread,
    &gdl_phase_em_hview_fwrite
};

/* Exported handle on the descriptor above. */
const gdl_phase_em_workspace_type * gdl_phase_em_hview = &_em_hview;
