function [b_bar,bn_bar,en_bar,Nstar,gn]=waterfill_soln(Ex_bar,Noise_var,Ntot,gap_dB,H_D,index_main)
%[b_bar,bn_bar,en_bar,Nstar,gn]=waterfill_soln(Ex_bar,Noise_var,Ntot,gap_dB,H_D,index_main)
% Waterfilling energy/bit allocation over Ntot equally spaced subchannels
% of the channel H(D).
%
% Inputs:
% Ex_bar     - normalized energy; the total energy distributed over the
%              subchannels is Ntot*Ex_bar
% Noise_var  - noise variance per dimension
% Ntot       - total number of real/complex subchannels; must be an even
%              integer >= 2
% gap_dB     - SNR gap in dB
% H_D        - taps of the channel H(D) (row or column vector)
% index_main - index into H_D of the main channel tap (zero delay)
%
% Outputs:
% b_bar  - bit rate: sum of the per-subchannel bits divided by Ntot
% bn_bar - bits/dim per subchannel, rebuilt from the f = 0..1/2 half of the
%          allocation mirrored onto the negative frequencies (equal to the
%          per-subchannel allocation when H_D is real)
% en_bar - energy/dim per subchannel, built the same way as bn_bar
% Nstar  - number of subchannels used (given non-negative energy)
% gn     - channel gain-to-noise ratio |H(f_n)|^2/Noise_var at the
%          subchannel frequencies f_n = -1/2 + n/Ntot, n = 1..Ntot

% dB into normal scale
gap=10^(gap_dB/10);

% The mirroring step below needs an integer zero-frequency bin Ntot/2.
if ~isscalar(Ntot) || Ntot < 2 || mod(Ntot,2) ~= 0
    error('waterfill_soln:badNtot', 'Ntot must be an even integer >= 2.');
end

% Subchannel center frequencies f_n = -1/2 + n/Ntot. They are computed
% directly rather than with a floating-point colon range, so the grid
% always has exactly Ntot points.
f = (1:Ntot)/Ntot - 1/2;

% H(f_n) = sum_k h_k * exp(+j*2*pi*f_n*(k - index_main)), for all n at once.
% H_D(:).' forces a (non-conjugated) row, so column-vector taps work too.
taps  = H_D(:).';
delay = (1:length(taps)) - index_main;
Hn    = taps * exp(2*pi*1i*(delay.' * f));

en = zeros(1,Ntot);
bn = zeros(1,Ntot);

% per-subchannel gain-to-noise ratio
gn = abs(Hn).^2/Noise_var;

%%%%%%%%%%%%%%%%%%%%%%%
% Now do waterfilling %
%%%%%%%%%%%%%%%%%%%%%%%

% Largest gain first; Index maps sorted positions back to subchannels.
[gn_sorted, Index] = sort(gn, 'descend');

% Zero-gain subchannels sort to the end and can never carry energy, so
% start from the number of nonzero-gain subchannels.
Nstar = Ntot - nnz(gn_sorted == 0);

% Shrink the set of used subchannels until the water level K leaves the
% weakest used subchannel with non-negative energy.
while Nstar > 0
    K = 1/Nstar*(Ntot*Ex_bar + gap*sum(1./gn_sorted(1:Nstar)));
    En_min = K - gap/gn_sorted(Nstar);   % En_min occurs in the worst channel
    if En_min < 0
        Nstar = Nstar - 1;               % negative energy: drop that channel
    else
        break;                           % all energies non-negative: done
    end
end
if Nstar == 0
    error('waterfill_soln:noSubchannels', ...
        'No subchannel can be assigned non-negative energy (check Ex_bar and H_D).');
end

En = K - gap./gn_sorted(1:Nstar);          % energy per used subchannel
Bn = .5*log2(K*gn_sorted(1:Nstar)/gap);    % bits per used subchannel

% scatter results back to the original subchannel order
bn(Index(1:Nstar)) = Bn;
en(Index(1:Nstar)) = En;

% f_n = 0 is at n = Ntot/2. Keep the f = 0..1/2 half and mirror its
% interior (excluding the f = 0 and f = 1/2 bins) onto negative frequencies.
middle = Ntot/2;
en_bar = en(middle:Ntot);
bn_bar = bn(middle:Ntot);
en_bar = [fliplr(en_bar(2:end-1)) en_bar];
bn_bar = [fliplr(bn_bar(2:end-1)) bn_bar];

% bit rate averaged over all Ntot subchannels
b_bar = 1/Ntot*(sum(bn));



% Example check (values rounded to 4 decimals): Ex_bar = 1, Noise_var = 0.181,
% Ntot = 8, gap = 0 dB, H(D) = 1 + 0.9D with the main tap at index 1.
%>> [b_bar,bn_bar,en_bar,Nstar,gn] = waterfill_soln(1, 0.181, 8, 0, [1 .9], 1)
%b_bar =
%    1.5541

%bn_bar =
%    0.9693    1.8456    2.2297    2.3436    2.2297    1.8456    0.9693         0

%en_bar =
%    0.9547    1.1916    1.2329    1.2415    1.2329    1.1916    0.9547         0

%Nstar =
%     7

%gn =
%    2.9680   10.0000   17.0320   19.9448   17.0320   10.0000    2.9680    0.0552



