function [ pMem mMu mSig mProb llTrace dist] = selapMix( dat,ss )
% SELAPMIX  Fit a mixture model whose components are seeded by affinity
% propagation.
%
% Inputs:
%   dat - NxP matrix; N observations from a space of dimension P.  The
%         observations are assumed to originate from a mixture of normal
%         distributions.
%   ss  - (optional) preference value handed to affProp.  When omitted or
%         empty, the median of all entries of the similarity matrix is used.
%
% The number of mixture elements K and their prior means are taken from
% the affinity propagation result: each distinct value returned by affProp
% is used as a row index into dat and that row becomes a prior mean.
%
% Outputs (all averages are over the post-burn-in iterations):
%   pMem    - NxK matrix; fraction of iterations in which each observation
%             was assigned to each mixture element.
%   mMu     - PxK matrix; average of the per-iteration posterior means.
%   mSig    - (P^2)xK matrix; average of the per-iteration scale matrices.
%             Each column holds the elements of one PxP matrix, stored
%             column-major.
%   mProb   - 1xK vector; average mixing proportions.
%   llTrace - iter x 1 vector; value of mvnMixtureLogLikelihood on dat at
%             every iteration (burn-in included).
%   dist    - iter x 1 vector; value of mvnMixtureLogLikelihood on a data
%             set of N rows drawn with mvnMixtureRnd at every iteration.
%
% Side effects: prints progress messages, redraws figures 1..floor(P/2)
% every 50 iterations, and finally prints the fraction of entries of dist
% that exceed the log-likelihood of dat under the averaged parameters.
%
% Requires (not defined in this file): affProp, mvtMixtureMembership,
% mvnMixtureLogLikelihood, mvnMixtureRnd, plotMVN.

disp('Calculating similarities matrix.');
[N P]=size(dat);
% Similarity is the negative Euclidean distance between every pair of rows.
% One full row of S is filled per pass instead of one element at a time.
S=zeros(N,N);
for i=1:N
    delta=bsxfun(@minus,dat,dat(i,:));
    S(i,:)=-sqrt(sum(delta.^2,2))';
end

% Default preference: median similarity (also used when ss is passed as []).
if nargin<2 || isempty(ss)
    ss=median(S(:));
end

disp('Computing affinity propogation clusters.');
% NOTE(review): the third argument (.9) is presumably the damping factor of
% affProp - confirm against that function.
idx=affProp(S,ss,.9);

% The distinct values of idx are used below as row indices into dat.
seeds=unique(idx);
K=length(seeds);
disp(['There are ' num2str(K) ' groups.']);

%%% hyper-parameters
% The prior degrees of freedom were hard-coded to 15.  The quantity
% df+nj-P+1 used below must stay positive even for an empty element
% (nj==0), which fails for P>=16 when df is fixed at 15.  max(15,P) keeps
% the previous value for every P<=15 and guarantees df+nj-P+1>=1 otherwise.
df=max(15,P);
kappa=.01;
priorProb=ones(1,K);
prob=priorProb/sum(priorProb);

%%% initialize
sigma=eye(P);
pSigma=repmat(sigma(:),1,K);    % identity prior scale for every element
pMu=dat(seeds,:)';              % prior means = rows of dat picked by affProp
priorMu=pMu;
kSig=pSigma;
% Translate the affProp labels into element numbers 1..K.
mem=zeros(size(idx));
for i=1:K
    mem(idx==seeds(i))=i;
end

