001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * DiscreteEstimatorBayes.java
019: * Adapted from DiscreteEstimator.java
020: *
021: */
022: package weka.classifiers.bayes.net.estimate;
023:
024: import weka.classifiers.bayes.net.search.local.Scoreable;
025: import weka.core.Statistics;
026: import weka.core.Utils;
027: import weka.estimators.DiscreteEstimator;
028: import weka.estimators.Estimator;
029:
030: /**
031: * Symbolic probability estimator based on symbol counts and a prior.
032: *
033: * @author Remco Bouckaert (rrb@xm.co.nz)
034: * @version $Revision: 1.6 $
035: */
036: public class DiscreteEstimatorBayes extends Estimator implements
037: Scoreable {
038:
039: /** for serialization */
040: static final long serialVersionUID = 4215400230843212684L;
041:
042: /**
043: * Hold the counts
044: */
045: protected double[] m_Counts;
046:
047: /**
048: * Hold the sum of counts
049: */
050: protected double m_SumOfCounts;
051:
052: /**
053: * Holds number of symbols in distribution
054: */
055: protected int m_nSymbols = 0;
056:
057: /**
058: * Holds the prior probability
059: */
060: protected double m_fPrior = 0.0;
061:
062: /**
063: * Constructor
064: *
065: * @param nSymbols the number of possible symbols (remember to include 0)
066: * @param fPrior
067: */
068: public DiscreteEstimatorBayes(int nSymbols, double fPrior) {
069: m_fPrior = fPrior;
070: m_nSymbols = nSymbols;
071: m_Counts = new double[m_nSymbols];
072:
073: for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
074: m_Counts[iSymbol] = m_fPrior;
075: }
076:
077: m_SumOfCounts = m_fPrior * (double) m_nSymbols;
078: } // DiscreteEstimatorBayes
079:
080: /**
081: * Add a new data value to the current estimator.
082: *
083: * @param data the new data value
084: * @param weight the weight assigned to the data value
085: */
086: public void addValue(double data, double weight) {
087: m_Counts[(int) data] += weight;
088: m_SumOfCounts += weight;
089: }
090:
091: /**
092: * Get a probability estimate for a value
093: *
094: * @param data the value to estimate the probability of
095: * @return the estimated probability of the supplied value
096: */
097: public double getProbability(double data) {
098: if (m_SumOfCounts == 0) {
099:
100: // this can only happen if numSymbols = 0 in constructor
101: return 0;
102: }
103:
104: return (double) m_Counts[(int) data] / m_SumOfCounts;
105: }
106:
107: /**
108: * Get a counts for a value
109: *
110: * @param data the value to get the counts for
111: * @return the count of the supplied value
112: */
113: public double getCount(double data) {
114: if (m_SumOfCounts == 0) {
115: // this can only happen if numSymbols = 0 in constructor
116: return 0;
117: }
118:
119: return m_Counts[(int) data];
120: }
121:
122: /**
123: * Gets the number of symbols this estimator operates with
124: *
125: * @return the number of estimator symbols
126: */
127: public int getNumSymbols() {
128: return (m_Counts == null) ? 0 : m_Counts.length;
129: }
130:
131: /**
132: * Gets the log score contribution of this distribution
133: * @param nType score type
134: * @return the score
135: */
136: public double logScore(int nType, int nCardinality) {
137: double fScore = 0.0;
138:
139: switch (nType) {
140:
141: case (Scoreable.BAYES): {
142: for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
143: fScore += Statistics.lnGamma(m_Counts[iSymbol]);
144: }
145:
146: fScore -= Statistics.lnGamma(m_SumOfCounts);
147: if (m_fPrior != 0.0) {
148: fScore -= m_nSymbols * Statistics.lnGamma(m_fPrior);
149: fScore += Statistics.lnGamma(m_nSymbols * m_fPrior);
150: }
151: }
152:
153: break;
154: case (Scoreable.BDeu): {
155: for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
156: fScore += Statistics.lnGamma(m_Counts[iSymbol]);
157: }
158:
159: fScore -= Statistics.lnGamma(m_SumOfCounts);
160: //fScore -= m_nSymbols * Statistics.lnGamma(1.0);
161: //fScore += Statistics.lnGamma(m_nSymbols * 1.0);
162: fScore -= m_nSymbols
163: * Statistics
164: .lnGamma(1.0 / (m_nSymbols * nCardinality));
165: fScore += Statistics.lnGamma(1.0 / nCardinality);
166: }
167: break;
168:
169: case (Scoreable.MDL):
170:
171: case (Scoreable.AIC):
172:
173: case (Scoreable.ENTROPY): {
174: for (int iSymbol = 0; iSymbol < m_nSymbols; iSymbol++) {
175: double fP = getProbability(iSymbol);
176:
177: fScore += m_Counts[iSymbol] * Math.log(fP);
178: }
179: }
180:
181: break;
182:
183: default: {
184: }
185: }
186:
187: return fScore;
188: }
189:
190: /**
191: * Display a representation of this estimator
192: *
193: * @return a string representation of the estimator
194: */
195: public String toString() {
196: String result = "Discrete Estimator. Counts = ";
197:
198: if (m_SumOfCounts > 1) {
199: for (int i = 0; i < m_Counts.length; i++) {
200: result += " " + Utils.doubleToString(m_Counts[i], 2);
201: }
202:
203: result += " (Total = "
204: + Utils.doubleToString(m_SumOfCounts, 2) + ")\n";
205: } else {
206: for (int i = 0; i < m_Counts.length; i++) {
207: result += " " + m_Counts[i];
208: }
209:
210: result += " (Total = " + m_SumOfCounts + ")\n";
211: }
212:
213: return result;
214: }
215:
216: /**
217: * Main method for testing this class.
218: *
219: * @param argv should contain a sequence of integers which
220: * will be treated as symbolic.
221: */
222: public static void main(String[] argv) {
223: try {
224: if (argv.length == 0) {
225: System.out
226: .println("Please specify a set of instances.");
227:
228: return;
229: }
230:
231: int current = Integer.parseInt(argv[0]);
232: int max = current;
233:
234: for (int i = 1; i < argv.length; i++) {
235: current = Integer.parseInt(argv[i]);
236:
237: if (current > max) {
238: max = current;
239: }
240: }
241:
242: DiscreteEstimator newEst = new DiscreteEstimator(max + 1,
243: true);
244:
245: for (int i = 0; i < argv.length; i++) {
246: current = Integer.parseInt(argv[i]);
247:
248: System.out.println(newEst);
249: System.out.println("Prediction for " + current + " = "
250: + newEst.getProbability(current));
251: newEst.addValue(current, 1);
252: }
253: } catch (Exception e) {
254: System.out.println(e.getMessage());
255: }
256: } // main
257:
258: } // class DiscreteEstimatorBayes
|