001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * ClassifierTree.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.trees.j48;
024:
025: import weka.core.Capabilities;
026: import weka.core.CapabilitiesHandler;
027: import weka.core.Drawable;
028: import weka.core.Instance;
029: import weka.core.Instances;
030: import weka.core.Utils;
031:
032: import java.io.Serializable;
033:
034: /**
035: * Class for handling a tree structure used for
036: * classification.
037: *
038: * @author Eibe Frank (eibe@cs.waikato.ac.nz)
039: * @version $Revision: 1.21 $
040: */
041: public class ClassifierTree implements Drawable, Serializable,
042: CapabilitiesHandler {
043:
044: /** for serialization */
045: static final long serialVersionUID = -8722249377542734193L;
046:
047: /** The model selection method. */
048: protected ModelSelection m_toSelectModel;
049:
050: /** Local model at node. */
051: protected ClassifierSplitModel m_localModel;
052:
053: /** References to sons. */
054: protected ClassifierTree[] m_sons;
055:
056: /** True if node is leaf. */
057: protected boolean m_isLeaf;
058:
059: /** True if node is empty. */
060: protected boolean m_isEmpty;
061:
062: /** The training instances. */
063: protected Instances m_train;
064:
065: /** The pruning instances. */
066: protected Distribution m_test;
067:
068: /** The id for the node. */
069: protected int m_id;
070:
071: /**
072: * For getting a unique ID when outputting the tree (hashcode isn't
073: * guaranteed unique)
074: */
075: private static long PRINTED_NODES = 0;
076:
077: /**
078: * Gets the next unique node ID.
079: *
080: * @return the next unique node ID.
081: */
082: protected static long nextID() {
083:
084: return PRINTED_NODES++;
085: }
086:
087: /**
088: * Resets the unique node ID counter (e.g.
089: * between repeated separate print types)
090: */
091: protected static void resetID() {
092:
093: PRINTED_NODES = 0;
094: }
095:
096: /**
097: * Constructor.
098: */
099: public ClassifierTree(ModelSelection toSelectLocModel) {
100:
101: m_toSelectModel = toSelectLocModel;
102: }
103:
104: /**
105: * Returns default capabilities of the classifier tree.
106: *
107: * @return the capabilities of this classifier tree
108: */
109: public Capabilities getCapabilities() {
110: return new Capabilities(this );
111: }
112:
113: /**
114: * Method for building a classifier tree.
115: *
116: * @param data the data to build the tree from
117: * @throws Exception if something goes wrong
118: */
119: public void buildClassifier(Instances data) throws Exception {
120:
121: // can classifier tree handle the data?
122: getCapabilities().testWithFail(data);
123:
124: // remove instances with missing class
125: data = new Instances(data);
126: data.deleteWithMissingClass();
127:
128: buildTree(data, false);
129: }
130:
131: /**
132: * Builds the tree structure.
133: *
134: * @param data the data for which the tree structure is to be
135: * generated.
136: * @param keepData is training data to be kept?
137: * @throws Exception if something goes wrong
138: */
139: public void buildTree(Instances data, boolean keepData)
140: throws Exception {
141:
142: Instances[] localInstances;
143:
144: if (keepData) {
145: m_train = data;
146: }
147: m_test = null;
148: m_isLeaf = false;
149: m_isEmpty = false;
150: m_sons = null;
151: m_localModel = m_toSelectModel.selectModel(data);
152: if (m_localModel.numSubsets() > 1) {
153: localInstances = m_localModel.split(data);
154: data = null;
155: m_sons = new ClassifierTree[m_localModel.numSubsets()];
156: for (int i = 0; i < m_sons.length; i++) {
157: m_sons[i] = getNewTree(localInstances[i]);
158: localInstances[i] = null;
159: }
160: } else {
161: m_isLeaf = true;
162: if (Utils.eq(data.sumOfWeights(), 0))
163: m_isEmpty = true;
164: data = null;
165: }
166: }
167:
168: /**
169: * Builds the tree structure with hold out set
170: *
171: * @param train the data for which the tree structure is to be
172: * generated.
173: * @param test the test data for potential pruning
174: * @param keepData is training Data to be kept?
175: * @throws Exception if something goes wrong
176: */
177: public void buildTree(Instances train, Instances test,
178: boolean keepData) throws Exception {
179:
180: Instances[] localTrain, localTest;
181: int i;
182:
183: if (keepData) {
184: m_train = train;
185: }
186: m_isLeaf = false;
187: m_isEmpty = false;
188: m_sons = null;
189: m_localModel = m_toSelectModel.selectModel(train, test);
190: m_test = new Distribution(test, m_localModel);
191: if (m_localModel.numSubsets() > 1) {
192: localTrain = m_localModel.split(train);
193: localTest = m_localModel.split(test);
194: train = test = null;
195: m_sons = new ClassifierTree[m_localModel.numSubsets()];
196: for (i = 0; i < m_sons.length; i++) {
197: m_sons[i] = getNewTree(localTrain[i], localTest[i]);
198: localTrain[i] = null;
199: localTest[i] = null;
200: }
201: } else {
202: m_isLeaf = true;
203: if (Utils.eq(train.sumOfWeights(), 0))
204: m_isEmpty = true;
205: train = test = null;
206: }
207: }
208:
209: /**
210: * Classifies an instance.
211: *
212: * @param instance the instance to classify
213: * @return the classification
214: * @throws Exception if something goes wrong
215: */
216: public double classifyInstance(Instance instance) throws Exception {
217:
218: double maxProb = -1;
219: double currentProb;
220: int maxIndex = 0;
221: int j;
222:
223: for (j = 0; j < instance.numClasses(); j++) {
224: currentProb = getProbs(j, instance, 1);
225: if (Utils.gr(currentProb, maxProb)) {
226: maxIndex = j;
227: maxProb = currentProb;
228: }
229: }
230:
231: return (double) maxIndex;
232: }
233:
234: /**
235: * Cleanup in order to save memory.
236: *
237: * @param justHeaderInfo
238: */
239: public final void cleanup(Instances justHeaderInfo) {
240:
241: m_train = justHeaderInfo;
242: m_test = null;
243: if (!m_isLeaf)
244: for (int i = 0; i < m_sons.length; i++)
245: m_sons[i].cleanup(justHeaderInfo);
246: }
247:
248: /**
249: * Returns class probabilities for a weighted instance.
250: *
251: * @param instance the instance to get the distribution for
252: * @param useLaplace whether to use laplace or not
253: * @return the distribution
254: * @throws Exception if something goes wrong
255: */
256: public final double[] distributionForInstance(Instance instance,
257: boolean useLaplace) throws Exception {
258:
259: double[] doubles = new double[instance.numClasses()];
260:
261: for (int i = 0; i < doubles.length; i++) {
262: if (!useLaplace) {
263: doubles[i] = getProbs(i, instance, 1);
264: } else {
265: doubles[i] = getProbsLaplace(i, instance, 1);
266: }
267: }
268:
269: return doubles;
270: }
271:
272: /**
273: * Assigns a uniqe id to every node in the tree.
274: *
275: * @param lastID the last ID that was assign
276: * @return the new current ID
277: */
278: public int assignIDs(int lastID) {
279:
280: int currLastID = lastID + 1;
281:
282: m_id = currLastID;
283: if (m_sons != null) {
284: for (int i = 0; i < m_sons.length; i++) {
285: currLastID = m_sons[i].assignIDs(currLastID);
286: }
287: }
288: return currLastID;
289: }
290:
291: /**
292: * Returns the type of graph this classifier
293: * represents.
294: * @return Drawable.TREE
295: */
296: public int graphType() {
297: return Drawable.TREE;
298: }
299:
300: /**
301: * Returns graph describing the tree.
302: *
303: * @throws Exception if something goes wrong
304: * @return the tree as graph
305: */
306: public String graph() throws Exception {
307:
308: StringBuffer text = new StringBuffer();
309:
310: assignIDs(-1);
311: text.append("digraph J48Tree {\n");
312: if (m_isLeaf) {
313: text.append("N" + m_id + " [label=\""
314: + m_localModel.dumpLabel(0, m_train) + "\" "
315: + "shape=box style=filled ");
316: if (m_train != null && m_train.numInstances() > 0) {
317: text.append("data =\n" + m_train + "\n");
318: text.append(",\n");
319:
320: }
321: text.append("]\n");
322: } else {
323: text.append("N" + m_id + " [label=\""
324: + m_localModel.leftSide(m_train) + "\" ");
325: if (m_train != null && m_train.numInstances() > 0) {
326: text.append("data =\n" + m_train + "\n");
327: text.append(",\n");
328: }
329: text.append("]\n");
330: graphTree(text);
331: }
332:
333: return text.toString() + "}\n";
334: }
335:
336: /**
337: * Returns tree in prefix order.
338: *
339: * @throws Exception if something goes wrong
340: * @return the prefix order
341: */
342: public String prefix() throws Exception {
343:
344: StringBuffer text;
345:
346: text = new StringBuffer();
347: if (m_isLeaf) {
348: text.append("[" + m_localModel.dumpLabel(0, m_train) + "]");
349: } else {
350: prefixTree(text);
351: }
352:
353: return text.toString();
354: }
355:
356: /**
357: * Returns source code for the tree as an if-then statement. The
358: * class is assigned to variable "p", and assumes the tested
359: * instance is named "i". The results are returned as two stringbuffers:
360: * a section of code for assignment of the class, and a section of
361: * code containing support code (eg: other support methods).
362: *
363: * @param className the classname that this static classifier has
364: * @return an array containing two stringbuffers, the first string containing
365: * assignment code, and the second containing source for support code.
366: * @throws Exception if something goes wrong
367: */
368: public StringBuffer[] toSource(String className) throws Exception {
369:
370: StringBuffer[] result = new StringBuffer[2];
371: if (m_isLeaf) {
372: result[0] = new StringBuffer(" p = "
373: + m_localModel.distribution().maxClass(0) + ";\n");
374: result[1] = new StringBuffer("");
375: } else {
376: StringBuffer text = new StringBuffer();
377: StringBuffer atEnd = new StringBuffer();
378:
379: long printID = ClassifierTree.nextID();
380:
381: text.append(" static double N").append(
382: Integer.toHexString(m_localModel.hashCode())
383: + printID).append("(Object []i) {\n")
384: .append(" double p = Double.NaN;\n");
385:
386: text.append(" if (").append(
387: m_localModel.sourceExpression(-1, m_train)).append(
388: ") {\n");
389: text.append(" p = ").append(
390: m_localModel.distribution().maxClass(0)).append(
391: ";\n");
392: text.append(" } ");
393: for (int i = 0; i < m_sons.length; i++) {
394: text.append("else if ("
395: + m_localModel.sourceExpression(i, m_train)
396: + ") {\n");
397: if (m_sons[i].m_isLeaf) {
398: text.append(" p = "
399: + m_localModel.distribution().maxClass(i)
400: + ";\n");
401: } else {
402: StringBuffer[] sub = m_sons[i].toSource(className);
403: text.append(sub[0]);
404: atEnd.append(sub[1]);
405: }
406: text.append(" } ");
407: if (i == m_sons.length - 1) {
408: text.append('\n');
409: }
410: }
411:
412: text.append(" return p;\n }\n");
413:
414: result[0] = new StringBuffer(" p = " + className + ".N");
415: result[0].append(
416: Integer.toHexString(m_localModel.hashCode())
417: + printID).append("(i);\n");
418: result[1] = text.append(atEnd);
419: }
420: return result;
421: }
422:
423: /**
424: * Returns number of leaves in tree structure.
425: *
426: * @return the number of leaves
427: */
428: public int numLeaves() {
429:
430: int num = 0;
431: int i;
432:
433: if (m_isLeaf)
434: return 1;
435: else
436: for (i = 0; i < m_sons.length; i++)
437: num = num + m_sons[i].numLeaves();
438:
439: return num;
440: }
441:
442: /**
443: * Returns number of nodes in tree structure.
444: *
445: * @return the number of nodes
446: */
447: public int numNodes() {
448:
449: int no = 1;
450: int i;
451:
452: if (!m_isLeaf)
453: for (i = 0; i < m_sons.length; i++)
454: no = no + m_sons[i].numNodes();
455:
456: return no;
457: }
458:
459: /**
460: * Prints tree structure.
461: *
462: * @return the tree structure
463: */
464: public String toString() {
465:
466: try {
467: StringBuffer text = new StringBuffer();
468:
469: if (m_isLeaf) {
470: text.append(": ");
471: text.append(m_localModel.dumpLabel(0, m_train));
472: } else
473: dumpTree(0, text);
474: text.append("\n\nNumber of Leaves : \t" + numLeaves()
475: + "\n");
476: text.append("\nSize of the tree : \t" + numNodes() + "\n");
477:
478: return text.toString();
479: } catch (Exception e) {
480: return "Can't print classification tree.";
481: }
482: }
483:
484: /**
485: * Returns a newly created tree.
486: *
487: * @param data the training data
488: * @return the generated tree
489: * @throws Exception if something goes wrong
490: */
491: protected ClassifierTree getNewTree(Instances data)
492: throws Exception {
493:
494: ClassifierTree newTree = new ClassifierTree(m_toSelectModel);
495: newTree.buildTree(data, false);
496:
497: return newTree;
498: }
499:
500: /**
501: * Returns a newly created tree.
502: *
503: * @param train the training data
504: * @param test the pruning data.
505: * @return the generated tree
506: * @throws Exception if something goes wrong
507: */
508: protected ClassifierTree getNewTree(Instances train, Instances test)
509: throws Exception {
510:
511: ClassifierTree newTree = new ClassifierTree(m_toSelectModel);
512: newTree.buildTree(train, test, false);
513:
514: return newTree;
515: }
516:
517: /**
518: * Help method for printing tree structure.
519: *
520: * @param depth the current depth
521: * @param text for outputting the structure
522: * @throws Exception if something goes wrong
523: */
524: private void dumpTree(int depth, StringBuffer text)
525: throws Exception {
526:
527: int i, j;
528:
529: for (i = 0; i < m_sons.length; i++) {
530: text.append("\n");
531: ;
532: for (j = 0; j < depth; j++)
533: text.append("| ");
534: text.append(m_localModel.leftSide(m_train));
535: text.append(m_localModel.rightSide(i, m_train));
536: if (m_sons[i].m_isLeaf) {
537: text.append(": ");
538: text.append(m_localModel.dumpLabel(i, m_train));
539: } else
540: m_sons[i].dumpTree(depth + 1, text);
541: }
542: }
543:
544: /**
545: * Help method for printing tree structure as a graph.
546: *
547: * @param text for outputting the tree
548: * @throws Exception if something goes wrong
549: */
550: private void graphTree(StringBuffer text) throws Exception {
551:
552: for (int i = 0; i < m_sons.length; i++) {
553: text.append("N" + m_id + "->" + "N" + m_sons[i].m_id
554: + " [label=\""
555: + m_localModel.rightSide(i, m_train).trim()
556: + "\"]\n");
557: if (m_sons[i].m_isLeaf) {
558: text.append("N" + m_sons[i].m_id + " [label=\""
559: + m_localModel.dumpLabel(i, m_train) + "\" "
560: + "shape=box style=filled ");
561: if (m_train != null && m_train.numInstances() > 0) {
562: text.append("data =\n" + m_sons[i].m_train + "\n");
563: text.append(",\n");
564: }
565: text.append("]\n");
566: } else {
567: text.append("N" + m_sons[i].m_id + " [label=\""
568: + m_sons[i].m_localModel.leftSide(m_train)
569: + "\" ");
570: if (m_train != null && m_train.numInstances() > 0) {
571: text.append("data =\n" + m_sons[i].m_train + "\n");
572: text.append(",\n");
573: }
574: text.append("]\n");
575: m_sons[i].graphTree(text);
576: }
577: }
578: }
579:
580: /**
581: * Prints the tree in prefix form
582: *
583: * @param text the buffer to output the prefix form to
584: * @throws Exception if something goes wrong
585: */
586: private void prefixTree(StringBuffer text) throws Exception {
587:
588: text.append("[");
589: text.append(m_localModel.leftSide(m_train) + ":");
590: for (int i = 0; i < m_sons.length; i++) {
591: if (i > 0) {
592: text.append(",\n");
593: }
594: text.append(m_localModel.rightSide(i, m_train));
595: }
596: for (int i = 0; i < m_sons.length; i++) {
597: if (m_sons[i].m_isLeaf) {
598: text.append("[");
599: text.append(m_localModel.dumpLabel(i, m_train));
600: text.append("]");
601: } else {
602: m_sons[i].prefixTree(text);
603: }
604: }
605: text.append("]");
606: }
607:
608: /**
609: * Help method for computing class probabilities of
610: * a given instance.
611: *
612: * @param classIndex the class index
613: * @param instance the instance to compute the probabilities for
614: * @param weight the weight to use
615: * @return the laplace probs
616: * @throws Exception if something goes wrong
617: */
618: private double getProbsLaplace(int classIndex, Instance instance,
619: double weight) throws Exception {
620:
621: double prob = 0;
622:
623: if (m_isLeaf) {
624: return weight
625: * localModel().classProbLaplace(classIndex,
626: instance, -1);
627: } else {
628: int treeIndex = localModel().whichSubset(instance);
629: if (treeIndex == -1) {
630: double[] weights = localModel().weights(instance);
631: for (int i = 0; i < m_sons.length; i++) {
632: if (!son(i).m_isEmpty) {
633: prob += son(i).getProbsLaplace(classIndex,
634: instance, weights[i] * weight);
635: }
636: }
637: return prob;
638: } else {
639: if (son(treeIndex).m_isEmpty) {
640: return weight
641: * localModel().classProbLaplace(classIndex,
642: instance, treeIndex);
643: } else {
644: return son(treeIndex).getProbsLaplace(classIndex,
645: instance, weight);
646: }
647: }
648: }
649: }
650:
651: /**
652: * Help method for computing class probabilities of
653: * a given instance.
654: *
655: * @param classIndex the class index
656: * @param instance the instance to compute the probabilities for
657: * @param weight the weight to use
658: * @return the probs
659: * @throws Exception if something goes wrong
660: */
661: private double getProbs(int classIndex, Instance instance,
662: double weight) throws Exception {
663:
664: double prob = 0;
665:
666: if (m_isLeaf) {
667: return weight
668: * localModel().classProb(classIndex, instance, -1);
669: } else {
670: int treeIndex = localModel().whichSubset(instance);
671: if (treeIndex == -1) {
672: double[] weights = localModel().weights(instance);
673: for (int i = 0; i < m_sons.length; i++) {
674: if (!son(i).m_isEmpty) {
675: prob += son(i).getProbs(classIndex, instance,
676: weights[i] * weight);
677: }
678: }
679: return prob;
680: } else {
681: if (son(treeIndex).m_isEmpty) {
682: return weight
683: * localModel().classProb(classIndex,
684: instance, treeIndex);
685: } else {
686: return son(treeIndex).getProbs(classIndex,
687: instance, weight);
688: }
689: }
690: }
691: }
692:
693: /**
694: * Method just exists to make program easier to read.
695: */
696: private ClassifierSplitModel localModel() {
697:
698: return (ClassifierSplitModel) m_localModel;
699: }
700:
701: /**
702: * Method just exists to make program easier to read.
703: */
704: private ClassifierTree son(int index) {
705:
706: return (ClassifierTree) m_sons[index];
707: }
708: }
|