001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * KKConditionalEstimator.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.estimators;
024:
025: import java.util.Random;
026:
027: import weka.core.Statistics;
028: import weka.core.Utils;
029:
030: /**
031: * Conditional probability estimator for a numeric domain conditional upon
032: * a numeric domain.
033: *
034: * @author Len Trigg (trigg@cs.waikato.ac.nz)
035: * @version $Revision: 1.7 $
036: */
037: public class KKConditionalEstimator implements ConditionalEstimator {
038:
039: /** Vector containing all of the values seen */
040: private double[] m_Values;
041:
042: /** Vector containing all of the conditioning values seen */
043: private double[] m_CondValues;
044:
045: /** Vector containing the associated weights */
046: private double[] m_Weights;
047:
048: /**
049: * Number of values stored in m_Weights, m_CondValues, and m_Values so far
050: */
051: private int m_NumValues;
052:
053: /** The sum of the weights so far */
054: private double m_SumOfWeights;
055:
056: /** Current standard dev */
057: private double m_StandardDev;
058:
059: /** Whether we can optimise the kernel summation */
060: private boolean m_AllWeightsOne;
061:
062: /** The numeric precision */
063: private double m_Precision;
064:
065: /**
066: * Execute a binary search to locate the nearest data value
067: *
068: * @param key the data value to locate
069: * @param secondaryKey the data value to locate
070: * @return the index of the nearest data value
071: */
072: private int findNearestPair(double key, double secondaryKey) {
073:
074: int low = 0;
075: int high = m_NumValues;
076: int middle = 0;
077: while (low < high) {
078: middle = (low + high) / 2;
079: double current = m_CondValues[middle];
080: if (current == key) {
081: double secondary = m_Values[middle];
082: if (secondary == secondaryKey) {
083: return middle;
084: }
085: if (secondary > secondaryKey) {
086: high = middle;
087: } else if (secondary < secondaryKey) {
088: low = middle + 1;
089: }
090: }
091: if (current > key) {
092: high = middle;
093: } else if (current < key) {
094: low = middle + 1;
095: }
096: }
097: return low;
098: }
099:
100: /**
101: * Round a data value using the defined precision for this estimator
102: *
103: * @param data the value to round
104: * @return the rounded data value
105: */
106: private double round(double data) {
107:
108: return Math.rint(data / m_Precision) * m_Precision;
109: }
110:
111: /**
112: * Constructor
113: *
114: * @param precision the precision to which numeric values are given. For
115: * example, if the precision is stated to be 0.1, the values in the
116: * interval (0.25,0.35] are all treated as 0.3.
117: */
118: public KKConditionalEstimator(double precision) {
119:
120: m_CondValues = new double[50];
121: m_Values = new double[50];
122: m_Weights = new double[50];
123: m_NumValues = 0;
124: m_SumOfWeights = 0;
125: m_StandardDev = 0;
126: m_AllWeightsOne = true;
127: m_Precision = precision;
128: }
129:
130: /**
131: * Add a new data value to the current estimator.
132: *
133: * @param data the new data value
134: * @param given the new value that data is conditional upon
135: * @param weight the weight assigned to the data value
136: */
137: public void addValue(double data, double given, double weight) {
138:
139: data = round(data);
140: given = round(given);
141: int insertIndex = findNearestPair(given, data);
142: if ((m_NumValues <= insertIndex)
143: || (m_CondValues[insertIndex] != given)
144: || (m_Values[insertIndex] != data)) {
145: if (m_NumValues < m_Values.length) {
146: int left = m_NumValues - insertIndex;
147: System.arraycopy(m_Values, insertIndex, m_Values,
148: insertIndex + 1, left);
149: System.arraycopy(m_CondValues, insertIndex,
150: m_CondValues, insertIndex + 1, left);
151: System.arraycopy(m_Weights, insertIndex, m_Weights,
152: insertIndex + 1, left);
153: m_Values[insertIndex] = data;
154: m_CondValues[insertIndex] = given;
155: m_Weights[insertIndex] = weight;
156: m_NumValues++;
157: } else {
158: double[] newValues = new double[m_Values.length * 2];
159: double[] newCondValues = new double[m_Values.length * 2];
160: double[] newWeights = new double[m_Values.length * 2];
161: int left = m_NumValues - insertIndex;
162: System
163: .arraycopy(m_Values, 0, newValues, 0,
164: insertIndex);
165: System.arraycopy(m_CondValues, 0, newCondValues, 0,
166: insertIndex);
167: System.arraycopy(m_Weights, 0, newWeights, 0,
168: insertIndex);
169: newValues[insertIndex] = data;
170: newCondValues[insertIndex] = given;
171: newWeights[insertIndex] = weight;
172: System.arraycopy(m_Values, insertIndex, newValues,
173: insertIndex + 1, left);
174: System.arraycopy(m_CondValues, insertIndex,
175: newCondValues, insertIndex + 1, left);
176: System.arraycopy(m_Weights, insertIndex, newWeights,
177: insertIndex + 1, left);
178: m_NumValues++;
179: m_Values = newValues;
180: m_CondValues = newCondValues;
181: m_Weights = newWeights;
182: }
183: if (weight != 1) {
184: m_AllWeightsOne = false;
185: }
186: } else {
187: m_Weights[insertIndex] += weight;
188: m_AllWeightsOne = false;
189: }
190: m_SumOfWeights += weight;
191: double range = m_CondValues[m_NumValues - 1] - m_CondValues[0];
192: m_StandardDev = Math.max(range / Math.sqrt(m_SumOfWeights),
193: // allow at most 3 sds within one interval
194: m_Precision / (2 * 3));
195: }
196:
197: /**
198: * Get a probability estimator for a value
199: *
200: * @param given the new value that data is conditional upon
201: * @return the estimator for the supplied value given the condition
202: */
203: public Estimator getEstimator(double given) {
204:
205: Estimator result = new KernelEstimator(m_Precision);
206: if (m_NumValues == 0) {
207: return result;
208: }
209:
210: double delta = 0, currentProb = 0;
211: double zLower, zUpper;
212: for (int i = 0; i < m_NumValues; i++) {
213: delta = m_CondValues[i] - given;
214: zLower = (delta - (m_Precision / 2)) / m_StandardDev;
215: zUpper = (delta + (m_Precision / 2)) / m_StandardDev;
216: currentProb = (Statistics.normalProbability(zUpper) - Statistics
217: .normalProbability(zLower));
218: result.addValue(m_Values[i], currentProb * m_Weights[i]);
219: }
220: return result;
221: }
222:
223: /**
224: * Get a probability estimate for a value
225: *
226: * @param data the value to estimate the probability of
227: * @param given the new value that data is conditional upon
228: * @return the estimated probability of the supplied value
229: */
230: public double getProbability(double data, double given) {
231:
232: return getEstimator(given).getProbability(data);
233: }
234:
235: /**
236: * Display a representation of this estimator
237: */
238: public String toString() {
239:
240: String result = "KK Conditional Estimator. " + m_NumValues
241: + " Normal Kernels:\n" + "StandardDev = "
242: + Utils.doubleToString(m_StandardDev, 4, 2)
243: + " \nMeans =";
244: for (int i = 0; i < m_NumValues; i++) {
245: result += " (" + m_Values[i] + ", " + m_CondValues[i] + ")";
246: if (!m_AllWeightsOne) {
247: result += "w=" + m_Weights[i];
248: }
249: }
250: return result;
251: }
252:
253: /**
254: * Main method for testing this class. Creates some random points
255: * in the range 0 - 100,
256: * and prints out a distribution conditional on some value
257: *
258: * @param argv should contain: seed conditional_value numpoints
259: */
260: public static void main(String[] argv) {
261:
262: try {
263: int seed = 42;
264: if (argv.length > 0) {
265: seed = Integer.parseInt(argv[0]);
266: }
267: KKConditionalEstimator newEst = new KKConditionalEstimator(
268: 0.1);
269:
270: // Create 100 random points and add them
271: Random r = new Random(seed);
272:
273: int numPoints = 50;
274: if (argv.length > 2) {
275: numPoints = Integer.parseInt(argv[2]);
276: }
277: for (int i = 0; i < numPoints; i++) {
278: int x = Math.abs(r.nextInt() % 100);
279: int y = Math.abs(r.nextInt() % 100);
280: System.out.println("# " + x + " " + y);
281: newEst.addValue(x, y, 1);
282: }
283: // System.out.println(newEst);
284: int cond;
285: if (argv.length > 1) {
286: cond = Integer.parseInt(argv[1]);
287: } else {
288: cond = Math.abs(r.nextInt() % 100);
289: }
290: System.out.println("## Conditional = " + cond);
291: Estimator result = newEst.getEstimator(cond);
292: for (int i = 0; i <= 100; i += 5) {
293: System.out.println(" " + i + " "
294: + result.getProbability(i));
295: }
296: } catch (Exception e) {
297: System.out.println(e.getMessage());
298: }
299: }
300: }
|