sklearn中的决策树-分类树:泰坦尼克号生存预测

news/2025/2/27 11:28:59

分类树实例:泰坦尼克号生存预测

代码分解
  • 需要导入的库

    """导入所需要的库"""
    import pandas as pd
    import numpy as np
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.model_selection import train_test_split
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import cross_val_score
    import matplotlib.pyplot as plt
    
  • 导入数据集

    """导入数据集,探索数据"""
    data_train = pd.read_csv('./need/Taitanic_data/data.csv',index_col=0)
    data_test = pd.read_csv('./need/Taitanic_data/test.csv',index_col=0)
    
    data = pd.concat([data_train,data_test],axis=0)
    data.head()
    data.info()
    

    20211207VOfRuP

    20211207faTpKR

  • 对数据集进行预处理

    """对数据集进行预处理"""
    #删除缺失值过多的列,和观察判断来说和预测的y没有关系的列 
    data.drop(["Cabin","Name","Ticket"],inplace=True,axis=1)
    #处理缺失值,对缺失值较多的列进行填补,有一些特征只确实一两个值,可以采取直接删除记录的方法 
    data["Age"] = data["Age"].fillna(data["Age"].mean())
    data = data.dropna()
    #将分类变量转换为数值型变量
    #将二分类变量转换为数值型变量 #astype能够将一个pandas对象转换为某种类型,和apply(int(x))不同,astype可以将文本类转换为数字,用这 个方式可以很便捷地将二分类特征转换为0~1
    data["Sex"] = (data["Sex"]== "male").astype("int")
    #将三分类变量转换为数值型变量
    labels = data["Embarked"].unique().tolist()
    data["Embarked"] = data["Embarked"].apply(lambda x: labels.index(x))
    #查看处理后的数据集 
    data.head()
    

    20211207cmRwvy

  • 提取标签和特征矩阵,分测试集和训练集

    """提取标签和特征矩阵,分测试集和训练集"""
    X = data.iloc[:,data.columns != "Survived"]
    y = data.iloc[:,data.columns == "Survived"]
    Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,y,test_size=0.3)
    #修正测试集和训练集的索引(或者直接reset_index(drop=True,inplace=True))
    for i in [Xtrain, Xtest, Ytrain, Ytest]:
        i.index = range(i.shape[0])
    #查看分好的训练集和测试集 
    Xtrain.head()
    

    20211207ylMKCB

  • 导入模型,粗略跑一下查看结果

    """导入模型,粗略跑一下查看结果"""
    clf = DecisionTreeClassifier(random_state=25)
    clf = clf.fit(Xtrain, Ytrain)
    score_ = clf.score(Xtest, Ytest)
    print('单颗决策树精度',score_)
    score = cross_val_score(clf,X,y,cv=10).mean()
    print('10次交叉验证平均精度',score)
    
    """输出"""
    单颗决策树精度 0.8164794007490637
    10次交叉验证平均精度 0.7739274770173645
    
  • 在不同max_depth下观察模型的拟合状况

    """在不同max_depth下观察模型的拟合状况"""
    tr = [] # 训练集精度
    te = [] # 测试集交叉验证精度
    for i in range(10):
        clf = DecisionTreeClassifier(random_state=25
                                     ,max_depth=i+1
                                     ,criterion="entropy"
                                    )
        clf = clf.fit(Xtrain, Ytrain)
        score_tr = clf.score(Xtrain,Ytrain)
        score_te = cross_val_score(clf,X,y,cv=10).mean()
        tr.append(score_tr)
        te.append(score_te)
    print("测试集交叉验证均值最大值(精度)",max(te))
    plt.figure(figsize=(12,8))
    plt.plot(range(1,11),tr,color="red",label="train")
    plt.plot(range(1,11),te,color="blue",label="test")
    plt.xticks(range(1,11))
    plt.legend()
    plt.show()
    #这里为什么使用“entropy”?因为我们注意到,在最大深度=3的时候,模型拟合不足,在训练集和测试集上的表现接 近,但却都不是非常理想,只能够达到83%左右,所以我们要使用entropy。
    
    
    """输出"""
    测试集交叉验证均值最大值(精度) 0.8177860061287026
    

    202112070tP4hM

  • 网格搜索调整参数

    """用网格搜索调整参数"""
    gini_thresholds = np.linspace(0,0.5,20)
    parameters = {'splitter':('best','random')
                  ,'criterion':("gini","entropy")
                  ,"max_depth":[*range(1,10)]
                  ,'min_samples_leaf':[*range(1,50,5)]
                  ,'min_impurity_decrease':[*np.linspace(0,0.5,20)]
    }
    clf = DecisionTreeClassifier(random_state=25)
    GS = GridSearchCV(clf, parameters, cv=10)
    GS.fit(Xtrain,Ytrain)
    print('最佳参数',GS.best_params_)
    print('最佳精度',GS.best_score_)
    
    
    """输出"""
    最佳参数 {'criterion': 'entropy'
              , 'max_depth': 9
              , 'min_impurity_decrease': 0.0
              , 'min_samples_leaf': 6
              , 'splitter': 'best'
             }
    最佳精度 0.815284178187404
    

    由此可见,网格搜索并非一定比自己调参好,因为网格搜索无法舍弃无用的参数,默认传入的所有参数必须得都选上。

