/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * EnsembleSelection.java
 * Copyright (C) 2006 David Michael
 *
 */

package weka.classifiers.meta;

import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.classifiers.meta.ensembleSelection.EnsembleMetricHelper;
import weka.classifiers.meta.ensembleSelection.EnsembleSelectionLibrary;
import weka.classifiers.meta.ensembleSelection.EnsembleSelectionLibraryModel;
import weka.classifiers.meta.ensembleSelection.ModelBag;
import weka.classifiers.trees.REPTree;
import weka.classifiers.xml.XMLClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.xml.KOML;
import weka.core.xml.XMLOptions;
import weka.core.xml.XMLSerialization;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.FileReader;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import java.util.Vector;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;

/**
 <!-- globalinfo-start -->
 * Combines several classifiers using the ensemble selection method. For more information, see: Caruana, Rich, Niculescu, Alex, Crew, Geoff, and Ksikes, Alex, Ensemble Selection from Libraries of Models, The International Conference on Machine Learning (ICML'04), 2004. Implemented in Weka by Bob Jung and David Michael.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * @inproceedings{RichCaruana2004,
 *    author = {Rich Caruana, Alex Niculescu, Geoff Crew, and Alex Ksikes},
 *    booktitle = {21st International Conference on Machine Learning},
 *    title = {Ensemble Selection from Libraries of Models},
 *    year = {2004}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 * Our implementation of ensemble selection is a bit different from the other
 * classifiers because we assume that the list of models to be trained is too
 * large to fit in memory and that our base classifiers will need to be
 * serialized to the file system (in the directory given by the
 * "workingDirectory" option). We have adopted the term "model library" for
 * this large set of classifiers, in keeping with the original paper.
 * <p/>
 *
 * If you are planning to use this classifier, we highly recommend you take a
 * quick look at our FAQ/tutorial on the WIKI. There are a few things that
 * are unique to this classifier that could trip you up. Otherwise, this
 * method is a great way to get very good classifier performance without
 * having to do much parameter tuning. What is nice is that in the worst
 * case you get a useful summary of how a large number of diverse models
 * performed on your data set.
 * <p/>
 *
 * This class relies on the package weka.classifiers.meta.ensembleSelection.
 * <p/>
 *
 * When run from the Explorer or another GUI, the classifier depends on the
 * package weka.gui.libraryEditor.
 * <p/>
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -L </path/to/modelLibrary>
 *  Specifies the Model Library File, containing the list of all models.</pre>
 *
 * <pre> -W </path/to/working/directory>
 *  Specifies the Working Directory, where all models will be stored.</pre>
 *
 * <pre> -B <numModelBags>
 *  Set the number of bags, i.e., number of iterations to run
 *  the ensemble selection algorithm.</pre>
 *
 * <pre> -E <modelRatio>
 *  Set the ratio of library models that will be randomly chosen
 *  to populate each bag of models.</pre>
 *
 * <pre> -V <validationRatio>
 *  Set the ratio of the training data set that will be reserved
 *  for validation.</pre>
 *
 * <pre> -H <hillClimbIterations>
 *  Set the number of hillclimbing iterations to be performed
 *  on each model bag.</pre>
 *
 * <pre> -I <sortInitialization>
 *  Set the ratio of the ensemble library that the sort
 *  initialization algorithm will be able to choose from while
 *  initializing the ensemble for each model bag</pre>
 *
 * <pre> -X <numFolds>
 *  Sets the number of cross-validation folds.</pre>
 *
 * <pre> -P <hillclimbMetric>
 *  Specify the metric that will be used for model selection
 *  during the hillclimbing algorithm.
 *  Valid metrics are:
 *   accuracy, rmse, roc, precision, recall, fscore, all</pre>
 *
 * <pre> -A <algorithm>
 *  Specifies the algorithm to be used for ensemble selection.
 *  Valid algorithms are:
 *   "forward" (default) for forward selection.
 *   "backward" for backward elimination.
 *   "both" for both forward and backward elimination.
 *   "best" to simply print out top performer from the
 *    ensemble library
 *   "library" to only train the models in the ensemble
 *    library</pre>
 *
 * <pre> -R
 *  Flag whether or not models can be selected more than once
 *  for an ensemble.</pre>
 *
 * <pre> -G
 *  Whether sort initialization greedily stops adding models
 *  when performance degrades.</pre>
 *
 * <pre> -O
 *  Flag for verbose output. Prints out performance of all
 *  selected models.</pre>
 *
 * <pre> -S <num>
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
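 *
 * An example command line (a sketch; the model-list file, working directory,
 * and data sets named here are hypothetical):
 * <pre>
 * java weka.classifiers.meta.EnsembleSelection \
 *    -L /path/to/models.mlf -W /path/to/workingDir \
 *    -B 10 -E 1.0 -V 0.25 -H 100 -X 10 -P rmse -A forward \
 *    -t train.arff -T test.arff
 * </pre>
 * <p/>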
 *
 * @author Robert Jung
 * @author David Michael
 * @version $Revision: 1.3 $
 */
public class EnsembleSelection extends RandomizableClassifier implements
    TechnicalInformationHandler {

  /** for serialization */
  private static final long serialVersionUID = -1744155148765058511L;

  /**
   * The Library of models, from which we can select our ensemble. Usually
   * loaded from a model list file (.mlf or .model.xml) using the -L
   * command-line option.
   */
  protected EnsembleSelectionLibrary m_library = new EnsembleSelectionLibrary();

  /**
   * List of models chosen by EnsembleSelection. Populated by buildClassifier.
   */
  protected EnsembleSelectionLibraryModel[] m_chosen_models = null;

  /**
   * An array of weights for the chosen models. Elements are parallel to those
   * in m_chosen_models. That is, m_chosen_model_weights[i] is the weight
   * associated with the model at m_chosen_models[i].
   */
  protected int[] m_chosen_model_weights = null;

  /** Total weight of all chosen models. */
  protected int m_total_weight = 0;

  /**
   * ratio of library models that will be randomly chosen to be used for each
   * model bag
   */
  protected double m_modelRatio = 0.5;

  /**
   * Indicates the fraction of the given training set that should be used for
   * hillclimbing/validation. This fraction is set aside and not used for
   * training. It is assumed that any loaded models were also not trained on
   * set-aside data. (If the same percentage and random seed were used
   * previously to train the models in the library, this will work as expected -
   * i.e., those models will be valid)
   */
  protected double m_validationRatio = 0.25;

  /** defines metrics that can be chosen for hillclimbing */
  public static final Tag[] TAGS_METRIC = {
    new Tag(EnsembleMetricHelper.METRIC_ACCURACY, "Optimize with Accuracy"),
    new Tag(EnsembleMetricHelper.METRIC_RMSE, "Optimize with RMSE"),
    new Tag(EnsembleMetricHelper.METRIC_ROC, "Optimize with ROC"),
    new Tag(EnsembleMetricHelper.METRIC_PRECISION, "Optimize with precision"),
    new Tag(EnsembleMetricHelper.METRIC_RECALL, "Optimize with recall"),
    new Tag(EnsembleMetricHelper.METRIC_FSCORE, "Optimize with fscore"),
    new Tag(EnsembleMetricHelper.METRIC_ALL, "Optimize with all metrics"), };
  /**
   * The "enumeration" of the algorithms we can use: forward selection,
   * backward elimination, forward selection plus backward elimination,
   * best single model, or build-library-only.
   */
  public static final int ALGORITHM_FORWARD = 0;

  public static final int ALGORITHM_BACKWARD = 1;

  public static final int ALGORITHM_FORWARD_BACKWARD = 2;

  public static final int ALGORITHM_BEST = 3;

  public static final int ALGORITHM_BUILD_LIBRARY = 4;
  /** defines the algorithms that can be chosen for ensemble selection */
  public static final Tag[] TAGS_ALGORITHM = {
    new Tag(ALGORITHM_FORWARD, "Forward selection"),
    new Tag(ALGORITHM_BACKWARD, "Backward elimination"),
    new Tag(ALGORITHM_FORWARD_BACKWARD,
        "Forward Selection + Backward Elimination"),
    new Tag(ALGORITHM_BEST, "Best model"),
    new Tag(ALGORITHM_BUILD_LIBRARY, "Build Library Only") };

  /**
   * this specifies the maximum number of "Ensemble-X" directories that are
   * allowed to be created in the user's home directory, where X is the
   * number of the ensemble
   */
  private static final int MAX_DEFAULT_DIRECTORIES = 1000;

  /**
   * The name of the Model Library File (if one is specified) which lists
   * models from which ensemble selection will choose. This is only used when
   * run from the command-line, as otherwise m_library is responsible for
   * this.
   */
  protected String m_modelLibraryFileName = null;

  /**
   * The number of "model bags". Using 1 is equivalent to no bagging at all.
   */
  protected int m_numModelBags = 10;

  /** The metric for which the ensemble will be optimized. */
  protected int m_hillclimbMetric = EnsembleMetricHelper.METRIC_RMSE;

  /** The algorithm used for ensemble selection. */
  protected int m_algorithm = ALGORITHM_FORWARD;

  /**
   * number of hillclimbing iterations for the ensemble selection algorithm
   */
  protected int m_hillclimbIterations = 100;

  /** ratio of library models to be used for sort initialization */
  protected double m_sortInitializationRatio = 1.0;

  /**
   * specifies whether or not the ensemble algorithm is allowed to include a
   * specific model in the library more than once in each ensemble
   */
  protected boolean m_replacement = true;

  /**
   * specifies whether we use "greedy" sort initialization. If false, we
   * simply add the best m_sortInitializationRatio models of the bag blindly.
   * If true, we add the best models in order up to m_sortInitializationRatio
   * until adding the next model would not help performance.
   */
  protected boolean m_greedySortInitialization = true;

  /**
   * Specifies whether or not we will output metrics for all models
   */
  protected boolean m_verboseOutput = false;

  /**
   * Hash map of cached predictions. The key is a stringified Instance. Each
   * entry is a 2d array, first indexed by classifier index (i.e., the one
   * used in m_chosen_models). The second index is the usual "distribution"
   * index across classes.
   */
  protected Map m_cachedPredictions = null;
  /**
   * The working directory where all models, temporary prediction values, and
   * model list logs are to be built and stored.
   */
  protected File m_workingDirectory = new File(getDefaultWorkingDirectory());

  /**
   * Indicates the number of folds for cross-validation. A value of 1
   * indicates there is no cross-validation. Cross validation is done in the
   * "embedded" fashion described by Caruana, Niculescu, and Munson
   * (unpublished work - tech report forthcoming)
   */
  protected int m_NumFolds = 1;

  /**
   * Returns a string describing classifier
   *
   * @return a description suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String globalInfo() {

    return "Combines several classifiers using the ensemble "
        + "selection method. For more information, see: "
        + "Caruana, Rich, Niculescu, Alex, Crew, Geoff, and Ksikes, Alex, "
        + "Ensemble Selection from Libraries of Models, "
        + "The International Conference on Machine Learning (ICML'04), 2004. "
        + "Implemented in Weka by Bob Jung and David Michael.";
  }

  /**
   * Returns an enumeration describing the available options.
   *
   * @return an enumeration of all the available options.
   */
  public Enumeration listOptions() {
    Vector result = new Vector();

    result.addElement(new Option(
        "\tSpecifies the Model Library File, containing the list of all models.",
        "L", 1, "-L </path/to/modelLibrary>"));

    result.addElement(new Option(
        "\tSpecifies the Working Directory, where all models will be stored.",
        "W", 1, "-W </path/to/working/directory>"));

    result.addElement(new Option(
        "\tSet the number of bags, i.e., number of iterations to run \n"
            + "\tthe ensemble selection algorithm.",
        "B", 1, "-B <numModelBags>"));

    result.addElement(new Option(
        "\tSet the ratio of library models that will be randomly chosen \n"
            + "\tto populate each bag of models.",
        "E", 1, "-E <modelRatio>"));

    result.addElement(new Option(
        "\tSet the ratio of the training data set that will be reserved \n"
            + "\tfor validation.",
        "V", 1, "-V <validationRatio>"));

    result.addElement(new Option(
        "\tSet the number of hillclimbing iterations to be performed \n"
            + "\ton each model bag.",
        "H", 1, "-H <hillClimbIterations>"));

    result.addElement(new Option(
        "\tSet the ratio of the ensemble library that the sort \n"
            + "\tinitialization algorithm will be able to choose from while \n"
            + "\tinitializing the ensemble for each model bag",
        "I", 1, "-I <sortInitialization>"));

    result.addElement(new Option(
        "\tSets the number of cross-validation folds.", "X", 1,
        "-X <numFolds>"));

    result.addElement(new Option(
        "\tSpecify the metric that will be used for model selection \n"
            + "\tduring the hillclimbing algorithm.\n"
            + "\tValid metrics are: \n"
            + "\t\taccuracy, rmse, roc, precision, recall, fscore, all",
        "P", 1, "-P <hillclimbMetric>"));

    result.addElement(new Option(
        "\tSpecifies the algorithm to be used for ensemble selection. \n"
            + "\tValid algorithms are:\n"
            + "\t\t\"forward\" (default) for forward selection.\n"
            + "\t\t\"backward\" for backward elimination.\n"
            + "\t\t\"both\" for both forward and backward elimination.\n"
            + "\t\t\"best\" to simply print out top performer from the \n"
            + "\t\t ensemble library\n"
            + "\t\t\"library\" to only train the models in the ensemble \n"
            + "\t\t library",
        "A", 1, "-A <algorithm>"));

    result.addElement(new Option(
        "\tFlag whether or not models can be selected more than once \n"
            + "\tfor an ensemble.",
        "R", 0, "-R"));

    result.addElement(new Option(
        "\tWhether sort initialization greedily stops adding models \n"
            + "\twhen performance degrades.",
        "G", 0, "-G"));

    result.addElement(new Option(
        "\tFlag for verbose output. Prints out performance of all \n"
            + "\tselected models.",
        "O", 0, "-O"));

    // TODO - Add more options here
    Enumeration enu = super.listOptions();
    while (enu.hasMoreElements()) {
      result.addElement(enu.nextElement());
    }

    return result.elements();
  }

  /**
   * We return true for basically everything except for missing class values,
   * because we can't really answer for all the models in our library. If any
   * of them don't work with the supplied data then we just trap the
   * exception.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    // returns the object from weka.classifiers.Classifier
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.DATE_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);
    result.enable(Capability.BINARY_ATTRIBUTES);

    // class
    result.enable(Capability.NOMINAL_CLASS);
    result.enable(Capability.NUMERIC_CLASS);
    result.enable(Capability.BINARY_CLASS);

    return result;
  }

  /**
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -L </path/to/modelLibrary>
   *  Specifies the Model Library File, containing the list of all models.</pre>
   *
   * <pre> -W </path/to/working/directory>
   *  Specifies the Working Directory, where all models will be stored.</pre>
   *
   * <pre> -B <numModelBags>
   *  Set the number of bags, i.e., number of iterations to run
   *  the ensemble selection algorithm.</pre>
   *
   * <pre> -E <modelRatio>
   *  Set the ratio of library models that will be randomly chosen
   *  to populate each bag of models.</pre>
   *
   * <pre> -V <validationRatio>
   *  Set the ratio of the training data set that will be reserved
   *  for validation.</pre>
   *
   * <pre> -H <hillClimbIterations>
   *  Set the number of hillclimbing iterations to be performed
   *  on each model bag.</pre>
   *
   * <pre> -I <sortInitialization>
   *  Set the ratio of the ensemble library that the sort
   *  initialization algorithm will be able to choose from while
   *  initializing the ensemble for each model bag</pre>
   *
   * <pre> -X <numFolds>
   *  Sets the number of cross-validation folds.</pre>
   *
   * <pre> -P <hillclimbMetric>
   *  Specify the metric that will be used for model selection
   *  during the hillclimbing algorithm.
   *  Valid metrics are:
   *   accuracy, rmse, roc, precision, recall, fscore, all</pre>
   *
   * <pre> -A <algorithm>
   *  Specifies the algorithm to be used for ensemble selection.
   *  Valid algorithms are:
   *   "forward" (default) for forward selection.
   *   "backward" for backward elimination.
   *   "both" for both forward and backward elimination.
   *   "best" to simply print out top performer from the
   *    ensemble library
   *   "library" to only train the models in the ensemble
   *    library</pre>
   *
   * <pre> -R
   *  Flag whether or not models can be selected more than once
   *  for an ensemble.</pre>
   *
   * <pre> -G
   *  Whether sort initialization greedily stops adding models
   *  when performance degrades.</pre>
   *
   * <pre> -O
   *  Flag for verbose output. Prints out performance of all
   *  selected models.</pre>
   *
   * <pre> -S <num>
   *  Random number seed.
   *  (default 1)</pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   <!-- options-end -->
   *
   * @param options
   *            the list of options as an array of strings
   * @throws Exception
   *             if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    String tmpStr;

    tmpStr = Utils.getOption('L', options);
    if (tmpStr.length() != 0) {
      m_modelLibraryFileName = tmpStr;
      m_library = new EnsembleSelectionLibrary(m_modelLibraryFileName);
    } else {
      setLibrary(new EnsembleSelectionLibrary());
      // setLibrary(new Library(super.m_Classifiers));
    }

    tmpStr = Utils.getOption('W', options);
    if (tmpStr.length() != 0 && validWorkingDirectory(tmpStr)) {
      m_workingDirectory = new File(tmpStr);
    } else {
      m_workingDirectory = new File(getDefaultWorkingDirectory());
    }
    m_library.setWorkingDirectory(m_workingDirectory);

    tmpStr = Utils.getOption('E', options);
    if (tmpStr.length() != 0) {
      setModelRatio(Double.parseDouble(tmpStr));
    } else {
      setModelRatio(1.0);
    }

    tmpStr = Utils.getOption('V', options);
    if (tmpStr.length() != 0) {
      setValidationRatio(Double.parseDouble(tmpStr));
    } else {
      setValidationRatio(0.25);
    }

    tmpStr = Utils.getOption('B', options);
    if (tmpStr.length() != 0) {
      setNumModelBags(Integer.parseInt(tmpStr));
    } else {
      setNumModelBags(10);
    }

    tmpStr = Utils.getOption('H', options);
    if (tmpStr.length() != 0) {
      setHillclimbIterations(Integer.parseInt(tmpStr));
    } else {
      setHillclimbIterations(100);
    }

    tmpStr = Utils.getOption('I', options);
    if (tmpStr.length() != 0) {
      setSortInitializationRatio(Double.parseDouble(tmpStr));
    } else {
      setSortInitializationRatio(1.0);
    }

    tmpStr = Utils.getOption('X', options);
    if (tmpStr.length() != 0) {
      setNumFolds(Integer.parseInt(tmpStr));
    } else {
      setNumFolds(10);
    }

    setReplacement(Utils.getFlag('R', options));

    setGreedySortInitialization(Utils.getFlag('G', options));

    setVerboseOutput(Utils.getFlag('O', options));

    tmpStr = Utils.getOption('P', options);
    if (tmpStr.toLowerCase().equals("accuracy")) {
      setHillclimbMetric(new SelectedTag(
          EnsembleMetricHelper.METRIC_ACCURACY, TAGS_METRIC));
    } else if (tmpStr.toLowerCase().equals("rmse")) {
      setHillclimbMetric(new SelectedTag(
          EnsembleMetricHelper.METRIC_RMSE, TAGS_METRIC));
    } else if (tmpStr.toLowerCase().equals("roc")) {
      setHillclimbMetric(new SelectedTag(
          EnsembleMetricHelper.METRIC_ROC, TAGS_METRIC));
    } else if (tmpStr.toLowerCase().equals("precision")) {
      setHillclimbMetric(new SelectedTag(
          EnsembleMetricHelper.METRIC_PRECISION, TAGS_METRIC));
    } else if (tmpStr.toLowerCase().equals("recall")) {
      setHillclimbMetric(new SelectedTag(
          EnsembleMetricHelper.METRIC_RECALL, TAGS_METRIC));
    } else if (tmpStr.toLowerCase().equals("fscore")) {
      setHillclimbMetric(new SelectedTag(
          EnsembleMetricHelper.METRIC_FSCORE, TAGS_METRIC));
    } else if (tmpStr.toLowerCase().equals("all")) {
      setHillclimbMetric(new SelectedTag(
          EnsembleMetricHelper.METRIC_ALL, TAGS_METRIC));
    } else {
      // default to RMSE if the metric is unrecognized or unspecified
      setHillclimbMetric(new SelectedTag(
          EnsembleMetricHelper.METRIC_RMSE, TAGS_METRIC));
    }

    tmpStr = Utils.getOption('A', options);
    if (tmpStr.toLowerCase().equals("forward")) {
      setAlgorithm(new SelectedTag(ALGORITHM_FORWARD, TAGS_ALGORITHM));
    } else if (tmpStr.toLowerCase().equals("backward")) {
      setAlgorithm(new SelectedTag(ALGORITHM_BACKWARD, TAGS_ALGORITHM));
    } else if (tmpStr.toLowerCase().equals("both")) {
      setAlgorithm(new SelectedTag(ALGORITHM_FORWARD_BACKWARD,
          TAGS_ALGORITHM));
    } else if (tmpStr.toLowerCase().equals("best")) {
      setAlgorithm(new SelectedTag(ALGORITHM_BEST, TAGS_ALGORITHM));
    } else if (tmpStr.toLowerCase().equals("library")) {
      setAlgorithm(new SelectedTag(ALGORITHM_BUILD_LIBRARY,
          TAGS_ALGORITHM));
    } else {
      // default to forward selection
      setAlgorithm(new SelectedTag(ALGORITHM_FORWARD, TAGS_ALGORITHM));
    }

    super.setOptions(options);

    m_library.setDebug(m_Debug);
  }

  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    Vector result;
    String[] options;
    int i;

    result = new Vector();

    if (m_library.getModelListFile() != null) {
      result.add("-L");
      result.add("" + m_library.getModelListFile());
    }

    // only emit -W if a working directory path has been set
    if (!m_workingDirectory.getPath().equals("")) {
      result.add("-W");
      result.add("" + getWorkingDirectory());
    }

    result.add("-P");
    switch (getHillclimbMetric().getSelectedTag().getID()) {
    case (EnsembleMetricHelper.METRIC_ACCURACY):
      result.add("accuracy");
      break;
    case (EnsembleMetricHelper.METRIC_RMSE):
      result.add("rmse");
      break;
    case (EnsembleMetricHelper.METRIC_ROC):
      result.add("roc");
      break;
    case (EnsembleMetricHelper.METRIC_PRECISION):
      result.add("precision");
      break;
    case (EnsembleMetricHelper.METRIC_RECALL):
      result.add("recall");
      break;
    case (EnsembleMetricHelper.METRIC_FSCORE):
      result.add("fscore");
      break;
    case (EnsembleMetricHelper.METRIC_ALL):
      result.add("all");
      break;
    }

    result.add("-A");
    switch (getAlgorithm().getSelectedTag().getID()) {
    case (ALGORITHM_FORWARD):
      result.add("forward");
      break;
    case (ALGORITHM_BACKWARD):
      result.add("backward");
      break;
    case (ALGORITHM_FORWARD_BACKWARD):
      result.add("both");
      break;
    case (ALGORITHM_BEST):
      result.add("best");
      break;
    case (ALGORITHM_BUILD_LIBRARY):
      result.add("library");
      break;
    }

    result.add("-B");
    result.add("" + getNumModelBags());
    result.add("-V");
    result.add("" + getValidationRatio());
    result.add("-E");
    result.add("" + getModelRatio());
    result.add("-H");
    result.add("" + getHillclimbIterations());
    result.add("-I");
    result.add("" + getSortInitializationRatio());
    result.add("-X");
    result.add("" + getNumFolds());

    if (m_replacement)
      result.add("-R");
    if (m_greedySortInitialization)
      result.add("-G");
    if (m_verboseOutput)
      result.add("-O");

    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String numFoldsTipText() {
    return "The number of folds used for cross-validation.";
  }

  /**
   * Gets the number of folds for the cross-validation.
   *
   * @return the number of folds for the cross-validation
   */
  public int getNumFolds() {
    return m_NumFolds;
  }

  /**
   * Sets the number of folds for the cross-validation.
   *
   * @param numFolds
   *            the number of folds for the cross-validation
   * @throws Exception
   *             if parameter illegal
   */
  public void setNumFolds(int numFolds) throws Exception {
    if (numFolds < 0) {
      throw new IllegalArgumentException(
          "EnsembleSelection: Number of cross-validation "
              + "folds cannot be negative.");
    }
    m_NumFolds = numFolds;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String libraryTipText() {
    return "An ensemble library.";
  }

  /**
   * Gets the ensemble library.
   *
   * @return the ensemble library
   */
  public EnsembleSelectionLibrary getLibrary() {
    return m_library;
  }

  /**
   * Sets the ensemble library.
   *
   * @param newLibrary
   *            the ensemble library
   */
  public void setLibrary(EnsembleSelectionLibrary newLibrary) {
    m_library = newLibrary;
    m_library.setDebug(m_Debug);
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String modelRatioTipText() {
    return "The ratio of library models that will be randomly chosen to be used for each iteration.";
  }

  /**
   * Get the value of modelRatio.
   *
   * @return Value of modelRatio.
   */
  public double getModelRatio() {
    return m_modelRatio;
  }

  /**
   * Set the value of modelRatio.
   *
   * @param v
   *            Value to assign to modelRatio.
   */
  public void setModelRatio(double v) {
    m_modelRatio = v;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String validationRatioTipText() {
    return "The ratio of the training data set that will be reserved for validation.";
  }

  /**
   * Get the value of validationRatio.
   *
   * @return Value of validationRatio.
   */
  public double getValidationRatio() {
    return m_validationRatio;
  }

  /**
   * Set the value of validationRatio.
   *
   * @param v
   *            Value to assign to validationRatio.
   */
  public void setValidationRatio(double v) {
    m_validationRatio = v;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String hillclimbMetricTipText() {
    return "The metric that will be used to optimize the chosen ensemble.";
  }

  /**
   * Gets the hill climbing metric. Will be one of METRIC_ACCURACY,
   * METRIC_RMSE, METRIC_ROC, METRIC_PRECISION, METRIC_RECALL, METRIC_FSCORE,
   * METRIC_ALL
   *
   * @return the hillclimbMetric
   */
  public SelectedTag getHillclimbMetric() {
    return new SelectedTag(m_hillclimbMetric, TAGS_METRIC);
  }

  /**
   * Sets the hill climbing metric. Will be one of METRIC_ACCURACY,
   * METRIC_RMSE, METRIC_ROC, METRIC_PRECISION, METRIC_RECALL, METRIC_FSCORE,
   * METRIC_ALL
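   * <p>
   * For example, a caller might select ROC like this (a sketch; "es" is a
   * hypothetical EnsembleSelection instance):
   * <pre>
   * es.setHillclimbMetric(new SelectedTag(
   *     EnsembleMetricHelper.METRIC_ROC, EnsembleSelection.TAGS_METRIC));
   * </pre>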
   *
   * @param newType
   *            the new hillclimbMetric
   */
  public void setHillclimbMetric(SelectedTag newType) {
    if (newType.getTags() == TAGS_METRIC) {
      m_hillclimbMetric = newType.getSelectedTag().getID();
    }
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String algorithmTipText() {
    return "The algorithm used to optimize the ensemble.";
  }

  /**
   * Gets the algorithm
   *
   * @return the algorithm
   */
  public SelectedTag getAlgorithm() {
    return new SelectedTag(m_algorithm, TAGS_ALGORITHM);
  }

  /**
   * Sets the Algorithm to use
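   * <p>
   * For example, backward elimination could be selected like this (a sketch;
   * "es" is a hypothetical EnsembleSelection instance):
   * <pre>
   * es.setAlgorithm(new SelectedTag(
   *     EnsembleSelection.ALGORITHM_BACKWARD,
   *     EnsembleSelection.TAGS_ALGORITHM));
   * </pre>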
   *
   * @param newType
   *            the new algorithm
   */
  public void setAlgorithm(SelectedTag newType) {
    if (newType.getTags() == TAGS_ALGORITHM) {
      m_algorithm = newType.getSelectedTag().getID();
    }
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String hillclimbIterationsTipText() {
    return "The number of hillclimbing iterations for the ensemble selection algorithm.";
  }

  /**
   * Gets the number of hillclimbIterations.
   *
   * @return the number of hillclimbIterations
   */
  public int getHillclimbIterations() {
    return m_hillclimbIterations;
  }

  /**
   * Sets the number of hillclimbIterations.
   *
   * @param n
   *            the number of hillclimbIterations
   * @throws Exception
   *             if parameter illegal
   */
  public void setHillclimbIterations(int n) throws Exception {
    if (n < 0) {
      throw new IllegalArgumentException(
          "EnsembleSelection: Number of hillclimb iterations "
              + "cannot be negative.");
    }
    m_hillclimbIterations = n;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String numModelBagsTipText() {
    return "The number of \"model bags\" used in the ensemble selection algorithm.";
  }

  /**
   * Gets numModelBags.
   *
   * @return numModelBags
   */
  public int getNumModelBags() {
    return m_numModelBags;
  }

  /**
   * Sets numModelBags.
   *
   * @param n
   *            the new value for numModelBags
   * @throws Exception
   *             if parameter illegal
   */
  public void setNumModelBags(int n) throws Exception {
    if (n <= 0) {
      throw new IllegalArgumentException(
          "EnsembleSelection: Number of model bags "
              + "must be positive.");
    }
    m_numModelBags = n;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String sortInitializationRatioTipText() {
    return "The ratio of library models to be used for sort initialization.";
  }

  /**
   * Get the value of sortInitializationRatio.
   *
   * @return Value of sortInitializationRatio.
   */
  public double getSortInitializationRatio() {
    return m_sortInitializationRatio;
  }

  /**
   * Set the value of sortInitializationRatio.
   *
   * @param v
   *            Value to assign to sortInitializationRatio.
   */
  public void setSortInitializationRatio(double v) {
    m_sortInitializationRatio = v;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String replacementTipText() {
    return "Whether models in the library can be included more than once in an ensemble.";
  }

  /**
   * Get the value of replacement.
   *
   * @return Value of replacement.
   */
  public boolean getReplacement() {
    return m_replacement;
  }

  /**
   * Set the value of replacement.
   *
   * @param newReplacement
   *            Value to assign to replacement.
   */
  public void setReplacement(boolean newReplacement) {
    m_replacement = newReplacement;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String greedySortInitializationTipText() {
    return "Whether sort initialization greedily stops adding models when performance degrades.";
  }

  /**
   * Get the value of greedySortInitialization.
   *
   * @return Value of greedySortInitialization.
   */
  public boolean getGreedySortInitialization() {
    return m_greedySortInitialization;
  }

  /**
   * Set the value of greedySortInitialization.
   *
   * @param newGreedySortInitialization
   *            Value to assign to greedySortInitialization.
   */
  public void setGreedySortInitialization(
      boolean newGreedySortInitialization) {
    m_greedySortInitialization = newGreedySortInitialization;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String verboseOutputTipText() {
    return "Whether metrics are printed for each model.";
  }

  /**
   * Get the value of verboseOutput.
   *
   * @return Value of verboseOutput.
   */
  public boolean getVerboseOutput() {
    return m_verboseOutput;
  }

  /**
   * Set the value of verboseOutput.
   *
   * @param newVerboseOutput
   *            Value to assign to verboseOutput.
   */
  public void setVerboseOutput(boolean newVerboseOutput) {
    m_verboseOutput = newVerboseOutput;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for displaying in the
   *         explorer/experimenter gui
   */
  public String workingDirectoryTipText() {
    return "The working directory of the ensemble - where trained models will be stored.";
  }

  /**
   * Get the value of working directory.
   *
   * @return Value of working directory.
   */
  public File getWorkingDirectory() {
    return m_workingDirectory;
  }

  /**
   * Set the value of working directory.
   *
   * @param newWorkingDirectory
   *            the new working directory
   */
  public void setWorkingDirectory(File newWorkingDirectory) {
    if (m_Debug) {
      System.out.println("working directory changed to: "
          + newWorkingDirectory);
    }
    m_library.setWorkingDirectory(newWorkingDirectory);

    m_workingDirectory = newWorkingDirectory;
  }

  /**
   * buildClassifier builds the ensemble: it trains any untrained library
   * models and then selects an ensemble that optimizes the chosen metric on
   * the reserved hillclimb/validation data.
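   * <p>
   * A minimal sketch of programmatic use (data loading and the class index
   * are assumed to be handled by the caller):
   * <pre>
   * Instances train = ...; // training data with class attribute set
   * EnsembleSelection es = new EnsembleSelection();
   * es.buildClassifier(train);
   * double[] dist = es.distributionForInstance(train.instance(0));
   * </pre>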
   *
   * @param trainData the training data to be used for building the
   * ensemble.
   * @throws Exception if the classifier could not be built successfully
   */
  public void buildClassifier(Instances trainData) throws Exception {

    getCapabilities().testWithFail(trainData);

    if (m_library == null) {
      m_library = new EnsembleSelectionLibrary();
      m_library.setDebug(m_Debug);
    }

    // First we need to make sure that some library models
    // were specified. If not, then use the default list.
    if (m_library.m_Models.size() == 0) {

      System.out.println("WARNING: No library file specified. Using some default models.");
      System.out.println("You should specify a model list with -L <file> from the command line.");
      System.out.println("Or edit the list directly with the LibraryEditor from the GUI");

      for (int i = 0; i < 10; i++) {
        REPTree tree = new REPTree();
        tree.setSeed(i);
        m_library.addModel(new EnsembleSelectionLibraryModel(tree));
      }
    }

    m_library.setNumFolds(getNumFolds());
    m_library.setValidationRatio(getValidationRatio());
    // train all untrained models, and set "data" to the hillclimbing set.
    Instances data = m_library.trainAll(trainData,
        m_workingDirectory.getAbsolutePath(), m_algorithm);
    // We cache the hillclimb predictions from all of the models in the
    // library so that we can evaluate their performances when we combine
    // them in various ways (without needing to keep the classifiers in
    // memory).
    double predictions[][][] = m_library.getHillclimbPredictions();
    int numModels = predictions.length;
    int modelWeights[] = new int[numModels];
    m_total_weight = 0;
    Random rand = new Random(m_Seed);

    if (m_algorithm == ALGORITHM_BUILD_LIBRARY) {
      return;

    } else if (m_algorithm == ALGORITHM_BEST) {
      // If we want to choose the best model, just make a model bag that
      // includes all the models, then sort initialize to find the 1 that
      // performs best.
      ModelBag model_bag = new ModelBag(predictions, 1.0, m_Debug);
      int[] modelPicked = model_bag.sortInitialize(1, false, data,
          m_hillclimbMetric);
      // Then give it a weight of 1, while all others remain 0.
      modelWeights[modelPicked[0]] = 1;
    } else {

      if (m_Debug)
        System.out.println("Starting hillclimbing algorithm: "
            + m_algorithm);

      for (int i = 0; i < getNumModelBags(); ++i) {
        // For the number of bags,
        if (m_Debug)
          System.out.println("Starting on ensemble bag: " + i);
        // Create a new bag of the appropriate size
        ModelBag modelBag = new ModelBag(predictions, getModelRatio(),
            m_Debug);
        // And shuffle it.
        modelBag.shuffle(rand);
        if (getSortInitializationRatio() > 0.0) {
          // Sort initialize, if the ratio is greater than 0.
          modelBag.sortInitialize((int) (getSortInitializationRatio()
              * getModelRatio() * numModels),
              getGreedySortInitialization(), data, m_hillclimbMetric);
        }

        if (m_algorithm == ALGORITHM_BACKWARD) {
          // If we're doing backwards elimination, we just give all
          // models a weight of 1 initially. If the # of hillclimb
          // iterations is too high, we'll end up with just one model in
          // the end (we never delete all models from a bag). TODO - it
          // might be smarter to base this weight off of how many models
          // we have.
          modelBag.weightAll(1); // for now at least, I'm just assuming 1.
        }
        // Now the bag is initialized, and we're ready to hillclimb.
        for (int j = 0; j < getHillclimbIterations(); ++j) {
          if (m_algorithm == ALGORITHM_FORWARD) {
            modelBag.forwardSelect(getReplacement(), data,
                m_hillclimbMetric);
          } else if (m_algorithm == ALGORITHM_BACKWARD) {
            modelBag.backwardEliminate(data, m_hillclimbMetric);
          } else if (m_algorithm == ALGORITHM_FORWARD_BACKWARD) {
            modelBag.forwardSelectOrBackwardEliminate(getReplacement(),
                data, m_hillclimbMetric);
          }
        }
        // Now that we've done all the hillclimbing steps, we can just
        // get the model weights that the bag determined, and add them
        // to our running total.
        int[] bagWeights = modelBag.getModelWeights();
        for (int j = 0; j < bagWeights.length; ++j) {
          modelWeights[j] += bagWeights[j];
        }
      }
    }
    // Now we've done the hard work of actually learning the ensemble. Now
    // we set up the appropriate data structures so that Ensemble Selection
    // can make predictions for future test examples.
    Set modelNames = m_library.getModelNames();
    String[] modelNamesArray = new String[m_library.size()];
    Iterator iter = modelNames.iterator();
    // libraryIndex indexes over all the models in the library (not just
    // those which we chose for the ensemble).
    int libraryIndex = 0;
    // chosenModels will count the total number of models which were
    // selected by EnsembleSelection (those that have non-zero weight).
    int chosenModels = 0;
    while (iter.hasNext()) {
      // Note that we have to be careful of order. Our model_weights array
      // is in the same order as our list of models in m_library.

      // Get the name of the model,
      modelNamesArray[libraryIndex] = (String) iter.next();
      // and its weight.
      int weightOfModel = modelWeights[libraryIndex++];
      m_total_weight += weightOfModel;
      if (weightOfModel > 0) {
        // If the model was chosen at least once, increment the
        // number of chosen models.
        ++chosenModels;
      }
    }
    if (m_verboseOutput) {
      // Output every model and its performance with respect to the
      // validation data.
      ModelBag bag = new ModelBag(predictions, 1.0, m_Debug);
      int modelIndexes[] = bag.sortInitialize(modelNamesArray.length,
          false, data, m_hillclimbMetric);
      double modelPerformance[] = bag.getIndividualPerformance(data,
          m_hillclimbMetric);
      for (int i = 0; i < modelIndexes.length; ++i) {
        // TODO - Could do this in a more readable way.
        System.out.println("" + modelPerformance[i] + " "
            + modelNamesArray[modelIndexes[i]]);
      }
    }
    // We're now ready to build our array of the models which were chosen
    // and their associated weights.
    m_chosen_models = new EnsembleSelectionLibraryModel[chosenModels];
    m_chosen_model_weights = new int[chosenModels];

    libraryIndex = 0;
    // chosenIndex indexes over the models which were chosen by
    // EnsembleSelection (those which have non-zero weight).
    int chosenIndex = 0;
    iter = m_library.getModels().iterator();
    while (iter.hasNext()) {
      int weightOfModel = modelWeights[libraryIndex++];

      EnsembleSelectionLibraryModel model =
          (EnsembleSelectionLibraryModel) iter.next();

      if (weightOfModel > 0) {
        // If the model was chosen at least once, add it to our array
        // of chosen models and weights.
        m_chosen_models[chosenIndex] = model;
        m_chosen_model_weights[chosenIndex] = weightOfModel;
        // Note that the EnsembleSelectionLibraryModel may not be
        // "loaded" - that is, its classifier(s) may be null pointers.
        // That's okay - we'll "rehydrate" them later, if and when we
        // need to.
        ++chosenIndex;
      }
    }
  }

  /**
   * Calculates the class membership probabilities for the given test
   * instance.
   *
   * @param instance the instance to be classified
   * @return predicted class probability distribution
   * @throws Exception if instance could not be classified successfully
   */
  public double[] distributionForInstance(Instance instance)
      throws Exception {
    String stringInstance = instance.toString();
    double cachedPreds[][] = null;

    if (m_cachedPredictions != null) {
      // If we have any cached predictions (i.e., if cachePredictions was
      // called), look for a cached set of predictions for this instance.
      if (m_cachedPredictions.containsKey(stringInstance)) {
        cachedPreds = (double[][]) m_cachedPredictions.get(stringInstance);
      }
    }
    double[] prediction = new double[instance.numClasses()];
    for (int i = 0; i < prediction.length; ++i) {
      prediction[i] = 0.0;
    }

    // Now do a weighted average of the predictions of each of our models.
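    // That is, for each class j the combined distribution is
    //   prediction[j] = sum_i (m_chosen_model_weights[i] * p_i[j]) / m_total_weight
    // where p_i is the distribution returned by chosen model i, so a model's
    // influence is proportional to how often the hillclimb selected it.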
    for (int i = 0; i < m_chosen_models.length; ++i) {
      double[] predictionForThisModel = null;
      if (cachedPreds == null) {
        // If there are no predictions cached, we'll load the model's
        // classifier(s) in to memory and get the predictions.
        m_chosen_models[i].rehydrateModel(m_workingDirectory
            .getAbsolutePath());
        predictionForThisModel = m_chosen_models[i]
            .getAveragePrediction(instance);
        // We could release the model here to save memory, but we assume
        // that there is enough available since we're not using the
        // prediction caching functionality. If we load and release a
        // model every time we need to get a prediction for an instance,
        // it can be prohibitively slow.
      } else {
        // If it's cached, just get it from the array of cached preds
        // for this instance.
        predictionForThisModel = cachedPreds[i];
      }
      // We have encountered a bug where MultilayerPerceptron returns a
      // null prediction array. If that happens, we just don't count that
      // model in our ensemble prediction.
      if (predictionForThisModel != null) {
        // Okay, the model returned a valid prediction array, so we'll
        // add the appropriate fraction of this model's prediction.
        for (int j = 0; j < prediction.length; ++j) {
          prediction[j] += m_chosen_model_weights[i]
              * predictionForThisModel[j] / m_total_weight;
        }
      }
    }
    // normalize to add up to 1.
    if (instance.classAttribute().isNominal()) {
      if (Utils.sum(prediction) > 0)
        Utils.normalize(prediction);
    }
    return prediction;
  }

  /**
   * This function tests whether or not a given path is appropriate for being
   * the working directory. Specifically, we care that we can write to the
   * path and that it doesn't point to a "non-directory" file handle.
   *
   * @param dir the directory to test
   * @return true if the directory is valid
   */
  private boolean validWorkingDirectory(String dir) {

    boolean valid = false;

    File f = new File(dir);

    if (f.exists()) {
      if (f.isDirectory() && f.canWrite())
        valid = true;
    } else {
      if (f.canWrite())
        valid = true;
    }

    return valid;
  }

  /**
   * This method tries to find a reasonable path name for the ensemble working
   * directory where models and files will be stored.
   *
   * @return the path of a writable default working directory, or the empty
   *         string if none could be found
   */
  public static String getDefaultWorkingDirectory() {

    String defaultDirectory = "";

    boolean success = false;

    int i = 1;

    while (i < MAX_DEFAULT_DIRECTORIES && !success) {

      File f = new File(System.getProperty("user.home"), "Ensemble-" + i);

      if (!f.exists() && f.getParentFile().canWrite()) {
        defaultDirectory = f.getPath();
        success = true;
      }
      i++;
    }

    if (!success) {
      defaultDirectory = "";
      // should we print an error or something?
    }

    return defaultDirectory;
  }

  /**
   * Output a representation of this classifier
   *
   * @return a string representation of the classifier
   */
  public String toString() {
    // We just print out the models which were selected, and the number
    // of times each was selected.
    String result = "";
    if (m_chosen_models != null) {
      for (int i = 0; i < m_chosen_models.length; ++i) {
        result += m_chosen_model_weights[i];
        result += " " + m_chosen_models[i].getStringRepresentation()
            + "\n";
      }
    } else {
      result = "No models selected.";
    }
    return result;
  }

  /**
   * Cache predictions for the individual base classifiers in the ensemble
   * with respect to the given dataset. This is used so that when testing a
   * large ensemble on a test set, we don't have to keep the models in memory.
   *
   * @param test The instances for which to cache predictions.
   * @throws Exception if something goes wrong
   */
  private void cachePredictions(Instances test) throws Exception {
    m_cachedPredictions = new HashMap();
    Evaluation evalModel = null;
    Instances originalInstances = null;
    // If the verbose flag is set, we'll also print out the performances of
    // all the individual models w.r.t. this test set while we're at it.
    boolean printModelPerformances = getVerboseOutput();
    if (printModelPerformances) {
      // To get performances, we need to keep the class attribute.
      originalInstances = new Instances(test);
    }

    // For each model, we'll go through the dataset and get predictions.
    // The idea is we want to only have one model in memory at a time, so
    // we'll load one model in to memory, get all its predictions, and add
    // them to the hash map. Then we can release it from memory and move on
    // to the next.
    for (int i = 0; i < m_chosen_models.length; ++i) {
      if (printModelPerformances) {
        // If we're going to print predictions, we need to make a new
        // Evaluation object.
        evalModel = new Evaluation(originalInstances);
      }

      Date startTime = new Date();

      // Load the model in to memory.
      m_chosen_models[i].rehydrateModel(m_workingDirectory
          .getAbsolutePath());
      // Now loop through all the instances and get the model's
      // predictions.
      for (int j = 0; j < test.numInstances(); ++j) {
        Instance currentInstance = test.instance(j);
        // When we're looking for a cached prediction later, we'll only
        // have the non-class attributes, so we set the class missing
        // here in order to make the string match up properly.
        currentInstance.setClassMissing();
        String stringInstance = currentInstance.toString();

        // When we come in here with the first model, the instance will
        // not yet be part of the map.
        if (!m_cachedPredictions.containsKey(stringInstance)) {
          // The instance isn't in the map yet, so add it.
          // For each instance, we store a two-dimensional array - the
          // first index is over all the models in the ensemble, and
          // the second index is over the classes (i.e., the typical
          // prediction array).
          int predSize = test.classAttribute().isNumeric() ? 1 : test
              .classAttribute().numValues();
          double predictionArray[][] =
              new double[m_chosen_models.length][predSize];
          m_cachedPredictions.put(stringInstance, predictionArray);
        }
        // Get the array from the map which is associated with this
        // instance
        double predictions[][] = (double[][]) m_cachedPredictions
            .get(stringInstance);
        // And add our model's prediction for it.
        predictions[i] = m_chosen_models[i].getAveragePrediction(test
            .instance(j));

        if (printModelPerformances) {
          evalModel.evaluateModelOnceAndRecordPrediction(
              predictions[i], originalInstances.instance(j));
        }
      }
      // Now we're done with model #i, so we can release it.
      m_chosen_models[i].releaseModel();

      Date endTime = new Date();
      long diff = endTime.getTime() - startTime.getTime();

      if (m_Debug)
        System.out.println("Test time for "
            + m_chosen_models[i].getStringRepresentation() + " was: "
            + diff);

      if (printModelPerformances) {
        String output = m_chosen_models[i].getStringRepresentation()
            + ": ";
        output += "\tRMSE:" + evalModel.rootMeanSquaredError();
        output += "\tACC:" + evalModel.pctCorrect();
        if (test.numClasses() == 2) {
          // For multiclass problems, we could print these too, but
          // it's not clear which class we should use in that case...
          // so instead we only print these metrics for binary
          // classification problems.
          output += "\tROC:" + evalModel.areaUnderROC(1);
          output += "\tPREC:" + evalModel.precision(1);
          output += "\tFSCR:" + evalModel.fMeasure(1);
        }
        System.out.println(output);
      }
    }
  }

  /**
   * Return the technical information. There is actually another
   * paper that describes our current method of CV for this classifier
   * TODO: Cite Technical report when published
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {

    TechnicalInformation result;

    result = new TechnicalInformation(Type.INPROCEEDINGS);
    result.setValue(Field.AUTHOR,
        "Rich Caruana, Alex Niculescu, Geoff Crew, and Alex Ksikes");
    result.setValue(Field.TITLE,
        "Ensemble Selection from Libraries of Models");
    result.setValue(Field.BOOKTITLE,
        "21st International Conference on Machine Learning");
    result.setValue(Field.YEAR, "2004");

    return result;
  }

  /**
   * Executes the classifier from commandline.
   *
   * @param argv
   *            should contain the following arguments: -t training file [-T
   *            test file] [-c class index]
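   * <p>
   * For example (a sketch; the file names are hypothetical):
   * <pre>
   * java weka.classifiers.meta.EnsembleSelection -L /path/to/models.mlf \
   *     -W /path/to/workingDir -t train.arff -T test.arff
   * </pre>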
   */
  public static void main(String[] argv) {

    try {

      String options[] = (String[]) argv.clone();

      // do we get the input from XML instead of normal parameters?
      String xml = Utils.getOption("xml", options);
      if (!xml.equals(""))
        options = new XMLOptions(xml).toArray();

      String trainFileName = Utils.getOption('t', options);
      String objectInputFileName = Utils.getOption('l', options);
      String testFileName = Utils.getOption('T', options);

      if (testFileName.length() != 0 && objectInputFileName.length() != 0
          && trainFileName.length() == 0) {

        System.out.println("Caching predictions");

        EnsembleSelection classifier = null;

        BufferedReader testReader = new BufferedReader(new FileReader(
            testFileName));

        // Set up the Instances Object
        Instances test;
        int classIndex = -1;
        String classIndexString = Utils.getOption('c', options);
        if (classIndexString.length() != 0) {
          classIndex = Integer.parseInt(classIndexString);
        }

        test = new Instances(testReader, 1);
        if (classIndex != -1) {
          test.setClassIndex(classIndex - 1);
        } else {
          test.setClassIndex(test.numAttributes() - 1);
        }
        if (classIndex > test.numAttributes()) {
          throw new Exception("Index of class attribute too large.");
        }

        // read all the test instances
        while (test.readInstance(testReader)) {
        }
        testReader.close();

        // Now yoink the EnsembleSelection Object from the fileSystem

        InputStream is = new FileInputStream(objectInputFileName);
        if (objectInputFileName.endsWith(".gz")) {
          is = new GZIPInputStream(is);
        }

        // load from KOML?
        if (!(objectInputFileName.endsWith("UpdateableClassifier.koml")
            && KOML.isPresent())) {
          ObjectInputStream objectInputStream = new ObjectInputStream(is);
          classifier = (EnsembleSelection) objectInputStream.readObject();
          objectInputStream.close();
        } else {
          BufferedInputStream xmlInputStream = new BufferedInputStream(is);
          classifier = (EnsembleSelection) KOML.read(xmlInputStream);
          xmlInputStream.close();
        }

        String workingDir = Utils.getOption('W', argv);
        if (!workingDir.equals("")) {
          classifier.setWorkingDirectory(new File(workingDir));
        }

        classifier.setDebug(Utils.getFlag('D', argv));
        classifier.setVerboseOutput(Utils.getFlag('O', argv));

        classifier.cachePredictions(test);

        // Now we write the model back out to the file system.
        String objectOutputFileName = objectInputFileName;
        OutputStream os = new FileOutputStream(objectOutputFileName);
        // binary
        if (!(objectOutputFileName.endsWith(".xml") || (objectOutputFileName
            .endsWith(".koml") && KOML.isPresent()))) {
          if (objectOutputFileName.endsWith(".gz")) {
            os = new GZIPOutputStream(os);
          }
          ObjectOutputStream objectOutputStream = new ObjectOutputStream(os);
          objectOutputStream.writeObject(classifier);
          objectOutputStream.flush();
          objectOutputStream.close();
        }
        // KOML/XML
        else {
          BufferedOutputStream xmlOutputStream = new BufferedOutputStream(os);
          if (objectOutputFileName.endsWith(".xml")) {
            XMLSerialization xmlSerial = new XMLClassifier();
            xmlSerial.write(xmlOutputStream, classifier);
          } else
          // whether KOML is present has already been checked
          // if not present -> ".koml" is interpreted as binary - see above
          if (objectOutputFileName.endsWith(".koml")) {
            KOML.write(xmlOutputStream, classifier);
          }
          xmlOutputStream.close();
        }
      }

      System.out.println(Evaluation.evaluateModel(new EnsembleSelection(),
          argv));

    } catch (Exception e) {
      if ((e.getMessage() != null)
          && (e.getMessage().indexOf("General options") == -1))
        e.printStackTrace();
      else
        System.err.println(e.getMessage());
    }
  }
}