% driver01.m
% 
% This script solves the emission estimation problem using a prior built from
% EDGAR 2005 and nightlights. The spatial pattern is that of the nightlights,
% scaled so that it reproduces EDGAR 2005 emissions over the US.
%
% - It first does the inversion based on wavelets. the emissions are
%  largely +ve, but do have some -ve emissions.
%
% - It then does non-negativity enforcement.
%
% obsErr (meas err) = 1.0e-1
% Jaideep Ray, 4/8/2013
% -------------------------------------------------------------------------

% Start from an empty workspace and with no figure windows open.
clear ;
close all ;

%%  ----- User-settable paths to toolboxes --------------------------------
% Set the paths to SparseLab 2.1 and Wavelab 8.50 here

SPARSELABLIBPATH1 = '/home/jaray/Projects/ascr-graph/lib/SparseLab2.1-Core/';
Wavelab850Path    = '../../../../lib/Wavelab850/Orthogonal/' ;

%% ---- Don't change anything below this

% The directories are added one call at a time, in this order, so the
% resulting search-path precedence is the same as issuing the addpath
% calls individually.
projectDirs = { '../../sparse-inversion/', ...
                '../../msrf-selection/',   ...   % for wavelet transforms
                Wavelab850Path } ;               % for MakeONFilter and wavelet utils
for k = 1 : numel(projectDirs)
    addpath(projectDirs{k}) ;
end

% SparseLab is added together with all of its sub-directories.
addpath(genpath(SPARSELABLIBPATH1)) ;

% Drop the loop temporaries so they do not linger in the workspace.
clear projectDirs k ;

%% Legal splash screen
% Print the three-line copyright notice with one fprintf call; the format
% string is the concatenation of the three original lines.
fprintf([ 'Copyright 2014 Sandia Corporation. Under terms of Contract DE-AC04-94AL85000\n', ...
          'there is a non-exclusive license for the use of this work by or on behalf of the U.S.\n', ...
          'Government.\n' ]) ;

%% ---- User-settable parameters

% ---- File and dir locations
% Emissions directory
edir         = '../../data/vulcan.8day.1x1/' ;
% Tower observations; used to get the tower names
obsDir       = '../../data/tower.obs/' ;
% H matrices
hDir         = '../../H-matrices/' ;
% Nightlights field and its latitude / longitude files
nlightsFile  = '../../data/nlights-0.1x0.1/nightlights.dat' ;
nlightsLats  = '../../data/nlights-0.1x0.1/nightlights_lat.dat' ;
nlightsLongs = '../../data/nlights-0.1x0.1/nightlights_long.dat' ;
% US boundary; used to chop off Canada in the US nightlights
USBndryFile  = '../../data/US_coords.dat' ;
% EDGAR 2005 inventory
edgarDir     = '../../data/edgar2005.1x1/' ;
% Output directory for the estimated wavelets
oDir1        = 'wavelets' ;
% Output directory for the positivity-enforced emissions
oDir2        = 'positivityEnforced' ;

% ---- Basic info about problem
info = struct( 'nUSCells',      816,  ...
               'nObsPerPeriod', 8*8,  ...
               'nCompSamples',  4000 ) ;
% Standard deviation of the noise added to the synthetic tower observations
obsErr = 1.0e-1 ;

% ---- Time period over which info is required
% Buffer periods before and after the duration of interest
buffer  = 2 ;
% Want to tackle Days 241:272 i.e. 32 days
periods = [31 34] ;
% periods = [3 43] ; % Want to tackle a full year (Week 3 to 43)

% ---- RF model related
% Original note: "Uncomment for AccurateBases"
levelsToZero = 0 ;

% ---- StOMP related
thresh    = 'FAR' ;
% Keep this at 0.5, it picks the largest # of bases
param     = 0.5 ;
maxIters  = 100 ;
verbose   = 1 ;
OptTol    = 1.0e-5 ;
% Only referenced by the commented-out SolveStOMP_JRmod2 call in this script
lsqrIters = 500 ;
%lsqrIters = 20 ;

