001: package org.skunk.net;
002:
003: import java.io.*;
004: import java.net.*;
005: import java.text.*;
006: import java.util.*;
007: import java.security.KeyStore;
008: import javax.net.*;
009: import javax.net.ssl.*;
010: import java.security.Security;
011: import javax.security.cert.X509Certificate;
012: import com.sun.net.ssl.*;
013:
014: public class LoggingProxyServer implements Runnable {
015: private int port;
016: private String remoteHost;
017: private int remotePort;
018: private String logFileName;
019: private Thread myThread;
020: private boolean stopThread = false;
021: private ServerSocket myServerSocket;
022: private boolean secure;
023:
024: public LoggingProxyServer(int port, String remoteHost,
025: int remotePort, String logFileName, boolean secure) {
026: this .port = port;
027: this .remoteHost = remoteHost;
028: this .remotePort = remotePort;
029: this .logFileName = logFileName;
030: this .secure = secure;
031: myThread = new Thread(this );
032: myThread.setPriority(Thread.NORM_PRIORITY);
033: myThread.start();
034: }
035:
036: public void quit() {
037: stopThread = true;
038: }
039:
040: public void finalize() {
041: quit();
042: }
043:
044: private static ServerSocketFactory getSSLServerSocketFactory() {
045: Security
046: .addProvider(new com.sun.net.ssl.internal.ssl.Provider());
047: SSLServerSocketFactory ssf = null;
048: try {
049: // set up key manager to do server authentication
050: SSLContext ctx;
051: KeyManagerFactory kmf;
052: KeyStore ks;
053: char[] passphrase = "passphrase".toCharArray();
054:
055: ctx = SSLContext.getInstance("TLS");
056: kmf = KeyManagerFactory.getInstance("SunX509");
057: ks = KeyStore.getInstance("JKS");
058:
059: ks.load(LoggingProxyServer.class
060: .getResourceAsStream("testkeys"), passphrase);
061: kmf.init(ks, passphrase);
062: ctx.init(kmf.getKeyManagers(), null, null);
063:
064: ssf = ctx.getServerSocketFactory();
065: return ssf;
066: } catch (Exception e) {
067: e.printStackTrace();
068: }
069: return null;
070: }
071:
072: /**
073: *
074: */
075: private void addAnonymousCipherSuites(SSLServerSocket sss) {
076: String[] supported = sss.getSupportedCipherSuites();
077: ArrayList enabled = new ArrayList();
078: String[] oldEnabled = sss.getEnabledCipherSuites();
079: for (int i = 0; i > oldEnabled.length; i++)
080: enabled.add(oldEnabled[i]);
081: for (int i = 0; i < supported.length; i++) {
082: String candidate = supported[i];
083: if (candidate.indexOf("_anon_") > 0
084: && !enabled.contains(candidate)) {
085: enabled.add(candidate);
086: }
087: }
088: String[] enabledArray = new String[enabled.size()];
089: for (int i = 0; i < enabledArray.length; i++)
090: enabledArray[i] = enabled.get(i).toString();
091: sss.setEnabledCipherSuites(enabledArray);
092: }
093:
094: protected ServerSocket getServerSocket(int port) throws IOException {
095: if (secure) {
096: ServerSocket ss = getSSLServerSocketFactory()
097: .createServerSocket(port);
098: addAnonymousCipherSuites((SSLServerSocket) ss);
099: return ss;
100:
101: } else
102: return new ServerSocket(port);
103: }
104:
105: public void run() {
106: try {
107: myServerSocket = getServerSocket(port); //new ServerSocket(port);
108: while (true) {
109: //gracefully kill thread if flag is set
110: if (stopThread) {
111: try {
112: myServerSocket.close();
113: } finally {
114: return;
115: }
116: }
117: try {
118: //get a request on listening socket
119: Socket clientToProxySocket = myServerSocket
120: .accept();
121: final InputStream in = clientToProxySocket
122: .getInputStream();
123: final OutputStream out = clientToProxySocket
124: .getOutputStream();
125: Socket proxyToRemoteSocket = new Socket(remoteHost,
126: remotePort);
127: //is this a reasonable timeout or not?
128: //proxyToRemoteSocket.setSoTimeout(150000);
129: final InputStream remoteIn = proxyToRemoteSocket
130: .getInputStream();
131: final OutputStream remoteOut = proxyToRemoteSocket
132: .getOutputStream();
133: //copy stream unchanged to log and remote host
134: /*
135: note that this logs to the same file in two threads. They can thus end up mixing their output.
136: for our test purposes this is okay -- the only interpolated messages seem to be some TCP codes
137: */
138: Thread requestThread = new Thread() {
139: public void run() {
140: try {
141: copyStream(in, new LoggingOutputStream(
142: logFileName, remoteOut));
143: } catch (IOException ohGoodness) {
144: ohGoodness.printStackTrace(System.err);
145: }
146: }
147: };
148:
149: requestThread.start();
150:
151: Thread responseThread = new Thread() {
152: public void run() {
153: try {
154: copyStream(remoteIn,
155: new LoggingOutputStream(
156: logFileName, out));
157: } catch (IOException notAgain) {
158: notAgain.printStackTrace(System.err);
159: }
160: }
161: };
162:
163: responseThread.start();
164: } catch (IOException oyster) {
165: System.err.println("error in main thread loop:");
166: oyster.printStackTrace();
167: stopThread = true;
168: continue;
169: }
170: }
171: } catch (IOException oyVeh) {
172: oyVeh.printStackTrace();
173: }
174:
175: }
176:
177: private void copyStream(InputStream in, OutputStream out)
178: throws IOException {
179:
180: byte[] buffer = new byte[1024];
181: while (true) {
182: int bytesRead = in.read(buffer);
183: if (bytesRead == -1)
184: break;
185: out.write(buffer, 0, bytesRead);
186: }
187: out.flush();
188: }
189:
190: public static void main(String[] args) {
191: if (args.length != 4 && args.length != 5) {
192: usage();
193: }
194: try {
195: int port = Integer.parseInt(args[0]);
196: String remoteHost = args[1];
197: int remotePort = Integer.parseInt(args[2]);
198: String logFile = args[3];
199: boolean secure = args.length == 5 && args[4].equals("-s");
200: LoggingProxyServer lps = new LoggingProxyServer(port,
201: remoteHost, remotePort, logFile, secure);
202: System.out.println("proxy server for " + remoteHost + ":"
203: + remotePort + " running on port " + port
204: + " logging to " + logFile);
205: if (secure)
206: System.out.println("server socket is SSL");
207: } catch (Exception e) {
208: e.printStackTrace(System.err);
209: }
210: }
211:
212: public static void usage() {
213: System.out
214: .println("Usage: java org.skunk.net.LoggingProxyServer port remoteHost remotePort logFile [-s]");
215: System.out
216: .println("\tuse of the -s option will cause a secure server socket to be created.");
217: System.exit(1);
218: }
219: }
220:
/**
 * An output stream that forwards every byte to a wrapped destination stream
 * and also writes the same bytes to a separate log stream.
 */
