随机森林,Random Forests Classifiers/Regressor

发布时间:2024年01月04日

目录

介绍:?

一、?Random Forests Classifiers(离散型)

1.1 数据处理

1.2建模

1.3特征值权值分析

1.4 特征值的缩减

二、Random Forests Regressor(连续型)

2.1数据处理?

2.2建模

2.3调参


介绍:?

随机森林(Random Forests)是一种集成学习算法,它由多个决策树组成。它在每个决策树的训练过程中引入了随机性,以提高模型的泛化能力和鲁棒性。

随机森林的训练过程如下:

  1. 从训练集中随机选取一部分样本,构建一个决策树。这种随机选取样本的过程叫做自助采样(bootstrap sampling)。
  2. 对于每个决策树的每个节点,从所有特征中随机选取一部分特征,根据这些特征来选择最优的分割点。
  3. 重复以上两个步骤,构建多个决策树。
  4. 预测时,将待预测样本输入到每个决策树中,得到多个预测结果。最终,根据这些预测结果进行投票或平均,确定最终的预测结果。

随机森林在许多方面都表现出良好的性能。它可以用于分类问题和回归问题,并且对于处理高维数据和大型数据集也非常有效。此外,随机森林能够处理缺失数据和不平衡数据,并能够评估特征的重要性。

总的来说,随机森林是一种强大的机器学习算法,它通过组合多个决策树的预测结果来提高模型的性能和鲁棒性。它在实际应用中广泛使用,并且具有很好的可解释性和通用性。

一、?Random Forests Classifiers(离散型)

1.1 数据处理

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
%matplotlib inline

data=pd.read_csv('iris.csv')#离散型

X=data.iloc[:,1:5]
'''结果:
 	Sepal.Length 	Sepal.Width 	Petal.Length 	Petal.Width
0 	5.1 	3.5 	1.4 	0.2
1 	4.9 	3.0 	1.4 	0.2
2 	4.7 	3.2 	1.3 	0.2
3 	4.6 	3.1 	1.5 	0.2
4 	5.0 	3.6 	1.4 	0.2
... 	... 	... 	... 	...
145 	6.7 	3.0 	5.2 	2.3
146 	6.3 	2.5 	5.0 	1.9
147 	6.5 	3.0 	5.2 	2.0
148 	6.2 	3.4 	5.4 	2.3
149 	5.9 	3.0 	5.1 	1.8

150 rows × 4 columns
'''

y=data.iloc[:,-1:]
'''结果:
 	Species
0 	setosa
1 	setosa
2 	setosa
3 	setosa
4 	setosa
... 	...
145 	virginica
146 	virginica
147 	virginica
148 	virginica
149 	virginica

150 rows × 1 columns
'''

1.2建模

from  sklearn.model_selection import train_test_split#将数据分成测试和训练集
X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3,random_state=0)#测试集占百分之三十,random_state=0随机抽取数据集里的成为测试集

from sklearn.ensemble import RandomForestClassifier#import  random forest model

clf=RandomForestClassifier(n_estimators=100)#赋类,100棵树,后面可以调参
clf.fit(X_train,y_train)#训练集喂给这个模型
y_pred=clf.predict(X_test)#预测值

y_pred#预测值
'''结果:
array(['virginica', 'versicolor', 'setosa', 'virginica', 'setosa',
       'virginica', 'setosa', 'versicolor', 'versicolor', 'versicolor',
       'virginica', 'versicolor', 'versicolor', 'versicolor',
       'versicolor', 'setosa', 'versicolor', 'versicolor', 'setosa',
       'setosa', 'virginica', 'versicolor', 'setosa', 'setosa',
       'virginica', 'setosa', 'setosa', 'versicolor', 'versicolor',
       'setosa', 'virginica', 'versicolor', 'setosa', 'virginica',
       'virginica', 'versicolor', 'setosa', 'virginica', 'versicolor',
       'versicolor', 'virginica', 'setosa', 'virginica', 'setosa',
       'setosa'], dtype=object)
'''

from sklearn import metrics
metrics.accuracy_score(y_test,y_pred)#模型的值,y_test,y_pred对比
#结果:0.9777777777777777

1.3特征值权值分析

#特征变量的权值分析
feature_list=list(X.columns)
feature_imp=pd.Series(clf.feature_importances_,index=feature_list).sort_values(ascending=False)
feature_imp#特征值的权重
'''结果:
Petal.Width     0.456188
Petal.Length    0.411471
Sepal.Length    0.106732
Sepal.Width     0.025609
dtype: float64
'''

