Skip to content

Commit

Permalink
upload files
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangyongcao committed Sep 2, 2019
1 parent a67ed6f commit 220f68b
Show file tree
Hide file tree
Showing 18 changed files with 1,023 additions and 0 deletions.
71 changes: 71 additions & 0 deletions CNN_AL_MRF_GenWeit.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
function CNN_AL_MRF_GenWeit(data, DirHor, DirVer)

BandInd = data.IndBand;
%BandInd = [12,60,180];
if ~isfield(data, 'F_norm')
Feature = data.F;
% normalize
for b=1:length(BandInd)
Feature(:,:,BandInd(b)) = (Feature(:,:,BandInd(b))-min(min(Feature(:,:,BandInd(b)))))./...
(max(max(Feature(:,:,BandInd(b))))-min(min(Feature(:,:,BandInd(b)))));
end
end
Feature = data.F_norm;
HSI_rgb = im2uint8(Feature(:,:,BandInd));
% figure, imshow(HSI_rgb)

H = size(HSI_rgb, 1);
W = size(HSI_rgb, 2);

R = flipud(HSI_rgb(:,:,1)); R = R(1:H,:);R = R(:,1:W);
G = flipud(HSI_rgb(:,:,2)); G = G(1:H,:);G = G(:,1:W);
B = flipud(HSI_rgb(:,:,3)); B = B(1:H,:);B = B(:,1:W);

for i = 2 : H
for j = 2 : W
Dij_left(i-1,j-1) = (abs(R(i,j) - R(i-1,j)))^2;
Dij_up(i-1,j-1) = (abs(R(i,j) - R(i,j-1)))^2;
end
end
var_left = std(double(Dij_left(:)));
var_up = std(double(Dij_up(:)));
HorzWeight1 = exp(-double(Dij_left)/(2*var_left));
VertWeight1 = exp(-double(Dij_up)/(2*var_up));

for i = 2 : H
for j = 2 : W
Dij_left(i-1,j-1) = (abs(G(i,j) - G(i-1,j)))^2;
Dij_up(i-1,j-1) = (abs(G(i,j) - G(i,j-1)))^2;
end
end
var_left = std(double(Dij_left(:)));
var_up = std(double(Dij_up(:)));
HorzWeight2 = exp(-double(Dij_left)/(2*var_left));
VertWeight2 = exp(-double(Dij_up)/(2*var_up));

for i = 2 : H
for j = 2 : W
Dij_left(i-1,j-1) = (abs(B(i,j) - B(i-1,j)))^2;
Dij_up(i-1,j-1) = (abs(B(i,j) - B(i,j-1)))^2;
end
end
var_left = std(double(Dij_left(:)));
var_up = std(double(Dij_up(:)));
HorzWeight3 = exp(-double(Dij_left)/(2*var_left));
VertWeight3 = exp(-double(Dij_up)/(2*var_up));

HorzWeight = real(HorzWeight1+HorzWeight2+HorzWeight3)/3;
VertWeight = real(VertWeight1+VertWeight2+VertWeight3)/3;

% figure;
% mesh(HorzWeight);
% title('HorzWeight');
% axis([1 W 1 H]);

% figure;
% mesh(VertWeight);
% title('VertWeight');
% axis([1 W 1 H]);

save(DirHor, 'HorzWeight', '-ascii');
save(DirVer, 'VertWeight', '-ascii');
127 changes: 127 additions & 0 deletions CNN_AL_MRF_main.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@

%% Experiment for HSIs classification using CNNs, active learning and MRF incorporated model
% writen by ...
% Aug. 2019.
% the folder containing this function is suggested to be located under ...\matconvnet\examples\
%
% if you find this code useful, please kindly cite the following papers:
% [1]
% [2]
% ...

close all; clear, clc;
Dir.Cur = pwd;
addpath(genpath(Dir.Cur));

%% Parameters for data
data.NameFolder = {'IndianPines', 'PaviaU', 'PaviaCenter'};
data.NameMat = {'GT.mat', 'Feature.mat'};
data.SizeOri = {[145, 145, 220], [610, 340, 103], [400, 300, 102]};
data.SizeWin = 8;
data.NumClass = {16, 9, 8};
data.IndBand = {[10, 80, 200], [12, 67, 98], [10, 60, 90]}; % to generate false RGB, which should be less contaminated bands
%data.flagPCA = true;
%data.ReducedDim = 10;

% Three datasets:
% data.flagSet = 1, Indian Pines;
% = 2, Pavia University;
% = 3, Pavia Center.
data.flagSet = 1;

