/* Written by Cy Chan, July 2007
 */

#include "mex.h"
#include "string.h"
#include "utilities.h"

/* Multiplies Q by prod_j (a[j] + add), divides it by prod_j (b[j] + add),
 * then scales it by mult, and returns the result.
 *
 * a, b      : parameter arrays of length p and q respectively.
 * partition : when non-NULL, add and mult are recomputed from it via
 *             computeCoefficientData (using n, partitionLength,
 *             elementDecremented and numMatrixArgs) and the values passed in
 *             are ignored. When NULL, the supplied (precomputed) add and mult
 *             are used directly.
 */
double updateQ(double Q, double *a, double *b, int p, int q, int n,
               int *partition, int partitionLength, int elementDecremented,
               int add, double mult, int numMatrixArgs)
{
    int idx;

    if (partition != NULL) {
        computeCoefficientData(&add, &mult, partition, n, partitionLength,
                               elementDecremented, numMatrixArgs);
    }

    for (idx = 0; idx < p; ++idx) {
        Q *= a[idx] + add;
    }
    for (idx = 0; idx < q; ++idx) {
        Q /= b[idx] + add;
    }

    return Q * mult;
}

/* Precomputes coefficient data for every nonzero partition of size (weight)
 * at most N with at most n parts, visited in iteratePartition order.
 *
 * For each position i in 1..outputLength-1:
 *   add[i], mult[i]   : computeCoefficientData for that partition, with its
 *                       last nonzero element as the decremented one;
 *   partitionSizes[i] : the partition's size.
 * Index 0 (the zero partition) of add, mult and partitionSizes is not written.
 */
void precomputeCoeffData(int *add, double *mult, int *partitionSizes, 
                         int outputLength, int n, int N, 
                         int numMatrixArgs)
{
    int i, *partition, partitionSize, partitionLength;

    /* Allocate n + 1 entries (the extra one stays zero), as
     * buildBackReferenceTable does. iteratePartition may read partition[1],
     * which lies past an n-element array when n == 1. The extra entry also
     * keeps partition[0] in bounds when n == 0. */
    partition = mxCalloc(n + 1, sizeof(int));

    /* current partition starts at [1, 0, 0, ... , 0] */
    partition[0] = 1;
    partitionSize = 1;
    partitionLength = 1;

    /* We always decrement the last nonzero element of the partition, so
     * elementDecremented == partitionLength. */
    for (i = 1; i < outputLength; i++) {
        computeCoefficientData(add + i, mult + i, partition, n, partitionLength, partitionLength, numMatrixArgs);
        partitionSizes[i] = partitionSize;
        iteratePartition(partition, &partitionSize, &partitionLength, N, n);
    }

    mxFree(partition);
}

/* Computes the (add, mult) pair consumed by updateQ for a partition whose
 * back reference was obtained by decrementing one of its elements
 *   - based on an algorithm by Raymond Kan (alpha = 1).
 *
 * In Raymond's notation:
 *   partitionLength    = h  (number of nonzero elements)
 *   elementDecremented = k  (1-based position that was decremented)
 *
 * With kappa = partition (1-based) and t = kappa_k - k:
 *   add = t
 *   numMatrixArgs == 1:
 *     mult = prod_{j=1..h, j!=k} (kappa_j - j - t)
 *            / ( (t + h) * prod_{j=1..h, j!=k} (kappa_j - j - t + 1) )
 *   otherwise:
 *     mult = 1 / (n + t)
 *
 * Any k in 1..h is handled. precomputeCoeffData always passes k == h (the
 * last nonzero element).
 */
void computeCoefficientData(int *add, double *mult, int *partition, int n,
                            int partitionLength, int elementDecremented,
                            int numMatrixArgs)
{
    const int k = elementDecremented;
    const int shift = partition[k - 1] - k;
    double numerator = 1, denominator = 1;
    int j, diff;

    *add = shift;

    if (numMatrixArgs != 1) {
        *mult = 1 / (double) (n + shift);
        return;
    }

    /* Product over every nonzero position except the decremented one. */
    for (j = 1; j <= partitionLength; j++) {
        if (j == k) {
            continue;
        }
        diff = partition[j - 1] - j - shift;
        numerator *= diff;
        denominator *= diff + 1;
    }

    denominator *= shift + partitionLength;
    *mult = numerator / denominator;
}

