001: /*
002: * All content copyright (c) 2003-2006 Terracotta, Inc., except as may otherwise be noted in a separate copyright notice. All rights reserved.
003: */
004: package com.tc.net.protocol.transport;
005:
006: import com.tc.bytes.TCByteBuffer;
007: import com.tc.net.protocol.AbstractTCNetworkHeader;
008: import com.tc.net.protocol.TCNetworkMessage;
009: import com.tc.net.protocol.delivery.OOOProtocolMessage;
010: import com.tc.net.protocol.tcm.TCMessage;
011: import com.tc.util.Conversion;
012:
013: /**
014: * This class models the header portion of a TC wire protocol message. NOTE: This class makes no attempt to be thread
 * safe! All concurrent access must be synchronized externally.
016: *
017: * <pre>
018: * 0 1 2 3
019: * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
020: * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
021: * |Version| HL |Type of Service| Time to Live | Protocol |
022: * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
023: * | Magic number |
024: * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
025: * | 32 Bit Total Length |
026: * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
 * | Adler32 Header Checksum |
028: * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
029: * | Source Address |
030: * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
031: * | Destination Address |
032: * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
033: * | Source Port | Destination Port |
034: * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
035: * | Options | Padding |
036: * +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
037: * </pre>
038: *
039: * @author teck
040: */
041:
042: public class WireProtocolHeader extends AbstractTCNetworkHeader {
043: public static final byte VERSION_1 = 1;
044: public static final byte[] VALID_VERSIONS = new byte[] { VERSION_1 };
045:
046: public static final short DEFAULT_TTL = 64;
047:
048: public static final short PROTOCOL_UNKNOWN = 0;
049: public static final short PROTOCOL_TCM = 1;
050: public static final short PROTOCOL_TRANSPORT_HANDSHAKE = 2;
051: public static final short PROTOCOL_OOOP = 3;
052:
053: private static final int MAGIC_NUM = 0xAAAAAAAA;
054:
055: public static final short[] VALID_PROTOCOLS = new short[] {
056: PROTOCOL_TCM, PROTOCOL_TRANSPORT_HANDSHAKE, PROTOCOL_OOOP };
057:
058: // 15 32-bit words max
059: static final short MAX_LENGTH = 15 * 4;
060:
061: // 7 32-bit words min
062: static final short MIN_LENGTH = 7 * 4;
063:
064: public static short getProtocolForMessageClass(TCNetworkMessage msg) {
065: // TODO: is there a better way to do this (ie. not using instanceof)?
066: if (msg instanceof TCMessage) {
067: return PROTOCOL_TCM;
068: } else if (msg instanceof OOOProtocolMessage) {
069: return PROTOCOL_OOOP;
070: }
071:
072: throw new AssertionError("Unknown protocol");
073: }
074:
075: public WireProtocolHeader() {
076: super (MIN_LENGTH, MAX_LENGTH);
077:
078: setMagicNum(MAGIC_NUM);
079: setVersion(VERSION_1);
080: setHeaderLength((byte) (MIN_LENGTH / 4));
081: setTimeToLive(DEFAULT_TTL);
082: setTypeOfService(TypeOfService.DEFAULT_TOS.getByteValue());
083: }
084:
085: public WireProtocolHeader(TCByteBuffer buffer) {
086: super (buffer, MIN_LENGTH, MAX_LENGTH);
087: }
088:
089: private void setMagicNum(int magic_num2) {
090: data.putInt(4, MAGIC_NUM);
091: }
092:
093: public void setVersion(byte version) {
094: if ((version <= 0) || (version > 15)) {
095: throw new IllegalArgumentException("invalid version: "
096: + version);
097: }
098:
099: set4BitValue(0, true, version);
100: }
101:
102: protected void setHeaderLength(short length) {
103: if ((length < 6) || (length > 15)) {
104: throw new IllegalArgumentException(
105: "Header length must in range 6-15");
106: }
107:
108: set4BitValue(0, false, (byte) length);
109: }
110:
111: public void setTypeOfService(short tos) {
112: data.putUbyte(1, tos);
113: }
114:
115: public void setTimeToLive(short ttl) {
116: data.putUbyte(2, ttl);
117: }
118:
119: public void setProtocol(short protocol) {
120: data.putUbyte(3, protocol);
121: }
122:
123: public void setTotalPacketLength(int length) {
124: data.putInt(8, length);
125: }
126:
127: public void setSourceAddress(byte[] srcAddr) {
128: data.put(16, srcAddr, 0, 4);
129: }
130:
131: public void setDestinationAddress(byte[] destAddr) {
132: data.put(20, destAddr, 0, 4);
133: }
134:
135: public void setSourcePort(int srcPort) {
136: data.putUshort(24, srcPort);
137: }
138:
139: public void setDestinationPort(int dstPort) {
140: data.putUshort(26, dstPort);
141: }
142:
143: public int getMagicNum() {
144: return data.getInt(4);
145: }
146:
147: public byte getVersion() {
148: return get4BitValue(0, true);
149: }
150:
151: public byte getHeaderLength() {
152: return get4BitValue(0, false);
153: }
154:
155: public short getTypeOfService() {
156: return data.getUbyte(1);
157: }
158:
159: public short getTimeToLive() {
160: return data.getUbyte(2);
161: }
162:
163: public short getProtocol() {
164: return data.getUbyte(3);
165: }
166:
167: public int getTotalPacketLength() {
168: return data.getInt(8);
169: }
170:
171: public long getChecksum() {
172: return data.getUint(12);
173: }
174:
175: public byte[] getSourceAddress() {
176: return getBytes(16, 4);
177: }
178:
179: public byte[] getDestinationAddress() {
180: return getBytes(20, 4);
181: }
182:
183: public int getSourcePort() {
184: return data.getUshort(24);
185: }
186:
187: public int getDestinationPort() {
188: return data.getUshort(26);
189: }
190:
191: public void computeChecksum() {
192: computeAdler32Checksum(12, true);
193: }
194:
195: public boolean isChecksumValid() {
196: return getChecksum() == computeAdler32Checksum(12, false);
197: }
198:
199: public void validate() throws WireProtocolHeaderFormatException {
200: // validate the magic num
201: int magic = getMagicNum();
202: if (magic != MAGIC_NUM) {
203: throw new WireProtocolHeaderFormatException(
204: "Invalid magic number: " + magic + " != "
205: + MAGIC_NUM);
206: }
207:
208: // validate the version byte
209: boolean validVersion = false;
210: byte version = getVersion();
211:
212: for (int i = 0; i < VALID_VERSIONS.length; i++) {
213: if (version == VALID_VERSIONS[i]) {
214: validVersion = true;
215: break;
216: }
217: }
218:
219: if (!validVersion) {
220: throw new WireProtocolHeaderFormatException("Bad Version: "
221: + Conversion.byte2uint(version));
222: }
223:
224: // TODO: validate the TOS byte
225:
226: // validate the TTL byte
227: int ttl = getTimeToLive();
228: if (0 == ttl) {
229: throw new WireProtocolHeaderFormatException(
230: "TTL byte cannot be equal to zero");
231: }
232:
233: // validate the protocol byte
234: boolean validProtocol = false;
235: short protocol = getProtocol();
236:
237: for (int i = 0; i < VALID_PROTOCOLS.length; i++) {
238: if (protocol == VALID_PROTOCOLS[i]) {
239: validProtocol = true;
240: break;
241: }
242: }
243:
244: if (!validProtocol) {
245: throw new WireProtocolHeaderFormatException(
246: "Bad Protocol byte: " + protocol);
247: }
248:
249: // validate the total packet length value
250: int totalLength = getTotalPacketLength();
251:
252: if (totalLength < MIN_LENGTH) {
253: throw new WireProtocolHeaderFormatException(
254: "Total length ("
255: + totalLength
256: + ") can not be less than minimum header size ("
257: + MIN_LENGTH + ")");
258: }
259:
260: if (totalLength < getHeaderByteLength()) {
261: throw new WireProtocolHeaderFormatException(
262: "Total length ("
263: + totalLength
264: + ") can not be less than actual header length ("
265: + getHeaderByteLength() + ")");
266: }
267:
268: // validate the checksum
269: if (!isChecksumValid()) {
270: throw new WireProtocolHeaderFormatException(
271: "Invalid Checksum");
272: }
273:
274: if (getSourcePort() == 0) {
275: throw new WireProtocolHeaderFormatException(
276: "Source port cannot be zero");
277: }
278:
279: if (getDestinationPort() == 0) {
280: throw new WireProtocolHeaderFormatException(
281: "Destination port cannot be zero");
282: }
283:
284: // if (Arrays.equals(getDestinationAddress(), FOUR_ZERO_BYTES)) { throw new WireProtocolHeaderFormatException(
285: // "Destination address cannot be 0.0.0.0"); }
286: //
287: // if (Arrays.equals(getSourceAddress(), FOUR_ZERO_BYTES)) { throw new WireProtocolHeaderFormatException(
288: // "Source address cannot be 0.0.0.0"); }
289:
290: // TODO: validate options (once they exist)
291: }
292:
293: public String toString() {
294: StringBuffer buf = new StringBuffer();
295: buf.append("Version: ").append(
296: Conversion.byte2uint(getVersion())).append(", ");
297: buf.append("Header Length: ").append(
298: Conversion.byte2uint(getHeaderLength())).append(", ");
299: buf.append("TOS: ").append(getTypeOfService()).append(", ");
300: buf.append("TTL: ").append(getTimeToLive()).append(", ");
301: buf.append("Protocol: ").append(getProtocolString());
302: buf.append("\n");
303: buf.append("Total Packet Length: ").append(
304: getTotalPacketLength()).append("\n");
305: buf.append("Adler32 Checksum: ").append(getChecksum()).append(
306: " (valid: ").append(isChecksumValid()).append(")\n");
307: buf.append("Source Addresss: ");
308:
309: byte src[] = getSourceAddress();
310: byte dest[] = getDestinationAddress();
311:
312: for (int i = 0; i < src.length; i++) {
313: buf.append(Conversion.byte2uint(src[i]));
314: if (i != (src.length - 1)) {
315: buf.append(".");
316: }
317: }
318: buf.append("\n");
319:
320: buf.append("Destination Addresss: ");
321: for (int i = 0; i < dest.length; i++) {
322: buf.append(Conversion.byte2uint(dest[i]));
323: if (i != (dest.length - 1)) {
324: buf.append(".");
325: }
326: }
327: buf.append("\n");
328:
329: buf.append("Source Port: ").append(getSourcePort());
330: buf.append(", Destination Port: ").append(getDestinationPort());
331: buf.append("\n");
332:
333: String errMsg = "no message";
334: boolean valid = true;
335: try {
336: validate();
337: } catch (WireProtocolHeaderFormatException e) {
338: errMsg = e.getMessage();
339: valid = false;
340: }
341: buf.append("Header Validity: ").append(valid).append(" (")
342: .append(errMsg).append(")\n");
343:
344: // TODO: display the options (if any)
345:
346: return buf.toString();
347: }
348:
349: private String getProtocolString() {
350: final short protocol = getProtocol();
351: switch (protocol) {
352: case PROTOCOL_TCM: {
353: return "TCM";
354: }
355: case PROTOCOL_OOOP: {
356: return "OOOP";
357: }
358: case PROTOCOL_TRANSPORT_HANDSHAKE: {
359: return "TRANSPORT HANDSHAKE";
360: }
361: default: {
362: return "UNKNOWN (" + protocol + ")";
363: }
364: }
365: }
366:
367: public int getMaxByteLength() {
368: return WireProtocolHeader.MAX_LENGTH;
369: }
370:
371: public int getMinByteLength() {
372: return WireProtocolHeader.MIN_LENGTH;
373: }
374:
375: public int getHeaderByteLength() {
376: return 4 * getHeaderLength();
377: }
378:
379: public boolean isTransportHandshakeMessage() {
380: return getProtocol() == PROTOCOL_TRANSPORT_HANDSHAKE;
381: }
382: }
|