FuzzySystems/Work 4/source/scenario2.m

298 lines
9.7 KiB
Matlab
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

%% scenario2.m — Assignment 4 (Classification), Scenario 2 (Epileptic Seizure)
% TSK classification on a high-dimensional dataset with feature selection.
% Verbose version with progress printing.
%
% Uses: split_data, preprocess_data, evaluate_classification, plot_results2
% Dataset path: ./Datasets/epileptic_seizure_data.csv
%
% Assignment 4 in Fuzzy systems
%
% author:
% Christos Choutouridis ΑΕΜ 8997
% cchoutou@ece.auth.gr
close all; clear; clc;
%parpool('threads',6);
fprintf('\nScenario 2 - Epileptic Seizure Classification');
fprintf('\n================================================\n\n');

% CONFIGURATION
cfg = struct();
rng(42,'twister');              % reproducibility
% Data handling
cfg.split = [0.6 0.2 0.2];      % train / validation / test fractions
cfg.standardize = true;
% Feature selection + SC hyper-params
% Debug configuration
cfg.feature_grid = [5 8];       %[5 8 11 15];
cfg.radii_grid = [0.5 0.75];    %[0.25 0.50 0.75 1.00];
cfg.kfold = 3;
cfg.maxEpochs = 20;             % ANFIS options
cfg.displayANFIS = 0;           % 1 = let ANFIS print its own info / final results
% Default configuration
%cfg.feature_grid = [5 8 11 15];
%cfg.radii_grid = [0.25 0.50 0.75 1.00];
%cfg.kfold = 5;
%cfg.maxEpochs = 100; % ANFIS options
%cfg.displayANFIS = 0;
% Output directory
cfg.outDir = 'figures_scn2';
if ~exist(cfg.outDir,'dir'), mkdir(cfg.outDir); end
fprintf('Configuration loaded: %d folds, %d feature options, %d radius options.\n', ...
    cfg.kfold, numel(cfg.feature_grid), numel(cfg.radii_grid));

% ANFIS options shared by every CV fold and by the final model.
% ANFIS returns (4th output) the FIS with minimum error on ValidationData,
% so ValidationData effectively selects the model: it must never be the
% test split.
makeAnfisOpts = @(fis0, vData) anfisOptions('InitialFis', fis0, ...
    'EpochNumber', cfg.maxEpochs, ...
    'ValidationData', vData, ...
    'OptimizationMethod', 1, ...                 % 1 = hybrid learning
    'DisplayANFISInformation', cfg.displayANFIS, ...
    'DisplayErrorValues', 0, ...
    'DisplayStepSize', 0, ...
    'DisplayFinalResults', cfg.displayANFIS);

% DATA
dataPath = './Datasets/epileptic_seizure_data.csv';
fprintf('Loading dataset from %s ...\n', dataPath);
assert(isfile(dataPath), 'Dataset not found!');
raw = importdata(dataPath);
if isstruct(raw) && isfield(raw,'data')
    A = raw.data;
else
    A = readmatrix(dataPath);
end
X = A(:,1:end-1);               % all but the last column are features
Y = double(A(:,end));           % last column is the class label
Y = Y(:);
classLabels = unique(Y);
num_classes = numel(classLabels);
fprintf('Dataset loaded: %d samples, %d features, %d classes.\n', ...
    size(X,1), size(X,2), num_classes);

% The TSK output is continuous: round it to the nearest integer and clip it
% into [min label, max label]. Assumes consecutive integer class labels.
toLabel = @(y) min(max(round(y), min(classLabels)), max(classLabels));

% SPLIT & PREPROCESS
fprintf('\nSplitting data into train/val/test (%.0f/%.0f/%.0f%%)...\n', cfg.split*100);
[trainX, valX, testX, trainY, valY, testY] = split_data(X, Y, cfg.split);
fprintf('-> train: %d val: %d test: %d\n', size(trainX,1), size(valX,1), size(testX,1));
if cfg.standardize
    fprintf('Applying z-score normalization...\n');
    % statistics come from the training split only
    [trainX, mu, sigma] = preprocess_data(trainX);
    valX = preprocess_data(valX, mu, sigma);
    testX = preprocess_data(testX, mu, sigma);
else
    mu = []; sigma = [];
end