% ---- Numerical parameters for my non-negativity enforcement iteration
smallEmission  = 1.0e-6 ;
% NOTE: this variable hides MATLAB's built-in eps for as long as it stays in
% the workspace. The name is kept because it is passed by that name to
% calcPositiveEmissions later in this script.
eps            = 6.0e-4 ;
maxNonlinIters = 500 ;
% This is for StOMP as called from my positivity enforcement
lsqrIters2     = 20 ;

% ---- Outputs
% Do we save estimated wavelet coeffs to disk?
saveStuff = true ;

% Plot bases or not
visualize  = false ;
%visualize = true ;   % Plot things or not

% Visualize the solution or not
%visualizeSoln = false ;
visualizeSoln = true ;

% ---- Plotting related
% Line and symbol specs
lc      = {'-k', '-b', '-r', '-g', '-m'} ;
sc      = {'ok', 'ob', 'or', 'og', 'om'} ;
marker  = 'dw' ;
% Marker size for tower locations
ms      = 12 ;
% Marker fill color
mfc     = 'w' ;
% Linewidth
lw      = 2 ;
% Fontsize
fs      = 14 ;
% Legend fontsize
lfs     = 12 ;

% ---- End, user-settable parameters --------------------------------------

%% Get the tower names
% Every loader below reads the same padded window: the periods of interest
% plus `buffer` extra periods on either side.
periodWindow = [periods(1) - buffer, periods(2) + buffer] ;

[ ~, towerNames ] = getTowerObservations(obsDir, periodWindow, info) ;

% Record the problem dimensions in info before it is passed on.
info.nTowers  = length(towerNames) ;
info.buffer   = buffer ; % buffer, in Periods, at the front and end of observations
info.nPeriods = periods(2) - periods(1) + 1 + 2*info.buffer ;

%% Extract the Hmatrices for the same period
[Hmatrices, Hlatlon] = getHMatrices(hDir, periodWindow, towerNames, info) ;

%% Construct a set of bases that provides emissions over US only, at
% full 1 deg x 1 deg resolution
basisTimer = tic ;
[Phi64, latlon, priorWeights] = constructBasesFromNlightEmissionModel( ...
                                    nlightsFile, nlightsLats, nlightsLongs, ...
                                    USBndryFile, edgarDir, levelsToZero, ...
                                    visualize) ;
fprintf('Constructing bases took %f secs\n', toc(basisTimer)) ;

%% Get emissions on the dyadic grid on which bases are described.
[ emissions, lat2D, long2D ] = getEmissionsOnBaseGrid(edir, periodWindow, latlon) ;

%% Get emissions on lower 48 and get tower (synthetic) observations
gridLatLon = [lat2D(:), long2D(:)] ;
[emissionsLower48, latlonLower48, ind] = getFieldInLower48(emissions, ...
                                                           gridLatLon, Hlatlon) ;

% ---- The total emissions and the emissions in lower 48 should be the same. Are they?
recoveredFraction = sum(emissionsLower48(:)) / sum(emissions(:)) ;
fprintf('Lower 48 recovers %f of the total emissions\n', recoveredFraction) ;

%% Project emissions on bases and get the wavelet weights
% Embed the Lower-48 emissions into a zero field on the full base grid (rows
% selected by ind), then project onto the (still unscaled) bases.
emissionsModfd = zeros(size(emissions, 1), size(emissions, 2));
emissionsModfd(ind, :) = emissionsLower48(:, :) ;
% NOTE(review): the original comment quoted a 1054 x 45 size for this matrix;
% that cannot be confirmed from this script. weights1 is not referenced again
% in this script.
weights1 = transpose(Phi64) * emissionsModfd ;
clear('emissionsModfd') ;

%% Use emissions to generate synthetic observations
% The flattened emission vector is the same for every tower, so build it
% once outside the loop, and preallocate the cell array of observations.
em       = emissionsLower48(:) ;
towerObs = cell(1, length(towerNames)) ;

