001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * HillClimber.java
019: * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.bayes.net.search.local;
024:
025: import weka.classifiers.bayes.BayesNet;
026: import weka.classifiers.bayes.net.ParentSet;
027: import weka.core.Instances;
028: import weka.core.Option;
029: import weka.core.Utils;
030:
031: import java.io.Serializable;
032: import java.util.Enumeration;
033: import java.util.Vector;
034:
035: /**
036: <!-- globalinfo-start -->
037: * This Bayes Network learning algorithm uses a hill climbing algorithm adding, deleting and reversing arcs. The search is not restricted by an order on the variables (unlike K2). The difference with B and B2 is that this hill climber also considers arrows part of the naive Bayes structure for deletion.
038: * <p/>
039: <!-- globalinfo-end -->
040: *
041: <!-- options-start -->
042: * Valid options are: <p/>
043: *
* <pre> -P &lt;nr of parents&gt;
045: * Maximum number of parents</pre>
046: *
047: * <pre> -R
048: * Use arc reversal operation.
049: * (default false)</pre>
050: *
051: * <pre> -N
052: * Initial structure is empty (instead of Naive Bayes)</pre>
053: *
054: * <pre> -mbc
055: * Applies a Markov Blanket correction to the network structure,
056: * after a network structure is learned. This ensures that all
057: * nodes in the network are part of the Markov blanket of the
058: * classifier node.</pre>
059: *
060: * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
061: * Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
062: *
063: <!-- options-end -->
064: *
065: * @author Remco Bouckaert (rrb@xm.co.nz)
066: * @version $Revision: 1.8 $
067: */
068: public class HillClimber extends LocalScoreSearchAlgorithm {
069:
070: /** for serialization */
071: static final long serialVersionUID = 4322783593818122403L;
072:
073: /** the Operation class contains info on operations performed
074: * on the current Bayesian network.
075: */
076: class Operation implements Serializable {
077:
078: /** for serialization */
079: static final long serialVersionUID = -4880888790432547895L;
080:
081: // constants indicating the type of an operation
082: final static int OPERATION_ADD = 0;
083: final static int OPERATION_DEL = 1;
084: final static int OPERATION_REVERSE = 2;
085:
086: /**
087: * c'tor
088: */
089: public Operation() {
090: }
091:
092: /** c'tor + initializers
093: *
094: * @param nTail
095: * @param nHead
096: * @param nOperation
097: */
098: public Operation(int nTail, int nHead, int nOperation) {
099: m_nHead = nHead;
100: m_nTail = nTail;
101: m_nOperation = nOperation;
102: }
103:
104: /** compare this operation with another
105: * @param other operation to compare with
106: * @return true if operation is the same
107: */
108: public boolean equals(Operation other) {
109: if (other == null) {
110: return false;
111: }
112: return ((m_nOperation == other.m_nOperation)
113: && (m_nHead == other.m_nHead) && (m_nTail == other.m_nTail));
114: } // equals
115:
116: /** number of the tail node **/
117: public int m_nTail;
118:
119: /** number of the head node **/
120: public int m_nHead;
121:
122: /** type of operation (ADD, DEL, REVERSE) **/
123: public int m_nOperation;
124:
125: /** change of score due to this operation **/
126: public double m_fDeltaScore = -1E100;
127: } // class Operation
128:
129: /** cache for remembering the change in score for steps in the search space
130: */
131: class Cache {
132: /** change in score due to adding an arc **/
133: double[][] m_fDeltaScoreAdd;
134: /** change in score due to deleting an arc **/
135: double[][] m_fDeltaScoreDel;
136:
137: /** c'tor
138: * @param nNrOfNodes number of nodes in network, used to determine memory size to reserve
139: */
140: Cache(int nNrOfNodes) {
141: m_fDeltaScoreAdd = new double[nNrOfNodes][nNrOfNodes];
142: m_fDeltaScoreDel = new double[nNrOfNodes][nNrOfNodes];
143: }
144:
145: /** set cache entry
146: * @param oOperation operation to perform
147: * @param fValue value to put in cache
148: */
149: public void put(Operation oOperation, double fValue) {
150: if (oOperation.m_nOperation == Operation.OPERATION_ADD) {
151: m_fDeltaScoreAdd[oOperation.m_nTail][oOperation.m_nHead] = fValue;
152: } else {
153: m_fDeltaScoreDel[oOperation.m_nTail][oOperation.m_nHead] = fValue;
154: }
155: } // put
156:
157: /** get cache entry
158: * @param oOperation operation to perform
159: * @return cache value
160: */
161: public double get(Operation oOperation) {
162: switch (oOperation.m_nOperation) {
163: case Operation.OPERATION_ADD:
164: return m_fDeltaScoreAdd[oOperation.m_nTail][oOperation.m_nHead];
165: case Operation.OPERATION_DEL:
166: return m_fDeltaScoreDel[oOperation.m_nTail][oOperation.m_nHead];
167: case Operation.OPERATION_REVERSE:
168: return m_fDeltaScoreDel[oOperation.m_nTail][oOperation.m_nHead]
169: + m_fDeltaScoreAdd[oOperation.m_nHead][oOperation.m_nTail];
170: }
171: // should never get here
172: return 0;
173: } // get
174: } // class Cache
175:
/**
 * Cache for storing score differences of candidate arc additions and
 * deletions. It is allocated by initCache and reset to null once search
 * has finished.
 **/
Cache m_Cache = null;

/** use the arc reversal operator; when false, getOptimalOperation only considers additions and deletions **/
boolean m_bUseArcReversal = false;
181:
182: /**
183: * search determines the network structure/graph of the network
184: * with the Taby algorithm.
185: *
186: * @param bayesNet the network to use
187: * @param instances the data to use
188: * @throws Exception if something goes wrong
189: */
190: protected void search(BayesNet bayesNet, Instances instances)
191: throws Exception {
192: initCache(bayesNet, instances);
193:
194: // go do the search
195: Operation oOperation = getOptimalOperation(bayesNet, instances);
196: while ((oOperation != null) && (oOperation.m_fDeltaScore > 0)) {
197: performOperation(bayesNet, instances, oOperation);
198: oOperation = getOptimalOperation(bayesNet, instances);
199: }
200:
201: // free up memory
202: m_Cache = null;
203: } // search
204:
205: /**
206: * initCache initializes the cache
207: *
208: * @param bayesNet Bayes network to be learned
209: * @param instances data set to learn from
210: * @throws Exception if something goes wrong
211: */
212: void initCache(BayesNet bayesNet, Instances instances)
213: throws Exception {
214:
215: // determine base scores
216: double[] fBaseScores = new double[instances.numAttributes()];
217: int nNrOfAtts = instances.numAttributes();
218:
219: m_Cache = new Cache(nNrOfAtts);
220:
221: for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) {
222: updateCache(iAttribute, nNrOfAtts, bayesNet
223: .getParentSet(iAttribute));
224: }
225:
226: for (int iAttribute = 0; iAttribute < nNrOfAtts; iAttribute++) {
227: fBaseScores[iAttribute] = calcNodeScore(iAttribute);
228: }
229:
230: for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) {
231: for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
232: if (iAttributeHead != iAttributeTail) {
233: Operation oOperation = new Operation(
234: iAttributeTail, iAttributeHead,
235: Operation.OPERATION_ADD);
236: m_Cache.put(oOperation, calcScoreWithExtraParent(
237: iAttributeHead, iAttributeTail)
238: - fBaseScores[iAttributeHead]);
239: }
240: }
241: }
242:
243: } // initCache
244:
/** check whether the operation is allowed, i.e. not forbidden by a tabu list.
 * For the base hill climber there are no restrictions on operations,
 * so this always returns true. The findBestArcTo... methods consult this
 * check before accepting a candidate as the new best operation.
 * @param oOperation operation to be checked (not inspected by this implementation)
 * @return true if operation is not in the tabu list; always true here
 */