/* Builds back-reference tables for the partitions of size (weight) at most N
 * with at most n parts, enumerated by iteratePartition. m is the 1-based
 * position; m = 1 is the zero partition and is skipped.
 *
 * backRefsTable         : indexed (m - 1) + (k - 1) * tableStride, where
 *                         tableStride = number of partitions with fewer than
 *                         n parts. For those partitions, entry (m, k) is the
 *                         position of partition m with its kth element
 *                         decremented. It is written only where the result is
 *                         still non-increasing. Assumed zero-initialized.
 * extraReferences       : for partitions with exactly n parts (positions
 *                         m > tableStride), entry m - tableStride - 1 is the
 *                         position of the partition with its last element
 *                         decremented.
 * backRefsArray         : entry m - 1 is the position of partition m with its
 *                         last nonzero element subtracted from every element.
 * lastElementIndexTable : indexed (e - 1) + (l - 1) * N. It holds the position
 *                         of the last partition in each run of consecutive
 *                         partitions sharing length l and last nonzero
 *                         element e.
 * levelIndexTable       : assumed to be N x n x N (strides 1, N, N * n).
 * outputLengths         : outputLengths[j - 1] = number of partitions with at
 *                         most j parts.
 */
void buildBackReferenceTable(int *backRefsTable, int *extraReferences, 
                             int *backRefsArray, 
                             int *lastElementIndexTable, int N, int n, 
                             int *levelIndexTable, int *outputLengths)
{
    int *partition, partitionSize, partitionLength, m, k,
        tableStride, curLength, curLastElement;

    /* keep an extra zero at the end of the partition array: computeBackReference
     * reads partition[k] */
    partition = mxCalloc(n + 1, sizeof(int));

    /* initialize partition to the second partition in the set (first is all zeros) */
    partition[0] = 1;
    partitionLength = 1;
    partitionSize = 1;
    curLength = 1;
    curLastElement = 1;

    /* Number of partitions with fewer than n parts (positions 1..tableStride).
     * outputLengths[n - 2] does not exist when n == 1. In that case only the
     * zero partition has fewer than one part, so the count is 1. */
    tableStride = (n >= 2) ? outputLengths[n - 2] : 1;

    /* backRefsTable is assumed to be initialized to all zeros.
     * start at m = 2 since first partition (all zeros) has no work to be done. */
    for (m = 2; m <= outputLengths[n - 1]; m++) {
        if (m <= tableStride) {
            /* Fewer than n parts, so k < n always. Decrementing position k
             * keeps a valid partition exactly when partition[k - 1] > partition[k]. */
            for (k = 1; k < n; k++) {
                if (partition[k - 1] > partition[k]) {
                    backRefsTable[(m - 1) + (k - 1) * tableStride] = computeBackReference(0, partition, partitionSize, m, k, N, levelIndexTable, N, N * n);
                }
            }
        } else {
            /* Exactly n parts: record only the decrement of the last element. */
            extraReferences[m - tableStride - 1] = computeBackReference(0, partition, partitionSize, m, partitionLength, N, levelIndexTable, N, N * n);
        }

        backRefsArray[m - 1] = computeBackReference(1, partition, partitionSize, m, partitionLength, N, levelIndexTable, N, N * n);

        /* A change in (length, last element) closes the previous run at m - 1. */
        if (partitionLength != curLength || partition[partitionLength - 1] != curLastElement) {
            lastElementIndexTable[(curLastElement - 1) + (curLength - 1) * N] = m - 1;
            curLength = partitionLength;
            curLastElement = partition[partitionLength - 1];
        }

        iteratePartition(partition, &partitionSize, &partitionLength, N, n);
    }

    /* set final entry of the lastElementIndexTable */
    lastElementIndexTable[(curLastElement - 1) + (curLength - 1) * N] = outputLengths[n - 1];

    mxFree(partition);
}

