001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * DiscreteEstimator.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.estimators;
024:
025: import weka.core.Capabilities.Capability;
026: import weka.core.Capabilities;
027: import weka.core.Utils;
028:
029: /**
030: * Simple symbolic probability estimator based on symbol counts.
031: *
032: * @author Len Trigg (trigg@cs.waikato.ac.nz)
033: * @version $Revision: 1.9 $
034: */
035: public class DiscreteEstimator extends Estimator implements
036: IncrementalEstimator {
037:
038: /** for serialization */
039: private static final long serialVersionUID = -5526486742612434779L;
040:
041: /** Hold the counts */
042: private double[] m_Counts;
043:
044: /** Hold the sum of counts */
045: private double m_SumOfCounts;
046:
047: /**
048: * Constructor
049: *
050: * @param numSymbols the number of possible symbols (remember to include 0)
051: * @param laplace if true, counts will be initialised to 1
052: */
053: public DiscreteEstimator(int numSymbols, boolean laplace) {
054:
055: m_Counts = new double[numSymbols];
056: m_SumOfCounts = 0;
057: if (laplace) {
058: for (int i = 0; i < numSymbols; i++) {
059: m_Counts[i] = 1;
060: }
061: m_SumOfCounts = (double) numSymbols;
062: }
063: }
064:
065: /**
066: * Constructor
067: *
068: * @param nSymbols the number of possible symbols (remember to include 0)
069: * @param fPrior value with which counts will be initialised
070: */
071: public DiscreteEstimator(int nSymbols, double fPrior) {
072:
073: m_Counts = new double[nSymbols];
074: for (int iSymbol = 0; iSymbol < nSymbols; iSymbol++) {
075: m_Counts[iSymbol] = fPrior;
076: }
077: m_SumOfCounts = fPrior * (double) nSymbols;
078: }
079:
080: /**
081: * Add a new data value to the current estimator.
082: *
083: * @param data the new data value
084: * @param weight the weight assigned to the data value
085: */
086: public void addValue(double data, double weight) {
087:
088: m_Counts[(int) data] += weight;
089: m_SumOfCounts += weight;
090: }
091:
092: /**
093: * Get a probability estimate for a value
094: *
095: * @param data the value to estimate the probability of
096: * @return the estimated probability of the supplied value
097: */
098: public double getProbability(double data) {
099:
100: if (m_SumOfCounts == 0) {
101: return 0;
102: }
103: return (double) m_Counts[(int) data] / m_SumOfCounts;
104: }
105:
106: /**
107: * Gets the number of symbols this estimator operates with
108: *
109: * @return the number of estimator symbols
110: */
111: public int getNumSymbols() {
112:
113: return (m_Counts == null) ? 0 : m_Counts.length;
114: }
115:
116: /**
117: * Get the count for a value
118: *
119: * @param data the value to get the count of
120: * @return the count of the supplied value
121: */
122: public double getCount(double data) {
123:
124: if (m_SumOfCounts == 0) {
125: return 0;
126: }
127: return m_Counts[(int) data];
128: }
129:
130: /**
131: * Get the sum of all the counts
132: *
133: * @return the total sum of counts
134: */
135: public double getSumOfCounts() {
136:
137: return m_SumOfCounts;
138: }
139:
140: /**
141: * Display a representation of this estimator
142: */
143: public String toString() {
144:
145: StringBuffer result = new StringBuffer(
146: "Discrete Estimator. Counts = ");
147: if (m_SumOfCounts > 1) {
148: for (int i = 0; i < m_Counts.length; i++) {
149: result.append(" ").append(
150: Utils.doubleToString(m_Counts[i], 2));
151: }
152: result.append(" (Total = ").append(
153: Utils.doubleToString(m_SumOfCounts, 2));
154: result.append(")\n");
155: } else {
156: for (int i = 0; i < m_Counts.length; i++) {
157: result.append(" ").append(m_Counts[i]);
158: }
159: result.append(" (Total = ").append(m_SumOfCounts).append(
160: ")\n");
161: }
162: return result.toString();
163: }
164:
165: /**
166: * Returns default capabilities of the classifier.
167: *
168: * @return the capabilities of this classifier
169: */
170: public Capabilities getCapabilities() {
171: Capabilities result = super .getCapabilities();
172:
173: // attributes
174: result.enable(Capability.NUMERIC_ATTRIBUTES);
175: return result;
176: }
177:
178: /**
179: * Main method for testing this class.
180: *
181: * @param argv should contain a sequence of integers which
182: * will be treated as symbolic.
183: */
184: public static void main(String[] argv) {
185:
186: try {
187: if (argv.length == 0) {
188: System.out
189: .println("Please specify a set of instances.");
190: return;
191: }
192: int current = Integer.parseInt(argv[0]);
193: int max = current;
194: for (int i = 1; i < argv.length; i++) {
195: current = Integer.parseInt(argv[i]);
196: if (current > max) {
197: max = current;
198: }
199: }
200: DiscreteEstimator newEst = new DiscreteEstimator(max + 1,
201: true);
202: for (int i = 0; i < argv.length; i++) {
203: current = Integer.parseInt(argv[i]);
204: System.out.println(newEst);
205: System.out.println("Prediction for " + current + " = "
206: + newEst.getProbability(current));
207: newEst.addValue(current, 1);
208: }
209: } catch (Exception e) {
210: System.out.println(e.getMessage());
211: }
212: }
213: }
|