FuzzySystems/Work 4/source/split_data.m

40 lines
1.3 KiB
Matlab

function [trainX, valX, testX, trainY, valY, testY] = split_data(X, Y, ratios)
% SPLIT_DATA Split dataset into train/validation/test sets (stratified)
%
%   [trainX, valX, testX, trainY, valY, testY] = split_data(X, Y, ratios)
%
%   Inputs:
%     X      : N-by-D matrix, one sample per row.
%     Y      : vector of N class labels (row or column). Labels must
%              support UNIQUE and == comparison against a scalar label.
%     ratios : [trainRatio, valRatio, testRatio] (e.g. [0.6 0.2 0.2]).
%              Must be non-negative and sum to 1. Only the first two
%              entries are used to size the subsets; the test set
%              receives whatever remains of each class.
%              Default: [0.6 0.2 0.2].
%
%   Outputs:
%     trainX, valX, testX : rows of X assigned to each subset.
%     trainY, valY, testY : the corresponding labels, indexed from Y.
%
%   The split is stratified: each class is partitioned separately with
%   the same ratios, so class proportions stay approximately the same in
%   every subset (up to rounding). Assignment is random; seed the global
%   generator with RNG beforehand for reproducible splits.

    if nargin < 3
        ratios = [0.6 0.2 0.2];
    end
    assert(numel(ratios) >= 2 && all(ratios >= 0), ...
        'Ratios must contain at least train and validation ratios, all non-negative.');
    assert(abs(sum(ratios) - 1) < 1e-6, 'Ratios must sum to 1.');
    assert(size(X,1) == numel(Y), ...
        'X must have one row per label in Y.');

    % Work on a column copy of the labels so that row-vector Y behaves the
    % same as column-vector Y (find returns column indices either way).
    labels  = Y(:);
    classes = unique(labels);

    idxTrain = zeros(0,1);
    idxVal   = zeros(0,1);
    idxTest  = zeros(0,1);

    % Index-based loop: "for c = classes'" iterates over columns and would
    % silently run once with a vector c if classes were not a column.
    for k = 1:numel(classes)
        idx = find(labels == classes(k));
        idx = idx(randperm(length(idx)));   % randomize within class
        nClass = length(idx);

        % round() is applied to each ratio independently, so the two counts
        % can add up to more than nClass (e.g. 3 samples, ratios 0.5/0.5
        % -> 2 + 2). Clamp to the samples still available to avoid
        % indexing past the end of idx.
        nTrain = min(round(ratios(1) * nClass), nClass);
        nVal   = min(round(ratios(2) * nClass), nClass - nTrain);

        idxTrain = [idxTrain; idx(1:nTrain)];                  %#ok<AGROW>
        idxVal   = [idxVal;   idx(nTrain+1:nTrain+nVal)];      %#ok<AGROW>
        idxTest  = [idxTest;  idx(nTrain+nVal+1:end)];         %#ok<AGROW>
    end

    % Shuffle within each subset to mix classes
    idxTrain = idxTrain(randperm(length(idxTrain)));
    idxVal   = idxVal(randperm(length(idxVal)));
    idxTest  = idxTest(randperm(length(idxTest)));

    trainX = X(idxTrain,:);  trainY = Y(idxTrain);
    valX   = X(idxVal,:);    valY   = Y(idxVal);
    testX  = X(idxTest,:);   testY  = Y(idxTest);
end