所有代码
  • 所有代码

    """导入所需要的库"""
    import pandas as pd
    import numpy as np
    from sklearn.tree import DecisionTreeClassifier
    from sklearn.model_selection import train_test_split
    from sklearn.model_selection import GridSearchCV
    from sklearn.model_selection import cross_val_score
    import matplotlib.pyplot as plt
    
    """导入数据集,探索数据"""
    data_train = pd.read_csv('./need/Taitanic_data/data.csv',index_col=0)
    data_test = pd.read_csv('./need/Taitanic_data/test.csv',index_col=0)
    
    data = pd.concat([data_train,data_test],axis=0)
    data.head()
    data.info()
    
    
    """对数据集进行预处理"""
    #删除缺失值过多的列,和观察判断来说和预测的y没有关系的列 
    data.drop(["Cabin","Name","Ticket"],inplace=True,axis=1)
    #处理缺失值,对缺失值较多的列进行填补,有一些特征只确实一两个值,可以采取直接删除记录的方法 
    data["Age"] = data["Age"].fillna(data["Age"].mean())
    data = data.dropna()
    #将分类变量转换为数值型变量
    #将二分类变量转换为数值型变量 #astype能够将一个pandas对象转换为某种类型,和apply(int(x))不同,astype可以将文本类转换为数字,用这 个方式可以很便捷地将二分类特征转换为0~1
    data["Sex"] = (data["Sex"]== "male").astype("int")
    #将三分类变量转换为数值型变量
    labels = data["Embarked"].unique().tolist()
    data["Embarked"] = data["Embarked"].apply(lambda x: labels.index(x))
    #查看处理后的数据集 
    data.head()
    
    
    """提取标签和特征矩阵,分测试集和训练集"""
    X = data.iloc[:,data.columns != "Survived"]
    y = data.iloc[:,data.columns == "Survived"]
    Xtrain, Xtest, Ytrain, Ytest = train_test_split(X,y,test_size=0.3)
    #修正测试集和训练集的索引(或者直接reset_index(drop=True,inplace=True))
    for i in [Xtrain, Xtest, Ytrain, Ytest]:
        i.index = range(i.shape[0])
    #查看分好的训练集和测试集 
    Xtrain.head()
    
    
    """导入模型,粗略跑一下查看结果"""
    clf = DecisionTreeClassifier(random_state=25)
    clf = clf.fit(Xtrain, Ytrain)
    score_ = clf.score(Xtest, Ytest)
    print('单颗决策树精度',score_)
    score = cross_val_score(clf,X,y,cv=10).mean()
    print('10次交叉验证平均精度',score)
    
    
    """在不同max_depth下观察模型的拟合状况"""
    tr = [] # 训练集精度
    te = [] # 测试集交叉验证精度
    for i in range(10):
        clf = DecisionTreeClassifier(random_state=25
                                     ,max_depth=i+1
                                     ,criterion="entropy"
                                    )
        clf = clf.fit(Xtrain, Ytrain)
        score_tr = clf.score(Xtrain,Ytrain)
        score_te = cross_val_score(clf,X,y,cv=10).mean()
        tr.append(score_tr)
        te.append(score_te)
    print("测试集交叉验证均值最大值(精度)",max(te))
    plt.figure(figsize=(12,8))
    plt.plot(range(1,11),tr,color="red",label="train")
    plt.plot(range(1,11),te,color="blue",label="test")
    plt.xticks(range(1,11))
    plt.legend()
    plt.show()
    #这里为什么使用“entropy”?因为我们注意到,在最大深度=3的时候,模型拟合不足,在训练集和测试集上的表现接 近,但却都不是非常理想,只能够达到83%左右,所以我们要使用entropy。
    """用网格搜索调整参数"""
    # gini系数最大为0.5最小为0、信息增益最大为1,最小为0
    gini_thresholds = np.linspace(0,0.5,20)
    parameters = {'splitter':('best','random')
                  ,'criterion':("gini","entropy")
                  ,"max_depth":[*range(1,10)]
                  ,'min_samples_leaf':[*range(1,50,5)]
                  ,'min_impurity_decrease':[*np.linspace(0,0.5,20)]
    }
    clf = DecisionTreeClassifier(random_state=25)
    GS = GridSearchCV(clf, parameters, cv=10)
    GS.fit(Xtrain,Ytrain)
    print('最佳参数',GS.best_params_)
    print('最佳精度',GS.best_score_)
    

