001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * ClassifierPerformanceEvaluator.java
019: * Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.gui.beans;
024:
025: import weka.classifiers.Classifier;
026: import weka.classifiers.Evaluation;
027: import weka.classifiers.evaluation.ThresholdCurve;
028: import weka.core.FastVector;
029: import weka.core.Instance;
030: import weka.core.Instances;
031: import weka.gui.visualize.PlotData2D;
032:
033: import java.io.Serializable;
034: import java.util.Enumeration;
035: import java.util.Vector;
036:
037: /**
038: * A bean that evaluates the performance of batch trained classifiers
039: *
040: * @author <a href="mailto:mhall@cs.waikato.ac.nz">Mark Hall</a>
041: * @version $Revision: 1.15 $
042: */
043: public class ClassifierPerformanceEvaluator extends AbstractEvaluator
044: implements BatchClassifierListener, Serializable,
045: UserRequestAcceptor, EventConstraints {
046:
047: /** for serialization */
048: private static final long serialVersionUID = -3511801418192148690L;
049:
050: /**
051: * Evaluation object used for evaluating a classifier
052: */
053: private transient Evaluation m_eval;
054:
055: /**
056: * Holds the classifier to be evaluated
057: */
058: private transient Classifier m_classifier;
059:
060: private transient Thread m_evaluateThread = null;
061:
062: private Vector m_textListeners = new Vector();
063: private Vector m_thresholdListeners = new Vector();
064: private Vector m_visualizableErrorListeners = new Vector();
065:
066: public ClassifierPerformanceEvaluator() {
067: m_visual
068: .loadIcons(
069: BeanVisual.ICON_PATH
070: + "ClassifierPerformanceEvaluator.gif",
071: BeanVisual.ICON_PATH
072: + "ClassifierPerformanceEvaluator_animated.gif");
073: m_visual.setText("ClassifierPerformanceEvaluator");
074: }
075:
076: /**
077: * Global info for this bean
078: *
079: * @return a <code>String</code> value
080: */
081: public String globalInfo() {
082: return "Evaluate the performance of batch trained classifiers.";
083: }
084:
085: // ----- Stuff for ROC curves
086: private boolean m_rocListenersConnected = false;
087: // Plottable Instances with predictions appended
088: private Instances m_predInstances = null;
089: // Actual predictions
090: private FastVector m_plotShape = null;
091: private FastVector m_plotSize = null;
092:
093: /**
094: * Accept a classifier to be evaluated
095: *
096: * @param ce a <code>BatchClassifierEvent</code> value
097: */
098: public void acceptClassifier(final BatchClassifierEvent ce) {
099: if (ce.getTestSet().isStructureOnly()) {
100: return; // cant evaluate empty instances
101: }
102: try {
103: if (m_evaluateThread == null) {
104: m_evaluateThread = new Thread() {
105: public void run() {
106: final String oldText = m_visual.getText();
107: try {
108: if (ce.getSetNumber() == 1
109: || ce.getClassifier() != m_classifier) {
110: m_eval = new Evaluation(ce.getTestSet()
111: .getDataSet());
112: m_classifier = ce.getClassifier();
113: m_predInstances = weka.gui.explorer.ClassifierPanel
114: .setUpVisualizableInstances(new Instances(
115: ce.getTestSet()
116: .getDataSet()));
117: m_plotShape = new FastVector();
118: m_plotSize = new FastVector();
119: }
120: if (ce.getSetNumber() <= ce
121: .getMaxSetNumber()) {
122: m_visual.setText("Evaluating ("
123: + ce.getSetNumber() + ")...");
124: if (m_logger != null) {
125: m_logger
126: .statusMessage("ClassifierPerformaceEvaluator : "
127: + "evaluating ("
128: + ce.getSetNumber()
129: + ")...");
130: }
131: m_visual.setAnimated();
132: /*
133: m_eval.evaluateModel(ce.getClassifier(),
134: ce.getTestSet().getDataSet()); */
135: for (int i = 0; i < ce.getTestSet()
136: .getDataSet().numInstances(); i++) {
137: Instance temp = ce.getTestSet()
138: .getDataSet().instance(i);
139: weka.gui.explorer.ClassifierPanel
140: .processClassifierPrediction(
141: temp,
142: ce.getClassifier(),
143: m_eval,
144: m_predInstances,
145: m_plotShape,
146: m_plotSize);
147: }
148: }
149:
150: if (ce.getSetNumber() == ce
151: .getMaxSetNumber()) {
152: System.err.println(m_eval
153: .toSummaryString());
154: // m_resultsString.append(m_eval.toSummaryString());
155: // m_outText.setText(m_resultsString.toString());
156: String textTitle = m_classifier
157: .getClass().getName();
158: textTitle = textTitle.substring(
159: textTitle.lastIndexOf('.') + 1,
160: textTitle.length());
161: String resultT = "=== Evaluation result ===\n\n"
162: + "Scheme: "
163: + textTitle
164: + "\n"
165: + "Relation: "
166: + ce.getTestSet().getDataSet()
167: .relationName()
168: + "\n\n"
169: + m_eval.toSummaryString();
170:
171: if (ce.getTestSet().getDataSet()
172: .classAttribute().isNominal()) {
173: resultT += "\n"
174: + m_eval
175: .toClassDetailsString()
176: + "\n"
177: + m_eval.toMatrixString();
178: }
179:
180: TextEvent te = new TextEvent(
181: ClassifierPerformanceEvaluator.this ,
182: resultT, textTitle);
183: notifyTextListeners(te);
184:
185: // set up visualizable errors
186: if (m_visualizableErrorListeners.size() > 0) {
187: PlotData2D errorD = new PlotData2D(
188: m_predInstances);
189: errorD.setShapeSize(m_plotSize);
190: errorD.setShapeType(m_plotShape);
191: errorD.setPlotName(textTitle
192: + " ("
193: + ce.getTestSet()
194: .getDataSet()
195: .relationName()
196: + ")");
197: errorD.addInstanceNumberAttribute();
198: VisualizableErrorEvent vel = new VisualizableErrorEvent(
199: ClassifierPerformanceEvaluator.this ,
200: errorD);
201: notifyVisualizableErrorListeners(vel);
202: }
203:
204: if (ce.getTestSet().getDataSet()
205: .classAttribute().isNominal()) {
206: ThresholdCurve tc = new ThresholdCurve();
207: Instances result = tc.getCurve(
208: m_eval.predictions(), 0);
209: result.setRelationName(ce
210: .getTestSet().getDataSet()
211: .relationName());
212: PlotData2D pd = new PlotData2D(
213: result);
214: pd.setPlotName(textTitle
215: + " ("
216: + ce.getTestSet()
217: .getDataSet()
218: .classAttribute()
219: .value(0) + ")");
220: boolean[] connectPoints = new boolean[result
221: .numInstances()];
222: for (int jj = 1; jj < connectPoints.length; jj++) {
223: connectPoints[jj] = true;
224: }
225: pd.setConnectPoints(connectPoints);
226: ThresholdDataEvent rde = new ThresholdDataEvent(
227: ClassifierPerformanceEvaluator.this ,
228: pd);
229: notifyThresholdListeners(rde);
230: /*te = new TextEvent(ClassifierPerformanceEvaluator.this,
231: result.toString(),
232: "ThresholdCurveInst");
233: notifyTextListeners(te); */
234: }
235: if (m_logger != null) {
236: m_logger.statusMessage("Done.");
237: }
238: }
239: } catch (Exception ex) {
240: ex.printStackTrace();
241: } finally {
242: m_visual.setText(oldText);
243: m_visual.setStatic();
244: m_evaluateThread = null;
245: if (isInterrupted()) {
246: if (m_logger != null) {
247: m_logger
248: .logMessage("Evaluation interrupted!");
249: m_logger.statusMessage("OK");
250: }
251: }
252: block(false);
253: }
254: }
255: };
256: m_evaluateThread.setPriority(Thread.MIN_PRIORITY);
257: m_evaluateThread.start();
258:
259: // make sure the thread is still running before we block
260: // if (m_evaluateThread.isAlive()) {
261: block(true);
262: // }
263: m_evaluateThread = null;
264: }
265: } catch (Exception ex) {
266: ex.printStackTrace();
267: }
268: }
269:
270: /**
271: * Try and stop any action
272: */
273: public void stop() {
274: // tell the listenee (upstream bean) to stop
275: if (m_listenee instanceof BeanCommon) {
276: System.err.println("Listener is BeanCommon");
277: ((BeanCommon) m_listenee).stop();
278: }
279:
280: // stop the evaluate thread
281: if (m_evaluateThread != null) {
282: m_evaluateThread.interrupt();
283: m_evaluateThread.stop();
284: }
285: }
286:
287: /**
288: * Function used to stop code that calls acceptClassifier. This is
289: * needed as classifier evaluation is performed inside a separate
290: * thread of execution.
291: *
292: * @param tf a <code>boolean</code> value
293: */
294: private synchronized void block(boolean tf) {
295: if (tf) {
296: try {
297: // only block if thread is still doing something useful!
298: if (m_evaluateThread != null
299: && m_evaluateThread.isAlive()) {
300: wait();
301: }
302: } catch (InterruptedException ex) {
303: }
304: } else {
305: notifyAll();
306: }
307: }
308:
309: /**
310: * Return an enumeration of user activated requests for this bean
311: *
312: * @return an <code>Enumeration</code> value
313: */
314: public Enumeration enumerateRequests() {
315: Vector newVector = new Vector(0);
316: if (m_evaluateThread != null) {
317: newVector.addElement("Stop");
318: }
319: return newVector.elements();
320: }
321:
322: /**
323: * Perform the named request
324: *
325: * @param request the request to perform
326: * @exception IllegalArgumentException if an error occurs
327: */
328: public void performRequest(String request) {
329: if (request.compareTo("Stop") == 0) {
330: stop();
331: } else {
332: throw new IllegalArgumentException(request
333:
334: + " not supported (ClassifierPerformanceEvaluator)");
335: }
336: }
337:
338: /**
339: * Add a text listener
340: *
341: * @param cl a <code>TextListener</code> value
342: */
343: public synchronized void addTextListener(TextListener cl) {
344: m_textListeners.addElement(cl);
345: }
346:
347: /**
348: * Remove a text listener
349: *
350: * @param cl a <code>TextListener</code> value
351: */
352: public synchronized void removeTextListener(TextListener cl) {
353: m_textListeners.remove(cl);
354: }
355:
356: /**
357: * Add a threshold data listener
358: *
359: * @param cl a <code>ThresholdDataListener</code> value
360: */
361: public synchronized void addThresholdDataListener(
362: ThresholdDataListener cl) {
363: m_thresholdListeners.addElement(cl);
364: }
365:
366: /**
367: * Remove a Threshold data listener
368: *
369: * @param cl a <code>ThresholdDataListener</code> value
370: */
371: public synchronized void removeThresholdDataListener(
372: ThresholdDataListener cl) {
373: m_thresholdListeners.remove(cl);
374: }
375:
376: /**
377: * Add a visualizable error listener
378: *
379: * @param vel a <code>VisualizableErrorListener</code> value
380: */
381: public synchronized void addVisualizableErrorListener(
382: VisualizableErrorListener vel) {
383: m_visualizableErrorListeners.add(vel);
384: }
385:
386: /**
387: * Remove a visualizable error listener
388: *
389: * @param vel a <code>VisualizableErrorListener</code> value
390: */
391: public synchronized void removeVisualizableErrorListener(
392: VisualizableErrorListener vel) {
393: m_visualizableErrorListeners.remove(vel);
394: }
395:
396: /**
397: * Notify all text listeners of a TextEvent
398: *
399: * @param te a <code>TextEvent</code> value
400: */
401: private void notifyTextListeners(TextEvent te) {
402: Vector l;
403: synchronized (this ) {
404: l = (Vector) m_textListeners.clone();
405: }
406: if (l.size() > 0) {
407: for (int i = 0; i < l.size(); i++) {
408: // System.err.println("Notifying text listeners "
409: // +"(ClassifierPerformanceEvaluator)");
410: ((TextListener) l.elementAt(i)).acceptText(te);
411: }
412: }
413: }
414:
415: /**
416: * Notify all ThresholdDataListeners of a ThresholdDataEvent
417: *
418: * @param te a <code>ThresholdDataEvent</code> value
419: */
420: private void notifyThresholdListeners(ThresholdDataEvent re) {
421: Vector l;
422: synchronized (this ) {
423: l = (Vector) m_thresholdListeners.clone();
424: }
425: if (l.size() > 0) {
426: for (int i = 0; i < l.size(); i++) {
427: // System.err.println("Notifying text listeners "
428: // +"(ClassifierPerformanceEvaluator)");
429: ((ThresholdDataListener) l.elementAt(i))
430: .acceptDataSet(re);
431: }
432: }
433: }
434:
435: /**
436: * Notify all VisualizableErrorListeners of a VisualizableErrorEvent
437: *
438: * @param te a <code>VisualizableErrorEvent</code> value
439: */
440: private void notifyVisualizableErrorListeners(
441: VisualizableErrorEvent re) {
442: Vector l;
443: synchronized (this ) {
444: l = (Vector) m_visualizableErrorListeners.clone();
445: }
446: if (l.size() > 0) {
447: for (int i = 0; i < l.size(); i++) {
448: // System.err.println("Notifying text listeners "
449: // +"(ClassifierPerformanceEvaluator)");
450: ((VisualizableErrorListener) l.elementAt(i))
451: .acceptDataSet(re);
452: }
453: }
454: }
455:
456: /**
457: * Returns true, if at the current time, the named event could
458: * be generated. Assumes that supplied event names are names of
459: * events that could be generated by this bean.
460: *
461: * @param eventName the name of the event in question
462: * @return true if the named event could be generated at this point in
463: * time
464: */
465: public boolean eventGeneratable(String eventName) {
466: if (m_listenee == null) {
467: return false;
468: }
469:
470: if (m_listenee instanceof EventConstraints) {
471: if (!((EventConstraints) m_listenee)
472: .eventGeneratable("batchClassifier")) {
473: return false;
474: }
475: }
476: return true;
477: }
478: }
|