001: /*
002: * Licensed to the Apache Software Foundation (ASF) under one or more
003: * contributor license agreements. See the NOTICE file distributed with
004: * this work for additional information regarding copyright ownership.
005: * The ASF licenses this file to You under the Apache License, Version 2.0
006: * (the "License"); you may not use this file except in compliance with
007: * the License. You may obtain a copy of the License at
008: *
009: * http://www.apache.org/licenses/LICENSE-2.0
010: *
011: * Unless required by applicable law or agreed to in writing, software
012: * distributed under the License is distributed on an "AS IS" BASIS,
013: * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014: * See the License for the specific language governing permissions and
015: * limitations under the License.
016: */
017:
018: package org.apache.tomcat.util.net.puretls;
019:
020: import java.io.IOException;
021: import java.net.InetAddress;
022: import java.net.ServerSocket;
023: import java.net.Socket;
024: import java.net.SocketException;
025: import java.util.Vector;
026:
027: import COM.claymoresystems.ptls.SSLContext;
028: import COM.claymoresystems.ptls.SSLException;
029: import COM.claymoresystems.ptls.SSLServerSocket;
030: import COM.claymoresystems.ptls.SSLSocket;
031: import COM.claymoresystems.sslg.SSLPolicyInt;
032:
033: /**
034: * SSL server socket factory--wraps PureTLS
035: *
036: * @author Eric Rescorla
037: *
038: * some sections of this file cribbed from SSLSocketFactory
039: * (the JSSE socket factory)
040: *
041: */
042:
043: public class PureTLSSocketFactory extends
044: org.apache.tomcat.util.net.ServerSocketFactory {
045: static org.apache.commons.logging.Log logger = org.apache.commons.logging.LogFactory
046: .getLog(PureTLSSocketFactory.class);
047: static String defaultProtocol = "TLS";
048: static boolean defaultClientAuth = false;
049: static String defaultKeyStoreFile = "server.pem";
050: static String defaultKeyPass = "password";
051: static String defaultRootFile = "root.pem";
052: static String defaultRandomFile = "random.pem";
053:
054: private COM.claymoresystems.ptls.SSLContext context = null;
055:
056: public PureTLSSocketFactory() {
057: }
058:
059: public ServerSocket createSocket(int port) throws IOException {
060: init();
061: return new SSLServerSocket(context, port);
062: }
063:
064: public ServerSocket createSocket(int port, int backlog)
065: throws IOException {
066: init();
067: ServerSocket tmp;
068:
069: try {
070: tmp = new SSLServerSocket(context, port, backlog);
071: } catch (IOException e) {
072: throw e;
073: }
074: return tmp;
075: }
076:
077: public ServerSocket createSocket(int port, int backlog,
078: InetAddress ifAddress) throws IOException {
079: init();
080: return new SSLServerSocket(context, port, backlog, ifAddress);
081: }
082:
083: private void init() throws IOException {
084: if (context != null)
085: return;
086:
087: boolean clientAuth = defaultClientAuth;
088:
089: try {
090: String keyStoreFile = (String) attributes.get("keystore");
091: if (keyStoreFile == null)
092: keyStoreFile = defaultKeyStoreFile;
093:
094: String keyPass = (String) attributes.get("keypass");
095: if (keyPass == null)
096: keyPass = defaultKeyPass;
097:
098: String rootFile = (String) attributes.get("rootfile");
099: if (rootFile == null)
100: rootFile = defaultRootFile;
101:
102: String randomFile = (String) attributes.get("randomfile");
103: if (randomFile == null)
104: randomFile = defaultRandomFile;
105:
106: String protocol = (String) attributes.get("protocol");
107: if (protocol == null)
108: protocol = defaultProtocol;
109:
110: String clientAuthStr = (String) attributes
111: .get("clientauth");
112: if (clientAuthStr != null) {
113: if (clientAuthStr.equals("true")) {
114: clientAuth = true;
115: } else if (clientAuthStr.equals("false")) {
116: clientAuth = false;
117: } else {
118: throw new IOException("Invalid value '"
119: + clientAuthStr
120: + "' for 'clientauth' parameter:");
121: }
122: }
123:
124: SSLContext tmpContext = new SSLContext();
125: try {
126: tmpContext.loadRootCertificates(rootFile);
127: } catch (IOException iex) {
128: if (logger.isDebugEnabled())
129: logger.debug("Error loading Client Root Store: "
130: + rootFile, iex);
131: }
132: tmpContext.loadEAYKeyFile(keyStoreFile, keyPass);
133: tmpContext.useRandomnessFile(randomFile, keyPass);
134:
135: SSLPolicyInt policy = new SSLPolicyInt();
136: policy.requireClientAuth(clientAuth);
137: policy.handshakeOnConnect(false);
138: policy.waitOnClose(false);
139: short[] enabledCiphers = getEnabledCiphers(policy
140: .getCipherSuites());
141: if (enabledCiphers != null) {
142: policy.setCipherSuites(enabledCiphers);
143: }
144: tmpContext.setPolicy(policy);
145: context = tmpContext;
146: } catch (Exception e) {
147: logger.info("Error initializing SocketFactory", e);
148: throw new IOException(e.getMessage());
149: }
150: }
151:
152: /*
153: * Determines the SSL cipher suites to be enabled.
154: *
155: * @return Array of SSL cipher suites to be enabled, or null if the
156: * cipherSuites property was not specified (meaning that all supported
157: * cipher suites are to be enabled)
158: */
159: private short[] getEnabledCiphers(short[] supportedCiphers) {
160:
161: short[] enabledCiphers = null;
162:
163: String attrValue = (String) attributes.get("ciphers");
164: if (attrValue != null) {
165: Vector vec = null;
166: int fromIndex = 0;
167: int index = attrValue.indexOf(',', fromIndex);
168: while (index != -1) {
169: String cipher = attrValue.substring(fromIndex, index)
170: .trim();
171: int cipherValue = SSLPolicyInt
172: .getCipherSuiteNumber(cipher);
173: /*
174: * Check to see if the requested cipher is among the supported
175: * ciphers, i.e., may be enabled
176: */
177: if (cipherValue >= 0) {
178: for (int i = 0; supportedCiphers != null
179: && i < supportedCiphers.length; i++) {
180:
181: if (cipherValue == supportedCiphers[i]) {
182: if (vec == null) {
183: vec = new Vector();
184: }
185: vec.addElement(new Integer(cipherValue));
186: break;
187: }
188: }
189: }
190: fromIndex = index + 1;
191: index = attrValue.indexOf(',', fromIndex);
192: }
193:
194: if (vec != null) {
195: int nCipher = vec.size();
196: enabledCiphers = new short[nCipher];
197: for (int i = 0; i < nCipher; i++) {
198: Integer value = (Integer) vec.elementAt(i);
199: enabledCiphers[i] = value.shortValue();
200: }
201: }
202: }
203:
204: return enabledCiphers;
205:
206: }
207:
208: public Socket acceptSocket(ServerSocket socket) throws IOException {
209: try {
210: Socket sock = socket.accept();
211: return sock;
212: } catch (SSLException e) {
213: logger.debug("SSL handshake error", e);
214: throw new SocketException("SSL handshake error"
215: + e.toString());
216: }
217: }
218:
219: public void handshake(Socket sock) throws IOException {
220: ((SSLSocket) sock).handshake();
221: }
222: }
|