001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * MultiScheme.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.meta;
024:
025: import weka.classifiers.Classifier;
026: import weka.classifiers.Evaluation;
027: import weka.classifiers.RandomizableMultipleClassifiersCombiner;
028: import weka.core.Instance;
029: import weka.core.Instances;
030: import weka.core.Option;
031: import weka.core.OptionHandler;
032: import weka.core.Utils;
033:
034: import java.util.Enumeration;
035: import java.util.Random;
036: import java.util.Vector;
037:
038: /**
039: <!-- globalinfo-start -->
040: * Class for selecting a classifier from among several using cross validation on the training data or the performance on the training data. Performance is measured based on percent correct (classification) or mean-squared error (regression).
041: * <p/>
042: <!-- globalinfo-end -->
043: *
044: <!-- options-start -->
045: * Valid options are: <p/>
046: *
047: * <pre> -X <number of folds>
048: * Use cross validation for model selection using the
049: * given number of folds. (default 0, is to
050: * use training error)</pre>
051: *
052: * <pre> -S <num>
053: * Random number seed.
054: * (default 1)</pre>
055: *
056: * <pre> -B <classifier specification>
057: * Full class name of classifier to include, followed
058: * by scheme options. May be specified multiple times.
059: * (default: "weka.classifiers.rules.ZeroR")</pre>
060: *
061: * <pre> -D
062: * If set, classifier is run in debug mode and
063: * may output additional info to the console</pre>
064: *
065: <!-- options-end -->
066: *
067: * @author Len Trigg (trigg@cs.waikato.ac.nz)
068: * @version $Revision: 1.24 $
069: */
070: public class MultiScheme extends
071: RandomizableMultipleClassifiersCombiner {
072:
073: /** for serialization */
074: static final long serialVersionUID = 5710744346128957520L;
075:
076: /** The classifier that had the best performance on training data. */
077: protected Classifier m_Classifier;
078:
079: /** The index into the vector for the selected scheme */
080: protected int m_ClassifierIndex;
081:
082: /**
083: * Number of folds to use for cross validation (0 means use training
084: * error for selection)
085: */
086: protected int m_NumXValFolds;
087:
088: /**
089: * Returns a string describing classifier
090: * @return a description suitable for
091: * displaying in the explorer/experimenter gui
092: */
093: public String globalInfo() {
094:
095: return "Class for selecting a classifier from among several using cross "
096: + "validation on the training data or the performance on the "
097: + "training data. Performance is measured based on percent correct "
098: + "(classification) or mean-squared error (regression).";
099: }
100:
101: /**
102: * Returns an enumeration describing the available options.
103: *
104: * @return an enumeration of all the available options.
105: */
106: public Enumeration listOptions() {
107:
108: Vector newVector = new Vector(1);
109: newVector
110: .addElement(new Option(
111: "\tUse cross validation for model selection using the\n"
112: + "\tgiven number of folds. (default 0, is to\n"
113: + "\tuse training error)", "X", 1,
114: "-X <number of folds>"));
115:
116: Enumeration enu = super .listOptions();
117: while (enu.hasMoreElements()) {
118: newVector.addElement(enu.nextElement());
119: }
120: return newVector.elements();
121: }
122:
123: /**
124: * Parses a given list of options. <p/>
125: *
126: <!-- options-start -->
127: * Valid options are: <p/>
128: *
129: * <pre> -X <number of folds>
130: * Use cross validation for model selection using the
131: * given number of folds. (default 0, is to
132: * use training error)</pre>
133: *
134: * <pre> -S <num>
135: * Random number seed.
136: * (default 1)</pre>
137: *
138: * <pre> -B <classifier specification>
139: * Full class name of classifier to include, followed
140: * by scheme options. May be specified multiple times.
141: * (default: "weka.classifiers.rules.ZeroR")</pre>
142: *
143: * <pre> -D
144: * If set, classifier is run in debug mode and
145: * may output additional info to the console</pre>
146: *
147: <!-- options-end -->
148: *
149: * @param options the list of options as an array of strings
150: * @throws Exception if an option is not supported
151: */
152: public void setOptions(String[] options) throws Exception {
153:
154: String numFoldsString = Utils.getOption('X', options);
155: if (numFoldsString.length() != 0) {
156: setNumFolds(Integer.parseInt(numFoldsString));
157: } else {
158: setNumFolds(0);
159: }
160: super .setOptions(options);
161: }
162:
163: /**
164: * Gets the current settings of the Classifier.
165: *
166: * @return an array of strings suitable for passing to setOptions
167: */
168: public String[] getOptions() {
169:
170: String[] super Options = super .getOptions();
171: String[] options = new String[super Options.length + 2];
172:
173: int current = 0;
174: options[current++] = "-X";
175: options[current++] = "" + getNumFolds();
176:
177: System.arraycopy(super Options, 0, options, current,
178: super Options.length);
179:
180: return options;
181: }
182:
183: /**
184: * Returns the tip text for this property
185: * @return tip text for this property suitable for
186: * displaying in the explorer/experimenter gui
187: */
188: public String classifiersTipText() {
189: return "The classifiers to be chosen from.";
190: }
191:
192: /**
193: * Sets the list of possible classifers to choose from.
194: *
195: * @param classifiers an array of classifiers with all options set.
196: */
197: public void setClassifiers(Classifier[] classifiers) {
198:
199: m_Classifiers = classifiers;
200: }
201:
202: /**
203: * Gets the list of possible classifers to choose from.
204: *
205: * @return the array of Classifiers
206: */
207: public Classifier[] getClassifiers() {
208:
209: return m_Classifiers;
210: }
211:
212: /**
213: * Gets a single classifier from the set of available classifiers.
214: *
215: * @param index the index of the classifier wanted
216: * @return the Classifier
217: */
218: public Classifier getClassifier(int index) {
219:
220: return m_Classifiers[index];
221: }
222:
223: /**
224: * Gets the classifier specification string, which contains the class name of
225: * the classifier and any options to the classifier
226: *
227: * @param index the index of the classifier string to retrieve, starting from
228: * 0.
229: * @return the classifier string, or the empty string if no classifier
230: * has been assigned (or the index given is out of range).
231: */
232: protected String getClassifierSpec(int index) {
233:
234: if (m_Classifiers.length < index) {
235: return "";
236: }
237: Classifier c = getClassifier(index);
238: if (c instanceof OptionHandler) {
239: return c.getClass().getName()
240: + " "
241: + Utils.joinOptions(((OptionHandler) c)
242: .getOptions());
243: }
244: return c.getClass().getName();
245: }
246:
247: /**
248: * Returns the tip text for this property
249: * @return tip text for this property suitable for
250: * displaying in the explorer/experimenter gui
251: */
252: public String seedTipText() {
253: return "The seed used for randomizing the data "
254: + "for cross-validation.";
255: }
256:
257: /**
258: * Sets the seed for random number generation.
259: *
260: * @param seed the random number seed
261: */
262: public void setSeed(int seed) {
263:
264: m_Seed = seed;
265: ;
266: }
267:
268: /**
269: * Gets the random number seed.
270: *
271: * @return the random number seed
272: */
273: public int getSeed() {
274:
275: return m_Seed;
276: }
277:
278: /**
279: * Returns the tip text for this property
280: * @return tip text for this property suitable for
281: * displaying in the explorer/experimenter gui
282: */
283: public String numFoldsTipText() {
284: return "The number of folds used for cross-validation (if 0, "
285: + "performance on training data will be used).";
286: }
287:
288: /**
289: * Gets the number of folds for cross-validation. A number less
290: * than 2 specifies using training error rather than cross-validation.
291: *
292: * @return the number of folds for cross-validation
293: */
294: public int getNumFolds() {
295:
296: return m_NumXValFolds;
297: }
298:
299: /**
300: * Sets the number of folds for cross-validation. A number less
301: * than 2 specifies using training error rather than cross-validation.
302: *
303: * @param numFolds the number of folds for cross-validation
304: */
305: public void setNumFolds(int numFolds) {
306:
307: m_NumXValFolds = numFolds;
308: }
309:
310: /**
311: * Returns the tip text for this property
312: * @return tip text for this property suitable for
313: * displaying in the explorer/experimenter gui
314: */
315: public String debugTipText() {
316: return "Whether debug information is output to console.";
317: }
318:
319: /**
320: * Set debugging mode
321: *
322: * @param debug true if debug output should be printed
323: */
324: public void setDebug(boolean debug) {
325:
326: m_Debug = debug;
327: }
328:
329: /**
330: * Get whether debugging is turned on
331: *
332: * @return true if debugging output is on
333: */
334: public boolean getDebug() {
335:
336: return m_Debug;
337: }
338:
339: /**
340: * Get the index of the classifier that was determined as best during
341: * cross-validation.
342: *
343: * @return the index in the classifier array
344: */
345: public int getBestClassifierIndex() {
346: return m_ClassifierIndex;
347: }
348:
349: /**
350: * Buildclassifier selects a classifier from the set of classifiers
351: * by minimising error on the training data.
352: *
353: * @param data the training data to be used for generating the
354: * boosted classifier.
355: * @throws Exception if the classifier could not be built successfully
356: */
357: public void buildClassifier(Instances data) throws Exception {
358:
359: if (m_Classifiers.length == 0) {
360: throw new Exception("No base classifiers have been set!");
361: }
362:
363: // can classifier handle the data?
364: getCapabilities().testWithFail(data);
365:
366: // remove instances with missing class
367: Instances newData = new Instances(data);
368: newData.deleteWithMissingClass();
369:
370: Random random = new Random(m_Seed);
371: newData.randomize(random);
372: if (newData.classAttribute().isNominal()
373: && (m_NumXValFolds > 1)) {
374: newData.stratify(m_NumXValFolds);
375: }
376: Instances train = newData; // train on all data by default
377: Instances test = newData; // test on training data by default
378: Classifier bestClassifier = null;
379: int bestIndex = -1;
380: double bestPerformance = Double.NaN;
381: int numClassifiers = m_Classifiers.length;
382: for (int i = 0; i < numClassifiers; i++) {
383: Classifier currentClassifier = getClassifier(i);
384: Evaluation evaluation;
385: if (m_NumXValFolds > 1) {
386: evaluation = new Evaluation(newData);
387: for (int j = 0; j < m_NumXValFolds; j++) {
388:
389: // We want to randomize the data the same way for every
390: // learning scheme.
391: train = newData.trainCV(m_NumXValFolds, j,
392: new Random(1));
393: test = newData.testCV(m_NumXValFolds, j);
394: currentClassifier.buildClassifier(train);
395: evaluation.setPriors(train);
396: evaluation.evaluateModel(currentClassifier, test);
397: }
398: } else {
399: currentClassifier.buildClassifier(train);
400: evaluation = new Evaluation(train);
401: evaluation.evaluateModel(currentClassifier, test);
402: }
403:
404: double error = evaluation.errorRate();
405: if (m_Debug) {
406: System.err.println("Error rate: "
407: + Utils.doubleToString(error, 6, 4)
408: + " for classifier "
409: + currentClassifier.getClass().getName());
410: }
411:
412: if ((i == 0) || (error < bestPerformance)) {
413: bestClassifier = currentClassifier;
414: bestPerformance = error;
415: bestIndex = i;
416: }
417: }
418: m_ClassifierIndex = bestIndex;
419: if (m_NumXValFolds > 1) {
420: bestClassifier.buildClassifier(newData);
421: }
422: m_Classifier = bestClassifier;
423: }
424:
425: /**
426: * Returns class probabilities.
427: *
428: * @param instance the instance to be classified
429: * @return the distribution for the instance
430: * @throws Exception if instance could not be classified
431: * successfully
432: */
433: public double[] distributionForInstance(Instance instance)
434: throws Exception {
435:
436: return m_Classifier.distributionForInstance(instance);
437: }
438:
439: /**
440: * Output a representation of this classifier
441: * @return a string representation of the classifier
442: */
443: public String toString() {
444:
445: if (m_Classifier == null) {
446: return "MultiScheme: No model built yet.";
447: }
448:
449: String result = "MultiScheme selection using";
450: if (m_NumXValFolds > 1) {
451: result += " cross validation error";
452: } else {
453: result += " error on training data";
454: }
455: result += " from the following:\n";
456: for (int i = 0; i < m_Classifiers.length; i++) {
457: result += '\t' + getClassifierSpec(i) + '\n';
458: }
459:
460: result += "Selected scheme: "
461: + getClassifierSpec(m_ClassifierIndex) + "\n\n"
462: + m_Classifier.toString();
463: return result;
464: }
465:
466: /**
467: * Main method for testing this class.
468: *
469: * @param argv should contain the following arguments:
470: * -t training file [-T test file] [-c class index]
471: */
472: public static void main(String[] argv) {
473: runClassifier(new MultiScheme(), argv);
474: }
475: }
|