/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.kafka.streams.processor.internals.assignment;
import org.apache.kafka.streams.processor.TaskId;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
/**
 * A {@link TaskAssignor} that tries to keep each task on the client that previously
 * owned it ("sticky" assignment) while spreading tasks across the available capacity.
 * <p>
 * Active tasks are assigned first, preferring, in order: the client that previously ran
 * the task as an active task, a client that previously hosted it as a standby, and
 * finally the least-loaded client. Standby replicas are then placed on clients that do
 * not already hold the task. When choosing among otherwise comparable candidates, a
 * client on which the task would form a not-yet-recorded pair with the tasks already
 * there is preferred, so the same combinations of tasks are not repeatedly co-located.
 * <p>
 * The {@link ClientState} objects passed to the constructor are updated in place by
 * {@link #assign(int)}. This class performs no synchronization of its own.
 *
 * @param <ID> the type used to identify clients
 */
public class StickyTaskAssignor<ID> implements TaskAssignor<ID, TaskId> {

    private static final Logger log = LoggerFactory.getLogger(StickyTaskAssignor.class);

    private final Map<ID, ClientState> clients;
    private final Set<TaskId> taskIds;
    // task -> client that previously ran it as an active task
    private final Map<TaskId, ID> previousActiveTaskAssignment = new HashMap<>();
    // task -> clients that previously hosted it as a standby task
    private final Map<TaskId, Set<ID>> previousStandbyTaskAssignment = new HashMap<>();
    // pairs of tasks that have been placed on the same client
    private final TaskPairs taskPairs;

    /**
     * @param clients the clients to assign to, keyed by client id; their {@link ClientState}
     *                objects are mutated by {@link #assign(int)}
     * @param taskIds the tasks to assign
     */
    public StickyTaskAssignor(final Map<ID, ClientState> clients, final Set<TaskId> taskIds) {
        this.clients = clients;
        this.taskIds = taskIds;
        // Number of distinct unordered task pairs, n * (n - 1) / 2. Computed in long
        // arithmetic because the int product overflows (going negative) once there are
        // more than 46341 tasks.
        final long taskCount = taskIds.size();
        taskPairs = new TaskPairs(taskCount * (taskCount - 1) / 2);
        mapPreviousTaskAssignment(clients);
    }

    /**
     * Assigns every task as an active task, then up to {@code numStandbyReplicas}
     * standby replicas of every task.
     *
     * @param numStandbyReplicas the number of standby replicas requested per task
     */
    @Override
    public void assign(final int numStandbyReplicas) {
        assignActive();
        assignStandby(numStandbyReplicas);
    }

    /**
     * Places up to {@code numStandbyReplicas} standby copies of every task, each on a
     * client for which {@code hasAssignedTask(taskId)} is false. If no such client is
     * left for a task, logs a warning with the number of replicas that could not be
     * placed and moves on to the next task.
     */
    private void assignStandby(final int numStandbyReplicas) {
        for (final TaskId taskId : taskIds) {
            for (int i = 0; i < numStandbyReplicas; i++) {
                final Set<ID> ids = findClientsWithoutAssignedTask(taskId);
                if (ids.isEmpty()) {
                    log.warn("Unable to assign {} of {} standby tasks for task [{}]. " +
                                     "There is not enough available capacity. You should " +
                                     "increase the number of threads and/or application instances " +
                                     "to maintain the requested number of standby replicas.",
                             numStandbyReplicas - i,
                             numStandbyReplicas, taskId);
                    break;
                }
                allocateTaskWithClientCandidates(taskId, ids, false);
            }
        }
    }