% GRID SEARCH
% Hyper-parameters are chosen by stratified k-fold CV on the training split
% only. The validation split is reserved for early stopping of the final
% model and the test split for the final evaluation.
fprintf('\nGRID SEARCH (features × radius) using %d-fold CV\n', cfg.kfold);
cvp = cvpartition(trainY, 'KFold', cfg.kfold, 'Stratify', true);
nF = numel(cfg.feature_grid);
nR = numel(cfg.radii_grid);
cvScores = zeros(nF, nR);
cvRules = zeros(nF, nR);
for fi = 1:nF
    featKeep = cfg.feature_grid(fi);
    for ri = 1:nR
        radius = cfg.radii_grid(ri);
        fprintf('\n[GRID] features=%2d, radius=%.2f ... ', featKeep, radius);
        kappas = zeros(cvp.NumTestSets,1);
        rulesK = zeros(cvp.NumTestSets,1);
        for k = 1:cvp.NumTestSets
            fprintf('\n-> Fold %d/%d ... ', k, cfg.kfold);
            trIdx = training(cvp, k);
            teIdx = test(cvp, k);
            Xtr = trainX(trIdx,:); Ytr = trainY(trIdx);
            Xva = trainX(teIdx,:); Yva = trainY(teIdx);
            % Relief feature selection, fitted on the fold's training part only
            [idxFeat, ~] = relief_select(Xtr, Ytr);
            sel = idxFeat(1:min(featKeep, numel(idxFeat)));
            Xtr = Xtr(:, sel);
            Xva = Xva(:, sel);
            % Build FIS
            inRanges = [min(Xtr,[],1); max(Xtr,[],1)];
            initFis = build_classdep_fis(Xtr, Ytr, classLabels, radius, inRanges);
            % Train. NOTE: the epoch is picked on the same held-out fold
            % that is scored below, so CV kappas are slightly optimistic
            % (equally so for every grid point).
            trData = [Xtr double(Ytr)];
            vaData = [Xva double(Yva)];
            [~, ~, ~, bestFis, ~] = anfis(trData, makeAnfisOpts(initFis, vaData));
            % Evaluate fold
            yhat = toLabel(evalfis(bestFis, Xva));
            R = evaluate_classification(Yva, yhat, classLabels);
            kappas(k) = R.Kappa;
            rulesK(k) = numel(bestFis.Rules);
            fprintf('kappa=%.3f rules=%d\n', R.Kappa, rulesK(k));
        end
        cvScores(fi,ri) = mean(kappas);
        cvRules(fi,ri) = round(mean(rulesK));
        fprintf('\n-> mean Kappa=%.3f mean rules=%d\n', cvScores(fi,ri), cvRules(fi,ri));
    end
end
[maxPerRow, idxR] = max(cvScores, [], 2);
[bestKappa, idxF] = max(maxPerRow);
idxR = idxR(idxF);
bestFeatures = cfg.feature_grid(idxF);
bestRadius = cfg.radii_grid(idxR);
bestRulesEst = cvRules(idxF, idxR);
fprintf('\nBEST HYPERPARAMS\nfeatures=%d radius=%.2f CV Kappa=%.3f mean rules=%d\n', ...
    bestFeatures, bestRadius, bestKappa, bestRulesEst);

% FINAL TRAIN
% Fit on the training split and early-stop on the validation split (ANFIS
% returns the FIS with minimum validation error). The test split is not
% touched here: using it as ValidationData would select the reported model
% by its own test error.
fprintf('\nTraining final model on train (early stopping on val) with best params ...\n');
[idxAll, weightsAll] = relief_select(trainX, trainY);
sel = idxAll(1:min(bestFeatures, numel(idxAll)));
Xtr = trainX(:, sel);
Xva = valX(:, sel);
Xte = testX(:, sel);
inRanges = [min(Xtr,[],1); max(Xtr,[],1)];
initFis = build_classdep_fis(Xtr, trainY, classLabels, bestRadius, inRanges);
trData = [Xtr double(trainY)];
vaData = [Xva double(valY)];
[~, trError, ~, bestFis, vaError] = anfis(trData, makeAnfisOpts(initFis, vaData));
fprintf('Final training complete: %d rules.\n', numel(bestFis.Rules));

% TEST EVAL
fprintf('\nEvaluating on TEST set ...\n');
yhat_test = toLabel(evalfis(bestFis, Xte));
Rtest = evaluate_classification(testY, yhat_test, classLabels);
fprintf('\n[TEST RESULTS]\n');
fprintf(' OA = %.2f %%\n', 100*Rtest.OA);
fprintf(' Kappa= %.3f\n', Rtest.Kappa);
fprintf(' Rules= %d\n', numel(bestFis.Rules));

% PLOTTING
fprintf('\nGenerating figures ...\n');
results = struct();
results.cvScores = cvScores;
results.cvRules = cvRules;
results.fGrid = cfg.feature_grid;
results.rGrid = cfg.radii_grid;
results.bestF = numel(sel);
results.bestR = bestRadius;
results.bestFis = bestFis;
results.initFis = initFis;
results.trError = trError;
results.vaError = vaError;        % validation-split error curve
results.ytrue = testY;
results.yhat = yhat_test;
results.metrics = Rtest;
results.selIdx = sel;
results.reliefW = weightsAll;
plot_results2(results, cfg, classLabels);
save('results_scn2.mat','results','cfg','classLabels','mu','sigma');
fprintf('\nDone. Figures saved in: %s\n', cfg.outDir);
% LOCAL FUNCTIONS
% ==================================================
function fis = build_classdep_fis(X, Y, classLabels, radius, inRanges)
% BUILD_CLASSDEP_FIS — class-dependent SC Sugeno FIS (ANFIS-ready)
%
%   fis = build_classdep_fis(X, Y, classLabels, radius, inRanges)
%
%   X           : N×D feature matrix (features only, no target column)
%   Y           : N×1 class labels
%   classLabels : labels to build rules for (one rule group per label)
%   radius      : subtractive-clustering range of influence (subclust)
%   inRanges    : 2×D matrix, row 1 = input minima, row 2 = input maxima
%   fis         : sugfis with inputs 'x1'..'xD' and one output 'y'
%
%   For every class, subclust is run on that class' samples (features
%   only). Each cluster centre becomes one rule: one Gaussian MF per input
%   (centre and sigma from subclust) and a constant consequent equal to the
%   class label. ANFIS requires ONE output MF PER RULE, so every rule gets
%   its own constant output MF.
%
%   Degenerate cases are repaired so the returned FIS is always valid:
%     - an input (or the output) whose min is not below its max gets its
%       range widened by ±0.5,
%     - non-positive or non-finite cluster sigmas (e.g. a class with a
%       single sample, or a feature constant within a class) are replaced
%       by half the input span, the value also used when subclust's sigma
%       cannot be broadcast to one value per rule and input.
D = size(X,2);
fis = sugfis('Name','TSK_CD');

