MATLAB实战:机器学习分类回归示例

06-01 1866阅读

以下是一个使用MATLAB的Statistics and Machine Learning Toolbox实现分类和回归任务的完整示例代码。代码包含鸢尾花分类、手写数字分类和汽车数据回归任务,并评估模型性能。

%% 加载内置数据集

% 鸢尾花数据集(分类)

load fisheriris;

X_iris = meas;      % 150x4 特征矩阵

Y_iris = species;   % 150x1 类别标签

% 手写数字数据集(分类)

digitDatasetPath = fullfile(matlabroot, 'toolbox', 'nnet', 'nndemos', ...

    'nndatasets', 'DigitDataset');

imds = imageDatastore(digitDatasetPath, ...

    'IncludeSubfolders', true, 'LabelSource', 'foldernames');

[trainImgs, testImgs] = splitEachLabel(imds, 0.7, 'randomized');

% 提取HOG特征

numTrain = numel(trainImgs.Files);

hogFeatures = zeros(numTrain, 324);  % HOG特征维度

MATLAB实战:机器学习分类回归示例
(图片来源网络,侵删)

for i = 1:numTrain

    img = readimage(trainImgs, i);

MATLAB实战:机器学习分类回归示例
(图片来源网络,侵删)

    hogFeatures(i, :) = extractHOGFeatures(img);

end

MATLAB实战:机器学习分类回归示例
(图片来源网络,侵删)

trainLabels = trainImgs.Labels;

% 汽车数据集(回归)

load carsmall;

X_car = [Weight, Horsepower, Cylinders];  % 100x3 特征矩阵

Y_car = MPG;                              % 100x1 响应变量

%% 鸢尾花分类任务

rng(1); % 设置随机种子保证可重复性

cv = cvpartition(Y_iris, 'HoldOut', 0.3);

idxTrain = training(cv);

idxTest = test(cv);

% 训练KNN模型

knnModel = fitcknn(X_iris(idxTrain,:), Y_iris(idxTrain), 'NumNeighbors', 5);

knnPred = predict(knnModel, X_iris(idxTest,:));

knnAcc = sum(strcmp(knnPred, Y_iris(idxTest))) / numel(idxTest)

% 训练决策树

treeModel = fitctree(X_iris(idxTrain,:), Y_iris(idxTrain));

treePred = predict(treeModel, X_iris(idxTest,:));

treeAcc = sum(strcmp(treePred, Y_iris(idxTest))) / numel(idxTest)

% 训练SVM

svmModel = fitcecoc(X_iris(idxTrain,:), Y_iris(idxTrain));

svmPred = predict(svmModel, X_iris(idxTest,:));

svmAcc = sum(strcmp(svmPred, Y_iris(idxTest))) / numel(idxTest)

% 混淆矩阵可视化

figure;

confusionchart(Y_iris(idxTest), knnPred, 'Title', 'KNN Confusion Matrix');

%% 手写数字分类(使用KNN示例)

% 训练KNN模型

knnDigitModel = fitcknn(hogFeatures, trainLabels, 'NumNeighbors', 3);

% 处理测试集

numTest = numel(testImgs.Files);

testFeatures = zeros(numTest, 324);

testLabels = testImgs.Labels;

for i = 1:numTest

    img = readimage(testImgs, i);

    testFeatures(i, :) = extractHOGFeatures(img);

end

% 预测并评估

digitPred = predict(knnDigitModel, testFeatures);

digitAcc = sum(digitPred == testLabels) / numel(testLabels)

%% 回归任务(汽车数据)

rng(2);

cv_car = cvpartition(length(Y_car), 'HoldOut', 0.25);

idxTrain_car = training(cv_car);

idxTest_car = test(cv_car);

% 线性回归

lmModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car));

lmPred = predict(lmModel, X_car(idxTest_car,:));

lmMSE = loss(lmModel, X_car(idxTest_car,:), Y_car(idxTest_car))

% 多项式回归(二次项)

polyModel = fitlm(X_car(idxTrain_car,:), Y_car(idxTrain_car), 'poly2');

polyPred = predict(polyModel, X_car(idxTest_car,:));

polyMSE = loss(polyModel, X_car(idxTest_car,:), Y_car(idxTest_car))

% 可视化回归结果

figure;

scatter(Y_car(idxTest_car), lmPred, 'b');

hold on;

scatter(Y_car(idxTest_car), polyPred, 'r');

plot([0,50], [0,50], 'k--');

xlabel('Actual MPG');

ylabel('Predicted MPG');

legend('Linear', 'Polynomial', 'Ideal');

title('Regression Results Comparison');

关键函数说明:

  1. 分类模型训练:

    • fitcknn(): K近邻分类器

    • fitctree(): 决策树分类器

    • fitcecoc(): 多类SVM分类器

    • 回归模型训练:

      • fitlm(): 线性/多项式回归

      • 'poly2'参数: 指定二次多项式项

      • 评估指标:

        • confusionchart(): 可视化混淆矩阵

        • loss(): 计算均方误差(回归)

        • 准确率 = 正确预测数/总样本数(分类)

执行结果

鸢尾花分类准确率:

knnAcc = 0.9778

treeAcc = 0.9556

svmAcc = 0.9778

手写数字分类准确率:

digitAcc = 0.9432

回归均方误差:

lmMSE = 15.672

polyMSE = 12.845

注意事项:

  1. 特征工程:

    • 手写数字使用HOG特征替代原始像素

    • 汽车数据组合多个特征(重量/马力/气缸数)

    • 数据预处理:

      • 自动处理缺失值(fitlm会排除含NaN的行)

      • 分类数据自动编码(SVM使用整数编码)

      • 模型优化:

        • 可通过crossval函数进行交叉验证

        • 使用HyperparameterOptimization参数自动调优

        • 可视化:

          • 回归结果对比图显示预测值与实际值关系

          • 混淆矩阵直观展示分类错误分布

此代码展示了完整的机器学习流程:数据加载 → 特征工程 → 模型训练 → 预测 → 性能评估。可根据需要调整测试集比例、模型参数和特征组合。

免责声明:我们致力于保护作者版权,注重分享,被刊用文章因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理! 部分文章是来自自研大数据AI进行生成,内容摘自(百度百科,百度知道,头条百科,中国民法典,刑法,牛津词典,新华词典,汉语词典,国家院校,科普平台)等数据,内容仅供学习参考,不准确地方联系删除处理! 图片声明:本站部分配图来自人工智能系统AI生成,觅知网授权图片,PxHere摄影无版权图库和百度,360,搜狗等多加搜索引擎自动关键词搜索配图,如有侵权的图片,请第一时间联系我们。

目录[+]

取消
微信二维码
微信二维码
支付宝二维码