clear all;

addpath D:\utilities;
addpath D:\utilities\bfrm;
addpath d:\utilities\sparseFactor\;
addpath d:\utilities\mcmc;
addpath d:\utilities\sss;

%%% Read the expression data frame (together with its key file) and the
%%% annotation table from disk.
datPth='D:\dataRepository\tissueSamples\serum\mccarthy\deplete7\hcv-180.txt';
dat=readDataframe(datPth,'D:\dataRepository\tissueSamples\serum\mccarthy\deplete7\dateKey.txt');
at=scanfile('D:\dataRepository\tissueSamples\serum\mccarthy\deplete7\hcv-180-annotationEnhanced.txt');
% at=scanfile('D:\dataRepository\tissueSamples\serum\mccarthy\deplete7\hcv-180-annotation.txt');
%%% dat and at are already sorted correctly

%%% Retain only the rows whose last annotation column is not '-'.
kp=find(~strcmp(at(:,end),'-'));
at=at(kp,:);
dat.pids=dat.pids(kp);

%%% Entries below .01 are marked missing, then the retained rows are moved
%%% to the natural-log scale.
xprKept=dat.xpr(kp,:);
xprKept(xprKept<.01)=NaN;
dat.xpr=log(xprKept);
clear xprKept;

%%% Outlier removal. Two passes are made; in each pass the data are
%%% re-standardized with standardize(...,2) and every cell whose standardized
%%% magnitude exceeds 3 is set to NaN in dat.xpr.
[p,n]=size(dat.xpr);
nPasses=2;
zCut=3;
%%% Disabled alternative: regress a design matrix H out of each row first.
% H=ones(length(dat.sids),1);
% H=[H cell2mat(dat.key(:,6))==2 cell2mat(dat.key(:,6))==3];
%     cdt=dat.xpr;
%     for(i=1:p)
%         b=regress(cdt(i,:)',H)';
%         cdt(i,:)=cdt(i,:)-b*H';
%     end;
%     cdt=standardize(cdt,2);
for pass=1:nPasses
    cdt=standardize(dat.xpr,2);
    % Flag the whole matrix at once; NaN cells compare false and stay as they are.
    isOut=abs(cdt)>zCut;
    % Blank the flagged cells in the standardized copy and in the data.
    cdt(isOut)=NaN;
    dat.xpr(isOut)=NaN;
end
%%%%%%%%%%%%
%%% Fill every missing (NaN) cell with the NaN-ignoring mean of its own row.
for r=1:size(dat.xpr,1)
    rowVals=dat.xpr(r,:);
    gaps=isnan(rowVals);
    dat.xpr(r,gaps)=nanmean(rowVals);
end

%%% Count how many rows carry each distinct label in the last annotation
%%% column.
%%% The counts are obtained by differencing the LAST-occurrence index of each
%%% label in the sorted list. 'last' is requested explicitly because the
%%% default occurrence returned by unique changed from last to first in
%%% MATLAB R2013a; with first-occurrence indices the differences would be
%%% shifted by one label and the size filter below would drop the wrong ones.
ptn=sort(at(:,end));
[uPtn,b,c]=unique(ptn,'last');
cnt=diff([0; b(:)]);
clear ptn;

%%% initialize the chain
[p,n]=size(dat.xpr);
% Labels with 5 or fewer rows are not given an atom of their own.
tooFew=cnt<=5;
uPtn(tooFew)=[];
cnt(tooFew)=[];
clear tooFew;
k=length(uPtn);
% Baseline membership probability for every (row, atom) pair.
pP=.01*ones(p,k);

%%% factor init.
%%% Candidate factors are the leading right singular vectors of the
%%% standardized data. A candidate counts as "supported" when more than 10%
%%% of the rows correlate with it above pcc in absolute value; nf is the
%%% index of the last supported candidate.
[a,b,c]=svd(standardize(dat.xpr,2),0);
% At most 10 candidates, but never more than the SVD actually returned
% (the economy SVD yields min(p,n) columns).
nCand=min(10,size(c,2));
cr=corr(dat.xpr',c(:,1:nCand));
% nf=max(find(sum(abs(cr)>.6)>.1*p));
pcc=.5;
nf=find(sum(abs(cr)>pcc,1)>.1*p,1,'last');
if(isempty(nf))
    % No candidate passed the threshold: continue with zero factors instead
    % of failing on an empty dimension below.
    warning('factorInit:noFactors','No candidate factor exceeded the correlation threshold; using nf=0.');
    nf=0;
end;
cr=cr(:,1:nf);
% Prior inclusion probability of each loading: .01 by default, .95 where the
% row correlates with the factor above pcc.
pPF=.01*ones(p,nf);
% f=zeros(nf,n);
% Factor scores: the selected singular vectors, each scaled to unit std.
f=c(:,1:nf)'./repmat(std(c(:,1:nf))',1,n);
% Initial loadings are 0/1 indicators of the correlation threshold.
delta=1*(abs(cr)>pcc);
pPF(delta~=0)=.95;

%%% Initial state of the sampler.
mem=zeros(p,1);                 % atom index assigned to each row
lambda=zeros(length(uPtn),n);   % one profile (row) per atom
% Per-row scale: std of the row-centred data.
beta=std(dat.xpr-repmat(mean(dat.xpr,2),1,n),[],2);
% sigma2=var(dat.xpr,[],2);
sigma2=ones(p,1);
% Normal prior on mu used in the mu update: variance pVMean, mean pMMean.
pVMean=100;
pMMean=8;
% Passed to uMvtMem1 and uMvn2; their exact role is defined in those helpers.
pPhiA=2;
pPhiB=3;
memProb=zeros(size(pP));

%%% Rows annotated with a retained label get prior probability .9999 for
%%% that label's atom and .0001 for every other atom; all other rows keep
%%% the baseline already stored in pP.
for(i=1:length(uPtn))
    kp=strcmp(at(:,end),uPtn(i));
    pP(kp,:)=.0001;
    pP(kp,i)=.9999;
end;
%%% Draw the initial membership of each row from its prior weights.
%%% The replacement flag is the logical true (not the char array 'true'),
%%% matching the call made inside the sampling loop.
for(i=1:p)
    mem(i)=randsample(k,1,true,pP(i,:));
end;
mu=mean(dat.xpr,2);

%%% Chain length and number of initial sweeps excluded from the averages.
nSteps=1000;
nBurnin=100;

%%% Running sums of the state, one per tracked quantity, each shaped like
%%% the quantity it accumulates.
mBeta=zeros(size(beta));
mSig=zeros(size(sigma2));
mMu=zeros(size(mu));
mLambda=zeros(size(lambda));
mProb=zeros(size(memProb));
mF=zeros(size(f));
mDelta=zeros(size(delta));
mPDelta=zeros(size(delta));

%%% Per-sweep traces: one column per sweep.
tLambda=zeros(numel(lambda),nSteps);
tCnt=zeros(k,nSteps);
tMem=zeros(p,nSteps);
tBeta=zeros(p,nSteps);
tSig=zeros(p,nSteps);
%%% Main sampling loop: nSteps sweeps. Each sweep refreshes, in this order,
%%% the factor loadings (delta), the atom profiles (lambda), the row-to-atom
%%% memberships (mem), the per-row beta and sigma2, and the per-row mean (mu).
%%% The factor scores f are not resampled here (their update is commented
%%% out). updateDelta, uMvn1, uMvtMem1 and uMvn2 are helpers defined
%%% elsewhere; only their call sites are visible in this script.
for(mc=1:nSteps)
    %%% update error factors
    % Residual after removing each row's mean and its scaled atom profile.
    z=dat.xpr-repmat(mu,1,n)-repmat(beta,1,n).*lambda(mem,:);
%     for(i=1:nf)
%         zchk=z-delta(:,[1:(i-1) (i+1):nf])*f([1:(i-1) (i+1):nf],:);
%         kp=find(delta(:,i)~=0);
% %         f(i,:)=uMvn3(zchk(kp,:)./repmat(delta(kp,i),1,n),sigma2(kp)./(delta(kp,i).^2));
%         f(i,:)=uMvn1(zchk(kp,:),sigma2(kp));
%     end;   
    
    %%% update error factor loadings
    for(i=1:nf)
        % Residual with the contribution of every OTHER factor removed, so
        % that column i is updated conditionally on the rest. This residual
        % (not the raw z) is what is handed to updateDelta; previously zchk
        % was computed and then discarded.
        others=[1:(i-1) (i+1):nf];
        zchk=z-delta(:,others)*f(others,:);
        delta(:,i)=updateDelta( zchk,sigma2,pPF(:,i),f(i,:) );
    end;
    
    %%% update atom locations
    for(i=1:length(uPtn))
        kp=find(mem==i);
%         z=(dat.xpr(kp,:)-repmat(mu(kp),1,n)-delta(kp,:)*f)./repmat(beta(kp),1,n);
%         lambda(i,:)=uMvn3(z,sigma2(kp)./(beta(kp).^2));
        % Rows currently assigned to atom i, with mean and factor part removed.
        z=(dat.xpr(kp,:)-repmat(mu(kp),1,n)-delta(kp,:)*f);
        lambda(i,:)=uMvn1(z,sigma2(kp));
    end;
    
    %%% update membership probabilities and membership id's
    z=dat.xpr-repmat(mu,1,n)-delta*f;
    memProb=uMvtMem1(z,lambda,pP,pPhiA,pPhiB);
    for(i=1:p)
        % One draw per row, weighted by that row's membership probabilities.
        mem(i)=randsample(k,1,true,memProb(i,:));
    end;
    tMem(:,mc)=mem;
    
    %%% update sigma and beta
    for(i=1:length(uPtn))
        kp=find(mem==i);
        z=dat.xpr(kp,:)-repmat(mu(kp),1,n)-delta(kp,:)*f;
        [beta(kp) sigma2(kp)]=uMvn2(z,lambda(i,:),pPhiA,pPhiB);
    end;    
    tBeta(:,mc)=beta;
    tSig(:,mc)=sigma2;
    
    %%% update mu
    for(i=1:length(uPtn))
        kp=find(mem==i);
        % Data with the atom part (outer product beta*lambda) and the factor
        % part removed; what remains is attributed to the row mean.
        z=dat.xpr(kp,:)-beta(kp)*lambda(i,:)-delta(kp,:)*f;
        % Normal-normal update: prior mean pMMean and variance pVMean,
        % combined with n observations of variance sigma2 per row.
        s2=1./(1/pVMean+n./sigma2(kp));
        m=s2.*((1/pVMean)*pMMean+(n./sigma2(kp)).*mean(z,2));
        mu(kp)=normrnd(m,sqrt(s2));
        tCnt(i,mc)=length(kp);
    end;
    % Every 100 sweeps: echo the sweep number and redraw the diagnostics.
    if(mod(mc,100)==0)
        mc
        figure(1)
        plot(lambda');
        figure(2)
        plot(f');
        figure(3)
        plot(beta);
        figure(4)
        plot(delta);
        pause(.01);
    end;
    % After burn-in, accumulate the state for the averages taken after the loop.
    if(mc>nBurnin)
        mBeta=mBeta+beta;
        mProb=mProb+memProb;
        mSig=mSig+sigma2;
        mLambda=mLambda+lambda;
        mMu=mMu+mu;
        mDelta=mDelta+delta;
        mF=mF+f;
        % Counts how often each loading was nonzero.
        mPDelta=mPDelta+1*(delta~=0);
    end;
    tLambda(:,mc)=lambda(:);
end;
%%% Turn the accumulated sums into averages over the sweeps kept after burn-in.
nKept=nSteps-nBurnin;
mBeta=mBeta/nKept;
mProb=mProb/nKept;
mSig=mSig/nKept;
mLambda=mLambda/nKept;
mMu=mMu/nKept;
mDelta=mDelta/nKept;
mF=mF/nKept;
mPDelta=mPDelta/nKept;

%%%%%%%%%%%%%%%%%%%%%%% end fitting of peptides to proteins

%%%%%%%%%%%%%% check the results
%%% For every sweep, rebuild the lambda matrix from its stored trace column
%%% and record the standard deviation of each atom's profile.
sdTrace=zeros(k,nSteps);
for(s=1:nSteps)
    lambda(:)=tLambda(:,s);
    perAtomSd=std(lambda');
    sdTrace(:,s)=perAtomSd';
end;

%%% Per-atom review. Each row is assigned to the atom with the largest
%%% averaged membership probability. For every atom the image shows its
%%% profile in the top row and, below it, the residuals of the rows assigned
%%% to it, ordered by annotation label. The printed ratio is the number of
%%% assigned rows whose label matches the atom's label, divided by cnt.
[jnk,mx]=max(mProb,[],2);
for(g=1:length(uPtn))
    kp=find(mx==g);
    [jnk,ord]=sort(at(kp,end));
    rowLabels=at(kp(ord),end);
%     resid=(dat.xpr(kp,:)-repmat(mu(kp),1,n)-delta(kp,:)*f)./repmat(beta(kp),1,n);
    resid=(dat.xpr(kp,:)-repmat(mu(kp),1,n)-delta(kp,:)*f);
    panel=[lambda(g,:); resid(ord,:)];
%     imagesc([lambda(i,:); resid(ord,:)]);
    imagepc(panel,1,0,1);
    % First tick: "<count>-<atom label>"; remaining ticks: the row labels.
    headLabel=cellstr([num2str(cnt(g)) '-' char(uPtn(g))]);
    set(gca,'YTick',1:(1+length(kp)),'YTickLabel',[headLabel; rowLabels]);
    sum(strcmp(rowLabels,uPtn(g)))/cnt(g)
    colorbar;
    pause;
end;


%%%%%%%%%%%%% predictions based on atoms

%%% Sample selection: every sample whose key column 6 is not 1.
% kp=find(cell2mat(dat.key(:,6))==0);
kp=find(cell2mat(dat.key(:,6))~=1);
% kp=find(cell2mat(dat.key(:,6))~=1 & sum(isnan([x(:,pvs<.01/length(pvs)) cell2mat(dat.key(kp,[3 7 9 15 17]))]),2)==0);
% kp=1:180;
%%% Response: key column 4. Predictors: the atom profiles, one column per atom.
y=cell2mat(dat.key(kp,4));
x=lambda(:,kp)';

%%% calculate associations
% pvs=zeros(size(x,2),1);
% for(i=1:length(pvs))
%     pvs(i)=anovan(x(:,i),y,'display','off');
% end;
% 
% res=glmval(glmfit(x(:,pvs<.01/length(pvs)),[y ones(size(y))],'binomial','link','probit'),x(:,pvs<.01/length(pvs)),'probit','size',ones(size(y)));
% res=glmval(glmfit([x(:,pvs<.01/length(pvs)) cell2mat(dat.key(kp,[3 7 9 15 17]))],[y ones(size(y))],'binomial','link','probit'),[x(:,pvs<.01/length(pvs)) cell2mat(dat.key(kp,[3 7 9 15 17]))],'probit','size',ones(size(y)));
% xval=res;
% % [res vlist]=sss(y,x(:,pvs<.01/length(pvs)));
% [res vlist]=sss(y,[x(:,pvs<.01/length(pvs)) cell2mat(dat.key(kp,[3 7 9 15 17]))]);

%%% Leave-one-group-out prediction. Groups are the sample-id prefixes before
%%% the first '_'. For each group, sss is called with that group's weights
%%% set to 0, and the values it returns for those samples are stored.
xval=zeros(size(y));
sidPrefix=strtok(dat.sids(kp),'_');
drp=unique(sidPrefix);
design=[x cell2mat(dat.key(kp,[3 7 9 15 17]))];
for(i=1:length(drp))
    w=ones(size(y));
    w(strcmp(drp(i),sidPrefix))=0;
    held=(w==0);
%     [res vlist]=sss(y,x,w);
    [res,vlist]=sss(y,design,w);
    xval(held)=res(held);
end;

%%% Left panel: held-out predictions by group. Right panel: ROC curve with
%%% its area printed inside the axes.
clf;
subplot(1,2,1);
plotGroups(xval,y);
subplot(1,2,2);
[tp,fp,a]=roc(y,xval);
plot(fp,tp,'linewidth',2);
set(gca,'fontsize',16,'linewidth',2);
text(.2,.1,['area = ' num2str(a,2)],'fontsize',16);

%%% Review only the atoms whose profile is associated with the response y.
%%% pvs holds one ANOVA p-value per atom (column of x) and is computed here:
%%% the loop below reads pvs, but its only computation elsewhere in the
%%% script is commented out, so the script would otherwise stop on an
%%% undefined variable. An atom is shown when its p-value is below
%%% .01/(number of atoms).
pvs=zeros(size(x,2),1);
for(j=1:length(pvs))
    pvs(j)=anovan(x(:,j),{y},'display','off');
end;

[jnk mx]=max(mProb,[],2);
for(i=find(pvs<.01/length(pvs))')
    % Rows whose largest averaged membership probability is for atom i.
    kp=find(mx==i);
    [jnk ord]=sort(at(kp,end));
%     resid=(dat.xpr(kp,:)-repmat(mu(kp),1,n)-delta(kp,:)*f)./repmat(beta(kp),1,n);
    resid=(dat.xpr(kp,:)-repmat(mu(kp),1,n)-delta(kp,:)*f);
%     imagesc([lambda(i,:); resid(ord,:)]);
    % Top row: the atom profile; below: residuals ordered by annotation label.
    imagepc([lambda(i,:); resid(ord,:)],0,0,1);
    set(gca,'YTick',1:(1+length(kp)),'YTickLabel',[cellstr([num2str(cnt(i)) '-' char(uPtn(i))]); at(kp(ord),end)]);
    % Printed: matching-label rows among those assigned, divided by cnt(i).
    sum(strcmp(at(kp(ord),end),uPtn(i)))/cnt(i)
    colorbar;
    pause;
end;