http://www.niftyadmin.cn/n/5870040.html

相关文章

蓝桥杯备赛-拔河

问题描述 小明是学校里的一名老师,他带的班级共有 nn 名同学,第 ii 名同学力量值为 aiai​。在闲暇之余,小明决定在班级里组织一场拔河比赛。 为了保证比赛的双方实力尽可能相近,需要在这 nn 名同学中挑选出两个队伍&#xff0c…

Android AsyncLayoutInflater异步加载xml布局文件,Kotlin

Android AsyncLayoutInflater异步加载xml布局文件,Kotlin implementation "androidx.asynclayoutinflater:asynclayoutinflater:1.1.0-alpha01" import android.os.Bundle import android.util.Log import android.view.View import android.view.ViewGro…

Redis 高可用性:如何让你的缓存一直在线,稳定运行?

🎯 引言:Redis的高可用性为啥这么重要? 在现代高可用系统中,Redis 是一款不可或缺的分布式缓存与数据库系统。无论是提升访问速度,还是实现数据的高效持久化,Redis 都能轻松搞定。可是,当你把 …

使用elasticdump导出/导入 -- ES数据

导出指定索引数据到指定文件夹: ./elasticdump --inputhttp://用户:密码IP:9201/索引名字 --output导出路径/out.json --typedata 将导出的文件导入 ./elasticdump --input路径/out.json --outputhttp://账号:密码IP:9201/索引名称 --typedata --fileTypejson 【el…

【Springboot知识】Logback从1.2.x升级到1.3.x需要注意哪些点?

文章目录 **1. 确认依赖版本**示例依赖配置(Maven): **2. 处理 StaticLoggerBinder 的移除**解决方案: **3. 修改日志配置文件**示例 logback.xml 配置: **4. 检查兼容性问题**Spring Boot 2.x 的兼容性解决方案&#…

鸿蒙-AVPlayer

compileVersion 5.0.2(14) 音频播放 import media from ohos.multimedia.media; import common from ohos.app.ability.common; import { BusinessError } from ohos.base;Entry Component struct AudioPlayer {private avPlayer: media.AVPlayer | nu…

[记录贴] 火绒奇怪的进程保护

最近一次更新火绒6.0到最新版,发现processhacker的结束进程功能无法杀掉火绒的进程,弹窗提示如下: 可能是打开进程时做了权限过滤,火绒注册了两个回调函数如下: 但奇怪的是,在另外一台机器上面更新到最新版…

C++ Qt常见面试题(3):Qt内存管理机制

Qt 内存管理机制是其框架的重要组成部分,目的是简化开发者对内存的管理,减少内存泄漏的风险,同时提供高效的资源使用方式。Qt 的内存管理机制主要依赖于 对象树(Object Tree) 和 父子关系(Parent-Child Relationship) 的设计,通过智能管理对象的生命周期来实现自动化的…