# extra routines for using regression trees
# requires: tree package, clus1.r
#   (ncut, data.tree, break.quantile, sort.levels, scatter, sstree are not
#    defined in this part of the file -- presumably they come from clus1.r;
#    TODO confirm)
# Tom Minka 10/26/01
library(tree)

# plot tree partitions along with residuals colored by magnitude
#
# Scatter-plots the data on the tree's two split variables and overlays the
# partition.  Residuals are binned at -4, -2, 0, 2, 4 residual standard
# deviations and coloured red -> white (negative) and white -> blue
# (positive).
plot.residuals <- function(tr) {
  res <- residuals(tr)
  # bin edges, in units of the residual standard deviation
  b <- c(-Inf, -4, -2, 0, 2, 4, Inf)
  # 6 colours, one per bin: hue 0 (red) fading to white, then white to
  # hue 4/6 (blue)
  color <- c(hsv(0,seq(1,0,length=3)), hsv(4/6,seq(0,1,length=3)))
  # bin of each residual; assumes ncut (clus1.r) returns 1..6 with the
  # lowest bin first -- TODO confirm
  q <- ncut(res,b*sd(res))
  # find the two variables
  frame <- tr$frame
  # NOTE(review): R's tree package labels leaf rows "<leaf>" in frame$var.
  # If that convention applies, this test never matches and v would include
  # the leaf marker; the "" may be a damaged "<leaf>" -- confirm.
  leaves <- frame$var == ""
  v <- unique(as.character(frame$var[!leaves]))
  x <- data.tree(tr)
  plot(x[[v[1]]],x[[v[2]]],col=color[q],xlab=v[1],ylab=v[2])
  partition.tree(tr,add=T)
}

# Name of the response variable of a fitted tree, as a string, read from the
# "variables" and "response" attributes of tr$terms.
response.var.tree <- function(tr) {
  a <- attributes(tr$terms)
  # a$variables is the call list(<response>, <predictors>...); element 1 is
  # the function `list` itself, so the response sits at 1 + a$response
  as.character(a$variables[1+a$response])
}

# generic: name of the response variable of a fitted model
response.var <- function(object, ...) UseMethod("response.var")

# plot tree partitions along with data colored by response
#
# tr:     fitted tree with one or two split variables
# x:      data frame holding the split variables (default model.frame(tr))
# label:  if TRUE, points are labelled with their row names; the value is
#         used both as the if() condition and as the index into x1/x2
# jitter: if TRUE, each point gets a random horizontal offset of up to
#         +/-0.5% of the range of the first split variable
# A numeric response is coloured by bin (break.quantile(y,4,...) + ncut,
# presumably quartiles -- confirm in clus1.r); a factor by its level code.
# With one split variable, that variable is plotted against the response.
cplot.tree <- function(tr,x=NULL,label=F,jitter=T) {
  y <- tr$y
  if(is.numeric(y)) {
    b <- break.quantile(y,4,plot=F)
    q <- ncut(y,b)
  } else {
    q <- as.numeric(y)
  }
  # find the two variables
  frame <- tr$frame
  # NOTE(review): R's tree package labels leaf rows "<leaf>", not "";
  # the "" may be a damaged "<leaf>" -- confirm.
  leaves <- frame$var == ""
  v <- unique(as.character(frame$var[!leaves]))
  if(is.null(x)) x <- model.frame(tr)
  # `jitter` is reused: afterwards it is either FALSE (which adds 0 below)
  # or a vector of per-point offsets
  if(jitter) {
    xv <- x[[v[1]]]
    jitter <- (runif(length(xv))-0.5)*diff(range(xv))/100
  }
  if(length(v) == 2) {
    # NOTE(review): this graphics setting is not restored on exit
    par(mai=c(1,1,0,0)*3/4)
    x1 <- x[[v[1]]]
    x2 <- x[[v[2]]]
    plot(x1+jitter,x2,col=q,xlab=v[1],ylab=v[2])
    # NOTE(review): if() errors on a condition of length > 1 in R >= 4.2,
    # so only a single TRUE/FALSE (or one index) works for `label`
    if(label) {
      text(x1[label],x2[label],rownames(x)[label],cex=0.6,adj=0)
    }
    partition.tree(tr,add=T)
  } else {
    # one split variable: plot it against the response
    yvar <- response.var(tr)
    plot(x[[v]]+jitter,y,col=q,xlab=v,ylab=yvar)
    partition.tree(tr,add=T)
  }
}