feature_imp.index
#结果:Index(['Petal.Width', 'Petal.Length', 'Sepal.Length', 'Sepal.Width'], dtype='object')
sns.barplot(x=feature_imp,y=feature_imp.index)
plt.xlabel('feature importance score')
plt.ylabel('feature')
plt.legend(feature_imp.index)
plt.show()

?

1.4 特征值的缩减

#特征变量的缩减,对于成百上千特征变量的大数据非常有意义
X=data.iloc[:,1:-2]
'''结果:
 	Sepal.Length 	Sepal.Width 	Petal.Length
0 	5.1 	3.5 	1.4
1 	4.9 	3.0 	1.4
2 	4.7 	3.2 	1.3
3 	4.6 	3.1 	1.5
4 	5.0 	3.6 	1.4
... 	... 	... 	...
145 	6.7 	3.0 	5.2
146 	6.3 	2.5 	5.0
147 	6.5 	3.0 	5.2
148 	6.2 	3.4 	5.4
149 	5.9 	3.0 	5.1

150 rows × 3 columns
'''


X_train,X_test,y_train,y_test=train_test_split(X,y,test_size=0.3,random_state=0)#测试集占百分之三十,random_state=0随机抽取数据集里的成为测试集
clf=RandomForestClassifier(n_estimators=100)#赋类,100棵树,后面可以调参
clf.fit(X_train,y_train)#训练集喂给这个模型
y_pred=clf.predict(X_test)#预测值
metrics.accuracy_score(y_test,y_pred)#模型的值,y_test,y_pred对比
#结果:0.9333333333333333

二、Random Forests Regressor(连续型)

2.1数据处理?

dataset = pd.read_csv('petrol_consumption.csv')
dataset#汽油税,收入,高速费,人口密度,汽油消耗

dataset.info()
'''结果:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 48 entries, 0 to 47
Data columns (total 5 columns):
 #   Column                        Non-Null Count  Dtype  
---  ------                        --------------  -----  
 0   Petrol_tax                    48 non-null     float64
 1   Average_income                48 non-null     int64  
 2   Paved_Highways                48 non-null     int64  
 3   Population_Driver_licence(%)  48 non-null     float64
 4   Petrol_Consumption            48 non-null     int64  
dtypes: float64(2), int64(3)
memory usage: 2.0 KB
'''

X = dataset.iloc[:,0:4]
y=dataset.iloc[:,4]

#数据差异非常大,需要引入数据标准化
from sklearn.preprocessing import StandardScaler
sc =  StandardScaler()
X_train = sc.fit_transform(X_train)
X_test=sc.transform(X_test)

X_train
'''结果:
array([[-0.60684249, -0.13370363, -0.39371558,  0.71661097],
       [-0.60684249,  0.73650306,  0.12337074,  2.41961586],
       [-0.60684249, -0.08812138,  1.37233744,  0.09273789],
       [ 0.33674156, -0.35747107,  0.14030588, -0.29507511],
       [ 1.28032561,  1.11152071, -0.85491594, -1.1718697 ],
       [-0.60684249, -0.3201765 ,  0.85525111, -0.14332219],
       [-0.60684249, -1.45973288, -0.42137631, -0.29507511],
       [-0.60684249,  0.48165682,  0.66501302,  1.39106835],
       [-0.60684249, -0.12541595, -0.52016464,  0.37938228],
       [ 0.33674156,  0.17915639,  0.87472653, -0.86836388],
       [-0.60684249, -0.31810458,  0.31106856, -0.59858093],
       [ 1.28032561, -1.32505804, -0.20658227, -0.61544236],
       [ 2.22390966,  2.03352542, -1.16990957, -0.16018363],
       [-0.60684249, -1.3312738 , -0.21250957, -0.6828881 ],
       [-1.00314779, -1.15723246,  0.66501302,  0.81777958],
       [-0.60684249, -0.05911449,  0.75674504,  0.46368945],
       [ 0.80853358, -1.50324322, -0.62205774,  1.39106835],
       [ 1.28032561, -0.55637546, -1.19333652, -0.14332219],
       [-0.60684249,  0.94576705,  0.40985689, -0.10959932],
       [-0.60684249, -1.27533194, -0.80919106, -1.22245401],
       [ 0.80853358,  0.44229032, -0.80693304, -0.49741232],
       [ 0.33674156,  1.98587124,  1.80361904, -2.18355578],
       [ 1.28032561, -0.21243662, -0.22351741, -1.0707011 ],
       [-2.49401059, -0.65375573,  3.4728595 , -0.2444908 ],
       [ 0.33674156,  1.28970589, -1.37623605,  0.36252084],
       [ 0.80853358, -0.0404672 ,  0.15018472,  1.62712844],
       [-0.60684249,  0.31383124,  0.85496886, -0.48055089],
       [-0.60684249, -0.03217952, -0.4439565 ,  1.54282127],
       [ 1.28032561,  0.23924209, -0.43351316, -0.16018363],
       [-0.13505047,  1.05557885, -0.88257667, -0.86836388],
       [ 1.28032561, -1.63584614, -0.9884213 , -0.93580962],
       [-1.55042654,  1.77039149, -0.89640703,  1.54282127]])
'''

