001: /*
002: * All content copyright (c) 2003-2006 Terracotta, Inc., except as may otherwise be noted in a separate copyright
003: * notice. All rights reserved.
004: */
005: package com.tc.objectserver.handshakemanager;
006:
007: import com.tc.exception.ImplementMe;
008: import com.tc.logging.TCLogger;
009: import com.tc.logging.TCLogging;
010: import com.tc.net.groups.ClientID;
011: import com.tc.net.groups.NodeID;
012: import com.tc.net.protocol.tcm.ChannelID;
013: import com.tc.net.protocol.tcm.MessageChannel;
014: import com.tc.net.protocol.tcm.TestMessageChannel;
015: import com.tc.net.protocol.transport.ConnectionID;
016: import com.tc.object.ObjectID;
017: import com.tc.object.lockmanager.api.LockContext;
018: import com.tc.object.lockmanager.api.LockID;
019: import com.tc.object.lockmanager.api.LockLevel;
020: import com.tc.object.lockmanager.api.ThreadID;
021: import com.tc.object.lockmanager.api.WaitContext;
022: import com.tc.object.msg.BatchTransactionAcknowledgeMessage;
023: import com.tc.object.msg.ClientHandshakeAckMessage;
024: import com.tc.object.msg.TestClientHandshakeMessage;
025: import com.tc.object.net.DSOChannelManager;
026: import com.tc.object.net.DSOChannelManagerEventListener;
027: import com.tc.object.tx.WaitInvocation;
028: import com.tc.objectserver.api.TestSink;
029: import com.tc.objectserver.l1.api.TestClientStateManager;
030: import com.tc.objectserver.l1.api.TestClientStateManager.AddReferenceContext;
031: import com.tc.objectserver.lockmanager.api.TestLockManager;
032: import com.tc.objectserver.lockmanager.api.TestLockManager.ReestablishLockContext;
033: import com.tc.objectserver.lockmanager.api.TestLockManager.WaitCallContext;
034: import com.tc.objectserver.tx.TestServerTransactionManager;
035: import com.tc.test.TCTestCase;
036: import com.tc.util.SequenceID;
037: import com.tc.util.SequenceValidator;
038: import com.tc.util.TestTimer;
039: import com.tc.util.TestTimer.ScheduleCallContext;
040: import com.tc.util.concurrent.NoExceptionLinkedQueue;
041: import com.tc.util.sequence.ObjectIDSequenceProvider;
042:
043: import java.util.ArrayList;
044: import java.util.Collection;
045: import java.util.HashMap;
046: import java.util.HashSet;
047: import java.util.Iterator;
048: import java.util.LinkedList;
049: import java.util.List;
050: import java.util.Map;
051: import java.util.Set;
052:
053: public class ServerClientHandshakeManagerTest extends TCTestCase {
054:
055: private ServerClientHandshakeManager hm;
056: private TestClientStateManager clientStateManager;
057: private TestLockManager lockManager;
058: private TestSink lockResponseSink;
059: private long reconnectTimeout;
060: private Set existingUnconnectedClients;
061: private TestTimer timer;
062: private TestChannelManager channelManager;
063: private SequenceValidator sequenceValidator;
064: private long objectIDSequenceStart;
065:
066: public void setUp() {
067: existingUnconnectedClients = new HashSet();
068: clientStateManager = new TestClientStateManager();
069: lockManager = new TestLockManager();
070: lockResponseSink = new TestSink();
071: reconnectTimeout = 10 * 1000;
072: timer = new TestTimer();
073: channelManager = new TestChannelManager();
074: sequenceValidator = new SequenceValidator(0);
075: objectIDSequenceStart = 1000;
076: }
077:
078: private void initHandshakeManager() {
079: TCLogger logger = TCLogging
080: .getLogger(ServerClientHandshakeManager.class);
081: this .hm = new ServerClientHandshakeManager(logger,
082: channelManager, new TestServerTransactionManager(),
083: sequenceValidator, clientStateManager, lockManager,
084: lockResponseSink, new ObjectIDSequenceProvider(
085: objectIDSequenceStart), timer,
086: reconnectTimeout, false, logger);
087: this .hm
088: .setStarting(convertToConnectionIds(existingUnconnectedClients));
089: }
090:
091: private Set convertToConnectionIds(Set s) {
092: HashSet ns = new HashSet();
093: for (Iterator i = s.iterator(); i.hasNext();) {
094: ClientID cid = (ClientID) i.next();
095: ns.add(new ConnectionID(cid.getChannelID().toLong(),
096: "FORTESTING"));
097: }
098: return ns;
099: }
100:
101: public void testNoUnconnectedClients() throws Exception {
102: initHandshakeManager();
103: assertStarted();
104: }
105:
106: public void testTimeout() throws Exception {
107: ClientID clientID = new ClientID(new ChannelID(100));
108:
109: existingUnconnectedClients.add(clientID);
110: existingUnconnectedClients
111: .add(new ClientID(new ChannelID(101)));
112:
113: initHandshakeManager();
114:
115: TestClientHandshakeMessage handshake = newClientHandshakeMessage(clientID);
116: hm.notifyClientConnect(handshake);
117:
118: // make sure connecting a client schedules the timer
119: assertEquals(1, timer.scheduleCalls.size());
120: TestTimer.ScheduleCallContext scc = (ScheduleCallContext) timer.scheduleCalls
121: .get(0);
122:
123: // make sure executing the timer task calls cancel on the timer and calls
124: // notifyTimeout() on the handshake manager.
125: assertTrue(timer.cancelCalls.isEmpty());
126: scc.task.run();
127: assertEquals(1, timer.cancelCalls.size());
128: assertEquals(1, channelManager.closeAllChannelIDs.size());
129: assertEquals(new ClientID(new ChannelID(101)),
130: channelManager.closeAllChannelIDs.get(0));
131:
132: // make sure everything is started properly
133: assertStarted();
134: }
135:
136: public void testNotifyTimeout() throws Exception {
137: ClientID channelID1 = new ClientID(new ChannelID(1));
138: ClientID channelID2 = new ClientID(new ChannelID(2));
139:
140: existingUnconnectedClients.add(channelID1);
141: existingUnconnectedClients.add(channelID2);
142:
143: initHandshakeManager();
144:
145: assertFalse(hm.isStarted());
146:
147: // make sure that calling notify timeout causes the remaining unconnected
148: // clients to be closed.
149: hm.notifyTimeout();
150: assertEquals(2, channelManager.closeAllChannelIDs.size());
151: assertEquals(existingUnconnectedClients, new HashSet(
152: channelManager.closeAllChannelIDs));
153: assertStarted();
154: }
155:
156: public void testBasic() throws Exception {
157: final Set connectedClients = new HashSet();
158: ClientID clientID1 = new ClientID(new ChannelID(100));
159: ClientID clientID2 = new ClientID(new ChannelID(101));
160: ClientID clientID3 = new ClientID(new ChannelID(102));
161:
162: // channelManager.channelIDs.add(channelID1);
163: // channelManager.channelIDs.add(channelID2);
164: // channelManager.channelIDs.add(channelID3);
165:
166: existingUnconnectedClients.add(clientID1);
167: existingUnconnectedClients.add(clientID2);
168:
169: initHandshakeManager();
170:
171: TestClientHandshakeMessage handshake = newClientHandshakeMessage(clientID1);
172: ArrayList sequenceIDs = new ArrayList();
173: SequenceID minSequenceID = new SequenceID(10);
174: sequenceIDs.add(minSequenceID);
175: handshake.transactionSequenceIDs = sequenceIDs;
176: handshake.clientObjectIds.add(new ObjectID(200));
177: handshake.clientObjectIds.add(new ObjectID(20002));
178:
179: List lockContexts = new LinkedList();
180:
181: lockContexts.add(new LockContext(new LockID("my lock"),
182: clientID1, new ThreadID(10001), LockLevel.WRITE));
183: lockContexts.add(new LockContext(new LockID("my other lock)"),
184: clientID1, new ThreadID(10002), LockLevel.READ));
185: handshake.lockContexts.addAll(lockContexts);
186:
187: WaitContext waitContext = new WaitContext(
188: new LockID("d;alkjd"), clientID1, new ThreadID(101),
189: LockLevel.WRITE, new WaitInvocation());
190: handshake.waitContexts.add(waitContext);
191: handshake.isChangeListener = true;
192:
193: assertFalse(sequenceValidator.isNext(handshake.getClientID(),
194: new SequenceID(minSequenceID.toLong())));
195: assertEquals(2, existingUnconnectedClients.size());
196: assertFalse(hm.isStarted());
197: assertTrue(hm.isStarting());
198:
199: // reset sequence validator
200: sequenceValidator.remove(handshake.getClientID());
201:
202: // connect the first client
203: channelManager.clientIDs.add(handshake.clientID);
204: hm.notifyClientConnect(handshake);
205: connectedClients.add(handshake);
206:
207: // make sure no state change happened.
208: assertTrue(hm.isStarting());
209: assertFalse(hm.isStarted());
210:
211: // make sure the timer task was scheduled properly
212: assertEquals(1, timer.scheduleCalls.size());
213: TestTimer.ScheduleCallContext scc = (ScheduleCallContext) timer.scheduleCalls
214: .get(0);
215: assertEquals(new Long(reconnectTimeout), scc.delay);
216: assertTrue(scc.period == null);
217: assertTrue(scc.time == null);
218:
219: // make sure the transaction sequence was set
220: assertTrue(sequenceValidator.isNext(handshake.getClientID(),
221: new SequenceID(minSequenceID.toLong())));
222:
223: // make sure all of the object references from that client were added to the
224: // client state manager.
225: assertTrue(handshake.clientObjectIds.size() > 0);
226: assertEquals(handshake.clientObjectIds.size(),
227: clientStateManager.addReferenceCalls.size());
228: for (Iterator i = clientStateManager.addReferenceCalls
229: .iterator(); i.hasNext();) {
230: TestClientStateManager.AddReferenceContext ctxt = (AddReferenceContext) i
231: .next();
232: assertTrue(handshake.clientObjectIds.remove(ctxt.objectID));
233: }
234: assertTrue(handshake.clientObjectIds.isEmpty());
235:
236: // make sure outstanding locks are reestablished
237: assertEquals(lockContexts.size(), handshake.lockContexts.size());
238: assertEquals(handshake.lockContexts.size(),
239: lockManager.reestablishLockCalls.size());
240: for (int i = 0; i < lockContexts.size(); i++) {
241: LockContext lockContext = (LockContext) lockContexts.get(i);
242: TestLockManager.ReestablishLockContext ctxt = (ReestablishLockContext) lockManager.reestablishLockCalls
243: .get(i);
244: assertEquals(lockContext.getLockID(), ctxt.lockContext
245: .getLockID());
246: assertEquals(lockContext.getNodeID(), ctxt.lockContext
247: .getNodeID());
248: assertEquals(lockContext.getThreadID(), ctxt.lockContext
249: .getThreadID());
250: assertEquals(lockContext.getLockLevel(), ctxt.lockContext
251: .getLockLevel());
252: }
253:
254: // make sure the wait contexts are reestablished.
255: assertEquals(1, handshake.waitContexts.size());
256: assertEquals(handshake.waitContexts.size(),
257: lockManager.reestablishWaitCalls.size());
258: TestLockManager.WaitCallContext ctxt = (WaitCallContext) lockManager.reestablishWaitCalls
259: .get(0);
260: assertEquals(waitContext.getLockID(), ctxt.lockID);
261: assertEquals(waitContext.getNodeID(), ctxt.nid);
262: assertEquals(waitContext.getThreadID(), ctxt.threadID);
263: assertEquals(waitContext.getWaitInvocation(),
264: ctxt.waitInvocation);
265: assertSame(lockResponseSink, ctxt.lockResponseSink);
266:
267: assertEquals(0, timer.cancelCalls.size());
268:
269: // make sure no ack messages have been sent, since we're not started yet.
270: assertEquals(0, channelManager.handshakeMessages.size());
271:
272: // connect the last outstanding client.
273: handshake = newClientHandshakeMessage(clientID2);
274: channelManager.clientIDs.add(handshake.clientID);
275: hm.notifyClientConnect(handshake);
276: connectedClients.add(handshake);
277:
278: assertStarted();
279:
280: // make sure it cancels the timeout timer.
281: assertEquals(1, timer.cancelCalls.size());
282:
283: // now that the server has started, connect a new client
284: handshake = newClientHandshakeMessage(clientID3);
285: channelManager.clientIDs.add(handshake.clientID);
286: hm.notifyClientConnect(handshake);
287: connectedClients.add(handshake);
288:
289: // make sure that ack messages were sent for all incoming handshake messages.
290: for (Iterator i = connectedClients.iterator(); i.hasNext();) {
291: handshake = (TestClientHandshakeMessage) i.next();
292: Collection acks = channelManager
293: .getMessages(handshake.clientID);
294: assertEquals("Wrong number of acks for channel: "
295: + handshake.clientID, 1, acks.size());
296: TestClientHandshakeAckMessage ack = (TestClientHandshakeAckMessage) new ArrayList(
297: acks).get(0);
298: assertNotNull(ack.sendQueue.poll(1));
299: }
300: }
301:
302: public void testObjectIDsInHandshake() throws Exception {
303: final Set connectedClients = new HashSet();
304: ClientID clientID1 = new ClientID(new ChannelID(100));
305: ClientID clientID2 = new ClientID(new ChannelID(101));
306: ClientID clientID3 = new ClientID(new ChannelID(102));
307:
308: existingUnconnectedClients.add(clientID1);
309: existingUnconnectedClients.add(clientID2);
310:
311: initHandshakeManager();
312:
313: TestClientHandshakeMessage handshake = newClientHandshakeMessage(clientID1);
314: handshake.setIsObjectIDsRequested(true);
315:
316: hm.notifyClientConnect(handshake);
317: channelManager.clientIDs.add(handshake.clientID);
318: connectedClients.add(handshake);
319:
320: // make sure no ack messages have been sent, since we're not started yet.
321: assertEquals(0, channelManager.handshakeMessages.size());
322:
323: // connect the last outstanding client.
324: handshake = newClientHandshakeMessage(clientID2);
325: handshake.setIsObjectIDsRequested(false);
326: channelManager.clientIDs.add(handshake.clientID);
327: hm.notifyClientConnect(handshake);
328: connectedClients.add(handshake);
329:
330: assertStarted();
331:
332: // now that the server has started, connect a new client
333: handshake = newClientHandshakeMessage(clientID3);
334: handshake.setIsObjectIDsRequested(true);
335: channelManager.clientIDs.add(handshake.clientID);
336: hm.notifyClientConnect(handshake);
337: connectedClients.add(handshake);
338:
339: // make sure that ack messages were sent for all incoming handshake messages.
340: for (Iterator i = connectedClients.iterator(); i.hasNext();) {
341: handshake = (TestClientHandshakeMessage) i.next();
342: Collection acks = channelManager
343: .getMessages(handshake.clientID);
344: assertEquals("Wrong number of acks for channel: "
345: + handshake.clientID, 1, acks.size());
346: TestClientHandshakeAckMessage ack = (TestClientHandshakeAckMessage) new ArrayList(
347: acks).get(0);
348: assertNotNull(ack.sendQueue.poll(1));
349:
350: if (ack.clientID.equals(clientID2)) {
351: assertTrue(ack.getObjectIDSequenceStart() == 0);
352: assertTrue(ack.getObjectIDSequenceEnd() == 0);
353: } else {
354: assertFalse(ack.getObjectIDSequenceStart() == 0);
355: assertFalse(ack.getObjectIDSequenceEnd() == 0);
356: assertTrue(ack.getObjectIDSequenceStart() < ack
357: .getObjectIDSequenceEnd());
358: }
359: }
360: }
361:
362: private void assertStarted() {
363: // make sure the lock manager got started
364: assertEquals(1, lockManager.startCalls.size());
365:
366: // make sure the state change happens properly
367: assertTrue(hm.isStarted());
368: }
369:
370: private TestClientHandshakeMessage newClientHandshakeMessage(
371: ClientID clientID) {
372: TestClientHandshakeMessage handshake = new TestClientHandshakeMessage();
373: handshake.clientID = clientID;
374: ArrayList sequenceIDs = new ArrayList();
375: sequenceIDs.add(new SequenceID(1));
376: handshake.setTransactionSequenceIDs(sequenceIDs);
377: return handshake;
378: }
379:
380: private static final class TestChannelManager implements
381: DSOChannelManager {
382:
383: public final List closeAllChannelIDs = new ArrayList();
384: public final Map handshakeMessages = new HashMap();
385: public final Set clientIDs = new HashSet();
386: private String serverVersion = "N/A";
387:
388: public void closeAll(Collection theChannelIDs) {
389: closeAllChannelIDs.addAll(theChannelIDs);
390: }
391:
392: public MessageChannel getActiveChannel(NodeID id) {
393: return null;
394: }
395:
396: public MessageChannel[] getActiveChannels() {
397: return null;
398: }
399:
400: public Set getAllActiveClientIDs() {
401: return this .clientIDs;
402: }
403:
404: public boolean isValidID(ChannelID channelID) {
405: return false;
406: }
407:
408: public String getChannelAddress(NodeID nid) {
409: return null;
410: }
411:
412: public Collection getMessages(ClientID clientID) {
413: Collection msgs = (Collection) this .handshakeMessages
414: .get(clientID);
415: if (msgs == null) {
416: msgs = new ArrayList();
417: this .handshakeMessages.put(clientID, msgs);
418: }
419: return msgs;
420: }
421:
422: private ClientHandshakeAckMessage newClientHandshakeAckMessage(
423: ClientID clientID) {
424: ClientHandshakeAckMessage msg = new TestClientHandshakeAckMessage(
425: clientID);
426: getMessages(clientID).add(msg);
427: return msg;
428: }
429:
430: public BatchTransactionAcknowledgeMessage newBatchTransactionAcknowledgeMessage(
431: NodeID nid) {
432: throw new ImplementMe();
433: }
434:
435: public void addEventListener(
436: DSOChannelManagerEventListener listener) {
437: throw new ImplementMe();
438: }
439:
440: public Set getAllClientIDs() {
441: return getAllActiveClientIDs();
442: }
443:
444: public boolean isActiveID(NodeID nodeID) {
445: throw new ImplementMe();
446: }
447:
448: public void makeChannelActive(ClientID clientID, long startIDs,
449: long endIDs, boolean persistent) {
450: ClientHandshakeAckMessage ackMsg = newClientHandshakeAckMessage(clientID);
451: ackMsg.initialize(startIDs, endIDs, persistent,
452: getAllClientIDsString(), clientID.toString(),
453: serverVersion);
454: ackMsg.send();
455: }
456:
457: private Set getAllClientIDsString() {
458: Set s = new HashSet();
459: for (Iterator i = getAllClientIDs().iterator(); i.hasNext();) {
460: ClientID cid = (ClientID) i.next();
461: s.add(cid.toString());
462: }
463: return s;
464: }
465:
466: public void makeChannelActiveNoAck(MessageChannel channel) {
467: //
468: }
469:
470: public ClientID getClientIDFor(ChannelID channelID) {
471: return new ClientID(channelID);
472: }
473:
474: }
475:
476: private static class TestClientHandshakeAckMessage implements
477: ClientHandshakeAckMessage {
478: public final NoExceptionLinkedQueue sendQueue = new NoExceptionLinkedQueue();
479: public final ClientID clientID;
480: public long start;
481: public long end;
482: private boolean persistent;
483: private final TestMessageChannel channel;
484: private String serverVersion;
485:
486: private TestClientHandshakeAckMessage(ClientID clientID) {
487: this .clientID = clientID;
488: this .channel = new TestMessageChannel();
489: this .channel.channelID = clientID.getChannelID();
490: }
491:
492: public void send() {
493: sendQueue.put(new Object());
494: }
495:
496: public long getObjectIDSequenceStart() {
497: return start;
498: }
499:
500: public long getObjectIDSequenceEnd() {
501: return end;
502: }
503:
504: public boolean getPersistentServer() {
505: return persistent;
506: }
507:
508: public void initialize(long startOid, long endOid,
509: boolean isPersistent, Set allNodes, String this NodeID,
510: String sv) {
511: this .start = startOid;
512: this .end = endOid;
513: this .persistent = isPersistent;
514: this .serverVersion = sv;
515: }
516:
517: public MessageChannel getChannel() {
518: return channel;
519: }
520:
521: public String[] getAllNodes() {
522: throw new ImplementMe();
523: }
524:
525: public String getThisNodeId() {
526: throw new ImplementMe();
527: }
528:
529: public String getServerVersion() {
530: return serverVersion;
531: }
532:
533: }
534:
535: }
|