/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */
018:
019: /*
020: * NBTreeClassifierTree.java
021: * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
022: *
023: */
024:
025: package weka.classifiers.trees.j48;
026:
027: import weka.core.Capabilities;
028: import weka.core.Instances;
029: import weka.core.Capabilities.Capability;
030:
031: /**
032: * Class for handling a naive bayes tree structure used for
033: * classification.
034: *
035: * @author Mark Hall (mhall@cs.waikato.ac.nz)
036: * @version $Revision: 1.5 $
037: */
038: public class NBTreeClassifierTree extends ClassifierTree {
039:
040: /** for serialization */
041: private static final long serialVersionUID = -4472639447877404786L;
042:
043: public NBTreeClassifierTree(ModelSelection toSelectLocModel) {
044: super (toSelectLocModel);
045: }
046:
047: /**
048: * Returns default capabilities of the classifier tree.
049: *
050: * @return the capabilities of this classifier tree
051: */
052: public Capabilities getCapabilities() {
053: Capabilities result = super .getCapabilities();
054:
055: // attributes
056: result.enable(Capability.NOMINAL_ATTRIBUTES);
057: result.enable(Capability.NUMERIC_ATTRIBUTES);
058: result.enable(Capability.DATE_ATTRIBUTES);
059: result.enable(Capability.MISSING_VALUES);
060:
061: // class
062: result.enable(Capability.NOMINAL_CLASS);
063: result.enable(Capability.MISSING_CLASS_VALUES);
064:
065: // instances
066: result.setMinimumNumberInstances(0);
067:
068: return result;
069: }
070:
071: /**
072: * Method for building a naive bayes classifier tree
073: *
074: * @exception Exception if something goes wrong
075: */
076: public void buildClassifier(Instances data) throws Exception {
077: super .buildClassifier(data);
078: cleanup(new Instances(data, 0));
079: assignIDs(-1);
080: }
081:
082: /**
083: * Assigns a uniqe id to every node in the tree.
084: *
085: public int assignIDs(int lastID) {
086:
087: int currLastID = lastID + 1;
088:
089: m_id = currLastID;
090: if (m_sons != null) {
091: for (int i = 0; i < m_sons.length; i++) {
092: currLastID = m_sons[i].assignIDs(currLastID);
093: }
094: }
095: return currLastID;
096: } */
097:
098: /**
099: * Returns a newly created tree.
100: *
101: * @param data the training data
102: * @exception Exception if something goes wrong
103: */
104: protected ClassifierTree getNewTree(Instances data)
105: throws Exception {
106:
107: ClassifierTree newTree = new NBTreeClassifierTree(
108: m_toSelectModel);
109: newTree.buildTree(data, false);
110:
111: return newTree;
112: }
113:
114: /**
115: * Returns a newly created tree.
116: *
117: * @param train the training data
118: * @param test the pruning data.
119: * @exception Exception if something goes wrong
120: */
121: protected ClassifierTree getNewTree(Instances train, Instances test)
122: throws Exception {
123:
124: ClassifierTree newTree = new NBTreeClassifierTree(
125: m_toSelectModel);
126: newTree.buildTree(train, test, false);
127:
128: return newTree;
129: }
130:
131: /**
132: * Print the models at the leaves
133: *
134: * @return textual description of the leaf models
135: */
136: public String printLeafModels() {
137: StringBuffer text = new StringBuffer();
138:
139: if (m_isLeaf) {
140: text.append("\nLeaf number: " + m_id + " ");
141: text.append(m_localModel.toString());
142: text.append("\n");
143: } else {
144: for (int i = 0; i < m_sons.length; i++) {
145: text.append(((NBTreeClassifierTree) m_sons[i])
146: .printLeafModels());
147: }
148: }
149: return text.toString();
150: }
151:
152: /**
153: * Prints tree structure.
154: */
155: public String toString() {
156:
157: try {
158: StringBuffer text = new StringBuffer();
159:
160: if (m_isLeaf) {
161: text.append(": NB");
162: text.append(m_id);
163: } else
164: dumpTreeNB(0, text);
165:
166: text.append("\n" + printLeafModels());
167: text.append("\n\nNumber of Leaves : \t" + numLeaves()
168: + "\n");
169: text.append("\nSize of the tree : \t" + numNodes() + "\n");
170:
171: return text.toString();
172: } catch (Exception e) {
173: e.printStackTrace();
174: return "Can't print nb tree.";
175: }
176: }
177:
178: /**
179: * Help method for printing tree structure.
180: *
181: * @exception Exception if something goes wrong
182: */
183: private void dumpTreeNB(int depth, StringBuffer text)
184: throws Exception {
185:
186: int i, j;
187:
188: for (i = 0; i < m_sons.length; i++) {
189: text.append("\n");
190: ;
191: for (j = 0; j < depth; j++)
192: text.append("| ");
193: text.append(m_localModel.leftSide(m_train));
194: text.append(m_localModel.rightSide(i, m_train));
195: if (m_sons[i].m_isLeaf) {
196: text.append(": NB ");
197: text.append(m_sons[i].m_id);
198: } else
199: ((NBTreeClassifierTree) m_sons[i]).dumpTreeNB(
200: depth + 1, text);
201: }
202: }
203:
204: /**
205: * Returns graph describing the tree.
206: *
207: * @exception Exception if something goes wrong
208: */
209: public String graph() throws Exception {
210:
211: StringBuffer text = new StringBuffer();
212:
213: text.append("digraph J48Tree {\n");
214: if (m_isLeaf) {
215: text.append("N" + m_id + " [label=\"" + "NB model" + "\" "
216: + "shape=box style=filled ");
217: if (m_train != null && m_train.numInstances() > 0) {
218: text.append("data =\n" + m_train + "\n");
219: text.append(",\n");
220:
221: }
222: text.append("]\n");
223: } else {
224: text.append("N" + m_id + " [label=\""
225: + m_localModel.leftSide(m_train) + "\" ");
226: if (m_train != null && m_train.numInstances() > 0) {
227: text.append("data =\n" + m_train + "\n");
228: text.append(",\n");
229: }
230: text.append("]\n");
231: graphTree(text);
232: }
233:
234: return text.toString() + "}\n";
235: }
236:
237: /**
238: * Help method for printing tree structure as a graph.
239: *
240: * @exception Exception if something goes wrong
241: */
242: private void graphTree(StringBuffer text) throws Exception {
243:
244: for (int i = 0; i < m_sons.length; i++) {
245: text.append("N" + m_id + "->" + "N" + m_sons[i].m_id
246: + " [label=\""
247: + m_localModel.rightSide(i, m_train).trim()
248: + "\"]\n");
249: if (m_sons[i].m_isLeaf) {
250: text.append("N" + m_sons[i].m_id + " [label=\""
251: + "NB Model" + "\" "
252: + "shape=box style=filled ");
253: if (m_train != null && m_train.numInstances() > 0) {
254: text.append("data =\n" + m_sons[i].m_train + "\n");
255: text.append(",\n");
256: }
257: text.append("]\n");
258: } else {
259: text.append("N" + m_sons[i].m_id + " [label=\""
260: + m_sons[i].m_localModel.leftSide(m_train)
261: + "\" ");
262: if (m_train != null && m_train.numInstances() > 0) {
263: text.append("data =\n" + m_sons[i].m_train + "\n");
264: text.append(",\n");
265: }
266: text.append("]\n");
267: ((NBTreeClassifierTree) m_sons[i]).graphTree(text);
268: }
269: }
270: }
271: }
|