001: /*
002: * Copyright 1999-2004 The Apache Software Foundation
003: *
004: * Licensed under the Apache License, Version 2.0 (the "License");
005: * you may not use this file except in compliance with the License.
006: * You may obtain a copy of the License at
007: *
008: * http://www.apache.org/licenses/LICENSE-2.0
009: *
010: * Unless required by applicable law or agreed to in writing, software
011: * distributed under the License is distributed on an "AS IS" BASIS,
012: * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013: * See the License for the specific language governing permissions and
014: * limitations under the License.
015: */
016:
017: package org.apache.tomcat.util.net.puretls;
018:
019: import java.io.IOException;
020: import java.net.InetAddress;
021: import java.net.ServerSocket;
022: import java.net.Socket;
023: import java.net.SocketException;
024: import java.util.Vector;
025:
026: import COM.claymoresystems.ptls.SSLContext;
027: import COM.claymoresystems.ptls.SSLException;
028: import COM.claymoresystems.ptls.SSLServerSocket;
029: import COM.claymoresystems.ptls.SSLSocket;
030: import COM.claymoresystems.sslg.SSLPolicyInt;
031:
032: /**
033: * SSL server socket factory--wraps PureTLS
034: *
035: * @author Eric Rescorla
036: *
037: * some sections of this file cribbed from SSLSocketFactory
038: * (the JSSE socket factory)
039: *
040: */
041:
042: public class PureTLSSocketFactory extends
043: org.apache.tomcat.util.net.ServerSocketFactory {
044: static org.apache.commons.logging.Log logger = org.apache.commons.logging.LogFactory
045: .getLog(PureTLSSocketFactory.class);
046: static String defaultProtocol = "TLS";
047: static boolean defaultClientAuth = false;
048: static String defaultKeyStoreFile = "server.pem";
049: static String defaultKeyPass = "password";
050: static String defaultRootFile = "root.pem";
051: static String defaultRandomFile = "random.pem";
052:
053: private COM.claymoresystems.ptls.SSLContext context = null;
054:
055: public PureTLSSocketFactory() {
056: }
057:
058: public ServerSocket createSocket(int port) throws IOException {
059: init();
060: return new SSLServerSocket(context, port);
061: }
062:
063: public ServerSocket createSocket(int port, int backlog)
064: throws IOException {
065: init();
066: ServerSocket tmp;
067:
068: try {
069: tmp = new SSLServerSocket(context, port, backlog);
070: } catch (IOException e) {
071: throw e;
072: }
073: return tmp;
074: }
075:
076: public ServerSocket createSocket(int port, int backlog,
077: InetAddress ifAddress) throws IOException {
078: init();
079: return new SSLServerSocket(context, port, backlog, ifAddress);
080: }
081:
082: private void init() throws IOException {
083: if (context != null)
084: return;
085:
086: boolean clientAuth = defaultClientAuth;
087:
088: try {
089: String keyStoreFile = (String) attributes.get("keystore");
090: if (keyStoreFile == null)
091: keyStoreFile = defaultKeyStoreFile;
092:
093: String keyPass = (String) attributes.get("keypass");
094: if (keyPass == null)
095: keyPass = defaultKeyPass;
096:
097: String rootFile = (String) attributes.get("rootfile");
098: if (rootFile == null)
099: rootFile = defaultRootFile;
100:
101: String randomFile = (String) attributes.get("randomfile");
102: if (randomFile == null)
103: randomFile = defaultRandomFile;
104:
105: String protocol = (String) attributes.get("protocol");
106: if (protocol == null)
107: protocol = defaultProtocol;
108:
109: String clientAuthStr = (String) attributes
110: .get("clientauth");
111: if (clientAuthStr != null) {
112: if (clientAuthStr.equals("true")) {
113: clientAuth = true;
114: } else if (clientAuthStr.equals("false")) {
115: clientAuth = false;
116: } else {
117: throw new IOException("Invalid value '"
118: + clientAuthStr
119: + "' for 'clientauth' parameter:");
120: }
121: }
122:
123: SSLContext tmpContext = new SSLContext();
124: try {
125: tmpContext.loadRootCertificates(rootFile);
126: } catch (IOException iex) {
127: if (logger.isDebugEnabled())
128: logger.debug("Error loading Client Root Store: "
129: + rootFile, iex);
130: }
131: tmpContext.loadEAYKeyFile(keyStoreFile, keyPass);
132: tmpContext.useRandomnessFile(randomFile, keyPass);
133:
134: SSLPolicyInt policy = new SSLPolicyInt();
135: policy.requireClientAuth(clientAuth);
136: policy.handshakeOnConnect(false);
137: policy.waitOnClose(false);
138: short[] enabledCiphers = getEnabledCiphers(policy
139: .getCipherSuites());
140: if (enabledCiphers != null) {
141: policy.setCipherSuites(enabledCiphers);
142: }
143: tmpContext.setPolicy(policy);
144: context = tmpContext;
145: } catch (Exception e) {
146: logger.info("Error initializing SocketFactory", e);
147: throw new IOException(e.getMessage());
148: }
149: }
150:
151: /*
152: * Determines the SSL cipher suites to be enabled.
153: *
154: * @return Array of SSL cipher suites to be enabled, or null if the
155: * cipherSuites property was not specified (meaning that all supported
156: * cipher suites are to be enabled)
157: */
158: private short[] getEnabledCiphers(short[] supportedCiphers) {
159:
160: short[] enabledCiphers = null;
161:
162: String attrValue = (String) attributes.get("ciphers");
163: if (attrValue != null) {
164: Vector vec = null;
165: int fromIndex = 0;
166: int index = attrValue.indexOf(',', fromIndex);
167: while (index != -1) {
168: String cipher = attrValue.substring(fromIndex, index)
169: .trim();
170: int cipherValue = SSLPolicyInt
171: .getCipherSuiteNumber(cipher);
172: /*
173: * Check to see if the requested cipher is among the supported
174: * ciphers, i.e., may be enabled
175: */
176: if (cipherValue >= 0) {
177: for (int i = 0; supportedCiphers != null
178: && i < supportedCiphers.length; i++) {
179:
180: if (cipherValue == supportedCiphers[i]) {
181: if (vec == null) {
182: vec = new Vector();
183: }
184: vec.addElement(new Integer(cipherValue));
185: break;
186: }
187: }
188: }
189: fromIndex = index + 1;
190: index = attrValue.indexOf(',', fromIndex);
191: }
192:
193: if (vec != null) {
194: int nCipher = vec.size();
195: enabledCiphers = new short[nCipher];
196: for (int i = 0; i < nCipher; i++) {
197: Integer value = (Integer) vec.elementAt(i);
198: enabledCiphers[i] = value.shortValue();
199: }
200: }
201: }
202:
203: return enabledCiphers;
204:
205: }
206:
207: public Socket acceptSocket(ServerSocket socket) throws IOException {
208: try {
209: Socket sock = socket.accept();
210: return sock;
211: } catch (SSLException e) {
212: logger.debug("SSL handshake error", e);
213: throw new SocketException("SSL handshake error"
214: + e.toString());
215: }
216: }
217:
218: public void handshake(Socket sock) throws IOException {
219: ((SSLSocket) sock).handshake();
220: }
221: }
|