for i = 1 : length(towerNames)
    % noise-free concentrations predicted at tower i
    cleanObs = Hmatrices{i} * em ;

    % add observation errors (zero-mean Gaussian, std-dev obsErr)
    tmp = cleanObs + obsErr * randn(size(cleanObs, 1), size(cleanObs, 2)) ;

    % observations can't be -ve conc; clip negative values to zero
    ii = tmp < 0.0 ;
    tmp(ii) = 0.0 ;

    towerObs{i} = tmp ;
end

%% Split bases into Lower48 and Clower48 (complement to Lower48) set

% ---- Scale basis set by a prior belief of wavelet weights. Cuts down on
%      the number of basis sets chosen.
%      The product below multiplies row k of transpose(Phi64) by
%      abs(priorWeights(k)); transposing back gives the scaled Phi64.
Phi64 = transpose( diag( abs(priorWeights) ) * transpose(Phi64) ) ;

% ---- Now split the basis set, once it's been scaled
[Phi64Lower48, latlonLower48, Phi64Clower48, latlonClower48] = ...
    splitBasisSet(Phi64, latlon, Hlatlon) ;

%% Generate the random matrix for compressive sampling
nComplementRows = size(Phi64Clower48, 1) ;
R = genRandomMatrixForCompSampling(info.nCompSamples, nComplementRows) ;

%% Construct gain functions
disp('Computing gain matrices') ;
gainTimer = tic ;
[Gtowers, GCS] = computeAugmentedGainMatrices(Hmatrices, Phi64Lower48, R, ...
                                              Phi64Clower48, info) ;
fprintf('Gain matrices took %f secs to compute\n', toc(gainTimer)) ;

%% Set up problem and call cost function
% PROB_GLOBAL is declared global here; fields are set one at a time so any
% other fields already present on the global are left untouched.
global PROB_GLOBAL ;

PROB_GLOBAL.obs     = towerObs ;     % Cell array
PROB_GLOBAL.Phi     = Phi64Lower48 ; % Wavelet bases matrix; only used to check # of unknowns
PROB_GLOBAL.info    = info ;         % struct array
PROB_GLOBAL.Gtowers = Gtowers ;      % Gain matrices
PROB_GLOBAL.GCS     = GCS ;

%% Observations and solve the inverse problem
% y stacks the observations tower by tower; the trailing info.nCompSamples
% entries are left at zero.
obsPerTower = PROB_GLOBAL.info.nPeriods * PROB_GLOBAL.info.nObsPerPeriod ;
nobs        = obsPerTower * PROB_GLOBAL.info.nTowers + info.nCompSamples ;

y = zeros(nobs, 1) ;
for t = 1 : length(towerNames)
    rows    = (t - 1) * obsPerTower + (1 : obsPerTower) ;
    y(rows) = towerObs{t} ;
end

% # of wavelet coeffs for nPeriods
N = size(PROB_GLOBAL.Phi, 2) * PROB_GLOBAL.info.nPeriods ;

% Use the standard version of StOMP. For the optimized version, contact
% Jaideep Ray, jairay@sandia.gov
% [sol, iters] = SolveStOMP_JRmod2(@projectionsUsingAugmentedGain, y, N, thresh, ...
%                                  param, maxIters, verbose, OptTol, ...
%                                  lsqrIters, PROB_GLOBAL) ;
% The standard call below takes neither lsqrIters nor PROB_GLOBAL.
[sol, iters] = SolveStOMP(@projectionsUsingAugmentedGain, y, N, thresh, ...
                          param, maxIters, verbose, OptTol) ;

%%  =============== ANALYSIS OF THE FIRST SOLUTION ==============================
%
% Got the solution; what is the L2 norm of the difference between the best
% coefficients and the solution? First, extract the relevant wavelet coefficients.