# generic: colour-coded scatter plot of a fitted model's data
cplot <- function(object, ...)
UseMethod("cplot")

# returns the parent index of every node
# Relies on node names numbering node k's children 2k and 2k+1: integer
# division by 2 gives the parent's name, matched back to a row index
# (the root, name 1, has parent name 0 and so gets NA).
parents.tree <- function(tr) {
  i <- as.numeric(rownames(tr$frame))
  match(i %/% 2, i)
}

# returns a vector of node indices under node (the first pass of the loop
# collects node itself)
# the input is a node name, but output is node number
# Works by repeatedly replacing every row index with its parent and
# collecting the rows whose current ancestor equals node.
descendants.tree <- function(tr,node) {
  p <- parents.tree(tr)
  node <- match(node,rownames(tr$frame))
  i <- 1:length(p)
  r <- c()
  repeat {
    r <- c(r,which(i==node))
    i <- p[i]
    # NOTE(review): text is missing on the next line.  The end of
    # descendants.tree (its loop-exit test and return value) and the start
    # of a separate function that draws hclust cut points on a boxplot
    # (it uses hc, n, x, r -- presumably boxplot.hclust.breaks, which
    # break.factor calls) appear to have been lost, likely a "<...>" span
    # stripped from the text.  The line does not parse as-is; restore it
    # from the original source.
    if(all(is.na(i) | (i length(hc$height)) q <- 1:(length(hc$height)+1) else q <- cutree(hc,n[i])
    # boundaries between consecutive cluster labels, at half-integer x
    b <- which(diff(q)!=0)+0.5
    b <- c(0,b,length(x)+1)
    # later entries of n get shorter segments
    scale <- (length(n)-i+1)/length(n)
    segments(b,rep(r[1],length(b)), b,rep(r[2]*scale,length(b)), col="blue")
  }
  par(lwd=1)
}

# breaks a factor into nbins bins in order to preserve the prediction of x
#
# Levels of f are sorted by the mean of x (sort.levels) and merged by sstree
# (clus1.r) into nbins groups; merged level names are joined with ".".
# same.var: if TRUE, print the within-bin sum of squares; otherwise add
#           0.1 to the per-level scatter before clustering and print the
#           sum of n*log(scatter/n) over bins
# trace:    if TRUE, split the device in two, plot the drops between
#           successive merge heights ("merging cost") on the left and put
#           the boxplot on the right; screens are closed on exit
# xlab, ylab: axis labels, defaulting to the expressions passed as f and x
# Returns the level names of the binned factor.
break.factor <- function(f,x,nbins,same.var=T,trace=T,xlab=NA,ylab=NA) {
  if(is.na(xlab)) xlab <- deparse(substitute(f))
  if(is.na(ylab)) ylab <- deparse(substitute(x))
  f <- sort.levels(f,x,fun=mean)
  # drop unused categories
  f <- factor(f)
  # keep the unmerged factor for the boxplot
  sf <- f
  n <- tapply(x,f,length)
  m <- tapply(x,f,mean)
  s <- tapply(x,f,scatter)
  if(!same.var) s <- s + 0.1
  h <- sstree(m,n,s,same.var=same.var)
  # asking for more bins than merges available keeps every level separate
  if(nbins > length(h$height)) q <- 1:(length(h$height)+1) else q <- cutree(h,nbins)
  nam <- tapply(levels(f),q,function(x) paste(x,collapse="."))
  levels(f) <- nam[q]
  if(same.var) {
    ss <- sum(tapply(x,f,scatter))
    cat("sum of squares =", format(ss), "\n")
  } else {
    n <- tapply(x,f,length)
    ss <- sum(n*log(tapply(x,f,scatter)/n))
    cat("sum of log-variance =", format(ss), "\n")
  }
  if(trace) {
    split.screen(c(1,2))
    on.exit(close.screen(all=TRUE))
    # merge heights, largest first; keep at most 9
    g <- rev(h$height)
    if(length(g) > 9) { g <- g[1:9] } else { g <- c(g,0) }
    if(length(g) > 1) g <- -diff(g)
    plot(g,ylab="merging cost",xlab="bins")
    screen(2)
  }
  boxplot(x ~ sf, xlab=xlab, ylab=ylab)
  boxplot.hclust.breaks(h,x)
  levels(f)
}

#############################################################################
# bug fixes

# partition.tree must require continuous predictors because split sets on a
# factor can vary
# Is each entry of a tree frame's `var` column a leaf marker?
# R's tree package labels leaves "<leaf>"; "" is also accepted so frames
# that use an empty marker keep working.
.is.leaf.var <- function(v) as.character(v) %in% c("<leaf>", "")

# Draw the partition induced by a tree with one or two continuous predictors.
#
#   tree:    fitted "tree" object
#   m:       (minka) data frame used for the axis ranges; defaults to
#            model.frame(tree)
#   label:   "yval" to label leaves with fitted values, or one of the
#            response levels to label them with that class probability
#   add:     if TRUE, draw on the current plot instead of starting a new one
#   ordvars: optional ordering of the predictor names (first = x axis)
#   ...:     passed to plot() (when add = FALSE) and to text()
# With a single predictor, the response is used as the vertical axis.
# Returns, invisibly, list(xcoord, ycoord): a 4-row matrix whose columns are
# the (x0, y0, x1, y1) end points of each split segment, and a 2-row matrix
# whose columns are the label position of each leaf.
partition.tree <- function(tree, m = NULL, label = "yval", add = FALSE,
                           ordvars, ...)
{
  # Walk the frame rows in order, recursing into left then right subtrees,
  # collecting split segments and leaf centres.
  # xrange is c(x1lo, x2lo, x1hi, x2hi) of the current cell.
  ptXlines <- function(x, v, xrange, xcoord = NULL, ycoord = NULL, tvar,
                       i = 1)
  {
    if(.is.leaf.var(v[i])) {
      y1 <- (xrange[1] + xrange[3])/2
      y2 <- (xrange[2] + xrange[4])/2
      # R's return() takes a single value; bundle the results in a list
      return(list(xcoord = xcoord, ycoord = c(ycoord, y1, y2), i = i))
    }
    if(v[i] == tvar[1]) {
      # vertical split line at x[i]
      xcoord <- c(xcoord, x[i], xrange[2], x[i], xrange[4])
      xr <- xrange
      xr[3] <- x[i]
      ll2 <- Recall(x, v, xr, xcoord, ycoord, tvar, i + 1)
      xr <- xrange
      xr[1] <- x[i]
      return(Recall(x, v, xr, ll2$xcoord, ll2$ycoord, tvar, ll2$i + 1))
    } else if(v[i] == tvar[2]) {
      # horizontal split line at x[i]
      xcoord <- c(xcoord, xrange[1], x[i], xrange[3], x[i])
      xr <- xrange
      xr[4] <- x[i]
      ll2 <- Recall(x, v, xr, xcoord, ycoord, tvar, i + 1)
      xr <- xrange
      xr[2] <- x[i]
      return(Recall(x, v, xr, ll2$xcoord, ll2$ycoord, tvar, ll2$i + 1))
    } else stop("Wrong variable numbers in tree.")
  }
  if(inherits(tree, "singlenode")) stop("Cannot plot singlenode tree")
  if(!inherits(tree, "tree")) stop("Not legitimate tree")
  frame <- tree$frame
  leaves <- .is.leaf.var(frame$var)
  var <- unique(as.character(frame$var[!leaves]))
  if(length(var) > 2 || length(var) < 1)
    stop("Tree can only have one or two predictors")
  # factor predictors have non-empty xlevels; their split sets can vary
  nlev <- vapply(attr(tree, "xlevels"), length, integer(1))
  if(any(nlev[var] > 0, na.rm = TRUE))
    stop("Tree can only have continuous predictors")
  # split points: "cutleft" strings look like "<value"
  x <- rep(NA, length(leaves))
  x[!leaves] <- as.double(substring(frame$splits[!leaves, "cutleft"], 2, 100))
  # minka: let caller specify data if desired
  if(is.null(m)) m <- model.frame(tree)
  if(!missing(ordvars)) {
    ind <- match(var, ordvars)
    if(any(is.na(ind))) stop("unmatched names in vars")
    var <- ordvars[sort(ind)]
  }
  # minka: treat 1 var as if var 2 was yvar
  if(length(var) == 1) var[2] <- response.var(tree)
  lab <- frame$yval[leaves]
  # minka: shorten text
  if(is.null(frame$yprob)) {
    lab <- formatC(signif(lab, 3))
  } else if(match(label, attr(tree, "ylevels"), nomatch = 0)) {
    lab <- formatC(signif(frame$yprob[leaves, label], 3))
  }
  rx <- range(m[[var[1]]])
  rx <- rx + c(-0.025, 0.025) * diff(rx)
  rz <- range(as.numeric(m[[var[2]]]))
  rz <- rz + c(-0.025, 0.025) * diff(rz)
  xrange <- c(rx, rz)[c(1, 3, 2, 4)]  # x1lo, x2lo, x1hi, x2hi
  xy <- ptXlines(x, frame$var, xrange, NULL, NULL, var)
  xx <- matrix(xy$xcoord, nrow = 4)   # x0, y0, x1, y1 per split
  yy <- matrix(xy$ycoord, nrow = 2)   # label x, y per leaf
  if(!add)
    plot(rx, rz, xlab = var[1], ylab = var[2], type = "n",
         xaxs = "i", yaxs = "i", ...)
  segments(xx[1, ], xx[2, ], xx[3, ], xx[4, ])
  text(yy[1, ], yy[2, ], as.character(lab), ...)
  invisible(list(xcoord = xx, ycoord = yy))
}

# Patched copy of tree's predict.tree.
#   newdata: data to predict for; if missing or NULL the training data are
#            used (and type = "tree" simply returns object)
#   type:    "vector" (fitted values, or class-probability matrix),
#            "class", "where" (leaf row of each case) or "tree" (object
#            with frame$n and frame$dev recomputed on newdata)
#   split:   if TRUE, predictions combine leaf values using the weights
#            from the compiled VR_pred2 routine
#   nwts:    case weights used when type = "tree"
#   eps:     replacement for zero class probabilities in the deviance
predict.tree <- function(object, newdata = list(),
                         type = c("vector", "tree", "class", "where"),
                         split = FALSE, nwts, eps = 1e-3)
{
  # (node x case) matrix of leaf-membership weights from VR_pred2.
  # NOTE(review): stops for trees without yprob, so split = TRUE currently
  # works only for classification trees.
  pred2.tree <- function(tree, x)
  {
    frame <- tree$frame
    if(!length(frame$yprob)) stop("only for classification trees")
    dimx <- dim(x)
    ypred <- .C("VR_pred2",
                as.double(x),
                as.integer(unclass(frame$var) - 1), # 0 denotes leaf node
                as.character(frame$splits[, "cutleft"]),
                as.character(frame$splits[, "cutright"]),
                as.integer(sapply(attr(tree, "xlevels"), length)),
                as.integer(row.names(frame)),
                as.integer(frame$n),
                as.integer(nf <- dim(frame)[1]),
                as.integer(dimx[1]),
                where = double(nf*dimx[1]),
                NAOK = TRUE)
    ypred <- matrix(ypred$where, nf)
    dimnames(ypred) <- list(row.names(frame), dimnames(x)[[1]])
    ypred
  }
  if(!inherits(object, "tree") && !inherits(object, "singlenode"))
    stop("Not legitimate tree")
  type <- match.arg(type)
  if(type == "class" && is.null(attr(object, "ylevels")))
    stop("type class only for classification trees")
  # minka: bugfix
  if((missing(newdata) || is.null(newdata)) && type == "tree")
    return(object) #idiot proofing
  # can the response be taken from newdata?  (only used for type = "tree")
  response.exists <- TRUE
  if(missing(newdata) || is.null(newdata)) {
    where <- object$where
    newdata <- model.frame.tree(object)
    if(!is.null(w <- object$call$weights))
      nwts <- model.extract(model.frame.tree(object), "weights")
  } else {
    if(is.null(attr(newdata, "terms"))) {
      # newdata is not a model frame.
      Terms <- object$terms
      if(type == "tree") {
        # test if response can be extracted from newdata
        response.vars <- all.vars(formula(Terms)[[2]])
        response.exists <- response.vars %in% names(newdata)
        if(!all(response.exists)) Terms <- delete.response(Terms)
      } else Terms <- delete.response(Terms)
      newdata <- model.frame(Terms, newdata, na.action = na.pass)
    }
    where <- pred1.tree(object, tree.matrix(newdata))
  }
  if(type == "where") return(where)
  frame <- object$frame
  node <- row.names(frame)
  nodes <- as.numeric(node)
  nnode <- length(node)
  if(type != "tree") {
    lev <- attr(object, "ylevels")
    if(is.null(lev)) {
      # regression tree
      if(!split) {
        pred <- frame$yval[where]
        names(pred) <- names(where)
        return(pred)
      }
      memb <- pred2.tree(object, tree.matrix(newdata))
      leaf <- .is.leaf.var(frame$var)
      pred <- as.vector(t(memb[leaf, , drop = FALSE]) %*% frame$yval[leaf])
      # memb is a matrix: case names are its column names
      names(pred) <- colnames(memb)
      return(pred)
    }
    # classification tree
    if(!split) {
      pr <- frame$yprob[where, , drop = FALSE]
      dimnames(pr)[[1]] <- names(where)
    } else {
      memb <- pred2.tree(object, tree.matrix(newdata))
      leaf <- .is.leaf.var(frame$var)
      pr <- t(memb[leaf, , drop = FALSE]) %*% frame$yprob[leaf, , drop = FALSE]
      dimnames(pr) <- list(colnames(memb), lev)
    }
    if(type == "class") {
      # minka: pick first class in order when there is a tie
      cl <- apply(pr, 1, which.max)
      return(factor(lev[cl], levels = lev))
    }
    return(pr)
  }
  # now must be type = "tree"
  desc <- descendants(as.numeric(row.names(frame)))[, where, drop = FALSE]
  y <- if(all(response.exists)) model.extract(newdata, "response")
  if(is.null(y)) {
    dev <- rep(NA, nrow(frame))
  } else {
    if(missing(nwts)) nwts <- rep(1, length(y))
    # cases with a missing response get zero weight and a placeholder value
    drp <- is.na(y)
    nwts[drp] <- 0
    if(!length(attr(object, "ylevels"))) {
      y[drp] <- 0
      dev <- .C("VR_dev3",
                as.integer(nnode),
                as.integer(nodes),
                integer(nnode),
                dev = double(nnode),
                sdev = double(nnode),
                as.double(y),
                as.integer(length(y)),
                as.double(frame$yval),
                as.integer(where),
                as.double(nwts))$dev
    } else {
      yp <- frame$yprob
      yp[yp == 0] <- max(0, eps)
      y[drp] <- levels(y)[1]
      dev <- -2 * .C("VR_dev2",
                     as.integer(nnode),
                     as.integer(nodes),
                     integer(nnode),
                     dev = double(nnode),
                     sdev = double(nnode),
                     as.integer(y),
                     as.integer(length(y)),
                     as.double(yp),
                     as.integer(where),
                     as.double(nwts))$dev
    }
    # any node containing a case with a missing response has unknown deviance
    dev[desc %*% drp > 0] <- NA
  }
  object$frame$dev <- as.vector(dev)
  object$frame$n <- as.vector(desc %*% rep(1, length(where)))
  object$where <- where
  object$call <- match.call()
  object$y <- object$x <- NULL
  object
}

# minka: added newdata option
# Number (weighted, without newdata) of misclassified cases.  With newdata,
# returns the count of cases whose predicted class differs from the
# response column of newdata (detail is ignored).  With detail = TRUE, the
# per-node weighted error counts are returned.
misclass.tree <- function(tree, newdata, detail = FALSE)
{
  if(!inherits(tree, "tree")) stop("Not legitimate tree")
  if(is.null(attr(tree, "ylevels")))
    stop("Misclassification error rate is appropriate for factor responses only")
  if(!missing(newdata)) {
    yvar <- response.var(tree)
    pred <- predict.tree(tree, newdata, type = "class")
    # compare labels: comparing two factors stops when their level sets
    # differ (e.g. a subset of the data with dropped levels)
    return(sum(as.character(pred) != as.character(newdata[[yvar]])))
  }
  if(is.null(y <- tree$y))
    y <- model.extract(model.frame(tree), "response")
  if(is.null(wts <- tree$weights))
    wts <- model.weights(model.frame(tree))
  if(is.null(wts)) wts <- rep(1, length(y))
  frame <- tree$frame
  if(detail) {
    desc <- descendants(as.numeric(row.names(frame)))
    tmp <- as.vector((desc[, tree$where] * outer(frame$yval, y, "!=")) %*% wts)
    names(tmp) <- row.names(tree$frame)
    tmp
  } else sum(wts*(frame$yval[tree$where] != y))
}

# minka: added newdata option
# Without newdata: total deviance of the leaves (or every node's deviance
# when detail = TRUE).  With newdata: for a regression tree, the residual
# sum of squares on newdata; for a classification tree, -2 times the
# log-probability of the observed classes (zero probabilities -> 1e-3).
deviance.tree <- function(object, newdata, detail = FALSE, ...)
{
  if(!inherits(object, "tree")) stop("Not legitimate tree")
  if(!missing(newdata)) {
    yvar <- response.var(object)
    p <- predict.tree(object, newdata)
    lev <- attr(object, "ylevels")
    if(is.null(lev)) {
      # regression tree: predict.tree returned a vector of fitted values
      return(sum((newdata[[yvar]] - p)^2))
    }
    # match by label so the column index is right whatever the level
    # order/coding of newdata's response
    truth <- match(as.character(newdata[[yvar]]), lev)
    # correction used in prune.tree
    p[p == 0] <- 1e-3
    return(-2*sum(log(p[cbind(seq_len(nrow(p)), truth)])))
  }
  frame <- object$frame
  if(detail) frame$dev else sum(frame$dev[.is.leaf.var(frame$var)])
}

# K-fold cross-validation of a tree-pruning function (default prune.tree).
# rand assigns each case to a fold (random if missing).  Returns the result
# of FUN on the full tree with $dev replaced by the deviance summed over
# held-out folds.
cv.tree <- function(object, rand, FUN = prune.tree, K = 10, ...)
{
  if(!inherits(object, "tree")) stop("Not legitimate tree")
  m <- model.frame(object)
  extras <- match.call(expand.dots = FALSE)$...
  FUN <- deparse(substitute(FUN))
  init <- do.call(FUN, c(list(object), extras))
  if(missing(rand)) rand <- sample(K, length(m[[1]]), replace = TRUE)
  cvdev <- 0
  for(i in unique(rand)) {
    # minka: must build a tree as big as original
    tlearn <- tree(model = m[rand != i, , drop = FALSE], mindev = 0)
    plearn <- do.call(FUN, c(list(tlearn,
                                  newdata = m[rand == i, , drop = FALSE],
                                  k = init$k),
                             extras))
    cvdev <- cvdev + plearn$dev
  }
  init$dev <- cvdev
  init
}