    /**
     * Assigns every task in {@code taskIds} as an active task, in three passes:
     * <ol>
     *   <li>to the client that previously ran it as active, if that client
     *       {@code hasUnfulfilledQuota(tasksPerThread)};</li>
     *   <li>otherwise to the first client (in set iteration order) that previously hosted
     *       it as a standby and still has an unfulfilled quota;</li>
     *   <li>any task still unassigned is placed via {@code findClient} over all clients.</li>
     * </ol>
     * NOTE(review): {@code tasksPerThread} is an integer division by the summed client
     * capacity, so this throws {@link ArithmeticException} if that sum is zero.
     */
    private void assignActive() {
        final int totalCapacity = sumCapacity(clients.values());
        final int tasksPerThread = taskIds.size() / totalCapacity;
        final Set<TaskId> assigned = new HashSet<>();

        // first try and re-assign existing active tasks to clients that previously had
        // the same active task
        for (final Map.Entry<TaskId, ID> entry : previousActiveTaskAssignment.entrySet()) {
            final TaskId taskId = entry.getKey();
            if (taskIds.contains(taskId)) {
                final ClientState client = clients.get(entry.getValue());
                if (client.hasUnfulfilledQuota(tasksPerThread)) {
                    assignTaskToClient(assigned, taskId, client);
                }
            }
        }

        final Set<TaskId> unassigned = new HashSet<>(taskIds);
        unassigned.removeAll(assigned);

        // try and assign any remaining unassigned tasks to clients that previously
        // hosted the task as a standby
        for (final Iterator<TaskId> iterator = unassigned.iterator(); iterator.hasNext(); ) {
            final TaskId taskId = iterator.next();
            final Set<ID> clientIds = previousStandbyTaskAssignment.get(taskId);
            if (clientIds != null) {
                for (final ID clientId : clientIds) {
                    final ClientState client = clients.get(clientId);
                    if (client.hasUnfulfilledQuota(tasksPerThread)) {
                        assignTaskToClient(assigned, taskId, client);
                        iterator.remove();
                        break;
                    }
                }
            }
        }

        // assign any remaining unassigned tasks
        for (final TaskId taskId : unassigned) {
            allocateTaskWithClientCandidates(taskId, clients.keySet(), true);
        }
    }

    /**
     * Picks a client from {@code clientsWithin} via {@link #findClient}, records the new
     * task pairs it forms there, and assigns the task to it.
     *
     * @param active {@code true} to assign as an active task, {@code false} for a standby
     */
    private void allocateTaskWithClientCandidates(final TaskId taskId, final Set<ID> clientsWithin, final boolean active) {
        final ClientState client = findClient(taskId, clientsWithin);
        taskPairs.addPairs(taskId, client.assignedTasks());
        client.assign(taskId, active);
    }

    /**
     * Assigns {@code taskId} to {@code client} as an active task, recording the pairs it
     * forms there, and adds it to {@code assigned}.
     */
    private void assignTaskToClient(final Set<TaskId> assigned, final TaskId taskId, final ClientState client) {
        taskPairs.addPairs(taskId, client.assignedTasks());
        client.assign(taskId, true);
        assigned.add(taskId);
    }

    /**
     * @return the ids of all clients for which {@code hasAssignedTask(taskId)} is false
     */
    private Set<ID> findClientsWithoutAssignedTask(final TaskId taskId) {
        final Set<ID> clientIds = new HashSet<>();
        for (final Map.Entry<ID, ClientState> client : clients.entrySet()) {
            if (!client.getValue().hasAssignedTask(taskId)) {
                clientIds.add(client.getKey());
            }
        }
        return clientIds;
    }

    /**
     * Chooses the client within {@code clientsWithin} that should receive {@code taskId}.
     * <ul>
     *   <li>With a single candidate, that candidate is returned.</li>
     *   <li>Otherwise the previous owner is preferred (previous active owner, else the
     *       least-loaded previous standby host).</li>
     *   <li>If there is no previous owner, the least-loaded candidate is returned.</li>
     *   <li>If the previous owner has reached capacity while another client has more
     *       room, the least-loaded previous standby host is tried, under the same
     *       condition, before falling back to the least-loaded candidate.</li>
     * </ul>
     * May return {@code null} if {@code clientsWithin} is empty.
     */
    private ClientState findClient(final TaskId taskId, final Set<ID> clientsWithin) {
        // optimize the case where there is only 1 id to search within.
        if (clientsWithin.size() == 1) {
            return clients.get(clientsWithin.iterator().next());
        }

        final ClientState previous = findClientWithPreviousAssignedTask(taskId, clientsWithin);
        if (previous == null) {
            return leastLoaded(taskId, clientsWithin);
        }

        if (shouldBalanceLoad(previous)) {
            final ClientState standby = findLeastLoadedClientWithPreviousStandByTask(taskId, clientsWithin);
            if (standby == null || shouldBalanceLoad(standby)) {
                return leastLoaded(taskId, clientsWithin);
            }
            return standby;
        }

        return previous;
    }

