001: /*
002: * TableDependencySorter.java
003: *
004: * This file is part of SQL Workbench/J, http://www.sql-workbench.net
005: *
006: * Copyright 2002-2008, Thomas Kellerer
007: * No part of this code maybe reused without the permission of the author
008: *
009: * To contact the author please send an email to: support@sql-workbench.net
010: *
011: */
012: package workbench.db.importer;
013:
014: import java.util.ArrayList;
015: import java.util.Collections;
016: import java.util.Comparator;
017: import java.util.LinkedList;
018: import java.util.List;
019: import workbench.db.DependencyNode;
020: import workbench.db.TableDependency;
021: import workbench.db.TableIdentifier;
022: import workbench.db.WbConnection;
023:
024: /**
025: * A class to sort tables according to their foreign key constraints,
026: * so that data can be imported or deleted without disabling FK constraints.
027: *
028: * @author support@sql-workbench.net
029: */
030: public class TableDependencySorter {
031: private WbConnection dbConn;
032: private List<TableIdentifier> cycleErrors;
033:
034: public TableDependencySorter(WbConnection con) {
035: this .dbConn = con;
036: }
037:
038: public List<TableIdentifier> sortForInsert(
039: List<TableIdentifier> tables) {
040: return getSortedTableList(tables, false, false);
041: }
042:
043: public List<TableIdentifier> sortForDelete(
044: List<TableIdentifier> tables, boolean addMissing) {
045: return getSortedTableList(tables, addMissing, true);
046: }
047:
048: public boolean hasErrors() {
049: return cycleErrors != null;
050: }
051:
052: public List<TableIdentifier> getErrorTables() {
053: if (cycleErrors == null)
054: return null;
055: return Collections.unmodifiableList(cycleErrors);
056: }
057:
058: /**
059: * Return a sorted list of DependencyNodes that need to be taken care of
060: * when deleting a row from the passed table.
061: *
062: * @param table
063: * @return all DependencyNodes relevant for deleting
064: */
065: public List<DependencyNode> getSortedNodesForDelete(
066: TableIdentifier table) {
067: ArrayList<TableIdentifier> tables = new ArrayList<TableIdentifier>(
068: 1);
069: tables.add(table);
070: List<LevelNode> levelMapping = createLevelMapping(tables, true);
071:
072: ArrayList<DependencyNode> result = new ArrayList<DependencyNode>(
073: levelMapping.size());
074: for (LevelNode lvl : levelMapping) {
075: result.add(lvl.node);
076: }
077: return result;
078: }
079:
080: /**
081: * Determines the FK dependencies for each table in the passed List,
082: * and sorts them so that data can be imported without violating
083: * foreign key constraints
084: *
085: * @param tables the list of tables to be sorted
086: * @returns the tables sorted according to their FK dependencies
087: * @throws DependencyCycleException if an endless loop in the dependencies was detected
088: */
089: private List<TableIdentifier> getSortedTableList(
090: List<TableIdentifier> tables, boolean addMissing,
091: boolean bottomUp) {
092: List<LevelNode> levelMapping = createLevelMapping(tables,
093: bottomUp);
094:
095: ArrayList<TableIdentifier> result = new ArrayList<TableIdentifier>();
096: for (LevelNode lvl : levelMapping) {
097: int index = findTable(lvl.node.getTable(), tables);
098: if (index > -1) {
099: result.add(tables.get(index));
100: } else if (addMissing) {
101: result.add(lvl.node.getTable());
102: }
103: }
104: return result;
105: }
106:
107: private List<LevelNode> createLevelMapping(
108: List<TableIdentifier> tables, boolean bottomUp) {
109: List<LevelNode> levelMapping = new ArrayList<LevelNode>(tables
110: .size());
111:
112: for (TableIdentifier tbl : tables) {
113: TableDependency deps = new TableDependency(dbConn, tbl);
114: deps.readTreeForChildren();
115: if (deps.wasAborted()) {
116: if (cycleErrors == null)
117: cycleErrors = new LinkedList<TableIdentifier>();
118: cycleErrors.add(tbl);
119: }
120:
121: DependencyNode root = deps.getRootNode();
122: if (root != null) {
123: List<DependencyNode> allChildren = getAllNodes(root);
124: putNodes(levelMapping, allChildren);
125: }
126: }
127:
128: Comparator<LevelNode> comp = null;
129:
130: if (bottomUp) {
131: comp = new Comparator<LevelNode>() {
132: public int compare(LevelNode o1, LevelNode o2) {
133: return o2.level - o1.level;
134: }
135: };
136: } else {
137: comp = new Comparator<LevelNode>() {
138: public int compare(LevelNode o1, LevelNode o2) {
139: return o1.level - o2.level;
140: }
141: };
142: }
143:
144: Collections.sort(levelMapping, comp);
145: return levelMapping;
146: }
147:
148: private int findTable(TableIdentifier tofind,
149: List<TableIdentifier> toSearch) {
150:
151: for (int i = 0; i < toSearch.size(); i++) {
152: TableIdentifier tbl = toSearch.get(i);
153: if (tbl.getTableName().equalsIgnoreCase(
154: tofind.getTableName()))
155: return i;
156: }
157: return -1;
158: }
159:
160: protected void putNodes(List<LevelNode> levelMapping,
161: List<DependencyNode> nodes) {
162: for (DependencyNode node : nodes) {
163: TableIdentifier tbl = node.getTable();
164: int level = node.getLevel();
165: LevelNode lvl = findLevelNode(levelMapping, tbl);
166: if (lvl == null) {
167: lvl = new LevelNode(node, level);
168: levelMapping.add(lvl);
169: } else if (level > lvl.level) {
170: lvl.level = level;
171: }
172: }
173: }
174:
175: private LevelNode findLevelNode(List<LevelNode> levelMapping,
176: TableIdentifier tbl) {
177: for (LevelNode lvl : levelMapping) {
178: if (lvl.node.getTable().getTableName().equalsIgnoreCase(
179: tbl.getTableName()))
180: return lvl;
181: }
182: return null;
183: }
184:
185: /**
186: * Get all nodes of the passed dependency hierarchy as a "flat" list.
187: * This is public mainly to be able to run a unit test agains it.
188: */
189: public List<DependencyNode> getAllNodes(DependencyNode startWith) {
190: if (startWith == null)
191: return Collections.emptyList();
192:
193: ArrayList<DependencyNode> result = new ArrayList<DependencyNode>();
194: result.add(startWith);
195:
196: List<DependencyNode> children = startWith.getChildren();
197:
198: if (children.size() == 0) {
199: return result;
200: }
201:
202: for (DependencyNode node : children) {
203: result.addAll(getAllNodes(node));
204: }
205: return result;
206: }
207:
208: static class LevelNode {
209: int level;
210: DependencyNode node;
211:
212: public LevelNode(DependencyNode nd, int lvl) {
213: level = lvl;
214: node = nd;
215: }
216:
217: public boolean equals(Object other) {
218: if (other instanceof LevelNode) {
219: LevelNode n = (LevelNode) other;
220: return node.getTable().getTableName().equalsIgnoreCase(
221: n.node.getTable().getTableName());
222: }
223: return false;
224: }
225:
226: public int hashCode() {
227: return node.getTable().getTableName().hashCode();
228: }
229:
230: public String toString() {
231: return node.getTable().getTableName() + ", Level=" + level;
232: }
233: }
234:
235: }
|