class LoggingOutputStream extends FilterOutputStream {
    private final OutputStream logStream;

    /**
     * Duplicates {@code out} into the file {@code fileName}, which is opened in
     * append mode and closed by {@link #close()}.
     *
     * @throws IOException if the log file cannot be opened
     */
    public LoggingOutputStream(String fileName, OutputStream out)
            throws IOException {
        this(new FileOutputStream(fileName, true), out);
    }

    /**
     * Duplicates {@code out} into {@code logStream}. Both streams are closed
     * by {@link #close()}.
     */
    public LoggingOutputStream(OutputStream logStream, OutputStream out) {
        super(out);
        this.logStream = logStream;
    }

    /** Writes one byte to the destination, then to the log. */
    public void write(int b) throws IOException {
        out.write(b);
        logStream.write(b);
    }

    /** Writes all of {@code b} to the destination, then to the log. */
    public void write(byte[] b) throws IOException {
        out.write(b);
        logStream.write(b);
    }

    /** Writes {@code len} bytes of {@code b} from {@code off} to the destination, then to the log. */
    public void write(byte[] b, int off, int len) throws IOException {
        out.write(b, off, len);
        logStream.write(b, off, len);
    }

    /** Flushes the destination and then the log; the log is flushed even if the first flush fails. */
    public void flush() throws IOException {
        try {
            out.flush();
        } finally {
            logStream.flush();
        }
    }

    /**
     * Closes the destination and then the log. The log is closed even if
     * closing the destination throws, so its file handle cannot leak.
     */
    public void close() throws IOException {
        try {
            out.close();
        } finally {
            logStream.close();
        }
    }
}
|