001: /*
002: * Licensed to the Apache Software Foundation (ASF) under one or more
003: * contributor license agreements. See the NOTICE file distributed with
004: * this work for additional information regarding copyright ownership.
005: * The ASF licenses this file to You under the Apache License, Version 2.0
006: * (the "License"); you may not use this file except in compliance with
007: * the License. You may obtain a copy of the License at
008: *
009: * http://www.apache.org/licenses/LICENSE-2.0
010: *
011: * Unless required by applicable law or agreed to in writing, software
012: * distributed under the License is distributed on an "AS IS" BASIS,
013: * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014: * See the License for the specific language governing permissions and
015: * limitations under the License.
016: */
017: package org.apache.tomcat.util.net;
018:
019: import java.io.IOException;
020: import java.nio.ByteBuffer;
021: import java.nio.channels.SelectionKey;
022: import java.nio.channels.SocketChannel;
023: import javax.net.ssl.SSLEngine;
024: import javax.net.ssl.SSLEngineResult;
025: import javax.net.ssl.SSLEngineResult.HandshakeStatus;
026: import javax.net.ssl.SSLEngineResult.Status;
027: import java.nio.channels.Selector;
028:
029: /**
030: *
031: * Implementation of a secure socket channel
032: * @author Filip Hanik
033: * @version 1.0
034: */
035:
036: public class SecureNioChannel extends NioChannel {
037:
038: protected ByteBuffer netInBuffer;
039: protected ByteBuffer netOutBuffer;
040:
041: protected SSLEngine sslEngine;
042:
043: protected boolean initHandshakeComplete = false;
044: protected HandshakeStatus initHandshakeStatus; //gets set by begin handshake
045:
046: protected boolean closed = false;
047: protected boolean closing = false;
048:
049: protected NioSelectorPool pool;
050:
051: public SecureNioChannel(SocketChannel channel, SSLEngine engine,
052: ApplicationBufferHandler bufHandler, NioSelectorPool pool)
053: throws IOException {
054: super (channel, bufHandler);
055: this .sslEngine = engine;
056: int appBufSize = sslEngine.getSession()
057: .getApplicationBufferSize();
058: int netBufSize = sslEngine.getSession().getPacketBufferSize();
059: //allocate network buffers - TODO, add in optional direct non-direct buffers
060: if (netInBuffer == null)
061: netInBuffer = ByteBuffer.allocateDirect(netBufSize);
062: if (netOutBuffer == null)
063: netOutBuffer = ByteBuffer.allocateDirect(netBufSize);
064:
065: //selector pool for blocking operations
066: this .pool = pool;
067:
068: //ensure that the application has a large enough read/write buffers
069: //by doing this, we should not encounter any buffer overflow errors
070: bufHandler.expand(bufHandler.getReadBuffer(), appBufSize);
071: bufHandler.expand(bufHandler.getWriteBuffer(), appBufSize);
072: reset();
073: }
074:
075: public void reset(SSLEngine engine) throws IOException {
076: this .sslEngine = engine;
077: reset();
078: }
079:
080: public void reset() throws IOException {
081: super .reset();
082: netOutBuffer.position(0);
083: netOutBuffer.limit(0);
084: netInBuffer.position(0);
085: netInBuffer.limit(0);
086: initHandshakeComplete = false;
087: closed = false;
088: closing = false;
089: //initiate handshake
090: sslEngine.beginHandshake();
091: initHandshakeStatus = sslEngine.getHandshakeStatus();
092: }
093:
094: public int getBufferSize() {
095: int size = super .getBufferSize();
096: size += netInBuffer != null ? netInBuffer.capacity() : 0;
097: size += netOutBuffer != null ? netOutBuffer.capacity() : 0;
098: return size;
099: }
100:
101: //===========================================================================================
102: // NIO SSL METHODS
103: //===========================================================================================
104: /**
105: * returns true if the network buffer has
106: * been flushed out and is empty
107: * @return boolean
108: */
109: public boolean flush(boolean block, Selector s, long timeout)
110: throws IOException {
111: if (!block) {
112: flush(netOutBuffer);
113: } else {
114: pool.write(netOutBuffer, this , s, timeout);
115: }
116: return !netOutBuffer.hasRemaining();
117: }
118:
119: /**
120: * Flushes the buffer to the network, non blocking
121: * @param buf ByteBuffer
122: * @return boolean true if the buffer has been emptied out, false otherwise
123: * @throws IOException
124: */
125: protected boolean flush(ByteBuffer buf) throws IOException {
126: int remaining = buf.remaining();
127: if (remaining > 0) {
128: int written = sc.write(buf);
129: return written >= remaining;
130: } else {
131: return true;
132: }
133: }
134:
135: /**
136: * Performs SSL handshake, non blocking, but performs NEED_TASK on the same thread.<br>
137: * Hence, you should never call this method using your Acceptor thread, as you would slow down
138: * your system significantly.<br>
139: * The return for this operation is 0 if the handshake is complete and a positive value if it is not complete.
140: * In the event of a positive value coming back, reregister the selection key for the return values interestOps.
141: * @param read boolean - true if the underlying channel is readable
142: * @param write boolean - true if the underlying channel is writable
143: * @return int - 0 if hand shake is complete, otherwise it returns a SelectionKey interestOps value
144: * @throws IOException
145: */
146: public int handshake(boolean read, boolean write)
147: throws IOException {
148: if (initHandshakeComplete)
149: return 0; //we have done our initial handshake
150:
151: if (!flush(netOutBuffer))
152: return SelectionKey.OP_WRITE; //we still have data to write
153:
154: SSLEngineResult handshake = null;
155:
156: while (!initHandshakeComplete) {
157: switch (initHandshakeStatus) {
158: case NOT_HANDSHAKING: {
159: //should never happen
160: throw new IOException(
161: "NOT_HANDSHAKING during handshake");
162: }
163: case FINISHED: {
164: //we are complete if we have delivered the last package
165: initHandshakeComplete = !netOutBuffer.hasRemaining();
166: //return 0 if we are complete, otherwise we still have data to write
167: return initHandshakeComplete ? 0
168: : SelectionKey.OP_WRITE;
169: }
170: case NEED_WRAP: {
171: //perform the wrap function
172: handshake = handshakeWrap(write);
173: if (handshake.getStatus() == Status.OK) {
174: if (initHandshakeStatus == HandshakeStatus.NEED_TASK)
175: initHandshakeStatus = tasks();
176: } else {
177: //wrap should always work with our buffers
178: throw new IOException("Unexpected status:"
179: + handshake.getStatus()
180: + " during handshake WRAP.");
181: }
182: if (initHandshakeStatus != HandshakeStatus.NEED_UNWRAP
183: || (!flush(netOutBuffer))) {
184: //should actually return OP_READ if we have NEED_UNWRAP
185: return SelectionKey.OP_WRITE;
186: }
187: //fall down to NEED_UNWRAP on the same call, will result in a
188: //BUFFER_UNDERFLOW if it needs data
189: }
190: case NEED_UNWRAP: {
191: //perform the unwrap function
192: handshake = handshakeUnwrap(read);
193: if (handshake.getStatus() == Status.OK) {
194: if (initHandshakeStatus == HandshakeStatus.NEED_TASK)
195: initHandshakeStatus = tasks();
196: } else if (handshake.getStatus() == Status.BUFFER_UNDERFLOW) {
197: //read more data, reregister for OP_READ
198: return SelectionKey.OP_READ;
199: } else {
200: throw new IOException("Invalid handshake status:"
201: + initHandshakeStatus
202: + " during handshake UNWRAP.");
203: }//switch
204: break;
205: }
206: case NEED_TASK: {
207: initHandshakeStatus = tasks();
208: break;
209: }
210: default:
211: throw new IllegalStateException(
212: "Invalid handshake status:"
213: + initHandshakeStatus);
214: }//switch
215: }//while
216: //return 0 if we are complete, otherwise reregister for any activity that
217: //would cause this method to be called again.
218: return initHandshakeComplete ? 0
219: : (SelectionKey.OP_WRITE | SelectionKey.OP_READ);
220: }
221:
222: /**
223: * Executes all the tasks needed on the same thread.
224: * @return HandshakeStatus
225: */
226: protected SSLEngineResult.HandshakeStatus tasks() {
227: Runnable r = null;
228: while ((r = sslEngine.getDelegatedTask()) != null) {
229: r.run();
230: }
231: return sslEngine.getHandshakeStatus();
232: }
233:
234: /**
235: * Performs the WRAP function
236: * @param doWrite boolean
237: * @return SSLEngineResult
238: * @throws IOException
239: */
240: protected SSLEngineResult handshakeWrap(boolean doWrite)
241: throws IOException {
242: //this should never be called with a network buffer that contains data
243: //so we can clear it here.
244: netOutBuffer.clear();
245: //perform the wrap
246: SSLEngineResult result = sslEngine.wrap(bufHandler
247: .getWriteBuffer(), netOutBuffer);
248: //prepare the results to be written
249: netOutBuffer.flip();
250: //set the status
251: initHandshakeStatus = result.getHandshakeStatus();
252: //optimization, if we do have a writable channel, write it now
253: if (doWrite)
254: flush(netOutBuffer);
255: return result;
256: }
257:
258: /**
259: * Perform handshake unwrap
260: * @param doread boolean
261: * @return SSLEngineResult
262: * @throws IOException
263: */
264: protected SSLEngineResult handshakeUnwrap(boolean doread)
265: throws IOException {
266:
267: if (netInBuffer.position() == netInBuffer.limit()) {
268: //clear the buffer if we have emptied it out on data
269: netInBuffer.clear();
270: }
271: if (doread) {
272: //if we have data to read, read it
273: int read = sc.read(netInBuffer);
274: if (read == -1)
275: throw new IOException(
276: "EOF encountered during handshake.");
277: }
278: SSLEngineResult result;
279: boolean cont = false;
280: //loop while we can perform pure SSLEngine data
281: do {
282: //prepare the buffer with the incoming data
283: netInBuffer.flip();
284: //call unwrap
285: result = sslEngine.unwrap(netInBuffer, bufHandler
286: .getReadBuffer());
287: //compact the buffer, this is an optional method, wonder what would happen if we didn't
288: netInBuffer.compact();
289: //read in the status
290: initHandshakeStatus = result.getHandshakeStatus();
291: if (result.getStatus() == SSLEngineResult.Status.OK
292: && result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
293: //execute tasks if we need to
294: initHandshakeStatus = tasks();
295: }
296: //perform another unwrap?
297: cont = result.getStatus() == SSLEngineResult.Status.OK
298: && initHandshakeStatus == HandshakeStatus.NEED_UNWRAP;
299: } while (cont);
300: return result;
301: }
302:
303: /**
304: * Sends a SSL close message, will not physically close the connection here.<br>
305: * To close the connection, you could do something like
306: * <pre><code>
307: * close();
308: * while (isOpen() && !myTimeoutFunction()) Thread.sleep(25);
309: * if ( isOpen() ) close(true); //forces a close if you timed out
310: * </code></pre>
311: * @throws IOException if an I/O error occurs
312: * @throws IOException if there is data on the outgoing network buffer and we are unable to flush it
313: * @todo Implement this java.io.Closeable method
314: */
315: public void close() throws IOException {
316: if (closing)
317: return;
318: closing = true;
319: sslEngine.closeOutbound();
320:
321: if (!flush(netOutBuffer)) {
322: throw new IOException(
323: "Remaining data in the network buffer, can't send SSL close message, force a close with close(true) instead");
324: }
325: //prep the buffer for the close message
326: netOutBuffer.clear();
327: //perform the close, since we called sslEngine.closeOutbound
328: SSLEngineResult handshake = sslEngine.wrap(getEmptyBuf(),
329: netOutBuffer);
330: //we should be in a close state
331: if (handshake.getStatus() != SSLEngineResult.Status.CLOSED) {
332: throw new IOException(
333: "Invalid close state, will not send network data.");
334: }
335: //prepare the buffer for writing
336: netOutBuffer.flip();
337: //if there is data to be written
338: flush(netOutBuffer);
339:
340: //is the channel closed?
341: closed = (!netOutBuffer.hasRemaining() && (handshake
342: .getHandshakeStatus() != HandshakeStatus.NEED_WRAP));
343: }
344:
345: /**
346: * Force a close, can throw an IOException
347: * @param force boolean
348: * @throws IOException
349: */
350: public void close(boolean force) throws IOException {
351: try {
352: close();
353: } finally {
354: if (force || closed) {
355: closed = true;
356: sc.socket().close();
357: sc.close();
358: }
359: }
360: }
361:
362: /**
363: * Reads a sequence of bytes from this channel into the given buffer.
364: *
365: * @param dst The buffer into which bytes are to be transferred
366: * @return The number of bytes read, possibly zero, or <tt>-1</tt> if the channel has reached end-of-stream
367: * @throws IOException If some other I/O error occurs
368: * @throws IllegalArgumentException if the destination buffer is different than bufHandler.getReadBuffer()
369: * @todo Implement this java.nio.channels.ReadableByteChannel method
370: */
371: public int read(ByteBuffer dst) throws IOException {
372: //if we want to take advantage of the expand function, make sure we only use the ApplicationBufferHandler's buffers
373: if (dst != bufHandler.getReadBuffer())
374: throw new IllegalArgumentException(
375: "You can only read using the application read buffer provided by the handler.");
376: //are we in the middle of closing or closed?
377: if (closing || closed)
378: return -1;
379: //did we finish our handshake?
380: if (!initHandshakeComplete)
381: throw new IllegalStateException(
382: "Handshake incomplete, you must complete handshake before reading data.");
383:
384: //read from the network
385: int netread = sc.read(netInBuffer);
386: //did we reach EOF? if so send EOF up one layer.
387: if (netread == -1)
388: return -1;
389:
390: //the data read
391: int read = 0;
392: //the SSL engine result
393: SSLEngineResult unwrap;
394: do {
395: //prepare the buffer
396: netInBuffer.flip();
397: //unwrap the data
398: unwrap = sslEngine.unwrap(netInBuffer, dst);
399: //compact the buffer
400: netInBuffer.compact();
401:
402: if (unwrap.getStatus() == Status.OK
403: || unwrap.getStatus() == Status.BUFFER_UNDERFLOW) {
404: //we did receive some data, add it to our total
405: read += unwrap.bytesProduced();
406: //perform any tasks if needed
407: if (unwrap.getHandshakeStatus() == HandshakeStatus.NEED_TASK)
408: tasks();
409: //if we need more network data, then bail out for now.
410: if (unwrap.getStatus() == Status.BUFFER_UNDERFLOW)
411: break;
412: } else {
413: //here we should trap BUFFER_OVERFLOW and call expand on the buffer
414: //for now, throw an exception, as we initialized the buffers
415: //in the constructor
416: throw new IOException(
417: "Unable to unwrap data, invalid status: "
418: + unwrap.getStatus());
419: }
420: } while ((netInBuffer.position() != 0)); //continue to unwrapping as long as the input buffer has stuff
421: return (read);
422: }
423:
424: /**
425: * Writes a sequence of bytes to this channel from the given buffer.
426: *
427: * @param src The buffer from which bytes are to be retrieved
428: * @return The number of bytes written, possibly zero
429: * @throws IOException If some other I/O error occurs
430: * @todo Implement this java.nio.channels.WritableByteChannel method
431: */
432: public int write(ByteBuffer src) throws IOException {
433: //make sure we can handle expand, and that we only use on buffer
434: if (src != bufHandler.getWriteBuffer())
435: throw new IllegalArgumentException(
436: "You can only write using the application write buffer provided by the handler.");
437: //are we closing or closed?
438: if (closing || closed)
439: throw new IOException("Channel is in closing state.");
440:
441: //the number of bytes written
442: int written = 0;
443:
444: if (!flush(netOutBuffer)) {
445: //we haven't emptied out the buffer yet
446: return written;
447: }
448:
449: /*
450: * The data buffer is empty, we can reuse the entire buffer.
451: */
452: netOutBuffer.clear();
453:
454: SSLEngineResult result = sslEngine.wrap(src, netOutBuffer);
455: written = result.bytesConsumed();
456: netOutBuffer.flip();
457:
458: if (result.getStatus() == Status.OK) {
459: if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK)
460: tasks();
461: } else {
462: throw new IOException(
463: "Unable to wrap data, invalid engine state: "
464: + result.getStatus());
465: }
466:
467: //force a flush
468: flush(netOutBuffer);
469:
470: return written;
471: }
472:
473: /**
474: * Callback interface to be able to expand buffers
475: * when buffer overflow exceptions happen
476: */
477: public static interface ApplicationBufferHandler {
478: public ByteBuffer expand(ByteBuffer buffer, int remaining);
479:
480: public ByteBuffer getReadBuffer();
481:
482: public ByteBuffer getWriteBuffer();
483: }
484:
485: public ApplicationBufferHandler getBufHandler() {
486: return bufHandler;
487: }
488:
489: public boolean isInitHandshakeComplete() {
490: return initHandshakeComplete;
491: }
492:
493: public boolean isClosing() {
494: return closing;
495: }
496:
497: public SSLEngine getSslEngine() {
498: return sslEngine;
499: }
500:
501: public ByteBuffer getEmptyBuf() {
502: return emptyBuf;
503: }
504:
505: public void setBufHandler(ApplicationBufferHandler bufHandler) {
506: this .bufHandler = bufHandler;
507: }
508:
509: public SocketChannel getIOChannel() {
510: return sc;
511: }
512:
513: }
|