    /**
     * @return {@code true} if {@code client} has reached capacity and some other client
     *         has more available capacity
     */
    private boolean shouldBalanceLoad(final ClientState client) {
        return client.reachedCapacity() && hasClientsWithMoreAvailableCapacity(client);
    }

    /**
     * @return {@code true} if any known client has more available capacity than {@code client}
     */
    private boolean hasClientsWithMoreAvailableCapacity(final ClientState client) {
        for (final ClientState clientState : clients.values()) {
            if (clientState.hasMoreAvailableCapacityThan(client)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the client within {@code clientsWithin} that previously ran {@code taskId} as
     * an active task. If there is none, returns the least-loaded client within
     * {@code clientsWithin} that previously hosted it as a standby, or {@code null}.
     */
    private ClientState findClientWithPreviousAssignedTask(final TaskId taskId,
                                                           final Set<ID> clientsWithin) {
        final ID previous = previousActiveTaskAssignment.get(taskId);
        if (previous != null && clientsWithin.contains(previous)) {
            return clients.get(previous);
        }
        return findLeastLoadedClientWithPreviousStandByTask(taskId, clientsWithin);
    }

    /**
     * @return the least-loaded client within {@code clientsWithin} that previously hosted
     *         {@code taskId} as a standby, or {@code null} if there is none
     */
    private ClientState findLeastLoadedClientWithPreviousStandByTask(final TaskId taskId, final Set<ID> clientsWithin) {
        final Set<ID> ids = previousStandbyTaskAssignment.get(taskId);
        if (ids == null) {
            return null;
        }
        final Set<ID> constrainTo = new HashSet<>(ids);
        constrainTo.retainAll(clientsWithin);
        return leastLoaded(taskId, constrainTo);
    }

    /**
     * Returns the least-loaded client among {@code clientIds}. It first restricts the
     * search to clients where {@code taskId} would form a new task pair, then drops that
     * restriction if no such client exists. Returns {@code null} only when
     * {@code clientIds} is empty.
     */
    private ClientState leastLoaded(final TaskId taskId, final Set<ID> clientIds) {
        final ClientState leastLoaded = findLeastLoaded(taskId, clientIds, true);
        if (leastLoaded == null) {
            return findLeastLoaded(taskId, clientIds, false);
        }
        return leastLoaded;
    }

    /**
     * Scans {@code clientIds} for the client with the most available capacity, returning
     * immediately any client with no assigned tasks.
     *
     * @param checkTaskPairs if {@code true}, a client only becomes the candidate when
     *                       placing {@code taskId} on it would create a new task pair
     * @return the chosen client, or {@code null} if none qualified
     */
    private ClientState findLeastLoaded(final TaskId taskId,
                                        final Set<ID> clientIds,
                                        final boolean checkTaskPairs) {
        ClientState leastLoaded = null;
        for (final ID id : clientIds) {
            final ClientState client = clients.get(id);
            if (client.assignedTaskCount() == 0) {
                return client;
            }

            if (leastLoaded == null || client.hasMoreAvailableCapacityThan(leastLoaded)) {
                if (!checkTaskPairs) {
                    leastLoaded = client;
                } else if (taskPairs.hasNewPair(taskId, client.assignedTasks())) {
                    leastLoaded = client;
                }
            }
        }
        return leastLoaded;
    }

    /**
     * Builds the reverse indexes {@code previousActiveTaskAssignment} and
     * {@code previousStandbyTaskAssignment} from each client's previous tasks. If a task
     * was previously active on several clients, the last client visited wins.
     */
    private void mapPreviousTaskAssignment(final Map<ID, ClientState> clients) {
        for (final Map.Entry<ID, ClientState> clientEntry : clients.entrySet()) {
            final ID clientId = clientEntry.getKey();
            final ClientState state = clientEntry.getValue();

            for (final TaskId activeTask : state.previousActiveTasks()) {
                previousActiveTaskAssignment.put(activeTask, clientId);
            }

            for (final TaskId standbyTask : state.previousStandbyTasks()) {
                // single lookup + lazy creation (Java 7 compatible, no computeIfAbsent)
                Set<ID> standbyClients = previousStandbyTaskAssignment.get(standbyTask);
                if (standbyClients == null) {
                    standbyClients = new HashSet<>();
                    previousStandbyTaskAssignment.put(standbyTask, standbyClients);
                }
                standbyClients.add(clientId);
            }
        }
    }

    /**
     * @return the sum of {@code capacity()} over {@code values}
     */
    private int sumCapacity(final Collection<ClientState> values) {
        int capacity = 0;
        for (final ClientState client : values) {
            capacity += client.capacity();
        }
        return capacity;
    }

    /**
     * Records which unordered pairs of tasks have already been placed on the same client.
     */
    private static class TaskPairs {
        private final Set<Pair> pairs;
        private final long maxPairs;

        /**
         * @param maxPairs the number of distinct unordered pairs that can be formed from the
         *                 tasks being assigned; once that many pairs are recorded, no new
         *                 pair is possible
         */
        TaskPairs(final long maxPairs) {
            this.maxPairs = maxPairs;
            // Deliberately not pre-sized to maxPairs. That bound grows quadratically with
            // the task count, and pre-sizing would eagerly allocate a hash table of that
            // size (hundreds of MB for ~10k tasks). Pairs are only recorded between tasks
            // sharing a client, which is usually far fewer.
            this.pairs = new HashSet<>();
        }

        /**
         * Returns {@code true} if placing {@code task1} next to {@code taskIds} would create
         * at least one pair that has not been recorded yet.
         */
        boolean hasNewPair(final TaskId task1, final Set<TaskId> taskIds) {
            // once every possible pair has been recorded there is nothing new to find
            if (pairs.size() >= maxPairs) {
                return false;
            }

            for (final TaskId taskId : taskIds) {
                if (!pairs.contains(pair(task1, taskId))) {
                    return true;
                }
            }
            return false;
        }

        /**
         * Records the pair formed by {@code taskId} with each task in {@code assigned}.
         */
        void addPairs(final TaskId taskId, final Set<TaskId> assigned) {
            for (final TaskId id : assigned) {
                pairs.add(pair(id, taskId));
            }
        }

        /**
         * Builds an order-independent pair: the smaller task (by {@code compareTo}) is
         * always stored first, so {@code pair(a, b)} equals {@code pair(b, a)}.
         */
        Pair pair(final TaskId task1, final TaskId task2) {
            if (task1.compareTo(task2) < 0) {
                return new Pair(task1, task2);
            }
            return new Pair(task2, task1);
        }

        /**
         * Immutable value holding two task ids; equality and hash are based on both.
         */
        private static class Pair {
            private final TaskId task1;
            private final TaskId task2;

            Pair(final TaskId task1, final TaskId task2) {
                this.task1 = task1;
                this.task2 = task2;
            }

            @Override
            public boolean equals(final Object o) {
                if (this == o) return true;
                if (o == null || getClass() != o.getClass()) return false;
                final Pair pair = (Pair) o;
                return Objects.equals(task1, pair.task1) &&
                        Objects.equals(task2, pair.task2);
            }

            @Override
            public int hashCode() {
                return Objects.hash(task1, task2);
            }
        }
    }
}