/* Written by Cy Chan, July 2007
 */

#include "mex.h"
#include "utilities.h"
#include "computeHG.h"

/* "add" and "mult" are optional parameters passed in when coefficient data 
 * has been precomputed.  Pass in NULL for both parameters if no 
 * preprocessing was done.
 *
 * "y" is an optional parameter to compute the hypergeometric function of 
 * two matrix arguments.  Set to NULL to compute the HG function of one 
 * argument.
 *
 * "bySize" is an optional parameter used to output a breakdown of the HG 
 * function based on size of the partitions included in the sum.  In other 
 * words, bySize(j) equals the HG sum over partitions of size exactly j.  
 * Thus, 1 + sum(bySize) = result, since the contribution of the 
 * zero-weight partition is 1.
 *
 * "partitionSizes" is an optional parameter that is needed to compute the 
 * "bySize" output when "add" and "mult" are supplied.  Since the 
 * partitions are never enumerated, their sizes must be passed in by the 
 * caller in order to separate the summands according to their respective 
 * partitions' sizes.
 */
double computeHGFromTable(double *bySize, double *x, double *y, double *a, 
                          double *b, int p, int q, int n, int N, 
                          int *outputLengths, int *backRefsTable, 
                          int tableStride, int *extraReferences, 
                          int *backRefsArray, int *lastElementIndexTable, 
                          int *add, double *mult, int *partitionSizes)
{
    /* one summand per table entry; the count is taken from the last slot 
     * of outputLengths that belongs to n */
    const int numTerms = outputLengths[n - 1];
    /* 2 when a second argument "y" is supplied, otherwise 1 */
    const int numMatrixArgs = (y != NULL) ? 2 : 1;
    double *coeffs = NULL, *schurX = NULL, *schurY = NULL;
    int *ownedSizes = NULL;
    double total = 0;
    int t;

    /* SCRATCH BUFFERS (zero-filled by mxCalloc) */

    coeffs = mxCalloc(numTerms, sizeof(double));
    if (bySize != NULL && partitionSizes == NULL) {
        /* caller wants the breakdown but gave no sizes: keep our own copy */
        ownedSizes = mxCalloc(numTerms, sizeof(int));
        partitionSizes = ownedSizes;
    }
    schurX = mxCalloc(numTerms, sizeof(double));
    if (y != NULL) {
        schurY = mxCalloc(numTerms, sizeof(double));
    }

    /* COEFFICIENTS
     * Only the locally owned size buffer is handed down to be filled in; 
     * a caller-supplied "partitionSizes" is read below but never passed on 
     * (NULL goes down in its place). */

    computeCoefficientsFromTable(coeffs, outputLengths, a, b, p, q, n, N, 
                                 backRefsTable, tableStride, 
                                 extraReferences, add, mult, ownedSizes, 
                                 numMatrixArgs);

    /* SCHUR VALUES, once per supplied argument */

    computeSchursFromTable(schurX, outputLengths, x, n, N, backRefsTable, 
                           tableStride, backRefsArray, 
                           lastElementIndexTable);
    if (y != NULL) {
        computeSchursFromTable(schurY, outputLengths, y, n, N, 
                               backRefsTable, tableStride, backRefsArray, 
                               lastElementIndexTable);
    }

    /* COMBINE */

    if (bySize != NULL) {
        /* accumulate each summand into the slot named by its partition 
         * size (bySize is added to, not cleared, here) ... */
        for (t = 0; t < numTerms; t++) {
            if (y != NULL) {
                bySize[partitionSizes[t]] += coeffs[t] * (schurX[t] * schurY[t]);
            } else {
                bySize[partitionSizes[t]] += coeffs[t] * schurX[t];
            }
        }
        /* ... then the result is the sum of slots 0 through N */
        for (t = 0; t <= N; t++) {
            total += bySize[t];
        }
    } else {
        for (t = 0; t < numTerms; t++) {
            if (y != NULL) {
                total += coeffs[t] * (schurX[t] * schurY[t]);
            } else {
                total += coeffs[t] * schurX[t];
            }
        }
    }

    /* RELEASE SCRATCH BUFFERS */

    mxFree(coeffs);
    if (ownedSizes != NULL) {
        mxFree(ownedSizes);
    }
    mxFree(schurX);
    if (y != NULL) {
        mxFree(schurY);
    }

    return total;
}

