/* Written by Cy Chan, July 2007
 */

#include "mex.h"
#include "string.h"
#include "utilities.h"
#include "computeHG.h"

/*
 * Evaluate the "HG function" as a sum over the partitions enumerated by the
 * level index table:
 *
 *     sum_kappa  coefficient(kappa) * s_kappa(x)              (y == NULL)
 *     sum_kappa  coefficient(kappa) * s_kappa(x) * s_kappa(y) (y != NULL)
 *
 * NOTE(review): presumably the truncated hypergeometric function of one or
 * two matrix arguments, given eigenvalues x (and y) -- confirm in computeHG.h.
 *
 * bySize          - optional (may be NULL). When non-NULL, every term is
 *                   added into bySize[size of its partition]. The return
 *                   value is then the sum of bySize[0..N]. It must have at
 *                   least N + 1 entries. Existing contents are added to and
 *                   included in the result, so the caller presumably zeroes it.
 * x               - first argument; the Schur routine reads x[0..n-1].
 * y               - optional second argument (may be NULL); same layout as x.
 * a, b, p, q      - passed through unchanged to the coefficient routine
 *                   (presumably a has p entries and b has q -- confirm).
 * n, N            - forwarded to the level-index-table helpers; N also bounds
 *                   the bySize summation.
 * levelIndexTable, stride1, stride2 - the partition lookup table and its
 *                   strides, forwarded to every helper.
 *
 * Returns the accumulated sum. All scratch memory comes from mxCalloc and is
 * released with mxFree before returning.
 */
double computeHGFromLevelIndexTable(double *bySize, double *x, double *y, 
                                    double *a, double *b, int p, int q, 
                                    int n, int N, int *levelIndexTable, 
                                    int stride1, int stride2)
{
    double total = 0, term;
    double *coeffs, *sx, *sy = NULL;
    int idx, len, *lengths, *sizes = NULL;

    /* lengths[k - 1] is read from the table; the last entry is the number of terms */
    lengths = mxCalloc(n, sizeof(int));
    computeOutputLengths(lengths, levelIndexTable, stride1, stride2, n, N);
    len = lengths[n - 1];

    /* scratch buffers: one slot per partition */
    coeffs = mxCalloc(len, sizeof(double));
    if (bySize != NULL) {
        sizes = mxCalloc(len, sizeof(int));
    }
    sx = mxCalloc(len, sizeof(double));
    if (y != NULL) {
        sy = mxCalloc(len, sizeof(double));
    }

    /* the last argument is the number of matrix arguments (1 or 2) */
    computeCoefficientsFromLevelIndexTable(coeffs, len, a, b, p, q, n, N,
                                           levelIndexTable, stride1, stride2,
                                           sizes, (y != NULL) ? 2 : 1);

    computeSchursFromLevelIndexTable(sx, lengths, x, n, N,
                                     levelIndexTable, stride1, stride2);
    if (y != NULL) {
        computeSchursFromLevelIndexTable(sy, lengths, y, n, N,
                                         levelIndexTable, stride1, stride2);
    }

    /* combine: either bin each term by partition size, or sum directly */
    for (idx = 0; idx < len; idx++) {
        term = coeffs[idx] * ((y != NULL) ? sx[idx] * sy[idx] : sx[idx]);
        if (bySize != NULL) {
            bySize[sizes[idx]] += term;
        } else {
            total += term;
        }
    }
    if (bySize != NULL) {
        for (idx = 0; idx <= N; idx++) {
            total += bySize[idx];
        }
    }

    /* release scratch memory in allocation order */
    mxFree(lengths);
    mxFree(coeffs);
    if (sizes != NULL) {
        mxFree(sizes);
    }
    mxFree(sx);
    if (sy != NULL) {
        mxFree(sy);
    }

    return total;
}

void computeCoefficientsFromLevelIndexTable(double *output, 
                                            int outputLength, double *a, 
                                            double *b, int p, int q, int n,
                                            int N, int *levelIndexTable, 
                                            int stride1, int stride2, 
                                            int *partitionSizes, 
                                            int numMatrixArgs)
{
    int i, *partition, partitionSize, partitionLength, level, backReference;
    
    output[0] = 1;

    /* initialize current partition to [1, 0, 0, ... , 0] */
    partition = mxCalloc(n + 1, sizeof(int));
    partition[0] = 1;
    
    partitionSize = 1;
    partitionLength = 1;
    
    /* loop through back references and do the computation */
    for (i = 2; i <= outputLength; i++) {
        level = 1;
        while (level < n && partition[level - 1] == partition[level]) {
            level++;
        }
        backReference = computeBackReference(0, partition, partitionSize, i, level, N, levelIndexTable, stride1, stride2);
        
        output[i - 1] = updateQ(output[backReference - 1], a, b, p, q, n, partition, partitionLength, level, 0, 0, numMatrixArgs);
        if (partitionSizes) {
            partitionSizes[i - 1] = partitionSize;
        }
        iteratePartition(partition, &partitionSize, &partitionLength, N, n);
    }
    
    mxFree(partition);
}

