基于sklearn的随机森林分类器,对mnist数据集进行预测
发布于 作者:苏南大叔 来源:程序如此灵动~

苏南大叔将使用sklearn
库中的随机森林分类器,对MNIST
数据集进行预测。MNIST
是经典的机器学习的数据集,内置7万张手写阿拉伯数字图片,并进行了数据标注,以便进行分类学习预测。
苏南大叔的“程序如此灵动”博客,记录苏南大叔的代码编程经验总结。测试环境:win10
,python@3.12.9
,sklearn@1.6.1
。本文使用的是随机森林分类法,对图像数据进行分类。
前文回顾
MNIST
数据集:
随机森林对鸢尾花数据集预测:
基于sklearn
库的随机森林分类器,对 MNIST
数据集进行预测,预测结果如何呢?是否精准?
数据集加载,fetch_openml
fetch_openml()
是sklearn
的一个工具函数,用于从OpenML
平台加载数据集。可以直接获取数据集,无需手动下载和解压。这里使用这个函数加载mnist
数据集。
from sklearn.datasets import fetch_openml
mnist = fetch_openml('mnist_784', version=1)
X, y = mnist.data, mnist.target
返回值:X
是包含图像像素值的二维数组,y
是对应的标签。
相关文章:
数据标准化,StandardScaler
MNIST
数据集中,像素值范围为0
到255
。直接使用这些数据可能导致模型对较大的特征值敏感。为了可以加速模型的收敛,提高预测精度,因此对数据进行标准化处理。
将数据转换为均值为0
、标准差为1
的分布。
from sklearn.preprocessing import StandardScaler
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
fit_transform()
用于训练集,transform()
用于测试集,避免数据泄露。
相关文章:
- https://newsn.net/say/var-std.html
- https://newsn.net/say/python-std.html
- https://newsn.net/say/sklearn-cv-fit.html
预测算法,RandomForestClassifier
RandomForestClassifier()
是一种基于决策树的集成学习算法,具有以下特点:
- 随机性:通过随机选择特征和样本,构建多棵决策树,减少过拟合。
- 高效性:适用于大规模数据集,训练速度快。
- 鲁棒性:对噪声和缺失数据不敏感。
随机森林训练模型,代码如下:
from sklearn.ensemble import RandomForestClassifier
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
n_estimators
:指定决策树的数量,更多的树通常能提高模型性能。random_state
:保证结果的可重复性。
评测标准:accuracy_score
模型的性能通过 accuracy_score
进行评估,计算预测正确的样本占总样本的比例:
from sklearn.metrics import accuracy_score
accuracy = accuracy_score(y_test, y_pred)
print(f'Accuracy: {accuracy}')
- 优点:简单直观,适合分类任务。
- 局限性:对于类别不平衡的数据集,可能无法全面反映模型性能。
相关文章:
模型预加载,joblib
因为训练耗时的缘故,这里使用joblib
来加载上次训练的结果。
if os.path.exists(MODEL_FILE):
print("Loading model from file...")
model = joblib.load(MODEL_FILE)
else:
print("Training model...")
model = train_model(X_train, y_train)
完整代码
完整代码如下:
import numpy as np
import ssl
import joblib
import os
from sklearn.datasets import fetch_openml
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score
ssl._create_default_https_context = ssl._create_unverified_context
MODEL_FILE = "random_forest_mnist_model.joblib"
def load_data():
mnist = fetch_openml('mnist_784', version=1)
X = mnist.data
y = mnist.target.astype(int) # Replace np.int with int
return train_test_split(X, y, test_size=0.2, random_state=42)
def preprocess_data(X_train, X_test):
scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)
return X_train_scaled, X_test_scaled, scaler
def train_model(X_train, y_train):
model = RandomForestClassifier(n_estimators=100, random_state=42)
model.fit(X_train, y_train)
# Save the trained model to a file
joblib.dump(model, MODEL_FILE)
return model
def load_or_train_model(X_train, y_train):
if os.path.exists(MODEL_FILE):
print("Loading model from file...")
model = joblib.load(MODEL_FILE)
else:
print("Training model...")
model = train_model(X_train, y_train)
return model
def evaluate_model(model, X_test, y_test):
y_pred = model.predict(X_test)
accuracy = accuracy_score(y_test, y_pred)
return accuracy
if __name__ == "__main__":
X_train, X_test, y_train, y_test = load_data()
X_train_scaled, X_test_scaled, scaler = preprocess_data(X_train, X_test)
model = load_or_train_model(X_train_scaled, y_train)
accuracy = evaluate_model(model, X_test_scaled, y_test)
print(f'Accuracy: {accuracy}')
精准度输出96.74%
。
实际效果不佳
然而,这么高的准确度都是表象。实际应用中,使用自己手写的数字图片来进行预测的话,准确度低的令人发指,简直是弱智中的弱智。有人说是因为东西方的数字写法有区别造成的,似乎有些道理。此处待续。
总结
在本文的实验中,使用 RandomForestClassifier
对MNIST
数据集进行预测,最终模型的准确率为96%。这表明随机森林在手写数字识别任务中表现良好。
更多苏南大叔的机器学习的文章,可以参考:


