/*
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
016:
/*
 * BMAEstimator.java
 * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
 *
 */
022:
023: package weka.classifiers.bayes.net.estimate;
024:
025: import weka.classifiers.bayes.BayesNet;
026: import weka.classifiers.bayes.net.search.local.K2;
027: import weka.core.Instance;
028: import weka.core.Instances;
029: import weka.core.Option;
030: import weka.core.Statistics;
031: import weka.core.Utils;
032: import weka.estimators.Estimator;
033:
034: import java.util.Enumeration;
035: import java.util.Vector;
036:
037: /**
038: <!-- globalinfo-start -->
039: * BMAEstimator estimates conditional probability tables of a Bayes network using Bayes Model Averaging (BMA).
040: * <p/>
041: <!-- globalinfo-end -->
042: *
043: <!-- options-start -->
044: * Valid options are: <p/>
045: *
046: * <pre> -k2
047: * Whether to use K2 prior.
048: * </pre>
049: *
050: * <pre> -A <alpha>
051: * Initial count (alpha)
052: * </pre>
053: *
054: <!-- options-end -->
055: *
056: * @author Remco Bouckaert (rrb@xm.co.nz)
057: * @version $Revision: 1.7 $
058: */
059: public class BMAEstimator extends SimpleEstimator {
060:
061: /** for serialization */
062: static final long serialVersionUID = -1846028304233257309L;
063:
064: /** whether to use K2 prior */
065: protected boolean m_bUseK2Prior = false;
066:
067: /**
068: * Returns a string describing this object
069: * @return a description of the classifier suitable for
070: * displaying in the explorer/experimenter gui
071: */
072: public String globalInfo() {
073: return "BMAEstimator estimates conditional probability tables of a Bayes "
074: + "network using Bayes Model Averaging (BMA).";
075: }
076:
077: /**
078: * estimateCPTs estimates the conditional probability tables for the Bayes
079: * Net using the network structure.
080: *
081: * @param bayesNet the bayes net to use
082: * @throws Exception if an error occurs
083: */
084: public void estimateCPTs(BayesNet bayesNet) throws Exception {
085: initCPTs(bayesNet);
086:
087: Instances instances = bayesNet.m_Instances;
088: // sanity check to see if nodes have not more than one parent
089: for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
090: if (bayesNet.getParentSet(iAttribute).getNrOfParents() > 1) {
091: throw new Exception(
092: "Cannot handle networks with nodes with more than 1 parent (yet).");
093: }
094: }
095:
096: BayesNet EmptyNet = new BayesNet();
097: K2 oSearchAlgorithm = new K2();
098: oSearchAlgorithm.setInitAsNaiveBayes(false);
099: oSearchAlgorithm.setMaxNrOfParents(0);
100: EmptyNet.setSearchAlgorithm(oSearchAlgorithm);
101: EmptyNet.buildClassifier(instances);
102:
103: BayesNet NBNet = new BayesNet();
104: oSearchAlgorithm.setInitAsNaiveBayes(true);
105: oSearchAlgorithm.setMaxNrOfParents(1);
106: NBNet.setSearchAlgorithm(oSearchAlgorithm);
107: NBNet.buildClassifier(instances);
108:
109: // estimate CPTs
110: for (int iAttribute = 0; iAttribute < instances.numAttributes(); iAttribute++) {
111: if (iAttribute != instances.classIndex()) {
112: double w1 = 0.0, w2 = 0.0;
113: int nAttValues = instances.attribute(iAttribute)
114: .numValues();
115: if (m_bUseK2Prior == true) {
116: // use Cooper and Herskovitz's metric
117: for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
118: w1 += Statistics
119: .lnGamma(1 + ((DiscreteEstimatorBayes) EmptyNet.m_Distributions[iAttribute][0])
120: .getCount(iAttValue))
121: - Statistics.lnGamma(1);
122: }
123: w1 += Statistics.lnGamma(nAttValues)
124: - Statistics.lnGamma(nAttValues
125: + instances.numInstances());
126:
127: for (int iParent = 0; iParent < bayesNet
128: .getParentSet(iAttribute)
129: .getCardinalityOfParents(); iParent++) {
130: int nTotal = 0;
131: for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
132: double nCount = ((DiscreteEstimatorBayes) NBNet.m_Distributions[iAttribute][iParent])
133: .getCount(iAttValue);
134: w2 += Statistics.lnGamma(1 + nCount)
135: - Statistics.lnGamma(1);
136: nTotal += nCount;
137: }
138: w2 += Statistics.lnGamma(nAttValues)
139: - Statistics.lnGamma(nAttValues
140: + nTotal);
141: }
142: } else {
143: // use BDe metric
144: for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
145: w1 += Statistics
146: .lnGamma(1.0
147: / nAttValues
148: + ((DiscreteEstimatorBayes) EmptyNet.m_Distributions[iAttribute][0])
149: .getCount(iAttValue))
150: - Statistics.lnGamma(1.0 / nAttValues);
151: }
152: w1 += Statistics.lnGamma(1)
153: - Statistics.lnGamma(1 + instances
154: .numInstances());
155:
156: int nParentValues = bayesNet.getParentSet(
157: iAttribute).getCardinalityOfParents();
158: for (int iParent = 0; iParent < nParentValues; iParent++) {
159: int nTotal = 0;
160: for (int iAttValue = 0; iAttValue < nAttValues; iAttValue++) {
161: double nCount = ((DiscreteEstimatorBayes) NBNet.m_Distributions[iAttribute][iParent])
162: .getCount(iAttValue);
163: w2 += Statistics.lnGamma(1.0
164: / (nAttValues * nParentValues)
165: + nCount)
166: - Statistics
167: .lnGamma(1.0 / (nAttValues * nParentValues));
168: nTotal += nCount;
169: }
170: w2 += Statistics.lnGamma(1)
171: - Statistics.lnGamma(1 + nTotal);
172: }
173: }
174:
175: // System.out.println(w1 + " " + w2 + " " + (w2 - w1));
176: if (w1 < w2) {
177: w2 = w2 - w1;
178: w1 = 0;
179: w1 = 1 / (1 + Math.exp(w2));
180: w2 = Math.exp(w2) / (1 + Math.exp(w2));
181: } else {
182: w1 = w1 - w2;
183: w2 = 0;
184: w2 = 1 / (1 + Math.exp(w1));
185: w1 = Math.exp(w1) / (1 + Math.exp(w1));
186: }
187:
188: for (int iParent = 0; iParent < bayesNet.getParentSet(
189: iAttribute).getCardinalityOfParents(); iParent++) {
190: bayesNet.m_Distributions[iAttribute][iParent] = new DiscreteEstimatorFullBayes(
191: instances.attribute(iAttribute).numValues(),
192: w1,
193: w2,
194: (DiscreteEstimatorBayes) EmptyNet.m_Distributions[iAttribute][0],
195: (DiscreteEstimatorBayes) NBNet.m_Distributions[iAttribute][iParent],
196: m_fAlpha);
197: }
198: }
199: }
200: int iAttribute = instances.classIndex();
201: bayesNet.m_Distributions[iAttribute][0] = EmptyNet.m_Distributions[iAttribute][0];
202: } // estimateCPTs
203:
204: /**
205: * Updates the classifier with the given instance.
206: *
207: * @param bayesNet the bayes net to use
208: * @param instance the new training instance to include in the model
209: * @throws Exception if the instance could not be incorporated in
210: * the model.
211: */
212: public void updateClassifier(BayesNet bayesNet, Instance instance)
213: throws Exception {
214: throw new Exception(
215: "updateClassifier does not apply to BMA estimator");
216: } // updateClassifier
217:
218: /**
219: * initCPTs reserves space for CPTs and set all counts to zero
220: *
221: * @param bayesNet the bayes net to use
222: * @throws Exception if something goes wrong
223: */
224: public void initCPTs(BayesNet bayesNet) throws Exception {
225: // Reserve space for CPTs
226: int nMaxParentCardinality = 1;
227:
228: for (int iAttribute = 0; iAttribute < bayesNet.m_Instances
229: .numAttributes(); iAttribute++) {
230: if (bayesNet.getParentSet(iAttribute)
231: .getCardinalityOfParents() > nMaxParentCardinality) {
232: nMaxParentCardinality = bayesNet.getParentSet(
233: iAttribute).getCardinalityOfParents();
234: }
235: }
236:
237: // Reserve plenty of memory
238: bayesNet.m_Distributions = new Estimator[bayesNet.m_Instances
239: .numAttributes()][nMaxParentCardinality];
240: } // initCPTs
241:
242: /**
243: * Returns whether K2 prior is used
244: *
245: * @return true if K2 prior is used
246: */
247: public boolean isUseK2Prior() {
248: return m_bUseK2Prior;
249: }
250:
251: /**
252: * Sets the UseK2Prior.
253: *
254: * @param bUseK2Prior The bUseK2Prior to set
255: */
256: public void setUseK2Prior(boolean bUseK2Prior) {
257: m_bUseK2Prior = bUseK2Prior;
258: }
259:
260: /**
261: * Returns an enumeration describing the available options
262: *
263: * @return an enumeration of all the available options
264: */
265: public Enumeration listOptions() {
266: Vector newVector = new Vector(1);
267:
268: newVector.addElement(new Option("\tWhether to use K2 prior.\n",
269: "k2", 0, "-k2"));
270:
271: Enumeration enu = super .listOptions();
272: while (enu.hasMoreElements()) {
273: newVector.addElement(enu.nextElement());
274: }
275:
276: return newVector.elements();
277: } // listOptions
278:
279: /**
280: * Parses a given list of options. <p/>
281: *
282: <!-- options-start -->
283: * Valid options are: <p/>
284: *
285: * <pre> -k2
286: * Whether to use K2 prior.
287: * </pre>
288: *
289: * <pre> -A <alpha>
290: * Initial count (alpha)
291: * </pre>
292: *
293: <!-- options-end -->
294: *
295: * @param options the list of options as an array of strings
296: * @throws Exception if an option is not supported
297: */
298: public void setOptions(String[] options) throws Exception {
299: setUseK2Prior(Utils.getFlag("k2", options));
300:
301: super .setOptions(options);
302: } // setOptions
303:
304: /**
305: * Gets the current settings of the classifier.
306: *
307: * @return an array of strings suitable for passing to setOptions
308: */
309: public String[] getOptions() {
310: String[] super Options = super .getOptions();
311: String[] options = new String[1 + super Options.length];
312: int current = 0;
313:
314: if (isUseK2Prior())
315: options[current++] = "-k2";
316:
317: // insert options from parent class
318: for (int iOption = 0; iOption < super Options.length; iOption++) {
319: options[current++] = super Options[iOption];
320: }
321:
322: // Fill up rest with empty strings, not nulls!
323: while (current < options.length) {
324: options[current++] = "";
325: }
326:
327: return options;
328: } // getOptions
329: } // class BMAEstimator
|