2.2建模

from sklearn.ensemble import RandomForestRegressor
regressor = RandomForestRegressor(n_estimators=100,random_state=0)
regressor.fit(X_train,y_train)
y_pred = regressor.predict(X_test)
#结果:array([586.98, 489.34, 626.27, 658.65, 631.72, 614.6 , 612.19, 579.32,
#       465.44, 512.32, 438.32, 657.18, 624.52, 598.46, 534.4 , 555.78])


from sklearn import metrics
print('Mean Absolute Error',metrics.mean_absolute_error(y_test,y_pred))
print('Mean Squared Error',metrics.mean_squared_error(y_test,y_pred))
print('root mean squared error',np.sqrt(metrics.mean_squared_error(y_test,y_pred)))
'''结果:
Mean Absolute Error 50.18562499999999
Mean Squared Error 3964.588093749999
root mean squared error 62.964975134990716
'''

2.3调参

rmse=nestimators=[]#调参
for n in [20,30,50,80,100,200,300,400,500,600,700,800]:
    regressor = RandomForestRegressor(n_estimators=n,random_state=0)
    regressor.fit(X_train,y_train)
    y_pred = regressor.predict(X_test)
    print('-------------------')
    print('n_estimators={}',format(n))
    print('Mean Absolute Error',metrics.mean_absolute_error(y_test,y_pred))
    print('Mean Squared Error',metrics.mean_squared_error(y_test,y_pred))
    print('Root mean squared error',np.sqrt(metrics.mean_squared_error(y_test,y_pred)))
    rmse=np.append(rmse,np.sqrt(metrics.mean_squared_error(y_test,y_pred)))
    nestimators=np.append(nestimators,n)

'''结果:
-------------------
n_estimators={} 20
Mean Absolute Error 56.128125000000004
Mean Squared Error 4606.41578125
Root mean squared error 67.87058111766835
-------------------
n_estimators={} 30
Mean Absolute Error 49.94375000000001
Mean Squared Error 3922.442708333335
Root mean squared error 62.62940769585271
-------------------
n_estimators={} 50
Mean Absolute Error 49.158750000000005
Mean Squared Error 3868.3672749999996
Root mean squared error 62.19619984372035
-------------------
n_estimators={} 80
Mean Absolute Error 50.70390625
Mean Squared Error 4013.614755859375
Root mean squared error 63.35309586641662
-------------------
n_estimators={} 100
Mean Absolute Error 50.18562499999999
Mean Squared Error 3964.588093749999
Root mean squared error 62.964975134990716
-------------------
n_estimators={} 200
Mean Absolute Error 48.34375
Mean Squared Error 3622.057096875
Root mean squared error 60.18352845152069
-------------------
n_estimators={} 300
Mean Absolute Error 49.467708333333334
Mean Squared Error 3789.9574437499987
Root mean squared error 61.5626302536693
-------------------
n_estimators={} 400
Mean Absolute Error 48.489999999999995
Mean Squared Error 3636.7144398437504
Root mean squared error 60.30517755420135
-------------------
n_estimators={} 500
Mean Absolute Error 48.917499999999976
Mean Squared Error 3726.081923499998
Root mean squared error 61.041640897833
-------------------
n_estimators={} 600
Mean Absolute Error 48.97749999999999
Mean Squared Error 3719.864061805555
Root mean squared error 60.990688320476885
-------------------
n_estimators={} 700
Mean Absolute Error 48.50473214285714
Mean Squared Error 3633.144154209183
Root mean squared error 60.275568468569276
-------------------
n_estimators={} 800
Mean Absolute Error 48.12984374999999
Mean Squared Error 3560.1158533203115
Root mean squared error 59.66670640583668
'''
rmse
'''结果:
array([67.87058112, 62.6294077 , 62.19619984, 63.35309587, 62.96497513,
       60.18352845, 61.56263025, 60.30517755, 61.0416409 , 60.99068832,
       60.27556847, 59.66670641])
'''

sns.set_style('whitegrid')
plt.plot(nestimators,rmse,'bo',linestyle='dashed',linewidth=1,markersize=10)#前面x,后面y
plt.xlabel('feature importance score')
plt.ylabel('features')
plt.title("viualizing importeant features")
plt.show()

文章来源:https://blog.csdn.net/qq_74156152/article/details/135381353
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。