data.NameFolder = data.NameFolder{data.flagSet};
data.SizeOri = data.SizeOri{data.flagSet};
data.NumClass = data.NumClass{data.flagSet};
data.IndBand = data.IndBand{data.flagSet};

%% Parameters for algorithm
alg.SampleSty = 'Rd'; % out of {'Rd', 'Classwise'}
alg.CountSty = 'Num'; % out of {'Num', 'Ratio'}
alg.NumTrn1st = {250, 107, 58};
alg.NumTrn1st = alg.NumTrn1st{data.flagSet};
% if alg.CountSty == 'Ratio'
%alg.RatioTrn1st = {0.02, 0.0025, 0.0025};
%alg.RatioTrn1st = alg.RatioTrn1st{data.flagSet};
alg.CrossVal = 0.05;
alg.NumAlAugPerIte = {[250, 150, 100, 50], [107, 107, 107], [26, 20]}; % The training samples added in each iteration keeps the same ratio with the training sample number of the first iteration
alg.NumAlAugPerIte = alg.NumAlAugPerIte{data.flagSet};
alg.NumIter = length(alg.NumAlAugPerIte)+1;
alg.AlStra = 'BvSB'; % out of {'BvSB', 'RS', 'EP'};
alg.SmoothFactor = 8; % around 10
alg.flagAL = true;
alg.flagMRF = true;

alg.RngSeed = 4;

%% Parameters for CNN module
cnn_net.NumEpoch = {[800, 400, 400, 300, 200], [400, 200, 200, 200], [600, 400, 300]}; % Finetune helps to coverge faster
cnn_net.SizeKer1Conv = 3; cnn_net.SizeKer2Conv = 2;
cnn_net.NumKer1Conv = 20; cnn_net.NumKer2Conv = 20;
cnn_net.SizeKer1Pool = 2; cnn_net.SizeKer2Pool = 2;
cnn_net.Stride1Pool = 2; cnn_net.Stride2Pool = 2;
cnn_net.SizeKer1FC = 2; cnn_net.SizeKer2FC = 1;
cnn_net.NumKer1FC = 500; cnn_net.NumKer2FC = data.NumClass;
cnn_net.WeightDecay = 0.0005; cnn_net.BatchSize = 50;
cnn_net.LearningRate = 0.001; % or logspace(-3, -4, cnn_net.NumEpoch(Ite))
cnn_net.flagDA = true; cnn_net.flagBN = true;
cnn_net.flagFT = true;

if cnn_net.flagFT
cnn_net.NumEpoch = cnn_net.NumEpoch{data.flagSet};
else
cnn_net.NumEpoch = [800, 400, 600];
cnn_net.NumEpoch = cnn_net.NumEpoch(data.flagSet);
end

data = rmfield(data, 'flagSet');

