package edu.washington.escience.myria.util.concurrent;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import com.google.common.collect.Sets;
import edu.washington.escience.myria.MyriaConstants;
import edu.washington.escience.myria.util.JVMUtils;
/**
 * Daemon thread that cleans up the user threads still running at shutdown.
 *
 * <p>Every {@link MyriaConstants#SHORT_WAITING_INTERVAL_100_MS} tick it re-scans all live threads
 * and tracks those whose thread group is neither the main thread group nor that group's parent. A
 * tracked thread is first given time to finish on its own; after {@code waitBeforeInterruptMS} it
 * is interrupted, and once it has been interrupted {@code numInterruptBeforeKill} times without
 * terminating it is abandoned and forcibly stopped (where the JVM still supports that). When no
 * tracked thread remains, the cleaner either returns (nothing was abandoned) or calls {@link
 * JVMUtils#shutdownVM()} (at least one thread was abandoned).
 */
public class ShutdownThreadCleaner extends Thread {
  /** The logger for this class. */
  private static final org.slf4j.Logger LOGGER =
      org.slf4j.LoggerFactory.getLogger(ShutdownThreadCleaner.class.getName());

  /** Default time to wait for a thread before interrupting it: 5 seconds. */
  public static final int DEFAULT_WAIT_MAXIMUM_MS = 5 * 1000;

  /** Default number of times an unresponsive thread is interrupted before it is killed: 3. */
  public static final int DEFAULT_MAX_INTERRUPT_TIMES = 3;

  /** Milliseconds to wait for a thread before (re-)interrupting it. */
  private final int waitBeforeInterruptMS;

  /** Number of interrupts a thread receives before it is killed. */
  private final int numInterruptBeforeKill;

  /** The thread group executing the application's main function. */
  private final ThreadGroup mainThreadGroup;

  /** How many milliseconds each thread has been waited for since its last interrupt. */
  private final HashMap<Thread, Integer> waitedForMS = new HashMap<Thread, Integer>();

  /**
   * Same as {@link #waitedForMS}, but keyed by thread name and never reset. The SQLiteQueue
   * threads reincarnate themselves when any error occurs, so {@link #waitedForMS} cannot capture
   * them: the Thread instance is recreated once they receive an InterruptedException.
   */
  private final HashMap<String, Integer> waitedForMSThreadName = new HashMap<String, Integer>();

  /** How many times each thread has been interrupted. */
  private final HashMap<Thread, Integer> interruptTimes = new HashMap<Thread, Integer>();

  /** Threads we have given up waiting for; they are excluded from all further scans. */
  private final Set<Thread> abandonThreads = Sets.newSetFromMap(new HashMap<Thread, Boolean>());

  /**
   * Creates a cleaner using {@link #DEFAULT_WAIT_MAXIMUM_MS} and {@link
   * #DEFAULT_MAX_INTERRUPT_TIMES}.
   *
   * @param mainThreadGroup the thread group executing the application's main function
   */
  public ShutdownThreadCleaner(final ThreadGroup mainThreadGroup) {
    this(mainThreadGroup, DEFAULT_WAIT_MAXIMUM_MS, DEFAULT_MAX_INTERRUPT_TIMES);
  }

  /**
   * @param mainThreadGroup the thread group executing the application's main function
   * @param waitBeforeInterruptMS wait this amount of milliseconds before interrupting
   * @param numInterruptBeforeKill number of interrupts before killing a thread
   */
  public ShutdownThreadCleaner(
      final ThreadGroup mainThreadGroup,
      final int waitBeforeInterruptMS,
      final int numInterruptBeforeKill) {
    super.setDaemon(true);
    this.mainThreadGroup = mainThreadGroup;
    this.waitBeforeInterruptMS = waitBeforeInterruptMS;
    this.numInterruptBeforeKill = numInterruptBeforeKill;
  }

  /**
   * Utility method: adds {@code v} to the value of {@code m[t]} and returns the new value. A
   * missing (or null) entry is treated as 0.
   *
   * @param <KEY> the map key type
   * @param m a map
   * @param t the key
   * @param v the value to add
   * @return the new value stored under {@code t}
   */
  private static <KEY> int addToMap(final Map<KEY, Integer> m, final KEY t, final int v) {
    final Integer current = m.get(t);
    final int updated = (current == null ? 0 : current) + v;
    m.put(t, updated);
    return updated;
  }

