001: package org.depunit;
002:
003: import java.lang.reflect.*;
004: import java.lang.annotation.*;
005: import org.depunit.annotations.*;
006: import java.util.*;
007:
008: public class TestClass {
009: public static class InitDataDriver extends DataDriver {
010: private Map<String, String> m_data;
011: private boolean m_reset;
012:
013: public InitDataDriver(Map<String, String> data) {
014: m_data = data;
015: m_reset = true;
016: }
017:
018: public void reset() {
019: m_reset = true;
020: }
021:
022: public boolean hasNextDataSet() {
023: return (m_reset);
024: }
025:
026: public Map<String, ? extends Object> getNextDataSet() {
027: m_reset = false;
028: return (m_data);
029: }
030: }
031:
032: private Class m_class;
033: private Object m_classInstance;
034: private LinkedList<Method> m_beforeTest;
035: private LinkedList<Method> m_afterTest;
036: private String m_fullName;
037: private LinkedList<TestMethod> m_testMethods;
038: //private Map<String, String> m_initParams;
039: private DataDriver m_dataDriver;
040:
041: public TestClass(String className) throws ClassNotFoundException {
042: m_dataDriver = null;
043: m_classInstance = null;
044: //m_initParams = null;
045: m_testMethods = new LinkedList<TestMethod>();
046: m_beforeTest = new LinkedList<Method>();
047: m_afterTest = new LinkedList<Method>();
048: List<TestMethod> testMethods = new LinkedList<TestMethod>();
049: List<TestMethod> beforeClass = new LinkedList<TestMethod>();
050: List<TestMethod> afterClass = new LinkedList<TestMethod>();
051:
052: m_class = Class.forName(className);
053:
054: Package p = m_class.getPackage();
055: m_fullName = "";
056: if (p != null)
057: m_fullName = p.getName() + ".";
058:
059: m_fullName += m_class.getName();
060:
061: Method[] methods = m_class.getMethods();
062: for (Method m : methods) {
063: boolean add = false;
064: Annotation[] annots = m.getDeclaredAnnotations();
065: for (Annotation a : annots) {
066: if (a instanceof Test)
067: testMethods.add(new TestMethod(m, this , false));
068:
069: else if (a instanceof BeforeTest)
070: m_beforeTest.add(m);
071:
072: else if (a instanceof AfterTest)
073: m_afterTest.add(m);
074:
075: else if (a instanceof BeforeClass)
076: beforeClass.add(new TestMethod(m, this , true));
077:
078: else if (a instanceof AfterClass)
079: afterClass.add(new TestMethod(m, this , true));
080: }
081: }
082:
083: //setup dependencies
084: for (TestMethod tm : testMethods) {
085: for (TestMethod pretm : beforeClass)
086: tm.addHardDependency(pretm.getFullName());
087:
088: for (TestMethod posttm : afterClass)
089: posttm.addSoftDependency(tm.getFullName());
090: }
091:
092: for (TestMethod posttm : afterClass)
093: for (TestMethod pretm : beforeClass) {
094: pretm.addCleanupMethod(posttm.getFullName());
095: //posttm.addHardDependency(pretm.getFullName()); //done later on
096: }
097:
098: for (TestMethod tm : beforeClass)
099: m_testMethods.add(tm);
100:
101: for (TestMethod tm : testMethods)
102: m_testMethods.add(tm);
103:
104: for (TestMethod tm : afterClass)
105: m_testMethods.add(tm);
106: }
107:
108: //---------------------------------------------------------------------------
109: public void setDataDriver(DataDriver dd) {
110: //System.out.println("DataDriver "+dd);
111: m_dataDriver = dd;
112: }
113:
114: //---------------------------------------------------------------------------
115: public DataDriver getDataDriver() {
116: return (m_dataDriver);
117: }
118:
119: //---------------------------------------------------------------------------
120: public void setInitParams(Map<String, String> initParams) {
121: m_dataDriver = new InitDataDriver(initParams);
122: }
123:
124: //---------------------------------------------------------------------------
125: public String getFullName() {
126: return (m_fullName);
127: }
128:
129: //---------------------------------------------------------------------------
130: /**
131: Sets the class with data from the data provider
132: */
133: public void initialize() throws InitializationException {
134: Map<String, ? extends Object> dataSet = null;
135: try {
136: dataSet = m_dataDriver.getNextDataSet();
137: } catch (Exception e) {
138: throw new InitializationException(e);
139: }
140:
141: BeanUtil.initializeClass(m_class, dataSet, m_classInstance);
142: }
143:
144: //---------------------------------------------------------------------------
145: public synchronized Object getClassInstance(
146: Map<String, Object> runContext)
147: throws ObjectCreationException {
148: try {
149: if (m_classInstance == null) {
150: try {
151: Constructor c = m_class
152: .getConstructor(RunContext.class);
153: m_classInstance = c.newInstance(new RunContext(
154: runContext));
155: } catch (Exception e) {
156: }
157:
158: if (m_classInstance == null)
159: m_classInstance = m_class.newInstance();
160: }
161:
162: //If the context has params that match set methods we will set them
163: //We want to do this every time in case the context values change
164: Method[] methods = m_class.getMethods();
165: for (Method m : methods) {
166: String methodName = m.getName();
167: if (methodName.startsWith("set")) {
168: String param = methodName.substring(3);
169:
170: Class<?>[] paramTypes = m.getParameterTypes();
171: Object data = runContext.get(param.toLowerCase());
172: if ((data != null) && (paramTypes.length == 1)
173: && (data.getClass() == paramTypes[0])) {
174: m.invoke(m_classInstance, data);
175: }
176: }
177: }
178: } catch (InvocationTargetException ite) {
179: throw new ObjectCreationException(ite);
180: } catch (InstantiationException ie) {
181: throw new ObjectCreationException(ie);
182: } catch (IllegalAccessException iae) {
183: throw new ObjectCreationException(iae);
184: }
185:
186: return (m_classInstance);
187: }
188:
189: //---------------------------------------------------------------------------
190: public void callBeforeTest() {
191: }
192:
193: //---------------------------------------------------------------------------
194: public void callAfterTest() {
195: }
196:
197: //---------------------------------------------------------------------------
198: public List<TestMethod> getTestMethods() {
199: return (m_testMethods);
200: }
201: }
|