/*
 * Fill output[] with the Schur function value s_kappa(x[0], ..., x[n-1]) for
 * every partition kappa in level-index order. Index 0 is the empty partition,
 * whose value is 1.
 *
 * outputLengths - outputLengths[k - 1] appears to be the number of partitions
 *                 (including the empty one) with at most k nonzero rows.
 *                 output must hold outputLengths[n - 1] entries.
 * output        - entries for partitions with at most n - 1 rows are updated
 *                 in place by mulYFromLevelIndexTable, which accumulates into
 *                 them. They are presumably expected to start at zero (the
 *                 caller in this file allocates output with mxCalloc).
 * x             - n values; all n of them are read.
 * n             - number of variables; must be >= 1.
 *
 * Partitions with all n rows nonzero are filled by a "block" step: output[i]
 * = output[backReference] * (x[0]*...*x[n-1])^e, where e tracks the value of
 * the last row. NOTE(review): this presumably uses the identity
 * s_kappa = (x_1...x_n)^{kappa_n} * s_{kappa - kappa_n}, with the back
 * reference pointing at kappa minus kappa_n in every row. Confirm against
 * computeBackReference.
 *
 * Fix: for n == 1 the original read outputLengths[n - 2], i.e.
 * outputLengths[-1], which is out of bounds. The number of partitions with at
 * most 0 rows is 1 (only the empty partition), so that value is used instead.
 */
void computeSchursFromLevelIndexTable(double *output, int *outputLengths, 
                                      double *x, int n, int N, 
                                      int *levelIndexTable, int stride1, 
                                      int stride2)
{
    int i, k, *partition, partitionSize, lastElement, backReference;
    int shortLength;
    double xProduct = 1, curXPower;
    
    /* compute the product of the entries in X */
    for (k = 0; k < n; k++) {
        xProduct *= x[k];
    }

    /* number of partitions with at most n - 1 nonzero rows (index 0 is the
     * empty partition); guard n == 1 so outputLengths[-1] is never read */
    shortLength = (n >= 2) ? outputLengths[n - 2] : 1;
    
    partition = mxCalloc(n + 1, sizeof(int));
    
    /* partitions with at most n - 1 rows: fold in x[0..n-2] with k = 1..n-1,
     * then x[n-1] with k = n - 1 (no-op work when n == 1, so skip it) */
    output[0] = 1;
    for (k = 1; k <= n - 1; k++) {
        mulYFromLevelIndexTable(output, outputLengths[k - 1], x[k - 1], k, 
                                n, N, levelIndexTable, stride1, stride2, 
                                partition);
    }
    if (n >= 2) {
        mulYFromLevelIndexTable(output, shortLength, x[n - 1], n - 1, 
                                n, N, levelIndexTable, stride1, stride2, 
                                partition);
    }
    
    /* initialize partition to first length n partition: [1, ..., 1, 0] */
    for (k = 0; k < n; k++) {
        partition[k] = 1;
    }
    partition[n] = 0;
    partitionSize = n;
    lastElement = 1;
    
    /* do block multiplications for partitions with n nonzero rows;
     * curXPower == xProduct^lastElement, where lastElement follows
     * partition[n - 1] (assumes it only ever grows by one per change) */
    curXPower = xProduct;
    for (i = shortLength; i < outputLengths[n - 1]; i++) {
        backReference = computeBackReference(1, partition, partitionSize, i, n, N, levelIndexTable, stride1, stride2);
        output[i] = output[backReference] * curXPower;
        iteratePartition(partition, &partitionSize, NULL, N, n);
        if (partition[n - 1] != lastElement) {
            curXPower *= xProduct;
            lastElement++;
        }
    }
    
    mxFree(partition);
}

/*
 * In-place update step used by the Schur computation. For each row j from k
 * down to 1, walk the partitions at indices 2..maxIndex (1-based) in
 * iteratePartition order. When row j of the current partition is strictly
 * larger than row j + 1 (or, for j == n, nonzero), do
 *
 *     output[i - 1] += x * output[backReference - 1]
 *
 * where backReference comes from computeBackReference for row j.
 *
 * The rows are swept in descending order and output is modified in place,
 * so the loop order is significant.
 *
 * partition - caller-supplied scratch of at least n + 1 ints. It is
 *             overwritten (reset to [1, 0, ..., 0] for every row j).
 */
void mulYFromLevelIndexTable(double *const output, int maxIndex, double x, 
                             int k, int n, int N, int *levelIndexTable, 
                             int stride1, int stride2, int *partition)
{
    int row, idx, size, ref, applies;

    /* sweep rows from k down to 1 */
    for (row = k; row > 0; row--) {

        /* restart the walk at the one-box partition [1, 0, ..., 0] */
        memset(partition, 0, (n + 1) * sizeof(int));
        partition[0] = 1;
        size = 1;

        for (idx = 2; idx <= maxIndex; idx++) {
            /* does this partition have a removable box in the current row? */
            if (row < n) {
                applies = partition[row - 1] > partition[row];
            } else {
                applies = (row == n) && partition[row - 1] > 0;
            }

            if (applies) {
                ref = computeBackReference(0, partition, size, idx, row, N,
                                           levelIndexTable, stride1, stride2);
                output[idx - 1] += x * output[ref - 1];
            }

            iteratePartition(partition, &size, NULL, N, n);
        }
    }
}
