001: /*
002: * Copyright 2006 Google Inc.
003: *
004: * Licensed under the Apache License, Version 2.0 (the "License"); you may not
005: * use this file except in compliance with the License. You may obtain a copy of
006: * the License at
007: *
008: * http://www.apache.org/licenses/LICENSE-2.0
009: *
010: * Unless required by applicable law or agreed to in writing, software
011: * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
012: * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
013: * License for the specific language governing permissions and limitations under
014: * the License.
015: */
016: package com.google.gwt.user.server.rpc;
017:
018: import com.google.gwt.user.client.rpc.RemoteService;
019: import com.google.gwt.user.client.rpc.SerializationException;
020: import com.google.gwt.user.server.rpc.impl.ServerSerializableTypeOracle;
021: import com.google.gwt.user.server.rpc.impl.ServerSerializableTypeOracleImpl;
022: import com.google.gwt.user.server.rpc.impl.ServerSerializationStreamReader;
023: import com.google.gwt.user.server.rpc.impl.ServerSerializationStreamWriter;
024:
025: import java.io.ByteArrayOutputStream;
026: import java.io.IOException;
027: import java.io.InputStream;
028: import java.io.UnsupportedEncodingException;
029: import java.lang.reflect.InvocationTargetException;
030: import java.lang.reflect.Method;
031: import java.util.HashMap;
032: import java.util.HashSet;
033: import java.util.Set;
034: import java.util.zip.GZIPOutputStream;
035:
036: import javax.servlet.ServletContext;
037: import javax.servlet.ServletException;
038: import javax.servlet.http.HttpServlet;
039: import javax.servlet.http.HttpServletRequest;
040: import javax.servlet.http.HttpServletResponse;
041:
042: /**
043: * The servlet base class for your RPC service implementations that
044: * automatically deserializes incoming requests from the client and serializes
045: * outgoing responses for client/server RPCs.
046: *
047: * This version is a modified version of RemoteServiceServlet. The only changes
048: * have been to remove some final declarations and to convert some private
049: * methods to protected to allow extension of this class.
050: */
051: public class OpenRemoteServiceServlet extends HttpServlet {
052: /*
053: * These members are used to get and set the different HttpServletResponse
054: * and HttpServletRequest headers.
055: */
056: private static final String ACCEPT_ENCODING = "Accept-Encoding";
057: private static final String CHARSET_UTF8 = "UTF-8";
058: private static final String CONTENT_ENCODING = "Content-Encoding";
059: private static final String CONTENT_ENCODING_GZIP = "gzip";
060: private static final String CONTENT_TYPE_TEXT_PLAIN_UTF8 = "text/plain; charset=utf-8";
061: private static final String GENERIC_FAILURE_MSG = "The call failed on the server; see server log for details";
062: private static final HashMap TYPE_NAMES;
063:
064: /**
065: * Controls the compression threshold at and below which no compression will
066: * take place.
067: */
068: private static final int UNCOMPRESSED_BYTE_SIZE_LIMIT = 256;
069:
070: /**
071: * Return true if the response object accepts Gzip encoding. This is done by
072: * checking that the accept-encoding header specifies gzip as a supported
073: * encoding.
074: */
075: private static boolean acceptsGzipEncoding(
076: HttpServletRequest request) {
077: assert (request != null);
078:
079: String acceptEncoding = request.getHeader(ACCEPT_ENCODING);
080: if (null == acceptEncoding) {
081: return false;
082: }
083:
084: return (acceptEncoding.indexOf(CONTENT_ENCODING_GZIP) != -1);
085: }
086:
087: /**
088: * This method attempts to estimate the number of bytes that a string will
089: * consume when it is sent out as part of an HttpServletResponse.
090: *
091: * This really a hack since we are assuming that every character will
092: * consume two bytes upon transmission. This is definitely not true since
093: * some characters actually consume more than two bytes and some consume
094: * less. This is even less accurate if the string is converted to UTF8.
095: * However, it does save us from converting every string that we plan on
096: * sending back to UTF8 just to determine that we should not compress it.
097: */
098: private static int estimateByteSize(final String buffer) {
099: return (buffer.length() * 2);
100: }
101:
102: /**
103: * Find the invoked method on either the specified interface or any super.
104: */
105: private static Method findInterfaceMethod(Class intf,
106: String methodName, Class[] paramTypes,
107: boolean includeInherited) {
108: try {
109: return intf.getDeclaredMethod(methodName, paramTypes);
110: } catch (NoSuchMethodException e) {
111: if (includeInherited) {
112: Class[] super intfs = intf.getInterfaces();
113: for (int i = 0; i < super intfs.length; i++) {
114: Method method = findInterfaceMethod(super intfs[i],
115: methodName, paramTypes, true);
116: if (method != null)
117: return method;
118: }
119: }
120:
121: return null;
122: }
123: }
124:
/**
 * Creates the servlet and builds its serializable-type oracle from the
 * package paths returned by getPackagePaths().
 */
public OpenRemoteServiceServlet() {
    String[] packagePaths = getPackagePaths();
    serializableTypeOracle = new ServerSerializableTypeOracleImpl(packagePaths);
}
132:
133: /**
134: * This is called internally.
135: */
136: public void doPost(HttpServletRequest request,
137: HttpServletResponse response) {
138: Throwable caught;
139: try {
140: // Store the request & response objects in thread-local storage.
141: //
142: perThreadRequest.set(request);
143: perThreadResponse.set(response);
144:
145: // Read the request fully.
146: //
147: String requestPayload = readPayloadAsUtf8(request);
148:
149: // Invoke the core dispatching logic, which returns the serialized
150: // result.
151: //
152: String responsePayload = processCall(requestPayload);
153:
154: // Write the response.
155: //
156: writeResponse(request, response, responsePayload);
157:
158: return;
159:
160: } catch (IOException e) {
161: caught = e;
162: } catch (ServletException e) {
163: caught = e;
164: } catch (SerializationException e) {
165: caught = e;
166: } catch (Throwable e) {
167: caught = e;
168: }
169:
170: respondWithFailure(response, caught);
171: }
172:
173: /**
174: * This is public so that it can be unit tested easily without HTTP.
175: */
176: public String processCall(String payload)
177: throws SerializationException {
178:
179: // Let subclasses see the serialized request.
180: //
181: onBeforeRequestDeserialized(payload);
182:
183: // Create a stream to deserialize the request.
184: //
185: ServerSerializationStreamReader streamReader = new ServerSerializationStreamReader(
186: serializableTypeOracle);
187: streamReader.prepareToRead(payload);
188:
189: // Read the service interface
190: //
191: String serviceIntfName = streamReader.readString();
192:
193: // TODO(mmendez): need a way to check the type signature of the service
194: // intf
195: // Verify that this very servlet implements the specified interface
196: // name.
197: //
198: if (!isImplementedRemoteServiceInterface(serviceIntfName)) {
199: // Bad payload, possible hack attempt.
200: //
201: throw new SecurityException(
202: "Blocked attempt to access interface '"
203: + serviceIntfName
204: + "', which is either not implemented by this servlet or which doesn't extend RemoteService; this is either misconfiguration or a hack attempt");
205: }
206:
207: // Actually get the service interface, so that we can query its methods.
208: //
209: Class serviceIntf;
210: try {
211: serviceIntf = getClassFromName(serviceIntfName);
212: } catch (ClassNotFoundException e) {
213: throw new SerializationException(
214: "Unknown service interface class '"
215: + serviceIntfName + "'", e);
216: }
217:
218: // Read the method name.
219: //
220: String methodName = streamReader.readString();
221:
222: // Read the number and names of the parameter classes from the stream.
223: // We have to do this so that we can find the correct overload of the
224: // method.
225: //
226: int paramCount = streamReader.readInt();
227: Class[] paramTypes = new Class[paramCount];
228: for (int i = 0; i < paramTypes.length; i++) {
229: String paramClassName = streamReader.readString();
230: try {
231: paramTypes[i] = getClassOrPrimitiveFromName(paramClassName);
232: } catch (ClassNotFoundException e) {
233: throw new SerializationException("Unknown parameter "
234: + i + " type '" + paramClassName + "'", e);
235: }
236: }
237:
238: // For security, make sure the method is found in the service interface
239: // and not just one that happens to be defined on this class.
240: //
241: Method serviceIntfMethod = findInterfaceMethod(serviceIntf,
242: methodName, paramTypes, true);
243:
244: // If it wasn't found, don't continue.
245: //
246: if (serviceIntfMethod == null) {
247: // Bad payload, possible hack attempt.
248: //
249: throw new SecurityException(
250: "Method '"
251: + methodName
252: + "' (or a particular overload) on interface '"
253: + serviceIntfName
254: + "' was not found, this is either misconfiguration or a hack attempt");
255: }
256:
257: // Deserialize the parameters.
258: //
259: Object[] args = new Object[paramCount];
260: for (int i = 0; i < args.length; i++) {
261: args[i] = streamReader.deserializeValue(paramTypes[i]);
262: }
263:
264: // Make the call via reflection.
265: //
266: String responsePayload = GENERIC_FAILURE_MSG;
267: ServerSerializationStreamWriter streamWriter = new ServerSerializationStreamWriter(
268: serializableTypeOracle);
269: Throwable caught = null;
270: try {
271: Class returnType = serviceIntfMethod.getReturnType();
272: Object returnVal = serviceIntfMethod.invoke(this , args);
273: responsePayload = createResponse(streamWriter, returnType,
274: returnVal, false);
275: } catch (IllegalArgumentException e) {
276: caught = e;
277: } catch (IllegalAccessException e) {
278: caught = e;
279: } catch (InvocationTargetException e) {
280: // Try to serialize the caught exception if the client is expecting
281: // it,
282: // otherwise log the exception server-side.
283: caught = e;
284: Throwable cause = e.getCause();
285: if (cause != null) {
286: // Update the caught exception to the underlying cause
287: caught = cause;
288: // Serialize the exception back to the client if it's a declared
289: // exception
290: if (isExpectedException(serviceIntfMethod, cause)) {
291: Class thrownClass = cause.getClass();
292: responsePayload = createResponse(streamWriter,
293: thrownClass, cause, true);
294: // Don't log the exception on the server
295: caught = null;
296: }
297: }
298: }
299:
300: if (caught != null) {
301: handleException(responsePayload, caught);
302: }
303:
304: // Let subclasses see the serialized response.
305: //
306: onAfterResponseSerialized(responsePayload);
307:
308: return responsePayload;
309: }
310:
311: protected void handleException(String responsePayload,
312: Throwable caught) {
313: responsePayload = GENERIC_FAILURE_MSG;
314: ServletContext servletContext = getServletContext();
315: // servletContext may be null (for example, when unit testing)
316: if (servletContext != null) {
317: // Log the exception server side
318: servletContext.log(
319: "Exception while dispatching incoming RPC call",
320: caught);
321: }
322: }
323:
324: /**
325: * Gets the <code>HttpServletRequest</code> object for the current call.
326: * It is stored thread-locally so that simultaneous invocations can have
327: * different request objects.
328: */
329: protected final HttpServletRequest getThreadLocalRequest() {
330: return (HttpServletRequest) perThreadRequest.get();
331: }
332:
333: /**
334: * Gets the <code>HttpServletResponse</code> object for the current call.
335: * It is stored thread-locally so that simultaneous invocations can have
336: * different response objects.
337: */
338: protected final HttpServletResponse getThreadLocalResponse() {
339: return (HttpServletResponse) perThreadResponse.get();
340: }
341:
342: /**
343: * Override this method to examine the serialized response that will be
344: * returned to the client. The default implementation does nothing and need
345: * not be called by subclasses.
346: */
347: protected void onAfterResponseSerialized(String serializedResponse) {
348: }
349:
350: /**
351: * Override this method to examine the serialized version of the request
352: * payload before it is deserialized into objects. The default
353: * implementation does nothing and need not be called by subclasses.
354: */
355: protected void onBeforeRequestDeserialized(String serializedRequest) {
356: }
357:
358: /**
359: * Determines whether the response to a given servlet request should or
360: * should not be GZIP compressed. This method is only called in cases where
361: * the requestor accepts GZIP encoding.
362: *
363: * <p>
364: * This implementation currently returns <code>true</code> if the response
365: * string's estimated byte length is longer than 256 bytes. Subclasses can
366: * override this logic.
367: * </p>
368: *
369: * @param request
370: * the request being served
371: * @param response
372: * the response that will be written into
373: * @param responsePayload
374: * the payload that is about to be sent to the client
375: * @return <code>true</code> if responsePayload should be GZIP compressed,
376: * otherwise <code>false</code>.
377: */
378: protected boolean shouldCompressResponse(
379: HttpServletRequest request, HttpServletResponse response,
380: String responsePayload) {
381: return estimateByteSize(responsePayload) > UNCOMPRESSED_BYTE_SIZE_LIMIT;
382: }
383:
384: /**
385: * @param stream
386: * @param responseType
387: * @param responseObj
388: * @param isException
389: * @return
390: */
391: private String createResponse(
392: ServerSerializationStreamWriter stream, Class responseType,
393: Object responseObj, boolean isException) {
394: stream.prepareToWrite();
395: if (responseType != void.class) {
396: try {
397: stream.serializeValue(responseObj, responseType);
398: } catch (SerializationException e) {
399: responseObj = e;
400: isException = true;
401: }
402: }
403:
404: String bufferStr = (isException ? "{EX}" : "{OK}")
405: + stream.toString();
406: return bufferStr;
407: }
408:
409: /**
410: * Returns the {@link Class} instance for the named class or primitive type.
411: *
412: * @param name
413: * the name of a class or primitive type
414: * @return Class instance for the given type name
415: * @throws ClassNotFoundException
416: * if the named type was not found
417: */
418: private Class getClassOrPrimitiveFromName(String name)
419: throws ClassNotFoundException {
420: Object value = TYPE_NAMES.get(name);
421: if (value != null) {
422: return (Class) value;
423: }
424:
425: return getClassFromName(name);
426: }
427:
428: /**
429: * Returns the {@link Class} instance for the named class.
430: *
431: * @param name
432: * the name of a class or primitive type
433: * @return Class instance for the given type name
434: * @throws ClassNotFoundException
435: * if the named type was not found
436: */
437: private Class getClassFromName(String name)
438: throws ClassNotFoundException {
439: return Class.forName(name, false, this .getClass()
440: .getClassLoader());
441: }
442:
443: /**
444: * Obtain the special package-prefixes we use to check for custom
445: * serializers that would like to live in a package that they cannot. For
446: * example, "java.util.ArrayList" is in a sealed package, so instead we use
447: * this prefix to check for a custom serializer in
448: * "com.google.gwt.user.client.rpc.core.java.util.ArrayList". Right now,
449: * it's hard-coded because we don't have a pressing need for this mechanism
450: * to be extensible, but it is imaginable, which is why it's implemented
451: * this way.
452: */
453: private String[] getPackagePaths() {
454: return new String[] { "com.google.gwt.user.client.rpc.core" };
455: }
456:
457: /**
458: * Returns true if the {@link java.lang.reflect.Method Method} definition on
459: * the service is specified to throw the exception contained in the
460: * InvocationTargetException or false otherwise.
461: *
462: * NOTE we do not check that the type is serializable here. We assume that
463: * it must be otherwise the application would never have been allowed to
464: * run.
465: *
466: * @param serviceIntfMethod
467: * @param e
468: * @return
469: */
470: private boolean isExpectedException(Method serviceIntfMethod,
471: Throwable cause) {
472: assert (serviceIntfMethod != null);
473: assert (cause != null);
474:
475: Class[] exceptionsThrown = serviceIntfMethod
476: .getExceptionTypes();
477: if (exceptionsThrown.length <= 0) {
478: // The method is not specified to throw any exceptions
479: //
480: return false;
481: }
482:
483: Class causeType = cause.getClass();
484:
485: for (int index = 0; index < exceptionsThrown.length; ++index) {
486: Class exceptionThrown = exceptionsThrown[index];
487: assert (exceptionThrown != null);
488:
489: if (exceptionThrown.isAssignableFrom(causeType)) {
490: return true;
491: }
492: }
493:
494: return false;
495: }
496:
497: /**
498: * Used to determine whether the specified interface name is implemented by
499: * this class without loading the class (for security).
500: */
501: private boolean isImplementedRemoteServiceInterface(String intfName) {
502: synchronized (knownImplementedInterfaces) {
503: // See if it's cached.
504: //
505: if (knownImplementedInterfaces.contains(intfName)) {
506: return true;
507: }
508:
509: Class cls = getClass();
510:
511: // Unknown, so walk up the class hierarchy to find the first class
512: // that
513: // implements the requested interface
514: //
515: while ((cls != null)
516: && !OpenRemoteServiceServlet.class.equals(cls)) {
517: Class[] intfs = cls.getInterfaces();
518: for (int i = 0; i < intfs.length; i++) {
519: Class intf = intfs[i];
520: if (isImplementedRemoteServiceInterfaceRecursive(
521: intfName, intf)) {
522: knownImplementedInterfaces.add(intfName);
523: return true;
524: }
525: }
526:
527: // did not find the interface in this class so we look in the
528: // superclass
529: cls = cls.getSuperclass();
530: }
531:
532: return false;
533: }
534: }
535:
536: /**
537: * Only called from isImplementedInterface().
538: */
539: private boolean isImplementedRemoteServiceInterfaceRecursive(
540: String intfName, Class intfToCheck) {
541: assert (intfToCheck.isInterface());
542:
543: if (intfToCheck.getName().equals(intfName)) {
544: // The name is right, but we also verify that it is assignable to
545: // RemoteService.
546: //
547: if (RemoteService.class.isAssignableFrom(intfToCheck)) {
548: return true;
549: } else {
550: return false;
551: }
552: }
553:
554: Class[] intfs = intfToCheck.getInterfaces();
555: for (int i = 0; i < intfs.length; i++) {
556: Class intf = intfs[i];
557: if (isImplementedRemoteServiceInterfaceRecursive(intfName,
558: intf)) {
559: return true;
560: }
561: }
562:
563: return false;
564: }
565:
566: protected String readPayloadAsUtf8(HttpServletRequest request)
567: throws IOException, ServletException {
568: int contentLength = request.getContentLength();
569: if (contentLength == -1) {
570: // Content length must be known.
571: throw new ServletException(
572: "Content-Length must be specified");
573: }
574:
575: String contentType = request.getContentType();
576: boolean contentTypeIsOkay = false;
577: // Content-Type must be specified.
578: if (contentType != null) {
579: // The type must be plain text.
580: if (contentType.startsWith("text/plain")) {
581: // And it must be UTF-8 encoded (or unspecified, in which case
582: // we assume
583: // that it's either UTF-8 or ASCII).
584: if (contentType.indexOf("charset=") == -1)
585: contentTypeIsOkay = true;
586: else if (contentType.indexOf("charset=utf-8") != -1)
587: contentTypeIsOkay = true;
588: }
589: }
590: if (!contentTypeIsOkay)
591: throw new ServletException(
592: "Content-Type must be 'text/plain' with 'charset=utf-8' (or unspecified charset)");
593:
594: InputStream in = request.getInputStream();
595: try {
596: byte[] payload = new byte[contentLength];
597: int offset = 0;
598: int len = contentLength;
599: int byteCount;
600: while (offset < contentLength) {
601: byteCount = in.read(payload, offset, len);
602: if (byteCount == -1)
603: throw new ServletException("Client did not send "
604: + contentLength + " bytes as expected");
605: offset += byteCount;
606: len -= byteCount;
607: }
608: return new String(payload, "UTF-8");
609: } finally {
610: if (in != null) {
611: in.close();
612: }
613: }
614: }
615:
616: /**
617: * Called when the machinery of this class itself has a problem, rather than
618: * the invoked third-party method. It writes a simple 500 message back to
619: * the client.
620: */
621: protected void respondWithFailure(HttpServletResponse response,
622: Throwable caught) {
623: ServletContext servletContext = getServletContext();
624: servletContext
625: .log("Exception while dispatching incoming RPC call",
626: caught);
627: try {
628: response.setContentType("text/plain");
629: response
630: .setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
631: response.getWriter().write(GENERIC_FAILURE_MSG);
632: } catch (IOException e) {
633: servletContext
634: .log(
635: "sendError() failed while sending the previous failure to the client",
636: caught);
637: }
638: }
639:
640: protected void writeResponse(HttpServletRequest request,
641: HttpServletResponse response, String responsePayload)
642: throws IOException {
643:
644: byte[] reply = responsePayload.getBytes(CHARSET_UTF8);
645: String contentType = CONTENT_TYPE_TEXT_PLAIN_UTF8;
646:
647: if (acceptsGzipEncoding(request)
648: && shouldCompressResponse(request, response,
649: responsePayload)) {
650: // Compress the reply and adjust headers.
651: //
652: ByteArrayOutputStream output = null;
653: GZIPOutputStream gzipOutputStream = null;
654: Throwable caught = null;
655: try {
656: output = new ByteArrayOutputStream(reply.length);
657: gzipOutputStream = new GZIPOutputStream(output);
658: gzipOutputStream.write(reply);
659: gzipOutputStream.finish();
660: gzipOutputStream.flush();
661: response.setHeader(CONTENT_ENCODING,
662: CONTENT_ENCODING_GZIP);
663: reply = output.toByteArray();
664: } catch (UnsupportedEncodingException e) {
665: caught = e;
666: } catch (IOException e) {
667: caught = e;
668: } finally {
669: if (null != gzipOutputStream) {
670: gzipOutputStream.close();
671: }
672: if (null != output) {
673: output.close();
674: }
675: }
676:
677: if (caught != null) {
678: getServletContext().log("Unable to compress response",
679: caught);
680: response
681: .sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
682: return;
683: }
684: }
685:
686: // Send the reply.
687: //
688: response.setContentLength(reply.length);
689: response.setContentType(contentType);
690: response.setStatus(HttpServletResponse.SC_OK);
691: response.getOutputStream().write(reply);
692: }
693:
static {
    // Build the table of primitive type keys locally, then publish it
    // through the final field.
    HashMap primitives = new HashMap();
    primitives.put("Z", boolean.class);
    primitives.put("B", byte.class);
    primitives.put("C", char.class);
    primitives.put("D", double.class);
    primitives.put("F", float.class);
    primitives.put("I", int.class);
    primitives.put("J", long.class);
    primitives.put("S", short.class);
    TYPE_NAMES = primitives;
}

/**
 * Names of interfaces already confirmed as implemented RemoteService
 * interfaces; accessed while synchronized on the set itself.
 */
private final Set knownImplementedInterfaces = new HashSet();

/** Holds the request of the call being handled, per thread. */
private final ThreadLocal perThreadRequest = new ThreadLocal();

/** Holds the response of the call being handled, per thread. */
private final ThreadLocal perThreadResponse = new ThreadLocal();

/** Type oracle handed to the stream reader and writer; set in the constructor. */
private final ServerSerializableTypeOracle serializableTypeOracle;
710: }
|