001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * NominalPrediction.java
019: * Copyright (C) 2002 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.evaluation;
024:
025: import java.io.Serializable;
026:
027: /**
028: * Encapsulates an evaluatable nominal prediction: the predicted probability
029: * distribution plus the actual class value.
030: *
031: * @author Len Trigg (len@reeltwo.com)
032: * @version $Revision: 1.11 $
033: */
034: public class NominalPrediction implements Prediction, Serializable {
035:
036: /**
037: * Remove this if you change this class so that serialization would be
038: * affected.
039: */
040: static final long serialVersionUID = -8871333992740492788L;
041:
042: /** The predicted probabilities */
043: private double[] m_Distribution;
044:
045: /** The actual class value */
046: private double m_Actual = MISSING_VALUE;
047:
048: /** The predicted class value */
049: private double m_Predicted = MISSING_VALUE;
050:
051: /** The weight assigned to this prediction */
052: private double m_Weight = 1;
053:
054: /**
055: * Creates the NominalPrediction object with a default weight of 1.0.
056: *
057: * @param actual the actual value, or MISSING_VALUE.
058: * @param distribution the predicted probability distribution. Use
059: * NominalPrediction.makeDistribution() if you only know the predicted value.
060: */
061: public NominalPrediction(double actual, double[] distribution) {
062:
063: this (actual, distribution, 1);
064: }
065:
066: /**
067: * Creates the NominalPrediction object.
068: *
069: * @param actual the actual value, or MISSING_VALUE.
070: * @param distribution the predicted probability distribution. Use
071: * NominalPrediction.makeDistribution() if you only know the predicted value.
072: * @param weight the weight assigned to the prediction.
073: */
074: public NominalPrediction(double actual, double[] distribution,
075: double weight) {
076:
077: if (distribution == null) {
078: throw new NullPointerException(
079: "Null distribution in NominalPrediction.");
080: }
081: m_Actual = actual;
082: m_Distribution = distribution.clone();
083: m_Weight = weight;
084: updatePredicted();
085: }
086:
087: /**
088: * Gets the predicted probabilities
089: *
090: * @return the predicted probabilities
091: */
092: public double[] distribution() {
093:
094: return m_Distribution;
095: }
096:
097: /**
098: * Gets the actual class value.
099: *
100: * @return the actual class value, or MISSING_VALUE if no
101: * prediction was made.
102: */
103: public double actual() {
104:
105: return m_Actual;
106: }
107:
108: /**
109: * Gets the predicted class value.
110: *
111: * @return the predicted class value, or MISSING_VALUE if no
112: * prediction was made.
113: */
114: public double predicted() {
115:
116: return m_Predicted;
117: }
118:
119: /**
120: * Gets the weight assigned to this prediction. This is typically the weight
121: * of the test instance the prediction was made for.
122: *
123: * @return the weight assigned to this prediction.
124: */
125: public double weight() {
126:
127: return m_Weight;
128: }
129:
130: /**
131: * Calculates the prediction margin. This is defined as the difference
132: * between the probability predicted for the actual class and the highest
133: * predicted probability of the other classes.
134: *
135: * @return the margin for this prediction, or
136: * MISSING_VALUE if either the actual or predicted value
137: * is missing.
138: */
139: public double margin() {
140:
141: if ((m_Actual == MISSING_VALUE)
142: || (m_Predicted == MISSING_VALUE)) {
143: return MISSING_VALUE;
144: }
145: double probActual = m_Distribution[(int) m_Actual];
146: double probNext = 0;
147: for (int i = 0; i < m_Distribution.length; i++)
148: if ((i != m_Actual) && (m_Distribution[i] > probNext))
149: probNext = m_Distribution[i];
150:
151: return probActual - probNext;
152: }
153:
154: /**
155: * Convert a single prediction into a probability distribution
156: * with all zero probabilities except the predicted value which
157: * has probability 1.0. If no prediction was made, all probabilities
158: * are zero.
159: *
160: * @param predictedClass the index of the predicted class, or
161: * MISSING_VALUE if no prediction was made.
162: * @param numClasses the number of possible classes for this nominal
163: * prediction.
164: * @return the probability distribution.
165: */
166: public static double[] makeDistribution(double predictedClass,
167: int numClasses) {
168:
169: double[] dist = new double[numClasses];
170: if (predictedClass == MISSING_VALUE) {
171: return dist;
172: }
173: dist[(int) predictedClass] = 1.0;
174: return dist;
175: }
176:
177: /**
178: * Creates a uniform probability distribution -- where each of the
179: * possible classes is assigned equal probability.
180: *
181: * @param numClasses the number of possible classes for this nominal
182: * prediction.
183: * @return the probability distribution.
184: */
185: public static double[] makeUniformDistribution(int numClasses) {
186:
187: double[] dist = new double[numClasses];
188: for (int i = 0; i < numClasses; i++) {
189: dist[i] = 1.0 / numClasses;
190: }
191: return dist;
192: }
193:
194: /**
195: * Determines the predicted class (doesn't detect multiple
196: * classifications). If no prediction was made (i.e. all zero
197: * probababilities in the distribution), m_Prediction is set to
198: * MISSING_VALUE.
199: */
200: private void updatePredicted() {
201:
202: int predictedClass = -1;
203: double bestProb = 0.0;
204: for (int i = 0; i < m_Distribution.length; i++) {
205: if (m_Distribution[i] > bestProb) {
206: predictedClass = i;
207: bestProb = m_Distribution[i];
208: }
209: }
210:
211: if (predictedClass != -1) {
212: m_Predicted = predictedClass;
213: } else {
214: m_Predicted = MISSING_VALUE;
215: }
216: }
217:
218: /**
219: * Gets a human readable representation of this prediction.
220: *
221: * @return a human readable representation of this prediction.
222: */
223: public String toString() {
224:
225: StringBuffer sb = new StringBuffer();
226: sb.append("NOM: ").append(actual()).append(" ").append(
227: predicted());
228: sb.append(' ').append(weight());
229: double[] dist = distribution();
230: for (int i = 0; i < dist.length; i++) {
231: sb.append(' ').append(dist[i]);
232: }
233: return sb.toString();
234: }
235: }
|