001: /*
002: * Copyright (c) 1998-2007 Caucho Technology -- all rights reserved
003: *
004: * This file is part of Resin(R) Open Source
005: *
006: * Each copy or derived work must preserve the copyright notice and this
007: * notice unmodified.
008: *
009: * Resin Open Source is free software; you can redistribute it and/or modify
010: * it under the terms of the GNU General Public License as published by
011: * the Free Software Foundation; either version 2 of the License, or
012: * (at your option) any later version.
013: *
014: * Resin Open Source is distributed in the hope that it will be useful,
015: * but WITHOUT ANY WARRANTY; without even the implied warranty of
016: * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE, or any warranty
017: * of NON-INFRINGEMENT. See the GNU General Public License for more
018: * details.
019: *
020: * You should have received a copy of the GNU General Public License
021: * along with Resin Open Source; if not, write to the
022: *
023: * Free Software Foundation, Inc.
024: * 59 Temple Place, Suite 330
025: * Boston, MA 02111-1307 USA
026: *
027: * @author Emil Ong
028: */
029:
030: package com.caucho.soa.rest;
031:
import com.caucho.server.util.CauchoSystem;
import com.caucho.soa.servlet.ProtocolServlet;
import com.caucho.vfs.Vfs;
import com.caucho.vfs.WriteStream;
import com.caucho.xml.stream.XMLStreamWriterImpl;

