001: package net.sf.saxon.functions;
002:
003: import net.sf.saxon.expr.*;
004: import net.sf.saxon.om.Item;
005: import net.sf.saxon.om.SequenceIterator;
006: import net.sf.saxon.trans.DynamicError;
007: import net.sf.saxon.trans.XPathException;
008: import net.sf.saxon.type.ItemType;
009: import net.sf.saxon.type.Type;
010: import net.sf.saxon.type.TypeHierarchy;
011: import net.sf.saxon.value.*;
012:
013: /**
014: * This class implements the sum(), avg(), count() functions,
015: */
016:
017: public class Aggregate extends SystemFunction {
018:
019: public static final int SUM = 0;
020: public static final int AVG = 1;
021: public static final int COUNT = 4;
022:
023: /**
024: * Static analysis: prevent sorting of the argument
025: */
026:
027: public void checkArguments(StaticContext env) throws XPathException {
028: super .checkArguments(env);
029: Optimizer opt = env.getConfiguration().getOptimizer();
030: argument[0] = ExpressionTool.unsorted(opt, argument[0], true);
031: }
032:
033: /**
034: * Determine the item type of the value returned by the function
035: * @param th
036: */
037:
038: public ItemType getItemType(TypeHierarchy th) {
039: switch (operation) {
040: case COUNT:
041: return super .getItemType(th);
042: case SUM: {
043: //ItemType base = argument[0].getItemType();
044: ItemType base = Atomizer.getAtomizedItemType(argument[0],
045: false, th);
046: if (base == Type.UNTYPED_ATOMIC_TYPE) {
047: base = Type.DOUBLE_TYPE;
048: }
049: if (Cardinality.allowsZero(argument[0].getCardinality())) {
050: if (argument.length == 1) {
051: return Type.getCommonSuperType(base,
052: Type.INTEGER_TYPE, th);
053: } else {
054: return Type.getCommonSuperType(base, argument[1]
055: .getItemType(th), th);
056: }
057: } else {
058: return base;
059: }
060: }
061: case AVG: {
062: ItemType base = Atomizer.getAtomizedItemType(argument[0],
063: false, th);
064: if (base == Type.UNTYPED_ATOMIC_TYPE) {
065: return Type.DOUBLE_TYPE;
066: } else if (base.getPrimitiveType() == Type.INTEGER) {
067: return Type.DECIMAL_TYPE;
068: } else {
069: return base;
070: }
071: }
072: default:
073: throw new AssertionError("Unknown aggregate operation");
074: }
075: }
076:
077: /**
078: * Evaluate the function
079: */
080:
081: public Item evaluateItem(XPathContext context)
082: throws XPathException {
083: // Note: these functions do not need to sort the underlying sequence,
084: // but they do need to de-duplicate it
085: switch (operation) {
086: case COUNT:
087: SequenceIterator iter = argument[0].iterate(context);
088: return new IntegerValue(count(iter));
089: case SUM:
090: return total(argument[0].iterate(context), context);
091: case AVG:
092: return average(argument[0].iterate(context), context);
093: default:
094: throw new UnsupportedOperationException(
095: "Unknown aggregate function");
096: }
097: }
098:
099: /**
100: * Calculate total
101: */
102:
103: private AtomicValue total(SequenceIterator iter,
104: XPathContext context) throws XPathException {
105: AtomicValue sum = (AtomicValue) iter.next();
106: if (sum == null) {
107: // the sequence is empty
108: if (argument.length == 2) {
109: return (AtomicValue) argument[1].evaluateItem(context);
110: } else {
111: return IntegerValue.ZERO;
112: }
113: }
114: if (!sum.hasBuiltInType()) {
115: sum = sum.getPrimitiveValue();
116: }
117: if (sum instanceof UntypedAtomicValue) {
118: sum = sum.convert(Type.DOUBLE, context);
119: }
120: if (sum instanceof NumericValue) {
121: while (true) {
122: AtomicValue nextVal = (AtomicValue) iter.next();
123: if (nextVal == null) {
124: return sum;
125: }
126: AtomicValue next = nextVal.getPrimitiveValue();
127: if (next instanceof UntypedAtomicValue) {
128: next = next.convert(Type.DOUBLE, context);
129: } else if (!(next instanceof NumericValue)) {
130: DynamicError err = new DynamicError(
131: "Input to sum() contains a mix of numeric and non-numeric values");
132: err.setXPathContext(context);
133: err.setErrorCode("FORG0006");
134: err.setLocator(this );
135: throw err;
136: }
137: sum = ((NumericValue) sum).arithmetic(Token.PLUS,
138: (NumericValue) next, context);
139: if (((NumericValue) sum).isNaN()) {
140: // take an early bath, once we've got a NaN it's not going to change
141: return sum;
142: }
143: }
144: } else if (sum instanceof DurationValue) {
145: while (true) {
146: AtomicValue nextVal = (AtomicValue) iter.next();
147: if (nextVal == null) {
148: return sum;
149: }
150: AtomicValue next = nextVal.getPrimitiveValue();
151: if (!(next instanceof DurationValue)) {
152: DynamicError err = new DynamicError(
153: "Input to sum() contains a mix of duration and non-duration values");
154: err.setXPathContext(context);
155: err.setErrorCode("FORG0006");
156: err.setLocator(this );
157: throw err;
158: }
159: sum = ((DurationValue) sum).add((DurationValue) next,
160: context);
161: }
162: } else {
163: DynamicError err = new DynamicError(
164: "Input to sum() contains a value that is neither numeric, nor a duration");
165: err.setXPathContext(context);
166: err.setErrorCode("FORG0006");
167: err.setLocator(this );
168: throw err;
169: }
170: }
171:
172: /**
173: * Calculate average
174: */
175:
176: private AtomicValue average(SequenceIterator iter,
177: XPathContext context) throws XPathException {
178: int count = 0;
179: AtomicValue sum = (AtomicValue) iter.next();
180: if (sum == null) {
181: // the sequence is empty
182: return null;
183: }
184: count++;
185: if (!sum.hasBuiltInType()) {
186: sum = sum.getPrimitiveValue();
187: }
188: if (sum instanceof UntypedAtomicValue) {
189: sum = sum.convert(Type.DOUBLE, context);
190: }
191: if (sum instanceof NumericValue) {
192: while (true) {
193: AtomicValue nextVal = (AtomicValue) iter.next();
194: if (nextVal == null) {
195: return ((NumericValue) sum).arithmetic(Token.DIV,
196: new IntegerValue(count), context);
197: }
198: count++;
199: AtomicValue next = nextVal.getPrimitiveValue();
200: if (next instanceof UntypedAtomicValue) {
201: next = next.convert(Type.DOUBLE, context);
202: } else if (!(next instanceof NumericValue)) {
203: DynamicError err = new DynamicError(
204: "Input to avg() contains a mix of numeric and non-numeric values");
205: err.setXPathContext(context);
206: err.setErrorCode("FORG0006");
207: err.setLocator(this );
208: throw err;
209: }
210: sum = ((NumericValue) sum).arithmetic(Token.PLUS,
211: (NumericValue) next, context);
212: if (((NumericValue) sum).isNaN()) {
213: // take an early bath, once we've got a NaN it's not going to change
214: return sum;
215: }
216: }
217: } else if (sum instanceof DurationValue) {
218: while (true) {
219: AtomicValue nextVal = (AtomicValue) iter.next();
220: if (nextVal == null) {
221: return ((DurationValue) sum).multiply(1.0 / count,
222: context);
223: }
224: count++;
225: AtomicValue next = nextVal.getPrimitiveValue();
226: if (!(next instanceof DurationValue)) {
227: DynamicError err = new DynamicError(
228: "Input to avg() contains a mix of duration and non-duration values");
229: err.setXPathContext(context);
230: err.setErrorCode("FORG0006");
231: err.setLocator(this );
232: throw err;
233: }
234: sum = ((DurationValue) sum).add((DurationValue) next,
235: context);
236: }
237: } else {
238: DynamicError err = new DynamicError(
239: "Input to avg() contains a value that is neither numeric, nor a duration");
240: err.setXPathContext(context);
241: err.setErrorCode("FORG0006");
242: err.setLocator(this );
243: throw err;
244: }
245: }
246:
247: /**
248: * Get the number of items in a sequence identified by a SequenceIterator
249: * @param iter The SequenceIterator. This method moves the current position
250: * of the supplied iterator; if this isn't safe, make a copy of the iterator
251: * first by calling getAnother(). The supplied iterator must be positioned
252: * before the first item (there must have been no call on next()).
253: * @return the number of items in the underlying sequence
254: * @throws XPathException if a failure occurs reading the input sequence
255: */
256:
257: public static int count(SequenceIterator iter)
258: throws XPathException {
259: if ((iter.getProperties() & SequenceIterator.LAST_POSITION_FINDER) != 0) {
260: return ((LastPositionFinder) iter).getLastPosition();
261: } else {
262: int n = 0;
263: while (iter.next() != null) {
264: n++;
265: }
266: return n;
267: }
268: }
269:
270: /**
271: * Determine whether a given expression is a call to the count() function
272: */
273:
274: public static boolean isCountFunction(Expression exp) {
275: if (!(exp instanceof Aggregate))
276: return false;
277: Aggregate ag = (Aggregate) exp;
278: return ag.getNumberOfArguments() == 1 && ag.operation == COUNT;
279: }
280:
281: }
282:
283: //
284: // The contents of this file are subject to the Mozilla Public License Version 1.0 (the "License");
285: // you may not use this file except in compliance with the License. You may obtain a copy of the
286: // License at http://www.mozilla.org/MPL/
287: //
288: // Software distributed under the License is distributed on an "AS IS" basis,
289: // WITHOUT WARRANTY OF ANY KIND, either express or implied.
290: // See the License for the specific language governing rights and limitations under the License.
291: //
292: // The Original Code is: all this file.
293: //
294: // The Initial Developer of the Original Code is Michael H. Kay.
295: //
296: // Portions created by (your name) are Copyright (C) (your legal entity). All Rights Reserved.
297: //
298: // Contributor(s): none.
299: //
|