博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
CART
阅读量:7289 次
发布时间:2019-06-30

本文共 7504 字,大约阅读时间需要 25 分钟。

一、为什么有CART回归树

  以前学过全局回归,顾名思义,就是指全部数据符合某种曲线。比如线性回归,多项式拟合(泰勒)等等。可是这些数学规律多强,硬硬地将全部数据逼近一些特殊的曲线。生活中的数据可是千变万化。那么,局部回归是一种合理地选择。在斯坦福大学NG的公开课中,他也提到局部回归的好处。其中,CART回归树就是局部回归的一种。

二、CART回归树的算法流程

  注意到,(1)中两步优化,即选择最优切分变量和切分点。(i)如果给定x的切分点。那么可以马上求得中括号内的最优。(ii)对于切分点怎么确定,这里是用遍历的方法。

三、CART分类树

  实际上,CART分类树的生成树和ID3方法类似,只是这里用基尼指数代替了信息增益,定义

四、CART剪枝算法流程

 例子参考:http://www.cnblogs.com/zhangchaoyang/articles/2709922.html

比如:

当分类回归树划分得太细时,会对噪声数据产生过拟合作用。因此我们要通过剪枝来解决。剪枝又分为前剪枝和后剪枝:前剪枝是指在构造树的过程中就知道哪些节点可以剪掉,于是干脆不对这些节点进行分裂,在中用的都是前剪枝,上面的χ2方法也可以认为是一种前剪枝;后剪枝是指构造出完整的决策树之后再来考查哪些子树可以剪掉。

在分类回归树中可以使用的后剪枝方法有多种,比如:代价复杂性剪枝、最小误差剪枝、悲观误差剪枝等等。这里我们只介绍代价复杂性剪枝法。

对于分类回归树中的每一个非叶子节点计算它的表面误差率增益值α。

是子树中包含的叶子节点个数;

是节点t的误差代价,如果该节点被剪枝;

r(t)是节点t的误差率;

p(t)是节点t上的数据占所有数据的比例。

是子树Tt的误差代价,如果该节点不被剪枝。它等于子树Tt上所有叶子节点的误差代价之和。

比如有个非叶子节点t4如图所示:

 

已知所有的数据总共有60条,则节点t4的节点误差代价为:

子树误差代价为:

以t4为根节点的子树上叶子节点有3个,最终:

找到α值最小的非叶子节点,令其左右孩子为NULL。当多个非叶子节点的α值同时达到最小时,取最大的进行剪枝。

 