/* Fills output[0 .. outputLengths[n - 1] - 1] with coefficients, each one 
 * derived by updateQ() from an earlier entry of "output".
 *
 * output[0] is always set to 1.  The earlier entry used for index i is 
 * located through a 1-based reference: backRefsTable (column-major, column 
 * stride "tableStride") or, for indices the table does not cover, 
 * extraReferences[i - tableStride].
 *
 * When both "add" and "mult" are non-NULL, add[i] and mult[i] are handed 
 * to updateQ() and no partition is enumerated; "partitionSizes" is not 
 * written in this mode.
 *
 * Otherwise the partitions are enumerated with iteratePartition(), 
 * starting from [1, 0, ..., 0], and if "partitionSizes" is non-NULL the 
 * size of the partition for each index i >= 1 is stored in 
 * partitionSizes[i] (index 0 is left untouched).
 */
void computeCoefficientsFromTable(double *output, int *outputLengths, 
                                  double *a, double *b, int p, int q, 
                                  int n, int N, int *backRefsTable, 
                                  int tableStride, int *extraReferences, 
                                  int *add, double *mult, 
                                  int *partitionSizes, int numMatrixArgs)
{
    const int total = outputLengths[n - 1];
    /* nonzero when the table has fewer rows than the requested output */
    const int tableIsShort = (tableStride < total);
    int idx;

    output[0] = 1;

    if (add != NULL && mult != NULL) {
        /* PRECOMPUTED DATA: walk the table one column at a time */
        const int lastColumn = tableIsShort ? n - 1 : n;
        int column;

        idx = 1;
        for (column = 0; column < lastColumn; column++) {
            while (idx < outputLengths[column]) {
                output[idx] = updateQ(output[backRefsTable[idx + column * tableStride] - 1],
                                      a, b, p, q, n, NULL, 0, 0, 
                                      add[idx], mult[idx], numMatrixArgs);
                idx++;
            }
        }

        /* the last column was skipped above when the table is short; the 
         * remaining entries come from the extra references instead.
         * NOTE(review): idx continues from where the loop above stopped 
         * while extraReferences is indexed by idx - tableStride, which 
         * looks like it assumes tableStride == outputLengths[n - 2] here 
         * -- confirm against the table builder. */
        if (tableIsShort) {
            while (idx < total) {
                output[idx] = updateQ(output[extraReferences[idx - tableStride] - 1],
                                      a, b, p, q, n, NULL, 0, 0, 
                                      add[idx], mult[idx], numMatrixArgs);
                idx++;
            }
        }
        return;
    }

    /* NO PRECOMPUTED DATA: enumerate the partitions ourselves */
    {
        int *shape, weight = 1, rows = 1;
        const int tableRows = tableIsShort ? tableStride : total;

        /* current partition starts as [1, 0, 0, ... , 0] */
        shape = mxCalloc(n, sizeof(int));
        shape[0] = 1;

        /* entries whose back reference lives in the table; the column is 
         * chosen by the current partition's length */
        idx = 1;
        while (idx < tableRows) {
            output[idx] = updateQ(output[backRefsTable[idx + (rows - 1) * tableStride] - 1],
                                  a, b, p, q, n, shape, rows, rows, 0, 0, 
                                  numMatrixArgs);
            if (partitionSizes != NULL) {
                partitionSizes[idx] = weight;
            }
            /* advance only after the current partition has been used */
            iteratePartition(shape, &weight, &rows, N, n);
            idx++;
        }

        /* entries beyond the table use extraReferences, continuing the 
         * same enumeration */
        if (tableIsShort) {
            for (idx = tableStride; idx < total; idx++) {
                output[idx] = updateQ(output[extraReferences[idx - tableStride] - 1],
                                      a, b, p, q, n, shape, rows, rows, 0, 0, 
                                      numMatrixArgs);
                if (partitionSizes != NULL) {
                    partitionSizes[idx] = weight;
                }
                iteratePartition(shape, &weight, &rows, N, n);
            }
        }

        mxFree(shape);
    }
}

