package com.breakersoft.plow.dispatcher.dao;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import org.springframework.jdbc.core.PreparedStatementCreator;
import org.springframework.jdbc.core.RowMapper;
import org.springframework.stereotype.Repository;
import com.breakersoft.plow.JobId;
import com.breakersoft.plow.Task;
import com.breakersoft.plow.dao.AbstractDao;
import com.breakersoft.plow.dispatcher.domain.DispatchProc;
import com.breakersoft.plow.dispatcher.domain.DispatchResource;
import com.breakersoft.plow.dispatcher.domain.DispatchTask;
import com.breakersoft.plow.rnd.thrift.RunTaskCommand;
import com.breakersoft.plow.thrift.TaskState;
import com.google.common.collect.Maps;
/**
 * JDBC implementation of {@link DispatchTaskDao}: the task-level SQL the
 * dispatcher uses to reserve, start, stop and describe tasks stored in the
 * {@code plow} schema.
 */
@Repository
public class DispatchTaskDaoImpl extends AbstractDao implements DispatchTaskDao {

    /**
     * Retries remaining for a task: its layer's retry limit minus the task's
     * current retry count. Tables are schema-qualified like every other
     * query on plow.task / plow.layer in this class, so the result does not
     * depend on the connection's search_path.
     */
    private static final String RETRY =
        "SELECT " +
            "layer.int_retries_max - task.int_retry " +
        "FROM " +
            "plow.task, " +
            "plow.layer " +
        "WHERE " +
            "task.pk_layer = layer.pk_layer " +
        "AND " +
            "task.pk_task = ?";

    /**
     * Reports whether a task has used up its retries.
     *
     * @param task the task to check
     * @return true when layer.int_retries_max - task.int_retry is zero or less
     */
    @Override
    public boolean isAtMaxRetries(Task task) {
        // NOTE(review): a SQL NULL result (e.g. NULL int_retries_max) makes
        // queryForObject return null, which fails here on unboxing.
        return jdbc.queryForObject(RETRY, Integer.class, task.getTaskId()) <= 0;
    }

    /**
     * Attempts to reserve a WAITING, currently unreserved task.
     *
     * The row is first locked with SELECT ... FOR UPDATE NOWAIT, so if another
     * transaction already holds the row lock the attempt fails immediately
     * instead of blocking.
     *
     * @param task the task to reserve
     * @return true if exactly one row was flipped to reserved; false if the
     *         task is not WAITING, is already reserved, or could not be locked
     */
    @Override
    public boolean reserve(Task task) {
        try {
            jdbc.queryForObject(
                "SELECT task.pk_task FROM plow.task WHERE pk_task=? AND int_state=? AND bool_reserved='f' FOR UPDATE NOWAIT",
                String.class, task.getTaskId(), TaskState.WAITING.ordinal());
            return jdbc.update("UPDATE plow.task SET bool_reserved='t' " +
                "WHERE pk_task=? AND int_state=? AND bool_reserved='f'",
                task.getTaskId(), TaskState.WAITING.ordinal()) == 1;
        } catch (RuntimeException ignored) {
            // Deliberate best-effort: queryForObject throws when no matching row
            // exists, and NOWAIT raises an error when the row is locked by
            // someone else. Both simply mean "could not reserve".
            return false;
        }
    }

    /**
     * Clears the reservation flag on a reserved task.
     *
     * @param task the task to unreserve
     * @return true if exactly one reserved row was updated
     */
    @Override
    public boolean unreserve(Task task) {
        return jdbc.update("UPDATE plow.task SET bool_reserved='f' " +
            "WHERE pk_task=? AND bool_reserved='t'", task.getTaskId()) == 1;
    }

    /**
     * Transitions a reserved task into a new state (bound as parameter 1),
     * clearing the reservation, bumping the retry counter, stamping start and
     * update times and recording the last node/RAM/cores it was started with.
     * Only matches rows in the state bound as the last parameter that are
     * still reserved.
     */
    public static final String START_TASK =
        "UPDATE " +
            "plow.task " +
        "SET " +
            "int_state = ?, " +
            "bool_reserved = 'f', " +
            "int_retry = int_retry + 1, " +
            "time_updated = txTimeMillis(), " +
            "time_started = txTimeMillis(), " +
            "time_stopped = 0, " +
            "str_last_node_name = ?, " +
            "int_last_ram = ?, " +
            "int_last_ram_high = 0, " +
            "int_last_cores = ?, " +
            // Trailing space is required: without it this concatenates to
            // "flt_last_cores_high = 0WHERE", which is invalid SQL.
            "flt_last_cores_high = 0 " +
        "WHERE " +
            "task.pk_task = ? " +
        "AND " +
            "int_state = ? " +
        "AND " +
            "bool_reserved = 't'";

    /**
     * Moves a reserved WAITING task to RUNNING on the given proc.
     *
     * @param task the task to start
     * @param proc supplies the hostname, idle RAM and idle cores recorded as
     *             the task's last-run values
     * @return true if exactly one row was updated
     */
    @Override
    public boolean start(Task task, DispatchProc proc) {
        return jdbc.update(START_TASK,
            TaskState.RUNNING.ordinal(),
            proc.getHostname(),
            proc.getIdleRam(),
            proc.getIdleCores(),
            task.getTaskId(),
            TaskState.WAITING.ordinal()) == 1;
    }

