function net=LCA_learn_NMF_lb(net, A, control)
% LCA_LEARN_NMF_LB  Fit the factorization  A ~ lB * U * V'  with lB held fixed.
%
%   net = LCA_learn_NMF_lb(net, A)
%   net = LCA_learn_NMF_lb(net, A, control)
%
% Inputs
%   net      struct; must provide lB and K. Optional fields:
%              U, V           initial factors (a random positive start is
%                             drawn when the field is empty -- presumably
%                             getOption returns [] when absent; confirm)
%              alphaU, alphaV additive terms used in the updates (default 0)
%   A        data matrix; V gets size(A,2) rows. Presumably sparse and
%            nonnegative, since loss() works on nonzeros(A) -- TODO confirm.
%   control  optional struct:
%              maxiter  maximum number of sweeps        (default 1000)
%              xtol     relative change tolerance on U  (default 1e-5)
%              display  print progress to stderr if true (default 1)
%
% Output
%   net      the input struct with alphaU, alphaV, U, V, lBt (= lB') and
%            loss (value after the last update) filled in.
%
% Each sweep does a multiplicative update of V, then of U, each followed by
% renormalization through nrmz. The loop stops after maxiter sweeps, or
% earlier once max|U_new - U_old| / max(U_new) falls below xtol.

if nargin < 3
  control = struct;
end
control.maxiter = getOption(control, 'maxiter', 1000);
control.xtol    = getOption(control, 'xtol', 1e-5);
control.display = getOption(control, 'display', 1);

[~, nCols] = size(A);

net.alphaU = getOption(net, 'alphaU', 0);
net.alphaV = getOption(net, 'alphaV', 0);
net.U = getOption(net, 'U');
net.V = getOption(net, 'V');

% Random positive initialization. V is drawn before U; keep this order so
% the random stream (and hence the result) is reproducible.
if isempty(net.V)
  net.V = nrmz(rand(nCols, net.K) + (net.alphaV + 0.1), 2);
end
if isempty(net.U)
  % dimension 1: normalize each column of U
  net.U = nrmz(rand(size(net.lB, 2), net.K) + (net.alphaU + .1), 1);
end

net.lBt = net.lB';              % transpose cached once for the U update
BUcur = net.lB * net.U;
[net.loss, ratio] = loss(net, A, BUcur);

for iter = 1:control.maxiter
  Uprev = net.U;

  % V step: V <- nrmz((ratio' * lB*U) .* V + alphaV, 2)
  net.V = nrmz(ratio' * BUcur .* net.V + net.alphaV, 2);
  [net.loss, ratio] = loss(net, A, BUcur);

  % U step: U <- nrmz((lB' * (ratio * V)) .* U + alphaU, 1)
  net.U = nrmz(net.lBt * (ratio * net.V) .* net.U + net.alphaU, 1);
  BUcur = net.lB * net.U;
  [net.loss, ratio] = loss(net, A, BUcur);

  % largest entrywise change in U, relative to the largest entry of U
  relDelta = max(abs(net.U(:) - Uprev(:)) / max(net.U(:)));

  if control.display
    fprintf(2, 'iter=%d loss=%.5e xdel=%.2e\n', iter, net.loss, relDelta);
  end
  if relDelta < control.xtol
    break;
  end
end

% Local subfunction: data-fit term used by LCA_learn_NMF_lb.
function [l,As]=loss(net,A,BU)
% LOSS  Divergence of the current model from A on the nonzero entries of A.
%
%   With a = nonzeros(A) and y = spmult_raw(A, BU', net.V'), returns
%       l  = sum(y) - sum(a) + sum(a .* log(a ./ y))
%   and, when a second output is requested, As = mksparse(A, a ./ y).
%
%   Presumably y holds (BU * net.V') sampled at the nonzero positions of A,
%   which would make l the generalized KL divergence restricted to that
%   pattern. spmult_raw and mksparse are defined elsewhere -- confirm.
%   NOTE(review): an earlier header said "actual loss need to minus
%   \sum_i A_i log(A_i)", but the a.*log(a) part is already included via
%   log(a ./ y) -- confirm which form is intended.
yhat  = spmult_raw(A, BU', net.V');
a     = nonzeros(A);
ratio = a ./ yhat;
l = sum(yhat) - sum(a) + sum(a .* log(ratio));
if nargout > 1
  As = mksparse(A, ratio);
end
