001: /*
002: * Copyright (c) 1998-2008 Caucho Technology -- all rights reserved
003: *
004: * This file is part of Resin(R) Open Source
005: *
006: * Each copy or derived work must preserve the copyright notice and this
007: * notice unmodified.
008: *
009: * Resin Open Source is free software; you can redistribute it and/or modify
010: * it under the terms of the GNU General Public License as published by
011: * the Free Software Foundation; either version 2 of the License, or
012: * (at your option) any later version.
013: *
014: * Resin Open Source is distributed in the hope that it will be useful,
015: * but WITHOUT ANY WARRANTY; without even the implied warranty of
016: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE, or any warranty
017: * of NON-INFRINGEMENT. See the GNU General Public License for more
018: * details.
019: *
020: * You should have received a copy of the GNU General Public License
021: * along with Resin Open Source; if not, write to the
022: *
023: * Free Software Foundation, Inc.
024: * 59 Temple Place, Suite 330
025: * Boston, MA 02111-1307 USA
026: *
027: * @author Scott Ferguson
028: */
029:
030: package com.caucho.ejb.gen;
031:
032: import com.caucho.java.JavaWriter;
033: import com.caucho.util.L10N;
034: import com.caucho.webbeans.component.*;
035: import com.caucho.webbeans.manager.*;
036:
037: import java.io.*;
038: import java.lang.annotation.*;
039: import java.lang.reflect.*;
040: import java.util.*;
041: import javax.annotation.security.*;
042: import javax.ejb.*;
043: import javax.interceptor.*;
044: import javax.webbeans.*;
045:
046: /**
047: * Represents the interception
048: */
049: public class InterceptorCallChain extends AbstractCallChain {
050: private static final L10N L = new L10N(InterceptorCallChain.class);
051:
052: private View _view;
053: private BusinessMethodGenerator _next;
054:
055: private String _uniqueName;
056: private Method _implMethod;
057:
058: private ArrayList<Class> _defaultInterceptors = new ArrayList<Class>();
059: private ArrayList<Class> _classInterceptors = new ArrayList<Class>();
060: private ArrayList<Class> _methodInterceptors = new ArrayList<Class>();
061:
062: private boolean _isExcludeDefaultInterceptors;
063: private boolean _isExcludeClassInterceptors;
064:
065: private ArrayList<Class> _interceptors = new ArrayList<Class>();
066:
067: // map from the interceptor class to the local variable for the interceptor
068: private HashMap<Class, String> _interceptorVarMap = new HashMap<Class, String>();
069:
070: // interceptors we're responsible for initializing
071: private ArrayList<Class> _ownInterceptors = new ArrayList<Class>();
072:
073: public InterceptorCallChain(BusinessMethodGenerator next, View view) {
074: super (next);
075:
076: _next = next;
077: _view = view;
078: }
079:
080: /**
081: * Returns true if the business method has any active XA annotation.
082: */
083: public boolean isEnhanced() {
084: return (_defaultInterceptors.size() > 0
085: || (!_isExcludeDefaultInterceptors && _view.getBean()
086: .getDefaultInterceptors().size() > 0)
087: || _classInterceptors.size() > 0
088: || _methodInterceptors.size() > 0 || getAroundInvokeMethod() != null);
089: }
090:
091: public ArrayList<Class> getInterceptors() {
092: return _interceptors;
093: }
094:
095: public Method getAroundInvokeMethod() {
096: return _view.getAroundInvokeMethod();
097: }
098:
099: /**
100: * Introspects the @Interceptors annotation on the method
101: * and the class.
102: */
103: public void introspect(Method apiMethod, Method implMethod) {
104: if (implMethod == null)
105: return;
106:
107: Class apiClass = apiMethod.getDeclaringClass();
108:
109: Class implClass = implMethod.getDeclaringClass();
110:
111: _implMethod = implMethod;
112:
113: Interceptors iAnn;
114:
115: iAnn = (Interceptors) apiClass
116: .getAnnotation(Interceptors.class);
117:
118: if (iAnn != null) {
119: for (Class iClass : iAnn.value())
120: _classInterceptors.add(iClass);
121: }
122:
123: if (implClass != null) {
124: iAnn = (Interceptors) implClass
125: .getAnnotation(Interceptors.class);
126:
127: if (apiMethod != implMethod && iAnn != null) {
128: for (Class iClass : iAnn.value())
129: _classInterceptors.add(iClass);
130: }
131: }
132:
133: iAnn = (Interceptors) apiMethod
134: .getAnnotation(Interceptors.class);
135:
136: if (iAnn != null) {
137: for (Class iClass : iAnn.value())
138: _methodInterceptors.add(iClass);
139: }
140:
141: if (implMethod != null) {
142: iAnn = (Interceptors) implMethod
143: .getAnnotation(Interceptors.class);
144:
145: if (apiMethod != implMethod && iAnn != null) {
146: for (Class iClass : iAnn.value())
147: _methodInterceptors.add(iClass);
148: }
149: }
150:
151: if (apiMethod
152: .isAnnotationPresent(ExcludeClassInterceptors.class))
153: _isExcludeClassInterceptors = true;
154:
155: if (implMethod
156: .isAnnotationPresent(ExcludeClassInterceptors.class))
157: _isExcludeClassInterceptors = true;
158:
159: if (apiMethod
160: .isAnnotationPresent(ExcludeDefaultInterceptors.class))
161: _isExcludeDefaultInterceptors = true;
162:
163: if (implMethod
164: .isAnnotationPresent(ExcludeDefaultInterceptors.class))
165: _isExcludeDefaultInterceptors = true;
166:
167: // webbeans annotations
168: WebBeansContainer webBeans = WebBeansContainer.create();
169:
170: ArrayList<Annotation> interceptorTypes = new ArrayList<Annotation>();
171: for (Annotation ann : implMethod.getAnnotations()) {
172: Class annType = ann.annotationType();
173:
174: if (annType
175: .isAnnotationPresent(InterceptorBindingType.class))
176: interceptorTypes.add(ann);
177: }
178:
179: if (interceptorTypes.size() > 0) {
180: ArrayList<Class> interceptors = webBeans
181: .findInterceptors(interceptorTypes);
182:
183: if (interceptors != null)
184: _methodInterceptors.addAll(interceptors);
185: }
186: }
187:
188: @Override
189: public void generatePrologue(JavaWriter out, HashMap map)
190: throws IOException {
191: if (!isEnhanced()) {
192: _next.generatePrologue(out, map);
193: return;
194: }
195:
196: if (!_isExcludeDefaultInterceptors)
197: _interceptors.addAll(_view.getBean()
198: .getDefaultInterceptors());
199:
200: // ejb/0fb6
201: if (!_isExcludeClassInterceptors && _interceptors.size() == 0)
202: _interceptors.addAll(_classInterceptors);
203:
204: _interceptors.addAll(_methodInterceptors);
205:
206: if (_interceptors.size() == 0
207: && getAroundInvokeMethod() == null)
208: return;
209:
210: _uniqueName = "_v" + out.generateId();
211:
212: out.println();
213: out.println("private static java.lang.reflect.Method "
214: + _uniqueName + "_method;");
215: out.println("private static java.lang.reflect.Method "
216: + _uniqueName + "_implMethod;");
217:
218: boolean isAroundInvokePrologue = false;
219: if (getAroundInvokeMethod() != null
220: && map.get("ejb.around-invoke") == null) {
221: isAroundInvokePrologue = true;
222: map.put("ejb.around-invoke", "_caucho_aroundInvokeMethod");
223:
224: out
225: .println("private static java.lang.reflect.Method __caucho_aroundInvokeMethod;");
226: }
227:
228: out.println("private static java.lang.reflect.Method []"
229: + _uniqueName + "_methodChain;");
230: out.println("private transient Object []" + _uniqueName
231: + "_objectChain;");
232:
233: Class cl = _implMethod.getDeclaringClass();
234:
235: out.println();
236: out.println("static {");
237: out.pushDepth();
238:
239: out.println("try {");
240: out.pushDepth();
241:
242: out.print(_uniqueName + "_method = ");
243: generateGetMethod(out, _implMethod.getDeclaringClass()
244: .getName(), _implMethod.getName(), _implMethod
245: .getParameterTypes());
246: out.println(";");
247: out.println(_uniqueName + "_method.setAccessible(true);");
248:
249: out.print(_uniqueName + "_implMethod = ");
250: generateGetMethod(out, _next.getView().getBeanClassName(),
251: "__caucho_" + _implMethod.getName(), _implMethod
252: .getParameterTypes());
253: out.println(";");
254: out.println(_uniqueName + "_implMethod.setAccessible(true);");
255:
256: if (isAroundInvokePrologue) {
257: Method aroundInvoke = getAroundInvokeMethod();
258:
259: out.print("__caucho_aroundInvokeMethod = ");
260: generateGetMethod(out, aroundInvoke.getDeclaringClass()
261: .getName(), aroundInvoke.getName(), aroundInvoke
262: .getParameterTypes());
263: out.println(";");
264: out
265: .println("__caucho_aroundInvokeMethod.setAccessible(true);");
266: }
267:
268: generateMethodChain(out);
269:
270: out.popDepth();
271: out.println("} catch (Exception e) {");
272: out.println(" throw new RuntimeException(e);");
273: out.println("}");
274: out.popDepth();
275: out.println("}");
276:
277: for (Class iClass : _interceptors) {
278: String var = (String) map.get("interceptor-"
279: + iClass.getName());
280: if (var == null) {
281: var = "__caucho_i" + out.generateId();
282:
283: out.println();
284: out.print("private static ");
285: out.printClass(ComponentFactory.class);
286: out.println(" " + var + "_f;");
287:
288: out.print("private transient ");
289: out.printClass(iClass);
290: out.println(" " + var + ";");
291:
292: map.put("interceptor-" + iClass.getName(), var);
293:
294: _ownInterceptors.add(iClass);
295: }
296:
297: _interceptorVarMap.put(iClass, var);
298: }
299:
300: _next.generatePrologue(out, map);
301: }
302:
303: @Override
304: public void generateConstructor(JavaWriter out, HashMap map)
305: throws IOException {
306: for (Class iClass : _ownInterceptors) {
307: String var = _interceptorVarMap.get(iClass);
308:
309: out.println("if (" + var + "_f == null)");
310: out
311: .println(" "
312: + var
313: + "_f = com.caucho.webbeans.manager.WebBeansContainer.create().createTransient("
314: + iClass.getName() + ".class);");
315:
316: out.print(var + " = (");
317: out.printClass(iClass);
318: out.println(")" + var + "_f.get();");
319: }
320:
321: _next.generateConstructor(out, map);
322: }
323:
324: @Override
325: public void generateCall(JavaWriter out) throws IOException {
326: if (_interceptors.size() == 0
327: && getAroundInvokeMethod() == null) {
328: _next.generateCall(out);
329: return;
330: }
331:
332: out.println("try {");
333: out.pushDepth();
334:
335: out.println("if (" + _uniqueName + "_objectChain == null) {");
336: out.pushDepth();
337: generateObjectChain(out);
338: out.popDepth();
339: out.println("}");
340:
341: if (!void.class.equals(_implMethod.getReturnType())) {
342: out.printClass(_implMethod.getReturnType());
343: out.println(" result;");
344: }
345:
346: if (!void.class.equals(_implMethod.getReturnType())) {
347: out.print("result = (");
348: printCastClass(out, _implMethod.getReturnType());
349: out.print(") ");
350: }
351:
352: out.print("new com.caucho.ejb3.gen.InvocationContextImpl(");
353: out.print("this, ");
354: out.print(_uniqueName + "_method, ");
355: out.print(_uniqueName + "_implMethod, ");
356: out.print(_uniqueName + "_methodChain, ");
357: out.print(_uniqueName + "_objectChain, ");
358: out.print("new Object[] { ");
359: for (int i = 0; i < _implMethod.getParameterTypes().length; i++) {
360: out.print("a" + i + ", ");
361: }
362: out.println("}).proceed();");
363:
364: if (!void.class.equals(_implMethod.getReturnType())) {
365: out.println("return result;");
366: }
367:
368: out.popDepth();
369: out.println("} catch (RuntimeException e) {");
370: out.println(" throw e;");
371:
372: for (Class cl : _implMethod.getExceptionTypes()) {
373: if (!RuntimeException.class.isAssignableFrom(cl)) {
374: out.println("} catch (" + cl.getName() + " e) {");
375: out.println(" throw e;");
376: }
377: }
378:
379: out.println("} catch (Exception e) {");
380: out.println(" throw new RuntimeException(e);");
381: out.println("}");
382: }
383:
384: protected Method findInterceptorMethod(Class cl) {
385: if (cl == null)
386: return null;
387:
388: for (Method method : cl.getDeclaredMethods()) {
389: if (method.isAnnotationPresent(AroundInvoke.class))
390: return method;
391: }
392:
393: return findInterceptorMethod(cl.getSuperclass());
394: }
395:
396: protected void generateMethodChain(JavaWriter out)
397: throws IOException {
398: out.println(_uniqueName
399: + "_methodChain = new java.lang.reflect.Method[] {");
400: out.pushDepth();
401:
402: for (Class iClass : _interceptors) {
403: Method method = findInterceptorMethod(iClass);
404:
405: if (method == null)
406: throw new IllegalStateException(L.l(
407: "Can't find @AroundInvoke in '{0}'", iClass
408: .getName()));
409:
410: generateGetMethod(out, method);
411: out.println(", ");
412: }
413:
414: if (getAroundInvokeMethod() != null) {
415: out.println("__caucho_aroundInvokeMethod, ");
416: }
417:
418: out.popDepth();
419: out.println("};");
420: }
421:
422: protected void generateObjectChain(JavaWriter out)
423: throws IOException {
424: out.print(_uniqueName + "_objectChain = new Object[] {");
425:
426: for (Class iClass : _interceptors) {
427: out.print(_interceptorVarMap.get(iClass) + ", ");
428: }
429:
430: if (getAroundInvokeMethod() != null) {
431: _next.generateThis(out);
432: out.print(", ");
433: }
434:
435: out.println("};");
436: }
437:
438: protected void generateGetMethod(JavaWriter out, Method method)
439: throws IOException {
440: generateGetMethod(out, method.getDeclaringClass().getName(),
441: method.getName(), method.getParameterTypes());
442: }
443:
444: protected void generateGetMethod(JavaWriter out, String className,
445: String methodName, Class[] paramTypes) throws IOException {
446: out.print("com.caucho.ejb.util.EjbUtil.getMethod(");
447: out.print(className + ".class");
448: out.print(", \"" + methodName + "\", new Class[] { ");
449:
450: for (Class type : paramTypes) {
451: out.printClass(type);
452: out.print(".class, ");
453: }
454: out.print("})");
455: }
456:
457: protected void printCastClass(JavaWriter out, Class type)
458: throws IOException {
459: if (!type.isPrimitive())
460: out.printClass(type);
461: else if (boolean.class.equals(type))
462: out.print("Boolean");
463: else if (char.class.equals(type))
464: out.print("Character");
465: else if (byte.class.equals(type))
466: out.print("Byte");
467: else if (short.class.equals(type))
468: out.print("Short");
469: else if (int.class.equals(type))
470: out.print("Integer");
471: else if (long.class.equals(type))
472: out.print("Long");
473: else if (float.class.equals(type))
474: out.print("Float");
475: else if (double.class.equals(type))
476: out.print("Double");
477: else
478: throw new IllegalStateException(type.getName());
479: }
480: }
|