%% make folders
cd ..
cd ..
Dir.Root = pwd;
Dir.Exp = strcat(pwd, '\data\HSI\', data.NameFolder, '\RngSeed-', num2str(alg.RngSeed));
if ~cnn_net.flagDA
Dir.Exp = fullfile(Dir.Exp, 'CNN');
else if ~cnn_net.flagBN
Dir.Exp = fullfile(Dir.Exp, 'DA');
else if ~alg.flagAL
Dir.Exp = fullfile(Dir.Exp, 'BN');
else if ~cnn_net.flagFT
Dir.Exp = fullfile(Dir.Exp, 'AL');
else if alg.flagMRF
Dir.Exp = fullfile(Dir.Exp, 'FT-MRF');
end
end
end
end
end
Dir.Results = fullfile(Dir.Exp, 'Results');

if ~exist(Dir.Exp) || ~exist(Dir.Results)
mkdir(Dir.Exp);
mkdir(Dir.Results);
end

cd(Dir.Cur)
Dir.Data = fullfile(pwd, 'Data', data.NameFolder);

%% Generate initialized labeled pixels
[data, Train, Patch] = CNN_AL_MRF_preprocess(Dir, data, alg);

% initialize
Train.Set_All = [];
Train.Data_Train = []; Train.Set_Train = [];
Train.Data_Val = []; Train.Set_Val = [];

%% Iterations of the overall algorithm, including data preparation, CNNs' training and testing (with active learning), post-preprossing with or w/o. MRF
for Ite = 1:alg.NumIter

%% step 1: data preparation (with data augmentation)
[Train, Test] = CNN_AL_MRF_prepdata(data, cnn_net, alg, Train, Patch);

%% step 2: train CNNs
[net] = CNN_AL_MRF_train(Dir, data, cnn_net, Train, Ite, Test);

%% step 3: test
[Train] = CNN_AL_MRF_test(Dir, data, alg, cnn_net, Train, Ite, Test);
end
30 changes: 30 additions & 0 deletions CNN_AL_MRF_plot.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
function CNN_AL_MRF_plot(labels, size, dir)

labels = reshape(labels, size);
im = zeros([size, 3]);
colorbar = [[0 0 0]; ...%% 0. Black: background
[0 0 255]; ...%% 1. Blue
[0 255 0]; ...%% 2. Beans?
[255 0 0]; ...%% 3. Beet?
[0 102 205]; ...%% 4. dzÀ¶£¿
[0 204 102]; ...%% 5. Maize
[255 128 0]; ...%% 6. Potato
[102 205 0]; ...%% 7. ?
[102 0 204]; ...%% 8. Lucerne
[204 0 102]; ...%% 9. Rapeseed
[204 102 255]; ...%% 10. Peas?
[251 232 45]; ...%% 11. Wheat
[138 42 166]; ...%% 12. Fruit?
[120 178 215]; ...%% 13. Barley
[204 255 204]; ...%% 14. Flax?
[255 204 204]; ...%% 15. Grass?
[40 210 180]; ...%% 16. ?
]/255;
for i=1:size(1)
for j=1:size(2)
im(i, j, :) = colorbar(labels(i, j)+1, :);
end
end
figure, imshow(im);
save(strcat(dir,'.mat'), 'im');
imwrite(im, strcat(dir, '.jpg'));
69 changes: 69 additions & 0 deletions CNN_AL_MRF_prepdata.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
function [Train, Test] = CNN_AL_MRF_prepdata(data, cnn_net, alg, Train, Patch)

% update train set for all
Train.Set_All = [Train.Set_All, Train.Set];

% update train pool
for i = 1:size(Train.Set, 2)
Train.Pool(:, find(Train.Pool(1, :) == Train.Set(1, i))) = [];
end

% Construct Test Dataset: load padded feature into #Test_num voxels
Test.Set = Train.Pool;
Test.Num = size(Train.Pool, 2);
Test.Data = single(zeros(data.SizeWin, data.SizeWin, size(data.F_pad,3), Test.Num));

for i=1:Test.Num
[x, y] = Index1Dto2D(Test.Set(1,i), data.SizeOri(1), data.SizeOri(2));
% patch center index: x+HalfWin, y+HalfWin
Test.Data(:,:,:,i) = data.F_pad(x+Patch.Start:x+Patch.End, y+Patch.Start:y+Patch.End, :);
end

% construct Train Dataset: load padded feature into #Train_num voxels
Train_num = size(Train.Set, 2);
Train_Data = single(zeros(data.SizeWin, data.SizeWin, size(data.F_pad,3), Train_num));

for i=1:Train_num
[x, y] = Index1Dto2D(Train.Set(1,i), data.SizeOri(1), data.SizeOri(2));
Train_Data(:,:,:,i) = data.F_pad(x+Patch.Start:x+Patch.End, y+Patch.Start:y+Patch.End, :);
end

% index partition for train and val.(in train dataset)
Val_num = floor(alg.CrossVal * Train_num);
tmp = randperm(Train_num);
Val_Idx = tmp(1:Val_num);
Train_Idx = tmp(Val_num+1: end);

% construct original train and cross validation data
Train_Data_Val = Train_Data(:, :, :, Val_Idx);
Train_Set_Val = Train.Set(:, Val_Idx);

Train_Data_Train = Train_Data(: ,:, :, Train_Idx);
Train_Set_Train = Train.Set(:, Train_Idx);

if cnn_net.flagDA
% Augment the data by rotating and flipping
Val_Data_1 = rot90(Train_Data_Val, 1);
Val_Data_2 = rot90(Train_Data_Val, 2);
Val_Data_3 = rot90(Train_Data_Val, 3);
Val_Data_4 = flipud(Train_Data_Val);
Val_Data_5 = fliplr(Train_Data_Val);

Train_Data_Val = cat(4, Train_Data_Val, Val_Data_1, Val_Data_2, Val_Data_3, Val_Data_4, Val_Data_5);
Train_Set_Val = repmat(Train_Set_Val, 1, 6);

Train_Data_1 = rot90(Train_Data_Train, 1);
Train_Data_2 = rot90(Train_Data_Train, 2);
Train_Data_3 = rot90(Train_Data_Train, 3);
Train_Data_4 = flipud(Train_Data_Train);
Train_Data_5 = fliplr(Train_Data_Train);

Train_Data_Train = cat(4, Train_Data_Train, Train_Data_1, Train_Data_2, Train_Data_3, Train_Data_4, Train_Data_5);
Train_Set_Train = repmat(Train_Set_Train, 1, 6);
end

% update Train
Train.Data_Val = cat(4, Train.Data_Val, Train_Data_Val);
Train.Set_Val = cat(2, Train.Set_Val, Train_Set_Val);
Train.Data_Train = cat(4, Train.Data_Train, Train_Data_Train);
Train.Set_Train = cat(2, Train.Set_Train, Train_Set_Train);
80 changes: 80 additions & 0 deletions CNN_AL_MRF_preprocess.m
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
function [data, Train, Patch] = CNN_AL_MRF_preprocess(Dir, data, alg)

Img_Size = data.SizeOri(1:2);

% Load feature and ground truth labels: Feature.mat, GT.mat
load(fullfile(Dir.Data, data.NameMat{1}));
load(fullfile(Dir.Data, data.NameMat{2}));
GT = double(GT);
data.GT = GT;

% if data.flagPCA
% data.F = Feature;
% % PCA Dim Reduction
% Feature = reshape(Feature, [prod(lw), data.SizeOri(3)]);
% [eigvec, ~] = PCA(Feature, data);
% Feature = Feature*eigvec;
% Feature = reshape(Feature, [lw, data.ReducedDim]);
% % normalize
% for b = 1:data.ReducedDim
% Feature(:, :, b) = (Feature(:, :, b) - min(min(Feature(:, :, b))))./(max(max(Feature(:, :, b))) ...
% - min(min(Feature(:, :, b))));
% end
% else
% for b = 1:data.SizeOri(3)
% Feature(:, :, b) = (Feature(:, :, b) - min(min(Feature(:, :, b))))./(max(max(Feature(:, :, b))) ...
% - min(min(Feature(:, :, b))));
% end
% data.F_norm = Feature;
% end

for b = 1:data.SizeOri(3)
Feature(:, :, b) = (Feature(:, :, b) - min(min(Feature(:, :, b))))./(max(max(Feature(:, :, b))) ...
- min(min(Feature(:, :, b))));
end
data.F_norm = Feature;

% padding
% if data.SizeWin is even£¬left padding HalfWin£¬right padding HalfWin-1
% if data.SizeWin is odd£¬left padding HalfWin£¬right padding HalfWin
HalfWin = floor(data.SizeWin/2);
if (mod(data.SizeWin, 2) == 0)
Patch.Start = HalfWin-HalfWin;
Patch.End = HalfWin+HalfWin-1;
else
Patch.Start = HalfWin-HalfWin;
Patch.End = HalfWin+HalfWin;
end
data.F_pad = single(zeros(Img_Size(1)+Patch.End, Img_Size(2)+Patch.End, size(Feature, 3)));
data.F_pad(HalfWin+1:Img_Size(1)+HalfWin, HalfWin+1:Img_Size(2)+HalfWin, :) = Feature;

data.MRF = single(zeros(data.SizeWin, data.SizeWin, size(data.F_pad,3), prod(Img_Size)));
if alg.flagMRF
for i=1:prod(Img_Size)
[x, y] = Index1Dto2D(i, Img_Size(1), Img_Size(2));
data.MRF(:,:,:,i) = data.F_pad(x+Patch.Start:x+Patch.End, y+Patch.Start:y+Patch.End, :);
end
end

% randomly select the training set
GT_1d = data.GT(:)';
GT_indexes = find(GT_1d);
Train.Pool = [GT_indexes; GT_1d(GT_indexes)];

% randomly sampling with seed
rng(alg.RngSeed), tmp = randperm(length(GT_indexes));
tmp_indexes = GT_indexes(tmp);
tmp_GT = GT_1d(tmp_indexes);
if strcmp(alg.SampleSty, 'Classwise') && strcmp(alg.CountSty, 'Ratio')
Train.Set = [];
for l=1:max(tmp_GT)
Train_GlbIndexes = tmp_indexes(find(tmp_GT==l));
%tmp = max(floor(length(Train_GlbIndexes)*alg.RatioTrn1st), 2);
tmp = round(length(Train_GlbIndexes)*alg.RatioTrn1st);
Train_Set = [Train_GlbIndexes(1:tmp);GT_1d(Train_GlbIndexes(1:tmp))];
Train.Set = [Train.Set, Train_Set];
end
elseif strcmp(alg.SampleSty, 'Rd') && strcmp(alg.CountSty, 'Num')
Train_GlbIndexes = tmp_indexes(1:alg.NumTrn1st);
Train.Set = [Train_GlbIndexes; GT_1d(Train_GlbIndexes)];
end
Loading

0 comments on commit 220f68b

Please sign in to comment.