001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * Bagging.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.meta;
024:
025: import weka.classifiers.RandomizableIteratedSingleClassifierEnhancer;
026: import weka.core.AdditionalMeasureProducer;
027: import weka.core.Instance;
028: import weka.core.Instances;
029: import weka.core.Option;
030: import weka.core.Randomizable;
031: import weka.core.TechnicalInformation;
032: import weka.core.TechnicalInformationHandler;
033: import weka.core.Utils;
034: import weka.core.WeightedInstancesHandler;
035: import weka.core.TechnicalInformation.Field;
036: import weka.core.TechnicalInformation.Type;
037:
038: import java.util.Enumeration;
039: import java.util.Random;
040: import java.util.Vector;
041:
042: /**
043: <!-- globalinfo-start -->
044: * Class for bagging a classifier to reduce variance. Can do classification and regression depending on the base learner. <br/>
045: * <br/>
046: * For more information, see<br/>
047: * <br/>
048: * Leo Breiman (1996). Bagging predictors. Machine Learning. 24(2):123-140.
049: * <p/>
050: <!-- globalinfo-end -->
051: *
052: <!-- technical-bibtex-start -->
053: * BibTeX:
054: * <pre>
055: * @article{Breiman1996,
056: * author = {Leo Breiman},
057: * journal = {Machine Learning},
058: * number = {2},
059: * pages = {123-140},
060: * title = {Bagging predictors},
061: * volume = {24},
062: * year = {1996}
063: * }
064: * </pre>
065: * <p/>
066: <!-- technical-bibtex-end -->
067: *
068: <!-- options-start -->
069: * Valid options are: <p/>
070: *
071: * <pre> -P
072: * Size of each bag, as a percentage of the
073: * training set size. (default 100)</pre>
074: *
075: * <pre> -O
076: * Calculate the out of bag error.</pre>
077: *
078: * <pre> -S <num>
079: * Random number seed.
080: * (default 1)</pre>
081: *
082: * <pre> -I <num>
083: * Number of iterations.
084: * (default 10)</pre>
085: *
086: * <pre> -D
087: * If set, classifier is run in debug mode and
088: * may output additional info to the console</pre>
089: *
090: * <pre> -W
091: * Full name of base classifier.
092: * (default: weka.classifiers.trees.REPTree)</pre>
093: *
094: * <pre>
095: * Options specific to classifier weka.classifiers.trees.REPTree:
096: * </pre>
097: *
098: * <pre> -M <minimum number of instances>
099: * Set minimum number of instances per leaf (default 2).</pre>
100: *
101: * <pre> -V <minimum variance for split>
102: * Set minimum numeric class variance proportion
103: * of train variance for split (default 1e-3).</pre>
104: *
105: * <pre> -N <number of folds>
106: * Number of folds for reduced error pruning (default 3).</pre>
107: *
108: * <pre> -S <seed>
109: * Seed for random data shuffling (default 1).</pre>
110: *
111: * <pre> -P
112: * No pruning.</pre>
113: *
114: * <pre> -L
115: * Maximum tree depth (default -1, no maximum)</pre>
116: *
117: <!-- options-end -->
118: *
119: * Options after -- are passed to the designated classifier.<p>
120: *
121: * @author Eibe Frank (eibe@cs.waikato.ac.nz)
122: * @author Len Trigg (len@reeltwo.com)
123: * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
124: * @version $Revision: 1.39 $
125: */
126: public class Bagging extends
127: RandomizableIteratedSingleClassifierEnhancer implements
128: WeightedInstancesHandler, AdditionalMeasureProducer,
129: TechnicalInformationHandler {
130:
131: /** for serialization */
132: static final long serialVersionUID = -505879962237199703L;
133:
134: /** The size of each bag sample, as a percentage of the training size */
135: protected int m_BagSizePercent = 100;
136:
137: /** Whether to calculate the out of bag error */
138: protected boolean m_CalcOutOfBag = false;
139:
140: /** The out of bag error that has been calculated */
141: protected double m_OutOfBagError;
142:
143: /**
144: * Constructor.
145: */
146: public Bagging() {
147:
148: m_Classifier = new weka.classifiers.trees.REPTree();
149: }
150:
151: /**
152: * Returns a string describing classifier
153: * @return a description suitable for
154: * displaying in the explorer/experimenter gui
155: */
156: public String globalInfo() {
157:
158: return "Class for bagging a classifier to reduce variance. Can do classification "
159: + "and regression depending on the base learner. \n\n"
160: + "For more information, see\n\n"
161: + getTechnicalInformation().toString();
162: }
163:
164: /**
165: * Returns an instance of a TechnicalInformation object, containing
166: * detailed information about the technical background of this class,
167: * e.g., paper reference or book this class is based on.
168: *
169: * @return the technical information about this class
170: */
171: public TechnicalInformation getTechnicalInformation() {
172: TechnicalInformation result;
173:
174: result = new TechnicalInformation(Type.ARTICLE);
175: result.setValue(Field.AUTHOR, "Leo Breiman");
176: result.setValue(Field.YEAR, "1996");
177: result.setValue(Field.TITLE, "Bagging predictors");
178: result.setValue(Field.JOURNAL, "Machine Learning");
179: result.setValue(Field.VOLUME, "24");
180: result.setValue(Field.NUMBER, "2");
181: result.setValue(Field.PAGES, "123-140");
182:
183: return result;
184: }
185:
186: /**
187: * String describing default classifier.
188: *
189: * @return the default classifier classname
190: */
191: protected String defaultClassifierString() {
192:
193: return "weka.classifiers.trees.REPTree";
194: }
195:
196: /**
197: * Returns an enumeration describing the available options.
198: *
199: * @return an enumeration of all the available options.
200: */
201: public Enumeration listOptions() {
202:
203: Vector newVector = new Vector(2);
204:
205: newVector.addElement(new Option(
206: "\tSize of each bag, as a percentage of the\n"
207: + "\ttraining set size. (default 100)", "P", 1,
208: "-P"));
209: newVector.addElement(new Option(
210: "\tCalculate the out of bag error.", "O", 0, "-O"));
211:
212: Enumeration enu = super .listOptions();
213: while (enu.hasMoreElements()) {
214: newVector.addElement(enu.nextElement());
215: }
216: return newVector.elements();
217: }
218:
219: /**
220: * Parses a given list of options. <p/>
221: *
222: <!-- options-start -->
223: * Valid options are: <p/>
224: *
225: * <pre> -P
226: * Size of each bag, as a percentage of the
227: * training set size. (default 100)</pre>
228: *
229: * <pre> -O
230: * Calculate the out of bag error.</pre>
231: *
232: * <pre> -S <num>
233: * Random number seed.
234: * (default 1)</pre>
235: *
236: * <pre> -I <num>
237: * Number of iterations.
238: * (default 10)</pre>
239: *
240: * <pre> -D
241: * If set, classifier is run in debug mode and
242: * may output additional info to the console</pre>
243: *
244: * <pre> -W
245: * Full name of base classifier.
246: * (default: weka.classifiers.trees.REPTree)</pre>
247: *
248: * <pre>
249: * Options specific to classifier weka.classifiers.trees.REPTree:
250: * </pre>
251: *
252: * <pre> -M <minimum number of instances>
253: * Set minimum number of instances per leaf (default 2).</pre>
254: *
255: * <pre> -V <minimum variance for split>
256: * Set minimum numeric class variance proportion
257: * of train variance for split (default 1e-3).</pre>
258: *
259: * <pre> -N <number of folds>
260: * Number of folds for reduced error pruning (default 3).</pre>
261: *
262: * <pre> -S <seed>
263: * Seed for random data shuffling (default 1).</pre>
264: *
265: * <pre> -P
266: * No pruning.</pre>
267: *
268: * <pre> -L
269: * Maximum tree depth (default -1, no maximum)</pre>
270: *
271: <!-- options-end -->
272: *
273: * Options after -- are passed to the designated classifier.<p>
274: *
275: * @param options the list of options as an array of strings
276: * @throws Exception if an option is not supported
277: */
278: public void setOptions(String[] options) throws Exception {
279:
280: String bagSize = Utils.getOption('P', options);
281: if (bagSize.length() != 0) {
282: setBagSizePercent(Integer.parseInt(bagSize));
283: } else {
284: setBagSizePercent(100);
285: }
286:
287: setCalcOutOfBag(Utils.getFlag('O', options));
288:
289: super .setOptions(options);
290: }
291:
292: /**
293: * Gets the current settings of the Classifier.
294: *
295: * @return an array of strings suitable for passing to setOptions
296: */
297: public String[] getOptions() {
298:
299: String[] super Options = super .getOptions();
300: String[] options = new String[super Options.length + 3];
301:
302: int current = 0;
303: options[current++] = "-P";
304: options[current++] = "" + getBagSizePercent();
305:
306: if (getCalcOutOfBag()) {
307: options[current++] = "-O";
308: }
309:
310: System.arraycopy(super Options, 0, options, current,
311: super Options.length);
312:
313: current += super Options.length;
314: while (current < options.length) {
315: options[current++] = "";
316: }
317: return options;
318: }
319:
320: /**
321: * Returns the tip text for this property
322: * @return tip text for this property suitable for
323: * displaying in the explorer/experimenter gui
324: */
325: public String bagSizePercentTipText() {
326: return "Size of each bag, as a percentage of the training set size.";
327: }
328:
329: /**
330: * Gets the size of each bag, as a percentage of the training set size.
331: *
332: * @return the bag size, as a percentage.
333: */
334: public int getBagSizePercent() {
335:
336: return m_BagSizePercent;
337: }
338:
339: /**
340: * Sets the size of each bag, as a percentage of the training set size.
341: *
342: * @param newBagSizePercent the bag size, as a percentage.
343: */
344: public void setBagSizePercent(int newBagSizePercent) {
345:
346: m_BagSizePercent = newBagSizePercent;
347: }
348:
349: /**
350: * Returns the tip text for this property
351: * @return tip text for this property suitable for
352: * displaying in the explorer/experimenter gui
353: */
354: public String calcOutOfBagTipText() {
355: return "Whether the out-of-bag error is calculated.";
356: }
357:
358: /**
359: * Set whether the out of bag error is calculated.
360: *
361: * @param calcOutOfBag whether to calculate the out of bag error
362: */
363: public void setCalcOutOfBag(boolean calcOutOfBag) {
364:
365: m_CalcOutOfBag = calcOutOfBag;
366: }
367:
368: /**
369: * Get whether the out of bag error is calculated.
370: *
371: * @return whether the out of bag error is calculated
372: */
373: public boolean getCalcOutOfBag() {
374:
375: return m_CalcOutOfBag;
376: }
377:
378: /**
379: * Gets the out of bag error that was calculated as the classifier
380: * was built.
381: *
382: * @return the out of bag error
383: */
384: public double measureOutOfBagError() {
385:
386: return m_OutOfBagError;
387: }
388:
389: /**
390: * Returns an enumeration of the additional measure names.
391: *
392: * @return an enumeration of the measure names
393: */
394: public Enumeration enumerateMeasures() {
395:
396: Vector newVector = new Vector(1);
397: newVector.addElement("measureOutOfBagError");
398: return newVector.elements();
399: }
400:
401: /**
402: * Returns the value of the named measure.
403: *
404: * @param additionalMeasureName the name of the measure to query for its value
405: * @return the value of the named measure
406: * @throws IllegalArgumentException if the named measure is not supported
407: */
408: public double getMeasure(String additionalMeasureName) {
409:
410: if (additionalMeasureName
411: .equalsIgnoreCase("measureOutOfBagError")) {
412: return measureOutOfBagError();
413: } else {
414: throw new IllegalArgumentException(additionalMeasureName
415: + " not supported (Bagging)");
416: }
417: }
418:
419: /**
420: * Creates a new dataset of the same size using random sampling
421: * with replacement according to the given weight vector. The
422: * weights of the instances in the new dataset are set to one.
423: * The length of the weight vector has to be the same as the
424: * number of instances in the dataset, and all weights have to
425: * be positive.
426: *
427: * @param data the data to be sampled from
428: * @param random a random number generator
429: * @param sampled indicating which instance has been sampled
430: * @return the new dataset
431: * @throws IllegalArgumentException if the weights array is of the wrong
432: * length or contains negative weights.
433: */
434: public final Instances resampleWithWeights(Instances data,
435: Random random, boolean[] sampled) {
436:
437: double[] weights = new double[data.numInstances()];
438: for (int i = 0; i < weights.length; i++) {
439: weights[i] = data.instance(i).weight();
440: }
441: Instances newData = new Instances(data, data.numInstances());
442: if (data.numInstances() == 0) {
443: return newData;
444: }
445: double[] probabilities = new double[data.numInstances()];
446: double sumProbs = 0, sumOfWeights = Utils.sum(weights);
447: for (int i = 0; i < data.numInstances(); i++) {
448: sumProbs += random.nextDouble();
449: probabilities[i] = sumProbs;
450: }
451: Utils.normalize(probabilities, sumProbs / sumOfWeights);
452:
453: // Make sure that rounding errors don't mess things up
454: probabilities[data.numInstances() - 1] = sumOfWeights;
455: int k = 0;
456: int l = 0;
457: sumProbs = 0;
458: while ((k < data.numInstances() && (l < data.numInstances()))) {
459: if (weights[l] < 0) {
460: throw new IllegalArgumentException(
461: "Weights have to be positive.");
462: }
463: sumProbs += weights[l];
464: while ((k < data.numInstances())
465: && (probabilities[k] <= sumProbs)) {
466: newData.add(data.instance(l));
467: sampled[l] = true;
468: newData.instance(k).setWeight(1);
469: k++;
470: }
471: l++;
472: }
473: return newData;
474: }
475:
476: /**
477: * Bagging method.
478: *
479: * @param data the training data to be used for generating the
480: * bagged classifier.
481: * @throws Exception if the classifier could not be built successfully
482: */
483: public void buildClassifier(Instances data) throws Exception {
484:
485: // can classifier handle the data?
486: getCapabilities().testWithFail(data);
487:
488: // remove instances with missing class
489: data = new Instances(data);
490: data.deleteWithMissingClass();
491:
492: super .buildClassifier(data);
493:
494: if (m_CalcOutOfBag && (m_BagSizePercent != 100)) {
495: throw new IllegalArgumentException(
496: "Bag size needs to be 100% if "
497: + "out-of-bag error is to be calculated!");
498: }
499:
500: int bagSize = data.numInstances() * m_BagSizePercent / 100;
501: Random random = new Random(m_Seed);
502:
503: boolean[][] inBag = null;
504: if (m_CalcOutOfBag)
505: inBag = new boolean[m_Classifiers.length][];
506:
507: for (int j = 0; j < m_Classifiers.length; j++) {
508: Instances bagData = null;
509:
510: // create the in-bag dataset
511: if (m_CalcOutOfBag) {
512: inBag[j] = new boolean[data.numInstances()];
513: bagData = resampleWithWeights(data, random, inBag[j]);
514: } else {
515: bagData = data.resampleWithWeights(random);
516: if (bagSize < data.numInstances()) {
517: bagData.randomize(random);
518: Instances newBagData = new Instances(bagData, 0,
519: bagSize);
520: bagData = newBagData;
521: }
522: }
523:
524: if (m_Classifier instanceof Randomizable) {
525: ((Randomizable) m_Classifiers[j]).setSeed(random
526: .nextInt());
527: }
528:
529: // build the classifier
530: m_Classifiers[j].buildClassifier(bagData);
531: }
532:
533: // calc OOB error?
534: if (getCalcOutOfBag()) {
535: double outOfBagCount = 0.0;
536: double errorSum = 0.0;
537: boolean numeric = data.classAttribute().isNumeric();
538:
539: for (int i = 0; i < data.numInstances(); i++) {
540: double vote;
541: double[] votes;
542: if (numeric)
543: votes = new double[1];
544: else
545: votes = new double[data.numClasses()];
546:
547: // determine predictions for instance
548: int voteCount = 0;
549: for (int j = 0; j < m_Classifiers.length; j++) {
550: if (inBag[j][i])
551: continue;
552:
553: voteCount++;
554: double pred = m_Classifiers[j]
555: .classifyInstance(data.instance(i));
556: if (numeric)
557: votes[0] += pred;
558: else
559: votes[(int) pred]++;
560: }
561:
562: // "vote"
563: if (numeric)
564: vote = votes[0] / voteCount; // average
565: else
566: vote = Utils.maxIndex(votes); // majority vote
567:
568: // error for instance
569: outOfBagCount += data.instance(i).weight();
570: if (numeric) {
571: errorSum += StrictMath.abs(vote
572: - data.instance(i).classValue())
573: * data.instance(i).weight();
574: } else {
575: if (vote != data.instance(i).classValue())
576: errorSum += data.instance(i).weight();
577: }
578: }
579:
580: m_OutOfBagError = errorSum / outOfBagCount;
581: } else {
582: m_OutOfBagError = 0;
583: }
584: }
585:
586: /**
587: * Calculates the class membership probabilities for the given test
588: * instance.
589: *
590: * @param instance the instance to be classified
591: * @return preedicted class probability distribution
592: * @throws Exception if distribution can't be computed successfully
593: */
594: public double[] distributionForInstance(Instance instance)
595: throws Exception {
596:
597: double[] sums = new double[instance.numClasses()], newProbs;
598:
599: for (int i = 0; i < m_NumIterations; i++) {
600: if (instance.classAttribute().isNumeric() == true) {
601: sums[0] += m_Classifiers[i].classifyInstance(instance);
602: } else {
603: newProbs = m_Classifiers[i]
604: .distributionForInstance(instance);
605: for (int j = 0; j < newProbs.length; j++)
606: sums[j] += newProbs[j];
607: }
608: }
609: if (instance.classAttribute().isNumeric() == true) {
610: sums[0] /= (double) m_NumIterations;
611: return sums;
612: } else if (Utils.eq(Utils.sum(sums), 0)) {
613: return sums;
614: } else {
615: Utils.normalize(sums);
616: return sums;
617: }
618: }
619:
620: /**
621: * Returns description of the bagged classifier.
622: *
623: * @return description of the bagged classifier as a string
624: */
625: public String toString() {
626:
627: if (m_Classifiers == null) {
628: return "Bagging: No model built yet.";
629: }
630: StringBuffer text = new StringBuffer();
631: text.append("All the base classifiers: \n\n");
632: for (int i = 0; i < m_Classifiers.length; i++)
633: text.append(m_Classifiers[i].toString() + "\n\n");
634:
635: if (m_CalcOutOfBag) {
636: text
637: .append("Out of bag error: "
638: + Utils.doubleToString(m_OutOfBagError, 4)
639: + "\n\n");
640: }
641:
642: return text.toString();
643: }
644:
645: /**
646: * Main method for testing this class.
647: *
648: * @param argv the options
649: */
650: public static void main(String[] argv) {
651: runClassifier(new Bagging(), argv);
652: }
653: }
|