/* Returns the position of a "back reference" of `partition` in the
 * enumeration order, given that `partition` itself is at position curIndex.
 *
 *   mode == 0 : the back reference is `partition` with its kth (1-based)
 *               element decremented by one.
 *   mode != 0 : k must be the position of the last nonzero element (checked;
 *               mexErrMsgTxt is called otherwise). The back reference is
 *               `partition` with that last element subtracted from each of
 *               its first k elements.
 *
 * The result is curIndex minus the levelIndexTable entries contributed by
 * levels 1..k of `partition`, plus the entries contributed by levels 1..k of
 * the modified partition. Elements after position k are never changed.
 *
 * partitionSize : current size (sum of elements) of `partition`.
 * N             : size (weight) bound.
 * stride1/2     : strides of the second and third dimensions of
 *                 levelIndexTable (the first dimension has unit stride).
 *
 * `partition` is modified in place during the call and restored before
 * returning.
 *
 * WARNING: the size of allocated partition array must be n + 1 with the 
 * last position set to zero for this function to work properly when k = n
 * (partition[k] is read below).
 */
int computeBackReference(int mode, int *partition, int partitionSize, 
                         int curIndex, int k, int N, int *levelIndexTable, 
                         int stride1, int stride2)
{
    /* tableData is unused. lastElement is only assigned and read when mode != 0. */
    int backReference = curIndex, *tableData = NULL, subSize, level, index, lastElement;
    
    /* For mode != 0, require partition[k - 1] > 0 and partition[k] == 0. */
    if (mode != 0 && (partition[k - 1] == 0 || partition[k] != 0)) {
        mexErrMsgTxt("k must be position of last nonzero element of partition");
    }
    
    /* Remove the contributions of levels 1..k of the current partition.
     * After adding partition[level - 1], subSize equals N minus the sum of the
     * elements after position `level`. Only levels where the partition
     * strictly drops (partition[level - 1] > partition[level]) contribute. */
    subSize = N - partitionSize;
    for (level = 1; level <= k; level++) {
        subSize += partition[level - 1];
        if (partition[level - 1] > partition[level]) {
            index =   ((subSize - level * partition[level]) - 1)
                    + (level - 1) * stride1
                    + ((partition[level - 1] - partition[level]) - 1) * stride2;
            backReference -= levelIndexTable[index];
        }
    }
    
    /* Temporarily change partition to its back reference.  If mode == 0, 
     * then decrement the kth element of partition.  If mode != 0, then k 
     * should be the position of the last nonzero element of the partition.  
     * We then subtract the last element from the entire partition. */
    if (mode == 0) {
        partition[k - 1]--;
    } else {
        lastElement = partition[k - 1];
        for (level = 1; level <= k; level++) {
            partition[level - 1] -= lastElement;
        }
    }
    
    /* Add the contributions of levels k..1 of the modified partition, walking
     * subSize back down from the value reached by the first loop. */
    for (level = k; level >= 1; level--) {
        if (partition[level - 1] > partition[level]) {
            index =   ((subSize - level * partition[level]) - 1)
                    + (level - 1) * stride1
                    + ((partition[level - 1] - partition[level]) - 1) * stride2;
            backReference += levelIndexTable[index];
        }
        subSize -= partition[level - 1];
    }
    
    /* change partition back to its original value */
    if (mode == 0) {
        partition[k - 1]++;
    } else {
        for (level = 1; level <= k; level++) {
            partition[level - 1] += lastElement;
        }
    }
    
    return backReference;
}

/* Advances `partition` in place to the next partition of size (weight) at
 * most N with at most n parts. For example, with N = 3 and n = 2 the
 * successive partitions are [1] [2] [3] [1 1] [2 1].
 *
 * partition       : array of at least n elements, non-increasing, with
 *                   unused trailing entries zero. Only partition[0..n-1] is
 *                   accessed.
 * partitionSize   : in/out, the sum of the elements of partition.
 * partitionLength : in/out, the number of nonzero elements. It is optional:
 *                   set to NULL to skip updating it.
 * N, n            : size bound and maximum number of parts.
 *
 * Returns true if a successor was produced. Returns false, leaving every
 * argument unchanged, if partition was the last one.
 */