    /**
     * Transitions a task from the state bound as the last parameter into a new
     * state, clearing the reservation, stamping stop/update times and storing
     * the exit status and signal.
     */
    public static final String STOP_TASK =
        "UPDATE " +
            "plow.task " +
        "SET " +
            "int_state = ?, " +
            "bool_reserved = 'f', " +
            "time_stopped = currentTimeMillis(), " +
            "time_updated = currentTimeMillis(), " +
            "int_exit_status=?," +
            "int_exit_signal=? " +
        "WHERE " +
            "task.pk_task = ? " +
        "AND " +
            "int_state = ? ";

    /**
     * Moves a RUNNING task into {@code newState}. When the transition succeeds
     * and the new state is SUCCEEDED, a row is also inserted into depend_queue
     * (presumably consumed by dependency processing — see
     * {@link #dependQueueProcessed(Task)}).
     *
     * @param task       the task to stop
     * @param newState   the state to move the task into
     * @param exitStatus exit status stored on the task
     * @param exitSignal exit signal stored on the task
     * @return true if the task was RUNNING and has been updated
     */
    @Override
    public boolean stop(Task task, TaskState newState, int exitStatus, int exitSignal) {
        final int updated = jdbc.update(STOP_TASK,
            newState.ordinal(),
            exitStatus,
            exitSignal,
            task.getTaskId(),
            TaskState.RUNNING.ordinal());
        if (updated != 1) {
            return false;
        }
        if (newState == TaskState.SUCCEEDED) {
            jdbc.update("INSERT INTO depend_queue (pk_job, pk_layer, pk_task) VALUES (?,?,?)",
                task.getJobId(), task.getLayerId(), task.getTaskId());
        }
        return true;
    }

    /**
     * Removes the task's entry from depend_queue.
     *
     * @param task the task whose queue entry is removed
     * @return true if exactly one row was deleted
     */
    @Override
    public boolean dependQueueProcessed(Task task) {
        return jdbc.update("DELETE FROM depend_queue WHERE pk_task=?", task.getTaskId()) == 1;
    }

    /**
     * Unreserved tasks of one job in a given state whose layer minimums fit
     * the supplied cores/RAM and whose layer tags overlap the supplied tag
     * array, ordered by task order then layer order, with a row limit.
     */
    private static final String GET_DISPATCHABLE_TASKS =
        "SELECT " +
            "task.pk_task,"+
            "task.pk_layer,"+
            "task.pk_job,"+
            "task.str_name," +
            "layer.int_ram_min, " +
            "layer.int_cores_min,"+
            "job.pk_project " +
        "FROM " +
            "plow.layer " +
        "INNER JOIN " +
            "plow.task ON layer.pk_layer = task.pk_layer " +
        "INNER JOIN " +
            "plow.job ON layer.pk_job = job.pk_job " +
        "WHERE " +
            "layer.pk_job = ? " +
        "AND " +
            "layer.int_cores_min <= ? " +
        "AND " +
            "layer.int_ram_min <= ? " +
        "AND " +
            "layer.str_tags && ? " +
        "AND " +
            "task.int_state = ? " +
        "AND " +
            "task.bool_reserved IS FALSE " +
        "ORDER BY " +
            "task.int_task_order, task.int_layer_order ASC " +
        "LIMIT ?";

    /** Maps a GET_DISPATCHABLE_TASKS row to a {@link DispatchTask}. */
    public static final RowMapper<DispatchTask> DISPATCHABLE_TASK_MAPPER =
        new RowMapper<DispatchTask>() {
            @Override
            public DispatchTask mapRow(ResultSet rs, int rowNum)
                    throws SQLException {
                DispatchTask task = new DispatchTask();
                task.projectId = (UUID) rs.getObject("pk_project");
                task.taskId = (UUID) rs.getObject("pk_task");
                task.layerId = (UUID) rs.getObject("pk_layer");
                task.jobId = (UUID) rs.getObject("pk_job");
                task.minCores = rs.getInt("int_cores_min");
                task.minRam = rs.getInt("int_ram_min");
                task.name = rs.getString("str_name");
                return task;
            }
        };

    /**
     * Finds WAITING, unreserved tasks of a job that fit the given resource.
     *
     * @param job      the job whose tasks are searched
     * @param resource supplies idle cores, idle RAM and the tag list to match
     * @param limit    maximum number of tasks returned
     * @return matching tasks ordered by task order, then layer order
     */
    @Override
    public List<DispatchTask> getDispatchableTasks(final JobId job, final DispatchResource resource, final int limit) {
        return jdbc.query(new PreparedStatementCreator() {
            @Override
            public PreparedStatement createPreparedStatement(final Connection conn) throws SQLException {
                final PreparedStatement ps = conn.prepareStatement(GET_DISPATCHABLE_TASKS);
                ps.setObject(1, job.getJobId());
                ps.setInt(2, resource.getIdleCores());
                ps.setInt(3, resource.getIdleRam());
                // Tags are bound as a SQL text[] for the && (overlap) test.
                ps.setArray(4, conn.createArrayOf("text", resource.getTags().toArray()));
                ps.setInt(5, TaskState.WAITING.ordinal());
                ps.setInt(6, limit);
                return ps;
            }
        }, DISPATCHABLE_TASK_MAPPER);
    }

