001: // Copyright (c) 2005 Brian Wellington (bwelling@xbill.org)
002:
003: package org.xbill.DNS;
004:
005: import java.io.*;
006: import java.net.*;
007: import java.nio.*;
008: import java.nio.channels.*;
009:
010: final class TCPClient extends Client {
011:
012: public TCPClient(long endTime) throws IOException {
013: super (SocketChannel.open(), endTime);
014: }
015:
016: void bind(SocketAddress addr) throws IOException {
017: SocketChannel channel = (SocketChannel) key.channel();
018: channel.socket().bind(addr);
019: }
020:
021: void connect(SocketAddress addr) throws IOException {
022: SocketChannel channel = (SocketChannel) key.channel();
023: if (channel.connect(addr))
024: return;
025: key.interestOps(SelectionKey.OP_CONNECT);
026: try {
027: while (!channel.finishConnect()) {
028: if (!key.isConnectable())
029: blockUntil(key, endTime);
030: }
031: } finally {
032: if (key.isValid())
033: key.interestOps(0);
034: }
035: }
036:
037: void send(byte[] data) throws IOException {
038: SocketChannel channel = (SocketChannel) key.channel();
039: verboseLog("TCP write", data);
040: byte[] lengthArray = new byte[2];
041: lengthArray[0] = (byte) (data.length >>> 8);
042: lengthArray[1] = (byte) (data.length & 0xFF);
043: ByteBuffer[] buffers = new ByteBuffer[2];
044: buffers[0] = ByteBuffer.wrap(lengthArray);
045: buffers[1] = ByteBuffer.wrap(data);
046: int nsent = 0;
047: key.interestOps(SelectionKey.OP_WRITE);
048: try {
049: while (nsent < data.length + 2) {
050: if (key.isWritable()) {
051: long n = channel.write(buffers);
052: if (n < 0)
053: throw new EOFException();
054: nsent += (int) n;
055: if (nsent < data.length + 2
056: && System.currentTimeMillis() > endTime)
057: throw new SocketTimeoutException();
058: } else
059: blockUntil(key, endTime);
060: }
061: } finally {
062: if (key.isValid())
063: key.interestOps(0);
064: }
065: }
066:
067: private byte[] _recv(int length) throws IOException {
068: SocketChannel channel = (SocketChannel) key.channel();
069: int nrecvd = 0;
070: byte[] data = new byte[length];
071: ByteBuffer buffer = ByteBuffer.wrap(data);
072: key.interestOps(SelectionKey.OP_READ);
073: try {
074: while (nrecvd < length) {
075: if (key.isReadable()) {
076: long n = channel.read(buffer);
077: if (n < 0)
078: throw new EOFException();
079: nrecvd += (int) n;
080: if (nrecvd < length
081: && System.currentTimeMillis() > endTime)
082: throw new SocketTimeoutException();
083: } else
084: blockUntil(key, endTime);
085: }
086: } finally {
087: if (key.isValid())
088: key.interestOps(0);
089: }
090: return data;
091: }
092:
093: byte[] recv() throws IOException {
094: byte[] buf = _recv(2);
095: int length = ((buf[0] & 0xFF) << 8) + (buf[1] & 0xFF);
096: byte[] data = _recv(length);
097: verboseLog("TCP read", data);
098: return data;
099: }
100:
101: static byte[] sendrecv(SocketAddress local, SocketAddress remote,
102: byte[] data, long endTime) throws IOException {
103: TCPClient client = new TCPClient(endTime);
104: try {
105: if (local != null)
106: client.bind(local);
107: client.connect(remote);
108: client.send(data);
109: return client.recv();
110: } finally {
111: client.cleanup();
112: }
113: }
114:
115: static byte[] sendrecv(SocketAddress addr, byte[] data, long endTime)
116: throws IOException {
117: return sendrecv(null, addr, data, endTime);
118: }
119:
120: }
|