function plot_results2(init_fis, final_fis, trn_error, val_error, y_true, y_pred, ...
    feature_grid, radius_grid, cv_scores, cv_rules, sel_idx)
% PLOT_RESULTS2  Plots for Scenario 2 (grid search over #features x radius,
% plus diagnostics of the final trained model).
%   (1) CV heatmap (mean CV error over the grid) with the mean #rules overlaid
%   (2) Error vs #rules (aggregated) and best CV error vs #features (min over radii)
%   (3) Final model: learning curves, predicted vs actual, prediction-error series
%   (4) Membership functions of up to 3 inputs, BEFORE/AFTER training
%
%   Inputs:
%     init_fis, final_fis  - FIS objects before/after training; their .Inputs
%                            are counted and they are passed to plotmf
%     trn_error, val_error - per-epoch error vectors; val_error may be empty
%     y_true, y_pred       - actual / predicted outputs (reshaped to columns)
%     feature_grid         - candidate #features (one per row of cv_scores)
%     radius_grid          - candidate cluster radii (one per column of cv_scores)
%     cv_scores            - numel(feature_grid)-by-numel(radius_grid) mean CV error
%     cv_rules             - same size as cv_scores; mean #rules per grid cell
%     sel_idx              - original feature indices of the selected inputs;
%                            assumed to be in the same order as the FIS inputs
%                            (NOTE(review): confirm against the caller)
%
%   All figures are saved as PNG files under ./figures_scn2 (created if needed).

outdir = fullfile('.', 'figures_scn2');
if ~exist(outdir, 'dir'), mkdir(outdir); end

nF = numel(feature_grid);
nR = numel(radius_grid);

% (1) CV heatmap.
% Plot in index coordinates and label the ticks with the grid values:
% imagesc(x, y, C) only uses the first/last x and y to place the image, so a
% non-uniformly spaced grid would misalign cells, ticks and the text overlay.
f = figure('Color', 'w', 'Name', 'CV Heatmap');
imagesc(1:nR, 1:nF, cv_scores);
set(gca, 'YDir', 'normal', ...
    'XTick', 1:nR, 'XTickLabel', compose('%g', radius_grid(:)), ...
    'YTick', 1:nF, 'YTickLabel', compose('%g', feature_grid(:)));
xlabel('Cluster radius $r_\alpha$', 'Interpreter', 'latex');
ylabel('\#Features', 'Interpreter', 'latex');
title('Mean CV Error', 'Interpreter', 'latex');
colorbar; grid on;
% Overlay the mean #rules; they are averages and may be non-integer, so
% round to one decimal instead of printing with %d.
hold on;
for i = 1:nF
    for j = 1:nR
        text(j, i, num2str(round(cv_rules(i,j), 1)), ...
            'HorizontalAlignment', 'center', 'Color', 'w', 'FontSize', 8);
    end
end
try
    % subtitle is not available in older MATLAB releases; best effort.
    subtitle('Numbers show mean \#rules', 'Interpreter', 'latex');
end
print(f, fullfile(outdir, 'scn2_cv_heatmap'), '-dpng');

% (2a) Error vs #rules (aggregate cells with identical rule counts).
% NaN cells are excluded: NaN never compares equal, so each would otherwise
% form its own group with a NaN mean.
valid = ~isnan(cv_rules) & ~isnan(cv_scores);
uniq_rules = unique(cv_rules(valid));
err_vs_rules = zeros(size(uniq_rules));
for k = 1:numel(uniq_rules)
    err_vs_rules(k) = mean(cv_scores(valid & cv_rules == uniq_rules(k)));
end
f = figure('Color', 'w', 'Name', 'Error vs Rules');
plot(uniq_rules, err_vs_rules, 'o-', 'LineWidth', 1.5); grid on;
xlabel('\#Rules', 'Interpreter', 'latex');
ylabel('Mean CV Error', 'Interpreter', 'latex');
title('Error vs Rules', 'Interpreter', 'latex');
print(f, fullfile(outdir, 'scn2_error_vs_rules'), '-dpng');

