001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: package weka.classifiers.bayes.net.estimate;
018:
019: import weka.classifiers.bayes.BayesNet;
020: import weka.classifiers.bayes.net.search.local.K2;
021: import weka.core.Attribute;
022: import weka.core.FastVector;
023: import weka.core.Instance;
024: import weka.core.Instances;
025: import weka.core.Option;
026: import weka.core.Statistics;
027: import weka.core.Utils;
028: import weka.estimators.Estimator;
029:
030: import java.util.Enumeration;
031: import java.util.Vector;
032:
033: /**
034: <!-- globalinfo-start -->
035: * Multinomial BMA Estimator.
036: * <p/>
037: <!-- globalinfo-end -->
038: *
039: <!-- options-start -->
040: * Valid options are: <p/>
041: *
042: * <pre> -k2
043: * Whether to use K2 prior.
044: * </pre>
045: *
* <pre> -A &lt;alpha&gt;
047: * Initial count (alpha)
048: * </pre>
049: *
050: <!-- options-end -->
051: *
052: * @version $Revision: 1.7 $
053: * @author Remco Bouckaert (rrb@xm.co.nz)
054: */
055: public class MultiNomialBMAEstimator extends BayesNetEstimator {
056:
057: /** for serialization */
058: static final long serialVersionUID = 8330705772601586313L;
059:
060: /** whether to use K2 prior */
061: protected boolean m_bUseK2Prior = true;
062:
063: /**
064: * Returns a string describing this object
065: * @return a description of the classifier suitable for
066: * displaying in the explorer/experimenter gui
067: */
068: public String globalInfo() {
069: return "Multinomial BMA Estimator.";
070: }
071:
072: /**
073: * estimateCPTs estimates the conditional probability tables for the Bayes
074: * Net using the network structure.
075: *
076: * @param bayesNet the bayes net to use
077: * @throws Exception if number of parents doesn't fit (more than 1)
078: */
079: public void estimateCPTs(BayesNet bayesNet) throws Exception {
080: initCPTs(bayesNet);
081:
082: // sanity check to see if nodes have not more than one parent
083: for (int iAttribute = 0; iAttribute < bayesNet.m_Instances
084: .numAttributes(); iAttribute++) {
085: if (bayesNet.getParentSet(iAttribute).getNrOfParents() > 1) {
086: throw new Exception(
087: "Cannot handle networks with nodes with more than 1 parent (yet).");
088: }
089: }
090:
091: // filter data to binary
092: Instances instances = new Instances(bayesNet.m_Instances);
093: while (instances.numInstances() > 0) {
094: instances.delete(0);
095: }
096: for (int iAttribute = instances.numAttributes() - 1; iAttribute >= 0; iAttribute--) {
097: if (iAttribute != instances.classIndex()) {
098: FastVector values = new FastVector();
099: values.addElement("0");
100: values.addElement("1");
101: Attribute a = new Attribute(instances.attribute(
102: iAttribute).name(), (FastVector) values);
103: instances.deleteAttributeAt(iAttribute);
104: instances.insertAttributeAt(a, iAttribute);
105: }
106: }
107:
108: for (int iInstance = 0; iInstance < bayesNet.m_Instances
109: .numInstances(); iInstance++) {
110: Instance instanceOrig = bayesNet.m_Instances
111: .instance(iInstance);
112: Instance instance = new Instance(instances.numAttributes());
113: for (int iAttribute = 0; iAttribute < instances
114: .numAttributes(); iAttribute++) {
115: if (iAttribute != instances.classIndex()) {
116: if (instanceOrig.value(iAttribute) > 0) {
117: instance.setValue(iAttribute, 1);
118: }
119: } else {
120: instance.setValue(iAttribute, instanceOrig
121: .value(iAttribute));
122: }
123: }
124: }
125: // ok, now all data is binary, except the class attribute
126: // now learn the empty and tree network
127:
128: BayesNet EmptyNet = new BayesNet();
129: K2 oSearchAlgorithm = new K2();
130: oSearchAlgorithm.setInitAsNaiveBayes(false);
131: oSearchAlgorithm.setMaxNrOfParents(0);
132: EmptyNet.setSearchAlgorithm(oSearchAlgorithm);
133: EmptyNet.buildClassifier(instances);
134:
135: BayesNet NBNet = new BayesNet();
136: oSearchAlgorithm.setInitAsNaiveBayes(true);
137: oSearchAlgorithm.setMaxNrOfParents(1);
138: NBNet.setSearchAlgorithm(oSearchAlgorithm);
139: NBNet.buildClassifier(instances);
140:
141: // estimate CPTs
142: for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
143: if (iAttribute != instances.classIndex()) {
144: double w1 = 0.0, w2 = 0.0;
145: int nAttValues = instances.attribute(iAttribute)
146: .numValues();
147: if (m_bUseK2Prior == true) {
148: // use Cooper and Herskovitz's metric
149: for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
150: w1 += Statistics
151: .lnGamma(1 + ((DiscreteEstimatorBayes) EmptyNet.m_Distributions[iAttribute][0])
152: .getCount(iAttValue))
153: - Statistics.lnGamma(1);
154: }
155: w1 += Statistics.lnGamma(nAttValues)
156: - Statistics.lnGamma(nAttValues
157: + instances.numInstances());
158:
159: for (int iParent = 0; iParent < bayesNet
160: .getParentSet(iAttribute)
161: .getCardinalityOfParents(); iParent++) {
162: int nTotal = 0;
163: for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
164: double nCount = ((DiscreteEstimatorBayes) NBNet.m_Distributions[iAttribute][iParent])
165: .getCount(iAttValue);
166: w2 += Statistics.lnGamma(1 + nCount)
167: - Statistics.lnGamma(1);
168: nTotal += nCount;
169: }
170: w2 += Statistics.lnGamma(nAttValues)
171: - Statistics.lnGamma(nAttValues
172: + nTotal);
173: }
174: } else {
175: // use BDe metric
176: for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
177: w1 += Statistics
178: .lnGamma(1.0
179: / nAttValues
180: + ((DiscreteEstimatorBayes) EmptyNet.m_Distributions[iAttribute][0])
181: .getCount(iAttValue))
182: - Statistics.lnGamma(1.0 / nAttValues);
183: }
184: w1 += Statistics.lnGamma(1)
185: - Statistics.lnGamma(1 + instances
186: .numInstances());
187:
188: int nParentValues = bayesNet.getParentSet(
189: iAttribute).getCardinalityOfParents();
190: for (int iParent = 0; iParent < nParentValues; iParent++) {
191: int nTotal = 0;
192: for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
193: double nCount = ((DiscreteEstimatorBayes) NBNet.m_Distributions[iAttribute][iParent])
194: .getCount(iAttValue);
195: w2 += Statistics.lnGamma(1.0
196: / (nAttValues * nParentValues)
197: + nCount)
198: - Statistics
199: .lnGamma(1.0 / (nAttValues * nParentValues));
200: nTotal += nCount;
201: }
202: w2 += Statistics.lnGamma(1)
203: - Statistics.lnGamma(1 + nTotal);
204: }
205: }
206:
207: // System.out.println(w1 + " " + w2 + " " + (w2 - w1));
208: // normalize weigths
209: if (w1 < w2) {
210: w2 = w2 - w1;
211: w1 = 0;
212: w1 = 1 / (1 + Math.exp(w2));
213: w2 = Math.exp(w2) / (1 + Math.exp(w2));
214: } else {
215: w1 = w1 - w2;
216: w2 = 0;
217: w2 = 1 / (1 + Math.exp(w1));
218: w1 = Math.exp(w1) / (1 + Math.exp(w1));
219: }
220:
221: for (int iParent = 0; iParent < bayesNet.getParentSet(
222: iAttribute).getCardinalityOfParents(); iParent++) {
223: bayesNet.m_Distributions[iAttribute][iParent] = new DiscreteEstimatorFullBayes(
224: instances.attribute(iAttribute).numValues(),
225: w1,
226: w2,
227: (DiscreteEstimatorBayes) EmptyNet.m_Distributions[iAttribute][0],
228: (DiscreteEstimatorBayes) NBNet.m_Distributions[iAttribute][iParent],
229: m_fAlpha);
230: }
231: }
232: }
233: int iAttribute = instances.classIndex();
234: bayesNet.m_Distributions[iAttribute][0] = EmptyNet.m_Distributions[iAttribute][0];
235: } // estimateCPTs
236:
237: /**
238: * Updates the classifier with the given instance.
239: *
240: * @param bayesNet the bayes net to use
241: * @param instance the new training instance to include in the model
242: * @throws Exception if the instance could not be incorporated in
243: * the model.
244: */
245: public void updateClassifier(BayesNet bayesNet, Instance instance)
246: throws Exception {
247: throw new Exception(
248: "updateClassifier does not apply to BMA estimator");
249: } // updateClassifier
250:
251: /**
252: * initCPTs reserves space for CPTs and set all counts to zero
253: *
254: * @param bayesNet the bayes net to use
255: * @throws Exception doesn't apply
256: */
257: public void initCPTs(BayesNet bayesNet) throws Exception {
258: // Reserve sufficient memory
259: bayesNet.m_Distributions = new Estimator[bayesNet.m_Instances
260: .numAttributes()][2];
261: } // initCPTs
262:
263: /**
264: * @return boolean
265: */
266: public boolean isUseK2Prior() {
267: return m_bUseK2Prior;
268: }
269:
270: /**
271: * Sets the UseK2Prior.
272: *
273: * @param bUseK2Prior The bUseK2Prior to set
274: */
275: public void setUseK2Prior(boolean bUseK2Prior) {
276: m_bUseK2Prior = bUseK2Prior;
277: }
278:
279: /**
280: * Calculates the class membership probabilities for the given test
281: * instance.
282: *
283: * @param bayesNet the bayes net to use
284: * @param instance the instance to be classified
285: * @return predicted class probability distribution
286: * @throws Exception if there is a problem generating the prediction
287: */
288: public double[] distributionForInstance(BayesNet bayesNet,
289: Instance instance) throws Exception {
290: Instances instances = bayesNet.m_Instances;
291: int nNumClasses = instances.numClasses();
292: double[] fProbs = new double[nNumClasses];
293:
294: for (int iClass = 0; iClass < nNumClasses; iClass++) {
295: fProbs[iClass] = 1.0;
296: }
297:
298: for (int iClass = 0; iClass < nNumClasses; iClass++) {
299: double logfP = 0;
300:
301: for (int iAttribute = 0; iAttribute < instances
302: .numAttributes(); iAttribute++) {
303: double iCPT = 0;
304:
305: for (int iParent = 0; iParent < bayesNet.getParentSet(
306: iAttribute).getNrOfParents(); iParent++) {
307: int nParent = bayesNet.getParentSet(iAttribute)
308: .getParent(iParent);
309:
310: if (nParent == instances.classIndex()) {
311: iCPT = iCPT * nNumClasses + iClass;
312: } else {
313: iCPT = iCPT
314: * instances.attribute(nParent)
315: .numValues()
316: + instance.value(nParent);
317: }
318: }
319:
320: if (iAttribute == instances.classIndex()) {
321: logfP += Math
322: .log(bayesNet.m_Distributions[iAttribute][(int) iCPT]
323: .getProbability(iClass));
324: } else {
325: logfP += instance.value(iAttribute)
326: * Math
327: .log(bayesNet.m_Distributions[iAttribute][(int) iCPT]
328: .getProbability(instance
329: .value(1)));
330: }
331: }
332:
333: fProbs[iClass] += logfP;
334: }
335:
336: // Find maximum
337: double fMax = fProbs[0];
338: for (int iClass = 0; iClass < nNumClasses; iClass++) {
339: if (fProbs[iClass] > fMax) {
340: fMax = fProbs[iClass];
341: }
342: }
343: // transform from log-space to normal-space
344: for (int iClass = 0; iClass < nNumClasses; iClass++) {
345: fProbs[iClass] = Math.exp(fProbs[iClass] - fMax);
346: }
347:
348: // Display probabilities
349: Utils.normalize(fProbs);
350:
351: return fProbs;
352: } // distributionForInstance
353:
354: /**
355: * Returns an enumeration describing the available options
356: *
357: * @return an enumeration of all the available options
358: */
359: public Enumeration listOptions() {
360: Vector newVector = new Vector(1);
361:
362: newVector.addElement(new Option("\tWhether to use K2 prior.\n",
363: "k2", 0, "-k2"));
364:
365: Enumeration enu = super .listOptions();
366: while (enu.hasMoreElements()) {
367: newVector.addElement(enu.nextElement());
368: }
369:
370: return newVector.elements();
371: } // listOptions
372:
373: /**
374: * Parses a given list of options. <p/>
375: *
376: <!-- options-start -->
377: * Valid options are: <p/>
378: *
379: * <pre> -k2
380: * Whether to use K2 prior.
381: * </pre>
382: *
383: * <pre> -A <alpha>
384: * Initial count (alpha)
385: * </pre>
386: *
387: <!-- options-end -->
388: *
389: * @param options the list of options as an array of strings
390: * @throws Exception if an option is not supported
391: */
392: public void setOptions(String[] options) throws Exception {
393: setUseK2Prior(Utils.getFlag("k2", options));
394:
395: super .setOptions(options);
396: } // setOptions
397:
398: /**
399: * Gets the current settings of the classifier.
400: *
401: * @return an array of strings suitable for passing to setOptions
402: */
403: public String[] getOptions() {
404: String[] super Options = super .getOptions();
405: String[] options = new String[1 + super Options.length];
406: int current = 0;
407:
408: if (isUseK2Prior())
409: options[current++] = "-k2";
410:
411: // insert options from parent class
412: for (int iOption = 0; iOption < super Options.length; iOption++) {
413: options[current++] = super Options[iOption];
414: }
415:
416: // Fill up rest with empty strings, not nulls!
417: while (current < options.length) {
418: options[current++] = "";
419: }
420:
421: return options;
422: } // getOptions
423: } // class MultiNomialBMAEstimator
|