/* * Licensed to the Apache Software Foundation (ASF) under one or more * contributor license agreements. See the NOTICE file distributed with * this work for additional information regarding copyright ownership. * The ASF licenses this file to You under the Apache License, Version 2.0 * (the "License"); you may not use this file except in compliance with * the License. You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package org.apache.tika.detect; import java.io.File; import java.io.IOException; import java.io.InputStream; import java.io.Writer; import java.nio.ByteBuffer; import java.nio.channels.Channels; import java.nio.channels.ReadableByteChannel; import java.nio.file.Files; import java.nio.file.Path; import java.util.HashMap; import java.util.Iterator; import java.util.Map; import org.apache.tika.io.TemporaryResources; import org.apache.tika.metadata.Metadata; import org.apache.tika.mime.MediaType; import static java.nio.charset.StandardCharsets.UTF_8; public abstract class TrainedModelDetector implements Detector { private final Map<MediaType, TrainedModel> MODEL_MAP = new HashMap<>(); private static final long serialVersionUID = 1L; public TrainedModelDetector() { loadDefaultModels(getClass().getClassLoader()); } public int getMinLength() { return Integer.MAX_VALUE; } public MediaType detect(InputStream input, Metadata metadata) throws IOException { // convert to byte-histogram if (input != null) { input.mark(getMinLength()); float[] histogram = readByteFrequencies(input); // writeHisto(histogram); //on testing purpose /* * iterate the map to find out the one that gives the higher * prediction value. */ Iterator<MediaType> iter = MODEL_MAP.keySet().iterator(); float threshold = 0.5f;// probability threshold, any value below the // threshold will be considered as // MediaType.OCTET_STREAM float maxprob = threshold; MediaType maxType = MediaType.OCTET_STREAM; while (iter.hasNext()) { MediaType key = iter.next(); TrainedModel model = MODEL_MAP.get(key); float prob = model.predict(histogram); if (maxprob < prob) { maxprob = prob; maxType = key; } } input.reset(); return maxType; } return null; } /** * Read the {@code inputstream} and build a byte frequency histogram * * @param input stream to read from * @return byte frequencies array * @throws IOException */ protected float[] readByteFrequencies(final InputStream input) throws IOException { ReadableByteChannel inputChannel; // TODO: any reason to avoid closing of input & inputChannel? try { inputChannel = Channels.newChannel(input); // long inSize = inputChannel.size(); float histogram[] = new float[257]; histogram[0] = 1; // create buffer with capacity of maxBufSize bytes ByteBuffer buf = ByteBuffer.allocate(1024 * 5); int bytesRead = inputChannel.read(buf); // read into buffer. float max = -1; while (bytesRead != -1) { buf.flip(); // make buffer ready for read while (buf.hasRemaining()) { byte byt = buf.get(); int idx = byt; idx++; if (byt < 0) { idx = 256 + idx; histogram[idx]++; } else { histogram[idx]++; } max = max < histogram[idx] ? histogram[idx] : max; } buf.clear(); // make buffer ready for writing bytesRead = inputChannel.read(buf); } int i; for (i = 1; i < histogram.length; i++) { histogram[i] /= max; histogram[i] = (float) Math.sqrt(histogram[i]); } return histogram; } finally { // inputChannel.close(); } } /** * for testing purposes; this method write the histogram vector to a file. * * @param histogram * @throws IOException */ private void writeHisto(final float[] histogram) throws IOException { Path histPath = new TemporaryResources().createTempFile(); try (Writer writer = Files.newBufferedWriter(histPath, UTF_8)) { for (float bin : histogram) { writer.write(String.valueOf(bin) + "\t"); // writer.write(i + "\t"); } writer.write("\r\n"); } } public void loadDefaultModels(Path modelFile) { try (InputStream in = Files.newInputStream(modelFile)) { loadDefaultModels(in); } catch (IOException e) { throw new RuntimeException("Unable to read the default media type registry", e); } } public void loadDefaultModels(File modelFile) { loadDefaultModels(modelFile.toPath()); } public abstract void loadDefaultModels(final InputStream modelStream); public abstract void loadDefaultModels(final ClassLoader classLoader); protected void registerModels(MediaType type, TrainedModel model) { MODEL_MAP.put(type, model); } }