001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * ResidualSplit.java
019: * Copyright (C) 2003 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.classifiers.trees.lmt;
024:
025: import weka.classifiers.trees.j48.ClassifierSplitModel;
026: import weka.classifiers.trees.j48.Distribution;
027: import weka.core.Attribute;
028: import weka.core.Instance;
029: import weka.core.Instances;
030: import weka.core.Utils;
031:
032: /**
033: * Helper class for logistic model trees (weka.classifiers.trees.lmt.LMT) to implement the
034: * splitting criterion based on residuals of the LogitBoost algorithm.
035: *
036: * @author Niels Landwehr
037: * @version $Revision: 1.3 $
038: */
039: public class ResidualSplit extends ClassifierSplitModel {
040:
041: /** for serialization */
042: private static final long serialVersionUID = -5055883734183713525L;
043:
044: /**The attribute selected for the split*/
045: protected Attribute m_attribute;
046:
047: /**The index of the attribute selected for the split*/
048: protected int m_attIndex;
049:
050: /**Number of instances in the set*/
051: protected int m_numInstances;
052:
053: /**Number of classed*/
054: protected int m_numClasses;
055:
056: /**The set of instances*/
057: protected Instances m_data;
058:
059: /**The Z-values (LogitBoost response) for the set of instances*/
060: protected double[][] m_dataZs;
061:
062: /**The LogitBoost-weights for the set of instances*/
063: protected double[][] m_dataWs;
064:
065: /**The split point (for numeric attributes)*/
066: protected double m_splitPoint;
067:
068: /**
069: *Creates a split object
070: *@param attIndex the index of the attribute to split on
071: */
072: public ResidualSplit(int attIndex) {
073: m_attIndex = attIndex;
074: }
075:
076: /**
077: * Builds the split.
078: * Needs the Z/W values of LogitBoost for the set of instances.
079: */
080: public void buildClassifier(Instances data, double[][] dataZs,
081: double[][] dataWs) throws Exception {
082:
083: m_numClasses = data.numClasses();
084: m_numInstances = data.numInstances();
085: if (m_numInstances == 0)
086: throw new Exception("Can't build split on 0 instances");
087:
088: //save data/Zs/Ws
089: m_data = data;
090: m_dataZs = dataZs;
091: m_dataWs = dataWs;
092: m_attribute = data.attribute(m_attIndex);
093:
094: //determine number of subsets and split point for numeric attributes
095: if (m_attribute.isNominal()) {
096: m_splitPoint = 0.0;
097: m_numSubsets = m_attribute.numValues();
098: } else {
099: getSplitPoint();
100: m_numSubsets = 2;
101: }
102: //create distribution for data
103: m_distribution = new Distribution(data, this );
104: }
105:
106: /**
107: * Selects split point for numeric attribute.
108: */
109: protected boolean getSplitPoint() throws Exception {
110:
111: //compute possible split points
112: double[] splitPoints = new double[m_numInstances];
113: int numSplitPoints = 0;
114:
115: Instances sortedData = new Instances(m_data);
116: sortedData.sort(sortedData.attribute(m_attIndex));
117:
118: double last, current;
119:
120: last = sortedData.instance(0).value(m_attIndex);
121:
122: for (int i = 0; i < m_numInstances - 1; i++) {
123: current = sortedData.instance(i + 1).value(m_attIndex);
124: if (!Utils.eq(current, last)) {
125: splitPoints[numSplitPoints++] = (last + current) / 2.0;
126: }
127: last = current;
128: }
129:
130: //compute entropy for all split points
131: double[] entropyGain = new double[numSplitPoints];
132:
133: for (int i = 0; i < numSplitPoints; i++) {
134: m_splitPoint = splitPoints[i];
135: entropyGain[i] = entropyGain();
136: }
137:
138: //get best entropy gain
139: int bestSplit = -1;
140: double bestGain = -Double.MAX_VALUE;
141:
142: for (int i = 0; i < numSplitPoints; i++) {
143: if (entropyGain[i] > bestGain) {
144: bestGain = entropyGain[i];
145: bestSplit = i;
146: }
147: }
148:
149: if (bestSplit < 0)
150: return false;
151:
152: m_splitPoint = splitPoints[bestSplit];
153: return true;
154: }
155:
156: /**
157: * Computes entropy gain for current split.
158: */
159: public double entropyGain() throws Exception {
160:
161: int numSubsets;
162: if (m_attribute.isNominal()) {
163: numSubsets = m_attribute.numValues();
164: } else {
165: numSubsets = 2;
166: }
167:
168: double[][][] splitDataZs = new double[numSubsets][][];
169: double[][][] splitDataWs = new double[numSubsets][][];
170:
171: //determine size of the subsets
172: int[] subsetSize = new int[numSubsets];
173: for (int i = 0; i < m_numInstances; i++) {
174: int subset = whichSubset(m_data.instance(i));
175: if (subset < 0)
176: throw new Exception(
177: "ResidualSplit: no support for splits on missing values");
178: subsetSize[subset]++;
179: }
180:
181: for (int i = 0; i < numSubsets; i++) {
182: splitDataZs[i] = new double[subsetSize[i]][];
183: splitDataWs[i] = new double[subsetSize[i]][];
184: }
185:
186: int[] subsetCount = new int[numSubsets];
187:
188: //sort Zs/Ws into subsets
189: for (int i = 0; i < m_numInstances; i++) {
190: int subset = whichSubset(m_data.instance(i));
191: splitDataZs[subset][subsetCount[subset]] = m_dataZs[i];
192: splitDataWs[subset][subsetCount[subset]] = m_dataWs[i];
193: subsetCount[subset]++;
194: }
195:
196: //calculate entropy gain
197: double entropyOrig = entropy(m_dataZs, m_dataWs);
198:
199: double entropySplit = 0.0;
200:
201: for (int i = 0; i < numSubsets; i++) {
202: entropySplit += entropy(splitDataZs[i], splitDataWs[i]);
203: }
204:
205: return entropyOrig - entropySplit;
206: }
207:
208: /**
209: * Helper function to compute entropy from Z/W values.
210: */
211: protected double entropy(double[][] dataZs, double[][] dataWs) {
212: //method returns entropy * sumOfWeights
213: double entropy = 0.0;
214: int numInstances = dataZs.length;
215:
216: for (int j = 0; j < m_numClasses; j++) {
217:
218: //compute mean for class
219: double m = 0.0;
220: double sum = 0.0;
221: for (int i = 0; i < numInstances; i++) {
222: m += dataZs[i][j] * dataWs[i][j];
223: sum += dataWs[i][j];
224: }
225: m /= sum;
226:
227: //sum up entropy for class
228: for (int i = 0; i < numInstances; i++) {
229: entropy += dataWs[i][j] * Math.pow(dataZs[i][j] - m, 2);
230: }
231:
232: }
233:
234: return entropy;
235: }
236:
237: /**
238: * Checks if there are at least 2 subsets that contain >= minNumInstances.
239: */
240: public boolean checkModel(int minNumInstances) {
241: //checks if there are at least 2 subsets that contain >= minNumInstances
242: int count = 0;
243: for (int i = 0; i < m_distribution.numBags(); i++) {
244: if (m_distribution.perBag(i) >= minNumInstances)
245: count++;
246: }
247: return (count >= 2);
248: }
249:
250: /**
251: * Returns name of splitting attribute (left side of condition).
252: */
253: public final String leftSide(Instances data) {
254:
255: return data.attribute(m_attIndex).name();
256: }
257:
258: /**
259: * Prints the condition satisfied by instances in a subset.
260: */
261: public final String rightSide(int index, Instances data) {
262:
263: StringBuffer text;
264:
265: text = new StringBuffer();
266: if (data.attribute(m_attIndex).isNominal())
267: text
268: .append(" = "
269: + data.attribute(m_attIndex).value(index));
270: else if (index == 0)
271: text.append(" <= " + Utils.doubleToString(m_splitPoint, 6));
272: else
273: text.append(" > " + Utils.doubleToString(m_splitPoint, 6));
274: return text.toString();
275: }
276:
277: public final int whichSubset(Instance instance) throws Exception {
278:
279: if (instance.isMissing(m_attIndex))
280: return -1;
281: else {
282: if (instance.attribute(m_attIndex).isNominal())
283: return (int) instance.value(m_attIndex);
284: else if (Utils.smOrEq(instance.value(m_attIndex),
285: m_splitPoint))
286: return 0;
287: else
288: return 1;
289: }
290: }
291:
292: /** Method not in use*/
293: public void buildClassifier(Instances data) {
294: //method not in use
295: }
296:
297: /**Method not in use*/
298: public final double[] weights(Instance instance) {
299: //method not in use
300: return null;
301: }
302:
303: /**Method not in use*/
304: public final String sourceExpression(int index, Instances data) {
305: //method not in use
306: return "";
307: }
308:
309: }
|