/*
 * Copyright 2010-2016 Amazon.com, Inc. or its affiliates. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License").
 * You may not use this file except in compliance with the License.
 * A copy of the License is located at
 *
 *  http://aws.amazon.com/apache2.0
 *
 * or in the "license" file accompanying this file. This file is distributed
 * on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
 * express or implied. See the License for the specific language governing
 * permissions and limitations under the License.
 */

package com.amazon.blueshift.bluefront.android.vad;

import com.amazonaws.mobileconnectors.lex.interactionkit.internal.vad.VADException;
import com.amazonaws.mobileconnectors.lex.interactionkit.internal.vad.VoiceActivityDetector;
import com.amazonaws.mobileconnectors.lex.interactionkit.internal.vad.config.VADConfig;

import java.nio.ByteBuffer;
import java.nio.ShortBuffer;

/**
 * Abstract Voice Activity Detector class.
 * @param <T> the VAD configuration type.
 */
public abstract class AbstractVAD<T extends VADConfig> implements VoiceActivityDetector {

    private static final int VAD_FRAMES_PER_SEC = 100;

    /**
     * Load the native audio-processing library.
     */
    static {
        System.loadLibrary("blueshift-audioprocessing");
    }

    private final int mStartpointingThreshold;
    private final int mEndpointingThreshold;
    private final int mSampleRate;
    private int mConsecutiveSpeechFrames;
    private int mConsecutiveNonSpeechFrames;
    private final ShortBuffer mSamplesBuffer;
    private ByteBuffer mVAD;
    private VADState mVADState;

    /**
     * Create a VAD with the given audio sample rate and configuration.
     * @param sampleRate audio sample rate.
     * @param vadConfig configuration used to set up the VAD.
     */
    protected AbstractVAD(final int sampleRate, final T vadConfig) {
        mSampleRate = sampleRate;
        mStartpointingThreshold = vadConfig.getStartpointingThreshold();
        mEndpointingThreshold = vadConfig.getEndpointingThreshold();
        mConsecutiveSpeechFrames = 0;
        mConsecutiveNonSpeechFrames = 0;
        mVAD = setupVAD(vadConfig);
        mVADState = VADState.NOT_STARTPOINTED;
        mSamplesBuffer = ShortBuffer.wrap(new short[mSampleRate / VAD_FRAMES_PER_SEC]);
    }

    /**
     * Abstract method to set up the VAD.
     * @param vadConfig the configuration used to set up the VAD.
     * @return the created VAD.
     */
    protected abstract ByteBuffer setupVAD(final T vadConfig);

    @Override
    public final synchronized VADState processSamples(final short[] samples, final int samplesRead)
            throws VADException {
        if (mVAD == null) {
            throw new VADException("VAD is not initialized");
        }

        int samplesProcessed = 0;
        int toWrite = 0;
        // Loop while there are enough samples to fill the buffer.
        while (mSamplesBuffer.remaining() <= samplesRead - samplesProcessed) {
            // Copy samples into buffer.
            toWrite = mSamplesBuffer.remaining();
            mSamplesBuffer.put(samples, samplesProcessed, toWrite);
            samplesProcessed += toWrite;

            // Process samples.
            final int result = isSpeech(mVAD, mSamplesBuffer.array(), mSampleRate);
            mSamplesBuffer.clear();

            // Update internal state.
            if (result == 1) {
                mConsecutiveSpeechFrames++;
                mConsecutiveNonSpeechFrames = 0;
            } else if (result == 0) {
                mConsecutiveNonSpeechFrames++;
                mConsecutiveSpeechFrames = 0;
            } else {
                throw new VADException("Error processing speech frames");
            }

            // Update VAD state.
            updateVADState();
        }

        // If any samples remain, copy into buffer.
        mSamplesBuffer.put(samples, samplesProcessed, samplesRead - samplesProcessed);

        return mVADState;
    }
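
    /*
     * Frame sizing: the native detector consumes fixed-size frames of
     * mSampleRate / VAD_FRAMES_PER_SEC samples, i.e. 10 ms of audio per frame
     * (for example, 16000 / 100 = 160 samples at a 16 kHz sample rate). The
     * startpointing and endpointing thresholds are therefore counted in
     * consecutive 10 ms frames of speech and non-speech respectively.
     */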

    /**
     * Update internal VAD state.
     */
    private void updateVADState() {
        if (mVADState == VADState.NOT_STARTPOINTED
                && mConsecutiveSpeechFrames >= mStartpointingThreshold) {
            mVADState = VADState.STARTPOINTED;
        } else if (mVADState == VADState.STARTPOINTED
                && mConsecutiveNonSpeechFrames >= mEndpointingThreshold) {
            mVADState = VADState.ENDPOINTED;
        }
    }

    @Override
    public synchronized void close() {
        if (mVAD != null) {
            destroyVAD(mVAD);
            mVAD = null;
        }
    }

    /**
     * Create a VAD.
     * @param useDNN flag indicating whether to use the DNN implementation or the WebRtc implementation.
     * @return the handle to the VAD structure.
     */
    protected native ByteBuffer createVAD(final boolean useDNN);

    /**
     * Set the aggression mode for the WebRtc VAD.
     * @param vadInstance the WebRtc VAD.
     * @param aggressionMode the aggression mode.
     * @return 0 on success, -1 on error.
     */
    protected native int setWebRtcMode(final ByteBuffer vadInstance, final int aggressionMode);

    /**
     * Set the customized aggression mode for the WebRtc VAD.
     * @param vadInstance the WebRtc VAD.
     * @param overHangMax1 the first maximum overhang value.
     * @param overHangMax2 the second maximum overhang value.
     * @param localThreshold the local threshold.
     * @param globalThreshold the global threshold.
     * @return 0 on success, -1 on error.
     */
    protected native int setWebRtcCustomizedMode(final ByteBuffer vadInstance, final int overHangMax1,
            final int overHangMax2, final int localThreshold, final int globalThreshold);

    /**
     * Set the specific threshold for the DNN VAD.
     * @param vadInstance the DNN VAD.
     * @param threshold the LRT threshold value.
     * @return 0 on success, -1 on error.
     */
    protected native int setDNNThreshold(final ByteBuffer vadInstance, final float threshold);

    /**
     * Free VAD memory.
     * @param vadInstance the VAD instance.
     * @return 0 on success, -1 on error.
     */
    protected native int destroyVAD(final ByteBuffer vadInstance);

    /**
     * Check if audio contains speech.
     * @param vadInstance the native VAD instance.
     * @param samples the audio buffer.
     * @param sampleRate the sample rate of the audio.
     * @return 1 if speech, 0 if non-speech, and -1 if error.
     */
    protected native int isSpeech(final ByteBuffer vadInstance, final short[] samples, final int sampleRate);
}
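
/*
 * Illustrative sketch: a hypothetical WebRtc-backed subclass showing how
 * setupVAD() is expected to combine the native helpers declared above. The
 * class name and the aggressiveness value are assumptions made for this
 * example only; the SDK provides its own concrete subclasses and configs.
 */
class ExampleWebRtcVAD extends AbstractVAD<VADConfig> {

    /** Assumed WebRtc aggressiveness mode for illustration (commonly 0 = least to 3 = most aggressive). */
    private static final int EXAMPLE_AGGRESSION_MODE = 2;

    ExampleWebRtcVAD(final int sampleRate, final VADConfig vadConfig) {
        super(sampleRate, vadConfig);
    }

    @Override
    protected ByteBuffer setupVAD(final VADConfig vadConfig) {
        // Create a WebRtc-based native VAD instance (useDNN == false).
        final ByteBuffer vad = createVAD(false);
        // Configure the aggressiveness mode; release the native memory on failure.
        if (setWebRtcMode(vad, EXAMPLE_AGGRESSION_MODE) != 0) {
            destroyVAD(vad);
            // Returning null leaves the VAD uninitialized, so processSamples() will throw.
            return null;
        }
        return vad;
    }
}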