/*  multifit/logit.c
 * 
 * Copyright (C) 2007 Jean-Baptiste Veyrieras
 * 
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or (at
 * your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301, USA.
 */
#include <math.h>
#include <stdlib.h>

#include <gdl/gdl_common.h>
#include <gdl/gdl_types.h>
#include <gdl/gdl_math.h>
#include <gdl/gdl_vector.h>
#include <gdl/gdl_matrix.h>
#include <gdl/gdl_multifit_nlin.h>

static void
gdl_multifit_logistic_update_g (const double * X,
                                const size_t * y,
                                const double * w,
                                double * g,
                                const size_t N,
                                const size_t M)
{
	size_t i, j;
	double z, yy;
	
	for (j = 0; j < M; g[j] = 0, j++);
	
	for (i = 0; i < N; i++)
	{
	 	for (z = 0, j = 0; j < M; z += w[j]*X[i*M + j], j++);
	 	yy = (y[i]) ? 1 : -1;
	 	z  = 1.0/(1+exp(-yy*z));
	 	z  = 1-z;
	 	for (j = 0; j < M; g[j] += z*yy*X[i*M + j], j++);
	}
	//for (j = 0; j < M; j++) printf ("g %d %g\n", j, g[j]);
}

static void
gdl_multifit_logistic_update_w (const double * X,
                                const double * g,
                                const double * u,
                                const double * wold, 
                                double * wnew,
                                const size_t N,
                                const size_t M)
{                              
   size_t i, j;
   double d, n, z, t;
   
   n = d = 0;
   for (i = 0; i < N; i++)
   {
   	z = t = 0;
   	for (j = 0; j < M; j++)
   	{
   		z += wold[j]*X[i*M + j];
   		t += u[j]*X[i*M + j];
   		if (!i) n += g[j]*u[j];
   	}	
   	z = 1.0/(1+exp(-z));
   	z *= (1-z);
   	t *= t;
   	d += z*t;
   }
   n /= d;
   for (i = 0; i < M; i++)
   {
   	wnew[i] = wold[i] + n*u[i];
   	//printf ("w %d %g\n", i, wnew[i]);
   }
}

static void
gdl_multifit_logistic_update_u (const double * gold,
                                const double * g,
                                const double * uold, 
                                double * u,
                                const size_t M)
{                              
   size_t i, j;
 	double n, d, z, beta;
 	
 	// Hestenes-Stiefel formula
 	n=d=0;
 	for (i = 0; i < M; i++)
 	{
 		z = g[i]-gold[i];
 		n += g[i]*z;
 		d += uold[i]*z;
 	}
 	beta = n/d;
 	n = 0;
 	for (i = 0; i < M; i++)
 	{
 		z = g[i]-uold[i]*beta;
 		u[i] = z;
 		n += z*z; 		
 	}
 	n = sqrt(n);
 	for (d = i = 0; i < M; i++)
 	{
 		u[i] /= n;
 		d += u[i]*u[i];
 		//printf ("u %d %g\n", i, u[i]);
 	}
 	//printf ("norm = %g\n", d);
 	// normalize u...
}

double
gdl_multifit_logistic_loglikelihood (const double * X,
                                     const size_t * y,
                                     const double * w,
                                     const size_t N,
                                     const size_t M)
{
	size_t i, j;
	double z, l=0;
	
	for (i = 0; i < N; i++)
	{
		z = 0;
		for (j = 0; j < M; j++)
		{
			z += w[j]*X[i*M+j];
		}
		z *= (y[i]) ? 1 : -1;
		l -= log (1 + exp(-z));
	}
	
	return l;
}                                     

/*
 * Fit logistic-regression weights by iterating the gradient / direction /
 * weight updates defined above.
 *
 *   X              design matrix; X->size1 rows (observations) by
 *                  X->size2 columns (coefficients)
 *   y              one label per row of X
 *   w              output: receives a copy of the final weights
 *   loglikelihood  output, may be NULL: log-likelihood at the returned w
 *   work           workspace from gdl_multifit_logistic_alloc(); its
 *                  max_iter and epsilon fields control termination
 *
 * Starts from zero weights and a uniform unit direction, takes one initial
 * step, then iterates at most work->max_iter times.  Iteration stops early
 * once (old - new)/old drops below work->epsilon.
 *
 * Returns 0.
 *
 * The function writes the input data, the per-iteration log-likelihood and
 * the final weights to stdout.
 *
 * NOTE(review): the helpers receive X->data and y->data as flat arrays.
 * This presumably requires X to be stored contiguously with row stride
 * X->size2, and y->data to have element type size_t - confirm against the
 * gdl_matrix / gdl_vector_uint definitions.
 */
