107 lines
5.0 KiB
Matlab
107 lines
5.0 KiB
Matlab
function plot_results2(init_fis, final_fis, trn_error, val_error, y_true, y_pred, ...
    feature_grid, radius_grid, cv_scores, cv_rules, sel_idx, outdir)
% PLOT_RESULTS2  Plots for Scenario 2 (grid search over #features x cluster radius).
%
%   plot_results2(init_fis, final_fis, trn_error, val_error, y_true, y_pred, ...
%                 feature_grid, radius_grid, cv_scores, cv_rules, sel_idx)
%   plot_results2(..., sel_idx, outdir)
%
%   Produces and saves (as PNG) the following figures:
%     (1) CV heatmap (mean CV error over the grid) with the mean #rules overlaid
%     (2) Error vs #rules (aggregated) and best error vs #features (min over radii)
%     (3) Final model: learning curves, predicted vs actual, prediction-error series
%     (4) Input MFs BEFORE/AFTER training for a subset (up to 3) of the inputs
%
%   Inputs:
%     init_fis, final_fis - FIS before/after training; only their input MFs are
%                           read (via plotmf) and init_fis.Inputs caps the MF subset.
%     trn_error           - training error per epoch (plotted as a curve).
%     val_error           - validation error per epoch; may be empty.
%     y_true, y_pred      - targets and predictions; flattened to column vectors,
%                           so they must have the same number of elements.
%     feature_grid        - grid values indexing the ROWS of cv_scores / cv_rules.
%     radius_grid         - grid values indexing the COLUMNS of cv_scores / cv_rules.
%     cv_scores           - mean CV error, numel(feature_grid) x numel(radius_grid).
%     cv_rules            - mean #rules per grid cell, same size as cv_scores.
%     sel_idx             - original indices of the selected features; sel_idx(k)
%                           is used as the label of FIS input k.
%     outdir              - (optional) output directory for the PNG files;
%                           defaults to ./figures_scn2. Created if missing.
%
%   Figures are left open after saving.

if nargin < 12 || isempty(outdir)
    outdir = fullfile('.', 'figures_scn2');
end
if ~exist(outdir, 'dir')
    mkdir(outdir);
end

% (1) + (2): grid-search summaries
local_plot_cv_heatmap(outdir, feature_grid, radius_grid, cv_scores, cv_rules);
local_plot_error_vs_rules(outdir, cv_scores, cv_rules);
local_plot_error_vs_features(outdir, feature_grid, cv_scores);

% (3): final model diagnostics
local_plot_learning_curves(outdir, trn_error, val_error);
y_true = y_true(:);
y_pred = y_pred(:);
local_plot_pred_vs_actual(outdir, y_true, y_pred);
local_plot_error_series(outdir, y_true, y_pred);

% (4): membership functions before/after training
local_plot_mfs_subset(outdir, init_fis, final_fis, sel_idx);
end


function local_plot_cv_heatmap(outdir, feature_grid, radius_grid, cv_scores, cv_rules)
% (1) Heatmap of mean CV error with each cell's mean #rules printed on top.
% Cells are drawn on index coordinates (1..n) and the ticks are labelled with the
% grid values: imagesc only uses the first/last x/y values, so drawing on the raw
% grid values misplaces the cells (and the overlaid numbers) for non-uniform grids.
nF = numel(feature_grid);
nR = numel(radius_grid);

f = figure('Color','w','Name','CV Heatmap');
imagesc(1:nR, 1:nF, cv_scores);
set(gca, 'YDir','normal', ...
    'XTick', 1:nR, 'XTickLabel', local_num_labels(radius_grid), ...
    'YTick', 1:nF, 'YTickLabel', local_num_labels(feature_grid));
% '#' must be escaped and '_'/'\alpha' need math mode under the LaTeX interpreter.
xlabel('Cluster radius $r_\alpha$','Interpreter','latex');
ylabel('\#Features','Interpreter','latex');
title('Mean CV Error','Interpreter','latex');
colorbar; grid on;

hold on;
for i = 1:nF
    for j = 1:nR
        text(j, i, local_fmt_rules(cv_rules(i,j)), ...
            'HorizontalAlignment','center','Color','w','FontSize',8);
    end
end
hold off;

try
    subtitle('Numbers show mean \#rules','Interpreter','latex');
catch
    % subtitle() does not exist in older MATLAB releases; the annotation is optional.
end
print(f, fullfile(outdir,'scn2_cv_heatmap'), '-dpng');
end


function local_plot_error_vs_rules(outdir, cv_scores, cv_rules)
% (2a) Mean CV error aggregated over all grid cells that share the same #rules.
% Cells whose rule count or score is NaN are skipped: NaN never compares equal,
% so keeping them would only add spurious NaN points to the curve.
valid = ~isnan(cv_rules) & ~isnan(cv_scores);
uniq_rules = unique(cv_rules(valid));
err_vs_rules = arrayfun(@(r) mean(cv_scores(valid & cv_rules == r)), uniq_rules);

f = figure('Color','w','Name','Error vs Rules');
plot(uniq_rules, err_vs_rules, 'o-','LineWidth',1.5); grid on;
xlabel('\#Rules','Interpreter','latex'); ylabel('Mean CV Error','Interpreter','latex');
title('Error vs Rules','Interpreter','latex');
print(f, fullfile(outdir,'scn2_error_vs_rules'), '-dpng');
end


function local_plot_error_vs_features(outdir, feature_grid, cv_scores)
% (2b) Best (minimum over radii, i.e. over columns) CV error for each #features.
min_err_per_feat = min(cv_scores, [], 2);

f = figure('Color','w','Name','Error vs Features');
plot(feature_grid, min_err_per_feat, 's-','LineWidth',1.5); grid on;
xlabel('\#Features','Interpreter','latex'); ylabel('Best CV Error','Interpreter','latex');
title('Best CV Error vs \#Features','Interpreter','latex');
print(f, fullfile(outdir,'scn2_error_vs_features'), '-dpng');
end


function local_plot_learning_curves(outdir, trn_error, val_error)
% (3a) Training (and, if provided, validation) error per epoch.
f = figure('Color','w','Name','Learning Curves');
plot(trn_error,'LineWidth',1.5); hold on; grid on;
if ~isempty(val_error)
    plot(val_error,'LineWidth',1.5);
    legend('Train','Validation','Location','best');
else
    legend('Train','Location','best');
end
hold off;
xlabel('Epoch','Interpreter','latex'); ylabel('Error','Interpreter','latex');
title('Learning Curves','Interpreter','latex');
print(f, fullfile(outdir,'scn2_final_learning_curves'), '-dpng');
end


function local_plot_pred_vs_actual(outdir, y_true, y_pred)
% (3b) Scatter of predictions against targets with the y = x identity line.
f = figure('Color','w','Name','Predicted vs Actual');
plot(y_true, y_pred, '.', 'MarkerSize', 10); hold on; grid on;
lo = min([y_true; y_pred]);
hi = max([y_true; y_pred]);
plot([lo hi], [lo hi], 'k-', 'LineWidth', 1);
hold off;
xlabel('Actual','Interpreter','latex'); ylabel('Predicted','Interpreter','latex');
title('Predicted vs Actual','Interpreter','latex');
print(f, fullfile(outdir,'scn2_final_pred_vs_actual'), '-dpng');
end


function local_plot_error_series(outdir, y_true, y_pred)
% (3c) Prediction error (y_true - y_pred) per sample, annotated with its MAE.
err = y_true - y_pred;

f = figure('Color','w','Name','Prediction Error');
plot(err, 'k'); grid on;
xlabel('Sample','Interpreter','latex'); ylabel('Error','Interpreter','latex');
title('Prediction Error','Interpreter','latex');
mae = mean(abs(err));
try
    subtitle(sprintf('Mean absolute error: %.6f', mae), 'Interpreter','latex');
catch
    % subtitle() does not exist in older MATLAB releases; the annotation is optional.
end
print(f, fullfile(outdir,'scn2_final_error_series'), '-dpng');
end


function local_plot_mfs_subset(outdir, init_fis, final_fis, sel_idx)
% (4) Input MFs for up to 3 inputs: top row BEFORE training, bottom row AFTER.
% NOTE(review): FIS input k is labelled x_{sel_idx(k)}, which assumes the FIS inputs
% are ordered like sel_idx -- confirm against the code that builds the FIS.
nShow = min([3, numel(sel_idx), numel(init_fis.Inputs)]);
if nShow == 0
    return;
end

f = figure('Color','w','Name','MFs (subset)');
tl = tiledlayout(2, nShow, 'TileSpacing','compact', 'Padding','compact');
for k = 1:nShow
    % Subscripts need math mode under the LaTeX interpreter.
    lbl = sprintf('$x_{%d}$', sel_idx(k));
    [xb, yb] = plotmf(init_fis, 'input', k);
    [xa, ya] = plotmf(final_fis, 'input', k);

    nexttile(tl, k); hold on; grid on;
    plot(xb, yb, 'LineWidth', 1.2);
    title(['BEFORE ' lbl], 'Interpreter','latex');
    xlabel(lbl, 'Interpreter','latex'); ylabel('Membership','Interpreter','latex');

    nexttile(tl, nShow + k); hold on; grid on;
    plot(xa, ya, 'LineWidth', 1.2);
    title(['AFTER ' lbl], 'Interpreter','latex');
    xlabel(lbl, 'Interpreter','latex'); ylabel('Membership','Interpreter','latex');
end
% sgtitle() targets figures/panels; a tiledlayout gets its shared title via title(tl, ...).
title(tl, 'Membership Functions (subset)', 'Interpreter','latex');
print(f, fullfile(outdir,'scn2_final_mfs_subset'), '-dpng');
end


function labels = local_num_labels(values)
% Format a numeric vector as a cell array of compact tick-label strings.
labels = arrayfun(@(v) sprintf('%g', v), values(:), 'UniformOutput', false);
end


function s = local_fmt_rules(v)
% Format a (possibly non-integer) mean rule count for a heatmap cell.
% sprintf('%d', x) switches to exponential notation for non-integers
% (e.g. 3.5 -> '3.500000e+00'), so non-integers are shown with one decimal.
if isfinite(v) && v ~= round(v)
    s = sprintf('%.1f', v);
else
    s = sprintf('%d', v);
end
end