FuzzySystems/Work 3/source/plot_results2.m

107 lines
5.0 KiB
Matlab

function plot_results2(init_fis, final_fis, trn_error, val_error, y_true, y_pred, ...
    feature_grid, radius_grid, cv_scores, cv_rules, sel_idx)
% PLOT_RESULTS2  Plots for scenario 2 (grid over #features x cluster radius).
%   (1) CV heatmap (mean CV error over grid) with overlaid mean #rules
%   (2) Error vs rules (aggregated) & Error vs #features (best over radii)
%   (3) Final model: Learning curves, Predicted vs Actual, Residuals
%   (4) MFs (subset of selected inputs) BEFORE/AFTER
%
% Inputs:
%   init_fis, final_fis - FIS before/after training; must expose .Inputs
%                         and be accepted by plotmf.
%   trn_error           - per-epoch training error curve
%   val_error           - per-epoch validation error curve (may be empty)
%   y_true, y_pred      - targets and predictions (flattened to columns)
%   feature_grid        - #features values; indexes the ROWS of cv_scores/cv_rules
%   radius_grid         - cluster radii; indexes the COLUMNS of cv_scores/cv_rules
%   cv_scores           - mean CV error per (feature, radius) grid cell
%   cv_rules            - mean #rules per grid cell (need not be integer)
%   sel_idx             - indices of the selected features; used only for
%                         axis labels (presumably original column numbers)
%
% All figures are saved as PNG under ./figures_scn2 (created if missing).
outdir = fullfile('.', 'figures_scn2');
if ~exist(outdir, 'dir'), mkdir(outdir); end

nF = numel(feature_grid);
nR = numel(radius_grid);

%% (1) CV heatmap
% Draw in index coordinates and label the ticks with the grid values.
% imagesc only uses the first/last coordinate values to place cells, so
% drawing directly against a non-uniform grid would put the overlaid
% rule-count text off its cells.
f = figure('Color','w','Name','CV Heatmap');
imagesc(1:nR, 1:nF, cv_scores);
ax = gca;
ax.YDir = 'normal';
ax.XTick = 1:nR;
ax.XTickLabel = arrayfun(@(v) sprintf('%g', v), radius_grid, 'UniformOutput', false);
ax.YTick = 1:nF;
ax.YTickLabel = arrayfun(@(v) sprintf('%g', v), feature_grid, 'UniformOutput', false);
% LaTeX needs math mode for subscripts/Greek letters and an escaped '#'.
xlabel('Cluster radius $r_\alpha$','Interpreter','latex');
ylabel('\#Features','Interpreter','latex');
title('Mean CV Error','Interpreter','latex'); colorbar; grid on;
% overlay mean rules
hold on;
for i = 1:nF
    for j = 1:nR
        % Mean rule counts may be non-integer; '%d' would then fall back to
        % exponential notation, so round to one decimal and print with '%g'.
        text(j, i, sprintf('%g', round(cv_rules(i,j), 1)), ...
            'HorizontalAlignment','center','Color','w','FontSize',8);
    end
end
try
    % subtitle() does not exist on older releases; it is optional decoration.
    subtitle('Numbers show mean \#rules','Interpreter','latex');
catch
end
print(f, fullfile(outdir,'scn2_cv_heatmap'), '-dpng');

%% (2a) Error vs #rules (aggregate by identical rule-counts)
% NaN rule counts are skipped: unique() returns each NaN separately and
% NaN == NaN is false, so they would only produce empty, NaN-valued groups.
uniq_rules = unique(cv_rules(~isnan(cv_rules)));
err_vs_rules = zeros(size(uniq_rules));
for k = 1:numel(uniq_rules)
    % 'omitnan' keeps this consistent with (2b), where min() ignores NaN.
    err_vs_rules(k) = mean(cv_scores(cv_rules == uniq_rules(k)), 'omitnan');
end
f = figure('Color','w','Name','Error vs Rules');
plot(uniq_rules, err_vs_rules, 'o-','LineWidth',1.5); grid on;
xlabel('\#Rules','Interpreter','latex'); ylabel('Mean CV Error','Interpreter','latex');
title('Error vs Rules','Interpreter','latex');
print(f, fullfile(outdir,'scn2_error_vs_rules'), '-dpng');

%% (2b) Best CV error vs #features (min across radii)
min_err_per_feat = min(cv_scores, [], 2);
f = figure('Color','w','Name','Error vs Features');
plot(feature_grid, min_err_per_feat, 's-','LineWidth',1.5); grid on;
xlabel('\#Features','Interpreter','latex'); ylabel('Best CV Error','Interpreter','latex');
title('Best CV Error vs \#Features','Interpreter','latex');
print(f, fullfile(outdir,'scn2_error_vs_features'), '-dpng');

%% (3a) Learning curves
f = figure('Color','w','Name','Learning Curves');
plot(trn_error,'LineWidth',1.5); hold on; grid on;
if ~isempty(val_error)
    plot(val_error,'LineWidth',1.5); legend('Train','Validation','Location','best');
else
    legend('Train','Location','best');
end
xlabel('Epoch','Interpreter','latex'); ylabel('Error','Interpreter','latex');
title('Learning Curves','Interpreter','latex');
print(f, fullfile(outdir,'scn2_final_learning_curves'), '-dpng');

%% (3b) Predicted vs Actual, with the y = x reference line
y_true = y_true(:); y_pred = y_pred(:);
f = figure('Color','w','Name','Predicted vs Actual');
plot(y_true, y_pred, '.', 'MarkerSize', 10); hold on; grid on;
lo = min([y_true; y_pred]); hi = max([y_true; y_pred]);
plot([lo hi], [lo hi], 'k-', 'LineWidth', 1);
xlabel('Actual','Interpreter','latex'); ylabel('Predicted','Interpreter','latex');
title('Predicted vs Actual','Interpreter','latex');
print(f, fullfile(outdir,'scn2_final_pred_vs_actual'), '-dpng');

%% (3c) Residuals (time series)
err = y_true - y_pred;
f = figure('Color','w','Name','Prediction Error');
plot(err, 'k'); grid on;
xlabel('Sample','Interpreter','latex'); ylabel('Error','Interpreter','latex');
title('Prediction Error','Interpreter','latex');
mae = mean(abs(err));
try
    % subtitle() does not exist on older releases; it is optional decoration.
    subtitle(sprintf('Mean absolute error: %.6f', mae), 'Interpreter','latex');
catch
end
print(f, fullfile(outdir,'scn2_final_error_series'), '-dpng');

%% (4) MFs (subset of selected inputs) BEFORE/AFTER
% Show up to 3 selected inputs for clarity
nShow = min([3, numel(sel_idx), numel(init_fis.Inputs)]);
if nShow > 0
    f = figure('Color','w','Name','MFs (subset)');
    tl = tiledlayout(2, nShow, 'TileSpacing','compact', 'Padding','compact');
    for k = 1:nShow
        % NOTE(review): assumes FIS input k corresponds to feature sel_idx(k),
        % i.e. the FIS was built on the selected columns in order - confirm
        % against the caller.
        inIdx = k;
        % With output arguments plotmf returns the curves instead of plotting.
        [xb, yb] = plotmf(init_fis, 'input', inIdx);
        [xa, ya] = plotmf(final_fis, 'input', inIdx);
        lbl = sprintf('$x_{%d}$', sel_idx(k));

        nexttile(tl, k); hold on; grid on; plot(xb, yb, 'LineWidth', 1.2);
        title(['BEFORE ' lbl], 'Interpreter','latex');
        xlabel(lbl, 'Interpreter','latex'); ylabel('Membership','Interpreter','latex');

        nexttile(tl, nShow + k); hold on; grid on; plot(xa, ya, 'LineWidth', 1.2);
        title(['AFTER ' lbl], 'Interpreter','latex');
        xlabel(lbl, 'Interpreter','latex'); ylabel('Membership','Interpreter','latex');
    end
    % sgtitle targets a figure or panel; a tiledlayout is titled via title(tl, ...).
    title(tl, 'Membership Functions (subset)', 'Interpreter','latex');
    print(f, fullfile(outdir,'scn2_final_mfs_subset'), '-dpng');
end
end