/* * Copyright (C) 2015 Actor LLC. <https://actor.im> */ package im.actor.core.network.api; import java.io.IOException; import java.util.HashMap; import im.actor.core.api.parser.RpcParser; import im.actor.core.api.parser.UpdatesParser; import im.actor.core.network.*; import im.actor.core.network.parser.ApiParserConfig; import im.actor.core.network.parser.ParsingExtension; import im.actor.runtime.*; import im.actor.runtime.Runtime; import im.actor.runtime.actors.ActorRef; import im.actor.runtime.actors.AskcableActor; import im.actor.core.util.RandomUtils; import im.actor.core.network.mtp.MTProto; import im.actor.core.network.mtp.MTProtoCallback; import im.actor.core.network.mtp.entity.ProtoSerializer; import im.actor.core.network.mtp.entity.ProtoStruct; import im.actor.core.network.mtp.entity.rpc.Push; import im.actor.core.network.mtp.entity.rpc.RpcError; import im.actor.core.network.mtp.entity.rpc.RpcFloodWait; import im.actor.core.network.mtp.entity.rpc.RpcInternalError; import im.actor.core.network.mtp.entity.rpc.RpcOk; import im.actor.core.network.mtp.entity.rpc.RpcRequest; import im.actor.core.network.parser.Request; import im.actor.core.network.parser.Response; import im.actor.core.network.parser.RpcScope; import im.actor.runtime.promise.Promise; import im.actor.runtime.threading.AtomicIntegerCompat; import im.actor.runtime.threading.CommonTimer; public class ApiBroker extends AskcableActor { public static ApiBrokerInt get(final Endpoints endpoints, final AuthKeyStorage keyStorage, final ActorApiCallback callback, final boolean isEnableLog, int id, final int minDelay, final int maxDelay, final int maxFailureCount) { return new ApiBrokerInt(endpoints, keyStorage, callback, isEnableLog, id, minDelay, maxDelay, maxFailureCount); } private static final String TAG = "ApiBroker"; private static final AtomicIntegerCompat NEXT_PROTO_ID = im.actor.runtime.Runtime.createAtomicInt(1); private Endpoints endpoints; private final AuthKeyStorage keyStorage; private final ActorApiCallback callback; private final boolean isEnableLog; private final int minDelay; private final int maxDelay; private final int maxFailureCount; private final HashMap<Long, RequestHolder> requests = new HashMap<>(); private final HashMap<Long, Long> idMap = new HashMap<>(); private final HashMap<Long, CommonTimer> timeouts = new HashMap<>(); private long currentAuthId; private MTProto proto; private ActorRef keyManager; private ApiParserConfig parserConfig; public ApiBroker(Endpoints endpoints, AuthKeyStorage keyStorage, ActorApiCallback callback, boolean isEnableLog, int minDelay, int maxDelay, int maxFailureCount) { this.isEnableLog = isEnableLog; this.endpoints = endpoints; this.keyStorage = keyStorage; this.callback = callback; this.minDelay = minDelay; this.maxDelay = maxDelay; this.maxFailureCount = maxFailureCount; this.parserConfig = new ApiParserConfig(); this.parserConfig.addExtension(new ParsingExtension(new RpcParser(), new UpdatesParser())); } @Override public void preStart() { this.currentAuthId = keyStorage.getAuthKey(); this.keyManager = system().actorOf(getPath() + "/key", AuthKeyActor::new); if (currentAuthId == 0) { this.keyManager.send(new AuthKeyActor.StartKeyCreation(this.endpoints), self()); } else { if (isEnableLog) { Log.d(TAG, "Key loaded: " + currentAuthId); } self().send(new InitMTProto(currentAuthId, keyStorage.getAuthMasterKey())); } } public void changeEndpoints(Endpoints endpoints) { if (endpoints.equals(this.endpoints)) { return; } this.endpoints = endpoints; recreateAuthId(); } @Override public void postStop() { if (proto != null) { proto.stopProto(); proto = null; } } private void onNetworkChanged(NetworkState state) { if (proto != null) { proto.onNetworkChanged(state); } } private void forceNetworkCheck() { if (proto != null) { proto.forceNetworkCheck(); } } private void onNewSessionCreated(long authId) { if (authId != currentAuthId) { return; } Log.w(TAG, "New Session Created"); callback.onNewSessionCreated(); } private void onAuthIdInvalidated(long authId) { if (authId != currentAuthId) { return; } Log.w(TAG, "Auth id invalidated"); callback.onAuthIdInvalidated(); recreateAuthId(); } private void recreateAuthId() { keyStorage.saveAuthKey(0); keyStorage.saveMasterKey(null); currentAuthId = 0; proto = null; this.keyManager.send(new AuthKeyActor.StartKeyCreation(this.endpoints), self()); } private void onKeyCreated(long authId, byte[] authKey) { Log.w(TAG, "Auth id created #" + authId); keyStorage.saveAuthKey(authId); keyStorage.saveMasterKey(authKey); self().send(new InitMTProto(authId, authKey)); } private void createMtProto(long key, byte[] authKey) { Log.d(TAG, "Creating proto"); keyStorage.saveAuthKey(key); keyStorage.saveMasterKey(authKey); currentAuthId = key; proto = new MTProto(key, authKey, RandomUtils.nextRid(), endpoints, new ProtoCallback(key), isEnableLog, getPath() + "/proto#" + NEXT_PROTO_ID.incrementAndGet(), minDelay, maxDelay, maxFailureCount); for (RequestHolder holder : requests.values()) { holder.protoId = proto.sendRpcMessage(holder.message); idMap.put(holder.protoId, holder.publicId); // Log.d(TAG, holder.message + " rid#" + holder.publicId + " <- mid#" + holder.protoId); } } private void performRequest(long randomId, Request message, RpcCallback callback, long timeout) { Log.d(TAG, "-> request#" + randomId + ": " + message); // Log.d(TAG, message + " rid#" + randomId); RequestHolder holder = new RequestHolder( Runtime.getCurrentTime(), randomId, new RpcRequest(message.getHeaderKey(), message.toByteArray()), callback); requests.put(holder.publicId, holder); if (proto != null) { long mid = proto.sendRpcMessage(holder.message); holder.protoId = mid; idMap.put(mid, randomId); // Log.d(TAG, message + " rid#" + randomId + " <- mid#" + mid); } if (timeout > 0) { CommonTimer commonTimer = new CommonTimer(new TimeoutTask(holder.publicId)); timeouts.put(holder.publicId, commonTimer); commonTimer.schedule(timeout); } } private void processResponse(long authId, long mid, byte[] content) { if (authId != currentAuthId) { return; } ProtoStruct protoStruct; try { protoStruct = ProtoSerializer.readRpcResponsePayload(content); } catch (IOException e) { e.printStackTrace(); Log.w(TAG, "Broken response mid#" + mid); return; } // Log.w(TAG, protoStruct + " mid#" + mid); long rid; if (idMap.containsKey(mid)) { rid = idMap.get(mid); } else { return; } CommonTimer timer = timeouts.get(rid); if (timer != null) { timer.cancel(); timeouts.remove(rid); } RequestHolder holder; if (requests.containsKey(rid)) { holder = requests.get(rid); } else { return; } if (protoStruct instanceof RpcOk) { RpcOk ok = (RpcOk) protoStruct; requests.remove(rid); if (holder.protoId != 0) { idMap.remove(holder.protoId); } Response response; try { response = (Response) parserConfig.parseRpc(ok.responseType, ok.payload); } catch (IOException e) { e.printStackTrace(); requests.remove(rid); if (holder.protoId != 0) { idMap.remove(holder.protoId); } holder.callback.onError(new RpcInternalException()); return; } Log.d(TAG, "<- response#" + holder.publicId + ": " + response + " in " + (Runtime.getCurrentTime() - holder.requestTime) + " ms"); holder.callback.onResult(response); } else if (protoStruct instanceof RpcError) { RpcError e = (RpcError) protoStruct; requests.remove(rid); if (holder.protoId != 0) { idMap.remove(holder.protoId); } Log.w(TAG, "<- error#" + holder.publicId + ": " + e.errorTag + " " + e.errorCode + " " + e.userMessage + " in " + (Runtime.getCurrentTime() - holder.requestTime) + " ms"); holder.callback.onError(new RpcException(e.errorTag, e.errorCode, e.userMessage, e.canTryAgain, e.relatedData)); } else if (protoStruct instanceof RpcInternalError) { RpcInternalError e = ((RpcInternalError) protoStruct); Log.d(TAG, "<- internal_error#" + holder.publicId + " " + e.getTryAgainDelay() + " sec" + " in " + (Runtime.getCurrentTime() - holder.requestTime) + " ms"); if (e.isCanTryAgain()) { schedule(new ForceResend(rid), e.getTryAgainDelay() * 1000L); } else { requests.remove(rid); if (holder.protoId != 0) { idMap.remove(holder.protoId); } holder.callback.onError(new RpcInternalException()); } } else if (protoStruct instanceof RpcFloodWait) { RpcFloodWait f = (RpcFloodWait) protoStruct; Log.d(TAG, "<- flood_wait#" + holder.publicId + " " + f.getDelay() + " sec" + " in " + (Runtime.getCurrentTime() - holder.requestTime) + " ms"); schedule(new ForceResend(rid), f.getDelay() * 1000L); } else { Log.d(TAG, "<- unknown_package#" + holder.publicId + " in " + (Runtime.getCurrentTime() - holder.requestTime) + " ms"); } } private void forceResend(long randomId) { RequestHolder holder = requests.get(randomId); if (holder != null) { if (holder.protoId != 0) { idMap.remove(holder.protoId); proto.cancelRpc(holder.protoId); } long mid = proto.sendRpcMessage(holder.message); holder.protoId = mid; idMap.put(mid, randomId); } } private void cancelRequest(long randomId) { CommonTimer timer = timeouts.get(randomId); if (timer != null) { timer.cancel(); timeouts.remove(randomId); } RequestHolder holder = requests.get(randomId); if (holder != null) { requests.remove(randomId); if (holder.protoId != 0 && proto != null) { idMap.remove(holder.protoId); proto.cancelRpc(holder.protoId); } } } private void processUpdate(long authId, byte[] content) { if (authId != currentAuthId) { return; } ProtoStruct protoStruct; try { protoStruct = ProtoSerializer.readUpdate(content); } catch (IOException e) { e.printStackTrace(); Log.w(TAG, "Broken mt update"); return; } int type = ((Push) protoStruct).updateType; byte[] body = ((Push) protoStruct).body; RpcScope updateBox; try { updateBox = parserConfig.parseRpc(type, body); } catch (IOException e) { e.printStackTrace(); Log.w(TAG, "Broken update box"); return; } // Log.w(TAG, "Box: " + updateBox + ""); callback.onUpdateReceived(updateBox); } void connectionCountChanged(int count) { callback.onConnectionsChanged(count); } private Promise<Boolean> checkIsCurrentAuthId(long authId) { return new Promise<>(resolver -> resolver.result(authId == currentAuthId)); } public static class PerformRequest { private Request message; private RpcCallback callback; private long rid; private long timeout; public PerformRequest(long rid, Request message, RpcCallback callback) { this.rid = rid; this.message = message; this.callback = callback; this.timeout = 0; } public PerformRequest(long rid, Request message, RpcCallback callback, long timeout) { this.rid = rid; this.message = message; this.callback = callback; this.timeout = timeout; } public long getRid() { return rid; } public Request getMessage() { return message; } public RpcCallback getCallback() { return callback; } public long getTimeout() { return timeout; } } public static class CancelRequest { private long randomId; public CancelRequest(long randomId) { this.randomId = randomId; } public long getRandomId() { return randomId; } } public static class NetworkChanged { private NetworkState state; public NetworkChanged(NetworkState state) { this.state = state; } public NetworkState getState() { return state; } } public static class ForceNetworkCheck { } public static class ChangeEndpoints { Endpoints endpoints; public ChangeEndpoints(Endpoints endpoints) { this.endpoints = endpoints; } public Endpoints getEndpoints() { return endpoints; } } private class InitMTProto { private long authId; private byte[] authKey; public InitMTProto(long authId, byte[] authKey) { this.authId = authId; this.authKey = authKey; } public long getAuthId() { return authId; } public byte[] getAuthKey() { return authKey; } } private class ProtoResponse { private long authId; private long responseId; private byte[] data; public ProtoResponse(long authId, long responseId, byte[] data) { this.authId = authId; this.responseId = responseId; this.data = data; } public long getAuthId() { return authId; } public long getResponseId() { return responseId; } public byte[] getData() { return data; } } private class ProtoUpdate { private long authId; private byte[] data; public ProtoUpdate(long authId, byte[] data) { this.authId = authId; this.data = data; } public long getAuthId() { return authId; } public byte[] getData() { return data; } } private class ForceResend { private long id; public ForceResend(long id) { this.id = id; } public long getId() { return id; } } private class TimeoutTask implements Runnable { private final long rid; public TimeoutTask(long rid) { this.rid = rid; } @Override public void run() { if (requests.containsKey(rid)) { RpcCallback callBack = requests.get(rid).callback; if (callBack != null) callBack.onError(new RpcTimeoutException()); } cancelRequest(rid); } } private class RequestHolder { private final long requestTime; private final RpcRequest message; private final long publicId; private final RpcCallback callback; private long protoId; private RequestHolder(long requestTime, long publicId, RpcRequest message, RpcCallback callback) { this.requestTime = requestTime; this.message = message; this.publicId = publicId; this.callback = callback; } } private class NewSessionCreated { private long authId; public NewSessionCreated(long authId) { this.authId = authId; } public long getAuthId() { return authId; } } private class AuthIdInvalidated { private long authId; public AuthIdInvalidated(long authId) { this.authId = authId; } public long getAuthId() { return authId; } } private class ConnectionsCountChanged { private int count; public ConnectionsCountChanged(int count) { this.count = count; } public int getCount() { return count; } } private class ProtoCallback implements MTProtoCallback { private long authId; public ProtoCallback(long authId) { this.authId = authId; } @Override public void onRpcResponse(long mid, byte[] content) { self().send(new ProtoResponse(authId, mid, content)); } @Override public void onUpdate(byte[] content) { self().send(new ProtoUpdate(authId, content)); } @Override public void onAuthKeyInvalidated(long authId) { if (this.authId != authId) { // But why?? return; } self().send(new AuthIdInvalidated(authId)); } @Override public void onSessionCreated() { self().send(new NewSessionCreated(authId)); } @Override public void onConnectionsCountChanged(int count) { self().send(new ConnectionsCountChanged(count)); } } @Override public Promise onAsk(Object message) throws Exception { if (message instanceof CheckIsCurrentAuthId) { return checkIsCurrentAuthId(((CheckIsCurrentAuthId) message).getAuthId()); } return super.onAsk(message); } @Override public void onReceive(Object message) { if (message instanceof InitMTProto) { InitMTProto initMTProto = (InitMTProto) message; createMtProto(initMTProto.getAuthId(), initMTProto.getAuthKey()); } else if (message instanceof PerformRequest) { PerformRequest request = (PerformRequest) message; performRequest(request.getRid(), request.getMessage(), request.getCallback(), request.getTimeout()); } else if (message instanceof CancelRequest) { CancelRequest cancelRequest = (CancelRequest) message; cancelRequest(cancelRequest.getRandomId()); } else if (message instanceof ProtoResponse) { ProtoResponse response = (ProtoResponse) message; processResponse(response.getAuthId(), response.getResponseId(), response.getData()); } else if (message instanceof ForceResend) { ForceResend forceResend = (ForceResend) message; forceResend(forceResend.getId()); } else if (message instanceof ProtoUpdate) { ProtoUpdate update = (ProtoUpdate) message; processUpdate(update.getAuthId(), update.getData()); } else if (message instanceof NewSessionCreated) { NewSessionCreated newSessionCreated = (NewSessionCreated) message; onNewSessionCreated(newSessionCreated.getAuthId()); } else if (message instanceof AuthIdInvalidated) { AuthIdInvalidated authIdInvalidated = (AuthIdInvalidated) message; onAuthIdInvalidated(authIdInvalidated.getAuthId()); } else if (message instanceof NetworkChanged) { NetworkChanged networkChanged = (NetworkChanged) message; onNetworkChanged(networkChanged.getState()); } else if (message instanceof ForceNetworkCheck) { forceNetworkCheck(); } else if (message instanceof ConnectionsCountChanged) { connectionCountChanged(((ConnectionsCountChanged) message).getCount()); } else if (message instanceof AuthKeyActor.KeyCreated) { onKeyCreated(((AuthKeyActor.KeyCreated) message).getAuthKeyId(), ((AuthKeyActor.KeyCreated) message).getAuthKey()); } else if (message instanceof ChangeEndpoints) { changeEndpoints(((ChangeEndpoints) message).getEndpoints()); } else { super.onReceive(message); } } }