    /**
     * Everything needed to build a {@link RunTaskCommand} for one task:
     * job, layer and task columns plus the proc the task is assigned to.
     */
    private static final String GET_RUN_TASK =
        "SELECT " +
            "job.int_uid," +
            "job.str_username," +
            "job.str_log_path, " +
            "job.str_active_name AS job_name, " +
            "job.hstore_env AS job_env, " +
            "layer.str_command, " +
            "layer.str_name AS layer_name, " +
            "layer.hstore_env AS layer_env, " +
            "layer.int_chunk_size, " +
            "task.int_number, " +
            "task.pk_task,"+
            "task.pk_layer,"+
            "task.pk_job,"+
            "task.str_name AS task_name, " +
            "task.int_retry, " +
            "proc.pk_proc,"+
            "proc.int_cores " +
        "FROM " +
            "plow.task " +
            "INNER JOIN plow.proc ON task.pk_task = proc.pk_task " +
            "INNER JOIN plow.layer ON layer.pk_layer = task.pk_layer " +
            "INNER JOIN plow.job ON layer.pk_job = job.pk_job " +
        "WHERE " +
            "task.pk_task = ? ";

    /**
     * Maps a GET_RUN_TASK row to a {@link RunTaskCommand}.
     *
     * The log file is "{log_path}/{task_name}.{retry}.log". In each command
     * argument "%{FRAME}" is replaced by the task number and "%{TASK}" by the
     * task name. The environment is built from the job env, then the layer env
     * (overriding job keys), then the PLOW_* variables (overriding both).
     */
    public static final RowMapper<RunTaskCommand> RUN_TASK_MAPPER =
        new RowMapper<RunTaskCommand>() {
            @Override
            public RunTaskCommand mapRow(ResultSet rs, int rowNum)
                    throws SQLException {
                // Read once; these are used by several fields below.
                final String taskName = rs.getString("task_name");
                final String logPath = rs.getString("str_log_path");

                RunTaskCommand task = new RunTaskCommand();
                task.jobId = rs.getString("pk_job");
                task.taskId = rs.getString("pk_task");
                task.layerId = rs.getString("pk_layer");
                task.procId = rs.getString("pk_proc");
                task.cores = rs.getInt("int_cores");
                task.logFile = String.format("%s/%s.%d.log",
                    logPath, taskName, rs.getInt("int_retry"));
                task.uid = rs.getInt("int_uid");
                task.username = rs.getString("str_username");

                // Build a regular (growable) list rather than an Arrays.asList
                // view, so later add() calls on the Thrift field do not throw.
                final String[] rawCommand = (String[]) rs.getArray("str_command").getArray();
                final String frame = String.valueOf(rs.getInt("int_number"));
                final List<String> command = new ArrayList<String>(rawCommand.length);
                for (String part : rawCommand) {
                    command.add(part.replace("%{FRAME}", frame).replace("%{TASK}", taskName));
                }
                task.command = command;

                task.env = Maps.newHashMap();
                // The driver returns hstore columns as a Map; keys/values are
                // assumed to be strings.
                @SuppressWarnings("unchecked")
                Map<String, String> jobEnv = (Map<String, String>) rs.getObject("job_env");
                if (jobEnv != null) {
                    task.env.putAll(jobEnv);
                }
                @SuppressWarnings("unchecked")
                Map<String, String> layerEnv = (Map<String, String>) rs.getObject("layer_env");
                if (layerEnv != null) {
                    task.env.putAll(layerEnv);
                }
                task.env.put("PLOW_TASK_ID", rs.getString("pk_task"));
                task.env.put("PLOW_JOB_ID", rs.getString("pk_job"));
                task.env.put("PLOW_PROC_ID", rs.getString("pk_proc"));
                task.env.put("PLOW_LAYER_ID", rs.getString("pk_layer"));
                task.env.put("PLOW_JOB_NAME", rs.getString("job_name"));
                task.env.put("PLOW_LAYER_NAME", rs.getString("layer_name"));
                task.env.put("PLOW_TASK_NAME", taskName);
                task.env.put("PLOW_LOG_DIR", logPath);
                task.env.put("PLOW_UID", rs.getString("int_uid"));
                task.env.put("PLOW_TASK_NUMBER", rs.getString("int_number"));
                task.env.put("PLOW_CHUNK", rs.getString("int_chunk_size"));
                return task;
            }
        };

    /**
     * Builds the run command for a task that has a proc assigned.
     *
     * @param task the task to describe
     * @return the populated command (see {@link #RUN_TASK_MAPPER})
     */
    @Override
    public RunTaskCommand getRunTaskCommand(Task task) {
        return jdbc.queryForObject(
            GET_RUN_TASK, RUN_TASK_MAPPER, task.getTaskId());
    }
}