001: package org.jboss.seam.remoting.gwt;
002:
003: /*
004: * Copyright 2005 JBoss Inc
005: *
006: * Licensed under the Apache License, Version 2.0 (the "License");
007: * you may not use this file except in compliance with the License.
008: * You may obtain a copy of the License at
009: *
010: * http://www.apache.org/licenses/LICENSE-2.0
011: *
012: * Unless required by applicable law or agreed to in writing, software
013: * distributed under the License is distributed on an "AS IS" BASIS,
014: * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
015: * See the License for the specific language governing permissions and
016: * limitations under the License.
017: */
018:
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.UnsupportedEncodingException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Locale;
import java.util.Set;
import java.util.zip.GZIPOutputStream;

import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.jboss.seam.remoting.gwt.GWTToSeamAdapter.ReturnedObject;

import com.google.gwt.user.client.rpc.RemoteService;
import com.google.gwt.user.client.rpc.SerializableException;
import com.google.gwt.user.client.rpc.SerializationException;
import com.google.gwt.user.server.rpc.impl.ServerSerializableTypeOracle;
import com.google.gwt.user.server.rpc.impl.ServerSerializableTypeOracleImpl;
import com.google.gwt.user.server.rpc.impl.ServerSerializationStreamReader;
import com.google.gwt.user.server.rpc.impl.ServerSerializationStreamWriter;
045:
/**
 * This is a less than ideal approach, but GWT (up to and including 1.3)
 * has no means to get into the internals of the RPC mechanism and free it
 * from the shackles of the servlet API.
 * So I, the liberator, have hacked this class out of GWT's RemoteServiceServlet
 * to do exactly that.
 *
 * When GWT 1.4 comes along to save us all, this class can be retired and the
 * RPC utility class (as contributed by Rob Jellinghaus) can be used instead.
 *
 * @author Michael Neale (@hacker)
 */