bool iteratePartition(int *const partition, int *const partitionSize, 
                      int *const partitionLength, const int N, const int n)
{
    int i, subSize, level, newValue;

    /* Cheapest successor: grow the first element while the size allows. */
    if (*partitionSize < N) {
        partition[0]++;
        (*partitionSize)++;
        return true;
    }

    /* Size is at the limit. With fewer than two parts allowed there is no
     * successor. Stopping here also avoids reading partition[1], which is
     * past the end of an n-element array when n == 1. */
    if (n < 2) {
        return false;
    }

    /* Find the smallest level >= 2 whose element can be incremented, with all
     * earlier elements raised to the same value, within the size budget.
     * subSize is N minus the sum of the elements after position `level`. */
    subSize = N - (*partitionSize - partition[0] - partition[1]);
    level = 2;
    while (level < n && (partition[level - 1] + 1) * level > subSize) {
        level++;
        subSize += partition[level - 1];
    }
    if ((partition[level - 1] + 1) * level > subSize) {
        return false;
    }

    newValue = partition[level - 1] + 1;
    for (i = 0; i < level; i++) {
        partition[i] = newValue;
    }
    *partitionSize = (N - subSize) + newValue * level;
    if (partitionLength && level > *partitionLength) {
        *partitionLength = level;
    }

    return true;
}

/* Builds a table of size maxN x maxn x maxN, column-major with strides 1,
 * maxN and maxN * maxn. table(N, n, k) is the number of partitions of size
 * (weight) at most N that occur before the first partition with k in the nth
 * position, in the enumeration order of iteratePartition. Entries with
 * k > floor(N / n) or n > N are not written.
 *
 * As with buildPhiTable, time to compute is approximately 
 * O(maxN^2 ln maxN).
 *
 * Also computes the optional array outputLengths (length maxn; may be NULL)
 * such that outputLengths[j - 1] is the number of partitions of size at most
 * maxN with at most j parts, counting the zero partition.
 *
 * If maxN < 1 or maxn < 1 the table is empty. Only outputLengths is filled,
 * with 1: the zero partition is the only one.
 */
void buildLevelIndexTable(int *levelIndexTable, int *outputLengths, int maxN, int maxn)
{
    int *phi, N, n, i, k, curIndex, maxK, stride1 = maxN, stride2 = maxN * maxn, maxCol;

    /* Degenerate sizes would otherwise write into a zero-size phi buffer and
     * read phi[-1] below. */
    if (maxN < 1 || maxn < 1) {
        if (outputLengths) {
            for (i = 0; i < maxn; i++) {
                outputLengths[i] = 1;
            }
        }
        return;
    }

    phi = mxCalloc(maxN * maxn, sizeof(int));
    buildPhiTable(phi, maxN, maxn);

    for (N = 1; N <= maxN; N++) {
        /* column n = 1: exactly k partitions ([0] .. [k - 1]) precede [k] */
        for (i = 1; i <= N; i++) {
            levelIndexTable[(N - 1) + (i - 1) * stride2] = i;
        }
        curIndex = N + 1;

        /* determine maximum column to iterate through */
        maxCol = maxn;
        if (N < maxn) {
            maxCol = N;
        }

        /* Iterate through the columns of the table. curIndex counts the
         * partitions seen so far. Partitions whose nth element equals k
         * correspond to phi(N - n * k, n - 1) partitions, or exactly 1 when
         * n * k == N. */
        for (n = 2; n <= maxCol; n++) {
            maxK = N / n;
            levelIndexTable[(N - 1) + (n - 1) * stride1] = curIndex;
            if (n < N) {
                curIndex += phi[(N - n - 1) + (n - 2) * stride1];
                for (k = 2; k <= maxK; k++) {
                    levelIndexTable[(N - 1) + (n - 1) * stride1 + (k - 1) * stride2] = curIndex;
                    if (n * k != N) {
                        curIndex += phi[(N - n * k - 1) + (n - 2) * stride1];
                    } else {
                        curIndex++;
                    }
                }
            }
        }
    }

    if (outputLengths) {
        for (i = 0; i < maxn; i++) {
            outputLengths[i] = phi[(maxN - 1) + i * stride1];
        }
    }

    mxFree(phi);
}

