001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * CostSensitiveClassifierSplitEvaluator.java
019: * Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.experiment;
024:
025: import weka.classifiers.Classifier;
026: import weka.classifiers.CostMatrix;
027: import weka.classifiers.Evaluation;
028: import weka.core.AdditionalMeasureProducer;
029: import weka.core.Attribute;
030: import weka.core.Instance;
031: import weka.core.Instances;
032: import weka.core.Option;
033: import weka.core.Summarizable;
034: import weka.core.Utils;
035:
036: import java.io.BufferedReader;
037: import java.io.File;
038: import java.io.FileReader;
039: import java.lang.management.ManagementFactory;
040: import java.lang.management.ThreadMXBean;
041: import java.util.Enumeration;
042: import java.util.Vector;
043:
044: /**
045: <!-- globalinfo-start -->
046: * SplitEvaluator that produces results for a classification scheme on a nominal class attribute, including weighted misclassification costs.
047: * <p/>
048: <!-- globalinfo-end -->
049: *
050: <!-- options-start -->
051: * Valid options are: <p/>
052: *
053: * <pre> -W <class name>
054: * The full class name of the classifier.
055: * eg: weka.classifiers.bayes.NaiveBayes</pre>
056: *
057: * <pre> -C <index>
058: * The index of the class for which IR statistics
059: * are to be output. (default 1)</pre>
060: *
061: * <pre> -I <index>
062: * The index of an attribute to output in the
063: * results. This attribute should identify an
064: * instance in order to know which instances are
065: * in the test set of a cross validation. if 0
066: * no output (default 0).</pre>
067: *
068: * <pre> -P
069: * Add target and prediction columns to the result
070: * for each fold.</pre>
071: *
072: * <pre>
073: * Options specific to classifier weka.classifiers.rules.ZeroR:
074: * </pre>
075: *
076: * <pre> -D
077: * If set, classifier is run in debug mode and
078: * may output additional info to the console</pre>
079: *
080: * <pre> -D <directory>
081: * Name of a directory to search for cost files when loading
082: * costs on demand (default current directory).</pre>
083: *
084: <!-- options-end -->
085: *
086: * All options after -- will be passed to the classifier.
087: *
088: * @author Len Trigg (len@reeltwo.com)
089: * @version $Revision: 1.15 $
090: */
091: public class CostSensitiveClassifierSplitEvaluator extends
092: ClassifierSplitEvaluator {
093:
094: /** for serialization */
095: static final long serialVersionUID = -8069566663019501276L;
096:
097: /**
098: * The directory used when loading cost files on demand, null indicates
099: * current directory
100: */
101: protected File m_OnDemandDirectory = new File(System
102: .getProperty("user.dir"));
103:
104: /** The length of a result */
105: private static final int RESULT_SIZE = 27; //23;
106:
107: /**
108: * Returns a string describing this split evaluator
109: * @return a description of the split evaluator suitable for
110: * displaying in the explorer/experimenter gui
111: */
112: public String globalInfo() {
113: return " SplitEvaluator that produces results for a classification scheme "
114: + "on a nominal class attribute, including weighted misclassification "
115: + "costs.";
116: }
117:
118: /**
119: * Returns an enumeration describing the available options..
120: *
121: * @return an enumeration of all the available options.
122: */
123: public Enumeration listOptions() {
124:
125: Vector newVector = new Vector(1);
126: Enumeration enu = super .listOptions();
127: while (enu.hasMoreElements()) {
128: newVector.addElement(enu.nextElement());
129: }
130:
131: newVector
132: .addElement(new Option(
133: "\tName of a directory to search for cost files when loading\n"
134: + "\tcosts on demand (default current directory).",
135: "D", 1, "-D <directory>"));
136:
137: return newVector.elements();
138: }
139:
140: /**
141: * Parses a given list of options. <p/>
142: *
143: <!-- options-start -->
144: * Valid options are: <p/>
145: *
146: * <pre> -W <class name>
147: * The full class name of the classifier.
148: * eg: weka.classifiers.bayes.NaiveBayes</pre>
149: *
150: * <pre> -C <index>
151: * The index of the class for which IR statistics
152: * are to be output. (default 1)</pre>
153: *
154: * <pre> -I <index>
155: * The index of an attribute to output in the
156: * results. This attribute should identify an
157: * instance in order to know which instances are
158: * in the test set of a cross validation. if 0
159: * no output (default 0).</pre>
160: *
161: * <pre> -P
162: * Add target and prediction columns to the result
163: * for each fold.</pre>
164: *
165: * <pre>
166: * Options specific to classifier weka.classifiers.rules.ZeroR:
167: * </pre>
168: *
169: * <pre> -D
170: * If set, classifier is run in debug mode and
171: * may output additional info to the console</pre>
172: *
173: * <pre> -D <directory>
174: * Name of a directory to search for cost files when loading
175: * costs on demand (default current directory).</pre>
176: *
177: <!-- options-end -->
178: *
179: * All options after -- will be passed to the classifier.
180: *
181: * @param options the list of options as an array of strings
182: * @throws Exception if an option is not supported
183: */
184: public void setOptions(String[] options) throws Exception {
185:
186: String demandDir = Utils.getOption('D', options);
187: if (demandDir.length() != 0) {
188: setOnDemandDirectory(new File(demandDir));
189: }
190:
191: super .setOptions(options);
192: }
193:
194: /**
195: * Gets the current settings of the Classifier.
196: *
197: * @return an array of strings suitable for passing to setOptions
198: */
199: public String[] getOptions() {
200:
201: String[] super Options = super .getOptions();
202: String[] options = new String[super Options.length + 3];
203: int current = 0;
204:
205: options[current++] = "-D";
206: options[current++] = "" + getOnDemandDirectory();
207:
208: System.arraycopy(super Options, 0, options, current,
209: super Options.length);
210: current += super Options.length;
211: while (current < options.length) {
212: options[current++] = "";
213: }
214: return options;
215: }
216:
217: /**
218: * Returns the tip text for this property
219: * @return tip text for this property suitable for
220: * displaying in the explorer/experimenter gui
221: */
222: public String onDemandDirectoryTipText() {
223: return "The directory to look in for cost files. This directory will be "
224: + "searched for cost files when loading on demand.";
225: }
226:
227: /**
228: * Returns the directory that will be searched for cost files when
229: * loading on demand.
230: *
231: * @return The cost file search directory.
232: */
233: public File getOnDemandDirectory() {
234:
235: return m_OnDemandDirectory;
236: }
237:
238: /**
239: * Sets the directory that will be searched for cost files when
240: * loading on demand.
241: *
242: * @param newDir The cost file search directory.
243: */
244: public void setOnDemandDirectory(File newDir) {
245:
246: if (newDir.isDirectory()) {
247: m_OnDemandDirectory = newDir;
248: } else {
249: m_OnDemandDirectory = new File(newDir.getParent());
250: }
251: }
252:
253: /**
254: * Gets the data types of each of the result columns produced for a
255: * single run. The number of result fields must be constant
256: * for a given SplitEvaluator.
257: *
258: * @return an array containing objects of the type of each result column.
259: * The objects should be Strings, or Doubles.
260: */
261: public Object[] getResultTypes() {
262: int addm = (m_AdditionalMeasures != null) ? m_AdditionalMeasures.length
263: : 0;
264: Object[] resultTypes = new Object[RESULT_SIZE + addm];
265: Double doub = new Double(0);
266: int current = 0;
267: resultTypes[current++] = doub;
268:
269: resultTypes[current++] = doub;
270: resultTypes[current++] = doub;
271: resultTypes[current++] = doub;
272: resultTypes[current++] = doub;
273: resultTypes[current++] = doub;
274: resultTypes[current++] = doub;
275: resultTypes[current++] = doub;
276: resultTypes[current++] = doub;
277:
278: resultTypes[current++] = doub;
279: resultTypes[current++] = doub;
280: resultTypes[current++] = doub;
281: resultTypes[current++] = doub;
282:
283: resultTypes[current++] = doub;
284: resultTypes[current++] = doub;
285: resultTypes[current++] = doub;
286: resultTypes[current++] = doub;
287: resultTypes[current++] = doub;
288: resultTypes[current++] = doub;
289:
290: resultTypes[current++] = doub;
291: resultTypes[current++] = doub;
292: resultTypes[current++] = doub;
293:
294: // Timing stats
295: resultTypes[current++] = doub;
296: resultTypes[current++] = doub;
297: resultTypes[current++] = doub;
298: resultTypes[current++] = doub;
299:
300: resultTypes[current++] = "";
301:
302: // add any additional measures
303: for (int i = 0; i < addm; i++) {
304: resultTypes[current++] = doub;
305: }
306: if (current != RESULT_SIZE + addm) {
307: throw new Error("ResultTypes didn't fit RESULT_SIZE");
308: }
309: return resultTypes;
310: }
311:
312: /**
313: * Gets the names of each of the result columns produced for a single run.
314: * The number of result fields must be constant
315: * for a given SplitEvaluator.
316: *
317: * @return an array containing the name of each result column
318: */
319: public String[] getResultNames() {
320: int addm = (m_AdditionalMeasures != null) ? m_AdditionalMeasures.length
321: : 0;
322: String[] resultNames = new String[RESULT_SIZE + addm];
323: int current = 0;
324: resultNames[current++] = "Number_of_instances";
325:
326: // Basic performance stats - right vs wrong
327: resultNames[current++] = "Number_correct";
328: resultNames[current++] = "Number_incorrect";
329: resultNames[current++] = "Number_unclassified";
330: resultNames[current++] = "Percent_correct";
331: resultNames[current++] = "Percent_incorrect";
332: resultNames[current++] = "Percent_unclassified";
333: resultNames[current++] = "Total_cost";
334: resultNames[current++] = "Average_cost";
335:
336: // Sensitive stats - certainty of predictions
337: resultNames[current++] = "Mean_absolute_error";
338: resultNames[current++] = "Root_mean_squared_error";
339: resultNames[current++] = "Relative_absolute_error";
340: resultNames[current++] = "Root_relative_squared_error";
341:
342: // SF stats
343: resultNames[current++] = "SF_prior_entropy";
344: resultNames[current++] = "SF_scheme_entropy";
345: resultNames[current++] = "SF_entropy_gain";
346: resultNames[current++] = "SF_mean_prior_entropy";
347: resultNames[current++] = "SF_mean_scheme_entropy";
348: resultNames[current++] = "SF_mean_entropy_gain";
349:
350: // K&B stats
351: resultNames[current++] = "KB_information";
352: resultNames[current++] = "KB_mean_information";
353: resultNames[current++] = "KB_relative_information";
354:
355: // Timing stats
356: resultNames[current++] = "Elapsed_Time_training";
357: resultNames[current++] = "Elapsed_Time_testing";
358: resultNames[current++] = "UserCPU_Time_training";
359: resultNames[current++] = "UserCPU_Time_testing";
360:
361: // Classifier defined extras
362: resultNames[current++] = "Summary";
363: // add any additional measures
364: for (int i = 0; i < addm; i++) {
365: resultNames[current++] = m_AdditionalMeasures[i];
366: }
367: if (current != RESULT_SIZE + addm) {
368: throw new Error("ResultNames didn't fit RESULT_SIZE");
369: }
370: return resultNames;
371: }
372:
373: /**
374: * Gets the results for the supplied train and test datasets. Now performs
375: * a deep copy of the classifier before it is built and evaluated (just in case
376: * the classifier is not initialized properly in buildClassifier()).
377: *
378: * @param train the training Instances.
379: * @param test the testing Instances.
380: * @return the results stored in an array. The objects stored in
381: * the array may be Strings, Doubles, or null (for the missing value).
382: * @throws Exception if a problem occurs while getting the results
383: */
384: public Object[] getResult(Instances train, Instances test)
385: throws Exception {
386:
387: if (train.classAttribute().type() != Attribute.NOMINAL) {
388: throw new Exception("Class attribute is not nominal!");
389: }
390: if (m_Template == null) {
391: throw new Exception("No classifier has been specified");
392: }
393: ThreadMXBean thMonitor = ManagementFactory.getThreadMXBean();
394: boolean canMeasureCPUTime = thMonitor
395: .isThreadCpuTimeSupported();
396: if (!thMonitor.isThreadCpuTimeEnabled())
397: thMonitor.setThreadCpuTimeEnabled(true);
398:
399: int addm = (m_AdditionalMeasures != null) ? m_AdditionalMeasures.length
400: : 0;
401: Object[] result = new Object[RESULT_SIZE + addm];
402: long thID = Thread.currentThread().getId();
403: long CPUStartTime = -1, trainCPUTimeElapsed = -1, testCPUTimeElapsed = -1, trainTimeStart, trainTimeElapsed, testTimeStart, testTimeElapsed;
404:
405: String costName = train.relationName()
406: + CostMatrix.FILE_EXTENSION;
407: File costFile = new File(getOnDemandDirectory(), costName);
408: if (!costFile.exists()) {
409: throw new Exception("On-demand cost file doesn't exist: "
410: + costFile);
411: }
412: CostMatrix costMatrix = new CostMatrix(new BufferedReader(
413: new FileReader(costFile)));
414:
415: Evaluation eval = new Evaluation(train, costMatrix);
416: m_Classifier = Classifier.makeCopy(m_Template);
417:
418: trainTimeStart = System.currentTimeMillis();
419: if (canMeasureCPUTime)
420: CPUStartTime = thMonitor.getThreadUserTime(thID);
421: m_Classifier.buildClassifier(train);
422: if (canMeasureCPUTime)
423: trainCPUTimeElapsed = thMonitor.getThreadUserTime(thID)
424: - CPUStartTime;
425: trainTimeElapsed = System.currentTimeMillis() - trainTimeStart;
426: testTimeStart = System.currentTimeMillis();
427: if (canMeasureCPUTime)
428: CPUStartTime = thMonitor.getThreadUserTime(thID);
429: eval.evaluateModel(m_Classifier, test);
430: if (canMeasureCPUTime)
431: testCPUTimeElapsed = thMonitor.getThreadUserTime(thID)
432: - CPUStartTime;
433: testTimeElapsed = System.currentTimeMillis() - testTimeStart;
434: thMonitor = null;
435:
436: m_result = eval.toSummaryString();
437: // The results stored are all per instance -- can be multiplied by the
438: // number of instances to get absolute numbers
439: int current = 0;
440: result[current++] = new Double(eval.numInstances());
441:
442: result[current++] = new Double(eval.correct());
443: result[current++] = new Double(eval.incorrect());
444: result[current++] = new Double(eval.unclassified());
445: result[current++] = new Double(eval.pctCorrect());
446: result[current++] = new Double(eval.pctIncorrect());
447: result[current++] = new Double(eval.pctUnclassified());
448: result[current++] = new Double(eval.totalCost());
449: result[current++] = new Double(eval.avgCost());
450:
451: result[current++] = new Double(eval.meanAbsoluteError());
452: result[current++] = new Double(eval.rootMeanSquaredError());
453: result[current++] = new Double(eval.relativeAbsoluteError());
454: result[current++] = new Double(eval.rootRelativeSquaredError());
455:
456: result[current++] = new Double(eval.SFPriorEntropy());
457: result[current++] = new Double(eval.SFSchemeEntropy());
458: result[current++] = new Double(eval.SFEntropyGain());
459: result[current++] = new Double(eval.SFMeanPriorEntropy());
460: result[current++] = new Double(eval.SFMeanSchemeEntropy());
461: result[current++] = new Double(eval.SFMeanEntropyGain());
462:
463: // K&B stats
464: result[current++] = new Double(eval.KBInformation());
465: result[current++] = new Double(eval.KBMeanInformation());
466: result[current++] = new Double(eval.KBRelativeInformation());
467:
468: // Timing stats
469: result[current++] = new Double(trainTimeElapsed / 1000.0);
470: result[current++] = new Double(testTimeElapsed / 1000.0);
471: if (canMeasureCPUTime) {
472: result[current++] = new Double(
473: (trainCPUTimeElapsed / 1000000.0) / 1000.0);
474: result[current++] = new Double(
475: (testCPUTimeElapsed / 1000000.0) / 1000.0);
476: } else {
477: result[current++] = new Double(Instance.missingValue());
478: result[current++] = new Double(Instance.missingValue());
479: }
480:
481: if (m_Classifier instanceof Summarizable) {
482: result[current++] = ((Summarizable) m_Classifier)
483: .toSummaryString();
484: } else {
485: result[current++] = null;
486: }
487:
488: for (int i = 0; i < addm; i++) {
489: if (m_doesProduce[i]) {
490: try {
491: double dv = ((AdditionalMeasureProducer) m_Classifier)
492: .getMeasure(m_AdditionalMeasures[i]);
493: if (!Instance.isMissingValue(dv)) {
494: Double value = new Double(dv);
495: result[current++] = value;
496: } else {
497: result[current++] = null;
498: }
499: } catch (Exception ex) {
500: System.err.println(ex);
501: }
502: } else {
503: result[current++] = null;
504: }
505: }
506:
507: if (current != RESULT_SIZE + addm) {
508: throw new Error("Results didn't fit RESULT_SIZE");
509: }
510: return result;
511: }
512:
513: /**
514: * Returns a text description of the split evaluator.
515: *
516: * @return a text description of the split evaluator.
517: */
518: public String toString() {
519:
520: String result = "CostSensitiveClassifierSplitEvaluator: ";
521: if (m_Template == null) {
522: return result + "<null> classifier";
523: }
524: return result + m_Template.getClass().getName() + " "
525: + m_ClassifierOptions + "(version "
526: + m_ClassifierVersion + ")";
527: }
528: } // CostSensitiveClassifierSplitEvaluator
|