import javax.jws.WebMethod;
import javax.jws.WebParam;
import javax.jws.WebService;
import javax.servlet.GenericServlet;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import javax.xml.bind.JAXBContext;
import javax.xml.bind.JAXBException;
import javax.xml.bind.Marshaller;
import javax.xml.bind.Unmarshaller;
import javax.xml.namespace.QName;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.UnsupportedEncodingException;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.net.URLDecoder;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import java.util.logging.Logger;
062:
063: /**
064: * A binding for REST services.
065: */
066: public abstract class RestProtocolServlet extends GenericServlet
067: implements ProtocolServlet {
068: private static final Logger log = Logger
069: .getLogger(RestProtocolServlet.class.getName());
070:
071: public static final String DELETE = "DELETE";
072: public static final String GET = "GET";
073: public static final String HEAD = "HEAD";
074: public static final String POST = "POST";
075: public static final String PUT = "PUT";
076:
077: private Map<String, Map<String, Method>> _methods = new HashMap<String, Map<String, Method>>();
078:
079: private HashMap<String, Method> _defaultMethods = new HashMap<String, Method>();
080:
081: protected Object _service;
082:
083: public RestProtocolServlet() {
084: }
085:
086: public void setService(Object service) {
087: _service = service;
088: }
089:
090: public void init() throws ServletException {
091: try {
092: Class cl = _service.getClass();
093:
094: if (cl.isAnnotationPresent(WebService.class)) {
095: WebService webService = (WebService) cl
096: .getAnnotation(WebService.class);
097:
098: String endpoint = webService.endpointInterface();
099:
100: if (endpoint != null && !"".equals(endpoint))
101: cl = CauchoSystem.loadClass(webService
102: .endpointInterface());
103: }
104:
105: _methods.put(DELETE, new HashMap<String, Method>());
106: _methods.put(GET, new HashMap<String, Method>());
107: _methods.put(HEAD, new HashMap<String, Method>());
108: _methods.put(POST, new HashMap<String, Method>());
109: _methods.put(PUT, new HashMap<String, Method>());
110:
111: for (Method method : cl.getMethods()) {
112: if (method.getDeclaringClass().equals(Object.class))
113: continue;
114:
115: int modifiers = method.getModifiers();
116:
117: // Allow abstract for interfaces
118: if (Modifier.isStatic(modifiers)
119: || Modifier.isFinal(modifiers)
120: || !Modifier.isPublic(modifiers))
121: continue;
122:
123: String methodName = method.getName();
124:
125: if (method.isAnnotationPresent(WebMethod.class)) {
126: WebMethod webMethod = (WebMethod) method
127: .getAnnotation(WebMethod.class);
128:
129: if (!"".equals(webMethod.operationName()))
130: methodName = webMethod.operationName();
131: }
132:
133: if (method.isAnnotationPresent(RestMethod.class)) {
134: RestMethod restMethod = (RestMethod) method
135: .getAnnotation(RestMethod.class);
136:
137: if (!"".equals(restMethod.operationName()))
138: methodName = restMethod.operationName();
139: }
140:
141: boolean hasHTTPMethod = false;
142:
143: if (method.isAnnotationPresent(Delete.class)) {
144: if (_methods.get(DELETE).containsKey(methodName)) {
145: throw new UnsupportedOperationException(
146: "Overloaded method: "
147: + method.getName());
148: }
149:
150: _methods.get(DELETE).put(methodName, method);
151:
152: hasHTTPMethod = true;
153: }
154:
155: if (method.isAnnotationPresent(Get.class)) {
156: if (_methods.get(GET).containsKey(methodName)) {
157: throw new UnsupportedOperationException(
158: "Overloaded method: "
159: + method.getName());
160: }
161:
162: _methods.get(GET).put(methodName, method);
163:
164: hasHTTPMethod = true;
165: }
166:
167: if (method.isAnnotationPresent(Post.class)) {
168: if (_methods.get(POST).containsKey(methodName)) {
169: throw new UnsupportedOperationException(
170: "Overloaded method: "
171: + method.getName());
172: }
173:
174: _methods.get(POST).put(methodName, method);
175:
176: hasHTTPMethod = true;
177: }
178:
179: if (method.isAnnotationPresent(Put.class)) {
180: if (_methods.get(PUT).containsKey(methodName)) {
181: throw new UnsupportedOperationException(
182: "Overloaded method: "
183: + method.getName());
184: }
185:
186: _methods.get(PUT).put(methodName, method);
187:
188: hasHTTPMethod = true;
189: }
190:
191: if (method.isAnnotationPresent(Head.class)) {
192: if (_methods.get(HEAD).containsKey(methodName)) {
193: throw new UnsupportedOperationException(
194: "Overloaded method: "
195: + method.getName());
196: }
197:
198: _methods.get(HEAD).put(methodName, method);
199:
200: hasHTTPMethod = true;
201: }
202:
203: if (!hasHTTPMethod) {
204: if (_defaultMethods.containsKey(methodName)) {
205: throw new UnsupportedOperationException(
206: "Overloaded method: "
207: + method.getName());
208: }
209:
210: _defaultMethods.put(methodName, method);
211: }
212: }
213: } catch (Exception e) {
214: throw new ServletException(e);
215: }
216: }
217:
218: public void service(ServletRequest request, ServletResponse response)
219: throws ServletException, IOException {
220: HttpServletRequest req = (HttpServletRequest) request;
221: HttpServletResponse res = (HttpServletResponse) response;
222:
223: Map<String, String> queryArguments = new HashMap<String, String>();
224:
225: if (req.getQueryString() != null)
226: queryToMap(req.getQueryString(), queryArguments);
227:
228: String[] pathArguments = null;
229:
230: if (req.getPathInfo() != null) {
231: String pathInfo = req.getPathInfo();
232:
233: // remove the initial and final slashes
234: int startPos = 0;
235: int endPos = pathInfo.length();
236:
237: if (pathInfo.length() > 0 && pathInfo.charAt(0) == '/')
238: startPos = 1;
239:
240: if (pathInfo.length() > startPos
241: && pathInfo.charAt(pathInfo.length() - 1) == '/')
242: endPos = pathInfo.length() - 1;
243:
244: pathInfo = pathInfo.substring(startPos, endPos);
245:
246: pathArguments = pathInfo.split("/");
247:
248: if (pathArguments.length == 1
249: && pathArguments[0].length() == 0)
250: pathArguments = new String[0];
251: } else
252: pathArguments = new String[0];
253:
254: try {
255: invoke(_service, req.getMethod(), pathArguments,
256: queryArguments, req, req.getInputStream(), res
257: .getOutputStream());
258: } catch (NoSuchMethodException e) {
259: res.sendError(HttpServletResponse.SC_BAD_REQUEST);
260: } catch (Throwable e) {
261: throw new ServletException(e);
262: }
263: }
264:
265: private static void queryToMap(String query,
266: Map<String, String> queryArguments) {
267: String[] entries = query.split("&");
268:
269: for (String entry : entries) {
270: if (entry.indexOf("=") < 0)
271: continue;
272:
273: String[] nameValue = entry.split("=", 2);
274:
275: queryArguments.put(nameValue[0], nameValue[1]);
276: }
277: }
278:
279: private void invoke(Object object, String httpMethod,
280: String[] pathArguments, Map<String, String> queryArguments,
281: HttpServletRequest req, InputStream postData,
282: OutputStream out) throws Throwable {
283: int pathIndex = 0;
284: boolean pathMethod = false;
285:
286: // Two special approaches: path and query
287: //
288: // Path takes the first part of the path as the method name
289: //
290: // Query checks for /?method=myMethod in the query part
291: //
292: // Query overrides path since it's more explicit
293:
294: String methodName = queryArguments.get("method");
295:
296: if ((methodName == null) && (pathArguments.length > 0)) {
297: methodName = pathArguments[0];
298:
299: if (methodName != null)
300: pathMethod = true;
301: }
302:
303: // First, look by http method and method name
304: // This may hit the default method since methodName can be null
305: Method method = _methods.get(httpMethod).get(methodName);
306:
307: // next, check for a default method, ignoring http method
308: if (method == null)
309: method = _defaultMethods.get(methodName);
310:
311: // finally, check for a completely default method
312: if (method == null) {
313: method = _defaultMethods.get(null);
314:
315: pathMethod = false;
316: }
317:
318: if (method == null)
319: throw new NoSuchMethodException(methodName);
320:
321: if (pathMethod)
322: pathIndex = 1;
323:
324: // Construct the arguments for the invocation
325: ArrayList arguments = new ArrayList();
326:
327: Class[] parameterTypes = method.getParameterTypes();
328: Annotation[][] annotations = method.getParameterAnnotations();
329:
330: for (int i = 0; i < parameterTypes.length; i++) {
331: RestParam.Source source = RestParam.Source.QUERY;
332: String key = "arg" + i;
333:
334: for (int j = 0; j < annotations[i].length; j++) {
335: if (annotations[i][j].annotationType().equals(
336: RestParam.class)) {
337: RestParam restParam = (RestParam) annotations[i][j];
338: source = restParam.source();
339: } else if (annotations[i][j].annotationType().equals(
340: WebParam.class)) {
341: WebParam webParam = (WebParam) annotations[i][j];
342:
343: if (!"".equals(webParam.name()))
344: key = webParam.name();
345: }
346: }
347:
348: switch (source) {
349: case PATH: {
350: String arg = null;
351:
352: if (pathIndex < pathArguments.length)
353: arg = pathArguments[pathIndex++];
354:
355: arguments.add(stringToType(parameterTypes[i], arg));
356: // XXX var args
357: }
358: break;
359: case QUERY:
360: arguments.add(stringToType(parameterTypes[i],
361: queryArguments.get(key)));
362: break;
363: case POST:
364: arguments.add(readPostData(postData));
365: break;
366: case HEADER:
367: arguments.add(stringToType(parameterTypes[i], req
368: .getHeader(key)));
369: break;
370: }
371: }
372:
373: Object result = method.invoke(object, arguments.toArray());
374:
375: if (result != null)
376: writeResponse(out, result);
377: }
378:
379: protected abstract Object readPostData(InputStream in)
380: throws IOException, RestException;
381:
382: protected abstract void writeResponse(OutputStream out, Object obj)
383: throws IOException, RestException;
384:
385: private static Object stringToType(Class type, String arg)
386: throws Throwable {
387: if (arg == null) {
388: return null;
389: } else if (type.equals(boolean.class)) {
390: return new Boolean(arg);
391: } else if (type.equals(Boolean.class)) {
392: return new Boolean(arg);
393: } else if (type.equals(byte.class)) {
394: return new Byte(arg);
395: } else if (type.equals(Byte.class)) {
396: return new Byte(arg);
397: } else if (type.equals(char.class)) {
398: if (arg.length() != 1) {
399: throw new IllegalArgumentException(
400: "Cannot convert String to type "
401: + type.getName());
402: }
403:
404: return new Character(arg.charAt(0));
405: } else if (type.equals(Character.class)) {
406: if (arg.length() != 1) {
407: throw new IllegalArgumentException(
408: "Cannot convert String to type "
409: + type.getName());
410: }
411:
412: return new Character(arg.charAt(0));
413: } else if (type.equals(double.class)) {
414: return new Double(arg);
415: } else if (type.equals(Double.class)) {
416: return new Double(arg);
417: } else if (type.equals(float.class)) {
418: return new Float(arg);
419: } else if (type.equals(Float.class)) {
420: return new Float(arg);
421: } else if (type.equals(int.class)) {
422: return new Integer(arg);
423: } else if (type.equals(Integer.class)) {
424: return new Integer(arg);
425: } else if (type.equals(long.class)) {
426: return new Long(arg);
427: } else if (type.equals(Long.class)) {
428: return new Long(arg);
429: } else if (type.equals(short.class)) {
430: return new Short(arg);
431: } else if (type.equals(Short.class)) {
432: return new Short(arg);
433: } else if (type.equals(String.class)) {
434: return arg;
435: } else
436: throw new IllegalArgumentException(
437: "Cannot convert String to type " + type.getName());
438: }
439: }
|