Skip to content

树模型

本节选用随机森林树模型进行演示,也可尝试sklearn和GBDT模型,以及xgboost模型,lightgbm模型 树模型的基础知识请参照周志华的《机器学习》以及李航的《统计学习方法》,另外代码可以参照机器学习实战 树模型最核心的就是信息熵,基尼指数等叶子结点得分,这是树模型分裂的依据,请着重研究~ 想偷懒的同学,社区的连接在此,看看吧

7.集成方法-随机森林和AdaBoost

顺便说一下,嫌慢的同学还是用lightgbm吧,这个超快的~ GridSearchCV是用来调参的,只需要把注释去掉就可以显示出最佳的参数,不过实测慢的要死。。

代码:

from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
import pandas as pd
import numpy as np
# from sklearn.grid_search import GridSearchCV
# from numpy import arange
# from lightgbm import LGBMClassifier

train_data=pd.read_csv(r"C:\Users\312\Desktop\digit-recognizer\train.csv")
test_data=pd.read_csv(r"C:\Users\312\Desktop\digit-recognizer\test.csv")
data=pd.concat([train_data,test_data],axis=0).reset_index(drop=True)
data.drop(['label'],axis=1,inplace=True)
label=train_data.label


pca=PCA(n_components=100, random_state=34)
data_pca=pca.fit_transform(data)

Xtrain,Ytrain,xtest,ytest=train_test_split(data_pca[0:len(train_data)],label,test_size=0.1, random_state=34)


clf=RandomForestClassifier(n_estimators=110,max_depth=5,min_samples_split=2, min_samples_leaf=1,random_state=34)

# clf=LGBMClassifier(num_leaves=63, max_depth=7, n_estimators=80, n_jobs=20)

# param_test1 = {'n_estimators':arange(10,150,10),'max_depth':arange(1,11,1)}
# gsearch1 = GridSearchCV(estimator = clf, param_grid = param_test1, scoring='accuracy',iid=False,cv=5)
# gsearch1.fit(Xtrain,xtest)
# print(gsearch1.grid_scores_, gsearch1.best_params_, gsearch1.best_score_)


clf.fit(Xtrain,xtest)
y_predict=clf.predict(Ytrain)


zeroLable=ytest-y_predict
rightCount=0
for i in range(len(zeroLable)):
    if list(zeroLable)[i]==0:
        rightCount+=1
print ('the right rate is:',float(rightCount)/len(zeroLable))


result=clf.predict(data_pca[len(train_data):])

with open("C:\\Users\\312\\Desktop\\digit-recognizer\\result.csv", 'w') as fw:
    with open('C:\\Users\\312\\Desktop\\digit-recognizer\\sample_submission.csv') as pred_file:
        fw.write('{},{}\n'.format('ImageId', 'Label'))
        for i,line in enumerate(pred_file.readlines()[1:]):
            splits = line.strip().split(',')
            fw.write('{},{}\n'.format(splits[0],result[i]))

结果:

the right rate is: 0.819047619047619


回到顶部