% Inputs — widen degenerate ranges so that min < max
flat = ~(inRanges(2,:) > inRanges(1,:));
inRanges(1,flat) = inRanges(1,flat) - 0.5;
inRanges(2,flat) = inRanges(2,flat) + 0.5;
span = inRanges(2,:) - inRanges(1,:);
for d = 1:D
    fis = addInput(fis, [inRanges(1,d) inRanges(2,d)], 'Name', sprintf('x%d', d));
end

% Output (range spans label space; widened when there is a single label)
outRange = [min(classLabels) max(classLabels)];
if ~(outRange(2) > outRange(1))
    outRange = outRange + [-0.5 0.5];
end
fis = addOutput(fis, outRange, 'Name', 'y');

ruleList = [];
for k = 1:numel(classLabels)
    c = classLabels(k);
    Xi = X(Y==c, :);
    if isempty(Xi), continue; end
    [centers, sigmas] = subclust(Xi, radius);
    nCl = size(centers,1);

    % Robust sigma broadcasting to nCl×D
    fallback = repmat(0.5*span, nCl, 1);
    if isscalar(sigmas)
        S = repmat(sigmas, nCl, D);
    elseif isequal(size(sigmas), [1 D])
        S = repmat(sigmas, nCl, 1);
    elseif isequal(size(sigmas), [nCl D])
        S = sigmas;
    else
        S = fallback;
    end
    % gaussmf needs a strictly positive, finite sigma
    bad = ~(S > 0) | ~isfinite(S);
    S(bad) = fallback(bad);

    for i = 1:nCl
        antIdx = zeros(1,D);
        for d = 1:D
            mfName = sprintf('c%d_r%d_x%d', c, i, d);
            params = [S(i,d) centers(i,d)];   % gaussmf params: [sigma center]
            fis = addMF(fis, sprintf('x%d', d), 'gaussmf', params, 'Name', mfName);
            antIdx(d) = numel(fis.Inputs(d).MembershipFunctions);
        end
        % ONE constant output MF per rule
        outName = sprintf('const_c%d_r%d', c, i);
        fis = addMF(fis, 'y', 'constant', double(c), 'Name', outName);
        outIdx = numel(fis.Outputs(1).MembershipFunctions);
        % rule row: [antecedent MF indices, consequent MF index, weight 1, AND]
        ruleList = [ruleList; [antIdx, outIdx, 1, 1]]; %#ok<AGROW>
    end
end
if ~isempty(ruleList)
    fis = addRule(fis, ruleList);
end
% Standard TSK ops
fis.AndMethod = 'prod';
fis.OrMethod = 'probor';
fis.ImplicationMethod = 'prod';
fis.AggregationMethod = 'sum';
fis.DefuzzificationMethod = 'wtaver';
end
function [idx, w] = relief_select(X, y, k)
% RELIEF_SELECT — rank features with ReliefF in classification mode.
%
%   [idx, w] = relief_select(X, y)     uses k = 10 nearest neighbours
%   [idx, w] = relief_select(X, y, k)  uses k nearest neighbours
%
%   X   : N×D predictor matrix
%   y   : N×1 class labels
%   k   : number of nearest neighbours for ReliefF (default 10)
%   idx : feature indices ordered from most to least important
%   w   : per-feature weights, in the column order of X
%
%   relieff treats a numeric y as a regression target (RReliefF) unless
%   'method','classification' is given. The method is therefore set
%   explicitly, because the labels here are numeric class codes.
%
%   If relieff fails (e.g. Statistics and Machine Learning Toolbox not
%   available), a warning is issued and features are ranked by variance.
%   NOTE: on z-scored data every variance is ~1, so that fallback ranking
%   is close to arbitrary.
if nargin < 3 || isempty(k)
    k = 10;
end
try
    [idx, w] = relieff(X, y, k, 'method', 'classification');
catch err
    warning('relief_select:fallback', ...
        'relieff failed (%s); falling back to variance ranking.', err.message);
    w = var(X, 0, 1);
    [~, idx] = sort(w, 'descend');
end
end