% (2b) Best CV error vs #features (min across radii, i.e. along columns).
min_err_per_feat = min(cv_scores, [], 2);
f = figure('Color', 'w', 'Name', 'Error vs Features');
plot(feature_grid, min_err_per_feat, 's-', 'LineWidth', 1.5); grid on;
xlabel('\#Features', 'Interpreter', 'latex');
ylabel('Best CV Error', 'Interpreter', 'latex');
title('Best CV Error vs \#Features', 'Interpreter', 'latex');
print(f, fullfile(outdir, 'scn2_error_vs_features'), '-dpng');

% (3a) Learning curves (validation curve only if provided).
f = figure('Color', 'w', 'Name', 'Learning Curves');
plot(trn_error, 'LineWidth', 1.5); hold on; grid on;
if ~isempty(val_error)
    plot(val_error, 'LineWidth', 1.5);
    legend('Train', 'Validation', 'Location', 'best');
else
    legend('Train', 'Location', 'best');
end
xlabel('Epoch', 'Interpreter', 'latex');
ylabel('Error', 'Interpreter', 'latex');
title('Learning Curves', 'Interpreter', 'latex');
print(f, fullfile(outdir, 'scn2_final_learning_curves'), '-dpng');

% (3b) Predicted vs Actual, with the y = x reference line.
y_true = y_true(:);
y_pred = y_pred(:);
f = figure('Color', 'w', 'Name', 'Predicted vs Actual');
plot(y_true, y_pred, '.', 'MarkerSize', 10); hold on; grid on;
mins = min([y_true; y_pred]);
maxs = max([y_true; y_pred]);
plot([mins maxs], [mins maxs], 'k-', 'LineWidth', 1);
xlabel('Actual', 'Interpreter', 'latex');
ylabel('Predicted', 'Interpreter', 'latex');
title('Predicted vs Actual', 'Interpreter', 'latex');
print(f, fullfile(outdir, 'scn2_final_pred_vs_actual'), '-dpng');

% (3c) Prediction error (y_true - y_pred) as a series over samples.
err = y_true - y_pred;
f = figure('Color', 'w', 'Name', 'Prediction Error');
plot(err, 'k'); grid on;
xlabel('Sample', 'Interpreter', 'latex');
ylabel('Error', 'Interpreter', 'latex');
title('Prediction Error', 'Interpreter', 'latex');
mae = mean(abs(err));
try
    subtitle(sprintf('Mean absolute error: %.6f', mae), 'Interpreter', 'latex');
end
print(f, fullfile(outdir, 'scn2_final_error_series'), '-dpng');

% (4) MFs of a subset of the selected inputs, BEFORE (top row) / AFTER
% (bottom row). At most 3 inputs are shown for clarity.
nShow = min([3, numel(sel_idx), numel(init_fis.Inputs)]);
if nShow > 0
    f = figure('Color', 'w', 'Name', 'MFs (subset)');
    tl = tiledlayout(2, nShow, 'TileSpacing', 'compact', 'Padding', 'compact');
    for k = 1:nShow
        % FIS input k is labelled with original feature index sel_idx(k).
        inIdx = k;
        [xb, yb] = plotmf(init_fis, 'input', inIdx);
        [xa, ya] = plotmf(final_fis, 'input', inIdx);

        nexttile(tl, k); hold on; grid on;
        plot(xb, yb, 'LineWidth', 1.2);
        title(sprintf('BEFORE $x_{%d}$', sel_idx(k)), 'Interpreter', 'latex');
        xlabel(sprintf('$x_{%d}$', sel_idx(k)), 'Interpreter', 'latex');
        ylabel('Membership', 'Interpreter', 'latex');

        nexttile(tl, nShow + k); hold on; grid on;
        plot(xa, ya, 'LineWidth', 1.2);
        title(sprintf('AFTER $x_{%d}$', sel_idx(k)), 'Interpreter', 'latex');
        xlabel(sprintf('$x_{%d}$', sel_idx(k)), 'Interpreter', 'latex');
        ylabel('Membership', 'Interpreter', 'latex');
    end
    % Layout-level title: title(tl, ...) is the documented form for a
    % TiledChartLayout (sgtitle documents Figure/Panel/Tab targets).
    title(tl, 'Membership Functions (subset)', 'Interpreter', 'latex');
    print(f, fullfile(outdir, 'scn2_final_mfs_subset'), '-dpng');
end
end