我们相信:世界是美好的,你是我也是。 来玩一下解压小游戏吧!

苏南大叔将使用sklearn库中的随机森林分类器,对MNIST数据集进行预测。MNIST是经典的机器学习的数据集,内置7万张手写阿拉伯数字图片,并进行了数据标注,以便进行分类学习预测。

苏南大叔:基于sklearn的随机森林分类器,对mnist数据集进行预测 - 随机森林-mnist数据集
基于sklearn的随机森林分类器,对mnist数据集进行预测(图3-1)

苏南大叔的“程序如此灵动”博客,记录苏南大叔的代码编程经验总结。测试环境:win10python@3.12.9sklearn@1.6.1。本文使用的是随机森林分类法,对图像数据进行分类。

前文回顾

MNIST数据集:

随机森林对鸢尾花数据集预测:

基于sklearn库的随机森林分类器,对 MNIST 数据集进行预测,预测结果如何呢?是否精准?

数据集加载,fetch_openml

fetch_openml()sklearn的一个工具函数,用于从OpenML平台加载数据集。可以直接获取数据集,无需手动下载和解压。这里使用这个函数加载mnist数据集。

from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist.data, mnist.target

返回值X 是包含图像像素值的二维数组,y 是对应的标签。

相关文章:

数据标准化,StandardScaler

MNIST数据集中,像素值范围为0255。直接使用这些数据可能导致模型对较大的特征值敏感。为了可以加速模型的收敛,提高预测精度,因此对数据进行标准化处理。

将数据转换为均值为0、标准差为1的分布。

from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

fit_transform() 用于训练集,transform() 用于测试集,避免数据泄露。

相关文章:

预测算法,RandomForestClassifier

RandomForestClassifier() 是一种基于决策树的集成学习算法,具有以下特点:

  • 随机性:通过随机选择特征和样本,构建多棵决策树,减少过拟合。
  • 高效性:适用于大规模数据集,训练速度快。
  • 鲁棒性:对噪声和缺失数据不敏感。

随机森林训练模型,代码如下:

from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
  • n_estimators:指定决策树的数量,更多的树通常能提高模型性能。
  • random_state:保证结果的可重复性。

评测标准:accuracy_score

模型的性能通过 accuracy_score 进行评估,计算预测正确的样本占总样本的比例:

from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
  • 优点:简单直观,适合分类任务。
  • 局限性:对于类别不平衡的数据集,可能无法全面反映模型性能。

相关文章:

模型预加载,joblib

因为训练耗时的缘故,这里使用joblib来加载上次训练的结果。

if os.path.exists(MODEL_FILE):
    print("Loading model from file...")
    model = joblib.load(MODEL_FILE)
else:
    print("Training model...")
    model = train_model(X_train, y_train)

完整代码

完整代码如下:

import numpy as np
import ssl
import joblib
import os
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score

ssl._create_default_https_context = ssl._create_unverified_context
MODEL_FILE = "random_forest_mnist_model.joblib"

def load_data():
    mnist = fetch_openml('mnist_784', version=1)
    X = mnist.data
    y = mnist.target.astype(int)  # Replace np.int with int
    return train_test_split(X, y, test_size=0.2, random_state=42)

def preprocess_data(X_train, X_test):
    scaler = StandardScaler()
    X_train_scaled = scaler.fit_transform(X_train)
    X_test_scaled = scaler.transform(X_test)
    return X_train_scaled, X_test_scaled, scaler

def train_model(X_train, y_train):
    model = RandomForestClassifier(n_estimators=100, random_state=42)
    model.fit(X_train, y_train)
    # Save the trained model to a file
    joblib.dump(model, MODEL_FILE)
    return model

def load_or_train_model(X_train, y_train):
    if os.path.exists(MODEL_FILE):
        print("Loading model from file...")
        model = joblib.load(MODEL_FILE)
    else:
        print("Training model...")
        model = train_model(X_train, y_train)
    return model

def evaluate_model(model, X_test, y_test):
    y_pred = model.predict(X_test)
    accuracy = accuracy_score(y_test, y_pred)
    return accuracy

if __name__ == "__main__":
    X_train, X_test, y_train, y_test = load_data()
    X_train_scaled, X_test_scaled, scaler = preprocess_data(X_train, X_test)
    model = load_or_train_model(X_train_scaled, y_train)
    accuracy = evaluate_model(model, X_test_scaled, y_test)
    print(f'Accuracy: {accuracy}')

精准度输出96.74%

苏南大叔:基于sklearn的随机森林分类器,对mnist数据集进行预测 - 精确度存疑
基于sklearn的随机森林分类器,对mnist数据集进行预测(图3-2)

实际效果不佳

然而,这么高的准确度都是表象。实际应用中,使用自己手写的数字图片来进行预测的话,准确度低的令人发指,简直是弱智中的弱智。有人说是因为东西方的数字写法有区别造成的,似乎有些道理。此处待续。

苏南大叔:基于sklearn的随机森林分类器,对mnist数据集进行预测 - 识别错误
基于sklearn的随机森林分类器,对mnist数据集进行预测(图3-3)

总结

在本文的实验中,使用 RandomForestClassifierMNIST数据集进行预测,最终模型的准确率为96%。这表明随机森林在手写数字识别任务中表现良好。

更多苏南大叔的机器学习的文章,可以参考:

如果本文对您有帮助,或者节约了您的时间,欢迎打赏瓶饮料,建立下友谊关系。
本博客不欢迎:各种镜像采集行为。请尊重原创文章内容,转载请保留作者链接。

 【福利】 腾讯云最新爆款活动!1核2G云服务器首年50元!

 【源码】本文代码片段及相关软件,请点此获取更多信息

 【绝密】秘籍文章入口,仅传授于有缘之人   ai