package beast.evolution.alignment;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import beast.core.Description;
import beast.core.Input;
import beast.core.Input.Validate;
import beast.core.parameter.IntegerParameter;
import beast.core.util.Log;
import beast.evolution.datatype.DataType;

@Description("Alignment based on a filter operation on another alignment")
public class FilteredAlignment extends Alignment {
    final public Input<String> filterInput = new Input<>("filter", "specifies which of the sites in the input alignment should be selected. " +
            "First site is 1. " +
            "Filter specs are comma separated, either a singleton, a range [from]-[to] or iteration [from]:[to]:[step]; " +
            "1-100 defines a range, " +
            "1-100\\3 or 1:100:3 defines every third in range 1-100, " +
            "1::3,2::3 removes every third site. " +
            "Default for range [1]-[last site], default for iterator [1]:[last site]:[1]", Validate.REQUIRED);
    final public Input<Alignment> alignmentInput = new Input<>("data", "alignment to be filtered", Validate.REQUIRED);
    final public Input<IntegerParameter> constantSiteWeightsInput = new Input<>("constantSiteWeights", "if specified, constant " +
            "sites will be added with weights specified by the input. The dimension and order of weights must match the datatype. " +
            "For example, for nucleotide data a 4 dimensional " +
            "parameter with weights for A, C, G and T respectively needs to be specified.");

    // these triples specify a range: for (i = from; i <= to; i += step)
    int[] from;
    int[] to;
    int[] step;
    /**
     * list of indices filtered from input alignment
     */
    int[] filter;

    boolean convertDataType = false;

    public FilteredAlignment() {
        sequenceInput.setRule(Validate.OPTIONAL);
        // it does not make sense to set weights on sites, since they can be scrambled by the filter
        siteWeightsInput.setRule(Validate.FORBIDDEN);
    }

    @Override
    public void initAndValidate() {
        parseFilterSpec();
        calcFilter();
        Alignment data = alignmentInput.get();
        m_dataType = data.m_dataType;
        // see if this filter changes the data type
        if (userDataTypeInput.get() != null) {
            m_dataType = userDataTypeInput.get();
            convertDataType = true;
        }

        if (constantSiteWeightsInput.get() != null) {
            if (constantSiteWeightsInput.get().getDimension() != m_dataType.getStateCount()) {
                throw new IllegalArgumentException("constantSiteWeights should be of the same dimension as the datatype " +
                        "(" + constantSiteWeightsInput.get().getDimension() + "!=" + m_dataType.getStateCount() + ")");
            }
        }

        counts = data.getCounts();
        taxaNames = data.taxaNames;
        stateCounts = data.stateCounts;
        if (convertDataType && m_dataType.getStateCount() > 0) {
            for (int i = 0; i < stateCounts.size(); i++) {
                stateCounts.set(i, m_dataType.getStateCount());
            }
        }

        if (alignmentInput.get().siteWeightsInput.get() != null) {
            String str = alignmentInput.get().siteWeightsInput.get().trim();
            String[] strs = str.split(",");
            siteWeights = new int[strs.length];
            for (int i = 0; i < strs.length; i++) {
                siteWeights[i] = Integer.parseInt(strs[i].trim());
            }
        }

        calcPatterns();
        setupAscertainment();
    }
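    // Worked examples of the "filter" spec (an illustrative sketch derived from the input
    // description above and from parseFilterSpec() below; site numbers are 1-based):
    //
    //   "3"          keeps site 3 only
    //   "1-100"      keeps sites 1 to 100
    //   "1-100\3"    keeps every third site in 1-100, i.e. sites 1, 4, 7, ..., 100
    //   "1:100:3"    iterator form of "1-100\3"
    //   "1::3,2::3"  keeps sites 1, 2, 4, 5, 7, 8, ... -- that is, removes every third site
    //
    // Omitted values fall back to their defaults: from = 1, to = last site, step = 1.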
    private void parseFilterSpec() {
        // parse filter specification
        String filterString = filterInput.get();
        String[] filters = filterString.split(",");
        from = new int[filters.length];
        to = new int[filters.length];
        step = new int[filters.length];
        for (int i = 0; i < filters.length; i++) {
            filterString = " " + filters[i] + " ";
            if (filterString.matches(".*-.*")) {
                // range, e.g. 1-100\3
                if (filterString.indexOf('\\') >= 0) {
                    String str2 = filterString.substring(filterString.indexOf('\\') + 1);
                    step[i] = parseInt(str2, 1);
                    filterString = filterString.substring(0, filterString.indexOf('\\'));
                } else {
                    step[i] = 1;
                }
                String[] strs = filterString.split("-");
                from[i] = parseInt(strs[0], 1) - 1;
                to[i] = parseInt(strs[1], alignmentInput.get().getSiteCount()) - 1;
            } else if (filterString.matches(".*:.*:.+")) {
                // iterator, e.g. 1:100:3
                String[] strs = filterString.split(":");
                from[i] = parseInt(strs[0], 1) - 1;
                to[i] = parseInt(strs[1], alignmentInput.get().getSiteCount()) - 1;
                step[i] = parseInt(strs[2], 1);
            } else if (filterString.trim().matches("[0-9]*")) {
                from[i] = parseInt(filterString.trim(), 1) - 1;
                to[i] = from[i];
                step[i] = 1;
            } else {
                throw new IllegalArgumentException("Don't know how to parse filter " + filterString);
            }
        }
    }

    int parseInt(String str, int defaultValue) {
        str = str.replaceAll("\\s+", "");
        try {
            return Integer.parseInt(str);
        } catch (Exception e) {
            return defaultValue;
        }
    }

    private void calcFilter() {
        boolean[] isUsed = new boolean[alignmentInput.get().getSiteCount()];
        for (int i = 0; i < to.length; i++) {
            for (int k = from[i]; k <= to[i]; k += step[i]) {
                isUsed[k] = true;
            }
        }
        // count
        int k = 0;
        for (int i = 0; i < isUsed.length; i++) {
            if (isUsed[i]) {
                k++;
            }
        }
        // set up index set
        filter = new int[k];
        k = 0;
        for (int i = 0; i < isUsed.length; i++) {
            if (isUsed[i]) {
                filter[k++] = i;
            }
        }
    }
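    // Illustrative sketch of the spec-to-index expansion (hypothetical spec, not from the
    // source): for filter="2-10\4", parseFilterSpec() yields from = {1}, to = {9}, step = {4}
    // (0-based), calcFilter() marks isUsed[1], isUsed[5] and isUsed[9], and filter becomes
    // {1, 5, 9} -- i.e. alignment sites 2, 6 and 10 in user numbering. Comma-separated specs
    // are combined as a union through the shared isUsed array.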
    @Override
    public List<List<Integer>> getCounts() {
        if (counts == null) {
            counts = new ArrayList<>();
            for (int j = 0; j < sitePatterns[0].length; j++) {
                counts.add(new ArrayList<>());
            }
            for (int i = 0; i < getSiteCount(); i++) {
                int[] sites = getPattern(getPatternIndex(i));
                for (int j = 0; j < getTaxonCount(); j++) {
                    counts.get(j).add(sites[j]);
                }
            }
        }
        return counts;
    }

    @Override
    protected void calcPatterns() {
        int nrOfTaxa = counts.size();
        int nrOfSites = filter.length;
        DataType baseType = alignmentInput.get().m_dataType;

        // convert data to transposed int array
        int[][] data = new int[nrOfSites][nrOfTaxa];
        for (int i = 0; i < nrOfTaxa; i++) {
            List<Integer> sites = counts.get(i);
            for (int j = 0; j < nrOfSites; j++) {
                data[j][i] = sites.get(filter[j]);
                if (convertDataType) {
                    try {
                        String code = baseType.getCode(data[j][i]);
                        data[j][i] = m_dataType.string2state(code).get(0);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }
        }

        // add constant sites, if specified
        if (constantSiteWeightsInput.get() != null) {
            int dim = constantSiteWeightsInput.get().getDimension();
            // add constant patterns
            int[][] data2 = new int[nrOfSites + dim][];
            System.arraycopy(data, 0, data2, 0, nrOfSites);
            for (int i = 0; i < dim; i++) {
                data2[nrOfSites + i] = new int[nrOfTaxa];
                for (int j = 0; j < nrOfTaxa; j++) {
                    data2[nrOfSites + i][j] = i;
                }
            }
            data = data2;
            nrOfSites += dim;
        }

        // sort data
        SiteComparator comparator = new SiteComparator();
        Arrays.sort(data, comparator);

        // count patterns in sorted data
        int[] weights = new int[nrOfSites];
        int nrOfPatterns = 1;
        if (nrOfSites > 0) {
            weights[0] = 1;
            for (int i = 1; i < nrOfSites; i++) {
                if (comparator.compare(data[i - 1], data[i]) != 0) {
                    nrOfPatterns++;
                    data[nrOfPatterns - 1] = data[i];
                }
                weights[nrOfPatterns - 1]++;
            }
        } else {
            nrOfPatterns = 0;
        }

        // adjust weight of invariant sites, if stripInvariantSitesInput is specified
        if (stripInvariantSitesInput.get()) {
            // don't add patterns that are invariant, e.g. all gaps
            Log.info.print("Stripping invariant sites");
            int removedSites = 0;
            for (int i = 0; i < nrOfPatterns; i++) {
                boolean isConstant = true;
                for (int j = 1; j < nrOfTaxa; j++) {
                    if (data[i][j] != data[i][0]) {
                        isConstant = false;
                        break;
                    }
                }
                // if this is a constant site, and it is not an ambiguous site
                if (isConstant) {
                    Log.warning.print(" <" + data[i][0] + "> ");
                    removedSites += weights[i];
                    weights[i] = 0;
                }
            }
            Log.warning.println(" removed " + removedSites + " sites ");
        }

        // adjust weight of constant sites, if specified
        if (constantSiteWeightsInput.get() != null) {
            Integer[] constantWeights = constantSiteWeightsInput.get().getValues();
            for (int i = 0; i < nrOfPatterns; i++) {
                boolean isConstant = true;
                for (int j = 1; j < nrOfTaxa; j++) {
                    if (data[i][j] != data[i][0]) {
                        isConstant = false;
                        break;
                    }
                }
                // if this is a constant site, and it is not an ambiguous site
                if (isConstant && data[i][0] >= 0 && data[i][0] < constantWeights.length) {
                    // take weights in the data into account as well:
                    // by adding constant patterns, we added a weight of 1, which now gets corrected,
                    // but if filtered by stripping constant sites, that weight is already set to zero
                    weights[i] = (stripInvariantSitesInput.get() ? 0 : weights[i] - 1) + constantWeights[data[i][0]];
                }
            }
            // need to decrease siteCount for mapping sites to patterns in m_nPatternIndex
            nrOfSites -= constantWeights.length;
        }

        // reserve memory for patterns
        patternWeight = new int[nrOfPatterns];
        sitePatterns = new int[nrOfPatterns][nrOfTaxa];
        for (int i = 0; i < nrOfPatterns; i++) {
            patternWeight[i] = weights[i];
            sitePatterns[i] = data[i];
        }

        // find patterns for the sites
        patternIndex = new int[nrOfSites];
        for (int i = 0; i < nrOfSites; i++) {
            int[] sites = new int[nrOfTaxa];
            for (int j = 0; j < nrOfTaxa; j++) {
                sites[j] = counts.get(j).get(filter[i]);
                if (convertDataType) {
                    try {
                        sites[j] = m_dataType.string2state(baseType.getCode(sites[j])).get(0);
                    } catch (Exception e) {
                        e.printStackTrace();
                    }
                }
            }
            patternIndex[i] = Arrays.binarySearch(sitePatterns, sites, comparator);
        }

        if (siteWeights != null) {
            // TODO: fill in weights with siteweights.
            throw new RuntimeException("Cannot handle site weights in FilteredAlignment. Remove \"weights\" from data input.");
        }

        // determine maximum state count
        // Usually, the state count is equal for all sites,
        // though for SnAP analysis, this is typically not the case.
        maxStateCount = 0;
        for (int stateCount1 : stateCounts) {
            maxStateCount = Math.max(maxStateCount, stateCount1);
        }
        if (convertDataType) {
            maxStateCount = Math.max(maxStateCount, m_dataType.getStateCount());
        }

        // report some statistics
        //for (int i = 0; i < m_sTaxaNames.size(); i++) {
        //    System.err.println(m_sTaxaNames.get(i) + ": " + m_counts.get(i).size() + " " + m_nStateCounts.get(i));
        //}
        Log.info.println("Filter " + filterInput.get());
        Log.info.println(getTaxonCount() + " taxa");
        if (constantSiteWeightsInput.get() != null) {
            Integer[] constantWeights = constantSiteWeightsInput.get().getValues();
            int sum = 0;
            for (int i : constantWeights) {
                sum += i;
            }
            Log.info.println(getSiteCount() + " sites + " + sum + " constant sites");
        } else {
            Log.info.println(getSiteCount() + " sites");
        }
        Log.info.println(getPatternCount() + " patterns");

        // counts are not valid any more -- better set them to null in case
        // someone gets bitten by this
        this.counts = null;
    }

    /** return indices of the sites that the filter uses **/
    public int[] indices() {
        return filter.clone();
    }
}
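// Usage sketch (illustrative, not part of the class): constructing a FilteredAlignment
// programmatically, e.g. to extract the first codon position. Assumes 'alignment' is an
// initialised Alignment and that initByName(...) inherited from BEASTObject is available,
// as usual for BEAST objects.
//
//   FilteredAlignment codon1 = new FilteredAlignment();
//   codon1.initByName("data", alignment, "filter", "1::3");  // sites 1, 4, 7, ...
//   int[] keptSites = codon1.indices();                      // 0-based indices into 'alignment'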