001: /*
002: *
003: * The DbUnit Database Testing Framework
004: * Copyright (C)2005, DbUnit.org
005: *
006: * This library is free software; you can redistribute it and/or
007: * modify it under the terms of the GNU Lesser General Public
008: * License as published by the Free Software Foundation; either
009: * version 2.1 of the License, or (at your option) any later version.
010: *
011: * This library is distributed in the hope that it will be useful,
012: * but WITHOUT ANY WARRANTY; without even the implied warranty of
013: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
014: * Lesser General Public License for more details.
015: *
016: * You should have received a copy of the GNU Lesser General Public
017: * License along with this library; if not, write to the Free Software
018: * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307 USA
019: *
020: */
021:
022: package org.dbunit.database;
023:
024: import java.sql.PreparedStatement;
025: import java.sql.ResultSet;
026: import java.sql.SQLException;
027: import java.util.ArrayList;
028: import java.util.HashMap;
029: import java.util.HashSet;
030: import java.util.Iterator;
031: import java.util.List;
032: import java.util.Map;
033: import java.util.Set;
034:
035: import org.slf4j.Logger;
036: import org.slf4j.LoggerFactory; //TODO: should not have dependency on sub-package!
037: import org.dbunit.database.search.ForeignKeyRelationshipEdge;
038: import org.dbunit.dataset.DataSetException;
039: import org.dbunit.dataset.IDataSet;
040: import org.dbunit.dataset.ITable;
041: import org.dbunit.dataset.ITableIterator;
042: import org.dbunit.dataset.ITableMetaData;
043: import org.dbunit.dataset.filter.AbstractTableFilter;
044: import org.dbunit.util.SQLHelper;
045:
046: /**
047: * Filter a table given a map of the allowed rows based on primary key values.<br>
048: * It uses a depth-first algorithm (although not recursive - it might be refactored
049: * in the future) to define which rows are allowed, as well which rows are necessary
050: * (and hence allowed) because of dependencies with the allowed rows.<br>
051: * <strong>NOTE:</strong> multi-column primary keys are not supported at the moment.
052: * TODO: test cases
053: * @author Felipe Leme <dbunit@felipeal.net>
054: * @version $Revision: 554 $
055: * @since Sep 9, 2005
056: */
057: public class PrimaryKeyFilter extends AbstractTableFilter {
058:
059: private final IDatabaseConnection connection;
060:
061: private final Map allowedPKsPerTable;
062: private final Map allowedPKsInput;
063: private final Map pksToScanPerTable;
064:
065: private final boolean reverseScan;
066:
067: protected final Logger logger = LoggerFactory.getLogger(getClass());
068:
069: // cache de primary keys
070: private final Map pkColumnPerTable = new HashMap();
071:
072: private final Map fkEdgesPerTable = new HashMap();
073: private final Map fkReverseEdgesPerTable = new HashMap();
074:
075: // name of the tables, in reverse order of depedency
076: private final List tableNames = new ArrayList();
077:
078: /**
079: * Default constructor, it takes as input a map with desired rows in a final
080: * dataset; the filter will ensure that the rows necessary by these initial rows
081: * are also allowed (and so on...).
082: * @param connection database connection
083: * @param allowedPKs map of allowed rows, based on the primary keys (key is the name
084: * of a table; value is a Set with allowed primary keys for that table)
085: * @param reverseDependency flag indicating if the rows that depend on a row should
086: * also be allowed by the filter
087: */
088: public PrimaryKeyFilter(IDatabaseConnection connection,
089: Map allowedPKs, boolean reverseDependency) {
090: this .connection = connection;
091: this .allowedPKsPerTable = new HashMap();
092: this .allowedPKsInput = allowedPKs;
093: this .reverseScan = reverseDependency;
094:
095: // we need a deep copy here
096: // this.idsToScanPerTable = new HashMap(allowedIds);
097: this .pksToScanPerTable = new HashMap(allowedPKs.size());
098: Iterator iterator = allowedPKs.entrySet().iterator();
099: while (iterator.hasNext()) {
100: Map.Entry entry = (Map.Entry) iterator.next();
101: Object table = entry.getKey();
102: Set inputSet = (Set) entry.getValue();
103: Set newSet = new HashSet(inputSet);
104: this .pksToScanPerTable.put(table, newSet);
105: }
106:
107: }
108:
109: public void nodeAdded(Object node) {
110: logger.debug("nodeAdded(node=" + node + ") - start");
111:
112: this .tableNames.add(node);
113: if (this .logger.isDebugEnabled()) {
114: this .logger.debug("nodeAdded: " + node);
115: }
116: }
117:
118: public void edgeAdded(ForeignKeyRelationshipEdge edge) {
119: if (this .logger.isDebugEnabled()) {
120: this .logger.debug("edgeAdded: " + edge);
121: }
122: // first add it to the "direct edges"
123: String from = (String) edge.getFrom();
124: Set edges = (Set) this .fkEdgesPerTable.get(from);
125: if (edges == null) {
126: edges = new HashSet();
127: this .fkEdgesPerTable.put(from, edges);
128: }
129: if (!edges.contains(edge)) {
130: edges.add(edge);
131: }
132:
133: // then add it to the "reverse edges"
134: String to = (String) edge.getTo();
135: edges = (Set) this .fkReverseEdgesPerTable.get(to);
136: if (edges == null) {
137: edges = new HashSet();
138: this .fkReverseEdgesPerTable.put(to, edges);
139: }
140: if (!edges.contains(edge)) {
141: edges.add(edge);
142: }
143:
144: // finally, update the PKs cache
145: Object pkTo = this .pkColumnPerTable.get(to);
146: if (pkTo == null) {
147: Object pk = edge.getPKColumn();
148: this .pkColumnPerTable.put(to, pk);
149: }
150:
151: }
152:
153: /**
154: * @see AbstractTableFilter
155: */
156: public boolean isValidName(String tableName)
157: throws DataSetException {
158: logger
159: .debug("isValidName(tableName=" + tableName
160: + ") - start");
161:
162: // boolean isValid = this.allowedIds.containsKey(tableName);
163: // return isValid;
164: return true;
165: }
166:
167: public ITableIterator iterator(IDataSet dataSet, boolean reversed)
168: throws DataSetException {
169: if (this .logger.isDebugEnabled()) {
170: this .logger.debug("Filter.iterator()");
171: }
172: try {
173: searchPKs(dataSet);
174: } catch (SQLException e) {
175: logger.error("iterator()", e);
176:
177: throw new DataSetException(e);
178: }
179: return new FilterIterator(reversed ? dataSet.reverseIterator()
180: : dataSet.iterator());
181: }
182:
183: private void searchPKs(IDataSet dataSet) throws DataSetException,
184: SQLException {
185: logger.debug("searchPKs(dataSet=" + dataSet + ") - start");
186:
187: int counter = 0;
188: while (!this .pksToScanPerTable.isEmpty()) {
189: counter++;
190: if (this .logger.isDebugEnabled()) {
191: this .logger.debug("RUN # " + counter);
192: }
193:
194: for (int i = this .tableNames.size() - 1; i >= 0; i--) {
195: String tableName = (String) this .tableNames.get(i);
196: // TODO: support multi-column PKs
197: String pkColumn = dataSet.getTable(tableName)
198: .getTableMetaData().getPrimaryKeys()[0]
199: .getColumnName();
200: Set tmpSet = (Set) this .pksToScanPerTable
201: .get(tableName);
202: if (tmpSet != null && !tmpSet.isEmpty()) {
203: Set pksToScan = new HashSet(tmpSet);
204: if (this .logger.isDebugEnabled()) {
205: this .logger.debug("before search: " + tableName
206: + "=>" + pksToScan);
207: }
208: scanPKs(tableName, pkColumn, pksToScan);
209: scanReversePKs(tableName, pksToScan);
210: allowPKs(tableName, pksToScan);
211: removePKsToScan(tableName, pksToScan);
212: } // if
213: } // for
214: removeScannedTables();
215: } // while
216: if (this .logger.isDebugEnabled()) {
217: this .logger.debug("Finished searchIds()");
218: }
219: }
220:
221: private void removeScannedTables() {
222: logger.debug("removeScannedTables() - start");
223:
224: Iterator iterator = this .pksToScanPerTable.entrySet()
225: .iterator();
226: List tablesToRemove = new ArrayList();
227: while (iterator.hasNext()) {
228: Map.Entry entry = (Map.Entry) iterator.next();
229: String table = (String) entry.getKey();
230: Set pksToScan = (Set) entry.getValue();
231: boolean removeIt = pksToScan.isEmpty();
232: if (!this .tableNames.contains(table)) {
233: if (this .logger.isWarnEnabled()) {
234: this .logger
235: .warn("Discarding ids "
236: + pksToScan
237: + " of table "
238: + table
239: + "as this table has not been passed as input");
240: }
241: removeIt = true;
242: }
243: if (removeIt) {
244: tablesToRemove.add(table);
245: }
246: }
247: iterator = tablesToRemove.iterator();
248: while (iterator.hasNext()) {
249: this .pksToScanPerTable.remove(iterator.next());
250: }
251: }
252:
253: private void allowPKs(String table, Set newAllowedPKs) {
254: logger.debug("allowPKs(table=" + table + ", newAllowedPKs="
255: + newAllowedPKs + ") - start");
256:
257: // first, obtain the current allowed ids for that table
258: Set currentAllowedIds = (Set) this .allowedPKsPerTable
259: .get(table);
260: if (currentAllowedIds == null) {
261: currentAllowedIds = new HashSet();
262: this .allowedPKsPerTable.put(table, currentAllowedIds);
263: }
264: // then, add the new ids, but checking if it should be allowed to add them
265: Set forcedAllowedPKs = (Set) this .allowedPKsInput.get(table);
266: if (forcedAllowedPKs == null || forcedAllowedPKs.isEmpty()) {
267: currentAllowedIds.addAll(newAllowedPKs);
268: } else {
269: Iterator iterator = newAllowedPKs.iterator();
270: while (iterator.hasNext()) {
271: Object id = iterator.next();
272: if (forcedAllowedPKs.contains(id)) {
273: currentAllowedIds.add(id);
274: } else {
275: if (this .logger.isDebugEnabled()) {
276: this .logger
277: .debug("Discarding id "
278: + id
279: + " of table "
280: + table
281: + " as it was not included in the input!");
282: }
283: }
284: }
285: }
286: }
287:
288: private void scanPKs(String table, String pkColumn, Set allowedIds)
289: throws SQLException {
290: logger
291: .debug("scanPKs(table=" + table + ", pkColumn="
292: + pkColumn + ", allowedIds=" + allowedIds
293: + ") - start");
294:
295: Set fkEdges = (Set) this .fkEdgesPerTable.get(table);
296: if (fkEdges == null || fkEdges.isEmpty()) {
297: return;
298: }
299: // we need a temporary list as there is no warranty about the set order...
300: List fkTables = new ArrayList(fkEdges.size());
301: Iterator iterator = fkEdges.iterator();
302: StringBuffer colsBuffer = new StringBuffer();
303: while (iterator.hasNext()) {
304: ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator
305: .next();
306: fkTables.add(edge.getTo());
307: colsBuffer.append(edge.getFKColumn());
308: if (iterator.hasNext()) {
309: colsBuffer.append(", ");
310: }
311: }
312: // NOTE: make sure the query below is compatible standard SQL
313: String sql = "SELECT " + colsBuffer + " FROM " + table
314: + " WHERE " + pkColumn + " = ? ";
315: if (this .logger.isDebugEnabled()) {
316: this .logger.debug("SQL: " + sql);
317: }
318: PreparedStatement pstmt = null;
319: ResultSet rs = null;
320: try {
321: pstmt = this .connection.getConnection().prepareStatement(
322: sql);
323: iterator = allowedIds.iterator();
324: while (iterator.hasNext()) {
325: Object pk = iterator.next(); // id being scanned
326: if (this .logger.isDebugEnabled()) {
327: this .logger.debug("Executing sql for ? = " + pk);
328: }
329: pstmt.setObject(1, pk);
330: rs = pstmt.executeQuery();
331: while (rs.next()) {
332: for (int i = 0; i < fkTables.size(); i++) {
333: String newTable = (String) fkTables.get(i);
334: Object fk = rs.getObject(i + 1);
335: if (fk != null) {
336: if (this .logger.isDebugEnabled()) {
337: this .logger.debug("New ID: " + newTable
338: + "->" + fk);
339: }
340: addPKToScan(newTable, fk);
341: } else {
342: this .logger
343: .warn("Found null FK for relationship "
344: + table + "=>" + newTable);
345: }
346: }
347: }
348: }
349: } catch (SQLException e) {
350: logger.error("scanPKs()", e);
351:
352: SQLHelper.close(rs, pstmt);
353: }
354: }
355:
356: private void scanReversePKs(String table, Set pksToScan)
357: throws SQLException {
358: logger.debug("scanReversePKs(table=" + table + ", pksToScan="
359: + pksToScan + ") - start");
360:
361: if (!this .reverseScan) {
362: return;
363: }
364: Set fkReverseEdges = (Set) this .fkReverseEdgesPerTable
365: .get(table);
366: if (fkReverseEdges == null || fkReverseEdges.isEmpty()) {
367: return;
368: }
369: Iterator iterator = fkReverseEdges.iterator();
370: while (iterator.hasNext()) {
371: ForeignKeyRelationshipEdge edge = (ForeignKeyRelationshipEdge) iterator
372: .next();
373: addReverseEdge(edge, pksToScan);
374: }
375: }
376:
377: private void addReverseEdge(ForeignKeyRelationshipEdge edge,
378: Set idsToScan) throws SQLException {
379: logger.debug("addReverseEdge(edge=" + edge + ", idsToScan="
380: + idsToScan + ") - start");
381:
382: String fkTable = (String) edge.getFrom();
383: String fkColumn = edge.getFKColumn();
384: String pkColumn = getPKColumn(fkTable);
385: // NOTE: make sure the query below is compatible standard SQL
386: String sql = "SELECT " + pkColumn + " FROM " + fkTable
387: + " WHERE " + fkColumn + " = ? ";
388:
389: PreparedStatement pstmt = null;
390: try {
391: if (this .logger.isDebugEnabled()) {
392: this .logger.debug("Preparing SQL query '" + sql + "'");
393: }
394: pstmt = this .connection.getConnection().prepareStatement(
395: sql);
396: } catch (SQLException e) {
397: logger.error("addReverseEdge()", e);
398:
399: SQLHelper.close(pstmt);
400: }
401: ResultSet rs = null;
402: Iterator iterator = idsToScan.iterator();
403: try {
404: while (iterator.hasNext()) {
405: Object pk = iterator.next();
406: if (this .logger.isDebugEnabled()) {
407: this .logger.debug("executing query '" + sql
408: + "' for ? = " + pk);
409: }
410: pstmt.setObject(1, pk);
411: rs = pstmt.executeQuery();
412: while (rs.next()) {
413: Object fk = rs.getObject(1);
414: addPKToScan(fkTable, fk);
415: }
416: }
417: } finally {
418: SQLHelper.close(rs, pstmt);
419: }
420: }
421:
422: // TODO: support PKs with multiple values
423: private String getPKColumn(String table) throws SQLException {
424: logger.debug("getPKColumn(table=" + table + ") - start");
425:
426: String pkColumn = (String) this .pkColumnPerTable.get(table);
427: if (pkColumn == null) {
428: pkColumn = SQLHelper.getPrimaryKeyColumn(this .connection
429: .getConnection(), table);
430: this .pkColumnPerTable.put(table, pkColumn);
431: }
432: return pkColumn;
433: }
434:
435: private void removePKsToScan(String table, Set ids) {
436: logger.debug("removePKsToScan(table=" + table + ", ids=" + ids
437: + ") - start");
438:
439: Set pksToScan = (Set) this .pksToScanPerTable.get(table);
440: if (pksToScan != null) {
441: if (pksToScan == ids) {
442: throw new RuntimeException(
443: "INTERNAL ERROR on removeIdsToScan() for table "
444: + table);
445: } else {
446: pksToScan.removeAll(ids);
447: }
448: }
449: }
450:
451: private void addPKToScan(String table, Object pk) {
452: logger.debug("addPKToScan(table=" + table + ", pk=" + pk
453: + ") - start");
454:
455: // first, check if it wasn't added yet
456: Set scannedIds = (Set) this .allowedPKsPerTable.get(table);
457: if (scannedIds != null && scannedIds.contains(pk)) {
458: if (this .logger.isDebugEnabled()) {
459: this .logger.debug("Discarding already scanned id=" + pk
460: + " for table " + table);
461: }
462: return;
463: }
464:
465: Set pksToScan = (Set) this .pksToScanPerTable.get(table);
466: if (pksToScan == null) {
467: pksToScan = new HashSet();
468: this .pksToScanPerTable.put(table, pksToScan);
469: }
470: pksToScan.add(pk);
471: }
472:
473: private class FilterIterator implements ITableIterator {
474:
475: /**
476: * Logger for this class
477: */
478: private final Logger logger = LoggerFactory
479: .getLogger(FilterIterator.class);
480:
481: private final ITableIterator _iterator;
482:
483: public FilterIterator(ITableIterator iterator) {
484:
485: _iterator = iterator;
486: }
487:
488: ////////////////////////////////////////////////////////////////////////////
489: // ITableIterator interface
490:
491: public boolean next() throws DataSetException {
492: if (logger.isDebugEnabled()) {
493: logger.debug("Iterator.next()");
494: }
495: while (_iterator.next()) {
496: if (accept(_iterator.getTableMetaData().getTableName())) {
497: return true;
498: }
499: }
500: return false;
501: }
502:
503: public ITableMetaData getTableMetaData()
504: throws DataSetException {
505: if (logger.isDebugEnabled()) {
506: logger.debug("Iterator.getTableMetaData()");
507: }
508: return _iterator.getTableMetaData();
509: }
510:
511: public ITable getTable() throws DataSetException {
512: if (logger.isDebugEnabled()) {
513: logger.debug("Iterator.getTable()");
514: }
515: ITable table = _iterator.getTable();
516: String tableName = table.getTableMetaData().getTableName();
517: Set allowedPKs = (Set) allowedPKsPerTable.get(tableName);
518: if (allowedPKs != null) {
519: return new PrimaryKeyFilteredTableWrapper(table,
520: allowedPKs);
521: }
522: return table;
523: }
524: }
525:
526: }
|