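%
% Incomplete Matrix Factorization loss,
%  optimized for large-scale sparse matrices
%
% function [f,g,h]=IMFutil(Y, M, which, k)
%
% Y:     sparse matrix of observations (unobserved entries are stored as zeros),
%        or a struct with fields I, J, Y holding the row indices, column
%        indices and values of the observed entries
% M:     model struct with fields U, S, V, u, v, m (S, u, v, m are optional);
%        Ut, Vt (transposed factors) may be given instead of U, V
% which: 'Y', 'Yr', 'U', 'S', 'V', 'u', 'v' or 'm' (plus the utility modes
%        handled at the top of the function)
% k:     factor column 1..K, used when which is 'U' or 'V'
%
% f = 0.5 * sum over observed (i,j) of (M.U*M.S*M.V' + M.u*1' + 1*M.v' + M.m - Y).^2
%     (1 denotes a column vector of ones)
% g = gradient of f with respect to the parameter selected by 'which'
% h = diagonal of the Hessian of f with respect to that parameter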

function [f,g,h]=IMFutil(Y, M, which,k)
% IMFUTIL  Loss, gradient and diagonal Hessian for incomplete matrix
%   factorization, plus sparse helper operations.
%
%   Utility modes (only f is returned):
%     'mkSparse'   - sparse matrix with the nonzero pattern of Y and values M
%     'spMult'     - (U*V')(i,j) at the nonzeros (i,j) of Y, as a column vector
%                    (uses M.Ut/M.Vt when present, otherwise M.U/M.V)
%     'scaleRow'   - diag(M)*Y
%     'maskmultU'  - (Y~=0)'*M        'maskmultUt' - M*(Y~=0)
%     'maskmultV'  - (Y~=0)*M         'maskmultVt' - M*(Y~=0)'
%     'multU'      - Yc'*M.U (or M.Ut*Yc), Yc = Y with biases u, v, m removed
%     'multV'      - Yc*M.V  (or M.Vt*Yc')
%
%   Model modes:
%     'Yr'  - model prediction at the observed entries, as a column vector
%     'Y'   - model prediction at the observed entries, as a sparse matrix
%     'U','V' (column k), 'S', 'u', 'v', 'm'
%           - f = half the squared error on the observed entries,
%             g = gradient, h = diagonal of the Hessian w.r.t. that parameter
%
%   Y may also be a struct with fields I, J, Y (observed rows, columns, values).

if nargin==0, f=0; return; end

switch which
  case 'mkSparse'
    [I,J]=find(Y);
    f=sparse(I,J,M,size(Y,1),size(Y,2));
    return;
  case 'spMult'
    [I,J]=find(Y);
    if isfield(M,'Ut'),
        f=spMult_t(M.Ut,M.Vt,I,J);
    else
        f=spMult(M.U,M.V,I,J);
    end
    return;
  case 'scaleRow'
    f=diag(sparse(M))*Y;
    return;
  case 'maskmultU'
    f=(Y~=0)'*M;
    return;
  case 'maskmultUt'
    f=M*(Y~=0);
    return;
  case 'maskmultV'
    f=(Y~=0)*M;
    return;
  case 'maskmultVt'
    f=M*(Y~=0)';
    return;
  case 'multU'
    Y=update_Y(Y,M);
    if isfield(M,'Ut'),
        f=M.Ut*Y;
    else
        f=Y'*M.U;
    end
    return;
  case 'multV'
    Y=update_Y(Y,M);
    if isfield(M,'Ut'),
        f=M.Vt*Y';
    else
        f=Y*M.V;
    end
    return;
  otherwise
    % not a utility mode: fall through to the model computations below
end

if isfield(M,'Ut'),
    U=M.Ut';
    V=M.Vt';
else
    U=M.U;
    V=M.V;
end

% Row indices, column indices and values of the observed entries.
if isstruct(Y),
    nmi=Y.I;
    nmj=Y.J;
    nmy=Y.Y;
else
    [nmi,nmj,nmy]=find(Y);
end
% Needed for both input kinds (it was previously unset for struct input).
nz=length(nmi);

% U1 = U*S drives the prediction and the V-gradient.  The U-gradient needs
% V*S' because d(U*S*V')(i,j)/dU(i,k) = (V*S')(j,k).
diagS=false;
if isfield(M,'S'),
    if (size(M.S,1)==1 || size(M.S,2)==1)
        % S is stored as the vector of its diagonal entries
        U1=U*diag(M.S);
        V1=V*diag(M.S);
        diagS=true;
    else
        U1=U*M.S;
        V1=V*M.S';
    end
else
    U1=U;
    V1=V;
end

maxi=size(U1,1);
maxj=size(V,1);

% Prediction (U*S*V')(i,j) at the observed entries, in batches to bound the
% size of the temporary row-gather matrices.
batch=256;
nmz=zeros(size(nmy));
for ik=1:batch:nz,
    kk=ik:min(ik+batch-1,nz);
    nmz(kk)=sum(U1(nmi(kk),:).*V(nmj(kk),:),2);
end

% Optional row, column and global bias terms.
if isfield(M,'u'),
    nmz=nmz+M.u(nmi);
end
if isfield(M,'v'),
    nmz=nmz+M.v(nmj);
end
if isfield(M,'m'),
    nmz=nmz+M.m;
end

if strcmp(which,'Yr'),
    f=nmz;
elseif strcmp(which,'Y'),
    f=sparse(nmi,nmj,nmz, maxi, maxj);
else
    nmd=nmz-nmy;                 % residuals on the observed entries
    f=nmd'*nmd/2;
    D=sparse(nmi,nmj,nmd, maxi, maxj);
    I=sparse(nmi,nmj,true(size(nmd)),maxi,maxj);   % observation mask
    switch which
      case 'U'
        g=D*V1(:,k);
        h=I*(V1(:,k).^2);
      case 'V'
        g=D'*U1(:,k);
        h=I'*(U1(:,k).^2);
      case 'S'
        g=U'*D*V;
        h=(U.^2)'*I*(V.^2);
        if diagS,
            g=diag(g);
            h=diag(h);
        end
      case 'u'
        g=full(sum(D,2));
        h=full(sum(I,2));
      case 'v'
        g=full(sum(D,1));
        h=full(sum(I,1));
      case 'm'
        g=sum(nmd);
        h=length(nmd);
      otherwise
        error('unknown method');
    end
end

function Y=update_Y(Y,M)
% UPDATE_Y  Subtract the model's bias terms from the stored entries of Y.
%   Each nonzero Y(i,j) has M.u(i), then M.v(j), then M.m subtracted from it
%   (each only when the corresponding field exists). The result is rebuilt
%   as a sparse matrix of the same size. SPARSE drops entries that become
%   exactly zero.
[rows,cols,vals]=find(Y);
nRows=size(Y,1);
nCols=size(Y,2);
if isfield(M,'u'), vals=vals-M.u(rows); end
if isfield(M,'v'), vals=vals-M.v(cols); end
if isfield(M,'m'), vals=vals-M.m; end
Y=sparse(rows,cols,vals,nRows,nCols);

function V=spMult(Z,X,I,J)
% SPMULT  Selected entries of Z*X' without forming the full product.
%   Returns the column vector whose t-th element is (Z*X')(I(t),J(t)), that
%   is dot(Z(I(t),:), X(J(t),:)), for t = 1..length(I). Rows are gathered
%   in chunks of 256 to bound the size of the temporaries.
chunk=256;
cnt=length(I);
V=zeros(cnt,1);
for first=1:chunk:cnt
    idx=first:min(first+chunk-1,cnt);
    rowsZ=Z(I(idx),:);
    rowsX=X(J(idx),:);
    V(idx)=V(idx)+sum(rowsZ.*rowsX,2);
end

function V=spMult_t(Z,X,I,J)
% SPMULT_T  Selected entries of Z'*X, for factors stored transposed.
%   Returns the column vector whose t-th element is (Z'*X)(I(t),J(t)), that
%   is dot(Z(:,I(t)), X(:,J(t))), for t = 1..length(I). Columns are
%   gathered in chunks of 256 to bound the size of the temporaries.
chunk=256;
cnt=length(I);
V=zeros(cnt,1);
for first=1:chunk:cnt
    idx=first:min(first+chunk-1,cnt);
    colsZ=Z(:,I(idx));
    colsX=X(:,J(idx));
    V(idx)=V(idx)+sum(colsZ.*colsX,1)';
end
