0001: /*
0002: * This program is free software; you can redistribute it and/or modify
0003: * it under the terms of the GNU General Public License as published by
0004: * the Free Software Foundation; either version 2 of the License, or
0005: * (at your option) any later version.
0006: *
0007: * This program is distributed in the hope that it will be useful,
0008: * but WITHOUT ANY WARRANTY; without even the implied warranty of
0009: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
0010: * GNU General Public License for more details.
0011: *
0012: * You should have received a copy of the GNU General Public License
0013: * along with this program; if not, write to the Free Software
0014: * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
0015: */
0016:
0017: /*
0018: * OSDLCore.java
0019: * Copyright (C) 2004 Stijn Lievens
0020: */
0021:
0022: package weka.classifiers.misc.monotone;
0023:
0024: import weka.classifiers.Classifier;
0025: import weka.core.Capabilities;
0026: import weka.core.Instance;
0027: import weka.core.Instances;
0028: import weka.core.Option;
0029: import weka.core.SelectedTag;
0030: import weka.core.Tag;
0031: import weka.core.TechnicalInformation;
0032: import weka.core.TechnicalInformationHandler;
0033: import weka.core.Utils;
0034: import weka.core.Capabilities.Capability;
0035: import weka.core.TechnicalInformation.Field;
0036: import weka.core.TechnicalInformation.Type;
0037: import weka.estimators.DiscreteEstimator;
0038:
0039: import java.util.Arrays;
0040: import java.util.Enumeration;
0041: import java.util.HashMap;
0042: import java.util.Iterator;
0043: import java.util.Map;
0044: import java.util.Vector;
0045:
0046: /**
0047: <!-- globalinfo-start -->
0048: * This class is an implementation of the Ordinal Stochastic Dominance Learner.<br/>
0049: * Further information regarding the OSDL-algorithm can be found in:<br/>
0050: * <br/>
* S. Lievens, B. De Baets, K. Cao-Van (2006). A Probabilistic Framework for the Design of Instance-Based Supervised Ranking Algorithms in an Ordinal Setting. Annals of Operations Research.<br/>
0052: * <br/>
0053: * Kim Cao-Van (2003). Supervised ranking: from semantics to algorithms.<br/>
0054: * <br/>
0055: * Stijn Lievens (2004). Studie en implementatie van instantie-gebaseerde algoritmen voor gesuperviseerd rangschikken.<br/>
0056: * <br/>
0057: * For more information about supervised ranking, see<br/>
0058: * <br/>
0059: * http://users.ugent.be/~slievens/supervised_ranking.php
0060: * <p/>
0061: <!-- globalinfo-end -->
0062: *
0063: <!-- technical-bibtex-start -->
0064: * BibTeX:
0065: * <pre>
0066: * @article{Lievens2006,
0067: * author = {S. Lievens and B. De Baets and K. Cao-Van},
0068: * journal = {Annals of Operations Research},
0069: * title = {A Probabilistic Framework for the Design of Instance-Based Supervised Ranking Algorithms in an Ordinal Setting},
0070: * year = {2006}
0071: * }
0072: *
0073: * @phdthesis{Cao-Van2003,
0074: * author = {Kim Cao-Van},
0075: * school = {Ghent University},
0076: * title = {Supervised ranking: from semantics to algorithms},
0077: * year = {2003}
0078: * }
0079: *
0080: * @mastersthesis{Lievens2004,
0081: * author = {Stijn Lievens},
0082: * school = {Ghent University},
0083: * title = {Studie en implementatie van instantie-gebaseerde algoritmen voor gesuperviseerd rangschikken},
0084: * year = {2004}
0085: * }
0086: * </pre>
0087: * <p/>
0088: <!-- technical-bibtex-end -->
0089: *
0090: <!-- options-start -->
0091: * Valid options are: <p/>
0092: *
0093: * <pre> -D
0094: * If set, classifier is run in debug mode and
0095: * may output additional info to the console</pre>
0096: *
0097: * <pre> -C <REG|WSUM|MAX|MED|RMED>
0098: * Sets the classification type to be used.
0099: * (Default: MED)</pre>
0100: *
0101: * <pre> -B
0102: * Use the balanced version of the Ordinal Stochastic Dominance Learner</pre>
0103: *
0104: * <pre> -W
0105: * Use the weighted version of the Ordinal Stochastic Dominance Learner</pre>
0106: *
0107: * <pre> -S <value of interpolation parameter>
0108: * Sets the value of the interpolation parameter (not with -W/T/P/L/U)
0109: * (default: 0.5).</pre>
0110: *
0111: * <pre> -T
0112: * Tune the interpolation parameter (not with -W/S)
0113: * (default: off)</pre>
0114: *
0115: * <pre> -L <Lower bound for interpolation parameter>
0116: * Lower bound for the interpolation parameter (not with -W/S)
0117: * (default: 0)</pre>
0118: *
0119: * <pre> -U <Upper bound for interpolation parameter>
0120: * Upper bound for the interpolation parameter (not with -W/S)
0121: * (default: 1)</pre>
0122: *
0123: * <pre> -P <Number of parts>
0124: * Determines the step size for tuning the interpolation
* parameter, namely (U-L)/P (not with -W/S)
0126: * (default: 10)</pre>
0127: *
0128: <!-- options-end -->
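*
* <p>
* A minimal usage sketch (illustrative only: the concrete subclass
* weka.classifiers.misc.OSDL and the file name "ordinal.arff" are
* assumptions, not prescribed by this abstract class):
* </p>
* <pre>
* Instances data = new Instances(
*     new java.io.BufferedReader(new java.io.FileReader("ordinal.arff")));
* data.setClassIndex(data.numAttributes() - 1);
* OSDL osdl = new OSDL();   // concrete subclass of OSDLCore
* osdl.setClassificationType(new SelectedTag(
*     OSDLCore.CT_MEDIAN, OSDLCore.TAGS_CLASSIFICATIONTYPES));
* osdl.buildClassifier(data);
* double label = osdl.classifyInstance(data.instance(0));
* </pre>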
0129: *
0130: * @author Stijn Lievens (stijn.lievens@ugent.be)
0131: * @version $Revision: 1.1 $
0132: */
0133: public abstract class OSDLCore extends Classifier implements
0134: TechnicalInformationHandler {
0135:
0136: /** for serialization */
0137: private static final long serialVersionUID = -9209888846680062897L;
0138:
0139: /**
* Constant indicating that the classification type is
* regression, i.e. the expected class value (probabilistic
* weighted sum), not rounded to a class label.
0142: */
0143: public static final int CT_REGRESSION = 0;
0144:
0145: /**
* Constant indicating that the classification type is
* the probabilistic weighted sum, rounded to the nearest class label.
0148: */
0149: public static final int CT_WEIGHTED_SUM = 1;
0150:
0151: /**
0152: * Constant indicating that the classification type is
0153: * the mode of the distribution.
0154: */
0155: public static final int CT_MAXPROB = 2;
0156:
0157: /**
0158: * Constant indicating that the classification type is
0159: * the median.
0160: */
0161: public static final int CT_MEDIAN = 3;
0162:
0163: /**
0164: * Constant indicating that the classification type is
0165: * the median, but not rounded to the nearest class.
0166: */
0167: public static final int CT_MEDIAN_REAL = 4;
0168:
0169: /** the classification types */
0170: public static final Tag[] TAGS_CLASSIFICATIONTYPES = {
0171: new Tag(CT_REGRESSION, "REG", "Regression"),
0172: new Tag(CT_WEIGHTED_SUM, "WSUM", "Weighted Sum"),
0173: new Tag(CT_MAXPROB, "MAX", "Maximum probability"),
0174: new Tag(CT_MEDIAN, "MED", "Median"),
0175: new Tag(CT_MEDIAN_REAL, "RMED", "Median without rounding") };
0176:
0177: /**
0178: * The classification type, by default set to CT_MEDIAN.
0179: */
0180: private int m_ctype = CT_MEDIAN;
0181:
0182: /**
0183: * The training examples.
0184: */
0185: private Instances m_train;
0186:
0187: /**
0188: * Collection of (Coordinates,DiscreteEstimator) pairs.
* This Map is built from the training examples.
* The DiscreteEstimator is over the classes.
* Each DiscreteEstimator indicates how many training examples
* there are for each of the class labels.
0193: */
0194: private Map m_estimatedDistributions;
0195:
0196: /**
0197: * Collection of (Coordinates,CumulativeDiscreteDistribution) pairs.
* This Map is built from the training examples, and more
0199: * specifically from the previous map.
0200: */
0201: private Map m_estimatedCumulativeDistributions;
0202:
0203: /**
* The interpolation parameter s.
* By default set to 0.5.
0206: */
0207: private double m_s = 0.5;
0208:
0209: /**
* Lower bound for the interpolation parameter s.
0211: * Default value is 0.
0212: */
0213: private double m_sLower = 0.;
0214:
0215: /**
* Upper bound for the interpolation parameter s.
0217: * Default value is 1.
0218: */
0219: private double m_sUpper = 1.0;
0220:
0221: /**
0222: * The number of parts the interval [m_sLower,m_sUpper] is
0223: * divided in, while searching for the best parameter s.
0224: * This thus determines the granularity of the search.
* m_sNrParts + 1 values of the interpolation parameter will
0226: * be tested.
0227: */
0228: private int m_sNrParts = 10;
0229:
0230: /**
* Indicates whether the interpolation parameter is to be tuned
0232: * using leave-one-out cross validation. <code> true </code> if
0233: * this is the case (default is <code> false </code>).
0234: */
0235: private boolean m_tuneInterpolationParameter = false;
0236:
0237: /**
* Indicates whether the current value of the interpolation parameter
* is valid. More specifically, if <code>
* m_tuneInterpolationParameter == true </code> and
* <code> m_interpolationParameterValid == false </code>,
* this means that the current interpolation parameter is not valid.
0243: * This parameter is only relevant if <code> m_tuneInterpolationParameter
0244: * == true </code>.
0245: *
0246: * If <code> m_tuneInterpolationParameter </code> and <code>
0247: * m_interpolationParameterValid </code> are both <code> true </code>,
0248: * then <code> m_s </code> should always be between
0249: * <code> m_sLower </code> and <code> m_sUpper </code>.
0250: */
0251: private boolean m_interpolationParameterValid = false;
0252:
0253: /**
* Flag to switch between the balanced and the unbalanced version of OSDL.
0255: * <code> true </code> means that one chooses balanced OSDL
0256: * (default: <code> false </code>).
0257: */
0258: private boolean m_balanced = false;
0259:
0260: /**
* Flag to choose the weighted variant of the OSDL algorithm.
0262: */
0263: private boolean m_weighted = false;
0264:
0265: /**
0266: * Coordinates representing the smallest element of the data space.
0267: */
0268: private Coordinates smallestElement;
0269:
0270: /**
0271: * Coordinates representing the biggest element of the data space.
0272: */
0273: private Coordinates biggestElement;
0274:
0275: /**
0276: * Returns a string describing the classifier.
0277: * @return a description suitable for displaying in the
0278: * explorer/experimenter gui
0279: */
0280: public String globalInfo() {
0281: return "This class is an implementation of the Ordinal Stochastic "
0282: + "Dominance Learner.\n"
0283: + "Further information regarding the OSDL-algorithm can be found in:\n\n"
0284: + getTechnicalInformation().toString()
0285: + "\n\n"
0286: + "For more information about supervised ranking, see\n\n"
0287: + "http://users.ugent.be/~slievens/supervised_ranking.php";
0288: }
0289:
0290: /**
0291: * Returns an instance of a TechnicalInformation object, containing
0292: * detailed information about the technical background of this class,
0293: * e.g., paper reference or book this class is based on.
0294: *
0295: * @return the technical information about this class
0296: */
0297: public TechnicalInformation getTechnicalInformation() {
0298: TechnicalInformation result;
0299: TechnicalInformation additional;
0300:
0301: result = new TechnicalInformation(Type.ARTICLE);
0302: result.setValue(Field.AUTHOR,
0303: "S. Lievens and B. De Baets and K. Cao-Van");
0304: result.setValue(Field.YEAR, "2006");
result.setValue(Field.TITLE,
"A Probabilistic Framework for the Design of Instance-Based Supervised Ranking Algorithms in an Ordinal Setting");
0309: result.setValue(Field.JOURNAL, "Annals of Operations Research");
0310:
0311: additional = result.add(Type.PHDTHESIS);
0312: additional.setValue(Field.AUTHOR, "Kim Cao-Van");
0313: additional.setValue(Field.YEAR, "2003");
0314: additional.setValue(Field.TITLE,
0315: "Supervised ranking: from semantics to algorithms");
0316: additional.setValue(Field.SCHOOL, "Ghent University");
0317:
0318: additional = result.add(Type.MASTERSTHESIS);
0319: additional.setValue(Field.AUTHOR, "Stijn Lievens");
0320: additional.setValue(Field.YEAR, "2004");
additional.setValue(Field.TITLE,
"Studie en implementatie van instantie-gebaseerde algoritmen voor gesuperviseerd rangschikken");
0325: additional.setValue(Field.SCHOOL, "Ghent University");
0326:
0327: return result;
0328: }
0329:
0330: /**
0331: * Returns default capabilities of the classifier.
0332: *
0333: * @return the capabilities of this classifier
0334: */
0335: public Capabilities getCapabilities() {
Capabilities result = super.getCapabilities();
0337:
0338: // attributes
0339: result.enable(Capability.NOMINAL_ATTRIBUTES);
0340:
0341: // class
0342: result.enable(Capability.NOMINAL_CLASS);
0343: result.enable(Capability.MISSING_CLASS_VALUES);
0344:
0345: // instances
0346: result.setMinimumNumberInstances(0);
0347:
0348: return result;
0349: }
0350:
0351: /**
0352: * Classifies a given instance using the current settings
0353: * of the classifier.
0354: *
0355: * @param instance the instance to be classified
0356: * @throws Exception if for some reason no distribution
0357: * could be predicted
0358: * @return the classification for the instance. Depending on the
0359: * settings of the classifier this is a double representing
0360: * a classlabel (internal WEKA format) or a real value in the sense
0361: * of regression.
0362: */
0363: public double classifyInstance(Instance instance) throws Exception {
0364:
0365: try {
0366: return classifyInstance(instance, m_s, m_ctype);
0367: } catch (IllegalArgumentException e) {
0368: throw new AssertionError(e);
0369: }
0370: }
0371:
0372: /**
* Classifies a given instance using the settings in the parameter
* list. This doesn't change the internal settings of the classifier.
* In particular the interpolation parameter <code> m_s </code>
* and the classification type <code> m_ctype </code> are not changed.
0377: *
0378: * @param instance the instance to be classified
* @param s the value of the interpolation parameter to be used
0380: * @param ctype the classification type to be used
* @throws IllegalStateException if for some reason no distribution
0382: * could be predicted
0383: * @throws IllegalArgumentException if the interpolation parameter or the
0384: * classification type is not valid
0385: * @return the label assigned to the instance. It is given in internal floating point format.
0386: */
0387: private double classifyInstance(Instance instance, double s,
0388: int ctype) throws IllegalArgumentException,
0389: IllegalStateException {
0390:
0391: if (s < 0 || s > 1) {
0392: throw new IllegalArgumentException(
0393: "Interpolation parameter is not valid " + s);
0394: }
0395:
0396: DiscreteDistribution dist = null;
0397: if (!m_balanced) {
0398: dist = distributionForInstance(instance, s);
0399: } else {
0400: dist = distributionForInstanceBalanced(instance, s);
0401: }
0402:
0403: if (dist == null) {
0404: throw new IllegalStateException(
0405: "Null distribution predicted");
0406: }
0407:
0408: double value = 0;
0409: switch (ctype) {
0410: case CT_REGRESSION:
0411: case CT_WEIGHTED_SUM:
0412: value = dist.mean();
0413: if (ctype == CT_WEIGHTED_SUM) {
0414: value = Math.round(value);
0415: }
0416: break;
0417:
0418: case CT_MAXPROB:
0419: value = dist.modes()[0];
0420: break;
0421:
0422: case CT_MEDIAN:
0423: case CT_MEDIAN_REAL:
0424: value = dist.median();
0425: if (ctype == CT_MEDIAN) {
0426: value = Math.round(value);
0427: }
0428: break;
0429:
0430: default:
0431: throw new IllegalArgumentException(
0432: "Not a valid classification type!");
0433: }
0434: return value;
0435: }
0436:
0437: /**
0438: * Calculates the class probabilities for the given test instance.
0439: * Uses the current settings of the parameters if these are valid.
* If necessary it updates the interpolation parameter first, and hence
0441: * this may change the classifier.
0442: *
0443: * @param instance the instance to be classified
0444: * @return an array of doubles representing the predicted
0445: * probability distribution over the class labels
0446: */
0447: public double[] distributionForInstance(Instance instance) {
0448:
0449: if (m_tuneInterpolationParameter
0450: && !m_interpolationParameterValid) {
0451: tuneInterpolationParameter();
0452: }
0453:
0454: if (!m_balanced) {
0455: return distributionForInstance(instance, m_s).toArray();
0456: }
0457: // balanced variant
0458: return distributionForInstanceBalanced(instance, m_s).toArray();
0459: }
0460:
0461: /**
0462: * Calculates the cumulative class probabilities for the given test
0463: * instance. Uses the current settings of the parameters if these are
* valid. If necessary it updates the interpolation parameter first,
0465: * and hence this may change the classifier.
0466: *
0467: * @param instance the instance to be classified
0468: * @return an array of doubles representing the predicted
0469: * cumulative probability distribution over the class labels
0470: */
0471: public double[] cumulativeDistributionForInstance(Instance instance) {
0472:
0473: if (m_tuneInterpolationParameter
0474: && !m_interpolationParameterValid) {
0475: tuneInterpolationParameter();
0476: }
0477:
0478: if (!m_balanced) {
0479: return cumulativeDistributionForInstance(instance, m_s)
0480: .toArray();
0481: }
0482: return cumulativeDistributionForInstanceBalanced(instance, m_s)
0483: .toArray();
0484: }
0485:
0486: /**
0487: * Calculates the class probabilities for the given test instance.
* Uses the interpolation parameter from the parameter list, and
0489: * always performs the ordinary or weighted OSDL algorithm,
0490: * according to the current settings of the classifier.
0491: * This method doesn't change the classifier.
0492: *
0493: * @param instance the instance to classify
* @param s value of the interpolation parameter to use
0495: * @return the calculated distribution
0496: */
0497: private DiscreteDistribution distributionForInstance(
0498: Instance instance, double s) {
0499: return new DiscreteDistribution(
0500: cumulativeDistributionForInstance(instance, s));
0501: }
0502:
0503: /**
0504: * Calculates the class probabilities for the given test
* instance. Uses the interpolation parameter from the parameter list, and
0506: * always performs the balanced OSDL algorithm.
0507: * This method doesn't change the classifier.
0508: *
0509: * @param instance the instance to classify
* @param s value of the interpolation parameter to use
0511: * @return the calculated distribution
0512: */
0513: private DiscreteDistribution distributionForInstanceBalanced(
0514: Instance instance, double s) {
0515:
0516: return new DiscreteDistribution(
0517: cumulativeDistributionForInstanceBalanced(instance, s));
0518: }
0519:
0520: /**
0521: * Calculates the cumulative class probabilities for the given test
* instance. Uses the interpolation parameter from the parameter list, and
0523: * always performs the ordinary or weighted OSDL algorithm,
0524: * according to the current settings of the classifier.
0525: * This method doesn't change the classifier.
0526: *
0527: * @param instance the instance to classify
* @param s value of the interpolation parameter to use
0529: * @return the calculated distribution
0530: */
0531: private CumulativeDiscreteDistribution cumulativeDistributionForInstance(
0532: Instance instance, double s) {
0533:
0534: Coordinates xc = new Coordinates(instance);
0535: int n = instance.numClasses();
0536: int nrSmaller = 0;
0537: int nrGreater = 0;
0538:
0539: if (!containsSmallestElement()) {
0540: // corresponds to adding the minimal element to the data space
0541: nrSmaller = 1; // avoid division by zero
0542: }
0543:
0544: if (!containsBiggestElement()) {
0545: // corresponds to adding the maximal element to the data space
0546: nrGreater = 1; // avoid division by zero
0547: }
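
// fMin will aggregate (via DistributionUtils.takeMin) the estimated
// cdfs of all training coordinates smaller than or equal to xc, and
// fMax (via DistributionUtils.takeMax) those of all coordinates greater
// than or equal to xc; the prediction below interpolates between the two.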
0548:
0549: // Create fMin and fMax
0550: CumulativeDiscreteDistribution fMin = DistributionUtils
0551: .getMinimalCumulativeDiscreteDistribution(n);
0552: CumulativeDiscreteDistribution fMax = DistributionUtils
0553: .getMaximalCumulativeDiscreteDistribution(n);
0554:
// Cycle through the map of cumulative distribution functions
0556: for (Iterator i = m_estimatedCumulativeDistributions.keySet()
0557: .iterator(); i.hasNext();) {
0558: Coordinates yc = (Coordinates) i.next();
0559: CumulativeDiscreteDistribution cdf = (CumulativeDiscreteDistribution) m_estimatedCumulativeDistributions
0560: .get(yc);
0561:
0562: if (yc.equals(xc)) {
0563: nrSmaller++;
0564: fMin = DistributionUtils.takeMin(fMin, cdf);
0565: nrGreater++;
0566: fMax = DistributionUtils.takeMax(fMax, cdf);
0567: } else if (yc.strictlySmaller(xc)) {
0568: nrSmaller++;
0569: fMin = DistributionUtils.takeMin(fMin, cdf);
0570: } else if (xc.strictlySmaller(yc)) {
0571: nrGreater++;
0572: fMax = DistributionUtils.takeMax(fMax, cdf);
0573: }
0574: }
0575:
0576: if (m_weighted) {
0577: s = ((double) nrSmaller) / (nrSmaller + nrGreater);
0578: if (m_Debug) {
System.err.println("Weighted OSDL: interpolation parameter"
+ " is s = " + s);
0582: }
0583: }
0584:
0585: // calculate s*fMin + (1-s)*fMax
0586: return DistributionUtils.interpolate(fMin, fMax, 1 - s);
0587: }
0588:
0589: /**
0590: * @return true if the learning examples contain an element for which
0591: * the coordinates are the minimal element of the data space, false
0592: * otherwise
0593: */
0594: private boolean containsSmallestElement() {
0595: return m_estimatedCumulativeDistributions
0596: .containsKey(smallestElement);
0597: }
0598:
0599: /**
0600: * @return true if the learning examples contain an element for which
0601: * the coordinates are the maximal element of the data space, false
0602: * otherwise
0603: */
0604: private boolean containsBiggestElement() {
0605: return m_estimatedCumulativeDistributions
0606: .containsKey(biggestElement);
0607: }
0608:
0609: /**
0610: * Calculates the cumulative class probabilities for the given test
* instance. Uses the interpolation parameter from the parameter list, and
0612: * always performs the single or double balanced OSDL algorithm.
0613: * This method doesn't change the classifier.
0614: *
0615: * @param instance the instance to classify
* @param s value of the interpolation parameter to use
0617: * @return the calculated distribution
0618: */
0619: private CumulativeDiscreteDistribution cumulativeDistributionForInstanceBalanced(
0620: Instance instance, double s) {
0621:
0622: Coordinates xc = new Coordinates(instance);
0623: int n = instance.numClasses();
0624:
// n_m[i] represents the number of examples smaller than or equal
// to xc and with a class label strictly greater than i
0627: int[] n_m = new int[n];
0628:
// n_M[i] represents the number of examples greater than or equal
// to xc and with a class label smaller than or equal to i
0631: int[] n_M = new int[n];
0632:
0633: // Create fMin and fMax
0634: CumulativeDiscreteDistribution fMin = DistributionUtils
0635: .getMinimalCumulativeDiscreteDistribution(n);
0636: CumulativeDiscreteDistribution fMax = DistributionUtils
0637: .getMaximalCumulativeDiscreteDistribution(n);
0638:
// Cycle through the map of cumulative distribution functions
0640: for (Iterator i = m_estimatedCumulativeDistributions.keySet()
0641: .iterator(); i.hasNext();) {
0642: Coordinates yc = (Coordinates) i.next();
0643: CumulativeDiscreteDistribution cdf = (CumulativeDiscreteDistribution) m_estimatedCumulativeDistributions
0644: .get(yc);
0645:
0646: if (yc.equals(xc)) {
0647: // update n_m and n_M
0648: DiscreteEstimator df = (DiscreteEstimator) m_estimatedDistributions
0649: .get(yc);
0650: updateN_m(n_m, df);
0651: updateN_M(n_M, df);
0652:
0653: fMin = DistributionUtils.takeMin(fMin, cdf);
0654: fMax = DistributionUtils.takeMax(fMax, cdf);
0655: } else if (yc.strictlySmaller(xc)) {
0656: // update n_m
0657: DiscreteEstimator df = (DiscreteEstimator) m_estimatedDistributions
0658: .get(yc);
0659: updateN_m(n_m, df);
0660: fMin = DistributionUtils.takeMin(fMin, cdf);
0661: } else if (xc.strictlySmaller(yc)) {
0662: // update n_M
0663: DiscreteEstimator df = (DiscreteEstimator) m_estimatedDistributions
0664: .get(yc);
0665: updateN_M(n_M, df);
0666: fMax = DistributionUtils.takeMax(fMax, cdf);
0667: }
0668: }
0669:
0670: double[] dd = new double[n];
0671:
// for each label decide which formula to use: either using
// n_m[i] and n_M[i] (if fMin[i] < fMax[i]), or using the
// interpolation parameter s, or using the double balanced version
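// (Reading of the loop below: if fMin[i] < fMax[i] the preference is
// "reversed" and dd[i] is the count-weighted average
// (n_m[i]*fMin[i] + n_M[i]*fMax[i]) / (n_m[i] + n_M[i]). Otherwise the
// singly balanced version interpolates with s, while the double balanced
// version again uses a count-weighted average with the roles of n_m[i]
// and n_M[i] swapped, falling back to the s-interpolation when both
// counts are zero.)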
0675: for (int i = 0; i < n; i++) {
0676: double fmin = fMin.getCumulativeProbability(i);
0677: double fmax = fMax.getCumulativeProbability(i);
0678:
if (m_weighted) { // double balanced version
0680: if (fmin < fmax) { // reversed preference
0681: dd[i] = (n_m[i] * fmin + n_M[i] * fmax)
0682: / (n_m[i] + n_M[i]);
0683: } else {
0684: if (n_m[i] + n_M[i] == 0) { // avoid division by zero
0685: dd[i] = s * fmin + (1 - s) * fmax;
0686: } else {
0687: dd[i] = (n_M[i] * fmin + n_m[i] * fmax)
0688: / (n_m[i] + n_M[i]);
0689: }
0690: }
0691: } else { // singly balanced version
0692: dd[i] = (fmin < fmax) ? (n_m[i] * fmin + n_M[i] * fmax)
0693: / (n_m[i] + n_M[i]) : s * fmin + (1 - s) * fmax;
0694: }
0695: }
0696: try {
0697: return new CumulativeDiscreteDistribution(dd);
0698: } catch (IllegalArgumentException e) {
0699: // this shouldn't happen.
0700: System.err.println("We tried to create a cumulative "
0701: + "discrete distribution from the following array");
0702: for (int i = 0; i < dd.length; i++) {
0703: System.err.print(dd[i] + " ");
0704: }
0705: System.err.println();
0706: throw new AssertionError(dd);
0707: }
0708: }
0709:
0710: /**
0711: * Update the array n_m using the given <code> DiscreteEstimator </code>.
0712: *
0713: * @param n_m the array n_m that will be updated.
0714: * @param de the <code> DiscreteEstimator </code> that gives the
0715: * count over the different class labels.
0716: */
0717: private void updateN_m(int[] n_m, DiscreteEstimator de) {
0718: int[] tmp = new int[n_m.length];
0719:
// all examples have a class label strictly greater
0721: // than 0, except those that have class label 0.
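// (Illustrative, with made-up counts: for 3 classes with class counts
// (2, 1, 3) the code below yields tmp = (4, 3, 0), i.e. 4 examples have
// a label > 0, 3 have a label > 1 and none has a label > 2.)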
0722: tmp[0] = (int) de.getSumOfCounts() - (int) de.getCount(0);
0723: n_m[0] += tmp[0];
0724: for (int i = 1; i < n_m.length; i++) {
0725:
0726: // the examples with a class label strictly greater
0727: // than i are exactly those that have a class label strictly
0728: // greater than i-1, except those that have class label i.
0729: tmp[i] = tmp[i - 1] - (int) de.getCount(i);
0730: n_m[i] += tmp[i];
0731: }
0732:
0733: if (n_m[n_m.length - 1] != 0) {
0734: // this shouldn't happen
0735: System.err.println("******** Problem with n_m in "
0736: + m_train.relationName());
0737: System.err.println("Last argument is non-zero, namely : "
0738: + n_m[n_m.length - 1]);
0739: }
0740: }
0741:
0742: /**
0743: * Update the array n_M using the given <code> DiscreteEstimator </code>.
0744: *
0745: * @param n_M the array n_M that will be updated.
0746: * @param de the <code> DiscreteEstimator </code> that gives the
0747: * count over the different class labels.
0748: */
0749: private void updateN_M(int[] n_M, DiscreteEstimator de) {
0750: int n = n_M.length;
0751: int[] tmp = new int[n];
0752:
// all examples have a class label smaller than or equal
// to n-1 (which is the maximum class label)
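// (With the same made-up counts (2, 1, 3) as above, the code below
// yields tmp = (2, 3, 6): 2 examples have a label <= 0, 3 have a
// label <= 1 and all 6 have a label <= 2.)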
0755: tmp[n - 1] = (int) de.getSumOfCounts();
0756: n_M[n - 1] += tmp[n - 1];
0757: for (int i = n - 2; i >= 0; i--) {
0758:
// the examples with a class label smaller than or equal
// to i are exactly those with a class label smaller than
// or equal to i+1, except those that have class label i+1.
0763: tmp[i] = tmp[i + 1] - (int) de.getCount(i + 1);
0764: n_M[i] += tmp[i];
0765: }
0766: }
0767:
0768: /**
0769: * Builds the classifier.
0770: * This means that all relevant examples are stored into memory.
0771: * If necessary the interpolation parameter is tuned.
0772: *
0773: * @param instances the instances to be used for building the classifier
0774: * @throws Exception if the classifier can't be built successfully
0775: */
0776: public void buildClassifier(Instances instances) throws Exception {
0777:
0778: getCapabilities().testWithFail(instances);
0779:
0780: // copy the dataset
0781: m_train = new Instances(instances);
0782:
0783: // new dataset in which examples with missing class value are removed
0784: m_train.deleteWithMissingClass();
0785:
0786: // build the Map for the estimatedDistributions
0787: m_estimatedDistributions = new HashMap(
0788: m_train.numInstances() / 2);
0789:
// cycle through all training instances (m_train has the
// instances with missing class value removed)
for (Iterator it = new EnumerationIterator(m_train
.enumerateInstances()); it.hasNext();) {
0793: Instance instance = (Instance) it.next();
0794: Coordinates c = new Coordinates(instance);
0795:
0796: // get DiscreteEstimator from the map
0797: DiscreteEstimator df = (DiscreteEstimator) m_estimatedDistributions
0798: .get(c);
0799:
0800: // if no DiscreteEstimator is present in the map, create one
0801: if (df == null) {
0802: df = new DiscreteEstimator(instances.numClasses(), 0);
0803: }
0804: df.addValue(instance.classValue(), instance.weight()); // update
0805: m_estimatedDistributions.put(c, df); // put back in map
0806: }
0807:
0808: // build the map of cumulative distribution functions
0809: m_estimatedCumulativeDistributions = new HashMap(
0810: m_estimatedDistributions.size() / 2);
0811:
// Cycle through the map of discrete distributions, and create a new
0813: // one containing cumulative discrete distributions
0814: for (Iterator it = m_estimatedDistributions.keySet().iterator(); it
0815: .hasNext();) {
0816: Coordinates c = (Coordinates) it.next();
0817: DiscreteEstimator df = (DiscreteEstimator) m_estimatedDistributions
0818: .get(c);
0819: m_estimatedCumulativeDistributions.put(c,
0820: new CumulativeDiscreteDistribution(df));
0821: }
0822:
0823: // check if the interpolation parameter needs to be tuned
0824: if (m_tuneInterpolationParameter
0825: && !m_interpolationParameterValid) {
0826: tuneInterpolationParameter();
0827: }
0828:
0829: // fill in the smallest and biggest element (for use in the
0830: // quasi monotone version of the algorithm)
0831: double[] tmpAttValues = new double[instances.numAttributes()];
0832: Instance instance = new Instance(1, tmpAttValues);
0833: instance.setDataset(instances);
0834: smallestElement = new Coordinates(instance);
0835: if (m_Debug) {
0836: System.err.println("minimal element of data space = "
0837: + smallestElement);
0838: }
0839: for (int i = 0; i < tmpAttValues.length; i++) {
0840: tmpAttValues[i] = instances.attribute(i).numValues() - 1;
0841: }
0842:
0843: instance = new Instance(1, tmpAttValues);
0844: instance.setDataset(instances);
0845: biggestElement = new Coordinates(instance);
0846: if (m_Debug) {
0847: System.err.println("maximal element of data space = "
0848: + biggestElement);
0849: }
0850: }
0851:
0852: /**
0853: * Returns the tip text for this property.
0854: *
0855: * @return tip text for this property suitable for
0856: * displaying in the explorer/experimenter gui
0857: */
0858: public String classificationTypeTipText() {
0859: return "Sets the way in which a single label will be extracted "
0860: + "from the estimated distribution.";
0861: }
0862:
0863: /**
0864: * Sets the classification type. Currently <code> ctype </code>
0865: * must be one of:
0866: * <ul>
0867: * <li> <code> CT_REGRESSION </code> : use expectation value of
0868: * distribution. (Non-ordinal in nature).
0869: * <li> <code> CT_WEIGHTED_SUM </code> : use expectation value of
0870: * distribution rounded to nearest class label. (Non-ordinal in
0871: * nature).
0872: * <li> <code> CT_MAXPROB </code> : use the mode of the distribution.
0873: * (May deliver non-monotone results).
0874: * <li> <code> CT_MEDIAN </code> : use the median of the distribution
0875: * (rounded to the nearest class label).
0876: * <li> <code> CT_MEDIAN_REAL </code> : use the median of the distribution
0877: * but not rounded to the nearest class label.
0878: * </ul>
0879: *
0880: * @param value the classification type
0881: */
0882: public void setClassificationType(SelectedTag value) {
0883: if (value.getTags() == TAGS_CLASSIFICATIONTYPES)
0884: m_ctype = value.getSelectedTag().getID();
0885: }
0886:
0887: /**
0888: * Returns the classification type.
0889: *
0890: * @return the classification type
0891: */
0892: public SelectedTag getClassificationType() {
0893: return new SelectedTag(m_ctype, TAGS_CLASSIFICATIONTYPES);
0894: }
0895:
0896: /**
0897: * Returns the tip text for this property.
0898: *
0899: * @return tip text for this property suitable for
0900: * displaying in the explorer/experimenter gui
0901: */
0902: public String tuneInterpolationParameterTipText() {
0903: return "Whether to tune the interpolation parameter based on the bounds.";
0904: }
0905:
0906: /**
0907: * Sets whether the interpolation parameter is to be tuned based on the
0908: * bounds.
0909: *
0910: * @param value if true the parameter is tuned
0911: */
0912: public void setTuneInterpolationParameter(boolean value) {
0913: m_tuneInterpolationParameter = value;
0914: }
0915:
0916: /**
0917: * Returns whether the interpolation parameter is to be tuned based on the
0918: * bounds.
0919: *
0920: * @return true if the parameter is to be tuned
0921: */
0922: public boolean getTuneInterpolationParameter() {
0923: return m_tuneInterpolationParameter;
0924: }
0925:
0926: /**
0927: * Returns the tip text for this property.
0928: *
0929: * @return tip text for this property suitable for
0930: * displaying in the explorer/experimenter gui
0931: */
0932: public String interpolationParameterLowerBoundTipText() {
0933: return "Sets the lower bound for the interpolation parameter tuning (0 <= x < 1).";
0934: }
0935:
0936: /**
0937: * Sets the lower bound for the interpolation parameter tuning
0938: * (0 <= x < 1).
0939: *
* @param value the lower bound
0941: * @throws IllegalArgumentException if bound is invalid
0942: */
0943: public void setInterpolationParameterLowerBound(double value) {
0944: if ((value < 0) || (value >= 1)
0945: || (value > getInterpolationParameterUpperBound()))
0946: throw new IllegalArgumentException("Illegal lower bound");
0947:
0948: m_sLower = value;
0949: m_tuneInterpolationParameter = true;
0950: m_interpolationParameterValid = false;
0951: }
0952:
0953: /**
0954: * Returns the lower bound for the interpolation parameter tuning
0955: * (0 <= x < 1).
0956: *
0957: * @return the lower bound
0958: */
0959: public double getInterpolationParameterLowerBound() {
0960: return m_sLower;
0961: }
0962:
0963: /**
0964: * Returns the tip text for this property.
0965: *
0966: * @return tip text for this property suitable for
0967: * displaying in the explorer/experimenter gui
0968: */
0969: public String interpolationParameterUpperBoundTipText() {
0970: return "Sets the upper bound for the interpolation parameter tuning (0 < x <= 1).";
0971: }
0972:
0973: /**
0974: * Sets the upper bound for the interpolation parameter tuning
0975: * (0 < x <= 1).
0976: *
* @param value the upper bound
0978: * @throws IllegalArgumentException if bound is invalid
0979: */
0980: public void setInterpolationParameterUpperBound(double value) {
0981: if ((value <= 0) || (value > 1)
0982: || (value < getInterpolationParameterLowerBound()))
0983: throw new IllegalArgumentException("Illegal upper bound");
0984:
0985: m_sUpper = value;
0986: m_tuneInterpolationParameter = true;
0987: m_interpolationParameterValid = false;
0988: }
0989:
0990: /**
0991: * Returns the upper bound for the interpolation parameter tuning
0992: * (0 < x <= 1).
0993: *
0994: * @return the upper bound
0995: */
0996: public double getInterpolationParameterUpperBound() {
0997: return m_sUpper;
0998: }
0999:
1000: /**
1001: * Sets the interpolation bounds for the interpolation parameter.
1002: * When tuning the interpolation parameter only values in the interval
1003: * <code> [sLow, sUp] </code> are considered.
1004: * It is important to note that using this method immediately
1005: * implies that the interpolation parameter is to be tuned.
1006: *
1007: * @param sLow lower bound for the interpolation parameter,
1008: * should not be smaller than 0 or greater than <code> sUp </code>
1009: * @param sUp upper bound for the interpolation parameter,
1010: * should not exceed 1 or be smaller than <code> sLow </code>
1011: * @throws IllegalArgumentException if one of the above conditions
1012: * is not satisfied.
1013: */
1014: public void setInterpolationParameterBounds(double sLow, double sUp)
1015: throws IllegalArgumentException {
1016:
1017: if (sLow < 0. || sUp > 1. || sLow > sUp)
1018: throw new IllegalArgumentException(
1019: "Illegal upper and lower bounds");
1020: m_sLower = sLow;
1021: m_sUpper = sUp;
1022: m_tuneInterpolationParameter = true;
1023: m_interpolationParameterValid = false;
1024: }
1025:
1026: /**
1027: * Returns the tip text for this property.
1028: *
1029: * @return tip text for this property suitable for
1030: * displaying in the explorer/experimenter gui
1031: */
1032: public String interpolationParameterTipText() {
1033: return "Sets the value of the interpolation parameter s;"
1034: + "Estimated distribution is s * f_min + (1 - s) * f_max. ";
1035: }
1036:
1037: /**
1038: * Sets the interpolation parameter. This immediately means that
1039: * the interpolation parameter is not to be tuned.
1040: *
1041: * @param s value for the interpolation parameter.
1042: * @throws IllegalArgumentException if <code> s </code> is not in
1043: * the range [0,1].
1044: */
1045: public void setInterpolationParameter(double s)
1046: throws IllegalArgumentException {
1047:
1048: if (0 > s || s > 1)
1049: throw new IllegalArgumentException(
1050: "Interpolationparameter exceeds bounds");
1051: m_tuneInterpolationParameter = false;
1052: m_interpolationParameterValid = false;
1053: m_s = s;
1054: }
1055:
1056: /**
1057: * Returns the current value of the interpolation parameter.
1058: *
1059: * @return the value of the interpolation parameter
1060: */
1061: public double getInterpolationParameter() {
1062: return m_s;
1063: }
1064:
1065: /**
1066: * Returns the tip text for this property.
1067: *
1068: * @return tip text for this property suitable for
1069: * displaying in the explorer/experimenter gui
1070: */
1071: public String numberOfPartsForInterpolationParameterTipText() {
1072: return "Sets the granularity for tuning the interpolation parameter; "
1073: + "For instance if the value is 32 then 33 values for the "
1074: + "interpolation are checked.";
1075: }
1076:
1077: /**
1078: * Sets the granularity for tuning the interpolation parameter.
1079: * The interval between lower and upper bounds for the interpolation
1080: * parameter is divided into <code> sParts </code> parts, i.e.
1081: * <code> sParts + 1 </code> values will be checked when
1082: * <code> tuneInterpolationParameter </code> is invoked.
1083: * This also means that the interpolation parameter is to
1084: * be tuned.
1085: *
1086: * @param sParts the number of parts
* @throws IllegalArgumentException if <code> sParts </code> is
* smaller than or equal to 0.
1089: */
1090: public void setNumberOfPartsForInterpolationParameter(int sParts)
1091: throws IllegalArgumentException {
1092:
1093: if (sParts <= 0)
1094: throw new IllegalArgumentException(
1095: "Number of parts is negative");
1096:
1097: m_tuneInterpolationParameter = true;
1098: if (m_sNrParts != sParts) {
1099: m_interpolationParameterValid = false;
1100: m_sNrParts = sParts;
1101: }
1102: }
1103:
1104: /**
1105: * Gets the granularity for tuning the interpolation parameter.
1106: *
1107: * @return the number of parts in which the interval
1108: * <code> [s_low, s_up] </code> is to be split
1109: */
1110: public int getNumberOfPartsForInterpolationParameter() {
1111: return m_sNrParts;
1112: }
1113:
1114: /**
1115: * Returns a string suitable for displaying in the gui/experimenter.
1116: *
1117: * @return tip text for this property suitable for
1118: * displaying in the explorer/experimenter gui
1119: */
1120: public String balancedTipText() {
1121: return "If true, the balanced version of the OSDL-algorithm is used\n"
1122: + "This means that distinction is made between the normal and "
1123: + "reversed preference situation.";
1124: }
1125:
1126: /**
1127: * If <code> balanced </code> is <code> true </code> then the balanced
1128: * version of OSDL will be used, otherwise the ordinary version of
1129: * OSDL will be in effect.
1130: *
1131: * @param balanced if <code> true </code> then B-OSDL is used, otherwise
1132: * it is OSDL
1133: */
1134: public void setBalanced(boolean balanced) {
1135: m_balanced = balanced;
1136: }
1137:
1138: /**
1139: * Returns if the balanced version of OSDL is in effect.
1140: *
1141: * @return <code> true </code> if the balanced version is in effect,
1142: * <code> false </code> otherwise
1143: */
1144: public boolean getBalanced() {
1145: return m_balanced;
1146: }
1147:
1148: /**
1149: * Returns a string suitable for displaying in the gui/experimenter.
1150: *
1151: * @return tip text for this property suitable for
1152: * displaying in the explorer/experimenter gui
1153: */
1154: public String weightedTipText() {
1155: return "If true, the weighted version of the OSDL-algorithm is used";
1156: }
1157:
1158: /**
1159: * If <code> weighted </code> is <code> true </code> then the
1160: * weighted version of the OSDL is used.
1161: * Note: using the weighted (non-balanced) version only ensures the
* quasi monotonicity of the results w.r.t. the training set.
1163: *
1164: * @param weighted <code> true </code> if the weighted version to be used,
1165: * <code> false </code> otherwise
1166: */
1167: public void setWeighted(boolean weighted) {
1168: m_weighted = weighted;
1169: }
1170:
1171: /**
1172: * Returns if the weighted version is in effect.
1173: *
1174: * @return <code> true </code> if the weighted version is in effect,
1175: * <code> false </code> otherwise.
1176: */
1177: public boolean getWeighted() {
1178: return m_weighted;
1179: }
1180:
1181: /**
1182: * Returns the current value of the lower bound for the interpolation
1183: * parameter.
1184: *
1185: * @return the current value of the lower bound for the interpolation
1186: * parameter
1187: */
1188: public double getLowerBound() {
1189: return m_sLower;
1190: }
1191:
1192: /**
1193: * Returns the current value of the upper bound for the interpolation
1194: * parameter.
1195: *
1196: * @return the current value of the upper bound for the interpolation
1197: * parameter
1198: */
1199: public double getUpperBound() {
1200: return m_sUpper;
1201: }
1202:
1203: /**
1204: * Returns the number of instances in the training set.
1205: *
1206: * @return the number of instances used for training
1207: */
1208: public int getNumInstances() {
1209: return m_train.numInstances();
1210: }
1211:
1212: /** Tune the interpolation parameter using the current
1213: * settings of the classifier.
1214: * This also sets the interpolation parameter.
1215: * @return the value of the tuned interpolation parameter.
1216: */
1217: public double tuneInterpolationParameter() {
1218: try {
1219: return tuneInterpolationParameter(m_sLower, m_sUpper,
1220: m_sNrParts, m_ctype);
1221: } catch (IllegalArgumentException e) {
1222: throw new AssertionError(e);
1223: }
1224: }
1225:
1226: /**
1227: * Tunes the interpolation parameter using the given settings.
1228: * The parameters of the classifier are updated accordingly!
1229: * Marks the interpolation parameter as valid.
1230: *
* @param sLow lower end point of the interval of parameters to be examined
* @param sUp upper end point of the interval of parameters to be examined
1233: * @param sParts number of parts the interval is divided into. This thus determines
1234: * the granularity of the search
1235: * @param ctype the classification type to use
1236: * @return the value of the tuned interpolation parameter
1237: * @throws IllegalArgumentException if the given parameter list is not
1238: * valid
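* <p>
* Illustrative call (the values are made up):
* <code> tuneInterpolationParameter(0.0, 1.0, 20, CT_MEDIAN) </code>
* examines 21 equally spaced values of s in [0,1] and keeps the best one.
* </p>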
1239: */
1240: public double tuneInterpolationParameter(double sLow, double sUp,
1241: int sParts, int ctype) throws IllegalArgumentException {
1242:
1243: setInterpolationParameterBounds(sLow, sUp);
1244: setNumberOfPartsForInterpolationParameter(sParts);
1245: setClassificationType(new SelectedTag(ctype,
1246: TAGS_CLASSIFICATIONTYPES));
1247:
1248: m_s = crossValidate(sLow, sUp, sParts, ctype);
1249: m_tuneInterpolationParameter = true;
1250: m_interpolationParameterValid = true;
1251: return m_s;
1252: }
1253:
1254: /**
1255: * Tunes the interpolation parameter using the current settings
1256: * of the classifier. This doesn't change the classifier, i.e.
1257: * none of the internal parameters is changed!
1258: *
1259: * @return the tuned value of the interpolation parameter
1260: * @throws IllegalArgumentException if somehow the current settings of the
1261: * classifier are illegal.
1262: */
1263: public double crossValidate() throws IllegalArgumentException {
1264: return crossValidate(m_sLower, m_sUpper, m_sNrParts, m_ctype);
1265: }
1266:
1267: /**
1268: * Tune the interpolation parameter using leave-one-out
* cross validation; the loss function used is the zero-one loss
1270: * function.
1271: * <p>
1272: * The given settings are used, but the classifier is not
* updated! Also, the interpolation parameter s is not
1274: * set.
1275: * </p>
1276: *
* @param sLow lower end point of the interval of parameters to be examined
* @param sUp upper end point of the interval of parameters to be examined
1279: * @param sNrParts number of parts the interval is divided into. This thus determines
1280: * the granularity of the search
1281: * @param ctype the classification type to use
1282: * @return the best value for the interpolation parameter
1283: * @throws IllegalArgumentException if the settings for the
1284: * interpolation parameter are not valid or if the classification
1285: * type is not valid
1286: */
1287: public double crossValidate(double sLow, double sUp, int sNrParts,
1288: int ctype) throws IllegalArgumentException {
1289:
1290: double[] performanceStats = new double[sNrParts + 1];
1291: return crossValidate(sLow, sUp, sNrParts, ctype,
1292: performanceStats, new ZeroOneLossFunction());
1293: }
1294:
1295: /**
1296: * Tune the interpolation parameter using leave-one-out
1297: * cross validation. The given parameters are used, but
1298: * the classifier is not changed, in particular, the interpolation
1299: * parameter remains unchanged.
1300: *
1301: * @param sLow lower bound for interpolation parameter
1302: * @param sUp upper bound for interpolation parameter
1303: * @param sNrParts determines the granularity of the search
1304: * @param ctype the classification type to use
1305: * @param performanceStats array acting as output, and that will
1306: * contain the total loss of the leave-one-out cross validation for
1307: * each considered value of the interpolation parameter
1308: * @param lossFunction the loss function to use
1309: * @return the value of the interpolation parameter that is considered
1310: * best
1311: * @throws IllegalArgumentException the length of the array
1312: * <code> performanceStats </code> is not sufficient
1313: * @throws IllegalArgumentException if the interpolation parameters
1314: * are not valid
1315: * @throws IllegalArgumentException if the classification type is
1316: * not valid
1317: */
1318: public double crossValidate(double sLow, double sUp, int sNrParts,
1319: int ctype, double[] performanceStats,
1320: NominalLossFunction lossFunction)
1321: throws IllegalArgumentException {
1322:
1323: if (performanceStats.length < sNrParts + 1) {
1324: throw new IllegalArgumentException(
1325: "Length of array is not sufficient");
1326: }
1327:
1328: if (!interpolationParametersValid(sLow, sUp, sNrParts)) {
1329: throw new IllegalArgumentException(
1330: "Interpolation parameters are not valid");
1331: }
1332:
1333: if (!classificationTypeValid(ctype)) {
1334: throw new IllegalArgumentException(
1335: "Not a valid classification type " + ctype);
1336: }
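
// Leave-one-out cross validation: every training instance is removed
// in turn, classified for each of the sNrParts + 1 candidate values of
// s in [sLow, sUp] (step (sUp - sLow) / sNrParts), and the resulting
// loss is accumulated per candidate value in performanceStats.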
1337:
1338: Arrays.fill(performanceStats, 0, sNrParts + 1, 0);
1339:
1340: // cycle through all instances
1341: for (Iterator it = new EnumerationIterator(m_train
1342: .enumerateInstances()); it.hasNext();) {
1343: Instance instance = (Instance) it.next();
1344: double classValue = instance.classValue();
1345: removeInstance(instance);
1346:
1347: double s = sLow;
1348: double step = (sUp - sLow) / sNrParts; //step size
1349: for (int i = 0; i <= sNrParts; i++, s += step) {
1350: try {
1351: performanceStats[i] += lossFunction.loss(
1352: classValue, classifyInstance(instance, s,
1353: ctype));
1354: } catch (Exception exception) {
1355:
// normally we should never get here; report the error and bail out
1357: System.err.println(exception.getMessage());
1358: System.exit(1);
1359: }
1360: }
1361:
1362: // XXX may be done more efficiently
1363: addInstance(instance); // update
1364: }
1365:
1366: // select the 'best' value for s
1367: // to this end, we sort the array with the leave-one-out
1368: // performance statistics, and we choose the middle one
// of all those that score 'best'
1370:
1371: // new code, august 2004
1372: // new code, june 2005. If performanceStats is longer than
1373: // necessary, copy it first
1374: double[] tmp = performanceStats;
1375: if (performanceStats.length > sNrParts + 1) {
1376: tmp = new double[sNrParts + 1];
1377: System.arraycopy(performanceStats, 0, tmp, 0, tmp.length);
1378: }
1379: int[] sort = Utils.stableSort(tmp);
1380: int minIndex = 0;
1381: while (minIndex + 1 < tmp.length
1382: && tmp[sort[minIndex + 1]] == tmp[sort[minIndex]]) {
1383: minIndex++;
1384: }
1385: minIndex = sort[minIndex / 2]; // middle one
1386: // int minIndex = Utils.minIndex(performanceStats); // OLD code
1387:
1388: return sLow + minIndex * (sUp - sLow) / sNrParts;
1389: }
1390:
1391: /**
1392: * Checks if <code> ctype </code> is a valid classification
1393: * type.
1394: * @param ctype the int to be checked
1395: * @return true if ctype is a valid classification type, false otherwise
1396: */
1397: private boolean classificationTypeValid(int ctype) {
1398: return ctype == CT_REGRESSION || ctype == CT_WEIGHTED_SUM
1399: || ctype == CT_MAXPROB || ctype == CT_MEDIAN
1400: || ctype == CT_MEDIAN_REAL;
1401: }
1402:
1403: /**
1404: * Checks if the given parameters are valid interpolation parameters.
1405: * @param sLow lower bound for the interval
1406: * @param sUp upper bound for the interval
1407: * @param sNrParts the number of parts the interval has to be divided in
* @return true if the given parameters are valid interpolation parameters,
1409: * false otherwise
1410: */
1411: private boolean interpolationParametersValid(double sLow,
1412: double sUp, int sNrParts) {
return (sLow >= 0 && sUp <= 1 && sLow < sUp && sNrParts > 0)
|| (sLow == sUp && sNrParts == 0);
1415: // special case included
1416: }
1417:
1418: /**
1419: * Remove an instance from the classifier. Updates the hashmaps.
1420: * @param instance the instance to be removed.
1421: */
1422: private void removeInstance(Instance instance) {
1423: Coordinates c = new Coordinates(instance);
1424:
1425: // Remove instance temporarily from the Maps with the distributions
1426: DiscreteEstimator df = (DiscreteEstimator) m_estimatedDistributions
1427: .get(c);
1428:
1429: // remove from df
1430: df.addValue(instance.classValue(), -instance.weight());
1431:
if (Math.abs(df.getSumOfCounts()) < Utils.SMALL) {
1433:
1434: /* There was apparently only one example with coordinates c
1435: * in the training set, and now we removed it.
1436: * Remove the key c from both maps.
1437: */
1438: m_estimatedDistributions.remove(c);
1439: m_estimatedCumulativeDistributions.remove(c);
1440: } else {
1441:
1442: // update both maps
1443: m_estimatedDistributions.put(c, df);
1444: m_estimatedCumulativeDistributions.put(c,
1445: new CumulativeDiscreteDistribution(df));
1446: }
1447: }
1448:
1449: /**
* Update the classifier using the given instance. Updates the hashmaps.
1451: * @param instance the instance to be added
1452: */
1453: private void addInstance(Instance instance) {
1454:
1455: Coordinates c = new Coordinates(instance);
1456:
1457: // Get DiscreteEstimator from the map
1458: DiscreteEstimator df = (DiscreteEstimator) m_estimatedDistributions
1459: .get(c);
1460:
1461: // If no DiscreteEstimator is present in the map, create one
1462: if (df == null) {
1463: df = new DiscreteEstimator(instance.dataset().numClasses(),
1464: 0);
1465: }
1466: df.addValue(instance.classValue(), instance.weight()); // update df
1467: m_estimatedDistributions.put(c, df); // put back in map
1468: m_estimatedCumulativeDistributions.put(c,
1469: new CumulativeDiscreteDistribution(df));
1470: }
1471:
1472: /**
1473: * Returns an enumeration describing the available options.
1474: * For a list of available options, see <code> setOptions </code>.
1475: *
1476: * @return an enumeration of all available options.
1477: */
1478: public Enumeration listOptions() {
1479: Vector options = new Vector();
1480:
Enumeration enm = super.listOptions();
1482: while (enm.hasMoreElements())
1483: options.addElement(enm.nextElement());
1484:
1485: String description = "\tSets the classification type to be used.\n"
1486: + "\t(Default: "
1487: + new SelectedTag(CT_MEDIAN, TAGS_CLASSIFICATIONTYPES)
1488: + ")";
1489: String synopsis = "-C "
1490: + Tag.toOptionList(TAGS_CLASSIFICATIONTYPES);
1491: String name = "C";
1492: options.addElement(new Option(description, name, 1, synopsis));
1493:
1494: description = "\tUse the balanced version of the "
1495: + "Ordinal Stochastic Dominance Learner";
1496: synopsis = "-B";
1497: name = "B";
options.addElement(new Option(description, name, 0, synopsis));
1499:
1500: description = "\tUse the weighted version of the "
1501: + "Ordinal Stochastic Dominance Learner";
1502: synopsis = "-W";
1503: name = "W";
options.addElement(new Option(description, name, 0, synopsis));
1505:
1506: description = "\tSets the value of the interpolation parameter (not with -W/T/P/L/U)\n"
1507: + "\t(default: 0.5).";
1508: synopsis = "-S <value of interpolation parameter>";
1509: name = "S";
1510: options.addElement(new Option(description, name, 1, synopsis));
1511:
1512: description = "\tTune the interpolation parameter (not with -W/S)\n"
1513: + "\t(default: off)";
1514: synopsis = "-T";
1515: name = "T";
1516: options.addElement(new Option(description, name, 0, synopsis));
1517:
1518: description = "\tLower bound for the interpolation parameter (not with -W/S)\n"
1519: + "\t(default: 0)";
1520: synopsis = "-L <Lower bound for interpolation parameter>";
1521: name = "L";
1522: options.addElement(new Option(description, name, 1, synopsis));
1523:
1524: description = "\tUpper bound for the interpolation parameter (not with -W/S)\n"
1525: + "\t(default: 1)";
1526: synopsis = "-U <Upper bound for interpolation parameter>";
1527: name = "U";
1528: options.addElement(new Option(description, name, 1, synopsis));
1529:
1530: description = "\tDetermines the step size for tuning the interpolation\n"
1531: + "\tparameter, nl. (U-L)/P (not with -W/S)\n"
1532: + "\t(default: 10)";
1533: synopsis = "-P <Number of parts>";
1534: name = "P";
1535: options.addElement(new Option(description, name, 1, synopsis));
1536:
1537: return options.elements();
1538: }
1539:
1540: /**
1541: * Parses the options for this object. <p/>
1542: *
1543: <!-- options-start -->
1544: * Valid options are: <p/>
1545: *
1546: * <pre> -D
1547: * If set, classifier is run in debug mode and
1548: * may output additional info to the console</pre>
1549: *
1550: * <pre> -C <REG|WSUM|MAX|MED|RMED>
1551: * Sets the classification type to be used.
1552: * (Default: MED)</pre>
1553: *
1554: * <pre> -B
1555: * Use the balanced version of the Ordinal Stochastic Dominance Learner</pre>
1556: *
1557: * <pre> -W
1558: * Use the weighted version of the Ordinal Stochastic Dominance Learner</pre>
1559: *
1560: * <pre> -S <value of interpolation parameter>
1561: * Sets the value of the interpolation parameter (not with -W/T/P/L/U)
1562: * (default: 0.5).</pre>
1563: *
1564: * <pre> -T
1565: * Tune the interpolation parameter (not with -W/S)
1566: * (default: off)</pre>
1567: *
1568: * <pre> -L <Lower bound for interpolation parameter>
1569: * Lower bound for the interpolation parameter (not with -W/S)
1570: * (default: 0)</pre>
1571: *
1572: * <pre> -U <Upper bound for interpolation parameter>
1573: * Upper bound for the interpolation parameter (not with -W/S)
1574: * (default: 1)</pre>
1575: *
1576: * <pre> -P <Number of parts>
1577: * Determines the step size for tuning the interpolation
* parameter, namely (U-L)/P (not with -W/S)
1579: * (default: 10)</pre>
1580: *
1581: <!-- options-end -->
1582: *
1583: * @param options the list of options as an array of strings
1584: * @throws Exception if an option is not supported
1585: */
1586: public void setOptions(String[] options) throws Exception {
1587: String args;
1588:
1589: args = Utils.getOption('C', options);
1590: if (args.length() != 0)
1591: setClassificationType(new SelectedTag(args,
1592: TAGS_CLASSIFICATIONTYPES));
1593: else
1594: setClassificationType(new SelectedTag(CT_MEDIAN,
1595: TAGS_CLASSIFICATIONTYPES));
1596:
1597: setBalanced(Utils.getFlag('B', options));
1598:
1599: if (Utils.getFlag('W', options)) {
1600: m_weighted = true;
1601: // ignore any T, S, P, L and U options
1602: Utils.getOption('T', options);
1603: Utils.getOption('S', options);
1604: Utils.getOption('P', options);
1605: Utils.getOption('L', options);
1606: Utils.getOption('U', options);
1607: } else {
1608: m_tuneInterpolationParameter = Utils.getFlag('T', options);
1609:
1610: if (!m_tuneInterpolationParameter) {
1611: // ignore P, L, U
1612: Utils.getOption('P', options);
1613: Utils.getOption('L', options);
1614: Utils.getOption('U', options);
1615:
1616: // value of s
1617: args = Utils.getOption('S', options);
1618: if (args.length() != 0)
1619: setInterpolationParameter(Double.parseDouble(args));
1620: else
1621: setInterpolationParameter(0.5);
1622: } else {
1623: // ignore S
1624: Utils.getOption('S', options);
1625:
1626: args = Utils.getOption('L', options);
1627: double l = m_sLower;
1628: if (args.length() != 0)
1629: l = Double.parseDouble(args);
1630: else
1631: l = 0.0;
1632:
1633: args = Utils.getOption('U', options);
1634: double u = m_sUpper;
1635: if (args.length() != 0)
1636: u = Double.parseDouble(args);
1637: else
1638: u = 1.0;
1639:
1640: if (m_tuneInterpolationParameter)
1641: setInterpolationParameterBounds(l, u);
1642:
1643: args = Utils.getOption('P', options);
1644: if (args.length() != 0)
1645: setNumberOfPartsForInterpolationParameter(Integer
1646: .parseInt(args));
1647: else
1648: setNumberOfPartsForInterpolationParameter(10);
1649: }
1650: }
1651:
super.setOptions(options);
1653: }
1654:
1655: /**
1656: * Gets the current settings of the OSDLCore classifier.
1657: *
1658: * @return an array of strings suitable for passing
1659: * to <code> setOptions </code>
1660: */
1661: public String[] getOptions() {
1662: int i;
1663: Vector result;
1664: String[] options;
1665:
1666: result = new Vector();
1667:
options = super.getOptions();
1669: for (i = 0; i < options.length; i++)
1670: result.add(options[i]);
1671:
1672: // classification type
1673: result.add("-C");
1674: result.add("" + getClassificationType());
1675:
1676: if (m_balanced)
1677: result.add("-B");
1678:
1679: if (m_weighted) {
1680: result.add("-W");
1681: } else {
1682: // interpolation parameter
1683: if (!m_tuneInterpolationParameter) {
1684: result.add("-S");
1685: result.add(Double.toString(m_s));
1686: } else {
1687: result.add("-T");
1688: result.add("-L");
1689: result.add(Double.toString(m_sLower));
1690: result.add("-U");
1691: result.add(Double.toString(m_sUpper));
1692: result.add("-P");
1693: result.add(Integer.toString(m_sNrParts));
1694: }
1695: }
1696:
1697: return (String[]) result.toArray(new String[result.size()]);
1698: }
1699:
1700: /**
1701: * Returns a description of the classifier.
* Attention: if debugging is on, the description can become
1703: * very lengthy.
1704: *
1705: * @return a string containing the description
1706: */
1707: public String toString() {
1708: StringBuffer sb = new StringBuffer();
1709:
1710: // balanced or ordinary OSDL
1711: if (m_balanced) {
1712: sb.append("Balanced OSDL\n=============\n\n");
1713: } else {
1714: sb.append("Ordinary OSDL\n=============\n\n");
1715: }
1716:
1717: if (m_weighted) {
1718: sb.append("Weighted variant\n");
1719: }
1720:
1721: // classification type used
1722: sb.append("Classification type: " + getClassificationType()
1723: + "\n");
1724:
1725: // parameter s
1726: if (!m_weighted) {
1727: sb.append("Interpolation parameter: " + m_s + "\n");
1728: if (m_tuneInterpolationParameter) {
1729: sb.append("Bounds and stepsize: " + m_sLower + " "
1730: + m_sUpper + " " + m_sNrParts + "\n");
1731: if (!m_interpolationParameterValid) {
1732: sb.append("Interpolation parameter is not valid");
1733: }
1734: }
1735: }
1736:
1737: if (m_Debug) {
1738:
1739: if (m_estimatedCumulativeDistributions != null) {
1740: /*
* Cycle through the map of cumulative distribution functions
1742: * and print each cumulative distribution function
1743: */
1744: for (Iterator i = m_estimatedCumulativeDistributions
1745: .keySet().iterator(); i.hasNext();) {
1746: Coordinates yc = (Coordinates) i.next();
1747: CumulativeDiscreteDistribution cdf = (CumulativeDiscreteDistribution) m_estimatedCumulativeDistributions
1748: .get(yc);
1749: sb.append("[" + yc.hashCode() + "] "
1750: + yc.toString() + " --> " + cdf.toString()
1751: + "\n");
1752: }
1753: }
1754: }
1755: return sb.toString();
1756: }
1757: }