001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * PriorEstimation.java
019: * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.associations;
024:
025: import weka.core.Instances;
026: import weka.core.SpecialFunctions;
027: import weka.core.Utils;
028:
029: import java.io.Serializable;
030: import java.util.Hashtable;
031: import java.util.Random;
032:
033: /**
034: * Class implementing the prior estimattion of the predictive apriori algorithm
035: * for mining association rules.
036: *
037: * Reference: T. Scheffer (2001). <i>Finding Association Rules That Trade Support
038: * Optimally against Confidence</i>. Proc of the 5th European Conf.
039: * on Principles and Practice of Knowledge Discovery in Databases (PKDD'01),
040: * pp. 424-435. Freiburg, Germany: Springer-Verlag. <p>
041: *
042: * @author Stefan Mutter (mutter@cs.waikato.ac.nz)
043: * @version $Revision: 1.6 $ */
044:
045: public class PriorEstimation implements Serializable {
046:
047: /** for serialization */
048: private static final long serialVersionUID = 5570863216522496271L;
049:
050: /** The number of rnadom rules. */
051: protected int m_numRandRules;
052:
053: /** The number of intervals. */
054: protected int m_numIntervals;
055:
056: /** The random seed used for the random rule generation step. */
057: protected static final int SEED = 0;
058:
059: /** The maximum number of attributes for which a prior can be estimated. */
060: protected static final int MAX_N = 1024;
061:
062: /** The random number generator. */
063: protected Random m_randNum;
064:
065: /** The instances for which association rules are mined. */
066: protected Instances m_instances;
067:
068: /** Flag indicating whether standard association rules or class association rules are mined. */
069: protected boolean m_CARs;
070:
071: /** Hashtable to store the confidence values of randomly generated rules. */
072: protected Hashtable m_distribution;
073:
074: /** Hashtable containing the estimated prior probabilities. */
075: protected Hashtable m_priors;
076:
077: /** Sums up the confidences of all rules with a certain length. */
078: protected double m_sum;
079:
080: /** The mid points of the discrete intervals in which the interval [0,1] is divided. */
081: protected double[] m_midPoints;
082:
083: /**
084: * Constructor
085: *
086: * @param instances the instances to be used for generating the associations
087: * @param numRules the number of random rules used for generating the prior
088: * @param numIntervals the number of intervals to discretise [0,1]
089: * @param car flag indicating whether standard or class association rules are mined
090: */
091: public PriorEstimation(Instances instances, int numRules,
092: int numIntervals, boolean car) {
093:
094: m_instances = instances;
095: m_CARs = car;
096: m_numRandRules = numRules;
097: m_numIntervals = numIntervals;
098: m_randNum = m_instances.getRandomNumberGenerator(SEED);
099: }
100:
101: /**
102: * Calculates the prior distribution.
103: *
104: * @exception Exception if prior can't be estimated successfully
105: */
106: public final void generateDistribution() throws Exception {
107:
108: boolean jump;
109: int i, maxLength = m_instances.numAttributes(), count = 0, count1 = 0, ruleCounter;
110: int[] itemArray;
111: m_distribution = new Hashtable(maxLength * m_numIntervals);
112: RuleItem current;
113: ItemSet generate;
114:
115: if (m_instances.numAttributes() == 0)
116: throw new Exception("Dataset has no attributes!");
117: if (m_instances.numAttributes() >= MAX_N)
118: throw new Exception(
119: "Dataset has to many attributes for prior estimation!");
120: if (m_instances.numInstances() == 0)
121: throw new Exception("Dataset has no instances!");
122: for (int h = 0; h < maxLength; h++) {
123: if (m_instances.attribute(h).isNumeric())
124: throw new Exception("Can't handle numeric attributes!");
125: }
126: if (m_numIntervals == 0 || m_numRandRules == 0)
127: throw new Exception("Prior initialisation impossible");
128:
129: //calculate mid points for the intervals
130: midPoints();
131:
132: //create random rules of length i and measure their support and if support >0 their confidence
133: for (i = 1; i <= maxLength; i++) {
134: m_sum = 0;
135: int j = 0;
136: count = 0;
137: count1 = 0;
138: while (j < m_numRandRules) {
139: count++;
140: jump = false;
141: if (!m_CARs) {
142: itemArray = randomRule(maxLength, i, m_randNum);
143: current = splitItemSet(m_randNum.nextInt(i),
144: itemArray);
145: } else {
146: itemArray = randomCARule(maxLength, i, m_randNum);
147: current = addCons(itemArray);
148: }
149: int[] ruleItem = new int[maxLength];
150: for (int k = 0; k < itemArray.length; k++) {
151: if (current.m_premise.m_items[k] != -1)
152: ruleItem[k] = current.m_premise.m_items[k];
153: else if (current.m_consequence.m_items[k] != -1)
154: ruleItem[k] = current.m_consequence.m_items[k];
155: else
156: ruleItem[k] = -1;
157: }
158: ItemSet rule = new ItemSet(ruleItem);
159: updateCounters(rule);
160: ruleCounter = rule.m_counter;
161: if (ruleCounter > 0)
162: jump = true;
163: updateCounters(current.m_premise);
164: j++;
165: if (jump) {
166: buildDistribution((double) ruleCounter
167: / (double) current.m_premise.m_counter,
168: (double) i);
169: }
170: }
171:
172: //normalize
173: if (m_sum > 0) {
174: for (int w = 0; w < m_midPoints.length; w++) {
175: String key = (String.valueOf(m_midPoints[w]))
176: .concat(String.valueOf((double) i));
177: Double oldValue = (Double) m_distribution
178: .remove(key);
179: if (oldValue == null) {
180: m_distribution.put(key, new Double(
181: 1.0 / m_numIntervals));
182: m_sum += 1.0 / m_numIntervals;
183: } else
184: m_distribution.put(key, oldValue);
185: }
186: for (int w = 0; w < m_midPoints.length; w++) {
187: double conf = 0;
188: String key = (String.valueOf(m_midPoints[w]))
189: .concat(String.valueOf((double) i));
190: Double oldValue = (Double) m_distribution
191: .remove(key);
192: if (oldValue != null) {
193: conf = oldValue.doubleValue() / m_sum;
194: m_distribution.put(key, new Double(conf));
195: }
196: }
197: } else {
198: for (int w = 0; w < m_midPoints.length; w++) {
199: String key = (String.valueOf(m_midPoints[w]))
200: .concat(String.valueOf((double) i));
201: m_distribution.put(key, new Double(
202: 1.0 / m_numIntervals));
203: }
204: }
205: }
206:
207: }
208:
209: /**
210: * Constructs an item set of certain length randomly.
211: * This method is used for standard association rule mining.
212: * @param maxLength the number of attributes of the instances
213: * @param actualLength the number of attributes that should be present in the item set
214: * @param randNum the random number generator
215: * @return a randomly constructed item set in form of an int array
216: */
217: public final int[] randomRule(int maxLength, int actualLength,
218: Random randNum) {
219:
220: int[] itemArray = new int[maxLength];
221: for (int k = 0; k < itemArray.length; k++)
222: itemArray[k] = -1;
223: int help = actualLength;
224: if (help == maxLength) {
225: help = 0;
226: for (int h = 0; h < itemArray.length; h++) {
227: itemArray[h] = m_randNum.nextInt((m_instances
228: .attribute(h)).numValues());
229: }
230: }
231: while (help > 0) {
232: int mark = randNum.nextInt(maxLength);
233: if (itemArray[mark] == -1) {
234: help--;
235: itemArray[mark] = m_randNum.nextInt((m_instances
236: .attribute(mark)).numValues());
237: }
238: }
239: return itemArray;
240: }
241:
242: /**
243: * Constructs an item set of certain length randomly.
244: * This method is used for class association rule mining.
245: * @param maxLength the number of attributes of the instances
246: * @param actualLength the number of attributes that should be present in the item set
247: * @param randNum the random number generator
248: * @return a randomly constructed item set in form of an int array
249: */
250: public final int[] randomCARule(int maxLength, int actualLength,
251: Random randNum) {
252:
253: int[] itemArray = new int[maxLength];
254: for (int k = 0; k < itemArray.length; k++)
255: itemArray[k] = -1;
256: if (actualLength == 1)
257: return itemArray;
258: int help = actualLength - 1;
259: if (help == maxLength - 1) {
260: help = 0;
261: for (int h = 0; h < itemArray.length; h++) {
262: if (h != m_instances.classIndex()) {
263: itemArray[h] = m_randNum.nextInt((m_instances
264: .attribute(h)).numValues());
265: }
266: }
267: }
268: while (help > 0) {
269: int mark = randNum.nextInt(maxLength);
270: if (itemArray[mark] == -1
271: && mark != m_instances.classIndex()) {
272: help--;
273: itemArray[mark] = m_randNum.nextInt((m_instances
274: .attribute(mark)).numValues());
275: }
276: }
277: return itemArray;
278: }
279:
280: /**
281: * updates the distribution of the confidence values.
282: * For every confidence value the interval to which it belongs is searched
283: * and the confidence is added to the confidence already found in this
284: * interval.
285: * @param conf the confidence of the randomly created rule
286: * @param length the legnth of the randomly created rule
287: */
288: public final void buildDistribution(double conf, double length) {
289:
290: double mPoint = findIntervall(conf);
291: String key = (String.valueOf(mPoint)).concat(String
292: .valueOf(length));
293: m_sum += conf;
294: Double oldValue = (Double) m_distribution.remove(key);
295: if (oldValue != null)
296: conf = conf + oldValue.doubleValue();
297: m_distribution.put(key, new Double(conf));
298:
299: }
300:
301: /**
302: * searches the mid point of the interval a given confidence value falls into
303: * @param conf the confidence of a rule
304: * @return the mid point of the interval the confidence belongs to
305: */
306: public final double findIntervall(double conf) {
307:
308: if (conf == 1.0)
309: return m_midPoints[m_midPoints.length - 1];
310: int end = m_midPoints.length - 1;
311: int start = 0;
312: while (Math.abs(end - start) > 1) {
313: int mid = (start + end) / 2;
314: if (conf > m_midPoints[mid])
315: start = mid + 1;
316: if (conf < m_midPoints[mid])
317: end = mid - 1;
318: if (conf == m_midPoints[mid])
319: return m_midPoints[mid];
320: }
321: if (Math.abs(conf - m_midPoints[start]) <= Math.abs(conf
322: - m_midPoints[end]))
323: return m_midPoints[start];
324: else
325: return m_midPoints[end];
326: }
327:
328: /**
329: * calculates the numerator and the denominator of the prior equation
330: * @param weighted indicates whether the numerator or the denominator is calculated
331: * @param mPoint the mid Point of an interval
332: * @return the numerator or denominator of the prior equation
333: */
334: public final double calculatePriorSum(boolean weighted,
335: double mPoint) {
336:
337: double distr, sum = 0, max = logbinomialCoefficient(m_instances
338: .numAttributes(), (int) m_instances.numAttributes() / 2);
339:
340: for (int i = 1; i <= m_instances.numAttributes(); i++) {
341:
342: if (weighted) {
343: String key = (String.valueOf(mPoint)).concat(String
344: .valueOf((double) i));
345: Double hashValue = (Double) m_distribution.get(key);
346:
347: if (hashValue != null)
348: distr = hashValue.doubleValue();
349: else
350: distr = 0;
351: //distr = 1.0/m_numIntervals;
352: if (distr != 0) {
353: double addend = Utils.log2(distr)
354: - max
355: + Utils.log2((Math.pow(2, i) - 1))
356: + logbinomialCoefficient(m_instances
357: .numAttributes(), i);
358: sum = sum + Math.pow(2, addend);
359: }
360: } else {
361: double addend = Utils.log2((Math.pow(2, i) - 1))
362: - max
363: + logbinomialCoefficient(m_instances
364: .numAttributes(), i);
365: sum = sum + Math.pow(2, addend);
366: }
367: }
368: return sum;
369: }
370:
371: /**
372: * Method that calculates the base 2 logarithm of a binomial coefficient
373: * @param upperIndex upper Inedx of the binomial coefficient
374: * @param lowerIndex lower index of the binomial coefficient
375: * @return the base 2 logarithm of the binomial coefficient
376: */
377: public static final double logbinomialCoefficient(int upperIndex,
378: int lowerIndex) {
379:
380: double result = 1.0;
381: if (upperIndex == lowerIndex || lowerIndex == 0)
382: return result;
383: result = SpecialFunctions.log2Binomial((double) upperIndex,
384: (double) lowerIndex);
385: return result;
386: }
387:
388: /**
389: * Method to estimate the prior probabilities
390: * @throws Exception throws exception if the prior cannot be calculated
391: * @return a hashtable containing the prior probabilities
392: */
393: public final Hashtable estimatePrior() throws Exception {
394:
395: double distr, prior, denominator, mPoint;
396:
397: Hashtable m_priors = new Hashtable(m_numIntervals);
398: denominator = calculatePriorSum(false, 1.0);
399: generateDistribution();
400: for (int i = 0; i < m_numIntervals; i++) {
401: mPoint = m_midPoints[i];
402: prior = calculatePriorSum(true, mPoint) / denominator;
403: m_priors.put(new Double(mPoint), new Double(prior));
404: }
405: return m_priors;
406: }
407:
408: /**
409: * split the interval [0,1] into a predefined number of intervals and calculates their mid points
410: */
411: public final void midPoints() {
412:
413: m_midPoints = new double[m_numIntervals];
414: for (int i = 0; i < m_numIntervals; i++)
415: m_midPoints[i] = midPoint(1.0 / m_numIntervals, i);
416: }
417:
418: /**
419: * calculates the mid point of an interval
420: * @param size the size of each interval
421: * @param number the number of the interval.
422: * The intervals are numbered from 0 to m_numIntervals.
423: * @return the mid point of the interval
424: */
425: public double midPoint(double size, int number) {
426:
427: return (size * (double) number) + (size / 2.0);
428: }
429:
430: /**
431: * returns an ordered array of all mid points
432: * @return an ordered array of doubles conatining all midpoints
433: */
434: public final double[] getMidPoints() {
435:
436: return m_midPoints;
437: }
438:
439: /**
440: * splits an item set into premise and consequence and constructs therefore
441: * an association rule. The length of the premise is given. The attributes
442: * for premise and consequence are chosen randomly. The result is a RuleItem.
443: * @param premiseLength the length of the premise
444: * @param itemArray a (randomly generated) item set
445: * @return a randomly generated association rule stored in a RuleItem
446: */
447: public final RuleItem splitItemSet(int premiseLength,
448: int[] itemArray) {
449:
450: int[] cons = new int[m_instances.numAttributes()];
451: System.arraycopy(itemArray, 0, cons, 0, itemArray.length);
452: int help = premiseLength;
453: while (help > 0) {
454: int mark = m_randNum.nextInt(itemArray.length);
455: if (cons[mark] != -1) {
456: help--;
457: cons[mark] = -1;
458: }
459: }
460: if (premiseLength == 0)
461: for (int i = 0; i < itemArray.length; i++)
462: itemArray[i] = -1;
463: else
464: for (int i = 0; i < itemArray.length; i++)
465: if (cons[i] != -1)
466: itemArray[i] = -1;
467: ItemSet premise = new ItemSet(itemArray);
468: ItemSet consequence = new ItemSet(cons);
469: RuleItem current = new RuleItem();
470: current.m_premise = premise;
471: current.m_consequence = consequence;
472: return current;
473: }
474:
475: /**
476: * generates a class association rule out of a given premise.
477: * It randomly chooses a class label as consequence.
478: * @param itemArray the (randomly constructed) premise of the class association rule
479: * @return a class association rule stored in a RuleItem
480: */
481: public final RuleItem addCons(int[] itemArray) {
482:
483: ItemSet premise = new ItemSet(itemArray);
484: int[] cons = new int[itemArray.length];
485: for (int i = 0; i < itemArray.length; i++)
486: cons[i] = -1;
487: cons[m_instances.classIndex()] = m_randNum.nextInt((m_instances
488: .attribute(m_instances.classIndex())).numValues());
489: ItemSet consequence = new ItemSet(cons);
490: RuleItem current = new RuleItem();
491: current.m_premise = premise;
492: current.m_consequence = consequence;
493: return current;
494: }
495:
496: /**
497: * updates the support count of an item set
498: * @param itemSet the item set
499: */
500: public final void updateCounters(ItemSet itemSet) {
501:
502: for (int i = 0; i < m_instances.numInstances(); i++)
503: itemSet.upDateCounter(m_instances.instance(i));
504: }
505:
506: }
|