/* Fills "output" with one value per table entry, built up from the entries 
 * of "x" by following the back references.
 *
 * Steps, in order:
 *   1. output[0] is set to 1.
 *   2. For k = 1 .. n-1, mulYFromTable() is applied with x[k - 1] to the 
 *      first outputLengths[k - 1] entries, using k table columns.
 *   3. mulYFromTable() is applied once more with x[n - 1] to the entries 
 *      that precede the n-row partitions, using n - 1 table columns.
 *   4. The remaining entries (partitions with n nonzero rows) are 
 *      output[backRefsArray[i] - 1] times (x[0] * ... * x[n-1]) raised to 
 *      "lastElement", for lastElement = 1 .. N / n; the end index of each 
 *      group is read from lastElementIndexTable[(lastElement - 1) + 
 *      (n - 1) * N].
 *
 * "n" must be at least 1.  "output" is written in place; references in the 
 * tables are 1-based.
 */
void computeSchursFromTable(double *output, int *outputLengths, double *x, 
                            int n, int N, int *backRefsTable, 
                            int tableStride, int *backRefsArray, 
                            int *lastElementIndexTable)
{
    int i, k, lastElement, shorterCount;
    double xProduct = 1, curXPower;
    
    /* compute the product of the entries in X */
    for (k = 0; k < n; k++) {
        xProduct *= x[k];
    }
    
    output[0] = 1;

    /* Number of entries that come before the partitions with n nonzero 
     * rows.  For n >= 2 this is outputLengths[n - 2].  For n == 1 that 
     * index would be -1 (out of bounds); the only entry written before the 
     * n-row partitions is then output[0], so the count is taken to be 1. */
    shorterCount = (n >= 2) ? outputLengths[n - 2] : 1;

    for (k = 1; k <= n - 1; k++) {
        mulYFromTable(output, outputLengths[k - 1], x[k - 1], k, backRefsTable, tableStride);
    }
    /* with n == 1 this call has zero columns to visit and changes nothing */
    mulYFromTable(output, shorterCount, x[n - 1], n - 1, backRefsTable, tableStride);
    
    /* do block multiplications for partitions with n nonzero rows */
    i = shorterCount;
    curXPower = xProduct;
    for (lastElement = 1; lastElement <= N / n; lastElement++, curXPower *= xProduct) {
        for ( ; i < lastElementIndexTable[(lastElement - 1) + (n - 1) * N]; i++) {
            output[i] = output[backRefsArray[i] - 1] * curXPower;
        }
    }
    
}

/* In-place update of output[0 .. maxIndex - 1] driven by the first "k" 
 * columns of "backRefsTable".
 *
 * The table is stored column by column, "tableStride" entries apart.  
 * Columns are visited from k - 1 down to 0 and, within a column, rows from 
 * 0 up to maxIndex - 1.  A zero entry means "no reference" and the row is 
 * skipped; a nonzero entry r is a 1-based index, and x * output[r - 1] is 
 * added to output[row].  Because the update is in place, a row may read a 
 * value that was already updated earlier in the same pass.
 *
 * Nothing happens when k <= 0 or maxIndex <= 0.
 */
void mulYFromTable(double *const output, int maxIndex, double x, int k, 
                   int *backRefsTable, int tableStride)
{
    int column, row;
    
    /* count down: column takes the values k - 1, k - 2, ..., 0 */
    for (column = k; column-- > 0; ) {
        const int base = column * tableStride; /* offset of this column */
        for (row = 0; row < maxIndex; row++) {
            const int ref = backRefsTable[base + row];
            if (ref != 0) {
                output[row] = x * output[ref - 1] + output[row];
            }
        }
    }
}