int
gdl_multifit_logistic (const gdl_matrix      * X,
                       const gdl_vector_uint * y,
                       gdl_vector * w,
                       double * loglikelihood,
                       gdl_multifit_logistic_workspace * work)
{
	const size_t N = X->size1;
	const size_t M = X->size2;
	gdl_vector * tmp;
	size_t i, j;
	double ll, cur, rel;
	
	for (i = 0; i < N; i++)
	{
		printf ("%u", (unsigned int) gdl_vector_uint_get (y, i));
		for (j = 0; j < M; j++)
		{
			printf (" %g", gdl_matrix_get (X, i, j));	
		}
		printf ("\n");
	}

#define SWAP(u,v) do { tmp = (u); (u) = (v); (v) = tmp; } while (0)
	
	gdl_vector_set_zero (work->wold);
	gdl_vector_set_all (work->unew, 1.0/sqrt((double)M));
	
	gdl_multifit_logistic_update_g (X->data, y->data, work->wold->data, work->gnew->data, N, M);
	gdl_multifit_logistic_update_w (X->data, work->gnew->data, work->unew->data, work->wold->data, work->wnew->data, N, M);
	
	/* ll always holds the log-likelihood of the weights in work->wnew */
	ll = gdl_multifit_logistic_loglikelihood (X->data, y->data, work->wnew->data, N, M);
	printf ("L(0) %g\n", ll);
	for (i = 0; i < work->max_iter; i++)
	{
		// gold <- gnew
		SWAP(work->gold, work->gnew);
		gdl_multifit_logistic_update_g (X->data, y->data, work->wnew->data, work->gnew->data, N, M);
		// uold <- unew 
		SWAP(work->uold, work->unew);
		gdl_multifit_logistic_update_u (work->gold->data, work->gnew->data, work->uold->data, work->unew->data, M);
		// wold <- wnew
		SWAP(work->wold, work->wnew);
		gdl_multifit_logistic_update_w (X->data, work->gnew->data, work->unew->data, work->wold->data, work->wnew->data, N, M);
		cur = gdl_multifit_logistic_loglikelihood (X->data, y->data, work->wnew->data, N, M);
		rel = (ll-cur)/ll;
		printf ("L(%zu) %g %e\n", i+1, cur, rel);
		ll = cur;
		if (rel < work->epsilon)
		{
			break;	
		}
	}
	
#undef SWAP

	gdl_vector_memcpy (w, work->wnew);
	
	for (i = 0; i < M; i++)
	{
		printf ("W[%zu] %g\n", i, gdl_vector_get (w, i));	
	}
	
	if (loglikelihood)
	{
		*loglikelihood = ll;
	}
	
	return 0;
}

/*
 * Allocate a workspace for gdl_multifit_logistic() with M coefficients.
 *
 * The workspace holds six vectors of length M (old/new search direction,
 * gradient and weights).  max_iter defaults to 25 and epsilon to 1e-5.
 *
 * Returns the new workspace, or NULL if any allocation fails; on failure
 * everything allocated so far is released.  Release a successful result
 * with gdl_multifit_logistic_free().
 */
gdl_multifit_logistic_workspace *
gdl_multifit_logistic_alloc (const size_t M)
{
	gdl_multifit_logistic_workspace * w;
	
	w = GDL_CALLOC (gdl_multifit_logistic_workspace, 1);
	if (w == NULL)
	{
		return NULL;
	}
	
	w->unew = gdl_vector_alloc (M);
	w->uold = gdl_vector_alloc (M);
	w->gnew = gdl_vector_alloc (M);
	w->gold = gdl_vector_alloc (M);
	w->wold = gdl_vector_alloc (M);
	w->wnew = gdl_vector_alloc (M);
	
	if (!w->unew || !w->uold || !w->gnew || !w->gold || !w->wold || !w->wnew)
	{
		/* never hand back a partially built workspace */
		if (w->unew) gdl_vector_free (w->unew);
		if (w->uold) gdl_vector_free (w->uold);
		if (w->gnew) gdl_vector_free (w->gnew);
		if (w->gold) gdl_vector_free (w->gold);
		if (w->wold) gdl_vector_free (w->wold);
		if (w->wnew) gdl_vector_free (w->wnew);
		free (w);
		return NULL;
	}
	
	w->max_iter = 25;
	w->epsilon  = 1.e-5;
	
	return w;
}

/*
 * Release a workspace obtained from gdl_multifit_logistic_alloc():
 * the six work vectors and the workspace structure itself.
 * Passing NULL is a no-op.  The pointer must not be used afterwards.
 */
void
gdl_multifit_logistic_free (gdl_multifit_logistic_workspace * w)
{
	if (w == NULL)
	{
		return;
	}

	gdl_vector_free (w->unew);
	gdl_vector_free (w->uold);
	gdl_vector_free (w->gnew);
	gdl_vector_free (w->gold);
	gdl_vector_free (w->wnew);
	gdl_vector_free (w->wold);

	/* the structure was allocated separately from its vectors */
	free (w);
}

