/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

/*
 * TLD.java
 * Copyright (C) 2005 University of Waikato, Hamilton, New Zealand
 *
 */

package weka.classifiers.mi;

import weka.classifiers.RandomizableClassifier;
import weka.core.Capabilities;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.MultiInstanceCapabilitiesHandler;
import weka.core.Optimization;
import weka.core.Option;
import weka.core.OptionHandler;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;
import weka.core.Utils;
import weka.core.Capabilities.Capability;
import weka.core.TechnicalInformation.Field;
import weka.core.TechnicalInformation.Type;

import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;

/**
 <!-- globalinfo-start -->
 * Two-Level Distribution approach: changes the starting value of the search algorithm, supplements the cut-off modification and checks missing values.<br/>
 * <br/>
 * For more information see:<br/>
 * <br/>
 * Xin Xu (2003). Statistical learning in multiple instance problem. Hamilton, NZ.
 * <p/>
 <!-- globalinfo-end -->
 *
 <!-- technical-bibtex-start -->
 * BibTeX:
 * <pre>
 * &#64;mastersthesis{Xu2003,
 *    address = {Hamilton, NZ},
 *    author = {Xin Xu},
 *    note = {0657.594},
 *    school = {University of Waikato},
 *    title = {Statistical learning in multiple instance problem},
 *    year = {2003}
 * }
 * </pre>
 * <p/>
 <!-- technical-bibtex-end -->
 *
 <!-- options-start -->
 * Valid options are: <p/>
 *
 * <pre> -C
 *  Set whether or not use empirical
 *  log-odds cut-off instead of 0</pre>
 *
 * <pre> -R &lt;numOfRuns&gt;
 *  Set the number of multiple runs
 *  needed for searching the MLE.</pre>
 *
 * <pre> -S &lt;num&gt;
 *  Random number seed.
 *  (default 1)</pre>
 *
 * <pre> -D
 *  If set, classifier is run in debug mode and
 *  may output additional info to the console</pre>
 *
 <!-- options-end -->
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @author Xin Xu (xx5@cs.waikato.ac.nz)
 * @version $Revision: 1.5 $
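 *
 * A minimal usage sketch (the dataset file name below is an illustrative
 * assumption; any multi-instance dataset with a relational bag attribute
 * and a binary class attribute will do):
 *
 * <pre>
 * Instances data = new Instances(
 *     new java.io.BufferedReader(new java.io.FileReader("musk1.arff")));
 * data.setClassIndex(data.numAttributes() - 1);
 * TLD tld = new TLD();
 * tld.setNumRuns(5);        // equivalent to -R 5
 * tld.setUsingCutOff(true); // equivalent to -C
 * tld.buildClassifier(data);
 * double label = tld.classifyInstance(data.instance(0));
 * </pre>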
 */
public class TLD extends RandomizableClassifier implements
  OptionHandler, MultiInstanceCapabilitiesHandler,
  TechnicalInformationHandler {

  /** for serialization */
  static final long serialVersionUID = 6657315525171152210L;

  /** The mean for each attribute of each positive exemplar */
  protected double[][] m_MeanP = null;

  /** The variance for each attribute of each positive exemplar */
  protected double[][] m_VarianceP = null;

  /** The mean for each attribute of each negative exemplar */
  protected double[][] m_MeanN = null;

  /** The variance for each attribute of each negative exemplar */
  protected double[][] m_VarianceN = null;

  /** The effective sum of weights of each positive exemplar in each dimension */
  protected double[][] m_SumP = null;

  /** The effective sum of weights of each negative exemplar in each dimension */
  protected double[][] m_SumN = null;

  /** The parameters to be estimated for each positive exemplar */
  protected double[] m_ParamsP = null;

  /** The parameters to be estimated for each negative exemplar */
  protected double[] m_ParamsN = null;

  /** The dimension of each exemplar, i.e. (numAttributes-2) */
  protected int m_Dimension = 0;

  /** The class label of each exemplar */
  protected double[] m_Class = null;

  /** The number of class labels in the data */
  protected int m_NumClasses = 2;

  /** The very small number representing zero */
  static public double ZERO = 1.0e-6;

  /** The number of runs to perform */
  protected int m_Run = 1;

  /** The cut-off log-odds value used for classification */
  protected double m_Cutoff;

  /** Whether to use the empirical cut-off instead of 0 */
  protected boolean m_UseEmpiricalCutOff = false;

  /**
   * Returns a string describing this classifier
   *
   * @return a description of the classifier suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String globalInfo() {
    return "Two-Level Distribution approach: changes the starting value of "
      + "the search algorithm, supplements the cut-off modification and "
      + "checks missing values.\n\n"
      + "For more information see:\n\n"
      + getTechnicalInformation().toString();
  }

  /**
   * Returns an instance of a TechnicalInformation object, containing
   * detailed information about the technical background of this class,
   * e.g., paper reference or book this class is based on.
   *
   * @return the technical information about this class
   */
  public TechnicalInformation getTechnicalInformation() {
    TechnicalInformation result;

    result = new TechnicalInformation(Type.MASTERSTHESIS);
    result.setValue(Field.AUTHOR, "Xin Xu");
    result.setValue(Field.YEAR, "2003");
    result.setValue(Field.TITLE,
      "Statistical learning in multiple instance problem");
    result.setValue(Field.SCHOOL, "University of Waikato");
    result.setValue(Field.ADDRESS, "Hamilton, NZ");
    result.setValue(Field.NOTE, "0657.594");

    return result;
  }

  /**
   * Returns default capabilities of the classifier.
   *
   * @return the capabilities of this classifier
   */
  public Capabilities getCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NOMINAL_ATTRIBUTES);
    result.enable(Capability.RELATIONAL_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.enable(Capability.BINARY_CLASS);
    result.enable(Capability.MISSING_CLASS_VALUES);

    // other
    result.enable(Capability.ONLY_MULTIINSTANCE);

    return result;
  }

  /**
   * Returns the capabilities of this multi-instance classifier for the
   * relational data.
   *
   * @return the capabilities of this object
   * @see Capabilities
   */
  public Capabilities getMultiInstanceCapabilities() {
    Capabilities result = super.getCapabilities();

    // attributes
    result.enable(Capability.NUMERIC_ATTRIBUTES);
    result.enable(Capability.MISSING_VALUES);

    // class
    result.disableAllClasses();
    result.enable(Capability.NO_CLASS);

    return result;
  }

  /**
   * Builds the classifier from the given multi-instance training data.
   *
   * @param exs the training exemplars
   * @throws Exception if the model cannot be built properly
   */
  public void buildClassifier(Instances exs) throws Exception {
    // can classifier handle the data?
    getCapabilities().testWithFail(exs);

    // remove instances with missing class
    exs = new Instances(exs);
    exs.deleteWithMissingClass();

    int numegs = exs.numInstances();
    m_Dimension = exs.attribute(1).relation().numAttributes();
    Instances pos = new Instances(exs, 0), neg = new Instances(exs, 0);

    for (int u = 0; u < numegs; u++) {
      Instance example = exs.instance(u);
      if (example.classValue() == 1)
        pos.add(example);
      else
        neg.add(example);
    }

    int pnum = pos.numInstances(), nnum = neg.numInstances();

    m_MeanP = new double[pnum][m_Dimension];
    m_VarianceP = new double[pnum][m_Dimension];
    m_SumP = new double[pnum][m_Dimension];
    m_MeanN = new double[nnum][m_Dimension];
    m_VarianceN = new double[nnum][m_Dimension];
    m_SumN = new double[nnum][m_Dimension];
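    // Four parameters (a, b, w, m) are estimated per dimension and stored
    // consecutively, so dimension x occupies entries 4*x .. 4*x+3.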
    m_ParamsP = new double[4 * m_Dimension];
    m_ParamsN = new double[4 * m_Dimension];

    // Estimation of the parameters: as the start value for search
    double[] pSumVal = new double[m_Dimension], // for m
      nSumVal = new double[m_Dimension];
    double[] maxVarsP = new double[m_Dimension], // for a
      maxVarsN = new double[m_Dimension];
    // Mean of sample variances: for b, b=a/E(\sigma^2)+2
    double[] varMeanP = new double[m_Dimension], varMeanN = new double[m_Dimension];
    // Variances of sample means: for w, w=E[var(\mu)]/E[\sigma^2]
    double[] meanVarP = new double[m_Dimension], meanVarN = new double[m_Dimension];
    // number of exemplars without all values missing
    double[] numExsP = new double[m_Dimension], numExsN = new double[m_Dimension];

    // Extract metadata from both positive and negative bags
    for (int v = 0; v < pnum; v++) {
      /*Exemplar px = pos.exemplar(v);
        m_MeanP[v] = px.meanOrMode();
        m_VarianceP[v] = px.variance();
        Instances pxi = px.getInstances();
       */

      Instances pxi = pos.instance(v).relationalValue(1);
      for (int k = 0; k < pxi.numAttributes(); k++) {
        m_MeanP[v][k] = pxi.meanOrMode(k);
        m_VarianceP[v][k] = pxi.variance(k);
      }

      for (int w = 0, t = 0; w < m_Dimension; w++, t++) {
        //if((t==m_ClassIndex) || (t==m_IdIndex))
        //  t++;

        if (!Double.isNaN(m_MeanP[v][w])) {
          for (int u = 0; u < pxi.numInstances(); u++) {
            Instance ins = pxi.instance(u);
            if (!ins.isMissing(t))
              m_SumP[v][w] += ins.weight();
          }
          numExsP[w]++;
          pSumVal[w] += m_MeanP[v][w];
          meanVarP[w] += m_MeanP[v][w] * m_MeanP[v][w];
          if (maxVarsP[w] < m_VarianceP[v][w])
            maxVarsP[w] = m_VarianceP[v][w];
          varMeanP[w] += m_VarianceP[v][w];
          m_VarianceP[v][w] *= (m_SumP[v][w] - 1.0);
          if (m_VarianceP[v][w] < 0.0)
            m_VarianceP[v][w] = 0.0;
        }
      }
    }

    for (int v = 0; v < nnum; v++) {
      /*Exemplar nx = neg.exemplar(v);
        m_MeanN[v] = nx.meanOrMode();
        m_VarianceN[v] = nx.variance();
        Instances nxi = nx.getInstances();
       */
      Instances nxi = neg.instance(v).relationalValue(1);
      for (int k = 0; k < nxi.numAttributes(); k++) {
        m_MeanN[v][k] = nxi.meanOrMode(k);
        m_VarianceN[v][k] = nxi.variance(k);
      }

      for (int w = 0, t = 0; w < m_Dimension; w++, t++) {
        //if((t==m_ClassIndex) || (t==m_IdIndex))
        //  t++;

        if (!Double.isNaN(m_MeanN[v][w])) {
          for (int u = 0; u < nxi.numInstances(); u++)
            if (!nxi.instance(u).isMissing(t))
              m_SumN[v][w] += nxi.instance(u).weight();
          numExsN[w]++;
          nSumVal[w] += m_MeanN[v][w];
          meanVarN[w] += m_MeanN[v][w] * m_MeanN[v][w];
          if (maxVarsN[w] < m_VarianceN[v][w])
            maxVarsN[w] = m_VarianceN[v][w];
          varMeanN[w] += m_VarianceN[v][w];
          m_VarianceN[v][w] *= (m_SumN[v][w] - 1.0);
          if (m_VarianceN[v][w] < 0.0)
            m_VarianceN[v][w] = 0.0;
        }
      }
    }

    for (int w = 0; w < m_Dimension; w++) {
      pSumVal[w] /= numExsP[w];
      nSumVal[w] /= numExsN[w];
      // sample variance of the bag means (pSumVal/nSumVal hold the mean here)
      if (numExsP[w] > 1)
        meanVarP[w] = meanVarP[w] / (numExsP[w] - 1.0)
          - pSumVal[w] * pSumVal[w] * numExsP[w] / (numExsP[w] - 1.0);
      if (numExsN[w] > 1)
        meanVarN[w] = meanVarN[w] / (numExsN[w] - 1.0)
          - nSumVal[w] * nSumVal[w] * numExsN[w] / (numExsN[w] - 1.0);
      varMeanP[w] /= numExsP[w];
      varMeanN[w] /= numExsN[w];
    }

    //Bounds and parameter values for each run
    double[][] bounds = new double[2][4];
    double[] pThisParam = new double[4], nThisParam = new double[4];

    // Initial values for parameters
    double a, b, w, m;

    // Optimize for one dimension
    for (int x = 0; x < m_Dimension; x++) {
      if (getDebug())
        System.err.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Dimension #" + x);

      // Positive exemplars: first run
      a = (maxVarsP[x] > ZERO) ? maxVarsP[x] : 1.0;
      if (varMeanP[x] <= ZERO)
        varMeanP[x] = ZERO; // modified by LinDong (09/2005)
      b = a / varMeanP[x] + 2.0; // a/(b-2) = E(\sigma^2)
      w = meanVarP[x] / varMeanP[x]; // E[var(\mu)] = w*E[\sigma^2]
      if (w <= ZERO)
        w = 1.0;

      m = pSumVal[x];
      pThisParam[0] = a; // a
      pThisParam[1] = b; // b
      pThisParam[2] = w; // w
      pThisParam[3] = m; // m

      // Negative exemplars: first run
      a = (maxVarsN[x] > ZERO) ? maxVarsN[x] : 1.0;
      if (varMeanN[x] <= ZERO)
        varMeanN[x] = ZERO; // modified by LinDong (09/2005)
      b = a / varMeanN[x] + 2.0; // a/(b-2) = E(\sigma^2)
      w = meanVarN[x] / varMeanN[x]; // E[var(\mu)] = w*E[\sigma^2]
      if (w <= ZERO)
        w = 1.0;

      m = nSumVal[x];
      nThisParam[0] = a; // a
      nThisParam[1] = b; // b
      nThisParam[2] = w; // w
      nThisParam[3] = m; // m

      // Bound constraints
      bounds[0][0] = ZERO; // a > 0
      bounds[0][1] = 2.0 + ZERO; // b > 2
      bounds[0][2] = ZERO; // w > 0
      bounds[0][3] = Double.NaN; // m is unconstrained (NaN = no bound)

      for (int t = 0; t < 4; t++) {
        bounds[1][t] = Double.NaN; // no upper bounds
        m_ParamsP[4 * x + t] = pThisParam[t];
        m_ParamsN[4 * x + t] = nThisParam[t];
      }
      double pminVal = Double.MAX_VALUE, nminVal = Double.MAX_VALUE;
      Random whichEx = new Random(m_Seed);
      TLD_Optm pOp = null, nOp = null;
      boolean isRunValid = true;
      double[] sumP = new double[pnum], meanP = new double[pnum], varP = new double[pnum];
      double[] sumN = new double[nnum], meanN = new double[nnum], varN = new double[nnum];

      // One dimension
      for (int p = 0; p < pnum; p++) {
        sumP[p] = m_SumP[p][x];
        meanP[p] = m_MeanP[p][x];
        varP[p] = m_VarianceP[p][x];
      }
      for (int q = 0; q < nnum; q++) {
        sumN[q] = m_SumN[q][x];
        meanN[q] = m_MeanN[q][x];
        varN[q] = m_VarianceN[q][x];
      }

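      // Multi-restart search for the MLE: keep the parameters with the
      // smallest negative log-likelihood over m_Run random restarts.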
      for (int y = 0; y < m_Run;) {
        if (getDebug())
          System.err.println("\n\n!!!!!!!!!!!!!!!!!!!!!!???Run #" + y);
        double thisMin;

        if (getDebug())
          System.err.println("\nPositive exemplars");
        pOp = new TLD_Optm();
        pOp.setNum(sumP);
        pOp.setSSquare(varP);
        pOp.setXBar(meanP);

        pThisParam = pOp.findArgmin(pThisParam, bounds);
        while (pThisParam == null) {
          pThisParam = pOp.getVarbValues();
          if (getDebug())
            System.err.println("!!! 200 iterations finished, not enough!");
          pThisParam = pOp.findArgmin(pThisParam, bounds);
        }

        thisMin = pOp.getMinFunction();
        if (!Double.isNaN(thisMin) && (thisMin < pminVal)) {
          pminVal = thisMin;
          for (int z = 0; z < 4; z++)
            m_ParamsP[4 * x + z] = pThisParam[z];
        }

        if (Double.isNaN(thisMin)) {
          pThisParam = new double[4];
          isRunValid = false;
        }

        if (getDebug())
          System.err.println("\nNegative exemplars");
        nOp = new TLD_Optm();
        nOp.setNum(sumN);
        nOp.setSSquare(varN);
        nOp.setXBar(meanN);

        nThisParam = nOp.findArgmin(nThisParam, bounds);
        while (nThisParam == null) {
          nThisParam = nOp.getVarbValues();
          if (getDebug())
            System.err.println("!!! 200 iterations finished, not enough!");
          nThisParam = nOp.findArgmin(nThisParam, bounds);
        }
        thisMin = nOp.getMinFunction();
        if (!Double.isNaN(thisMin) && (thisMin < nminVal)) {
          nminVal = thisMin;
          for (int z = 0; z < 4; z++)
            m_ParamsN[4 * x + z] = nThisParam[z];
        }

        if (Double.isNaN(thisMin)) {
          nThisParam = new double[4];
          isRunValid = false;
        }

        if (!isRunValid) {
          y--;
          isRunValid = true;
        }

        if (++y < m_Run) {
          // Change the initial parameters and restart
          int pone = whichEx.nextInt(pnum), // Randomly pick one pos. exmpl.
            none = whichEx.nextInt(nnum);

          // Positive exemplars: next run
          while ((m_SumP[pone][x] <= 1.0)
            || Double.isNaN(m_MeanP[pone][x]))
            pone = whichEx.nextInt(pnum);

          a = m_VarianceP[pone][x] / (m_SumP[pone][x] - 1.0);
          if (a <= ZERO)
            a = m_ParamsN[4 * x]; // Change to negative params
          m = m_MeanP[pone][x];
          double sq = (m - m_ParamsP[4 * x + 3])
            * (m - m_ParamsP[4 * x + 3]);

          b = a * m_ParamsP[4 * x + 2] / sq + 2.0; // b=a/Var+2, assuming Var=Sq/w'
          if ((b <= ZERO) || Double.isNaN(b) || Double.isInfinite(b))
            b = m_ParamsN[4 * x + 1];

          w = sq * (m_ParamsP[4 * x + 1] - 2.0)
            / m_ParamsP[4 * x]; // w=Sq/Var, assuming Var=a'/(b'-2)
          if ((w <= ZERO) || Double.isNaN(w) || Double.isInfinite(w))
            w = m_ParamsN[4 * x + 2];

          pThisParam[0] = a; // a
          pThisParam[1] = b; // b
          pThisParam[2] = w; // w
          pThisParam[3] = m; // m

          // Negative exemplars: next run
          while ((m_SumN[none][x] <= 1.0)
            || Double.isNaN(m_MeanN[none][x]))
            none = whichEx.nextInt(nnum);

          a = m_VarianceN[none][x] / (m_SumN[none][x] - 1.0);
          if (a <= ZERO)
            a = m_ParamsP[4 * x];
          m = m_MeanN[none][x];
          sq = (m - m_ParamsN[4 * x + 3])
            * (m - m_ParamsN[4 * x + 3]);

          b = a * m_ParamsN[4 * x + 2] / sq + 2.0; // b=a/Var+2, assuming Var=Sq/w'
          if ((b <= ZERO) || Double.isNaN(b) || Double.isInfinite(b))
            b = m_ParamsP[4 * x + 1];

          w = sq * (m_ParamsN[4 * x + 1] - 2.0)
            / m_ParamsN[4 * x]; // w=Sq/Var, assuming Var=a'/(b'-2)
          if ((w <= ZERO) || Double.isNaN(w) || Double.isInfinite(w))
            w = m_ParamsP[4 * x + 2];

          nThisParam[0] = a; // a
          nThisParam[1] = b; // b
          nThisParam[2] = w; // w
          nThisParam[3] = m; // m
        }
      }
    }

    for (int x = 0, y = 0; x < m_Dimension; x++, y++) {
      //if((x==exs.classIndex()) || (x==exs.idIndex()))
      //y++;
      a = m_ParamsP[4 * x];
      b = m_ParamsP[4 * x + 1];
      w = m_ParamsP[4 * x + 2];
      m = m_ParamsP[4 * x + 3];
      if (getDebug())
        System.err.println("\n\n???Positive: ( "
          + exs.attribute(1).relation().attribute(y)
          + "): a=" + a + ", b=" + b + ", w=" + w + ", m=" + m);

      a = m_ParamsN[4 * x];
      b = m_ParamsN[4 * x + 1];
      w = m_ParamsN[4 * x + 2];
      m = m_ParamsN[4 * x + 3];
      if (getDebug())
        System.err.println("???Negative: ("
          + exs.attribute(1).relation().attribute(y)
          + "): a=" + a + ", b=" + b + ", w=" + w + ", m=" + m);
    }

    if (m_UseEmpiricalCutOff) {
      // Find the empirical cut-off
      double[] pLogOdds = new double[pnum], nLogOdds = new double[nnum];
      for (int p = 0; p < pnum; p++)
        pLogOdds[p] = likelihoodRatio(m_SumP[p], m_MeanP[p], m_VarianceP[p]);

      for (int q = 0; q < nnum; q++)
        nLogOdds[q] = likelihoodRatio(m_SumN[q], m_MeanN[q], m_VarianceN[q]);

      // Update m_Cutoff
      findCutOff(pLogOdds, nLogOdds);
    } else
      m_Cutoff = -Math.log((double) pnum / (double) nnum); // log prior odds of the negative class

    if (getDebug())
      System.err.println("???Cut-off=" + m_Cutoff);
  }

  /**
   * Computes the classification of the given test exemplar.
   *
   * @param ex the given test exemplar
   * @return the classification
   * @throws Exception if the exemplar could not be classified
   *           successfully
   */
  public double classifyInstance(Instance ex) throws Exception {
    //Exemplar ex = new Exemplar(e);
    Instances exi = ex.relationalValue(1);
    double[] n = new double[m_Dimension];
    double[] xBar = new double[m_Dimension];
    double[] sSq = new double[m_Dimension];
    for (int i = 0; i < exi.numAttributes(); i++) {
      xBar[i] = exi.meanOrMode(i);
      sSq[i] = exi.variance(i);
    }

    for (int w = 0, t = 0; w < m_Dimension; w++, t++) {
      //if((t==m_ClassIndex) || (t==m_IdIndex))
      //t++;
      for (int u = 0; u < exi.numInstances(); u++)
        if (!exi.instance(u).isMissing(t))
          n[w] += exi.instance(u).weight();

      sSq[w] = sSq[w] * (n[w] - 1.0);
      if (sSq[w] <= 0.0)
        sSq[w] = 0.0;
    }

    double logOdds = likelihoodRatio(n, xBar, sSq);
    return (logOdds > m_Cutoff) ? 1 : 0;
  }

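  /**
   * Computes the log-odds (log-likelihood ratio) of an exemplar under the
   * estimated positive and negative two-level distributions.
   *
   * @param n the effective number of observations in each dimension
   * @param xBar the sample mean in each dimension
   * @param sSq the summed squared deviation from the mean in each dimension
   * @return the log-likelihood ratio LLP - LLN
   */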
  private double likelihoodRatio(double[] n, double[] xBar, double[] sSq) {
    double LLP = 0.0, LLN = 0.0;

    for (int x = 0; x < m_Dimension; x++) {
      if (Double.isNaN(xBar[x]))
        continue; // All missing values

      int halfN = ((int) n[x]) / 2;
      //Log-likelihood for positive
      double a = m_ParamsP[4 * x], b = m_ParamsP[4 * x + 1],
        w = m_ParamsP[4 * x + 2], m = m_ParamsP[4 * x + 3];
      LLP += 0.5 * b * Math.log(a)
        + 0.5 * (b + n[x] - 1.0) * Math.log(1.0 + n[x] * w)
        - 0.5 * (b + n[x])
          * Math.log((1.0 + n[x] * w) * (a + sSq[x])
            + n[x] * (xBar[x] - m) * (xBar[x] - m))
        - 0.5 * n[x] * Math.log(Math.PI);
      for (int y = 1; y <= halfN; y++)
        LLP += Math.log(b / 2.0 + n[x] / 2.0 - (double) y);

      if (n[x] / 2.0 > halfN) // n is odd
        LLP += TLD_Optm.diffLnGamma(b / 2.0);

      //Log-likelihood for negative
      a = m_ParamsN[4 * x];
      b = m_ParamsN[4 * x + 1];
      w = m_ParamsN[4 * x + 2];
      m = m_ParamsN[4 * x + 3];
      LLN += 0.5 * b * Math.log(a)
        + 0.5 * (b + n[x] - 1.0) * Math.log(1.0 + n[x] * w)
        - 0.5 * (b + n[x])
          * Math.log((1.0 + n[x] * w) * (a + sSq[x])
            + n[x] * (xBar[x] - m) * (xBar[x] - m))
        - 0.5 * n[x] * Math.log(Math.PI);
      for (int y = 1; y <= halfN; y++)
        LLN += Math.log(b / 2.0 + n[x] / 2.0 - (double) y);

      if (n[x] / 2.0 > halfN) // n is odd
        LLN += TLD_Optm.diffLnGamma(b / 2.0);
    }

    return LLP - LLN;
  }

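  /**
   * Finds the empirical cut-off: the split on the training log-odds values
   * that maximizes accuracy, with ties broken in favour of the split closest
   * to zero. The result is stored in m_Cutoff.
   *
   * @param pos the log-odds of the positive exemplars
   * @param neg the log-odds of the negative exemplars
   */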
  private void findCutOff(double[] pos, double[] neg) {
    int[] pOrder = Utils.sort(pos), nOrder = Utils.sort(neg);
    /*
       System.err.println("\n\n???Positive: ");
       for(int t=0; t<pOrder.length; t++)
         System.err.print(t+":"+Utils.doubleToString(pos[pOrder[t]],0,2)+" ");
       System.err.println("\n\n???Negative: ");
       for(int t=0; t<nOrder.length; t++)
         System.err.print(t+":"+Utils.doubleToString(neg[nOrder[t]],0,2)+" ");
     */
    int pNum = pos.length, nNum = neg.length, count, p = 0, n = 0;
    double fstAccu = 0.0, sndAccu = (double) pNum, split;
    double maxAccu = 0, minDistTo0 = Double.MAX_VALUE;

    // Skip continuous negatives
    for (; (n < nNum) && (pos[pOrder[0]] >= neg[nOrder[n]]); n++, fstAccu++)
      ;

    if (n >= nNum) { // totally separate
      m_Cutoff = (neg[nOrder[nNum - 1]] + pos[pOrder[0]]) / 2.0;
      //m_Cutoff = neg[nOrder[nNum-1]];
      return;
    }

    count = n;
    while ((p < pNum) && (n < nNum)) {
      // Compare the next in the two lists
      if (pos[pOrder[p]] >= neg[nOrder[n]]) { // Neg has less log-odds
        fstAccu += 1.0;
        split = neg[nOrder[n]];
        n++;
      } else {
        sndAccu -= 1.0;
        split = pos[pOrder[p]];
        p++;
      }
      count++;
      if ((fstAccu + sndAccu > maxAccu)
        || ((fstAccu + sndAccu == maxAccu) && (Math.abs(split) < minDistTo0))) {
        maxAccu = fstAccu + sndAccu;
        m_Cutoff = split;
        minDistTo0 = Math.abs(split);
      }
    }
  }

  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  public Enumeration listOptions() {
    Vector result = new Vector();

    result.addElement(new Option(
      "\tSet whether or not use empirical\n"
        + "\tlog-odds cut-off instead of 0", "C", 0, "-C"));

    result.addElement(new Option(
      "\tSet the number of multiple runs \n"
        + "\tneeded for searching the MLE.", "R", 1, "-R <numOfRuns>"));

    Enumeration enu = super.listOptions();
    while (enu.hasMoreElements()) {
      result.addElement(enu.nextElement());
    }

    return result.elements();
  }

  /**
   * Parses a given list of options. <p/>
   *
   <!-- options-start -->
   * Valid options are: <p/>
   *
   * <pre> -C
   *  Set whether or not use empirical
   *  log-odds cut-off instead of 0</pre>
   *
   * <pre> -R &lt;numOfRuns&gt;
   *  Set the number of multiple runs
   *  needed for searching the MLE.</pre>
   *
   * <pre> -S &lt;num&gt;
   *  Random number seed.
   *  (default 1)</pre>
   *
   * <pre> -D
   *  If set, classifier is run in debug mode and
   *  may output additional info to the console</pre>
   *
   <!-- options-end -->
   *
   * @param options the list of options as an array of strings
   * @throws Exception if an option is not supported
   */
  public void setOptions(String[] options) throws Exception {
    setDebug(Utils.getFlag('D', options));

    setUsingCutOff(Utils.getFlag('C', options));

    String runString = Utils.getOption('R', options);
    if (runString.length() != 0)
      setNumRuns(Integer.parseInt(runString));
    else
      setNumRuns(1);

    super.setOptions(options);
  }

  /**
   * Gets the current settings of the Classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  public String[] getOptions() {
    Vector result;
    String[] options;
    int i;

    result = new Vector();
    options = super.getOptions();
    for (i = 0; i < options.length; i++)
      result.add(options[i]);

    if (getDebug())
      result.add("-D");

    if (getUsingCutOff())
      result.add("-C");

    result.add("-R");
    result.add("" + getNumRuns());

    return (String[]) result.toArray(new String[result.size()]);
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String numRunsTipText() {
    return "The number of runs to perform.";
  }

  /**
   * Sets the number of runs to perform.
   *
   * @param numRuns the number of runs to perform
   */
  public void setNumRuns(int numRuns) {
    m_Run = numRuns;
  }

  /**
   * Returns the number of runs to perform.
   *
   * @return the number of runs to perform
   */
  public int getNumRuns() {
    return m_Run;
  }

  /**
   * Returns the tip text for this property
   *
   * @return tip text for this property suitable for
   *         displaying in the explorer/experimenter gui
   */
  public String usingCutOffTipText() {
    return "Whether to use an empirical cutoff.";
  }

  /**
   * Sets whether to use an empirical cutoff.
   *
   * @param cutOff whether to use an empirical cutoff
   */
  public void setUsingCutOff(boolean cutOff) {
    m_UseEmpiricalCutOff = cutOff;
  }

  /**
   * Returns whether an empirical cutoff is used
   *
   * @return true if an empirical cutoff is used
   */
  public boolean getUsingCutOff() {
    return m_UseEmpiricalCutOff;
  }

  /**
   * Main method for testing.
   *
   * @param args the options for the classifier
   */
  public static void main(String[] args) {
    runClassifier(new TLD(), args);
  }
}

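/**
 * Helper class that minimizes the negative log-likelihood of the two-level
 * distribution with respect to the parameters (a, b, w, m) of a single
 * dimension, using the search implemented in weka.core.Optimization.
 */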
class TLD_Optm extends Optimization {

  /** the effective number of observations (summed weights) of each exemplar */
  private double[] num;

  /** the summed squared deviation from the mean of each exemplar */
  private double[] sSq;

  /** the sample mean of each exemplar */
  private double[] xBar;

  public void setNum(double[] n) {
    num = n;
  }

  public void setSSquare(double[] s) {
    sSq = s;
  }

  public void setXBar(double[] x) {
    xBar = x;
  }

  /**
   * Compute Ln[Gamma(b+0.5)] - Ln[Gamma(b)]
   *
   * @param b the value in the above formula
   * @return the result
   */
  public static double diffLnGamma(double b) {
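    // Both terms use the Lanczos series approximation of ln Gamma; only
    // the difference of the two expansions is accumulated below.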
    double[] coef = { 76.18009172947146, -86.50532032941677,
      24.01409824083091, -1.231739572450155,
      0.1208650973866179e-2, -0.5395239384953e-5 };
    double rt = -0.5;
    rt += (b + 1.0) * Math.log(b + 6.0) - (b + 0.5) * Math.log(b + 5.5);
    double series1 = 1.000000000190015, series2 = 1.000000000190015;
    for (int i = 0; i < 6; i++) {
      series1 += coef[i] / (b + 1.5 + (double) i);
      series2 += coef[i] / (b + 1.0 + (double) i);
    }

    rt += Math.log(series1 * b) - Math.log(series2 * (b + 0.5));
    return rt;
  }

  /**
   * Compute dLn[Gamma(x+0.5)]/dx - dLn[Gamma(x)]/dx
   *
   * @param x the value in the above formula
   * @return the result
   */
  protected double diffFstDervLnGamma(double x) {
    double rt = 0, series = 1.0; // Just make it >0
    for (int i = 0; series >= m_Zero * 1e-3; i++) {
      series = 0.5 / ((x + (double) i) * (x + (double) i + 0.5));
      rt += series;
    }
    return rt;
  }

  /**
   * Compute {Ln[Gamma(x+0.5)]}'' - {Ln[Gamma(x)]}''
   *
   * @param x the value in the above formula
   * @return the result
   */
  protected double diffSndDervLnGamma(double x) {
    double rt = 0, series = 1.0; // Just make it >0
    for (int i = 0; series >= m_Zero * 1e-3; i++) {
      series = (x + (double) i + 0.25)
        / ((x + (double) i) * (x + (double) i)
          * (x + (double) i + 0.5) * (x + (double) i + 0.5));
      rt -= series;
    }
    return rt;
  }

  /**
   * Implement this procedure to evaluate objective
   * function to be minimized
   */
  protected double objectiveFunction(double[] x) {
    int numExs = num.length;
    double NLL = 0; // Negative Log-Likelihood

    double a = x[0], b = x[1], w = x[2], m = x[3];
    for (int j = 0; j < numExs; j++) {

      if (Double.isNaN(xBar[j]))
        continue; // All missing values

      NLL += 0.5
        * (b + num[j])
        * Math.log((1.0 + num[j] * w) * (a + sSq[j])
          + num[j] * (xBar[j] - m) * (xBar[j] - m));

      if (Double.isNaN(NLL) && m_Debug) {
        System.err.println("???????????1: " + a + " " + b + " "
          + w + " " + m + "|x-: " + xBar[j] + "|n: "
          + num[j] + "|S^2: " + sSq[j]);
        System.exit(1);
      }

      // Doesn't affect optimization
      //NLL += 0.5*num[j]*Math.log(Math.PI);

      NLL -= 0.5 * (b + num[j] - 1.0) * Math.log(1.0 + num[j] * w);

      if (Double.isNaN(NLL) && m_Debug) {
        System.err.println("???????????2: " + a + " " + b + " "
          + w + " " + m + "|x-: " + xBar[j] + "|n: "
          + num[j] + "|S^2: " + sSq[j]);
        System.exit(1);
      }

      int halfNum = ((int) num[j]) / 2;
      for (int z = 1; z <= halfNum; z++)
        NLL -= Math.log(0.5 * b + 0.5 * num[j] - (double) z);

      if (0.5 * num[j] > halfNum) // num[j] is odd
        NLL -= diffLnGamma(0.5 * b);

      if (Double.isNaN(NLL) && m_Debug) {
        System.err.println("???????????3: " + a + " " + b + " "
          + w + " " + m + "|x-: " + xBar[j] + "|n: "
          + num[j] + "|S^2: " + sSq[j]);
        System.exit(1);
      }

      NLL -= 0.5 * Math.log(a) * b;
      if (Double.isNaN(NLL) && m_Debug) {
        System.err.println("???????????4:" + a + " " + b + " " + w + " " + m);
        System.exit(1);
      }
    }
    if (m_Debug)
      System.err.println("?????????????5: " + NLL);
    if (Double.isNaN(NLL))
      System.exit(1);

    return NLL;
  }

  /**
   * Subclass should implement this procedure to evaluate gradient
   * of the objective function
   */
  protected double[] evaluateGradient(double[] x) {
    double[] g = new double[x.length];
    int numExs = num.length;

    double a = x[0], b = x[1], w = x[2], m = x[3];

    double da = 0.0, db = 0.0, dw = 0.0, dm = 0.0;
    for (int j = 0; j < numExs; j++) {

      if (Double.isNaN(xBar[j]))
        continue; // All missing values

      double denorm = (1.0 + num[j] * w) * (a + sSq[j])
        + num[j] * (xBar[j] - m) * (xBar[j] - m);

      da += 0.5 * (b + num[j]) * (1.0 + num[j] * w) / denorm
        - 0.5 * b / a;

      db += 0.5 * Math.log(denorm)
        - 0.5 * Math.log(1.0 + num[j] * w) - 0.5 * Math.log(a);

      int halfNum = ((int) num[j]) / 2;
      for (int z = 1; z <= halfNum; z++)
        db -= 1.0 / (b + num[j] - 2.0 * (double) z);
      if (num[j] / 2.0 > halfNum) // num[j] is odd
        db -= 0.5 * diffFstDervLnGamma(0.5 * b);

      dw += 0.5 * (b + num[j]) * (a + sSq[j]) * num[j] / denorm
        - 0.5 * (b + num[j] - 1.0) * num[j] / (1.0 + num[j] * w);

      dm += num[j] * (b + num[j]) * (m - xBar[j]) / denorm;
    }

    g[0] = da;
    g[1] = db;
    g[2] = dw;
    g[3] = dm;
    return g;
  }

  /**
   * Subclass should implement this procedure to evaluate the second-order
   * gradient of the objective function: the row of the Hessian that
   * corresponds to the variable given by 'index' is returned.
   */
  protected double[] evaluateHessian(double[] x, int index) {
    double[] h = new double[x.length];

    // # of exemplars, # of dimensions
    // which dimension and which variable for 'index'
    int numExs = num.length;
    double a, b, w, m;
    // Take the 2nd-order derivative
    switch (index) {
    case 0: // a
      a = x[0];
      b = x[1];
      w = x[2];
      m = x[3];

      for (int j = 0; j < numExs; j++) {
        if (Double.isNaN(xBar[j]))
          continue; //All missing values
        double denorm = (1.0 + num[j] * w) * (a + sSq[j])
          + num[j] * (xBar[j] - m) * (xBar[j] - m);

        h[0] += 0.5 * b / (a * a) - 0.5 * (b + num[j])
          * (1.0 + num[j] * w) * (1.0 + num[j] * w)
          / (denorm * denorm);

        h[1] += 0.5 * (1.0 + num[j] * w) / denorm - 0.5 / a;

        h[2] += 0.5 * num[j] * num[j] * (b + num[j])
          * (xBar[j] - m) * (xBar[j] - m)
          / (denorm * denorm);

        h[3] -= num[j] * (b + num[j]) * (m - xBar[j])
          * (1.0 + num[j] * w) / (denorm * denorm);
      }
      break;

    case 1: // b
      a = x[0];
      b = x[1];
      w = x[2];
      m = x[3];

      for (int j = 0; j < numExs; j++) {
        if (Double.isNaN(xBar[j]))
          continue; //All missing values
        double denorm = (1.0 + num[j] * w) * (a + sSq[j])
          + num[j] * (xBar[j] - m) * (xBar[j] - m);

        h[0] += 0.5 * (1.0 + num[j] * w) / denorm - 0.5 / a;

        int halfNum = ((int) num[j]) / 2;
        for (int z = 1; z <= halfNum; z++)
          h[1] += 1.0 / ((b + num[j] - 2.0 * (double) z)
            * (b + num[j] - 2.0 * (double) z));
        if (num[j] / 2.0 > halfNum) // num[j] is odd
          h[1] -= 0.25 * diffSndDervLnGamma(0.5 * b);

        h[2] += 0.5 * (a + sSq[j]) * num[j] / denorm - 0.5
          * num[j] / (1.0 + num[j] * w);

        h[3] += num[j] * (m - xBar[j]) / denorm;
      }
      break;

    case 2: // w
      a = x[0];
      b = x[1];
      w = x[2];
      m = x[3];

      for (int j = 0; j < numExs; j++) {
        if (Double.isNaN(xBar[j]))
          continue; //All missing values
        double denorm = (1.0 + num[j] * w) * (a + sSq[j])
          + num[j] * (xBar[j] - m) * (xBar[j] - m);

        h[0] += 0.5 * num[j] * num[j] * (b + num[j])
          * (xBar[j] - m) * (xBar[j] - m)
          / (denorm * denorm);

        h[1] += 0.5 * (a + sSq[j]) * num[j] / denorm - 0.5
          * num[j] / (1.0 + num[j] * w);

        h[2] += 0.5 * (b + num[j] - 1.0) * num[j] * num[j]
          / ((1.0 + num[j] * w) * (1.0 + num[j] * w))
          - 0.5 * (b + num[j]) * (a + sSq[j])
          * (a + sSq[j]) * num[j] * num[j]
          / (denorm * denorm);

        h[3] -= num[j] * num[j] * (b + num[j]) * (m - xBar[j])
          * (a + sSq[j]) / (denorm * denorm);
      }
      break;

    case 3: // m
      a = x[0];
      b = x[1];
      w = x[2];
      m = x[3];

      for (int j = 0; j < numExs; j++) {
        if (Double.isNaN(xBar[j]))
          continue; //All missing values
        double denorm = (1.0 + num[j] * w) * (a + sSq[j])
          + num[j] * (xBar[j] - m) * (xBar[j] - m);

        h[0] -= num[j] * (b + num[j]) * (m - xBar[j])
          * (1.0 + num[j] * w) / (denorm * denorm);

        h[1] += num[j] * (m - xBar[j]) / denorm;

        h[2] -= num[j] * num[j] * (b + num[j]) * (m - xBar[j])
          * (a + sSq[j]) / (denorm * denorm);

        h[3] += num[j]
          * (b + num[j])
          * ((1.0 + num[j] * w) * (a + sSq[j]) - num[j]
            * (m - xBar[j]) * (m - xBar[j]))
          / (denorm * denorm);
      }
    }

    return h;
  }
}
|