disp('Fitting mixture distribution.');
iter=1000;
burnin=100;
pMem=zeros(N,K);
mMu=zeros(size(pMu));
mSig=zeros(size(kSig));
mProb=zeros(1,K);
vdf=zeros(K,1);
llTrace=zeros(iter,1);
dist=zeros(iter,1);
for i=1:iter
    % Mixing proportions: prior count plus current membership count.
    for j=1:K
        prob(j)=priorProb(j)+sum(mem==j);
    end
    prob=prob/sum(prob);
    if i>burnin
        mProb=mProb+prob;
    end

    for j=1:K
        jdx=find(mem==j);
        nj=length(jdx);

        % Start from the prior scale matrix; when the element has members,
        % add the centred scatter and the shrinkage term between the
        % sample mean and the prior mean.
        sigma(:)=pSigma(:,j);
        if nj>0
            z=dat(jdx,:)';
            mz=mean(z,2);
            z=z-repmat(mz,1,nj);
            dmu=mz-priorMu(:,j);
            sigma=sigma+z*z'+kappa*nj/(kappa+nj)*(dmu*dmu');
        end
        % Rescale to the scale matrix handed to the multivariate-t helper,
        % with vdf(j) degrees of freedom.
        vdf(j)=df+nj-P+1;
        sigma=sigma*(kappa+nj+1)/((kappa+nj)*vdf(j));
        kSig(:,j)=sigma(:);

        % Posterior mean: kappa-weighted blend of prior mean and sample mean.
        pMu(:,j)=kappa/(kappa+nj)*priorMu(:,j);
        if nj>0
            pMu(:,j)=pMu(:,j)+nj/(kappa+nj)*mz;
        end

        % TODO (from original author): need to update dimension subset,
        % e.g. dimkp(:,j)=mvtDimensionSubset(pMu,kSig,vdf,dat(jdx,:)');

        if i>burnin
            pMem(jdx,j)=pMem(jdx,j)+1;
            mSig(:,j)=mSig(:,j)+sigma(:);
            mMu(:,j)=mMu(:,j)+pMu(:,j);
        end
    end

    % New membership labels for every observation.
    % NOTE(review): whether this samples or picks the most likely element
    % is decided inside mvtMixtureMembership - confirm there.
    mem=mvtMixtureMembership(pMu,kSig,vdf, prob, dat');

    if mod(i,100)==0
        disp([num2str(i) '/' num2str(iter)]);
    end

    %%% plotting: one figure per consecutive pair of dimensions
    if mod(i,50)==0
        cmap=colormap();
        nCol=size(cmap,1);
        % floor() is 0 when K exceeds the colormap length; a zero step
        % would produce an empty colour table, so clamp it to 1.
        step=max(1,floor(nCol/K));
        col=cmap(1:step:nCol,:);
        nPick=size(col,1);
        tmp=zeros(P,P);
        for jnk=1:floor(P/2)
            dims=[2*jnk-1 2*jnk];
            figure(jnk);
            clf;
            hold on;
            for cnt=1:max(mem)
                % Wrap the colour index so it stays valid when there are
                % more elements than colours (no effect otherwise).
                c=col(mod(cnt-1,nPick)+1,:);
                sel=find(mem==cnt);
                plot(dat(sel,dims(1)),dat(sel,dims(2)),'.','color',c);
                tmp(:)=kSig(:,cnt);
                plotMVN(pMu(dims,cnt),tmp(dims,dims),c);
            end
            hold off;
        end
        pause(.01);
    end

    % Trace of the fit on the observed data and on a replicate data set
    % drawn from the current parameters.
    llTrace(i)=mvnMixtureLogLikelihood(dat,pMu,kSig,prob);
    r=mvnMixtureRnd(size(dat,1),pMu,kSig,prob);
    dist(i)=mvnMixtureLogLikelihood(r,pMu,kSig,prob);
end

% Turn the accumulated sums into averages over the retained iterations.
nKeep=iter-burnin;
pMem=pMem/nKeep;
mSig=mSig/nKeep;
mMu=mMu/nKeep;
mProb=mProb/nKeep;

disp('Estimating overfitting.');
% Fraction of replicate log-likelihoods that exceed the log-likelihood of
% the observed data under the averaged parameters.
ll=mvnMixtureLogLikelihood(dat,mMu,mSig,mProb);
disp(sum(dist>ll)/length(dist));