/* buildPhiTable(phi, maxN, maxn)
 *
 * Fills phi, a maxN x maxn table stored column-major, for 1 <= N <= maxN and
 * 1 <= n <= maxn: phi[(N - 1) + (n - 1) * maxN] = phi(N, n).
 *
 * phi(N, n) is defined to be the number of length n partitions (where 0's
 * are allowed) of size at most N, i.e. partitions with at most n nonzero
 * parts, including the zero partition.
 *
 *   phi(0, n) = phi(N, 0) = 1, for all N, n.
 *   phi(N, n) = \sum_{i = 0}^{\lfloor N / n \rfloor} phi(N - i n, n - 1),
 *               for N > 0 and n > 0.
 *
 * Time to compute is approximately O(maxN^2 ln maxN).
 * If maxN < 1 or maxn < 1 the table is empty and nothing is written.
 */
void buildPhiTable(int *phi, int maxN, int maxn)
{
    int N, n, i, k, sum, maxCol;

    /* An empty table: the writes below would fall outside the buffer. */
    if (maxN < 1 || maxn < 1) {
        return;
    }

    /* for N = 1, there are only (0, 0, ...) and (1, 0, 0, ...) for all n */
    for (i = 0; i < maxn; i++) {
        phi[i * maxN] = 2;
    }

    for (N = 2; N <= maxN; N++) {
        /* n = 1: the partitions [0], [1], ..., [N] */
        phi[N - 1] = N + 1;

        /* determine maximum column to iterate through */
        maxCol = maxn;
        if (N < maxn) {
            maxCol = N;
        }

        /* iterate through columns of table */
        for (n = 2; n <= maxCol; n++) {
            sum = phi[N - 1 + (n - 2) * maxN]; /* start with the number of partitions when lastElement == 0 */

            /* loop through possible values for k = the last element in the partition */
            for (k = 1; k <= N / n; k++) {
                if (n * k == N) {
                    sum++; /* phi(0, n - 1) = 1 */
                } else {
                    sum += phi[(N - k * n - 1) + (n - 2) * maxN];
                }
            }
            phi[(N - 1) + (n - 1) * maxN] = sum;
        }

        /* A size-N partition has at most N nonzero parts, so columns n > N
         * repeat column N. */
        for (i = N; i < maxn; i++) {
            phi[(N - 1) + i * maxN] = phi[(N - 1) + (N - 1) * maxN];
        }
    }
}

/* Fills outputLengths[0..n-1] for size bound N from a level index table
 * indexed (row - 1) + (col - 1) * stride1 + (k - 1) * stride2:
 *   outputLengths[col - 2], 2 <= col <= n : table(N, col, 1)
 *   outputLengths[n - 1]                  : 1 plus the sum of
 *       table(remaining, col, remaining / col). The sum runs over col = n
 *       down to 1; whenever remaining / col > 0, col * (remaining / col) is
 *       removed from remaining, which starts at N.
 */
void computeOutputLengths(int *outputLengths, int *levelIndexTable, int stride1, int stride2, int n, int N)
{
    int col, part, remaining;

    for (col = 2; col <= n; col++) {
        outputLengths[col - 2] = levelIndexTable[(N - 1) + (col - 1) * stride1];
    }

    outputLengths[n - 1] = 1;
    remaining = N;
    col = n;
    while (col >= 1 && remaining > 0) {
        part = remaining / col;
        if (part > 0) {
            outputLengths[n - 1] += levelIndexTable[(remaining - 1)
                                                    + (col - 1) * stride1
                                                    + (part - 1) * stride2];
            remaining -= col * part;
        }
        col--;
    }
}

/* Prints an int array via mexPrintf as "[a b c]", or "[]" when length <= 0.
 * No trailing newline is printed. */
void mexPrintArray(const int *const array, const int length)
{
    int idx;

    if (length <= 0) {
        mexPrintf("[]");
        return;
    }

    mexPrintf("[%d", array[0]);
    for (idx = 1; idx < length; ++idx) {
        mexPrintf(" %d", array[idx]);
    }
    mexPrintf("]");
}

/* Prints a double array via mexPrintf as "[a b c]", each value formatted
 * with %.4g, or "[]" when length <= 0. No trailing newline is printed. */
void mexPrintDoubleArray(const double *const array, const int length)
{
    int idx;

    if (length <= 0) {
        mexPrintf("[]");
        return;
    }

    mexPrintf("[%.4g", array[0]);
    for (idx = 1; idx < length; ++idx) {
        mexPrintf(" %.4g", array[idx]);
    }
    mexPrintf("]");
}
