function [trainX, valX, testX, trainY, valY, testY] = split_data(X, Y, ratios)
% SPLIT_DATA Split dataset into train/validation/test sets (stratified)
%
%   [trainX, valX, testX, trainY, valY, testY] = split_data(X, Y, ratios)
%
%   X      : N-by-D matrix, one sample per row.
%   Y      : N-element label vector (row or column; numeric, logical,
%            char, cellstr, string or categorical - anything UNIQUE accepts).
%   ratios : [trainRatio, valRatio, testRatio] (e.g. [0.6 0.2 0.2]).
%            Must be non-negative and sum to 1. Defaults to [0.6 0.2 0.2].
%
%   Stratified split: each class is partitioned separately so that class
%   proportions remain (approximately) consistent across the three subsets.
%   Per class, the train and validation counts are rounded and the test
%   set receives the remainder. The label outputs keep the orientation
%   (row/column) of Y.
%
%   The split is random; seed the generator with RNG for reproducibility.

    if nargin < 3
        ratios = [0.6 0.2 0.2];
    end
    assert(isnumeric(ratios) && numel(ratios) >= 2 && all(ratios(:) >= 0), ...
        'Ratios must be a numeric vector of non-negative values.');
    assert(abs(sum(ratios) - 1) < 1e-6, 'Ratios must sum to 1.');

    n = size(X, 1);
    assert(numel(Y) == n, 'X must have one row per label in Y.');

    % Map every label to an integer class id. This avoids "Y == c", which
    % fails for cellstr labels and misbehaves when Y is a row vector.
    % Note: UNIQUE treats each NaN as its own class.
    [~, ~, classId] = unique(Y);
    classId = classId(:);
    numClasses = max([classId; 0]);   % 0 when Y is empty

    idxTrain = zeros(0, 1);
    idxVal   = zeros(0, 1);
    idxTest  = zeros(0, 1);

    for k = 1:numClasses
        idx = find(classId == k);
        idx = idx(randperm(numel(idx)));   % randomize within class
        idx = idx(:);
        m = numel(idx);

        % Independent rounding can make nTrain + nVal exceed m (e.g. m = 1
        % with ratios [0.5 0.5 0]); clamp so indexing never runs past the end.
        nTrain = min(round(ratios(1) * m), m);
        nVal   = min(round(ratios(2) * m), m - nTrain);

        idxTrain = [idxTrain; idx(1:nTrain)];                 %#ok<AGROW>
        idxVal   = [idxVal;   idx(nTrain+1:nTrain+nVal)];     %#ok<AGROW>
        idxTest  = [idxTest;  idx(nTrain+nVal+1:end)];        %#ok<AGROW>
    end

    % Shuffle within each subset to mix classes
    idxTrain = idxTrain(randperm(numel(idxTrain)));
    idxVal   = idxVal(randperm(numel(idxVal)));
    idxTest  = idxTest(randperm(numel(idxTest)));

    trainX = X(idxTrain, :);  trainY = Y(idxTrain);
    valX   = X(idxVal, :);    valY   = Y(idxVal);
    testX  = X(idxTest, :);   testY  = Y(idxTest);
end