forked from xiangyongcao/CNN-AL-MRF
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
a67ed6f
commit 220f68b
Showing
18 changed files
with
1,023 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
function CNN_AL_MRF_GenWeit(data, DirHor, DirVer) | ||
|
||
BandInd = data.IndBand; | ||
%BandInd = [12,60,180]; | ||
if ~isfield(data, 'F_norm') | ||
Feature = data.F; | ||
% normalize | ||
for b=1:length(BandInd) | ||
Feature(:,:,BandInd(b)) = (Feature(:,:,BandInd(b))-min(min(Feature(:,:,BandInd(b)))))./... | ||
(max(max(Feature(:,:,BandInd(b))))-min(min(Feature(:,:,BandInd(b))))); | ||
end | ||
end | ||
Feature = data.F_norm; | ||
HSI_rgb = im2uint8(Feature(:,:,BandInd)); | ||
% figure, imshow(HSI_rgb) | ||
|
||
H = size(HSI_rgb, 1); | ||
W = size(HSI_rgb, 2); | ||
|
||
R = flipud(HSI_rgb(:,:,1)); R = R(1:H,:);R = R(:,1:W); | ||
G = flipud(HSI_rgb(:,:,2)); G = G(1:H,:);G = G(:,1:W); | ||
B = flipud(HSI_rgb(:,:,3)); B = B(1:H,:);B = B(:,1:W); | ||
|
||
for i = 2 : H | ||
for j = 2 : W | ||
Dij_left(i-1,j-1) = (abs(R(i,j) - R(i-1,j)))^2; | ||
Dij_up(i-1,j-1) = (abs(R(i,j) - R(i,j-1)))^2; | ||
end | ||
end | ||
var_left = std(double(Dij_left(:))); | ||
var_up = std(double(Dij_up(:))); | ||
HorzWeight1 = exp(-double(Dij_left)/(2*var_left)); | ||
VertWeight1 = exp(-double(Dij_up)/(2*var_up)); | ||
|
||
for i = 2 : H | ||
for j = 2 : W | ||
Dij_left(i-1,j-1) = (abs(G(i,j) - G(i-1,j)))^2; | ||
Dij_up(i-1,j-1) = (abs(G(i,j) - G(i,j-1)))^2; | ||
end | ||
end | ||
var_left = std(double(Dij_left(:))); | ||
var_up = std(double(Dij_up(:))); | ||
HorzWeight2 = exp(-double(Dij_left)/(2*var_left)); | ||
VertWeight2 = exp(-double(Dij_up)/(2*var_up)); | ||
|
||
for i = 2 : H | ||
for j = 2 : W | ||
Dij_left(i-1,j-1) = (abs(B(i,j) - B(i-1,j)))^2; | ||
Dij_up(i-1,j-1) = (abs(B(i,j) - B(i,j-1)))^2; | ||
end | ||
end | ||
var_left = std(double(Dij_left(:))); | ||
var_up = std(double(Dij_up(:))); | ||
HorzWeight3 = exp(-double(Dij_left)/(2*var_left)); | ||
VertWeight3 = exp(-double(Dij_up)/(2*var_up)); | ||
|
||
HorzWeight = real(HorzWeight1+HorzWeight2+HorzWeight3)/3; | ||
VertWeight = real(VertWeight1+VertWeight2+VertWeight3)/3; | ||
|
||
% figure; | ||
% mesh(HorzWeight); | ||
% title('HorzWeight'); | ||
% axis([1 W 1 H]); | ||
|
||
% figure; | ||
% mesh(VertWeight); | ||
% title('VertWeight'); | ||
% axis([1 W 1 H]); | ||
|
||
save(DirHor, 'HorzWeight', '-ascii'); | ||
save(DirVer, 'VertWeight', '-ascii'); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,127 @@ | ||
|
||
%% Experiment for HSIs classification using CNNs, active learning and MRF incorporated model | ||
% writen by ... | ||
% Aug. 2019. | ||
% the folder containing this function is suggested to be located under ...\matconvnet\examples\ | ||
% | ||
% if you find this code useful, please kindly cite the following papers: | ||
% [1] | ||
% [2] | ||
% ... | ||
|
||
close all; clear, clc; | ||
Dir.Cur = pwd; | ||
addpath(genpath(Dir.Cur)); | ||
|
||
%% Parameters for data | ||
data.NameFolder = {'IndianPines', 'PaviaU', 'PaviaCenter'}; | ||
data.NameMat = {'GT.mat', 'Feature.mat'}; | ||
data.SizeOri = {[145, 145, 220], [610, 340, 103], [400, 300, 102]}; | ||
data.SizeWin = 8; | ||
data.NumClass = {16, 9, 8}; | ||
data.IndBand = {[10, 80, 200], [12, 67, 98], [10, 60, 90]}; % to generate false RGB, which should be less contaminated bands | ||
%data.flagPCA = true; | ||
%data.ReducedDim = 10; | ||
|
||
% Three datasets: | ||
% data.flagSet = 1, Indian Pines; | ||
% = 2, Pavia University; | ||
% = 3, Pavia Center. | ||
data.flagSet = 1; | ||
|
||
data.NameFolder = data.NameFolder{data.flagSet}; | ||
data.SizeOri = data.SizeOri{data.flagSet}; | ||
data.NumClass = data.NumClass{data.flagSet}; | ||
data.IndBand = data.IndBand{data.flagSet}; | ||
|
||
%% Parameters for algorithm | ||
alg.SampleSty = 'Rd'; % out of {'Rd', 'Classwise'} | ||
alg.CountSty = 'Num'; % out of {'Num', 'Ratio'} | ||
alg.NumTrn1st = {250, 107, 58}; | ||
alg.NumTrn1st = alg.NumTrn1st{data.flagSet}; | ||
% if alg.CountSty == 'Ratio' | ||
%alg.RatioTrn1st = {0.02, 0.0025, 0.0025}; | ||
%alg.RatioTrn1st = alg.RatioTrn1st{data.flagSet}; | ||
alg.CrossVal = 0.05; | ||
alg.NumAlAugPerIte = {[250, 150, 100, 50], [107, 107, 107], [26, 20]}; % The training samples added in each iteration keeps the same ratio with the training sample number of the first iteration | ||
alg.NumAlAugPerIte = alg.NumAlAugPerIte{data.flagSet}; | ||
alg.NumIter = length(alg.NumAlAugPerIte)+1; | ||
alg.AlStra = 'BvSB'; % out of {'BvSB', 'RS', 'EP'}; | ||
alg.SmoothFactor = 8; % around 10 | ||
alg.flagAL = true; | ||
alg.flagMRF = true; | ||
|
||
alg.RngSeed = 4; | ||
|
||
%% Parameters for CNN module | ||
cnn_net.NumEpoch = {[800, 400, 400, 300, 200], [400, 200, 200, 200], [600, 400, 300]}; % Finetune helps to coverge faster | ||
cnn_net.SizeKer1Conv = 3; cnn_net.SizeKer2Conv = 2; | ||
cnn_net.NumKer1Conv = 20; cnn_net.NumKer2Conv = 20; | ||
cnn_net.SizeKer1Pool = 2; cnn_net.SizeKer2Pool = 2; | ||
cnn_net.Stride1Pool = 2; cnn_net.Stride2Pool = 2; | ||
cnn_net.SizeKer1FC = 2; cnn_net.SizeKer2FC = 1; | ||
cnn_net.NumKer1FC = 500; cnn_net.NumKer2FC = data.NumClass; | ||
cnn_net.WeightDecay = 0.0005; cnn_net.BatchSize = 50; | ||
cnn_net.LearningRate = 0.001; % or logspace(-3, -4, cnn_net.NumEpoch(Ite)) | ||
cnn_net.flagDA = true; cnn_net.flagBN = true; | ||
cnn_net.flagFT = true; | ||
|
||
if cnn_net.flagFT | ||
cnn_net.NumEpoch = cnn_net.NumEpoch{data.flagSet}; | ||
else | ||
cnn_net.NumEpoch = [800, 400, 600]; | ||
cnn_net.NumEpoch = cnn_net.NumEpoch(data.flagSet); | ||
end | ||
|
||
data = rmfield(data, 'flagSet'); | ||
|
||
%% make folders | ||
cd .. | ||
cd .. | ||
Dir.Root = pwd; | ||
Dir.Exp = strcat(pwd, '\data\HSI\', data.NameFolder, '\RngSeed-', num2str(alg.RngSeed)); | ||
if ~cnn_net.flagDA | ||
Dir.Exp = fullfile(Dir.Exp, 'CNN'); | ||
else if ~cnn_net.flagBN | ||
Dir.Exp = fullfile(Dir.Exp, 'DA'); | ||
else if ~alg.flagAL | ||
Dir.Exp = fullfile(Dir.Exp, 'BN'); | ||
else if ~cnn_net.flagFT | ||
Dir.Exp = fullfile(Dir.Exp, 'AL'); | ||
else if alg.flagMRF | ||
Dir.Exp = fullfile(Dir.Exp, 'FT-MRF'); | ||
end | ||
end | ||
end | ||
end | ||
end | ||
Dir.Results = fullfile(Dir.Exp, 'Results'); | ||
|
||
if ~exist(Dir.Exp) || ~exist(Dir.Results) | ||
mkdir(Dir.Exp); | ||
mkdir(Dir.Results); | ||
end | ||
|
||
cd(Dir.Cur) | ||
Dir.Data = fullfile(pwd, 'Data', data.NameFolder); | ||
|
||
%% Generate initialized labeled pixels | ||
[data, Train, Patch] = CNN_AL_MRF_preprocess(Dir, data, alg); | ||
|
||
% initialize | ||
Train.Set_All = []; | ||
Train.Data_Train = []; Train.Set_Train = []; | ||
Train.Data_Val = []; Train.Set_Val = []; | ||
|
||
%% Iterations of the overall algorithm, including data preparation, CNNs' training and testing (with active learning), post-preprossing with or w/o. MRF | ||
for Ite = 1:alg.NumIter | ||
|
||
%% step 1: data preparation (with data augmentation) | ||
[Train, Test] = CNN_AL_MRF_prepdata(data, cnn_net, alg, Train, Patch); | ||
|
||
%% step 2: train CNNs | ||
[net] = CNN_AL_MRF_train(Dir, data, cnn_net, Train, Ite, Test); | ||
|
||
%% step 3: test | ||
[Train] = CNN_AL_MRF_test(Dir, data, alg, cnn_net, Train, Ite, Test); | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
function CNN_AL_MRF_plot(labels, size, dir) | ||
|
||
labels = reshape(labels, size); | ||
im = zeros([size, 3]); | ||
colorbar = [[0 0 0]; ...%% 0. Black: background | ||
[0 0 255]; ...%% 1. Blue | ||
[0 255 0]; ...%% 2. Beans? | ||
[255 0 0]; ...%% 3. Beet? | ||
[0 102 205]; ...%% 4. dzÀ¶£¿ | ||
[0 204 102]; ...%% 5. Maize | ||
[255 128 0]; ...%% 6. Potato | ||
[102 205 0]; ...%% 7. ? | ||
[102 0 204]; ...%% 8. Lucerne | ||
[204 0 102]; ...%% 9. Rapeseed | ||
[204 102 255]; ...%% 10. Peas? | ||
[251 232 45]; ...%% 11. Wheat | ||
[138 42 166]; ...%% 12. Fruit? | ||
[120 178 215]; ...%% 13. Barley | ||
[204 255 204]; ...%% 14. Flax? | ||
[255 204 204]; ...%% 15. Grass? | ||
[40 210 180]; ...%% 16. ? | ||
]/255; | ||
for i=1:size(1) | ||
for j=1:size(2) | ||
im(i, j, :) = colorbar(labels(i, j)+1, :); | ||
end | ||
end | ||
figure, imshow(im); | ||
save(strcat(dir,'.mat'), 'im'); | ||
imwrite(im, strcat(dir, '.jpg')); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
function [Train, Test] = CNN_AL_MRF_prepdata(data, cnn_net, alg, Train, Patch) | ||
|
||
% update train set for all | ||
Train.Set_All = [Train.Set_All, Train.Set]; | ||
|
||
% update train pool | ||
for i = 1:size(Train.Set, 2) | ||
Train.Pool(:, find(Train.Pool(1, :) == Train.Set(1, i))) = []; | ||
end | ||
|
||
% Construct Test Dataset: load padded feature into #Test_num voxels | ||
Test.Set = Train.Pool; | ||
Test.Num = size(Train.Pool, 2); | ||
Test.Data = single(zeros(data.SizeWin, data.SizeWin, size(data.F_pad,3), Test.Num)); | ||
|
||
for i=1:Test.Num | ||
[x, y] = Index1Dto2D(Test.Set(1,i), data.SizeOri(1), data.SizeOri(2)); | ||
% patch center index: x+HalfWin, y+HalfWin | ||
Test.Data(:,:,:,i) = data.F_pad(x+Patch.Start:x+Patch.End, y+Patch.Start:y+Patch.End, :); | ||
end | ||
|
||
% construct Train Dataset: load padded feature into #Train_num voxels | ||
Train_num = size(Train.Set, 2); | ||
Train_Data = single(zeros(data.SizeWin, data.SizeWin, size(data.F_pad,3), Train_num)); | ||
|
||
for i=1:Train_num | ||
[x, y] = Index1Dto2D(Train.Set(1,i), data.SizeOri(1), data.SizeOri(2)); | ||
Train_Data(:,:,:,i) = data.F_pad(x+Patch.Start:x+Patch.End, y+Patch.Start:y+Patch.End, :); | ||
end | ||
|
||
% index partition for train and val.(in train dataset) | ||
Val_num = floor(alg.CrossVal * Train_num); | ||
tmp = randperm(Train_num); | ||
Val_Idx = tmp(1:Val_num); | ||
Train_Idx = tmp(Val_num+1: end); | ||
|
||
% construct original train and cross validation data | ||
Train_Data_Val = Train_Data(:, :, :, Val_Idx); | ||
Train_Set_Val = Train.Set(:, Val_Idx); | ||
|
||
Train_Data_Train = Train_Data(: ,:, :, Train_Idx); | ||
Train_Set_Train = Train.Set(:, Train_Idx); | ||
|
||
if cnn_net.flagDA | ||
% Augment the data by rotating and flipping | ||
Val_Data_1 = rot90(Train_Data_Val, 1); | ||
Val_Data_2 = rot90(Train_Data_Val, 2); | ||
Val_Data_3 = rot90(Train_Data_Val, 3); | ||
Val_Data_4 = flipud(Train_Data_Val); | ||
Val_Data_5 = fliplr(Train_Data_Val); | ||
|
||
Train_Data_Val = cat(4, Train_Data_Val, Val_Data_1, Val_Data_2, Val_Data_3, Val_Data_4, Val_Data_5); | ||
Train_Set_Val = repmat(Train_Set_Val, 1, 6); | ||
|
||
Train_Data_1 = rot90(Train_Data_Train, 1); | ||
Train_Data_2 = rot90(Train_Data_Train, 2); | ||
Train_Data_3 = rot90(Train_Data_Train, 3); | ||
Train_Data_4 = flipud(Train_Data_Train); | ||
Train_Data_5 = fliplr(Train_Data_Train); | ||
|
||
Train_Data_Train = cat(4, Train_Data_Train, Train_Data_1, Train_Data_2, Train_Data_3, Train_Data_4, Train_Data_5); | ||
Train_Set_Train = repmat(Train_Set_Train, 1, 6); | ||
end | ||
|
||
% update Train | ||
Train.Data_Val = cat(4, Train.Data_Val, Train_Data_Val); | ||
Train.Set_Val = cat(2, Train.Set_Val, Train_Set_Val); | ||
Train.Data_Train = cat(4, Train.Data_Train, Train_Data_Train); | ||
Train.Set_Train = cat(2, Train.Set_Train, Train_Set_Train); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
function [data, Train, Patch] = CNN_AL_MRF_preprocess(Dir, data, alg) | ||
|
||
Img_Size = data.SizeOri(1:2); | ||
|
||
% Load feature and ground truth labels: Feature.mat, GT.mat | ||
load(fullfile(Dir.Data, data.NameMat{1})); | ||
load(fullfile(Dir.Data, data.NameMat{2})); | ||
GT = double(GT); | ||
data.GT = GT; | ||
|
||
% if data.flagPCA | ||
% data.F = Feature; | ||
% % PCA Dim Reduction | ||
% Feature = reshape(Feature, [prod(lw), data.SizeOri(3)]); | ||
% [eigvec, ~] = PCA(Feature, data); | ||
% Feature = Feature*eigvec; | ||
% Feature = reshape(Feature, [lw, data.ReducedDim]); | ||
% % normalize | ||
% for b = 1:data.ReducedDim | ||
% Feature(:, :, b) = (Feature(:, :, b) - min(min(Feature(:, :, b))))./(max(max(Feature(:, :, b))) ... | ||
% - min(min(Feature(:, :, b)))); | ||
% end | ||
% else | ||
% for b = 1:data.SizeOri(3) | ||
% Feature(:, :, b) = (Feature(:, :, b) - min(min(Feature(:, :, b))))./(max(max(Feature(:, :, b))) ... | ||
% - min(min(Feature(:, :, b)))); | ||
% end | ||
% data.F_norm = Feature; | ||
% end | ||
|
||
for b = 1:data.SizeOri(3) | ||
Feature(:, :, b) = (Feature(:, :, b) - min(min(Feature(:, :, b))))./(max(max(Feature(:, :, b))) ... | ||
- min(min(Feature(:, :, b)))); | ||
end | ||
data.F_norm = Feature; | ||
|
||
% padding | ||
% if data.SizeWin is even£¬left padding HalfWin£¬right padding HalfWin-1 | ||
% if data.SizeWin is odd£¬left padding HalfWin£¬right padding HalfWin | ||
HalfWin = floor(data.SizeWin/2); | ||
if (mod(data.SizeWin, 2) == 0) | ||
Patch.Start = HalfWin-HalfWin; | ||
Patch.End = HalfWin+HalfWin-1; | ||
else | ||
Patch.Start = HalfWin-HalfWin; | ||
Patch.End = HalfWin+HalfWin; | ||
end | ||
data.F_pad = single(zeros(Img_Size(1)+Patch.End, Img_Size(2)+Patch.End, size(Feature, 3))); | ||
data.F_pad(HalfWin+1:Img_Size(1)+HalfWin, HalfWin+1:Img_Size(2)+HalfWin, :) = Feature; | ||
|
||
data.MRF = single(zeros(data.SizeWin, data.SizeWin, size(data.F_pad,3), prod(Img_Size))); | ||
if alg.flagMRF | ||
for i=1:prod(Img_Size) | ||
[x, y] = Index1Dto2D(i, Img_Size(1), Img_Size(2)); | ||
data.MRF(:,:,:,i) = data.F_pad(x+Patch.Start:x+Patch.End, y+Patch.Start:y+Patch.End, :); | ||
end | ||
end | ||
|
||
% randomly select the training set | ||
GT_1d = data.GT(:)'; | ||
GT_indexes = find(GT_1d); | ||
Train.Pool = [GT_indexes; GT_1d(GT_indexes)]; | ||
|
||
% randomly sampling with seed | ||
rng(alg.RngSeed), tmp = randperm(length(GT_indexes)); | ||
tmp_indexes = GT_indexes(tmp); | ||
tmp_GT = GT_1d(tmp_indexes); | ||
if strcmp(alg.SampleSty, 'Classwise') && strcmp(alg.CountSty, 'Ratio') | ||
Train.Set = []; | ||
for l=1:max(tmp_GT) | ||
Train_GlbIndexes = tmp_indexes(find(tmp_GT==l)); | ||
%tmp = max(floor(length(Train_GlbIndexes)*alg.RatioTrn1st), 2); | ||
tmp = round(length(Train_GlbIndexes)*alg.RatioTrn1st); | ||
Train_Set = [Train_GlbIndexes(1:tmp);GT_1d(Train_GlbIndexes(1:tmp))]; | ||
Train.Set = [Train.Set, Train_Set]; | ||
end | ||
elseif strcmp(alg.SampleSty, 'Rd') && strcmp(alg.CountSty, 'Num') | ||
Train_GlbIndexes = tmp_indexes(1:alg.NumTrn1st); | ||
Train.Set = [Train_GlbIndexes; GT_1d(Train_GlbIndexes)]; | ||
end |
Oops, something went wrong.