0001: /*
0002: * This program is free software; you can redistribute it and/or modify
0003: * it under the terms of the GNU General Public License as published by
0004: * the Free Software Foundation; either version 2 of the License, or
0005: * (at your option) any later version.
0006: *
0007: * This program is distributed in the hope that it will be useful,
0008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
0009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
0010: * GNU General Public License for more details.
0011: *
0012: * You should have received a copy of the GNU General Public License
0013: * along with this program; if not, write to the Free Software
0014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
0015: */
0016:
0017: /*
0018: * BFTree.java
0019: * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
0020: *
0021: */
0022:
0023: package weka.classifiers.trees;
0024:
0025: import weka.classifiers.Evaluation;
0026: import weka.classifiers.RandomizableClassifier;
0027: import weka.core.AdditionalMeasureProducer;
0028: import weka.core.Attribute;
0029: import weka.core.Capabilities;
0030: import weka.core.FastVector;
0031: import weka.core.Instance;
0032: import weka.core.Instances;
0033: import weka.core.Option;
0034: import weka.core.SelectedTag;
0035: import weka.core.Tag;
0036: import weka.core.TechnicalInformation;
0037: import weka.core.TechnicalInformationHandler;
0038: import weka.core.Utils;
0039: import weka.core.Capabilities.Capability;
0040: import weka.core.TechnicalInformation.Field;
0041: import weka.core.TechnicalInformation.Type;
0042: import weka.core.matrix.Matrix;
0043:
0044: import java.util.Arrays;
0045: import java.util.Enumeration;
0046: import java.util.Random;
0047: import java.util.Vector;
0048:
0049: /**
0050: <!-- globalinfo-start -->
0051: * Class for building a best-first decision tree classifier. This class uses binary split for both nominal and numeric attributes. For missing values, the method of 'fractional' instances is used.<br/>
0052: * <br/>
0053: * For more information, see:<br/>
0054: * <br/>
0055: * Haijian Shi (2007). Best-first decision tree learning. Hamilton, NZ.<br/>
0056: * <br/>
0057: * Jerome Friedman, Trevor Hastie, Robert Tibshirani (2000). Additive logistic regression : A statistical view of boosting. Annals of statistics. 28(2):337-407.
0058: * <p/>
0059: <!-- globalinfo-end -->
0060: *
0061: <!-- technical-bibtex-start -->
0062: * BibTeX:
0063: * <pre>
0064: * @mastersthesis{Shi2007,
0065: * address = {Hamilton, NZ},
0066: * author = {Haijian Shi},
0067: * note = {COMP594},
0068: * school = {University of Waikato},
0069: * title = {Best-first decision tree learning},
0070: * year = {2007}
0071: * }
0072: *
0073: * @article{Friedman2000,
0074: * author = {Jerome Friedman and Trevor Hastie and Robert Tibshirani},
0075: * journal = {Annals of statistics},
0076: * number = {2},
0077: * pages = {337-407},
0078: * title = {Additive logistic regression : A statistical view of boosting},
0079: * volume = {28},
0080: * year = {2000},
0081: * ISSN = {0090-5364}
0082: * }
0083: * </pre>
0084: * <p/>
0085: <!-- technical-bibtex-end -->
0086: *
0087: <!-- options-start -->
0088: * Valid options are: <p/>
0089: *
0090: * <pre> -S <num>
0091: * Random number seed.
0092: * (default 1)</pre>
0093: *
0094: * <pre> -D
0095: * If set, classifier is run in debug mode and
0096: * may output additional info to the console</pre>
0097: *
0098: * <pre> -P <UNPRUNED|POSTPRUNED|PREPRUNED>
0099: * The pruning strategy.
0100: * (default: POSTPRUNED)</pre>
0101: *
0102: * <pre> -M <min no>
0103: * The minimal number of instances at the terminal nodes.
0104: * (default 2)</pre>
0105: *
0106: * <pre> -N <num folds>
0107: * The number of folds used in the pruning.
0108: * (default 5)</pre>
0109: *
0110: * <pre> -H
0111: * Don't use heuristic search for nominal attributes in multi-class
0112: * problem (default yes).
0113: * </pre>
0114: *
0115: * <pre> -G
0116: * Don't use Gini index for splitting (default yes),
0117: * if not information is used.</pre>
0118: *
0119: * <pre> -R
0120: * Don't use error rate in internal cross-validation (default yes),
0121: * but root mean squared error.</pre>
0122: *
0123: * <pre> -A
0124: * Use the 1 SE rule to make pruning decision.
0125: * (default no).</pre>
0126: *
0127: * <pre> -C
0128: * Percentage of training data size (0-1]
0129: * (default 1).</pre>
0130: *
0131: <!-- options-end -->
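 *
 * A minimal usage sketch (data loading is assumed; the option values follow
 * the listing above):
 * <pre>
 * BFTree tree = new BFTree();
 * tree.setOptions(Utils.splitOptions("-P POSTPRUNED -M 2 -N 5"));
 * tree.buildClassifier(data); // data: weka.core.Instances with class index set
 * double prediction = tree.classifyInstance(data.instance(0));
 * </pre>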
0132: *
0133: * @author Haijian Shi (hs69@cs.waikato.ac.nz)
0134: * @version $Revision: 1.2 $
0135: */
0136: public class BFTree extends RandomizableClassifier implements
0137: AdditionalMeasureProducer, TechnicalInformationHandler {
0138:
0139: /** For serialization. */
0140: private static final long serialVersionUID = -7035607375962528217L;
0141:
0142: /** pruning strategy: un-pruned */
0143: public static final int PRUNING_UNPRUNED = 0;
0144: /** pruning strategy: post-pruning */
0145: public static final int PRUNING_POSTPRUNING = 1;
0146: /** pruning strategy: pre-pruning */
0147: public static final int PRUNING_PREPRUNING = 2;
0148: /** pruning strategy */
0149: public static final Tag[] TAGS_PRUNING = {
0150: new Tag(PRUNING_UNPRUNED, "unpruned", "Un-pruned"),
0151: new Tag(PRUNING_POSTPRUNING, "postpruned", "Post-pruning"),
0152: new Tag(PRUNING_PREPRUNING, "prepruned", "Pre-pruning") };
0153:
0154: /** the pruning strategy */
0155: protected int m_PruningStrategy = PRUNING_POSTPRUNING;
0156:
0157: /** Successor nodes. */
0158: protected BFTree[] m_Successors;
0159:
0160: /** Attribute used for splitting. */
0161: protected Attribute m_Attribute;
0162:
0163: /** Split point (for numeric attributes). */
0164: protected double m_SplitValue;
0165:
0166: /** Split subset (for nominal attributes). */
0167: protected String m_SplitString;
0168:
0169: /** Class value for a node. */
0170: protected double m_ClassValue;
0171:
0172: /** Class attribute of a dataset. */
0173: protected Attribute m_ClassAttribute;
0174:
0175: /** Minimum number of instances at leaf nodes. */
0176: protected int m_minNumObj = 2;
0177:
0178: /** Number of folds for the pruning. */
0179: protected int m_numFoldsPruning = 5;
0180:
0181: /** Whether the node is a leaf node. */
0182: protected boolean m_isLeaf;
0183:
0184: /** Number of expansions. */
0185: protected static int m_Expansion;
0186:
0187: /** Fixed number of expansions (if no pruning method is used, its value is -1. Otherwise,
0188: * its value is obtained from internal cross-validation). */
0189: protected int m_FixedExpansion = -1;
0190:
0191: /** Whether to use heuristic search for binary splits (default true). Note that even if true, it is
0192: * only used when a nominal attribute has more than 4 values. */
0193: protected boolean m_Heuristic = true;
0194:
0195: /** Whether to use the Gini index as the splitting criterion - default (if not, information gain is used). */
0196: protected boolean m_UseGini = true;
0197:
0198: /** Whether to use the error rate in internal cross-validation to fix the number of expansions - default
0199: * (if not, the root mean squared error is used). */
0200: protected boolean m_UseErrorRate = true;
0201:
0202: /** Whether to use the 1 SE rule to make the pruning decision. */
0203: protected boolean m_UseOneSE = false;
0204:
0205: /** Class distributions. */
0206: protected double[] m_Distribution;
0207:
0208: /** Branch proportions. */
0209: protected double[] m_Props;
0210:
0211: /** Sorted indices. */
0212: protected int[][] m_SortedIndices;
0213:
0214: /** Sorted weights. */
0215: protected double[][] m_Weights;
0216:
0217: /** Distributions of each attribute for two successor nodes. */
0218: protected double[][][] m_Dists;
0219:
0220: /** Class probabilities. */
0221: protected double[] m_ClassProbs;
0222:
0223: /** Total weights. */
0224: protected double m_TotalWeight;
0225:
0226: /** The percentage of the training data to use (0-1]. Default 1. */
0227: protected double m_SizePer = 1;
0228:
0229: /**
0230: * Returns a string describing classifier
0231: *
0232: * @return a description suitable for displaying in the
0233: * explorer/experimenter gui
0234: */
0235: public String globalInfo() {
0236: return "Class for building a best-first decision tree classifier. "
0237: + "This class uses binary split for both nominal and numeric attributes. "
0238: + "For missing values, the method of 'fractional' instances is used.\n\n"
0239: + "For more information, see:\n\n"
0240: + getTechnicalInformation().toString();
0241: }
0242:
0243: /**
0244: * Returns an instance of a TechnicalInformation object, containing
0245: * detailed information about the technical background of this class,
0246: * e.g., paper reference or book this class is based on.
0247: *
0248: * @return the technical information about this class
0249: */
0250: public TechnicalInformation getTechnicalInformation() {
0251: TechnicalInformation result;
0252: TechnicalInformation additional;
0253:
0254: result = new TechnicalInformation(Type.MASTERSTHESIS);
0255: result.setValue(Field.AUTHOR, "Haijian Shi");
0256: result.setValue(Field.YEAR, "2007");
0257: result.setValue(Field.TITLE,
0258: "Best-first decision tree learning");
0259: result.setValue(Field.SCHOOL, "University of Waikato");
0260: result.setValue(Field.ADDRESS, "Hamilton, NZ");
0261: result.setValue(Field.NOTE, "COMP594");
0262:
0263: additional = result.add(Type.ARTICLE);
0264: additional
0265: .setValue(Field.AUTHOR,
0266: "Jerome Friedman and Trevor Hastie and Robert Tibshirani");
0267: additional.setValue(Field.YEAR, "2000");
0268: additional
0269: .setValue(Field.TITLE,
0270: "Additive logistic regression : A statistical view of boosting");
0271: additional.setValue(Field.JOURNAL, "Annals of statistics");
0272: additional.setValue(Field.VOLUME, "28");
0273: additional.setValue(Field.NUMBER, "2");
0274: additional.setValue(Field.PAGES, "337-407");
0275: additional.setValue(Field.ISSN, "0090-5364");
0276:
0277: return result;
0278: }
0279:
0280: /**
0281: * Returns default capabilities of the classifier.
0282: *
0283: * @return the capabilities of this classifier
0284: */
0285: public Capabilities getCapabilities() {
0286: Capabilities result = super.getCapabilities();
0287:
0288: // attributes
0289: result.enable(Capability.NOMINAL_ATTRIBUTES);
0290: result.enable(Capability.NUMERIC_ATTRIBUTES);
0291: result.enable(Capability.MISSING_VALUES);
0292:
0293: // class
0294: result.enable(Capability.NOMINAL_CLASS);
0295:
0296: return result;
0297: }
0298:
0299: /**
0300: * Method for building a BestFirst decision tree classifier.
0301: *
0302: * @param data set of instances serving as training data
0303: * @throws Exception if decision tree cannot be built successfully
0304: */
0305: public void buildClassifier(Instances data) throws Exception {
0306:
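// Overview: an unpruned tree is grown best-first directly on all the data;
// for pre-pruning and post-pruning, the number of expansions is first fixed
// by internal cross-validation and the final tree is then rebuilt on all the
// data using that expansion limit.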
0307: getCapabilities().testWithFail(data);
0308: data = new Instances(data);
0309: data.deleteWithMissingClass();
0310:
0311: // build an unpruned tree
0312: if (m_PruningStrategy == PRUNING_UNPRUNED) {
0313:
0314: // calculate sorted indices, weights and initial class probabilities
0315: int[][] sortedIndices = new int[data.numAttributes()][0];
0316: double[][] weights = new double[data.numAttributes()][0];
0317: double[] classProbs = new double[data.numClasses()];
0318: double totalWeight = computeSortedInfo(data, sortedIndices,
0319: weights, classProbs);
0320:
0321: // Compute information of the best split for this node (including split attribute,
0322: // split value and gini gain (or information gain)). At the same time, compute
0323: // variables dists, props and totalSubsetWeights.
0324: double[][][] dists = new double[data.numAttributes()][2][data
0325: .numClasses()];
0326: double[][] props = new double[data.numAttributes()][2];
0327: double[][] totalSubsetWeights = new double[data
0328: .numAttributes()][2];
0329: FastVector nodeInfo = computeSplitInfo(this, data,
0330: sortedIndices, weights, dists, props,
0331: totalSubsetWeights, m_Heuristic, m_UseGini);
0332:
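// Each nodeInfo element is a FastVector laid out as: [0] = the BFTree node,
// [1] = the split Attribute, [2] = the split value (Double) or split subset
// (String), [3] = the gini gain or information gain (Double).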
0333: // add the node (with all split info) into BestFirstElements
0334: FastVector BestFirstElements = new FastVector();
0335: BestFirstElements.addElement(nodeInfo);
0336:
0337: // Make the best-first decision tree.
0338: int attIndex = ((Attribute) nodeInfo.elementAt(1)).index();
0339: m_Expansion = 0;
0340: makeTree(BestFirstElements, data, sortedIndices, weights,
0341: dists, classProbs, totalWeight, props[attIndex],
0342: m_minNumObj, m_Heuristic, m_UseGini,
0343: m_FixedExpansion);
0344:
0345: return;
0346: }
0347:
0348: // the following code is for pre-pruning and post-pruning methods
0349:
0350: // Compute train data, test data, sorted indices, sorted weights, total weights,
0351: // class probabilities, class distributions, branch proportions and total subset
0352: // weights for root nodes of each fold for prepruning and postpruning.
0353: int expansion = 0;
0354:
0355: Random random = new Random(m_Seed);
0356: Instances cvData = new Instances(data);
0357: cvData.randomize(random);
0358: cvData = new Instances(cvData, 0,
0359: (int) (cvData.numInstances() * m_SizePer) - 1);
0360: cvData.stratify(m_numFoldsPruning);
0361:
0362: Instances[] train = new Instances[m_numFoldsPruning];
0363: Instances[] test = new Instances[m_numFoldsPruning];
0364: FastVector[] parallelBFElements = new FastVector[m_numFoldsPruning];
0365: BFTree[] m_roots = new BFTree[m_numFoldsPruning];
0366:
0367: int[][][] sortedIndices = new int[m_numFoldsPruning][data
0368: .numAttributes()][0];
0369: double[][][] weights = new double[m_numFoldsPruning][data
0370: .numAttributes()][0];
0371: double[][] classProbs = new double[m_numFoldsPruning][data
0372: .numClasses()];
0373: double[] totalWeight = new double[m_numFoldsPruning];
0374:
0375: double[][][][] dists = new double[m_numFoldsPruning][data
0376: .numAttributes()][2][data.numClasses()];
0377: double[][][] props = new double[m_numFoldsPruning][data
0378: .numAttributes()][2];
0379: double[][][] totalSubsetWeights = new double[m_numFoldsPruning][data
0380: .numAttributes()][2];
0381: FastVector[] nodeInfo = new FastVector[m_numFoldsPruning];
0382:
0383: for (int i = 0; i < m_numFoldsPruning; i++) {
0384: train[i] = cvData.trainCV(m_numFoldsPruning, i);
0385: test[i] = cvData.testCV(m_numFoldsPruning, i);
0386: parallelBFElements[i] = new FastVector();
0387: m_roots[i] = new BFTree();
0388:
0389: // calculate sorted indices, weights, initial class counts and total weights for each training data
0390: totalWeight[i] = computeSortedInfo(train[i],
0391: sortedIndices[i], weights[i], classProbs[i]);
0392:
0393: // compute information of the best split for this node (including split attribute,
0394: // split value and gini gain (or information gain)) in this fold
0395: nodeInfo[i] = computeSplitInfo(m_roots[i], train[i],
0396: sortedIndices[i], weights[i], dists[i], props[i],
0397: totalSubsetWeights[i], m_Heuristic, m_UseGini);
0398:
0399: // compute information for root nodes
0400:
0401: int attIndex = ((Attribute) nodeInfo[i].elementAt(1))
0402: .index();
0403:
0404: m_roots[i].m_SortedIndices = new int[sortedIndices[i].length][0];
0405: m_roots[i].m_Weights = new double[weights[i].length][0];
0406: m_roots[i].m_Dists = new double[dists[i].length][0][0];
0407: m_roots[i].m_ClassProbs = new double[classProbs[i].length];
0408: m_roots[i].m_Distribution = new double[classProbs[i].length];
0409: m_roots[i].m_Props = new double[2];
0410:
0411: for (int j = 0; j < m_roots[i].m_SortedIndices.length; j++) {
0412: m_roots[i].m_SortedIndices[j] = sortedIndices[i][j];
0413: m_roots[i].m_Weights[j] = weights[i][j];
0414: m_roots[i].m_Dists[j] = dists[i][j];
0415: }
0416:
0417: System.arraycopy(classProbs[i], 0, m_roots[i].m_ClassProbs,
0418: 0, classProbs[i].length);
0419: if (Utils.sum(m_roots[i].m_ClassProbs) != 0)
0420: Utils.normalize(m_roots[i].m_ClassProbs);
0421:
0422: System.arraycopy(classProbs[i], 0,
0423: m_roots[i].m_Distribution, 0, classProbs[i].length);
0424: System.arraycopy(props[i][attIndex], 0, m_roots[i].m_Props,
0425: 0, props[i][attIndex].length);
0426:
0427: m_roots[i].m_TotalWeight = totalWeight[i];
0428:
0429: parallelBFElements[i].addElement(nodeInfo[i]);
0430: }
0431:
0432: // build a pre-pruned tree
0433: if (m_PruningStrategy == PRUNING_PREPRUNING) {
0434:
0435: double previousError = Double.MAX_VALUE;
0436: double currentError = previousError;
0437: double minError = Double.MAX_VALUE;
0438: int minExpansion = 0;
0439: FastVector errorList = new FastVector();
0440: while (true) {
0441: // compute average error
0442: double expansionError = 0;
0443: int count = 0;
0444:
0445: for (int i = 0; i < m_numFoldsPruning; i++) {
0446: Evaluation eval;
0447:
0448: // calculate error rate if only root node
0449: if (expansion == 0) {
0450: m_roots[i].m_isLeaf = true;
0451: eval = new Evaluation(test[i]);
0452: eval.evaluateModel(m_roots[i], test[i]);
0453: if (m_UseErrorRate)
0454: expansionError += eval.errorRate();
0455: else
0456: expansionError += eval
0457: .rootMeanSquaredError();
0458: count++;
0459: }
0460:
0461: // make tree - expand one node at a time
0462: else {
0463: if (m_roots[i] == null)
0464: continue; // if the tree cannot be expanded, go to next fold
0465: m_roots[i].m_isLeaf = false;
0466: BFTree nodeToSplit = (BFTree) (((FastVector) (parallelBFElements[i]
0467: .elementAt(0))).elementAt(0));
0468: if (!m_roots[i].makeTree(parallelBFElements[i],
0469: m_roots[i], train[i],
0470: nodeToSplit.m_SortedIndices,
0471: nodeToSplit.m_Weights,
0472: nodeToSplit.m_Dists,
0473: nodeToSplit.m_ClassProbs,
0474: nodeToSplit.m_TotalWeight,
0475: nodeToSplit.m_Props, m_minNumObj,
0476: m_Heuristic, m_UseGini)) {
0477: m_roots[i] = null; // cannot be expanded
0478: continue;
0479: }
0480: eval = new Evaluation(test[i]);
0481: eval.evaluateModel(m_roots[i], test[i]);
0482: if (m_UseErrorRate)
0483: expansionError += eval.errorRate();
0484: else
0485: expansionError += eval
0486: .rootMeanSquaredError();
0487: count++;
0488: }
0489: }
0490:
0491: // no tree can be expanded any more
0492: if (count == 0)
0493: break;
0494:
0495: expansionError /= count;
0496: errorList.addElement(new Double(expansionError));
0497: currentError = expansionError;
0498:
0499: if (!m_UseOneSE) {
0500: if (currentError > previousError)
0501: break;
0502: }
0503:
0504: else {
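// 1 SE rule: track the minimum error and keep expanding until the
// current error exceeds that minimum by more than one standard error,
// estimated as sqrt(e * (1 - e) / n) (a binomial approximation)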
0505: if (expansionError < minError) {
0506: minError = expansionError;
0507: minExpansion = expansion;
0508: }
0509:
0510: if (currentError > previousError) {
0511: double oneSE = Math.sqrt(minError
0512: * (1 - minError) / data.numInstances());
0513: if (currentError > minError + oneSE) {
0514: break;
0515: }
0516: }
0517: }
0518:
0519: expansion++;
0520: previousError = currentError;
0521: }
0522:
0523: if (!m_UseOneSE)
0524: expansion = expansion - 1;
0525: else {
0526: double oneSE = Math.sqrt(minError * (1 - minError)
0527: / data.numInstances());
0528: for (int i = 0; i < errorList.size(); i++) {
0529: double error = ((Double) (errorList.elementAt(i)))
0530: .doubleValue();
0531: if (error <= minError + oneSE) { // && counts[i]>=m_numFoldsPruning/2) {
0532: expansion = i;
0533: break;
0534: }
0535: }
0536: }
0537: }
0538:
0539: // build a postpruned tree
0540: else {
0541: FastVector[] modelError = new FastVector[m_numFoldsPruning];
0542:
0543: // calculate error of each expansion for each fold
0544: for (int i = 0; i < m_numFoldsPruning; i++) {
0545: modelError[i] = new FastVector();
0546:
0547: m_roots[i].m_isLeaf = true;
0548: Evaluation eval = new Evaluation(test[i]);
0549: eval.evaluateModel(m_roots[i], test[i]);
0550: double error;
0551: if (m_UseErrorRate)
0552: error = eval.errorRate();
0553: else
0554: error = eval.rootMeanSquaredError();
0555: modelError[i].addElement(new Double(error));
0556:
0557: m_roots[i].m_isLeaf = false;
0558: BFTree nodeToSplit = (BFTree) (((FastVector) (parallelBFElements[i]
0559: .elementAt(0))).elementAt(0));
0560:
0561: m_roots[i].makeTree(parallelBFElements[i], m_roots[i],
0562: train[i], test[i], modelError[i],
0563: nodeToSplit.m_SortedIndices,
0564: nodeToSplit.m_Weights, nodeToSplit.m_Dists,
0565: nodeToSplit.m_ClassProbs,
0566: nodeToSplit.m_TotalWeight, nodeToSplit.m_Props,
0567: m_minNumObj, m_Heuristic, m_UseGini,
0568: m_UseErrorRate);
0569: m_roots[i] = null;
0570: }
0571:
0572: // find the expansion with minimal error rate
0573: double minError = Double.MAX_VALUE;
0574:
0575: int maxExpansion = modelError[0].size();
0576: for (int i = 1; i < modelError.length; i++) {
0577: if (modelError[i].size() > maxExpansion)
0578: maxExpansion = modelError[i].size();
0579: }
0580:
0581: double[] error = new double[maxExpansion];
0582: int[] counts = new int[maxExpansion];
0583: for (int i = 0; i < maxExpansion; i++) {
0584: counts[i] = 0;
0585: error[i] = 0;
0586: for (int j = 0; j < m_numFoldsPruning; j++) {
0587: if (i < modelError[j].size()) {
0588: error[i] += ((Double) modelError[j]
0589: .elementAt(i)).doubleValue();
0590: counts[i]++;
0591: }
0592: }
0593: error[i] = error[i] / counts[i]; //average error for each expansion
0594:
0595: if (error[i] < minError) {// && counts[i]>=m_numFoldsPruning/2) {
0596: minError = error[i];
0597: expansion = i;
0598: }
0599: }
0600:
0601: // if the 1 SE rule is chosen
0602: if (m_UseOneSE) {
0603: double oneSE = Math.sqrt(minError * (1 - minError)
0604: / data.numInstances());
0605: for (int i = 0; i < maxExpansion; i++) {
0606: if (error[i] <= minError + oneSE) { // && counts[i]>=m_numFoldsPruning/2) {
0607: expansion = i;
0608: break;
0609: }
0610: }
0611: }
0612: }
0613:
0614: // make tree on all data based on the expansion calculated
0615: // from cross-validation
0616:
0617: // calculate sorted indices, weights and initial class counts
0618: int[][] prune_sortedIndices = new int[data.numAttributes()][0];
0619: double[][] prune_weights = new double[data.numAttributes()][0];
0620: double[] prune_classProbs = new double[data.numClasses()];
0621: double prune_totalWeight = computeSortedInfo(data,
0622: prune_sortedIndices, prune_weights, prune_classProbs);
0623:
0624: // compute information of the best split for this node (including split attribute,
0625: // split value and gini gain)
0626: double[][][] prune_dists = new double[data.numAttributes()][2][data
0627: .numClasses()];
0628: double[][] prune_props = new double[data.numAttributes()][2];
0629: double[][] prune_totalSubsetWeights = new double[data
0630: .numAttributes()][2];
0631: FastVector prune_nodeInfo = computeSplitInfo(this, data,
0632: prune_sortedIndices, prune_weights, prune_dists,
0633: prune_props, prune_totalSubsetWeights, m_Heuristic,
0634: m_UseGini);
0635:
0636: // add the root node (with its split info) to BestFirstElements
0637: FastVector BestFirstElements = new FastVector();
0638: BestFirstElements.addElement(prune_nodeInfo);
0639:
0640: int attIndex = ((Attribute) prune_nodeInfo.elementAt(1))
0641: .index();
0642: m_Expansion = 0;
0643: makeTree(BestFirstElements, data, prune_sortedIndices,
0644: prune_weights, prune_dists, prune_classProbs,
0645: prune_totalWeight, prune_props[attIndex], m_minNumObj,
0646: m_Heuristic, m_UseGini, expansion);
0647: }
0648:
0649: /**
0650: * Recursively build a best-first decision tree.
0651: * Method for building a Best-First tree for a given number of expansions.
0652: * A preExpansion value of -1 means that no expansion limit is specified (used for a
0653: * tree without any pruning method). Pre-pruning and post-pruning methods also
0654: * use this method to build the final tree on all training data based on the
0655: * expansion calculated from internal cross-validation.
0656: *
0657: * @param BestFirstElements list to store BFTree nodes
0658: * @param data training data
0659: * @param sortedIndices sorted indices of the instances
0660: * @param weights weights of the instances
0661: * @param dists class distributions for each attribute
0662: * @param classProbs class probabilities of this node
0663: * @param totalWeight total weight of this node (note that if the node
0664: * cannot be split, this value is not calculated)
0665: * @param branchProps proportions of two subbranches
0666: * @param minNumObj minimal number of instances at leaf nodes
0667: * @param useHeuristic if use heuristic search for nominal attributes
0668: * in multi-class problem
0669: * @param useGini if use Gini index as splitting criterion
0670: * @param preExpansion the number of expansions to perform when building the tree
0671: * @throws Exception if something goes wrong
0672: */
0673: protected void makeTree(FastVector BestFirstElements,
0674: Instances data, int[][] sortedIndices, double[][] weights,
0675: double[][][] dists, double[] classProbs,
0676: double totalWeight, double[] branchProps, int minNumObj,
0677: boolean useHeuristic, boolean useGini, int preExpansion)
0678: throws Exception {
0679:
0680: if (BestFirstElements.size() == 0)
0681: return;
0682:
0683: ///////////////////////////////////////////////////////////////////////
0684: // All information about the node to split (the first BestFirst object in
0685: // BestFirstElements)
0686: FastVector firstElement = (FastVector) BestFirstElements
0687: .elementAt(0);
0688:
0689: // split attribute
0690: Attribute att = (Attribute) firstElement.elementAt(1);
0691:
0692: // info of split value or split string
0693: double splitValue = Double.NaN;
0694: String splitStr = null;
0695: if (att.isNumeric())
0696: splitValue = ((Double) firstElement.elementAt(2))
0697: .doubleValue();
0698: else {
0699: splitStr = (String) firstElement.elementAt(2);
0700: }
0701:
0702: // the best gini gain or information gain of this node
0703: double gain = ((Double) firstElement.elementAt(3))
0704: .doubleValue();
0705: ///////////////////////////////////////////////////////////////////////
0706:
0707: if (m_ClassProbs == null) {
0708: m_SortedIndices = new int[sortedIndices.length][0];
0709: m_Weights = new double[weights.length][0];
0710: m_Dists = new double[dists.length][0][0];
0711: m_ClassProbs = new double[classProbs.length];
0712: m_Distribution = new double[classProbs.length];
0713: m_Props = new double[2];
0714:
0715: for (int i = 0; i < m_SortedIndices.length; i++) {
0716: m_SortedIndices[i] = sortedIndices[i];
0717: m_Weights[i] = weights[i];
0718: m_Dists[i] = dists[i];
0719: }
0720:
0721: System.arraycopy(classProbs, 0, m_ClassProbs, 0,
0722: classProbs.length);
0723: System.arraycopy(classProbs, 0, m_Distribution, 0,
0724: classProbs.length);
0725: System
0726: .arraycopy(branchProps, 0, m_Props, 0,
0727: m_Props.length);
0728: m_TotalWeight = totalWeight;
0729: if (Utils.sum(m_ClassProbs) != 0)
0730: Utils.normalize(m_ClassProbs);
0731: }
0732:
0733: // If there is not enough data or this node cannot be split, find the next node to split.
0734: if (totalWeight < 2 * minNumObj || branchProps[0] == 0
0735: || branchProps[1] == 0) {
0736: // remove the first element
0737: BestFirstElements.removeElementAt(0);
0738:
0739: makeLeaf(data);
0740: if (BestFirstElements.size() != 0) {
0741: FastVector nextSplitElement = (FastVector) BestFirstElements
0742: .elementAt(0);
0743: BFTree nextSplitNode = (BFTree) nextSplitElement
0744: .elementAt(0);
0745: nextSplitNode.makeTree(BestFirstElements, data,
0746: nextSplitNode.m_SortedIndices,
0747: nextSplitNode.m_Weights, nextSplitNode.m_Dists,
0748: nextSplitNode.m_ClassProbs,
0749: nextSplitNode.m_TotalWeight,
0750: nextSplitNode.m_Props, minNumObj, useHeuristic,
0751: useGini, preExpansion);
0752: }
0753: return;
0754: }
0755:
0756: // If the gini gain or information gain is 0, make all nodes in BestFirstElements leaf nodes,
0757: // because these nodes are sorted in descending order of gini gain or information gain
0758: // (namely, the gain of all nodes in BestFirstElements is 0).
0759: if (gain == 0 || preExpansion == m_Expansion) {
0760: for (int i = 0; i < BestFirstElements.size(); i++) {
0761: FastVector element = (FastVector) BestFirstElements
0762: .elementAt(i);
0763: BFTree node = (BFTree) element.elementAt(0);
0764: node.makeLeaf(data);
0765: }
0766: BestFirstElements.removeAllElements();
0767: }
0768:
0769: // gain is not 0
0770: else {
0771: // remove the first element
0772: BestFirstElements.removeElementAt(0);
0773:
0774: m_Attribute = att;
0775: if (m_Attribute.isNumeric())
0776: m_SplitValue = splitValue;
0777: else
0778: m_SplitString = splitStr;
0779:
0780: int[][][] subsetIndices = new int[2][data.numAttributes()][0];
0781: double[][][] subsetWeights = new double[2][data
0782: .numAttributes()][0];
0783:
0784: splitData(subsetIndices, subsetWeights, m_Attribute,
0785: m_SplitValue, m_SplitString, sortedIndices,
0786: weights, data);
0787:
0788: // If the split would generate node(s) whose total weight is less than m_minNumObj,
0789: // do not split.
0790: int attIndex = att.index();
0791: if (subsetIndices[0][attIndex].length < minNumObj
0792: || subsetIndices[1][attIndex].length < minNumObj) {
0793: makeLeaf(data);
0794: }
0795:
0796: // split the node
0797: else {
0798: m_isLeaf = false;
0799: m_Attribute = att;
0800:
0801: // if expansion is specified (if pruning method used)
0802: if ((m_PruningStrategy == PRUNING_PREPRUNING)
0803: || (m_PruningStrategy == PRUNING_POSTPRUNING)
0804: || (preExpansion != -1))
0805: m_Expansion++;
0806:
0807: makeSuccessors(BestFirstElements, data, subsetIndices,
0808: subsetWeights, dists, att, useHeuristic,
0809: useGini);
0810: }
0811:
0812: // choose next node to split
0813: if (BestFirstElements.size() != 0) {
0814: FastVector nextSplitElement = (FastVector) BestFirstElements
0815: .elementAt(0);
0816: BFTree nextSplitNode = (BFTree) nextSplitElement
0817: .elementAt(0);
0818: nextSplitNode.makeTree(BestFirstElements, data,
0819: nextSplitNode.m_SortedIndices,
0820: nextSplitNode.m_Weights, nextSplitNode.m_Dists,
0821: nextSplitNode.m_ClassProbs,
0822: nextSplitNode.m_TotalWeight,
0823: nextSplitNode.m_Props, minNumObj, useHeuristic,
0824: useGini, preExpansion);
0825: }
0826:
0827: }
0828: }
0829:
0830: /**
0831: * This method is used to find the number of expansions based on internal
0832: * cross-validation for pre-pruning only. It expands the first BestFirst
0833: * node in BestFirstElements if it is expansible, otherwise it looks
0834: * for the next expansible node. If it finds an expansible node, it expands
0835: * that node and returns true. (Note it only expands one node at a time.)
0836: *
0837: * @param BestFirstElements list to store BFTree nodes
0838: * @param root root node of tree in each fold
0839: * @param train training data
0840: * @param sortedIndices sorted indices of the instances
0841: * @param weights weights of the instances
0842: * @param dists class distributions for each attribute
0843: * @param classProbs class probabilities of this node
0844: * @param totalWeight total weight of this node (note that if the node
0845: * cannot be split, this value is not calculated)
0846: * @param branchProps proportions of two subbranches
0847: * @param minNumObj minimal number of instances at leaf nodes
0848: * @param useHeuristic if use heuristic search for nominal attributes
0849: * in multi-class problem
0850: * @param useGini if use Gini index as splitting criterion
0851: * @return true if expand successfully, otherwise return false
0852: * (all nodes in BestFirstElements cannot be
0853: * expanded).
0854: * @throws Exception if something goes wrong
0855: */
0856: protected boolean makeTree(FastVector BestFirstElements,
0857: BFTree root, Instances train, int[][] sortedIndices,
0858: double[][] weights, double[][][] dists,
0859: double[] classProbs, double totalWeight,
0860: double[] branchProps, int minNumObj, boolean useHeuristic,
0861: boolean useGini) throws Exception {
0862:
0863: if (BestFirstElements.size() == 0)
0864: return false;
0865:
0866: ///////////////////////////////////////////////////////////////////////
0867: // All information about the node to split (first BestFirst object in
0868: // BestFirstElements)
0869: FastVector firstElement = (FastVector) BestFirstElements
0870: .elementAt(0);
0871:
0872: // node to split
0873: BFTree nodeToSplit = (BFTree) firstElement.elementAt(0);
0874:
0875: // split attribute
0876: Attribute att = (Attribute) firstElement.elementAt(1);
0877:
0878: // info of split value or split string
0879: double splitValue = Double.NaN;
0880: String splitStr = null;
0881: if (att.isNumeric())
0882: splitValue = ((Double) firstElement.elementAt(2))
0883: .doubleValue();
0884: else {
0885: splitStr = (String) firstElement.elementAt(2);
0886: }
0887:
0888: // the best gini gain or information gain of this node
0889: double gain = ((Double) firstElement.elementAt(3))
0890: .doubleValue();
0891: ///////////////////////////////////////////////////////////////////////
0892:
0893: // If there is not enough data to split this node or this node cannot be split, find the next node to split.
0894: if (totalWeight < 2 * minNumObj || branchProps[0] == 0
0895: || branchProps[1] == 0) {
0896: // remove the first element
0897: BestFirstElements.removeElementAt(0);
0898: nodeToSplit.makeLeaf(train);
0899: BFTree nextNode = (BFTree) ((FastVector) BestFirstElements
0900: .elementAt(0)).elementAt(0);
0901: return root.makeTree(BestFirstElements, root, train,
0902: nextNode.m_SortedIndices, nextNode.m_Weights,
0903: nextNode.m_Dists, nextNode.m_ClassProbs,
0904: nextNode.m_TotalWeight, nextNode.m_Props,
0905: minNumObj, useHeuristic, useGini);
0906: }
0907:
0908: // If the gini gain or information gain is 0, make all nodes in BestFirstElements leaf nodes,
0909: // because these nodes are sorted in descending order of gini gain or information gain
0910: // (namely, the gain of all nodes in BestFirstElements is 0).
0911: if (gain == 0) {
0912: for (int i = 0; i < BestFirstElements.size(); i++) {
0913: FastVector element = (FastVector) BestFirstElements
0914: .elementAt(i);
0915: BFTree node = (BFTree) element.elementAt(0);
0916: node.makeLeaf(train);
0917: }
0918: BestFirstElements.removeAllElements();
0919: return false;
0920: }
0921:
0922: else {
0923: // remove the first element
0924: BestFirstElements.removeElementAt(0);
0925: nodeToSplit.m_Attribute = att;
0926: if (att.isNumeric())
0927: nodeToSplit.m_SplitValue = splitValue;
0928: else
0929: nodeToSplit.m_SplitString = splitStr;
0930:
0931: int[][][] subsetIndices = new int[2][train.numAttributes()][0];
0932: double[][][] subsetWeights = new double[2][train
0933: .numAttributes()][0];
0934:
0935: splitData(subsetIndices, subsetWeights,
0936: nodeToSplit.m_Attribute, nodeToSplit.m_SplitValue,
0937: nodeToSplit.m_SplitString,
0938: nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights,
0939: train);
0940:
0941: // if the split would generate node(s) whose total weight is less than m_minNumObj,
0942: // do not split
0943: int attIndex = att.index();
0944: if (subsetIndices[0][attIndex].length < minNumObj
0945: || subsetIndices[1][attIndex].length < minNumObj) {
0946:
0947: nodeToSplit.makeLeaf(train);
0948: BFTree nextNode = (BFTree) ((FastVector) BestFirstElements
0949: .elementAt(0)).elementAt(0);
0950: return root.makeTree(BestFirstElements, root, train,
0951: nextNode.m_SortedIndices, nextNode.m_Weights,
0952: nextNode.m_Dists, nextNode.m_ClassProbs,
0953: nextNode.m_TotalWeight, nextNode.m_Props,
0954: minNumObj, useHeuristic, useGini);
0955: }
0956:
0957: // split the node
0958: else {
0959: nodeToSplit.m_isLeaf = false;
0960: nodeToSplit.m_Attribute = att;
0961:
0962: nodeToSplit.makeSuccessors(BestFirstElements, train,
0963: subsetIndices, subsetWeights, dists,
0964: nodeToSplit.m_Attribute, useHeuristic, useGini);
0965:
0966: for (int i = 0; i < 2; i++) {
0967: nodeToSplit.m_Successors[i].makeLeaf(train);
0968: }
0969:
0970: return true;
0971: }
0972: }
0973: }
0974:
0975: /**
0976: * This method is used to find the number of expansions based on internal
0977: * cross-validation for post-pruning only. It expands the first BestFirst
0978: * node in BestFirstElements until no node can be split. While building
0979: * the tree, it stores the error of each temporary tree, namely for each expansion.
0980: *
0981: * @param BestFirstElements list to store BFTree nodes
0982: * @param root root node of tree in each fold
0983: * @param train training data in each fold
0984: * @param test test data in each fold
0985: * @param modelError list to store error for each expansion in
0986: * each fold
0987: * @param sortedIndices sorted indices of the instances
0988: * @param weights weights of the instances
0989: * @param dists class distributions for each attribute
0990: * @param classProbs class probabilities of this node
0991: * @param totalWeight total weight of this node (note that if the node
0992: * cannot be split, this value is not calculated)
0993: * @param branchProps proportions of two subbranches
0994: * @param minNumObj minimal number of instances at leaf nodes
0995: * @param useHeuristic if use heuristic search for nominal attributes
0996: * in multi-class problem
0997: * @param useGini if use Gini index as splitting criterion
0998: * @param useErrorRate if use error rate in internal cross-validation
0999: * @throws Exception if something goes wrong
1000: */
1001: protected void makeTree(FastVector BestFirstElements, BFTree root,
1002: Instances train, Instances test, FastVector modelError,
1003: int[][] sortedIndices, double[][] weights,
1004: double[][][] dists, double[] classProbs,
1005: double totalWeight, double[] branchProps, int minNumObj,
1006: boolean useHeuristic, boolean useGini, boolean useErrorRate)
1007: throws Exception {
1008:
1009: if (BestFirstElements.size() == 0)
1010: return;
1011:
1012: ///////////////////////////////////////////////////////////////////////
1013: // All information about the node to split (first BestFirst object in
1014: // BestFirstElements)
1015: FastVector firstElement = (FastVector) BestFirstElements
1016: .elementAt(0);
1017:
1018: // node to split
1019: //BFTree nodeToSplit = (BFTree)firstElement.elementAt(0);
1020:
1021: // split attribute
1022: Attribute att = (Attribute) firstElement.elementAt(1);
1023:
1024: // info of split value or split string
1025: double splitValue = Double.NaN;
1026: String splitStr = null;
1027: if (att.isNumeric())
1028: splitValue = ((Double) firstElement.elementAt(2))
1029: .doubleValue();
1030: else {
1031: splitStr = (String) firstElement.elementAt(2);
1032: }
1033:
1034: // the best gini gain or information gain of this node
1035: double gain = ((Double) firstElement.elementAt(3))
1036: .doubleValue();
1037: ///////////////////////////////////////////////////////////////////////
1038:
1039: if (totalWeight < 2 * minNumObj || branchProps[0] == 0
1040: || branchProps[1] == 0) {
1041: // remove the first element
1042: BestFirstElements.removeElementAt(0);
1043: makeLeaf(train);
1044: BFTree nextSplitNode = (BFTree) ((FastVector) BestFirstElements
1045: .elementAt(0)).elementAt(0);
1046: nextSplitNode.makeTree(BestFirstElements, root, train,
1047: test, modelError, nextSplitNode.m_SortedIndices,
1048: nextSplitNode.m_Weights, nextSplitNode.m_Dists,
1049: nextSplitNode.m_ClassProbs,
1050: nextSplitNode.m_TotalWeight, nextSplitNode.m_Props,
1051: minNumObj, useHeuristic, useGini, useErrorRate);
1052: return;
1053:
1054: }
1055:
1056: // If the gini gain or information gain is 0, make all nodes in BestFirstElements leaf nodes,
1057: // because these nodes are sorted in descending order of gini gain or information gain
1058: // (namely, the gain of all nodes in BestFirstElements is 0).
1059: if (gain == 0) {
1060: for (int i = 0; i < BestFirstElements.size(); i++) {
1061: FastVector element = (FastVector) BestFirstElements
1062: .elementAt(i);
1063: BFTree node = (BFTree) element.elementAt(0);
1064: node.makeLeaf(train);
1065: }
1066: BestFirstElements.removeAllElements();
1067: }
1068:
1069: // gini gain or information gain is not 0
1070: else {
1071: // remove the first element
1072: BestFirstElements.removeElementAt(0);
1073: m_Attribute = att;
1074: if (att.isNumeric())
1075: m_SplitValue = splitValue;
1076: else
1077: m_SplitString = splitStr;
1078:
1079: int[][][] subsetIndices = new int[2][train.numAttributes()][0];
1080: double[][][] subsetWeights = new double[2][train
1081: .numAttributes()][0];
1082:
1083: splitData(subsetIndices, subsetWeights, m_Attribute,
1084: m_SplitValue, m_SplitString, sortedIndices,
1085: weights, train);
1086:
1087: // if the split would generate node(s) whose total weight is less than m_minNumObj,
1088: // do not split
1089: int attIndex = att.index();
1090: if (subsetIndices[0][attIndex].length < minNumObj
1091: || subsetIndices[1][attIndex].length < minNumObj) {
1092: makeLeaf(train);
1093: }
1094:
1095: // split the node and calculate the error rate of this temporary tree
1096: else {
1097: m_isLeaf = false;
1098: m_Attribute = att;
1099:
1100: makeSuccessors(BestFirstElements, train, subsetIndices,
1101: subsetWeights, dists, m_Attribute,
1102: useHeuristic, useGini);
1103: for (int i = 0; i < 2; i++) {
1104: m_Successors[i].makeLeaf(train);
1105: }
1106:
1107: Evaluation eval = new Evaluation(test);
1108: eval.evaluateModel(root, test);
1109: double error;
1110: if (useErrorRate)
1111: error = eval.errorRate();
1112: else
1113: error = eval.rootMeanSquaredError();
1114: modelError.addElement(new Double(error));
1115: }
1116:
1117: if (BestFirstElements.size() != 0) {
1118: FastVector nextSplitElement = (FastVector) BestFirstElements
1119: .elementAt(0);
1120: BFTree nextSplitNode = (BFTree) nextSplitElement
1121: .elementAt(0);
1122: nextSplitNode.makeTree(BestFirstElements, root, train,
1123: test, modelError,
1124: nextSplitNode.m_SortedIndices,
1125: nextSplitNode.m_Weights, nextSplitNode.m_Dists,
1126: nextSplitNode.m_ClassProbs,
1127: nextSplitNode.m_TotalWeight,
1128: nextSplitNode.m_Props, minNumObj, useHeuristic,
1129: useGini, useErrorRate);
1130: }
1131: }
1132: }
1133:
1134: /**
1135: * Generate successor nodes for a node and put them into BestFirstElements
1136: * according to gini gain or information gain in a descending order.
1137: *
1138: * @param BestFirstElements list to store BestFirst nodes
1139: * @param data training instances
1140: * @param subsetSortedIndices sorted indices of instances of successor nodes
1141: * @param subsetWeights weights of instances of successor nodes
1142: * @param dists class distributions of successor nodes
1143: * @param att attribute used to split the node
1144: * @param useHeuristic if use heuristic search for nominal attributes in multi-class problem
1145: * @param useGini if use Gini index as splitting criterion
1146: * @throws Exception if something goes wrong
1147: */
1148: protected void makeSuccessors(FastVector BestFirstElements,
1149: Instances data, int[][][] subsetSortedIndices,
1150: double[][][] subsetWeights, double[][][] dists,
1151: Attribute att, boolean useHeuristic, boolean useGini)
1152: throws Exception {
1153:
1154: m_Successors = new BFTree[2];
1155:
1156: for (int i = 0; i < 2; i++) {
1157: m_Successors[i] = new BFTree();
1158: m_Successors[i].m_isLeaf = true;
1159:
1160: // class probability and distribution for this successor node
1161: m_Successors[i].m_ClassProbs = new double[data.numClasses()];
1162: m_Successors[i].m_Distribution = new double[data
1163: .numClasses()];
1164: System.arraycopy(dists[att.index()][i], 0,
1165: m_Successors[i].m_ClassProbs, 0,
1166: m_Successors[i].m_ClassProbs.length);
1167: System.arraycopy(dists[att.index()][i], 0,
1168: m_Successors[i].m_Distribution, 0,
1169: m_Successors[i].m_Distribution.length);
1170: if (Utils.sum(m_Successors[i].m_ClassProbs) != 0)
1171: Utils.normalize(m_Successors[i].m_ClassProbs);
1172:
1173: // split information for this successor node
1174: double[][] props = new double[data.numAttributes()][2];
1175: double[][][] subDists = new double[data.numAttributes()][2][data
1176: .numClasses()];
1177: double[][] totalSubsetWeights = new double[data
1178: .numAttributes()][2];
1179: FastVector splitInfo = m_Successors[i].computeSplitInfo(
1180: m_Successors[i], data, subsetSortedIndices[i],
1181: subsetWeights[i], subDists, props,
1182: totalSubsetWeights, useHeuristic, useGini);
1183:
1184: // branch proportion for this successor node
1185: int splitIndex = ((Attribute) splitInfo.elementAt(1))
1186: .index();
1187: m_Successors[i].m_Props = new double[2];
1188: System.arraycopy(props[splitIndex], 0,
1189: m_Successors[i].m_Props, 0,
1190: m_Successors[i].m_Props.length);
1191:
1192: // sorted indices and weights of each attribute for this successor node
1193: m_Successors[i].m_SortedIndices = new int[data
1194: .numAttributes()][0];
1195: m_Successors[i].m_Weights = new double[data.numAttributes()][0];
1196: for (int j = 0; j < m_Successors[i].m_SortedIndices.length; j++) {
1197: m_Successors[i].m_SortedIndices[j] = subsetSortedIndices[i][j];
1198: m_Successors[i].m_Weights[j] = subsetWeights[i][j];
1199: }
1200:
1201: // distribution of each attribute for this successor node
1202: m_Successors[i].m_Dists = new double[data.numAttributes()][2][data
1203: .numClasses()];
1204: for (int j = 0; j < subDists.length; j++) {
1205: m_Successors[i].m_Dists[j] = subDists[j];
1206: }
1207:
1208: // total weights for this successor node.
1209: m_Successors[i].m_TotalWeight = Utils
1210: .sum(totalSubsetWeights[splitIndex]);
1211:
1212: // insert this successor node into BestFirstElements according to gini gain or information gain
1213: // descendingly
1214: if (BestFirstElements.size() == 0) {
1215: BestFirstElements.addElement(splitInfo);
1216: } else {
1217: double gGain = ((Double) (splitInfo.elementAt(3)))
1218: .doubleValue();
1219: int vectorSize = BestFirstElements.size();
1220: FastVector lastNode = (FastVector) BestFirstElements
1221: .elementAt(vectorSize - 1);
1222:
1223: // If gini gain is less than that of last node in FastVector
1224: if (gGain < ((Double) (lastNode.elementAt(3)))
1225: .doubleValue()) {
1226: BestFirstElements.insertElementAt(splitInfo,
1227: vectorSize);
1228: } else {
1229: for (int j = 0; j < vectorSize; j++) {
1230: FastVector node = (FastVector) BestFirstElements
1231: .elementAt(j);
1232: double nodeGain = ((Double) (node.elementAt(3)))
1233: .doubleValue();
1234: if (gGain >= nodeGain) {
1235: BestFirstElements.insertElementAt(
1236: splitInfo, j);
1237: break;
1238: }
1239: }
1240: }
1241: }
1242: }
1243: }
1244:
1245: /**
1246: * Compute sorted indices, weights and class probabilities for a given
1247: * dataset. Return total weights of the data at the node.
1248: *
1249: * @param data training data
1250: * @param sortedIndices sorted indices of instances at the node
1251: * @param weights weights of instances at the node
1252: * @param classProbs class probabilities at the node
1253: * @return total weights of instances at the node
1254: * @throws Exception if something goes wrong
1255: */
1256: protected double computeSortedInfo(Instances data,
1257: int[][] sortedIndices, double[][] weights,
1258: double[] classProbs) throws Exception {
1259:
1260: // Create array of sorted indices and weights
1261: double[] vals = new double[data.numInstances()];
1262: for (int j = 0; j < data.numAttributes(); j++) {
1263: if (j == data.classIndex())
1264: continue;
1265: weights[j] = new double[data.numInstances()];
1266:
1267: if (data.attribute(j).isNominal()) {
1268:
1269: // Handling nominal attributes. Putting indices of
1270: // instances with missing values at the end.
1271: sortedIndices[j] = new int[data.numInstances()];
1272: int count = 0;
1273: for (int i = 0; i < data.numInstances(); i++) {
1274: Instance inst = data.instance(i);
1275: if (!inst.isMissing(j)) {
1276: sortedIndices[j][count] = i;
1277: weights[j][count] = inst.weight();
1278: count++;
1279: }
1280: }
1281: for (int i = 0; i < data.numInstances(); i++) {
1282: Instance inst = data.instance(i);
1283: if (inst.isMissing(j)) {
1284: sortedIndices[j][count] = i;
1285: weights[j][count] = inst.weight();
1286: count++;
1287: }
1288: }
1289: } else {
1290:
1291: // Sorted indices are computed for numeric attributes;
1292: // instances with missing values are put at the end (by the Utils.sort() method)
1293: for (int i = 0; i < data.numInstances(); i++) {
1294: Instance inst = data.instance(i);
1295: vals[i] = inst.value(j);
1296: }
1297: sortedIndices[j] = Utils.sort(vals);
1298: for (int i = 0; i < data.numInstances(); i++) {
1299: weights[j][i] = data.instance(sortedIndices[j][i])
1300: .weight();
1301: }
1302: }
1303: }
1304:
1305: // Compute initial class counts and total weight
1306: double totalWeight = 0;
1307: for (int i = 0; i < data.numInstances(); i++) {
1308: Instance inst = data.instance(i);
1309: classProbs[(int) inst.classValue()] += inst.weight();
1310: totalWeight += inst.weight();
1311: }
1312:
1313: return totalWeight;
1314: }
1315:
1316: /**
1317: * Compute the best splitting attribute, split point or subset and the best
1318: * gini gain or information gain for a given dataset.
1319: *
1320: * @param node node to be split
1321: * @param data training data
1322: * @param sortedIndices sorted indices of the instances
1323: * @param weights weights of the instances
1324: * @param dists class distributions for each attribute
1325: * @param props proportions of two branches
1326: * @param totalSubsetWeights total weight of two subsets
1327: * @param useHeuristic if use heuristic search for nominal attributes
1328: * in multi-class problem
1329: * @param useGini if use Gini index as splitting criterion
1330: * @return split information about the node
1331: * @throws Exception if something is wrong
1332: */
1333: protected FastVector computeSplitInfo(BFTree node, Instances data,
1334: int[][] sortedIndices, double[][] weights,
1335: double[][][] dists, double[][] props,
1336: double[][] totalSubsetWeights, boolean useHeuristic,
1337: boolean useGini) throws Exception {
1338:
1339: double[] splits = new double[data.numAttributes()];
1340: String[] splitString = new String[data.numAttributes()];
1341: double[] gains = new double[data.numAttributes()];
1342:
1343: for (int i = 0; i < data.numAttributes(); i++) {
1344: if (i == data.classIndex())
1345: continue;
1346: Attribute att = data.attribute(i);
1347: if (att.isNumeric()) {
1348: // numeric attribute
1349: splits[i] = numericDistribution(props, dists, att,
1350: sortedIndices[i], weights[i],
1351: totalSubsetWeights, gains, data, useGini);
1352: } else {
1353: // nominal attribute
1354: splitString[i] = nominalDistribution(props, dists, att,
1355: sortedIndices[i], weights[i],
1356: totalSubsetWeights, gains, data, useHeuristic,
1357: useGini);
1358: }
1359: }
1360:
1361: int index = Utils.maxIndex(gains);
1362: double mBestGain = gains[index];
1363:
1364: Attribute att = data.attribute(index);
1365: double mValue = Double.NaN;
1366: String mString = null;
1367: if (att.isNumeric())
1368: mValue = splits[index];
1369: else {
1370: mString = splitString[index];
1371: if (mString == null)
1372: mString = "";
1373: }
1374:
1375: // split information
1376: FastVector splitInfo = new FastVector();
1377: splitInfo.addElement(node);
1378: splitInfo.addElement(att);
1379: if (att.isNumeric())
1380: splitInfo.addElement(new Double(mValue));
1381: else
1382: splitInfo.addElement(mString);
1383: splitInfo.addElement(new Double(mBestGain));
1384:
1385: return splitInfo;
1386: }
1387:
1388: /**
1389: * Compute distributions, proportions and total weights of two successor nodes for
1390: * a given numeric attribute.
1391: *
1392: * @param props proportions of the two branches for each attribute
1393: * @param dists class distributions of the two branches for each attribute
1394: * @param att numeric attribute to split on
1395: * @param sortedIndices sorted indices of instances for the attribute
1396: * @param weights weights of instances for the attribute
1397: * @param subsetWeights total weight of two branches split based on the attribute
1398: * @param gains Gini gains or information gains for each attribute
1399: * @param data training instances
1400: * @param useGini if use Gini index as splitting criterion
1401: * @return the best split point for the given attribute (the gain is stored in gains)
1402: * @throws Exception if something goes wrong
1403: */
1404: protected double numericDistribution(double[][] props,
1405: double[][][] dists, Attribute att, int[] sortedIndices,
1406: double[] weights, double[][] subsetWeights, double[] gains,
1407: Instances data, boolean useGini) throws Exception {
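// For each candidate split point, instances with missing values are first
// distributed fractionally across the two branches; the split quality is then
// the gain in purity: the parent's Gini index (or entropy) minus the
// weighted average of the two children's.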
1408:
1409: double splitPoint = Double.NaN;
1410: double[][] dist = null;
1411: int numClasses = data.numClasses();
1412: int i; // loop index; instances with missing values are sorted to the end
1413:
1414: double[][] currDist = new double[2][numClasses];
1415: dist = new double[2][numClasses];
1416:
1417: // Move all instances without missing values into second subset
1418: double[] parentDist = new double[numClasses];
1419: int missingStart = 0;
1420: for (int j = 0; j < sortedIndices.length; j++) {
1421: Instance inst = data.instance(sortedIndices[j]);
1422: if (!inst.isMissing(att)) {
1423: missingStart++;
1424: currDist[1][(int) inst.classValue()] += weights[j];
1425: }
1426: parentDist[(int) inst.classValue()] += weights[j];
1427: }
1428: System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);
1429:
1430: // Try all possible split points
1431: double currSplit = data.instance(sortedIndices[0]).value(att);
1432: double currGain;
1433: double bestGain = -Double.MAX_VALUE;
1434:
1435: for (i = 0; i < sortedIndices.length; i++) {
1436: Instance inst = data.instance(sortedIndices[i]);
1437: if (inst.isMissing(att)) {
1438: break;
1439: }
1440: if (inst.value(att) > currSplit) {
1441:
1442: double[][] tempDist = new double[2][numClasses];
1443: for (int k = 0; k < 2; k++) {
1444: //tempDist[k] = currDist[k];
1445: System.arraycopy(currDist[k], 0, tempDist[k], 0,
1446: tempDist[k].length);
1447: }
1448:
1449: double[] tempProps = new double[2];
1450: for (int k = 0; k < 2; k++) {
1451: tempProps[k] = Utils.sum(tempDist[k]);
1452: }
1453:
1454: if (Utils.sum(tempProps) != 0)
1455: Utils.normalize(tempProps);
1456:
1457: // split missing values: distribute them to both branches in proportion to the branch weights (the 'fractional instances' method)
1458: int index = missingStart;
1459: while (index < sortedIndices.length) {
1460: Instance insta = data
1461: .instance(sortedIndices[index]);
1462: for (int j = 0; j < 2; j++) {
1463: tempDist[j][(int) insta.classValue()] += tempProps[j]
1464: * weights[index];
1465: }
1466: index++;
1467: }
1468:
1469: if (useGini)
1470: currGain = computeGiniGain(parentDist, tempDist);
1471: else
1472: currGain = computeInfoGain(parentDist, tempDist);
1473:
1474: if (currGain > bestGain) {
1475: bestGain = currGain;
1476: // clean split point: round the midpoint to 5 decimal places
1477: splitPoint = Math
1478: .rint((inst.value(att) + currSplit) / 2.0 * 100000) / 100000.0;
1479: for (int j = 0; j < currDist.length; j++) {
1480: System.arraycopy(tempDist[j], 0, dist[j], 0,
1481: dist[j].length);
1482: }
1483: }
1484: }
1485: currSplit = inst.value(att);
1486: currDist[0][(int) inst.classValue()] += weights[i];
1487: currDist[1][(int) inst.classValue()] -= weights[i];
1488: }
1489:
1490: // Compute weights
1491: int attIndex = att.index();
1492: props[attIndex] = new double[2];
1493: for (int k = 0; k < 2; k++) {
1494: props[attIndex][k] = Utils.sum(dist[k]);
1495: }
1496: if (Utils.sum(props[attIndex]) != 0)
1497: Utils.normalize(props[attIndex]);
1498:
1499: // Compute subset weights
1500: subsetWeights[attIndex] = new double[2];
1501: for (int j = 0; j < 2; j++) {
1502: subsetWeights[attIndex][j] += Utils.sum(dist[j]);
1503: }
1504:
1505: // clean gain: round to 7 decimal places
1506: gains[attIndex] = Math.rint(bestGain * 10000000) / 10000000.0;
1507: dists[attIndex] = dist;
1508: return splitPoint;
1509: }
1510:
1511: /**
1512: * Compute distributions, proportions and total weights of two successor
1513: * nodes for a given nominal attribute.
1514: *
1515: * @param props proportions of the two branches for each attribute
1516: * @param dists class distributions of the two branches for each attribute
1517: * @param att nominal attribute to split on
1518: * @param sortedIndices sorted indices of instances for the attribute
1519: * @param weights weights of instances for the attribute
1520: * @param subsetWeights total weight of two branches split based on the attribute
1521: * @param gains Gini gains for each attribute
1522: * @param data training instances
1523: * @param useHeuristic if use heuristic search
1524: * @param useGini if use Gini index as splitting criterion
1525: * @return the best split subset (as a string) for the given attribute
1526: * @throws Exception if something goes wrong
1527: */
1528: protected String nominalDistribution(double[][] props,
1529: double[][][] dists, Attribute att, int[] sortedIndices,
1530: double[] weights, double[][] subsetWeights, double[] gains,
1531: Instances data, boolean useHeuristic, boolean useGini)
1532: throws Exception {
1533:
1534: String[] values = new String[att.numValues()];
1535: int numCat = values.length; // number of values of the attribute
1536: int numClasses = data.numClasses();
1537:
1538: String bestSplitString = "";
1539: double bestGain = -Double.MAX_VALUE;
1540:
1541: // class frequency for each value
1542: int[] classFreq = new int[numCat];
1543: for (int j = 0; j < numCat; j++)
1544: classFreq[j] = 0;
1545:
1546: double[] parentDist = new double[numClasses];
1547: double[][] currDist = new double[2][numClasses];
1548: double[][] dist = new double[2][numClasses];
1549: int missingStart = 0;
1550:
1551: for (int i = 0; i < sortedIndices.length; i++) {
1552: Instance inst = data.instance(sortedIndices[i]);
1553: if (!inst.isMissing(att)) {
1554: missingStart++;
1555: classFreq[(int) inst.value(att)]++;
1556: }
1557: parentDist[(int) inst.classValue()] += weights[i];
1558: }
1559:
      // count the number of values whose class frequency is not 0
1561: int nonEmpty = 0;
1562: for (int j = 0; j < numCat; j++) {
1563: if (classFreq[j] != 0)
1564: nonEmpty++;
1565: }
1566:
      // attribute values whose class frequency is not 0
1568: String[] nonEmptyValues = new String[nonEmpty];
1569: int nonEmptyIndex = 0;
1570: for (int j = 0; j < numCat; j++) {
1571: if (classFreq[j] != 0) {
1572: nonEmptyValues[nonEmptyIndex] = att.value(j);
1573: nonEmptyIndex++;
1574: }
1575: }
1576:
      // attribute values whose class frequency is 0
1578: int empty = numCat - nonEmpty;
1579: String[] emptyValues = new String[empty];
1580: int emptyIndex = 0;
1581: for (int j = 0; j < numCat; j++) {
1582: if (classFreq[j] == 0) {
1583: emptyValues[emptyIndex] = att.value(j);
1584: emptyIndex++;
1585: }
1586: }
1587:
1588: if (nonEmpty <= 1) {
1589: gains[att.index()] = 0;
1590: return "";
1591: }
1592:
      // for two-class problems
1594: if (data.numClasses() == 2) {
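        // For two-class problems the attribute values are sorted by their
        // probability of class 0; a classical CART result (Breiman et al.,
        // 1984) guarantees that the best binary partition is among the
        // nonEmpty-1 splits of this ordering, so no exhaustive subset
        // search is needed.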
1595:
        // First, handle the attribute values whose class frequency is not zero
1597:
1598: // probability of class 0 for each attribute value
1599: double[] pClass0 = new double[nonEmpty];
1600: // class distribution for each attribute value
1601: double[][] valDist = new double[nonEmpty][2];
1602:
1603: for (int j = 0; j < nonEmpty; j++) {
1604: for (int k = 0; k < 2; k++) {
1605: valDist[j][k] = 0;
1606: }
1607: }
1608:
1609: for (int i = 0; i < sortedIndices.length; i++) {
1610: Instance inst = data.instance(sortedIndices[i]);
1611: if (inst.isMissing(att)) {
1612: break;
1613: }
1614:
1615: for (int j = 0; j < nonEmpty; j++) {
1616: if (att.value((int) inst.value(att)).compareTo(
1617: nonEmptyValues[j]) == 0) {
1618: valDist[j][(int) inst.classValue()] += inst
1619: .weight();
1620: break;
1621: }
1622: }
1623: }
1624:
1625: for (int j = 0; j < nonEmpty; j++) {
1626: double distSum = Utils.sum(valDist[j]);
1627: if (distSum == 0)
1628: pClass0[j] = 0;
1629: else
1630: pClass0[j] = valDist[j][0] / distSum;
1631: }
1632:
        // sort the attribute values according to the probability of class 0
1634: String[] sortedValues = new String[nonEmpty];
1635: for (int j = 0; j < nonEmpty; j++) {
1636: sortedValues[j] = nonEmptyValues[Utils
1637: .minIndex(pClass0)];
1638: pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
1639: }
1640:
        // Find the subset of attribute values that maximizes the impurity decrease
1642:
        // for the attribute values whose class frequency is not 0
1644: String tempStr = "";
1645:
1646: for (int j = 0; j < nonEmpty - 1; j++) {
1647: currDist = new double[2][numClasses];
          if (tempStr.equals(""))
            tempStr = "(" + sortedValues[j] + ")";
          else
            tempStr += "|" + "(" + sortedValues[j] + ")";
1653: for (int i = 0; i < sortedIndices.length; i++) {
1654: Instance inst = data.instance(sortedIndices[i]);
1655: if (inst.isMissing(att)) {
1656: break;
1657: }
1658:
1659: if (tempStr.indexOf("("
1660: + att.value((int) inst.value(att)) + ")") != -1) {
1661: currDist[0][(int) inst.classValue()] += weights[i];
1662: } else
1663: currDist[1][(int) inst.classValue()] += weights[i];
1664: }
1665:
1666: double[][] tempDist = new double[2][numClasses];
          for (int kk = 0; kk < 2; kk++) {
            // copy rather than alias currDist, so that adding the
            // missing-value weights below does not modify it
            System.arraycopy(currDist[kk], 0, tempDist[kk], 0, tempDist[kk].length);
          }
1670:
1671: double[] tempProps = new double[2];
1672: for (int kk = 0; kk < 2; kk++) {
1673: tempProps[kk] = Utils.sum(tempDist[kk]);
1674: }
1675:
1676: if (Utils.sum(tempProps) != 0)
1677: Utils.normalize(tempProps);
1678:
1679: // split missing values
1680: int mstart = missingStart;
1681: while (mstart < sortedIndices.length) {
1682: Instance insta = data
1683: .instance(sortedIndices[mstart]);
1684: for (int jj = 0; jj < 2; jj++) {
1685: tempDist[jj][(int) insta.classValue()] += tempProps[jj]
1686: * weights[mstart];
1687: }
1688: mstart++;
1689: }
1690:
1691: double currGain;
1692: if (useGini)
1693: currGain = computeGiniGain(parentDist, tempDist);
1694: else
1695: currGain = computeInfoGain(parentDist, tempDist);
1696:
1697: if (currGain > bestGain) {
1698: bestGain = currGain;
1699: bestSplitString = tempStr;
1700: for (int jj = 0; jj < 2; jj++) {
1701: System.arraycopy(tempDist[jj], 0, dist[jj], 0,
1702: dist[jj].length);
1703: }
1704: }
1705: }
1706: }
1707:
1708: // multi-class problems (exhaustive search)
1709: else if (!useHeuristic || nonEmpty <= 4) {
1711:
        // First, handle the attribute values whose class frequency is not zero
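        // Each i in [0, 2^(nonEmpty-1)) encodes one candidate subset via its
        // binary representation: bit b of i (counting from the least
        // significant bit) decides whether nonEmptyValues[nonEmpty-1-b] joins
        // the first branch. For example, with nonEmpty = 3, i = 3 (binary 011)
        // yields the subset {nonEmptyValues[1], nonEmptyValues[2]}. Only half
        // of all subsets are enumerated, because a subset and its complement
        // describe the same split.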
1713: for (int i = 0; i < (int) Math.pow(2, nonEmpty - 1); i++) {
1714: String tempStr = "";
1715: currDist = new double[2][numClasses];
1716: int mod;
1717: int bit10 = i;
1718: for (int j = nonEmpty - 1; j >= 0; j--) {
            mod = bit10 % 2; // extract the next binary digit of i
1720: if (mod == 1) {
            if (tempStr.equals(""))
              tempStr = "(" + nonEmptyValues[j] + ")";
            else
              tempStr += "|" + "(" + nonEmptyValues[j] + ")";
1726: }
1727: bit10 = bit10 / 2;
1728: }
1729: for (int j = 0; j < sortedIndices.length; j++) {
1730: Instance inst = data.instance(sortedIndices[j]);
1731: if (inst.isMissing(att)) {
1732: break;
1733: }
1734:
1735: if (tempStr.indexOf("("
1736: + att.value((int) inst.value(att)) + ")") != -1) {
1737: currDist[0][(int) inst.classValue()] += weights[j];
1738: } else
1739: currDist[1][(int) inst.classValue()] += weights[j];
1740: }
1741:
1742: double[][] tempDist = new double[2][numClasses];
          for (int k = 0; k < 2; k++) {
            // copy rather than alias currDist (see the two-class case above)
            System.arraycopy(currDist[k], 0, tempDist[k], 0, tempDist[k].length);
          }
1746:
1747: double[] tempProps = new double[2];
1748: for (int k = 0; k < 2; k++) {
1749: tempProps[k] = Utils.sum(tempDist[k]);
1750: }
1751:
1752: if (Utils.sum(tempProps) != 0)
1753: Utils.normalize(tempProps);
1754:
1755: // split missing values
1756: int index = missingStart;
1757: while (index < sortedIndices.length) {
1758: Instance insta = data
1759: .instance(sortedIndices[index]);
1760: for (int j = 0; j < 2; j++) {
1761: tempDist[j][(int) insta.classValue()] += tempProps[j]
1762: * weights[index];
1763: }
1764: index++;
1765: }
1766:
1767: double currGain;
1768: if (useGini)
1769: currGain = computeGiniGain(parentDist, tempDist);
1770: else
1771: currGain = computeInfoGain(parentDist, tempDist);
1772:
1773: if (currGain > bestGain) {
1774: bestGain = currGain;
1775: bestSplitString = tempStr;
1776: for (int j = 0; j < 2; j++) {
1778: System.arraycopy(tempDist[j], 0, dist[j], 0,
1779: dist[j].length);
1780: }
1781: }
1782: }
1783: }
1784:
      // heuristic method for multi-class problems
1786: else {
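        // Heuristic for many-valued nominal attributes in multi-class
        // problems: instead of enumerating all subsets, build the matrix P of
        // class probabilities per attribute value, project its rows onto the
        // first principal component of their weighted covariance matrix, sort
        // the attribute values by that score, and evaluate only the
        // nonEmpty-1 splits of the resulting ordering.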
        // First, handle the attribute values whose class frequency is not zero
1788: int n = nonEmpty;
1789: int k = data.numClasses(); // number of classes of the data
1790: double[][] P = new double[n][k]; // class probability matrix
1791: int[] numInstancesValue = new int[n]; // number of instances for an attribute value
1792: double[] meanClass = new double[k]; // vector of mean class probability
1793: int numInstances = data.numInstances(); // total number of instances
1794:
1795: // initialize the vector of mean class probability
1796: for (int j = 0; j < meanClass.length; j++)
1797: meanClass[j] = 0;
1798:
1799: for (int j = 0; j < numInstances; j++) {
1800: Instance inst = (Instance) data.instance(j);
1801: int valueIndex = 0; // attribute value index in nonEmptyValues
1802: for (int i = 0; i < nonEmpty; i++) {
1803: if (att.value((int) inst.value(att))
1804: .compareToIgnoreCase(nonEmptyValues[i]) == 0) {
1805: valueIndex = i;
1806: break;
1807: }
1808: }
1809: P[valueIndex][(int) inst.classValue()]++;
1810: numInstancesValue[valueIndex]++;
1811: meanClass[(int) inst.classValue()]++;
1812: }
1813:
1814: // calculate the class probability matrix
1815: for (int i = 0; i < P.length; i++) {
1816: for (int j = 0; j < P[0].length; j++) {
1817: if (numInstancesValue[i] == 0)
1818: P[i][j] = 0;
1819: else
1820: P[i][j] /= numInstancesValue[i];
1821: }
1822: }
1823:
        // calculate the vector of mean class probabilities
1825: for (int i = 0; i < meanClass.length; i++) {
1826: meanClass[i] /= numInstances;
1827: }
1828:
1829: // calculate the covariance matrix
1830: double[][] covariance = new double[k][k];
1831: for (int i1 = 0; i1 < k; i1++) {
1832: for (int i2 = 0; i2 < k; i2++) {
1833: double element = 0;
1834: for (int j = 0; j < n; j++) {
1835: element += (P[j][i2] - meanClass[i2])
1836: * (P[j][i1] - meanClass[i1])
1837: * numInstancesValue[j];
1838: }
1839: covariance[i1][i2] = element;
1840: }
1841: }
1842:
1843: Matrix matrix = new Matrix(covariance);
1844: weka.core.matrix.EigenvalueDecomposition eigen = new weka.core.matrix.EigenvalueDecomposition(
1845: matrix);
1846: double[] eigenValues = eigen.getRealEigenvalues();
1847:
1848: // find index of the largest eigenvalue
1849: int index = 0;
1850: double largest = eigenValues[0];
1851: for (int i = 1; i < eigenValues.length; i++) {
1852: if (eigenValues[i] > largest) {
1853: index = i;
1854: largest = eigenValues[i];
1855: }
1856: }
1857:
        // calculate the first principal component
1859: double[] FPC = new double[k];
1860: Matrix eigenVector = eigen.getV();
1861: double[][] vectorArray = eigenVector.getArray();
1862: for (int i = 0; i < FPC.length; i++) {
1863: FPC[i] = vectorArray[i][index];
1864: }
1865:
        // calculate the first principal component scores
1867: double[] Sa = new double[n];
1868: for (int i = 0; i < Sa.length; i++) {
1869: Sa[i] = 0;
1870: for (int j = 0; j < k; j++) {
1871: Sa[i] += FPC[j] * P[i][j];
1872: }
1873: }
1874:
        // sort the attribute values by their first principal component scores
1876: double[] pCopy = new double[n];
1877: System.arraycopy(Sa, 0, pCopy, 0, n);
1878: String[] sortedValues = new String[n];
1879: Arrays.sort(Sa);
1880:
1881: for (int j = 0; j < n; j++) {
1882: sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
1883: pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
1884: }
1885:
        // for the attribute values whose class frequency is not 0
1887: String tempStr = "";
1888:
1889: for (int j = 0; j < nonEmpty - 1; j++) {
1890: currDist = new double[2][numClasses];
          if (tempStr.equals(""))
            tempStr = "(" + sortedValues[j] + ")";
          else
            tempStr += "|" + "(" + sortedValues[j] + ")";
1895: for (int i = 0; i < sortedIndices.length; i++) {
1896: Instance inst = data.instance(sortedIndices[i]);
1897: if (inst.isMissing(att)) {
1898: break;
1899: }
1900:
1901: if (tempStr.indexOf("("
1902: + att.value((int) inst.value(att)) + ")") != -1) {
1903: currDist[0][(int) inst.classValue()] += weights[i];
1904: } else
1905: currDist[1][(int) inst.classValue()] += weights[i];
1906: }
1907:
1908: double[][] tempDist = new double[2][numClasses];
          for (int kk = 0; kk < 2; kk++) {
            // copy rather than alias currDist (see the two-class case above)
            System.arraycopy(currDist[kk], 0, tempDist[kk], 0, tempDist[kk].length);
          }
1912:
1913: double[] tempProps = new double[2];
1914: for (int kk = 0; kk < 2; kk++) {
1915: tempProps[kk] = Utils.sum(tempDist[kk]);
1916: }
1917:
1918: if (Utils.sum(tempProps) != 0)
1919: Utils.normalize(tempProps);
1920:
1921: // split missing values
1922: int mstart = missingStart;
1923: while (mstart < sortedIndices.length) {
1924: Instance insta = data
1925: .instance(sortedIndices[mstart]);
1926: for (int jj = 0; jj < 2; jj++) {
1927: tempDist[jj][(int) insta.classValue()] += tempProps[jj]
1928: * weights[mstart];
1929: }
1930: mstart++;
1931: }
1932:
1933: double currGain;
1934: if (useGini)
1935: currGain = computeGiniGain(parentDist, tempDist);
1936: else
1937: currGain = computeInfoGain(parentDist, tempDist);
1938:
1939: if (currGain > bestGain) {
1940: bestGain = currGain;
1941: bestSplitString = tempStr;
1942: for (int jj = 0; jj < 2; jj++) {
1944: System.arraycopy(tempDist[jj], 0, dist[jj], 0,
1945: dist[jj].length);
1946: }
1947: }
1948: }
1949: }
1950:
1951: // Compute weights
1952: int attIndex = att.index();
1953: props[attIndex] = new double[2];
1954: for (int k = 0; k < 2; k++) {
1955: props[attIndex][k] = Utils.sum(dist[k]);
1956: }
1957: if (!(Utils.sum(props[attIndex]) > 0)) {
1958: for (int k = 0; k < props[attIndex].length; k++) {
1959: props[attIndex][k] = 1.0 / (double) props[attIndex].length;
1960: }
1961: } else {
1962: Utils.normalize(props[attIndex]);
1963: }
1964:
1965: // Compute subset weights
1966: subsetWeights[attIndex] = new double[2];
1967: for (int j = 0; j < 2; j++) {
1968: subsetWeights[attIndex][j] += Utils.sum(dist[j]);
1969: }
1970:
      // Then send the attribute values whose class frequency is 0 down the
      // more frequent branch
1973: for (int j = 0; j < empty; j++) {
1974: if (props[attIndex][0] >= props[attIndex][1]) {
        if (bestSplitString.equals(""))
          bestSplitString = "(" + emptyValues[j] + ")";
        else
          bestSplitString += "|" + "(" + emptyValues[j] + ")";
1979: }
1980: }
1981:
      // round the gain to 7 decimal places
      gains[attIndex] = Math.rint(bestGain * 10000000) / 10000000.0;
1984:
1985: dists[attIndex] = dist;
1986: return bestSplitString;
1987: }
1988:
1989: /**
1990: * Split data into two subsets and store sorted indices and weights for two
1991: * successor nodes.
1992: *
     * @param subsetIndices sorted indices of instances for each attribute for the two successor nodes
     * @param subsetWeights weights of instances for each attribute for the two successor nodes
     * @param att attribute the split is based on
     * @param splitPoint split point the split is based on, if att is numeric
     * @param splitStr split subset the split is based on, if att is nominal
     * @param sortedIndices sorted indices of the instances to be split
     * @param weights weights of the instances to be split
2000: * @param data training data
2001: * @throws Exception if something goes wrong
2002: */
2003: protected void splitData(int[][][] subsetIndices,
2004: double[][][] subsetWeights, Attribute att,
2005: double splitPoint, String splitStr, int[][] sortedIndices,
2006: double[][] weights, Instances data) throws Exception {
2007:
2008: int j;
2009: // For each attribute
2010: for (int i = 0; i < data.numAttributes(); i++) {
2011: if (i == data.classIndex())
2012: continue;
2013: int[] num = new int[2];
2014: for (int k = 0; k < 2; k++) {
2015: subsetIndices[k][i] = new int[sortedIndices[i].length];
2016: subsetWeights[k][i] = new double[weights[i].length];
2017: }
2018:
2019: for (j = 0; j < sortedIndices[i].length; j++) {
2020: Instance inst = data.instance(sortedIndices[i][j]);
2021: if (inst.isMissing(att)) {
2022: // Split instance up
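          // (fractional instances: the instance is sent down both branches,
          // its weight scaled by the branch proportions m_Props computed
          // from the non-missing instances)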
2023: for (int k = 0; k < 2; k++) {
2024: if (m_Props[k] > 0) {
2025: subsetIndices[k][i][num[k]] = sortedIndices[i][j];
2026: subsetWeights[k][i][num[k]] = m_Props[k]
2027: * weights[i][j];
2028: num[k]++;
2029: }
2030: }
2031: } else {
2032: int subset;
2033: if (att.isNumeric()) {
2034: subset = (inst.value(att) < splitPoint) ? 0 : 1;
2035: } else { // nominal attribute
            if (splitStr.indexOf("(" + att.value((int) inst.value(att)) + ")") != -1) {
              subset = 0;
            } else {
              subset = 1;
            }
2042: }
2043: subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
2044: subsetWeights[subset][i][num[subset]] = weights[i][j];
2045: num[subset]++;
2046: }
2047: }
2048:
2049: // Trim arrays
2050: for (int k = 0; k < 2; k++) {
2051: int[] copy = new int[num[k]];
2052: System.arraycopy(subsetIndices[k][i], 0, copy, 0,
2053: num[k]);
2054: subsetIndices[k][i] = copy;
2055: double[] copyWeights = new double[num[k]];
2056: System.arraycopy(subsetWeights[k][i], 0, copyWeights,
2057: 0, num[k]);
2058: subsetWeights[k][i] = copyWeights;
2059: }
2060: }
2061: }
2062:
2063: /**
   * Compute and return the Gini gain for given distributions of a node and its
2065: * successor nodes.
2066: *
2067: * @param parentDist class distributions of parent node
2068: * @param childDist class distributions of successor nodes
2069: * @return Gini gain computed
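   *
   * <p>Computed as<br/>
   * <code>giniGain = gini(parent) - wL/w * gini(left) - wR/w * gini(right)</code>,<br/>
   * where <code>w</code> is the total weight of the parent node and
   * <code>wL</code>, <code>wR</code> are the total weights of the two
   * successor nodes.</p>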
2070: */
2071: protected double computeGiniGain(double[] parentDist,
2072: double[][] childDist) {
2073: double totalWeight = Utils.sum(parentDist);
2074: if (totalWeight == 0)
2075: return 0;
2076:
2077: double leftWeight = Utils.sum(childDist[0]);
2078: double rightWeight = Utils.sum(childDist[1]);
2079:
2080: double parentGini = computeGini(parentDist, totalWeight);
2081: double leftGini = computeGini(childDist[0], leftWeight);
2082: double rightGini = computeGini(childDist[1], rightWeight);
2083:
2084: return parentGini - leftWeight / totalWeight * leftGini
2085: - rightWeight / totalWeight * rightGini;
2086: }
2087:
2088: /**
   * Compute and return the Gini index for a given distribution of a node.
2090: *
2091: * @param dist class distributions
   * @param total total weight of the class distribution
2093: * @return Gini index of the class distributions
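   *
   * <p>Computed as <code>gini = 1 - sum_i (dist[i] / total)^2</code>.</p>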
2094: */
2095: protected double computeGini(double[] dist, double total) {
2096: if (total == 0)
2097: return 0;
2098: double val = 0;
2099: for (int i = 0; i < dist.length; i++) {
2100: val += (dist[i] / total) * (dist[i] / total);
2101: }
2102: return 1 - val;
2103: }
2104:
2105: /**
2106: * Compute and return information gain for given distributions of a node
2107: * and its successor nodes.
2108: *
2109: * @param parentDist class distributions of parent node
2110: * @param childDist class distributions of successor nodes
2111: * @return information gain computed
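   *
   * <p>Computed as<br/>
   * <code>infoGain = entropy(parent) - wL/w * entropy(left) - wR/w * entropy(right)</code>,<br/>
   * where <code>w</code> is the total weight of the parent node and
   * <code>wL</code>, <code>wR</code> are the total weights of the two
   * successor nodes.</p>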
2112: */
2113: protected double computeInfoGain(double[] parentDist,
2114: double[][] childDist) {
2115: double totalWeight = Utils.sum(parentDist);
2116: if (totalWeight == 0)
2117: return 0;
2118:
2119: double leftWeight = Utils.sum(childDist[0]);
2120: double rightWeight = Utils.sum(childDist[1]);
2121:
2122: double parentInfo = computeEntropy(parentDist, totalWeight);
2123: double leftInfo = computeEntropy(childDist[0], leftWeight);
2124: double rightInfo = computeEntropy(childDist[1], rightWeight);
2125:
2126: return parentInfo - leftWeight / totalWeight * leftInfo
2127: - rightWeight / totalWeight * rightInfo;
2128: }
2129:
2130: /**
2131: * Compute and return entropy for a given distribution of a node.
2132: *
2133: * @param dist class distributions
   * @param total total weight of the class distribution
2135: * @return entropy of the class distributions
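   *
   * <p>Computed as
   * <code>entropy = - sum_i (dist[i] / total) * log2(dist[i] / total)</code>,
   * skipping terms with <code>dist[i] == 0</code>.</p>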
2136: */
2137: protected double computeEntropy(double[] dist, double total) {
2138: if (total == 0)
2139: return 0;
2140: double entropy = 0;
2141: for (int i = 0; i < dist.length; i++) {
2142: if (dist[i] != 0)
2143: entropy -= dist[i] / total
2144: * Utils.log2(dist[i] / total);
2145: }
2146: return entropy;
2147: }
2148:
2149: /**
   * Make this node a leaf node.
2151: *
2152: * @param data training data
2153: */
2154: protected void makeLeaf(Instances data) {
2155: m_Attribute = null;
2156: m_isLeaf = true;
2157: m_ClassValue = Utils.maxIndex(m_ClassProbs);
2158: m_ClassAttribute = data.classAttribute();
2159: }
2160:
2161: /**
2162: * Computes class probabilities for instance using the decision tree.
2163: *
   * @param instance the instance for which class probabilities are to be computed
2165: * @return the class probabilities for the given instance
2166: * @throws Exception if something goes wrong
2167: */
2168: public double[] distributionForInstance(Instance instance)
2169: throws Exception {
2170: if (!m_isLeaf) {
2171: // value of split attribute is missing
2172: if (instance.isMissing(m_Attribute)) {
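        // Fractional prediction: combine the class distributions returned by
        // all successors, weighted by the training proportions m_Props of the
        // corresponding branches.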
2173: double[] returnedDist = new double[m_ClassProbs.length];
2174:
2175: for (int i = 0; i < m_Successors.length; i++) {
2176: double[] help = m_Successors[i]
2177: .distributionForInstance(instance);
2178: if (help != null) {
2179: for (int j = 0; j < help.length; j++) {
2180: returnedDist[j] += m_Props[i] * help[j];
2181: }
2182: }
2183: }
2184: return returnedDist;
2185: }
2186:
      // split attribute is nominal
2188: else if (m_Attribute.isNominal()) {
2189: if (m_SplitString.indexOf("("
2190: + m_Attribute.value((int) instance
2191: .value(m_Attribute)) + ")") != -1)
2192: return m_Successors[0]
2193: .distributionForInstance(instance);
2194: else
2195: return m_Successors[1]
2196: .distributionForInstance(instance);
2197: }
2198:
2199: // split attribute is numeric
2200: else {
2201: if (instance.value(m_Attribute) < m_SplitValue)
2202: return m_Successors[0]
2203: .distributionForInstance(instance);
2204: else
2205: return m_Successors[1]
2206: .distributionForInstance(instance);
2207: }
2208: }
2209:
2210: // leaf node
2211: else
2212: return m_ClassProbs;
2213: }
2214:
2215: /**
   * Returns a textual description of the decision tree, built using the protected toString(int) method below.
2217: *
2218: * @return a textual description of the classifier
2219: */
2220: public String toString() {
2221: if ((m_Distribution == null) && (m_Successors == null)) {
2222: return "Best-First: No model built yet.";
2223: }
2224: return "Best-First Decision Tree\n" + toString(0) + "\n\n"
2225: + "Size of the Tree: " + numNodes() + "\n\n"
2226: + "Number of Leaf Nodes: " + numLeaves();
2227: }
2228:
2229: /**
2230: * Outputs a tree at a certain level.
2231: *
2232: * @param level the level at which the tree is to be printed
2233: * @return a tree at a certain level.
2234: */
2235: protected String toString(int level) {
2236: StringBuffer text = new StringBuffer();
2237: // if leaf nodes
2238: if (m_Attribute == null) {
2239: if (Instance.isMissingValue(m_ClassValue)) {
2240: text.append(": null");
2241: } else {
2242: double correctNum = Math.rint(m_Distribution[Utils
2243: .maxIndex(m_Distribution)] * 100) / 100.0;
2244: double wrongNum = Math
2245: .rint((Utils.sum(m_Distribution) - m_Distribution[Utils
2246: .maxIndex(m_Distribution)]) * 100) / 100.0;
2247: String str = "(" + correctNum + "/" + wrongNum + ")";
2248: text.append(": "
2249: + m_ClassAttribute.value((int) m_ClassValue)
2250: + str);
2251: }
2252: } else {
2253: for (int j = 0; j < 2; j++) {
2254: text.append("\n");
2255: for (int i = 0; i < level; i++) {
2256: text.append("| ");
2257: }
2258: if (j == 0) {
2259: if (m_Attribute.isNumeric())
2260: text.append(m_Attribute.name() + " < "
2261: + m_SplitValue);
2262: else
2263: text.append(m_Attribute.name() + "="
2264: + m_SplitString);
2265: } else {
2266: if (m_Attribute.isNumeric())
2267: text.append(m_Attribute.name() + " >= "
2268: + m_SplitValue);
2269: else
2270: text.append(m_Attribute.name() + "!="
2271: + m_SplitString);
2272: }
2273: text.append(m_Successors[j].toString(level + 1));
2274: }
2275: }
2276: return text.toString();
2277: }
2278:
2279: /**
2280: * Compute size of the tree.
2281: *
2282: * @return size of the tree
2283: */
2284: public int numNodes() {
2285: if (m_isLeaf) {
2286: return 1;
2287: } else {
2288: int size = 1;
2289: for (int i = 0; i < m_Successors.length; i++) {
2290: size += m_Successors[i].numNodes();
2291: }
2292: return size;
2293: }
2294: }
2295:
2296: /**
2297: * Compute number of leaf nodes.
2298: *
2299: * @return number of leaf nodes
2300: */
2301: public int numLeaves() {
2302: if (m_isLeaf)
2303: return 1;
2304: else {
2305: int size = 0;
2306: for (int i = 0; i < m_Successors.length; i++) {
2307: size += m_Successors[i].numLeaves();
2308: }
2309: return size;
2310: }
2311: }
2312:
2313: /**
2314: * Returns an enumeration describing the available options.
2315: *
2316: * @return an enumeration describing the available options.
2317: */
2318: public Enumeration listOptions() {
2319: Vector result;
2320: Enumeration en;
2321:
2322: result = new Vector();
2323:
    en = super.listOptions();
2325: while (en.hasMoreElements())
2326: result.addElement(en.nextElement());
2327:
2328: result.addElement(new Option("\tThe pruning strategy.\n"
2329: + "\t(default: "
2330: + new SelectedTag(PRUNING_POSTPRUNING, TAGS_PRUNING)
2331: + ")", "P", 1, "-P " + Tag.toOptionList(TAGS_PRUNING)));
2332:
2333: result.addElement(new Option(
2334: "\tThe minimal number of instances at the terminal nodes.\n"
2335: + "\t(default 2)", "M", 1, "-M <min no>"));
2336:
    result.addElement(new Option(
        "\tThe number of folds used in the pruning.\n"
        + "\t(default 5)", "N", 1, "-N <num folds>"));
2340:
    result.addElement(new Option(
        "\tDon't use heuristic search for nominal attributes in multi-class\n"
        + "\tproblems (default yes).\n", "H", 0, "-H"));
2344:
    result.addElement(new Option(
        "\tDon't use Gini index for splitting (default yes);\n"
        + "\tif not, information gain is used.", "G", 0, "-G"));
2350:
    result.addElement(new Option(
        "\tDon't use error rate in internal cross-validation (default yes),\n"
        + "\tbut root mean squared error.", "R", 0, "-R"));
2355:
2356: result.addElement(new Option(
2357: "\tUse the 1 SE rule to make pruning decision.\n"
2358: + "\t(default no).", "A", 0, "-A"));
2359:
    result.addElement(new Option(
        "\tPercentage of training data size (0-1]\n"
        + "\t(default 1).", "C", 1, "-C <percentage>"));
2363:
2364: return result.elements();
2365: }
2366:
2367: /**
2368: * Parses the options for this object. <p/>
2369: *
2370: <!-- options-start -->
2371: * Valid options are: <p/>
2372: *
2373: * <pre> -S <num>
2374: * Random number seed.
2375: * (default 1)</pre>
2376: *
2377: * <pre> -D
2378: * If set, classifier is run in debug mode and
2379: * may output additional info to the console</pre>
2380: *
2381: * <pre> -P <UNPRUNED|POSTPRUNED|PREPRUNED>
2382: * The pruning strategy.
2383: * (default: POSTPRUNED)</pre>
2384: *
2385: * <pre> -M <min no>
2386: * The minimal number of instances at the terminal nodes.
2387: * (default 2)</pre>
2388: *
2389: * <pre> -N <num folds>
2390: * The number of folds used in the pruning.
2391: * (default 5)</pre>
2392: *
2393: * <pre> -H
2394: * Don't use heuristic search for nominal attributes in multi-class
   * problems (default yes).
2396: * </pre>
2397: *
2398: * <pre> -G
   * Don't use Gini index for splitting (default yes);
   * if not, information gain is used.</pre>
2401: *
2402: * <pre> -R
2403: * Don't use error rate in internal cross-validation (default yes),
2404: * but root mean squared error.</pre>
2405: *
2406: * <pre> -A
2407: * Use the 1 SE rule to make pruning decision.
2408: * (default no).</pre>
2409: *
   * <pre> -C <percentage>
2411: * Percentage of training data size (0-1]
2412: * (default 1).</pre>
2413: *
2414: <!-- options-end -->
2415: *
2416: * @param options the options to use
2417: * @throws Exception if setting of options fails
2418: */
2419: public void setOptions(String[] options) throws Exception {
2420: String tmpStr;
2421:
    super.setOptions(options);
2423:
2424: tmpStr = Utils.getOption('M', options);
2425: if (tmpStr.length() != 0)
2426: setMinNumObj(Integer.parseInt(tmpStr));
2427: else
2428: setMinNumObj(2);
2429:
2430: tmpStr = Utils.getOption('N', options);
2431: if (tmpStr.length() != 0)
2432: setNumFoldsPruning(Integer.parseInt(tmpStr));
2433: else
2434: setNumFoldsPruning(5);
2435:
2436: tmpStr = Utils.getOption('C', options);
2437: if (tmpStr.length() != 0)
2438: setSizePer(Double.parseDouble(tmpStr));
2439: else
2440: setSizePer(1);
2441:
2442: tmpStr = Utils.getOption('P', options);
2443: if (tmpStr.length() != 0)
2444: setPruningStrategy(new SelectedTag(tmpStr, TAGS_PRUNING));
2445: else
2446: setPruningStrategy(new SelectedTag(PRUNING_POSTPRUNING,
2447: TAGS_PRUNING));
2448:
2449: setHeuristic(!Utils.getFlag('H', options));
2450:
2451: setUseGini(!Utils.getFlag('G', options));
2452:
2453: setUseErrorRate(!Utils.getFlag('R', options));
2454:
2455: setUseOneSE(Utils.getFlag('A', options));
2456: }
2457:
2458: /**
2459: * Gets the current settings of the Classifier.
2460: *
2461: * @return the current settings of the Classifier
2462: */
2463: public String[] getOptions() {
2464: int i;
2465: Vector result;
2466: String[] options;
2467:
2468: result = new Vector();
2469:
    options = super.getOptions();
2471: for (i = 0; i < options.length; i++)
2472: result.add(options[i]);
2473:
2474: result.add("-M");
2475: result.add("" + getMinNumObj());
2476:
2477: result.add("-N");
2478: result.add("" + getNumFoldsPruning());
2479:
2480: if (!getHeuristic())
2481: result.add("-H");
2482:
2483: if (!getUseGini())
2484: result.add("-G");
2485:
2486: if (!getUseErrorRate())
2487: result.add("-R");
2488:
2489: if (getUseOneSE())
2490: result.add("-A");
2491:
2492: result.add("-C");
2493: result.add("" + getSizePer());
2494:
2495: result.add("-P");
2496: result.add("" + getPruningStrategy());
2497:
2498: return (String[]) result.toArray(new String[result.size()]);
2499: }
2500:
2501: /**
2502: * Return an enumeration of the measure names.
2503: *
2504: * @return an enumeration of the measure names
2505: */
2506: public Enumeration enumerateMeasures() {
2507: Vector result = new Vector();
2508:
2509: result.addElement("measureTreeSize");
2510:
2511: return result.elements();
2512: }
2513:
2514: /**
   * Return the size of the tree (number of nodes).
   *
   * @return the size of the tree
2518: */
2519: public double measureTreeSize() {
2520: return numNodes();
2521: }
2522:
2523: /**
2524: * Returns the value of the named measure
2525: *
2526: * @param additionalMeasureName the name of the measure to query for its value
2527: * @return the value of the named measure
2528: * @throws IllegalArgumentException if the named measure is not supported
2529: */
2530: public double getMeasure(String additionalMeasureName) {
2531: if (additionalMeasureName
2532: .compareToIgnoreCase("measureTreeSize") == 0) {
2533: return measureTreeSize();
2534: } else {
2535: throw new IllegalArgumentException(additionalMeasureName
2536: + " not supported (Best-First)");
2537: }
2538: }
2539:
2540: /**
2541: * Returns the tip text for this property
2542: *
2543: * @return tip text for this property suitable for
2544: * displaying in the explorer/experimenter gui
2545: */
2546: public String pruningStrategyTipText() {
2547: return "Sets the pruning strategy.";
2548: }
2549:
2550: /**
2551: * Sets the pruning strategy.
2552: *
2553: * @param value the strategy
2554: */
2555: public void setPruningStrategy(SelectedTag value) {
2556: if (value.getTags() == TAGS_PRUNING) {
2557: m_PruningStrategy = value.getSelectedTag().getID();
2558: }
2559: }
2560:
2561: /**
2562: * Gets the pruning strategy.
2563: *
2564: * @return the current strategy.
2565: */
2566: public SelectedTag getPruningStrategy() {
2567: return new SelectedTag(m_PruningStrategy, TAGS_PRUNING);
2568: }
2569:
2570: /**
2571: * Returns the tip text for this property
2572: *
2573: * @return tip text for this property suitable for
2574: * displaying in the explorer/experimenter gui
2575: */
2576: public String minNumObjTipText() {
2577: return "Set minimal number of instances at the terminal nodes.";
2578: }
2579:
2580: /**
2581: * Set minimal number of instances at the terminal nodes.
2582: *
2583: * @param value minimal number of instances at the terminal nodes
2584: */
2585: public void setMinNumObj(int value) {
2586: m_minNumObj = value;
2587: }
2588:
2589: /**
2590: * Get minimal number of instances at the terminal nodes.
2591: *
2592: * @return minimal number of instances at the terminal nodes
2593: */
2594: public int getMinNumObj() {
2595: return m_minNumObj;
2596: }
2597:
2598: /**
2599: * Returns the tip text for this property
2600: *
2601: * @return tip text for this property suitable for
2602: * displaying in the explorer/experimenter gui
2603: */
2604: public String numFoldsPruningTipText() {
2605: return "Number of folds in internal cross-validation.";
2606: }
2607:
2608: /**
2609: * Set number of folds in internal cross-validation.
2610: *
2611: * @param value the number of folds
2612: */
2613: public void setNumFoldsPruning(int value) {
2614: m_numFoldsPruning = value;
2615: }
2616:
2617: /**
   * Get number of folds in internal cross-validation.
2619: *
2620: * @return number of folds in internal cross-validation
2621: */
2622: public int getNumFoldsPruning() {
2623: return m_numFoldsPruning;
2624: }
2625:
2626: /**
2627: * Returns the tip text for this property
2628: *
2629: * @return tip text for this property suitable for
2630: * displaying in the explorer/experimenter gui.
2631: */
2632: public String heuristicTipText() {
2633: return "If heuristic search is used for binary split for nominal attributes.";
2634: }
2635:
2636: /**
   * Set whether to use heuristic search for nominal attributes in multi-class problems.
   *
   * @param value whether to use heuristic search for nominal attributes in
   * multi-class problems
2641: */
2642: public void setHeuristic(boolean value) {
2643: m_Heuristic = value;
2644: }
2645:
2646: /**
   * Get whether heuristic search is used for nominal attributes in multi-class problems.
   *
   * @return whether heuristic search is used for nominal attributes in
   * multi-class problems
2651: */
2652: public boolean getHeuristic() {
2653: return m_Heuristic;
2654: }
2655:
2656: /**
2657: * Returns the tip text for this property
2658: *
2659: * @return tip text for this property suitable for
2660: * displaying in the explorer/experimenter gui.
2661: */
2662: public String useGiniTipText() {
2663: return "If true the Gini index is used for splitting criterion, otherwise the information is used.";
2664: }
2665:
2666: /**
   * Set whether to use the Gini index as the splitting criterion.
   *
   * @param value whether to use the Gini index as the splitting criterion
2670: */
2671: public void setUseGini(boolean value) {
2672: m_UseGini = value;
2673: }
2674:
2675: /**
   * Get whether the Gini index is used as the splitting criterion.
   *
   * @return whether the Gini index is used as the splitting criterion
2679: */
2680: public boolean getUseGini() {
2681: return m_UseGini;
2682: }
2683:
2684: /**
2685: * Returns the tip text for this property
2686: *
2687: * @return tip text for this property suitable for
2688: * displaying in the explorer/experimenter gui.
2689: */
2690: public String useErrorRateTipText() {
2691: return "If error rate is used as error estimate. if not, root mean squared error is used.";
2692: }
2693:
2694: /**
   * Set whether to use the error rate in internal cross-validation.
   *
   * @param value whether to use the error rate in internal cross-validation
2698: */
2699: public void setUseErrorRate(boolean value) {
2700: m_UseErrorRate = value;
2701: }
2702:
2703: /**
   * Get whether the error rate is used in internal cross-validation.
   *
   * @return whether the error rate is used in internal cross-validation
2707: */
2708: public boolean getUseErrorRate() {
2709: return m_UseErrorRate;
2710: }
2711:
2712: /**
2713: * Returns the tip text for this property
2714: *
2715: * @return tip text for this property suitable for
2716: * displaying in the explorer/experimenter gui.
2717: */
2718: public String useOneSETipText() {
2719: return "Use the 1SE rule to make pruning decision.";
2720: }
2721:
2722: /**
   * Set whether to use the 1SE rule to choose the final model.
   *
   * @param value whether to use the 1SE rule to choose the final model
2726: */
2727: public void setUseOneSE(boolean value) {
2728: m_UseOneSE = value;
2729: }
2730:
2731: /**
   * Get whether the 1SE rule is used to choose the final model.
   *
   * @return whether the 1SE rule is used to choose the final model
2735: */
2736: public boolean getUseOneSE() {
2737: return m_UseOneSE;
2738: }
2739:
2740: /**
2741: * Returns the tip text for this property
2742: *
2743: * @return tip text for this property suitable for
2744: * displaying in the explorer/experimenter gui.
2745: */
2746: public String sizePerTipText() {
2747: return "The percentage of the training set size (0-1, 0 not included).";
2748: }
2749:
2750: /**
   * Set the training set size as a percentage (0-1] of the original training set.
   *
   * @param value the training set size percentage
2754: */
2755: public void setSizePer(double value) {
2756: if ((value <= 0) || (value > 1))
2757: System.err
2758: .println("The percentage of the training set size must be in range 0 to 1 "
2759: + "(0 not included) - ignored!");
2760: else
2761: m_SizePer = value;
2762: }
2763:
2764: /**
   * Get the training set size percentage.
   *
   * @return the training set size percentage
2768: */
2769: public double getSizePer() {
2770: return m_SizePer;
2771: }
2772:
2773: /**
2774: * Main method.
2775: *
2776: * @param args the options for the classifier
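   *
   * <p>Example command line (the dataset path is hypothetical):<br/>
   * <code>java weka.classifiers.trees.BFTree -t /path/to/data.arff -P POSTPRUNED -M 2</code></p>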
2777: */
2778: public static void main(String[] args) {
2779: runClassifier(new BFTree(), args);
2780: }
2781: }
|