/*
 * Encog(tm) Core v3.4 - Java Version
 * http://www.heatonresearch.com/encog/
 * https://github.com/encog/encog-java-core

 * Copyright 2008-2016 Heaton Research, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 *
 * For more information on Heaton Research copyrights, licenses
 * and trademarks visit:
 * http://www.heatonresearch.com/copyright
 */
package org.encog.ml.world.grid;

import java.util.ArrayList;
import java.util.List;

import org.encog.mathutil.EncogMath;
import org.encog.ml.world.Action;
import org.encog.ml.world.State;
import org.encog.ml.world.basic.BasicAction;
import org.encog.ml.world.basic.BasicWorld;

/**
 * A rectangular world of {@link GridState} cells addressed by (row, column).
 * <p>
 * Orientation used throughout this class: NORTH moves to row - 1, SOUTH to
 * row + 1, EAST to column + 1 and WEST to column - 1. A cell that has been
 * removed with {@link #setBlocked(int, int)} is stored as {@code null} and is
 * treated as blocked.
 */
public class GridWorld extends BasicWorld {

	/** Move one row up (row - 1). */
	public static final Action ACTION_NORTH = new BasicAction("NORTH");
	/** Move one row down (row + 1). */
	public static final Action ACTION_SOUTH = new BasicAction("SOUTH");
	/** Move one column right (column + 1). */
	public static final Action ACTION_EAST = new BasicAction("EAST");
	/** Move one column left (column - 1). */
	public static final Action ACTION_WEST = new BasicAction("WEST");

	/** Grid cells indexed [row][column]; a {@code null} entry is a blocked cell. */
	private GridState[][] state;

	/**
	 * Creates a grid world in which every cell starts as an open state.
	 * <p>
	 * The four compass actions are registered first. Each cell's state is then
	 * added to the world and sized to hold one policy value per action.
	 *
	 * @param rows    number of rows in the grid
	 * @param columns number of columns in the grid
	 */
	public GridWorld(int rows, int columns) {
		addAction(ACTION_NORTH);
		addAction(ACTION_SOUTH);
		addAction(ACTION_EAST);
		addAction(ACTION_WEST);

		this.state = new GridState[rows][columns];
		for (int row = 0; row < rows; row++) {
			for (int col = 0; col < columns; col++) {
				// Named "cell" so it does not shadow the "state" field.
				GridState cell = new GridState(this, row, col, false);
				addState(cell);
				this.state[row][col] = cell;
				cell.setPolicyValueSize(getActions().size());
			}
		}
	}

	/**
	 * Tests whether a cell is blocked. Blocked cells, and positions outside
	 * the grid, are represented as {@code null}.
	 *
	 * @param state the cell to test; may be {@code null}
	 * @return {@code true} if {@code state} is {@code null}
	 */
	public static boolean isStateBlocked(GridState state) {
		return state == null;
	}

	/**
	 * @return the number of rows in the grid
	 */
	public int getRows() {
		return this.state.length;
	}

	/**
	 * Returns the number of columns in the grid.
	 * <p>
	 * A grid created with zero rows has no cells at all. In that case this
	 * returns 0 instead of failing on a missing row 0.
	 *
	 * @return the number of columns, or 0 if the grid has no rows
	 */
	public int getColumns() {
		return this.state.length == 0 ? 0 : this.state[0].length;
	}

	/**
	 * Looks up the cell at the given position.
	 *
	 * @param row    row index
	 * @param column column index
	 * @return the cell, or {@code null} if the position is outside the grid
	 *         or the cell is blocked
	 */
	public GridState getState(int row, int column) {
		if (row < 0 || row >= getRows()) {
			return null;
		} else if (column < 0 || column >= getColumns()) {
			return null;
		}
		return this.state[row][column];
	}

	/**
	 * Returns the action 90 degrees counter-clockwise from the given one.
	 *
	 * @param action a compass action
	 * @return the action to the left, or {@code null} if {@code action} is not
	 *         one of the four compass actions (compared by identity)
	 */
	public static Action leftOfAction(Action action) {
		if (action == GridWorld.ACTION_NORTH) {
			return GridWorld.ACTION_WEST;
		} else if (action == GridWorld.ACTION_SOUTH) {
			return GridWorld.ACTION_EAST;
		} else if (action == GridWorld.ACTION_EAST) {
			return GridWorld.ACTION_NORTH;
		} else if (action == GridWorld.ACTION_WEST) {
			return GridWorld.ACTION_SOUTH;
		}
		return null;
	}

	/**
	 * Returns the action 90 degrees clockwise from the given one.
	 *
	 * @param action a compass action
	 * @return the action to the right, or {@code null} if {@code action} is not
	 *         one of the four compass actions (compared by identity)
	 */
	public static Action rightOfAction(Action action) {
		if (action == GridWorld.ACTION_NORTH) {
			return GridWorld.ACTION_EAST;
		} else if (action == GridWorld.ACTION_SOUTH) {
			return GridWorld.ACTION_WEST;
		} else if (action == GridWorld.ACTION_EAST) {
			return GridWorld.ACTION_SOUTH;
		} else if (action == GridWorld.ACTION_WEST) {
			return GridWorld.ACTION_NORTH;
		}
		return null;
	}

