#include <stdio.h>
#include <stdlib.h>

#include "../swsse2.h" 
#include "swstriped.h"

/* Lane-wise saturating 16-bit arithmetic helpers (defined below in this
 * file): *result receives a - b (spu_subs) or a + b (spu_adds), with any lane
 * that overflows clamped to 0x7FFF or 0x8000. */
void spu_subs(vector signed short *result, vector signed short a, vector signed short b);
void spu_adds(vector signed short *result, vector signed short a, vector signed short b);

/* Striped Smith-Waterman score using eight 16-bit lanes per vector.
 * Parameter names match the definition below. */
int
swStripedWord (unsigned char   *querySeq,
               int              queryLength,
               unsigned char   *dbSeq,
               int              dbLength,
               unsigned short   gapOpen,
               unsigned short   gapExtend,
               vector signed short         *pvQueryProf,
               vector signed short         *pvHLoad,
               vector signed short         *pvHStore,
               vector signed short         *pvE);


/* NOTE(review): no definition of this function is visible here; presumably
 * the 8-bit lane counterpart of swStripedWord that takes the byte profile and
 * the matrix bias computed by swStripedInit -- confirm against its definition. */
int
swStripedByte (unsigned char   *querySeq,
               int              queryLength,
               unsigned char   *dbSeq,
               int              dbLength,
               unsigned short   gapOpen,
               unsigned short   gapExtend,
               vector unsigned short         *queryProf,
               vector unsigned short         *pvH1,
               vector unsigned short         *pvH2,
               vector unsigned short         *pvE,
               unsigned short   bias);

//void *
//swStripedInit(unsigned char   *querySeq,
//              int              queryLength,
//              signed char     *matrix)
//{
//	SwStripedData *pSwData;
//    return pSwData;
//}
/*
 * Build the striped query profiles and scratch rows used by the SPU
 * Smith-Waterman kernels.
 *
 * querySeq    : query residues; each one is used as a column index into a
 *               row of matrix.
 * queryLength : number of residues in querySeq.
 * matrix      : ALPHA_SIZE x ALPHA_SIZE substitution scores; row i holds the
 *               scores of residue i against every residue.
 *
 * Returns a heap-allocated SwStripedData (release it with swStripedComplete):
 *   pvbQueryProf     - byte profile, lenQryByte vectors per residue, every
 *                      entry offset by -bias so it is non-negative;
 *   pvsQueryProf     - short profile, lenQryShort vectors per residue, raw
 *                      matrix scores;
 *   pvH1, pvH2, pvE  - lenQryShort-vector scratch rows;
 *   bias             - the negated offset applied to the byte profile.
 * The process exits if memory cannot be allocated.
 */
