001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * C45Split.java
019: * Copyright (C) 1999 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.trees.j48;
024:
025: import weka.core.Instance;
026: import weka.core.Instances;
027: import weka.core.Utils;
028:
029: import java.util.Enumeration;
030:
031: /**
032: * Class implementing a C4.5-type split on an attribute.
033: *
034: * @author Eibe Frank (eibe@cs.waikato.ac.nz)
035: * @version $Revision: 1.12 $
036: */
037: public class C45Split extends ClassifierSplitModel {
038:
039: /** for serialization */
040: private static final long serialVersionUID = 3064079330067903161L;
041:
042: /** Desired number of branches. */
043: private int m_complexityIndex;
044:
045: /** Attribute to split on. */
046: private int m_attIndex;
047:
048: /** Minimum number of objects in a split. */
049: private int m_minNoObj;
050:
051: /** Value of split point. */
052: private double m_splitPoint;
053:
054: /** InfoGain of split. */
055: private double m_infoGain;
056:
057: /** GainRatio of split. */
058: private double m_gainRatio;
059:
060: /** The sum of the weights of the instances. */
061: private double m_sumOfWeights;
062:
063: /** Number of split points. */
064: private int m_index;
065:
066: /** Static reference to splitting criterion. */
067: private static InfoGainSplitCrit infoGainCrit = new InfoGainSplitCrit();
068:
069: /** Static reference to splitting criterion. */
070: private static GainRatioSplitCrit gainRatioCrit = new GainRatioSplitCrit();
071:
/**
 * Creates a split model for a single attribute. No split is computed
 * here; that happens in buildClassifier.
 *
 * @param attIndex index of the attribute to split on
 * @param minNoObj minimum number of objects required in a split
 * @param sumOfWeights sum of the weights of the instances
 */
public C45Split(int attIndex, int minNoObj, double sumOfWeights) {

  this.m_sumOfWeights = sumOfWeights;
  this.m_minNoObj = minNoObj;
  this.m_attIndex = attIndex;
}
086:
087: /**
088: * Creates a C4.5-type split on the given data. Assumes that none of
089: * the class values is missing.
090: *
091: * @exception Exception if something goes wrong
092: */
093: public void buildClassifier(Instances trainInstances)
094: throws Exception {
095:
096: // Initialize the remaining instance variables.
097: m_numSubsets = 0;
098: m_splitPoint = Double.MAX_VALUE;
099: m_infoGain = 0;
100: m_gainRatio = 0;
101:
102: // Different treatment for enumerated and numeric
103: // attributes.
104: if (trainInstances.attribute(m_attIndex).isNominal()) {
105: m_complexityIndex = trainInstances.attribute(m_attIndex)
106: .numValues();
107: m_index = m_complexityIndex;
108: handleEnumeratedAttribute(trainInstances);
109: } else {
110: m_complexityIndex = 2;
111: m_index = 0;
112: trainInstances.sort(trainInstances.attribute(m_attIndex));
113: handleNumericAttribute(trainInstances);
114: }
115: }
116:
117: /**
118: * Returns index of attribute for which split was generated.
119: */
120: public final int attIndex() {
121:
122: return m_attIndex;
123: }
124:
125: /**
126: * Gets class probability for instance.
127: *
128: * @exception Exception if something goes wrong
129: */
130: public final double classProb(int classIndex, Instance instance,
131: int theSubset) throws Exception {
132:
133: if (theSubset <= -1) {
134: double[] weights = weights(instance);
135: if (weights == null) {
136: return m_distribution.prob(classIndex);
137: } else {
138: double prob = 0;
139: for (int i = 0; i < weights.length; i++) {
140: prob += weights[i]
141: * m_distribution.prob(classIndex, i);
142: }
143: return prob;
144: }
145: } else {
146: if (Utils.gr(m_distribution.perBag(theSubset), 0)) {
147: return m_distribution.prob(classIndex, theSubset);
148: } else {
149: return m_distribution.prob(classIndex);
150: }
151: }
152: }
153:
154: /**
155: * Returns coding cost for split (used in rule learner).
156: */
157: public final double codingCost() {
158:
159: return Utils.log2(m_index);
160: }
161:
162: /**
163: * Returns (C4.5-type) gain ratio for the generated split.
164: */
165: public final double gainRatio() {
166: return m_gainRatio;
167: }
168:
169: /**
170: * Creates split on enumerated attribute.
171: *
172: * @exception Exception if something goes wrong
173: */
174: private void handleEnumeratedAttribute(Instances trainInstances)
175: throws Exception {
176:
177: Instance instance;
178:
179: m_distribution = new Distribution(m_complexityIndex,
180: trainInstances.numClasses());
181:
182: // Only Instances with known values are relevant.
183: Enumeration enu = trainInstances.enumerateInstances();
184: while (enu.hasMoreElements()) {
185: instance = (Instance) enu.nextElement();
186: if (!instance.isMissing(m_attIndex))
187: m_distribution.add((int) instance.value(m_attIndex),
188: instance);
189: }
190:
191: // Check if minimum number of Instances in at least two
192: // subsets.
193: if (m_distribution.check(m_minNoObj)) {
194: m_numSubsets = m_complexityIndex;
195: m_infoGain = infoGainCrit.splitCritValue(m_distribution,
196: m_sumOfWeights);
197: m_gainRatio = gainRatioCrit.splitCritValue(m_distribution,
198: m_sumOfWeights, m_infoGain);
199: }
200: }
201:
202: /**
203: * Creates split on numeric attribute.
204: *
205: * @exception Exception if something goes wrong
206: */
207: private void handleNumericAttribute(Instances trainInstances)
208: throws Exception {
209:
210: int firstMiss;
211: int next = 1;
212: int last = 0;
213: int splitIndex = -1;
214: double currentInfoGain;
215: double defaultEnt;
216: double minSplit;
217: Instance instance;
218: int i;
219:
220: // Current attribute is a numeric attribute.
221: m_distribution = new Distribution(2, trainInstances
222: .numClasses());
223:
224: // Only Instances with known values are relevant.
225: Enumeration enu = trainInstances.enumerateInstances();
226: i = 0;
227: while (enu.hasMoreElements()) {
228: instance = (Instance) enu.nextElement();
229: if (instance.isMissing(m_attIndex))
230: break;
231: m_distribution.add(1, instance);
232: i++;
233: }
234: firstMiss = i;
235:
236: // Compute minimum number of Instances required in each
237: // subset.
238: minSplit = 0.1 * (m_distribution.total())
239: / ((double) trainInstances.numClasses());
240: if (Utils.smOrEq(minSplit, m_minNoObj))
241: minSplit = m_minNoObj;
242: else if (Utils.gr(minSplit, 25))
243: minSplit = 25;
244:
245: // Enough Instances with known values?
246: if (Utils.sm((double) firstMiss, 2 * minSplit))
247: return;
248:
249: // Compute values of criteria for all possible split
250: // indices.
251: defaultEnt = infoGainCrit.oldEnt(m_distribution);
252: while (next < firstMiss) {
253:
254: if (trainInstances.instance(next - 1).value(m_attIndex) + 1e-5 < trainInstances
255: .instance(next).value(m_attIndex)) {
256:
257: // Move class values for all Instances up to next
258: // possible split point.
259: m_distribution.shiftRange(1, 0, trainInstances, last,
260: next);
261:
262: // Check if enough Instances in each subset and compute
263: // values for criteria.
264: if (Utils.grOrEq(m_distribution.perBag(0), minSplit)
265: && Utils.grOrEq(m_distribution.perBag(1),
266: minSplit)) {
267: currentInfoGain = infoGainCrit.splitCritValue(
268: m_distribution, m_sumOfWeights, defaultEnt);
269: if (Utils.gr(currentInfoGain, m_infoGain)) {
270: m_infoGain = currentInfoGain;
271: splitIndex = next - 1;
272: }
273: m_index++;
274: }
275: last = next;
276: }
277: next++;
278: }
279:
280: // Was there any useful split?
281: if (m_index == 0)
282: return;
283:
284: // Compute modified information gain for best split.
285: m_infoGain = m_infoGain
286: - (Utils.log2(m_index) / m_sumOfWeights);
287: if (Utils.smOrEq(m_infoGain, 0))
288: return;
289:
290: // Set instance variables' values to values for
291: // best split.
292: m_numSubsets = 2;
293: m_splitPoint = (trainInstances.instance(splitIndex + 1).value(
294: m_attIndex) + trainInstances.instance(splitIndex)
295: .value(m_attIndex)) / 2;
296:
297: // In case we have a numerical precision problem we need to choose the
298: // smaller value
299: if (m_splitPoint == trainInstances.instance(splitIndex + 1)
300: .value(m_attIndex)) {
301: m_splitPoint = trainInstances.instance(splitIndex).value(
302: m_attIndex);
303: }
304:
305: // Restore distributioN for best split.
306: m_distribution = new Distribution(2, trainInstances
307: .numClasses());
308: m_distribution.addRange(0, trainInstances, 0, splitIndex + 1);
309: m_distribution.addRange(1, trainInstances, splitIndex + 1,
310: firstMiss);
311:
312: // Compute modified gain ratio for best split.
313: m_gainRatio = gainRatioCrit.splitCritValue(m_distribution,
314: m_sumOfWeights, m_infoGain);
315: }
316:
317: /**
318: * Returns (C4.5-type) information gain for the generated split.
319: */
320: public final double infoGain() {
321:
322: return m_infoGain;
323: }
324:
325: /**
326: * Prints left side of condition..
327: *
328: * @param data training set.
329: */
330: public final String leftSide(Instances data) {
331:
332: return data.attribute(m_attIndex).name();
333: }
334:
335: /**
336: * Prints the condition satisfied by instances in a subset.
337: *
338: * @param index of subset
339: * @param data training set.
340: */
341: public final String rightSide(int index, Instances data) {
342:
343: StringBuffer text;
344:
345: text = new StringBuffer();
346: if (data.attribute(m_attIndex).isNominal())
347: text
348: .append(" = "
349: + data.attribute(m_attIndex).value(index));
350: else if (index == 0)
351: text.append(" <= " + Utils.doubleToString(m_splitPoint, 6));
352: else
353: text.append(" > " + Utils.doubleToString(m_splitPoint, 6));
354: return text.toString();
355: }
356:
357: /**
358: * Returns a string containing java source code equivalent to the test
359: * made at this node. The instance being tested is called "i".
360: *
361: * @param index index of the nominal value tested
362: * @param data the data containing instance structure info
363: * @return a value of type 'String'
364: */
365: public final String sourceExpression(int index, Instances data) {
366:
367: StringBuffer expr = null;
368: if (index < 0) {
369: return "i[" + m_attIndex + "] == null";
370: }
371: if (data.attribute(m_attIndex).isNominal()) {
372: expr = new StringBuffer("i[");
373: expr.append(m_attIndex).append("]");
374: expr.append(".equals(\"").append(
375: data.attribute(m_attIndex).value(index)).append(
376: "\")");
377: } else {
378: expr = new StringBuffer("((Double) i[");
379: expr.append(m_attIndex).append("])");
380: if (index == 0) {
381: expr.append(".doubleValue() <= ").append(m_splitPoint);
382: } else {
383: expr.append(".doubleValue() > ").append(m_splitPoint);
384: }
385: }
386: return expr.toString();
387: }
388:
389: /**
390: * Sets split point to greatest value in given data smaller or equal to
391: * old split point.
392: * (C4.5 does this for some strange reason).
393: */
394: public final void setSplitPoint(Instances allInstances) {
395:
396: double newSplitPoint = -Double.MAX_VALUE;
397: double tempValue;
398: Instance instance;
399:
400: if ((allInstances.attribute(m_attIndex).isNumeric())
401: && (m_numSubsets > 1)) {
402: Enumeration enu = allInstances.enumerateInstances();
403: while (enu.hasMoreElements()) {
404: instance = (Instance) enu.nextElement();
405: if (!instance.isMissing(m_attIndex)) {
406: tempValue = instance.value(m_attIndex);
407: if (Utils.gr(tempValue, newSplitPoint)
408: && Utils.smOrEq(tempValue, m_splitPoint))
409: newSplitPoint = tempValue;
410: }
411: }
412: m_splitPoint = newSplitPoint;
413: }
414: }
415:
416: /**
417: * Returns the minsAndMaxs of the index.th subset.
418: */
419: public final double[][] minsAndMaxs(Instances data,
420: double[][] minsAndMaxs, int index) {
421:
422: double[][] newMinsAndMaxs = new double[data.numAttributes()][2];
423:
424: for (int i = 0; i < data.numAttributes(); i++) {
425: newMinsAndMaxs[i][0] = minsAndMaxs[i][0];
426: newMinsAndMaxs[i][1] = minsAndMaxs[i][1];
427: if (i == m_attIndex)
428: if (data.attribute(m_attIndex).isNominal())
429: newMinsAndMaxs[m_attIndex][1] = 1;
430: else
431: newMinsAndMaxs[m_attIndex][1 - index] = m_splitPoint;
432: }
433:
434: return newMinsAndMaxs;
435: }
436:
437: /**
438: * Sets distribution associated with model.
439: */
440: public void resetDistribution(Instances data) throws Exception {
441:
442: Instances insts = new Instances(data, data.numInstances());
443: for (int i = 0; i < data.numInstances(); i++) {
444: if (whichSubset(data.instance(i)) > -1) {
445: insts.add(data.instance(i));
446: }
447: }
448: Distribution newD = new Distribution(insts, this );
449: newD.addInstWithUnknown(data, m_attIndex);
450: m_distribution = newD;
451: }
452:
453: /**
454: * Returns weights if instance is assigned to more than one subset.
455: * Returns null if instance is only assigned to one subset.
456: */
457: public final double[] weights(Instance instance) {
458:
459: double[] weights;
460: int i;
461:
462: if (instance.isMissing(m_attIndex)) {
463: weights = new double[m_numSubsets];
464: for (i = 0; i < m_numSubsets; i++)
465: weights[i] = m_distribution.perBag(i)
466: / m_distribution.total();
467: return weights;
468: } else {
469: return null;
470: }
471: }
472:
473: /**
474: * Returns index of subset instance is assigned to.
475: * Returns -1 if instance is assigned to more than one subset.
476: *
477: * @exception Exception if something goes wrong
478: */
479: public final int whichSubset(Instance instance) throws Exception {
480:
481: if (instance.isMissing(m_attIndex))
482: return -1;
483: else {
484: if (instance.attribute(m_attIndex).isNominal())
485: return (int) instance.value(m_attIndex);
486: else if (Utils.smOrEq(instance.value(m_attIndex),
487: m_splitPoint))
488: return 0;
489: else
490: return 1;
491: }
492: }
493: }
|