001: /*
002: * Name: AuthSSLProtocolSocketFactory.java
003: * Version: $Id: AuthSSLProtocolSocketFactory.java 11799 2008-02-01 04:47:53Z mpreston $
004: * Copyright: Copyright (c) 2006, Bostech Corporation
005: * Author: mpreston
006: * Created on: Apr 20, 2006
007: * Description: Class to provide a mechanism to provide certificates for server or client authentication
008: */
009: package com.bostechcorp.cbesb.runtime.component.http.client;
010:
import java.io.FileInputStream;
import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.Socket;
import java.net.SocketTimeoutException;
import java.net.UnknownHostException;
import java.security.KeyStore;
import java.security.cert.Certificate;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.List;
import java.util.Vector;

import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLSocket;
import javax.net.ssl.SSLSocketFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509TrustManager;

import org.apache.commons.httpclient.ConnectTimeoutException;
import org.apache.commons.httpclient.params.HttpConnectionParams;
import org.apache.commons.httpclient.protocol.SecureProtocolSocketFactory;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import com.bostechcorp.cbesb.common.util.EsbPathHelper;
import com.bostechcorp.cbesb.runtime.component.http.HttpEndpoint;
037:
038: public class AuthSSLProtocolSocketFactory implements
039: SecureProtocolSocketFactory {
040: /** Log object for this class. */
041: private static final Log logger = LogFactory
042: .getLog(AuthSSLProtocolSocketFactory.class);
043:
044: private HttpEndpoint endpoint;
045:
046: private SSLContext sslcontext = null;
047: private String[] anon_ciphers = null;
048: private String[] authenticated_ciphers = null;
049:
050: public AuthSSLProtocolSocketFactory(HttpEndpoint endpoint) {
051: this .endpoint = endpoint;
052: }
053:
054: private SSLContext createSSLContext() {
055: SSLContext ctx = null;
056: try {
057: ctx = SSLContext.getInstance(endpoint.getSslProtocol());
058:
059: // get the private key
060: KeyManager[] km = null;
061: if (endpoint.isUsePrivateKey()) {
062: KeyStore ks = KeyStore.getInstance("JKS");
063: KeyManagerFactory kmf = KeyManagerFactory
064: .getInstance("SunX509");
065: String keyStorePath = EsbPathHelper
066: .getFullPathForDef(endpoint.getKeyStoreFile());
067: if (keyStorePath.startsWith("file:")) {
068: keyStorePath = keyStorePath.substring(5);
069: }
070: logger.debug("Keystore Path = " + keyStorePath);
071: char[] keyStorePassword = endpoint
072: .getKeyStorePassword().toCharArray();
073: FileInputStream kstr = new FileInputStream(keyStorePath);
074: ks.load(kstr, keyStorePassword);
075: kmf.init(ks, keyStorePassword);
076: if (logger.isDebugEnabled()) {
077: Enumeration<String> aliases = ks.aliases();
078: while (aliases.hasMoreElements()) {
079: String alias = aliases.nextElement();
080: Certificate cert = ks.getCertificate(alias);
081: logger.debug("Keystore contains cert alias ["
082: + alias + "]" + cert.toString());
083: }
084: }
085:
086: km = kmf.getKeyManagers();
087: }
088: // get the trust store
089: TrustManager[] tm = null;
090: if ((endpoint.isAuthenticateServer() || endpoint
091: .isAuthenticateClient())) {
092: if (!endpoint.isUseDefaultTrustStore()) {
093: TrustManagerFactory tmf = TrustManagerFactory
094: .getInstance("SunX509");
095: KeyStore ks = KeyStore.getInstance("JKS");
096: String trustStorePath = EsbPathHelper
097: .getFullPathForDef(endpoint
098: .getTrustStoreFile());
099: if (trustStorePath.startsWith("file:")) {
100: trustStorePath = trustStorePath.substring(5);
101: }
102: logger.debug("TrustStore Path = " + trustStorePath);
103: FileInputStream tstr = new FileInputStream(
104: trustStorePath);
105: ks.load(tstr, null);
106: tmf.init(ks);
107: if (logger.isDebugEnabled()) {
108: Enumeration<String> aliases = ks.aliases();
109: while (aliases.hasMoreElements()) {
110: String alias = aliases.nextElement();
111: Certificate cert = ks.getCertificate(alias);
112: logger
113: .debug("Truststore contains cert alias ["
114: + alias
115: + "]"
116: + cert.toString());
117: }
118: }
119: tm = tmf.getTrustManagers();
120: }
121: } else {
122: // trust anyone for unauthenticated
123: TrustManager[] trustAllCerts = new TrustManager[] { new X509TrustManager() {
124:
125: public java.security.cert.X509Certificate[] getAcceptedIssuers() {
126: return null;
127: }
128:
129: public void checkClientTrusted(
130: java.security.cert.X509Certificate[] certs,
131: String authType) {
132: // No need to implement.
133: }
134:
135: public void checkServerTrusted(
136: java.security.cert.X509Certificate[] certs,
137: String authType) {
138: // No need to implement.
139: }
140: } };
141: tm = trustAllCerts;
142: }
143: // initialize the SSLContext
144: ctx.init(km, tm, null);
145: } catch (Exception e) {
146: logger.error("error getting SSL context ", e);
147: endpoint
148: .setState("Error", "error getting SSL context " + e);
149: ctx = null;
150: }
151: return ctx;
152: }
153:
154: private SSLContext getSSLContext() {
155: logger.debug("AuthSSLProtocolSocketFactory::getSSLContext");
156:
157: if (this .sslcontext == null) {
158: this .sslcontext = createSSLContext();
159: }
160: return this .sslcontext;
161: }
162:
163: public Socket createSocket(Socket socket, String host, int port,
164: boolean autoClose) throws IOException, UnknownHostException {
165: SSLSocket sslsocket = (SSLSocket) getSSLContext()
166: .getSocketFactory().createSocket(socket, host, port,
167: autoClose);
168: setCipherSuites(sslsocket);
169:
170: return socket;
171: }
172:
173: public Socket createSocket(String host, int port,
174: InetAddress clientHost, int clientPort) throws IOException,
175: UnknownHostException {
176: SSLSocket socket = (SSLSocket) getSSLContext()
177: .getSocketFactory().createSocket(host, port,
178: clientHost, clientPort);
179: setCipherSuites(socket);
180:
181: return socket;
182: }
183:
184: public Socket createSocket(String host, int port,
185: InetAddress localAddress, int localPort,
186: HttpConnectionParams params) throws IOException,
187: UnknownHostException, ConnectTimeoutException {
188: // if (params == null) {
189: // throw new IllegalArgumentException("Parameters may not be null");
190: // }
191: // int timeout = params.getConnectionTimeout();
192: // if (timeout == 0) {
193: return createSocket(host, port, localAddress, localPort);
194: // } else {
195: // To be eventually deprecated when migrated to Java 1.4 or above
196: // return ControllerThreadSocketFactory.createSocket(
197: // this, host, port, localAddress, localPort, timeout);
198: // }
199: }
200:
201: public Socket createSocket(String host, int port)
202: throws IOException, UnknownHostException {
203: SSLSocket socket = (SSLSocket) getSSLContext()
204: .getSocketFactory().createSocket(host, port);
205:
206: setCipherSuites(socket);
207:
208: return socket;
209: }
210:
211: private void setCipherSuites(SSLSocket socket) {
212: String proto = endpoint.getSslProtocol();
213: String[] protos = new String[1];
214: if (proto.equalsIgnoreCase("TLS")) {
215: protos[0] = "TLSv1";
216: } else if (proto.equalsIgnoreCase("SSL")) {
217: protos[0] = "SSLv2";
218: } else if (proto.equalsIgnoreCase("SSLv3")) {
219: protos[0] = "SSLv3";
220: } else {
221: protos = socket.getEnabledProtocols();
222: }
223: socket.setEnabledProtocols(protos);
224:
225: //First time through, build the cipher lists
226: if (anon_ciphers == null || authenticated_ciphers == null) {
227: String[] cipherSuites = socket.getSupportedCipherSuites();
228: Vector<String> authList = new Vector<String>();
229: Vector<String> anonList = new Vector<String>();
230: for (int i = 0; i < cipherSuites.length; i++) {
231: //Don't use the SSL_RSA_WITH_NULL_* ciphers
232: if (!cipherSuites[i].startsWith("SSL_RSA_WITH_NULL")) {
233: if (isAnonCipher(cipherSuites[i])) {
234: anonList.add(cipherSuites[i]);
235: } else {
236: authList.add(cipherSuites[i]);
237: }
238: }
239: }
240: anon_ciphers = new String[anonList.size()];
241: authenticated_ciphers = new String[authList.size()];
242: for (int i = 0; i < anonList.size(); i++) {
243: anon_ciphers[i] = (String) anonList.get(i);
244: }
245: for (int i = 0; i < authList.size(); i++) {
246: authenticated_ciphers[i] = (String) authList.get(i);
247: }
248: }
249: if (endpoint.isAllowAnonymous()) {
250: socket.setEnabledCipherSuites(anon_ciphers);
251: if (logger.isDebugEnabled()) {
252: logger.debug("Client socket set with Anon ciphers:");
253: for (int i = 0; i < anon_ciphers.length; i++) {
254: logger.debug(" " + anon_ciphers[i]);
255: }
256: }
257: } else {
258: socket.setEnabledCipherSuites(authenticated_ciphers);
259: if (logger.isDebugEnabled()) {
260: logger.debug("Client socket set with Auth ciphers:");
261: for (int i = 0; i < authenticated_ciphers.length; i++) {
262: logger.debug(" " + authenticated_ciphers[i]);
263: }
264: }
265: }
266: }
267:
268: private boolean isAnonCipher(String cipher) {
269: return (cipher.indexOf("_DH_anon") > -1 || cipher
270: .indexOf("_DHE_") > -1);
271: }
272:
273: }
|