void *
swStripedInit(unsigned char   *querySeq,
              int              queryLength,
              signed char     *matrix)
{
    int i, j, k;

    int segSize;
    int nCount;

    int bias;

    int lenQryByte;
    int lenQryShort;

    int weight;

    short *ps;
    unsigned char *pc;

    signed char *matrixRow;

    size_t aligned;

    SwStripedData *pSwData;

    /* number of 16-lane (byte) and 8-lane (short) vectors per query column */
    lenQryByte = (queryLength + 15) / 16;
    lenQryShort = (queryLength + 7) / 8;

    pSwData = malloc (sizeof *pSwData);
    if (!pSwData) {
        fprintf (stderr, "Unable to allocate memory for SW data\n");
        exit (-1);
    }

    /* buffer size, counted in vectors */
    nCount = 64 +                             /* slack bytes */
             lenQryByte * ALPHA_SIZE +        /* query profile byte */
             lenQryShort * ALPHA_SIZE +       /* query profile short */
             (lenQryShort * 3);               /* vH1, vH2 and vE */

    pSwData->pData = (unsigned char *) calloc (nCount, sizeof (vector unsigned short));
    if (!pSwData->pData) {
        fprintf (stderr, "Unable to allocate memory for SW data buffers\n");
        free (pSwData);
        exit (-1);
    }

    /* since we might port this to another platform, let's align the data */
    /* to 16 byte boundaries ourselves (the slack above covers the skip) */
    aligned = ((size_t) pSwData->pData + 15) & ~(size_t) 0x0f;

    pSwData->pvbQueryProf = (vector unsigned short *) aligned;
    pSwData->pvsQueryProf = pSwData->pvbQueryProf + lenQryByte * ALPHA_SIZE;

    pSwData->pvH1 = pSwData->pvsQueryProf + lenQryShort * ALPHA_SIZE;
    pSwData->pvH2 = pSwData->pvH1 + lenQryShort;
    pSwData->pvE  = pSwData->pvH2 + lenQryShort;

    /* The scoring profile is parallel to the query but is accessed in a
     * striped pattern.  The query is divided into equal-length segments, one
     * per element of the SIMD register: 16 segments for 8-bit calculations,
     * 8 for 16-bit ones.  If the query is not long enough to fill the last
     * segment, the remainder is filled with neutral weights.  The first
     * element of each vector holds a value from the first segment, the second
     * element a value from the second segment, and so on.  So for 8-bit
     * calculations with a query length of 288, each segment has a length of
     * 18: the first 16 bytes hold the weights Q1, Q19, Q37, ... Q271; the
     * next 16 bytes hold Q2, Q20, Q38, ... Q272; and so on until the last 16
     * bytes, which hold Q18, Q36, Q54, ... Q288.  This is repeated for every
     * residue of the alphabet.
     */

    /* Find the bias to use in the substitution matrix: its most negative
     * entry, or 0 if no entry is negative. */
    bias = 127;
    for (i = 0; i < ALPHA_SIZE * ALPHA_SIZE; i++) {
        if (matrix[i] < bias) {
            bias = matrix[i];
        }
    }
    if (bias > 0) {
        bias = 0;
    }

    /* Fill in the byte query profile.  weight - bias lies in [0, 255], so it
     * is stored through an unsigned char pointer to avoid the
     * implementation-defined conversion of values above 127 to plain char. */
    pc = (unsigned char *) pSwData->pvbQueryProf;
    segSize = (queryLength + 15) / 16;
    nCount = segSize * 16;
    for (i = 0; i < ALPHA_SIZE; ++i) {
        matrixRow = matrix + i * ALPHA_SIZE;
        for (j = 0; j < segSize; ++j) {
            for (k = j; k < nCount; k += segSize) {
                if (k >= queryLength) {
                    weight = 0;
                } else {
                    weight = matrixRow[querySeq[k]];
                }
                *pc++ = (unsigned char) (weight - bias);
            }
        }
    }

    /* Fill in the short query profile (raw scores; they always fit a short) */
    ps = (short *) pSwData->pvsQueryProf;
    segSize = (queryLength + 7) / 8;
    nCount = segSize * 8;
    for (i = 0; i < ALPHA_SIZE; ++i) {
        matrixRow = matrix + i * ALPHA_SIZE;
        for (j = 0; j < segSize; ++j) {
            for (k = j; k < nCount; k += segSize) {
                if (k >= queryLength) {
                    weight = 0;
                } else {
                    weight = matrixRow[querySeq[k]];
                }
                *ps++ = (short) weight;
            }
        }
    }

    pSwData->bias = (unsigned short) -bias;

    return pSwData;
}

/*
 * Release the data returned by swStripedInit: the profile/scratch buffer and
 * the SwStripedData header itself.  Like free(), passing NULL is a no-op.
 */
void
swStripedComplete(void *pSwData)
{
    SwStripedData *pStripedData = (SwStripedData *) pSwData;

    if (pStripedData == NULL) {
        return;
    }

    free (pStripedData->pData);
    free (pStripedData);
}

/*
 * Nonzero when any of the eight 16-bit lanes of vF exceeds
 * max(vH - vGapOpen, vMinimums), i.e. when the lazy-F pass in swStripedWord
 * could still raise some H value.  spu_gather collects one bit per lane, so
 * every lane is tested (the role _mm_movemask_epi8 plays in SSE2 code).
 */
static inline int
swStripedWordFPending(vector signed short vF,
                      vector signed short vH,
                      vector signed short vGapOpen,
                      vector signed short vMinimums)
{
    vector signed short vTemp;

    vTemp = spu_sub (vH, vGapOpen);
    vTemp = spu_sel (vMinimums, vTemp, spu_cmpgt (vTemp, vMinimums));
    return spu_extract (spu_gather (spu_cmpgt (vF, vTemp)), 0) != 0;
}

