Source Code Cross Referenced for MIEMDD.java » Science » weka » weka.classifiers.mi

/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * MIEMDD.java
 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.mi;

import weka.classifiers.RandomizableClassifier;
import weka.core.Capabilities;
import weka.core.FastVector;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.SelectedTag;
import weka.core.Tag;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.Normalize;
import weka.filters.unsupervised.attribute.ReplaceMissingValues;
import weka.filters.unsupervised.attribute.Standardize;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * The EMDD model builds heavily upon Maron and Lozano-Perez's Diverse Density (DD) algorithm.<br/>
 * It is a general framework for MI learning that converts the MI problem to a single-instance setting using EM. In this implementation, we use the most-likely-cause DD model and use only 3 randomly selected positive bags as the initial starting points for EM.<br/>
 * <br/>
 * For more information see:<br/>
 * <br/>
 * Qi Zhang, Sally A. Goldman: EM-DD: An Improved Multiple-Instance Learning Technique. In: Advances in Neural Information Processing Systems 14, 1073-1080, 2001.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;inproceedings{Zhang2001,
 *    author = {Qi Zhang and Sally A. Goldman},
 *    booktitle = {Advances in Neural Information Processing Systems 14},
 *    pages = {1073-1080},
 *    publisher = {MIT Press},
 *    title = {EM-DD: An Improved Multiple-Instance Learning Technique},
 *    year = {2001}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -N &lt;num&gt;
 *  Whether to 0=normalize/1=standardize/2=neither.
 *  (default 1=standardize)</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Lin Dong (ld21@cs.waikato.ac.nz)
 * @version $Revision: 1.5 $
 */
public class MIEMDD extends RandomizableClassifier implements
        OptionHandler, MultiInstanceCapabilitiesHandler,
        TechnicalInformationHandler {

    /** for serialization */
    static final long serialVersionUID = 3899547154866223734L;

    /** The index of the class attribute */
    protected int m_ClassIndex;

    /** The fitted parameters: for each attribute, a target point (even index) and a scale (odd index) */
    protected double[] m_Par;

    /** The number of class labels */
    protected int m_NumClasses;

    /** Class labels for each bag */
    protected int[] m_Classes;

    /** MI data: [bag][attribute][instance] */
    protected double[][][] m_Data;

    /** All attribute names */
    protected Instances m_Attributes;

    /** The selected instance of each bag (E-step result): [bag][attribute] */
    protected double[][] m_emData;

    /** The filter used to standardize/normalize all values. */
    protected Filter m_Filter = null;

    /** Whether to normalize/standardize/neither, default: standardize */
    protected int m_filterType = FILTER_STANDARDIZE;

    /** Normalize training data */
    public static final int FILTER_NORMALIZE = 0;
    /** Standardize training data */
    public static final int FILTER_STANDARDIZE = 1;
    /** No normalization/standardization */
    public static final int FILTER_NONE = 2;
    /** The filter to apply to the training data */
    public static final Tag[] TAGS_FILTER = {
            new Tag(FILTER_NORMALIZE, "Normalize training data"),
            new Tag(FILTER_STANDARDIZE, "Standardize training data"),
            new Tag(FILTER_NONE, "No normalization/standardization"), };

    /** The filter used to get rid of missing values. */
    protected ReplaceMissingValues m_Missing = new ReplaceMissingValues();

    /**
     * Returns a string describing this classifier
     *
     * @return a description of the classifier suitable for
     * displaying in the explorer/experimenter gui
     */
    public String globalInfo() {
        return "The EMDD model builds heavily upon Maron and Lozano-Perez's "
                + "Diverse Density (DD) algorithm.\nIt is a general framework "
                + "for MI learning that converts the MI problem to a "
                + "single-instance setting using EM. In this implementation, we "
                + "use the most-likely-cause DD model and use only 3 randomly "
                + "selected positive bags as the initial starting points for EM.\n\n"
                + "For more information see:\n\n"
                + getTechnicalInformation().toString();
    }

    /**
     * Returns an instance of a TechnicalInformation object, containing
     * detailed information about the technical background of this class,
     * e.g., paper reference or book this class is based on.
     *
     * @return the technical information about this class
     */
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation result;

        result = new TechnicalInformation(Type.INPROCEEDINGS);
        result.setValue(Field.AUTHOR, "Qi Zhang and Sally A. Goldman");
        result.setValue(Field.TITLE,
                "EM-DD: An Improved Multiple-Instance Learning Technique");
        result.setValue(Field.BOOKTITLE,
                "Advances in Neural Information Processing Systems 14");
        result.setValue(Field.YEAR, "2001");
        result.setValue(Field.PAGES, "1073-1080");
        result.setValue(Field.PUBLISHER, "MIT Press");

        return result;
    }

    /**
     * Returns an enumeration describing the available options
     *
     * @return an enumeration of all the available options
     */
    public Enumeration listOptions() {
        Vector result = new Vector();

        result.addElement(new Option(
                "\tWhether to 0=normalize/1=standardize/2=neither.\n"
                        + "\t(default 1=standardize)", "N", 1,
                "-N <num>"));

        Enumeration enm = super.listOptions();
        while (enm.hasMoreElements())
            result.addElement(enm.nextElement());

        return result.elements();
    }

    /**
     * Parses a given list of options. <p/>
     *
     <!-- options-start -->
     * Valid options are: <p/>
     *
     * <pre> -N &lt;num&gt;
     *  Whether to 0=normalize/1=standardize/2=neither.
     *  (default 1=standardize)</pre>
     *
     * <pre> -S &lt;num&gt;
     *  Random number seed.
     *  (default 1)</pre>
     *
     * <pre> -D
     *  If set, classifier is run in debug mode and
     *  may output additional info to the console</pre>
     *
     <!-- options-end -->
     *
     * @param options the list of options as an array of strings
     * @throws Exception if an option is not supported
     */
    public void setOptions(String[] options) throws Exception {
        String tmpStr;

        tmpStr = Utils.getOption('N', options);
        if (tmpStr.length() != 0) {
            setFilterType(new SelectedTag(Integer.parseInt(tmpStr),
                    TAGS_FILTER));
        } else {
            setFilterType(new SelectedTag(FILTER_STANDARDIZE,
                    TAGS_FILTER));
        }

        super.setOptions(options);
    }

    /**
     * Gets the current settings of the classifier.
     *
     * @return an array of strings suitable for passing to setOptions
     */
    public String[] getOptions() {
        Vector result;
        String[] options;
        int i;

        result = new Vector();
        options = super.getOptions();
        for (i = 0; i < options.length; i++)
            result.add(options[i]);

        result.add("-N");
        result.add("" + m_filterType);

        return (String[]) result.toArray(new String[result.size()]);
    }

    /**
     * Returns the tip text for this property
     *
     * @return tip text for this property suitable for
     * displaying in the explorer/experimenter gui
     */
    public String filterTypeTipText() {
        return "The filter type for transforming the training data.";
    }

    /**
     * Gets how the training data will be transformed. Will be one of
     * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
     *
     * @return the filtering mode
     */
    public SelectedTag getFilterType() {
        return new SelectedTag(m_filterType, TAGS_FILTER);
    }

    /**
     * Sets how the training data will be transformed. Should be one of
     * FILTER_NORMALIZE, FILTER_STANDARDIZE, FILTER_NONE.
     *
     * @param newType the new filtering mode
     */
    public void setFilterType(SelectedTag newType) {

        if (newType.getTags() == TAGS_FILTER) {
            m_filterType = newType.getSelectedTag().getID();
        }
    }

    private class OptEng extends Optimization {
        /**
         * Evaluate objective function
         * @param x the current values of variables
         * @return the value of the objective function
         */
        protected double objectiveFunction(double[] x) {
            double nll = 0; // -LogLikelihood
            for (int i = 0; i < m_Classes.length; i++) { // ith bag
                double ins = 0.0;
                for (int k = 0; k < m_emData[i].length; k++)
                    // attribute index
                    ins += (m_emData[i][k] - x[k * 2])
                            * (m_emData[i][k] - x[k * 2])
                            * x[k * 2 + 1] * x[k * 2 + 1];
                ins = Math.exp(-ins); // Pr. of being positive

                if (m_Classes[i] == 1) {
                    if (ins <= m_Zero)
                        ins = m_Zero;
                    nll -= Math.log(ins); // bag-level -LogLikelihood
                } else {
                    ins = 1.0 - ins; // Pr. of being negative
                    if (ins <= m_Zero)
                        ins = m_Zero;
                    nll -= Math.log(ins);
                }
            }
            return nll;
        }
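
        /*
         * In DD notation: with target point p_k = x[2k] and scale s_k = x[2k+1],
         * the selected instance b_i of bag i is scored as
         *
         *     Pr(positive | b_i) = exp( - sum_k s_k^2 * (b_i[k] - p_k)^2 )
         *
         * so the value minimized above is the negative log-likelihood
         *
         *     NLL(x) = - sum over positive bags of log Pr(positive | b_i)
         *              - sum over negative bags of log (1 - Pr(positive | b_i)),
         *
         * clamped by m_Zero to avoid log(0).
         */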

        /**
         * Evaluate the gradient vector
         * @param x the current values of variables
         * @return the gradient vector
         */
        protected double[] evaluateGradient(double[] x) {
            double[] grad = new double[x.length];
            for (int i = 0; i < m_Classes.length; i++) { // ith bag
                double[] numrt = new double[x.length];
                double exp = 0.0;
                for (int k = 0; k < m_emData[i].length; k++)
                    // attribute index
                    exp += (m_emData[i][k] - x[k * 2])
                            * (m_emData[i][k] - x[k * 2])
                            * x[k * 2 + 1] * x[k * 2 + 1];
                exp = Math.exp(-exp); // Pr. of being positive

                // Instance-wise update
                for (int p = 0; p < m_emData[i].length; p++) { // pth variable
                    numrt[2 * p] = 2.0 * (x[2 * p] - m_emData[i][p])
                            * x[p * 2 + 1] * x[p * 2 + 1];
                    numrt[2 * p + 1] = 2.0
                            * (x[2 * p] - m_emData[i][p])
                            * (x[2 * p] - m_emData[i][p])
                            * x[p * 2 + 1];
                }

                // Bag-wise update
                for (int q = 0; q < m_emData[i].length; q++) {
                    if (m_Classes[i] == 1) { // derivative of (-LogLikelihood) for positive bags
                        grad[2 * q] += numrt[2 * q];
                        grad[2 * q + 1] += numrt[2 * q + 1];
                    } else { // derivative of (-LogLikelihood) for negative bags
                        grad[2 * q] -= numrt[2 * q] * exp / (1.0 - exp);
                        grad[2 * q + 1] -= numrt[2 * q + 1] * exp
                                / (1.0 - exp);
                    }
                }
            } // one bag

            return grad;
        }
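
        /*
         * Sanity check for the gradient: writing
         * E_i = sum_k s_k^2 * (b_i[k] - p_k)^2 so that Pr_i = exp(-E_i),
         * the partial derivatives are
         *
         *     dE_i/dp_k = 2 * s_k^2 * (p_k - b_i[k])        (numrt[2k])
         *     dE_i/ds_k = 2 * s_k * (b_i[k] - p_k)^2        (numrt[2k+1])
         *
         * A positive bag contributes E_i to the NLL, so its gradient is numrt
         * itself; a negative bag contributes -log(1 - Pr_i), whose gradient is
         * -numrt * Pr_i / (1 - Pr_i), matching the two branches above.
         */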
    }

    /**
     * Returns default capabilities of the classifier.
     *
     * @return      the capabilities of this classifier
     */
    public Capabilities getCapabilities() {
        Capabilities result = super.getCapabilities();

        // attributes
        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.RELATIONAL_ATTRIBUTES);
        result.enable(Capability.MISSING_VALUES);

        // class
        result.enable(Capability.BINARY_CLASS);
        result.enable(Capability.MISSING_CLASS_VALUES);

        // other
        result.enable(Capability.ONLY_MULTIINSTANCE);

        return result;
    }

    /**
     * Returns the capabilities of this multi-instance classifier for the
     * relational data.
     *
     * @return            the capabilities of this object
     * @see               Capabilities
     */
    public Capabilities getMultiInstanceCapabilities() {
        Capabilities result = super.getCapabilities();

        // attributes
        result.enable(Capability.NOMINAL_ATTRIBUTES);
        result.enable(Capability.NUMERIC_ATTRIBUTES);
        result.enable(Capability.DATE_ATTRIBUTES);
        result.enable(Capability.MISSING_VALUES);

        // class
        result.disableAllClasses();
        result.enable(Capability.NO_CLASS);

        return result;
    }

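    /*
     * Training overview: buildClassifier() below runs EM-DD from several
     * starting points. It randomly picks 3 distinct positive bags and, for
     * every instance in each of them, initializes the hypothesis
     * x = (instance values as the target point, all scales set to 1). It then
     * alternates an E-step, which selects from every bag the single instance
     * closest to the current target (findInstance), with an M-step, which
     * minimizes the negative log-likelihood over x (OptEng), until the NLL
     * stops improving or 10 EM iterations are reached. The hypothesis with
     * the lowest error on the training bags across all restarts becomes m_Par.
     */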
    /**
     * Builds the classifier
     *
     * @param train the training data to be used for generating the classifier
     * @throws Exception if the classifier could not be built successfully
     */
    public void buildClassifier(Instances train) throws Exception {
        // can classifier handle the data?
        getCapabilities().testWithFail(train);

        // remove instances with missing class
        train = new Instances(train);
        train.deleteWithMissingClass();

        m_ClassIndex = train.classIndex();
        m_NumClasses = train.numClasses();

        int nR = train.attribute(1).relation().numAttributes();
        int nC = train.numInstances();
        int[] bagSize = new int[nC];
        Instances datasets = new Instances(train.attribute(1)
                .relation(), 0);

        m_Data = new double[nC][nR][]; // Data values
        m_Classes = new int[nC]; // Class values
        m_Attributes = datasets.stringFreeStructure();
        if (m_Debug) {
            System.out.println("\n\nExtracting data...");
        }

        for (int h = 0; h < nC; h++) { // h-th bag
            Instance current = train.instance(h);
            m_Classes[h] = (int) current.classValue(); // Class value starts from 0
            Instances currInsts = current.relationalValue(1);
            for (int i = 0; i < currInsts.numInstances(); i++) {
                Instance inst = currInsts.instance(i);
                datasets.add(inst);
            }

            int nI = currInsts.numInstances();
            bagSize[h] = nI;
        }

        /* filter the training data */
        if (m_filterType == FILTER_STANDARDIZE)
            m_Filter = new Standardize();
        else if (m_filterType == FILTER_NORMALIZE)
            m_Filter = new Normalize();
        else
            m_Filter = null;

        if (m_Filter != null) {
            m_Filter.setInputFormat(datasets);
            datasets = Filter.useFilter(datasets, m_Filter);
        }

        m_Missing.setInputFormat(datasets);
        datasets = Filter.useFilter(datasets, m_Missing);

        int instIndex = 0;
        int start = 0;
        for (int h = 0; h < nC; h++) {
            for (int i = 0; i < datasets.numAttributes(); i++) {
                // initialize m_Data[][][]
                m_Data[h][i] = new double[bagSize[h]];
                instIndex = start;
                for (int k = 0; k < bagSize[h]; k++) {
                    m_Data[h][i][k] = datasets.instance(instIndex)
                            .value(i);
                    instIndex++;
                }
            }
            start = instIndex;
        }

        if (m_Debug) {
            System.out.println("\n\nIteration History...");
        }

        m_emData = new double[nC][nR];
        m_Par = new double[2 * nR];

        double[] x = new double[nR * 2];
        double[] tmp = new double[x.length];
        double[] pre_x = new double[x.length];
        double[] best_hypothesis = new double[x.length];
        double[][] b = new double[2][x.length];

        OptEng opt;
        double bestnll = Double.MAX_VALUE;
        double min_error = Double.MAX_VALUE;
        double nll, pre_nll;
        int iterationCount;

        for (int t = 0; t < x.length; t++) {
            b[0][t] = Double.NaN;
            b[1][t] = Double.NaN;
        }

        // randomly pick 3 positive bags
        Random r = new Random(getSeed());
        FastVector index = new FastVector();
        int n1, n2, n3;
        do {
            n1 = r.nextInt(nC - 1);
        } while (m_Classes[n1] == 0);
        index.addElement(new Integer(n1));

        do {
            n2 = r.nextInt(nC - 1);
        } while (n2 == n1 || m_Classes[n2] == 0);
        index.addElement(new Integer(n2));

        do {
            n3 = r.nextInt(nC - 1);
        } while (n3 == n1 || n3 == n2 || m_Classes[n3] == 0);
        index.addElement(new Integer(n3));

        for (int s = 0; s < index.size(); s++) {
            int exIdx = ((Integer) index.elementAt(s)).intValue();
            if (m_Debug)
                System.out.println("\nH0 at " + exIdx);

            for (int p = 0; p < m_Data[exIdx][0].length; p++) {
                // initialize a hypothesis
                for (int q = 0; q < nR; q++) {
                    x[2 * q] = m_Data[exIdx][q][p];
                    x[2 * q + 1] = 1.0;
                }

                pre_nll = Double.MAX_VALUE;
                nll = Double.MAX_VALUE / 10.0;
                iterationCount = 0;
                //while (Math.abs(nll-pre_nll)>0.01*pre_nll && iterationCount<10) {  //stop condition
                while (nll < pre_nll && iterationCount < 10) {
                    iterationCount++;
                    pre_nll = nll;

                    if (m_Debug)
                        System.out.println("\niteration: " + iterationCount);

                    // E-step (find the one instance in each bag with maximal likelihood)
                    for (int i = 0; i < m_Data.length; i++) { // for each bag

                        int insIndex = findInstance(i, x);

                        for (int att = 0; att < m_Data[0].length; att++)
                            // for each attribute
                            m_emData[i][att] = m_Data[i][att][insIndex];
                    }
                    if (m_Debug)
                        System.out.println("E-step for new H' finished");

                    // M-step
                    opt = new OptEng();
                    tmp = opt.findArgmin(x, b);
                    while (tmp == null) {
                        tmp = opt.getVarbValues();
                        if (m_Debug)
                            System.out.println("200 iterations finished, not enough!");
                        tmp = opt.findArgmin(tmp, b);
                    }
                    nll = opt.getMinFunction();

                    pre_x = x;
                    x = tmp; // update the hypothesis

                    // keep track of the best target point, which has the minimum nll
                    /* if (nll < bestnll) {
                       bestnll = nll;
                       m_Par = tmp;
                       if (m_Debug)
                       System.out.println("!!!!!!!!!!!!!!!!Smaller NLL found: " + nll);
                       }*/

                    //if (m_Debug)
                    //System.out.println(exIdx+" "+p+": "+nll+" "+pre_nll+" " +bestnll);
                } // converged for one instance

                // evaluate the hypothesis on the training data and
                // keep track of the hypothesis with minimum error on the training data
                double[] distribution = new double[2];
                int error = 0;
                if (nll > pre_nll)
                    m_Par = pre_x;
                else
                    m_Par = x;

                for (int i = 0; i < train.numInstances(); i++) {
                    distribution = distributionForInstance(train
                            .instance(i));
                    if (distribution[1] >= 0.5 && m_Classes[i] == 0)
                        error++;
                    else if (distribution[1] < 0.5 && m_Classes[i] == 1)
                        error++;
                }
                if (error < min_error) {
                    best_hypothesis = m_Par;
                    min_error = error;
                    if (nll > pre_nll)
                        bestnll = pre_nll;
                    else
                        bestnll = nll;
                    if (m_Debug)
                        System.out.println("error= " + error
                                + "  nll= " + bestnll);
                }
            }
            if (m_Debug) {
                System.out.println(exIdx
                        + ":  -------------<Converged>--------------");
                System.out.println("current minimum error= "
                        + min_error + "  nll= " + bestnll);
            }
        }
        m_Par = best_hypothesis;
    }

    /**
     * Given x, find the instance in the ith bag with the maximal likelihood,
     * i.e., the instance most likely to be responsible for the label of the
     * bag. For a positive bag, find the instance with the maximal probability
     * of being positive; for a negative bag, find the instance with the
     * minimal probability of being negative.
     *
     * @param i the bag index
     * @param x the current values of variables
     * @return index of the instance in the bag
     */
    protected int findInstance(int i, double[] x) {

        double min = Double.MAX_VALUE;
        int insIndex = 0;
        int nI = m_Data[i][0].length; // numInstances in ith bag

        for (int j = 0; j < nI; j++) {
            double ins = 0.0;
            for (int k = 0; k < m_Data[i].length; k++)
                // for each attribute
                ins += (m_Data[i][k][j] - x[k * 2])
                        * (m_Data[i][k][j] - x[k * 2]) * x[k * 2 + 1]
                        * x[k * 2 + 1];

            // the probability can be calculated as Math.exp(-ins);
            // finding the maximum of Math.exp(-ins) is equivalent to
            // finding the minimum of ins
            if (ins < min) {
                min = ins;
                insIndex = j;
            }
        }
        return insIndex;
    }

    /**
     * Computes the distribution for a given exemplar
     *
     * @param exmp the exemplar for which the distribution is computed
     * @return the distribution
     * @throws Exception if the distribution can't be computed successfully
     */
    public double[] distributionForInstance(Instance exmp)
            throws Exception {

        // Extract the data
        Instances ins = exmp.relationalValue(1);
        if (m_Filter != null)
            ins = Filter.useFilter(ins, m_Filter);

        ins = Filter.useFilter(ins, m_Missing);

        int nI = ins.numInstances(), nA = ins.numAttributes();
        double[][] dat = new double[nI][nA];
        for (int j = 0; j < nI; j++) {
            for (int k = 0; k < nA; k++) {
                dat[j][k] = ins.instance(j).value(k);
            }
        }
        // find the concept instance in the exemplar
        double min = Double.MAX_VALUE;
        double maxProb = -1.0;
        for (int j = 0; j < nI; j++) {
            double exp = 0.0;
            for (int k = 0; k < nA; k++)
                // for each attribute
                exp += (dat[j][k] - m_Par[k * 2])
                        * (dat[j][k] - m_Par[k * 2]) * m_Par[k * 2 + 1]
                        * m_Par[k * 2 + 1];
            // the probability can be calculated as Math.exp(-exp);
            // finding the maximum of Math.exp(-exp) is equivalent to
            // finding the minimum of exp
            if (exp < min) {
                min = exp;
                maxProb = Math.exp(-exp); // maximum probability of being positive
            }
        }

        // Compute the probability of the bag
        double[] distribution = new double[2];
        distribution[1] = maxProb;
        distribution[0] = 1.0 - distribution[1]; // minimum prob. of being negative

        return distribution;
    }

    /**
     * Gets a string describing the classifier.
     *
     * @return a string describing the classifier built.
     */
    public String toString() {

        String result = "MIEMDD";
        if (m_Par == null) {
            return result + ": No model built yet.";
        }

        result += "\nCoefficients...\n"
                + "Variable       Point       Scale\n";
        for (int j = 0, idx = 0; j < m_Par.length / 2; j++, idx++) {
            result += m_Attributes.attribute(idx).name();
            result += " " + Utils.doubleToString(m_Par[j * 2], 12, 4);
            result += " "
                    + Utils.doubleToString(m_Par[j * 2 + 1], 12, 4)
                    + "\n";
        }

        return result;
    }

    /**
     * Main method for testing this class.
     *
     * @param argv should contain the command line arguments to the
     * scheme (see Evaluation)
     */
    public static void main(String[] argv) {
        runClassifier(new MIEMDD(), argv);
    }
}
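
Usage: a minimal sketch of training MIEMDD from Java, assuming the ConverterUtils.DataSource loader available in this generation of Weka and a multi-instance ARFF data set (bag identifier, relational bag attribute, class label); "musk1.arff" stands in for your own file.

import weka.classifiers.mi.MIEMDD;
import weka.core.Instances;
import weka.core.converters.ConverterUtils.DataSource;

public class MIEMDDExample {
    public static void main(String[] args) throws Exception {
        // Multi-instance ARFF: bag id, relational bag attribute, class.
        // "musk1.arff" is only a placeholder for your own data set.
        Instances data = DataSource.read("musk1.arff");
        data.setClassIndex(data.numAttributes() - 1);

        MIEMDD classifier = new MIEMDD();
        // -N 1: standardize training data (the default); -S 1: random seed
        classifier.setOptions(new String[] { "-N", "1", "-S", "1" });
        classifier.buildClassifier(data);

        System.out.println(classifier);
    }
}

Equivalently, the main method above hands its arguments to Weka's standard evaluation harness, so something like "java weka.classifiers.mi.MIEMDD -t musk1.arff -N 1 -S 1" trains and cross-validates the classifier from the command line.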