001: /*
002: *
003: * The DbUnit Database Testing Framework
004: * Copyright (C)2002-2004, DbUnit.org
005: *
006: * This library is free software; you can redistribute it and/or
007: * modify it under the terms of the GNU Lesser General Public
008: * License as published by the Free Software Foundation; either
009: * version 2.1 of the License, or (at your option) any later version.
010: *
011: * This library is distributed in the hope that it will be useful,
012: * but WITHOUT ANY WARRANTY; without even the implied warranty of
013: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
014: * Lesser General Public License for more details.
015: *
016: * You should have received a copy of the GNU Lesser General Public
017: * License along with this library; if not, write to the Free Software
018: * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
019: *
020: */
021: package org.dbunit.ant;
022:
023: import org.slf4j.Logger;
024: import org.slf4j.LoggerFactory;
025:
026: import java.io.File;
027: import java.io.IOException;
028: import java.sql.SQLException;
029: import java.util.*;
030:
031: import org.apache.tools.ant.Task;
032: import org.dbunit.DatabaseUnitException;
033: import org.dbunit.database.*;
034: import org.dbunit.dataset.*;
035: import org.dbunit.dataset.csv.CsvProducer;
036: import org.dbunit.dataset.stream.IDataSetProducer;
037: import org.dbunit.dataset.stream.StreamingDataSet;
038: import org.dbunit.dataset.xml.*;
039: import org.xml.sax.InputSource;
040:
041: /**
042: * @author Manuel Laflamme
043: * @since Apr 3, 2004
044: * @version $Revision: 554 $
045: */
046: public abstract class AbstractStep implements DbUnitTaskStep {
047:
048: /**
049: * Logger for this class
050: */
051: private static final Logger logger = LoggerFactory
052: .getLogger(AbstractStep.class);
053:
054: public static final String FORMAT_FLAT = "flat";
055: public static final String FORMAT_XML = "xml";
056: public static final String FORMAT_DTD = "dtd";
057: public static final String FORMAT_CSV = "csv";
058:
059: // Needed a path to Project for logging and references.
060: private Task parentTask;
061:
062: protected IDataSet getDatabaseDataSet(
063: IDatabaseConnection connection, List tables,
064: boolean forwardonly) throws DatabaseUnitException {
065: logger.debug("getDatabaseDataSet(connection=" + connection
066: + ", tables=" + tables + ", forwardonly=" + forwardonly
067: + ") - start");
068:
069: try {
070: // Setup the ResultSet table factory
071: IResultSetTableFactory factory = null;
072: if (forwardonly) {
073: factory = new ForwardOnlyResultSetTableFactory();
074: } else {
075: factory = new CachedResultSetTableFactory();
076: }
077: DatabaseConfig config = connection.getConfig();
078: config.setProperty(
079: DatabaseConfig.PROPERTY_RESULTSET_TABLE_FACTORY,
080: factory);
081:
082: // Retrieve the complete database if no tables or queries specified.
083: if (tables.size() == 0) {
084: return connection.createDataSet();
085: }
086:
087: List queryDataSets = new ArrayList();
088:
089: QueryDataSet queryDataSet = new QueryDataSet(connection);
090:
091: for (Iterator it = tables.iterator(); it.hasNext();) {
092: Object item = it.next();
093: if (item instanceof QuerySet) {
094: if (queryDataSet.getTableNames().length > 0)
095: queryDataSets.add(queryDataSet);
096: queryDataSets.add(getQueryDataSetForQuerySet(
097: connection, (QuerySet) item));
098: queryDataSet = new QueryDataSet(connection);
099: } else if (item instanceof Query) {
100: Query queryItem = (Query) item;
101: queryDataSet.addTable(queryItem.getName(),
102: queryItem.getSql());
103: } else {
104: Table tableItem = (Table) item;
105: queryDataSet.addTable(tableItem.getName());
106: }
107: }
108:
109: if (queryDataSet.getTableNames().length > 0)
110: queryDataSets.add(queryDataSet);
111:
112: IDataSet[] dataSetsArray = new IDataSet[queryDataSets
113: .size()];
114: return new CompositeDataSet((IDataSet[]) queryDataSets
115: .toArray(dataSetsArray));
116: } catch (SQLException e) {
117: logger.error("getDatabaseDataSet()", e);
118:
119: throw new DatabaseUnitException(e);
120: }
121: }
122:
123: protected IDataSet getSrcDataSet(File src, String format,
124: boolean forwardonly) throws DatabaseUnitException {
125: logger.debug("getSrcDataSet(src=" + src + ", format=" + format
126: + ", forwardonly=" + forwardonly + ") - start");
127:
128: try {
129: IDataSetProducer producer = null;
130: if (format.equalsIgnoreCase(FORMAT_XML)) {
131: producer = new XmlProducer(new InputSource(src.toURL()
132: .toString()));
133: } else if (format.equalsIgnoreCase(FORMAT_CSV)) {
134: producer = new CsvProducer(src);
135: } else if (format.equalsIgnoreCase(FORMAT_FLAT)) {
136: producer = new FlatXmlProducer(new InputSource(src
137: .toURL().toString()));
138: } else if (format.equalsIgnoreCase(FORMAT_DTD)) {
139: producer = new FlatDtdProducer(new InputSource(src
140: .toURL().toString()));
141: } else {
142: throw new IllegalArgumentException(
143: "Type must be either 'flat'(default), 'xml', 'csv' or 'dtd' but was: "
144: + format);
145: }
146:
147: if (forwardonly) {
148: return new StreamingDataSet(producer);
149: }
150: return new CachedDataSet(producer);
151: } catch (IOException e) {
152: logger.error("getSrcDataSet()", e);
153:
154: throw new DatabaseUnitException(e);
155: }
156: }
157:
158: private QueryDataSet getQueryDataSetForQuerySet(
159: IDatabaseConnection connection, QuerySet querySet)
160: throws SQLException {
161: logger.debug("getQueryDataSetForQuerySet(connection="
162: + connection + ", querySet=" + querySet + ") - start");
163:
164: //incorporate queries from referenced queryset
165: String refid = querySet.getRefid();
166: if (refid != null) {
167: QuerySet referenced = (QuerySet) getParentTask()
168: .getProject().getReference(refid);
169: querySet.copyQueriesFrom(referenced);
170: }
171:
172: QueryDataSet partialDataSet = new QueryDataSet(connection);
173:
174: Iterator queriesIter = querySet.getQueries().iterator();
175: while (queriesIter.hasNext()) {
176: Query query = (Query) queriesIter.next();
177: partialDataSet.addTable(query.getName(), query.getSql());
178: }
179:
180: return partialDataSet;
181:
182: }
183:
184: public Task getParentTask() {
185: logger.debug("getParentTask() - start");
186:
187: return parentTask;
188: }
189:
190: public void setParentTask(Task task) {
191: logger.debug("setParentTask(task=" + task + ") - start");
192:
193: parentTask = task;
194: }
195:
196: public void log(String msg, int level) {
197: logger.debug("log(msg=" + msg + ", level=" + level
198: + ") - start");
199:
200: if (parentTask != null)
201: parentTask.log(msg, level);
202: }
203:
204: }
|