/*
 * Striped Smith-Waterman local alignment score using eight 16-bit lanes.
 *
 * querySeq    : unused here; only queryLength is needed because the scores
 *               come from pvQueryProf.
 * queryLength : query length; iter = ceil(queryLength / 8) vectors per row.
 * dbSeq       : database residues, used to select a profile row.
 * dbLength    : number of database residues.
 * gapOpen     : penalty subtracted from H when opening a gap.
 * gapExtend   : penalty subtracted from E/F when extending a gap.
 * pvQueryProf : striped short profile, iter vectors per residue.
 * pvHLoad, pvHStore, pvE : scratch rows of iter vectors each; the two H rows
 *               are swapped for every database residue.
 *
 * Returns the largest H value seen plus SHORT_BIAS.
 *
 * SPU has no saturating 16-bit arithmetic, so every value is clamped to a
 * floor of 0xC000 (-16384) with spu_sel instead.  NOTE(review): presumably
 * the floor is 0xC000 rather than -32768 so that subtracting a penalty from it
 * cannot wrap; confirm that SHORT_BIAS (swsse2.h) matches this floor, otherwise
 * the returned score is offset.
 */
int
swStripedWord(unsigned char   *querySeq,
              int              queryLength,
              unsigned char   *dbSeq,
              int              dbLength,
              unsigned short   gapOpen,
              unsigned short   gapExtend,
              vector signed short         *pvQueryProf,
              vector signed short         *pvHLoad,
              vector signed short         *pvHStore,
              vector signed short         *pvE)
{
    int     i, j;
    int     score;

    int     iter = (queryLength + 7) / 8;

    vector signed short *pv;

    vector signed short vE, vF, vH;

    vector signed short vMaxScore;
    vector signed short vGapOpen;
    vector signed short vGapExtend;

    vector signed short vMin;
    vector signed short vMinimums;
    vector signed short vTemp;

    vector signed short *pvScore;

    /* Load the gap penalties into all elements of constants */
    vGapOpen = spu_splats ((signed short) gapOpen);
    vGapExtend = spu_splats ((signed short) gapExtend);

    /* Every lane starts at the floor; vMin holds the floor in lane 0 only and
     * is OR-ed into lane 0 after each one-lane shift. */
    vMinimums = spu_splats ((signed short) 0xC000);
    vMaxScore = vMinimums;
    vMin = spu_slqwbyte (vMinimums, 14);

    for (i = 0; i < iter; i++)
    {
        pvE[i] = vMaxScore;
        pvHStore[i] = vMaxScore;
    }

    /* An empty query (iter <= 0) has no cells; skipping the rows also keeps
     * the pvHStore[iter - 1] load below in bounds. */
    for (i = 0; iter > 0 && i < dbLength; ++i)
    {
        /* fetch first data asap. */
        pvScore = pvQueryProf + dbSeq[i] * iter;

        vF = vMinimums;

        /* the last H vector of the previous row, shifted up one lane */
        vH = pvHStore[iter - 1];
        vH = spu_rlmaskqwbyte (vH, -2);
        vH = spu_or (vH, vMin);

        pv = pvHLoad;
        pvHLoad = pvHStore;
        pvHStore = pv;

        for (j = 0; j < iter; j++)
        {
            /* load values of vE from previous row (one unit up) */
            vE = pvE[j];

            /* add score to vH, clamped to the floor */
            vH = spu_add (vH, *pvScore++);
            vH = spu_sel (vMinimums, vH, spu_cmpgt (vH, vMinimums));

            /* Update highest score encountered this far */
            vMaxScore = spu_sel (vMaxScore, vH, spu_cmpgt (vH, vMaxScore));

            /* get max from vH, vE and vF */
            vH = spu_sel (vE, vH, spu_cmpgt (vH, vE));
            vH = spu_sel (vF, vH, spu_cmpgt (vH, vF));

            /* save vH values */
            pvHStore[j] = vH;

            /* vH now becomes H - gapOpen (clamped) for the E and F updates */
            vH = spu_sub (vH, vGapOpen);
            vH = spu_sel (vMinimums, vH, spu_cmpgt (vH, vMinimums));

            /* E = max(E - gapExtend, H - gapOpen) */
            vE = spu_sub (vE, vGapExtend);
            vE = spu_sel (vE, vH, spu_cmpgt (vH, vE));

            /* F = max(clamp(F - gapExtend), H - gapOpen).  The previous code
             * overwrote vH with the clamped F here, so F never picked up
             * H - gapOpen. */
            vF = spu_sub (vF, vGapExtend);
            vF = spu_sel (vMinimums, vF, spu_cmpgt (vF, vMinimums));
            vF = spu_sel (vF, vH, spu_cmpgt (vH, vF));

            /* save vE values */
            pvE[j] = vE;

            /* load the next h value */
            vH = pvHLoad[j];
        }

        /* reset pointers to the start of the saved data */
        j = 0;
        vH = pvHStore[j];

        /*  the computed vF value is for the given column.  since */
        /*  we are at the end, we need to shift the vF value over */
        /*  to the next column. */
        vF = spu_rlmaskqwbyte (vF, -2);
        vF = spu_or (vF, vMin);

        /* Lazy-F pass: keep propagating F while it can still raise H in any
         * lane (all eight lanes are tested). */
        while (swStripedWordFPending (vF, vH, vGapOpen, vMinimums))
        {
            vE = pvE[j];

            vH = spu_sel (vF, vH, spu_cmpgt (vH, vF));

            /* save vH values */
            pvHStore[j] = vH;

            /*  update vE incase the new vH value would change it */
            vH = spu_sub (vH, vGapOpen);
            vH = spu_sel (vMinimums, vH, spu_cmpgt (vH, vMinimums));
            vE = spu_sel (vE, vH, spu_cmpgt (vH, vE));
            pvE[j] = vE;

            /* update vF value */
            vF = spu_sub (vF, vGapExtend);
            vF = spu_sel (vMinimums, vF, spu_cmpgt (vF, vMinimums));

            j++;
            if (j >= iter)
            {
                j = 0;
                vF = spu_rlmaskqwbyte (vF, -2);
                vF = spu_or (vF, vMin);
            }

            vH = pvHStore[j];
        }
    }

    /* find largest score in the vMaxScore vector: fold lanes toward lane 7 */
    vTemp = spu_rlmaskqwbyte (vMaxScore, -8);
    vMaxScore = spu_sel (vMaxScore, vTemp, spu_cmpgt (vTemp, vMaxScore));
    vTemp = spu_rlmaskqwbyte (vMaxScore, -4);
    vMaxScore = spu_sel (vMaxScore, vTemp, spu_cmpgt (vTemp, vMaxScore));
    vTemp = spu_rlmaskqwbyte (vMaxScore, -2);
    vMaxScore = spu_sel (vMaxScore, vTemp, spu_cmpgt (vTemp, vMaxScore));

    /* store in temporary variable */
    score = (short) spu_extract (vMaxScore, 7);

    /* return largest score */
    return score + SHORT_BIAS;
}

