package org.zstack.core.rest; import org.apache.http.HttpStatus; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.http.*; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.web.client.HttpClientErrorException; import org.springframework.web.client.RestClientException; import org.springframework.web.client.RestTemplate; import org.springframework.web.util.UriComponentsBuilder; import org.zstack.core.CoreGlobalProperty; import org.zstack.core.MessageCommandRecorder; import org.zstack.core.Platform; import org.zstack.core.errorcode.ErrorFacade; import org.zstack.core.retry.Retry; import org.zstack.core.retry.RetryCondition; import org.zstack.core.thread.AsyncThread; import org.zstack.core.thread.CancelablePeriodicTask; import org.zstack.core.thread.ThreadFacade; import org.zstack.core.thread.ThreadFacadeImpl.TimeoutTaskReceipt; import org.zstack.core.timeout.ApiTimeoutManager; import org.zstack.core.validation.ValidationFacade; import org.zstack.header.core.Completion; import org.zstack.header.errorcode.ErrorCode; import org.zstack.header.errorcode.OperationFailureException; import org.zstack.header.errorcode.SysErrors; import org.zstack.header.exception.CloudRuntimeException; import org.zstack.header.rest.*; import org.zstack.utils.DebugUtils; import org.zstack.utils.ExceptionDSL; import org.zstack.utils.IptablesUtils; import org.zstack.utils.Utils; import org.zstack.utils.gson.JSONObjectUtil; import org.zstack.utils.logging.CLogger; import static org.zstack.core.Platform.operr; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; import java.io.IOException; import java.util.ArrayList; import java.util.Enumeration; import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; public class RESTFacadeImpl implements RESTFacade { private static final CLogger logger = Utils.getLogger(RESTFacadeImpl.class); @Autowired private ThreadFacade thdf; @Autowired private ErrorFacade errf; @Autowired private ApiTimeoutManager timeoutMgr; @Autowired private ValidationFacade vf; private String hostname; private int port = 8080; private String path; private String callbackUrl; private RestTemplate template; private String baseUrl; private String sendCommandUrl; private Map<String, HttpCallStatistic> statistics = new ConcurrentHashMap<String, HttpCallStatistic>(); private Map<String, HttpCallHandlerWrapper> httpCallhandlers = new ConcurrentHashMap<String, HttpCallHandlerWrapper>(); private List<BeforeAsyncJsonPostInterceptor> interceptors = new ArrayList<BeforeAsyncJsonPostInterceptor>(); private interface AsyncHttpWrapper { void fail(ErrorCode err); void success(HttpEntity<String> responseEntity); } private interface HttpCallHandlerWrapper { String handle(HttpEntity<String> entity); HttpCallHandler getHandler(); } private Map<String, AsyncHttpWrapper> wrappers = new ConcurrentHashMap<String, AsyncHttpWrapper>(); void init() { IptablesUtils.insertRuleToFilterTable(String.format("-A INPUT -p tcp -m state --state NEW -m tcp --dport %s -j ACCEPT", port)); String hname = null; if ("AUTO".equals(hostname)) { hname = Platform.getManagementServerIp(); } else { hname = hostname; } String url; if ("".equals(path) || path == null) { url = String.format("http://%s:%s", hname, port); } else { url = String.format("http://%s:%s/%s", hname, port, path); } UriComponentsBuilder ub = UriComponentsBuilder.fromHttpUrl(url); ub.path(RESTConstant.CALLBACK_PATH); callbackUrl = ub.build().toUriString(); ub = UriComponentsBuilder.fromHttpUrl(url); baseUrl = ub.build().toUriString(); ub = UriComponentsBuilder.fromHttpUrl(url); ub.path(RESTConstant.COMMAND_CHANNEL_PATH); sendCommandUrl = ub.build().toUriString(); logger.debug(String.format("RESTFacade built callback url: %s", callbackUrl)); template = RESTFacade.createRestTemplate(CoreGlobalProperty.REST_FACADE_READ_TIMEOUT, CoreGlobalProperty.REST_FACADE_CONNECT_TIMEOUT); } void notifyCallback(HttpServletRequest req, HttpServletResponse rsp) { String taskUuid = req.getHeader(RESTConstant.TASK_UUID); try { HttpEntity<String> entity = this.httpServletRequestToHttpEntity(req); if (taskUuid == null) { rsp.sendError(HttpStatus.SC_BAD_REQUEST, "No 'taskUuid' found in the header"); logger.warn(String.format("Received a callback request, but no 'taskUuid' found in headers. request body: %s", entity.getBody())); return; } AsyncHttpWrapper wrapper = wrappers.get(taskUuid); if (wrapper == null) { rsp.sendError(HttpStatus.SC_NOT_FOUND, String.format("No callback found for taskUuid[%s]", taskUuid)); logger.warn(String.format("Received a callback request, but no 'callback found for taskUuid[%s]. request body: %s", taskUuid, entity.getBody())); return; } rsp.setStatus(HttpStatus.SC_OK); wrapper.success(entity); } catch (IOException e) { logger.warn(e.getMessage(), e); } catch (Throwable t) { try { rsp.sendError(HttpStatus.SC_INTERNAL_SERVER_ERROR, t.getMessage()); } catch (IOException e) { logger.warn(e.getMessage(), e); } } } void sendCommand(HttpServletRequest req, HttpServletResponse rsp) { String commandPath = req.getHeader(RESTConstant.COMMAND_PATH); try { HttpEntity<String> entity = this.httpServletRequestToHttpEntity(req); if (commandPath == null) { rsp.sendError(HttpStatus.SC_BAD_REQUEST, "No 'commandPath' found in the header"); logger.warn(String.format("Received a command, but no 'taskUuid' found in headers. request body: %s", entity.getBody())); return; } HttpCallHandlerWrapper handler = httpCallhandlers.get(commandPath); if (handler == null) { rsp.sendError(HttpStatus.SC_NOT_FOUND, String.format("no handler found for the command path[%s]", commandPath)); logger.warn(String.format("Received a command, but no handler found for the path[%s]. request body: %s", commandPath, entity.getBody())); return; } String ret = handler.handle(entity); if (ret == null) { rsp.setStatus(HttpStatus.SC_OK); } else { rsp.setStatus(HttpStatus.SC_OK, ret); } } catch (IOException e) { logger.warn(e.getMessage(), e); } catch (Throwable t) { logger.warn(t.getMessage(), t); try { rsp.sendError(HttpStatus.SC_INTERNAL_SERVER_ERROR, t.getMessage()); } catch (IOException e) { logger.warn(e.getMessage(), e); } } } public void setHostname(String hostname) { this.hostname = hostname; } public void setPort(int port) { this.port = port; } public void setPath(String path) { this.path = path; } @Override public void asyncJsonPost(String url, Object body, Map<String, String> headers, AsyncRESTCallback callback, TimeUnit unit, long timeout) { for (BeforeAsyncJsonPostInterceptor ic : interceptors) { ic.beforeAsyncJsonPost(url, body, unit, timeout); } // for unit test finding invocation chain MessageCommandRecorder.record(body.getClass()); String bodyStr = JSONObjectUtil.toJsonString(body); asyncJsonPost(url, bodyStr, headers, callback, unit, timeout); } @Override public void asyncJsonPost(String url, Object body, AsyncRESTCallback callback, TimeUnit unit, long timeout) { asyncJsonPost(url, body, null, callback, unit, timeout); } @Override public void asyncJsonPost(final String url, final String body, final AsyncRESTCallback callback, final TimeUnit unit, final long timeout) { asyncJsonPost(url, body, null, callback, unit, timeout); } @Override public void asyncJsonPost(final String url, final String body, Map<String, String> headers, final AsyncRESTCallback callback, final TimeUnit unit, final long timeout) { for (BeforeAsyncJsonPostInterceptor ic : interceptors) { ic.beforeAsyncJsonPost(url, body, unit, timeout); } long stime = 0; if (CoreGlobalProperty.PROFILER_HTTP_CALL) { stime = System.currentTimeMillis(); HttpCallStatistic stat = statistics.get(url); if (stat == null) { stat = new HttpCallStatistic(); stat.setUrl(url); statistics.put(url, stat); } } final String taskUuid = Platform.getUuid(); final long finalStime = stime; AsyncHttpWrapper wrapper = new AsyncHttpWrapper() { AtomicBoolean called = new AtomicBoolean(false); final AsyncHttpWrapper self = this; TimeoutTaskReceipt timeoutTaskReceipt = thdf.submitTimeoutTask(new Runnable() { @Override public void run() { self.fail(errf.stringToTimeoutError( String.format("[Async Http Timeout] url: %s, timeout after %s[%s], command: %s", url, timeout, unit.toString(), body) )); } }, unit, timeout); private void cancelTimeout() { timeoutTaskReceipt.cancel(); } public void fail(ErrorCode err) { if (!called.compareAndSet(false, true)) { return; } wrappers.remove(taskUuid); if (!SysErrors.TIMEOUT.toString().equals(err.getCode())) { cancelTimeout(); } callback.fail(err); } @Override @AsyncThread public void success(HttpEntity<String> responseEntity) { if (!called.compareAndSet(false, true)) { return; } if (CoreGlobalProperty.PROFILER_HTTP_CALL) { HttpCallStatistic stat = statistics.get(url); stat.addStatistic(System.currentTimeMillis() - finalStime); } wrappers.remove(taskUuid); cancelTimeout(); if (logger.isTraceEnabled()) { List<String> hs = responseEntity.getHeaders().get(RESTConstant.TASK_UUID); String taskUuid = hs == null || hs.isEmpty() ? null : hs.get(0); if (taskUuid == null) { logger.trace(String.format("[http response(url: %s)] %s", url, responseEntity.getBody())); } else { logger.trace(String.format("[http response(url: %s, taskUuid: %s)] %s", url, taskUuid, responseEntity.getBody())); } } if (callback instanceof JsonAsyncRESTCallback) { JsonAsyncRESTCallback jcallback = (JsonAsyncRESTCallback)callback; Object obj = JSONObjectUtil.toObject(responseEntity.getBody(), jcallback.getReturnClass()); try { ErrorCode err = vf.validateErrorByErrorCode(obj); if (err != null) { logger.warn(String.format("error response that causes validation failure: %s", responseEntity.getBody())); jcallback.fail(err); } else { jcallback.success(obj); } } catch (Throwable t) { logger.warn(t.getMessage(), t); callback.fail(errf.throwableToInternalError(t)); } } else { callback.success(responseEntity); } } }; try { wrappers.put(taskUuid, wrapper); HttpHeaders requestHeaders = new HttpHeaders(); requestHeaders.setContentType(MediaType.APPLICATION_JSON); requestHeaders.setContentLength(body.length()); requestHeaders.set(RESTConstant.TASK_UUID, taskUuid); requestHeaders.set(RESTConstant.CALLBACK_URL, callbackUrl); if (headers != null) { for (Map.Entry<String, String> e : headers.entrySet()) { requestHeaders.set(e.getKey(), e.getValue()); } } HttpEntity<String> req = new HttpEntity<String>(body, requestHeaders); if (logger.isTraceEnabled()) { logger.trace(String.format("json post[%s], %s", url, req.toString())); } ResponseEntity<String> rsp; try { if (CoreGlobalProperty.UNIT_TEST_ON) { rsp = template.exchange(url, HttpMethod.POST, req, String.class); } else { rsp = new Retry<ResponseEntity<String>>() { @Override @RetryCondition(onExceptions = {IOException.class, RestClientException.class, HttpClientErrorException.class}) protected ResponseEntity<String> call() { return template.exchange(url, HttpMethod.POST, req, String.class); } }.run(); } } catch (HttpClientErrorException e) { String err = String.format("http status: %s, response body:%s", e.getStatusCode(), e.getResponseBodyAsString()); wrapper.fail(errf.instantiateErrorCode(SysErrors.HTTP_ERROR, err)); return; } if (rsp.getStatusCode() != org.springframework.http.HttpStatus.OK) { String err = String.format("http status: %s, response body:%s", rsp.getStatusCode().toString(), rsp.getBody()); logger.warn(err); wrapper.fail(errf.instantiateErrorCode(SysErrors.HTTP_ERROR, err)); } } catch (Throwable e) { logger.warn(String.format("Unable to post to %s", url), e); wrapper.fail(ExceptionDSL.isCausedBy(e, IOException.class) ? errf.instantiateErrorCode(SysErrors.IO_ERROR, e.getMessage()) : errf.throwableToInternalError(e)); } } @Override public void asyncJsonPost(String url, Object body, Map<String, String> headers, AsyncRESTCallback callback) { Long timeout = timeoutMgr.getTimeout(body.getClass()); asyncJsonPost(url, body, headers, callback, TimeUnit.MILLISECONDS, timeout == null ? 300000 : timeout); } @Override public void asyncJsonPost(String url, Object body, AsyncRESTCallback callback) { Long timeout = timeoutMgr.getTimeout(body.getClass()); asyncJsonPost(url, body, callback, TimeUnit.MILLISECONDS, timeout == null ? 300000 : timeout); } @Override public void asyncJsonPost(String url, String body, AsyncRESTCallback callback) { asyncJsonPost(url, body, callback, TimeUnit.SECONDS, 300); } @Override public HttpEntity<String> httpServletRequestToHttpEntity(HttpServletRequest req) { try { StringBuilder sb = new StringBuilder(); String line; while ((line = req.getReader().readLine()) != null) { sb.append(line); } req.getReader().close(); HttpHeaders header = new HttpHeaders(); for (Enumeration e = req.getHeaderNames() ; e.hasMoreElements() ;) { String name = e.nextElement().toString(); header.add(name, req.getHeader(name)); } return new HttpEntity<String>(sb.toString(), header); } catch (Exception e) { logger.warn(e.getMessage(), e); throw new CloudRuntimeException(e); } } @Override public RestTemplate getRESTTemplate() { return template; } @Override public <T> T syncJsonPost(String url, Object body, Class<T> returnClass) { // for unit test finding invocation chain if (body != null) { MessageCommandRecorder.record(body.getClass()); } return syncJsonPost(url, body == null ? null : JSONObjectUtil.toJsonString(body), returnClass); } @Override public <T> T syncJsonPost(String url, String body, Class<T> returnClass) { return syncJsonPost(url, body, null, returnClass); } @Override public <T> T syncJsonPost(String url, String body, Map<String, String> headers, Class<T> returnClass) { body = body == null ? "" : body; HttpHeaders requestHeaders = new HttpHeaders(); if (headers != null) { requestHeaders.setAll(headers); } requestHeaders.setContentType(MediaType.APPLICATION_JSON); requestHeaders.setContentLength(body.length()); HttpEntity<String> req = new HttpEntity<String>(body, requestHeaders); if (logger.isTraceEnabled()) { logger.trace(String.format("json post[%s], %s", url, req.toString())); } ResponseEntity<String> rsp = new Retry<ResponseEntity<String>>() { @Override @RetryCondition(onExceptions = {IOException.class, RestClientException.class}) protected ResponseEntity<String> call() { return template.exchange(url, HttpMethod.POST, req, String.class); } }.run(); if (rsp.getStatusCode() != org.springframework.http.HttpStatus.OK) { throw new OperationFailureException(operr("failed to post to %s, status code: %s, response body: %s", url, rsp.getStatusCode(), rsp.getBody())); } if (rsp.getBody() != null && returnClass != Void.class) { if (logger.isTraceEnabled()) { logger.trace(String.format("[http response(url: %s)] %s", url, rsp.getBody())); } return JSONObjectUtil.toObject(rsp.getBody(), returnClass); } else { return null; } } @Override public void echo(String url, Completion callback) { echo(url, callback, TimeUnit.SECONDS.toMillis(1), TimeUnit.SECONDS.toMillis(30)); } @Override public void echo(final String url, final Completion completion, final long interval, final long timeout) { class Echo implements CancelablePeriodicTask { private long count; Echo() { this.count = timeout / interval; DebugUtils.Assert(count!=0, String.format("invalid timeout[%s], interval[%s]", timeout, interval)); } @Override public boolean run() { try { syncJsonPost(url, "", Void.class); logger.debug(String.format("successfully echo %s", url)); completion.success(); return true; } catch (Exception e) { String info = String.format("still unable to echo %s, will try %s times. %s", url, count, e.getMessage()); logger.debug(info); if (--count <= 0) { completion.fail(operr("unable to echo %s in %sms", url, timeout)); return true; } else { return false; } } } @Override public TimeUnit getTimeUnit() { return TimeUnit.MILLISECONDS; } @Override public long getInterval() { return interval; } @Override public String getName() { return "RESTFacade echo"; } } thdf.submitCancelablePeriodicTask(new Echo()); } @Override public Map<String, HttpCallStatistic> getStatistics() { return statistics; } @Override public <T> void registerSyncHttpCallHandler(String path, final Class<T> objectType, final SyncHttpCallHandler<T> handler) { HttpCallHandlerWrapper wrapper = httpCallhandlers.get(path); if (wrapper != null) { throw new CloudRuntimeException(String.format("duplicate SyncHttpCallHandler[%s, %s] for the command path[%s]", wrapper.getHandler().getClass(), handler.getClass(), path)); } wrapper = new HttpCallHandlerWrapper() { @Override public String handle(HttpEntity<String> entity) { T cmd = JSONObjectUtil.toObject(entity.getBody(), objectType); return handler.handleSyncHttpCall(cmd); } @Override public HttpCallHandler getHandler() { return handler; } }; httpCallhandlers.put(path, wrapper); } @Override public String getBaseUrl() { return baseUrl; } @Override public String getSendCommandUrl() { return sendCommandUrl; } @Override public String getCallbackUrl() { return callbackUrl; } @Override public String makeUrl(String path) { UriComponentsBuilder ub = UriComponentsBuilder.fromHttpUrl(baseUrl); ub.path(path); return ub.build().toUriString(); } @Override public void installBeforeAsyncJsonPostInterceptor(BeforeAsyncJsonPostInterceptor interceptor) { interceptors.add(interceptor); } }