	/**
	 * Returns the opposite of the given action.
	 *
	 * @param action a compass action
	 * @return the reverse action, or {@code null} if {@code action} is not one
	 *         of the four compass actions (compared by identity)
	 */
	public static Action reverseOfAction(Action action) {
		if (action == GridWorld.ACTION_NORTH) {
			return GridWorld.ACTION_SOUTH;
		} else if (action == GridWorld.ACTION_SOUTH) {
			return GridWorld.ACTION_NORTH;
		} else if (action == GridWorld.ACTION_EAST) {
			return GridWorld.ACTION_WEST;
		} else if (action == GridWorld.ACTION_WEST) {
			return GridWorld.ACTION_EAST;
		}
		return null;
	}

	/**
	 * Collects the cells reachable from {@code s} in one step, including
	 * staying in place.
	 * <p>
	 * The result is ordered north, south, east, west, then {@code s} itself.
	 * Neighbours outside the grid or blocked are omitted.
	 *
	 * @param s the current cell; must not be {@code null}
	 * @return a new, mutable list of reachable cells
	 */
	public List<GridState> getAdjacentStates(GridState s) {
		List<GridState> result = new ArrayList<GridState>();
		GridState northState = this.getState(s.getRow() - 1, s.getColumn());
		GridState southState = this.getState(s.getRow() + 1, s.getColumn());
		GridState eastState = this.getState(s.getRow(), s.getColumn() + 1);
		GridState westState = this.getState(s.getRow(), s.getColumn() - 1);

		if (!isStateBlocked(northState)) {
			result.add(northState);
		}
		if (!isStateBlocked(southState)) {
			result.add(southState);
		}
		if (!isStateBlocked(eastState)) {
			result.add(eastState);
		}
		if (!isStateBlocked(westState)) {
			result.add(westState);
		}
		if (!isStateBlocked(s)) {
			result.add(s);
		}
		return result;
	}

	/**
	 * Computes the straight-line distance between two cells' (row, column)
	 * coordinates.
	 *
	 * @param s1 first cell
	 * @param s2 second cell
	 * @return the Euclidean distance between the two cells
	 */
	public static double euclideanDistance(GridState s1, GridState s2) {
		double d = EncogMath.square(s1.getRow() - s2.getRow())
				+ EncogMath.square(s1.getColumn() - s2.getColumn());
		return Math.sqrt(d);
	}

	/**
	 * Finds the candidate nearest to {@code goalState} by Euclidean distance.
	 * When distances tie, the earliest candidate in {@code states} wins.
	 *
	 * @param states    candidate cells
	 * @param goalState the target cell
	 * @return the nearest candidate, or {@code null} if {@code states} is empty
	 */
	public GridState findClosestStateTo(List<GridState> states,
			GridState goalState) {
		double min = Double.POSITIVE_INFINITY;
		GridState minState = null;

		for (GridState candidate : states) {
			double d = euclideanDistance(candidate, goalState);
			if (d < min) {
				min = d;
				minState = candidate;
			}
		}

		return minState;
	}

	/**
	 * Picks the single compass action that moves most directly from
	 * {@code currentState} toward {@code targetState}.
	 * <p>
	 * Vertical movement is preferred when the row and column distances are
	 * equal.
	 *
	 * @param currentState the starting cell
	 * @param targetState  the destination cell
	 * @return the chosen action, or {@code null} if both cells share a position
	 */
	public Action determineActionToState(GridState currentState,
			GridState targetState) {
		int rowDiff = currentState.getRow() - targetState.getRow();
		int colDiff = currentState.getColumn() - targetState.getColumn();

		if (rowDiff == 0 && colDiff == 0) {
			return null;
		}

		if (Math.abs(rowDiff) >= Math.abs(colDiff)) {
			// Target is at a larger row index, i.e. further south.
			return rowDiff < 0 ? GridWorld.ACTION_SOUTH : GridWorld.ACTION_NORTH;
		} else {
			// Target is at a larger column index, i.e. further east.
			return colDiff < 0 ? GridWorld.ACTION_EAST : GridWorld.ACTION_WEST;
		}
	}

	/**
	 * Finds the candidate nearest to any of this world's goals by Euclidean
	 * distance.
	 * <p>
	 * Each goal is cast to {@link GridState}, so goals of any other
	 * {@code State} type cause a {@code ClassCastException}.
	 *
	 * @param states candidate cells
	 * @return the nearest candidate, or {@code null} if there are no
	 *         candidates or no goals
	 */
	public GridState findClosestStateToGoal(List<GridState> states) {
		double min = Double.POSITIVE_INFINITY;
		GridState minState = null;

		for (State goalState : this.getGoals()) {
			for (GridState candidate : states) {
				double d = euclideanDistance(candidate, (GridState) goalState);
				if (d < min) {
					min = d;
					minState = candidate;
				}
			}
		}

		return minState;
	}

	/**
	 * Blocks the cell at the given position.
	 * <p>
	 * The grid slot is cleared to {@code null} and the cell's state is removed
	 * from the world's state list. Blocking an already-blocked cell is a no-op.
	 *
	 * @param row    row index; must be inside the grid
	 * @param column column index; must be inside the grid
	 */
	public void setBlocked(int row, int column) {
		State removed = this.state[row][column];
		this.state[row][column] = null;
		// Skip asking the state list to remove null when the cell was already blocked.
		if (removed != null) {
			this.getStates().remove(removed);
		}
	}
}