001: package org.drools.base;
002:
003: /*
004: * Copyright 2005 JBoss Inc
005: *
006: * Licensed under the Apache License, Version 2.0 (the "License");
007: * you may not use this file except in compliance with the License.
008: * You may obtain a copy of the License at
009: *
010: * http://www.apache.org/licenses/LICENSE-2.0
011: *
012: * Unless required by applicable law or agreed to in writing, software
013: * distributed under the License is distributed on an "AS IS" BASIS,
014: * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015: * See the License for the specific language governing permissions and
016: * limitations under the License.
017: */
018:
019: import java.lang.reflect.Method;
020: import java.security.AccessController;
021: import java.security.PrivilegedAction;
022: import java.security.ProtectionDomain;
023: import java.util.HashMap;
024: import java.util.Map;
025:
026: import org.drools.RuntimeDroolsException;
027: import org.drools.asm.ClassWriter;
028: import org.drools.asm.Label;
029: import org.drools.asm.MethodVisitor;
030: import org.drools.asm.Opcodes;
031: import org.drools.asm.Type;
032: import org.drools.base.extractors.BaseBooleanClassFieldExtractor;
033: import org.drools.base.extractors.BaseByteClassFieldExtractor;
034: import org.drools.base.extractors.BaseCharClassFieldExtractor;
035: import org.drools.base.extractors.BaseDoubleClassFieldExtractor;
036: import org.drools.base.extractors.BaseFloatClassFieldExtractor;
037: import org.drools.base.extractors.BaseIntClassFieldExtractor;
038: import org.drools.base.extractors.BaseLongClassFieldExtractors;
039: import org.drools.base.extractors.BaseObjectClassFieldExtractor;
040: import org.drools.base.extractors.BaseShortClassFieldExtractor;
041: import org.drools.base.extractors.MVELClassFieldExtractor;
042: import org.drools.base.extractors.SelfReferenceClassFieldExtractor;
043: import org.drools.common.InternalWorkingMemory;
044: import org.drools.util.asm.ClassFieldInspector;
045:
046: /**
047: * This generates subclasses of BaseClassFieldExtractor to provide field extractors.
048: * This should not be used directly, but via ClassFieldExtractor (which ensures that it is
049: * all nicely serializable).
050: *
051: * @author Alexander Bagerman
052: * @author Michael Neale
053: */
054:
055: public class ClassFieldExtractorFactory {
056:
057: private static final String BASE_PACKAGE = "org/drools/base";
058:
059: private static final String SELF_REFERENCE_FIELD = "this";
060:
061: private static final ProtectionDomain PROTECTION_DOMAIN;
062:
063: private static final Map inspectors = new HashMap();
064:
065: private static ByteArrayClassLoader byteArrayClassLoader;
066:
067: static {
068: PROTECTION_DOMAIN = (ProtectionDomain) AccessController
069: .doPrivileged(new PrivilegedAction() {
070: public Object run() {
071: return ClassFieldExtractorFactory.class
072: .getProtectionDomain();
073: }
074: });
075: }
076:
077: public static BaseClassFieldExtractor getClassFieldExtractor(
078: final Class clazz, final String fieldName,
079: final ClassLoader classLoader) {
080: if (byteArrayClassLoader == null) {
081: if (classLoader == null) {
082: throw new RuntimeDroolsException(
083: "ClassFieldExtractorFactory cannot have a null parent ClassLoader");
084: }
085: byteArrayClassLoader = new ByteArrayClassLoader(classLoader);
086: }
087: try {
088: // if it is a self reference
089: if (SELF_REFERENCE_FIELD.equals(fieldName)) {
090: // then just create an instance of the special class field extractor
091: return new SelfReferenceClassFieldExtractor(clazz,
092: fieldName);
093: } else if (fieldName.indexOf('.') > -1
094: || fieldName.indexOf('[') > -1) {
095: // we need MVEL extractor for expressions
096: return new MVELClassFieldExtractor(clazz, fieldName,
097: classLoader);
098: } else {
099: // otherwise, bytecode generate a specific extractor
100: ClassFieldInspector inspector = (ClassFieldInspector) inspectors
101: .get(clazz);
102: if (inspector == null) {
103: inspector = new ClassFieldInspector(clazz);
104: inspectors.put(clazz, inspector);
105: }
106: final Class fieldType = (Class) inspector
107: .getFieldTypes().get(fieldName);
108: final Method getterMethod = (Method) inspector
109: .getGetterMethods().get(fieldName);
110: if (fieldType != null && getterMethod != null) {
111: final String className = ClassFieldExtractorFactory.BASE_PACKAGE
112: + "/"
113: + Type.getInternalName(clazz)
114: + "$"
115: + getterMethod.getName();
116:
117: // generating byte array to create target class
118: final byte[] bytes = dump(clazz, className,
119: getterMethod, fieldType, clazz
120: .isInterface());
121: // use bytes to get a class
122:
123: final Class newClass = byteArrayClassLoader
124: .defineClass(className.replace('/', '.'),
125: bytes, PROTECTION_DOMAIN);
126: // instantiating target class
127: final Integer index = (Integer) inspector
128: .getFieldNames().get(fieldName);
129: final ValueType valueType = ValueType
130: .determineValueType(fieldType);
131: final Object[] params = { index, fieldType,
132: valueType };
133: return (BaseClassFieldExtractor) newClass
134: .getConstructors()[0].newInstance(params);
135: } else {
136: throw new RuntimeDroolsException("Field/method '"
137: + fieldName + "' not found for class '"
138: + clazz.getName() + "'");
139: }
140: }
141: } catch (final RuntimeDroolsException e) {
142: throw e;
143: } catch (final Exception e) {
144: throw new RuntimeDroolsException(e);
145: }
146: }
147:
148: private static byte[] dump(final Class originalClass,
149: final String className, final Method getterMethod,
150: final Class fieldType, final boolean isInterface)
151: throws Exception {
152:
153: final ClassWriter cw = new ClassWriter(true);
154:
155: final Class super Class = getSuperClassFor(fieldType);
156: buildClassHeader(super Class, className, cw);
157:
158: // buildConstructor( superClass,
159: // className,
160: // cw );
161:
162: build3ArgConstructor(super Class, className, cw);
163:
164: buildGetMethod(originalClass, className, super Class,
165: getterMethod, cw);
166:
167: cw.visitEnd();
168:
169: return cw.toByteArray();
170: }
171:
172: /**
173: * Builds the class header
174: *
175: * @param clazz The class to build the extractor for
176: * @param className The extractor class name
177: * @param cw
178: */
179: protected static void buildClassHeader(final Class super Class,
180: final String className, final ClassWriter cw) {
181: cw
182: .visit(Opcodes.V1_2, Opcodes.ACC_PUBLIC
183: + Opcodes.ACC_SUPER, className, null, Type
184: .getInternalName(super Class), null);
185:
186: cw.visitSource(null, null);
187: }
188:
189: // /**
190: // * Creates a constructor for the field extractor receiving
191: // * the class instance and field name
192: // *
193: // * @param originalClassName
194: // * @param className
195: // * @param cw
196: // */
197: // private static void buildConstructor(final Class superClazz,
198: // final String className,
199: // final ClassWriter cw) {
200: // MethodVisitor mv;
201: // {
202: // mv = cw.visitMethod( Opcodes.ACC_PUBLIC,
203: // "<init>",
204: // Type.getMethodDescriptor( Type.VOID_TYPE,
205: // new Type[]{Type.getType( Class.class ), Type.getType( String.class )} ),
206: // null,
207: // null );
208: // mv.visitCode();
209: // final Label l0 = new Label();
210: // mv.visitLabel( l0 );
211: // mv.visitVarInsn( Opcodes.ALOAD,
212: // 0 );
213: // mv.visitVarInsn( Opcodes.ALOAD,
214: // 1 );
215: // mv.visitVarInsn( Opcodes.ALOAD,
216: // 2 );
217: // mv.visitMethodInsn( Opcodes.INVOKESPECIAL,
218: // Type.getInternalName( superClazz ),
219: // "<init>",
220: // Type.getMethodDescriptor( Type.VOID_TYPE,
221: // new Type[]{Type.getType( Class.class ), Type.getType( String.class )} ) );
222: // final Label l1 = new Label();
223: // mv.visitLabel( l1 );
224: // mv.visitInsn( Opcodes.RETURN );
225: // final Label l2 = new Label();
226: // mv.visitLabel( l2 );
227: // mv.visitLocalVariable( "this",
228: // "L" + className + ";",
229: // null,
230: // l0,
231: // l2,
232: // 0 );
233: // mv.visitLocalVariable( "clazz",
234: // Type.getDescriptor( Class.class ),
235: // null,
236: // l0,
237: // l2,
238: // 1 );
239: // mv.visitLocalVariable( "fieldName",
240: // Type.getDescriptor( String.class ),
241: // null,
242: // l0,
243: // l2,
244: // 2 );
245: // mv.visitMaxs( 0,
246: // 0 );
247: // mv.visitEnd();
248: // }
249: // }
250:
251: /**
252: * Creates a constructor for the field extractor receiving
253: * the index, field type and value type
254: *
255: * @param originalClassName
256: * @param className
257: * @param cw
258: */
259: private static void build3ArgConstructor(final Class super Clazz,
260: final String className, final ClassWriter cw) {
261: MethodVisitor mv;
262: {
263: mv = cw.visitMethod(Opcodes.ACC_PUBLIC, "<init>", Type
264: .getMethodDescriptor(Type.VOID_TYPE, new Type[] {
265: Type.getType(int.class),
266: Type.getType(Class.class),
267: Type.getType(ValueType.class) }), null,
268: null);
269: mv.visitCode();
270: final Label l0 = new Label();
271: mv.visitLabel(l0);
272: mv.visitVarInsn(Opcodes.ALOAD, 0);
273: mv.visitVarInsn(Opcodes.ILOAD, 1);
274: mv.visitVarInsn(Opcodes.ALOAD, 2);
275: mv.visitVarInsn(Opcodes.ALOAD, 3);
276: mv.visitMethodInsn(Opcodes.INVOKESPECIAL, Type
277: .getInternalName(super Clazz), "<init>", Type
278: .getMethodDescriptor(Type.VOID_TYPE, new Type[] {
279: Type.getType(int.class),
280: Type.getType(Class.class),
281: Type.getType(ValueType.class) }));
282: final Label l1 = new Label();
283: mv.visitLabel(l1);
284: mv.visitInsn(Opcodes.RETURN);
285: final Label l2 = new Label();
286: mv.visitLabel(l2);
287: mv.visitLocalVariable("this", "L" + className + ";", null,
288: l0, l2, 0);
289: mv.visitLocalVariable("index", Type
290: .getDescriptor(int.class), null, l0, l2, 1);
291: mv.visitLocalVariable("fieldType", Type
292: .getDescriptor(Class.class), null, l0, l2, 2);
293: mv.visitLocalVariable("valueType", Type
294: .getDescriptor(ValueType.class), null, l0, l2, 3);
295: mv.visitMaxs(0, 0);
296: mv.visitEnd();
297: }
298: }
299:
300: /**
301: * Creates the proxy reader method for the given method
302: *
303: * @param fieldName
304: * @param fieldFlag
305: * @param method
306: * @param cw
307: */
308: protected static void buildGetMethod(final Class originalClass,
309: final String className, final Class super Class,
310: final Method getterMethod, final ClassWriter cw) {
311:
312: final Class fieldType = getterMethod.getReturnType();
313: Method overridingMethod;
314: try {
315: overridingMethod = super Class
316: .getMethod(getOverridingMethodName(fieldType),
317: new Class[] { InternalWorkingMemory.class,
318: Object.class });
319: } catch (final Exception e) {
320: throw new RuntimeDroolsException(
321: "This is a bug. Please report back to JBoss Rules team.",
322: e);
323: }
324: final MethodVisitor mv = cw.visitMethod(Opcodes.ACC_PUBLIC,
325: overridingMethod.getName(), Type
326: .getMethodDescriptor(overridingMethod), null,
327: null);
328:
329: mv.visitCode();
330:
331: final Label l0 = new Label();
332: mv.visitLabel(l0);
333: mv.visitVarInsn(Opcodes.ALOAD, 2);
334: mv.visitTypeInsn(Opcodes.CHECKCAST, Type
335: .getInternalName(originalClass));
336:
337: if (originalClass.isInterface()) {
338: mv.visitMethodInsn(Opcodes.INVOKEINTERFACE, Type
339: .getInternalName(originalClass), getterMethod
340: .getName(), Type.getMethodDescriptor(getterMethod));
341: } else {
342: mv.visitMethodInsn(Opcodes.INVOKEVIRTUAL, Type
343: .getInternalName(originalClass), getterMethod
344: .getName(), Type.getMethodDescriptor(getterMethod));
345: }
346: mv
347: .visitInsn(Type.getType(fieldType).getOpcode(
348: Opcodes.IRETURN));
349: final Label l1 = new Label();
350: mv.visitLabel(l1);
351: mv.visitLocalVariable("this", "L" + className + ";", null, l0,
352: l1, 0);
353: mv.visitLocalVariable("workingMemory", Type
354: .getDescriptor(InternalWorkingMemory.class), null, l0,
355: l1, 1);
356: mv.visitLocalVariable("object", Type
357: .getDescriptor(Object.class), null, l0, l1, 2);
358: mv.visitMaxs(0, 0);
359: mv.visitEnd();
360: }
361:
362: private static String getOverridingMethodName(final Class fieldType) {
363: String ret = null;
364: if (fieldType.isPrimitive()) {
365: if (fieldType == char.class) {
366: ret = "getCharValue";
367: } else if (fieldType == byte.class) {
368: ret = "getByteValue";
369: } else if (fieldType == short.class) {
370: ret = "getShortValue";
371: } else if (fieldType == int.class) {
372: ret = "getIntValue";
373: } else if (fieldType == long.class) {
374: ret = "getLongValue";
375: } else if (fieldType == float.class) {
376: ret = "getFloatValue";
377: } else if (fieldType == double.class) {
378: ret = "getDoubleValue";
379: } else if (fieldType == boolean.class) {
380: ret = "getBooleanValue";
381: }
382: } else {
383: ret = "getValue";
384: }
385: return ret;
386: }
387:
388: /**
389: * Returns the appropriate Base class field extractor class
390: * for the given fieldType
391: *
392: * @param fieldType
393: * @return
394: */
395: private static Class getSuperClassFor(final Class fieldType) {
396: Class ret = null;
397: if (fieldType.isPrimitive()) {
398: if (fieldType == char.class) {
399: ret = BaseCharClassFieldExtractor.class;
400: } else if (fieldType == byte.class) {
401: ret = BaseByteClassFieldExtractor.class;
402: } else if (fieldType == short.class) {
403: ret = BaseShortClassFieldExtractor.class;
404: } else if (fieldType == int.class) {
405: ret = BaseIntClassFieldExtractor.class;
406: } else if (fieldType == long.class) {
407: ret = BaseLongClassFieldExtractors.class;
408: } else if (fieldType == float.class) {
409: ret = BaseFloatClassFieldExtractor.class;
410: } else if (fieldType == double.class) {
411: ret = BaseDoubleClassFieldExtractor.class;
412: } else if (fieldType == boolean.class) {
413: ret = BaseBooleanClassFieldExtractor.class;
414: }
415: } else {
416: ret = BaseObjectClassFieldExtractor.class;
417: }
418: return ret;
419: }
420:
421: /**
422: * Simple classloader
423: * @author Michael Neale
424: */
425: static class ByteArrayClassLoader extends ClassLoader {
426: public ByteArrayClassLoader(final ClassLoader parent) {
427: super (parent);
428: }
429:
430: public Class defineClass(final String name, final byte[] bytes,
431: final ProtectionDomain domain) {
432: return defineClass(name, bytes, 0, bytes.length, domain);
433: }
434: }
435: }
|