목표변수가 집단을 의미하는 범주형 의사결정나무 -> 분류나무모형
목표변수가 연속형 변수인 의사결정나무 -> 회귀나무모형
예제) 목표변수가 범주형 good, bad인 독일 신용평가 데이터를 이용하여 cart 방법을 이용한 의사결정나무 구축
1) 데이터 불러오기
> setwd('c:/Rwork')
> german<-read.table('germandata.txt',header=T)
> str(german)
'data.frame': 1000 obs. of 21 variables:
$ check : Factor w/ 4 levels "A11","A12","A13",..: 1 2 4 1 1 4 4 2 4 2 ...
$ duration : int 6 48 12 42 24 36 24 36 12 30 ...
$ history : Factor w/ 5 levels "A30","A31","A32",..: 5 3 5 3 4 3 3 3 3 5 ...
$ purpose : Factor w/ 10 levels "A40","A41","A410",..: 5 5 8 4 1 8 4 2 5 1 ...
$ credit : int 1169 5951 2096 7882 4870 9055 2835 6948 3059 5234 ...
$ savings : Factor w/ 5 levels "A61","A62","A63",..: 5 1 1 1 1 5 3 1 4 1 ...
$ employment : Factor w/ 5 levels "A71","A72","A73",..: 5 3 4 4 3 3 5 3 4 1 ...
$ installment: int 4 2 2 2 3 2 3 2 2 4 ...
$ personal : Factor w/ 4 levels "A91","A92","A93",..: 3 2 3 3 3 3 3 3 1 4 ...
$ debtors : Factor w/ 3 levels "A101","A102",..: 1 1 1 3 1 1 1 1 1 1 ...
$ residence : int 4 2 3 4 4 4 4 2 4 2 ...
$ property : Factor w/ 4 levels "A121","A122",..: 1 1 1 2 4 4 2 3 1 3 ...
$ age : int 67 22 49 45 53 35 53 35 61 28 ...
$ others : Factor w/ 3 levels "A141","A142",..: 3 3 3 3 3 3 3 3 3 3 ...
$ housing : Factor w/ 3 levels "A151","A152",..: 2 2 2 3 3 3 2 1 2 2 ...
$ numcredits : int 2 1 1 1 2 1 1 1 1 2 ...
$ job : Factor w/ 4 levels "A171","A172",..: 3 3 2 3 3 2 3 4 2 4 ...
$ residpeople: int 1 1 2 2 2 2 1 1 1 1 ...
$ telephone : Factor w/ 2 levels "A191","A192": 2 1 1 1 1 2 1 2 1 1 ...
$ foreign : Factor w/ 2 levels "A201","A202": 1 1 1 1 1 1 1 1 1 1 ...
$ y : Factor w/ 2 levels "bad","good": 2 1 2 2 1 2 2 2 2 1 ...
> german$numcredits<-factor(german$numcredits)
> german$residence<-factor(german$residence)
> german$residpeople<-factor(german$residpeople)
> class(german$numcredits);class(german$residence);class(german$residpeople)
[1] "factor"
[1] "factor"
[1] "factor"
2) cart 방법 적용
> library(rpart)
> my.control<-rpart.control(xval=10,cp=0,minsplit=5)
> fit.german<-rpart(y~.,data=german,method='class',control=my.control)
> fit.german #최초의 나무. 가지치기를 하지 않은 최대 크기의 나무 보기
n= 1000
node), split, n, loss, yval, (yprob)
* denotes terminal node
1) root 1000 300 good (0.300000000 0.700000000)
2) check=A11,A12 543 240 good (0.441988950 0.558011050)
.
. (너무 커서 중략)
.
253) credit>=1273 10 0 good (0.000000000 1.000000000) *
127) check=A14 122 1 good (0.008196721 0.991803279) *
함수 설명
1. rpart.control:
- xval=10: 교타 타당성의 fold 개수, 디폴트는 10
- cp=0: 오분류율이 cp값 이상으로 향상되지 않으면 더 이상 분할하지 않고 나무구조 생성을 멈춘다. cp값이 0이면 오분류값이 최소, 디폴트는 0.01
- minsplit=5: 한 노드를 분할하기 위해 필요한 데이터의 개수. 이 값보다 적은 수의 관측치가 있는 노드는 분할하지 않는다. 디폴트는 20
2. r.part
- method=class: 나무 모형을 지정한다. anova는 회귀나무, poisson 포아송 회귀나무, class는 분류나무 exp는 생존나무. 디폴트는 class
- na.action=na.rpart: 목표변수가 결측치이면 전체 관측치를 삭제. 입력변수가 결측치인 경우에는 삭제하지 않는다.
결과 해석
중간노드를 분할하는 최소 자료의 수를 5개로 지정하였고, cp값은 0으로 하여 나무모형의 오분류값이 최소가 될 때 까지 분할을 진행하였다. 또한 10-fold 교차타당성을 수행하여 최적의 cp값을 찾도록 하였다. 나무가 너무나 큰 관계로 중간 부분을 생략하였고, 용이한 모형 분석을 위해 가지치기를 해보자.
3) 큰 나무를 줄이기 위한 가지치기 작업
> printcp(fit.german)
Classification tree:
rpart(formula = y ~ ., data = german, method = "class", control = my.control)
Variables actually used in tree construction:
[1] age check credit debtors duration employment history
[8] housing installment job numcredits others personal property
[15] purpose residence savings
Root node error: 300/1000 = 0.3
n= 1000
CP nsplit rel error xerror xstd
1 0.0516667 0 1.00000 1.00000 0.048305
2 0.0466667 3 0.84000 0.94667 0.047533
3 0.0183333 4 0.79333 0.86333 0.046178
4 0.0166667 6 0.75667 0.87000 0.046294
5 0.0155556 8 0.72333 0.88667 0.046577
6 0.0116667 11 0.67667 0.88000 0.046464
7 0.0100000 13 0.65333 0.85667 0.046062
8 0.0083333 16 0.62333 0.87000 0.046294
9 0.0066667 18 0.60667 0.87333 0.046351
10 0.0060000 38 0.44333 0.92000 0.047120
11 0.0050000 43 0.41333 0.91000 0.046960
12 0.0044444 55 0.35333 0.92000 0.047120
13 0.0033333 59 0.33333 0.92000 0.047120
14 0.0029167 83 0.25000 0.97000 0.047879
15 0.0022222 93 0.22000 0.97667 0.047976
16 0.0016667 96 0.21333 0.97667 0.047976
17 0.0000000 104 0.20000 1.01333 0.048486
결과 해석
10-fold 교차타당성 방법에 의한 오분율(xerror)이 최소가 되는 값은 0.85667이며 이때의 cp값은 0.01임을 알 수 있다. 이 때 분리의 횟수가 13회(nsplit=13)인 나무를 의미한다.
또는 아래와 같은 방법으로도 최소 오분류값(xerror)를 찾을 수 있다.
> names(fit.german)
[1] "frame" "where" "call" "terms"
[5] "cptable" "method" "parms" "control"
[9] "functions" "numresp" "splits" "csplit"
[13] "variable.importance" "y" "ordered"
> fit.german$cptable[,'xerror']
1 2 3 4 5 6 7 8 9
1.0000000 0.9466667 0.8633333 0.8700000 0.8866667 0.8800000 0.8566667 0.8700000 0.8733333
10 11 12 13 14 15 16 17
0.9200000 0.9100000 0.9200000 0.9200000 0.9700000 0.9766667 0.9766667 1.0133333
> which.min(fit.german$cptable[,'xerror'])
7
7
> fit.german$cptable[7,]
CP nsplit rel error xerror xstd
0.01000000 13.00000000 0.65333333 0.85666667 0.04606167
> fit.german$cptable[7]
[1] 0.01
> fit.german$cptable[which.min(fit.german$cptable[,'xerror'])]
[1] 0.01
> min.cp<-fit.german$cptable[which.min(fit.german$cptable[,'xerror'])]
> min.cp
[1] 0.01
> fit.prune.german<-prune(fit.german,cp=min.cp)
4) 오분류율이 최소인 cp값(=0.011)을 찾았으니 이 값을 기준으로 가지치기를 시행하자.
> fit.prune.german<-prune(fit.german,cp=0.01)
> fit.prune.german
결과 해석
node), split, n, loss, yval, (yprob) 기준으로 첫번째 결과를 분석하면 다음과 같다.
노드, 분할점, 개수, …공부 필요
16) duration>=47.5 36 5 bad (0.8611111 0.1388889) *
duration 변수 중 47.5보다 큰 경우, 전체 36(n)개를 bad(yval)로 분류하였고 그 중 5개(loss)가 good이다. 그리하여 bad로 분류되는 것은 31/36 = 0.8611111로 표기하게 되고, 5개의 loss는 5/36 = 1388889 로 그 확률을 볼 수 있다. 아래 plot에서는 bad 31/5로 표기
376) property=A123,A124 20 3 bad (0.8500000 0.1500000) *
377) property=A121,A122 45 14 good (0.3111111 0.6888889) *
property가 a123(car), a124(unknown / no property)의 경우 전체 20개를 bad로 분류하였고 3개의 loss 즉 good (3/20 = 0.15)로 분류하였다. 아래 plot에서는 bad 17/3로 표기
property가 a121(real estate), a122(building society savings agreement)인 경우에는 전체 45개를 good으로 분류하였고 14개의 loss 즉 bad로 분류 (14/45=0.3111111), 아래 plot에서는 good 14/31로 표기
<< 17.6.18(일)>> 해석 부분 내용 추가
duration > = 22.5인 경우, 전체 고객은 237명이고, 이 중 신용도가 나쁜 사람의 비율은 56.5%이고 좋은 사람의 비율은43.5%로 103명이다. 따라서 duration > 22.5 그룹은 bad로 분류된다.
가지치기를 한 모형을 그림으로 나타내는 함수는 아래와 같다.
> plot(fit.prune.german,uniform = T,compress=T,margin=0.1)
> text(fit.prune.german,use.n=T,col='blue',cex=0.7)
왼쪽 가지의 가장 아랫부분의 분할점인 ‘purpose=acdeghj’는 purpose 변수의 범주값 중에서 알파벳 순서로, 1(=a), 3(=c), 4(=d), 5(=e), 7(=g), 8(=h), 10(=j)번째 범주값을 의미하며, fit.prune.german에서 각각 A40,A410,A42,A43,A45,A46,A49 임을 알 수 있다.
34) purpose=A40,A410,A42,A43,A45,A46,A49 137 52 bad (0.6204380 0.3795620) *
<< 17.6.18(일)>> 해석 부분 내용 추가
가장 우측의 duration > = 11.5가 아닌 경우, 신용다가 나쁜 / 좋은 사람의 비율은 9명 / . 4명이고, 신용도가 좋은 good으로 분류된다.
5) 나무수를 더 줄여보자.
> printcp(fit.german)
Classification tree:
rpart(formula = y ~ ., data = german, method = "class", control = my.control)
Variables actually used in tree construction:
[1] age check credit debtors duration employment history
[8] housing installment job numcredits others personal property
[15] purpose residence savings
Root node error: 300/1000 = 0.3
n= 1000
CP nsplit rel error xerror xstd
1 0.0516667 0 1.00000 1.00000 0.048305
2 0.0466667 3 0.84000 0.94667 0.047533
3 0.0183333 4 0.79333 0.86333 0.046178
4 0.0166667 6 0.75667 0.87000 0.046294
5 0.0155556 8 0.72333 0.88667 0.046577
6 0.0116667 11 0.67667 0.88000 0.046464
7 0.0100000 13 0.65333 0.85667 0.046062
8 0.0083333 16 0.62333 0.87000 0.046294
9 0.0066667 18 0.60667 0.87333 0.046351
10 0.0060000 38 0.44333 0.92000 0.047120
11 0.0050000 43 0.41333 0.91000 0.046960
12 0.0044444 55 0.35333 0.92000 0.047120
13 0.0033333 59 0.33333 0.92000 0.047120
14 0.0029167 83 0.25000 0.97000 0.047879
15 0.0022222 93 0.22000 0.97667 0.047976
16 0.0016667 96 0.21333 0.97667 0.047976
17 0.0000000 104 0.20000 1.01333 0.048486
5번째 단계이며 분리의 횟수가 8회(nsplit=8)인 나무는 교차타당성 오분류율이 0.88667로 최소는 아니지만 7번째 단계의 분리의 횟수 13회 나무 가지의 최소 오분류율 0.85667과는 크게 차이가 나지 않는다. 그리고 최소 오분류율 표준편차의 1배 범위(0.88667 < 0.85667 + 0.046062)에 있다. 이런 경우에는 5번째 단계이며 분리의 횟수가 8인 나무를 선택하는 경우도 있다.
5번째 단계이며 분리 횟수가 8인 cp값 0.0155556의 반올림 값 0.016 적용하여 다시 가지치기
> fit.prune.german<-prune(fit.german,cp=0.016)
> fit.prune.german
> plot(fit.prune.german,uniform=T,compress=T,margin=0.1)
> text(fit.prune.german,use.n=T,col='blue',cex=0.7)
6) 목표변수의 분류예측치를 구하고 그 정확도에 대해서 평가해 보자
> fit.prune.german<-prune(fit.german,cp=0.01)
> pred.german=predict(fit.prune.german,newdata=german,type='class')
> tab=table(german$y,pred.german,dnn=c('Actual','Predicted'))
> tab
Predicted
Actual bad good
bad 180 120
good 76 624
함수 설명
predict(fit.prune.german,newdata=german,type='class'), type = class는 분류나무의 집단값 예측결과, 회귀나무라면 type = vector라고 해야 한다.
결과 해석
실제 good인데 good으로 예측한 것이 624개, 실제 bad인데 bad로 예측한 것이 180
따라서 오분류율은 {1000 – (624+180)} / 1000 = 19.6%
R코드를 이용하면 1-sum(diag(tab)) / sum(tab)
7) 마지막으로 독일신용평가데이터를 훈련데이터와 검증 데이터로 분할하여 분류나무를 평가해보자.
> set.seed(1234)
> i=sample(1:nrow(german),round(nrow(german)*0.7)) #70% for training훈련 data, 30% for test검증
> german.train=german[i,]
> german.test=german[-i,]
> fit.german<-rpart(y~.,data=german.train,method='class',control=my.control)
> printcp(fit.german)
Classification tree:
rpart(formula = y ~ ., data = german.train, method = "class",
control = my.control)
Variables actually used in tree construction:
[1] age check credit debtors duration employment history
[8] housing installment job numcredits others personal property
[15] purpose residence savings telephone
Root node error: 201/700 = 0.28714
n= 700
CP nsplit rel error xerror xstd
1 0.05721393 0 1.00000 1.00000 0.059553
2 0.03482587 2 0.88557 1.00498 0.059641
3 0.02985075 5 0.78109 1.00000 0.059553
4 0.01990050 6 0.75124 0.95025 0.058631
5 0.01741294 8 0.71144 0.96020 0.058822
6 0.01492537 10 0.67662 1.00000 0.059553
7 0.01243781 14 0.61692 1.00000 0.059553
8 0.00995025 17 0.57711 1.00995 0.059728
9 0.00746269 35 0.39303 1.03980 0.060238
10 0.00621891 46 0.30846 1.06965 0.060722
11 0.00497512 50 0.28358 1.04975 0.060402
12 0.00331675 58 0.24378 1.09950 0.061181
13 0.00248756 61 0.23383 1.11940 0.061474
14 0.00124378 69 0.21393 1.14925 0.061894
15 0.00099502 73 0.20896 1.14925 0.061894
16 0.00000000 78 0.20398 1.14925 0.061894
> fit.prune.german<-prune(fit.german,cp=0.02)
> fit.prune.german
> p.german.test=predict(fit.prune.german,newdata=german.test,type='class')
> tab=table(german.test$y,p.german.test,dnn=c('Actual','Predicted'))
> tab
Predicted
Actual bad good
bad 34 65
good 14 187
> 1-sum(diag(tab))/sum(tab) #오분류율
[1] 0.2633333
출처: 데이터마이닝(장영재, 김현중, 조형준 공저,knou press)