058: public class GWTRemoteServiceServlet extends HttpServlet {
059:
060: /*
061: * These members are used to get and set the different HttpServletResponse and
062: * HttpServletRequest headers.
063: */
064: private static final String ACCEPT_ENCODING = "Accept-Encoding";
065: private static final String CHARSET_UTF8 = "UTF-8";
066: private static final String CONTENT_ENCODING = "Content-Encoding";
067: private static final String CONTENT_ENCODING_GZIP = "gzip";
068: private static final String CONTENT_TYPE_TEXT_PLAIN_UTF8 = "text/plain; charset=utf-8";
069: private static final String GENERIC_FAILURE_MSG = "An error occurred on the server. Consult server log for details.";
070: private static final HashMap TYPE_NAMES;
071:
072: /**
073: * Controls the compression threshold at and below which no compression will
074: * take place.
075: */
076: private static final int UNCOMPRESSED_BYTE_SIZE_LIMIT = 256;
077:
078: static {
079: TYPE_NAMES = new HashMap();
080: TYPE_NAMES.put("Z", boolean.class);
081: TYPE_NAMES.put("B", byte.class);
082: TYPE_NAMES.put("C", char.class);
083: TYPE_NAMES.put("D", double.class);
084: TYPE_NAMES.put("F", float.class);
085: TYPE_NAMES.put("I", int.class);
086: TYPE_NAMES.put("J", long.class);
087: TYPE_NAMES.put("S", short.class);
088: }
089:
090: /**
091: * Return true if the response object accepts Gzip encoding. This is done by
092: * checking that the accept-encoding header specifies gzip as a supported
093: * encoding.
094: */
095: private static boolean acceptsGzipEncoding(
096: HttpServletRequest request) {
097: assert (request != null);
098:
099: String acceptEncoding = request.getHeader(ACCEPT_ENCODING);
100: if (null == acceptEncoding) {
101: return false;
102: }
103:
104: return (acceptEncoding.indexOf(CONTENT_ENCODING_GZIP) != -1);
105: }
106:
107: /**
108: * This method attempts to estimate the number of bytes that a string will
109: * consume when it is sent out as part of an HttpServletResponse. This really
110: * a hack since we are assuming that every character will consume two bytes
111: * upon transmission. This is definitely not true since some characters
112: * actually consume more than two bytes and some consume less. This is even
113: * less accurate if the string is converted to UTF8. However, it does save us
114: * from converting every string that we plan on sending back to UTF8 just to
115: * determine that we should not compress it.
116: */
117: private static int estimateByteSize(final String buffer) {
118: return (buffer.length() * 2);
119: }
120:
121: /**
122: * Find the invoked method on either the specified interface or any super.
123: */
124: private static Method findInterfaceMethod(Class intf,
125: String methodName, Class[] paramTypes,
126: boolean includeInherited) {
127: try {
128: return intf.getDeclaredMethod(methodName, paramTypes);
129: } catch (NoSuchMethodException e) {
130: if (includeInherited) {
131: Class[] super intfs = intf.getInterfaces();
132: for (int i = 0; i < super intfs.length; i++) {
133: Method method = findInterfaceMethod(super intfs[i],
134: methodName, paramTypes, true);
135: if (method != null) {
136: return method;
137: }
138: }
139: }
140:
141: return null;
142: }
143: }
144:
145: private final Set knownImplementedInterfaces = new HashSet();
146:
147: private final ThreadLocal perThreadRequest = new ThreadLocal();
148:
149: private final ThreadLocal perThreadResponse = new ThreadLocal();
150:
151: private final ServerSerializableTypeOracle serializableTypeOracle;
152:
153: /**
154: * The default constructor.
155: */
156: public GWTRemoteServiceServlet() {
157: serializableTypeOracle = new ServerSerializableTypeOracleImpl(
158: getPackagePaths());
159: }
160:
161: /**
162: * This is called internally.
163: */
164: public final void doPost(HttpServletRequest request,
165: HttpServletResponse response) {
166: Throwable caught;
167: try {
168: // Store the request & response objects in thread-local storage.
169: //
170: perThreadRequest.set(request);
171: perThreadResponse.set(response);
172:
173: // Read the request fully.
174: //
175: String requestPayload = readPayloadAsUtf8(request);
176:
177: // Invoke the core dispatching logic, which returns the serialized
178: // result.
179: //
180: String responsePayload = processCall(requestPayload);
181:
182: // Write the response.
183: //
184: writeResponse(request, response, responsePayload);
185: return;
186: } catch (IOException e) {
187: caught = e;
188: } catch (ServletException e) {
189: caught = e;
190: } catch (SerializationException e) {
191: caught = e;
192: } catch (Throwable e) {
193: caught = e;
194: }
195:
196: respondWithFailure(response, caught);
197: }
198:
199: /**
200: * This is public so that it can be unit tested easily without HTTP.
201: */
202: public String processCall(String payload)
203: throws SerializationException {
204:
205: // Let subclasses see the serialized request.
206: //
207: onBeforeRequestDeserialized(payload);
208:
209: // Create a stream to deserialize the request.
210: //
211: ServerSerializationStreamReader streamReader = new ServerSerializationStreamReader(
212: serializableTypeOracle);
213: streamReader.prepareToRead(payload);
214:
215: // Read the service interface
216: //
217: String serviceIntfName = streamReader.readString();
218:
219: // // TODO(mmendez): need a way to check the type signature of the service intf
220: // // Verify that this very servlet implements the specified interface name.
221: // //
222: // if (!isImplementedRemoteServiceInterface(serviceIntfName)) {
223: // // Bad payload, possible hack attempt.
224: // //
225: // throw new SecurityException(
226: // "Blocked attempt to access interface '"
227: // + serviceIntfName
228: // + "', which is either not implemented by this servlet or which doesn't extend RemoteService; this is either misconfiguration or a hack attempt");
229: // }
230:
231: // Actually get the service interface, so that we can query its methods.
232: //
233: // Class serviceIntf;
234: // try {
235: // serviceIntf = getClassFromName(serviceIntfName);
236: // } catch (ClassNotFoundException e) {
237: // throw new SerializationException("Unknown service interface class '"
238: // + serviceIntfName + "'", e);
239: // }
240:
241: // Read the method name.
242: //
243: String methodName = streamReader.readString();
244:
245: // Read the number and names of the parameter classes from the stream.
246: // We have to do this so that we can find the correct overload of the
247: // method.
248: //
249: int paramCount = streamReader.readInt();
250: Class[] paramTypes = new Class[paramCount];
251: for (int i = 0; i < paramTypes.length; i++) {
252: String paramClassName = streamReader.readString();
253: try {
254: paramTypes[i] = getClassOrPrimitiveFromName(paramClassName);
255: } catch (ClassNotFoundException e) {
256: throw new SerializationException("Unknown parameter "
257: + i + " type '" + paramClassName + "'", e);
258: }
259: }
260:
261: // // For security, make sure the method is found in the service interface
262: // // and not just one that happens to be defined on this class.
263: // //
264: // Method serviceIntfMethod = findInterfaceMethod(serviceIntf, methodName,
265: // paramTypes, true);
266:
267: // // If it wasn't found, don't continue.
268: // //
269: // if (serviceIntfMethod == null) {
270: // // Bad payload, possible hack attempt.
271: // //
272: // throw new SecurityException(
273: // "Method '"
274: // + methodName
275: // + "' (or a particular overload) on interface '"
276: // + serviceIntfName
277: // + "' was not found, this is either misconfiguration or a hack attempt");
278: // }
279:
280: // Deserialize the parameters.
281: //
282: Object[] args = new Object[paramCount];
283: for (int i = 0; i < args.length; i++) {
284: args[i] = streamReader.deserializeValue(paramTypes[i]);
285: }
286:
287: GWTToSeamAdapter adapter = new GWTToSeamAdapter();
288:
289: // Make the call via reflection.
290: //
291: String responsePayload = GENERIC_FAILURE_MSG;
292: ServerSerializationStreamWriter streamWriter = new ServerSerializationStreamWriter(
293: serializableTypeOracle);
294: Throwable caught = null;
295: try {
296:
297: //call the component
298: ReturnedObject returnedObject = adapter
299: .callWebRemoteMethod(serviceIntfName, methodName,
300: paramTypes, args);
301: Class returnType = returnedObject.returnType;
302: Object returnVal = returnedObject.returnedObject;
303: // Class returnType = serviceIntfMethod.getReturnType();
304: // Object returnVal = serviceIntfMethod.invoke(this, args);
305: responsePayload = createResponse(streamWriter, returnType,
306: returnVal, false);
307: } catch (IllegalArgumentException e) {
308: caught = e;
309: } catch (IllegalAccessException e) {
310: caught = e;
311: } catch (InvocationTargetException e) {
312: // Try to serialize the caught exception if the client is expecting it,
313: // otherwise log the exception server-side.
314: caught = e;
315: Throwable cause = e.getCause();
316: if (cause != null) {
317: // Update the caught exception to the underlying cause
318: caught = cause;
319: // Serialize the exception back to the client if it's a declared
320: // exception
321: if (cause instanceof SerializableException) {
322: Class thrownClass = cause.getClass();
323: responsePayload = createResponse(streamWriter,
324: thrownClass, cause, true);
325: // Don't log the exception on the server
326: caught = null;
327: }
328: }
329: }
330:
331: if (caught != null) {
332: responsePayload = GENERIC_FAILURE_MSG;
333: ServletContext servletContext = getServletContext();
334: // servletContext may be null (for example, when unit testing)
335: if (servletContext != null) {
336: // Log the exception server side
337: servletContext
338: .log(
339: "Exception while dispatching incoming RPC call",
340: caught);
341: }
342: }
343:
344: // Let subclasses see the serialized response.
345: //
346: onAfterResponseSerialized(responsePayload);
347:
348: return responsePayload;
349: }
350:
351: /**
352: * Gets the <code>HttpServletRequest</code> object for the current call. It
353: * is stored thread-locally so that simultaneous invocations can have
354: * different request objects.
355: */
356: protected final HttpServletRequest getThreadLocalRequest() {
357: return (HttpServletRequest) perThreadRequest.get();
358: }
359:
360: /**
361: * Gets the <code>HttpServletResponse</code> object for the current call. It
362: * is stored thread-locally so that simultaneous invocations can have
363: * different response objects.
364: */
365: protected final HttpServletResponse getThreadLocalResponse() {
366: return (HttpServletResponse) perThreadResponse.get();
367: }
368:
369: /**
370: * Override this method to examine the serialized response that will be
371: * returned to the client. The default implementation does nothing and need
372: * not be called by subclasses.
373: */
374: protected void onAfterResponseSerialized(String serializedResponse) {
375: }
376:
377: /**
378: * Override this method to examine the serialized version of the request
379: * payload before it is deserialized into objects. The default implementation
380: * does nothing and need not be called by subclasses.
381: */
382: protected void onBeforeRequestDeserialized(String serializedRequest) {
383: }
384:
385: /**
386: * Determines whether the response to a given servlet request should or should
387: * not be GZIP compressed. This method is only called in cases where the
388: * requestor accepts GZIP encoding.
389: * <p>
390: * This implementation currently returns <code>true</code> if the response
391: * string's estimated byte length is longer than 256 bytes. Subclasses can
392: * override this logic.
393: * </p>
394: *
395: * @param request the request being served
396: * @param response the response that will be written into
397: * @param responsePayload the payload that is about to be sent to the client
398: * @return <code>true</code> if responsePayload should be GZIP compressed,
399: * otherwise <code>false</code>.
400: */
401: protected boolean shouldCompressResponse(
402: HttpServletRequest request, HttpServletResponse response,
403: String responsePayload) {
404: return estimateByteSize(responsePayload) > UNCOMPRESSED_BYTE_SIZE_LIMIT;
405: }
406:
407: /**
408: * @param stream
409: * @param responseType
410: * @param responseObj
411: * @param isException
412: * @return response
413: */
414: private String createResponse(
415: ServerSerializationStreamWriter stream, Class responseType,
416: Object responseObj, boolean isException) {
417: stream.prepareToWrite();
418: if (responseType != void.class) {
419: try {
420: stream.serializeValue(responseObj, responseType);
421: } catch (SerializationException e) {
422: responseObj = e;
423: isException = true;
424: }
425: }
426:
427: String bufferStr = (isException ? "{EX}" : "{OK}")
428: + stream.toString();
429: return bufferStr;
430: }
431:
432: /**
433: * Returns the {@link Class} instance for the named class.
434: *
435: * @param name the name of a class or primitive type
436: * @return Class instance for the given type name
437: * @throws ClassNotFoundException if the named type was not found
438: */
439: private Class getClassFromName(String name)
440: throws ClassNotFoundException {
441: return Class.forName(name, false, this .getClass()
442: .getClassLoader());
443: }
444:
445: /**
446: * Returns the {@link Class} instance for the named class or primitive type.
447: *
448: * @param name the name of a class or primitive type
449: * @return Class instance for the given type name
450: * @throws ClassNotFoundException if the named type was not found
451: */
452: private Class getClassOrPrimitiveFromName(String name)
453: throws ClassNotFoundException {
454: Object value = TYPE_NAMES.get(name);
455: if (value != null) {
456: return (Class) value;
457: }
458:
459: return getClassFromName(name);
460: }
461:
462: /**
463: * Obtain the special package-prefixes we use to check for custom serializers
464: * that would like to live in a package that they cannot. For example,
465: * "java.util.ArrayList" is in a sealed package, so instead we use this prefix
466: * to check for a custom serializer in
467: * "com.google.gwt.user.client.rpc.core.java.util.ArrayList". Right now, it's
468: * hard-coded because we don't have a pressing need for this mechanism to be
469: * extensible, but it is imaginable, which is why it's implemented this way.
470: */
471: private String[] getPackagePaths() {
472: return new String[] { "com.google.gwt.user.client.rpc.core" };
473: }
474:
475: /**
476: * Returns true if the {@link java.lang.reflect.Method Method} definition on
477: * the service is specified to throw the exception contained in the
478: * InvocationTargetException or false otherwise. NOTE we do not check that the
479: * type is serializable here. We assume that it must be otherwise the
480: * application would never have been allowed to run.
481: *
482: * @param serviceIntfMethod
483: * @param e
484: * @return is expected exception
485: */
486: private boolean isExpectedException(Method serviceIntfMethod,
487: Throwable cause) {
488: assert (serviceIntfMethod != null);
489: assert (cause != null);
490:
491: Class[] exceptionsThrown = serviceIntfMethod
492: .getExceptionTypes();
493: if (exceptionsThrown.length <= 0) {
494: // The method is not specified to throw any exceptions
495: //
496: return false;
497: }
498:
499: Class causeType = cause.getClass();
500:
501: for (int index = 0; index < exceptionsThrown.length; ++index) {
502: Class exceptionThrown = exceptionsThrown[index];
503: assert (exceptionThrown != null);
504:
505: if (exceptionThrown.isAssignableFrom(causeType)) {
506: return true;
507: }
508: }
509:
510: return false;
511: }
512:
513: /**
514: * Used to determine whether the specified interface name is implemented by
515: * this class without loading the class (for security).
516: */
517: private boolean isImplementedRemoteServiceInterface(String intfName) {
518: synchronized (knownImplementedInterfaces) {
519: // See if it's cached.
520: //
521: if (knownImplementedInterfaces.contains(intfName)) {
522: return true;
523: }
524:
525: Class cls = getClass();
526:
527: // Unknown, so walk up the class hierarchy to find the first class that
528: // implements the requested interface
529: //
530: while ((cls != null)
531: && !GWTRemoteServiceServlet.class.equals(cls)) {
532: Class[] intfs = cls.getInterfaces();
533: for (int i = 0; i < intfs.length; i++) {
534: Class intf = intfs[i];
535: if (isImplementedRemoteServiceInterfaceRecursive(
536: intfName, intf)) {
537: knownImplementedInterfaces.add(intfName);
538: return true;
539: }
540: }
541:
542: // did not find the interface in this class so we look in the
543: // superclass
544: cls = cls.getSuperclass();
545: }
546:
547: return false;
548: }
549: }
550:
551: /**
552: * Only called from isImplementedInterface().
553: */
554: private boolean isImplementedRemoteServiceInterfaceRecursive(
555: String intfName, Class intfToCheck) {
556: assert (intfToCheck.isInterface());
557:
558: if (intfToCheck.getName().equals(intfName)) {
559: // The name is right, but we also verify that it is assignable to
560: // RemoteService.
561: //
562: if (RemoteService.class.isAssignableFrom(intfToCheck)) {
563: return true;
564: } else {
565: return false;
566: }
567: }
568:
569: Class[] intfs = intfToCheck.getInterfaces();
570: for (int i = 0; i < intfs.length; i++) {
571: Class intf = intfs[i];
572: if (isImplementedRemoteServiceInterfaceRecursive(intfName,
573: intf)) {
574: return true;
575: }
576: }
577:
578: return false;
579: }
580:
581: private String readPayloadAsUtf8(HttpServletRequest request)
582: throws IOException, ServletException {
583: int contentLength = request.getContentLength();
584: if (contentLength == -1) {
585: // Content length must be known.
586: throw new ServletException(
587: "Content-Length must be specified");
588: }
589:
590: String contentType = request.getContentType();
591: boolean contentTypeIsOkay = false;
592: // Content-Type must be specified.
593: if (contentType != null) {
594: // The type must be plain text.
595: if (contentType.startsWith("text/plain")) {
596: // And it must be UTF-8 encoded (or unspecified, in which case we assume
597: // that it's either UTF-8 or ASCII).
598: if (contentType.indexOf("charset=") == -1) {
599: contentTypeIsOkay = true;
600: } else if (contentType.indexOf("charset=utf-8") != -1) {
601: contentTypeIsOkay = true;
602: }
603: }
604: }
605: if (!contentTypeIsOkay) {
606: throw new ServletException(
607: "Content-Type must be 'text/plain' with 'charset=utf-8' (or unspecified charset)");
608: }
609: InputStream in = request.getInputStream();
610: try {
611: byte[] payload = new byte[contentLength];
612: int offset = 0;
613: int len = contentLength;
614: int byteCount;
615: while (offset < contentLength) {
616: byteCount = in.read(payload, offset, len);
617: if (byteCount == -1) {
618: throw new ServletException("Client did not send "
619: + contentLength + " bytes as expected");
620: }
621: offset += byteCount;
622: len -= byteCount;
623: }
624: return new String(payload, "UTF-8");
625: } finally {
626: if (in != null) {
627: in.close();
628: }
629: }
630: }
631:
632: /**
633: * Called when the machinery of this class itself has a problem, rather than
634: * the invoked third-party method. It writes a simple 500 message back to the
635: * client.
636: */
637: private void respondWithFailure(HttpServletResponse response,
638: Throwable caught) {
639: ServletContext servletContext = getServletContext();
640: servletContext
641: .log("Exception while dispatching incoming RPC call",
642: caught);
643: try {
644: response.setContentType("text/plain");
645: response
646: .setStatus(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
647: response.getWriter().write(GENERIC_FAILURE_MSG);
648: } catch (IOException e) {
649: servletContext
650: .log(
651: "sendError() failed while sending the previous failure to the client",
652: caught);
653: }
654: }
655:
656: private void writeResponse(HttpServletRequest request,
657: HttpServletResponse response, String responsePayload)
658: throws IOException {
659:
660: byte[] reply = responsePayload.getBytes(CHARSET_UTF8);
661: String contentType = CONTENT_TYPE_TEXT_PLAIN_UTF8;
662:
663: if (acceptsGzipEncoding(request)
664: && shouldCompressResponse(request, response,
665: responsePayload)) {
666: // Compress the reply and adjust headers.
667: //
668: ByteArrayOutputStream output = null;
669: GZIPOutputStream gzipOutputStream = null;
670: Throwable caught = null;
671: try {
672: output = new ByteArrayOutputStream(reply.length);
673: gzipOutputStream = new GZIPOutputStream(output);
674: gzipOutputStream.write(reply);
675: gzipOutputStream.finish();
676: gzipOutputStream.flush();
677: response.setHeader(CONTENT_ENCODING,
678: CONTENT_ENCODING_GZIP);
679: reply = output.toByteArray();
680: } catch (UnsupportedEncodingException e) {
681: caught = e;
682: } catch (IOException e) {
683: caught = e;
684: } finally {
685: if (null != gzipOutputStream) {
686: gzipOutputStream.close();
687: }
688: if (null != output) {
689: output.close();
690: }
691: }
692:
693: if (caught != null) {
694: getServletContext().log("Unable to compress response",
695: caught);
696: response
697: .sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR);
698: return;
699: }
700: }
701:
702: // Send the reply.
703: //
704: response.setContentLength(reply.length);
705: response.setContentType(contentType);
706: response.setStatus(HttpServletResponse.SC_OK);
707: response.getOutputStream().write(reply);
708: }
709: }
|