function net=LCA_learn_NMF_rb(net, A, control)
% LCA_LEARN_NMF_RB  Alternating multiplicative updates of net.V and net.U
% for the factorisation
%
%     A ~ U * V' * rB'
%
% Inputs
%   net      struct; must carry fields rB and K. Optional fields alphaU,
%            alphaV (default 0) are added to the respective factor before
%            each normalisation; optional fields U, V give starting values
%            (drawn with rand when empty).
%   A        data matrix; size(A,1) sets the number of rows of U.
%            NOTE(review): the helpers called by loss() look like they
%            expect A to be sparse - confirm.
%   control  optional struct with fields maxiter (default 1000),
%            xtol (default 1e-5) and display (default 1).
%
% Output
%   net      the input struct with alphaU, alphaV, U, V, rBt and loss set.
%
% Iteration stops after control.maxiter sweeps, or earlier once the largest
% absolute change in U, divided by the largest entry of U, is below
% control.xtol.

if nargin < 3
  control = struct;
end

% Fill in solver options (getOption presumably returns the field when it
% exists and the supplied default otherwise - confirm against getOption).
control.maxiter=getOption(control,'maxiter',1000);
control.xtol=getOption(control,'xtol',1e-5);
control.display=getOption(control,'display',1);

nRows = size(A,1);

net.alphaU=getOption(net,'alphaU',0);
net.alphaV=getOption(net,'alphaV',0);

net.U=getOption(net,'U');
net.V=getOption(net,'V');

% Random starting points. V is drawn before U, so the random stream is
% consumed in that order.
if isempty(net.V)
  nBasis = size(net.rB,2);
  net.V = nrmz(rand(nBasis,net.K)+(net.alphaV+0.1),2);
end
if isempty(net.U)
  net.U = nrmz(rand(nRows,net.K)+(net.alphaU+.1),1); % normalize each column
end

% Cache the transpose of rB; it is reused in every V update.
net.rBt = net.rB';
proj = net.rB*net.V;

[net.loss, ratio] = loss(net,A,proj);
for it = 1:control.maxiter
  prevU = net.U;

  % --- V step: scale V element-wise, add alphaV, renormalise along dim 2.
  scaleV = net.rBt*(ratio'*net.U);
  net.V = nrmz((scaleV.*net.V) + net.alphaV, 2);
  proj = net.rB*net.V;
  [net.loss, ratio] = loss(net,A,proj);

  % --- U step: scale U element-wise, add alphaU, renormalise along dim 1.
  scaleU = ratio*proj;
  net.U = nrmz((scaleU.*net.U) + net.alphaU, 1);
  [net.loss, ratio] = loss(net,A,proj);

  % Relative movement of U during this sweep.
  xdel = max(abs(net.U(:) - prevU(:))/max(net.U(:)));

  if control.display
    fprintf(2,'iter=%d loss=%.5e xdel=%.2e\n',it, net.loss, xdel);
  end
  if xdel < control.xtol
    break;
  end
end

%% NOTE: to obtain the actual loss, \sum_i A_i log(A_i) still needs to be subtracted
function [l,As]=loss(net,A,BV)
% LOSS  Objective value for the current factors and, optionally, the
% element-wise ratio data ./ model.
%
%   l  = sum(model) - sum(data) + sum(data .* log(data ./ model)),
%        where "data" are the nonzero entries of A and "model" is the
%        vector returned by spmult_raw(A, net.U', BV').
%        NOTE(review): presumably spmult_raw evaluates U*BV' only at the
%        nonzero positions of A - confirm against spmult_raw.
%   As = mksparse(A, data ./ model); only built when a second output is
%        requested. NOTE(review): presumably a sparse matrix with the
%        pattern of A holding the ratios - confirm against mksparse.
model = spmult_raw(A,net.U',BV');
data = nonzeros(A);

l = sum(model) - sum(data);

ratio = data./model;
l = l + sum(data.*log(ratio));

if nargout > 1
  As = mksparse(A,ratio);
end