boolean isNotTabu(Operation oOperation) {
return true;
} // isNotTabu
254:
255: /**
256: * getOptimalOperation finds the optimal operation that can be performed
257: * on the Bayes network that is not in the tabu list.
258: *
259: * @param bayesNet Bayes network to apply operation on
260: * @param instances data set to learn from
261: * @return optimal operation found
262: * @throws Exception if something goes wrong
263: */
264: Operation getOptimalOperation(BayesNet bayesNet, Instances instances)
265: throws Exception {
266: Operation oBestOperation = new Operation();
267:
268: // Add???
269: oBestOperation = findBestArcToAdd(bayesNet, instances,
270: oBestOperation);
271: // Delete???
272: oBestOperation = findBestArcToDelete(bayesNet, instances,
273: oBestOperation);
274: // Reverse???
275: if (getUseArcReversal()) {
276: oBestOperation = findBestArcToReverse(bayesNet, instances,
277: oBestOperation);
278: }
279:
280: // did we find something?
281: if (oBestOperation.m_fDeltaScore == -1E100) {
282: return null;
283: }
284:
285: return oBestOperation;
286: } // getOptimalOperation
287:
288: /**
289: * performOperation applies an operation
290: * on the Bayes network and update the cache.
291: *
292: * @param bayesNet Bayes network to apply operation on
293: * @param instances data set to learn from
294: * @param oOperation operation to perform
295: * @throws Exception if something goes wrong
296: */
297: void performOperation(BayesNet bayesNet, Instances instances,
298: Operation oOperation) throws Exception {
299: // perform operation
300: switch (oOperation.m_nOperation) {
301: case Operation.OPERATION_ADD:
302: applyArcAddition(bayesNet, oOperation.m_nHead,
303: oOperation.m_nTail, instances);
304: if (bayesNet.getDebug()) {
305: System.out.print("Add " + oOperation.m_nHead + " -> "
306: + oOperation.m_nTail);
307: }
308: break;
309: case Operation.OPERATION_DEL:
310: applyArcDeletion(bayesNet, oOperation.m_nHead,
311: oOperation.m_nTail, instances);
312: if (bayesNet.getDebug()) {
313: System.out.print("Del " + oOperation.m_nHead + " -> "
314: + oOperation.m_nTail);
315: }
316: break;
317: case Operation.OPERATION_REVERSE:
318: applyArcDeletion(bayesNet, oOperation.m_nHead,
319: oOperation.m_nTail, instances);
320: applyArcAddition(bayesNet, oOperation.m_nTail,
321: oOperation.m_nHead, instances);
322: if (bayesNet.getDebug()) {
323: System.out.print("Rev " + oOperation.m_nHead + " -> "
324: + oOperation.m_nTail);
325: }
326: break;
327: }
328: } // performOperation
329:
330: /**
331: *
332: * @param bayesNet
333: * @param iHead
334: * @param iTail
335: * @param instances
336: */
337: void applyArcAddition(BayesNet bayesNet, int iHead, int iTail,
338: Instances instances) {
339: ParentSet bestParentSet = bayesNet.getParentSet(iHead);
340: bestParentSet.addParent(iTail, instances);
341: updateCache(iHead, instances.numAttributes(), bestParentSet);
342: } // applyArcAddition
343:
344: /**
345: *
346: * @param bayesNet
347: * @param iHead
348: * @param iTail
349: * @param instances
350: */
351: void applyArcDeletion(BayesNet bayesNet, int iHead, int iTail,
352: Instances instances) {
353: ParentSet bestParentSet = bayesNet.getParentSet(iHead);
354: bestParentSet.deleteParent(iTail, instances);
355: updateCache(iHead, instances.numAttributes(), bestParentSet);
356: } // applyArcAddition
357:
358: /**
359: * find best (or least bad) arc addition operation
360: *
361: * @param bayesNet Bayes network to add arc to
362: * @param instances data set
363: * @param oBestOperation
364: * @return Operation containing best arc to add, or null if no arc addition is allowed
365: * (this can happen if any arc addition introduces a cycle, or all parent sets are filled
366: * up to the maximum nr of parents).
367: */
368: Operation findBestArcToAdd(BayesNet bayesNet, Instances instances,
369: Operation oBestOperation) {
370: int nNrOfAtts = instances.numAttributes();
371: // find best arc to add
372: for (int iAttributeHead = 0; iAttributeHead < nNrOfAtts; iAttributeHead++) {
373: if (bayesNet.getParentSet(iAttributeHead).getNrOfParents() < m_nMaxNrOfParents) {
374: for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
375: if (addArcMakesSense(bayesNet, instances,
376: iAttributeHead, iAttributeTail)) {
377: Operation oOperation = new Operation(
378: iAttributeTail, iAttributeHead,
379: Operation.OPERATION_ADD);
380: if (m_Cache.get(oOperation) > oBestOperation.m_fDeltaScore) {
381: if (isNotTabu(oOperation)) {
382: oBestOperation = oOperation;
383: oBestOperation.m_fDeltaScore = m_Cache
384: .get(oOperation);
385: }
386: }
387: }
388: }
389: }
390: }
391: return oBestOperation;
392: } // findBestArcToAdd
393:
394: /**
395: * find best (or least bad) arc deletion operation
396: *
397: * @param bayesNet Bayes network to delete arc from
398: * @param instances data set
399: * @param oBestOperation
400: * @return Operation containing best arc to delete, or null if no deletion can be made
401: * (happens when there is no arc in the network yet).
402: */
403: Operation findBestArcToDelete(BayesNet bayesNet,
404: Instances instances, Operation oBestOperation) {
405: int nNrOfAtts = instances.numAttributes();
406: // find best arc to delete
407: for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
408: ParentSet parentSet = bayesNet.getParentSet(iNode);
409: for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
410: Operation oOperation = new Operation(parentSet
411: .getParent(iParent), iNode,
412: Operation.OPERATION_DEL);
413: if (m_Cache.get(oOperation) > oBestOperation.m_fDeltaScore) {
414: if (isNotTabu(oOperation)) {
415: oBestOperation = oOperation;
416: oBestOperation.m_fDeltaScore = m_Cache
417: .get(oOperation);
418: }
419: }
420: }
421: }
422: return oBestOperation;
423: } // findBestArcToDelete
424:
425: /**
426: * find best (or least bad) arc reversal operation
427: *
428: * @param bayesNet Bayes network to reverse arc in
429: * @param instances data set
430: * @param oBestOperation
431: * @return Operation containing best arc to reverse, or null if no reversal is allowed
432: * (happens if there is no arc in the network yet, or when any such reversal introduces
433: * a cycle).
434: */
435: Operation findBestArcToReverse(BayesNet bayesNet,
436: Instances instances, Operation oBestOperation) {
437: int nNrOfAtts = instances.numAttributes();
438: // find best arc to reverse
439: for (int iNode = 0; iNode < nNrOfAtts; iNode++) {
440: ParentSet parentSet = bayesNet.getParentSet(iNode);
441: for (int iParent = 0; iParent < parentSet.getNrOfParents(); iParent++) {
442: int iTail = parentSet.getParent(iParent);
443: // is reversal allowed?
444: if (reverseArcMakesSense(bayesNet, instances, iNode,
445: iTail)
446: && bayesNet.getParentSet(iTail)
447: .getNrOfParents() < m_nMaxNrOfParents) {
448: // go check if reversal results in the best step forward
449: Operation oOperation = new Operation(parentSet
450: .getParent(iParent), iNode,
451: Operation.OPERATION_REVERSE);
452: if (m_Cache.get(oOperation) > oBestOperation.m_fDeltaScore) {
453: if (isNotTabu(oOperation)) {
454: oBestOperation = oOperation;
455: oBestOperation.m_fDeltaScore = m_Cache
456: .get(oOperation);
457: }
458: }
459: }
460: }
461: }
462: return oBestOperation;
463: } // findBestArcToReverse
464:
465: /**
466: * update the cache due to change of parent set of a node
467: *
468: * @param iAttributeHead node that has its parent set changed
469: * @param nNrOfAtts number of nodes/attributes in data set
470: * @param parentSet new parents set of node iAttributeHead
471: */
472: void updateCache(int iAttributeHead, int nNrOfAtts,
473: ParentSet parentSet) {
474: // update cache entries for arrows heading towards iAttributeHead
475: double fBaseScore = calcNodeScore(iAttributeHead);
476: int nNrOfParents = parentSet.getNrOfParents();
477: for (int iAttributeTail = 0; iAttributeTail < nNrOfAtts; iAttributeTail++) {
478: if (iAttributeTail != iAttributeHead) {
479: if (!parentSet.contains(iAttributeTail)) {
480: // add entries to cache for adding arcs
481: if (nNrOfParents < m_nMaxNrOfParents) {
482: Operation oOperation = new Operation(
483: iAttributeTail, iAttributeHead,
484: Operation.OPERATION_ADD);
485: m_Cache.put(oOperation,
486: calcScoreWithExtraParent(
487: iAttributeHead, iAttributeTail)
488: - fBaseScore);
489: }
490: } else {
491: // add entries to cache for deleting arcs
492: Operation oOperation = new Operation(
493: iAttributeTail, iAttributeHead,
494: Operation.OPERATION_DEL);
495: m_Cache.put(oOperation, calcScoreWithMissingParent(
496: iAttributeHead, iAttributeTail)
497: - fBaseScore);
498: }
499: }
500: }
501: } // updateCache
502:
503: /**
504: * Sets the max number of parents
505: *
506: * @param nMaxNrOfParents the max number of parents
507: */
508: public void setMaxNrOfParents(int nMaxNrOfParents) {
509: m_nMaxNrOfParents = nMaxNrOfParents;
510: }
511:
512: /**
513: * Gets the max number of parents.
514: *
515: * @return the max number of parents
516: */
517: public int getMaxNrOfParents() {
518: return m_nMaxNrOfParents;
519: }
520:
521: /**
522: * Returns an enumeration describing the available options.
523: *
524: * @return an enumeration of all the available options.
525: */
526: public Enumeration listOptions() {
527: Vector newVector = new Vector(2);
528:
529: newVector.addElement(new Option("\tMaximum number of parents",
530: "P", 1, "-P <nr of parents>"));
531: newVector.addElement(new Option(
532: "\tUse arc reversal operation.\n\t(default false)",
533: "R", 0, "-R"));
534: newVector
535: .addElement(new Option(
536: "\tInitial structure is empty (instead of Naive Bayes)",
537: "N", 0, "-N"));
538:
539: Enumeration enu = super .listOptions();
540: while (enu.hasMoreElements()) {
541: newVector.addElement(enu.nextElement());
542: }
543: return newVector.elements();
544: } // listOptions
545:
546: /**
547: * Parses a given list of options. <p/>
548: *
549: <!-- options-start -->
550: * Valid options are: <p/>
551: *
552: * <pre> -P <nr of parents>
553: * Maximum number of parents</pre>
554: *
555: * <pre> -R
556: * Use arc reversal operation.
557: * (default false)</pre>
558: *
559: * <pre> -N
560: * Initial structure is empty (instead of Naive Bayes)</pre>
561: *
562: * <pre> -mbc
563: * Applies a Markov Blanket correction to the network structure,
564: * after a network structure is learned. This ensures that all
565: * nodes in the network are part of the Markov blanket of the
566: * classifier node.</pre>
567: *
568: * <pre> -S [BAYES|MDL|ENTROPY|AIC|CROSS_CLASSIC|CROSS_BAYES]
569: * Score type (BAYES, BDeu, MDL, ENTROPY and AIC)</pre>
570: *
571: <!-- options-end -->
572: *
573: * @param options the list of options as an array of strings
574: * @throws Exception if an option is not supported
575: */
576: public void setOptions(String[] options) throws Exception {
577: setUseArcReversal(Utils.getFlag('R', options));
578:
579: setInitAsNaiveBayes(!(Utils.getFlag('N', options)));
580:
581: String sMaxNrOfParents = Utils.getOption('P', options);
582: if (sMaxNrOfParents.length() != 0) {
583: setMaxNrOfParents(Integer.parseInt(sMaxNrOfParents));
584: } else {
585: setMaxNrOfParents(100000);
586: }
587:
588: super .setOptions(options);
589: } // setOptions
590:
591: /**
592: * Gets the current settings of the search algorithm.
593: *
594: * @return an array of strings suitable for passing to setOptions
595: */
596: public String[] getOptions() {
597: String[] super Options = super .getOptions();
598: String[] options = new String[7 + super Options.length];
599: int current = 0;
600: if (getUseArcReversal()) {
601: options[current++] = "-R";
602: }
603:
604: if (!getInitAsNaiveBayes()) {
605: options[current++] = "-N";
606: }
607:
608: options[current++] = "-P";
609: options[current++] = "" + m_nMaxNrOfParents;
610:
611: // insert options from parent class
612: for (int iOption = 0; iOption < super Options.length; iOption++) {
613: options[current++] = super Options[iOption];
614: }
615:
616: // Fill up rest with empty strings, not nulls!
617: while (current < options.length) {
618: options[current++] = "";
619: }
620: return options;
621: } // getOptions
622:
623: /**
624: * Sets whether to init as naive bayes
625: *
626: * @param bInitAsNaiveBayes whether to init as naive bayes
627: */
628: public void setInitAsNaiveBayes(boolean bInitAsNaiveBayes) {
629: m_bInitAsNaiveBayes = bInitAsNaiveBayes;
630: }
631:
632: /**
633: * Gets whether to init as naive bayes
634: *
635: * @return whether to init as naive bayes
636: */
637: public boolean getInitAsNaiveBayes() {
638: return m_bInitAsNaiveBayes;
639: }
640:
641: /** get use the arc reversal operation
642: * @return whether the arc reversal operation should be used
643: */
644: public boolean getUseArcReversal() {
645: return m_bUseArcReversal;
646: } // getUseArcReversal
647:
648: /** set use the arc reversal operation
649: * @param bUseArcReversal whether the arc reversal operation should be used
650: */
651: public void setUseArcReversal(boolean bUseArcReversal) {
652: m_bUseArcReversal = bUseArcReversal;
653: } // setUseArcReversal
654:
655: /**
656: * This will return a string describing the search algorithm.
657: * @return The string.
658: */
659: public String globalInfo() {
660: return "This Bayes Network learning algorithm uses a hill climbing algorithm "
661: + "adding, deleting and reversing arcs. The search is not restricted by an order "
662: + "on the variables (unlike K2). The difference with B and B2 is that this hill "
663: + "climber also considers arrows part of the naive Bayes structure for deletion.";
664: } // globalInfo
665:
666: /**
667: * @return a string to describe the Use Arc Reversal option.
668: */
669: public String useArcReversalTipText() {
670: return "When set to true, the arc reversal operation is used in the search.";
671: } // useArcReversalTipText
672:
673: } // HillClimber
|