% ---- Recalculate weights for true emissions with scaled Phi64
% Embed the Lower-48 emissions into a zero field on the full base grid (rows
% selected by ind) and project onto the bases.
emissionsModfd = zeros(size(emissions, 1), size(emissions, 2));
emissionsModfd(ind, :) = emissionsLower48(:, :) ;
% NOTE(review): the original comment quoted a 1054 x 45 size for this matrix;
% that cannot be confirmed from this script.
weights = transpose(Phi64) * emissionsModfd ;
clear('emissionsModfd') ;

% ---- Find difference in wavelet weights
% Drop the buffer periods at both ends: the first and last
% info.buffer * nCoeffs entries of sol, and the matching columns of weights.
nCoeffs     = size(PROB_GLOBAL.Phi, 2) ;
usefulSol   = sol(info.buffer*nCoeffs + 1 : end - info.buffer*nCoeffs) ;
goodWeights = weights(:, (info.buffer+1) : (end - info.buffer)) ;

% Named weightErr rather than diff, so that MATLAB's built-in diff function
% is not hidden by a workspace variable after this script has run.
weightErr = goodWeights(:) - usefulSol(:) ;
fprintf('L2 norm, of soln = %f; vector size = %d \n', norm(weightErr), length(weightErr));

if (saveStuff == true)
    res = saveWaveletCoeffsToDisk(goodWeights(:), usefulSol, oDir1, periods, PROB_GLOBAL) ;
end

%% How much of the emissions did we recover? 
% Attach the H matrices to the shared problem struct and reset the running
% figure counter; figNo is threaded through every plotting helper below and
% each helper hands back the updated value.
PROB_GLOBAL.H = Hmatrices ;
% NOTE(review): the longitudes quoted in the two trailing comments (-127.5 and
% -62.5) do not match the coded values (-126.5 and -63.5) - confirm which
% pair is intended.
topcorner = [51.5, -126.5] ; % (51.5N, -127.5W) Extreme values of cell centers
botcorner = [23.5, -63.5]  ; % (23.5N, -62.5W)  spanned by the wavelets
figNo = 0;

% NOTE(review): this call is gated by visualizeSoln, while the three plotting
% calls that follow it in this section are gated by visualize - confirm that
% the mix is intentional.
[towerPreds, emissionsPreds, emissionsTrue2D, emissionsRecons2D, figNo] = ...
 estimateEmissionReconsError(PROB_GLOBAL, sol, emissionsLower48, ...
                             latlonLower48, topcorner, botcorner, periods, ...
                             figNo, visualizeSoln, saveStuff, oDir1, fs) ;
                                                    
%% What is the correlation between the estimated and reconstructed emissions?
figNo = calculateEmissionCorr(info, emissionsTrue2D, emissionsRecons2D, ...
                              periods, figNo, visualize, saveStuff, oDir1, fs);
                          
%% Plots of co2 conc at a few towers. Posterior predictive check
figNo = plotTowerCO2Conc(info, towerNames, towerObs, towerPreds, periods, ...
                         figNo, visualize, saveStuff, oDir1, fs, lfs, sc, lc) ;

%% Separate plots of emissions; overplot tower locations
figNo = plotReconsEmission(hDir, info, periods, USBndryFile, ...
                           topcorner, botcorner, emissionsTrue2D, ...
                           emissionsRecons2D, figNo, visualize, ...
                           saveStuff, oDir1, fs, marker, ms, mfc, lw) ;
                       
%% ---- Check whether the estimated wavelet weights do not lead to emissions over oceans
% oceanEmissions is not referenced again in this script.
oceanEmissions = calcOceanEmissions(PROB_GLOBAL, sol) ;

%% ---- Dump out the emissions as a function of time, both true and reconstructed
% Only the columns between the leading and trailing buffer periods are
% written; output goes to oDir1.
answer = outputEmissionsToFile(emissionsLower48(:, (buffer+1) : (end-buffer)), ... 
                               emissionsPreds(:, (buffer+1) : (end-buffer)), ...
                               latlonLower48, oDir1) ;