#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
#include
using namespace std; //置信水平取0.95时的卡方表const double CHI[18]={ 0.004,0.103,0.352,0.711,1.145,1.635,2.167,2.733,3.325,3.94,4.575,5.226,5.892,6.571,7.261,7.962};/*根据多维数组计算卡方值*/template
double cal_chi(Comparable **arr,int row,int col){ vector
rowsum(row); vector
colsum(col); Comparable totalsum=static_cast
(0); //cout<<"observation"<
first-obj.first; if(cmp>0) return false; else if(cmp<0) return true; else{ cmp=obj.second-this->second; if(cmp<0) return true; else return false; } }}; typedef map
MAP_REST_COUNT;typedef map
MAP_ATTR_REST;typedef vector
VEC_STATI; const int ATTR_NUM=8; //自变量的维度vector
X(ATTR_NUM);int rest_number; //因变量的种类数,即类别数vector
> classes; //把类别、对应的记录数存放在一个数组中int total_record_number; //总的记录数vector
> inputData; //原始输入数据 class node{public: node* parent; //父节点 node* leftchild; //左孩子节点 node* rightchild; //右孩子节点 string cond; //分枝条件 string decision; //在该节点上作出的类别判定 double precision; //判定的正确率 int record_number; //该节点上涵盖的记录个数 int size; //子树包含的叶子节点的数目 int index; //层次遍历树,给节点标上序号 double alpha; //表面误差率的增加量 node(){ parent=NULL; leftchild=NULL; rightchild=NULL; precision=0.0; record_number=0; size=1; index=0; alpha=1.0; } node(node* p){ parent=p; leftchild=NULL; rightchild=NULL; precision=0.0; record_number=0; size=1; index=0; alpha=1.0; } node(node* p,string c,string d):cond(c),decision(d){ parent=p; leftchild=NULL; rightchild=NULL; precision=0.0; record_number=0; size=1; index=0; alpha=1.0; } void printInfo(){ cout<<"index:"<
<<"\tdecisoin:"<
<<"\tprecision:"<
<<"\tcondition:"<
<<"\tsize:"<
index; if(leftchild!=NULL) cout<<"\tleftchild:"<
index<<"\trightchild:"<
index; cout<
printTree(); if(rightchild!=NULL) rightchild->printTree(); }}; int readInput(string filename){ ifstream ifs(filename.c_str()); if(!ifs){ cerr<<"open inputfile failed!"<
catg; string line; getline(ifs,line); string item; istringstream strstm(line); strstm>>item; for(int i=0;i
>item; X[i]=item; } while(getline(ifs,line)){ vector
conts(ATTR_NUM+2); istringstream strstm(line); //strstm.str(line); for(int i=0;i
>item; conts[i]=item; if(i==conts.size()-1) catg[item]++; } inputData.push_back(conts); } total_record_number=inputData.size(); ifs.close(); map
::const_iterator itr=catg.begin(); while(itr!=catg.end()){ classes.push_back(make_pair(itr->first,itr->second)); itr++; } rest_number=classes.size(); return 0;} /*根据inputData作出一个统计stati*/void statistic(vector
> &inputData,VEC_STATI &stati){ for(int i=1;i
second).find(rest); if(iter==(itr->second).end()){ (itr->second).insert(make_pair(rest,1)); } else{ iter->second+=1; } } } stati.push_back(attr_rest); }} /*依据某条件作出分枝时,inputData被分成两部分*/void splitInput(vector
> &inputData,int fitIndex,string cond,vector
> &LinputData,vector
> &RinputData){ for(int i=0;i
first; MAP_REST_COUNT::const_iterator iter=(itr->second).begin(); while(iter!=(itr->second).end()){ cout<<"\t"<
first<<"\t"<
second; iter++; } itr++; cout<
> &inputData,vector
> classes){ //root->printInfo(); root->record_number=inputData.size(); VEC_STATI stati; statistic(inputData,stati); //printStati(stati); //for(int i=0;i
> fitleftclasses; vector
> fitrightclasses; int fitleftnumber; int fitrightnumber; for(int i=0;i
first; //判定的条件,即到达左孩子的条件 //cout<<"cond 为"<
<<"时:"; vector
> leftclasses(classes); //左孩子节点上类别、及对应的数目 vector
> rightclasses(classes); //右孩子节点上类别、及对应的数目 int leftnumber=0; //左孩子节点上包含的类别数目 int rightnumber=0; //右孩子节点上包含的类别数目 for(int j=0;j
second).find(rest); if(iter2==(itr->second).end()){ //没找到 leftclasses[j].second=0; rightnumber+=rightclasses[j].second; } else{ //找到 leftclasses[j].second=iter2->second; leftnumber+=leftclasses[j].second; rightclasses[j].second-=(iter2->second); rightnumber+=rightclasses[j].second; } } /**if(leftnumber==0 || rightnumber==0){ cout<<"左右有一边为空"<
cond<
size)++; travel=travel->parent; } node *LChild=new node(root); //创建左右孩子 node *RChild=new node(root); root->leftchild=LChild; root->rightchild=RChild; int maxLcount=0; int maxRcount=0; string Ldicision,Rdicision; for(int i=0;i
maxLcount){ maxLcount=fitleftclasses[i].second; Ldicision=fitleftclasses[i].first; } if(fitrightclasses[i].second>maxRcount){ maxRcount=fitrightclasses[i].second; Rdicision=fitrightclasses[i].first; } } LChild->decision=Ldicision; RChild->decision=Rdicision; LChild->precision=1.0*maxLcount/fitleftnumber; RChild->precision=1.0*maxRcount/fitrightnumber; /*递归对左右孩子进行分裂*/ vector
> LinputData,RinputData; splitInput(inputData,fitIndex,fitCond,LinputData,RinputData); //cout<<"左边inputData行数:"<
<
leftchild==NULL) return (1-root->precision)*root->record_number/total_record_number; else return calR2(root->leftchild)+calR2(root->rightchild);} /*层次遍历树,给节点标上序号。同时计算alpha*/void index(node *root,priority_queue
&pq){ int i=1; queue
que; que.push(root); while(!que.empty()){ node* n=que.front(); que.pop(); n->index=i++; if(n->leftchild!=NULL){ que.push(n->leftchild); que.push(n->rightchild); //计算表面误差率的增量 double r1=(1-n->precision)*n->record_number/total_record_number; //节点的误差代价 double r2=calR2(n); n->alpha=(r1-r2)/(n->size-1); pq.push(MyTriple(n->alpha,n->size,n->index)); } }} /*剪枝*/void prune(node *root,priority_queue
&pq){ MyTriple triple=pq.top(); int i=triple.third; queue
que; que.push(root); while(!que.empty()){ node* n=que.front(); que.pop(); if(n->index==i){ cout<<"将要剪掉"<
<<"的左右子树"<
leftchild=NULL; n->rightchild=NULL; int s=n->size-1; node *trav=n; while(trav!=NULL){ trav->size-=s; trav=trav->parent; } break; } else if(n->leftchild!=NULL){ que.push(n->leftchild); que.push(n->rightchild); } }} void test(string filename,node *root){ ifstream ifs(filename.c_str()); if(!ifs){ cerr<<"open inputfile failed!"<
independent; //自变量,即分类的依据 while(getline(ifs,line)){ istringstream strstm(line); //strstm.str(line); strstm>>item; cout<
<<"\t"; for(int i=0;i
>item; independent[X[i]]=item; } node *trav=root; while(trav!=NULL){ if(trav->leftchild==NULL){ cout<<(trav->decision)<<"\t置信度:"<<(trav->precision)<
cond; string::size_type pos=cond.find("="); string pre=cond.substr(0,pos); string post=cond.substr(pos+1); if(independent[pre]==post) trav=trav->leftchild; else trav=trav->rightchild; } } ifs.close();} int main(){ string inputFile="animal"; readInput(inputFile); VEC_STATI stati; //最原始的统计 statistic(inputData,stati); // for(int i=0;i
pq; index(root,pq); root->printTree(); cout<<"剪枝前使用该决策树最多进行"<
size-1<<"次条件判断"<
size-1<<"次条件判断"<

 

参考文献:

http://blog.csdn.net/google19890102/article/details/32329823

转载于:https://www.cnblogs.com/Wanggcong/p/4670138.html

你可能感兴趣的文章
java 设计模式 建造者模式
查看>>
mysql备份和恢复工作记录
查看>>
我的友情链接
查看>>
vFrank考VCDX的过程
查看>>
jQuery input同步发sims
查看>>
memcached起步
查看>>
lesson 10-你所不知道的邮件退信代码
查看>>
OSPF LSA过滤简述
查看>>
m283-tftp传输,nfs挂载rootfs
查看>>
Windows Server 2008搭建***服务
查看>>
实验一 路由配置(cisco packet tracer)
查看>>
装机流程
查看>>
练习题7
查看>>
简单的nginx启动脚本
查看>>
我的友情链接
查看>>
React Native集成到Android项目当中
查看>>
cd ls
查看>>
linux学习命令总结⑩①
查看>>
【好程序员笔记分享】C语言之交换变量的值
查看>>
linux 安装和初级优化
查看>>