001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * PruneableClassifierTree.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.trees.j48;
024:
025: import weka.core.Capabilities;
026: import weka.core.Instances;
027: import weka.core.Utils;
028: import weka.core.Capabilities.Capability;
029:
030: import java.util.Random;
031:
032: /**
033: * Class for handling a tree structure that can
034: * be pruned using a pruning set.
035: *
036: * @author Eibe Frank (eibe@cs.waikato.ac.nz)
037: * @version $Revision: 1.11 $
038: */
039: public class PruneableClassifierTree extends ClassifierTree {
040:
041: /** for serialization */
042: static final long serialVersionUID = -555775736857600201L;
043:
044: /** True if the tree is to be pruned. */
045: private boolean pruneTheTree = false;
046:
047: /** How many subsets of equal size? One used for pruning, the rest for training. */
048: private int numSets = 3;
049:
050: /** Cleanup after the tree has been built. */
051: private boolean m_cleanup = true;
052:
053: /** The random number seed. */
054: private int m_seed = 1;
055:
056: /**
057: * Constructor for pruneable tree structure. Stores reference
058: * to associated training data at each node.
059: *
060: * @param toSelectLocModel selection method for local splitting model
061: * @param pruneTree true if the tree is to be pruned
062: * @param num number of subsets of equal size
063: * @param cleanup
064: * @param seed the seed value to use
065: * @throws Exception if something goes wrong
066: */
067: public PruneableClassifierTree(ModelSelection toSelectLocModel,
068: boolean pruneTree, int num, boolean cleanup, int seed)
069: throws Exception {
070:
071: super (toSelectLocModel);
072:
073: pruneTheTree = pruneTree;
074: numSets = num;
075: m_cleanup = cleanup;
076: m_seed = seed;
077: }
078:
079: /**
080: * Returns default capabilities of the classifier tree.
081: *
082: * @return the capabilities of this classifier tree
083: */
084: public Capabilities getCapabilities() {
085: Capabilities result = super .getCapabilities();
086:
087: // attributes
088: result.enable(Capability.NOMINAL_ATTRIBUTES);
089: result.enable(Capability.NUMERIC_ATTRIBUTES);
090: result.enable(Capability.DATE_ATTRIBUTES);
091: result.enable(Capability.MISSING_VALUES);
092:
093: // class
094: result.enable(Capability.NOMINAL_CLASS);
095: result.enable(Capability.MISSING_CLASS_VALUES);
096:
097: // instances
098: result.setMinimumNumberInstances(0);
099:
100: return result;
101: }
102:
103: /**
104: * Method for building a pruneable classifier tree.
105: *
106: * @param data the data to build the tree from
107: * @throws Exception if tree can't be built successfully
108: */
109: public void buildClassifier(Instances data) throws Exception {
110:
111: // can classifier tree handle the data?
112: getCapabilities().testWithFail(data);
113:
114: // remove instances with missing class
115: data = new Instances(data);
116: data.deleteWithMissingClass();
117:
118: Random random = new Random(m_seed);
119: data.stratify(numSets);
120: buildTree(data.trainCV(numSets, numSets - 1, random), data
121: .testCV(numSets, numSets - 1), false);
122: if (pruneTheTree) {
123: prune();
124: }
125: if (m_cleanup) {
126: cleanup(new Instances(data, 0));
127: }
128: }
129:
130: /**
131: * Prunes a tree.
132: *
133: * @throws Exception if tree can't be pruned successfully
134: */
135: public void prune() throws Exception {
136:
137: if (!m_isLeaf) {
138:
139: // Prune all subtrees.
140: for (int i = 0; i < m_sons.length; i++)
141: son(i).prune();
142:
143: // Decide if leaf is best choice.
144: if (Utils.smOrEq(errorsForLeaf(), errorsForTree())) {
145:
146: // Free son Trees
147: m_sons = null;
148: m_isLeaf = true;
149:
150: // Get NoSplit Model for node.
151: m_localModel = new NoSplit(localModel().distribution());
152: }
153: }
154: }
155:
156: /**
157: * Returns a newly created tree.
158: *
159: * @param train the training data
160: * @param test the test data
161: * @return the generated tree
162: * @throws Exception if something goes wrong
163: */
164: protected ClassifierTree getNewTree(Instances train, Instances test)
165: throws Exception {
166:
167: PruneableClassifierTree newTree = new PruneableClassifierTree(
168: m_toSelectModel, pruneTheTree, numSets, m_cleanup,
169: m_seed);
170: newTree.buildTree(train, test, false);
171: return newTree;
172: }
173:
174: /**
175: * Computes estimated errors for tree.
176: *
177: * @return the estimated errors
178: * @throws Exception if error estimate can't be computed
179: */
180: private double errorsForTree() throws Exception {
181:
182: double errors = 0;
183:
184: if (m_isLeaf)
185: return errorsForLeaf();
186: else {
187: for (int i = 0; i < m_sons.length; i++)
188: if (Utils.eq(localModel().distribution().perBag(i), 0)) {
189: errors += m_test.perBag(i)
190: - m_test.perClassPerBag(i, localModel()
191: .distribution().maxClass());
192: } else
193: errors += son(i).errorsForTree();
194:
195: return errors;
196: }
197: }
198:
199: /**
200: * Computes estimated errors for leaf.
201: *
202: * @return the estimated errors
203: * @throws Exception if error estimate can't be computed
204: */
205: private double errorsForLeaf() throws Exception {
206:
207: return m_test.total()
208: - m_test.perClass(localModel().distribution()
209: .maxClass());
210: }
211:
212: /**
213: * Method just exists to make program easier to read.
214: */
215: private ClassifierSplitModel localModel() {
216:
217: return (ClassifierSplitModel) m_localModel;
218: }
219:
220: /**
221: * Method just exists to make program easier to read.
222: */
223: private PruneableClassifierTree son(int index) {
224:
225: return (PruneableClassifierTree) m_sons[index];
226: }
227: }
|