Source Code Cross Referenced for BFTree.java in weka.classifiers.trees (Java Source Code / Java Documentation » Science » weka)

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * BFTree.java
 * Copyright (C) 2007 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.trees;

import weka.classifiers.Evaluation;
import weka.classifiers.RandomizableClassifier;
import weka.core.AdditionalMeasureProducer;
import weka.core.Attribute;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.Option;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.core.matrix.Matrix;

import java.util.Arrays;
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
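
/*
 * Usage sketch: BFTree is driven through the standard Weka classifier API,
 * so the options documented in the class javadoc below can also be set
 * programmatically. A minimal example, assuming an ARFF file (the name
 * "iris.arff" is a placeholder) whose last attribute is the class:
 *
 *   Instances data = new Instances(new java.io.BufferedReader(
 *       new java.io.FileReader("iris.arff")));
 *   data.setClassIndex(data.numAttributes() - 1);
 *   BFTree tree = new BFTree();
 *   tree.setOptions(weka.core.Utils.splitOptions("-M 2 -N 5 -A"));
 *   tree.buildClassifier(data);
 */
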
/**
 <!-- globalinfo-start -->
 * Class for building a best-first decision tree classifier. This class uses binary splits for both nominal and numeric attributes. For missing values, the method of 'fractional' instances is used.<br/>
 * <br/>
 * For more information, see:<br/>
 * <br/>
 * Haijian Shi (2007). Best-first decision tree learning. Hamilton, NZ.<br/>
 * <br/>
 * Jerome Friedman, Trevor Hastie, Robert Tibshirani (2000). Additive logistic regression: A statistical view of boosting. Annals of Statistics. 28(2):337-407.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;mastersthesis{Shi2007,
 *    address = {Hamilton, NZ},
 *    author = {Haijian Shi},
 *    note = {COMP594},
 *    school = {University of Waikato},
 *    title = {Best-first decision tree learning},
 *    year = {2007}
 * }
 * 
 * &#64;article{Friedman2000,
 *    author = {Jerome Friedman and Trevor Hastie and Robert Tibshirani},
 *    journal = {Annals of Statistics},
 *    number = {2},
 *    pages = {337-407},
 *    title = {Additive logistic regression: A statistical view of boosting},
 *    volume = {28},
 *    year = {2000},
 *    ISSN = {0090-5364}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 * 
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 * 
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 * 
 * <pre> -P &lt;UNPRUNED|POSTPRUNED|PREPRUNED&gt;
 *  The pruning strategy.
 *  (default: POSTPRUNED)</pre>
 * 
 * <pre> -M &lt;min no&gt;
 *  The minimal number of instances at the terminal nodes.
 *  (default 2)</pre>
 * 
 * <pre> -N &lt;num folds&gt;
 *  The number of folds used in the pruning.
 *  (default 5)</pre>
 * 
 * <pre> -H
 *  Don't use heuristic search for nominal attributes in multi-class
 *  problems (default: heuristic search is used).
 * </pre>
 * 
 * <pre> -G
 *  Don't use Gini index for splitting (default: Gini is used);
 *  if set, information gain is used instead.</pre>
 * 
 * <pre> -R
 *  Don't use error rate in internal cross-validation (default: error
 *  rate is used); if set, root mean squared error is used instead.</pre>
 * 
 * <pre> -A
 *  Use the 1 SE rule to make the pruning decision.
 *  (default no).</pre>
 * 
 * <pre> -C
 *  Percentage of training data size (0-1]
 *  (default 1).</pre>
 * 
 <!-- options-end -->
 *
 * @author Haijian Shi (hs69@cs.waikato.ac.nz)
 * @version $Revision: 1.2 $
 */
public class BFTree extends RandomizableClassifier implements
        AdditionalMeasureProducer, TechnicalInformationHandler {

    /** For serialization. */
    private static final long serialVersionUID = -7035607375962528217L;

    /** pruning strategy: un-pruned */
    public static final int PRUNING_UNPRUNED = 0;
    /** pruning strategy: post-pruning */
    public static final int PRUNING_POSTPRUNING = 1;
    /** pruning strategy: pre-pruning */
    public static final int PRUNING_PREPRUNING = 2;
    /** pruning strategy */
    public static final Tag[] TAGS_PRUNING = {
            new Tag(PRUNING_UNPRUNED, "unpruned", "Un-pruned"),
            new Tag(PRUNING_POSTPRUNING, "postpruned", "Post-pruning"),
            new Tag(PRUNING_PREPRUNING, "prepruned", "Pre-pruning") };

    /** The pruning strategy. */
    protected int m_PruningStrategy = PRUNING_POSTPRUNING;

    /** Successor nodes. */
    protected BFTree[] m_Successors;

    /** Attribute used for splitting. */
    protected Attribute m_Attribute;

    /** Split point (for numeric attributes). */
    protected double m_SplitValue;

    /** Split subset (for nominal attributes). */
    protected String m_SplitString;

    /** Class value for a node. */
    protected double m_ClassValue;

    /** Class attribute of a dataset. */
    protected Attribute m_ClassAttribute;

    /** Minimum number of instances at leaf nodes. */
    protected int m_minNumObj = 2;

    /** Number of folds for the pruning. */
    protected int m_numFoldsPruning = 5;

    /** Whether the node is a leaf node. */
    protected boolean m_isLeaf;

    /** Number of expansions. */
    protected static int m_Expansion;

    /** Fixed number of expansions (if no pruning method is used, its value is -1.
     *  Otherwise, its value is obtained from internal cross-validation). */
    protected int m_FixedExpansion = -1;

    /** Whether to use heuristic search for binary splits (default true). Note that even if its
     *  value is true, it is only used when the number of values of a nominal attribute is larger
     *  than 4. */
    protected boolean m_Heuristic = true;

    /** Whether to use the Gini index as the splitting criterion - the default (if not,
     *  information gain is used). */
    protected boolean m_UseGini = true;

    /** Whether to use the error rate in internal cross-validation to fix the number of
     *  expansions - the default (if not, root mean squared error is used). */
    protected boolean m_UseErrorRate = true;

    /** Whether to use the 1SE rule to make the pruning decision. */
    protected boolean m_UseOneSE = false;

    /** Class distributions. */
    protected double[] m_Distribution;

    /** Branch proportions. */
    protected double[] m_Props;

    /** Sorted indices. */
    protected int[][] m_SortedIndices;

    /** Sorted weights. */
    protected double[][] m_Weights;

    /** Distributions of each attribute for the two successor nodes. */
    protected double[][][] m_Dists;

    /** Class probabilities. */
    protected double[] m_ClassProbs;

    /** Total weights. */
    protected double m_TotalWeight;

    /** The training data size (0-1]. Default 1. */
    protected double m_SizePer = 1;

    /**
     * Returns a string describing the classifier.
     * 
     * @return 		a description suitable for displaying in the 
     * 			explorer/experimenter gui
     */
    public String globalInfo() {
        return "Class for building a best-first decision tree classifier. "
                + "This class uses binary splits for both nominal and numeric attributes. "
                + "For missing values, the method of 'fractional' instances is used.\n\n"
                + "For more information, see:\n\n"
                + getTechnicalInformation().toString();
    }

    /**
     * Returns an instance of a TechnicalInformation object, containing 
     * detailed information about the technical background of this class,
     * e.g., paper reference or book this class is based on.
     * 
     * @return the technical information about this class
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result;
        TechnicalInformation additional;

        result = new TechnicalInformation(Type.MASTERSTHESIS);
        result.setValue(Field.AUTHOR, "Haijian Shi");
        result.setValue(Field.YEAR, "2007");
        result.setValue(Field.TITLE, "Best-first decision tree learning");
        result.setValue(Field.SCHOOL, "University of Waikato");
        result.setValue(Field.ADDRESS, "Hamilton, NZ");
        result.setValue(Field.NOTE, "COMP594");

        additional = result.add(Type.ARTICLE);
        additional.setValue(Field.AUTHOR,
                "Jerome Friedman and Trevor Hastie and Robert Tibshirani");
        additional.setValue(Field.YEAR, "2000");
        additional.setValue(Field.TITLE,
                "Additive logistic regression: A statistical view of boosting");
        additional.setValue(Field.JOURNAL, "Annals of Statistics");
        additional.setValue(Field.VOLUME, "28");
        additional.setValue(Field.NUMBER, "2");
        additional.setValue(Field.PAGES, "337-407");
        additional.setValue(Field.ISSN, "0090-5364");

        return result;
    }

    /**
     * Returns default capabilities of the classifier.
     * 
     * @return 		the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();

        // attributes
        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capability.MISSING_VALUES);

        // class
        result.enable(Capability.NOMINAL_CLASS);

        return result;
    }
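
    /*
     * Capabilities sketch: buildClassifier() below calls
     * getCapabilities().testWithFail(data), which throws an exception on
     * unsupported input. A caller can check non-destructively first; a
     * minimal example, assuming an already loaded Instances object "data":
     *
     *   BFTree tree = new BFTree();
     *   if (tree.getCapabilities().test(data)) {
     *       tree.buildClassifier(data);
     *   }
     */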

    /**
     * Method for building a best-first decision tree classifier.
     *
     * @param data 	set of instances serving as training data
     * @throws Exception 	if decision tree cannot be built successfully
     */
    public void buildClassifier(Instances data) throws Exception {

        getCapabilities().testWithFail(data);
        data = new Instances(data);
        data.deleteWithMissingClass();

        // build an unpruned tree
        if (m_PruningStrategy == PRUNING_UNPRUNED) {

            // calculate sorted indices, weights and initial class probabilities
            int[][] sortedIndices = new int[data.numAttributes()][0];
            double[][] weights = new double[data.numAttributes()][0];
            double[] classProbs = new double[data.numClasses()];
            double totalWeight = computeSortedInfo(data, sortedIndices,
                    weights, classProbs);

            // Compute the information of the best split for this node (including split
            // attribute, split value and gini gain (or information gain)). At the same
            // time, compute the variables dists, props and totalSubsetWeights.
            double[][][] dists = new double[data.numAttributes()][2][data
                    .numClasses()];
            double[][] props = new double[data.numAttributes()][2];
            double[][] totalSubsetWeights = new double[data.numAttributes()][2];
            FastVector nodeInfo = computeSplitInfo(this, data, sortedIndices,
                    weights, dists, props, totalSubsetWeights, m_Heuristic,
                    m_UseGini);

            // add the node (with all split info) into BestFirstElements
            FastVector BestFirstElements = new FastVector();
            BestFirstElements.addElement(nodeInfo);

            // Make the best-first decision tree.
            int attIndex = ((Attribute) nodeInfo.elementAt(1)).index();
            m_Expansion = 0;
            makeTree(BestFirstElements, data, sortedIndices, weights, dists,
                    classProbs, totalWeight, props[attIndex], m_minNumObj,
                    m_Heuristic, m_UseGini, m_FixedExpansion);

            return;
        }

        // the following code is for the pre-pruning and post-pruning methods

        // Compute train data, test data, sorted indices, sorted weights, total weights,
        // class probabilities, class distributions, branch proportions and total subset
        // weights for the root node of each fold for pre-pruning and post-pruning.
        int expansion = 0;

        Random random = new Random(m_Seed);
        Instances cvData = new Instances(data);
        cvData.randomize(random);
        // note: this keeps (int)(numInstances * m_SizePer) - 1 instances
        cvData = new Instances(cvData, 0,
                (int) (cvData.numInstances() * m_SizePer) - 1);
        cvData.stratify(m_numFoldsPruning);

        Instances[] train = new Instances[m_numFoldsPruning];
        Instances[] test = new Instances[m_numFoldsPruning];
        FastVector[] parallelBFElements = new FastVector[m_numFoldsPruning];
        BFTree[] m_roots = new BFTree[m_numFoldsPruning];

        int[][][] sortedIndices = new int[m_numFoldsPruning][data
                .numAttributes()][0];
        double[][][] weights = new double[m_numFoldsPruning][data
                .numAttributes()][0];
        double[][] classProbs = new double[m_numFoldsPruning][data
                .numClasses()];
        double[] totalWeight = new double[m_numFoldsPruning];

        double[][][][] dists = new double[m_numFoldsPruning][data
                .numAttributes()][2][data.numClasses()];
        double[][][] props = new double[m_numFoldsPruning][data
                .numAttributes()][2];
        double[][][] totalSubsetWeights = new double[m_numFoldsPruning][data
                .numAttributes()][2];
        FastVector[] nodeInfo = new FastVector[m_numFoldsPruning];

        for (int i = 0; i < m_numFoldsPruning; i++) {
            train[i] = cvData.trainCV(m_numFoldsPruning, i);
            test[i] = cvData.testCV(m_numFoldsPruning, i);
            parallelBFElements[i] = new FastVector();
            m_roots[i] = new BFTree();

            // calculate sorted indices, weights, initial class counts and the total
            // weight for each fold's training data
            totalWeight[i] = computeSortedInfo(train[i], sortedIndices[i],
                    weights[i], classProbs[i]);

            // compute the information of the best split for this node (including split
            // attribute, split value and gini gain (or information gain)) in this fold
            nodeInfo[i] = computeSplitInfo(m_roots[i], train[i],
                    sortedIndices[i], weights[i], dists[i], props[i],
                    totalSubsetWeights[i], m_Heuristic, m_UseGini);

            // compute information for the root nodes

            int attIndex = ((Attribute) nodeInfo[i].elementAt(1)).index();

            m_roots[i].m_SortedIndices = new int[sortedIndices[i].length][0];
            m_roots[i].m_Weights = new double[weights[i].length][0];
            m_roots[i].m_Dists = new double[dists[i].length][0][0];
            m_roots[i].m_ClassProbs = new double[classProbs[i].length];
            m_roots[i].m_Distribution = new double[classProbs[i].length];
            m_roots[i].m_Props = new double[2];

            for (int j = 0; j < m_roots[i].m_SortedIndices.length; j++) {
                m_roots[i].m_SortedIndices[j] = sortedIndices[i][j];
                m_roots[i].m_Weights[j] = weights[i][j];
                m_roots[i].m_Dists[j] = dists[i][j];
            }

            System.arraycopy(classProbs[i], 0, m_roots[i].m_ClassProbs, 0,
                    classProbs[i].length);
            if (Utils.sum(m_roots[i].m_ClassProbs) != 0)
                Utils.normalize(m_roots[i].m_ClassProbs);

            System.arraycopy(classProbs[i], 0, m_roots[i].m_Distribution, 0,
                    classProbs[i].length);
            System.arraycopy(props[i][attIndex], 0, m_roots[i].m_Props, 0,
                    props[i][attIndex].length);

            m_roots[i].m_TotalWeight = totalWeight[i];

            parallelBFElements[i].addElement(nodeInfo[i]);
        }
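
        /*
         * Fold semantics: trainCV(k, i) and testCV(k, i) are the standard
         * weka.core.Instances methods, so with k = m_numFoldsPruning each
         * test[i] holds roughly 1/k of cvData and train[i] the remaining
         * (k-1)/k, with class proportions preserved by the stratify() call
         * above. For example, with the default of 5 folds on 100 instances,
         * each fold trains on about 80 instances and tests on about 20.
         */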

        // build a pre-pruned tree
        if (m_PruningStrategy == PRUNING_PREPRUNING) {

            double previousError = Double.MAX_VALUE;
            double currentError = previousError;
            double minError = Double.MAX_VALUE;
            int minExpansion = 0;
            FastVector errorList = new FastVector();
            while (true) {
                // compute the average error
                double expansionError = 0;
                int count = 0;

                for (int i = 0; i < m_numFoldsPruning; i++) {
                    Evaluation eval;

                    // calculate the error rate if there is only a root node
                    if (expansion == 0) {
                        m_roots[i].m_isLeaf = true;
                        eval = new Evaluation(test[i]);
                        eval.evaluateModel(m_roots[i], test[i]);
                        if (m_UseErrorRate)
                            expansionError += eval.errorRate();
                        else
                            expansionError += eval.rootMeanSquaredError();
                        count++;
                    }

                    // make tree - expand one node at a time
                    else {
                        if (m_roots[i] == null)
                            continue; // if the tree cannot be expanded, go to the next fold
                        m_roots[i].m_isLeaf = false;
                        BFTree nodeToSplit = (BFTree) (((FastVector) (parallelBFElements[i]
                                .elementAt(0))).elementAt(0));
                        if (!m_roots[i].makeTree(parallelBFElements[i],
                                m_roots[i], train[i],
                                nodeToSplit.m_SortedIndices,
                                nodeToSplit.m_Weights, nodeToSplit.m_Dists,
                                nodeToSplit.m_ClassProbs,
                                nodeToSplit.m_TotalWeight,
                                nodeToSplit.m_Props, m_minNumObj,
                                m_Heuristic, m_UseGini)) {
                            m_roots[i] = null; // cannot be expanded
                            continue;
                        }
                        eval = new Evaluation(test[i]);
                        eval.evaluateModel(m_roots[i], test[i]);
                        if (m_UseErrorRate)
                            expansionError += eval.errorRate();
                        else
                            expansionError += eval.rootMeanSquaredError();
                        count++;
                    }
                }

                // no tree can be expanded any more
                if (count == 0)
                    break;

                expansionError /= count;
                errorList.addElement(new Double(expansionError));
                currentError = expansionError;

                if (!m_UseOneSE) {
                    if (currentError > previousError)
                        break;
                }

                else {
                    if (expansionError < minError) {
                        minError = expansionError;
                        minExpansion = expansion;
                    }

                    if (currentError > previousError) {
                        double oneSE = Math.sqrt(minError * (1 - minError)
                                / data.numInstances());
                        if (currentError > minError + oneSE) {
                            break;
                        }
                    }
                }

                expansion++;
                previousError = currentError;
            }

            if (!m_UseOneSE)
                expansion = expansion - 1;
            else {
                double oneSE = Math.sqrt(minError * (1 - minError)
                        / data.numInstances());
                for (int i = 0; i < errorList.size(); i++) {
                    double error = ((Double) (errorList.elementAt(i)))
                            .doubleValue();
                    if (error <= minError + oneSE) { // && counts[i]>=m_numFoldsPruning/2) {
                        expansion = i;
                        break;
                    }
                }
            }
        }
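
        /*
         * The 1 SE rule, worked through: the code treats minError as a
         * binomial proportion over the data.numInstances() evaluated
         * instances, so its standard error is
         *
         *   oneSE = sqrt(minError * (1 - minError) / n)
         *
         * For example, with minError = 0.20 and n = 100,
         * oneSE = sqrt(0.20 * 0.80 / 100) = 0.04, and the smallest expansion
         * whose average error is at most 0.20 + 0.04 = 0.24 is chosen.
         */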

        // build a post-pruned tree
        else {
            FastVector[] modelError = new FastVector[m_numFoldsPruning];

            // calculate the error of each expansion for each fold
            for (int i = 0; i < m_numFoldsPruning; i++) {
                modelError[i] = new FastVector();

                m_roots[i].m_isLeaf = true;
                Evaluation eval = new Evaluation(test[i]);
                eval.evaluateModel(m_roots[i], test[i]);
                double error;
                if (m_UseErrorRate)
                    error = eval.errorRate();
                else
                    error = eval.rootMeanSquaredError();
                modelError[i].addElement(new Double(error));

                m_roots[i].m_isLeaf = false;
                BFTree nodeToSplit = (BFTree) (((FastVector) (parallelBFElements[i]
                        .elementAt(0))).elementAt(0));

                m_roots[i].makeTree(parallelBFElements[i], m_roots[i],
                        train[i], test[i], modelError[i],
                        nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights,
                        nodeToSplit.m_Dists, nodeToSplit.m_ClassProbs,
                        nodeToSplit.m_TotalWeight, nodeToSplit.m_Props,
                        m_minNumObj, m_Heuristic, m_UseGini, m_UseErrorRate);
                m_roots[i] = null;
            }

            // find the expansion with the minimal error rate
            double minError = Double.MAX_VALUE;

            int maxExpansion = modelError[0].size();
            for (int i = 1; i < modelError.length; i++) {
                if (modelError[i].size() > maxExpansion)
                    maxExpansion = modelError[i].size();
            }

            double[] error = new double[maxExpansion];
            int[] counts = new int[maxExpansion];
            for (int i = 0; i < maxExpansion; i++) {
                counts[i] = 0;
                error[i] = 0;
                for (int j = 0; j < m_numFoldsPruning; j++) {
                    if (i < modelError[j].size()) {
                        error[i] += ((Double) modelError[j].elementAt(i))
                                .doubleValue();
                        counts[i]++;
                    }
                }
                error[i] = error[i] / counts[i]; // average error for each expansion

                if (error[i] < minError) { // && counts[i]>=m_numFoldsPruning/2) {
                    minError = error[i];
                    expansion = i;
                }
            }

            // the 1 SE rule was chosen
            if (m_UseOneSE) {
                double oneSE = Math.sqrt(minError * (1 - minError)
                        / data.numInstances());
                for (int i = 0; i < maxExpansion; i++) {
                    if (error[i] <= minError + oneSE) { // && counts[i]>=m_numFoldsPruning/2) {
                        expansion = i;
                        break;
                    }
                }
            }
        }

        // make the tree on all the data based on the expansion calculated
        // from the internal cross-validation

        // calculate sorted indices, weights and initial class counts
        int[][] prune_sortedIndices = new int[data.numAttributes()][0];
        double[][] prune_weights = new double[data.numAttributes()][0];
        double[] prune_classProbs = new double[data.numClasses()];
        double prune_totalWeight = computeSortedInfo(data,
                prune_sortedIndices, prune_weights, prune_classProbs);

        // compute the information of the best split for this node (including split
        // attribute, split value and gini gain)
        double[][][] prune_dists = new double[data.numAttributes()][2][data
                .numClasses()];
        double[][] prune_props = new double[data.numAttributes()][2];
        double[][] prune_totalSubsetWeights = new double[data
                .numAttributes()][2];
        FastVector prune_nodeInfo = computeSplitInfo(this, data,
                prune_sortedIndices, prune_weights, prune_dists, prune_props,
                prune_totalSubsetWeights, m_Heuristic, m_UseGini);

        // add the root node (with its split info) to BestFirstElements
        FastVector BestFirstElements = new FastVector();
        BestFirstElements.addElement(prune_nodeInfo);

        int attIndex = ((Attribute) prune_nodeInfo.elementAt(1)).index();
        m_Expansion = 0;
        makeTree(BestFirstElements, data, prune_sortedIndices, prune_weights,
                prune_dists, prune_classProbs, prune_totalWeight,
                prune_props[attIndex], m_minNumObj, m_Heuristic, m_UseGini,
                expansion);
    }
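
    /*
     * Evaluation sketch: once built, the tree behaves like any other Weka
     * classifier. A minimal cross-validation run over a loaded dataset
     * "data" (a placeholder name) using the standard Evaluation API:
     *
     *   Evaluation eval = new Evaluation(data);
     *   eval.crossValidateModel(new BFTree(), data, 10, new Random(1));
     *   System.out.println(eval.toSummaryString());
     */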

    /**
     * Recursively builds a best-first decision tree.
     * Method for building a best-first tree for a given number of expansions.
     * A preExpansion of -1 means that no expansion is specified (just for a
     * tree without any pruning method). The pre-pruning and post-pruning methods also
     * use this method to build the final tree on all training data based on the
     * expansion calculated from internal cross-validation.
     *
     * @param BestFirstElements 	list to store BFTree nodes
     * @param data 		training data
     * @param sortedIndices 	sorted indices of the instances
     * @param weights 		weights of the instances
     * @param dists 		class distributions for each attribute
     * @param classProbs 		class probabilities of this node
     * @param totalWeight 	total weight of this node (note if the node 
     * 				cannot split, this value is not calculated)
     * @param branchProps 	proportions of the two subbranches
     * @param minNumObj 		minimal number of instances at leaf nodes
     * @param useHeuristic 	whether to use heuristic search for nominal attributes 
     * 				in multi-class problems
     * @param useGini 		whether to use the Gini index as the splitting criterion
     * @param preExpansion 	the number of expansions the tree is to be expanded by
     * @throws Exception 		if something goes wrong
     */
    protected void makeTree(FastVector BestFirstElements, Instances data,
            int[][] sortedIndices, double[][] weights, double[][][] dists,
            double[] classProbs, double totalWeight, double[] branchProps,
            int minNumObj, boolean useHeuristic, boolean useGini,
            int preExpansion) throws Exception {

        if (BestFirstElements.size() == 0)
            return;

        ///////////////////////////////////////////////////////////////////////
        // All information about the node to split (the first BestFirst object in
        // BestFirstElements)
        FastVector firstElement = (FastVector) BestFirstElements.elementAt(0);

        // split attribute
        Attribute att = (Attribute) firstElement.elementAt(1);

        // info of the split value or split string
        double splitValue = Double.NaN;
        String splitStr = null;
        if (att.isNumeric())
            splitValue = ((Double) firstElement.elementAt(2)).doubleValue();
        else {
            splitStr = ((String) firstElement.elementAt(2)).toString();
        }

        // the best gini gain or information gain of this node
        double gain = ((Double) firstElement.elementAt(3)).doubleValue();
        ///////////////////////////////////////////////////////////////////////

        if (m_ClassProbs == null) {
            m_SortedIndices = new int[sortedIndices.length][0];
            m_Weights = new double[weights.length][0];
            m_Dists = new double[dists.length][0][0];
            m_ClassProbs = new double[classProbs.length];
            m_Distribution = new double[classProbs.length];
            m_Props = new double[2];

            for (int i = 0; i < m_SortedIndices.length; i++) {
                m_SortedIndices[i] = sortedIndices[i];
                m_Weights[i] = weights[i];
                m_Dists[i] = dists[i];
            }

            System.arraycopy(classProbs, 0, m_ClassProbs, 0,
                    classProbs.length);
            System.arraycopy(classProbs, 0, m_Distribution, 0,
                    classProbs.length);
            System.arraycopy(branchProps, 0, m_Props, 0, m_Props.length);
            m_TotalWeight = totalWeight;
            if (Utils.sum(m_ClassProbs) != 0)
                Utils.normalize(m_ClassProbs);
        }

        // If there is not enough data or this node cannot be split, find the next node to split.
        if (totalWeight < 2 * minNumObj || branchProps[0] == 0
                || branchProps[1] == 0) {
            // remove the first element
            BestFirstElements.removeElementAt(0);

            makeLeaf(data);
            if (BestFirstElements.size() != 0) {
                FastVector nextSplitElement = (FastVector) BestFirstElements
                        .elementAt(0);
                BFTree nextSplitNode = (BFTree) nextSplitElement.elementAt(0);
                nextSplitNode.makeTree(BestFirstElements, data,
                        nextSplitNode.m_SortedIndices,
                        nextSplitNode.m_Weights, nextSplitNode.m_Dists,
                        nextSplitNode.m_ClassProbs,
                        nextSplitNode.m_TotalWeight, nextSplitNode.m_Props,
                        minNumObj, useHeuristic, useGini, preExpansion);
            }
            return;
        }

        // If the gini gain or information gain is 0, make all nodes in BestFirstElements
        // leaf nodes, because these nodes are sorted in descending order of gini gain or
        // information gain (namely, the gain of every node in BestFirstElements is 0).
        if (gain == 0 || preExpansion == m_Expansion) {
            for (int i = 0; i < BestFirstElements.size(); i++) {
                FastVector element = (FastVector) BestFirstElements
                        .elementAt(i);
                BFTree node = (BFTree) element.elementAt(0);
                node.makeLeaf(data);
            }
            BestFirstElements.removeAllElements();
        }

        // gain is not 0
        else {
            // remove the first element
            BestFirstElements.removeElementAt(0);

            m_Attribute = att;
            if (m_Attribute.isNumeric())
                m_SplitValue = splitValue;
            else
                m_SplitString = splitStr;

            int[][][] subsetIndices = new int[2][data.numAttributes()][0];
            double[][][] subsetWeights = new double[2][data
                    .numAttributes()][0];

            splitData(subsetIndices, subsetWeights, m_Attribute,
                    m_SplitValue, m_SplitString, sortedIndices, weights,
                    data);

            // If the split would generate a node whose total weight is less than
            // m_minNumObj, do not split.
            int attIndex = att.index();
            if (subsetIndices[0][attIndex].length < minNumObj
                    || subsetIndices[1][attIndex].length < minNumObj) {
                makeLeaf(data);
            }

            // split the node
            else {
                m_isLeaf = false;
                m_Attribute = att;

                // if an expansion is specified (i.e. a pruning method is used)
                if ((m_PruningStrategy == PRUNING_PREPRUNING)
                        || (m_PruningStrategy == PRUNING_POSTPRUNING)
                        || (preExpansion != -1))
                    m_Expansion++;

                makeSuccessors(BestFirstElements, data, subsetIndices,
                        subsetWeights, dists, att, useHeuristic, useGini);
            }

            // choose the next node to split
            if (BestFirstElements.size() != 0) {
                FastVector nextSplitElement = (FastVector) BestFirstElements
                        .elementAt(0);
                BFTree nextSplitNode = (BFTree) nextSplitElement.elementAt(0);
                nextSplitNode.makeTree(BestFirstElements, data,
                        nextSplitNode.m_SortedIndices,
                        nextSplitNode.m_Weights, nextSplitNode.m_Dists,
                        nextSplitNode.m_ClassProbs,
                        nextSplitNode.m_TotalWeight, nextSplitNode.m_Props,
                        minNumObj, useHeuristic, useGini, preExpansion);
            }

        }
    }
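
    /*
     * Element layout: as the makeTree variants above and below read it, each
     * entry of BestFirstElements is itself a FastVector of the form
     *
     *   [0] BFTree        - the node to split
     *   [1] Attribute     - the split attribute
     *   [2] Double/String - split value (numeric) or split subset (nominal)
     *   [3] Double        - the node's gini gain (or information gain)
     *
     * and the list is kept sorted in descending order of that gain, so
     * elementAt(0) is always the most promising node to expand next.
     */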

    /**
     * This method is used to find the number of expansions based on internal 
     * cross-validation for pre-pruning only. It expands the first BestFirst 
     * node in BestFirstElements if it is expansible, otherwise it looks 
     * for the next expansible node. If it finds an expansible node, it expands
     * that node and then returns true. (Note that it expands just one node at a time.)
     *
     * @param BestFirstElements 	list to store BFTree nodes
     * @param root 		root node of the tree in each fold
     * @param train 		training data
     * @param sortedIndices 	sorted indices of the instances
     * @param weights 		weights of the instances
     * @param dists 		class distributions for each attribute
     * @param classProbs 		class probabilities of this node
     * @param totalWeight 	total weight of this node (note if the node 
     * 				cannot split, this value is not calculated)
     * @param branchProps 	proportions of the two subbranches
     * @param minNumObj 	minimal number of instances at leaf nodes
     * @param useHeuristic 	whether to use heuristic search for nominal attributes 
     * 				in multi-class problems
     * @param useGini 		whether to use the Gini index as the splitting criterion
     * @return true 		if the expansion succeeds, otherwise false 
     * 				(no node in BestFirstElements can be 
     * 				expanded)
     * @throws Exception 		if something goes wrong
     */
    protected boolean makeTree(FastVector BestFirstElements, BFTree root,
            Instances train, int[][] sortedIndices, double[][] weights,
            double[][][] dists, double[] classProbs, double totalWeight,
            double[] branchProps, int minNumObj, boolean useHeuristic,
            boolean useGini) throws Exception {

        if (BestFirstElements.size() == 0)
            return false;

        ///////////////////////////////////////////////////////////////////////
        // All information about the node to split (the first BestFirst object in
        // BestFirstElements)
        FastVector firstElement = (FastVector) BestFirstElements.elementAt(0);

        // node to split
        BFTree nodeToSplit = (BFTree) firstElement.elementAt(0);

        // split attribute
        Attribute att = (Attribute) firstElement.elementAt(1);

        // info of the split value or split string
        double splitValue = Double.NaN;
        String splitStr = null;
        if (att.isNumeric())
            splitValue = ((Double) firstElement.elementAt(2)).doubleValue();
        else {
            splitStr = ((String) firstElement.elementAt(2)).toString();
        }

        // the best gini gain or information gain of this node
        double gain = ((Double) firstElement.elementAt(3)).doubleValue();
        ///////////////////////////////////////////////////////////////////////

        // If there is not enough data to split this node or the node cannot be split,
        // find the next node to split.
        if (totalWeight < 2 * minNumObj || branchProps[0] == 0
                || branchProps[1] == 0) {
            // remove the first element
            BestFirstElements.removeElementAt(0);
            nodeToSplit.makeLeaf(train);
            BFTree nextNode = (BFTree) ((FastVector) BestFirstElements
                    .elementAt(0)).elementAt(0);
            return root.makeTree(BestFirstElements, root, train,
                    nextNode.m_SortedIndices, nextNode.m_Weights,
                    nextNode.m_Dists, nextNode.m_ClassProbs,
                    nextNode.m_TotalWeight, nextNode.m_Props, minNumObj,
                    useHeuristic, useGini);
        }

        // If the gini gain or information gain is 0, make all nodes in BestFirstElements
        // leaf nodes, because these nodes are sorted in descending order of gini gain or
        // information gain (namely, the gain of every node in BestFirstElements is 0).
        if (gain == 0) {
            for (int i = 0; i < BestFirstElements.size(); i++) {
                FastVector element = (FastVector) BestFirstElements
                        .elementAt(i);
                BFTree node = (BFTree) element.elementAt(0);
                node.makeLeaf(train);
            }
            BestFirstElements.removeAllElements();
            return false;
        }

        else {
            // remove the first element
            BestFirstElements.removeElementAt(0);
            nodeToSplit.m_Attribute = att;
            if (att.isNumeric())
                nodeToSplit.m_SplitValue = splitValue;
            else
                nodeToSplit.m_SplitString = splitStr;

            int[][][] subsetIndices = new int[2][train.numAttributes()][0];
            double[][][] subsetWeights = new double[2][train
                    .numAttributes()][0];

            splitData(subsetIndices, subsetWeights, nodeToSplit.m_Attribute,
                    nodeToSplit.m_SplitValue, nodeToSplit.m_SplitString,
                    nodeToSplit.m_SortedIndices, nodeToSplit.m_Weights, train);

            // if the split would generate a node whose total weight is less than
            // m_minNumObj, do not split
            int attIndex = att.index();
            if (subsetIndices[0][attIndex].length < minNumObj
                    || subsetIndices[1][attIndex].length < minNumObj) {

                nodeToSplit.makeLeaf(train);
                BFTree nextNode = (BFTree) ((FastVector) BestFirstElements
                        .elementAt(0)).elementAt(0);
                return root.makeTree(BestFirstElements, root, train,
                        nextNode.m_SortedIndices, nextNode.m_Weights,
                        nextNode.m_Dists, nextNode.m_ClassProbs,
                        nextNode.m_TotalWeight, nextNode.m_Props, minNumObj,
                        useHeuristic, useGini);
            }

            // split the node
            else {
                nodeToSplit.m_isLeaf = false;
                nodeToSplit.m_Attribute = att;

                nodeToSplit.makeSuccessors(BestFirstElements, train,
                        subsetIndices, subsetWeights, dists,
                        nodeToSplit.m_Attribute, useHeuristic, useGini);

                for (int i = 0; i < 2; i++) {
                    nodeToSplit.m_Successors[i].makeLeaf(train);
                }

                return true;
            }
        }
    }
0974:
0975:            /**
0976:             * This method is to find the number of expansions based on internal 
0977:             * cross-validation for just post-pruning. It expands the first BestFirst 
0978:             * node in the BestFirstElements until no node can be split. When building 
0979:             * the tree, stroe error for each temporary tree, namely for each expansion.
0980:             *
0981:             * @param BestFirstElements 	list to store BFTree nodes
0982:             * @param root 		root node of tree in each fold
0983:             * @param train 		training data in each fold
0984:             * @param test 		test data in each fold
0985:             * @param modelError 		list to store error for each expansion in 
0986:             * 				each fold
0987:             * @param sortedIndices 	sorted indices of the instances
0988:             * @param weights 		weights of the instances
0989:             * @param dists 		class distributions for each attribute
0990:             * @param classProbs 		class probabilities of this node
0991:             * @param totalWeight 	total weight of this node (note if the node 
0992:             * 				can not split, this value is not calculated.)
0993:             * @param branchProps 	proportions of two subbranches
0994:             * @param minNumObj 		minimal number of instances at leaf nodes
0995:             * @param useHeuristic 	if use heuristic search for nominal attributes 
0996:             * 				in multi-class problem
0997:             * @param useGini 		if use Gini index as splitting criterion
0998:             * @param useErrorRate 	if use error rate in internal cross-validation
0999:             * @throws Exception 		if something goes wrong
1000:             */
1001:            protected void makeTree(FastVector BestFirstElements, BFTree root,
1002:                    Instances train, Instances test, FastVector modelError,
1003:                    int[][] sortedIndices, double[][] weights,
1004:                    double[][][] dists, double[] classProbs,
1005:                    double totalWeight, double[] branchProps, int minNumObj,
1006:                    boolean useHeuristic, boolean useGini, boolean useErrorRate)
1007:                    throws Exception {
1008:
1009:                if (BestFirstElements.size() == 0)
1010:                    return;
1011:
1012:                ///////////////////////////////////////////////////////////////////////
1013:                // All information about the node to split (first BestFirst object in
1014:                // BestFirstElements)
1015:                FastVector firstElement = (FastVector) BestFirstElements
1016:                        .elementAt(0);
1017:
1018:                // node to split
1019:                //BFTree nodeToSplit = (BFTree)firstElement.elementAt(0);
1020:
1021:                // split attribute
1022:                Attribute att = (Attribute) firstElement.elementAt(1);
1023:
1024:                // info of split value or split string
1025:                double splitValue = Double.NaN;
1026:                String splitStr = null;
1027:                if (att.isNumeric())
1028:                    splitValue = ((Double) firstElement.elementAt(2))
1029:                            .doubleValue();
1030:                else {
1031:                    splitStr = ((String) firstElement.elementAt(2)).toString();
1032:                }
1033:
1034:                // the best gini gain or information gain of this node
1035:                double gain = ((Double) firstElement.elementAt(3))
1036:                        .doubleValue();
1037:                ///////////////////////////////////////////////////////////////////////
1038:
1039:                if (totalWeight < 2 * minNumObj || branchProps[0] == 0
1040:                        || branchProps[1] == 0) {
1041:                    // remove the first element
1042:                    BestFirstElements.removeElementAt(0);
1043:                    makeLeaf(train);
1044:                    if (BestFirstElements.size() == 0)
1045:                        return; // guard: the removed element may have been the last one
1046:                    BFTree nextSplitNode = (BFTree) ((FastVector) BestFirstElements.elementAt(0)).elementAt(0);
1047:                    nextSplitNode.makeTree(BestFirstElements, root, train,
1048:                            test, modelError, nextSplitNode.m_SortedIndices,
1049:                            nextSplitNode.m_Weights, nextSplitNode.m_Dists,
1050:                            nextSplitNode.m_ClassProbs, nextSplitNode.m_TotalWeight,
1051:                            nextSplitNode.m_Props, minNumObj, useHeuristic, useGini, useErrorRate);
1052:                    return;
1053:
1054:                }
1055:
1056:                // If the gini gain or information gain is 0, make all nodes in
1057:                // BestFirstElements leaf nodes: the nodes are sorted in descending order
1058:                // of gain, so every node remaining in BestFirstElements also has gain 0.
1059:                if (gain == 0) {
1060:                    for (int i = 0; i < BestFirstElements.size(); i++) {
1061:                        FastVector element = (FastVector) BestFirstElements
1062:                                .elementAt(i);
1063:                        BFTree node = (BFTree) element.elementAt(0);
1064:                        node.makeLeaf(train);
1065:                    }
1066:                    BestFirstElements.removeAllElements();
1067:                }
1068:
1069:                // gini gain or information gain is not 0
1070:                else {
1071:                    // remove the first element
1072:                    BestFirstElements.removeElementAt(0);
1073:                    m_Attribute = att;
1074:                    if (att.isNumeric())
1075:                        m_SplitValue = splitValue;
1076:                    else
1077:                        m_SplitString = splitStr;
1078:
1079:                    int[][][] subsetIndices = new int[2][train.numAttributes()][0];
1080:                    double[][][] subsetWeights = new double[2][train
1081:                            .numAttributes()][0];
1082:
1083:                    splitData(subsetIndices, subsetWeights, m_Attribute,
1084:                            m_SplitValue, m_SplitString, sortedIndices,
1085:                            weights, train);
1086:
1087:                // if the split would generate a node with fewer than minNumObj
1088:                // instances, do not split
1089:                    int attIndex = att.index();
1090:                    if (subsetIndices[0][attIndex].length < minNumObj
1091:                            || subsetIndices[1][attIndex].length < minNumObj) {
1092:                        makeLeaf(train);
1093:                    }
1094:
1095:                    // split the node and calculate the error rate of this temporary tree
1096:                    else {
1097:                        m_isLeaf = false;
1098:                        m_Attribute = att;
1099:
1100:                        makeSuccessors(BestFirstElements, train, subsetIndices,
1101:                                subsetWeights, dists, m_Attribute,
1102:                                useHeuristic, useGini);
1103:                        for (int i = 0; i < 2; i++) {
1104:                            m_Successors[i].makeLeaf(train);
1105:                        }
1106:
1107:                        Evaluation eval = new Evaluation(test);
1108:                        eval.evaluateModel(root, test);
1109:                        double error;
1110:                        if (useErrorRate)
1111:                            error = eval.errorRate();
1112:                        else
1113:                            error = eval.rootMeanSquaredError();
1114:                        modelError.addElement(new Double(error));
1115:                    }
1116:
1117:                    if (BestFirstElements.size() != 0) {
1118:                        FastVector nextSplitElement = (FastVector) BestFirstElements
1119:                                .elementAt(0);
1120:                        BFTree nextSplitNode = (BFTree) nextSplitElement
1121:                                .elementAt(0);
1122:                        nextSplitNode.makeTree(BestFirstElements, root, train,
1123:                                test, modelError,
1124:                                nextSplitNode.m_SortedIndices,
1125:                                nextSplitNode.m_Weights, nextSplitNode.m_Dists,
1126:                                nextSplitNode.m_ClassProbs,
1127:                                nextSplitNode.m_TotalWeight,
1128:                                nextSplitNode.m_Props, minNumObj, useHeuristic,
1129:                                useGini, useErrorRate);
1130:                    }
1131:                }
1132:            }
1133:
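The control flow above is a best-first expansion: the node with the highest gini gain or information gain is always split next, and its successors re-enter the sorted list. A minimal sketch of the same idea, assuming a hypothetical Node type with a gain field and a split() method (neither is part of BFTree):

        import java.util.*;

        class BestFirstSketch {
            static class Node {
                double gain;                   // gini gain or information gain
                List<Node> split() { return Collections.emptyList(); } // stand-in
            }

            static void expand(Node root) {
                // highest gain first, mirroring the ordering of BestFirstElements
                PriorityQueue<Node> frontier = new PriorityQueue<>(
                        Comparator.comparingDouble((Node n) -> n.gain).reversed());
                frontier.add(root);
                while (!frontier.isEmpty() && frontier.peek().gain > 0) {
                    Node best = frontier.poll();   // expand the best node
                    frontier.addAll(best.split()); // successors rejoin the frontier
                }
            }
        }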
1134:            /**
1135:             * Generate successor nodes for a node and put them into BestFirstElements 
1136:             * according to gini gain or information gain, in descending order.
1137:             *
1138:             * @param BestFirstElements 	list to store BestFirst nodes
1139:             * @param data 		training instances
1140:             * @param subsetSortedIndices	sorted indices of instances of successor nodes
1141:             * @param subsetWeights 	weights of instances of successor nodes
1142:             * @param dists 		class distributions of successor nodes
1143:             * @param att 		attribute used to split the node
1144:             * @param useHeuristic 	whether to use heuristic search for nominal attributes in multi-class problems
1145:             * @param useGini 		whether to use the Gini index as the splitting criterion
1146:             * @throws Exception 		if something goes wrong 
1147:             */
1148:            protected void makeSuccessors(FastVector BestFirstElements,
1149:                    Instances data, int[][][] subsetSortedIndices,
1150:                    double[][][] subsetWeights, double[][][] dists,
1151:                    Attribute att, boolean useHeuristic, boolean useGini)
1152:                    throws Exception {
1153:
1154:                m_Successors = new BFTree[2];
1155:
1156:                for (int i = 0; i < 2; i++) {
1157:                    m_Successors[i] = new BFTree();
1158:                    m_Successors[i].m_isLeaf = true;
1159:
1160:                    // class probability and distribution for this successor node
1161:                    m_Successors[i].m_ClassProbs = new double[data.numClasses()];
1162:                    m_Successors[i].m_Distribution = new double[data
1163:                            .numClasses()];
1164:                    System.arraycopy(dists[att.index()][i], 0,
1165:                            m_Successors[i].m_ClassProbs, 0,
1166:                            m_Successors[i].m_ClassProbs.length);
1167:                    System.arraycopy(dists[att.index()][i], 0,
1168:                            m_Successors[i].m_Distribution, 0,
1169:                            m_Successors[i].m_Distribution.length);
1170:                    if (Utils.sum(m_Successors[i].m_ClassProbs) != 0)
1171:                        Utils.normalize(m_Successors[i].m_ClassProbs);
1172:
1173:                    // split information for this successor node
1174:                    double[][] props = new double[data.numAttributes()][2];
1175:                    double[][][] subDists = new double[data.numAttributes()][2][data
1176:                            .numClasses()];
1177:                    double[][] totalSubsetWeights = new double[data
1178:                            .numAttributes()][2];
1179:                    FastVector splitInfo = m_Successors[i].computeSplitInfo(
1180:                            m_Successors[i], data, subsetSortedIndices[i],
1181:                            subsetWeights[i], subDists, props,
1182:                            totalSubsetWeights, useHeuristic, useGini);
1183:
1184:                    // branch proportion for this successor node
1185:                    int splitIndex = ((Attribute) splitInfo.elementAt(1))
1186:                            .index();
1187:                    m_Successors[i].m_Props = new double[2];
1188:                    System.arraycopy(props[splitIndex], 0,
1189:                            m_Successors[i].m_Props, 0,
1190:                            m_Successors[i].m_Props.length);
1191:
1192:                    // sorted indices and weights of each attribute for this successor node
1193:                    m_Successors[i].m_SortedIndices = new int[data
1194:                            .numAttributes()][0];
1195:                    m_Successors[i].m_Weights = new double[data.numAttributes()][0];
1196:                    for (int j = 0; j < m_Successors[i].m_SortedIndices.length; j++) {
1197:                        m_Successors[i].m_SortedIndices[j] = subsetSortedIndices[i][j];
1198:                        m_Successors[i].m_Weights[j] = subsetWeights[i][j];
1199:                    }
1200:
1201:                    // distribution of each attribute for this successor node
1202:                    m_Successors[i].m_Dists = new double[data.numAttributes()][2][data
1203:                            .numClasses()];
1204:                    for (int j = 0; j < subDists.length; j++) {
1205:                        m_Successors[i].m_Dists[j] = subDists[j];
1206:                    }
1207:
1208:                    // total weights for this successor node. 
1209:                    m_Successors[i].m_TotalWeight = Utils
1210:                            .sum(totalSubsetWeights[splitIndex]);
1211:
1212:                    // insert this successor node into BestFirstElements, in descending
1213:                    // order of gini gain or information gain
1214:                    if (BestFirstElements.size() == 0) {
1215:                        BestFirstElements.addElement(splitInfo);
1216:                    } else {
1217:                        double gGain = ((Double) (splitInfo.elementAt(3)))
1218:                                .doubleValue();
1219:                        int vectorSize = BestFirstElements.size();
1220:                        FastVector lastNode = (FastVector) BestFirstElements
1221:                                .elementAt(vectorSize - 1);
1222:
1223:                        // If the gain is less than that of the last node, append at the end
1224:                        if (gGain < ((Double) (lastNode.elementAt(3)))
1225:                                .doubleValue()) {
1226:                            BestFirstElements.insertElementAt(splitInfo,
1227:                                    vectorSize);
1228:                        } else {
1229:                            for (int j = 0; j < vectorSize; j++) {
1230:                                FastVector node = (FastVector) BestFirstElements
1231:                                        .elementAt(j);
1232:                                double nodeGain = ((Double) (node.elementAt(3)))
1233:                                        .doubleValue();
1234:                                if (gGain >= nodeGain) {
1235:                                    BestFirstElements.insertElementAt(
1236:                                            splitInfo, j);
1237:                                    break;
1238:                                }
1239:                            }
1240:                        }
1241:                    }
1242:                }
1243:            }
1244:
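The insertion logic at the end of makeSuccessors is a single insertion-sort step: scan for the first stored gain that is not larger than the new gain and insert there, so the list stays sorted in descending order. A minimal sketch with core Java collections (the gain values are illustrative):

        java.util.List<Double> gains =
                new java.util.ArrayList<Double>(java.util.Arrays.asList(0.9, 0.5, 0.2));
        double g = 0.6;
        int pos = 0;
        while (pos < gains.size() && gains.get(pos) > g)
            pos++;                  // first position whose stored gain is <= g
        gains.add(pos, g);          // gains is now [0.9, 0.6, 0.5, 0.2]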
1245:            /**
1246:             * Compute sorted indices, weights and class probabilities for a given 
1247:             * dataset. Return total weights of the data at the node.
1248:             * 
1249:             * @param data 		training data
1250:             * @param sortedIndices 	sorted indices of instances at the node
1251:             * @param weights 		weights of instances at the node
1252:             * @param classProbs 		class probabilities at the node
1253:             * @return 			total weights of instances at the node
1254:             * @throws Exception 		if something goes wrong
1255:             */
1256:            protected double computeSortedInfo(Instances data,
1257:                    int[][] sortedIndices, double[][] weights,
1258:                    double[] classProbs) throws Exception {
1259:
1260:                // Create array of sorted indices and weights
1261:                double[] vals = new double[data.numInstances()];
1262:                for (int j = 0; j < data.numAttributes(); j++) {
1263:                    if (j == data.classIndex())
1264:                        continue;
1265:                    weights[j] = new double[data.numInstances()];
1266:
1267:                    if (data.attribute(j).isNominal()) {
1268:
1269:                        // Handling nominal attributes. Putting indices of
1270:                        // instances with missing values at the end.
1271:                        sortedIndices[j] = new int[data.numInstances()];
1272:                        int count = 0;
1273:                        for (int i = 0; i < data.numInstances(); i++) {
1274:                            Instance inst = data.instance(i);
1275:                            if (!inst.isMissing(j)) {
1276:                                sortedIndices[j][count] = i;
1277:                                weights[j][count] = inst.weight();
1278:                                count++;
1279:                            }
1280:                        }
1281:                        for (int i = 0; i < data.numInstances(); i++) {
1282:                            Instance inst = data.instance(i);
1283:                            if (inst.isMissing(j)) {
1284:                                sortedIndices[j][count] = i;
1285:                                weights[j][count] = inst.weight();
1286:                                count++;
1287:                            }
1288:                        }
1289:                    } else {
1290:
1291:                        // Sorted indices are computed for numeric attributes;
1292:                        // instances with missing values are put at the end (by Utils.sort())
1293:                        for (int i = 0; i < data.numInstances(); i++) {
1294:                            Instance inst = data.instance(i);
1295:                            vals[i] = inst.value(j);
1296:                        }
1297:                        sortedIndices[j] = Utils.sort(vals);
1298:                        for (int i = 0; i < data.numInstances(); i++) {
1299:                            weights[j][i] = data.instance(sortedIndices[j][i])
1300:                                    .weight();
1301:                        }
1302:                    }
1303:                }
1304:
1305:                // Compute initial class counts and total weight
1306:                double totalWeight = 0;
1307:                for (int i = 0; i < data.numInstances(); i++) {
1308:                    Instance inst = data.instance(i);
1309:                    classProbs[(int) inst.classValue()] += inst.weight();
1310:                    totalWeight += inst.weight();
1311:                }
1312:
1313:                return totalWeight;
1314:            }
1315:
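computeSortedInfo relies on Utils.sort to turn each numeric column into an array of instance indices ordered by value, with missing values moved to the end. The same index-sorting idea in core Java, for a single column and ignoring the missing-value handling (the vals array is illustrative):

        final double[] vals = {1.4, 0.2, 3.7};
        int[] sorted = java.util.stream.IntStream.range(0, vals.length)
                .boxed()
                .sorted(java.util.Comparator.comparingDouble(i -> vals[i]))
                .mapToInt(Integer::intValue)
                .toArray();         // sorted == {1, 0, 2}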
1316:            /**
1317:             * Compute the best splitting attribute, split point or subset and the best
1318:             * gini gain or information gain for a given dataset.
1319:             *
1320:             * @param node 		node to be split
1321:             * @param data 		training data
1322:             * @param sortedIndices 	sorted indices of the instances
1323:             * @param weights 		weights of the instances
1324:             * @param dists 		class distributions for each attribute
1325:             * @param props 		proportions of two branches
1326:             * @param totalSubsetWeights 	total weight of two subsets
1327:             * @param useHeuristic 	whether to use heuristic search for nominal 
1328:             * 				attributes in multi-class problems
1329:             * @param useGini 		whether to use the Gini index as the splitting criterion
1330:             * @return 			split information about the node
1331:             * @throws Exception 		if something goes wrong
1332:             */
1333:            protected FastVector computeSplitInfo(BFTree node, Instances data,
1334:                    int[][] sortedIndices, double[][] weights,
1335:                    double[][][] dists, double[][] props,
1336:                    double[][] totalSubsetWeights, boolean useHeuristic,
1337:                    boolean useGini) throws Exception {
1338:
1339:                double[] splits = new double[data.numAttributes()];
1340:                String[] splitString = new String[data.numAttributes()];
1341:                double[] gains = new double[data.numAttributes()];
1342:
1343:                for (int i = 0; i < data.numAttributes(); i++) {
1344:                    if (i == data.classIndex())
1345:                        continue;
1346:                    Attribute att = data.attribute(i);
1347:                    if (att.isNumeric()) {
1348:                        // numeric attribute
1349:                        splits[i] = numericDistribution(props, dists, att,
1350:                                sortedIndices[i], weights[i],
1351:                                totalSubsetWeights, gains, data, useGini);
1352:                    } else {
1353:                        // nominal attribute
1354:                        splitString[i] = nominalDistribution(props, dists, att,
1355:                                sortedIndices[i], weights[i],
1356:                                totalSubsetWeights, gains, data, useHeuristic,
1357:                                useGini);
1358:                    }
1359:                }
1360:
1361:                int index = Utils.maxIndex(gains);
1362:                double mBestGain = gains[index];
1363:
1364:                Attribute att = data.attribute(index);
1365:                double mValue = Double.NaN;
1366:                String mString = null;
1367:                if (att.isNumeric())
1368:                    mValue = splits[index];
1369:                else {
1370:                    mString = splitString[index];
1371:                    if (mString == null)
1372:                        mString = "";
1373:                }
1374:
1375:                // split information
1376:                FastVector splitInfo = new FastVector();
1377:                splitInfo.addElement(node);
1378:                splitInfo.addElement(att);
1379:                if (att.isNumeric())
1380:                    splitInfo.addElement(new Double(mValue));
1381:                else
1382:                    splitInfo.addElement(mString);
1383:                splitInfo.addElement(new Double(mBestGain));
1384:
1385:                return splitInfo;
1386:            }
1387:
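The FastVector built here is a positional record that makeTree later unpacks by index. A sketch of the layout, assuming splitInfo is the vector returned by this method:

        BFTree node   = (BFTree)    splitInfo.elementAt(0); // node to be split
        Attribute att = (Attribute) splitInfo.elementAt(1); // best split attribute
        Object where  =             splitInfo.elementAt(2); // Double (numeric) or String (nominal subset)
        double gain   = ((Double)   splitInfo.elementAt(3)).doubleValue(); // best gain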
1388:            /**
1389:             * Compute distributions, proportions and total weights of two successor nodes for 
1390:             * a given numeric attribute.
1391:             *
1392:             * @param props 		proportions of the two branches for each attribute
1393:             * @param dists 		class distributions of the two branches for each attribute
1394:             * @param att 		numeric attribute to split on
1395:             * @param sortedIndices 	sorted indices of instances for the attribute
1396:             * @param weights 		weights of instances for the attribute
1397:             * @param subsetWeights 	total weights of the two branches split on the attribute
1398:             * @param gains 		Gini gains or information gains for each attribute 
1399:             * @param data 		training instances
1400:             * @param useGini 		whether to use the Gini index as the splitting criterion
1401:             * @return 			Gini gain or information gain for the given attribute
1402:             * @throws Exception 		if something goes wrong
1403:             */
1404:            protected double numericDistribution(double[][] props,
1405:                    double[][][] dists, Attribute att, int[] sortedIndices,
1406:                    double[] weights, double[][] subsetWeights, double[] gains,
1407:                    Instances data, boolean useGini) throws Exception {
1408:
1409:                double splitPoint = Double.NaN;
1410:                double[][] dist = null;
1411:                int numClasses = data.numClasses();
1412:                int i; // separates instances with and without missing values
1413:
1414:                double[][] currDist = new double[2][numClasses];
1415:                dist = new double[2][numClasses];
1416:
1417:                // Move all instances without missing values into second subset
1418:                double[] parentDist = new double[numClasses];
1419:                int missingStart = 0;
1420:                for (int j = 0; j < sortedIndices.length; j++) {
1421:                    Instance inst = data.instance(sortedIndices[j]);
1422:                    if (!inst.isMissing(att)) {
1423:                        missingStart++;
1424:                        currDist[1][(int) inst.classValue()] += weights[j];
1425:                    }
1426:                    parentDist[(int) inst.classValue()] += weights[j];
1427:                }
1428:                System.arraycopy(currDist[1], 0, dist[1], 0, dist[1].length);
1429:
1430:                // Try all possible split points
1431:                double currSplit = data.instance(sortedIndices[0]).value(att);
1432:                double currGain;
1433:                double bestGain = -Double.MAX_VALUE;
1434:
1435:                for (i = 0; i < sortedIndices.length; i++) {
1436:                    Instance inst = data.instance(sortedIndices[i]);
1437:                    if (inst.isMissing(att)) {
1438:                        break;
1439:                    }
1440:                    if (inst.value(att) > currSplit) {
1441:
1442:                        double[][] tempDist = new double[2][numClasses];
1443:                        for (int k = 0; k < 2; k++) {
1444:                            //tempDist[k] = currDist[k];
1445:                            System.arraycopy(currDist[k], 0, tempDist[k], 0,
1446:                                    tempDist[k].length);
1447:                        }
1448:
1449:                        double[] tempProps = new double[2];
1450:                        for (int k = 0; k < 2; k++) {
1451:                            tempProps[k] = Utils.sum(tempDist[k]);
1452:                        }
1453:
1454:                        if (Utils.sum(tempProps) != 0)
1455:                            Utils.normalize(tempProps);
1456:
1457:                        // split missing values
1458:                        int index = missingStart;
1459:                        while (index < sortedIndices.length) {
1460:                            Instance insta = data
1461:                                    .instance(sortedIndices[index]);
1462:                            for (int j = 0; j < 2; j++) {
1463:                                tempDist[j][(int) insta.classValue()] += tempProps[j]
1464:                                        * weights[index];
1465:                            }
1466:                            index++;
1467:                        }
1468:
1469:                        if (useGini)
1470:                            currGain = computeGiniGain(parentDist, tempDist);
1471:                        else
1472:                            currGain = computeInfoGain(parentDist, tempDist);
1473:
1474:                        if (currGain > bestGain) {
1475:                            bestGain = currGain;
1476:                            // round the split point to five decimal places
1477:                            splitPoint = Math
1478:                                    .rint((inst.value(att) + currSplit) / 2.0 * 100000) / 100000.0;
1479:                            for (int j = 0; j < currDist.length; j++) {
1480:                                System.arraycopy(tempDist[j], 0, dist[j], 0,
1481:                                        dist[j].length);
1482:                            }
1483:                        }
1484:                    }
1485:                    currSplit = inst.value(att);
1486:                    currDist[0][(int) inst.classValue()] += weights[i];
1487:                    currDist[1][(int) inst.classValue()] -= weights[i];
1488:                }
1489:
1490:                // Compute weights
1491:                int attIndex = att.index();
1492:                props[attIndex] = new double[2];
1493:                for (int k = 0; k < 2; k++) {
1494:                    props[attIndex][k] = Utils.sum(dist[k]);
1495:                }
1496:                if (Utils.sum(props[attIndex]) != 0)
1497:                    Utils.normalize(props[attIndex]);
1498:
1499:                // Compute subset weights
1500:                subsetWeights[attIndex] = new double[2];
1501:                for (int j = 0; j < 2; j++) {
1502:                    subsetWeights[attIndex][j] += Utils.sum(dist[j]);
1503:                }
1504:
1505:                // round the gain
1506:                gains[attIndex] = Math.rint(bestGain * 10000000) / 10000000.0;
1507:                dists[attIndex] = dist;
1508:                return splitPoint;
1509:            }
1510:
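Candidate split points are the midpoints between adjacent distinct attribute values, rounded to five decimal places by the Math.rint expression above. For example, with two adjacent sorted values:

        double prev = 1.25, next = 1.35;
        double splitPoint =
                Math.rint((next + prev) / 2.0 * 100000) / 100000.0;   // == 1.3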
1511:            /**
1512:             * Compute distributions, proportions and total weights of two successor 
1513:             * nodes for a given nominal attribute.
1514:             *
1515:             * @param props 		proportions of the two branches for each attribute
1516:             * @param dists 		class distributions of the two branches for each attribute
1517:             * @param att 		nominal attribute to split on
1518:             * @param sortedIndices 	sorted indices of instances for the attribute
1519:             * @param weights 		weights of instances for the attribute
1520:             * @param subsetWeights 	total weights of the two branches split on the attribute
1521:             * @param gains 		Gini gains or information gains for each attribute 
1522:             * @param data 		training instances
1523:             * @param useHeuristic 	whether to use heuristic search
1524:             * @param useGini 		whether to use the Gini index as the splitting criterion
1525:             * @return 			Gini gain or information gain for the given attribute
1526:             * @throws Exception 		if something goes wrong
1527:             */
1528:            protected String nominalDistribution(double[][] props,
1529:                    double[][][] dists, Attribute att, int[] sortedIndices,
1530:                    double[] weights, double[][] subsetWeights, double[] gains,
1531:                    Instances data, boolean useHeuristic, boolean useGini)
1532:                    throws Exception {
1533:
1534:                String[] values = new String[att.numValues()];
1535:                int numCat = values.length; // number of values of the attribute
1536:                int numClasses = data.numClasses();
1537:
1538:                String bestSplitString = "";
1539:                double bestGain = -Double.MAX_VALUE;
1540:
1541:                // frequency of each attribute value
1542:                int[] classFreq = new int[numCat];
1543:                for (int j = 0; j < numCat; j++)
1544:                    classFreq[j] = 0;
1545:
1546:                double[] parentDist = new double[numClasses];
1547:                double[][] currDist = new double[2][numClasses];
1548:                double[][] dist = new double[2][numClasses];
1549:                int missingStart = 0;
1550:
1551:                for (int i = 0; i < sortedIndices.length; i++) {
1552:                    Instance inst = data.instance(sortedIndices[i]);
1553:                    if (!inst.isMissing(att)) {
1554:                        missingStart++;
1555:                        classFreq[(int) inst.value(att)]++;
1556:                    }
1557:                    parentDist[(int) inst.classValue()] += weights[i];
1558:                }
1559:
1560:                // count the number of values whose class frequency is not 0
1561:                int nonEmpty = 0;
1562:                for (int j = 0; j < numCat; j++) {
1563:                    if (classFreq[j] != 0)
1564:                        nonEmpty++;
1565:                }
1566:
1567:                // attribute values whose class frequency is not 0
1568:                String[] nonEmptyValues = new String[nonEmpty];
1569:                int nonEmptyIndex = 0;
1570:                for (int j = 0; j < numCat; j++) {
1571:                    if (classFreq[j] != 0) {
1572:                        nonEmptyValues[nonEmptyIndex] = att.value(j);
1573:                        nonEmptyIndex++;
1574:                    }
1575:                }
1576:
1577:                // attribute values whose class frequency is 0
1578:                int empty = numCat - nonEmpty;
1579:                String[] emptyValues = new String[empty];
1580:                int emptyIndex = 0;
1581:                for (int j = 0; j < numCat; j++) {
1582:                    if (classFreq[j] == 0) {
1583:                        emptyValues[emptyIndex] = att.value(j);
1584:                        emptyIndex++;
1585:                    }
1586:                }
1587:
1588:                if (nonEmpty <= 1) {
1589:                    gains[att.index()] = 0;
1590:                    return "";
1591:                }
1592:
1593:                // for two-class problems
1594:                if (data.numClasses() == 2) {
1595:
1596:                    //// First, handle attribute values whose class frequency is not zero
1597:
1598:                    // probability of class 0 for each attribute value
1599:                    double[] pClass0 = new double[nonEmpty];
1600:                    // class distribution for each attribute value
1601:                    double[][] valDist = new double[nonEmpty][2];
1602:
1603:                    for (int j = 0; j < nonEmpty; j++) {
1604:                        for (int k = 0; k < 2; k++) {
1605:                            valDist[j][k] = 0;
1606:                        }
1607:                    }
1608:
1609:                    for (int i = 0; i < sortedIndices.length; i++) {
1610:                        Instance inst = data.instance(sortedIndices[i]);
1611:                        if (inst.isMissing(att)) {
1612:                            break;
1613:                        }
1614:
1615:                        for (int j = 0; j < nonEmpty; j++) {
1616:                            if (att.value((int) inst.value(att)).compareTo(
1617:                                    nonEmptyValues[j]) == 0) {
1618:                                valDist[j][(int) inst.classValue()] += inst
1619:                                        .weight();
1620:                                break;
1621:                            }
1622:                        }
1623:                    }
1624:
1625:                    for (int j = 0; j < nonEmpty; j++) {
1626:                        double distSum = Utils.sum(valDist[j]);
1627:                        if (distSum == 0)
1628:                            pClass0[j] = 0;
1629:                        else
1630:                            pClass0[j] = valDist[j][0] / distSum;
1631:                    }
1632:
1633:                    // sort categories by the probability of class 0
1634:                    String[] sortedValues = new String[nonEmpty];
1635:                    for (int j = 0; j < nonEmpty; j++) {
1636:                        sortedValues[j] = nonEmptyValues[Utils
1637:                                .minIndex(pClass0)];
1638:                        pClass0[Utils.minIndex(pClass0)] = Double.MAX_VALUE;
1639:                    }
1640:
1641:                    // Find a subset of attribute values that maximizes the impurity decrease
1642:
1643:                    // for the attribute values whose class frequency is not 0
1644:                    String tempStr = "";
1645:
1646:                    for (int j = 0; j < nonEmpty - 1; j++) {
1647:                        currDist = new double[2][numClasses];
1648:                        if (tempStr == "")
1649:                            tempStr = "(" + sortedValues[j] + ")";
1650:                        else
1651:                            tempStr += "|" + "(" + sortedValues[j] + ")";
1652:                        //System.out.println(sortedValues[j]);
1653:                        for (int i = 0; i < sortedIndices.length; i++) {
1654:                            Instance inst = data.instance(sortedIndices[i]);
1655:                            if (inst.isMissing(att)) {
1656:                                break;
1657:                            }
1658:
1659:                            if (tempStr.indexOf("("
1660:                                    + att.value((int) inst.value(att)) + ")") != -1) {
1661:                                currDist[0][(int) inst.classValue()] += weights[i];
1662:                            } else
1663:                                currDist[1][(int) inst.classValue()] += weights[i];
1664:                        }
1665:
1666:                        double[][] tempDist = new double[2][numClasses];
1667:                        for (int kk = 0; kk < 2; kk++) {
1668:                            tempDist[kk] = currDist[kk];
1669:                        }
1670:
1671:                        double[] tempProps = new double[2];
1672:                        for (int kk = 0; kk < 2; kk++) {
1673:                            tempProps[kk] = Utils.sum(tempDist[kk]);
1674:                        }
1675:
1676:                        if (Utils.sum(tempProps) != 0)
1677:                            Utils.normalize(tempProps);
1678:
1679:                        // split missing values
1680:                        int mstart = missingStart;
1681:                        while (mstart < sortedIndices.length) {
1682:                            Instance insta = data
1683:                                    .instance(sortedIndices[mstart]);
1684:                            for (int jj = 0; jj < 2; jj++) {
1685:                                tempDist[jj][(int) insta.classValue()] += tempProps[jj]
1686:                                        * weights[mstart];
1687:                            }
1688:                            mstart++;
1689:                        }
1690:
1691:                        double currGain;
1692:                        if (useGini)
1693:                            currGain = computeGiniGain(parentDist, tempDist);
1694:                        else
1695:                            currGain = computeInfoGain(parentDist, tempDist);
1696:
1697:                        if (currGain > bestGain) {
1698:                            bestGain = currGain;
1699:                            bestSplitString = tempStr;
1700:                            for (int jj = 0; jj < 2; jj++) {
1701:                                System.arraycopy(tempDist[jj], 0, dist[jj], 0,
1702:                                        dist[jj].length);
1703:                            }
1704:                        }
1705:                    }
1706:                }
1707:
1708:                // multi-class problems (exhaustive search)
1709:                else if (!useHeuristic || nonEmpty <= 4) {
1710:                    //else if (!useHeuristic || nonEmpty==2) {
1711:
1712:                    // First, handle attribute values whose class frequency is not zero
1713:                    for (int i = 0; i < (int) Math.pow(2, nonEmpty - 1); i++) {
1714:                        String tempStr = "";
1715:                        currDist = new double[2][numClasses];
1716:                        int mod;
1717:                        int bit10 = i;
1718:                        for (int j = nonEmpty - 1; j >= 0; j--) {
1719:                            mod = bit10 % 2; // convert from decimal to binary
1720:                            if (mod == 1) {
1721:                                if (tempStr == "")
1722:                                    tempStr = "(" + nonEmptyValues[j] + ")";
1723:                                else
1724:                                    tempStr += "|" + "(" + nonEmptyValues[j]
1725:                                            + ")";
1726:                            }
1727:                            bit10 = bit10 / 2;
1728:                        }
1729:                        for (int j = 0; j < sortedIndices.length; j++) {
1730:                            Instance inst = data.instance(sortedIndices[j]);
1731:                            if (inst.isMissing(att)) {
1732:                                break;
1733:                            }
1734:
1735:                            if (tempStr.indexOf("("
1736:                                    + att.value((int) inst.value(att)) + ")") != -1) {
1737:                                currDist[0][(int) inst.classValue()] += weights[j];
1738:                            } else
1739:                                currDist[1][(int) inst.classValue()] += weights[j];
1740:                        }
1741:
1742:                        double[][] tempDist = new double[2][numClasses];
1743:                        for (int k = 0; k < 2; k++) {
1744:                            tempDist[k] = currDist[k];
1745:                        }
1746:
1747:                        double[] tempProps = new double[2];
1748:                        for (int k = 0; k < 2; k++) {
1749:                            tempProps[k] = Utils.sum(tempDist[k]);
1750:                        }
1751:
1752:                        if (Utils.sum(tempProps) != 0)
1753:                            Utils.normalize(tempProps);
1754:
1755:                        // split missing values
1756:                        int index = missingStart;
1757:                        while (index < sortedIndices.length) {
1758:                            Instance insta = data
1759:                                    .instance(sortedIndices[index]);
1760:                            for (int j = 0; j < 2; j++) {
1761:                                tempDist[j][(int) insta.classValue()] += tempProps[j]
1762:                                        * weights[index];
1763:                            }
1764:                            index++;
1765:                        }
1766:
1767:                        double currGain;
1768:                        if (useGini)
1769:                            currGain = computeGiniGain(parentDist, tempDist);
1770:                        else
1771:                            currGain = computeInfoGain(parentDist, tempDist);
1772:
1773:                        if (currGain > bestGain) {
1774:                            bestGain = currGain;
1775:                            bestSplitString = tempStr;
1776:                            for (int j = 0; j < 2; j++) {
1777:                                //dist[jj] = new double[currDist[jj].length];
1778:                                System.arraycopy(tempDist[j], 0, dist[j], 0,
1779:                                        dist[j].length);
1780:                            }
1781:                        }
1782:                    }
1783:                }
1784:
1785:                // heuristic method for multi-class problems
1786:                else {
1787:                    // First, handle attribute values whose class frequency is not zero
1788:                    int n = nonEmpty;
1789:                    int k = data.numClasses(); // number of classes of the data
1790:                    double[][] P = new double[n][k]; // class probability matrix
1791:                    int[] numInstancesValue = new int[n]; // number of instances for an attribute value
1792:                    double[] meanClass = new double[k]; // vector of mean class probability
1793:                    int numInstances = data.numInstances(); // total number of instances
1794:
1795:                    // initialize the vector of mean class probability
1796:                    for (int j = 0; j < meanClass.length; j++)
1797:                        meanClass[j] = 0;
1798:
1799:                    for (int j = 0; j < numInstances; j++) {
1800:                        Instance inst = (Instance) data.instance(j);
1801:                        int valueIndex = 0; // attribute value index in nonEmptyValues
1802:                        for (int i = 0; i < nonEmpty; i++) {
1803:                            if (att.value((int) inst.value(att))
1804:                                    .compareToIgnoreCase(nonEmptyValues[i]) == 0) {
1805:                                valueIndex = i;
1806:                                break;
1807:                            }
1808:                        }
1809:                        P[valueIndex][(int) inst.classValue()]++;
1810:                        numInstancesValue[valueIndex]++;
1811:                        meanClass[(int) inst.classValue()]++;
1812:                    }
1813:
1814:                    // calculate the class probability matrix
1815:                    for (int i = 0; i < P.length; i++) {
1816:                        for (int j = 0; j < P[0].length; j++) {
1817:                            if (numInstancesValue[i] == 0)
1818:                                P[i][j] = 0;
1819:                            else
1820:                                P[i][j] /= numInstancesValue[i];
1821:                        }
1822:                    }
1823:
1824:                    // calculate the vector of mean class probability
1825:                    for (int i = 0; i < meanClass.length; i++) {
1826:                        meanClass[i] /= numInstances;
1827:                    }
1828:
1829:                    // calculate the covariance matrix
1830:                    double[][] covariance = new double[k][k];
1831:                    for (int i1 = 0; i1 < k; i1++) {
1832:                        for (int i2 = 0; i2 < k; i2++) {
1833:                            double element = 0;
1834:                            for (int j = 0; j < n; j++) {
1835:                                element += (P[j][i2] - meanClass[i2])
1836:                                        * (P[j][i1] - meanClass[i1])
1837:                                        * numInstancesValue[j];
1838:                            }
1839:                            covariance[i1][i2] = element;
1840:                        }
1841:                    }
1842:
1843:                    Matrix matrix = new Matrix(covariance);
1844:                    weka.core.matrix.EigenvalueDecomposition eigen = new weka.core.matrix.EigenvalueDecomposition(
1845:                            matrix);
1846:                    double[] eigenValues = eigen.getRealEigenvalues();
1847:
1848:                    // find index of the largest eigenvalue
1849:                    int index = 0;
1850:                    double largest = eigenValues[0];
1851:                    for (int i = 1; i < eigenValues.length; i++) {
1852:                        if (eigenValues[i] > largest) {
1853:                            index = i;
1854:                            largest = eigenValues[i];
1855:                        }
1856:                    }
1857:
1858:                    // calculate the first principal component
1859:                    double[] FPC = new double[k];
1860:                    Matrix eigenVector = eigen.getV();
1861:                    double[][] vectorArray = eigenVector.getArray();
1862:                    for (int i = 0; i < FPC.length; i++) {
1863:                        FPC[i] = vectorArray[i][index];
1864:                    }
1865:
1866:                    // calculate the first principal component scores
1867:                    double[] Sa = new double[n];
1868:                    for (int i = 0; i < Sa.length; i++) {
1869:                        Sa[i] = 0;
1870:                        for (int j = 0; j < k; j++) {
1871:                            Sa[i] += FPC[j] * P[i][j];
1872:                        }
1873:                    }
1874:
1875:                    // sort categories by their first principal component scores Sa
1876:                    double[] pCopy = new double[n];
1877:                    System.arraycopy(Sa, 0, pCopy, 0, n);
1878:                    String[] sortedValues = new String[n];
1879:                    Arrays.sort(Sa);
1880:
1881:                    for (int j = 0; j < n; j++) {
1882:                        sortedValues[j] = nonEmptyValues[Utils.minIndex(pCopy)];
1883:                        pCopy[Utils.minIndex(pCopy)] = Double.MAX_VALUE;
1884:                    }
1885:
1886:                    // for the attribute values whose class frequency is not 0
1887:                    String tempStr = "";
1888:
1889:                    for (int j = 0; j < nonEmpty - 1; j++) {
1890:                        currDist = new double[2][numClasses];
1891:                        if (tempStr == "")
1892:                            tempStr = "(" + sortedValues[j] + ")";
1893:                        else
1894:                            tempStr += "|" + "(" + sortedValues[j] + ")";
1895:                        for (int i = 0; i < sortedIndices.length; i++) {
1896:                            Instance inst = data.instance(sortedIndices[i]);
1897:                            if (inst.isMissing(att)) {
1898:                                break;
1899:                            }
1900:
1901:                            if (tempStr.indexOf("("
1902:                                    + att.value((int) inst.value(att)) + ")") != -1) {
1903:                                currDist[0][(int) inst.classValue()] += weights[i];
1904:                            } else
1905:                                currDist[1][(int) inst.classValue()] += weights[i];
1906:                        }
1907:
1908:                        double[][] tempDist = new double[2][numClasses];
1909:                        for (int kk = 0; kk < 2; kk++) {
1910:                            tempDist[kk] = currDist[kk];
1911:                        }
1912:
1913:                        double[] tempProps = new double[2];
1914:                        for (int kk = 0; kk < 2; kk++) {
1915:                            tempProps[kk] = Utils.sum(tempDist[kk]);
1916:                        }
1917:
1918:                        if (Utils.sum(tempProps) != 0)
1919:                            Utils.normalize(tempProps);
1920:
1921:                        // split missing values
1922:                        int mstart = missingStart;
1923:                        while (mstart < sortedIndices.length) {
1924:                            Instance insta = data
1925:                                    .instance(sortedIndices[mstart]);
1926:                            for (int jj = 0; jj < 2; jj++) {
1927:                                tempDist[jj][(int) insta.classValue()] += tempProps[jj]
1928:                                        * weights[mstart];
1929:                            }
1930:                            mstart++;
1931:                        }
1932:
1933:                        double currGain;
1934:                        if (useGini)
1935:                            currGain = computeGiniGain(parentDist, tempDist);
1936:                        else
1937:                            currGain = computeInfoGain(parentDist, tempDist);
1938:
1939:                        if (currGain > bestGain) {
1940:                            bestGain = currGain;
1941:                            bestSplitString = tempStr;
1942:                            for (int jj = 0; jj < 2; jj++) {
1943:                                //dist[jj] = new double[currDist[jj].length];
1944:                                System.arraycopy(tempDist[jj], 0, dist[jj], 0,
1945:                                        dist[jj].length);
1946:                            }
1947:                        }
1948:                    }
1949:                }
1950:
1951:                // Compute weights
1952:                int attIndex = att.index();
1953:                props[attIndex] = new double[2];
1954:                for (int k = 0; k < 2; k++) {
1955:                    props[attIndex][k] = Utils.sum(dist[k]);
1956:                }
1957:                if (!(Utils.sum(props[attIndex]) > 0)) {
1958:                    for (int k = 0; k < props[attIndex].length; k++) {
1959:                        props[attIndex][k] = 1.0 / (double) props[attIndex].length;
1960:                    }
1961:                } else {
1962:                    Utils.normalize(props[attIndex]);
1963:                }
1964:
1965:                // Compute subset weights
1966:                subsetWeights[attIndex] = new double[2];
1967:                for (int j = 0; j < 2; j++) {
1968:                    subsetWeights[attIndex][j] += Utils.sum(dist[j]);
1969:                }
1970:
1971:                // Then send the attribute values whose class frequency is 0 down the
1972:                // more frequent branch
1973:                for (int j = 0; j < empty; j++) {
1974:                    if (props[attIndex][0] >= props[attIndex][1]) {
1975:                        if (bestSplitString == "")
1976:                            bestSplitString = "(" + emptyValues[j] + ")";
1977:                        else
1978:                            bestSplitString += "|" + "(" + emptyValues[j] + ")";
1979:                    }
1980:                }
1981:
1982:                // round the gain
1983:                gains[attIndex] = Math.rint(bestGain * 10000000) / 10000000.0;
1984:
1985:                dists[attIndex] = dist;
1986:                return bestSplitString;
1987:            }
1988:
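In the exhaustive multi-class branch above, the 2^(nonEmpty-1) candidate binary partitions are enumerated through the bits of a counter: each set bit sends the corresponding attribute value to the left branch. A minimal sketch of the same enumeration (the values are illustrative):

        String[] values = {"a", "b", "c"};
        for (int i = 0; i < (1 << (values.length - 1)); i++) {
            StringBuilder left = new StringBuilder();
            for (int j = 0; j < values.length; j++) {
                if ((i >> j & 1) == 1) {
                    if (left.length() > 0)
                        left.append('|');
                    left.append('(').append(values[j]).append(')');
                }
            }
            System.out.println(left);   // prints "", "(a)", "(b)", "(a)|(b)"
        }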
1989:            /**
1990:             * Split data into two subsets and store sorted indices and weights for two
1991:             * successor nodes.
1992:             *
1993:             * @param subsetIndices 	sorted indices of instances for each attribute for the two successor nodes
1994:             * @param subsetWeights 	weights of instances for each attribute for the two successor nodes
1995:             * @param att 		attribute the split is based on
1996:             * @param splitPoint 		split point the split is based on, if att is numeric
1997:             * @param splitStr 		split subset the split is based on, if att is nominal
1998:             * @param sortedIndices 	sorted indices of the instances to be split
1999:             * @param weights 		weights of the instances to be split
2000:             * @param data 		training data
2001:             * @throws Exception 		if something goes wrong  
2002:             */
2003:            protected void splitData(int[][][] subsetIndices,
2004:                    double[][][] subsetWeights, Attribute att,
2005:                    double splitPoint, String splitStr, int[][] sortedIndices,
2006:                    double[][] weights, Instances data) throws Exception {
2007:
2008:                int j;
2009:                // For each attribute
2010:                for (int i = 0; i < data.numAttributes(); i++) {
2011:                    if (i == data.classIndex())
2012:                        continue;
2013:                    int[] num = new int[2];
2014:                    for (int k = 0; k < 2; k++) {
2015:                        subsetIndices[k][i] = new int[sortedIndices[i].length];
2016:                        subsetWeights[k][i] = new double[weights[i].length];
2017:                    }
2018:
2019:                    for (j = 0; j < sortedIndices[i].length; j++) {
2020:                        Instance inst = data.instance(sortedIndices[i][j]);
2021:                        if (inst.isMissing(att)) {
2022:                            // Split instance up
2023:                            for (int k = 0; k < 2; k++) {
2024:                                if (m_Props[k] > 0) {
2025:                                    subsetIndices[k][i][num[k]] = sortedIndices[i][j];
2026:                                    subsetWeights[k][i][num[k]] = m_Props[k]
2027:                                            * weights[i][j];
2028:                                    num[k]++;
2029:                                }
2030:                            }
2031:                        } else {
2032:                            int subset;
2033:                            if (att.isNumeric()) {
2034:                                subset = (inst.value(att) < splitPoint) ? 0 : 1;
2035:                            } else { // nominal attribute
2036:                                if (splitStr.indexOf("("
2037:                                        + att.value((int) inst.value(att
2038:                                                .index())) + ")") != -1) {
2039:                                    subset = 0;
2040:                                } else
2041:                                    subset = 1;
2042:                            }
2043:                            subsetIndices[subset][i][num[subset]] = sortedIndices[i][j];
2044:                            subsetWeights[subset][i][num[subset]] = weights[i][j];
2045:                            num[subset]++;
2046:                        }
2047:                    }
2048:
2049:                    // Trim arrays
2050:                    for (int k = 0; k < 2; k++) {
2051:                        int[] copy = new int[num[k]];
2052:                        System.arraycopy(subsetIndices[k][i], 0, copy, 0,
2053:                                num[k]);
2054:                        subsetIndices[k][i] = copy;
2055:                        double[] copyWeights = new double[num[k]];
2056:                        System.arraycopy(subsetWeights[k][i], 0, copyWeights,
2057:                                0, num[k]);
2058:                        subsetWeights[k][i] = copyWeights;
2059:                    }
2060:                }
2061:            }
2062:
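Instances with a missing split value are not discarded: splitData sends them down both branches with their weight scaled by the branch proportions m_Props, as the first branch of the loop above shows. For example, with branch proportions {0.6, 0.4} an instance of weight 2.0 contributes:

        double[] props = {0.6, 0.4};         // proportions of the two branches
        double w = 2.0;                      // weight of the missing-value instance
        double leftWeight  = props[0] * w;   // 1.2 follows the left branch
        double rightWeight = props[1] * w;   // 0.8 follows the right branch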
2063:            /**
2064:             * Compute and return gini gain for given distributions of a node and its 
2065:             * successor nodes.
2066:             * 
2067:             * @param parentDist 	class distributions of parent node
2068:             * @param childDist 	class distributions of successor nodes
2069:             * @return 		Gini gain computed
2070:             */
2071:            protected double computeGiniGain(double[] parentDist,
2072:                    double[][] childDist) {
2073:                double totalWeight = Utils.sum(parentDist);
2074:                if (totalWeight == 0)
2075:                    return 0;
2076:
2077:                double leftWeight = Utils.sum(childDist[0]);
2078:                double rightWeight = Utils.sum(childDist[1]);
2079:
2080:                double parentGini = computeGini(parentDist, totalWeight);
2081:                double leftGini = computeGini(childDist[0], leftWeight);
2082:                double rightGini = computeGini(childDist[1], rightWeight);
2083:
2084:                return parentGini - leftWeight / totalWeight * leftGini
2085:                        - rightWeight / totalWeight * rightGini;
2086:            }
2087:
2088:            /**
2089:             * Compute and return gini index for a given distribution of a node.
2090:             * 
2091:             * @param dist 	class distributions
2092:             * @param total 	total weight of the class distribution
2093:             * @return 		Gini index of the class distributions
2094:             */
2095:            protected double computeGini(double[] dist, double total) {
2096:                if (total == 0)
2097:                    return 0;
2098:                double val = 0;
2099:                for (int i = 0; i < dist.length; i++) {
2100:                    val += (dist[i] / total) * (dist[i] / total);
2101:                }
2102:                return 1 - val;
2103:            }
2104:
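A worked example for the two methods above: a parent with class counts {6, 6} split into children {5, 1} and {1, 5} gives

        parent Gini = 1 - (6/12)^2 - (6/12)^2 = 0.5
        child Gini  = 1 - (5/6)^2 - (1/6)^2  ≈ 0.278   (both children)
        Gini gain   = 0.5 - (6/12)(0.278) - (6/12)(0.278) ≈ 0.222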
2105:            /**
2106:             * Compute and return information gain for given distributions of a node 
2107:             * and its successor nodes.
2108:             * 
2109:             * @param parentDist 	class distributions of parent node
2110:             * @param childDist 	class distributions of successor nodes
2111:             * @return 		information gain computed
2112:             */
2113:            protected double computeInfoGain(double[] parentDist,
2114:                    double[][] childDist) {
2115:                double totalWeight = Utils.sum(parentDist);
2116:                if (totalWeight == 0)
2117:                    return 0;
2118:
2119:                double leftWeight = Utils.sum(childDist[0]);
2120:                double rightWeight = Utils.sum(childDist[1]);
2121:
2122:                double parentInfo = computeEntropy(parentDist, totalWeight);
2123:                double leftInfo = computeEntropy(childDist[0], leftWeight);
2124:                double rightInfo = computeEntropy(childDist[1], rightWeight);
2125:
2126:                return parentInfo - leftWeight / totalWeight * leftInfo
2127:                        - rightWeight / totalWeight * rightInfo;
2128:            }
2129:
2130:            /**
2131:             * Compute and return entropy for a given distribution of a node.
2132:             * 
2133:             * @param dist 	class distributions
2134:             * @param total 	sum of the class distributions (total weight)
2135:             * @return 		entropy of the class distributions
2136:             */
2137:            protected double computeEntropy(double[] dist, double total) {
2138:                if (total == 0)
2139:                    return 0;
2140:                double entropy = 0;
2141:                for (int i = 0; i < dist.length; i++) {
2142:                    if (dist[i] != 0)
2143:                        entropy -= dist[i] / total
2144:                                * Utils.log2(dist[i] / total);
2145:                }
2146:                return entropy;
2147:            }
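
            On the same example split ({40, 60} into {30, 10} and {10, 50}), the
            entropy-based criterion gives a parent entropy of about 0.971, child
            entropies of about 0.811 and 0.650, and an information gain of about 0.256.
            A matching stand-alone sketch, with Utils.log2 replaced by Math.log so the
            example needs no Weka classes:

                // Mirrors computeEntropy/computeInfoGain on fixed example counts.
                public class EntropyDemo {
                    static double entropy(double[] dist, double total) {
                        if (total == 0) return 0;
                        double e = 0;
                        for (double d : dist)
                            if (d != 0) e -= d / total * (Math.log(d / total) / Math.log(2));
                        return e;
                    }
                    public static void main(String[] args) {
                        double[] parent = {40, 60};
                        double[][] child = {{30, 10}, {10, 50}};
                        double total = 100, left = 40, right = 60;
                        double gain = entropy(parent, total)
                                - left / total * entropy(child[0], left)
                                - right / total * entropy(child[1], right);
                        System.out.println(gain);  // prints ~0.2564
                    }
                }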
2148:
2149:            /**
2150:             * Make the node a leaf node.
2151:             * 
2152:             * @param data 	training data
2153:             */
2154:            protected void makeLeaf(Instances data) {
2155:                m_Attribute = null;
2156:                m_isLeaf = true;
2157:                m_ClassValue = Utils.maxIndex(m_ClassProbs);
2158:                m_ClassAttribute = data.classAttribute();
2159:            }
2160:
2161:            /**
2162:             * Computes class probabilities for an instance using the decision tree.
2163:             *
2164:             * @param instance 	the instance for which class probabilities are to be computed
2165:             * @return 		the class probabilities for the given instance
2166:             * @throws Exception 	if something goes wrong
2167:             */
2168:            public double[] distributionForInstance(Instance instance)
2169:                    throws Exception {
2170:                if (!m_isLeaf) {
2171:                    // value of split attribute is missing
2172:                    if (instance.isMissing(m_Attribute)) {
2173:                        double[] returnedDist = new double[m_ClassProbs.length];
2174:
2175:                        for (int i = 0; i < m_Successors.length; i++) {
2176:                            double[] help = m_Successors[i]
2177:                                    .distributionForInstance(instance);
2178:                            if (help != null) {
2179:                                for (int j = 0; j < help.length; j++) {
2180:                                    returnedDist[j] += m_Props[i] * help[j];
2181:                                }
2182:                            }
2183:                        }
2184:                        return returnedDist;
2185:                    }
2186:
2187:                    // split attribute is nominal
2188:                    else if (m_Attribute.isNominal()) {
2189:                        if (m_SplitString.indexOf("("
2190:                                + m_Attribute.value((int) instance
2191:                                        .value(m_Attribute)) + ")") != -1)
2192:                            return m_Successors[0]
2193:                                    .distributionForInstance(instance);
2194:                        else
2195:                            return m_Successors[1]
2196:                                    .distributionForInstance(instance);
2197:                    }
2198:
2199:                    // split attribute is numeric
2200:                    else {
2201:                        if (instance.value(m_Attribute) < m_SplitValue)
2202:                            return m_Successors[0]
2203:                                    .distributionForInstance(instance);
2204:                        else
2205:                            return m_Successors[1]
2206:                                    .distributionForInstance(instance);
2207:                    }
2208:                }
2209:
2210:                // leaf node
2211:                else
2212:                    return m_ClassProbs;
2213:            }
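
            In the missing-value branch above, the instance is effectively sent down both
            subtrees and the successors' distributions are mixed, weighted by m_Props
            (the fraction of training weight that followed each branch). A minimal
            stand-alone sketch of that mixture, with illustrative numbers:

                // Not part of BFTree: toy version of the m_Props-weighted mixture.
                public class MissingValueDemo {
                    public static void main(String[] args) {
                        double[] props = {0.7, 0.3};                      // training-weight fraction per branch
                        double[][] successorDist = {{0.9, 0.1}, {0.2, 0.8}};
                        double[] mixed = new double[2];
                        for (int i = 0; i < props.length; i++)
                            for (int j = 0; j < mixed.length; j++)
                                mixed[j] += props[i] * successorDist[i][j];
                        System.out.println(mixed[0] + " " + mixed[1]);    // prints ~0.69 ~0.31
                    }
                }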
2214:
2215:            /**
2216:             * Prints the decision tree using the protected toString method defined below.
2217:             * 
2218:             * @return 		a textual description of the classifier
2219:             */
2220:            public String toString() {
2221:                if ((m_Distribution == null) && (m_Successors == null)) {
2222:                    return "Best-First: No model built yet.";
2223:                }
2224:                return "Best-First Decision Tree\n" + toString(0) + "\n\n"
2225:                        + "Size of the Tree: " + numNodes() + "\n\n"
2226:                        + "Number of Leaf Nodes: " + numLeaves();
2227:            }
2228:
2229:            /**
2230:             * Outputs the subtree, indented according to the given level.
2231:             * 
2232:             * @param level 	the level (indentation depth) at which the subtree is printed
2233:             * @return 		a textual description of the subtree
2234:             */
2235:            protected String toString(int level) {
2236:                StringBuffer text = new StringBuffer();
2237:                // if leaf nodes
2238:                if (m_Attribute == null) {
2239:                    if (Instance.isMissingValue(m_ClassValue)) {
2240:                        text.append(": null");
2241:                    } else {
2242:                        double correctNum = Math.rint(m_Distribution[Utils
2243:                                .maxIndex(m_Distribution)] * 100) / 100.0;
2244:                        double wrongNum = Math
2245:                                .rint((Utils.sum(m_Distribution) - m_Distribution[Utils
2246:                                        .maxIndex(m_Distribution)]) * 100) / 100.0;
2247:                        String str = "(" + correctNum + "/" + wrongNum + ")";
2248:                        text.append(": "
2249:                                + m_ClassAttribute.value((int) m_ClassValue)
2250:                                + str);
2251:                    }
2252:                } else {
2253:                    for (int j = 0; j < 2; j++) {
2254:                        text.append("\n");
2255:                        for (int i = 0; i < level; i++) {
2256:                            text.append("|  ");
2257:                        }
2258:                        if (j == 0) {
2259:                            if (m_Attribute.isNumeric())
2260:                                text.append(m_Attribute.name() + " < "
2261:                                        + m_SplitValue);
2262:                            else
2263:                                text.append(m_Attribute.name() + "="
2264:                                        + m_SplitString);
2265:                        } else {
2266:                            if (m_Attribute.isNumeric())
2267:                                text.append(m_Attribute.name() + " >= "
2268:                                        + m_SplitValue);
2269:                            else
2270:                                text.append(m_Attribute.name() + "!="
2271:                                        + m_SplitString);
2272:                        }
2273:                        text.append(m_Successors[j].toString(level + 1));
2274:                    }
2275:                }
2276:                return text.toString();
2277:            }
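
            For orientation, the output assembled by the two toString methods might look
            like the following for a small tree on numeric attributes (the attribute
            names and counts are illustrative):

                Best-First Decision Tree

                petallength < 2.45: Iris-setosa(50.0/0.0)
                petallength >= 2.45
                |  petalwidth < 1.75: Iris-versicolor(49.0/5.0)
                |  petalwidth >= 1.75: Iris-virginica(45.0/1.0)

                Size of the Tree: 5
                Number of Leaf Nodes: 3

            Each leaf shows the majority class followed by (correctly/incorrectly
            classified weight), rounded to two decimals from m_Distribution above.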
2278:
2279:            /**
2280:             * Compute size of the tree.
2281:             * 
2282:             * @return 		size of the tree
2283:             */
2284:            public int numNodes() {
2285:                if (m_isLeaf) {
2286:                    return 1;
2287:                } else {
2288:                    int size = 1;
2289:                    for (int i = 0; i < m_Successors.length; i++) {
2290:                        size += m_Successors[i].numNodes();
2291:                    }
2292:                    return size;
2293:                }
2294:            }
2295:
2296:            /**
2297:             * Compute number of leaf nodes.
2298:             * 
2299:             * @return 		number of leaf nodes
2300:             */
2301:            public int numLeaves() {
2302:                if (m_isLeaf)
2303:                    return 1;
2304:                else {
2305:                    int size = 0;
2306:                    for (int i = 0; i < m_Successors.length; i++) {
2307:                        size += m_Successors[i].numLeaves();
2308:                    }
2309:                    return size;
2310:                }
2311:            }
2312:
2313:            /**
2314:             * Returns an enumeration describing the available options.
2315:             * 
2316:             * @return 		an enumeration describing the available options.
2317:             */
2318:            public Enumeration listOptions() {
2319:                Vector result;
2320:                Enumeration en;
2321:
2322:                result = new Vector();
2323:
2324:                en = super.listOptions();
2325:                while (en.hasMoreElements())
2326:                    result.addElement(en.nextElement());
2327:
2328:                result.addElement(new Option("\tThe pruning strategy.\n"
2329:                        + "\t(default: "
2330:                        + new SelectedTag(PRUNING_POSTPRUNING, TAGS_PRUNING)
2331:                        + ")", "P", 1, "-P " + Tag.toOptionList(TAGS_PRUNING)));
2332:
2333:                result.addElement(new Option(
2334:                        "\tThe minimal number of instances at the terminal nodes.\n"
2335:                                + "\t(default 2)", "M", 1, "-M <min no>"));
2336:
2337:                result.addElement(new Option(
2338:                        "\tThe number of folds used in the pruning.\n"
2339:                                + "\t(default 5)", "N", 1, "-N <num folds>"));
2340:
2341:                result.addElement(new Option(
2342:                        "\tDon't use heuristic search for nominal attributes in multi-class\n"
2343:                                + "\tproblems (default yes).\n", "H", 0, "-H"));
2344:
2345:                result
2346:                        .addElement(new Option(
2347:                                "\tDon't use Gini index for splitting (default yes);\n"
2348:                                        + "\tif not, information gain is used.", "G",
2349:                                0, "-G"));
2350:
2351:                result.addElement(new Option(
2352:                        "\tDon't use error rate in internal cross-validation (default yes);\n"
2353:                                + "\tuse root mean squared error instead.", "R", 0,
2354:                        "-R"));
2355:
2356:                result.addElement(new Option(
2357:                        "\tUse the 1 SE rule to make the pruning decision.\n"
2358:                                + "\t(default no).", "A", 0, "-A"));
2359:
2360:                result.addElement(new Option(
2361:                        "\tPercentage of training data size (0-1]\n"
2362:                                + "\t(default 1).", "C", 1, "-C <size per>"));
2363:
2364:                return result.elements();
2365:            }
2366:
2367:            /**
2368:             * Parses the options for this object. <p/>
2369:             *
2370:             <!-- options-start -->
2371:             * Valid options are: <p/>
2372:             * 
2373:             * <pre> -S &lt;num&gt;
2374:             *  Random number seed.
2375:             *  (default 1)</pre>
2376:             * 
2377:             * <pre> -D
2378:             *  If set, classifier is run in debug mode and
2379:             *  may output additional info to the console</pre>
2380:             * 
2381:             * <pre> -P &lt;UNPRUNED|POSTPRUNED|PREPRUNED&gt;
2382:             *  The pruning strategy.
2383:             *  (default: POSTPRUNED)</pre>
2384:             * 
2385:             * <pre> -M &lt;min no&gt;
2386:             *  The minimal number of instances at the terminal nodes.
2387:             *  (default 2)</pre>
2388:             * 
2389:             * <pre> -N &lt;num folds&gt;
2390:             *  The number of folds used in the pruning.
2391:             *  (default 5)</pre>
2392:             * 
2393:             * <pre> -H
2394:             *  Don't use heuristic search for nominal attributes in multi-class
2395:             *  problems (default yes).
2396:             * </pre>
2397:             * 
2398:             * <pre> -G
2399:             *  Don't use Gini index for splitting (default yes);
2400:             *  if not, information gain is used.</pre>
2401:             * 
2402:             * <pre> -R
2403:             *  Don't use error rate in internal cross-validation (default yes);
2404:             *  use root mean squared error instead.</pre>
2405:             * 
2406:             * <pre> -A
2407:             *  Use the 1 SE rule to make the pruning decision.
2408:             *  (default no).</pre>
2409:             * 
2410:             * <pre> -C &lt;size per&gt;
2411:             *  Percentage of training data size (0-1]
2412:             *  (default 1).</pre>
2413:             * 
2414:             <!-- options-end -->
2415:             *
2416:             * @param options	the options to use
2417:             * @throws Exception	if setting of options fails
2418:             */
2419:            public void setOptions(String[] options) throws Exception {
2420:                String tmpStr;
2421:
2422:                super.setOptions(options);
2423:
2424:                tmpStr = Utils.getOption('M', options);
2425:                if (tmpStr.length() != 0)
2426:                    setMinNumObj(Integer.parseInt(tmpStr));
2427:                else
2428:                    setMinNumObj(2);
2429:
2430:                tmpStr = Utils.getOption('N', options);
2431:                if (tmpStr.length() != 0)
2432:                    setNumFoldsPruning(Integer.parseInt(tmpStr));
2433:                else
2434:                    setNumFoldsPruning(5);
2435:
2436:                tmpStr = Utils.getOption('C', options);
2437:                if (tmpStr.length() != 0)
2438:                    setSizePer(Double.parseDouble(tmpStr));
2439:                else
2440:                    setSizePer(1);
2441:
2442:                tmpStr = Utils.getOption('P', options);
2443:                if (tmpStr.length() != 0)
2444:                    setPruningStrategy(new SelectedTag(tmpStr, TAGS_PRUNING));
2445:                else
2446:                    setPruningStrategy(new SelectedTag(PRUNING_POSTPRUNING,
2447:                            TAGS_PRUNING));
2448:
2449:                setHeuristic(!Utils.getFlag('H', options));
2450:
2451:                setUseGini(!Utils.getFlag('G', options));
2452:
2453:                setUseErrorRate(!Utils.getFlag('R', options));
2454:
2455:                setUseOneSE(Utils.getFlag('A', options));
2456:            }
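
            As a hedged usage sketch of the flags parsed above (the dataset file name is
            illustrative; -t is Weka's standard training-file option handled by the
            superclass):

                java weka.classifiers.trees.BFTree -t train.arff -P POSTPRUNED -M 2 -N 5 -C 1.0

            Flags -H, -G and -R invert their "yes" defaults when present, and -A switches
            on the 1 SE rule.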
2457:
2458:            /**
2459:             * Gets the current settings of the Classifier.
2460:             * 
2461:             * @return 		the current settings of the Classifier
2462:             */
2463:            public String[] getOptions() {
2464:                int i;
2465:                Vector result;
2466:                String[] options;
2467:
2468:                result = new Vector();
2469:
2470:                options = super.getOptions();
2471:                for (i = 0; i < options.length; i++)
2472:                    result.add(options[i]);
2473:
2474:                result.add("-M");
2475:                result.add("" + getMinNumObj());
2476:
2477:                result.add("-N");
2478:                result.add("" + getNumFoldsPruning());
2479:
2480:                if (!getHeuristic())
2481:                    result.add("-H");
2482:
2483:                if (!getUseGini())
2484:                    result.add("-G");
2485:
2486:                if (!getUseErrorRate())
2487:                    result.add("-R");
2488:
2489:                if (getUseOneSE())
2490:                    result.add("-A");
2491:
2492:                result.add("-C");
2493:                result.add("" + getSizePer());
2494:
2495:                result.add("-P");
2496:                result.add("" + getPruningStrategy());
2497:
2498:                return (String[]) result.toArray(new String[result.size()]);
2499:            }
2500:
2501:            /**
2502:             * Return an enumeration of the measure names.
2503:             * 
2504:             * @return 		an enumeration of the measure names
2505:             */
2506:            public Enumeration enumerateMeasures() {
2507:                Vector result = new Vector();
2508:
2509:                result.addElement("measureTreeSize");
2510:
2511:                return result.elements();
2512:            }
2513:
2514:            /**
2515:             * Return the size of the tree (number of nodes).
2516:             * 
2517:             * @return 		the size of the tree
2518:             */
2519:            public double measureTreeSize() {
2520:                return numNodes();
2521:            }
2522:
2523:            /**
2524:             * Returns the value of the named measure
2525:             *
2526:             * @param additionalMeasureName 	the name of the measure to query for its value
2527:             * @return 				the value of the named measure
2528:             * @throws IllegalArgumentException 	if the named measure is not supported
2529:             */
2530:            public double getMeasure(String additionalMeasureName) {
2531:                if (additionalMeasureName
2532:                        .compareToIgnoreCase("measureTreeSize") == 0) {
2533:                    return measureTreeSize();
2534:                } else {
2535:                    throw new IllegalArgumentException(additionalMeasureName
2536:                            + " not supported (Best-First)");
2537:                }
2538:            }
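
            A brief, hedged sketch of querying this measure interface (variable names are
            illustrative; the classifier must be built first, or the traversal has no
            tree to count):

                BFTree tree = new BFTree();
                // ... tree.buildClassifier(data) must be called before querying ...
                double size = tree.getMeasure("measureTreeSize");  // same value as tree.numNodes()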
2539:
2540:            /**
2541:             * Returns the tip text for this property
2542:             * 
2543:             * @return 		tip text for this property suitable for
2544:             * 			displaying in the explorer/experimenter gui
2545:             */
2546:            public String pruningStrategyTipText() {
2547:                return "Sets the pruning strategy.";
2548:            }
2549:
2550:            /**
2551:             * Sets the pruning strategy. 
2552:             *
2553:             * @param value 	the strategy
2554:             */
2555:            public void setPruningStrategy(SelectedTag value) {
2556:                if (value.getTags() == TAGS_PRUNING) {
2557:                    m_PruningStrategy = value.getSelectedTag().getID();
2558:                }
2559:            }
2560:
2561:            /**
2562:             * Gets the pruning strategy. 
2563:             *
2564:             * @return 		the current strategy.
2565:             */
2566:            public SelectedTag getPruningStrategy() {
2567:                return new SelectedTag(m_PruningStrategy, TAGS_PRUNING);
2568:            }
2569:
2570:            /**
2571:             * Returns the tip text for this property
2572:             * 
2573:             * @return 		tip text for this property suitable for
2574:             * 			displaying in the explorer/experimenter gui
2575:             */
2576:            public String minNumObjTipText() {
2577:                return "Set minimal number of instances at the terminal nodes.";
2578:            }
2579:
2580:            /**
2581:             * Set minimal number of instances at the terminal nodes.
2582:             * 
2583:             * @param value 	minimal number of instances at the terminal nodes
2584:             */
2585:            public void setMinNumObj(int value) {
2586:                m_minNumObj = value;
2587:            }
2588:
2589:            /**
2590:             * Get minimal number of instances at the terminal nodes.
2591:             * 
2592:             * @return 		minimal number of instances at the terminal nodes
2593:             */
2594:            public int getMinNumObj() {
2595:                return m_minNumObj;
2596:            }
2597:
2598:            /**
2599:             * Returns the tip text for this property
2600:             * 
2601:             * @return 		tip text for this property suitable for
2602:             * 			displaying in the explorer/experimenter gui
2603:             */
2604:            public String numFoldsPruningTipText() {
2605:                return "Number of folds in internal cross-validation.";
2606:            }
2607:
2608:            /**
2609:             * Set number of folds in internal cross-validation.
2610:             * 
2611:             * @param value 	the number of folds
2612:             */
2613:            public void setNumFoldsPruning(int value) {
2614:                m_numFoldsPruning = value;
2615:            }
2616:
2617:            /**
2618:             * Get number of folds in internal cross-validation.
2619:             * 
2620:             * @return 		number of folds in internal cross-validation
2621:             */
2622:            public int getNumFoldsPruning() {
2623:                return m_numFoldsPruning;
2624:            }
2625:
2626:            /**
2627:             * Returns the tip text for this property
2628:             * 
2629:             * @return 		tip text for this property suitable for
2630:             * 			displaying in the explorer/experimenter gui.
2631:             */
2632:            public String heuristicTipText() {
2633:                return "Whether heuristic search is used for binary splits on nominal attributes.";
2634:            }
2635:
2636:            /**
2637:             * Set whether heuristic search is used for nominal attributes in multi-class problems.
2638:             * 
2639:             * @param value 	whether to use heuristic search for nominal attributes in 
2640:             * 			multi-class problems
2641:             */
2642:            public void setHeuristic(boolean value) {
2643:                m_Heuristic = value;
2644:            }
2645:
2646:            /**
2647:             * Get whether heuristic search is used for nominal attributes in multi-class problems.
2648:             * 
2649:             * @return 		whether heuristic search is used for nominal attributes in 
2650:             * 			multi-class problems
2651:             */
2652:            public boolean getHeuristic() {
2653:                return m_Heuristic;
2654:            }
2655:
2656:            /**
2657:             * Returns the tip text for this property
2658:             * 
2659:             * @return 		tip text for this property suitable for
2660:             * 			displaying in the explorer/experimenter gui.
2661:             */
2662:            public String useGiniTipText() {
2663:                return "If true, the Gini index is used as the splitting criterion; otherwise, information gain is used.";
2664:            }
2665:
2666:            /**
2667:             * Set whether the Gini index is used as the splitting criterion.
2668:             * 
2669:             * @param value 	whether to use the Gini index as the splitting criterion
2670:             */
2671:            public void setUseGini(boolean value) {
2672:                m_UseGini = value;
2673:            }
2674:
2675:            /**
2676:             * Get whether the Gini index is used as the splitting criterion.
2677:             * 
2678:             * @return 		whether the Gini index is used as the splitting criterion
2679:             */
2680:            public boolean getUseGini() {
2681:                return m_UseGini;
2682:            }
2683:
2684:            /**
2685:             * Returns the tip text for this property
2686:             * 
2687:             * @return 		tip text for this property suitable for
2688:             * 			displaying in the explorer/experimenter gui.
2689:             */
2690:            public String useErrorRateTipText() {
2691:                return "Whether the error rate is used as the error estimate; if not, the root mean squared error is used.";
2692:            }
2693:
2694:            /**
2695:             * Set whether the error rate is used in internal cross-validation.
2696:             * 
2697:             * @param value 	whether to use the error rate in internal cross-validation
2698:             */
2699:            public void setUseErrorRate(boolean value) {
2700:                m_UseErrorRate = value;
2701:            }
2702:
2703:            /**
2704:             * Get whether the error rate is used in internal cross-validation.
2705:             * 
2706:             * @return 		whether the error rate is used in internal cross-validation
2707:             */
2708:            public boolean getUseErrorRate() {
2709:                return m_UseErrorRate;
2710:            }
2711:
2712:            /**
2713:             * Returns the tip text for this property
2714:             * 
2715:             * @return 		tip text for this property suitable for
2716:             * 			displaying in the explorer/experimenter gui.
2717:             */
2718:            public String useOneSETipText() {
2719:                return "Use the 1SE rule to make the pruning decision.";
2720:            }
2721:
2722:            /**
2723:             * Set whether the 1SE rule is used to choose the final model.
2724:             * 
2725:             * @param value 	whether to use the 1SE rule to choose the final model
2726:             */
2727:            public void setUseOneSE(boolean value) {
2728:                m_UseOneSE = value;
2729:            }
2730:
2731:            /**
2732:             * Get whether the 1SE rule is used to choose the final model.
2733:             * 
2734:             * @return 		whether the 1SE rule is used to choose the final model
2735:             */
2736:            public boolean getUseOneSE() {
2737:                return m_UseOneSE;
2738:            }
2739:
2740:            /**
2741:             * Returns the tip text for this property
2742:             * 
2743:             * @return 		tip text for this property suitable for
2744:             * 			displaying in the explorer/experimenter gui.
2745:             */
2746:            public String sizePerTipText() {
2747:                return "The percentage of the training set size (0-1, 0 not included).";
2748:            }
2749:
2750:            /**
2751:             * Set the percentage of the training set size.
2752:             * 
2753:             * @param value 	the percentage of the training set size (0-1]
2754:             */
2755:            public void setSizePer(double value) {
2756:                if ((value <= 0) || (value > 1))
2757:                    System.err
2758:                            .println("The percentage of the training set size must be in range 0 to 1 "
2759:                                    + "(0 not included) - ignored!");
2760:                else
2761:                    m_SizePer = value;
2762:            }
2763:
2764:            /**
2765:             * Get the percentage of the training set size.
2766:             * 
2767:             * @return 		the percentage of the training set size
2768:             */
2769:            public double getSizePer() {
2770:                return m_SizePer;
2771:            }
2772:
2773:            /**
2774:             * Main method.
2775:             *
2776:             * @param args the options for the classifier
2777:             */
2778:            public static void main(String[] args) {
2779:                runClassifier(new BFTree(), args);
2780:            }
2781:        }
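
        Finally, a hedged end-to-end sketch of using the class programmatically (the
        file name and parameter values are illustrative; the setters and tag constants
        are the ones defined above):

            import java.io.BufferedReader;
            import java.io.FileReader;
            import weka.classifiers.trees.BFTree;
            import weka.core.Instances;
            import weka.core.SelectedTag;

            public class BFTreeDemo {
                public static void main(String[] args) throws Exception {
                    // Load an ARFF dataset; assume the class is the last attribute.
                    Instances data = new Instances(new BufferedReader(new FileReader("train.arff")));
                    data.setClassIndex(data.numAttributes() - 1);

                    BFTree tree = new BFTree();
                    tree.setPruningStrategy(new SelectedTag(BFTree.PRUNING_POSTPRUNING, BFTree.TAGS_PRUNING));
                    tree.setMinNumObj(2);
                    tree.setNumFoldsPruning(5);
                    tree.buildClassifier(data);

                    System.out.println(tree);  // the textual tree from toString()
                    double[] dist = tree.distributionForInstance(data.instance(0));
                    System.out.println("P(first class) = " + dist[0]);
                }
            }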