  /**
   * Takes a snapshot of the threads that still have to be waited for.
   *
   * @return every live thread, other than the calling thread and the already abandoned ones,
   *     whose thread group is neither {@link #mainThreadGroup} nor its parent
   */
  private Set<Thread> collectPendingThreads() {
    final ThreadGroup parentGroup = mainThreadGroup.getParent();
    final Thread self = Thread.currentThread();
    final Set<Thread> pending = Sets.newSetFromMap(new HashMap<Thread, Boolean>());
    for (final Thread t : Thread.getAllStackTraces().keySet()) {
      // Read the group once: getThreadGroup() returns null as soon as the thread terminates.
      final ThreadGroup group = t.getThreadGroup();
      if (group != null
          && group != mainThreadGroup
          && group != parentGroup
          && t != self
          && !abandonThreads.contains(t)) {
        pending.add(t);
      }
    }
    return pending;
  }

  /**
   * Forcibly stops {@code t}, logging instead of propagating a refusal. {@link Thread#stop()}
   * throws {@link UnsupportedOperationException} on recent JDKs and may throw {@link
   * SecurityException}; letting either escape would terminate the cleaner before it can shut the
   * JVM down.
   *
   * @param t the thread to stop
   */
  @SuppressWarnings({"deprecation", "removal"})
  private static void stopQuietly(final Thread t) {
    try {
      t.stop();
    } catch (UnsupportedOperationException e) {
      LOGGER.warn("Unable to forcibly stop thread " + t + "; leaving it abandoned.", e);
    } catch (SecurityException e) {
      LOGGER.warn("Not permitted to stop thread " + t + "; leaving it abandoned.", e);
    }
  }

  /**
   * Accounts one more tick of waiting for {@code t} and escalates if a threshold is crossed:
   * abandon (by accumulated per-name waiting time), interrupt, or abandon and stop.
   *
   * @param t a thread that was still alive at the last scan
   * @param abandonAfterMS total per-name waiting time at which a thread is abandoned outright
   */
  private void escalate(final Thread t, final long abandonAfterMS) {
    final int tickMS = MyriaConstants.SHORT_WAITING_INTERVAL_100_MS;

    // Name-based budget: catches threads that are re-created under the same name.
    if (addToMap(waitedForMSThreadName, t.getName(), tickMS) >= abandonAfterMS) {
      abandonThreads.add(t);
      return;
    }
    if (addToMap(waitedForMS, t, tickMS) <= waitBeforeInterruptMS) {
      return; // still within the grace period since the last interrupt
    }

    // Grace period used up: restart it and escalate.
    waitedForMS.put(t, 0);
    final int interrupts = addToMap(interruptTimes, t, 1);
    if (interrupts > numInterruptBeforeKill) {
      LOGGER.debug(
          "Thread {} have been interrupted for {} times. Kill it directly.", t, interrupts - 1);
      abandonThreads.add(t);
      stopQuietly(t);
    } else {
      LOGGER.debug(
          "Waited Thread {} to finish for {} seconds. I'll try interrupting it.",
          t,
          TimeUnit.MILLISECONDS.toSeconds(waitBeforeInterruptMS) * interrupts);
      t.interrupt();
    }
  }

  /**
   * Polls the live threads until none is left to wait for, escalating from waiting to
   * interrupting to stopping. Returns normally only if every tracked thread finished without
   * having been abandoned; otherwise {@link JVMUtils#shutdownVM()} is invoked.
   */
  @Override
  public final void run() {
    // Widen before multiplying so that large caller-supplied values cannot overflow int.
    final long abandonAfterMS = (long) waitBeforeInterruptMS * numInterruptBeforeKill;

    while (true) {
      final Set<Thread> pending = collectPendingThreads();
      if (pending.isEmpty()) {
        if (abandonThreads.isEmpty()) {
          return;
        }
        // NOTE(review): presumably shutdownVM() terminates the JVM and does not return; if it
        // does return, the loop simply retries on the next pass.
        JVMUtils.shutdownVM();
      }

      try {
        Thread.sleep(MyriaConstants.SHORT_WAITING_INTERVAL_100_MS);
      } catch (InterruptedException e) {
        // Being interrupted while cleaning up is treated as a request to shut down right away.
        JVMUtils.shutdownVM();
      }

      for (final Thread t : pending) {
        escalate(t, abandonAfterMS);
      }
    }
  }
}