/*
 * Lane-wise saturating subtraction: *result = a - b, with lanes that overflow
 * upward pinned to 0x7FFF and lanes that overflow downward pinned to 0x8000.
 */
void spu_subs(vector signed short *result, vector signed short a, vector signed short b)
{
    vector signed short diff;
    vector signed short tooHigh;
    vector signed short tooLow;
    vector signed short out;

    diff = spu_sub(a, b);

    /* sign bit set where a >= 0, b < 0 and the difference came out negative */
    tooHigh = spu_nor(a, spu_nand(diff, b));
    /* sign bit set where a < 0, b >= 0 and the difference came out non-negative */
    tooLow = spu_and(a, spu_nor(diff, b));

    /* arithmetic shift by 15 spreads each sign bit over its whole lane */
    out = spu_sel(diff,
                  spu_splats((signed short)0x7FFF),
                  (vector unsigned short) spu_rlmaska(tooHigh, -15));
    out = spu_sel(out,
                  (vector signed short) spu_splats((signed short)0x8000),
                  (vector unsigned short) spu_rlmaska(tooLow, -15));

    *result = out;
}

/*
 * Lane-wise saturating addition: *result = a + b, with lanes that overflow
 * upward pinned to 0x7FFF and lanes that overflow downward pinned to 0x8000.
 */
void spu_adds(vector signed short *result, vector signed short a, vector signed short b)
{
    vector signed short total;
    vector signed short tooHigh;
    vector signed short tooLow;
    vector signed short out;

    total = spu_add(a, b);

    /* sign bit set where a >= 0, b >= 0 but the sum came out negative */
    tooHigh = spu_and(total, spu_nor(a, b));
    /* sign bit set where a < 0, b < 0 but the sum came out non-negative */
    tooLow = spu_nor(total, spu_nand(a, b));

    /* arithmetic shift by 15 spreads each sign bit over its whole lane */
    out = spu_sel(total,
                  spu_splats((signed short)0x7FFF),
                  (vector unsigned short) spu_rlmaska(tooHigh, -15));
    out = spu_sel(out,
                  (vector signed short) spu_splats((signed short)0x8000),
                  (vector unsigned short) spu_rlmaska(tooLow, -15));

    *result = out;
}