%%  ======== START THE SECOND SOLUTION WHERE POSITIVITY IS ENFORCED ======
% First, close all windows, because we'll have to redo all of them anyway.
% Then plot how far our approximate solution, without positivity
% enforcement, is from truth.

close all ; figNo = 1 ;

% The figures in this section are written straight into oDir2 with print,
% which fails if the directory is missing, so make sure it exists first.
if ~exist(oDir2, 'dir')
    mkdir(oDir2) ;
end

% ---- Plot the total US emissions obtained from just the +ve fluxes as well as 
% obtained from all fluxes i.e., do the -ve fluxes account for a lot?
d = emissionsPreds(:, (buffer+1) : (end - buffer)) ;   % drop buffer periods

% Logical mask of cells whose flux exceeds smallEmission. Not called "i", so
% that the imaginary unit / usual loop index is not replaced by a mask.
isPositive = d > smallEmission ;
pd = zeros( size(d) ) ; pd(isPositive) = d(isPositive) ;

% Column sums, i.e. one total per period. The dimension is given explicitly
% so a single-row d is still summed per period rather than into one scalar.
wEm  = sum( d, 1 ) ;  % emissions for US, per period
pwEm = sum( pd, 1 ) ; % same, but +ve fluxes only
ratio = pwEm ./ wEm ;

xPeriods = periods(1) : periods(2) ;

figure(figNo) ; figNo = figNo + 1 ;
plot(xPeriods, wEm, '-r'); hold on
[ax, h1, h2] = plotyy(xPeriods, pwEm, xPeriods, ratio);
set(get(ax(1), 'Ylabel'), 'String', '8-day emissions over US') ;
set(get(ax(2), 'Ylabel'), 'String', 'Ratio - positive only v/s emissions') ;
set(h1, 'LineWidth', lw) ; set(h2, 'LineWidth', lw) ; 
xlabel('Period #') ; title('Comparision of emissions calculated from positive fluxes');
legend('Emissions', 'Emissions from positive fluxes only', 'Ratio', 'Location', 'Best');
legend boxoff;
hold off;

ofname = [oDir2 '/' 'emissionComp.jpg'] ; fprintf('Saving file %s\n', ofname) ;
print('-djpeg', ofname) ;

% ---- How many grid cells have a +ve flux in them?
fprintf('Fraction of cells that have +ve emissions = %f\n', ...
        sum(isPositive(:)) / numel(isPositive)) ;

% ---- Are most of the negative emissions small? Plot the distribution.
% This figure is left with hold on: a second CDF is overlaid on it after the
% positivity-enforced solve.
figure(figNo) ; figNo = figNo + 1 ;
h = cdfplot( d(:) ) ; 
set(h, 'LineWidth', lw) ; set(h, 'Color', 'b') ; hold on ;
xlabel('Emissions in a grid-cell') ; ylabel('Probability mass') ;
title('Distribution of grid-cell emissions') ;

clear pd; clear d ; % Don't need them anymore

%% Set up the inversion problem, but this time with Lower48 emissions, not
% wavelets. The observations are just the tower observations and not the
% compressive samples. Further, the dimensionality of the inversion is # of
% grid cells X nPeriods, and not # of wavelet coefficients.

% Tower observations occupy the leading entries of y; keep only those.
pinfo     = PROB_GLOBAL.info ;
nTowerObs = pinfo.nPeriods * pinfo.nObsPerPeriod * pinfo.nTowers ;
yy        = y(1 : nTowerObs) ;

% # of emissions for nPeriods = dimensionality of inv prob
NN = size(PROB_GLOBAL.H{1}, 2) ;

% emissionsPreds from the wavelet solution is passed in along with the StOMP
% settings; lsqrIters2, eps and maxNonlinIters are the positivity-iteration
% parameters set at the top of the script.
[ePlus, iters, exitflag] = calcPositiveEmissions( ...
                               yy, NN, emissionsPreds, thresh, param, ...
                               maxIters, verbose, OptTol, ...
                               lsqrIters2, PROB_GLOBAL, eps, ...
                               maxNonlinIters) ;

