package com.github.bingoohuang.springrestclient.provider; import com.github.bingoohuang.utils.codec.Base64; import com.github.bingoohuang.utils.time.Now; import com.google.common.base.Joiner; import com.google.common.base.Throwables; import com.google.common.collect.Maps; import com.google.common.io.Files; import com.mashape.unirest.request.HttpRequest; import com.mashape.unirest.request.ValueUtils; import lombok.val; import org.apache.commons.lang3.ArrayUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import org.springframework.web.multipart.MultipartFile; import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import java.io.File; import java.io.IOException; import java.security.MessageDigest; import java.util.Collection; import java.util.List; import java.util.Map; public class DefaultSignProvider implements SignProvider { public static final String CLIENT_KEY = "8d37d3eb-310f-4354-81bb-222e9441e37f"; public static final String CLIENT_SECURITY = "d51fd93e-f6c9-4eae-ae7a-9b37af1a60cc"; private final String clientKey; private final String clientSecurity; public DefaultSignProvider() { this.clientKey = CLIENT_KEY; this.clientSecurity = CLIENT_SECURITY; } public DefaultSignProvider(Class<?> apiClass) { val clientSign = apiClass.getAnnotation(ClientSign.class); if (clientSign != null) { this.clientKey = clientSign.clientKey(); this.clientSecurity = clientSign.security(); } else { this.clientKey = CLIENT_KEY; this.clientSecurity = CLIENT_SECURITY; } } @Override public void sign( Class<?> apiClass, String uuid, Map<String, Object> requestParams, HttpRequest httpRequest) { httpRequest.header("hict", Now.now()); httpRequest.header("hici", uuid); httpRequest.header("hick", clientKey); httpRequest.header("hisa", "hmac"); httpRequest.header("hisv", hmac(apiClass, requestParams, httpRequest)); } private String hmac( Class<?> apiClass, Map<String, Object> requestParams, HttpRequest httpRequest) { String originalStr = createOriginalStringForSign(apiClass, requestParams, httpRequest); return hmacSHA256(originalStr, clientSecurity); } private String createOriginalStringForSign( Class<?> apiClass, Map<String, Object> requestParams, HttpRequest httpRequest) { val signStr = new StringBuilder(); val logStr = new StringBuilder(); val proxy = new AbbreviateAppendable(logStr, signStr); appendMethodAndUrl(httpRequest, proxy); appendHeaders(httpRequest, proxy); appendRequestParams(requestParams, proxy); Logger logger = LoggerFactory.getLogger(apiClass); logger.debug("string to be signed : {}", logStr); return signStr.toString(); } private void appendMethodAndUrl( HttpRequest httpRequest, Appendable signStr) { signStr.append(httpRequest.getHttpMethod().name()).append('$'); signStr.append(httpRequest.getUrl()).append('$'); } private void appendRequestParams( Map<String, Object> requestParams, Appendable signStr) { Map<String, Object> params = Maps.newTreeMap(); if (requestParams != null) params.putAll(requestParams); for (Map.Entry<String, Object> entry : params.entrySet()) { signStr.append(entry.getKey()).append('$'); Object value = entry.getValue(); boolean isFile = false; if (value instanceof Collection) { isFile = true; for (Object s : (Collection) value) { if (s instanceof File || s instanceof MultipartFile) append(signStr, s); else { isFile = false; break; } } } if (!isFile) append(signStr, value); } } private static String[] filtered = new String[]{ "Content-Type" }; private void appendHeaders(HttpRequest httpRequest, Appendable signStr) { Map<String, List<String>> headers = Maps.newTreeMap(); headers.putAll(httpRequest.getHeaders()); Joiner joiner = Joiner.on('$'); for (Map.Entry<String, List<String>> entry : headers.entrySet()) { String key = entry.getKey(); if (ArrayUtils.contains(filtered, key)) continue; signStr.append(key).append('$') .append(joiner.join(entry.getValue())).append('$'); } } private void append(Appendable signStr, Object object) { if (object instanceof File) { signStr.append(md5((File) object)); } else if (object instanceof MultipartFile) { byte[] bytes = fuckFileGetBytesException((MultipartFile) object); signStr.append(md5(bytes)); } else { signStr.appendAbbreviate(ValueUtils.processValue(object)); } signStr.append('$'); } private byte[] fuckFileGetBytesException(MultipartFile file) { try { return file.getBytes(); } catch (IOException e) { // fuck this will never happen return new byte[0]; } } private String md5(File file) { try { byte[] bytes = Files.toByteArray(file); return md5(bytes); } catch (IOException e) { throw new RuntimeException(e); } } public static String md5(byte[] bytes) { try { val md = MessageDigest.getInstance("MD5"); byte[] digest = md.digest(bytes); return Base64.base64(digest, Base64.Format.Standard); } catch (Exception e) { throw new RuntimeException(e); } } public static String hmacSHA256(String data, String key) { try { val secretKey = new SecretKeySpec(key.getBytes("UTF-8"), "HmacSHA256"); val mac = Mac.getInstance("HmacSHA256"); mac.init(secretKey); val hmacData = mac.doFinal(data.getBytes("UTF-8")); return Base64.base64(hmacData, Base64.Format.Standard); } catch (Throwable e) { throw Throwables.propagate(e); } } }