001: /*
002: * This program is free software; you can redistribute it and/or modify
003: * it under the terms of the GNU General Public License as published by
004: * the Free Software Foundation; either version 2 of the License, or
005: * (at your option) any later version.
006: *
007: * This program is distributed in the hope that it will be useful,
008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
010: * GNU General Public License for more details.
011: *
012: * You should have received a copy of the GNU General Public License
013: * along with this program; if not, write to the Free Software
014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
015: */
016:
017: /*
018: * EstimatorUtils.java
019: * Copyright (C) 2004 University of Waikato, Hamilton, New Zealand
020: *
021: */
022:
023: package weka.estimators;
024:
025: import java.io.FileOutputStream;
026: import java.io.PrintWriter;
027: import java.util.*;
028: import weka.core.*;
029:
030: /**
031: * Contains static utility functions for Estimators.<p>
032: *
033: * @author Gabi Schmidberger (gabi@cs.waikato.ac.nz)
034: * @version $Revision: 1.3 $
035: */
036: public class EstimatorUtils {
037:
038: /**
039: * Find the minimum distance between values
040: * @param inst sorted instances, sorted
041: * @param attrIndex index of the attribute, they are sorted after
042: * @return the minimal distance
043: */
044: public static double findMinDistance(Instances inst, int attrIndex) {
045: double min = Double.MAX_VALUE;
046: int numInst = inst.numInstances();
047: double diff;
048: if (numInst < 2)
049: return min;
050: int begin = -1;
051: Instance instance = null;
052: do {
053: begin++;
054: if (begin < numInst) {
055: instance = inst.instance(begin);
056: }
057: } while (begin < numInst && instance.isMissing(attrIndex));
058:
059: double secondValue = inst.instance(begin).value(attrIndex);
060: for (int i = begin; i < numInst
061: && !inst.instance(i).isMissing(attrIndex); i++) {
062: double firstValue = secondValue;
063: secondValue = inst.instance(i).value(attrIndex);
064: if (secondValue != firstValue) {
065: diff = secondValue - firstValue;
066: if (diff < min && diff > 0.0) {
067: min = diff;
068: }
069: }
070: }
071: return min;
072: }
073:
074: /**
075: * Find the minimum and the maximum of the attribute and return it in
076: * the last parameter..
077: * @param inst instances used to build the estimator
078: * @param attrIndex index of the attribute
079: * @param minMax the array to return minimum and maximum in
080: * @return number of not missing values
081: * @exception Exception if parameter minMax wasn't initialized properly
082: */
083: public static int getMinMax(Instances inst, int attrIndex,
084: double[] minMax) throws Exception {
085: double min = Double.NaN;
086: double max = Double.NaN;
087: Instance instance = null;
088: int numNotMissing = 0;
089: if ((minMax == null) || (minMax.length < 2)) {
090: throw new Exception(
091: "Error in Program, privat method getMinMax");
092: }
093:
094: Enumeration enumInst = inst.enumerateInstances();
095: if (enumInst.hasMoreElements()) {
096: do {
097: instance = (Instance) enumInst.nextElement();
098: } while (instance.isMissing(attrIndex)
099: && (enumInst.hasMoreElements()));
100:
101: // add values if not missing
102: if (!instance.isMissing(attrIndex)) {
103: numNotMissing++;
104: min = instance.value(attrIndex);
105: max = instance.value(attrIndex);
106: }
107: while (enumInst.hasMoreElements()) {
108: instance = (Instance) enumInst.nextElement();
109: if (!instance.isMissing(attrIndex)) {
110: numNotMissing++;
111: if (instance.value(attrIndex) < min) {
112: min = (instance.value(attrIndex));
113: } else {
114: if (instance.value(attrIndex) > max) {
115: max = (instance.value(attrIndex));
116: }
117: }
118: }
119: }
120: }
121: minMax[0] = min;
122: minMax[1] = max;
123: return numNotMissing;
124: }
125:
126: /**
127: * Returns a dataset that contains all instances of a certain class value.
128: *
129: * @param data dataset to select the instances from
130: * @param attrIndex index of the relevant attribute
131: * @param classIndex index of the class attribute
132: * @param classValue the relevant class value
133: * @return a dataset with only
134: */
135: public static Vector getInstancesFromClass(Instances data,
136: int attrIndex, int classIndex, double classValue,
137: Instances workData) {
138: //Oops.pln("getInstancesFromClass classValue"+classValue+" workData"+data.numInstances());
139: Vector dataPlusInfo = new Vector(0);
140: int num = 0;
141: int numClassValue = 0;
142: //workData = new Instances(data, 0);
143: for (int i = 0; i < data.numInstances(); i++) {
144: if (!data.instance(i).isMissing(attrIndex)) {
145: num++;
146: if (data.instance(i).value(classIndex) == classValue) {
147: workData.add(data.instance(i));
148: numClassValue++;
149: }
150: }
151: }
152:
153: Double alphaFactor = new Double((double) numClassValue
154: / (double) num);
155: dataPlusInfo.add(workData);
156: dataPlusInfo.add(alphaFactor);
157: return dataPlusInfo;
158: }
159:
160: /**
161: * Returns a dataset that contains of all instances of a certain class value.
162: * @param data dataset to select the instances from
163: * @param classIndex index of the class attribute
164: * @param classValue the class value
165: * @return a dataset with only instances of one class value
166: */
167: public static Instances getInstancesFromClass(Instances data,
168: int classIndex, double classValue) {
169: Instances workData = new Instances(data, 0);
170: for (int i = 0; i < data.numInstances(); i++) {
171: if (data.instance(i).value(classIndex) == classValue) {
172: workData.add(data.instance(i));
173: }
174:
175: }
176: return workData;
177: }
178:
179: /**
180: * Output of an n points of a density curve.
181: * Filename is parameter f + ".curv".
182: *
183: * @param f string to build filename
184: * @param est
185: * @param min
186: * @param max
187: * @param numPoints
188: * @throws Exception if something goes wrong
189: */
190: public static void writeCurve(String f, Estimator est, double min,
191: double max, int numPoints) throws Exception {
192:
193: PrintWriter output = null;
194: StringBuffer text = new StringBuffer("");
195:
196: if (f.length() != 0) {
197: // add attribute indexnumber to filename and extension .hist
198: String name = f + ".curv";
199: output = new PrintWriter(new FileOutputStream(name));
200: } else {
201: return;
202: }
203:
204: double diff = (max - min) / ((double) numPoints - 1.0);
205: try {
206: text.append("" + min + " " + est.getProbability(min)
207: + " \n");
208:
209: for (double value = min + diff; value < max; value += diff) {
210: text.append("" + value + " "
211: + est.getProbability(value) + " \n");
212: }
213: text.append("" + max + " " + est.getProbability(max)
214: + " \n");
215: } catch (Exception ex) {
216: ex.printStackTrace();
217: System.out.println(ex.getMessage());
218: }
219: output.println(text.toString());
220:
221: // close output
222: if (output != null) {
223: output.close();
224: }
225: }
226:
227: /**
228: * Output of an n points of a density curve.
229: * Filename is parameter f + ".curv".
230: *
231: * @param f string to build filename
232: * @param est
233: * @param classEst
234: * @param classIndex
235: * @param min
236: * @param max
237: * @param numPoints
238: * @throws Exception if something goes wrong
239: */
240: public static void writeCurve(String f, Estimator est,
241: Estimator classEst, double classIndex, double min,
242: double max, int numPoints) throws Exception {
243:
244: PrintWriter output = null;
245: StringBuffer text = new StringBuffer("");
246:
247: if (f.length() != 0) {
248: // add attribute indexnumber to filename and extension .hist
249: String name = f + ".curv";
250: output = new PrintWriter(new FileOutputStream(name));
251: } else {
252: return;
253: }
254:
255: double diff = (max - min) / ((double) numPoints - 1.0);
256: try {
257: text.append("" + min + " " + est.getProbability(min)
258: * classEst.getProbability(classIndex) + " \n");
259:
260: for (double value = min + diff; value < max; value += diff) {
261: text.append("" + value + " "
262: + est.getProbability(value)
263: * classEst.getProbability(classIndex) + " \n");
264: }
265: text.append("" + max + " " + est.getProbability(max)
266: * classEst.getProbability(classIndex) + " \n");
267: } catch (Exception ex) {
268: ex.printStackTrace();
269: System.out.println(ex.getMessage());
270: }
271: output.println(text.toString());
272:
273: // close output
274: if (output != null) {
275: output.close();
276: }
277: }
278:
279: /**
280: * Returns a dataset that contains of all instances of a certain value
281: * for the given attribute.
282: * @param data dataset to select the instances from
283: * @param index the index of the attribute
284: * @param v the value
285: * @return a subdataset with only instances of one value for the attribute
286: */
287: public static Instances getInstancesFromValue(Instances data,
288: int index, double v) {
289: Instances workData = new Instances(data, 0);
290: for (int i = 0; i < data.numInstances(); i++) {
291: if (data.instance(i).value(index) == v) {
292: workData.add(data.instance(i));
293: }
294: }
295: return workData;
296: }
297:
298: /**
299: * Returns a string representing the cutpoints
300: */
301: public static String cutpointsToString(double[] cutPoints,
302: boolean[] cutAndLeft) {
303: StringBuffer text = new StringBuffer("");
304: if (cutPoints == null) {
305: text.append("\n# no cutpoints found - attribute \n");
306: } else {
307: text
308: .append("\n#* " + cutPoints.length
309: + " cutpoint(s) -\n");
310: for (int i = 0; i < cutPoints.length; i++) {
311: text.append("# " + cutPoints[i] + " ");
312: text.append("" + cutAndLeft[i] + "\n");
313: }
314: text.append("# end\n");
315: }
316: return text.toString();
317: }
318:
319: }
|