%% =============== ANALYSE THE SECOND SOLUTION ===================================

% ---- Make sure we removed all negative fluxes
emissionsPreds = reshape(ePlus, size(emissionsLower48)) ;
d = emissionsPreds(:, (buffer+1) : (end-buffer)) ;   % drop buffer periods

% The "before" CDF was drawn on the most recently opened figure, which is
% number figNo - 1 because figNo is incremented right after each figure()
% call. Select it explicitly instead of relying on it still being the
% current figure after the positivity solve, and re-assert hold so the new
% curve is overlaid rather than replacing the old one.
figure(figNo - 1) ;
hold on ;
h = cdfplot( d(:) ) ;  
set(h, 'LineWidth', lw) ; set(h, 'Color', 'r') ; hold off ;
legend('Before imposing non-negativity', 'After imposition', 'Location', 'Best') ;
legend boxoff ;

% print fails if the output directory is missing, so make sure it exists.
if ~exist(oDir2, 'dir')
    mkdir(oDir2) ;
end
ofname = [oDir2 '/' 'emissionPositivity.jpg'] ; fprintf('Saving file %s\n', ofname) ;
print('-djpeg', ofname) ;

% ---- How much of the emissions did we recover? 
% Same sequence of checks as for the first (wavelet) solution, now applied
% to the positivity-enforced emissionsPreds, with output going to oDir2 and
% the plotting calls gated by visualizeSoln.
% NOTE(review): the longitudes quoted in the two trailing comments (-127.5 and
% -62.5) do not match the coded values (-126.5 and -63.5) - confirm which
% pair is intended.
topcorner = [51.5, -126.5] ; % (51.5N, -127.5W) Extreme values of cell centers
botcorner = [23.5, -63.5]  ; % (23.5N, -62.5W)  spanned by the wavelets

% The V2 helper is given the emissions directly (not the wavelet solution)
% and is not passed a visualize flag.
[towerPreds, emissionsTrue2D, emissionsRecons2D, figNo] = ...
    estimateEmissionReconsErrorV2(PROB_GLOBAL, emissionsPreds, emissionsLower48, ...
                             latlonLower48, topcorner, botcorner, periods, ...
                             figNo, saveStuff, oDir2, fs) ;
                         
%% What is the correlation between the estimated and reconstructed emissions?
figNo = calculateEmissionCorr(info, emissionsTrue2D, emissionsRecons2D, ...
                              periods, figNo, visualizeSoln, saveStuff, oDir2, fs);
                          
%% Plots of co2 conc at a few towers. Posterior predictive check
figNo = plotTowerCO2Conc(info, towerNames, towerObs, towerPreds, periods, ...
                         figNo, visualizeSoln, saveStuff, oDir2, fs, lfs, sc, lc) ;

%% Separate plots of emissions; overplot tower locations
figNo = plotReconsEmission(hDir, info, periods, USBndryFile, ...
                           topcorner, botcorner, emissionsTrue2D, ...
                           emissionsRecons2D, figNo, visualizeSoln, ...
                           saveStuff, oDir2, fs, marker, ms, mfc, lw) ;
                       
%% ---- Dump out the emissions as a function of time, both true and reconstructed
% Only the columns between the leading and trailing buffer periods are
% written; output goes to oDir2.
answer = outputEmissionsToFile(emissionsLower48(:, (buffer+1) : (end-buffer)), ... 
                               emissionsPreds(:, (buffer+1) : (end-buffer)), ...
                               latlonLower48, oDir2) ;
                           
%% ---- Compare estimated and true emissions in terms of wavelets.
%  Illustrates sparsity of reconstruction.
% Uses the Lower-48 bases and the same buffer-trimmed columns as above.
answer = computeAndSaveWaveletDecomp(oDir2, periods, Phi64Lower48, ...
                        emissionsLower48(:, (buffer+1) : (end-buffer)), ...
                        emissionsPreds(:, (buffer+1) : (end-buffer)) ) ;  
                    
