40 lines
1.3 KiB
Matlab
40 lines
1.3 KiB
Matlab
function [trainX, valX, testX, trainY, valY, testY] = split_data(X, Y, ratios)
% SPLIT_DATA Split a dataset into train/validation/test sets (stratified).
%
%   [trainX, valX, testX, trainY, valY, testY] = split_data(X, Y, ratios)
%
%   X      : data, one observation per row (rows are selected with X(idx,:)).
%   Y      : labels, one per row of X. Any label type accepted by UNIQUE
%            works (numeric, logical, cellstr, string, categorical); row and
%            column vectors are both accepted.
%   ratios : [trainRatio, valRatio, testRatio], must sum to 1
%            (default [0.6 0.2 0.2]). The test share is whatever remains
%            after the train and validation counts are taken.
%
%   The observations of each class are shuffled independently and divided
%   according to ratios, so class proportions stay roughly consistent
%   across the three subsets. Per class, the train and validation counts
%   are rounded (and clamped so they never exceed the class size); the test
%   subset receives the remainder. Each subset is then shuffled again so the
%   classes are interleaved. Randomness comes from RANDPERM, i.e. the global
%   random stream; call rng(...) beforehand for reproducible splits.
%
%   Label outputs are taken as Y(idx), so they keep Y's orientation.

if nargin < 3 || isempty(ratios)
    ratios = [0.6 0.2 0.2];
end
assert(numel(ratios) >= 2, 'Ratios must have at least two elements.');
assert(all(ratios >= 0), 'Ratios must be non-negative.');
assert(abs(sum(ratios) - 1) < 1e-6, 'Ratios must sum to 1.');

n = size(X, 1);
assert(numel(Y) == n, 'X and Y must have the same number of observations.');

% Map every label to an integer group id. Working on Y(:) makes the result
% independent of Y's orientation, and comparing integer ids (instead of
% Y == c) supports non-numeric labels such as cell arrays of strings.
[classes, ~, groupId] = unique(Y(:));

% Column accumulators so vertical concatenation below is always valid.
idxTrain = zeros(0, 1);
idxVal   = zeros(0, 1);
idxTest  = zeros(0, 1);

for k = 1:numel(classes)
    idx = find(groupId == k);               % column vector of row indices
    idx = idx(randperm(numel(idx)));        % randomize within class
    m = numel(idx);

    nTrain = min(round(ratios(1) * m), m);
    % Clamp: rounding both counts up can otherwise exceed the class size
    % (e.g. 3 samples with [0.5 0.5 0] would request 2 + 2 indices).
    nVal = min(round(ratios(2) * m), m - nTrain);

    idxTrain = [idxTrain; idx(1:nTrain)];               %#ok<AGROW>
    idxVal   = [idxVal;   idx(nTrain+1:nTrain+nVal)];   %#ok<AGROW>
    idxTest  = [idxTest;  idx(nTrain+nVal+1:end)];      %#ok<AGROW>
end

% Shuffle within each subset to mix classes
idxTrain = idxTrain(randperm(numel(idxTrain)));
idxVal   = idxVal(randperm(numel(idxVal)));
idxTest  = idxTest(randperm(numel(idxTest)));

trainX = X(idxTrain, :); trainY = Y(idxTrain);
valX   = X(idxVal, :);   valY   = Y(idxVal);
testX  = X(idxTest, :);  testY  = Y(idxTest);

end
|