/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * RacedIncrementalLogitBoost.java
 * Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.meta;

import weka.classifiers.Classifier;
import weka.classifiers.RandomizableSingleClassifierEnhancer;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.rules.ZeroR;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.Utils;
import weka.core.WeightedInstancesHandler;
import weka.core.Capabilities.Capability;

import java.io.Serializable;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Classifier for incremental learning of large datasets by way of racing logit-boosted committees.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -C &lt;num&gt;
 *  Minimum size of chunks.
 *  (default 500)</pre>
 *
 * <pre> -M &lt;num&gt;
 *  Maximum size of chunks.
 *  (default 2000)</pre>
 *
 * <pre> -V &lt;num&gt;
 *  Size of validation set.
 *  (default 1000)</pre>
 *
 * <pre> -P &lt;pruning type&gt;
 *  Committee pruning to perform.
 *  0=none, 1=log likelihood (default)</pre>
 *
 * <pre> -Q
 *  Use resampling for boosting.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 * <pre> -W
 *  Full name of base classifier.
 *  (default: weka.classifiers.trees.DecisionStump)</pre>
 *
 * <pre>
 * Options specific to classifier weka.classifiers.trees.DecisionStump:
 * </pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
 *
 * Options after -- are passed to the designated learner.<p>
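 *
 * Example command line, assuming an ARFF training file at data.arff
 * (the path is illustrative):<p>
 *
 * <pre> java weka.classifiers.meta.RacedIncrementalLogitBoost \
 *    -t data.arff -C 500 -M 2000 -V 1000 \
 *    -W weka.classifiers.trees.DecisionStump</pre>
 * <p>
 * A minimal programmatic sketch of incremental use; the loader code and
 * the furtherInstance placeholder are illustrative, not part of this API:<p>
 *
 * <pre> Instances data = new Instances(
 *     new java.io.BufferedReader(new java.io.FileReader("data.arff")));
 * data.setClassIndex(data.numAttributes() - 1);
 * RacedIncrementalLogitBoost rilb = new RacedIncrementalLogitBoost();
 * rilb.buildClassifier(data);             // primes the racing committees
 * rilb.updateClassifier(furtherInstance); // consume further stream instances</pre>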
 *
 * @author Richard Kirkby (rkirkby@cs.waikato.ac.nz)
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.11 $
 */
public class RacedIncrementalLogitBoost
  extends RandomizableSingleClassifierEnhancer
  implements UpdateableClassifier {

  /** for serialization */
  static final long serialVersionUID = 908598343772170052L;

  /** no pruning */
  public static final int PRUNETYPE_NONE = 0;
  /** log likelihood pruning */
  public static final int PRUNETYPE_LOGLIKELIHOOD = 1;
  /** The pruning types */
  public static final Tag[] TAGS_PRUNETYPE = {
    new Tag(PRUNETYPE_NONE, "No pruning"),
    new Tag(PRUNETYPE_LOGLIKELIHOOD, "Log likelihood pruning")
  };

  /** The committees */
  protected FastVector m_committees;

  /** The pruning type used */
  protected int m_PruningType = PRUNETYPE_LOGLIKELIHOOD;

  /** Whether to use resampling */
  protected boolean m_UseResampling = false;

  /** The number of classes */
  protected int m_NumClasses;

  /** A threshold for responses (Friedman suggests between 2 and 4) */
  protected static final double Z_MAX = 4;

  /** Dummy dataset with a numeric class */
  protected Instances m_NumericClassData;

  /** The actual class attribute (for getting class names) */
  protected Attribute m_ClassAttribute;

  /** The minimum chunk size used for training */
  protected int m_minChunkSize = 500;
  /** The maximum chunk size used for training */
  protected int m_maxChunkSize = 2000;

  /** The size of the validation set */
  protected int m_validationChunkSize = 1000;

  /** The number of instances consumed */
  protected int m_numInstancesConsumed;

  /** The instances used for validation */
  protected Instances m_validationSet;

  /** The instances currently in memory for training */
  protected Instances m_currentSet;

  /** The current best committee */
  protected Committee m_bestCommittee;

  /** The default scheme used when committees aren't ready */
  protected ZeroR m_zeroR = null;

  /** Whether the validation set has recently been changed */
  protected boolean m_validationSetChanged;

  /** The maximum number of instances required for processing */
  protected int m_maxBatchSizeRequired;

  /** The random number generator used */
  protected Random m_RandomInstance = null;

  /**
   * Constructor.
   */
  public RacedIncrementalLogitBoost() {

    m_Classifier = new weka.classifiers.trees.DecisionStump();
  }

  /**
   * String describing default classifier.
   *
   * @return the default classifier classname
   */
  protected String defaultClassifierString() {

    return "weka.classifiers.trees.DecisionStump";
  }

  /**
   * Class representing a committee of LogitBoosted models
   */
  protected class Committee implements Serializable {

    /** for serialization */
    static final long serialVersionUID = 5559880306684082199L;

    protected int m_chunkSize;

    /** number eaten from m_currentSet */
    protected int m_instancesConsumed;

    protected FastVector m_models;
    protected double m_lastValidationError;
    protected double m_lastLogLikelihood;
    protected boolean m_modelHasChanged;
    protected boolean m_modelHasChangedLL;
    protected double[][] m_validationFs;
    protected double[][] m_newValidationFs;

    /**
     * constructor
     *
     * @param chunkSize the size of the chunk
     */
    public Committee(int chunkSize) {

      m_chunkSize = chunkSize;
      m_instancesConsumed = 0;
      m_models = new FastVector();
      m_lastValidationError = 1.0;
      m_lastLogLikelihood = Double.MAX_VALUE;
      m_modelHasChanged = true;
      m_modelHasChangedLL = true;
      m_validationFs = new double[m_validationChunkSize][m_NumClasses];
      m_newValidationFs = new double[m_validationChunkSize][m_NumClasses];
    }

    /**
     * update the committee
     *
     * @return true if the committee has changed
     * @throws Exception if anything goes wrong
     */
    public boolean update() throws Exception {

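      // Consume every full chunk currently available in m_currentSet:
      // boost one new model per chunk and cache the candidate validation
      // responses so pruning can compare log likelihoods before and after.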
      boolean hasChanged = false;
      while (m_currentSet.numInstances() - m_instancesConsumed >= m_chunkSize) {
        Classifier[] newModel = boost(new Instances(
          m_currentSet, m_instancesConsumed, m_chunkSize));
        for (int i = 0; i < m_validationSet.numInstances(); i++) {
          m_newValidationFs[i] = updateFS(m_validationSet.instance(i),
            newModel, m_validationFs[i]);
        }
        m_models.addElement(newModel);
        m_instancesConsumed += m_chunkSize;
        hasChanged = true;
      }
      if (hasChanged) {
        m_modelHasChanged = true;
        m_modelHasChangedLL = true;
      }
      return hasChanged;
    }

    /** reset consumption counts */
    public void resetConsumed() {

      m_instancesConsumed = 0;
    }

    /** remove the last model from the committee */
    public void pruneLastModel() {

      if (m_models.size() > 0) {
        m_models.removeElementAt(m_models.size() - 1);
        m_modelHasChanged = true;
        m_modelHasChangedLL = true;
      }
    }

    /**
     * decide to keep the last model in the committee
     * @throws Exception if anything goes wrong
     */
    public void keepLastModel() throws Exception {

      m_validationFs = m_newValidationFs;
      m_newValidationFs = new double[m_validationChunkSize][m_NumClasses];
      m_modelHasChanged = true;
      m_modelHasChangedLL = true;
    }

    /**
     * calculate the log likelihood on the validation data
     * @return the log likelihood
     * @throws Exception if computation fails
     */
    public double logLikelihood() throws Exception {

      if (m_modelHasChangedLL) {

        Instance inst;
        double llsum = 0.0;
        for (int i = 0; i < m_validationSet.numInstances(); i++) {
          inst = m_validationSet.instance(i);
          llsum += (logLikelihood(m_validationFs[i], (int) inst.classValue()));
        }
        m_lastLogLikelihood = llsum / (double) m_validationSet.numInstances();
        m_modelHasChangedLL = false;
      }
      return m_lastLogLikelihood;
    }

    /**
     * calculate the log likelihood on the validation data after adding the last model
     * @return the log likelihood
     * @throws Exception if computation fails
     */
    public double logLikelihoodAfter() throws Exception {

      Instance inst;
      double llsum = 0.0;
      for (int i = 0; i < m_validationSet.numInstances(); i++) {
        inst = m_validationSet.instance(i);
        llsum += (logLikelihood(m_newValidationFs[i], (int) inst.classValue()));
      }
      return llsum / (double) m_validationSet.numInstances();
    }

    /**
     * calculates the log likelihood of an instance
     * @param Fs the Fs values
     * @param classIndex the class index
     * @return the log likelihood
     * @throws Exception if computation fails
     */
    private double logLikelihood(double[] Fs, int classIndex) throws Exception {

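      // Negative log likelihood of the true class under the committee's
      // current distribution; smaller values indicate a better fit, which
      // is why pruning keeps a model only when this average decreases.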
      return -Math.log(distributionForInstance(Fs)[classIndex]);
    }

    /**
     * calculates the validation error of the committee
     * @return the validation error
     * @throws Exception if computation fails
     */
    public double validationError() throws Exception {

      if (m_modelHasChanged) {

        Instance inst;
        int numIncorrect = 0;
        for (int i = 0; i < m_validationSet.numInstances(); i++) {
          inst = m_validationSet.instance(i);
          if (classifyInstance(m_validationFs[i]) != inst.classValue())
            numIncorrect++;
        }
        m_lastValidationError =
          (double) numIncorrect / (double) m_validationSet.numInstances();
        m_modelHasChanged = false;
      }
      return m_lastValidationError;
    }

    /**
     * returns the chunk size used by the committee
     *
     * @return the chunk size
     */
    public int chunkSize() {

      return m_chunkSize;
    }

    /**
     * returns the number of models in the committee
     *
     * @return the committee size
     */
    public int committeeSize() {

      return m_models.size();
    }

    /**
     * classifies an instance (given Fs values) with the committee
     *
     * @param Fs the Fs values
     * @return the classification
     * @throws Exception if anything goes wrong
     */
    public double classifyInstance(double[] Fs) throws Exception {

      double[] dist = distributionForInstance(Fs);

      double max = 0;
      int maxIndex = 0;

      for (int i = 0; i < dist.length; i++) {
        if (dist[i] > max) {
          maxIndex = i;
          max = dist[i];
        }
      }
      if (max > 0) {
        return maxIndex;
      } else {
        return Instance.missingValue();
      }
    }

    /**
     * classifies an instance with the committee
     *
     * @param instance the instance to classify
     * @return the classification
     * @throws Exception if anything goes wrong
     */
    public double classifyInstance(Instance instance) throws Exception {

      double[] dist = distributionForInstance(instance);
      switch (instance.classAttribute().type()) {
      case Attribute.NOMINAL:
        double max = 0;
        int maxIndex = 0;

        for (int i = 0; i < dist.length; i++) {
          if (dist[i] > max) {
            maxIndex = i;
            max = dist[i];
          }
        }
        if (max > 0) {
          return maxIndex;
        } else {
          return Instance.missingValue();
        }
      case Attribute.NUMERIC:
        return dist[0];
      default:
        return Instance.missingValue();
      }
    }

    /**
     * returns the distribution the committee generates for an instance (given Fs values)
     *
     * @param Fs the Fs values
     * @return the distribution
     * @throws Exception if anything goes wrong
     */
    public double[] distributionForInstance(double[] Fs) throws Exception {

      double[] distribution = new double[m_NumClasses];
      for (int j = 0; j < m_NumClasses; j++) {
        distribution[j] = RtoP(Fs, j);
      }
      return distribution;
    }

    /**
     * updates the Fs values given a new model in the committee
     *
     * @param instance the instance to use
     * @param newModel the new model
     * @param Fs the Fs values to update
     * @return the updated Fs values
     * @throws Exception if anything goes wrong
     */
    public double[] updateFS(Instance instance, Classifier[] newModel,
      double[] Fs) throws Exception {

      instance = (Instance) instance.copy();
      instance.setDataset(m_NumericClassData);

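      // LogitBoost additive update: each response F_j grows by
      // ((K - 1) / K) * (f_j - mean(f)), where f_j is the new model's
      // prediction for class j and K is the number of classes.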
      double[] Fi = new double[m_NumClasses];
      double Fsum = 0;
      for (int j = 0; j < m_NumClasses; j++) {
        Fi[j] = newModel[j].classifyInstance(instance);
        Fsum += Fi[j];
      }
      Fsum /= m_NumClasses;

      double[] newFs = new double[Fs.length];
      for (int j = 0; j < m_NumClasses; j++) {
        newFs[j] = Fs[j] + ((Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses);
      }
      return newFs;
    }

    /**
     * returns the distribution the committee generates for an instance
     *
     * @param instance the instance to get the distribution for
     * @return the distribution
     * @throws Exception if anything goes wrong
     */
    public double[] distributionForInstance(Instance instance) throws Exception {

      instance = (Instance) instance.copy();
      instance.setDataset(m_NumericClassData);
      double[] Fs = new double[m_NumClasses];
      for (int i = 0; i < m_models.size(); i++) {
        double[] Fi = new double[m_NumClasses];
        double Fsum = 0;
        Classifier[] model = (Classifier[]) m_models.elementAt(i);
        for (int j = 0; j < m_NumClasses; j++) {
          Fi[j] = model[j].classifyInstance(instance);
          Fsum += Fi[j];
        }
        Fsum /= m_NumClasses;
        for (int j = 0; j < m_NumClasses; j++) {
          Fs[j] += (Fi[j] - Fsum) * (m_NumClasses - 1) / m_NumClasses;
        }
      }
      double[] distribution = new double[m_NumClasses];
      for (int j = 0; j < m_NumClasses; j++) {
        distribution[j] = RtoP(Fs, j);
      }
      return distribution;
    }

    /**
     * performs a boosting iteration, returning a new model for the committee
     *
     * @param data the data to boost on
     * @return the new model
     * @throws Exception if anything goes wrong
     */
    protected Classifier[] boost(Instances data) throws Exception {

      Classifier[] newModel = Classifier.makeCopies(m_Classifier, m_NumClasses);

      // Create a copy of the data with the class transformed into numeric
      Instances boostData = new Instances(data);
      boostData.deleteWithMissingClass();
      int numInstances = boostData.numInstances();

      // Temporarily unset the class index
      int classIndex = data.classIndex();
      boostData.setClassIndex(-1);
      boostData.deleteAttributeAt(classIndex);
      boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
      boostData.setClassIndex(classIndex);
      double[][] trainFs = new double[numInstances][m_NumClasses];
      double[][] trainYs = new double[numInstances][m_NumClasses];
      for (int j = 0; j < m_NumClasses; j++) {
        for (int i = 0, k = 0; i < numInstances; i++, k++) {
          while (data.instance(k).classIsMissing())
            k++;
          trainYs[i][j] = (data.instance(k).classValue() == j) ? 1 : 0;
        }
      }

      // Evaluate / increment trainFs from the classifiers
      for (int x = 0; x < m_models.size(); x++) {
        for (int i = 0; i < numInstances; i++) {
          double[] pred = new double[m_NumClasses];
          double predSum = 0;
          Classifier[] model = (Classifier[]) m_models.elementAt(x);
          for (int j = 0; j < m_NumClasses; j++) {
            pred[j] = model[j].classifyInstance(boostData.instance(i));
            predSum += pred[j];
          }
          predSum /= m_NumClasses;
          for (int j = 0; j < m_NumClasses; j++) {
            trainFs[i][j] += (pred[j] - predSum) * (m_NumClasses - 1)
              / m_NumClasses;
          }
        }
      }

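      // LogitBoost working response and weight for each class j:
      // z = (y - p) / (p * (1 - p)), which reduces to 1/p when y = 1 and
      // -1/(1 - p) when y = 0, clipped to +/- Z_MAX; the instance weight
      // is w = (y - p) / z, i.e. p * (1 - p) for 0/1 targets before clipping.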
      for (int j = 0; j < m_NumClasses; j++) {

        // Set instance pseudoclass and weights
        for (int i = 0; i < numInstances; i++) {
          double p = RtoP(trainFs[i], j);
          Instance current = boostData.instance(i);
          double z, actual = trainYs[i][j];
          if (actual == 1) {
            z = 1.0 / p;
            if (z > Z_MAX) { // threshold
              z = Z_MAX;
            }
          } else if (actual == 0) {
            z = -1.0 / (1.0 - p);
            if (z < -Z_MAX) { // threshold
              z = -Z_MAX;
            }
          } else {
            z = (actual - p) / (p * (1 - p));
          }

          double w = (actual - p) / z;
          current.setValue(classIndex, z);
          current.setWeight(numInstances * w);
        }

        Instances trainData = boostData;
        if (m_UseResampling) {
          double[] weights = new double[boostData.numInstances()];
          for (int kk = 0; kk < weights.length; kk++) {
            weights[kk] = boostData.instance(kk).weight();
          }
          trainData = boostData.resampleWithWeights(m_RandomInstance, weights);
        }

        // Build the classifier
        newModel[j].buildClassifier(trainData);
      }

      return newModel;
    }

    /**
     * outputs description of the committee
     *
     * @return a string representation of the classifier
     */
    public String toString() {

      StringBuffer text = new StringBuffer();

      text.append("RacedIncrementalLogitBoost: Best committee on validation data\n");
      text.append("Base classifiers: \n");

      for (int i = 0; i < m_models.size(); i++) {
        text.append("\nModel " + (i + 1));
        Classifier[] cModels = (Classifier[]) m_models.elementAt(i);
        for (int j = 0; j < m_NumClasses; j++) {
          text.append("\n\tClass " + (j + 1) + " ("
            + m_ClassAttribute.name() + "=" + m_ClassAttribute.value(j)
            + ")\n\n" + cModels[j].toString() + "\n");
        }
      }
      text.append("Number of models: " + m_models.size() + "\n");
      text.append("Chunk size per model: " + m_chunkSize + "\n");

      return text.toString();
    }
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // class
    result.disableAllClasses();
    result.disableAllClassDependencies();
    result.enable(Capability.NOMINAL_CLASS);

    // instances
    result.setMinimumNumberInstances(0);

    return result;
  }

  /**
   * Builds the classifier.
   *
   * @param data the instances to train the classifier with
   * @throws Exception if something goes wrong
   */
  public void buildClassifier(Instances data) throws Exception {

    m_RandomInstance = new Random(m_Seed);

    Instances boostData;
    int classIndex = data.classIndex();

    // can classifier handle the data?
    getCapabilities().testWithFail(data);

    // remove instances with missing class
    data = new Instances(data);
    data.deleteWithMissingClass();

    if (m_Classifier == null) {
      throw new Exception("A base classifier has not been specified!");
    }

    if (!(m_Classifier instanceof WeightedInstancesHandler)
      && !m_UseResampling) {
      m_UseResampling = true;
    }

    m_NumClasses = data.numClasses();
    m_ClassAttribute = data.classAttribute();

    // Create a copy of the data with the class transformed into numeric
    boostData = new Instances(data);

    // Temporarily unset the class index
    boostData.setClassIndex(-1);
    boostData.deleteAttributeAt(classIndex);
    boostData.insertAttributeAt(new Attribute("'pseudo class'"), classIndex);
    boostData.setClassIndex(classIndex);
    m_NumericClassData = new Instances(boostData, 0);

    data.randomize(m_RandomInstance);

    // create the committees
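    // Chunk sizes double from the minimum up to the maximum (with the
    // defaults: 500, 1000, 2000), one racing committee per size; the
    // largest size also sets how many instances must be buffered.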
    int cSize = m_minChunkSize;
    m_committees = new FastVector();
    while (cSize <= m_maxChunkSize) {
      m_committees.addElement(new Committee(cSize));
      m_maxBatchSizeRequired = cSize;
      cSize *= 2;
    }

    // set up for consumption
    m_validationSet = new Instances(data, m_validationChunkSize);
    m_currentSet = new Instances(data, m_maxBatchSizeRequired);
    m_bestCommittee = null;
    m_numInstancesConsumed = 0;

    // start eating what we've been given
    for (int i = 0; i < data.numInstances(); i++)
      updateClassifier(data.instance(i));
  }

  /**
   * Updates the classifier.
   *
   * @param instance the next instance in the stream of training data
   * @throws Exception if something goes wrong
   */
  public void updateClassifier(Instance instance) throws Exception {

    m_numInstancesConsumed++;

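    // The first m_validationChunkSize instances seen are held out to form
    // the validation set; training chunks only start to accumulate once
    // that set is full.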
    if (m_validationSet.numInstances() < m_validationChunkSize) {
      m_validationSet.add(instance);
      m_validationSetChanged = true;
    } else {
      m_currentSet.add(instance);
      boolean hasChanged = false;

      // update each committee
      for (int i = 0; i < m_committees.size(); i++) {
        Committee c = (Committee) m_committees.elementAt(i);
        if (c.update()) {

          hasChanged = true;

          if (m_PruningType == PRUNETYPE_LOGLIKELIHOOD) {
            double oldLL = c.logLikelihood();
            double newLL = c.logLikelihoodAfter();
            if (newLL >= oldLL && c.committeeSize() > 1) {
              c.pruneLastModel();
              if (m_Debug)
                System.out.println("Pruning " + c.chunkSize()
                  + " committee (" + newLL + " >= " + oldLL + ")");
            } else
              c.keepLastModel();
          } else
            c.keepLastModel(); // no pruning
        }
      }
      if (hasChanged) {

        if (m_Debug)
          System.out.println("After consuming " + m_numInstancesConsumed
            + " instances... (" + m_validationSet.numInstances() + " + "
            + m_currentSet.numInstances() + " instances currently in memory)");

        // find best committee
        double lowestError = 1.0;
        for (int i = 0; i < m_committees.size(); i++) {
          Committee c = (Committee) m_committees.elementAt(i);

          if (c.committeeSize() > 0) {

            double err = c.validationError();
            double ll = c.logLikelihood();

            if (m_Debug)
              System.out.println("Chunk size " + c.chunkSize() + " with "
                + c.committeeSize() + " models, has validation error of "
                + err + ", log likelihood of " + ll);
            if (err < lowestError) {
              lowestError = err;
              m_bestCommittee = c;
            }
          }
        }
      }
      if (m_currentSet.numInstances() >= m_maxBatchSizeRequired) {
        m_currentSet = new Instances(m_currentSet, m_maxBatchSizeRequired);

        // reset consumption counts
        for (int i = 0; i < m_committees.size(); i++) {
          Committee c = (Committee) m_committees.elementAt(i);
          c.resetConsumed();
        }
      }
    }
  }

  /**
   * Convert from function responses to probabilities
   *
   * @param Fs an array containing the responses from each function
   * @param j the class value of interest
   * @return the probability prediction for j
   * @throws Exception if can't normalize
   */
  protected static double RtoP(double[] Fs, int j) throws Exception {

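    // Symmetric multiple-logistic (softmax) transform:
    // p_j = exp(F_j) / sum_k exp(F_k). Subtracting the maximum response
    // before exponentiating guards against numeric overflow.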
    double maxF = -Double.MAX_VALUE;
    for (int i = 0; i < Fs.length; i++) {
      if (Fs[i] > maxF) {
        maxF = Fs[i];
      }
    }
    double sum = 0;
    double[] probs = new double[Fs.length];
    for (int i = 0; i < Fs.length; i++) {
      probs[i] = Math.exp(Fs[i] - maxF);
      sum += probs[i];
    }
    if (sum == 0) {
      throw new Exception("Can't normalize");
    }
    return probs[j] / sum;
  }

  /**
   * Computes class distribution of an instance using the best committee.
   *
   * @param instance the instance to get the distribution for
   * @return the distribution
   * @throws Exception if anything goes wrong
   */
  public double[] distributionForInstance(Instance instance) throws Exception {

    if (m_bestCommittee != null)
      return m_bestCommittee.distributionForInstance(instance);
    else {
      if (m_validationSetChanged || m_zeroR == null) {
        m_zeroR = new ZeroR();
        m_zeroR.buildClassifier(m_validationSet);
        m_validationSetChanged = false;
      }
      return m_zeroR.distributionForInstance(instance);
    }
  }

  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {

    Vector newVector = new Vector(9);

    newVector.addElement(new Option("\tMinimum size of chunks.\n"
      + "\t(default 500)", "C", 1, "-C <num>"));

    newVector.addElement(new Option("\tMaximum size of chunks.\n"
      + "\t(default 2000)", "M", 1, "-M <num>"));

    newVector.addElement(new Option("\tSize of validation set.\n"
      + "\t(default 1000)", "V", 1, "-V <num>"));

    newVector.addElement(new Option("\tCommittee pruning to perform.\n"
      + "\t0=none, 1=log likelihood (default)", "P", 1, "-P <pruning type>"));

    newVector.addElement(new Option("\tUse resampling for boosting.",
      "Q", 0, "-Q"));

    Enumeration enu = super.listOptions();
    while (enu.hasMoreElements()) {
      newVector.addElement(enu.nextElement());
    }
    return newVector.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -C &lt;num&gt;
   *  Minimum size of chunks.
   *  (default 500)</pre>
   *
   * <pre> -M &lt;num&gt;
   *  Maximum size of chunks.
   *  (default 2000)</pre>
   *
   * <pre> -V &lt;num&gt;
   *  Size of validation set.
   *  (default 1000)</pre>
   *
   * <pre> -P &lt;pruning type&gt;
   *  Committee pruning to perform.
   *  0=none, 1=log likelihood (default)</pre>
   *
   * <pre> -Q
   *  Use resampling for boosting.</pre>
   *
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)</pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   * <pre> -W
   *  Full name of base classifier.
   *  (default: weka.classifiers.trees.DecisionStump)</pre>
   *
   * <pre>
   * Options specific to classifier weka.classifiers.trees.DecisionStump:
   * </pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {

    String minChunkSize = Utils.getOption('C', options);
    if (minChunkSize.length() != 0) {
      setMinChunkSize(Integer.parseInt(minChunkSize));
    } else {
      setMinChunkSize(500);
    }

    String maxChunkSize = Utils.getOption('M', options);
    if (maxChunkSize.length() != 0) {
      setMaxChunkSize(Integer.parseInt(maxChunkSize));
    } else {
      setMaxChunkSize(2000);
    }

    String validationChunkSize = Utils.getOption('V', options);
    if (validationChunkSize.length() != 0) {
      setValidationChunkSize(Integer.parseInt(validationChunkSize));
    } else {
      setValidationChunkSize(1000);
    }

    String pruneType = Utils.getOption('P', options);
    if (pruneType.length() != 0) {
      setPruningType(new SelectedTag(Integer.parseInt(pruneType),
        TAGS_PRUNETYPE));
    } else {
      setPruningType(new SelectedTag(PRUNETYPE_LOGLIKELIHOOD, TAGS_PRUNETYPE));
    }

    setUseResampling(Utils.getFlag('Q', options));

    super.setOptions(options);
  }

  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {

    String[] superOptions = super.getOptions();
    String[] options = new String[superOptions.length + 9];

    int current = 0;

    if (getUseResampling()) {
      options[current++] = "-Q";
    }
    options[current++] = "-C";
    options[current++] = "" + getMinChunkSize();

    options[current++] = "-M";
    options[current++] = "" + getMaxChunkSize();

    options[current++] = "-V";
    options[current++] = "" + getValidationChunkSize();

    options[current++] = "-P";
    options[current++] = "" + m_PruningType;

    System.arraycopy(superOptions, 0, options, current, superOptions.length);

    current += superOptions.length;
    while (current < options.length) {
      options[current++] = "";
    }
    return options;
  }

  /**
   * @return a description of the classifier suitable for
   * displaying in the explorer/experimenter gui
   */
  public String globalInfo() {

    return "Classifier for incremental learning of large datasets by way of racing logit-boosted committees.";
  }

  /**
   * Set the base learner.
   *
   * @param newClassifier the classifier to use.
   * @throws IllegalArgumentException if base classifier cannot handle numeric
   * class
   */
  public void setClassifier(Classifier newClassifier) {
    Capabilities cap = newClassifier.getCapabilities();

    if (!cap.handles(Capability.NUMERIC_CLASS))
      throw new IllegalArgumentException(
        "Base classifier cannot handle numeric class!");

    super.setClassifier(newClassifier);
  }

  /**
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String minChunkSizeTipText() {

    return "The minimum number of instances to train the base learner with.";
  }

  /**
   * Set the minimum chunk size
   *
   * @param chunkSize the minimum chunk size
   */
  public void setMinChunkSize(int chunkSize) {

    m_minChunkSize = chunkSize;
  }

  /**
   * Get the minimum chunk size
   *
   * @return the chunk size
   */
  public int getMinChunkSize() {

    return m_minChunkSize;
  }

  /**
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String maxChunkSizeTipText() {

    return "The maximum number of instances to train the base learner with. The chunk sizes used will start at minChunkSize and grow twice as large for as many times as they are less than or equal to the maximum size.";
  }

  /**
   * Set the maximum chunk size
   *
   * @param chunkSize the maximum chunk size
   */
  public void setMaxChunkSize(int chunkSize) {

    m_maxChunkSize = chunkSize;
  }

  /**
   * Get the maximum chunk size
   *
   * @return the chunk size
   */
  public int getMaxChunkSize() {

    return m_maxChunkSize;
  }

  /**
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String validationChunkSizeTipText() {

    return "The number of instances to hold out for validation. These instances will be taken from the beginning of the stream, so learning will not start until these instances have been consumed first.";
  }

  /**
   * Set the validation chunk size
   *
   * @param chunkSize the validation chunk size
   */
  public void setValidationChunkSize(int chunkSize) {

    m_validationChunkSize = chunkSize;
  }

  /**
   * Get the validation chunk size
   *
   * @return the chunk size
   */
  public int getValidationChunkSize() {

    return m_validationChunkSize;
  }

  /**
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String pruningTypeTipText() {

    return "The pruning method to use within each committee. Log likelihood pruning will discard new models if they have a negative effect on the log likelihood of the validation data.";
  }

  /**
   * Set the pruning type
   *
   * @param pruneType the pruning type
   */
  public void setPruningType(SelectedTag pruneType) {

    if (pruneType.getTags() == TAGS_PRUNETYPE) {
      m_PruningType = pruneType.getSelectedTag().getID();
    }
  }

  /**
   * Get the pruning type
   *
   * @return the type
   */
  public SelectedTag getPruningType() {

    return new SelectedTag(m_PruningType, TAGS_PRUNETYPE);
  }

  /**
   * @return tip text for this property suitable for
   * displaying in the explorer/experimenter gui
   */
  public String useResamplingTipText() {

    return "Force the use of resampling data rather than using the weight-handling capabilities of the base classifier. Resampling is always used if the base classifier cannot handle weighted instances.";
  }

  /**
   * Set resampling mode
   *
   * @param r true if resampling should be done
   */
  public void setUseResampling(boolean r) {

    m_UseResampling = r;
  }

  /**
   * Get whether resampling is turned on
   *
   * @return true if resampling is turned on
   */
  public boolean getUseResampling() {

    return m_UseResampling;
  }

  /**
   * Get the best committee chunk size
   *
   * @return the best committee chunk size
   */
  public int getBestCommitteeChunkSize() {

    if (m_bestCommittee != null) {
      return m_bestCommittee.chunkSize();
    } else
      return 0;
  }

  /**
   * Get the number of members in the best committee
   *
   * @return the number of members
   */
  public int getBestCommitteeSize() {

    if (m_bestCommittee != null) {
      return m_bestCommittee.committeeSize();
    } else
      return 0;
  }

  /**
   * Get the best committee's error on the validation data
   *
   * @return the best committee's error
   */
  public double getBestCommitteeErrorEstimate() {

    if (m_bestCommittee != null) {
      try {
        return m_bestCommittee.validationError() * 100.0;
      } catch (Exception e) {
        System.err.println(e.getMessage());
        return 100.0;
      }
    } else
      return 100.0;
  }

  /**
   * Get the best committee's log likelihood on the validation data
   *
   * @return best committee's log likelihood
   */
  public double getBestCommitteeLLEstimate() {

    if (m_bestCommittee != null) {
      try {
        return m_bestCommittee.logLikelihood();
      } catch (Exception e) {
        System.err.println(e.getMessage());
        return Double.MAX_VALUE;
      }
    } else
      return Double.MAX_VALUE;
  }

  /**
   * Returns description of the boosted classifier.
   *
   * @return description of the boosted classifier as a string
   */
  public String toString() {

    if (m_bestCommittee != null) {
      return m_bestCommittee.toString();
    } else {
      if ((m_validationSetChanged || m_zeroR == null)
        && m_validationSet != null
        && m_validationSet.numInstances() > 0) {
        m_zeroR = new ZeroR();
        try {
          m_zeroR.buildClassifier(m_validationSet);
        } catch (Exception e) {
        }
        m_validationSetChanged = false;
      }
      if (m_zeroR != null) {
        return ("RacedIncrementalLogitBoost: insufficient data to build model, resorting to ZeroR:\n\n"
          + m_zeroR.toString());
      } else
        return ("RacedIncrementalLogitBoost: no model built yet.");
    }
  }

  /**
   * Main method for this class.
   *
   * @param argv the commandline parameters
   */
  public static void main(String[] argv) {
    runClassifier(new RacedIncrementalLogitBoost(), argv);
  }
}
|