Repo created
This commit is contained in:
parent
81b91f4139
commit
f8c34fa5ee
22732 changed files with 4815320 additions and 2 deletions
|
|
@ -0,0 +1,216 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/adaptive_digital_gain_controller.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "common_audio/include/audio_util.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
#include "system_wrappers/include/metrics.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
using AdaptiveDigitalConfig =
|
||||
AudioProcessing::Config::GainController2::AdaptiveDigital;
|
||||
|
||||
constexpr int kHeadroomHistogramMin = 0;
|
||||
constexpr int kHeadroomHistogramMax = 50;
|
||||
constexpr int kGainDbHistogramMax = 30;
|
||||
|
||||
// Computes the gain for `input_level_dbfs` to reach `-config.headroom_db`.
|
||||
// Clamps the gain in [0, `config.max_gain_db`]. `config.headroom_db` is a
|
||||
// safety margin to allow transient peaks to exceed the target peak level
|
||||
// without clipping.
|
||||
float ComputeGainDb(float input_level_dbfs,
|
||||
const AdaptiveDigitalConfig& config) {
|
||||
// If the level is very low, apply the maximum gain.
|
||||
if (input_level_dbfs < -(config.headroom_db + config.max_gain_db)) {
|
||||
return config.max_gain_db;
|
||||
}
|
||||
// We expect to end up here most of the time: the level is below
|
||||
// -headroom, but we can boost it to -headroom.
|
||||
if (input_level_dbfs < -config.headroom_db) {
|
||||
return -config.headroom_db - input_level_dbfs;
|
||||
}
|
||||
// The level is too high and we can't boost.
|
||||
RTC_DCHECK_GE(input_level_dbfs, -config.headroom_db);
|
||||
return 0.0f;
|
||||
}
|
||||
|
||||
// Returns `target_gain_db` if applying such a gain to `input_noise_level_dbfs`
|
||||
// does not exceed `max_output_noise_level_dbfs`. Otherwise lowers and returns
|
||||
// `target_gain_db` so that the output noise level equals
|
||||
// `max_output_noise_level_dbfs`.
|
||||
float LimitGainByNoise(float target_gain_db,
|
||||
float input_noise_level_dbfs,
|
||||
float max_output_noise_level_dbfs,
|
||||
ApmDataDumper& apm_data_dumper) {
|
||||
const float max_allowed_gain_db =
|
||||
max_output_noise_level_dbfs - input_noise_level_dbfs;
|
||||
apm_data_dumper.DumpRaw("agc2_adaptive_gain_applier_max_allowed_gain_db",
|
||||
max_allowed_gain_db);
|
||||
return std::min(target_gain_db, std::max(max_allowed_gain_db, 0.0f));
|
||||
}
|
||||
|
||||
float LimitGainByLowConfidence(float target_gain_db,
|
||||
float last_gain_db,
|
||||
float limiter_audio_level_dbfs,
|
||||
bool estimate_is_confident) {
|
||||
if (estimate_is_confident ||
|
||||
limiter_audio_level_dbfs <= kLimiterThresholdForAgcGainDbfs) {
|
||||
return target_gain_db;
|
||||
}
|
||||
const float limiter_level_dbfs_before_gain =
|
||||
limiter_audio_level_dbfs - last_gain_db;
|
||||
|
||||
// Compute a new gain so that `limiter_level_dbfs_before_gain` +
|
||||
// `new_target_gain_db` is not great than `kLimiterThresholdForAgcGainDbfs`.
|
||||
const float new_target_gain_db = std::max(
|
||||
kLimiterThresholdForAgcGainDbfs - limiter_level_dbfs_before_gain, 0.0f);
|
||||
return std::min(new_target_gain_db, target_gain_db);
|
||||
}
|
||||
|
||||
// Computes how the gain should change during this frame.
|
||||
// Return the gain difference in db to 'last_gain_db'.
|
||||
float ComputeGainChangeThisFrameDb(float target_gain_db,
|
||||
float last_gain_db,
|
||||
bool gain_increase_allowed,
|
||||
float max_gain_decrease_db,
|
||||
float max_gain_increase_db) {
|
||||
RTC_DCHECK_GT(max_gain_decrease_db, 0);
|
||||
RTC_DCHECK_GT(max_gain_increase_db, 0);
|
||||
float target_gain_difference_db = target_gain_db - last_gain_db;
|
||||
if (!gain_increase_allowed) {
|
||||
target_gain_difference_db = std::min(target_gain_difference_db, 0.0f);
|
||||
}
|
||||
return rtc::SafeClamp(target_gain_difference_db, -max_gain_decrease_db,
|
||||
max_gain_increase_db);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates a controller whose initial gain is `config.initial_gain_db`.
//
// `apm_data_dumper` is stored as a raw pointer and dereferenced by every
// `Process()` call, so it must be non-null and remain valid for as long as
// `Process()` may be called.
// `adjacent_speech_frames_threshold` is the number of adjacent frames with a
// sufficiently high speech probability that must be observed before the gain
// is allowed to increase; it must be at least 1.
AdaptiveDigitalGainController::AdaptiveDigitalGainController(
    ApmDataDumper* apm_data_dumper,
    const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
    int adjacent_speech_frames_threshold)
    : apm_data_dumper_(apm_data_dumper),
      gain_applier_(
          /*hard_clip_samples=*/false,
          /*initial_gain_factor=*/DbToRatio(config.initial_gain_db)),
      config_(config),
      adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
      // Converts the per-second limit into a per-frame limit.
      max_gain_change_db_per_10ms_(config_.max_gain_change_db_per_second *
                                   kFrameDurationMs / 1000.0f),
      calls_since_last_gain_log_(0),
      frames_to_gain_increase_allowed_(adjacent_speech_frames_threshold),
      last_gain_db_(config_.initial_gain_db) {
  // `Process()` dereferences the dumper unconditionally; fail early instead of
  // crashing on the first processed frame.
  RTC_DCHECK(apm_data_dumper_);
  RTC_DCHECK_GT(max_gain_change_db_per_10ms_, 0.0f);
  RTC_DCHECK_GE(frames_to_gain_increase_allowed_, 1);
  RTC_DCHECK_GE(config_.max_output_noise_level_dbfs, -90.0f);
  RTC_DCHECK_LE(config_.max_output_noise_level_dbfs, 0.0f);
}
|
||||
|
||||
void AdaptiveDigitalGainController::Process(const FrameInfo& info,
|
||||
AudioFrameView<float> frame) {
|
||||
RTC_DCHECK_GE(info.speech_level_dbfs, -150.0f);
|
||||
RTC_DCHECK_GE(frame.num_channels(), 1);
|
||||
RTC_DCHECK(
|
||||
frame.samples_per_channel() == 80 || frame.samples_per_channel() == 160 ||
|
||||
frame.samples_per_channel() == 320 || frame.samples_per_channel() == 480)
|
||||
<< "`frame` does not look like a 10 ms frame for an APM supported sample "
|
||||
"rate";
|
||||
|
||||
// Compute the input level used to select the desired gain.
|
||||
RTC_DCHECK_GT(info.headroom_db, 0.0f);
|
||||
const float input_level_dbfs = info.speech_level_dbfs + info.headroom_db;
|
||||
|
||||
const float target_gain_db = LimitGainByLowConfidence(
|
||||
LimitGainByNoise(ComputeGainDb(input_level_dbfs, config_),
|
||||
info.noise_rms_dbfs, config_.max_output_noise_level_dbfs,
|
||||
*apm_data_dumper_),
|
||||
last_gain_db_, info.limiter_envelope_dbfs, info.speech_level_reliable);
|
||||
|
||||
// Forbid increasing the gain until enough adjacent speech frames are
|
||||
// observed.
|
||||
bool first_confident_speech_frame = false;
|
||||
if (info.speech_probability < kVadConfidenceThreshold) {
|
||||
frames_to_gain_increase_allowed_ = adjacent_speech_frames_threshold_;
|
||||
} else if (frames_to_gain_increase_allowed_ > 0) {
|
||||
frames_to_gain_increase_allowed_--;
|
||||
first_confident_speech_frame = frames_to_gain_increase_allowed_ == 0;
|
||||
}
|
||||
apm_data_dumper_->DumpRaw(
|
||||
"agc2_adaptive_gain_applier_frames_to_gain_increase_allowed",
|
||||
frames_to_gain_increase_allowed_);
|
||||
|
||||
const bool gain_increase_allowed = frames_to_gain_increase_allowed_ == 0;
|
||||
|
||||
float max_gain_increase_db = max_gain_change_db_per_10ms_;
|
||||
if (first_confident_speech_frame) {
|
||||
// No gain increase happened while waiting for a long enough speech
|
||||
// sequence. Therefore, temporarily allow a faster gain increase.
|
||||
RTC_DCHECK(gain_increase_allowed);
|
||||
max_gain_increase_db *= adjacent_speech_frames_threshold_;
|
||||
}
|
||||
|
||||
const float gain_change_this_frame_db = ComputeGainChangeThisFrameDb(
|
||||
target_gain_db, last_gain_db_, gain_increase_allowed,
|
||||
/*max_gain_decrease_db=*/max_gain_change_db_per_10ms_,
|
||||
max_gain_increase_db);
|
||||
|
||||
apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_want_to_change_by_db",
|
||||
target_gain_db - last_gain_db_);
|
||||
apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_will_change_by_db",
|
||||
gain_change_this_frame_db);
|
||||
|
||||
// Optimization: avoid calling math functions if gain does not
|
||||
// change.
|
||||
if (gain_change_this_frame_db != 0.f) {
|
||||
gain_applier_.SetGainFactor(
|
||||
DbToRatio(last_gain_db_ + gain_change_this_frame_db));
|
||||
}
|
||||
|
||||
gain_applier_.ApplyGain(frame);
|
||||
|
||||
// Remember that the gain has changed for the next iteration.
|
||||
last_gain_db_ = last_gain_db_ + gain_change_this_frame_db;
|
||||
apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_applied_gain_db",
|
||||
last_gain_db_);
|
||||
|
||||
// Log every 10 seconds.
|
||||
calls_since_last_gain_log_++;
|
||||
if (calls_since_last_gain_log_ == 1000) {
|
||||
calls_since_last_gain_log_ = 0;
|
||||
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedSpeechLevel",
|
||||
-info.speech_level_dbfs, 0, 100, 101);
|
||||
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedNoiseLevel",
|
||||
-info.noise_rms_dbfs, 0, 100, 101);
|
||||
RTC_HISTOGRAM_COUNTS_LINEAR(
|
||||
"WebRTC.Audio.Agc2.Headroom", info.headroom_db, kHeadroomHistogramMin,
|
||||
kHeadroomHistogramMax,
|
||||
kHeadroomHistogramMax - kHeadroomHistogramMin + 1);
|
||||
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.DigitalGainApplied",
|
||||
last_gain_db_, 0, kGainDbHistogramMax,
|
||||
kGainDbHistogramMax + 1);
|
||||
RTC_LOG(LS_INFO) << "AGC2 adaptive digital"
|
||||
<< " | speech_dbfs: " << info.speech_level_dbfs
|
||||
<< " | noise_dbfs: " << info.noise_rms_dbfs
|
||||
<< " | headroom_db: " << info.headroom_db
|
||||
<< " | gain_db: " << last_gain_db_;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_CONTROLLER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_CONTROLLER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/gain_applier.h"
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
#include "modules/audio_processing/include/audio_processing.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
class ApmDataDumper;
|
||||
|
||||
// Selects the target digital gain, decides when and how quickly to adapt to the
|
||||
// target and applies the current gain to 10 ms frames.
|
||||
class AdaptiveDigitalGainController {
|
||||
public:
|
||||
// Information about a frame to process.
|
||||
struct FrameInfo {
|
||||
float speech_probability; // Probability of speech in the [0, 1] range.
|
||||
float speech_level_dbfs; // Estimated speech level (dBFS).
|
||||
bool speech_level_reliable; // True with reliable speech level estimation.
|
||||
float noise_rms_dbfs; // Estimated noise RMS level (dBFS).
|
||||
float headroom_db; // Headroom (dB).
|
||||
// TODO(bugs.webrtc.org/7494): Remove `limiter_envelope_dbfs`.
|
||||
float limiter_envelope_dbfs; // Envelope level from the limiter (dBFS).
|
||||
};
|
||||
|
||||
AdaptiveDigitalGainController(
|
||||
ApmDataDumper* apm_data_dumper,
|
||||
const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
|
||||
int adjacent_speech_frames_threshold);
|
||||
AdaptiveDigitalGainController(const AdaptiveDigitalGainController&) = delete;
|
||||
AdaptiveDigitalGainController& operator=(
|
||||
const AdaptiveDigitalGainController&) = delete;
|
||||
|
||||
// Analyzes `info`, updates the digital gain and applies it to a 10 ms
|
||||
// `frame`. Supports any sample rate supported by APM.
|
||||
void Process(const FrameInfo& info, AudioFrameView<float> frame);
|
||||
|
||||
private:
|
||||
ApmDataDumper* const apm_data_dumper_;
|
||||
GainApplier gain_applier_;
|
||||
|
||||
const AudioProcessing::Config::GainController2::AdaptiveDigital config_;
|
||||
const int adjacent_speech_frames_threshold_;
|
||||
const float max_gain_change_db_per_10ms_;
|
||||
|
||||
int calls_since_last_gain_log_;
|
||||
int frames_to_gain_increase_allowed_;
|
||||
float last_gain_db_;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_CONTROLLER_H_
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_
|
||||
|
||||
namespace webrtc {

// Sample value limits of the S16 range, expressed as floats.
constexpr float kMinFloatS16Value = -32768.0f;
constexpr float kMaxFloatS16Value = 32767.0f;
constexpr float kMaxAbsFloatS16Value = 32768.0f;

// Lowest audio level, in dBFS, for S16 samples.
constexpr float kMinLevelDbfs = -90.31f;

// Frame layout.
constexpr int kFrameDurationMs = 10;
constexpr int kSubFramesInFrame = 20;
constexpr int kMaximalNumberOfSamplesPerChannel = 480;

// Adaptive digital gain applier settings.

// Limiter level above which the adaptive digital gain starts being decreased.
constexpr float kLimiterThresholdForAgcGainDbfs = -1.0f;

// Period, in milliseconds, of the periodic VAD reset.
constexpr int kVadResetPeriodMs = 1500;

// Speech probability at or above which speech activity is detected.
constexpr float kVadConfidenceThreshold = 0.95f;

// Minimum number of adjacent frames with a sufficiently high speech
// probability required to reliably detect speech activity.
constexpr int kAdjacentSpeechFramesThreshold = 12;

// Amount of speech, in milliseconds, to observe before the level estimator
// becomes confident, and the leak factor derived from it.
constexpr float kLevelEstimatorTimeToConfidenceMs = 400.0f;
constexpr float kLevelEstimatorLeakFactor =
    1.0f - 1.0f / kLevelEstimatorTimeToConfidenceMs;

// Saturation Protector settings.
constexpr float kSaturationProtectorInitialHeadroomDb = 20.0f;
constexpr int kSaturationProtectorBufferSize = 4;

// Interpolation point counts for each region of the limiter. The values have
// been tuned to bound the interpolated gain curve error, given the limiter
// parameters, to at most +/- 32768^-1.
constexpr int kInterpolatedGainCurveKneePoints = 22;
constexpr int kInterpolatedGainCurveBeyondKneePoints = 10;
constexpr int kInterpolatedGainCurveTotalPoints =
    kInterpolatedGainCurveKneePoints + kInterpolatedGainCurveBeyondKneePoints;

}  // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_
|
||||
|
|
@ -0,0 +1,94 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_testing_common.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace test {
|
||||
|
||||
// Returns `num_points` evenly spaced values covering [l, r]. The first and
// last values are set to exactly `l` and `r`. `num_points` must be at least 2.
std::vector<double> LinSpace(double l, double r, int num_points) {
  RTC_CHECK_GE(num_points, 2);
  const int last_index = num_points - 1;
  const double step = (r - l) / (num_points - 1.0);
  std::vector<double> points(num_points);
  points.front() = l;
  for (int i = 1; i < last_index; ++i) {
    points[i] = l + i * step;
  }
  // Assign the end point directly instead of computing it from `step`.
  points.back() = r;
  return points;
}
|
||||
|
||||
// Creates a generator producing values between `min_amplitude` and
// `max_amplitude`, which must satisfy
// kMinS16 <= min_amplitude < max_amplitude <= kMaxS16.
// The random generator is always constructed with the fixed value 42
// (presumably its seed - confirm against rtc_base/random.h).
WhiteNoiseGenerator::WhiteNoiseGenerator(int min_amplitude, int max_amplitude)
    : rand_gen_(42),
      min_amplitude_(min_amplitude),
      max_amplitude_(max_amplitude) {
  // The range must be non-empty.
  RTC_DCHECK_LT(min_amplitude_, max_amplitude_);
  // Both bounds must be representable as S16 samples.
  RTC_DCHECK_LE(kMinS16, min_amplitude_);
  RTC_DCHECK_LE(min_amplitude_, kMaxS16);
  RTC_DCHECK_LE(kMinS16, max_amplitude_);
  RTC_DCHECK_LE(max_amplitude_, kMaxS16);
}
|
||||
|
||||
// Returns the next noise sample, drawn from the random generator using the
// configured amplitude bounds and converted to float.
float WhiteNoiseGenerator::operator()() {
  const auto sample = rand_gen_.Rand(min_amplitude_, max_amplitude_);
  return static_cast<float>(sample);
}
|
||||
|
||||
// Creates a sine generator with peak value `amplitude`, which must be in
// (0, kMaxS16], oscillating at `frequency_hz` when sampled at
// `sample_rate_hz`. The phase starts at zero.
SineGenerator::SineGenerator(float amplitude,
                             float frequency_hz,
                             int sample_rate_hz)
    : amplitude_(amplitude),
      frequency_hz_(frequency_hz),
      sample_rate_hz_(sample_rate_hz),
      x_radians_(0.0f) {
  // The amplitude must be positive and fit the S16 range.
  RTC_DCHECK_GT(amplitude_, 0);
  RTC_DCHECK_LE(amplitude_, kMaxS16);
}
|
||||
|
||||
// Advances the phase by one sample period and returns the sine value at the
// new phase scaled by the amplitude.
float SineGenerator::operator()() {
  constexpr float kPi = 3.1415926536f;
  constexpr float kTwoPi = 2 * kPi;
  x_radians_ += frequency_hz_ / sample_rate_hz_ * 2 * kPi;
  // Wrap the phase once it reaches a full turn.
  if (x_radians_ >= kTwoPi) {
    x_radians_ -= kTwoPi;
  }
  // `sinf` is used in place of `std::sinf` for libstdc++ compatibility.
  const float unit_sample = sinf(x_radians_);
  return amplitude_ * unit_sample;
}
|
||||
|
||||
// Creates a generator that emits `pulse_amplitude` once per period and
// `no_pulse_amplitude` for every other sample. The period, in samples, is
// `sample_rate_hz` / `frequency_hz` truncated to an integer.
// `pulse_amplitude` must be in [kMinS16, kMaxS16], `no_pulse_amplitude` in
// (kMinS16, kMaxS16] and `sample_rate_hz` must exceed `frequency_hz`.
PulseGenerator::PulseGenerator(float pulse_amplitude,
                               float no_pulse_amplitude,
                               float frequency_hz,
                               int sample_rate_hz)
    : pulse_amplitude_(pulse_amplitude),
      no_pulse_amplitude_(no_pulse_amplitude),
      samples_period_(
          static_cast<int>(static_cast<float>(sample_rate_hz) / frequency_hz)),
      sample_counter_(0) {
  // Amplitude range checks.
  RTC_DCHECK_GE(pulse_amplitude_, kMinS16);
  RTC_DCHECK_LE(pulse_amplitude_, kMaxS16);
  RTC_DCHECK_GT(no_pulse_amplitude_, kMinS16);
  RTC_DCHECK_LE(no_pulse_amplitude_, kMaxS16);
  // The pulse frequency must be lower than the sample rate.
  RTC_DCHECK_GT(sample_rate_hz, frequency_hz);
}
|
||||
|
||||
// Advances the sample counter, wrapping it at the period, and returns the
// pulse amplitude when the counter is zero or the no-pulse amplitude
// otherwise.
float PulseGenerator::operator()() {
  if (++sample_counter_ >= samples_period_) {
    sample_counter_ -= samples_period_;
  }
  const bool emit_pulse = (sample_counter_ == 0);
  return emit_pulse ? pulse_amplitude_ : no_pulse_amplitude_;
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
|
||||
|
||||
#include <limits>
|
||||
#include <vector>
|
||||
|
||||
#include "rtc_base/random.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace test {
|
||||
|
||||
// Limits of the int16_t range, expressed as floats.
constexpr float kMinS16 =
    static_cast<float>(std::numeric_limits<int16_t>::min());
constexpr float kMaxS16 =
    static_cast<float>(std::numeric_limits<int16_t>::max());

// Test parameters for the level estimator.
constexpr float kDecayMs = 20.0f;

// Test parameters for the limiter.
constexpr float kLimiterMaxInputLevelDbFs = 1.f;
constexpr float kLimiterKneeSmoothnessDb = 1.f;
constexpr float kLimiterCompressionRatio = 5.f;
|
||||
|
||||
// Returns evenly spaced `num_points` numbers over a specified interval [l, r].
|
||||
std::vector<double> LinSpace(double l, double r, int num_points);
|
||||
|
||||
// Generates white noise.
|
||||
class WhiteNoiseGenerator {
|
||||
public:
|
||||
WhiteNoiseGenerator(int min_amplitude, int max_amplitude);
|
||||
float operator()();
|
||||
|
||||
private:
|
||||
Random rand_gen_;
|
||||
const int min_amplitude_;
|
||||
const int max_amplitude_;
|
||||
};
|
||||
|
||||
// Sine wave source. Each call returns the next sample of a sinusoid with the
// amplitude and frequency given at construction.
class SineGenerator {
 public:
  SineGenerator(float amplitude, float frequency_hz, int sample_rate_hz);
  // Returns the next sample.
  float operator()();

 private:
  const float amplitude_;
  const float frequency_hz_;
  const int sample_rate_hz_;
  // Current phase of the sinusoid.
  float x_radians_;
};
|
||||
|
||||
// Periodic pulse source. Each call returns either the pulse amplitude or the
// no-pulse amplitude given at construction.
class PulseGenerator {
 public:
  PulseGenerator(float pulse_amplitude,
                 float no_pulse_amplitude,
                 float frequency_hz,
                 int sample_rate_hz);
  // Returns the next sample.
  float operator()();

 private:
  const float pulse_amplitude_;
  const float no_pulse_amplitude_;
  // Pulse period expressed in samples.
  const int samples_period_;
  // Position within the current period.
  int sample_counter_;
};
|
||||
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/biquad_filter.h"
|
||||
|
||||
#include "rtc_base/arraysize.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
BiQuadFilter::BiQuadFilter(const Config& config)
|
||||
: config_(config), state_({}) {}
|
||||
|
||||
// Uses the compiler-generated destructor, defined out of line.
BiQuadFilter::~BiQuadFilter() = default;
|
||||
|
||||
void BiQuadFilter::SetConfig(const Config& config) {
|
||||
config_ = config;
|
||||
state_ = {};
|
||||
}
|
||||
|
||||
void BiQuadFilter::Reset() {
|
||||
state_ = {};
|
||||
}
|
||||
|
||||
// Filters `x` into `y`, which must have the same size as `x`. The filter
// state is carried across calls. `x` and `y` may refer to the same buffer.
void BiQuadFilter::Process(rtc::ArrayView<const float> x,
                           rtc::ArrayView<float> y) {
  RTC_DCHECK_EQ(x.size(), y.size());

  // Work on local copies of the coefficients and of the state; the state is
  // written back once all the samples have been processed.
  const float b0 = config_.b[0];
  const float b1 = config_.b[1];
  const float b2 = config_.b[2];
  const float a0 = config_.a[0];
  const float a1 = config_.a[1];
  float prev_x0 = state_.b[0];  // Most recent input.
  float prev_x1 = state_.b[1];  // Input before the most recent one.
  float prev_y0 = state_.a[0];  // Most recent output.
  float prev_y1 = state_.a[1];  // Output before the most recent one.

  const size_t num_samples = x.size();
  for (size_t i = 0; i < num_samples; ++i) {
    // Read the input before writing the output so that in-place processing
    // works.
    const float x_i = x[i];
    const float y_i = b0 * x_i + b1 * prev_x0 + b2 * prev_x1 - a0 * prev_y0 -
                      a1 * prev_y1;
    // Shift the input and output histories by one sample.
    prev_x1 = prev_x0;
    prev_x0 = x_i;
    prev_y1 = prev_y0;
    prev_y0 = y_i;
    y[i] = y_i;
  }

  state_.b[0] = prev_x0;
  state_.b[1] = prev_x1;
  state_.a[0] = prev_y0;
  state_.a[1] = prev_y1;
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,56 @@
|
|||
/*
|
||||
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
|
||||
|
||||
#include "api/array_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// Transposed direct form I implementation of a bi-quad filter.
|
||||
// b[0] + b[1] • z^(-1) + b[2] • z^(-2)
|
||||
// H(z) = ------------------------------------
|
||||
// 1 + a[1] • z^(-1) + a[2] • z^(-2)
|
||||
class BiQuadFilter {
|
||||
public:
|
||||
// Normalized filter coefficients.
|
||||
// Computed as `[b, a] = scipy.signal.butter(N=2, Wn, btype)`.
|
||||
struct Config {
|
||||
float b[3]; // b[0], b[1], b[2].
|
||||
float a[2]; // a[1], a[2].
|
||||
};
|
||||
|
||||
explicit BiQuadFilter(const Config& config);
|
||||
BiQuadFilter(const BiQuadFilter&) = delete;
|
||||
BiQuadFilter& operator=(const BiQuadFilter&) = delete;
|
||||
~BiQuadFilter();
|
||||
|
||||
// Sets the filter configuration and resets the internal state.
|
||||
void SetConfig(const Config& config);
|
||||
|
||||
// Zeroes the filter state.
|
||||
void Reset();
|
||||
|
||||
// Filters `x` and writes the output in `y`, which must have the same length
|
||||
// of `x`. In-place processing is supported.
|
||||
void Process(rtc::ArrayView<const float> x, rtc::ArrayView<float> y);
|
||||
|
||||
private:
|
||||
Config config_;
|
||||
struct State {
|
||||
float b[2];
|
||||
float a[2];
|
||||
} state_;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
|
||||
|
|
@ -0,0 +1,175 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/biquad_filter.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
|
||||
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// #include "test/fpe_observer.h"
|
||||
#include "rtc_base/gunit.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
constexpr int kFrameSize = 8;
|
||||
constexpr int kNumFrames = 4;
|
||||
using FloatArraySequence =
|
||||
std::array<std::array<float, kFrameSize>, kNumFrames>;
|
||||
|
||||
constexpr FloatArraySequence kBiQuadInputSeq = {
|
||||
{{{-87.166290f, -8.029022f, 101.619583f, -0.294296f, -5.825764f, -8.890625f,
|
||||
10.310432f, 54.845333f}},
|
||||
{{-64.647644f, -6.883945f, 11.059189f, -95.242538f, -108.870834f,
|
||||
11.024944f, 63.044102f, -52.709583f}},
|
||||
{{-32.350529f, -18.108028f, -74.022339f, -8.986874f, -1.525581f,
|
||||
103.705513f, 6.346226f, -14.319557f}},
|
||||
{{22.645832f, -64.597153f, 55.462521f, -109.393188f, 10.117825f,
|
||||
-40.019642f, -98.612228f, -8.330326f}}}};
|
||||
|
||||
// Computed as `scipy.signal.butter(N=2, Wn=60/24000, btype='highpass')`.
|
||||
constexpr BiQuadFilter::Config kBiQuadConfig{
|
||||
{0.99446179f, -1.98892358f, 0.99446179f},
|
||||
{-1.98889291f, 0.98895425f}};
|
||||
|
||||
// Comparing to scipy. The expected output is generated as follows:
|
||||
// zi = np.float32([0, 0])
|
||||
// for i in range(4):
|
||||
// yn, zi = scipy.signal.lfilter(B, A, x[i], zi=zi)
|
||||
// print(yn)
|
||||
constexpr FloatArraySequence kBiQuadOutputSeq = {
|
||||
{{{-86.68354497f, -7.02175351f, 102.10290352f, -0.37487333f, -5.87205847f,
|
||||
-8.85521608f, 10.33772563f, 54.51157181f}},
|
||||
{{-64.92531604f, -6.76395978f, 11.15534507f, -94.68073341f, -107.18177856f,
|
||||
13.24642474f, 64.84288941f, -50.97822629f}},
|
||||
{{-30.1579652f, -15.64850899f, -71.06662821f, -5.5883229f, 1.91175353f,
|
||||
106.5572003f, 8.57183046f, -12.06298473f}},
|
||||
{{24.84286614f, -62.18094158f, 57.91488056f, -106.65685933f, 13.38760103f,
|
||||
-36.60367134f, -94.44880104f, -3.59920354f}}}};
|
||||
|
||||
// Fails for every pair from two equally sized rtc::ArrayView<float> views such
|
||||
// that their relative error is above a given threshold. If the expected value
|
||||
// of a pair is 0, `tolerance` is used to check the absolute error.
|
||||
void ExpectNearRelative(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed,
|
||||
const float tolerance) {
|
||||
// The relative error is undefined when the expected value is 0.
|
||||
// When that happens, check the absolute error instead. `safe_den` is used
|
||||
// below to implement such logic.
|
||||
auto safe_den = [](float x) { return (x == 0.0f) ? 1.0f : std::fabs(x); };
|
||||
ASSERT_EQ(expected.size(), computed.size());
|
||||
for (size_t i = 0; i < expected.size(); ++i) {
|
||||
const float abs_diff = std::fabs(expected[i] - computed[i]);
|
||||
// No failure when the values are equal.
|
||||
if (abs_diff == 0.0f) {
|
||||
continue;
|
||||
}
|
||||
SCOPED_TRACE(i);
|
||||
SCOPED_TRACE(expected[i]);
|
||||
SCOPED_TRACE(computed[i]);
|
||||
EXPECT_LE(abs_diff / safe_den(expected[i]), tolerance);
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that filtering works when different containers are used both as input
|
||||
// and as output.
|
||||
TEST(BiQuadFilterTest, FilterNotInPlace) {
|
||||
BiQuadFilter filter(kBiQuadConfig);
|
||||
std::array<float, kFrameSize> samples;
|
||||
|
||||
// TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
|
||||
for (int i = 0; i < kNumFrames; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
filter.Process(kBiQuadInputSeq[i], samples);
|
||||
ExpectNearRelative(kBiQuadOutputSeq[i], samples, 2e-4f);
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that filtering works when the same container is used both as input and
|
||||
// as output.
|
||||
TEST(BiQuadFilterTest, FilterInPlace) {
|
||||
BiQuadFilter filter(kBiQuadConfig);
|
||||
std::array<float, kFrameSize> samples;
|
||||
|
||||
// TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed.
|
||||
// FloatingPointExceptionObserver fpe_observer;
|
||||
|
||||
for (int i = 0; i < kNumFrames; ++i) {
|
||||
SCOPED_TRACE(i);
|
||||
std::copy(kBiQuadInputSeq[i].begin(), kBiQuadInputSeq[i].end(),
|
||||
samples.begin());
|
||||
filter.Process({samples}, {samples});
|
||||
ExpectNearRelative(kBiQuadOutputSeq[i], samples, 2e-4f);
|
||||
}
|
||||
}
|
||||
|
||||
// Checks that different configurations produce different outputs.
|
||||
TEST(BiQuadFilterTest, SetConfigDifferentOutput) {
|
||||
BiQuadFilter filter(/*config=*/{{0.97803048f, -1.95606096f, 0.97803048f},
|
||||
{-1.95557824f, 0.95654368f}});
|
||||
|
||||
std::array<float, kFrameSize> samples1;
|
||||
for (int i = 0; i < kNumFrames; ++i) {
|
||||
filter.Process(kBiQuadInputSeq[i], samples1);
|
||||
}
|
||||
|
||||
filter.SetConfig(
|
||||
{{0.09763107f, 0.19526215f, 0.09763107f}, {-0.94280904f, 0.33333333f}});
|
||||
std::array<float, kFrameSize> samples2;
|
||||
for (int i = 0; i < kNumFrames; ++i) {
|
||||
filter.Process(kBiQuadInputSeq[i], samples2);
|
||||
}
|
||||
|
||||
EXPECT_NE(samples1, samples2);
|
||||
}
|
||||
|
||||
// Checks that when `SetConfig()` is called but the filter coefficients are the
|
||||
// same the filter state is reset.
|
||||
TEST(BiQuadFilterTest, SetConfigResetsState) {
|
||||
BiQuadFilter filter(kBiQuadConfig);
|
||||
|
||||
std::array<float, kFrameSize> samples1;
|
||||
for (int i = 0; i < kNumFrames; ++i) {
|
||||
filter.Process(kBiQuadInputSeq[i], samples1);
|
||||
}
|
||||
|
||||
filter.SetConfig(kBiQuadConfig);
|
||||
std::array<float, kFrameSize> samples2;
|
||||
for (int i = 0; i < kNumFrames; ++i) {
|
||||
filter.Process(kBiQuadInputSeq[i], samples2);
|
||||
}
|
||||
|
||||
EXPECT_EQ(samples1, samples2);
|
||||
}
|
||||
|
||||
// Checks that when `Reset()` is called the filter state is reset.
|
||||
TEST(BiQuadFilterTest, Reset) {
|
||||
BiQuadFilter filter(kBiQuadConfig);
|
||||
|
||||
std::array<float, kFrameSize> samples1;
|
||||
for (int i = 0; i < kNumFrames; ++i) {
|
||||
filter.Process(kBiQuadInputSeq[i], samples1);
|
||||
}
|
||||
|
||||
filter.Reset();
|
||||
std::array<float, kFrameSize> samples2;
|
||||
for (int i = 0; i < kNumFrames; ++i) {
|
||||
filter.Process(kBiQuadInputSeq[i], samples2);
|
||||
}
|
||||
|
||||
EXPECT_EQ(samples1, samples2);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,384 @@
|
|||
/*
|
||||
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/clipping_predictor.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <memory>
|
||||
|
||||
#include "common_audio/include/audio_util.h"
|
||||
#include "modules/audio_processing/agc2/clipping_predictor_level_buffer.h"
|
||||
#include "modules/audio_processing/agc2/gain_map_internal.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
constexpr int kClippingPredictorMaxGainChange = 15;
|
||||
|
||||
// Returns an input volume in the [`min_input_volume`, `max_input_volume`] range
|
||||
// that reduces `gain_error_db`, which is a gain error estimated when
|
||||
// `input_volume` was applied, according to a fixed gain map.
|
||||
int ComputeVolumeUpdate(int gain_error_db,
|
||||
int input_volume,
|
||||
int min_input_volume,
|
||||
int max_input_volume) {
|
||||
RTC_DCHECK_GE(input_volume, 0);
|
||||
RTC_DCHECK_LE(input_volume, max_input_volume);
|
||||
if (gain_error_db == 0) {
|
||||
return input_volume;
|
||||
}
|
||||
int new_volume = input_volume;
|
||||
if (gain_error_db > 0) {
|
||||
while (kGainMap[new_volume] - kGainMap[input_volume] < gain_error_db &&
|
||||
new_volume < max_input_volume) {
|
||||
++new_volume;
|
||||
}
|
||||
} else {
|
||||
while (kGainMap[new_volume] - kGainMap[input_volume] > gain_error_db &&
|
||||
new_volume > min_input_volume) {
|
||||
--new_volume;
|
||||
}
|
||||
}
|
||||
return new_volume;
|
||||
}
|
||||
|
||||
// Returns the crest factor of `level` as the difference between `level.max`
// and the square root of `level.average`, both converted via
// `FloatS16ToDbfs()`.
float ComputeCrestFactor(const ClippingPredictorLevelBuffer::Level& level) {
  const float peak_dbfs = FloatS16ToDbfs(level.max);
  const float rms_dbfs = FloatS16ToDbfs(std::sqrt(level.average));
  return peak_dbfs - rms_dbfs;
}
|
||||
|
||||
// Crest factor-based clipping prediction and clipped level step estimation.
|
||||
class ClippingEventPredictor : public ClippingPredictor {
|
||||
public:
|
||||
// ClippingEventPredictor with `num_channels` channels (limited to values
|
||||
// higher than zero); window size `window_length` and reference window size
|
||||
// `reference_window_length` (both referring to the number of frames in the
|
||||
// respective sliding windows and limited to values higher than zero);
|
||||
// reference window delay `reference_window_delay` (delay in frames, limited
|
||||
// to values zero and higher with an additional requirement of
|
||||
// `window_length` < `reference_window_length` + reference_window_delay`);
|
||||
// and an estimation peak threshold `clipping_threshold` and a crest factor
|
||||
// drop threshold `crest_factor_margin` (both in dB).
|
||||
ClippingEventPredictor(int num_channels,
|
||||
int window_length,
|
||||
int reference_window_length,
|
||||
int reference_window_delay,
|
||||
float clipping_threshold,
|
||||
float crest_factor_margin)
|
||||
: window_length_(window_length),
|
||||
reference_window_length_(reference_window_length),
|
||||
reference_window_delay_(reference_window_delay),
|
||||
clipping_threshold_(clipping_threshold),
|
||||
crest_factor_margin_(crest_factor_margin) {
|
||||
RTC_DCHECK_GT(num_channels, 0);
|
||||
RTC_DCHECK_GT(window_length, 0);
|
||||
RTC_DCHECK_GT(reference_window_length, 0);
|
||||
RTC_DCHECK_GE(reference_window_delay, 0);
|
||||
RTC_DCHECK_GT(reference_window_length + reference_window_delay,
|
||||
window_length);
|
||||
const int buffer_length = GetMinFramesProcessed();
|
||||
RTC_DCHECK_GT(buffer_length, 0);
|
||||
for (int i = 0; i < num_channels; ++i) {
|
||||
ch_buffers_.push_back(
|
||||
std::make_unique<ClippingPredictorLevelBuffer>(buffer_length));
|
||||
}
|
||||
}
|
||||
|
||||
ClippingEventPredictor(const ClippingEventPredictor&) = delete;
|
||||
ClippingEventPredictor& operator=(const ClippingEventPredictor&) = delete;
|
||||
~ClippingEventPredictor() {}
|
||||
|
||||
void Reset() {
|
||||
const int num_channels = ch_buffers_.size();
|
||||
for (int i = 0; i < num_channels; ++i) {
|
||||
ch_buffers_[i]->Reset();
|
||||
}
|
||||
}
|
||||
|
||||
// Analyzes a frame of audio and stores the framewise metrics in
|
||||
// `ch_buffers_`.
|
||||
void Analyze(const AudioFrameView<const float>& frame) {
|
||||
const int num_channels = frame.num_channels();
|
||||
RTC_DCHECK_EQ(num_channels, ch_buffers_.size());
|
||||
const int samples_per_channel = frame.samples_per_channel();
|
||||
RTC_DCHECK_GT(samples_per_channel, 0);
|
||||
for (int channel = 0; channel < num_channels; ++channel) {
|
||||
float sum_squares = 0.0f;
|
||||
float peak = 0.0f;
|
||||
for (const auto& sample : frame.channel(channel)) {
|
||||
sum_squares += sample * sample;
|
||||
peak = std::max(std::fabs(sample), peak);
|
||||
}
|
||||
ch_buffers_[channel]->Push(
|
||||
{sum_squares / static_cast<float>(samples_per_channel), peak});
|
||||
}
|
||||
}
|
||||
|
||||
// Estimates the analog gain adjustment for channel `channel` using a
|
||||
// sliding window over the frame-wise metrics in `ch_buffers_`. Returns an
|
||||
// estimate for the clipped level step equal to `default_clipped_level_step_`
|
||||
// if at least `GetMinFramesProcessed()` frames have been processed since the
|
||||
// last reset and a clipping event is predicted. `level`, `min_mic_level`, and
|
||||
// `max_mic_level` are limited to [0, 255] and `default_step` to [1, 255].
|
||||
absl::optional<int> EstimateClippedLevelStep(int channel,
|
||||
int level,
|
||||
int default_step,
|
||||
int min_mic_level,
|
||||
int max_mic_level) const {
|
||||
RTC_CHECK_GE(channel, 0);
|
||||
RTC_CHECK_LT(channel, ch_buffers_.size());
|
||||
RTC_DCHECK_GE(level, 0);
|
||||
RTC_DCHECK_LE(level, 255);
|
||||
RTC_DCHECK_GT(default_step, 0);
|
||||
RTC_DCHECK_LE(default_step, 255);
|
||||
RTC_DCHECK_GE(min_mic_level, 0);
|
||||
RTC_DCHECK_LE(min_mic_level, 255);
|
||||
RTC_DCHECK_GE(max_mic_level, 0);
|
||||
RTC_DCHECK_LE(max_mic_level, 255);
|
||||
if (level <= min_mic_level) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
if (PredictClippingEvent(channel)) {
|
||||
const int new_level =
|
||||
rtc::SafeClamp(level - default_step, min_mic_level, max_mic_level);
|
||||
const int step = level - new_level;
|
||||
if (step > 0) {
|
||||
return step;
|
||||
}
|
||||
}
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
private:
|
||||
int GetMinFramesProcessed() const {
|
||||
return reference_window_delay_ + reference_window_length_;
|
||||
}
|
||||
|
||||
// Predicts clipping events based on the processed audio frames. Returns
|
||||
// true if a clipping event is likely.
|
||||
bool PredictClippingEvent(int channel) const {
|
||||
const auto metrics =
|
||||
ch_buffers_[channel]->ComputePartialMetrics(0, window_length_);
|
||||
if (!metrics.has_value() ||
|
||||
!(FloatS16ToDbfs(metrics.value().max) > clipping_threshold_)) {
|
||||
return false;
|
||||
}
|
||||
const auto reference_metrics = ch_buffers_[channel]->ComputePartialMetrics(
|
||||
reference_window_delay_, reference_window_length_);
|
||||
if (!reference_metrics.has_value()) {
|
||||
return false;
|
||||
}
|
||||
const float crest_factor = ComputeCrestFactor(metrics.value());
|
||||
const float reference_crest_factor =
|
||||
ComputeCrestFactor(reference_metrics.value());
|
||||
if (crest_factor < reference_crest_factor - crest_factor_margin_) {
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<ClippingPredictorLevelBuffer>> ch_buffers_;
|
||||
const int window_length_;
|
||||
const int reference_window_length_;
|
||||
const int reference_window_delay_;
|
||||
const float clipping_threshold_;
|
||||
const float crest_factor_margin_;
|
||||
};
|
||||
|
||||
// Performs crest factor-based clipping peak prediction.
|
||||
class ClippingPeakPredictor : public ClippingPredictor {
|
||||
public:
|
||||
// Ctor. ClippingPeakPredictor with `num_channels` channels (limited to values
|
||||
// higher than zero); window size `window_length` and reference window size
|
||||
// `reference_window_length` (both referring to the number of frames in the
|
||||
// respective sliding windows and limited to values higher than zero);
|
||||
// reference window delay `reference_window_delay` (delay in frames, limited
|
||||
// to values zero and higher with an additional requirement of
|
||||
// `window_length` < `reference_window_length` + reference_window_delay`);
|
||||
// and a clipping prediction threshold `clipping_threshold` (in dB). Adaptive
|
||||
// clipped level step estimation is used if `adaptive_step_estimation` is
|
||||
// true.
|
||||
explicit ClippingPeakPredictor(int num_channels,
|
||||
int window_length,
|
||||
int reference_window_length,
|
||||
int reference_window_delay,
|
||||
int clipping_threshold,
|
||||
bool adaptive_step_estimation)
|
||||
: window_length_(window_length),
|
||||
reference_window_length_(reference_window_length),
|
||||
reference_window_delay_(reference_window_delay),
|
||||
clipping_threshold_(clipping_threshold),
|
||||
adaptive_step_estimation_(adaptive_step_estimation) {
|
||||
RTC_DCHECK_GT(num_channels, 0);
|
||||
RTC_DCHECK_GT(window_length, 0);
|
||||
RTC_DCHECK_GT(reference_window_length, 0);
|
||||
RTC_DCHECK_GE(reference_window_delay, 0);
|
||||
RTC_DCHECK_GT(reference_window_length + reference_window_delay,
|
||||
window_length);
|
||||
const int buffer_length = GetMinFramesProcessed();
|
||||
RTC_DCHECK_GT(buffer_length, 0);
|
||||
for (int i = 0; i < num_channels; ++i) {
|
||||
ch_buffers_.push_back(
|
||||
std::make_unique<ClippingPredictorLevelBuffer>(buffer_length));
|
||||
}
|
||||
}
|
||||
|
||||
ClippingPeakPredictor(const ClippingPeakPredictor&) = delete;
|
||||
ClippingPeakPredictor& operator=(const ClippingPeakPredictor&) = delete;
|
||||
~ClippingPeakPredictor() {}
|
||||
|
||||
void Reset() {
|
||||
const int num_channels = ch_buffers_.size();
|
||||
for (int i = 0; i < num_channels; ++i) {
|
||||
ch_buffers_[i]->Reset();
|
||||
}
|
||||
}
|
||||
|
||||
// Analyzes a frame of audio and stores the framewise metrics in
|
||||
// `ch_buffers_`.
|
||||
void Analyze(const AudioFrameView<const float>& frame) {
|
||||
const int num_channels = frame.num_channels();
|
||||
RTC_DCHECK_EQ(num_channels, ch_buffers_.size());
|
||||
const int samples_per_channel = frame.samples_per_channel();
|
||||
RTC_DCHECK_GT(samples_per_channel, 0);
|
||||
for (int channel = 0; channel < num_channels; ++channel) {
|
||||
float sum_squares = 0.0f;
|
||||
float peak = 0.0f;
|
||||
for (const auto& sample : frame.channel(channel)) {
|
||||
sum_squares += sample * sample;
|
||||
peak = std::max(std::fabs(sample), peak);
|
||||
}
|
||||
ch_buffers_[channel]->Push(
|
||||
{sum_squares / static_cast<float>(samples_per_channel), peak});
|
||||
}
|
||||
}
|
||||
|
||||
// Estimates the analog gain adjustment for channel `channel` using a
|
||||
// sliding window over the frame-wise metrics in `ch_buffers_`. Returns an
|
||||
// estimate for the clipped level step (equal to
|
||||
// `default_clipped_level_step_` if `adaptive_estimation_` is false) if at
|
||||
// least `GetMinFramesProcessed()` frames have been processed since the last
|
||||
// reset and a clipping event is predicted. `level`, `min_mic_level`, and
|
||||
// `max_mic_level` are limited to [0, 255] and `default_step` to [1, 255].
|
||||
absl::optional<int> EstimateClippedLevelStep(int channel,
|
||||
int level,
|
||||
int default_step,
|
||||
int min_mic_level,
|
||||
int max_mic_level) const {
|
||||
RTC_DCHECK_GE(channel, 0);
|
||||
RTC_DCHECK_LT(channel, ch_buffers_.size());
|
||||
RTC_DCHECK_GE(level, 0);
|
||||
RTC_DCHECK_LE(level, 255);
|
||||
RTC_DCHECK_GT(default_step, 0);
|
||||
RTC_DCHECK_LE(default_step, 255);
|
||||
RTC_DCHECK_GE(min_mic_level, 0);
|
||||
RTC_DCHECK_LE(min_mic_level, 255);
|
||||
RTC_DCHECK_GE(max_mic_level, 0);
|
||||
RTC_DCHECK_LE(max_mic_level, 255);
|
||||
if (level <= min_mic_level) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
absl::optional<float> estimate_db = EstimatePeakValue(channel);
|
||||
if (estimate_db.has_value() && estimate_db.value() > clipping_threshold_) {
|
||||
int step = 0;
|
||||
if (!adaptive_step_estimation_) {
|
||||
step = default_step;
|
||||
} else {
|
||||
const int estimated_gain_change =
|
||||
rtc::SafeClamp(-static_cast<int>(std::ceil(estimate_db.value())),
|
||||
-kClippingPredictorMaxGainChange, 0);
|
||||
step =
|
||||
std::max(level - ComputeVolumeUpdate(estimated_gain_change, level,
|
||||
min_mic_level, max_mic_level),
|
||||
default_step);
|
||||
}
|
||||
const int new_level =
|
||||
rtc::SafeClamp(level - step, min_mic_level, max_mic_level);
|
||||
if (level > new_level) {
|
||||
return level - new_level;
|
||||
}
|
||||
}
|
||||
return absl::nullopt;
|
||||
}
|
||||
|
||||
private:
|
||||
int GetMinFramesProcessed() {
|
||||
return reference_window_delay_ + reference_window_length_;
|
||||
}
|
||||
|
||||
// Predicts clipping sample peaks based on the processed audio frames.
|
||||
// Returns the estimated peak value if clipping is predicted. Otherwise
|
||||
// returns absl::nullopt.
|
||||
absl::optional<float> EstimatePeakValue(int channel) const {
|
||||
const auto reference_metrics = ch_buffers_[channel]->ComputePartialMetrics(
|
||||
reference_window_delay_, reference_window_length_);
|
||||
if (!reference_metrics.has_value()) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
const auto metrics =
|
||||
ch_buffers_[channel]->ComputePartialMetrics(0, window_length_);
|
||||
if (!metrics.has_value() ||
|
||||
!(FloatS16ToDbfs(metrics.value().max) > clipping_threshold_)) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
const float reference_crest_factor =
|
||||
ComputeCrestFactor(reference_metrics.value());
|
||||
const float& mean_squares = metrics.value().average;
|
||||
const float projected_peak =
|
||||
reference_crest_factor + FloatS16ToDbfs(std::sqrt(mean_squares));
|
||||
return projected_peak;
|
||||
}
|
||||
|
||||
std::vector<std::unique_ptr<ClippingPredictorLevelBuffer>> ch_buffers_;
|
||||
const int window_length_;
|
||||
const int reference_window_length_;
|
||||
const int reference_window_delay_;
|
||||
const int clipping_threshold_;
|
||||
const bool adaptive_step_estimation_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<ClippingPredictor> CreateClippingPredictor(
|
||||
int num_channels,
|
||||
const AudioProcessing::Config::GainController1::AnalogGainController::
|
||||
ClippingPredictor& config) {
|
||||
if (!config.enabled) {
|
||||
RTC_LOG(LS_INFO) << "[AGC2] Clipping prediction disabled.";
|
||||
return nullptr;
|
||||
}
|
||||
RTC_LOG(LS_INFO) << "[AGC2] Clipping prediction enabled.";
|
||||
using ClippingPredictorMode = AudioProcessing::Config::GainController1::
|
||||
AnalogGainController::ClippingPredictor::Mode;
|
||||
switch (config.mode) {
|
||||
case ClippingPredictorMode::kClippingEventPrediction:
|
||||
return std::make_unique<ClippingEventPredictor>(
|
||||
num_channels, config.window_length, config.reference_window_length,
|
||||
config.reference_window_delay, config.clipping_threshold,
|
||||
config.crest_factor_margin);
|
||||
case ClippingPredictorMode::kAdaptiveStepClippingPeakPrediction:
|
||||
return std::make_unique<ClippingPeakPredictor>(
|
||||
num_channels, config.window_length, config.reference_window_length,
|
||||
config.reference_window_delay, config.clipping_threshold,
|
||||
/*adaptive_step_estimation=*/true);
|
||||
case ClippingPredictorMode::kFixedStepClippingPeakPrediction:
|
||||
return std::make_unique<ClippingPeakPredictor>(
|
||||
num_channels, config.window_length, config.reference_window_length,
|
||||
config.reference_window_delay, config.clipping_threshold,
|
||||
/*adaptive_step_estimation=*/false);
|
||||
}
|
||||
RTC_DCHECK_NOTREACHED();
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
#include "modules/audio_processing/include/audio_processing.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// Frame-wise clipping prediction and clipped level step estimation. Analyzes
|
||||
// 10 ms multi-channel frames and estimates an analog mic level decrease step
|
||||
// to possibly avoid clipping when predicted. `Analyze()` and
|
||||
// `EstimateClippedLevelStep()` can be called in any order.
|
||||
class ClippingPredictor {
|
||||
public:
|
||||
virtual ~ClippingPredictor() = default;
|
||||
|
||||
virtual void Reset() = 0;
|
||||
|
||||
// Analyzes a 10 ms multi-channel audio frame.
|
||||
virtual void Analyze(const AudioFrameView<const float>& frame) = 0;
|
||||
|
||||
// Predicts if clipping is going to occur for the specified `channel` in the
|
||||
// near-future and, if so, it returns a recommended analog mic level decrease
|
||||
// step. Returns absl::nullopt if clipping is not predicted.
|
||||
// `level` is the current analog mic level, `default_step` is the amount the
|
||||
// mic level is lowered by the analog controller with every clipping event and
|
||||
// `min_mic_level` and `max_mic_level` is the range of allowed analog mic
|
||||
// levels.
|
||||
virtual absl::optional<int> EstimateClippedLevelStep(
|
||||
int channel,
|
||||
int level,
|
||||
int default_step,
|
||||
int min_mic_level,
|
||||
int max_mic_level) const = 0;
|
||||
};
|
||||
|
||||
// Creates a ClippingPredictor based on the provided `config`. When enabled,
|
||||
// the following must hold for `config`:
|
||||
// `window_length < reference_window_length + reference_window_delay`.
|
||||
// Returns `nullptr` if `config.enabled` is false.
|
||||
std::unique_ptr<ClippingPredictor> CreateClippingPredictor(
|
||||
int num_channels,
|
||||
const AudioProcessing::Config::GainController1::AnalogGainController::
|
||||
ClippingPredictor& config);
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_H_
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/clipping_predictor_level_buffer.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// Two levels compare equal if both `average` and `max` differ by less than a
// fixed absolute tolerance.
bool ClippingPredictorLevelBuffer::Level::operator==(const Level& level) const {
  constexpr float kTolerance = 1e-6f;
  const bool same_average = std::fabs(average - level.average) < kTolerance;
  const bool same_max = std::fabs(max - level.max) < kTolerance;
  return same_average && same_max;
}
|
||||
|
||||
// Creates an empty buffer holding max(1, `capacity`) items. A warning is
// logged if `capacity` exceeds `kMaxCapacity`.
ClippingPredictorLevelBuffer::ClippingPredictorLevelBuffer(int capacity)
    : tail_(-1), size_(0), data_(std::max(1, capacity)) {
  const bool exceeds_max_capacity = capacity > kMaxCapacity;
  if (exceeds_max_capacity) {
    RTC_LOG(LS_WARNING) << "[agc]: ClippingPredictorLevelBuffer exceeds the "
                        << "maximum allowed capacity. Capacity: " << capacity;
  }
  RTC_DCHECK(!data_.empty());
}
|
||||
|
||||
// Marks the buffer as empty. The stored items are not cleared, they just
// become unreachable until overwritten by `Push()`.
void ClippingPredictorLevelBuffer::Reset() {
  size_ = 0;
  tail_ = -1;
}
|
||||
|
||||
// Stores `level` in the slot following the most recently written one, wrapping
// around at the end of `data_`, and grows the size up to the capacity.
void ClippingPredictorLevelBuffer::Push(Level level) {
  const int capacity = Capacity();
  const int next = tail_ + 1;
  tail_ = (next == capacity) ? 0 : next;
  if (size_ < capacity) {
    ++size_;
  }
  data_[tail_] = level;
}
|
||||
|
||||
// TODO(bugs.webrtc.org/12774): Optimize partial computation for long buffers.
|
||||
absl::optional<ClippingPredictorLevelBuffer::Level>
|
||||
ClippingPredictorLevelBuffer::ComputePartialMetrics(int delay,
|
||||
int num_items) const {
|
||||
RTC_DCHECK_GE(delay, 0);
|
||||
RTC_DCHECK_LT(delay, Capacity());
|
||||
RTC_DCHECK_GT(num_items, 0);
|
||||
RTC_DCHECK_LE(num_items, Capacity());
|
||||
RTC_DCHECK_LE(delay + num_items, Capacity());
|
||||
if (delay + num_items > Size()) {
|
||||
return absl::nullopt;
|
||||
}
|
||||
float sum = 0.0f;
|
||||
float max = 0.0f;
|
||||
for (int i = 0; i < num_items && i < Size(); ++i) {
|
||||
int idx = tail_ - delay - i;
|
||||
if (idx < 0) {
|
||||
idx += Capacity();
|
||||
}
|
||||
sum += data_[idx].average;
|
||||
max = std::fmax(data_[idx].max, max);
|
||||
}
|
||||
return absl::optional<Level>({sum / static_cast<float>(num_items), max});
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
/*
|
||||
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// A circular buffer to store frame-wise `Level` items for clipping prediction.
|
||||
// The current implementation is not optimized for large buffer lengths.
|
||||
class ClippingPredictorLevelBuffer {
|
||||
public:
|
||||
struct Level {
|
||||
float average;
|
||||
float max;
|
||||
bool operator==(const Level& level) const;
|
||||
};
|
||||
|
||||
// Recommended maximum capacity. It is possible to create a buffer with a
|
||||
// larger capacity, but the implementation is not optimized for large values.
|
||||
static constexpr int kMaxCapacity = 100;
|
||||
|
||||
// Ctor. Sets the buffer capacity to max(1, `capacity`) and logs a warning
|
||||
// message if the capacity is greater than `kMaxCapacity`.
|
||||
explicit ClippingPredictorLevelBuffer(int capacity);
|
||||
~ClippingPredictorLevelBuffer() {}
|
||||
ClippingPredictorLevelBuffer(const ClippingPredictorLevelBuffer&) = delete;
|
||||
ClippingPredictorLevelBuffer& operator=(const ClippingPredictorLevelBuffer&) =
|
||||
delete;
|
||||
|
||||
void Reset();
|
||||
|
||||
// Returns the current number of items stored in the buffer.
|
||||
int Size() const { return size_; }
|
||||
|
||||
// Returns the capacity of the buffer.
|
||||
int Capacity() const { return data_.size(); }
|
||||
|
||||
// Adds a `level` item into the circular buffer `data_`. Stores at most
|
||||
// `Capacity()` items. If more items are pushed, the new item replaces the
|
||||
// least recently pushed item.
|
||||
void Push(Level level);
|
||||
|
||||
// If at least `num_items` + `delay` items have been pushed, returns the
|
||||
// average and maximum value for the `num_items` most recently pushed items
|
||||
// from `delay` to `delay` - `num_items` (a delay equal to zero corresponds
|
||||
// to the most recently pushed item). The value of `delay` is limited to
|
||||
// [0, N] and `num_items` to [1, M] where N + M is the capacity of the buffer.
|
||||
absl::optional<Level> ComputePartialMetrics(int delay, int num_items) const;
|
||||
|
||||
private:
|
||||
int tail_;
|
||||
int size_;
|
||||
std::vector<Level> data_;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_
|
||||
|
|
@ -0,0 +1,229 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/compute_interpolated_gain_curve.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <queue>
|
||||
#include <tuple>
|
||||
#include <utility>
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/agc2/agc2_testing_common.h"
|
||||
#include "modules/audio_processing/agc2/limiter_db_gain_curve.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
std::pair<double, double> ComputeLinearApproximationParams(
|
||||
const LimiterDbGainCurve* limiter,
|
||||
const double x) {
|
||||
const double m = limiter->GetGainFirstDerivativeLinear(x);
|
||||
const double q = limiter->GetGainLinear(x) - m * x;
|
||||
return {m, q};
|
||||
}
|
||||
|
||||
double ComputeAreaUnderPiecewiseLinearApproximation(
|
||||
const LimiterDbGainCurve* limiter,
|
||||
const double x0,
|
||||
const double x1) {
|
||||
RTC_CHECK_LT(x0, x1);
|
||||
|
||||
// Linear approximation in x0 and x1.
|
||||
double m0, q0, m1, q1;
|
||||
std::tie(m0, q0) = ComputeLinearApproximationParams(limiter, x0);
|
||||
std::tie(m1, q1) = ComputeLinearApproximationParams(limiter, x1);
|
||||
|
||||
// Intersection point between two adjacent linear pieces.
|
||||
RTC_CHECK_NE(m1, m0);
|
||||
const double x_split = (q0 - q1) / (m1 - m0);
|
||||
RTC_CHECK_LT(x0, x_split);
|
||||
RTC_CHECK_LT(x_split, x1);
|
||||
|
||||
auto area_under_linear_piece = [](double x_l, double x_r, double m,
|
||||
double q) {
|
||||
return x_r * (m * x_r / 2.0 + q) - x_l * (m * x_l / 2.0 + q);
|
||||
};
|
||||
return area_under_linear_piece(x0, x_split, m0, q0) +
|
||||
area_under_linear_piece(x_split, x1, m1, q1);
|
||||
}
|
||||
|
||||
// Computes the approximation error in the limiter region for a given interval.
|
||||
// The error is computed as the difference between the areas beneath the limiter
|
||||
// curve to approximate and its linear under-approximation.
|
||||
double LimiterUnderApproximationNegativeError(const LimiterDbGainCurve* limiter,
|
||||
const double x0,
|
||||
const double x1) {
|
||||
const double area_limiter = limiter->GetGainIntegralLinear(x0, x1);
|
||||
const double area_interpolated_curve =
|
||||
ComputeAreaUnderPiecewiseLinearApproximation(limiter, x0, x1);
|
||||
RTC_CHECK_GE(area_limiter, area_interpolated_curve);
|
||||
return area_limiter - area_interpolated_curve;
|
||||
}
|
||||
|
||||
// Automatically finds where to sample the beyond-knee region of a limiter using
|
||||
// a greedy optimization algorithm that iteratively decreases the approximation
|
||||
// error.
|
||||
// The solution is sub-optimal because the algorithm is greedy and the points
|
||||
// are assigned by halving intervals (starting with the whole beyond-knee region
|
||||
// as a single interval). However, even if sub-optimal, this algorithm works
|
||||
// well in practice and it is efficiently implemented using priority queues.
|
||||
std::vector<double> SampleLimiterRegion(const LimiterDbGainCurve* limiter) {
|
||||
static_assert(kInterpolatedGainCurveBeyondKneePoints > 2, "");
|
||||
|
||||
struct Interval {
|
||||
Interval() = default; // Ctor required by std::priority_queue.
|
||||
Interval(double l, double r, double e) : x0(l), x1(r), error(e) {
|
||||
RTC_CHECK(x0 < x1);
|
||||
}
|
||||
bool operator<(const Interval& other) const { return error < other.error; }
|
||||
|
||||
double x0;
|
||||
double x1;
|
||||
double error;
|
||||
};
|
||||
|
||||
std::priority_queue<Interval, std::vector<Interval>> q;
|
||||
q.emplace(limiter->limiter_start_linear(), limiter->max_input_level_linear(),
|
||||
LimiterUnderApproximationNegativeError(
|
||||
limiter, limiter->limiter_start_linear(),
|
||||
limiter->max_input_level_linear()));
|
||||
|
||||
// Iteratively find points by halving the interval with greatest error.
|
||||
while (q.size() < kInterpolatedGainCurveBeyondKneePoints) {
|
||||
// Get the interval with highest error.
|
||||
const auto interval = q.top();
|
||||
q.pop();
|
||||
|
||||
// Split `interval` and enqueue.
|
||||
double x_split = (interval.x0 + interval.x1) / 2.0;
|
||||
q.emplace(interval.x0, x_split,
|
||||
LimiterUnderApproximationNegativeError(limiter, interval.x0,
|
||||
x_split)); // Left.
|
||||
q.emplace(x_split, interval.x1,
|
||||
LimiterUnderApproximationNegativeError(limiter, x_split,
|
||||
interval.x1)); // Right.
|
||||
}
|
||||
|
||||
// Copy x1 values and sort them.
|
||||
RTC_CHECK_EQ(q.size(), kInterpolatedGainCurveBeyondKneePoints);
|
||||
std::vector<double> samples(kInterpolatedGainCurveBeyondKneePoints);
|
||||
for (size_t i = 0; i < kInterpolatedGainCurveBeyondKneePoints; ++i) {
|
||||
const auto interval = q.top();
|
||||
q.pop();
|
||||
samples[i] = interval.x1;
|
||||
}
|
||||
RTC_CHECK(q.empty());
|
||||
std::sort(samples.begin(), samples.end());
|
||||
|
||||
return samples;
|
||||
}
|
||||
|
||||
// Compute the parameters to over-approximate the knee region via linear
|
||||
// interpolation. Over-approximating is saturation-safe since the knee region is
|
||||
// convex.
|
||||
void PrecomputeKneeApproxParams(const LimiterDbGainCurve* limiter,
|
||||
test::InterpolatedParameters* parameters) {
|
||||
static_assert(kInterpolatedGainCurveKneePoints > 2, "");
|
||||
// Get `kInterpolatedGainCurveKneePoints` - 1 equally spaced points.
|
||||
const std::vector<double> points = test::LinSpace(
|
||||
limiter->knee_start_linear(), limiter->limiter_start_linear(),
|
||||
kInterpolatedGainCurveKneePoints - 1);
|
||||
|
||||
// Set the first two points. The second is computed to help with the beginning
|
||||
// of the knee region, which has high curvature.
|
||||
parameters->computed_approximation_params_x[0] = points[0];
|
||||
parameters->computed_approximation_params_x[1] =
|
||||
(points[0] + points[1]) / 2.0;
|
||||
// Copy the remaining points.
|
||||
std::copy(std::begin(points) + 1, std::end(points),
|
||||
std::begin(parameters->computed_approximation_params_x) + 2);
|
||||
|
||||
// Compute (m, q) pairs for each linear piece y = mx + q.
|
||||
for (size_t i = 0; i < kInterpolatedGainCurveKneePoints - 1; ++i) {
|
||||
const double x0 = parameters->computed_approximation_params_x[i];
|
||||
const double x1 = parameters->computed_approximation_params_x[i + 1];
|
||||
const double y0 = limiter->GetGainLinear(x0);
|
||||
const double y1 = limiter->GetGainLinear(x1);
|
||||
RTC_CHECK_NE(x1, x0);
|
||||
parameters->computed_approximation_params_m[i] = (y1 - y0) / (x1 - x0);
|
||||
parameters->computed_approximation_params_q[i] =
|
||||
y0 - parameters->computed_approximation_params_m[i] * x0;
|
||||
}
|
||||
}
|
||||
|
||||
// Compute the parameters to under-approximate the beyond-knee region via linear
|
||||
// interpolation and greedy sampling. Under-approximating is saturation-safe
|
||||
// since the beyond-knee region is concave.
|
||||
void PrecomputeBeyondKneeApproxParams(
|
||||
const LimiterDbGainCurve* limiter,
|
||||
test::InterpolatedParameters* parameters) {
|
||||
// Find points on which the linear pieces are tangent to the gain curve.
|
||||
const auto samples = SampleLimiterRegion(limiter);
|
||||
|
||||
// Parametrize each linear piece.
|
||||
double m, q;
|
||||
std::tie(m, q) = ComputeLinearApproximationParams(
|
||||
limiter,
|
||||
parameters
|
||||
->computed_approximation_params_x[kInterpolatedGainCurveKneePoints -
|
||||
1]);
|
||||
parameters
|
||||
->computed_approximation_params_m[kInterpolatedGainCurveKneePoints - 1] =
|
||||
m;
|
||||
parameters
|
||||
->computed_approximation_params_q[kInterpolatedGainCurveKneePoints - 1] =
|
||||
q;
|
||||
for (size_t i = 0; i < samples.size(); ++i) {
|
||||
std::tie(m, q) = ComputeLinearApproximationParams(limiter, samples[i]);
|
||||
parameters
|
||||
->computed_approximation_params_m[i +
|
||||
kInterpolatedGainCurveKneePoints] = m;
|
||||
parameters
|
||||
->computed_approximation_params_q[i +
|
||||
kInterpolatedGainCurveKneePoints] = q;
|
||||
}
|
||||
|
||||
// Find the point of intersection between adjacent linear pieces. They will be
|
||||
// used as boundaries between adjacent linear pieces.
|
||||
for (size_t i = kInterpolatedGainCurveKneePoints;
|
||||
i < kInterpolatedGainCurveKneePoints +
|
||||
kInterpolatedGainCurveBeyondKneePoints;
|
||||
++i) {
|
||||
RTC_CHECK_NE(parameters->computed_approximation_params_m[i],
|
||||
parameters->computed_approximation_params_m[i - 1]);
|
||||
parameters->computed_approximation_params_x[i] =
|
||||
( // Formula: (q0 - q1) / (m1 - m0).
|
||||
parameters->computed_approximation_params_q[i - 1] -
|
||||
parameters->computed_approximation_params_q[i]) /
|
||||
(parameters->computed_approximation_params_m[i] -
|
||||
parameters->computed_approximation_params_m[i - 1]);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
namespace test {
|
||||
|
||||
// Computes the boundaries, slopes and y-intercepts of the piece-wise linear
// approximation of the limiter gain curve. All entries are zeroed first; the
// knee region is computed before the beyond-knee region because the latter
// reads the last knee boundary.
InterpolatedParameters ComputeInterpolatedGainCurveApproximationParams() {
  InterpolatedParameters parameters;
  LimiterDbGainCurve limiter;
  parameters.computed_approximation_params_x.fill(0.0f);
  parameters.computed_approximation_params_m.fill(0.0f);
  parameters.computed_approximation_params_q.fill(0.0f);
  PrecomputeKneeApproxParams(&limiter, &parameters);
  PrecomputeBeyondKneeApproxParams(&limiter, &parameters);
  return parameters;
}
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
namespace test {
|
||||
|
||||
// Parameters for interpolated gain curve using under-approximation to
|
||||
// avoid saturation.
|
||||
//
|
||||
// The saturation gain is defined in order to let hard-clipping occur for
|
||||
// those samples having a level that falls in the saturation region. It is an
|
||||
// upper bound of the actual gain to apply - i.e., that returned by the
|
||||
// limiter.
|
||||
|
||||
// Knee and beyond-knee regions approximation parameters.
|
||||
// The gain curve is approximated as a piece-wise linear function.
|
||||
// `approx_params_x_` are the boundaries between adjacent linear pieces,
|
||||
// `approx_params_m_` and `approx_params_q_` are the slope and the y-intercept
|
||||
// values of each piece.
|
||||
struct InterpolatedParameters {
|
||||
std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
computed_approximation_params_x;
|
||||
std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
computed_approximation_params_m;
|
||||
std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
computed_approximation_params_q;
|
||||
};
|
||||
|
||||
InterpolatedParameters ComputeInterpolatedGainCurveApproximationParams();
|
||||
} // namespace test
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_
|
||||
|
|
@ -0,0 +1,62 @@
|
|||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
|
||||
#include "rtc_base/strings/string_builder.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
#include "system_wrappers/include/cpu_features_wrapper.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
std::string AvailableCpuFeatures::ToString() const {
|
||||
char buf[64];
|
||||
rtc::SimpleStringBuilder builder(buf);
|
||||
bool first = true;
|
||||
if (sse2) {
|
||||
builder << (first ? "SSE2" : "_SSE2");
|
||||
first = false;
|
||||
}
|
||||
if (avx2) {
|
||||
builder << (first ? "AVX2" : "_AVX2");
|
||||
first = false;
|
||||
}
|
||||
if (neon) {
|
||||
builder << (first ? "NEON" : "_NEON");
|
||||
first = false;
|
||||
}
|
||||
if (first) {
|
||||
return "none";
|
||||
}
|
||||
return builder.str();
|
||||
}
|
||||
|
||||
// Detects available CPU features.
|
||||
AvailableCpuFeatures GetAvailableCpuFeatures() {
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
return {/*sse2=*/GetCPUInfo(kSSE2) != 0,
|
||||
/*avx2=*/GetCPUInfo(kAVX2) != 0,
|
||||
/*neon=*/false};
|
||||
#elif defined(WEBRTC_HAS_NEON)
|
||||
return {/*sse2=*/false,
|
||||
/*avx2=*/false,
|
||||
/*neon=*/true};
|
||||
#else
|
||||
return {/*sse2=*/false,
|
||||
/*avx2=*/false,
|
||||
/*neon=*/false};
|
||||
#endif
|
||||
}
|
||||
|
||||
// Returns a feature set with every flag cleared.
AvailableCpuFeatures NoAvailableCpuFeatures() {
  return AvailableCpuFeatures(/*sse2=*/false, /*avx2=*/false, /*neon=*/false);
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
|
||||
|
||||
#include <string>
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// Set of flags telling which CPU features can be used on the current
// platform; a flag set to true means that the feature is available.
struct AvailableCpuFeatures {
  AvailableCpuFeatures(bool sse2, bool avx2, bool neon)
      : sse2(sse2), avx2(avx2), neon(neon) {}

  // Returns a textual representation of the flags.
  std::string ToString() const;

  // Intel.
  bool sse2;
  bool avx2;
  // ARM.
  bool neon;
};
|
||||
|
||||
// Detects what CPU features are available.
|
||||
AvailableCpuFeatures GetAvailableCpuFeatures();
|
||||
|
||||
// Returns the CPU feature flags all set to false.
|
||||
AvailableCpuFeatures NoAvailableCpuFeatures();
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
|
||||
|
|
@ -0,0 +1,121 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/fixed_digital_level_estimator.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
constexpr float kInitialFilterStateLevel = 0.0f;
|
||||
|
||||
// Instant attack.
|
||||
constexpr float kAttackFilterConstant = 0.0f;
|
||||
|
||||
// Limiter decay constant.
|
||||
// Computed as `10 ** (-1/20 * subframe_duration / kDecayMs)` where:
|
||||
// - `subframe_duration` is `kFrameDurationMs / kSubFramesInFrame`;
|
||||
// - `kDecayMs` is defined in agc2_testing_common.h.
|
||||
constexpr float kDecayFilterConstant = 0.9971259f;
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates an estimator for audio sampled at `sample_rate_hz`.
// `apm_data_dumper` must not be null; it is used to dump debug data.
FixedDigitalLevelEstimator::FixedDigitalLevelEstimator(
    int sample_rate_hz,
    ApmDataDumper* apm_data_dumper)
    : apm_data_dumper_(apm_data_dumper),
      filter_state_level_(kInitialFilterStateLevel) {
  // Validate the dumper before doing any other work.
  RTC_DCHECK(apm_data_dumper_);
  // `SetSampleRate()` also validates the resulting frame parameters, so no
  // separate call to `CheckParameterCombination()` is needed here.
  SetSampleRate(sample_rate_hz);
  apm_data_dumper_->DumpRaw("agc2_level_estimator_samplerate", sample_rate_hz);
}
|
||||
|
||||
// Debug-only validation of the frame and sub-frame sizes.
void FixedDigitalLevelEstimator::CheckParameterCombination() {
  // A frame must be non-empty and hold at least one sample per sub-frame.
  RTC_DCHECK_GT(samples_in_frame_, 0);
  RTC_DCHECK_LE(kSubFramesInFrame, samples_in_frame_);
  // A frame must split evenly into sub-frames of more than one sample each.
  RTC_DCHECK_EQ(samples_in_frame_ % kSubFramesInFrame, 0);
  RTC_DCHECK_GT(samples_in_sub_frame_, 1);
}
|
||||
|
||||
std::array<float, kSubFramesInFrame> FixedDigitalLevelEstimator::ComputeLevel(
|
||||
const AudioFrameView<const float>& float_frame) {
|
||||
RTC_DCHECK_GT(float_frame.num_channels(), 0);
|
||||
RTC_DCHECK_EQ(float_frame.samples_per_channel(), samples_in_frame_);
|
||||
|
||||
// Compute max envelope without smoothing.
|
||||
std::array<float, kSubFramesInFrame> envelope{};
|
||||
for (int channel_idx = 0; channel_idx < float_frame.num_channels();
|
||||
++channel_idx) {
|
||||
const auto channel = float_frame.channel(channel_idx);
|
||||
for (int sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
|
||||
for (int sample_in_sub_frame = 0;
|
||||
sample_in_sub_frame < samples_in_sub_frame_; ++sample_in_sub_frame) {
|
||||
envelope[sub_frame] =
|
||||
std::max(envelope[sub_frame],
|
||||
std::abs(channel[sub_frame * samples_in_sub_frame_ +
|
||||
sample_in_sub_frame]));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Make sure envelope increases happen one step earlier so that the
|
||||
// corresponding *gain decrease* doesn't miss a sudden signal
|
||||
// increase due to interpolation.
|
||||
for (int sub_frame = 0; sub_frame < kSubFramesInFrame - 1; ++sub_frame) {
|
||||
if (envelope[sub_frame] < envelope[sub_frame + 1]) {
|
||||
envelope[sub_frame] = envelope[sub_frame + 1];
|
||||
}
|
||||
}
|
||||
|
||||
// Add attack / decay smoothing.
|
||||
for (int sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
|
||||
const float envelope_value = envelope[sub_frame];
|
||||
if (envelope_value > filter_state_level_) {
|
||||
envelope[sub_frame] = envelope_value * (1 - kAttackFilterConstant) +
|
||||
filter_state_level_ * kAttackFilterConstant;
|
||||
} else {
|
||||
envelope[sub_frame] = envelope_value * (1 - kDecayFilterConstant) +
|
||||
filter_state_level_ * kDecayFilterConstant;
|
||||
}
|
||||
filter_state_level_ = envelope[sub_frame];
|
||||
|
||||
// Dump data for debug.
|
||||
RTC_DCHECK(apm_data_dumper_);
|
||||
const auto channel = float_frame.channel(0);
|
||||
apm_data_dumper_->DumpRaw("agc2_level_estimator_samples",
|
||||
samples_in_sub_frame_,
|
||||
&channel[sub_frame * samples_in_sub_frame_]);
|
||||
apm_data_dumper_->DumpRaw("agc2_level_estimator_level",
|
||||
envelope[sub_frame]);
|
||||
}
|
||||
|
||||
return envelope;
|
||||
}
|
||||
|
||||
// Recomputes the frame and sub-frame sizes for `sample_rate_hz` and validates
// them. Both divisions are required to be exact.
void FixedDigitalLevelEstimator::SetSampleRate(int sample_rate_hz) {
  const int samples_per_second_times_ms = sample_rate_hz * kFrameDurationMs;
  samples_in_frame_ = rtc::CheckedDivExact(samples_per_second_times_ms, 1000);
  samples_in_sub_frame_ =
      rtc::CheckedDivExact(samples_in_frame_, kSubFramesInFrame);
  CheckParameterCombination();
}
|
||||
|
||||
// Restores the smoothing filter state to its initial level.
void FixedDigitalLevelEstimator::Reset() {
  filter_state_level_ = kInitialFilterStateLevel;
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,66 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
class ApmDataDumper;
|
||||
// Produces a smooth signal level estimate from an input audio
|
||||
// stream. The estimate smoothing is done through exponential
|
||||
// filtering.
|
||||
class FixedDigitalLevelEstimator {
|
||||
public:
|
||||
// Sample rates are allowed if the number of samples in a frame
|
||||
// (sample_rate_hz * kFrameDurationMs / 1000) is divisible by
|
||||
// kSubFramesInSample. For kFrameDurationMs=10 and
|
||||
// kSubFramesInSample=20, this means that sample_rate_hz has to be
|
||||
// divisible by 2000.
|
||||
FixedDigitalLevelEstimator(int sample_rate_hz,
|
||||
ApmDataDumper* apm_data_dumper);
|
||||
|
||||
FixedDigitalLevelEstimator(const FixedDigitalLevelEstimator&) = delete;
|
||||
FixedDigitalLevelEstimator& operator=(const FixedDigitalLevelEstimator&) =
|
||||
delete;
|
||||
|
||||
// The input is assumed to be in FloatS16 format. Scaled input will
|
||||
// produce similarly scaled output. A frame of with kFrameDurationMs
|
||||
// ms of audio produces a level estimates in the same scale. The
|
||||
// level estimate contains kSubFramesInFrame values.
|
||||
std::array<float, kSubFramesInFrame> ComputeLevel(
|
||||
const AudioFrameView<const float>& float_frame);
|
||||
|
||||
// Rate may be changed at any time (but not concurrently) from the
|
||||
// value passed to the constructor. The class is not thread safe.
|
||||
void SetSampleRate(int sample_rate_hz);
|
||||
|
||||
// Resets the level estimator internal state.
|
||||
void Reset();
|
||||
|
||||
float LastAudioLevel() const { return filter_state_level_; }
|
||||
|
||||
private:
|
||||
void CheckParameterCombination();
|
||||
|
||||
ApmDataDumper* const apm_data_dumper_ = nullptr;
|
||||
float filter_state_level_;
|
||||
int samples_in_frame_;
|
||||
int samples_in_sub_frame_;
|
||||
};
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_
|
||||
|
|
@ -0,0 +1,103 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/gain_applier.h"
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
// Returns true when the gain factor is so close to 1 that it would
|
||||
// not affect int16 samples.
|
||||
bool GainCloseToOne(float gain_factor) {
|
||||
return 1.f - 1.f / kMaxFloatS16Value <= gain_factor &&
|
||||
gain_factor <= 1.f + 1.f / kMaxFloatS16Value;
|
||||
}
|
||||
|
||||
// Hard-clips every sample of `signal` to the
// [kMinFloatS16Value, kMaxFloatS16Value] range.
void ClipSignal(AudioFrameView<float> signal) {
  for (int ch = 0; ch < signal.num_channels(); ++ch) {
    for (float& sample : signal.channel(ch)) {
      sample = rtc::SafeClamp(sample, kMinFloatS16Value, kMaxFloatS16Value);
    }
  }
}
|
||||
|
||||
void ApplyGainWithRamping(float last_gain_linear,
|
||||
float gain_at_end_of_frame_linear,
|
||||
float inverse_samples_per_channel,
|
||||
AudioFrameView<float> float_frame) {
|
||||
// Do not modify the signal.
|
||||
if (last_gain_linear == gain_at_end_of_frame_linear &&
|
||||
GainCloseToOne(gain_at_end_of_frame_linear)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Gain is constant and different from 1.
|
||||
if (last_gain_linear == gain_at_end_of_frame_linear) {
|
||||
for (int k = 0; k < float_frame.num_channels(); ++k) {
|
||||
rtc::ArrayView<float> channel_view = float_frame.channel(k);
|
||||
for (auto& sample : channel_view) {
|
||||
sample *= gain_at_end_of_frame_linear;
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// The gain changes. We have to change slowly to avoid discontinuities.
|
||||
const float increment = (gain_at_end_of_frame_linear - last_gain_linear) *
|
||||
inverse_samples_per_channel;
|
||||
float gain = last_gain_linear;
|
||||
for (int i = 0; i < float_frame.samples_per_channel(); ++i) {
|
||||
for (int ch = 0; ch < float_frame.num_channels(); ++ch) {
|
||||
float_frame.channel(ch)[i] *= gain;
|
||||
}
|
||||
gain += increment;
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// The last and the current gain factors both start at `initial_gain_factor`,
// hence no ramping takes place until `SetGainFactor()` changes the gain.
GainApplier::GainApplier(bool hard_clip_samples, float initial_gain_factor)
    : hard_clip_samples_(hard_clip_samples),
      last_gain_factor_(initial_gain_factor),
      current_gain_factor_(initial_gain_factor) {}
|
||||
|
||||
// Applies the gain to `signal` in place, ramping from the gain used for the
// previous frame to the current one, and optionally hard-clips the result.
void GainApplier::ApplyGain(AudioFrameView<float> signal) {
  // Re-initialize when the frame size changes.
  const int frame_size = static_cast<int>(signal.samples_per_channel());
  if (frame_size != samples_per_channel_) {
    Initialize(frame_size);
  }

  ApplyGainWithRamping(last_gain_factor_, current_gain_factor_,
                       inverse_samples_per_channel_, signal);
  // The ramp is complete: the next frame starts from the current gain.
  last_gain_factor_ = current_gain_factor_;

  if (!hard_clip_samples_) {
    return;
  }
  ClipSignal(signal);
}
|
||||
|
||||
// TODO(bugs.webrtc.org/7494): Remove once switched to gains in dB.
|
||||
void GainApplier::SetGainFactor(float gain_factor) {
|
||||
RTC_DCHECK_GT(gain_factor, 0.f);
|
||||
current_gain_factor_ = gain_factor;
|
||||
}
|
||||
|
||||
// Caches the frame size and its reciprocal, which is used to compute the
// per-sample gain increment when ramping.
void GainApplier::Initialize(int samples_per_channel) {
  RTC_DCHECK_GT(samples_per_channel, 0);
  samples_per_channel_ = samples_per_channel;
  inverse_samples_per_channel_ = 1.f / samples_per_channel;
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,44 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
class GainApplier {
|
||||
public:
|
||||
GainApplier(bool hard_clip_samples, float initial_gain_factor);
|
||||
|
||||
void ApplyGain(AudioFrameView<float> signal);
|
||||
void SetGainFactor(float gain_factor);
|
||||
float GetGainFactor() const { return current_gain_factor_; }
|
||||
|
||||
private:
|
||||
void Initialize(int samples_per_channel);
|
||||
|
||||
// Whether to clip samples after gain is applied. If 'true', result
|
||||
// will fit in FloatS16 range.
|
||||
const bool hard_clip_samples_;
|
||||
float last_gain_factor_;
|
||||
|
||||
// If this value is not equal to 'last_gain_factor', gain will be
|
||||
// ramped from 'last_gain_factor_' to this value during the next
|
||||
// 'ApplyGain'.
|
||||
float current_gain_factor_;
|
||||
int samples_per_channel_ = -1;
|
||||
float inverse_samples_per_channel_ = -1.f;
|
||||
};
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_
|
||||
|
|
@ -0,0 +1,93 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/gain_applier.h"
|
||||
|
||||
#include <math.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
|
||||
#include "modules/audio_processing/agc2/vector_float_frame.h"
|
||||
#include "rtc_base/gunit.h"
|
||||
|
||||
namespace webrtc {
|
||||
// Checks that the gain passed to the constructor is applied to the first
// frame.
TEST(AutomaticGainController2GainApplier, InitialGainIsRespected) {
  constexpr float kSignalLevel = 123.f;
  constexpr float kGainFactor = 10.f;
  VectorFloatFrame frame(1, 1, kSignalLevel);
  GainApplier applier(/*hard_clip_samples=*/true, kGainFactor);

  applier.ApplyGain(frame.float_frame_view());

  const float first_sample = frame.float_frame_view().channel(0)[0];
  EXPECT_NEAR(first_sample, kSignalLevel * kGainFactor, 0.1f);
}
|
||||
|
||||
// Checks that, with hard clipping enabled, an amplified sample exceeding the
// int16 range is clamped to the int16 maximum.
TEST(AutomaticGainController2GainApplier, ClippingIsDone) {
  constexpr float kSignalLevel = 30000.f;
  constexpr float kGainFactor = 10.f;
  VectorFloatFrame frame(1, 1, kSignalLevel);
  GainApplier applier(/*hard_clip_samples=*/true, kGainFactor);

  applier.ApplyGain(frame.float_frame_view());

  const float first_sample = frame.float_frame_view().channel(0)[0];
  EXPECT_NEAR(first_sample, std::numeric_limits<int16_t>::max(), 0.1f);
}
|
||||
|
||||
// Checks that, with hard clipping disabled, an amplified sample may exceed
// the int16 range.
TEST(AutomaticGainController2GainApplier, ClippingIsNotDone) {
  constexpr float kSignalLevel = 30000.f;
  constexpr float kGainFactor = 10.f;
  VectorFloatFrame frame(1, 1, kSignalLevel);
  GainApplier applier(/*hard_clip_samples=*/false, kGainFactor);

  applier.ApplyGain(frame.float_frame_view());

  const float first_sample = frame.float_frame_view().channel(0)[0];
  EXPECT_NEAR(first_sample, kSignalLevel * kGainFactor, 0.1f);
}
|
||||
|
||||
// Checks that a gain change is spread over one frame via a linear ramp and
// that the following frame is processed with the target gain.
TEST(AutomaticGainController2GainApplier, RampingIsDone) {
  constexpr float initial_signal_level = 30000.f;
  constexpr float initial_gain_factor = 1.f;
  constexpr float target_gain_factor = 0.5f;
  constexpr int num_channels = 3;
  constexpr int samples_per_channel = 4;
  VectorFloatFrame fake_audio(num_channels, samples_per_channel,
                              initial_signal_level);
  GainApplier gain_applier(false, initial_gain_factor);

  gain_applier.SetGainFactor(target_gain_factor);
  gain_applier.ApplyGain(fake_audio.float_frame_view());

  // The maximal gain change should be close to that in linear interpolation.
  // The channel index is an `int`, like `num_channels`, to avoid a
  // signed/unsigned comparison.
  for (int channel = 0; channel < num_channels; ++channel) {
    float max_signal_change = 0.f;
    float last_signal_level = initial_signal_level;
    for (const auto sample : fake_audio.float_frame_view().channel(channel)) {
      const float current_change = fabs(last_signal_level - sample);
      max_signal_change = std::max(max_signal_change, current_change);
      last_signal_level = sample;
    }
    const float total_gain_change =
        fabs((initial_gain_factor - target_gain_factor) * initial_signal_level);
    EXPECT_NEAR(max_signal_change, total_gain_change / samples_per_channel,
                0.1f);
  }

  // Next frame should have the desired level.
  VectorFloatFrame next_fake_audio_frame(num_channels, samples_per_channel,
                                         initial_signal_level);
  gain_applier.ApplyGain(next_fake_audio_frame.float_frame_view());

  // The ramp is over, so the first sample already has the new gain.
  EXPECT_NEAR(next_fake_audio_frame.float_frame_view().channel(0)[0],
              initial_signal_level * target_gain_factor, 0.1f);
}
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_GAIN_MAP_INTERNAL_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_GAIN_MAP_INTERNAL_H_
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// Number of entries in `kGainMap`, one per input volume.
static constexpr int kGainMapSize = 256;
// Maps input volumes, which are values in the [0, 255] range, to gains in dB.
// The table is `constexpr` so that it can be used in constant expressions.
// The values below are generated with numpy as follows:
// SI = 2                         # Initial slope.
// SF = 0.25                      # Final slope.
// D = 8/256                      # Quantization factor.
// x = np.linspace(0, 255, 256)   # Input volumes.
// y = (SF * x + (SI - SF) * (1 - np.exp(-D*x)) / D - 56).round()
static constexpr int kGainMap[kGainMapSize] = {
    -56, -54, -52, -50, -48, -47, -45, -43, -42, -40, -38, -37, -35, -34, -33,
    -31, -30, -29, -27, -26, -25, -24, -23, -22, -20, -19, -18, -17, -16, -15,
    -14, -14, -13, -12, -11, -10, -9,  -8,  -8,  -7,  -6,  -5,  -5,  -4,  -3,
    -2,  -2,  -1,  0,   0,   1,   1,   2,   3,   3,   4,   4,   5,   5,   6,
    6,   7,   7,   8,   8,   9,   9,   10,  10,  11,  11,  12,  12,  13,  13,
    13,  14,  14,  15,  15,  15,  16,  16,  17,  17,  17,  18,  18,  18,  19,
    19,  19,  20,  20,  21,  21,  21,  22,  22,  22,  23,  23,  23,  24,  24,
    24,  24,  25,  25,  25,  26,  26,  26,  27,  27,  27,  28,  28,  28,  28,
    29,  29,  29,  30,  30,  30,  30,  31,  31,  31,  32,  32,  32,  32,  33,
    33,  33,  33,  34,  34,  34,  35,  35,  35,  35,  36,  36,  36,  36,  37,
    37,  37,  38,  38,  38,  38,  39,  39,  39,  39,  40,  40,  40,  40,  41,
    41,  41,  41,  42,  42,  42,  42,  43,  43,  43,  44,  44,  44,  44,  45,
    45,  45,  45,  46,  46,  46,  46,  47,  47,  47,  47,  48,  48,  48,  48,
    49,  49,  49,  49,  50,  50,  50,  50,  51,  51,  51,  51,  52,  52,  52,
    52,  53,  53,  53,  53,  54,  54,  54,  54,  55,  55,  55,  55,  56,  56,
    56,  56,  57,  57,  57,  57,  58,  58,  58,  58,  59,  59,  59,  59,  60,
    60,  60,  60,  61,  61,  61,  61,  62,  62,  62,  62,  63,  63,  63,  63,
    64};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_GAIN_MAP_INTERNAL_H_
|
||||
|
|
@ -0,0 +1,580 @@
|
|||
/*
|
||||
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/input_volume_controller.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/gain_map_internal.h"
|
||||
#include "modules/audio_processing/agc2/input_volume_stats_reporter.h"
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
#include "system_wrappers/include/field_trial.h"
|
||||
#include "system_wrappers/include/metrics.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
namespace {
|
||||
|
||||
// Amount of error we tolerate in the microphone input volume (presumably due to
|
||||
// OS quantization) before we assume the user has manually adjusted the volume.
|
||||
constexpr int kVolumeQuantizationSlack = 25;
|
||||
|
||||
constexpr int kMaxInputVolume = 255;
|
||||
static_assert(kGainMapSize > kMaxInputVolume, "gain map too small");
|
||||
|
||||
// Maximum absolute RMS error.
|
||||
constexpr int KMaxAbsRmsErrorDbfs = 15;
|
||||
static_assert(KMaxAbsRmsErrorDbfs > 0, "");
|
||||
|
||||
using Agc1ClippingPredictorConfig = AudioProcessing::Config::GainController1::
|
||||
AnalogGainController::ClippingPredictor;
|
||||
|
||||
// TODO(webrtc:7494): Hardcode clipping predictor parameters and remove this
|
||||
// function after no longer needed in the ctor.
|
||||
Agc1ClippingPredictorConfig CreateClippingPredictorConfig(bool enabled) {
|
||||
Agc1ClippingPredictorConfig config;
|
||||
config.enabled = enabled;
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
// Returns an input volume in the [`min_input_volume`, `kMaxInputVolume`] range
|
||||
// that reduces `gain_error_db`, which is a gain error estimated when
|
||||
// `input_volume` was applied, according to a fixed gain map.
|
||||
int ComputeVolumeUpdate(int gain_error_db,
|
||||
int input_volume,
|
||||
int min_input_volume) {
|
||||
RTC_DCHECK_GE(input_volume, 0);
|
||||
RTC_DCHECK_LE(input_volume, kMaxInputVolume);
|
||||
if (gain_error_db == 0) {
|
||||
return input_volume;
|
||||
}
|
||||
|
||||
int new_volume = input_volume;
|
||||
if (gain_error_db > 0) {
|
||||
while (kGainMap[new_volume] - kGainMap[input_volume] < gain_error_db &&
|
||||
new_volume < kMaxInputVolume) {
|
||||
++new_volume;
|
||||
}
|
||||
} else {
|
||||
while (kGainMap[new_volume] - kGainMap[input_volume] > gain_error_db &&
|
||||
new_volume > min_input_volume) {
|
||||
--new_volume;
|
||||
}
|
||||
}
|
||||
return new_volume;
|
||||
}
|
||||
|
||||
// Returns the proportion of samples in the buffer which are at full-scale
|
||||
// (and presumably clipped).
|
||||
float ComputeClippedRatio(const float* const* audio,
|
||||
size_t num_channels,
|
||||
size_t samples_per_channel) {
|
||||
RTC_DCHECK_GT(samples_per_channel, 0);
|
||||
int num_clipped = 0;
|
||||
for (size_t ch = 0; ch < num_channels; ++ch) {
|
||||
int num_clipped_in_ch = 0;
|
||||
for (size_t i = 0; i < samples_per_channel; ++i) {
|
||||
RTC_DCHECK(audio[ch]);
|
||||
if (audio[ch][i] >= 32767.0f || audio[ch][i] <= -32768.0f) {
|
||||
++num_clipped_in_ch;
|
||||
}
|
||||
}
|
||||
num_clipped = std::max(num_clipped, num_clipped_in_ch);
|
||||
}
|
||||
return static_cast<float>(num_clipped) / (samples_per_channel);
|
||||
}
|
||||
|
||||
// Writes `clipping_rate` (a percentage) to the info log and records it in the
// "WebRTC.Audio.Agc.InputClippingRate" linear histogram.
void LogClippingMetrics(int clipping_rate) {
  constexpr int kHistogramMin = 0;
  constexpr int kHistogramMax = 100;
  constexpr int kHistogramBucketCount = 50;
  RTC_LOG(LS_INFO) << "[AGC2] Input clipping rate: " << clipping_rate << "%";
  RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc.InputClippingRate",
                              clipping_rate, kHistogramMin, kHistogramMax,
                              kHistogramBucketCount);
}
|
||||
|
||||
// Compares `speech_level_dbfs` to the [`target_range_min_dbfs`,
|
||||
// `target_range_max_dbfs`] range and returns the error to be compensated via
|
||||
// input volume adjustment. Returns a positive value when the level is below
|
||||
// the range, a negative value when the level is above the range, zero
|
||||
// otherwise.
|
||||
int GetSpeechLevelRmsErrorDb(float speech_level_dbfs,
|
||||
int target_range_min_dbfs,
|
||||
int target_range_max_dbfs) {
|
||||
constexpr float kMinSpeechLevelDbfs = -90.0f;
|
||||
constexpr float kMaxSpeechLevelDbfs = 30.0f;
|
||||
RTC_DCHECK_GE(speech_level_dbfs, kMinSpeechLevelDbfs);
|
||||
RTC_DCHECK_LE(speech_level_dbfs, kMaxSpeechLevelDbfs);
|
||||
speech_level_dbfs = rtc::SafeClamp<float>(
|
||||
speech_level_dbfs, kMinSpeechLevelDbfs, kMaxSpeechLevelDbfs);
|
||||
|
||||
int rms_error_db = 0;
|
||||
if (speech_level_dbfs > target_range_max_dbfs) {
|
||||
rms_error_db = std::round(target_range_max_dbfs - speech_level_dbfs);
|
||||
} else if (speech_level_dbfs < target_range_min_dbfs) {
|
||||
rms_error_db = std::round(target_range_min_dbfs - speech_level_dbfs);
|
||||
}
|
||||
|
||||
return rms_error_db;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates a controller for one capture channel.
//
// `min_input_volume_after_clipping` and `min_input_volume` must be in
// [0, `kMaxInputVolume`]. `update_input_volume_wait_frames` is raised to 1
// when smaller. `speech_probability_threshold` and `speech_ratio_threshold`
// must be in [0, 1]. The maximum input volume starts at `kMaxInputVolume`.
MonoInputVolumeController::MonoInputVolumeController(
    int min_input_volume_after_clipping,
    int min_input_volume,
    int update_input_volume_wait_frames,
    float speech_probability_threshold,
    float speech_ratio_threshold)
    : min_input_volume_(min_input_volume),
      min_input_volume_after_clipping_(min_input_volume_after_clipping),
      max_input_volume_(kMaxInputVolume),
      update_input_volume_wait_frames_(
          std::max(update_input_volume_wait_frames, 1)),
      speech_probability_threshold_(speech_probability_threshold),
      speech_ratio_threshold_(speech_ratio_threshold) {
  RTC_DCHECK_GE(min_input_volume_, 0);
  RTC_DCHECK_LE(min_input_volume_, kMaxInputVolume);
  RTC_DCHECK_GE(min_input_volume_after_clipping_, 0);
  RTC_DCHECK_LE(min_input_volume_after_clipping_, kMaxInputVolume);
  RTC_DCHECK_GE(max_input_volume_, 0);
  RTC_DCHECK_LE(max_input_volume_, kMaxInputVolume);
  // The initializer above clamps the wait to at least one frame.
  RTC_DCHECK_GT(update_input_volume_wait_frames_, 0);
  RTC_DCHECK_GE(speech_probability_threshold_, 0.0f);
  RTC_DCHECK_LE(speech_probability_threshold_, 1.0f);
  RTC_DCHECK_GE(speech_ratio_threshold_, 0.0f);
  RTC_DCHECK_LE(speech_ratio_threshold_, 1.0f);
}
|
||||
|
||||
// Defaulted out-of-line; the controller performs no custom cleanup.
MonoInputVolumeController::~MonoInputVolumeController() = default;
|
||||
|
||||
// Restores the initial controller state: the maximum input volume goes back to
// `kMaxInputVolume`, the frame counters are cleared and the applied input
// volume is re-checked on the next `Process()` call.
void MonoInputVolumeController::Initialize() {
  // Volume limits and capture state.
  max_input_volume_ = kMaxInputVolume;
  capture_output_used_ = true;
  check_volume_on_next_process_ = true;

  // Update bookkeeping.
  is_first_frame_ = true;
  frames_since_update_input_volume_ = 0;
  speech_frames_since_update_input_volume_ = 0;
}
|
||||
|
||||
// A speeh segment is considered active if at least
|
||||
// `update_input_volume_wait_frames_` new frames have been processed since the
|
||||
// previous update and the ratio of non-silence frames (i.e., frames with a
|
||||
// `speech_probability` higher than `speech_probability_threshold_`) is at least
|
||||
// `speech_ratio_threshold_`.
|
||||
void MonoInputVolumeController::Process(absl::optional<int> rms_error_db,
|
||||
float speech_probability) {
|
||||
if (check_volume_on_next_process_) {
|
||||
check_volume_on_next_process_ = false;
|
||||
// We have to wait until the first process call to check the volume,
|
||||
// because Chromium doesn't guarantee it to be valid any earlier.
|
||||
CheckVolumeAndReset();
|
||||
}
|
||||
|
||||
// Count frames with a high speech probability as speech.
|
||||
if (speech_probability >= speech_probability_threshold_) {
|
||||
++speech_frames_since_update_input_volume_;
|
||||
}
|
||||
|
||||
// Reset the counters and maybe update the input volume.
|
||||
if (++frames_since_update_input_volume_ >= update_input_volume_wait_frames_) {
|
||||
const float speech_ratio =
|
||||
static_cast<float>(speech_frames_since_update_input_volume_) /
|
||||
static_cast<float>(update_input_volume_wait_frames_);
|
||||
|
||||
// Always reset the counters regardless of whether the volume changes or
|
||||
// not.
|
||||
frames_since_update_input_volume_ = 0;
|
||||
speech_frames_since_update_input_volume_ = 0;
|
||||
|
||||
// Update the input volume if allowed.
|
||||
if (!is_first_frame_ && speech_ratio >= speech_ratio_threshold_ &&
|
||||
rms_error_db.has_value()) {
|
||||
UpdateInputVolume(*rms_error_db);
|
||||
}
|
||||
}
|
||||
|
||||
is_first_frame_ = false;
|
||||
}
|
||||
|
||||
void MonoInputVolumeController::HandleClipping(int clipped_level_step) {
|
||||
RTC_DCHECK_GT(clipped_level_step, 0);
|
||||
// Always decrease the maximum input volume, even if the current input volume
|
||||
// is below threshold.
|
||||
SetMaxLevel(std::max(min_input_volume_after_clipping_,
|
||||
max_input_volume_ - clipped_level_step));
|
||||
if (log_to_histograms_) {
|
||||
RTC_HISTOGRAM_BOOLEAN("WebRTC.Audio.AgcClippingAdjustmentAllowed",
|
||||
last_recommended_input_volume_ - clipped_level_step >=
|
||||
min_input_volume_after_clipping_);
|
||||
}
|
||||
if (last_recommended_input_volume_ > min_input_volume_after_clipping_) {
|
||||
// Don't try to adjust the input volume if we're already below the limit. As
|
||||
// a consequence, if the user has brought the input volume above the limit,
|
||||
// we will still not react until the postproc updates the input volume.
|
||||
SetInputVolume(
|
||||
std::max(min_input_volume_after_clipping_,
|
||||
last_recommended_input_volume_ - clipped_level_step));
|
||||
frames_since_update_input_volume_ = 0;
|
||||
speech_frames_since_update_input_volume_ = 0;
|
||||
is_first_frame_ = false;
|
||||
}
|
||||
}
|
||||
|
||||
void MonoInputVolumeController::SetInputVolume(int new_volume) {
|
||||
int applied_input_volume = recommended_input_volume_;
|
||||
if (applied_input_volume == 0) {
|
||||
RTC_DLOG(LS_INFO)
|
||||
<< "[AGC2] The applied input volume is zero, taking no action.";
|
||||
return;
|
||||
}
|
||||
if (applied_input_volume < 0 || applied_input_volume > kMaxInputVolume) {
|
||||
RTC_LOG(LS_ERROR) << "[AGC2] Invalid value for the applied input volume: "
|
||||
<< applied_input_volume;
|
||||
return;
|
||||
}
|
||||
|
||||
// Detect manual input volume adjustments by checking if the
|
||||
// `applied_input_volume` is outside of the `[last_recommended_input_volume_ -
|
||||
// kVolumeQuantizationSlack, last_recommended_input_volume_ +
|
||||
// kVolumeQuantizationSlack]` range.
|
||||
if (applied_input_volume >
|
||||
last_recommended_input_volume_ + kVolumeQuantizationSlack ||
|
||||
applied_input_volume <
|
||||
last_recommended_input_volume_ - kVolumeQuantizationSlack) {
|
||||
RTC_DLOG(LS_INFO)
|
||||
<< "[AGC2] The input volume was manually adjusted. Updating "
|
||||
"stored input volume from "
|
||||
<< last_recommended_input_volume_ << " to " << applied_input_volume;
|
||||
last_recommended_input_volume_ = applied_input_volume;
|
||||
// Always allow the user to increase the volume.
|
||||
if (last_recommended_input_volume_ > max_input_volume_) {
|
||||
SetMaxLevel(last_recommended_input_volume_);
|
||||
}
|
||||
// Take no action in this case, since we can't be sure when the volume
|
||||
// was manually adjusted.
|
||||
frames_since_update_input_volume_ = 0;
|
||||
speech_frames_since_update_input_volume_ = 0;
|
||||
is_first_frame_ = false;
|
||||
return;
|
||||
}
|
||||
|
||||
new_volume = std::min(new_volume, max_input_volume_);
|
||||
if (new_volume == last_recommended_input_volume_) {
|
||||
return;
|
||||
}
|
||||
|
||||
recommended_input_volume_ = new_volume;
|
||||
RTC_DLOG(LS_INFO) << "[AGC2] Applied input volume: " << applied_input_volume
|
||||
<< " | last recommended input volume: "
|
||||
<< last_recommended_input_volume_
|
||||
<< " | newly recommended input volume: " << new_volume;
|
||||
last_recommended_input_volume_ = new_volume;
|
||||
}
|
||||
|
||||
// Stores `input_volume` as the new maximum input volume. `input_volume` must
// not be lower than `min_input_volume_after_clipping_`.
void MonoInputVolumeController::SetMaxLevel(int input_volume) {
  RTC_DCHECK_GE(input_volume, min_input_volume_after_clipping_);
  max_input_volume_ = input_volume;
  RTC_DLOG(LS_INFO) << "[AGC2] Maximum input volume updated: "
                    << max_input_volume_;
}
|
||||
|
||||
void MonoInputVolumeController::HandleCaptureOutputUsedChange(
|
||||
bool capture_output_used) {
|
||||
if (capture_output_used_ == capture_output_used) {
|
||||
return;
|
||||
}
|
||||
capture_output_used_ = capture_output_used;
|
||||
|
||||
if (capture_output_used) {
|
||||
// When we start using the output, we should reset things to be safe.
|
||||
check_volume_on_next_process_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
int MonoInputVolumeController::CheckVolumeAndReset() {
|
||||
int input_volume = recommended_input_volume_;
|
||||
// Reasons for taking action at startup:
|
||||
// 1) A person starting a call is expected to be heard.
|
||||
// 2) Independent of interpretation of `input_volume` == 0 we should raise it
|
||||
// so the AGC can do its job properly.
|
||||
if (input_volume == 0 && !startup_) {
|
||||
RTC_DLOG(LS_INFO)
|
||||
<< "[AGC2] The applied input volume is zero, taking no action.";
|
||||
return 0;
|
||||
}
|
||||
if (input_volume < 0 || input_volume > kMaxInputVolume) {
|
||||
RTC_LOG(LS_ERROR) << "[AGC2] Invalid value for the applied input volume: "
|
||||
<< input_volume;
|
||||
return -1;
|
||||
}
|
||||
RTC_DLOG(LS_INFO) << "[AGC2] Initial input volume: " << input_volume;
|
||||
|
||||
if (input_volume < min_input_volume_) {
|
||||
input_volume = min_input_volume_;
|
||||
RTC_DLOG(LS_INFO)
|
||||
<< "[AGC2] The initial input volume is too low, raising to "
|
||||
<< input_volume;
|
||||
recommended_input_volume_ = input_volume;
|
||||
}
|
||||
|
||||
last_recommended_input_volume_ = input_volume;
|
||||
startup_ = false;
|
||||
frames_since_update_input_volume_ = 0;
|
||||
speech_frames_since_update_input_volume_ = 0;
|
||||
is_first_frame_ = true;
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
// Converts `rms_error_db` into an input volume change. The error is clamped to
// [-`KMaxAbsRmsErrorDbfs`, `KMaxAbsRmsErrorDbfs`]; a zero error after clamping
// leaves the input volume untouched.
void MonoInputVolumeController::UpdateInputVolume(int rms_error_db) {
  RTC_DLOG(LS_INFO) << "[AGC2] RMS error: " << rms_error_db << " dB";
  // Prevent too large microphone input volume changes by clamping the RMS
  // error.
  const int clamped_error_db =
      rtc::SafeClamp(rms_error_db, -KMaxAbsRmsErrorDbfs, KMaxAbsRmsErrorDbfs);
  if (clamped_error_db == 0) {
    return;
  }
  const int updated_input_volume = ComputeVolumeUpdate(
      clamped_error_db, last_recommended_input_volume_, min_input_volume_);
  SetInputVolume(updated_input_volume);
}
|
||||
|
||||
// Creates one `MonoInputVolumeController` per capture channel, configured from
// `config`, and activates histogram logging on the first one. The clipping
// predictor is created from a default predictor config whose `enabled` flag is
// `config.enable_clipping_predictor`.
InputVolumeController::InputVolumeController(int num_capture_channels,
                                             const Config& config)
    : num_capture_channels_(num_capture_channels),
      min_input_volume_(config.min_input_volume),
      capture_output_used_(true),
      clipped_level_step_(config.clipped_level_step),
      clipped_ratio_threshold_(config.clipped_ratio_threshold),
      clipped_wait_frames_(config.clipped_wait_frames),
      clipping_predictor_(CreateClippingPredictor(
          num_capture_channels,
          CreateClippingPredictorConfig(config.enable_clipping_predictor))),
      use_clipping_predictor_step_(
          clipping_predictor_ != nullptr &&
          CreateClippingPredictorConfig(config.enable_clipping_predictor)
              .use_predicted_step),
      // Start with the wait already elapsed so that clipping is handled from
      // the first analyzed frame.
      frames_since_clipped_(config.clipped_wait_frames),
      clipping_rate_log_counter_(0),
      clipping_rate_log_(0.0f),
      target_range_max_dbfs_(config.target_range_max_dbfs),
      target_range_min_dbfs_(config.target_range_min_dbfs),
      channel_controllers_(num_capture_channels) {
  RTC_LOG(LS_INFO)
      << "[AGC2] Input volume controller enabled. Minimum input volume: "
      << min_input_volume_;

  for (auto& channel_controller : channel_controllers_) {
    channel_controller = std::make_unique<MonoInputVolumeController>(
        config.clipped_level_min, min_input_volume_,
        config.update_input_volume_wait_frames,
        config.speech_probability_threshold, config.speech_ratio_threshold);
  }

  RTC_DCHECK(!channel_controllers_.empty());
  RTC_DCHECK_GT(clipped_level_step_, 0);
  RTC_DCHECK_LE(clipped_level_step_, 255);
  RTC_DCHECK_GT(clipped_ratio_threshold_, 0.0f);
  RTC_DCHECK_LT(clipped_ratio_threshold_, 1.0f);
  RTC_DCHECK_GT(clipped_wait_frames_, 0);
  // Only the first channel controller reports to histograms.
  channel_controllers_[0]->ActivateLogging();
}
|
||||
|
||||
// Defaulted out-of-line; the members release their own resources.
InputVolumeController::~InputVolumeController() = default;
|
||||
|
||||
void InputVolumeController::Initialize() {
|
||||
for (auto& controller : channel_controllers_) {
|
||||
controller->Initialize();
|
||||
}
|
||||
capture_output_used_ = true;
|
||||
|
||||
AggregateChannelLevels();
|
||||
clipping_rate_log_ = 0.0f;
|
||||
clipping_rate_log_counter_ = 0;
|
||||
|
||||
applied_input_volume_ = absl::nullopt;
|
||||
}
|
||||
|
||||
void InputVolumeController::AnalyzeInputAudio(int applied_input_volume,
|
||||
const AudioBuffer& audio_buffer) {
|
||||
RTC_DCHECK_GE(applied_input_volume, 0);
|
||||
RTC_DCHECK_LE(applied_input_volume, 255);
|
||||
|
||||
SetAppliedInputVolume(applied_input_volume);
|
||||
|
||||
RTC_DCHECK_EQ(audio_buffer.num_channels(), channel_controllers_.size());
|
||||
const float* const* audio = audio_buffer.channels_const();
|
||||
size_t samples_per_channel = audio_buffer.num_frames();
|
||||
RTC_DCHECK(audio);
|
||||
|
||||
AggregateChannelLevels();
|
||||
if (!capture_output_used_) {
|
||||
return;
|
||||
}
|
||||
|
||||
if (!!clipping_predictor_) {
|
||||
AudioFrameView<const float> frame = AudioFrameView<const float>(
|
||||
audio, num_capture_channels_, static_cast<int>(samples_per_channel));
|
||||
clipping_predictor_->Analyze(frame);
|
||||
}
|
||||
|
||||
// Check for clipped samples. We do this in the preprocessing phase in order
|
||||
// to catch clipped echo as well.
|
||||
//
|
||||
// If we find a sufficiently clipped frame, drop the current microphone
|
||||
// input volume and enforce a new maximum input volume, dropped the same
|
||||
// amount from the current maximum. This harsh treatment is an effort to avoid
|
||||
// repeated clipped echo events.
|
||||
float clipped_ratio =
|
||||
ComputeClippedRatio(audio, num_capture_channels_, samples_per_channel);
|
||||
clipping_rate_log_ = std::max(clipped_ratio, clipping_rate_log_);
|
||||
clipping_rate_log_counter_++;
|
||||
constexpr int kNumFramesIn30Seconds = 3000;
|
||||
if (clipping_rate_log_counter_ == kNumFramesIn30Seconds) {
|
||||
LogClippingMetrics(std::round(100.0f * clipping_rate_log_));
|
||||
clipping_rate_log_ = 0.0f;
|
||||
clipping_rate_log_counter_ = 0;
|
||||
}
|
||||
|
||||
if (frames_since_clipped_ < clipped_wait_frames_) {
|
||||
++frames_since_clipped_;
|
||||
return;
|
||||
}
|
||||
|
||||
const bool clipping_detected = clipped_ratio > clipped_ratio_threshold_;
|
||||
bool clipping_predicted = false;
|
||||
int predicted_step = 0;
|
||||
if (!!clipping_predictor_) {
|
||||
for (int channel = 0; channel < num_capture_channels_; ++channel) {
|
||||
const auto step = clipping_predictor_->EstimateClippedLevelStep(
|
||||
channel, recommended_input_volume_, clipped_level_step_,
|
||||
channel_controllers_[channel]->min_input_volume_after_clipping(),
|
||||
kMaxInputVolume);
|
||||
if (step.has_value()) {
|
||||
predicted_step = std::max(predicted_step, step.value());
|
||||
clipping_predicted = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (clipping_detected) {
|
||||
RTC_DLOG(LS_INFO) << "[AGC2] Clipping detected (ratio: " << clipped_ratio
|
||||
<< ")";
|
||||
}
|
||||
|
||||
int step = clipped_level_step_;
|
||||
if (clipping_predicted) {
|
||||
predicted_step = std::max(predicted_step, clipped_level_step_);
|
||||
RTC_DLOG(LS_INFO) << "[AGC2] Clipping predicted (volume down step: "
|
||||
<< predicted_step << ")";
|
||||
if (use_clipping_predictor_step_) {
|
||||
step = predicted_step;
|
||||
}
|
||||
}
|
||||
|
||||
if (clipping_detected ||
|
||||
(clipping_predicted && use_clipping_predictor_step_)) {
|
||||
for (auto& state_ch : channel_controllers_) {
|
||||
state_ch->HandleClipping(step);
|
||||
}
|
||||
frames_since_clipped_ = 0;
|
||||
if (!!clipping_predictor_) {
|
||||
clipping_predictor_->Reset();
|
||||
}
|
||||
}
|
||||
|
||||
AggregateChannelLevels();
|
||||
}
|
||||
|
||||
// Returns the input volume to apply next. Returns an empty value when no
// applied input volume is set, and the applied input volume itself when the
// capture output is not used. Otherwise every channel controller processes the
// speech level error (if `speech_level_dbfs` is set) together with
// `speech_probability`, and the aggregated recommendation is returned.
absl::optional<int> InputVolumeController::RecommendInputVolume(
    float speech_probability,
    absl::optional<float> speech_level_dbfs) {
  // Only process if applied input volume is set.
  if (!applied_input_volume_.has_value()) {
    RTC_LOG(LS_ERROR) << "[AGC2] Applied input volume not set.";
    return absl::nullopt;
  }

  AggregateChannelLevels();
  // Snapshot used to tell whether processing changes the recommendation.
  const int volume_after_clipping_handling = recommended_input_volume_;

  if (!capture_output_used_) {
    return applied_input_volume_;
  }

  // Compute the error for all frames (both speech and non-speech frames).
  absl::optional<int> rms_error_db;
  if (speech_level_dbfs.has_value()) {
    rms_error_db = GetSpeechLevelRmsErrorDb(
        *speech_level_dbfs, target_range_min_dbfs_, target_range_max_dbfs_);
  }

  for (const auto& channel_controller : channel_controllers_) {
    channel_controller->Process(rms_error_db, speech_probability);
  }

  AggregateChannelLevels();
  const bool adjusted_to_match_target =
      volume_after_clipping_handling != recommended_input_volume_;
  if (adjusted_to_match_target) {
    // The recommended input volume was adjusted in order to match the target
    // level.
    UpdateHistogramOnRecommendedInputVolumeChangeToMatchTarget(
        recommended_input_volume_);
  }

  applied_input_volume_ = absl::nullopt;
  return recommended_input_volume();
}
|
||||
|
||||
void InputVolumeController::HandleCaptureOutputUsedChange(
|
||||
bool capture_output_used) {
|
||||
for (auto& controller : channel_controllers_) {
|
||||
controller->HandleCaptureOutputUsedChange(capture_output_used);
|
||||
}
|
||||
|
||||
capture_output_used_ = capture_output_used;
|
||||
}
|
||||
|
||||
void InputVolumeController::SetAppliedInputVolume(int input_volume) {
|
||||
applied_input_volume_ = input_volume;
|
||||
|
||||
for (auto& controller : channel_controllers_) {
|
||||
controller->set_stream_analog_level(input_volume);
|
||||
}
|
||||
|
||||
AggregateChannelLevels();
|
||||
}
|
||||
|
||||
void InputVolumeController::AggregateChannelLevels() {
|
||||
int new_recommended_input_volume =
|
||||
channel_controllers_[0]->recommended_analog_level();
|
||||
channel_controlling_gain_ = 0;
|
||||
for (size_t ch = 1; ch < channel_controllers_.size(); ++ch) {
|
||||
int input_volume = channel_controllers_[ch]->recommended_analog_level();
|
||||
if (input_volume < new_recommended_input_volume) {
|
||||
new_recommended_input_volume = input_volume;
|
||||
channel_controlling_gain_ = static_cast<int>(ch);
|
||||
}
|
||||
}
|
||||
|
||||
// Enforce the minimum input volume when a recommendation is made.
|
||||
if (applied_input_volume_.has_value() && *applied_input_volume_ > 0) {
|
||||
new_recommended_input_volume =
|
||||
std::max(new_recommended_input_volume, min_input_volume_);
|
||||
}
|
||||
|
||||
recommended_input_volume_ = new_recommended_input_volume;
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,282 @@
|
|||
/*
|
||||
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_CONTROLLER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_CONTROLLER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/clipping_predictor.h"
|
||||
#include "modules/audio_processing/audio_buffer.h"
|
||||
#include "modules/audio_processing/include/audio_processing.h"
|
||||
#include "rtc_base/gtest_prod_util.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
class MonoInputVolumeController;
|
||||
|
||||
// The input volume controller recommends what volume to use, handles volume
|
||||
// changes and clipping detection and prediction. In particular, it handles
|
||||
// changes triggered by the user (e.g., volume set to zero by a HW mute button).
|
||||
// This class is not thread-safe.
|
||||
// TODO(bugs.webrtc.org/7494): Use applied/recommended input volume naming
|
||||
// convention.
|
||||
class InputVolumeController final {
|
||||
public:
|
||||
// Config for the constructor.
|
||||
struct Config {
|
||||
// Minimum input volume that can be recommended. Not enforced when the
|
||||
// applied input volume is zero outside startup.
|
||||
int min_input_volume = 20;
|
||||
// Lowest input volume level that will be applied in response to clipping.
|
||||
int clipped_level_min = 70;
|
||||
// Amount input volume level is lowered with every clipping event. Limited
|
||||
// to (0, 255].
|
||||
int clipped_level_step = 15;
|
||||
// Proportion of clipped samples required to declare a clipping event.
|
||||
// Limited to (0.0f, 1.0f).
|
||||
float clipped_ratio_threshold = 0.1f;
|
||||
// Time in frames to wait after a clipping event before checking again.
|
||||
// Limited to values higher than 0.
|
||||
int clipped_wait_frames = 300;
|
||||
// Enables clipping prediction functionality.
|
||||
bool enable_clipping_predictor = true;
|
||||
// Speech level target range (dBFS). If the speech level is in the range
|
||||
// [`target_range_min_dbfs`, `target_range_max_dbfs`], no input volume
|
||||
// adjustments are done based on the speech level. For speech levels below
|
||||
// and above the range, the targets `target_range_min_dbfs` and
|
||||
// `target_range_max_dbfs` are used, respectively.
|
||||
int target_range_max_dbfs = -30;
|
||||
int target_range_min_dbfs = -50;
|
||||
// Number of wait frames between the recommended input volume updates.
|
||||
int update_input_volume_wait_frames = 100;
|
||||
// Speech probability threshold: speech probabilities below the threshold
|
||||
// are considered silence. Limited to [0.0f, 1.0f].
|
||||
float speech_probability_threshold = 0.7f;
|
||||
// Minimum speech frame ratio for volume updates to be allowed. Limited to
|
||||
// [0.0f, 1.0f].
|
||||
float speech_ratio_threshold = 0.6f;
|
||||
};
|
||||
|
||||
// Ctor. `num_capture_channels` specifies the number of channels for the audio
|
||||
// passed to `AnalyzePreProcess()` and `Process()`. Clamps
|
||||
// `config.startup_min_level` in the [12, 255] range.
|
||||
InputVolumeController(int num_capture_channels, const Config& config);
|
||||
|
||||
~InputVolumeController();
|
||||
InputVolumeController(const InputVolumeController&) = delete;
|
||||
InputVolumeController& operator=(const InputVolumeController&) = delete;
|
||||
|
||||
// TODO(webrtc:7494): Integrate initialization into ctor and remove.
|
||||
void Initialize();
|
||||
|
||||
// Analyzes `audio_buffer` before `RecommendInputVolume()` is called so tha
|
||||
// the analysis can be performed before digital processing operations take
|
||||
// place (e.g., echo cancellation). The analysis consists of input clipping
|
||||
// detection and prediction (if enabled).
|
||||
void AnalyzeInputAudio(int applied_input_volume,
|
||||
const AudioBuffer& audio_buffer);
|
||||
|
||||
// Adjusts the recommended input volume upwards/downwards based on the result
|
||||
// of `AnalyzeInputAudio()` and on `speech_level_dbfs` (if specified). Must
|
||||
// be called after `AnalyzeInputAudio()`. The value of `speech_probability`
|
||||
// is expected to be in the range [0, 1] and `speech_level_dbfs` in the range
|
||||
// [-90, 30] and both should be estimated after echo cancellation and noise
|
||||
// suppression are applied. Returns a non-empty input volume recommendation if
|
||||
// available. If `capture_output_used_` is true, returns the applied input
|
||||
// volume.
|
||||
absl::optional<int> RecommendInputVolume(
|
||||
float speech_probability,
|
||||
absl::optional<float> speech_level_dbfs);
|
||||
|
||||
// Stores whether the capture output will be used or not. Call when the
|
||||
// capture stream output has been flagged to be used/not-used. If unused, the
|
||||
// controller disregards all incoming audio.
|
||||
void HandleCaptureOutputUsedChange(bool capture_output_used);
|
||||
|
||||
// Returns true if clipping prediction is enabled.
|
||||
// TODO(bugs.webrtc.org/7494): Deprecate this method.
|
||||
bool clipping_predictor_enabled() const { return !!clipping_predictor_; }
|
||||
|
||||
// Returns true if clipping prediction is used to adjust the input volume.
|
||||
// TODO(bugs.webrtc.org/7494): Deprecate this method.
|
||||
bool use_clipping_predictor_step() const {
|
||||
return use_clipping_predictor_step_;
|
||||
}
|
||||
|
||||
// Only use for testing: Use `RecommendInputVolume()` elsewhere.
|
||||
// Returns the value of a member variable, needed for testing
|
||||
// `AnalyzeInputAudio()`.
|
||||
int recommended_input_volume() const { return recommended_input_volume_; }
|
||||
|
||||
// Only use for testing.
|
||||
bool capture_output_used() const { return capture_output_used_; }
|
||||
|
||||
private:
|
||||
friend class InputVolumeControllerTestHelper;
|
||||
|
||||
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest, MinInputVolumeDefault);
|
||||
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest, MinInputVolumeDisabled);
|
||||
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest,
|
||||
MinInputVolumeOutOfRangeAbove);
|
||||
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest,
|
||||
MinInputVolumeOutOfRangeBelow);
|
||||
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest, MinInputVolumeEnabled50);
|
||||
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerParametrizedTest,
|
||||
ClippingParametersVerified);
|
||||
|
||||
// Sets the applied input volume and resets the recommended input volume.
|
||||
void SetAppliedInputVolume(int level);
|
||||
|
||||
void AggregateChannelLevels();
|
||||
|
||||
const int num_capture_channels_;
|
||||
|
||||
// Minimum input volume that can be recommended.
|
||||
const int min_input_volume_;
|
||||
|
||||
// TODO(bugs.webrtc.org/7494): Once
|
||||
// `AudioProcessingImpl::recommended_stream_analog_level()` becomes a trivial
|
||||
// getter, leave uninitialized.
|
||||
// Recommended input volume. After `SetAppliedInputVolume()` is called it
|
||||
// holds holds the observed input volume. Possibly updated by
|
||||
// `AnalyzePreProcess()` and `Process()`; after these calls, holds the
|
||||
// recommended input volume.
|
||||
int recommended_input_volume_ = 0;
|
||||
// Applied input volume. After `SetAppliedInputVolume()` is called it holds
|
||||
// the current applied volume.
|
||||
absl::optional<int> applied_input_volume_;
|
||||
|
||||
bool capture_output_used_;
|
||||
|
||||
// Clipping detection and prediction.
|
||||
const int clipped_level_step_;
|
||||
const float clipped_ratio_threshold_;
|
||||
const int clipped_wait_frames_;
|
||||
const std::unique_ptr<ClippingPredictor> clipping_predictor_;
|
||||
const bool use_clipping_predictor_step_;
|
||||
int frames_since_clipped_;
|
||||
int clipping_rate_log_counter_;
|
||||
float clipping_rate_log_;
|
||||
|
||||
// Target range minimum and maximum. If the seech level is in the range
|
||||
// [`target_range_min_dbfs`, `target_range_max_dbfs`], no volume adjustments
|
||||
// take place. Instead, the digital gain controller is assumed to adapt to
|
||||
// compensate for the speech level RMS error.
|
||||
const int target_range_max_dbfs_;
|
||||
const int target_range_min_dbfs_;
|
||||
|
||||
// Channel controllers updating the gain upwards/downwards.
|
||||
std::vector<std::unique_ptr<MonoInputVolumeController>> channel_controllers_;
|
||||
int channel_controlling_gain_ = 0;
|
||||
};
|
||||
|
||||
// TODO(bugs.webrtc.org/7494): Use applied/recommended input volume naming
|
||||
// convention.
|
||||
class MonoInputVolumeController {
|
||||
public:
|
||||
MonoInputVolumeController(int min_input_volume_after_clipping,
|
||||
int min_input_volume,
|
||||
int update_input_volume_wait_frames,
|
||||
float speech_probability_threshold,
|
||||
float speech_ratio_threshold);
|
||||
~MonoInputVolumeController();
|
||||
MonoInputVolumeController(const MonoInputVolumeController&) = delete;
|
||||
MonoInputVolumeController& operator=(const MonoInputVolumeController&) =
|
||||
delete;
|
||||
|
||||
void Initialize();
|
||||
void HandleCaptureOutputUsedChange(bool capture_output_used);
|
||||
|
||||
// Sets the current input volume.
|
||||
void set_stream_analog_level(int input_volume) {
|
||||
recommended_input_volume_ = input_volume;
|
||||
}
|
||||
|
||||
// Lowers the recommended input volume in response to clipping based on the
|
||||
// suggested reduction `clipped_level_step`. Must be called after
|
||||
// `set_stream_analog_level()`.
|
||||
void HandleClipping(int clipped_level_step);
|
||||
|
||||
// TODO(bugs.webrtc.org/7494): Rename, audio not passed to the method anymore.
|
||||
// Adjusts the recommended input volume upwards/downwards depending on the
|
||||
// result of `HandleClipping()` and on `rms_error_dbfs`. Updates are only
|
||||
// allowed for active speech segments and when `rms_error_dbfs` is not empty.
|
||||
// Must be called after `HandleClipping()`.
|
||||
void Process(absl::optional<int> rms_error_dbfs, float speech_probability);
|
||||
|
||||
// Returns the recommended input volume. Must be called after `Process()`.
|
||||
int recommended_analog_level() const { return recommended_input_volume_; }
|
||||
|
||||
void ActivateLogging() { log_to_histograms_ = true; }
|
||||
|
||||
int min_input_volume_after_clipping() const {
|
||||
return min_input_volume_after_clipping_;
|
||||
}
|
||||
|
||||
// Only used for testing.
|
||||
int min_input_volume() const { return min_input_volume_; }
|
||||
|
||||
private:
|
||||
// Sets a new input volume, after first checking that it hasn't been updated
|
||||
// by the user, in which case no action is taken.
|
||||
void SetInputVolume(int new_volume);
|
||||
|
||||
// Sets the maximum input volume that the input volume controller is allowed
|
||||
// to apply. The volume must be at least `kClippedLevelMin`.
|
||||
void SetMaxLevel(int level);
|
||||
|
||||
int CheckVolumeAndReset();
|
||||
|
||||
// Updates the recommended input volume. If the volume slider needs to be
|
||||
// moved, we check first if the user has adjusted it, in which case we take no
|
||||
// action and cache the updated level.
|
||||
void UpdateInputVolume(int rms_error_dbfs);
|
||||
|
||||
const int min_input_volume_;
|
||||
const int min_input_volume_after_clipping_;
|
||||
int max_input_volume_;
|
||||
|
||||
int last_recommended_input_volume_ = 0;
|
||||
|
||||
bool capture_output_used_ = true;
|
||||
bool check_volume_on_next_process_ = true;
|
||||
bool startup_ = true;
|
||||
|
||||
// TODO(bugs.webrtc.org/7494): Create a separate member for the applied
|
||||
// input volume.
|
||||
// Recommended input volume. After `set_stream_analog_level()` is
|
||||
// called, it holds the observed applied input volume. Possibly updated by
|
||||
// `HandleClipping()` and `Process()`; after these calls, holds the
|
||||
// recommended input volume.
|
||||
int recommended_input_volume_ = 0;
|
||||
|
||||
bool log_to_histograms_ = false;
|
||||
|
||||
// Counters for frames and speech frames since the last update in the
|
||||
// recommended input volume.
|
||||
const int update_input_volume_wait_frames_;
|
||||
int frames_since_update_input_volume_ = 0;
|
||||
int speech_frames_since_update_input_volume_ = 0;
|
||||
bool is_first_frame_ = true;
|
||||
|
||||
// Speech probability threshold for a frame to be considered speech (instead
|
||||
// of silence). Limited to [0.0f, 1.0f].
|
||||
const float speech_probability_threshold_;
|
||||
// Minimum ratio of speech frames. Limited to [0.0f, 1.0f].
|
||||
const float speech_ratio_threshold_;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_CONTROLLER_H_
|
||||
|
|
@ -0,0 +1,171 @@
|
|||
/*
|
||||
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/input_volume_stats_reporter.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
#include "rtc_base/strings/string_builder.h"
|
||||
#include "system_wrappers/include/metrics.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
using InputVolumeType = InputVolumeStatsReporter::InputVolumeType;
|
||||
|
||||
constexpr int kFramesIn60Seconds = 6000;
|
||||
constexpr int kMinInputVolume = 0;
|
||||
constexpr int kMaxInputVolume = 255;
|
||||
constexpr int kMaxUpdate = kMaxInputVolume - kMinInputVolume;
|
||||
|
||||
// Returns the mean update size, i.e. `sum_updates / num_updates`, rounded to
// the nearest integer. Returns 0 when `num_updates` is 0.
int ComputeAverageUpdate(int sum_updates, int num_updates) {
  RTC_DCHECK_GE(sum_updates, 0);
  RTC_DCHECK_LE(sum_updates, kMaxUpdate * kFramesIn60Seconds);
  RTC_DCHECK_GE(num_updates, 0);
  RTC_DCHECK_LE(num_updates, kFramesIn60Seconds);
  if (num_updates != 0) {
    const float total = static_cast<float>(sum_updates);
    const float count = static_cast<float>(num_updates);
    return std::round(total / count);
  }
  // No updates: avoid dividing by zero.
  return 0;
}
|
||||
|
||||
// Maps `input_volume_type` to the common prefix of the histogram names used
// for that input volume type. The returned prefix ends with a dot.
constexpr absl::string_view MetricNamePrefix(
    InputVolumeType input_volume_type) {
  switch (input_volume_type) {
    case InputVolumeType::kRecommended:
      return "WebRTC.Audio.Apm.RecommendedInputVolume.";
    case InputVolumeType::kApplied:
      return "WebRTC.Audio.Apm.AppliedInputVolume.";
  }
}
|
||||
|
||||
// Creates the linear-counts histogram named "<prefix>OnChange", where
// <prefix> is `MetricNamePrefix(input_volume_type)`. Returns whatever the
// histogram factory returns.
metrics::Histogram* CreateVolumeHistogram(InputVolumeType input_volume_type) {
  char name_storage[64];
  rtc::SimpleStringBuilder name(name_storage);
  name << MetricNamePrefix(input_volume_type);
  name << "OnChange";
  return metrics::HistogramFactoryGetCountsLinear(
      /*name=*/name.str(), /*min=*/1, /*max=*/kMaxInputVolume,
      /*bucket_count=*/50);
}
|
||||
|
||||
// Creates the linear-counts histogram named "<prefix><name>", where <prefix>
// is `MetricNamePrefix(input_volume_type)`. The histogram range is
// [1, kFramesIn60Seconds] split into 50 buckets.
metrics::Histogram* CreateRateHistogram(InputVolumeType input_volume_type,
                                        absl::string_view name) {
  char name_storage[64];
  rtc::SimpleStringBuilder full_name(name_storage);
  full_name << MetricNamePrefix(input_volume_type);
  full_name << name;
  return metrics::HistogramFactoryGetCountsLinear(
      /*name=*/full_name.str(), /*min=*/1, /*max=*/kFramesIn60Seconds,
      /*bucket_count=*/50);
}
|
||||
|
||||
// Creates the linear-counts histogram named "<prefix><name>", where <prefix>
// is `MetricNamePrefix(input_volume_type)`. The histogram range is
// [1, kMaxUpdate] split into 50 buckets.
metrics::Histogram* CreateAverageHistogram(InputVolumeType input_volume_type,
                                           absl::string_view name) {
  char name_storage[64];
  rtc::SimpleStringBuilder full_name(name_storage);
  full_name << MetricNamePrefix(input_volume_type);
  full_name << name;
  return metrics::HistogramFactoryGetCountsLinear(
      /*name=*/full_name.str(), /*min=*/1, /*max=*/kMaxUpdate,
      /*bucket_count=*/50);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates all the histograms for `type`. If any of them could not be created,
// stats logging is disabled for this instance and a warning is emitted.
InputVolumeStatsReporter::InputVolumeStatsReporter(InputVolumeType type)
    : histograms_(
          {.on_volume_change = CreateVolumeHistogram(type),
           .decrease_rate = CreateRateHistogram(type, "DecreaseRate"),
           .decrease_average = CreateAverageHistogram(type, "DecreaseAverage"),
           .increase_rate = CreateRateHistogram(type, "IncreaseRate"),
           .increase_average = CreateAverageHistogram(type, "IncreaseAverage"),
           .update_rate = CreateRateHistogram(type, "UpdateRate"),
           .update_average = CreateAverageHistogram(type, "UpdateAverage")}),
      cannot_log_stats_(!histograms_.AllPointersSet()) {
  if (!cannot_log_stats_) {
    return;
  }
  RTC_LOG(LS_WARNING) << "Will not log any `" << MetricNamePrefix(type)
                      << "*` histogram stats.";
}
|
||||
|
||||
// Defaulted destructor: no custom teardown is performed here, and the stored
// histogram pointers are not released by this class.
InputVolumeStatsReporter::~InputVolumeStatsReporter() = default;
|
||||
|
||||
void InputVolumeStatsReporter::UpdateStatistics(int input_volume) {
|
||||
if (cannot_log_stats_) {
|
||||
// Since the stats cannot be logged, do not bother updating them.
|
||||
return;
|
||||
}
|
||||
|
||||
RTC_DCHECK_GE(input_volume, kMinInputVolume);
|
||||
RTC_DCHECK_LE(input_volume, kMaxInputVolume);
|
||||
if (previous_input_volume_.has_value() &&
|
||||
input_volume != previous_input_volume_.value()) {
|
||||
// Update stats when the input volume changes.
|
||||
metrics::HistogramAdd(histograms_.on_volume_change, input_volume);
|
||||
// Update stats that are periodically logged.
|
||||
const int volume_change = input_volume - previous_input_volume_.value();
|
||||
if (volume_change < 0) {
|
||||
++volume_update_stats_.num_decreases;
|
||||
volume_update_stats_.sum_decreases -= volume_change;
|
||||
} else {
|
||||
++volume_update_stats_.num_increases;
|
||||
volume_update_stats_.sum_increases += volume_change;
|
||||
}
|
||||
}
|
||||
// Periodically log input volume change metrics.
|
||||
if (++log_volume_update_stats_counter_ >= kFramesIn60Seconds) {
|
||||
LogVolumeUpdateStats();
|
||||
volume_update_stats_ = {};
|
||||
log_volume_update_stats_counter_ = 0;
|
||||
}
|
||||
previous_input_volume_ = input_volume;
|
||||
}
|
||||
|
||||
void InputVolumeStatsReporter::LogVolumeUpdateStats() const {
|
||||
// Decrease rate and average.
|
||||
metrics::HistogramAdd(histograms_.decrease_rate,
|
||||
volume_update_stats_.num_decreases);
|
||||
if (volume_update_stats_.num_decreases > 0) {
|
||||
int average_decrease = ComputeAverageUpdate(
|
||||
volume_update_stats_.sum_decreases, volume_update_stats_.num_decreases);
|
||||
metrics::HistogramAdd(histograms_.decrease_average, average_decrease);
|
||||
}
|
||||
// Increase rate and average.
|
||||
metrics::HistogramAdd(histograms_.increase_rate,
|
||||
volume_update_stats_.num_increases);
|
||||
if (volume_update_stats_.num_increases > 0) {
|
||||
int average_increase = ComputeAverageUpdate(
|
||||
volume_update_stats_.sum_increases, volume_update_stats_.num_increases);
|
||||
metrics::HistogramAdd(histograms_.increase_average, average_increase);
|
||||
}
|
||||
// Update rate and average.
|
||||
int num_updates =
|
||||
volume_update_stats_.num_decreases + volume_update_stats_.num_increases;
|
||||
metrics::HistogramAdd(histograms_.update_rate, num_updates);
|
||||
if (num_updates > 0) {
|
||||
int average_update = ComputeAverageUpdate(
|
||||
volume_update_stats_.sum_decreases + volume_update_stats_.sum_increases,
|
||||
num_updates);
|
||||
metrics::HistogramAdd(histograms_.update_average, average_update);
|
||||
}
|
||||
}
|
||||
|
||||
// Adds `volume` as a sample to the
// "WebRTC.Audio.Apm.RecommendedInputVolume.OnChangeToMatchTarget" linear
// histogram (range [1, kMaxInputVolume], 50 buckets).
void UpdateHistogramOnRecommendedInputVolumeChangeToMatchTarget(int volume) {
  RTC_HISTOGRAM_COUNTS_LINEAR(
      "WebRTC.Audio.Apm.RecommendedInputVolume.OnChangeToMatchTarget", volume,
      /*min=*/1, /*max=*/kMaxInputVolume, /*bucket_count=*/50);
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,96 @@
|
|||
/*
|
||||
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_STATS_REPORTER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_STATS_REPORTER_H_
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "rtc_base/gtest_prod_util.h"
|
||||
#include "system_wrappers/include/metrics.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// Input volume statistics calculator. Computes aggregate stats based on the
|
||||
// framewise input volume observed by `UpdateStatistics()`. Periodically logs
|
||||
// the statistics into a histogram.
|
||||
class InputVolumeStatsReporter {
|
||||
public:
|
||||
enum class InputVolumeType {
|
||||
kApplied = 0,
|
||||
kRecommended = 1,
|
||||
};
|
||||
|
||||
explicit InputVolumeStatsReporter(InputVolumeType input_volume_type);
|
||||
InputVolumeStatsReporter(const InputVolumeStatsReporter&) = delete;
|
||||
InputVolumeStatsReporter operator=(const InputVolumeStatsReporter&) = delete;
|
||||
~InputVolumeStatsReporter();
|
||||
|
||||
// Updates the stats based on `input_volume`. Periodically logs the stats into
|
||||
// a histogram.
|
||||
void UpdateStatistics(int input_volume);
|
||||
|
||||
private:
|
||||
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
|
||||
CheckVolumeUpdateStatsForEmptyStats);
|
||||
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
|
||||
CheckVolumeUpdateStatsAfterNoVolumeChange);
|
||||
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
|
||||
CheckVolumeUpdateStatsAfterVolumeIncrease);
|
||||
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
|
||||
CheckVolumeUpdateStatsAfterVolumeDecrease);
|
||||
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
|
||||
CheckVolumeUpdateStatsAfterReset);
|
||||
|
||||
// Stores input volume update stats to enable calculation of update rate and
|
||||
// average update separately for volume increases and decreases.
|
||||
struct VolumeUpdateStats {
|
||||
int num_decreases = 0;
|
||||
int num_increases = 0;
|
||||
int sum_decreases = 0;
|
||||
int sum_increases = 0;
|
||||
} volume_update_stats_;
|
||||
|
||||
// Returns a copy of the stored statistics. Use only for testing.
|
||||
VolumeUpdateStats volume_update_stats() const { return volume_update_stats_; }
|
||||
|
||||
// Computes aggregate stat and logs them into a histogram.
|
||||
void LogVolumeUpdateStats() const;
|
||||
|
||||
// Histograms.
|
||||
struct Histograms {
|
||||
metrics::Histogram* const on_volume_change;
|
||||
metrics::Histogram* const decrease_rate;
|
||||
metrics::Histogram* const decrease_average;
|
||||
metrics::Histogram* const increase_rate;
|
||||
metrics::Histogram* const increase_average;
|
||||
metrics::Histogram* const update_rate;
|
||||
metrics::Histogram* const update_average;
|
||||
bool AllPointersSet() const {
|
||||
return !!on_volume_change && !!decrease_rate && !!decrease_average &&
|
||||
!!increase_rate && !!increase_average && !!update_rate &&
|
||||
!!update_average;
|
||||
}
|
||||
} histograms_;
|
||||
|
||||
// True if the stats cannot be logged.
|
||||
const bool cannot_log_stats_;
|
||||
|
||||
int log_volume_update_stats_counter_ = 0;
|
||||
absl::optional<int> previous_input_volume_ = absl::nullopt;
|
||||
};
|
||||
|
||||
// Updates the histogram that keeps track of recommended input volume changes
|
||||
// required in order to match the target level in the input volume adaptation
|
||||
// process.
|
||||
void UpdateHistogramOnRecommendedInputVolumeChangeToMatchTarget(int volume);
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_STATS_REPORTER_H_
|
||||
|
|
@ -0,0 +1,204 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/interpolated_gain_curve.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/strings/string_builder.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
InterpolatedGainCurve::approximation_params_x_;
|
||||
|
||||
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
InterpolatedGainCurve::approximation_params_m_;
|
||||
|
||||
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
InterpolatedGainCurve::approximation_params_q_;
|
||||
|
||||
InterpolatedGainCurve::InterpolatedGainCurve(
|
||||
ApmDataDumper* apm_data_dumper,
|
||||
absl::string_view histogram_name_prefix)
|
||||
: region_logger_(
|
||||
(rtc::StringBuilder("WebRTC.Audio.")
|
||||
<< histogram_name_prefix << ".FixedDigitalGainCurveRegion.Identity")
|
||||
.str(),
|
||||
(rtc::StringBuilder("WebRTC.Audio.")
|
||||
<< histogram_name_prefix << ".FixedDigitalGainCurveRegion.Knee")
|
||||
.str(),
|
||||
(rtc::StringBuilder("WebRTC.Audio.")
|
||||
<< histogram_name_prefix << ".FixedDigitalGainCurveRegion.Limiter")
|
||||
.str(),
|
||||
(rtc::StringBuilder("WebRTC.Audio.")
|
||||
<< histogram_name_prefix
|
||||
<< ".FixedDigitalGainCurveRegion.Saturation")
|
||||
.str()),
|
||||
apm_data_dumper_(apm_data_dumper) {}
|
||||
|
||||
// If any look-up has been recorded, dumps the per-region look-up counters and
// logs the duration of the region that is current at destruction time.
InterpolatedGainCurve::~InterpolatedGainCurve() {
  if (!stats_.available) {
    // No stats have been collected; nothing to dump.
    return;
  }
  RTC_DCHECK(apm_data_dumper_);
  apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_identity",
                            stats_.look_ups_identity_region);
  apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_knee",
                            stats_.look_ups_knee_region);
  apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_limiter",
                            stats_.look_ups_limiter_region);
  apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_saturation",
                            stats_.look_ups_saturation_region);
  region_logger_.LogRegionStats(stats_);
}
|
||||
|
||||
namespace {

// Range and resolution shared by all the region duration histograms.
constexpr int kRegionHistogramMin = 1;
constexpr int kRegionHistogramMax = 10000;
constexpr int kRegionHistogramBucketCount = 50;

}  // namespace

// Creates one counts histogram per gain curve region using the given names.
InterpolatedGainCurve::RegionLogger::RegionLogger(
    absl::string_view identity_histogram_name,
    absl::string_view knee_histogram_name,
    absl::string_view limiter_histogram_name,
    absl::string_view saturation_histogram_name)
    : identity_histogram(
          metrics::HistogramFactoryGetCounts(identity_histogram_name,
                                             kRegionHistogramMin,
                                             kRegionHistogramMax,
                                             kRegionHistogramBucketCount)),
      knee_histogram(
          metrics::HistogramFactoryGetCounts(knee_histogram_name,
                                             kRegionHistogramMin,
                                             kRegionHistogramMax,
                                             kRegionHistogramBucketCount)),
      limiter_histogram(
          metrics::HistogramFactoryGetCounts(limiter_histogram_name,
                                             kRegionHistogramMin,
                                             kRegionHistogramMax,
                                             kRegionHistogramBucketCount)),
      saturation_histogram(
          metrics::HistogramFactoryGetCounts(saturation_histogram_name,
                                             kRegionHistogramMin,
                                             kRegionHistogramMax,
                                             kRegionHistogramBucketCount)) {}
|
||||
|
||||
// Defaulted destructor: the histogram pointers held by the logger are not
// released here.
InterpolatedGainCurve::RegionLogger::~RegionLogger() = default;
|
||||
|
||||
void InterpolatedGainCurve::RegionLogger::LogRegionStats(
|
||||
const InterpolatedGainCurve::Stats& stats) const {
|
||||
using Region = InterpolatedGainCurve::GainCurveRegion;
|
||||
const int duration_s =
|
||||
stats.region_duration_frames / (1000 / kFrameDurationMs);
|
||||
|
||||
switch (stats.region) {
|
||||
case Region::kIdentity: {
|
||||
if (identity_histogram) {
|
||||
metrics::HistogramAdd(identity_histogram, duration_s);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Region::kKnee: {
|
||||
if (knee_histogram) {
|
||||
metrics::HistogramAdd(knee_histogram, duration_s);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Region::kLimiter: {
|
||||
if (limiter_histogram) {
|
||||
metrics::HistogramAdd(limiter_histogram, duration_s);
|
||||
}
|
||||
break;
|
||||
}
|
||||
case Region::kSaturation: {
|
||||
if (saturation_histogram) {
|
||||
metrics::HistogramAdd(saturation_histogram, duration_s);
|
||||
}
|
||||
break;
|
||||
}
|
||||
default: {
|
||||
RTC_DCHECK_NOTREACHED();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void InterpolatedGainCurve::UpdateStats(float input_level) const {
|
||||
stats_.available = true;
|
||||
|
||||
GainCurveRegion region;
|
||||
|
||||
if (input_level < approximation_params_x_[0]) {
|
||||
stats_.look_ups_identity_region++;
|
||||
region = GainCurveRegion::kIdentity;
|
||||
} else if (input_level <
|
||||
approximation_params_x_[kInterpolatedGainCurveKneePoints - 1]) {
|
||||
stats_.look_ups_knee_region++;
|
||||
region = GainCurveRegion::kKnee;
|
||||
} else if (input_level < kMaxInputLevelLinear) {
|
||||
stats_.look_ups_limiter_region++;
|
||||
region = GainCurveRegion::kLimiter;
|
||||
} else {
|
||||
stats_.look_ups_saturation_region++;
|
||||
region = GainCurveRegion::kSaturation;
|
||||
}
|
||||
|
||||
if (region == stats_.region) {
|
||||
++stats_.region_duration_frames;
|
||||
} else {
|
||||
region_logger_.LogRegionStats(stats_);
|
||||
|
||||
stats_.region_duration_frames = 0;
|
||||
stats_.region = region;
|
||||
}
|
||||
}
|
||||
|
||||
// Looks up a gain to apply given a non-negative input level.
|
||||
// The cost of this operation depends on the region in which `input_level`
|
||||
// falls.
|
||||
// For the identity and the saturation regions the cost is O(1).
|
||||
// For the other regions, namely knee and limiter, the cost is
|
||||
// O(2 + log2(`LightkInterpolatedGainCurveTotalPoints`), plus O(1) for the
|
||||
// linear interpolation (one product and one sum).
|
||||
float InterpolatedGainCurve::LookUpGainToApply(float input_level) const {
|
||||
UpdateStats(input_level);
|
||||
|
||||
if (input_level <= approximation_params_x_[0]) {
|
||||
// Identity region.
|
||||
return 1.0f;
|
||||
}
|
||||
|
||||
if (input_level >= kMaxInputLevelLinear) {
|
||||
// Saturating lower bound. The saturing samples exactly hit the clipping
|
||||
// level. This method achieves has the lowest harmonic distorsion, but it
|
||||
// may reduce the amplitude of the non-saturating samples too much.
|
||||
return 32768.f / input_level;
|
||||
}
|
||||
|
||||
// Knee and limiter regions; find the linear piece index. Spelling
|
||||
// out the complete type was the only way to silence both the clang
|
||||
// plugin and the windows compilers.
|
||||
std::array<float, kInterpolatedGainCurveTotalPoints>::const_iterator it =
|
||||
std::lower_bound(approximation_params_x_.begin(),
|
||||
approximation_params_x_.end(), input_level);
|
||||
const size_t index = std::distance(approximation_params_x_.begin(), it) - 1;
|
||||
RTC_DCHECK_LE(0, index);
|
||||
RTC_DCHECK_LT(index, approximation_params_m_.size());
|
||||
RTC_DCHECK_LE(approximation_params_x_[index], input_level);
|
||||
if (index < approximation_params_m_.size() - 1) {
|
||||
RTC_DCHECK_LE(input_level, approximation_params_x_[index + 1]);
|
||||
}
|
||||
|
||||
// Piece-wise linear interploation.
|
||||
const float gain = approximation_params_m_[index] * input_level +
|
||||
approximation_params_q_[index];
|
||||
RTC_DCHECK_LE(0.f, gain);
|
||||
return gain;
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,152 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "rtc_base/gtest_prod_util.h"
|
||||
#include "system_wrappers/include/metrics.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
class ApmDataDumper;
|
||||
|
||||
constexpr float kInputLevelScalingFactor = 32768.0f;
|
||||
|
||||
// Defined as DbfsToLinear(kLimiterMaxInputLevelDbFs)
|
||||
constexpr float kMaxInputLevelLinear = static_cast<float>(36766.300710566735);
|
||||
|
||||
// Interpolated gain curve using under-approximation to avoid saturation.
|
||||
//
|
||||
// The goal of this class is allowing fast look ups to get an accurate
|
||||
// estimates of the gain to apply given an estimated input level.
|
||||
class InterpolatedGainCurve {
|
||||
public:
|
||||
enum class GainCurveRegion {
|
||||
kIdentity = 0,
|
||||
kKnee = 1,
|
||||
kLimiter = 2,
|
||||
kSaturation = 3
|
||||
};
|
||||
|
||||
struct Stats {
|
||||
// Region in which the output level equals the input one.
|
||||
size_t look_ups_identity_region = 0;
|
||||
// Smoothing between the identity and the limiter regions.
|
||||
size_t look_ups_knee_region = 0;
|
||||
// Limiter region in which the output and input levels are linearly related.
|
||||
size_t look_ups_limiter_region = 0;
|
||||
// Region in which saturation may occur since the input level is beyond the
|
||||
// maximum expected by the limiter.
|
||||
size_t look_ups_saturation_region = 0;
|
||||
// True if stats have been populated.
|
||||
bool available = false;
|
||||
|
||||
// The current region, and for how many frames the level has been
|
||||
// in that region.
|
||||
GainCurveRegion region = GainCurveRegion::kIdentity;
|
||||
int64_t region_duration_frames = 0;
|
||||
};
|
||||
|
||||
InterpolatedGainCurve(ApmDataDumper* apm_data_dumper,
|
||||
absl::string_view histogram_name_prefix);
|
||||
~InterpolatedGainCurve();
|
||||
|
||||
InterpolatedGainCurve(const InterpolatedGainCurve&) = delete;
|
||||
InterpolatedGainCurve& operator=(const InterpolatedGainCurve&) = delete;
|
||||
|
||||
Stats get_stats() const { return stats_; }
|
||||
|
||||
// Given a non-negative input level (linear scale), a scalar factor to apply
|
||||
// to a sub-frame is returned.
|
||||
// Levels above kLimiterMaxInputLevelDbFs will be reduced to 0 dBFS
|
||||
// after applying this gain
|
||||
float LookUpGainToApply(float input_level) const;
|
||||
|
||||
private:
|
||||
// For comparing 'approximation_params_*_' with ones computed by
|
||||
// ComputeInterpolatedGainCurve.
|
||||
FRIEND_TEST_ALL_PREFIXES(GainController2InterpolatedGainCurve,
|
||||
CheckApproximationParams);
|
||||
|
||||
struct RegionLogger {
|
||||
metrics::Histogram* identity_histogram;
|
||||
metrics::Histogram* knee_histogram;
|
||||
metrics::Histogram* limiter_histogram;
|
||||
metrics::Histogram* saturation_histogram;
|
||||
|
||||
RegionLogger(absl::string_view identity_histogram_name,
|
||||
absl::string_view knee_histogram_name,
|
||||
absl::string_view limiter_histogram_name,
|
||||
absl::string_view saturation_histogram_name);
|
||||
|
||||
~RegionLogger();
|
||||
|
||||
void LogRegionStats(const InterpolatedGainCurve::Stats& stats) const;
|
||||
} region_logger_;
|
||||
|
||||
void UpdateStats(float input_level) const;
|
||||
|
||||
ApmDataDumper* const apm_data_dumper_;
|
||||
|
||||
static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
approximation_params_x_ = {
|
||||
{30057.296875, 30148.986328125, 30240.67578125, 30424.052734375,
|
||||
30607.4296875, 30790.806640625, 30974.18359375, 31157.560546875,
|
||||
31340.939453125, 31524.31640625, 31707.693359375, 31891.0703125,
|
||||
32074.447265625, 32257.82421875, 32441.201171875, 32624.580078125,
|
||||
32807.95703125, 32991.33203125, 33174.7109375, 33358.08984375,
|
||||
33541.46484375, 33724.84375, 33819.53515625, 34009.5390625,
|
||||
34200.05859375, 34389.81640625, 34674.48828125, 35054.375,
|
||||
35434.86328125, 35814.81640625, 36195.16796875, 36575.03125}};
|
||||
static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
approximation_params_m_ = {
|
||||
{-3.515235675877192989e-07, -1.050251626111275982e-06,
|
||||
-2.085213736791047268e-06, -3.443004743530764244e-06,
|
||||
-4.773849468620028347e-06, -6.077375928725814447e-06,
|
||||
-7.353257842623861507e-06, -8.601219633419532329e-06,
|
||||
-9.821013009059242904e-06, -1.101243378798244521e-05,
|
||||
-1.217532644659513608e-05, -1.330956911260727793e-05,
|
||||
-1.441507538402220234e-05, -1.549179251014720649e-05,
|
||||
-1.653970684856176376e-05, -1.755882840370759368e-05,
|
||||
-1.854918446042574942e-05, -1.951086778717581183e-05,
|
||||
-2.044398024736437947e-05, -2.1348627342376858e-05,
|
||||
-2.222496914328075945e-05, -2.265374678245279938e-05,
|
||||
-2.242570917587727308e-05, -2.220122041762806475e-05,
|
||||
-2.19802095671184361e-05, -2.176260204578284174e-05,
|
||||
-2.133731686626560986e-05, -2.092481918225530535e-05,
|
||||
-2.052459603874012828e-05, -2.013615448959171772e-05,
|
||||
-1.975903069251216948e-05, -1.939277899509761482e-05}};
|
||||
|
||||
static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
|
||||
approximation_params_q_ = {
|
||||
{1.010565876960754395, 1.031631827354431152, 1.062929749488830566,
|
||||
1.104239225387573242, 1.144973039627075195, 1.185109615325927734,
|
||||
1.224629044532775879, 1.263512492179870605, 1.301741957664489746,
|
||||
1.339300632476806641, 1.376173257827758789, 1.412345528602600098,
|
||||
1.447803974151611328, 1.482536554336547852, 1.516532182693481445,
|
||||
1.549780607223510742, 1.582272171974182129, 1.613999366760253906,
|
||||
1.644955039024353027, 1.675132393836975098, 1.704526185989379883,
|
||||
1.718986630439758301, 1.711274504661560059, 1.703639745712280273,
|
||||
1.696081161499023438, 1.688597679138183594, 1.673851132392883301,
|
||||
1.659391283988952637, 1.645209431648254395, 1.631297469139099121,
|
||||
1.617647409439086914, 1.604251742362976074}};
|
||||
|
||||
// Stats.
|
||||
mutable Stats stats_;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_
|
||||
|
|
@ -0,0 +1,155 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/limiter.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
// This constant affects the way scaling factors are interpolated for the first
|
||||
// sub-frame of a frame. Only in the case in which the first sub-frame has an
|
||||
// estimated level which is greater than that of the previous analyzed
|
||||
// sub-frame, linear interpolation is replaced with a power function which
|
||||
// reduces the chances of over-shooting (and hence saturation), however reducing
|
||||
// the fixed gain effectiveness.
|
||||
constexpr float kAttackFirstSubframeInterpolationPower = 8.0f;
|
||||
|
||||
void InterpolateFirstSubframe(float last_factor,
|
||||
float current_factor,
|
||||
rtc::ArrayView<float> subframe) {
|
||||
const int n = rtc::dchecked_cast<int>(subframe.size());
|
||||
constexpr float p = kAttackFirstSubframeInterpolationPower;
|
||||
for (int i = 0; i < n; ++i) {
|
||||
subframe[i] = std::pow(1.f - i / n, p) * (last_factor - current_factor) +
|
||||
current_factor;
|
||||
}
|
||||
}
|
||||
|
||||
void ComputePerSampleSubframeFactors(
|
||||
const std::array<float, kSubFramesInFrame + 1>& scaling_factors,
|
||||
int samples_per_channel,
|
||||
rtc::ArrayView<float> per_sample_scaling_factors) {
|
||||
const int num_subframes = scaling_factors.size() - 1;
|
||||
const int subframe_size =
|
||||
rtc::CheckedDivExact(samples_per_channel, num_subframes);
|
||||
|
||||
// Handle first sub-frame differently in case of attack.
|
||||
const bool is_attack = scaling_factors[0] > scaling_factors[1];
|
||||
if (is_attack) {
|
||||
InterpolateFirstSubframe(
|
||||
scaling_factors[0], scaling_factors[1],
|
||||
rtc::ArrayView<float>(
|
||||
per_sample_scaling_factors.subview(0, subframe_size)));
|
||||
}
|
||||
|
||||
for (int i = is_attack ? 1 : 0; i < num_subframes; ++i) {
|
||||
const int subframe_start = i * subframe_size;
|
||||
const float scaling_start = scaling_factors[i];
|
||||
const float scaling_end = scaling_factors[i + 1];
|
||||
const float scaling_diff = (scaling_end - scaling_start) / subframe_size;
|
||||
for (int j = 0; j < subframe_size; ++j) {
|
||||
per_sample_scaling_factors[subframe_start + j] =
|
||||
scaling_start + scaling_diff * j;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Multiplies each sample of every channel of `signal` by the scaling factor
// having the same index in `per_sample_scaling_factors` and clamps the result
// to [kMinFloatS16Value, kMaxFloatS16Value].
void ScaleSamples(rtc::ArrayView<const float> per_sample_scaling_factors,
                  AudioFrameView<float> signal) {
  const int samples_per_channel = signal.samples_per_channel();
  RTC_DCHECK_EQ(samples_per_channel, per_sample_scaling_factors.size());
  for (int ch = 0; ch < signal.num_channels(); ++ch) {
    rtc::ArrayView<float> samples = signal.channel(ch);
    for (int k = 0; k < samples_per_channel; ++k) {
      samples[k] = rtc::SafeClamp(samples[k] * per_sample_scaling_factors[k],
                                  kMinFloatS16Value, kMaxFloatS16Value);
    }
  }
}
|
||||
|
||||
// Debug-only check that a frame at `sample_rate_hz` has no more than
// `kMaximalNumberOfSamplesPerChannel` samples per channel, i.e., that
// per_sample_scaling_factors_ is large enough.
void CheckLimiterSampleRate(int sample_rate_hz) {
  RTC_DCHECK_LE(sample_rate_hz,
                kMaximalNumberOfSamplesPerChannel * 1000 / kFrameDurationMs);
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Builds the gain curve and the level estimator, stores `apm_data_dumper`
// without taking ownership and checks that `sample_rate_hz` is supported.
Limiter::Limiter(int sample_rate_hz,
                 ApmDataDumper* apm_data_dumper,
                 absl::string_view histogram_name)
    : interp_gain_curve_(apm_data_dumper, histogram_name),
      level_estimator_(sample_rate_hz, apm_data_dumper),
      apm_data_dumper_(apm_data_dumper) {
  CheckLimiterSampleRate(sample_rate_hz);
}
|
||||
|
||||
// Defaulted destructor: no custom teardown is performed here.
Limiter::~Limiter() = default;
|
||||
|
||||
// Limits `signal` in place: estimates one level per sub-frame, maps each level
// to a gain through the interpolated gain curve, expands the sub-frame gains
// to per-sample gains and applies them (with clamping) to every channel.
void Limiter::Process(AudioFrameView<float> signal) {
  const std::array<float, kSubFramesInFrame> subframe_levels =
      level_estimator_.ComputeLevel(signal);

  // scaling_factors_[0] is the gain the previous frame ended with; the
  // remaining entries are the gains looked up for the current sub-frames.
  RTC_DCHECK_EQ(subframe_levels.size() + 1, scaling_factors_.size());
  scaling_factors_[0] = last_scaling_factor_;
  for (std::size_t k = 0; k < subframe_levels.size(); ++k) {
    scaling_factors_[k + 1] =
        interp_gain_curve_.LookUpGainToApply(subframe_levels[k]);
  }

  const int samples_per_channel = signal.samples_per_channel();
  RTC_DCHECK_LE(samples_per_channel, kMaximalNumberOfSamplesPerChannel);

  // View over the part of the work buffer that is used by this frame.
  rtc::ArrayView<float> per_sample_scaling_factors(
      per_sample_scaling_factors_.data(), samples_per_channel);
  ComputePerSampleSubframeFactors(scaling_factors_, samples_per_channel,
                                  per_sample_scaling_factors);
  ScaleSamples(per_sample_scaling_factors, signal);

  // Remember the final gain as the starting point for the next frame.
  last_scaling_factor_ = scaling_factors_.back();

  // Dump data for debug.
  apm_data_dumper_->DumpRaw("agc2_limiter_last_scaling_factor",
                            last_scaling_factor_);
  const int region = static_cast<int>(interp_gain_curve_.get_stats().region);
  apm_data_dumper_->DumpRaw("agc2_limiter_region", region);
}
|
||||
|
||||
InterpolatedGainCurve::Stats Limiter::GetGainCurveStats() const {
|
||||
return interp_gain_curve_.get_stats();
|
||||
}
|
||||
|
||||
void Limiter::SetSampleRate(int sample_rate_hz) {
|
||||
CheckLimiterSampleRate(sample_rate_hz);
|
||||
level_estimator_.SetSampleRate(sample_rate_hz);
|
||||
}
|
||||
|
||||
void Limiter::Reset() {
|
||||
level_estimator_.Reset();
|
||||
}
|
||||
|
||||
float Limiter::LastAudioLevel() const {
|
||||
return level_estimator_.LastAudioLevel();
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,63 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "modules/audio_processing/agc2/fixed_digital_level_estimator.h"
|
||||
#include "modules/audio_processing/agc2/interpolated_gain_curve.h"
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
class ApmDataDumper;
|
||||
|
||||
class Limiter {
|
||||
public:
|
||||
Limiter(int sample_rate_hz,
|
||||
ApmDataDumper* apm_data_dumper,
|
||||
absl::string_view histogram_name_prefix);
|
||||
Limiter(const Limiter& limiter) = delete;
|
||||
Limiter& operator=(const Limiter& limiter) = delete;
|
||||
~Limiter();
|
||||
|
||||
// Applies limiter and hard-clipping to `signal`.
|
||||
void Process(AudioFrameView<float> signal);
|
||||
InterpolatedGainCurve::Stats GetGainCurveStats() const;
|
||||
|
||||
// Supported rates must be
|
||||
// * supported by FixedDigitalLevelEstimator
|
||||
// * below kMaximalNumberOfSamplesPerChannel*1000/kFrameDurationMs
|
||||
// so that samples_per_channel fit in the
|
||||
// per_sample_scaling_factors_ array.
|
||||
void SetSampleRate(int sample_rate_hz);
|
||||
|
||||
// Resets the internal state.
|
||||
void Reset();
|
||||
|
||||
float LastAudioLevel() const;
|
||||
|
||||
private:
|
||||
const InterpolatedGainCurve interp_gain_curve_;
|
||||
FixedDigitalLevelEstimator level_estimator_;
|
||||
ApmDataDumper* const apm_data_dumper_ = nullptr;
|
||||
|
||||
// Work array containing the sub-frame scaling factors to be interpolated.
|
||||
std::array<float, kSubFramesInFrame + 1> scaling_factors_ = {};
|
||||
std::array<float, kMaximalNumberOfSamplesPerChannel>
|
||||
per_sample_scaling_factors_ = {};
|
||||
float last_scaling_factor_ = 1.f;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
|
||||
|
|
@ -0,0 +1,138 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/limiter_db_gain_curve.h"
|
||||
|
||||
#include <cmath>
|
||||
|
||||
#include "common_audio/include/audio_util.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
// Returns the level (dBFS) at which the knee region starts for the given
// curve parameters. Crashes (RTC_CHECK) if the parameters violate the
// precondition checked below.
double ComputeKneeStart(double max_input_level_db,
                        double knee_smoothness_db,
                        double compression_ratio) {
  const double lower_bound_db = (compression_ratio - 1.0) *
                                knee_smoothness_db /
                                (2.0 * compression_ratio);
  RTC_CHECK_LT(lower_bound_db, max_input_level_db);
  const double ratio_minus_one = compression_ratio - 1.0;
  return -knee_smoothness_db / 2.0 - max_input_level_db / ratio_minus_one;
}
|
||||
|
||||
// Returns the coefficients {a, b, c} of the knee region polynomial
// a*x^2 + b*x + c for the given knee start, knee width and compression ratio.
std::array<double, 3> ComputeKneeRegionPolynomial(double knee_start_dbfs,
                                                  double knee_smoothness_db,
                                                  double compression_ratio) {
  std::array<double, 3> coefficients{};
  // Quadratic term.
  coefficients[0] = (1.0 - compression_ratio) /
                    (2.0 * knee_smoothness_db * compression_ratio);
  // Linear term.
  coefficients[1] = 1.0 - 2.0 * coefficients[0] * knee_start_dbfs;
  // Constant term.
  coefficients[2] = coefficients[0] * knee_start_dbfs * knee_start_dbfs;
  return coefficients;
}
|
||||
|
||||
double ComputeLimiterD1(double max_input_level_db, double compression_ratio) {
|
||||
return (std::pow(10.0, -max_input_level_db / (20.0 * compression_ratio)) *
|
||||
(1.0 - compression_ratio) / compression_ratio) /
|
||||
kMaxAbsFloatS16Value;
|
||||
}
|
||||
|
||||
// Computes the exponent used by GetGainFirstDerivativeLinear():
// (1 - 2 * ratio) / ratio.
constexpr double ComputeLimiterD2(double compression_ratio) {
  return (1.0 - 2.0 * compression_ratio) /
         compression_ratio;
}
|
||||
|
||||
double ComputeLimiterI2(double max_input_level_db,
|
||||
double compression_ratio,
|
||||
double gain_curve_limiter_i1) {
|
||||
RTC_CHECK_NE(gain_curve_limiter_i1, 0.f);
|
||||
return std::pow(10.0, -max_input_level_db / (20.0 * compression_ratio)) /
|
||||
gain_curve_limiter_i1 /
|
||||
std::pow(kMaxAbsFloatS16Value, gain_curve_limiter_i1 - 1);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
LimiterDbGainCurve::LimiterDbGainCurve()
|
||||
: max_input_level_linear_(DbfsToFloatS16(max_input_level_db_)),
|
||||
knee_start_dbfs_(ComputeKneeStart(max_input_level_db_,
|
||||
knee_smoothness_db_,
|
||||
compression_ratio_)),
|
||||
knee_start_linear_(DbfsToFloatS16(knee_start_dbfs_)),
|
||||
limiter_start_dbfs_(knee_start_dbfs_ + knee_smoothness_db_),
|
||||
limiter_start_linear_(DbfsToFloatS16(limiter_start_dbfs_)),
|
||||
knee_region_polynomial_(ComputeKneeRegionPolynomial(knee_start_dbfs_,
|
||||
knee_smoothness_db_,
|
||||
compression_ratio_)),
|
||||
gain_curve_limiter_d1_(
|
||||
ComputeLimiterD1(max_input_level_db_, compression_ratio_)),
|
||||
gain_curve_limiter_d2_(ComputeLimiterD2(compression_ratio_)),
|
||||
gain_curve_limiter_i1_(1.0 / compression_ratio_),
|
||||
gain_curve_limiter_i2_(ComputeLimiterI2(max_input_level_db_,
|
||||
compression_ratio_,
|
||||
gain_curve_limiter_i1_)) {
|
||||
static_assert(knee_smoothness_db_ > 0.0f, "");
|
||||
static_assert(compression_ratio_ > 1.0f, "");
|
||||
RTC_CHECK_GE(max_input_level_db_, knee_start_dbfs_ + knee_smoothness_db_);
|
||||
}
|
||||
|
||||
constexpr double LimiterDbGainCurve::max_input_level_db_;
|
||||
constexpr double LimiterDbGainCurve::knee_smoothness_db_;
|
||||
constexpr double LimiterDbGainCurve::compression_ratio_;
|
||||
|
||||
double LimiterDbGainCurve::GetOutputLevelDbfs(double input_level_dbfs) const {
|
||||
if (input_level_dbfs < knee_start_dbfs_) {
|
||||
return input_level_dbfs;
|
||||
} else if (input_level_dbfs < limiter_start_dbfs_) {
|
||||
return GetKneeRegionOutputLevelDbfs(input_level_dbfs);
|
||||
}
|
||||
return GetCompressorRegionOutputLevelDbfs(input_level_dbfs);
|
||||
}
|
||||
|
||||
double LimiterDbGainCurve::GetGainLinear(double input_level_linear) const {
|
||||
if (input_level_linear < knee_start_linear_) {
|
||||
return 1.0;
|
||||
}
|
||||
return DbfsToFloatS16(
|
||||
GetOutputLevelDbfs(FloatS16ToDbfs(input_level_linear))) /
|
||||
input_level_linear;
|
||||
}
|
||||
|
||||
// Computes the first derivative of GetGainLinear() in `x`.
|
||||
double LimiterDbGainCurve::GetGainFirstDerivativeLinear(double x) const {
|
||||
// Beyond-knee region only.
|
||||
RTC_CHECK_GE(x, limiter_start_linear_ - 1e-7 * kMaxAbsFloatS16Value);
|
||||
return gain_curve_limiter_d1_ *
|
||||
std::pow(x / kMaxAbsFloatS16Value, gain_curve_limiter_d2_);
|
||||
}
|
||||
|
||||
// Computes the integral of GetGainLinear() in the range [x0, x1].
|
||||
double LimiterDbGainCurve::GetGainIntegralLinear(double x0, double x1) const {
|
||||
RTC_CHECK_LE(x0, x1); // Valid interval.
|
||||
RTC_CHECK_GE(x0, limiter_start_linear_); // Beyond-knee region only.
|
||||
auto limiter_integral = [this](const double& x) {
|
||||
return gain_curve_limiter_i2_ * std::pow(x, gain_curve_limiter_i1_);
|
||||
};
|
||||
return limiter_integral(x1) - limiter_integral(x0);
|
||||
}
|
||||
|
||||
double LimiterDbGainCurve::GetKneeRegionOutputLevelDbfs(
|
||||
double input_level_dbfs) const {
|
||||
return knee_region_polynomial_[0] * input_level_dbfs * input_level_dbfs +
|
||||
knee_region_polynomial_[1] * input_level_dbfs +
|
||||
knee_region_polynomial_[2];
|
||||
}
|
||||
|
||||
double LimiterDbGainCurve::GetCompressorRegionOutputLevelDbfs(
|
||||
double input_level_dbfs) const {
|
||||
return (input_level_dbfs - max_input_level_db_) / compression_ratio_;
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,76 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_testing_common.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// A class for computing a limiter gain curve (in dB scale) given a set of
|
||||
// hard-coded parameters (namely, kLimiterDbGainCurveMaxInputLevelDbFs,
|
||||
// kLimiterDbGainCurveKneeSmoothnessDb, and
|
||||
// kLimiterDbGainCurveCompressionRatio). The generated curve consists of four
|
||||
// regions: identity (linear), knee (quadratic polynomial), compression
|
||||
// (linear), saturation (linear). The aforementioned constants are used to shape
|
||||
// the different regions.
|
||||
class LimiterDbGainCurve {
|
||||
public:
|
||||
LimiterDbGainCurve();
|
||||
|
||||
double max_input_level_db() const { return max_input_level_db_; }
|
||||
double max_input_level_linear() const { return max_input_level_linear_; }
|
||||
double knee_start_linear() const { return knee_start_linear_; }
|
||||
double limiter_start_linear() const { return limiter_start_linear_; }
|
||||
|
||||
// These methods can be marked 'constexpr' in C++ 14.
|
||||
double GetOutputLevelDbfs(double input_level_dbfs) const;
|
||||
double GetGainLinear(double input_level_linear) const;
|
||||
double GetGainFirstDerivativeLinear(double x) const;
|
||||
double GetGainIntegralLinear(double x0, double x1) const;
|
||||
|
||||
private:
|
||||
double GetKneeRegionOutputLevelDbfs(double input_level_dbfs) const;
|
||||
double GetCompressorRegionOutputLevelDbfs(double input_level_dbfs) const;
|
||||
|
||||
static constexpr double max_input_level_db_ = test::kLimiterMaxInputLevelDbFs;
|
||||
static constexpr double knee_smoothness_db_ = test::kLimiterKneeSmoothnessDb;
|
||||
static constexpr double compression_ratio_ = test::kLimiterCompressionRatio;
|
||||
|
||||
const double max_input_level_linear_;
|
||||
|
||||
// Do not modify signal with level <= knee_start_dbfs_.
|
||||
const double knee_start_dbfs_;
|
||||
const double knee_start_linear_;
|
||||
|
||||
// The upper end of the knee region, which is between knee_start_dbfs_ and
|
||||
// limiter_start_dbfs_.
|
||||
const double limiter_start_dbfs_;
|
||||
const double limiter_start_linear_;
|
||||
|
||||
// Coefficients {a, b, c} of the knee region polynomial
|
||||
// ax^2 + bx + c in the DB scale.
|
||||
const std::array<double, 3> knee_region_polynomial_;
|
||||
|
||||
// Parameters for the computation of the first derivative of GetGainLinear().
|
||||
const double gain_curve_limiter_d1_;
|
||||
const double gain_curve_limiter_d2_;
|
||||
|
||||
// Parameters for the computation of the integral of GetGainLinear().
|
||||
const double gain_curve_limiter_i1_;
|
||||
const double gain_curve_limiter_i2_;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_
|
||||
|
|
@ -0,0 +1,60 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/limiter.h"
|
||||
|
||||
#include "common_audio/include/audio_util.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/agc2/agc2_testing_common.h"
|
||||
#include "modules/audio_processing/agc2/vector_float_frame.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/gunit.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// Smoke test: a limiter can be built and can process one full-scale mono
// 10 ms frame at 48 kHz.
TEST(Limiter, LimiterShouldConstructAndRun) {
  constexpr int kSampleRateHz = 48000;
  constexpr int kSamplesPerFrame = kSampleRateHz / 100;
  ApmDataDumper data_dumper(0);

  Limiter limiter(kSampleRateHz, &data_dumper, "");

  VectorFloatFrame frame(1, kSamplesPerFrame, kMaxAbsFloatS16Value);
  limiter.Process(frame.float_frame_view());
}
|
||||
|
||||
// Feeds constant frames whose level is half-way between full scale and the
// limiter maximum input level, and expects every output sample of the last
// frame to stay above 90% of full scale.
TEST(Limiter, OutputVolumeAboveThreshold) {
  constexpr int kSampleRateHz = 48000;
  constexpr int kSamplesPerFrame = kSampleRateHz / 100;
  constexpr int kNumWarmUpFrames = 5;
  const float input_level =
      (kMaxAbsFloatS16Value + DbfsToFloatS16(test::kLimiterMaxInputLevelDbFs)) /
      2.f;
  ApmDataDumper data_dumper(0);

  Limiter limiter(kSampleRateHz, &data_dumper, "");

  // Give the level estimator time to adapt.
  for (int frame_index = 0; frame_index < kNumWarmUpFrames; ++frame_index) {
    VectorFloatFrame warm_up_frame(1, kSamplesPerFrame, input_level);
    limiter.Process(warm_up_frame.float_frame_view());
  }

  VectorFloatFrame frame(1, kSamplesPerFrame, input_level);
  limiter.Process(frame.float_frame_view());
  rtc::ArrayView<const float> channel = frame.float_frame_view().channel(0);

  const float threshold = 0.9f * kMaxAbsFloatS16Value;
  for (const float sample : channel) {
    EXPECT_LT(threshold, sample);
  }
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,172 @@
|
|||
/*
|
||||
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/noise_level_estimator.h"
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
constexpr int kFramesPerSecond = 100;
|
||||
|
||||
// Returns the largest per-channel energy (sum of squared samples) found in
// `audio`, or 0 if there are no channels.
float FrameEnergy(const AudioFrameView<const float>& audio) {
  float max_energy = 0.0f;
  for (int ch = 0; ch < audio.num_channels(); ++ch) {
    float channel_energy = 0.0f;
    for (const float sample : audio.channel(ch)) {
      channel_energy = channel_energy + sample * sample;
    }
    max_energy = std::max(channel_energy, max_energy);
  }
  return max_energy;
}
|
||||
|
||||
// Converts the energy of `num_samples` samples into an RMS level in dBFS.
// Returns the floor level `kMinDbfs` when the mean square is at most 1 or
// when `num_samples` is not positive (an empty frame would otherwise divide by
// zero and yield NaN or +infinity).
float EnergyToDbfs(float signal_energy, int num_samples) {
  RTC_DCHECK_GE(signal_energy, 0.0f);
  // Numerically 20 * log10(1 / 32768) - presumably the level of a
  // unit-amplitude sample in the S16 range.
  constexpr float kMinDbfs = -90.30899869919436f;
  if (num_samples <= 0) {
    return kMinDbfs;
  }
  const float rms_square = signal_energy / num_samples;
  if (rms_square <= 1.0f) {
    return kMinDbfs;
  }
  return 10.0f * std::log10(rms_square) + kMinDbfs;
}
|
||||
|
||||
// Updates the noise floor with instant decay and slow attack. This tuning is
// specific for AGC2, so that (i) it can promptly increase the gain if the noise
// floor drops (instant decay) and (ii) in case of music or fast speech, due to
// which the noise floor can be overestimated, the gain reduction is slowed
// down.
// Returns the smoothed estimate given the current one and the new observation.
float SmoothNoiseFloorEstimate(float current_estimate, float new_estimate) {
  constexpr float kAttack = 0.5f;
  if (current_estimate < new_estimate) {
    // Attack phase: the estimate rises, move only part of the way towards the
    // new value.
    return kAttack * new_estimate + (1.0f - kAttack) * current_estimate;
  }
  // Instant decay: the estimate does not rise, adopt the new value as is.
  return new_estimate;
}
|
||||
|
||||
class NoiseFloorEstimator : public NoiseLevelEstimator {
|
||||
public:
|
||||
// Update the noise floor every 5 seconds.
|
||||
static constexpr int kUpdatePeriodNumFrames = 500;
|
||||
static_assert(kUpdatePeriodNumFrames >= 200,
|
||||
"A too small value may cause noise level overestimation.");
|
||||
static_assert(kUpdatePeriodNumFrames <= 1500,
|
||||
"A too large value may make AGC2 slow at reacting to increased "
|
||||
"noise levels.");
|
||||
|
||||
NoiseFloorEstimator(ApmDataDumper* data_dumper) : data_dumper_(data_dumper) {
|
||||
RTC_DCHECK(data_dumper_);
|
||||
// Initially assume that 48 kHz will be used. `Analyze()` will detect the
|
||||
// used sample rate and call `Initialize()` again if needed.
|
||||
Initialize(/*sample_rate_hz=*/48000);
|
||||
}
|
||||
NoiseFloorEstimator(const NoiseFloorEstimator&) = delete;
|
||||
NoiseFloorEstimator& operator=(const NoiseFloorEstimator&) = delete;
|
||||
~NoiseFloorEstimator() = default;
|
||||
|
||||
float Analyze(const AudioFrameView<const float>& frame) override {
|
||||
// Detect sample rate changes.
|
||||
const int sample_rate_hz =
|
||||
static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
|
||||
if (sample_rate_hz != sample_rate_hz_) {
|
||||
Initialize(sample_rate_hz);
|
||||
}
|
||||
|
||||
const float frame_energy = FrameEnergy(frame);
|
||||
if (frame_energy <= min_noise_energy_) {
|
||||
// Ignore frames when muted or below the minimum measurable energy.
|
||||
if (data_dumper_)
|
||||
data_dumper_->DumpRaw("agc2_noise_floor_estimator_preliminary_level",
|
||||
noise_energy_);
|
||||
return EnergyToDbfs(noise_energy_,
|
||||
static_cast<int>(frame.samples_per_channel()));
|
||||
}
|
||||
|
||||
if (preliminary_noise_energy_set_) {
|
||||
preliminary_noise_energy_ =
|
||||
std::min(preliminary_noise_energy_, frame_energy);
|
||||
} else {
|
||||
preliminary_noise_energy_ = frame_energy;
|
||||
preliminary_noise_energy_set_ = true;
|
||||
}
|
||||
if (data_dumper_)
|
||||
data_dumper_->DumpRaw("agc2_noise_floor_estimator_preliminary_level",
|
||||
preliminary_noise_energy_);
|
||||
|
||||
if (counter_ == 0) {
|
||||
// Full period observed.
|
||||
first_period_ = false;
|
||||
// Update the estimated noise floor energy with the preliminary
|
||||
// estimation.
|
||||
noise_energy_ = SmoothNoiseFloorEstimate(
|
||||
/*current_estimate=*/noise_energy_,
|
||||
/*new_estimate=*/preliminary_noise_energy_);
|
||||
// Reset for a new observation period.
|
||||
counter_ = kUpdatePeriodNumFrames;
|
||||
preliminary_noise_energy_set_ = false;
|
||||
} else if (first_period_) {
|
||||
// While analyzing the signal during the initial period, continuously
|
||||
// update the estimated noise energy, which is monotonic.
|
||||
noise_energy_ = preliminary_noise_energy_;
|
||||
counter_--;
|
||||
} else {
|
||||
// During the observation period it's only allowed to lower the energy.
|
||||
noise_energy_ = std::min(noise_energy_, preliminary_noise_energy_);
|
||||
counter_--;
|
||||
}
|
||||
|
||||
float noise_rms_dbfs = EnergyToDbfs(
|
||||
noise_energy_, static_cast<int>(frame.samples_per_channel()));
|
||||
if (data_dumper_)
|
||||
data_dumper_->DumpRaw("agc2_noise_rms_dbfs", noise_rms_dbfs);
|
||||
|
||||
return noise_rms_dbfs;
|
||||
}
|
||||
|
||||
private:
|
||||
void Initialize(int sample_rate_hz) {
|
||||
sample_rate_hz_ = sample_rate_hz;
|
||||
first_period_ = true;
|
||||
preliminary_noise_energy_set_ = false;
|
||||
// Initialize the minimum noise energy to -84 dBFS.
|
||||
min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond;
|
||||
preliminary_noise_energy_ = min_noise_energy_;
|
||||
noise_energy_ = min_noise_energy_;
|
||||
counter_ = kUpdatePeriodNumFrames;
|
||||
}
|
||||
|
||||
ApmDataDumper* const data_dumper_;
|
||||
int sample_rate_hz_;
|
||||
float min_noise_energy_;
|
||||
bool first_period_;
|
||||
bool preliminary_noise_energy_set_;
|
||||
float preliminary_noise_energy_;
|
||||
float noise_energy_;
|
||||
int counter_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Factory: returns a heap-allocated NoiseFloorEstimator behind the
// NoiseLevelEstimator interface. `data_dumper` is passed to its constructor.
std::unique_ptr<NoiseLevelEstimator> CreateNoiseFloorEstimator(
    ApmDataDumper* data_dumper) {
  std::unique_ptr<NoiseLevelEstimator> estimator =
      std::make_unique<NoiseFloorEstimator>(data_dumper);
  return estimator;
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
/*
|
||||
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
class ApmDataDumper;
|
||||
|
||||
// Noise level estimator interface.
|
||||
class NoiseLevelEstimator {
|
||||
public:
|
||||
virtual ~NoiseLevelEstimator() = default;
|
||||
// Analyzes a 10 ms `frame`, updates the noise level estimation and returns
|
||||
// the value for the latter in dBFS.
|
||||
virtual float Analyze(const AudioFrameView<const float>& frame) = 0;
|
||||
};
|
||||
|
||||
// Creates a noise level estimator based on noise floor detection.
|
||||
std::unique_ptr<NoiseLevelEstimator> CreateNoiseFloorEstimator(
|
||||
ApmDataDumper* data_dumper);
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
|
||||
|
|
@ -0,0 +1,3 @@
|
|||
include_rules = [
|
||||
"+third_party/rnnoise",
|
||||
]
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
|
||||
static_assert(1 << kAutoCorrelationFftOrder >
|
||||
kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
|
||||
"");
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates a real FFT of length 2^kAutoCorrelationFftOrder and three buffers
// obtained from it: a scratch buffer and the two frequency-domain buffers used
// by ComputeOnPitchBuffer().
AutoCorrelationCalculator::AutoCorrelationCalculator()
    : fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal),
      tmp_(fft_.CreateBuffer()),
      X_(fft_.CreateBuffer()),
      H_(fft_.CreateBuffer()) {
}

AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
|
||||
|
||||
// The auto-correlations coefficients are computed as follows:
|
||||
// |.........|...........| <- pitch buffer
|
||||
// [ x (fixed) ]
|
||||
// [ y_0 ]
|
||||
// [ y_{m-1} ]
|
||||
// x and y are sub-array of equal length; x is never moved, whereas y slides.
|
||||
// The cross-correlation between y_0 and x corresponds to the auto-correlation
|
||||
// for the maximum pitch period. Hence, the first value in `auto_corr` has an
|
||||
// inverted lag equal to 0 that corresponds to a lag equal to the maximum
|
||||
// pitch period.
|
||||
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
||||
rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
|
||||
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
|
||||
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
|
||||
constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
|
||||
constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
|
||||
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
|
||||
"Mismatch between pitch buffer size, frame size and maximum "
|
||||
"pitch period.");
|
||||
static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength,
|
||||
"The FFT length is not sufficiently big to avoid cyclic "
|
||||
"convolution errors.");
|
||||
auto tmp = tmp_->GetView();
|
||||
|
||||
// Compute the FFT for the reversed reference frame - i.e.,
|
||||
// pitch_buf[-kConvolutionLength:].
|
||||
std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(),
|
||||
tmp.begin());
|
||||
std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f);
|
||||
fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false);
|
||||
|
||||
// Compute the FFT for the sliding frames chunk. The sliding frames are
|
||||
// defined as pitch_buf[i:i+kConvolutionLength] where i in
|
||||
// [0, kNumLags12kHz). The chunk includes all of them, hence it is
|
||||
// defined as pitch_buf[:kNumLags12kHz+kConvolutionLength].
|
||||
std::copy(pitch_buf.begin(),
|
||||
pitch_buf.begin() + kConvolutionLength + kNumLags12kHz,
|
||||
tmp.begin());
|
||||
std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f);
|
||||
fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
|
||||
|
||||
// Convolve in the frequency domain.
|
||||
constexpr float kScalingFactor = 1.f / static_cast<float>(kFftFrameSize);
|
||||
std::fill(tmp.begin(), tmp.end(), 0.f);
|
||||
fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor);
|
||||
fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false);
|
||||
|
||||
// Extract the auto-correlation coefficients.
|
||||
std::copy(tmp.begin() + kConvolutionLength - 1,
|
||||
tmp.begin() + kConvolutionLength + kNumLags12kHz - 1,
|
||||
auto_corr.begin());
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
/*
|
||||
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/utility/pffft_wrapper.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Class to compute the auto correlation on the pitch buffer for a target pitch
|
||||
// interval.
|
||||
class AutoCorrelationCalculator {
|
||||
public:
|
||||
AutoCorrelationCalculator();
|
||||
AutoCorrelationCalculator(const AutoCorrelationCalculator&) = delete;
|
||||
AutoCorrelationCalculator& operator=(const AutoCorrelationCalculator&) =
|
||||
delete;
|
||||
~AutoCorrelationCalculator();
|
||||
|
||||
// Computes the auto-correlation coefficients for a target pitch interval.
|
||||
// `auto_corr` indexes are inverted lags.
|
||||
void ComputeOnPitchBuffer(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
|
||||
rtc::ArrayView<float, kNumLags12kHz> auto_corr);
|
||||
|
||||
private:
|
||||
Pffft fft_;
|
||||
std::unique_ptr<Pffft::FloatBuffer> tmp_;
|
||||
std::unique_ptr<Pffft::FloatBuffer> X_;
|
||||
std::unique_ptr<Pffft::FloatBuffer> H_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
constexpr double kPi = 3.14159265358979323846;

// Sample rate and frame sizes (in samples) used by the 24 kHz analysis.
constexpr int kSampleRate24kHz = 24000;
constexpr int kFrameSize10ms24kHz = kSampleRate24kHz / 100;
constexpr int kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;

// Pitch buffer.
constexpr int kMinPitch24kHz = kSampleRate24kHz / 800;  // 0.00125 s.
// 62.5 is not an integer, hence the division is carried out in floating point
// and the conversion back to `int` is made explicit.
constexpr int kMaxPitch24kHz =
    static_cast<int>(kSampleRate24kHz / 62.5);  // 0.016 s.
static_assert(kMaxPitch24kHz * 62.5 == kSampleRate24kHz,
              "The maximum pitch period must be a whole number of samples.");
// The pitch buffer holds the longest pitch period plus one 20 ms frame.
constexpr int kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");

// 24 kHz analysis.
// Define a higher minimum pitch period for the initial search. This is used to
// avoid searching for very short periods, for which a refinement step is
// responsible.
constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
// Number of (inverted) lags during the initial pitch search phase at 24 kHz.
constexpr int kInitialNumLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// Number of (inverted) lags during the pitch search refinement phase at 24 kHz.
constexpr int kRefineNumLags24kHz = kMaxPitch24kHz + 1;
static_assert(
    kRefineNumLags24kHz > kInitialNumLags24kHz,
    "The refinement step must search the pitch in an extended pitch range.");

// 12 kHz analysis. Sizes and pitch limits are half of the 24 kHz ones.
constexpr int kSampleRate12kHz = 12000;
constexpr int kFrameSize10ms12kHz = kSampleRate12kHz / 100;
constexpr int kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
constexpr int kBufSize12kHz = kBufSize24kHz / 2;
constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2;
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
// The inverted lags for the pitch interval [`kInitialMinPitch12kHz`,
// `kMaxPitch12kHz`] are in the range [0, `kNumLags12kHz`].
constexpr int kNumLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;

// 48 kHz constants. Pitch limits are twice the 24 kHz ones.
constexpr int kMinPitch48kHz = kMinPitch24kHz * 2;
constexpr int kMaxPitch48kHz = kMaxPitch24kHz * 2;

// Spectral features.
constexpr int kNumBands = 22;
constexpr int kNumLowerBands = 6;
static_assert((0 < kNumLowerBands) && (kNumLowerBands < kNumBands), "");
constexpr int kCepstralCoeffsHistorySize = 8;
static_assert(kCepstralCoeffsHistorySize > 2,
              "The history size must at least be 3 to compute first and second "
              "derivatives.");

constexpr int kFeatureVectorSize = 42;
// `FeaturesExtractor::CheckSilenceComputeFeatures()` splits the feature vector
// into `kNumBands` entries, three groups of `kNumLowerBands` entries and two
// trailing scalars.
static_assert(kFeatureVectorSize == kNumBands + 3 * kNumLowerBands + 2,
              "Update the feature vector layout.");
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
|
||||
|
|
@ -0,0 +1,90 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Configuration of the second-order high-pass filter `hpf_`, which
// `FeaturesExtractor` applies to the input samples only when
// `use_high_pass_filter_` is set.
// Computed as `scipy.signal.butter(N=2, Wn=60/24000, btype='highpass')`.
// NOTE(review): presumably the first triplet holds the numerator coefficients
// and the second pair the denominator ones (without the leading 1) - confirm
// against `BiQuadFilter::Config`.
|
||||
constexpr BiQuadFilter::Config kHpfConfig24k{
|
||||
{0.99446179f, -1.98892358f, 0.99446179f},
|
||||
{-1.98889291f, 0.98895425f}};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates the extractor with the high-pass pre-filter disabled, allocates the
// LP residual buffer and resets the internal state. `cpu_features` is only
// forwarded to the pitch estimator.
FeaturesExtractor::FeaturesExtractor(const AvailableCpuFeatures& cpu_features)
|
||||
: use_high_pass_filter_(false),
|
||||
hpf_(kHpfConfig24k),
|
||||
pitch_buf_24kHz_(),
|
||||
// View obtained once from the pitch buffer and reused for every frame.
pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()),
|
||||
lp_residual_(kBufSize24kHz),
|
||||
// Fixed-size view over the storage of `lp_residual_`.
lp_residual_view_(lp_residual_.data(), kBufSize24kHz),
|
||||
pitch_estimator_(cpu_features),
|
||||
// View on the most recent values of the pitch buffer; it is passed as
// reference frame to the spectral features extractor.
reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) {
|
||||
// `lp_residual_view_` assumes that the vector holds `kBufSize24kHz` items.
RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size());
|
||||
Reset();
|
||||
}
|
||||
|
||||
// Defaulted out of line.
FeaturesExtractor::~FeaturesExtractor() = default;
|
||||
|
||||
void FeaturesExtractor::Reset() {
|
||||
pitch_buf_24kHz_.Reset();
|
||||
spectral_features_extractor_.Reset();
|
||||
if (use_high_pass_filter_) {
|
||||
hpf_.Reset();
|
||||
}
|
||||
}
|
||||
|
||||
// Pushes `samples` into the pitch buffer, estimates the pitch period on the
// LP residual and delegates the silence check and the remaining features to
// the spectral features extractor, whose return value is returned as is.
//
// Positions written in `feature_vector`:
// - [0, kNumLowerBands), [kNumLowerBands, kNumBands) and three consecutive
//   groups of kNumLowerBands items starting at kNumBands: filled by the
//   spectral features extractor;
// - kFeatureVectorSize - 2: normalized pitch period;
// - kFeatureVectorSize - 1: scalar written by the spectral features
//   extractor.
bool FeaturesExtractor::CheckSilenceComputeFeatures(
    rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
    rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
  // Pre-processing: feed the pitch buffer either with `samples` or with their
  // high-pass filtered version.
  if (!use_high_pass_filter_) {
    pitch_buf_24kHz_.Push(samples);
  } else {
    std::array<float, kFrameSize10ms24kHz> samples_filtered;
    hpf_.Process(samples, samples_filtered);
    pitch_buf_24kHz_.Push(samples_filtered);
  }

  // LP residual of the whole pitch buffer.
  float lpc_coeffs[kNumLpcCoefficients];
  ComputeAndPostProcessLpcCoefficients(pitch_buf_24kHz_view_, lpc_coeffs);
  ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_);

  // Pitch estimation on the LP residual. The normalized pitch period is stored
  // in the output vector (normalization based on training data stats).
  pitch_period_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
  feature_vector[kFeatureVectorSize - 2] = 0.01f * (pitch_period_48kHz_ - 300);

  // The lagged frame starts one pitch period (converted to 24 kHz) before the
  // position `kMaxPitch24kHz` of the pitch buffer.
  const int pitch_period_24kHz = pitch_period_48kHz_ / 2;
  RTC_DCHECK_LE(pitch_period_24kHz, kMaxPitch24kHz);
  auto lagged_frame = pitch_buf_24kHz_view_.subview(
      kMaxPitch24kHz - pitch_period_24kHz, kFrameSize20ms24kHz);

  // Analyze reference and lagged frames; the callee reports whether silence
  // has been detected and writes the rest of the feature vector.
  float* const features = feature_vector.data();
  return spectral_features_extractor_.CheckSilenceComputeFeatures(
      reference_frame_view_, {lagged_frame.data(), kFrameSize20ms24kHz},
      {features + kNumLowerBands, kNumBands - kNumLowerBands},
      {features, kNumLowerBands},
      {features + kNumBands, kNumLowerBands},
      {features + kNumBands + kNumLowerBands, kNumLowerBands},
      {features + kNumBands + 2 * kNumLowerBands, kNumLowerBands},
      &feature_vector[kFeatureVectorSize - 1]);
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,61 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/biquad_filter.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Feature extractor to feed the VAD RNN.
|
||||
class FeaturesExtractor {
|
||||
public:
|
||||
explicit FeaturesExtractor(const AvailableCpuFeatures& cpu_features);
|
||||
FeaturesExtractor(const FeaturesExtractor&) = delete;
|
||||
FeaturesExtractor& operator=(const FeaturesExtractor&) = delete;
|
||||
~FeaturesExtractor();
|
||||
void Reset();
|
||||
// Analyzes the samples, computes the feature vector and returns true if
|
||||
// silence is detected (false if not). When silence is detected,
|
||||
// `feature_vector` is partially written and therefore must not be used to
|
||||
// feed the VAD RNN.
|
||||
bool CheckSilenceComputeFeatures(
|
||||
rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
|
||||
rtc::ArrayView<float, kFeatureVectorSize> feature_vector);
|
||||
|
||||
private:
|
||||
const bool use_high_pass_filter_;
|
||||
// TODO(bugs.webrtc.org/7494): Remove HPF depending on how AGC2 is used in APM
|
||||
// and on whether an HPF is already used as pre-processing step in APM.
|
||||
BiQuadFilter hpf_;
|
||||
SequenceBuffer<float, kBufSize24kHz, kFrameSize10ms24kHz, kFrameSize20ms24kHz>
|
||||
pitch_buf_24kHz_;
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf_24kHz_view_;
|
||||
std::vector<float> lp_residual_;
|
||||
rtc::ArrayView<float, kBufSize24kHz> lp_residual_view_;
|
||||
PitchEstimator pitch_estimator_;
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame_view_;
|
||||
SpectralFeaturesExtractor spectral_features_extractor_;
|
||||
int pitch_period_48kHz_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
|
||||
|
|
@ -0,0 +1,141 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cmath>
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Computes auto-correlation coefficients for `x` and writes them in
|
||||
// `auto_corr`. The lag values are in {0, ..., max_lag - 1}, where max_lag
|
||||
// equals the size of `auto_corr`.
|
||||
void ComputeAutoCorrelation(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
|
||||
constexpr int max_lag = auto_corr.size();
|
||||
RTC_DCHECK_LT(max_lag, x.size());
|
||||
for (int lag = 0; lag < max_lag; ++lag) {
|
||||
auto_corr[lag] =
|
||||
std::inner_product(x.begin(), x.end() - lag, x.begin() + lag, 0.f);
|
||||
}
|
||||
}
|
||||
|
||||
// Applies denoising to the auto-correlation coefficients.
|
||||
void DenoiseAutoCorrelation(
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
|
||||
// Assume -40 dB white noise floor.
|
||||
auto_corr[0] *= 1.0001f;
|
||||
// Hard-coded values obtained as
|
||||
// [np.float32((0.008*0.008*i*i)) for i in range(1,5)].
|
||||
auto_corr[1] -= auto_corr[1] * 0.000064f;
|
||||
auto_corr[2] -= auto_corr[2] * 0.000256f;
|
||||
auto_corr[3] -= auto_corr[3] * 0.000576f;
|
||||
auto_corr[4] -= auto_corr[4] * 0.001024f;
|
||||
static_assert(kNumLpcCoefficients == 5, "Update `auto_corr`.");
|
||||
}
|
||||
|
||||
// Computes the initial inverse filter coefficients given the auto-correlation
|
||||
// coefficients of an input frame.
|
||||
void ComputeInitialInverseFilterCoefficients(
|
||||
rtc::ArrayView<const float, kNumLpcCoefficients> auto_corr,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients - 1> lpc_coeffs) {
|
||||
float error = auto_corr[0];
|
||||
for (int i = 0; i < kNumLpcCoefficients - 1; ++i) {
|
||||
float reflection_coeff = 0.f;
|
||||
for (int j = 0; j < i; ++j) {
|
||||
reflection_coeff += lpc_coeffs[j] * auto_corr[i - j];
|
||||
}
|
||||
reflection_coeff += auto_corr[i + 1];
|
||||
|
||||
// Avoid division by numbers close to zero.
|
||||
constexpr float kMinErrorMagnitude = 1e-6f;
|
||||
if (std::fabs(error) < kMinErrorMagnitude) {
|
||||
error = std::copysign(kMinErrorMagnitude, error);
|
||||
}
|
||||
|
||||
reflection_coeff /= -error;
|
||||
// Update LPC coefficients and total error.
|
||||
lpc_coeffs[i] = reflection_coeff;
|
||||
for (int j = 0; j < ((i + 1) >> 1); ++j) {
|
||||
const float tmp1 = lpc_coeffs[j];
|
||||
const float tmp2 = lpc_coeffs[i - 1 - j];
|
||||
lpc_coeffs[j] = tmp1 + reflection_coeff * tmp2;
|
||||
lpc_coeffs[i - 1 - j] = tmp2 + reflection_coeff * tmp1;
|
||||
}
|
||||
error -= reflection_coeff * reflection_coeff * error;
|
||||
if (error < 0.001f * auto_corr[0]) {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Computes the LPC coefficients of `x` and post-processes them. When the
// zero-lag auto-correlation of `x` is zero, `lpc_coeffs` is filled with zeros.
void ComputeAndPostProcessLpcCoefficients(
    rtc::ArrayView<const float> x,
    rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs) {
  static_assert(kNumLpcCoefficients == 5, "Update `lpc_coeffs(_pre)`.");
  constexpr int kNumPreCoefficients = kNumLpcCoefficients - 1;

  std::array<float, kNumLpcCoefficients> auto_corr;
  ComputeAutoCorrelation(x, auto_corr);
  if (auto_corr[0] == 0.f) {
    // Empty frame.
    for (float& coefficient : lpc_coeffs) {
      coefficient = 0;
    }
    return;
  }
  DenoiseAutoCorrelation(auto_corr);
  // Zero-initialized since the callee reads the coefficients while computing
  // them.
  std::array<float, kNumPreCoefficients> lpc_coeffs_pre{};
  ComputeInitialInverseFilterCoefficients(auto_corr, lpc_coeffs_pre);

  // LPC coefficients post-processing.
  // TODO(bugs.webrtc.org/9076): Consider removing these steps.
  // Step 1: scale the i-th coefficient by 0.9^(i + 1).
  constexpr float kScaling[kNumPreCoefficients] = {
      0.9f, 0.9f * 0.9f, 0.9f * 0.9f * 0.9f, 0.9f * 0.9f * 0.9f * 0.9f};
  for (int i = 0; i < kNumPreCoefficients; ++i) {
    lpc_coeffs_pre[i] *= kScaling[i];
  }
  // Step 2: combine each coefficient with the previous one weighted by `kC`.
  // The first output only adds `kC` and the last one only has the weighted
  // term.
  constexpr float kC = 0.8f;
  lpc_coeffs[0] = lpc_coeffs_pre[0] + kC;
  for (int i = 1; i < kNumPreCoefficients; ++i) {
    lpc_coeffs[i] = lpc_coeffs_pre[i] + kC * lpc_coeffs_pre[i - 1];
  }
  lpc_coeffs[kNumPreCoefficients] = kC * lpc_coeffs_pre[kNumPreCoefficients - 1];
}
|
||||
|
||||
// Filters `x` with the inverse filter `lpc_coeffs` and writes the result in
// `y`, which must have the same size as `x`. `x` must hold more than
// `kNumLpcCoefficients` samples.
void ComputeLpResidual(
    rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
    rtc::ArrayView<const float> x,
    rtc::ArrayView<float> y) {
  RTC_DCHECK_GT(x.size(), kNumLpcCoefficients);
  RTC_DCHECK_EQ(x.size(), y.size());
  // The code below implements the following operation:
  // y[i] = x[i] + dot_product({x[i - 1], ..., x[i - kNumLpcCoefficients]},
  //                           lpc_coeffs)
  // where the samples before the beginning of `x` are skipped.

  // Edge case: i < kNumLpcCoefficients, only the `i` available past samples
  // are used.
  y[0] = x[0];
  for (int i = 1; i < kNumLpcCoefficients; ++i) {
    float residual = x[i];
    for (int tap = 0; tap < i; ++tap) {
      residual += x[i - 1 - tap] * lpc_coeffs[tap];
    }
    y[i] = residual;
  }
  // Regular case: all the taps are used.
  for (int i = kNumLpcCoefficients; rtc::SafeLt(i, y.size()); ++i) {
    float residual = x[i];
    for (int tap = 0; tap < kNumLpcCoefficients; ++tap) {
      residual += x[i - 1 - tap] * lpc_coeffs[tap];
    }
    y[i] = residual;
  }
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,41 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include "api/array_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Linear predictive coding (LPC) inverse filter length.
|
||||
constexpr int kNumLpcCoefficients = 5;
|
||||
|
||||
// Given a frame `x`, computes a post-processed version of LPC coefficients
|
||||
// tailored for pitch estimation.
// `x` must hold more than `kNumLpcCoefficients` samples. When the zero-lag
// auto-correlation of `x` is zero, `lpc_coeffs` is filled with zeros.
|
||||
void ComputeAndPostProcessLpcCoefficients(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs);
|
||||
|
||||
// Computes the LP residual for the input frame `x` and the LPC coefficients
|
||||
// `lpc_coeffs`. `y` and `x` can point to the same array for in-place
|
||||
// computation.
// `x` and `y` must have the same size, which must be greater than
// `kNumLpcCoefficients`.
// NOTE(review): each output sample is computed from past samples read from
// `x`; if `y` aliases `x` those samples have already been overwritten, hence
// the in-place result presumably differs from the out-of-place one - confirm
// before relying on in-place computation.
|
||||
void ComputeLpResidual(
|
||||
rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float> y);
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Stores a copy of `cpu_features` and allocates the buffers used by
// `Estimate()`, so that they are not re-allocated at every call.
PitchEstimator::PitchEstimator(const AvailableCpuFeatures& cpu_features)
|
||||
: cpu_features_(cpu_features),
|
||||
// `kRefineNumLags24kHz` items explicitly set to zero.
y_energy_24kHz_(kRefineNumLags24kHz, 0.f),
|
||||
pitch_buffer_12kHz_(kBufSize12kHz),
|
||||
auto_correlation_12kHz_(kNumLags12kHz) {}
|
||||
|
||||
// Defaulted out of line.
PitchEstimator::~PitchEstimator() = default;
|
||||
|
||||
// Estimates the pitch period of `pitch_buffer` in three stages: an initial
// search on a 2x decimated copy of the buffer, a refinement on the 24 kHz
// buffer and a final step that also takes into account the pitch estimated
// by the previous call. The result is stored in `last_pitch_48kHz_` and its
// period is returned.
int PitchEstimator::Estimate(
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
  // Fixed-size views on the buffers allocated by the constructor.
  rtc::ArrayView<float, kBufSize12kHz> decimated_buffer(
      pitch_buffer_12kHz_.data(), kBufSize12kHz);
  RTC_DCHECK_EQ(pitch_buffer_12kHz_.size(), decimated_buffer.size());
  rtc::ArrayView<float, kNumLags12kHz> auto_correlation(
      auto_correlation_12kHz_.data(), kNumLags12kHz);
  RTC_DCHECK_EQ(auto_correlation_12kHz_.size(), auto_correlation.size());

  // TODO(bugs.chromium.org/10480): Use `cpu_features_` to estimate pitch.
  // Perform the initial pitch search at 12 kHz.
  Decimate2x(pitch_buffer, decimated_buffer);
  auto_corr_calculator_.ComputeOnPitchBuffer(decimated_buffer,
                                             auto_correlation);
  CandidatePitchPeriods candidates = ComputePitchPeriod12kHz(
      decimated_buffer, auto_correlation, cpu_features_);
  // The refinement is done using the pitch buffer that contains 24 kHz
  // samples. Therefore, rescale the candidates found at 12 kHz to 24 kHz.
  candidates.best *= 2;
  candidates.second_best *= 2;

  // Refine the initial pitch period estimation from 12 kHz to 48 kHz.
  // Pre-compute frame energies at 24 kHz.
  rtc::ArrayView<float, kRefineNumLags24kHz> frame_energies(
      y_energy_24kHz_.data(), kRefineNumLags24kHz);
  RTC_DCHECK_EQ(y_energy_24kHz_.size(), frame_energies.size());
  ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, frame_energies,
                                         cpu_features_);
  // Estimation at 48 kHz.
  const int refined_lag_48kHz = ComputePitchPeriod48kHz(
      pitch_buffer, frame_energies, candidates, cpu_features_);
  // The refined value is subtracted from `kMaxPitch48kHz` to obtain the
  // initial pitch period for the last step.
  last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
      pitch_buffer, frame_energies,
      /*initial_pitch_period_48kHz=*/kMaxPitch48kHz - refined_lag_48kHz,
      last_pitch_48kHz_, cpu_features_);
  return last_pitch_48kHz_.period;
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
#include "rtc_base/gtest_prod_util.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Pitch estimator.
// Keeps the pitch estimated by the previous call to `Estimate()` and owns the
// buffers used during the estimation. Not copyable.
|
||||
class PitchEstimator {
|
||||
public:
|
||||
// `cpu_features` is copied and passed to the pitch search functions.
explicit PitchEstimator(const AvailableCpuFeatures& cpu_features);
|
||||
PitchEstimator(const PitchEstimator&) = delete;
|
||||
PitchEstimator& operator=(const PitchEstimator&) = delete;
|
||||
~PitchEstimator();
|
||||
// Returns the estimated pitch period at 48 kHz.
// `pitch_buffer` holds `kBufSize24kHz` samples; the estimation also depends
// on the pitch estimated by the previous call.
|
||||
int Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer);
|
||||
|
||||
private:
|
||||
// Lets the test below read the strength of the last estimated pitch.
FRIEND_TEST_ALL_PREFIXES(RnnVadTest, PitchSearchWithinTolerance);
|
||||
float GetLastPitchStrengthForTesting() const {
|
||||
return last_pitch_48kHz_.strength;
|
||||
}
|
||||
|
||||
const AvailableCpuFeatures cpu_features_;
|
||||
// Result of the last call to `Estimate()`; value-initialized until then.
PitchInfo last_pitch_48kHz_{};
|
||||
AutoCorrelationCalculator auto_corr_calculator_;
|
||||
// Buffers allocated once by the constructor and reused by `Estimate()`.
std::vector<float> y_energy_24kHz_;
|
||||
std::vector<float> pitch_buffer_12kHz_;
|
||||
std::vector<float> auto_correlation_12kHz_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
|
||||
|
|
@ -0,0 +1,513 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
|
||||
|
||||
#include <stdlib.h>
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
#include <numeric>
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
float ComputeAutoCorrelation(
|
||||
int inverted_lag,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
const VectorMath& vector_math) {
|
||||
RTC_DCHECK_LT(inverted_lag, kBufSize24kHz);
|
||||
RTC_DCHECK_LT(inverted_lag, kRefineNumLags24kHz);
|
||||
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
|
||||
return vector_math.DotProduct(
|
||||
pitch_buffer.subview(/*offset=*/kMaxPitch24kHz),
|
||||
pitch_buffer.subview(inverted_lag, kFrameSize20ms24kHz));
|
||||
}
|
||||
|
||||
// Given an auto-correlation coefficient `curr_auto_correlation` and its
// neighboring values `prev_auto_correlation` and `next_auto_correlation`,
// computes a pseudo-interpolation offset to be applied to the pitch period
// associated to `curr_auto_correlation`. The output is a lag in {-1, 0, +1}:
// +1 when the next value is large enough compared to the other two, -1 when
// the previous one is, 0 otherwise. The check on the next value has priority.
// TODO(bugs.webrtc.org/9076): Consider removing this method.
// `GetPitchPseudoInterpolationOffset()` is relevant only if the spectral
// analysis works at a sample rate that is twice that of the pitch buffer; in
// particular, it is not relevant for the estimated pitch period feature fed
// into the RNN.
int GetPitchPseudoInterpolationOffset(float prev_auto_correlation,
                                      float curr_auto_correlation,
                                      float next_auto_correlation) {
  const bool next_dominates =
      (next_auto_correlation - prev_auto_correlation) >
      0.7f * (curr_auto_correlation - prev_auto_correlation);
  if (next_dominates) {
    return 1;
  }
  const bool prev_dominates =
      (prev_auto_correlation - next_auto_correlation) >
      0.7f * (curr_auto_correlation - next_auto_correlation);
  return prev_dominates ? -1 : 0;
}
|
||||
|
||||
// Refines a pitch period `lag` encoded as lag with pseudo-interpolation. The
|
||||
// output sample rate is twice as that of `lag`.
|
||||
int PitchPseudoInterpolationLagPitchBuf(
|
||||
int lag,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
const VectorMath& vector_math) {
|
||||
int offset = 0;
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
if (lag > 0 && lag < kMaxPitch24kHz) {
|
||||
const int inverted_lag = kMaxPitch24kHz - lag;
|
||||
offset = GetPitchPseudoInterpolationOffset(
|
||||
ComputeAutoCorrelation(inverted_lag + 1, pitch_buffer, vector_math),
|
||||
ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math),
|
||||
ComputeAutoCorrelation(inverted_lag - 1, pitch_buffer, vector_math));
|
||||
}
|
||||
return 2 * lag + offset;
|
||||
}
|
||||
|
||||
// Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when
|
||||
// looking for sub-harmonics.
|
||||
// The values have been chosen to serve the following algorithm. Given the
|
||||
// initial pitch period T, we examine whether one of its harmonics is the true
|
||||
// fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of
|
||||
// these harmonics, in addition to the pitch strength of itself, we choose one
|
||||
// multiple of its pitch period, n*T/k, to validate it (by averaging their pitch
|
||||
// strengths). The multiplier n is chosen so that n*T/k is used only one time
|
||||
// over all k. When for example k = 4, we should also expect a peak at 3*T/4.
|
||||
// When k = 8 instead we don't want to look at 2*T/8, since we have already
|
||||
// checked T/4 before. Instead, we look at T*3/8.
|
||||
// The array can be generated in Python as follows:
|
||||
// from fractions import Fraction
|
||||
// # Smallest positive integer not in X.
|
||||
// def mex(X):
|
||||
// for i in range(1, int(max(X)+2)):
|
||||
// if i not in X:
|
||||
// return i
|
||||
// # Visited multiples of the period.
|
||||
// S = {1}
|
||||
// for n in range(2, 16):
|
||||
// sn = mex({n * i for i in S} | {1})
|
||||
// S = S | {Fraction(1, n), Fraction(sn, n)}
|
||||
// print(sn, end=', ')
|
||||
// One multiplier for each k in {2, ..., 15}, in that order (i.e., the item at
// position i is the multiplier chosen for k = i + 2).
constexpr std::array<int, 14> kSubHarmonicMultipliers = {
|
||||
{3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
|
||||
|
||||
// Interval of integers with both bounds included; used in this file for
// intervals of inverted lags.
struct Range {
|
||||
// First value of the interval.
int min;
|
||||
// Last value of the interval (included).
int max;
|
||||
};
|
||||
|
||||
// Number of analyzed pitches to the left(right) of a pitch candidate.
|
||||
constexpr int kPitchNeighborhoodRadius = 2;
|
||||
|
||||
// Creates a pitch period interval centered in `inverted_lag` with hard-coded
|
||||
// radius. Clipping is applied so that the interval is always valid for a 24 kHz
|
||||
// pitch buffer.
|
||||
Range CreateInvertedLagRange(int inverted_lag) {
|
||||
return {std::max(inverted_lag - kPitchNeighborhoodRadius, 0),
|
||||
std::min(inverted_lag + kPitchNeighborhoodRadius,
|
||||
kInitialNumLags24kHz - 1)};
|
||||
}
|
||||
|
||||
constexpr int kNumPitchCandidates = 2; // Best and second best.
|
||||
// Maximum number of analyzed pitch periods.
|
||||
constexpr int kMaxPitchPeriods24kHz =
|
||||
kNumPitchCandidates * (2 * kPitchNeighborhoodRadius + 1);
|
||||
|
||||
// Collection of inverted lags.
|
||||
class InvertedLagsIndex {
|
||||
public:
|
||||
InvertedLagsIndex() : num_entries_(0) {}
|
||||
// Adds an inverted lag to the index. Cannot add more than
|
||||
// `kMaxPitchPeriods24kHz` values.
|
||||
void Append(int inverted_lag) {
|
||||
RTC_DCHECK_LT(num_entries_, kMaxPitchPeriods24kHz);
|
||||
inverted_lags_[num_entries_++] = inverted_lag;
|
||||
}
|
||||
const int* data() const { return inverted_lags_.data(); }
|
||||
int size() const { return num_entries_; }
|
||||
|
||||
private:
|
||||
std::array<int, kMaxPitchPeriods24kHz> inverted_lags_;
|
||||
int num_entries_;
|
||||
};
|
||||
|
||||
// Computes the auto correlation coefficients for the inverted lags in the
|
||||
// closed interval `inverted_lags`. Updates `inverted_lags_index` by appending
|
||||
// the inverted lags for the computed auto correlation values.
|
||||
void ComputeAutoCorrelation(
|
||||
Range inverted_lags,
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation,
|
||||
InvertedLagsIndex& inverted_lags_index,
|
||||
const VectorMath& vector_math) {
|
||||
// Check valid range.
|
||||
RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max);
|
||||
// Trick to avoid zero initialization of `auto_correlation`.
|
||||
// Needed by the pseudo-interpolation.
|
||||
if (inverted_lags.min > 0) {
|
||||
auto_correlation[inverted_lags.min - 1] = 0.f;
|
||||
}
|
||||
if (inverted_lags.max < kInitialNumLags24kHz - 1) {
|
||||
auto_correlation[inverted_lags.max + 1] = 0.f;
|
||||
}
|
||||
// Check valid `inverted_lag` indexes.
|
||||
RTC_DCHECK_GE(inverted_lags.min, 0);
|
||||
RTC_DCHECK_LT(inverted_lags.max, kInitialNumLags24kHz);
|
||||
for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max;
|
||||
++inverted_lag) {
|
||||
auto_correlation[inverted_lag] =
|
||||
ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math);
|
||||
inverted_lags_index.Append(inverted_lag);
|
||||
}
|
||||
}
|
||||
|
||||
// Searches the strongest pitch period at 24 kHz and returns its inverted lag at
|
||||
// 48 kHz.
|
||||
int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const int> inverted_lags,
|
||||
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
const VectorMath& vector_math) {
|
||||
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
|
||||
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
|
||||
int best_inverted_lag = 0; // Pitch period.
|
||||
float best_numerator = -1.f; // Pitch strength numerator.
|
||||
float best_denominator = 0.f; // Pitch strength denominator.
|
||||
for (int inverted_lag : inverted_lags) {
|
||||
// A pitch candidate must have positive correlation.
|
||||
if (auto_correlation[inverted_lag] > 0.f) {
|
||||
// Auto-correlation energy normalized by frame energy.
|
||||
const float numerator =
|
||||
auto_correlation[inverted_lag] * auto_correlation[inverted_lag];
|
||||
const float denominator = y_energy[inverted_lag];
|
||||
// Compare numerator/denominator ratios without using divisions.
|
||||
if (numerator * best_denominator > best_numerator * denominator) {
|
||||
best_inverted_lag = inverted_lag;
|
||||
best_numerator = numerator;
|
||||
best_denominator = denominator;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Pseudo-interpolation to transform `best_inverted_lag` (24 kHz pitch) to a
|
||||
// 48 kHz pitch period.
|
||||
if (best_inverted_lag == 0 || best_inverted_lag >= kInitialNumLags24kHz - 1) {
|
||||
// Cannot apply pseudo-interpolation at the boundaries.
|
||||
return best_inverted_lag * 2;
|
||||
}
|
||||
int offset = GetPitchPseudoInterpolationOffset(
|
||||
auto_correlation[best_inverted_lag + 1],
|
||||
auto_correlation[best_inverted_lag],
|
||||
auto_correlation[best_inverted_lag - 1]);
|
||||
// TODO(bugs.webrtc.org/9076): When retraining, check if `offset` below should
|
||||
// be subtracted since `inverted_lag` is an inverted lag but offset is a lag.
|
||||
return 2 * best_inverted_lag + offset;
|
||||
}
|
||||
|
||||
// Returns an alternative pitch period for `pitch_period` given a `multiplier`
|
||||
// and a `divisor` of the period.
|
||||
constexpr int GetAlternativePitchPeriod(int pitch_period,
|
||||
int multiplier,
|
||||
int divisor) {
|
||||
RTC_DCHECK_GT(divisor, 0);
|
||||
// Same as `round(multiplier * pitch_period / divisor)`.
|
||||
return (2 * multiplier * pitch_period + divisor) / (2 * divisor);
|
||||
}
|
||||
|
||||
// Returns true if the alternative pitch period is stronger than the initial one
|
||||
// given the last estimated pitch and the value of `period_divisor` used to
|
||||
// compute the alternative pitch period via `GetAlternativePitchPeriod()`.
|
||||
bool IsAlternativePitchStrongerThanInitial(PitchInfo last,
|
||||
PitchInfo initial,
|
||||
PitchInfo alternative,
|
||||
int period_divisor) {
|
||||
// Initial pitch period candidate thresholds for a sample rate of 24 kHz.
|
||||
// Computed as [5*k*k for k in range(16)].
|
||||
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
|
||||
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
|
||||
static_assert(
|
||||
kInitialPitchPeriodThresholds.size() == kSubHarmonicMultipliers.size(),
|
||||
"");
|
||||
RTC_DCHECK_GE(last.period, 0);
|
||||
RTC_DCHECK_GE(initial.period, 0);
|
||||
RTC_DCHECK_GE(alternative.period, 0);
|
||||
RTC_DCHECK_GE(period_divisor, 2);
|
||||
// Compute a term that lowers the threshold when `alternative.period` is close
|
||||
// to the last estimated period `last.period` - i.e., pitch tracking.
|
||||
float lower_threshold_term = 0.f;
|
||||
if (std::abs(alternative.period - last.period) <= 1) {
|
||||
// The candidate pitch period is within 1 sample from the last one.
|
||||
// Make the candidate at `alternative.period` very easy to be accepted.
|
||||
lower_threshold_term = last.strength;
|
||||
} else if (std::abs(alternative.period - last.period) == 2 &&
|
||||
initial.period >
|
||||
kInitialPitchPeriodThresholds[period_divisor - 2]) {
|
||||
// The candidate pitch period is 2 samples far from the last one and the
|
||||
// period `initial.period` (from which `alternative.period` has been
|
||||
// derived) is greater than a threshold. Make `alternative.period` easy to
|
||||
// be accepted.
|
||||
lower_threshold_term = 0.5f * last.strength;
|
||||
}
|
||||
// Set the threshold based on the strength of the initial estimate
|
||||
// `initial.period`. Also reduce the chance of false positives caused by a
|
||||
// bias towards high frequencies (originating from short-term correlations).
|
||||
float threshold =
|
||||
std::max(0.3f, 0.7f * initial.strength - lower_threshold_term);
|
||||
if (alternative.period < 3 * kMinPitch24kHz) {
|
||||
// High frequency.
|
||||
threshold = std::max(0.4f, 0.85f * initial.strength - lower_threshold_term);
|
||||
} else if (alternative.period < 2 * kMinPitch24kHz) {
|
||||
// Even higher frequency.
|
||||
threshold = std::max(0.5f, 0.9f * initial.strength - lower_threshold_term);
|
||||
}
|
||||
return alternative.strength > threshold;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Copies every second sample of `src` into `dst`, i.e., `dst[i] = src[2 * i]`.
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
                rtc::ArrayView<float, kBufSize12kHz> dst) {
  // TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter.
  static_assert(2 * kBufSize12kHz == kBufSize24kHz, "");
  int src_index = 0;
  for (float& out : dst) {
    out = src[src_index];
    src_index += 2;
  }
}
|
||||
|
||||
// Fills `y_energy` with the energy of a 20 ms frame that slides over
// `pitch_buffer` one sample at a time. `y_energy[k]` is the energy of the
// frame starting at index `k`. The first value is computed with a dot product,
// the following ones recursively; the recursively computed values are
// clamped to be at least 1.
void ComputeSlidingFrameSquareEnergies24kHz(
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
    rtc::ArrayView<float, kRefineNumLags24kHz> y_energy,
    AvailableCpuFeatures cpu_features) {
  static_assert(kFrameSize20ms24kHz < kBufSize24kHz, "");
  static_assert(kMaxPitch24kHz - 1 + kFrameSize20ms24kHz < kBufSize24kHz, "");
  static_assert(kMaxPitch24kHz < kRefineNumLags24kHz, "");

  VectorMath vector_math(cpu_features);
  const auto first_frame = pitch_buffer.subview(0, kFrameSize20ms24kHz);
  float energy = vector_math.DotProduct(first_frame, first_frame);
  y_energy[0] = energy;

  for (int start = 1; start <= kMaxPitch24kHz; ++start) {
    // Moving the frame by one sample drops its oldest sample and includes
    // the next one.
    const float leaving = pitch_buffer[start - 1];
    const float entering = pitch_buffer[start - 1 + kFrameSize20ms24kHz];
    energy -= leaving * leaving;
    energy += entering * entering;
    energy = std::max(1.f, energy);
    y_energy[start] = energy;
  }
}
|
||||
|
||||
// Finds the two strongest pitch candidates at 12 kHz and returns their
// inverted lags. The strength of a candidate is the squared auto-correlation
// divided by the energy of the corresponding sliding frame. If no lag has a
// positive auto-correlation, the inverted lags 0 and 1 are returned.
CandidatePitchPeriods ComputePitchPeriod12kHz(
    rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
    rtc::ArrayView<const float, kNumLags12kHz> auto_correlation,
    AvailableCpuFeatures cpu_features) {
  static_assert(kMaxPitch12kHz > kNumLags12kHz, "");
  static_assert(kMaxPitch12kHz < kBufSize12kHz, "");
  static_assert(kFrameSize20ms12kHz + 1 < kBufSize12kHz, "");

  // Pitch candidate: inverted lag plus strength encoded as a fraction.
  struct Candidate {
    int inverted_lag;
    float numerator;
    float denominator;
  };
  // Returns true if `a` is stronger than `b`; the two fractions are compared
  // by cross-multiplication to avoid divisions.
  const auto is_stronger = [](const Candidate& a, const Candidate& b) {
    return a.numerator * b.denominator > b.numerator * a.denominator;
  };
  Candidate best{0, -1.f, 0.f};
  Candidate runner_up{1, -1.f, 0.f};

  VectorMath vector_math(cpu_features);
  const auto first_frame = pitch_buffer.subview(0, kFrameSize20ms12kHz + 1);
  float frame_energy = 1.f + vector_math.DotProduct(first_frame, first_frame);

  for (int lag = 0; lag < kNumLags12kHz; ++lag) {
    const float xy = auto_correlation[lag];
    // Only positively correlated lags qualify as pitch candidates.
    if (xy > 0.f) {
      const Candidate candidate{lag, xy * xy, frame_energy};
      if (is_stronger(candidate, runner_up)) {
        if (is_stronger(candidate, best)) {
          runner_up = best;
          best = candidate;
        } else {
          runner_up = candidate;
        }
      }
    }
    // Slide the frame energy to the next inverted lag: remove the sample that
    // leaves the frame, add the one that enters and clamp at zero.
    const float leaving = pitch_buffer[lag];
    const float entering = pitch_buffer[lag + kFrameSize20ms12kHz];
    frame_energy -= leaving * leaving;
    frame_energy += entering * entering;
    frame_energy = std::max(0.f, frame_energy);
  }
  return {best.inverted_lag, runner_up.inverted_lag};
}
|
||||
|
||||
int ComputePitchPeriod48kHz(
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
    rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
    CandidatePitchPeriods pitch_candidates,
    AvailableCpuFeatures cpu_features) {
  // The auto-correlation is only evaluated in the neighborhoods of the best
  // and second best candidates; `inverted_lags_index` keeps track of the
  // inverted lags for which it has been computed.
  std::array<float, kInitialNumLags24kHz> auto_correlation;
  InvertedLagsIndex inverted_lags_index;

  // Build the two neighborhoods so that `lower_range` precedes `upper_range`.
  const int lower_candidate =
      std::min(pitch_candidates.best, pitch_candidates.second_best);
  const int upper_candidate =
      std::max(pitch_candidates.best, pitch_candidates.second_best);
  const Range lower_range = CreateInvertedLagRange(lower_candidate);
  const Range upper_range = CreateInvertedLagRange(upper_candidate);
  // Both ranges must be valid.
  RTC_DCHECK_LE(lower_range.min, lower_range.max);
  RTC_DCHECK_LE(upper_range.min, upper_range.max);
  // And they must be ordered.
  RTC_DCHECK_LE(lower_range.min, upper_range.min);
  RTC_DCHECK_LE(lower_range.max, upper_range.max);

  VectorMath vector_math(cpu_features);
  const bool disjoint = lower_range.max + 1 < upper_range.min;
  if (disjoint) {
    ComputeAutoCorrelation(lower_range, pitch_buffer, auto_correlation,
                           inverted_lags_index, vector_math);
    ComputeAutoCorrelation(upper_range, pitch_buffer, auto_correlation,
                           inverted_lags_index, vector_math);
  } else {
    // Overlapping or adjacent ranges are processed as a single merged range.
    ComputeAutoCorrelation({lower_range.min, upper_range.max}, pitch_buffer,
                           auto_correlation, inverted_lags_index, vector_math);
  }
  return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index,
                                 auto_correlation, y_energy, vector_math);
}
|
||||
|
||||
PitchInfo ComputeExtendedPitchPeriod48kHz(
    rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
    rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
    int initial_pitch_period_48kHz,
    PitchInfo last_pitch_48kHz,
    AvailableCpuFeatures cpu_features) {
  RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
  RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);

  // Refined pitch candidate. On top of period and strength it carries the
  // terms needed to compute the final pitch strength.
  struct RefinedPitchCandidate {
    int period;
    float strength;
    float xy;        // Auto-correlation.
    float y_energy;  // Energy of the sliding frame `y`.
  };

  VectorMath vector_math(cpu_features);
  const float x_energy = y_energy[kMaxPitch24kHz];
  // Auto-correlation `xy` normalized by the energies of the two frames.
  const auto normalized_strength = [x_energy](float xy, float yy) {
    RTC_DCHECK_GE(x_energy * yy, 0.f);
    return xy / std::sqrt(1.f + x_energy * yy);
  };

  // Seed the search with the initial estimate expressed at 24 kHz.
  const int initial_period_24kHz =
      std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
  const int initial_inverted_lag = kMaxPitch24kHz - initial_period_24kHz;
  const float initial_xy =
      ComputeAutoCorrelation(initial_inverted_lag, pitch_buffer, vector_math);
  const float initial_yy = y_energy[initial_inverted_lag];
  RefinedPitchCandidate best_pitch{
      initial_period_24kHz, normalized_strength(initial_xy, initial_yy),
      initial_xy, initial_yy};
  // Snapshot of the initial candidate; `best_pitch` may change below.
  const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
  // Last estimated pitch expressed at 24 kHz.
  const PitchInfo last_pitch{last_pitch_48kHz.period / 2,
                             last_pitch_48kHz.strength};

  // Largest divisor for which
  // `GetAlternativePitchPeriod(initial_pitch.period, 1, divisor)` does not
  // fall below `kMinPitch24kHz`.
  const int max_period_divisor =
      (2 * initial_pitch.period) / (2 * kMinPitch24kHz - 1);
  for (int divisor = 2; divisor <= max_period_divisor; ++divisor) {
    const int alternative_period = GetAlternativePitchPeriod(
        initial_pitch.period, /*multiplier=*/1, divisor);
    RTC_DCHECK_GE(alternative_period, kMinPitch24kHz);
    // Each alternative period is evaluated together with one of its
    // sub-harmonics, selected via `kSubHarmonicMultipliers`. For a divisor
    // equal to 2 the sub-harmonic may exceed the maximum pitch period; in that
    // case the initial period is used in its place.
    int sub_harmonic_period = GetAlternativePitchPeriod(
        initial_pitch.period, kSubHarmonicMultipliers[divisor - 2], divisor);
    RTC_DCHECK_GT(sub_harmonic_period, 0);
    if (divisor == 2 && sub_harmonic_period > kMaxPitch24kHz) {
      sub_harmonic_period = initial_pitch.period;
    }
    RTC_DCHECK_NE(alternative_period, sub_harmonic_period)
        << "The lower pitch period and the additional sub-harmonic must not "
           "coincide.";

    // Score the candidate by averaging the auto-correlation and the frame
    // energy over the alternative period and its sub-harmonic.
    const int alternative_inverted_lag = kMaxPitch24kHz - alternative_period;
    const int sub_harmonic_inverted_lag = kMaxPitch24kHz - sub_harmonic_period;
    const float xy_alternative = ComputeAutoCorrelation(
        alternative_inverted_lag, pitch_buffer, vector_math);
    // TODO(webrtc:10480): Copy `xy_alternative` if the secondary period is
    // equal to the primary one.
    const float xy_sub_harmonic = ComputeAutoCorrelation(
        sub_harmonic_inverted_lag, pitch_buffer, vector_math);
    const float xy = 0.5f * (xy_alternative + xy_sub_harmonic);
    const float yy = 0.5f * (y_energy[alternative_inverted_lag] +
                             y_energy[sub_harmonic_inverted_lag]);
    const PitchInfo alternative_pitch{alternative_period,
                                      normalized_strength(xy, yy)};

    // Maybe update the best candidate.
    if (IsAlternativePitchStrongerThanInitial(last_pitch, initial_pitch,
                                              alternative_pitch, divisor)) {
      best_pitch = {alternative_period, alternative_pitch.strength, xy, yy};
    }
  }

  // Final pitch strength: ratio between the non-negative auto-correlation and
  // the frame energy, capped by the strength of the best candidate.
  best_pitch.xy = std::max(0.f, best_pitch.xy);
  RTC_DCHECK_LE(0.f, best_pitch.y_energy);
  float final_pitch_strength;
  if (best_pitch.y_energy <= best_pitch.xy) {
    final_pitch_strength = 1.f;
  } else {
    final_pitch_strength = best_pitch.xy / (best_pitch.y_energy + 1.f);
  }
  final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength);

  // Final pitch period at 48 kHz, never below `kMinPitch48kHz`.
  const int final_pitch_period_48kHz = std::max(
      kMinPitch48kHz, PitchPseudoInterpolationLagPitchBuf(
                          best_pitch.period, pitch_buffer, vector_math));
  return {final_pitch_period_48kHz, final_pitch_strength};
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,114 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <array>
|
||||
#include <utility>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Performs 2x decimation without any anti-aliasing filter.
|
||||
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
|
||||
rtc::ArrayView<float, kBufSize12kHz> dst);
|
||||
|
||||
// Key concepts and keywords used below in this file.
|
||||
//
|
||||
// The pitch estimation relies on a pitch buffer, which is an array-like data
|
||||
// structure designed as follows:
|
||||
//
|
||||
// |....A....|.....B.....|
|
||||
//
|
||||
// The part on the left, named `A` contains the oldest samples, whereas `B`
|
||||
// contains the most recent ones. The size of `A` corresponds to the maximum
|
||||
// pitch period, that of `B` to the analysis frame size (e.g., 16 ms and 20 ms
|
||||
// respectively).
|
||||
//
|
||||
// Pitch estimation is essentially based on the analysis of two 20 ms frames
|
||||
// extracted from the pitch buffer. One frame, called `x`, is kept fixed and
|
||||
// corresponds to `B` - i.e., the most recent 20 ms. The other frame, called
|
||||
// `y`, is extracted from different parts of the buffer instead.
|
||||
//
|
||||
// The offset between `x` and `y` corresponds to a specific pitch period.
|
||||
// For instance, if `y` is positioned at the beginning of the pitch buffer, then
|
||||
// the cross-correlation between `x` and `y` can be used as an indication of the
|
||||
// strength for the maximum pitch.
|
||||
//
|
||||
// Such an offset can be encoded in two ways:
|
||||
// - As a lag, which is the index in the pitch buffer for the first item in `y`
|
||||
// - As an inverted lag, which is the number of samples from the beginning of
|
||||
// `x` and the end of `y`
|
||||
//
|
||||
// |---->| lag
|
||||
// |....A....|.....B.....|
|
||||
// |<--| inverted lag
|
||||
// |.....y.....| `y` 20 ms frame
|
||||
//
|
||||
// The inverted lag has the advantage of being directly proportional to the
|
||||
// corresponding pitch period.
|
||||
|
||||
// Computes the sum of squared samples for every sliding frame `y` in the pitch
|
||||
// buffer. The indexes of `y_energy` are inverted lags.
|
||||
void ComputeSlidingFrameSquareEnergies24kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
// Top-2 pitch period candidates. Unit: number of samples - i.e., inverted lags.
// Returned by `ComputePitchPeriod12kHz()` and consumed by
// `ComputePitchPeriod48kHz()`.
|
||||
struct CandidatePitchPeriods {
|
||||
// Inverted lag of the strongest candidate.
int best;
|
||||
// Inverted lag of the second strongest candidate.
int second_best;
|
||||
};
|
||||
|
||||
// Computes the candidate pitch periods at 12 kHz given a view on the 12 kHz
|
||||
// pitch buffer and the auto-correlation values (having inverted lags as
|
||||
// indexes).
|
||||
CandidatePitchPeriods ComputePitchPeriod12kHz(
|
||||
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer,
|
||||
// the energies for the sliding frames `y` at 24 kHz and the pitch period
|
||||
// candidates at 24 kHz (encoded as inverted lag).
|
||||
int ComputePitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
CandidatePitchPeriods pitch_candidates_24kHz,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
// Pitch estimate: a period and the corresponding strength.
struct PitchInfo {
|
||||
// Pitch period in number of samples. The sample rate depends on the context
// (e.g., `ComputeExtendedPitchPeriod48kHz()` returns a 48 kHz period).
int period;
|
||||
// Pitch strength.
float strength;
|
||||
};
|
||||
|
||||
// Computes the pitch period at 48 kHz searching in an extended pitch range
|
||||
// given a view on the 24 kHz pitch buffer, the energies for the sliding frames
|
||||
// `y` at 24 kHz, the initial 48 kHz estimation (computed by
|
||||
// `ComputePitchPeriod48kHz()`) and the last estimated pitch.
|
||||
PitchInfo ComputeExtendedPitchPeriod48kHz(
|
||||
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
|
||||
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
|
||||
int initial_pitch_period_48kHz,
|
||||
PitchInfo last_pitch_48kHz,
|
||||
AvailableCpuFeatures cpu_features);
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_
|
||||
|
|
@ -0,0 +1,65 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
|
||||
#include "api/array_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Ring buffer for N arrays of type T each one with size S.
|
||||
template <typename T, int S, int N>
|
||||
class RingBuffer {
|
||||
static_assert(S > 0, "");
|
||||
static_assert(N > 0, "");
|
||||
static_assert(std::is_arithmetic<T>::value,
|
||||
"Integral or floating point required.");
|
||||
|
||||
public:
|
||||
RingBuffer() : tail_(0) {}
|
||||
RingBuffer(const RingBuffer&) = delete;
|
||||
RingBuffer& operator=(const RingBuffer&) = delete;
|
||||
~RingBuffer() = default;
|
||||
// Set the ring buffer values to zero.
|
||||
void Reset() { buffer_.fill(0); }
|
||||
// Replace the least recently pushed array in the buffer with `new_values`.
|
||||
void Push(rtc::ArrayView<const T, S> new_values) {
|
||||
std::memcpy(buffer_.data() + S * tail_, new_values.data(), S * sizeof(T));
|
||||
tail_ += 1;
|
||||
if (tail_ == N)
|
||||
tail_ = 0;
|
||||
}
|
||||
// Return an array view onto the array with a given delay. A view on the last
|
||||
// and least recently push array is returned when `delay` is 0 and N - 1
|
||||
// respectively.
|
||||
rtc::ArrayView<const T, S> GetArrayView(int delay) const {
|
||||
RTC_DCHECK_LE(0, delay);
|
||||
RTC_DCHECK_LT(delay, N);
|
||||
int offset = tail_ - 1 - delay;
|
||||
if (offset < 0)
|
||||
offset += N;
|
||||
return {buffer_.data() + S * offset, S};
|
||||
}
|
||||
|
||||
private:
|
||||
int tail_; // Index of the least recently pushed sub-array.
|
||||
std::array<T, S * N> buffer_{};
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
|
||||
|
|
@ -0,0 +1,91 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
using ::rnnoise::kInputLayerInputSize;
|
||||
static_assert(kFeatureVectorSize == kInputLayerInputSize, "");
|
||||
using ::rnnoise::kInputDenseBias;
|
||||
using ::rnnoise::kInputDenseWeights;
|
||||
using ::rnnoise::kInputLayerOutputSize;
|
||||
static_assert(kInputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
|
||||
|
||||
using ::rnnoise::kHiddenGruBias;
|
||||
using ::rnnoise::kHiddenGruRecurrentWeights;
|
||||
using ::rnnoise::kHiddenGruWeights;
|
||||
using ::rnnoise::kHiddenLayerOutputSize;
|
||||
static_assert(kHiddenLayerOutputSize <= kGruLayerMaxUnits, "");
|
||||
|
||||
using ::rnnoise::kOutputDenseBias;
|
||||
using ::rnnoise::kOutputDenseWeights;
|
||||
using ::rnnoise::kOutputLayerOutputSize;
|
||||
static_assert(kOutputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
|
||||
|
||||
} // namespace
|
||||
|
||||
// Builds the three layers of the network - fully-connected input, GRU hidden
// and fully-connected output - from the rnnoise weights and checks that the
// size of each layer output matches the input size of the following layer.
RnnVad::RnnVad(const AvailableCpuFeatures& cpu_features)
    : input_(kInputLayerInputSize,
             kInputLayerOutputSize,
             kInputDenseBias,
             kInputDenseWeights,
             ActivationFunction::kTansigApproximated,
             cpu_features,
             /*layer_name=*/"FC1"),
      hidden_(kInputLayerOutputSize,
              kHiddenLayerOutputSize,
              kHiddenGruBias,
              kHiddenGruWeights,
              kHiddenGruRecurrentWeights,
              cpu_features,
              /*layer_name=*/"GRU1"),
      // The output layer is just 24x1. The unoptimized code is faster, hence
      // no CPU features are passed.
      output_(kHiddenLayerOutputSize,
              kOutputLayerOutputSize,
              kOutputDenseBias,
              kOutputDenseWeights,
              ActivationFunction::kSigmoidApproximated,
              NoAvailableCpuFeatures(),
              /*layer_name=*/"FC2") {
  // Input -> hidden chaining.
  RTC_DCHECK_EQ(input_.size(), hidden_.input_size())
      << "The input and the hidden layers sizes do not match.";
  // Hidden -> output chaining.
  RTC_DCHECK_EQ(hidden_.size(), output_.input_size())
      << "The hidden and the output layers sizes do not match.";
}
|
||||
|
||||
// Destructor declared in rnn.h and defaulted here, out of line.
RnnVad::~RnnVad() = default;
|
||||
|
||||
// Resets the hidden (GRU) layer, which is the only layer that is reset here.
void RnnVad::Reset() {
  hidden_.Reset();
}
|
||||
|
||||
// Feeds `feature_vector` through the input, hidden and output layers and
// returns the single value produced by the output layer. When `is_silence` is
// true the network is not evaluated: the RNN is reset and 0 is returned.
float RnnVad::ComputeVadProbability(
    rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
    bool is_silence) {
  if (!is_silence) {
    // Each layer consumes the output of the previous one.
    input_.ComputeOutput(feature_vector);
    hidden_.ComputeOutput(input_);
    output_.ComputeOutput(hidden_);
    RTC_DCHECK_EQ(output_.size(), 1);
    return *output_.data();
  }
  Reset();
  return 0.f;
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,53 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
|
||||
|
||||
#include <stddef.h>
|
||||
#include <sys/types.h>
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Recurrent network with hard-coded architecture and weights for voice activity
|
||||
// detection. It chains a fully-connected input layer, a gated recurrent hidden
// layer and a fully-connected output layer. Not copyable.
|
||||
class RnnVad {
|
||||
public:
|
||||
// Creates the layers; `cpu_features` is forwarded to the input and hidden
// layers.
explicit RnnVad(const AvailableCpuFeatures& cpu_features);
|
||||
RnnVad(const RnnVad&) = delete;
|
||||
RnnVad& operator=(const RnnVad&) = delete;
|
||||
~RnnVad();
|
||||
// Resets the hidden (recurrent) layer.
void Reset();
|
||||
// Observes `feature_vector` and `is_silence`, updates the RNN and returns the
|
||||
// current voice probability. If `is_silence` is true, the RNN is reset and 0
// is returned.
|
||||
float ComputeVadProbability(
|
||||
rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
|
||||
bool is_silence);
|
||||
|
||||
private:
|
||||
// Layers listed in evaluation order.
FullyConnectedLayer input_;
|
||||
GatedRecurrentLayer hidden_;
|
||||
FullyConnectedLayer output_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
|
||||
|
|
@ -0,0 +1,104 @@
|
|||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Converts the quantized `params` to floating point by multiplying each value
// by `::rnnoise::kWeightsScale`. The order of the items is preserved.
std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
  std::vector<float> scaled;
  scaled.reserve(params.size());
  for (const int8_t quantized : params) {
    scaled.push_back(::rnnoise::kWeightsScale * static_cast<float>(quantized));
  }
  return scaled;
}
|
||||
|
||||
// TODO(bugs.chromium.org/10480): Hard-code optimized layout and remove this
|
||||
// function to improve setup time.
|
||||
// Casts and scales `weights` and re-arranges the layout.
|
||||
std::vector<float> PreprocessWeights(rtc::ArrayView<const int8_t> weights,
|
||||
int output_size) {
|
||||
if (output_size == 1) {
|
||||
return GetScaledParams(weights);
|
||||
}
|
||||
// Transpose, scale and cast.
|
||||
const int input_size = rtc::CheckedDivExact(
|
||||
rtc::dchecked_cast<int>(weights.size()), output_size);
|
||||
std::vector<float> w(weights.size());
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
for (int i = 0; i < input_size; ++i) {
|
||||
w[o * input_size + i] = rnnoise::kWeightsScale *
|
||||
static_cast<float>(weights[i * output_size + o]);
|
||||
}
|
||||
}
|
||||
return w;
|
||||
}
|
||||
|
||||
// Returns the rnnoise implementation associated to `activation_function`.
// Crashes if `activation_function` is not one of the `ActivationFunction`
// enumerators.
rtc::FunctionView<float(float)> GetActivationFunction(
    ActivationFunction activation_function) {
  switch (activation_function) {
    case ActivationFunction::kTansigApproximated:
      return ::rnnoise::TansigApproximated;
    case ActivationFunction::kSigmoidApproximated:
      return ::rnnoise::SigmoidApproximated;
  }
  // No default case above so that the compiler can flag enumerators that are
  // not handled. Without this check, an out-of-range value would flow off the
  // end of a non-void function, which is undefined behavior.
  RTC_CHECK_NOTREACHED();
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Stores the layer sizes, converts `bias` and `weights` to scaled floating
// point values (the weights are also re-arranged, see `PreprocessWeights()`)
// and selects the activation function. `layer_name` is only used in the
// messages of the debug checks.
FullyConnectedLayer::FullyConnectedLayer(
    const int input_size,
    const int output_size,
    const rtc::ArrayView<const int8_t> bias,
    const rtc::ArrayView<const int8_t> weights,
    ActivationFunction activation_function,
    const AvailableCpuFeatures& cpu_features,
    absl::string_view layer_name)
    : input_size_(input_size),
      output_size_(output_size),
      bias_(GetScaledParams(bias)),
      weights_(PreprocessWeights(weights, output_size)),
      vector_math_(cpu_features),
      activation_function_(GetActivationFunction(activation_function)) {
  // The output buffer has a fixed capacity.
  RTC_DCHECK_LE(output_size_, kFullyConnectedLayerMaxUnits)
      << "Insufficient FC layer over-allocation (" << layer_name << ").";
  // One bias term per output unit.
  RTC_DCHECK_EQ(output_size_, bias_.size())
      << "Mismatching output size and bias terms array size (" << layer_name
      << ").";
  // One weight per (input, output) pair.
  RTC_DCHECK_EQ(input_size_ * output_size_, weights_.size())
      << "Mismatching input-output size and weight coefficients array size ("
      << layer_name << ").";
}
|
||||
|
||||
// Destructor defaulted out of line.
FullyConnectedLayer::~FullyConnectedLayer() = default;
|
||||
|
||||
// Computes, for each output unit, the activation function applied to the bias
// plus the dot product between `input` and the weights of that unit. The
// results are written into the internal output buffer. The size of `input`
// must equal the layer input size.
void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
  RTC_DCHECK_EQ(input.size(), input_size_);
  const rtc::ArrayView<const float> all_weights(weights_);
  for (int unit = 0; unit < output_size_; ++unit) {
    // The weights of a unit are stored contiguously.
    const auto unit_weights =
        all_weights.subview(unit * input_size_, input_size_);
    const float pre_activation =
        bias_[unit] + vector_math_.DotProduct(input, unit_weights);
    output_[unit] = activation_function_(pre_activation);
  }
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,72 @@
|
|||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "api/array_view.h"
|
||||
#include "api/function_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Activation function for a neural network cell.
|
||||
// Passed to the `FullyConnectedLayer` constructor to select the activation.
enum class ActivationFunction { kTansigApproximated, kSigmoidApproximated };
|
||||
|
||||
// Maximum number of units for an FC layer.
|
||||
// Also the compile-time capacity of the layer's fixed-size output array.
constexpr int kFullyConnectedLayerMaxUnits = 24;
|
||||
|
||||
// Fully-connected layer with a custom activation function which owns the output
|
||||
// buffer.
|
||||
class FullyConnectedLayer {
|
||||
public:
|
||||
// Ctor. `output_size` cannot be greater than `kFullyConnectedLayerMaxUnits`.
|
||||
FullyConnectedLayer(int input_size,
|
||||
int output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
ActivationFunction activation_function,
|
||||
const AvailableCpuFeatures& cpu_features,
|
||||
absl::string_view layer_name);
|
||||
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
|
||||
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
|
||||
~FullyConnectedLayer();
|
||||
|
||||
// Returns the size of the input vector.
|
||||
int input_size() const { return input_size_; }
|
||||
// Returns the pointer to the first element of the output buffer.
|
||||
const float* data() const { return output_.data(); }
|
||||
// Returns the size of the output buffer.
|
||||
int size() const { return output_size_; }
|
||||
|
||||
// Computes the fully-connected layer output.
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
const int input_size_;
|
||||
const int output_size_;
|
||||
const std::vector<float> bias_;
|
||||
const std::vector<float> weights_;
|
||||
const VectorMath vector_math_;
|
||||
rtc::FunctionView<float(float)> activation_function_;
|
||||
// Over-allocated array with size equal to `output_size_`.
|
||||
std::array<float, kFullyConnectedLayerMaxUnits> output_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
|
||||
|
|
@ -0,0 +1,198 @@
|
|||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "third_party/rnnoise/src/rnn_activations.h"
|
||||
#include "third_party/rnnoise/src/rnn_vad_weights.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Number of per-gate tensors stored back to back in each GRU parameter array.
constexpr int kNumGruGates = 3; // Update, reset, output.
|
||||
|
||||
// Converts a quantized GRU tensor to floats: the values are transposed, cast
// and scaled by `::rnnoise::kWeightsScale`. The source is indexed as
// [row][gate][output unit] and the result as [gate][output unit][row], where
// `row` spans the first dimension of the source tensor.
std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
                                       int output_size) {
  // Size of the first dimension of the source tensor.
  const int num_rows = rtc::CheckedDivExact(
      rtc::dchecked_cast<int>(tensor_src.size()), output_size * kNumGruGates);
  const int src_row_size = kNumGruGates * output_size;
  const int dst_gate_size = num_rows * output_size;
  std::vector<float> tensor_dst(tensor_src.size());
  for (int gate = 0; gate < kNumGruGates; ++gate) {
    for (int unit = 0; unit < output_size; ++unit) {
      const int dst_offset = gate * dst_gate_size + unit * num_rows;
      const int src_offset = gate * output_size + unit;
      for (int row = 0; row < num_rows; ++row) {
        tensor_dst[dst_offset + row] =
            ::rnnoise::kWeightsScale *
            static_cast<float>(tensor_src[row * src_row_size + src_offset]);
      }
    }
  }
  return tensor_dst;
}
|
||||
|
||||
// Computes the output for the update or the reset gate.
|
||||
// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
|
||||
// - `g`: output gate vector
|
||||
// - `W`: weights matrix
|
||||
// - `i`: input vector
|
||||
// - `R`: recurrent weights matrix
|
||||
// - `s`: state gate vector
|
||||
// - `b`: bias vector
|
||||
void ComputeUpdateResetGate(int input_size,
|
||||
int output_size,
|
||||
const VectorMath& vector_math,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> state,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<float> gate) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_EQ(state.size(), output_size);
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
|
||||
RTC_DCHECK_GE(gate.size(), output_size); // `gate` is over-allocated.
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
float x = bias[o];
|
||||
x += vector_math.DotProduct(input,
|
||||
weights.subview(o * input_size, input_size));
|
||||
x += vector_math.DotProduct(
|
||||
state, recurrent_weights.subview(o * output_size, output_size));
|
||||
gate[o] = ::rnnoise::SigmoidApproximated(x);
|
||||
}
|
||||
}
|
||||
|
||||
// Computes the output for the state gate.
|
||||
// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
|
||||
// - `s'`: output state gate vector
|
||||
// - `s`: previous state gate vector
|
||||
// - `u`: update gate vector
|
||||
// - `W`: weights matrix
|
||||
// - `i`: input vector
|
||||
// - `R`: recurrent weights matrix
|
||||
// - `r`: reset gate vector
|
||||
// - `b`: bias vector
|
||||
// - `.*` element-wise product
|
||||
void ComputeStateGate(int input_size,
|
||||
int output_size,
|
||||
const VectorMath& vector_math,
|
||||
rtc::ArrayView<const float> input,
|
||||
rtc::ArrayView<const float> update,
|
||||
rtc::ArrayView<const float> reset,
|
||||
rtc::ArrayView<const float> bias,
|
||||
rtc::ArrayView<const float> weights,
|
||||
rtc::ArrayView<const float> recurrent_weights,
|
||||
rtc::ArrayView<float> state) {
|
||||
RTC_DCHECK_EQ(input.size(), input_size);
|
||||
RTC_DCHECK_GE(update.size(), output_size); // `update` is over-allocated.
|
||||
RTC_DCHECK_GE(reset.size(), output_size); // `reset` is over-allocated.
|
||||
RTC_DCHECK_EQ(bias.size(), output_size);
|
||||
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
|
||||
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
|
||||
RTC_DCHECK_EQ(state.size(), output_size);
|
||||
std::array<float, kGruLayerMaxUnits> reset_x_state;
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
reset_x_state[o] = state[o] * reset[o];
|
||||
}
|
||||
for (int o = 0; o < output_size; ++o) {
|
||||
float x = bias[o];
|
||||
x += vector_math.DotProduct(input,
|
||||
weights.subview(o * input_size, input_size));
|
||||
x += vector_math.DotProduct(
|
||||
{reset_x_state.data(), static_cast<size_t>(output_size)},
|
||||
recurrent_weights.subview(o * output_size, output_size));
|
||||
state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Builds a GRU layer from quantized parameters. `bias`, `weights` and
// `recurrent_weights` are converted to floats with `PreprocessGruTensor()`. In
// debug builds the sizes of the converted tensors are checked against
// `input_size` and `output_size`; `layer_name` identifies the layer in the
// failure messages. The state is reset at the end of the construction.
GatedRecurrentLayer::GatedRecurrentLayer(
    int input_size,
    int output_size,
    rtc::ArrayView<const int8_t> bias,
    rtc::ArrayView<const int8_t> weights,
    rtc::ArrayView<const int8_t> recurrent_weights,
    const AvailableCpuFeatures& cpu_features,
    absl::string_view layer_name)
    : input_size_(input_size),
      output_size_(output_size),
      bias_(PreprocessGruTensor(bias, output_size)),
      weights_(PreprocessGruTensor(weights, output_size)),
      recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
      vector_math_(cpu_features) {
  // The state array has a fixed capacity.
  RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
      << "Insufficient GRU layer over-allocation (" << layer_name << ").";
  // One bias vector per gate.
  RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
      << "Mismatching output size and bias terms array size (" << layer_name
      << ").";
  // One input weights matrix per gate.
  RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
      << "Mismatching input-output size and weight coefficients array size ("
      << layer_name << ").";
  // One recurrent weights matrix per gate.
  RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
                recurrent_weights_.size())
      << "Mismatching input-output size and recurrent weight coefficients "
         "array size ("
      << layer_name << ").";
  Reset();
}
|
||||
|
||||
// Defaulted destructor, defined here rather than in the class declaration.
GatedRecurrentLayer::~GatedRecurrentLayer() = default;
|
||||
|
||||
// Clears the recurrent state by zeroing every item of `state_`.
void GatedRecurrentLayer::Reset() {
  for (float& value : state_) {
    value = 0.f;
  }
}
|
||||
|
||||
// Runs one GRU step on `input`: computes the update and the reset gates from
// the current state and then overwrites the first `output_size_` items of
// `state_` with the new state.
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
  RTC_DCHECK_EQ(input.size(), input_size_);

  // Each parameter array stores the flattened tensors of the update, reset and
  // state gates one after the other; the gate index selects the sub-tensor.
  constexpr int kUpdateGate = 0;
  constexpr int kResetGate = 1;
  constexpr int kStateGate = 2;
  const rtc::ArrayView<const float> all_bias(bias_);
  const rtc::ArrayView<const float> all_weights(weights_);
  const rtc::ArrayView<const float> all_recurrent_weights(recurrent_weights_);
  // Number of items of a single gate in the flattened weights tensors.
  const int weights_size = input_size_ * output_size_;
  const int recurrent_weights_size = output_size_ * output_size_;

  rtc::ArrayView<float> state(state_.data(), output_size_);

  // Update gate.
  std::array<float, kGruLayerMaxUnits> update_gate;
  ComputeUpdateResetGate(
      input_size_, output_size_, vector_math_, input, state,
      all_bias.subview(kUpdateGate * output_size_, output_size_),
      all_weights.subview(kUpdateGate * weights_size, weights_size),
      all_recurrent_weights.subview(kUpdateGate * recurrent_weights_size,
                                    recurrent_weights_size),
      update_gate);

  // Reset gate.
  std::array<float, kGruLayerMaxUnits> reset_gate;
  ComputeUpdateResetGate(
      input_size_, output_size_, vector_math_, input, state,
      all_bias.subview(kResetGate * output_size_, output_size_),
      all_weights.subview(kResetGate * weights_size, weights_size),
      all_recurrent_weights.subview(kResetGate * recurrent_weights_size,
                                    recurrent_weights_size),
      reset_gate);

  // State gate.
  ComputeStateGate(
      input_size_, output_size_, vector_math_, input, update_gate, reset_gate,
      all_bias.subview(kStateGate * output_size_, output_size_),
      all_weights.subview(kStateGate * weights_size, weights_size),
      all_recurrent_weights.subview(kStateGate * recurrent_weights_size,
                                    recurrent_weights_size),
      state);
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,70 @@
|
|||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Maximum number of units for a GRU layer.
|
||||
// Also the compile-time capacity of the layer's fixed-size state array.
constexpr int kGruLayerMaxUnits = 24;
|
||||
|
||||
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
|
||||
// activation functions for the update/reset and output gates respectively.
|
||||
class GatedRecurrentLayer {
|
||||
public:
|
||||
// Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
|
||||
GatedRecurrentLayer(int input_size,
|
||||
int output_size,
|
||||
rtc::ArrayView<const int8_t> bias,
|
||||
rtc::ArrayView<const int8_t> weights,
|
||||
rtc::ArrayView<const int8_t> recurrent_weights,
|
||||
const AvailableCpuFeatures& cpu_features,
|
||||
absl::string_view layer_name);
|
||||
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
|
||||
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
|
||||
~GatedRecurrentLayer();
|
||||
|
||||
// Returns the size of the input vector.
|
||||
int input_size() const { return input_size_; }
|
||||
// Returns the pointer to the first element of the output buffer.
|
||||
const float* data() const { return state_.data(); }
|
||||
// Returns the size of the output buffer.
|
||||
int size() const { return output_size_; }
|
||||
|
||||
// Resets the GRU state.
|
||||
void Reset();
|
||||
// Computes the recurrent layer output and updates the status.
|
||||
void ComputeOutput(rtc::ArrayView<const float> input);
|
||||
|
||||
private:
|
||||
const int input_size_;
|
||||
const int output_size_;
|
||||
const std::vector<float> bias_;
|
||||
const std::vector<float> weights_;
|
||||
const std::vector<float> recurrent_weights_;
|
||||
const VectorMath vector_math_;
|
||||
// Over-allocated array with size equal to `output_size_`.
|
||||
std::array<float, kGruLayerMaxUnits> state_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
|
||||
|
|
@ -0,0 +1,123 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include <array>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/flags/flag.h"
|
||||
#include "absl/flags/parse.h"
|
||||
#include "common_audio/resampler/push_sinc_resampler.h"
|
||||
#include "common_audio/wav_file.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
// Command line flags of the tool: input wav path (`i`), optional features
// output path (`f`, features are only written when it is not empty) and VAD
// probabilities output path (`o`).
ABSL_FLAG(std::string, i, "", "Path to the input wav file");
|
||||
ABSL_FLAG(std::string, f, "", "Path to the output features file");
|
||||
ABSL_FLAG(std::string, o, "", "Path to the output VAD probabilities file");
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace test {
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
absl::ParseCommandLine(argc, argv);
|
||||
rtc::LogMessage::LogToDebug(rtc::LS_INFO);
|
||||
|
||||
// Open wav input file and check properties.
|
||||
const std::string input_wav_file = absl::GetFlag(FLAGS_i);
|
||||
WavReader wav_reader(input_wav_file);
|
||||
if (wav_reader.num_channels() != 1) {
|
||||
RTC_LOG(LS_ERROR) << "Only mono wav files are supported";
|
||||
return 1;
|
||||
}
|
||||
if (wav_reader.sample_rate() % 100 != 0) {
|
||||
RTC_LOG(LS_ERROR) << "The sample rate rate must allow 10 ms frames.";
|
||||
return 1;
|
||||
}
|
||||
RTC_LOG(LS_INFO) << "Input sample rate: " << wav_reader.sample_rate();
|
||||
|
||||
// Init output files.
|
||||
const std::string output_vad_probs_file = absl::GetFlag(FLAGS_o);
|
||||
FILE* vad_probs_file = fopen(output_vad_probs_file.c_str(), "wb");
|
||||
FILE* features_file = nullptr;
|
||||
const std::string output_feature_file = absl::GetFlag(FLAGS_f);
|
||||
if (!output_feature_file.empty()) {
|
||||
features_file = fopen(output_feature_file.c_str(), "wb");
|
||||
}
|
||||
|
||||
// Initialize.
|
||||
const int frame_size_10ms =
|
||||
rtc::CheckedDivExact(wav_reader.sample_rate(), 100);
|
||||
std::vector<float> samples_10ms;
|
||||
samples_10ms.resize(frame_size_10ms);
|
||||
std::array<float, kFrameSize10ms24kHz> samples_10ms_24kHz;
|
||||
PushSincResampler resampler(frame_size_10ms, kFrameSize10ms24kHz);
|
||||
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
|
||||
FeaturesExtractor features_extractor(cpu_features);
|
||||
std::array<float, kFeatureVectorSize> feature_vector;
|
||||
RnnVad rnn_vad(cpu_features);
|
||||
|
||||
// Compute VAD probabilities.
|
||||
while (true) {
|
||||
// Read frame at the input sample rate.
|
||||
const size_t read_samples =
|
||||
wav_reader.ReadSamples(frame_size_10ms, samples_10ms.data());
|
||||
if (rtc::SafeLt(read_samples, frame_size_10ms)) {
|
||||
break; // EOF.
|
||||
}
|
||||
// Resample input.
|
||||
resampler.Resample(samples_10ms.data(), samples_10ms.size(),
|
||||
samples_10ms_24kHz.data(), samples_10ms_24kHz.size());
|
||||
|
||||
// Extract features and feed the RNN.
|
||||
bool is_silence = features_extractor.CheckSilenceComputeFeatures(
|
||||
samples_10ms_24kHz, feature_vector);
|
||||
float vad_probability =
|
||||
rnn_vad.ComputeVadProbability(feature_vector, is_silence);
|
||||
// Write voice probability.
|
||||
RTC_DCHECK_GE(vad_probability, 0.f);
|
||||
RTC_DCHECK_GE(1.f, vad_probability);
|
||||
fwrite(&vad_probability, sizeof(float), 1, vad_probs_file);
|
||||
// Write features.
|
||||
if (features_file) {
|
||||
const float float_is_silence = is_silence ? 1.f : 0.f;
|
||||
fwrite(&float_is_silence, sizeof(float), 1, features_file);
|
||||
if (is_silence) {
|
||||
// Do not write uninitialized values.
|
||||
feature_vector.fill(0.f);
|
||||
}
|
||||
fwrite(feature_vector.data(), sizeof(float), kFeatureVectorSize,
|
||||
features_file);
|
||||
}
|
||||
}
|
||||
|
||||
// Close output file(s).
|
||||
fclose(vad_probs_file);
|
||||
RTC_LOG(LS_INFO) << "VAD probabilities written to " << output_vad_probs_file;
|
||||
if (features_file) {
|
||||
fclose(features_file);
|
||||
RTC_LOG(LS_INFO) << "features written to " << output_feature_file;
|
||||
}
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
||||
} // namespace test
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
int main(int argc, char* argv[]) {
|
||||
return webrtc::rnn_vad::test::main(argc, argv);
|
||||
}
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstring>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Linear buffer implementation to (i) push fixed size chunks of sequential data
|
||||
// and (ii) view contiguous parts of the buffer. The buffer and the pushed
|
||||
// chunks have size S and N respectively. For instance, when S = 2N the first
|
||||
// half of the sequence buffer is replaced with its second half, and the new N
|
||||
// values are written at the end of the buffer.
|
||||
// The class also provides a view on the most recent M values, where 0 < M <= S
|
||||
// and by default M = N.
|
||||
template <typename T, int S, int N, int M = N>
|
||||
class SequenceBuffer {
|
||||
static_assert(N <= S,
|
||||
"The new chunk size cannot be larger than the sequence buffer "
|
||||
"size.");
|
||||
static_assert(std::is_arithmetic<T>::value,
|
||||
"Integral or floating point required.");
|
||||
|
||||
public:
|
||||
SequenceBuffer() : buffer_(S) {
|
||||
RTC_DCHECK_EQ(S, buffer_.size());
|
||||
Reset();
|
||||
}
|
||||
SequenceBuffer(const SequenceBuffer&) = delete;
|
||||
SequenceBuffer& operator=(const SequenceBuffer&) = delete;
|
||||
~SequenceBuffer() = default;
|
||||
int size() const { return S; }
|
||||
int chunks_size() const { return N; }
|
||||
// Sets the sequence buffer values to zero.
|
||||
void Reset() { std::fill(buffer_.begin(), buffer_.end(), 0); }
|
||||
// Returns a view on the whole buffer.
|
||||
rtc::ArrayView<const T, S> GetBufferView() const {
|
||||
return {buffer_.data(), S};
|
||||
}
|
||||
// Returns a view on the M most recent values of the buffer.
|
||||
rtc::ArrayView<const T, M> GetMostRecentValuesView() const {
|
||||
static_assert(M <= S,
|
||||
"The number of most recent values cannot be larger than the "
|
||||
"sequence buffer size.");
|
||||
return {buffer_.data() + S - M, M};
|
||||
}
|
||||
// Shifts left the buffer by N items and add new N items at the end.
|
||||
void Push(rtc::ArrayView<const T, N> new_values) {
|
||||
// Make space for the new values.
|
||||
if (S > N)
|
||||
std::memmove(buffer_.data(), buffer_.data() + N, (S - N) * sizeof(T));
|
||||
// Copy the new values at the end of the buffer.
|
||||
std::memcpy(buffer_.data() + S - N, new_values.data(), N * sizeof(T));
|
||||
}
|
||||
|
||||
private:
|
||||
std::vector<T> buffer_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
|
||||
|
|
@ -0,0 +1,214 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <limits>
|
||||
#include <numeric>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// A reference frame whose total band energy is below this value is reported
// as silence by `SpectralFeaturesExtractor::CheckSilenceComputeFeatures()`.
constexpr float kSilenceThreshold = 0.04f;
|
||||
|
||||
// Computes the new cepstral difference stats and pushes them into the passed
|
||||
// symmetric matrix buffer.
|
||||
void UpdateCepstralDifferenceStats(
|
||||
rtc::ArrayView<const float, kNumBands> new_cepstral_coeffs,
|
||||
const RingBuffer<float, kNumBands, kCepstralCoeffsHistorySize>& ring_buf,
|
||||
SymmetricMatrixBuffer<float, kCepstralCoeffsHistorySize>* sym_matrix_buf) {
|
||||
RTC_DCHECK(sym_matrix_buf);
|
||||
// Compute the new cepstral distance stats.
|
||||
std::array<float, kCepstralCoeffsHistorySize - 1> distances;
|
||||
for (int i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
|
||||
const int delay = i + 1;
|
||||
auto old_cepstral_coeffs = ring_buf.GetArrayView(delay);
|
||||
distances[i] = 0.f;
|
||||
for (int k = 0; k < kNumBands; ++k) {
|
||||
const float c = new_cepstral_coeffs[k] - old_cepstral_coeffs[k];
|
||||
distances[i] += c * c;
|
||||
}
|
||||
}
|
||||
// Push the new spectral distance stats into the symmetric matrix buffer.
|
||||
sym_matrix_buf->Push(distances);
|
||||
}
|
||||
|
||||
// Computes the first half of the Vorbis window.
|
||||
std::array<float, kFrameSize20ms24kHz / 2> ComputeScaledHalfVorbisWindow(
|
||||
float scaling = 1.f) {
|
||||
constexpr int kHalfSize = kFrameSize20ms24kHz / 2;
|
||||
std::array<float, kHalfSize> half_window{};
|
||||
for (int i = 0; i < kHalfSize; ++i) {
|
||||
half_window[i] =
|
||||
scaling *
|
||||
std::sin(0.5 * kPi * std::sin(0.5 * kPi * (i + 0.5) / kHalfSize) *
|
||||
std::sin(0.5 * kPi * (i + 0.5) / kHalfSize));
|
||||
}
|
||||
return half_window;
|
||||
}
|
||||
|
||||
// Computes the forward FFT on a 20 ms frame to which a given window function is
|
||||
// applied. The Fourier coefficient corresponding to the Nyquist frequency is
|
||||
// set to zero (it is never used and this allows to simplify the code).
|
||||
void ComputeWindowedForwardFft(
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> frame,
|
||||
const std::array<float, kFrameSize20ms24kHz / 2>& half_window,
|
||||
Pffft::FloatBuffer* fft_input_buffer,
|
||||
Pffft::FloatBuffer* fft_output_buffer,
|
||||
Pffft* fft) {
|
||||
RTC_DCHECK_EQ(frame.size(), 2 * half_window.size());
|
||||
// Apply windowing.
|
||||
auto in = fft_input_buffer->GetView();
|
||||
for (int i = 0, j = kFrameSize20ms24kHz - 1;
|
||||
rtc::SafeLt(i, half_window.size()); ++i, --j) {
|
||||
in[i] = frame[i] * half_window[i];
|
||||
in[j] = frame[j] * half_window[i];
|
||||
}
|
||||
fft->ForwardTransform(*fft_input_buffer, fft_output_buffer, /*ordered=*/true);
|
||||
// Set the Nyquist frequency coefficient to zero.
|
||||
auto out = fft_output_buffer->GetView();
|
||||
out[1] = 0.f;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Pre-computes the half Vorbis window (scaled by the inverse of the 20 ms
// frame size) and the DCT table, creates a real FFT for 20 ms frames and
// allocates the FFT buffers for the input, the reference frame and the lagged
// frame.
SpectralFeaturesExtractor::SpectralFeaturesExtractor()
    : half_window_(ComputeScaledHalfVorbisWindow(
          1.f / static_cast<float>(kFrameSize20ms24kHz))),
      fft_(kFrameSize20ms24kHz, Pffft::FftType::kReal),
      fft_buffer_(fft_.CreateBuffer()),
      reference_frame_fft_(fft_.CreateBuffer()),
      lagged_frame_fft_(fft_.CreateBuffer()),
      dct_table_(ComputeDctTable()) {
}
|
||||
|
||||
// Defaulted destructor, defined here rather than in the class declaration.
SpectralFeaturesExtractor::~SpectralFeaturesExtractor() = default;
|
||||
|
||||
// Resets the history of the extractor, namely the cepstral coefficients ring
// buffer and the cepstral differences buffer.
void SpectralFeaturesExtractor::Reset() {
  cepstral_coeffs_ring_buf_.Reset();
  cepstral_diffs_buf_.Reset();
}
|
||||
|
||||
// Analyzes `reference_frame` and `lagged_frame`. Returns true, without writing
// any output argument, if the total band energy of the reference frame is
// below `kSilenceThreshold`. Otherwise updates the cepstral history, writes
// all the output arguments and returns false. `variability` must not be null.
bool SpectralFeaturesExtractor::CheckSilenceComputeFeatures(
    rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame,
    rtc::ArrayView<const float, kFrameSize20ms24kHz> lagged_frame,
    rtc::ArrayView<float, kNumBands - kNumLowerBands> higher_bands_cepstrum,
    rtc::ArrayView<float, kNumLowerBands> average,
    rtc::ArrayView<float, kNumLowerBands> first_derivative,
    rtc::ArrayView<float, kNumLowerBands> second_derivative,
    rtc::ArrayView<float, kNumLowerBands> bands_cross_corr,
    float* variability) {
  // Band energies of the reference frame.
  ComputeWindowedForwardFft(reference_frame, half_window_, fft_buffer_.get(),
                            reference_frame_fft_.get(), &fft_);
  spectral_correlator_.ComputeAutoCorrelation(
      reference_frame_fft_->GetConstView(), reference_frame_bands_energy_);

  // Silence detection based on the total energy of the reference frame.
  float total_energy = 0.f;
  for (const float band_energy : reference_frame_bands_energy_) {
    total_energy += band_energy;
  }
  if (total_energy < kSilenceThreshold) {
    return true;
  }

  // Band energies of the lagged frame.
  ComputeWindowedForwardFft(lagged_frame, half_window_, fft_buffer_.get(),
                            lagged_frame_fft_.get(), &fft_);
  spectral_correlator_.ComputeAutoCorrelation(lagged_frame_fft_->GetConstView(),
                                              lagged_frame_bands_energy_);

  // Cepstrum of the reference frame: DCT of the smoothed log band energies.
  std::array<float, kNumBands> log_bands_energy;
  ComputeSmoothedLogMagnitudeSpectrum(reference_frame_bands_energy_,
                                      log_bands_energy);
  std::array<float, kNumBands> cepstrum;
  ComputeDct(log_bands_energy, dct_table_, cepstrum);
  // Ad-hoc correction terms for the first two cepstral coefficients.
  cepstrum[0] -= 12.f;
  cepstrum[1] -= 4.f;

  // Update the ring buffer and the cepstral difference stats.
  cepstral_coeffs_ring_buf_.Push(cepstrum);
  UpdateCepstralDifferenceStats(cepstrum, cepstral_coeffs_ring_buf_,
                                &cepstral_diffs_buf_);

  // Write the higher bands cepstral coefficients.
  RTC_DCHECK_EQ(cepstrum.size() - kNumLowerBands, higher_bands_cepstrum.size());
  for (int band = kNumLowerBands; band < kNumBands; ++band) {
    higher_bands_cepstrum[band - kNumLowerBands] = cepstrum[band];
  }

  // Compute and write the remaining features.
  ComputeAvgAndDerivatives(average, first_derivative, second_derivative);
  ComputeNormalizedCepstralCorrelation(bands_cross_corr);
  RTC_DCHECK(variability);
  *variability = ComputeVariability();
  return false;
}
|
||||
|
||||
// Computes three temporal features for the lower bands from the three most
// recent cepstral coefficients vectors in the ring buffer (delays 0, 1, 2).
// Each feature applies a 3-tap kernel over time; see the comments below.
void SpectralFeaturesExtractor::ComputeAvgAndDerivatives(
    rtc::ArrayView<float, kNumLowerBands> average,
    rtc::ArrayView<float, kNumLowerBands> first_derivative,
    rtc::ArrayView<float, kNumLowerBands> second_derivative) const {
  auto newest = cepstral_coeffs_ring_buf_.GetArrayView(0);
  auto delayed_by_1 = cepstral_coeffs_ring_buf_.GetArrayView(1);
  auto delayed_by_2 = cepstral_coeffs_ring_buf_.GetArrayView(2);
  RTC_DCHECK_EQ(average.size(), first_derivative.size());
  RTC_DCHECK_EQ(first_derivative.size(), second_derivative.size());
  RTC_DCHECK_LE(average.size(), newest.size());
  for (int band = 0; rtc::SafeLt(band, average.size()); ++band) {
    const float c0 = newest[band];
    const float c1 = delayed_by_1[band];
    const float c2 = delayed_by_2[band];
    // Average (not normalized), kernel: [1, 1, 1].
    average[band] = c0 + c1 + c2;
    // First derivative, kernel: [1, 0, -1].
    first_derivative[band] = c0 - c2;
    // Second derivative, Laplacian kernel: [1, -2, 1].
    second_derivative[band] = c0 - 2 * c1 + c2;
  }
}
|
||||
|
||||
// Computes the band-wise cross-correlation between the reference and the
// lagged frame spectra, normalizes it with the band energies of the two
// frames, and writes its DCT (with ad-hoc corrections of the first two
// coefficients) into `bands_cross_corr`.
void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
    rtc::ArrayView<float, kNumLowerBands> bands_cross_corr) {
  spectral_correlator_.ComputeCrossCorrelation(
      reference_frame_fft_->GetConstView(), lagged_frame_fft_->GetConstView(),
      bands_cross_corr_);
  // Normalize. The 0.001 offset keeps the denominator strictly positive.
  for (int band = 0; rtc::SafeLt(band, bands_cross_corr_.size()); ++band) {
    const float energy_product = reference_frame_bands_energy_[band] *
                                 lagged_frame_bands_energy_[band];
    bands_cross_corr_[band] /= std::sqrt(0.001f + energy_product);
  }
  // Cepstrum.
  ComputeDct(bands_cross_corr_, dct_table_, bands_cross_corr);
  // Ad-hoc correction terms for the first two cepstral coefficients.
  bands_cross_corr[0] -= 1.3f;
  bands_cross_corr[1] -= 0.9f;
}
|
||||
|
||||
float SpectralFeaturesExtractor::ComputeVariability() const {
|
||||
// Compute cepstral variability score.
|
||||
float variability = 0.f;
|
||||
for (int delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
|
||||
float min_dist = std::numeric_limits<float>::max();
|
||||
for (int delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
|
||||
if (delay1 == delay2) // The distance would be 0.
|
||||
continue;
|
||||
min_dist =
|
||||
std::min(min_dist, cepstral_diffs_buf_.GetValue(delay1, delay2));
|
||||
}
|
||||
variability += min_dist;
|
||||
}
|
||||
// Normalize (based on training set stats).
|
||||
// TODO(bugs.webrtc.org/10480): Isolate normalization from feature extraction.
|
||||
return variability / kCepstralCoeffsHistorySize - 2.1f;
|
||||
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,79 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_
|
||||
|
||||
#include <array>
|
||||
#include <cstddef>
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/ring_buffer.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h"
|
||||
#include "modules/audio_processing/utility/pffft_wrapper.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Class to compute spectral features from a pair of reference and lagged
// frames. Instances are neither copyable nor assignable.
|
||||
class SpectralFeaturesExtractor {
|
||||
public:
|
||||
SpectralFeaturesExtractor();
|
||||
SpectralFeaturesExtractor(const SpectralFeaturesExtractor&) = delete;
|
||||
SpectralFeaturesExtractor& operator=(const SpectralFeaturesExtractor&) =
|
||||
delete;
|
||||
~SpectralFeaturesExtractor();
|
||||
// Resets the internal state of the feature extractor.
|
||||
void Reset();
|
||||
// Analyzes a pair of reference and lagged frames from the pitch buffer,
|
||||
// detects silence and computes features. If silence is detected, the output
|
||||
// is neither computed nor written.
// NOTE(review): the return value presumably reports whether silence was
// detected - confirm against the definition in spectral_features.cc.
|
||||
bool CheckSilenceComputeFeatures(
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame,
|
||||
rtc::ArrayView<const float, kFrameSize20ms24kHz> lagged_frame,
|
||||
rtc::ArrayView<float, kNumBands - kNumLowerBands> higher_bands_cepstrum,
|
||||
rtc::ArrayView<float, kNumLowerBands> average,
|
||||
rtc::ArrayView<float, kNumLowerBands> first_derivative,
|
||||
rtc::ArrayView<float, kNumLowerBands> second_derivative,
|
||||
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr,
|
||||
float* variability);
|
||||
|
||||
private:
|
||||
// Writes the `average`, `first_derivative` and `second_derivative` outputs.
// NOTE(review): presumably computed from the cepstral coefficients history -
// confirm against the definition in spectral_features.cc.
void ComputeAvgAndDerivatives(
|
||||
rtc::ArrayView<float, kNumLowerBands> average,
|
||||
rtc::ArrayView<float, kNumLowerBands> first_derivative,
|
||||
rtc::ArrayView<float, kNumLowerBands> second_derivative) const;
|
||||
// Cross-correlates the reference and lagged frame spectra band-wise,
// normalizes the result by the band energies of the two frames, applies the
// DCT and writes the first coefficients (with ad-hoc corrections on the first
// two) into `bands_cross_corr`.
void ComputeNormalizedCepstralCorrelation(
|
||||
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr);
|
||||
// Returns the cepstral variability score computed from the values stored in
// `cepstral_diffs_buf_`.
float ComputeVariability() const;
|
||||
|
||||
const std::array<float, kFrameSize20ms24kHz / 2> half_window_;
|
||||
Pffft fft_;
|
||||
std::unique_ptr<Pffft::FloatBuffer> fft_buffer_;
|
||||
// FFT coefficients of the reference and lagged frames.
std::unique_ptr<Pffft::FloatBuffer> reference_frame_fft_;
|
||||
std::unique_ptr<Pffft::FloatBuffer> lagged_frame_fft_;
|
||||
SpectralCorrelator spectral_correlator_;
|
||||
// Band-wise energies of the two frames; used to normalize the
// cross-correlation.
std::array<float, kOpusBands24kHz> reference_frame_bands_energy_;
|
||||
std::array<float, kOpusBands24kHz> lagged_frame_bands_energy_;
|
||||
// Band-wise cross-correlation between the two frames (scratch buffer).
std::array<float, kOpusBands24kHz> bands_cross_corr_;
|
||||
// Pre-computed table passed to ComputeDct().
const std::array<float, kNumBands * kNumBands> dct_table_;
|
||||
RingBuffer<float, kNumBands, kCepstralCoeffsHistorySize>
|
||||
cepstral_coeffs_ring_buf_;
|
||||
// Pair-wise comparison results between the items in the cepstral history.
SymmetricMatrixBuffer<float, kCepstralCoeffsHistorySize> cepstral_diffs_buf_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_
|
||||
|
|
@ -0,0 +1,188 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <cmath>
|
||||
#include <cstddef>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// Weights for each FFT coefficient for each Opus band (Nyquist frequency
|
||||
// excluded). The size of each band is specified in
|
||||
// `kOpusScaleNumBins24kHz20ms`.
|
||||
constexpr std::array<float, kFrameSize20ms24kHz / 2> kOpusBandWeights24kHz20ms =
|
||||
{{
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 0
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 1
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 2
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 3
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 4
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 5
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 6
|
||||
0.f, 0.25f, 0.5f, 0.75f, // Band 7
|
||||
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
|
||||
0.625f, 0.75f, 0.875f, // Band 8
|
||||
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
|
||||
0.625f, 0.75f, 0.875f, // Band 9
|
||||
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
|
||||
0.625f, 0.75f, 0.875f, // Band 10
|
||||
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
|
||||
0.625f, 0.75f, 0.875f, // Band 11
|
||||
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
|
||||
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
|
||||
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
|
||||
0.9375f, // Band 12
|
||||
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
|
||||
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
|
||||
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
|
||||
0.9375f, // Band 13
|
||||
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
|
||||
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
|
||||
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
|
||||
0.9375f, // Band 14
|
||||
0.f, 0.0416667f, 0.0833333f, 0.125f, 0.166667f,
|
||||
0.208333f, 0.25f, 0.291667f, 0.333333f, 0.375f,
|
||||
0.416667f, 0.458333f, 0.5f, 0.541667f, 0.583333f,
|
||||
0.625f, 0.666667f, 0.708333f, 0.75f, 0.791667f,
|
||||
0.833333f, 0.875f, 0.916667f, 0.958333f, // Band 15
|
||||
0.f, 0.0416667f, 0.0833333f, 0.125f, 0.166667f,
|
||||
0.208333f, 0.25f, 0.291667f, 0.333333f, 0.375f,
|
||||
0.416667f, 0.458333f, 0.5f, 0.541667f, 0.583333f,
|
||||
0.625f, 0.666667f, 0.708333f, 0.75f, 0.791667f,
|
||||
0.833333f, 0.875f, 0.916667f, 0.958333f, // Band 16
|
||||
0.f, 0.03125f, 0.0625f, 0.09375f, 0.125f,
|
||||
0.15625f, 0.1875f, 0.21875f, 0.25f, 0.28125f,
|
||||
0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f,
|
||||
0.46875f, 0.5f, 0.53125f, 0.5625f, 0.59375f,
|
||||
0.625f, 0.65625f, 0.6875f, 0.71875f, 0.75f,
|
||||
0.78125f, 0.8125f, 0.84375f, 0.875f, 0.90625f,
|
||||
0.9375f, 0.96875f, // Band 17
|
||||
0.f, 0.0208333f, 0.0416667f, 0.0625f, 0.0833333f,
|
||||
0.104167f, 0.125f, 0.145833f, 0.166667f, 0.1875f,
|
||||
0.208333f, 0.229167f, 0.25f, 0.270833f, 0.291667f,
|
||||
0.3125f, 0.333333f, 0.354167f, 0.375f, 0.395833f,
|
||||
0.416667f, 0.4375f, 0.458333f, 0.479167f, 0.5f,
|
||||
0.520833f, 0.541667f, 0.5625f, 0.583333f, 0.604167f,
|
||||
0.625f, 0.645833f, 0.666667f, 0.6875f, 0.708333f,
|
||||
0.729167f, 0.75f, 0.770833f, 0.791667f, 0.8125f,
|
||||
0.833333f, 0.854167f, 0.875f, 0.895833f, 0.916667f,
|
||||
0.9375f, 0.958333f, 0.979167f // Band 18
|
||||
}};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Initializes `weights_` with a copy of the `kOpusBandWeights24kHz20ms` table.
SpectralCorrelator::SpectralCorrelator()
|
||||
: weights_(kOpusBandWeights24kHz20ms.begin(),
|
||||
kOpusBandWeights24kHz20ms.end()) {}
|
||||
|
||||
// Nothing to release explicitly: the defaulted destructor is sufficient.
SpectralCorrelator::~SpectralCorrelator() = default;
|
||||
|
||||
// Computes the band-wise auto-correlation of `x` and writes it to `auto_corr`.
// Implemented as the cross-correlation of `x` with itself, hence the same
// requirements on `x` as for ComputeCrossCorrelation() apply.
void SpectralCorrelator::ComputeAutoCorrelation(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const {
|
||||
ComputeCrossCorrelation(x, x, auto_corr);
|
||||
}
|
||||
|
||||
// Computes the band-wise cross-correlation between the spectra `x` and `y`
// (interleaved real/imaginary FFT coefficients) and writes it to `cross_corr`.
// The contribution of each FFT coefficient is split between the band it
// belongs to and the next one according to `weights_`.
void SpectralCorrelator::ComputeCrossCorrelation(
    rtc::ArrayView<const float> x,
    rtc::ArrayView<const float> y,
    rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const {
  RTC_DCHECK_EQ(x.size(), kFrameSize20ms24kHz);
  RTC_DCHECK_EQ(x.size(), y.size());
  RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed.";
  RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed.";
  constexpr auto kBandSizes = GetOpusScaleNumBins24kHz20ms();
  int bin = 0;  // Index of the next Fourier coefficient to process.
  cross_corr[0] = 0.f;
  for (int band = 0; band + 1 < kOpusBands24kHz; ++band) {
    float& lower = cross_corr[band];
    float& upper = cross_corr[band + 1];
    upper = 0.f;
    // Walk through the coefficients that belong to the current band.
    const int band_end = bin + kBandSizes[band];
    for (; bin < band_end; ++bin) {
      const int re = 2 * bin;
      const int im = re + 1;
      const float product = x[re] * y[re] + x[im] * y[im];
      const float upper_share = weights_[bin] * product;
      lower += product - upper_share;
      upper += upper_share;
    }
  }
  cross_corr[0] *= 2.f;  // The first band only gets half contribution.
  // All the coefficients but the Nyquist one must have been consumed.
  RTC_DCHECK_EQ(bin, kFrameSize20ms24kHz / 2);
}
|
||||
|
||||
// Converts `bands_energy` to a log10 scale and smooths it across bands. Each
// output value is bounded from below by the running maximum minus 7 and by a
// follower that decays by 1.5 per band. Bands beyond `bands_energy.size()` are
// treated as having zero energy.
void ComputeSmoothedLogMagnitudeSpectrum(
    rtc::ArrayView<const float> bands_energy,
    rtc::ArrayView<float, kNumBands> log_bands_energy) {
  RTC_DCHECK_LE(bands_energy.size(), kNumBands);
  constexpr float kOneByHundred = 1e-2f;
  constexpr float kLogOneByHundred = -2.f;
  float running_max = kLogOneByHundred;
  float follower = kLogOneByHundred;
  const int num_defined_bands = static_cast<int>(bands_energy.size());
  for (int band = 0; band < kNumBands; ++band) {
    // Log energy for the bands that are defined, log(0.01) for the others.
    float value = kLogOneByHundred;
    if (band < num_defined_bands) {
      value = std::log10(kOneByHundred + bands_energy[band]);
    }
    // Apply the two lower bounds, then update the smoothing state.
    value = std::max(running_max - 7.f, std::max(follower - 1.5f, value));
    running_max = std::max(running_max, value);
    follower = std::max(follower - 1.5f, value);
    log_bands_energy[band] = value;
  }
}
|
||||
|
||||
std::array<float, kNumBands * kNumBands> ComputeDctTable() {
|
||||
std::array<float, kNumBands * kNumBands> dct_table;
|
||||
const double k = std::sqrt(0.5);
|
||||
for (int i = 0; i < kNumBands; ++i) {
|
||||
for (int j = 0; j < kNumBands; ++j)
|
||||
dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands);
|
||||
dct_table[i * kNumBands] *= k;
|
||||
}
|
||||
return dct_table;
|
||||
}
|
||||
|
||||
// Computes the first `out.size()` DCT coefficients of `in` using the
// pre-computed `dct_table` (see ComputeDctTable()).
// - `in` must have at most `kNumBands` elements.
// - `out` must have at least one element and at most `in.size()` elements.
// - In-place computation is not supported (`in` and `out` must not share the
//   same storage).
void ComputeDct(rtc::ArrayView<const float> in,
                rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
                rtc::ArrayView<float> out) {
  // DCT scaling factor - i.e., sqrt(2 / kNumBands).
  constexpr float kDctScalingFactor = 0.301511345f;
  constexpr float kDctScalingFactorError =
      kDctScalingFactor * kDctScalingFactor -
      2.f / static_cast<float>(kNumBands);
  // The tolerance must be much smaller than 2 / kNumBands itself, otherwise
  // the check below cannot detect that `kNumBands` has changed.
  constexpr float kMaxDctScalingFactorError = 1e-6f;
  static_assert(
      kDctScalingFactorError < kMaxDctScalingFactorError &&
          kDctScalingFactorError > -kMaxDctScalingFactorError,
      "kNumBands changed and kDctScalingFactor has not been updated.");
  RTC_DCHECK_NE(in.data(), out.data()) << "In-place DCT is not supported.";
  RTC_DCHECK_LE(in.size(), kNumBands);
  RTC_DCHECK_LE(1, out.size());
  RTC_DCHECK_LE(out.size(), in.size());
  for (int i = 0; rtc::SafeLt(i, out.size()); ++i) {
    out[i] = 0.f;
    // Dot product between `in` and the i-th column of the DCT table.
    for (int j = 0; rtc::SafeLt(j, in.size()); ++j) {
      out[i] += in[j] * dct_table[j * kNumBands + i];
    }
    // TODO(bugs.webrtc.org/10480): Scaling factor in the DCT table.
    out[i] *= kDctScalingFactor;
  }
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,100 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <array>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// At a sample rate of 24 kHz, the last 3 Opus bands are beyond the Nyquist
|
||||
// frequency. However, band #19 gets the contributions from band #18 because
|
||||
// of the symmetric triangular filter with peak response at 12 kHz.
|
||||
constexpr int kOpusBands24kHz = 20;
|
||||
static_assert(kOpusBands24kHz < kNumBands,
|
||||
"The number of bands at 24 kHz must be less than those defined "
|
||||
"in the Opus scale at 48 kHz.");
|
||||
|
||||
// Number of FFT frequency bins covered by each band in the Opus scale at a
|
||||
// sample rate of 24 kHz for 20 ms frames.
|
||||
// Declared here for unit testing.
// The `kOpusBands24kHz - 1` entries add up to 240 bins, which is the total
// that SpectralCorrelator::ComputeCrossCorrelation() expects to consume
// (`kFrameSize20ms24kHz / 2`).
|
||||
constexpr std::array<int, kOpusBands24kHz - 1> GetOpusScaleNumBins24kHz20ms() {
|
||||
return {4, 4, 4, 4, 4, 4, 4, 4, 8, 8, 8, 8, 16, 16, 16, 24, 24, 32, 48};
|
||||
}
|
||||
|
||||
// TODO(bugs.webrtc.org/10480): Move to a separate file.
|
||||
// Class to compute band-wise spectral features in the Opus perceptual scale
|
||||
// for 20 ms frames sampled at 24 kHz. The analysis methods apply triangular
|
||||
// filters with peak response at each band boundary.
|
||||
class SpectralCorrelator {
|
||||
public:
|
||||
// Ctor. Copy construction and copy assignment are disabled.
|
||||
SpectralCorrelator();
|
||||
SpectralCorrelator(const SpectralCorrelator&) = delete;
|
||||
SpectralCorrelator& operator=(const SpectralCorrelator&) = delete;
|
||||
~SpectralCorrelator();
|
||||
|
||||
// Computes the band-wise spectral auto-correlations.
|
||||
// `x` must:
|
||||
// - have size equal to `kFrameSize20ms24kHz`;
|
||||
// - be encoded as a vector of interleaved real-imaginary FFT coefficients
|
||||
// where x[1] = 0 (the Nyquist frequency coefficient is omitted).
|
||||
void ComputeAutoCorrelation(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const;
|
||||
|
||||
// Computes the band-wise spectral cross-correlations.
|
||||
// `x` and `y` must:
|
||||
// - have size equal to `kFrameSize20ms24kHz`;
|
||||
// - be encoded as vectors of interleaved real-imaginary FFT coefficients where
|
||||
// x[1] = y[1] = 0 (the Nyquist frequency coefficient is omitted).
|
||||
void ComputeCrossCorrelation(
|
||||
rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y,
|
||||
rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const;
|
||||
|
||||
private:
|
||||
const std::vector<float> weights_; // Weights for each Fourier coefficient.
|
||||
};
|
||||
|
||||
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
|
||||
// spectral_features.cc. Given a vector of Opus-bands energy coefficients,
|
||||
// computes the log magnitude spectrum applying smoothing both over time and
|
||||
// over frequency. Declared here for unit testing.
|
||||
void ComputeSmoothedLogMagnitudeSpectrum(
|
||||
rtc::ArrayView<const float> bands_energy,
|
||||
rtc::ArrayView<float, kNumBands> log_bands_energy);
|
||||
|
||||
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
|
||||
// spectral_features.cc. Creates a DCT table for arrays having size equal to
|
||||
// `kNumBands`. Declared here for unit testing.
|
||||
std::array<float, kNumBands * kNumBands> ComputeDctTable();
|
||||
|
||||
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
|
||||
// spectral_features.cc. Computes DCT for `in` given a pre-computed DCT table.
|
||||
// In-place computation is not allowed and `out` can be smaller than `in` in
|
||||
// order to only compute the first DCT coefficients. Declared here for unit
|
||||
// testing.
|
||||
void ComputeDct(rtc::ArrayView<const float> in,
|
||||
rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
|
||||
rtc::ArrayView<float> out);
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
|
||||
|
|
@ -0,0 +1,95 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <array>
|
||||
#include <cstring>
|
||||
#include <utility>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Data structure to buffer the results of pair-wise comparisons between items
|
||||
// stored in a ring buffer. Every time that the oldest item is replaced in the
|
||||
// ring buffer, the new one is compared to the remaining items in the ring
|
||||
// buffer. The results of such comparisons need to be buffered and automatically
|
||||
// removed when one of the two corresponding items that have been compared is
|
||||
// removed from the ring buffer. It is assumed that the comparison is symmetric
|
||||
// and that comparing an item with itself is not needed.
|
||||
template <typename T, int S>
|
||||
class SymmetricMatrixBuffer {
|
||||
static_assert(S > 2, "");
|
||||
|
||||
public:
|
||||
SymmetricMatrixBuffer() = default;
|
||||
SymmetricMatrixBuffer(const SymmetricMatrixBuffer&) = delete;
|
||||
SymmetricMatrixBuffer& operator=(const SymmetricMatrixBuffer&) = delete;
|
||||
~SymmetricMatrixBuffer() = default;
|
||||
// Sets the buffer values to zero.
|
||||
void Reset() {
|
||||
static_assert(std::is_arithmetic<T>::value,
|
||||
"Integral or floating point required.");
|
||||
buf_.fill(0);
|
||||
}
|
||||
// Pushes the results from the comparison between the most recent item and
|
||||
// those that are still in the ring buffer. The first element in `values` must
|
||||
// correspond to the comparison between the most recent item and the second
|
||||
// most recent one in the ring buffer, whereas the last element in `values`
|
||||
// must correspond to the comparison between the most recent item and the
|
||||
// oldest one in the ring buffer.
|
||||
void Push(rtc::ArrayView<T, S - 1> values) {
|
||||
// Move the lower-right sub-matrix of size (S-2) x (S-2) one row up and one
|
||||
// column left.
|
||||
std::memmove(buf_.data(), buf_.data() + S, (buf_.size() - S) * sizeof(T));
|
||||
// Copy new values in the last column in the right order.
|
||||
for (int i = 0; rtc::SafeLt(i, values.size()); ++i) {
|
||||
const int index = (S - 1 - i) * (S - 1) - 1;
|
||||
RTC_DCHECK_GE(index, 0);
|
||||
RTC_DCHECK_LT(index, buf_.size());
|
||||
buf_[index] = values[i];
|
||||
}
|
||||
}
|
||||
// Reads the value that corresponds to comparison of two items in the ring
|
||||
// buffer having delay `delay1` and `delay2`. The two arguments must not be
|
||||
// equal and both must be in {0, ..., S - 1}.
|
||||
T GetValue(int delay1, int delay2) const {
|
||||
int row = S - 1 - delay1;
|
||||
int col = S - 1 - delay2;
|
||||
RTC_DCHECK_NE(row, col) << "The diagonal cannot be accessed.";
|
||||
if (row > col)
|
||||
std::swap(row, col); // Swap to access the upper-right triangular part.
|
||||
RTC_DCHECK_LE(0, row);
|
||||
RTC_DCHECK_LT(row, S - 1) << "Not enforcing row < col and row != col.";
|
||||
RTC_DCHECK_LE(1, col) << "Not enforcing row < col and row != col.";
|
||||
RTC_DCHECK_LT(col, S);
|
||||
const int index = row * (S - 1) + (col - 1);
|
||||
RTC_DCHECK_LE(0, index);
|
||||
RTC_DCHECK_LT(index, buf_.size());
|
||||
return buf_[index];
|
||||
}
|
||||
|
||||
private:
|
||||
// Encode an upper-right triangular matrix (excluding its diagonal) using a
|
||||
// square matrix. This allows to move the data in Push() with one single
|
||||
// operation.
|
||||
std::array<T, (S - 1) * (S - 1)> buf_{};
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
|
||||
|
||||
#include <algorithm>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
#include <type_traits>
|
||||
#include <vector>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
#include "test/gtest.h"
|
||||
#include "test/testsupport/file_utils.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
namespace {
|
||||
|
||||
// File reader for binary files that contain a sequence of values with
|
||||
// arithmetic type `T`. The values of type `T` that are read are cast to float.
|
||||
template <typename T>
|
||||
class FloatFileReader : public FileReader {
|
||||
public:
|
||||
static_assert(std::is_arithmetic<T>::value, "");
|
||||
explicit FloatFileReader(absl::string_view filename)
|
||||
: is_(std::string(filename), std::ios::binary | std::ios::ate),
|
||||
size_(is_.tellg() / sizeof(T)) {
|
||||
RTC_CHECK(is_);
|
||||
SeekBeginning();
|
||||
}
|
||||
FloatFileReader(const FloatFileReader&) = delete;
|
||||
FloatFileReader& operator=(const FloatFileReader&) = delete;
|
||||
~FloatFileReader() = default;
|
||||
|
||||
int size() const override { return size_; }
|
||||
bool ReadChunk(rtc::ArrayView<float> dst) override {
|
||||
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
|
||||
if (std::is_same<T, float>::value) {
|
||||
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
|
||||
} else {
|
||||
buffer_.resize(dst.size());
|
||||
is_.read(reinterpret_cast<char*>(buffer_.data()), bytes_to_read);
|
||||
std::transform(buffer_.begin(), buffer_.end(), dst.begin(),
|
||||
[](const T& v) -> float { return static_cast<float>(v); });
|
||||
}
|
||||
return is_.gcount() == bytes_to_read;
|
||||
}
|
||||
bool ReadValue(float& dst) override { return ReadChunk({&dst, 1}); }
|
||||
void SeekForward(int hop) override { is_.seekg(hop * sizeof(T), is_.cur); }
|
||||
void SeekBeginning() override { is_.seekg(0, is_.beg); }
|
||||
|
||||
private:
|
||||
std::ifstream is_;
|
||||
const int size_;
|
||||
std::vector<T> buffer_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
using webrtc::test::ResourcePath;
|
||||
|
||||
// Checks that `expected` and `computed` have the same size and that they
// match element by element; one failure is reported for each mismatching
// pair, together with its index.
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
                           rtc::ArrayView<const float> computed) {
  // An element-wise comparison only makes sense for equally sized views.
  ASSERT_EQ(expected.size(), computed.size());
  const int num_values = static_cast<int>(expected.size());
  int i = 0;
  while (i < num_values) {
    SCOPED_TRACE(i);
    EXPECT_FLOAT_EQ(expected[i], computed[i]);
    ++i;
  }
}
|
||||
|
||||
// Checks that `expected` and `computed` have the same size and that the
// absolute difference of each pair of elements does not exceed `tolerance`;
// one failure is reported for each offending pair, together with its index.
void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
                        rtc::ArrayView<const float> computed,
                        float tolerance) {
  // An element-wise comparison only makes sense for equally sized views.
  ASSERT_EQ(expected.size(), computed.size());
  const int num_values = static_cast<int>(expected.size());
  int i = 0;
  while (i < num_values) {
    SCOPED_TRACE(i);
    EXPECT_NEAR(expected[i], computed[i], tolerance);
    ++i;
  }
}
|
||||
|
||||
std::unique_ptr<FileReader> CreatePcmSamplesReader() {
|
||||
return std::make_unique<FloatFileReader<int16_t>>(
|
||||
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples",
|
||||
"pcm"));
|
||||
}
|
||||
|
||||
// Returns a chunked reader for the 24 kHz pitch buffer resource file. Each
// chunk holds `kBufSize24kHz` values and the file must hold a whole number of
// chunks.
ChunksFileReader CreatePitchBuffer24kHzReader() {
  const std::string path =
      test::ResourcePath("audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat");
  auto reader = std::make_unique<FloatFileReader<float>>(/*filename=*/path);
  const int num_chunks = rtc::CheckedDivExact(reader->size(), kBufSize24kHz);
  return {/*chunk_size=*/kBufSize24kHz, num_chunks, std::move(reader)};
}
|
||||
|
||||
// Returns a chunked reader for the LP residual resource file. Each chunk holds
// `kBufSize24kHz` values followed by the pitch period and strength, and the
// file must hold a whole number of chunks.
ChunksFileReader CreateLpResidualAndPitchInfoReader() {
  constexpr int kPitchInfoSize = 2;  // Pitch period and strength.
  constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize;
  const std::string path =
      test::ResourcePath("audio_processing/agc2/rnn_vad/pitch_lp_res", "dat");
  auto reader = std::make_unique<FloatFileReader<float>>(/*filename=*/path);
  const int num_chunks = rtc::CheckedDivExact(reader->size(), kChunkSize);
  return {kChunkSize, num_chunks, std::move(reader)};
}
|
||||
|
||||
std::unique_ptr<FileReader> CreateGruInputReader() {
|
||||
return std::make_unique<FloatFileReader<float>>(
|
||||
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/gru_in",
|
||||
"dat"));
|
||||
}
|
||||
|
||||
std::unique_ptr<FileReader> CreateVadProbsReader() {
|
||||
return std::make_unique<FloatFileReader<float>>(
|
||||
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
|
||||
"dat"));
|
||||
}
|
||||
|
||||
// Loads the pitch buffer, the square energies and the auto-correlation
// values, in this order, from the pitch search resource file. Aborts if the
// file does not hold enough values, so that tests never run on partially
// initialized data.
PitchTestData::PitchTestData() {
  FloatFileReader<float> reader(
      /*filename=*/ResourcePath(
          "audio_processing/agc2/rnn_vad/pitch_search_int", "dat"));
  RTC_CHECK(reader.ReadChunk(pitch_buffer_24k_));
  RTC_CHECK(reader.ReadChunk(square_energies_24k_));
  RTC_CHECK(reader.ReadChunk(auto_correlation_12k_));
  // Reverse the order of the squared energy values.
  // Required after the WebRTC CL 191703 which switched to forward computation.
  std::reverse(square_energies_24k_.begin(), square_energies_24k_.end());
}
|
||||
|
||||
PitchTestData::~PitchTestData() = default;
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,130 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
|
||||
|
||||
#include <array>
|
||||
#include <fstream>
|
||||
#include <memory>
|
||||
#include <string>
|
||||
|
||||
#include "absl/strings/string_view.h"
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
constexpr float kFloatMin = std::numeric_limits<float>::min();
|
||||
|
||||
// Fails for every pair from two equally sized rtc::ArrayView<float> views such
|
||||
// that the values in the pair do not match.
|
||||
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed);
|
||||
|
||||
// Fails for every pair from two equally sized rtc::ArrayView<float> views such
|
||||
// that their absolute error is above a given threshold.
|
||||
void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
|
||||
rtc::ArrayView<const float> computed,
|
||||
float tolerance);
|
||||
|
||||
// File reader interface. Values are always returned as floats.
|
||||
class FileReader {
|
||||
public:
|
||||
virtual ~FileReader() = default;
|
||||
// Number of values (not bytes) in the file.
|
||||
virtual int size() const = 0;
|
||||
// Reads `dst.size()` float values into `dst`, advances the internal file
|
||||
// position according to the number of read bytes and returns true if the
|
||||
// values are correctly read. If the number of remaining bytes in the file is
|
||||
// not sufficient to read `dst.size()` float values, `dst` is partially
|
||||
// modified and false is returned.
|
||||
virtual bool ReadChunk(rtc::ArrayView<float> dst) = 0;
|
||||
// Reads a single float value, advances the internal file position according
|
||||
// to the number of read bytes and returns true if the value is correctly
|
||||
// read. If the number of remaining bytes in the file is not sufficient to
|
||||
// read one float, `dst` is not modified and false is returned.
|
||||
virtual bool ReadValue(float& dst) = 0;
|
||||
// Advances the internal file position by `hop` values.
|
||||
virtual void SeekForward(int hop) = 0;
|
||||
// Resets the internal file position to the beginning of the file.
|
||||
virtual void SeekBeginning() = 0;
|
||||
};
|
||||
|
||||
// File reader for files that contain `num_chunks` chunks with size equal to
|
||||
// `chunk_size`.
|
||||
struct ChunksFileReader {
|
||||
// Number of values in each chunk.
const int chunk_size;
|
||||
// Number of chunks stored in the file.
const int num_chunks;
|
||||
// Reader that gives access to the values in the file.
std::unique_ptr<FileReader> reader;
|
||||
};
|
||||
|
||||
// Creates a reader for the PCM S16 samples file.
|
||||
std::unique_ptr<FileReader> CreatePcmSamplesReader();
|
||||
|
||||
// Creates a reader for the 24 kHz pitch buffer test data.
|
||||
ChunksFileReader CreatePitchBuffer24kHzReader();
|
||||
|
||||
// Creates a reader for the LP residual and pitch information test data.
|
||||
ChunksFileReader CreateLpResidualAndPitchInfoReader();
|
||||
|
||||
// Creates a reader for the sequence of GRU input vectors.
|
||||
std::unique_ptr<FileReader> CreateGruInputReader();
|
||||
|
||||
// Creates a reader for the VAD probabilities test data.
|
||||
std::unique_ptr<FileReader> CreateVadProbsReader();
|
||||
|
||||
// Class to retrieve a test pitch buffer content and the expected output for the
|
||||
// analysis steps. The data is loaded once, at construction time, and it is
// exposed through read-only views.
|
||||
class PitchTestData {
|
||||
public:
|
||||
PitchTestData();
|
||||
~PitchTestData();
|
||||
// Read-only view on the 24 kHz pitch buffer.
rtc::ArrayView<const float, kBufSize24kHz> PitchBuffer24kHzView() const {
|
||||
return pitch_buffer_24k_;
|
||||
}
|
||||
// Read-only view on the square energies (stored in reverse order with
// respect to the order in the test data file).
rtc::ArrayView<const float, kRefineNumLags24kHz> SquareEnergies24kHzView()
|
||||
const {
|
||||
return square_energies_24k_;
|
||||
}
|
||||
// Read-only view on the 12 kHz auto-correlation values.
rtc::ArrayView<const float, kNumLags12kHz> AutoCorrelation12kHzView() const {
|
||||
return auto_correlation_12k_;
|
||||
}
|
||||
|
||||
private:
|
||||
std::array<float, kBufSize24kHz> pitch_buffer_24k_;
|
||||
std::array<float, kRefineNumLags24kHz> square_energies_24k_;
|
||||
std::array<float, kNumLags12kHz> auto_correlation_12k_;
|
||||
};
|
||||
|
||||
// Writer for binary files.
|
||||
class FileWriter {
|
||||
public:
|
||||
explicit FileWriter(absl::string_view file_path)
|
||||
: os_(std::string(file_path), std::ios::binary) {}
|
||||
FileWriter(const FileWriter&) = delete;
|
||||
FileWriter& operator=(const FileWriter&) = delete;
|
||||
~FileWriter() = default;
|
||||
void WriteChunk(rtc::ArrayView<const float> value) {
|
||||
const std::streamsize bytes_to_write = value.size() * sizeof(float);
|
||||
os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
|
||||
}
|
||||
|
||||
private:
|
||||
std::ofstream os_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
|
||||
|
|
@ -0,0 +1,115 @@
|
|||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
|
||||
|
||||
// Defines WEBRTC_ARCH_X86_FAMILY, used below.
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
#if defined(WEBRTC_HAS_NEON)
|
||||
#include <arm_neon.h>
|
||||
#endif
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
#include <emmintrin.h>
|
||||
#endif
|
||||
|
||||
#include <numeric>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
#include "rtc_base/system/arch.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Provides optimizations for mathematical operations having vectors as
|
||||
// operand(s).
|
||||
class VectorMath {
|
||||
public:
|
||||
explicit VectorMath(AvailableCpuFeatures cpu_features)
|
||||
: cpu_features_(cpu_features) {}
|
||||
|
||||
// Computes the dot product between two equally sized vectors.
|
||||
float DotProduct(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y) const {
|
||||
RTC_DCHECK_EQ(x.size(), y.size());
|
||||
#if defined(WEBRTC_ARCH_X86_FAMILY)
|
||||
// TODO(@dkaraush): compile with avx support
|
||||
/*if (cpu_features_.avx2) {
|
||||
return DotProductAvx2(x, y);
|
||||
} else */if (cpu_features_.sse2) {
|
||||
__m128 accumulator = _mm_setzero_ps();
|
||||
constexpr int kBlockSizeLog2 = 2;
|
||||
constexpr int kBlockSize = 1 << kBlockSizeLog2;
|
||||
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
|
||||
<< kBlockSizeLog2;
|
||||
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
|
||||
RTC_DCHECK_LE(i + kBlockSize, x.size());
|
||||
const __m128 x_i = _mm_loadu_ps(&x[i]);
|
||||
const __m128 y_i = _mm_loadu_ps(&y[i]);
|
||||
// Multiply-add.
|
||||
const __m128 z_j = _mm_mul_ps(x_i, y_i);
|
||||
accumulator = _mm_add_ps(accumulator, z_j);
|
||||
}
|
||||
// Reduce `accumulator` by addition.
|
||||
__m128 high = _mm_movehl_ps(accumulator, accumulator);
|
||||
accumulator = _mm_add_ps(accumulator, high);
|
||||
high = _mm_shuffle_ps(accumulator, accumulator, 1);
|
||||
accumulator = _mm_add_ps(accumulator, high);
|
||||
float dot_product = _mm_cvtss_f32(accumulator);
|
||||
// Add the result for the last block if incomplete.
|
||||
for (int i = incomplete_block_index;
|
||||
i < rtc::dchecked_cast<int>(x.size()); ++i) {
|
||||
dot_product += x[i] * y[i];
|
||||
}
|
||||
return dot_product;
|
||||
}
|
||||
#elif defined(WEBRTC_HAS_NEON) && defined(WEBRTC_ARCH_ARM64)
|
||||
if (cpu_features_.neon) {
|
||||
float32x4_t accumulator = vdupq_n_f32(0.f);
|
||||
constexpr int kBlockSizeLog2 = 2;
|
||||
constexpr int kBlockSize = 1 << kBlockSizeLog2;
|
||||
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
|
||||
<< kBlockSizeLog2;
|
||||
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
|
||||
RTC_DCHECK_LE(i + kBlockSize, x.size());
|
||||
const float32x4_t x_i = vld1q_f32(&x[i]);
|
||||
const float32x4_t y_i = vld1q_f32(&y[i]);
|
||||
accumulator = vfmaq_f32(accumulator, x_i, y_i);
|
||||
}
|
||||
// Reduce `accumulator` by addition.
|
||||
const float32x2_t tmp =
|
||||
vpadd_f32(vget_low_f32(accumulator), vget_high_f32(accumulator));
|
||||
float dot_product = vget_lane_f32(vpadd_f32(tmp, vrev64_f32(tmp)), 0);
|
||||
// Add the result for the last block if incomplete.
|
||||
for (int i = incomplete_block_index;
|
||||
i < rtc::dchecked_cast<int>(x.size()); ++i) {
|
||||
dot_product += x[i] * y[i];
|
||||
}
|
||||
return dot_product;
|
||||
}
|
||||
#endif
|
||||
return std::inner_product(x.begin(), x.end(), y.begin(), 0.f);
|
||||
}
|
||||
|
||||
private:
|
||||
float DotProductAvx2(rtc::ArrayView<const float> x,
|
||||
rtc::ArrayView<const float> y) const;
|
||||
|
||||
const AvailableCpuFeatures cpu_features_;
|
||||
};
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
|
||||
|
|
@ -0,0 +1,54 @@
|
|||
/*
|
||||
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include <immintrin.h>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_conversions.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace rnn_vad {
|
||||
|
||||
// Computes the dot product of `x` and `y` (same size required) processing
// eight floats per iteration with AVX2/FMA instructions; the items that do not
// fill a complete block are handled by a scalar loop.
float VectorMath::DotProductAvx2(rtc::ArrayView<const float> x,
                                 rtc::ArrayView<const float> y) const {
  RTC_DCHECK(cpu_features_.avx2);
  RTC_DCHECK_EQ(x.size(), y.size());
  constexpr int kLanesLog2 = 3;
  constexpr int kLanes = 1 << kLanesLog2;
  // Index of the first item that does not belong to a complete block.
  const int tail_begin = (x.size() >> kLanesLog2) << kLanesLog2;
  __m256 lane_sums = _mm256_setzero_ps();
  for (int i = 0; i < tail_begin; i += kLanes) {
    RTC_DCHECK_LE(i + kLanes, x.size());
    const __m256 x_block = _mm256_loadu_ps(&x[i]);
    const __m256 y_block = _mm256_loadu_ps(&y[i]);
    lane_sums = _mm256_fmadd_ps(x_block, y_block, lane_sums);
  }
  // Horizontal sum of `lane_sums`: 8 lanes -> 4 lanes -> 2 lanes -> 1 lane.
  const __m128 upper_half = _mm256_extractf128_ps(lane_sums, 1);
  const __m128 lower_half = _mm256_extractf128_ps(lane_sums, 0);
  const __m128 sums_4 = _mm_add_ps(upper_half, lower_half);
  // Only the two lowest lanes of the results below are meaningful.
  const __m128 sums_4_upper_pair = _mm_movehl_ps(upper_half, sums_4);
  const __m128 sums_2 = _mm_add_ps(sums_4_upper_pair, sums_4);
  const __m128 sums_2_lane_one = _mm_shuffle_ps(sums_2, sums_2, 1);
  const __m128 total = _mm_add_ss(sums_2_lane_one, sums_2);
  float result = _mm_cvtss_f32(total);
  // Scalar tail for the items left out of the last complete block.
  const int num_items = rtc::dchecked_cast<int>(x.size());
  for (int i = tail_begin; i < num_items; ++i) {
    result += x[i] * y[i];
  }
  return result;
}
|
||||
|
||||
} // namespace rnn_vad
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,183 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/saturation_protector.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/agc2/saturation_protector_buffer.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
// Time span, in ms, over which `UpdateSaturationProtectorState()` tracks the
// maximum peak before pushing it into the peak delay buffer.
constexpr int kPeakEnveloperSuperFrameLengthMs = 400;
|
||||
// Lower bound, in dB, used to clamp the recommended headroom.
constexpr float kMinMarginDb = 12.0f;
|
||||
// Upper bound, in dB, used to clamp the recommended headroom.
constexpr float kMaxMarginDb = 25.0f;
|
||||
// Smoothing coefficient applied to the headroom when the measured peak-to-level
// difference exceeds the current headroom.
constexpr float kAttack = 0.9988493699365052f;
|
||||
// Smoothing coefficient applied to the headroom otherwise.
constexpr float kDecay = 0.9997697679981565f;
|
||||
|
||||
// Saturation protector state. Defined outside of `SaturationProtectorImpl` to
|
||||
// implement check-point and restore ops.
|
||||
struct SaturationProtectorState {
|
||||
bool operator==(const SaturationProtectorState& s) const {
|
||||
return headroom_db == s.headroom_db &&
|
||||
peak_delay_buffer == s.peak_delay_buffer &&
|
||||
max_peaks_dbfs == s.max_peaks_dbfs &&
|
||||
time_since_push_ms == s.time_since_push_ms;
|
||||
}
|
||||
inline bool operator!=(const SaturationProtectorState& s) const {
|
||||
return !(*this == s);
|
||||
}
|
||||
|
||||
float headroom_db;
|
||||
SaturationProtectorBuffer peak_delay_buffer;
|
||||
float max_peaks_dbfs;
|
||||
int time_since_push_ms; // Time since the last ring buffer push operation.
|
||||
};
|
||||
|
||||
// Resets the saturation protector state.
|
||||
void ResetSaturationProtectorState(float initial_headroom_db,
|
||||
SaturationProtectorState& state) {
|
||||
state.headroom_db = initial_headroom_db;
|
||||
state.peak_delay_buffer.Reset();
|
||||
state.max_peaks_dbfs = kMinLevelDbfs;
|
||||
state.time_since_push_ms = 0;
|
||||
}
|
||||
|
||||
// Updates `state` by analyzing the estimated speech level `speech_level_dbfs`
|
||||
// and the peak level `peak_dbfs` for an observed frame. `state` must not be
|
||||
// modified without calling this function.
|
||||
void UpdateSaturationProtectorState(float peak_dbfs,
|
||||
float speech_level_dbfs,
|
||||
SaturationProtectorState& state) {
|
||||
// Get the max peak over `kPeakEnveloperSuperFrameLengthMs` ms.
|
||||
state.max_peaks_dbfs = std::max(state.max_peaks_dbfs, peak_dbfs);
|
||||
state.time_since_push_ms += kFrameDurationMs;
|
||||
if (rtc::SafeGt(state.time_since_push_ms, kPeakEnveloperSuperFrameLengthMs)) {
|
||||
// Push `max_peaks_dbfs` back into the ring buffer.
|
||||
state.peak_delay_buffer.PushBack(state.max_peaks_dbfs);
|
||||
// Reset.
|
||||
state.max_peaks_dbfs = kMinLevelDbfs;
|
||||
state.time_since_push_ms = 0;
|
||||
}
|
||||
|
||||
// Update the headroom by comparing the estimated speech level and the delayed
|
||||
// max speech peak.
|
||||
const float delayed_peak_dbfs =
|
||||
state.peak_delay_buffer.Front().value_or(state.max_peaks_dbfs);
|
||||
const float difference_db = delayed_peak_dbfs - speech_level_dbfs;
|
||||
if (difference_db > state.headroom_db) {
|
||||
// Attack.
|
||||
state.headroom_db =
|
||||
state.headroom_db * kAttack + difference_db * (1.0f - kAttack);
|
||||
} else {
|
||||
// Decay.
|
||||
state.headroom_db =
|
||||
state.headroom_db * kDecay + difference_db * (1.0f - kDecay);
|
||||
}
|
||||
|
||||
state.headroom_db =
|
||||
rtc::SafeClamp<float>(state.headroom_db, kMinMarginDb, kMaxMarginDb);
|
||||
}
|
||||
|
||||
// Saturation protector which recommends a headroom based on the recent peaks.
|
||||
class SaturationProtectorImpl : public SaturationProtector {
|
||||
public:
|
||||
explicit SaturationProtectorImpl(float initial_headroom_db,
|
||||
int adjacent_speech_frames_threshold,
|
||||
ApmDataDumper* apm_data_dumper)
|
||||
: apm_data_dumper_(apm_data_dumper),
|
||||
initial_headroom_db_(initial_headroom_db),
|
||||
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold) {
|
||||
Reset();
|
||||
}
|
||||
SaturationProtectorImpl(const SaturationProtectorImpl&) = delete;
|
||||
SaturationProtectorImpl& operator=(const SaturationProtectorImpl&) = delete;
|
||||
~SaturationProtectorImpl() = default;
|
||||
|
||||
float HeadroomDb() override { return headroom_db_; }
|
||||
|
||||
void Analyze(float speech_probability,
|
||||
float peak_dbfs,
|
||||
float speech_level_dbfs) override {
|
||||
if (speech_probability < kVadConfidenceThreshold) {
|
||||
// Not a speech frame.
|
||||
if (adjacent_speech_frames_threshold_ > 1) {
|
||||
// When two or more adjacent speech frames are required in order to
|
||||
// update the state, we need to decide whether to discard or confirm the
|
||||
// updates based on the speech sequence length.
|
||||
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
|
||||
// First non-speech frame after a long enough sequence of speech
|
||||
// frames. Update the reliable state.
|
||||
reliable_state_ = preliminary_state_;
|
||||
} else if (num_adjacent_speech_frames_ > 0) {
|
||||
// First non-speech frame after a too short sequence of speech frames.
|
||||
// Reset to the last reliable state.
|
||||
preliminary_state_ = reliable_state_;
|
||||
}
|
||||
}
|
||||
num_adjacent_speech_frames_ = 0;
|
||||
} else {
|
||||
// Speech frame observed.
|
||||
num_adjacent_speech_frames_++;
|
||||
|
||||
// Update preliminary level estimate.
|
||||
UpdateSaturationProtectorState(peak_dbfs, speech_level_dbfs,
|
||||
preliminary_state_);
|
||||
|
||||
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
|
||||
// `preliminary_state_` is now reliable. Update the headroom.
|
||||
headroom_db_ = preliminary_state_.headroom_db;
|
||||
}
|
||||
}
|
||||
DumpDebugData();
|
||||
}
|
||||
|
||||
void Reset() override {
|
||||
num_adjacent_speech_frames_ = 0;
|
||||
headroom_db_ = initial_headroom_db_;
|
||||
ResetSaturationProtectorState(initial_headroom_db_, preliminary_state_);
|
||||
ResetSaturationProtectorState(initial_headroom_db_, reliable_state_);
|
||||
}
|
||||
|
||||
private:
|
||||
void DumpDebugData() {
|
||||
apm_data_dumper_->DumpRaw(
|
||||
"agc2_saturation_protector_preliminary_max_peak_dbfs",
|
||||
preliminary_state_.max_peaks_dbfs);
|
||||
apm_data_dumper_->DumpRaw(
|
||||
"agc2_saturation_protector_reliable_max_peak_dbfs",
|
||||
reliable_state_.max_peaks_dbfs);
|
||||
}
|
||||
|
||||
ApmDataDumper* const apm_data_dumper_;
|
||||
const float initial_headroom_db_;
|
||||
const int adjacent_speech_frames_threshold_;
|
||||
int num_adjacent_speech_frames_;
|
||||
float headroom_db_;
|
||||
SaturationProtectorState preliminary_state_;
|
||||
SaturationProtectorState reliable_state_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<SaturationProtector> CreateSaturationProtector(
|
||||
float initial_headroom_db,
|
||||
int adjacent_speech_frames_threshold,
|
||||
ApmDataDumper* apm_data_dumper) {
|
||||
return std::make_unique<SaturationProtectorImpl>(
|
||||
initial_headroom_db, adjacent_speech_frames_threshold, apm_data_dumper);
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,46 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
namespace webrtc {
|
||||
class ApmDataDumper;
|
||||
|
||||
// Interface of the saturation protector, which analyzes peak levels and
// recommends a headroom to reduce the chances of clipping.
class SaturationProtector {
 public:
  virtual ~SaturationProtector() = default;

  // Returns the recommended headroom in dB.
  virtual float HeadroomDb() = 0;

  // Updates the recommended headroom given the peak level `peak_dbfs` of a
  // 10 ms frame, its speech probability `speech_probability` and the current
  // speech level estimate `speech_level_dbfs`.
  virtual void Analyze(float speech_probability,
                       float peak_dbfs,
                       float speech_level_dbfs) = 0;

  // Resets the internal state.
  virtual void Reset() = 0;
};
|
||||
|
||||
// Creates a saturation protector that starts at `initial_headroom_db`.
// `adjacent_speech_frames_threshold` is the number of adjacent speech frames
// that must be observed before the recommended headroom is updated.
// `apm_data_dumper` is used to dump debug data.
|
||||
std::unique_ptr<SaturationProtector> CreateSaturationProtector(
|
||||
float initial_headroom_db,
|
||||
int adjacent_speech_frames_threshold,
|
||||
ApmDataDumper* apm_data_dumper);
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
|
||||
|
|
@ -0,0 +1,77 @@
|
|||
/*
|
||||
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/saturation_protector_buffer.h"
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/numerics/safe_compare.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
SaturationProtectorBuffer::SaturationProtectorBuffer() = default;
|
||||
|
||||
SaturationProtectorBuffer::~SaturationProtectorBuffer() = default;
|
||||
|
||||
bool SaturationProtectorBuffer::operator==(
|
||||
const SaturationProtectorBuffer& b) const {
|
||||
RTC_DCHECK_LE(size_, buffer_.size());
|
||||
RTC_DCHECK_LE(b.size_, b.buffer_.size());
|
||||
if (size_ != b.size_) {
|
||||
return false;
|
||||
}
|
||||
for (int i = 0, i0 = FrontIndex(), i1 = b.FrontIndex(); i < size_;
|
||||
++i, ++i0, ++i1) {
|
||||
if (buffer_[i0 % buffer_.size()] != b.buffer_[i1 % b.buffer_.size()]) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
int SaturationProtectorBuffer::Capacity() const {
|
||||
return buffer_.size();
|
||||
}
|
||||
|
||||
int SaturationProtectorBuffer::Size() const {
|
||||
return size_;
|
||||
}
|
||||
|
||||
// Empties the buffer. The stored values are not overwritten: only the write
// position and the item count are reset.
void SaturationProtectorBuffer::Reset() {
  size_ = 0;
  next_ = 0;
}
|
||||
|
||||
void SaturationProtectorBuffer::PushBack(float v) {
|
||||
RTC_DCHECK_GE(next_, 0);
|
||||
RTC_DCHECK_GE(size_, 0);
|
||||
RTC_DCHECK_LT(next_, buffer_.size());
|
||||
RTC_DCHECK_LE(size_, buffer_.size());
|
||||
buffer_[next_++] = v;
|
||||
if (rtc::SafeEq(next_, buffer_.size())) {
|
||||
next_ = 0;
|
||||
}
|
||||
if (rtc::SafeLt(size_, buffer_.size())) {
|
||||
size_++;
|
||||
}
|
||||
}
|
||||
|
||||
// Returns the item at the front of the buffer, or an empty optional when the
// buffer holds no items.
absl::optional<float> SaturationProtectorBuffer::Front() const {
  const bool is_empty = size_ == 0;
  if (is_empty) {
    return absl::nullopt;
  }
  const int front_index = FrontIndex();
  RTC_DCHECK_LT(front_index, buffer_.size());
  return buffer_[front_index];
}
|
||||
|
||||
int SaturationProtectorBuffer::FrontIndex() const {
|
||||
return rtc::SafeEq(size_, buffer_.size()) ? next_ : 0;
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,59 @@
|
|||
/*
|
||||
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_
|
||||
|
||||
#include <array>
|
||||
|
||||
#include "absl/types/optional.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// Ring buffer for the saturation protector which only supports (i) push back
|
||||
// and (ii) read oldest item.
|
||||
class SaturationProtectorBuffer {
|
||||
public:
|
||||
SaturationProtectorBuffer();
|
||||
~SaturationProtectorBuffer();
|
||||
|
||||
bool operator==(const SaturationProtectorBuffer& b) const;
|
||||
inline bool operator!=(const SaturationProtectorBuffer& b) const {
|
||||
return !(*this == b);
|
||||
}
|
||||
|
||||
// Maximum number of values that the buffer can contain.
|
||||
int Capacity() const;
|
||||
|
||||
// Number of values in the buffer.
|
||||
int Size() const;
|
||||
|
||||
void Reset();
|
||||
|
||||
// Pushes back `v`. If the buffer is full, the oldest value is replaced.
|
||||
void PushBack(float v);
|
||||
|
||||
// Returns the oldest item in the buffer. Returns an empty value if the
|
||||
// buffer is empty.
|
||||
absl::optional<float> Front() const;
|
||||
|
||||
private:
|
||||
int FrontIndex() const;
|
||||
// `buffer_` has `size_` elements (up to the size of `buffer_`) and `next_` is
|
||||
// the position where the next new value is written in `buffer_`.
|
||||
std::array<float, kSaturationProtectorBufferSize> buffer_;
|
||||
int next_ = 0;
|
||||
int size_ = 0;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_
|
||||
|
|
@ -0,0 +1,174 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/speech_level_estimator.h"
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/logging/apm_data_dumper.h"
|
||||
#include "rtc_base/checks.h"
|
||||
#include "rtc_base/logging.h"
|
||||
#include "rtc_base/numerics/safe_minmax.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
float ClampLevelEstimateDbfs(float level_estimate_dbfs) {
|
||||
return rtc::SafeClamp<float>(level_estimate_dbfs, -90.0f, 30.0f);
|
||||
}
|
||||
|
||||
// Returns the initial speech level estimate needed to apply the initial gain.
|
||||
float GetInitialSpeechLevelEstimateDbfs(
|
||||
const AudioProcessing::Config::GainController2::AdaptiveDigital& config) {
|
||||
return ClampLevelEstimateDbfs(-kSaturationProtectorInitialHeadroomDb -
|
||||
config.initial_gain_db - config.headroom_db);
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
bool SpeechLevelEstimator::LevelEstimatorState::operator==(
|
||||
const SpeechLevelEstimator::LevelEstimatorState& b) const {
|
||||
return time_to_confidence_ms == b.time_to_confidence_ms &&
|
||||
level_dbfs.numerator == b.level_dbfs.numerator &&
|
||||
level_dbfs.denominator == b.level_dbfs.denominator;
|
||||
}
|
||||
|
||||
float SpeechLevelEstimator::LevelEstimatorState::Ratio::GetRatio() const {
|
||||
RTC_DCHECK_NE(denominator, 0.f);
|
||||
return numerator / denominator;
|
||||
}
|
||||
|
||||
// Builds an estimator whose initial level is derived from `config`.
// `apm_data_dumper` must not be null and `adjacent_speech_frames_threshold`
// must be at least 1.
SpeechLevelEstimator::SpeechLevelEstimator(
    ApmDataDumper* apm_data_dumper,
    const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
    int adjacent_speech_frames_threshold)
    : apm_data_dumper_(apm_data_dumper),
      initial_speech_level_dbfs_(GetInitialSpeechLevelEstimateDbfs(config)),
      adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
      level_dbfs_(initial_speech_level_dbfs_),
      // TODO(bugs.webrtc.org/7494): Remove init below when AGC2 input volume
      // controller temporal dependency removed.
      is_confident_(false) {
  // Validate the arguments before initializing the remaining state.
  RTC_DCHECK(apm_data_dumper_);
  RTC_DCHECK_GE(adjacent_speech_frames_threshold_, 1);
  Reset();
}
|
||||
|
||||
void SpeechLevelEstimator::Update(float rms_dbfs,
|
||||
float peak_dbfs,
|
||||
float speech_probability) {
|
||||
RTC_DCHECK_GT(rms_dbfs, -150.0f);
|
||||
RTC_DCHECK_LT(rms_dbfs, 50.0f);
|
||||
RTC_DCHECK_GT(peak_dbfs, -150.0f);
|
||||
RTC_DCHECK_LT(peak_dbfs, 50.0f);
|
||||
RTC_DCHECK_GE(speech_probability, 0.0f);
|
||||
RTC_DCHECK_LE(speech_probability, 1.0f);
|
||||
if (speech_probability < kVadConfidenceThreshold) {
|
||||
// Not a speech frame.
|
||||
if (adjacent_speech_frames_threshold_ > 1) {
|
||||
// When two or more adjacent speech frames are required in order to update
|
||||
// the state, we need to decide whether to discard or confirm the updates
|
||||
// based on the speech sequence length.
|
||||
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
|
||||
// First non-speech frame after a long enough sequence of speech frames.
|
||||
// Update the reliable state.
|
||||
reliable_state_ = preliminary_state_;
|
||||
} else if (num_adjacent_speech_frames_ > 0) {
|
||||
// First non-speech frame after a too short sequence of speech frames.
|
||||
// Reset to the last reliable state.
|
||||
preliminary_state_ = reliable_state_;
|
||||
}
|
||||
}
|
||||
num_adjacent_speech_frames_ = 0;
|
||||
} else {
|
||||
// Speech frame observed.
|
||||
num_adjacent_speech_frames_++;
|
||||
|
||||
// Update preliminary level estimate.
|
||||
RTC_DCHECK_GE(preliminary_state_.time_to_confidence_ms, 0);
|
||||
const bool buffer_is_full = preliminary_state_.time_to_confidence_ms == 0;
|
||||
if (!buffer_is_full) {
|
||||
preliminary_state_.time_to_confidence_ms -= kFrameDurationMs;
|
||||
}
|
||||
// Weighted average of levels with speech probability as weight.
|
||||
RTC_DCHECK_GT(speech_probability, 0.0f);
|
||||
const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.0f;
|
||||
preliminary_state_.level_dbfs.numerator =
|
||||
preliminary_state_.level_dbfs.numerator * leak_factor +
|
||||
rms_dbfs * speech_probability;
|
||||
preliminary_state_.level_dbfs.denominator =
|
||||
preliminary_state_.level_dbfs.denominator * leak_factor +
|
||||
speech_probability;
|
||||
|
||||
const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();
|
||||
|
||||
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
|
||||
// `preliminary_state_` is now reliable. Update the last level estimation.
|
||||
level_dbfs_ = ClampLevelEstimateDbfs(level_dbfs);
|
||||
}
|
||||
}
|
||||
UpdateIsConfident();
|
||||
DumpDebugData();
|
||||
}
|
||||
|
||||
// Recomputes `is_confident_`. A state is confident once its time to confidence
// has reached zero.
void SpeechLevelEstimator::UpdateIsConfident() {
  const bool preliminary_is_confident =
      preliminary_state_.time_to_confidence_ms == 0;
  if (adjacent_speech_frames_threshold_ == 1) {
    // `reliable_state_` is ignored when a single frame is enough to update the
    // level estimate (because it is not used).
    is_confident_ = preliminary_is_confident;
    return;
  }
  const bool reliable_is_confident =
      reliable_state_.time_to_confidence_ms == 0;
  // Once confident, it remains confident.
  RTC_DCHECK(!reliable_is_confident || preliminary_is_confident);
  // During the first long enough speech sequence, `reliable_state_` must be
  // ignored since `preliminary_state_` is used.
  const bool preliminary_in_use =
      num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_;
  is_confident_ = reliable_is_confident ||
                  (preliminary_in_use && preliminary_is_confident);
}
|
||||
|
||||
// Restores the initial level estimate, clears the speech frame counter and
// re-initializes both the preliminary and the reliable states.
void SpeechLevelEstimator::Reset() {
  num_adjacent_speech_frames_ = 0;
  level_dbfs_ = initial_speech_level_dbfs_;
  ResetLevelEstimatorState(reliable_state_);
  ResetLevelEstimatorState(preliminary_state_);
}
|
||||
|
||||
// Re-initializes `state`: the level ratio is set to the initial speech level
// (with unit denominator) and the time to confidence restarts from
// `kLevelEstimatorTimeToConfidenceMs`.
void SpeechLevelEstimator::ResetLevelEstimatorState(
    LevelEstimatorState& state) const {
  state.level_dbfs.denominator = 1.0f;
  state.level_dbfs.numerator = initial_speech_level_dbfs_;
  state.time_to_confidence_ms = kLevelEstimatorTimeToConfidenceMs;
}
|
||||
|
||||
void SpeechLevelEstimator::DumpDebugData() const {
|
||||
if (!apm_data_dumper_)
|
||||
return;
|
||||
apm_data_dumper_->DumpRaw("agc2_speech_level_dbfs", level_dbfs_);
|
||||
apm_data_dumper_->DumpRaw("agc2_speech_level_is_confident", is_confident_);
|
||||
apm_data_dumper_->DumpRaw(
|
||||
"agc2_adaptive_level_estimator_num_adjacent_speech_frames",
|
||||
num_adjacent_speech_frames_);
|
||||
apm_data_dumper_->DumpRaw(
|
||||
"agc2_adaptive_level_estimator_preliminary_level_estimate_num",
|
||||
preliminary_state_.level_dbfs.numerator);
|
||||
apm_data_dumper_->DumpRaw(
|
||||
"agc2_adaptive_level_estimator_preliminary_level_estimate_den",
|
||||
preliminary_state_.level_dbfs.denominator);
|
||||
apm_data_dumper_->DumpRaw(
|
||||
"agc2_adaptive_level_estimator_preliminary_time_to_confidence_ms",
|
||||
preliminary_state_.time_to_confidence_ms);
|
||||
apm_data_dumper_->DumpRaw(
|
||||
"agc2_adaptive_level_estimator_reliable_time_to_confidence_ms",
|
||||
reliable_state_.time_to_confidence_ms);
|
||||
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,81 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SPEECH_LEVEL_ESTIMATOR_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_SPEECH_LEVEL_ESTIMATOR_H_
|
||||
|
||||
#include <stddef.h>
|
||||
|
||||
#include <type_traits>
|
||||
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/include/audio_processing.h"
|
||||
|
||||
namespace webrtc {
|
||||
class ApmDataDumper;
|
||||
|
||||
// Active speech level estimator based on the analysis of the following
|
||||
// framewise properties: RMS level (dBFS), peak level (dBFS), speech
|
||||
// probability.
|
||||
class SpeechLevelEstimator {
|
||||
public:
|
||||
SpeechLevelEstimator(
|
||||
ApmDataDumper* apm_data_dumper,
|
||||
const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
|
||||
int adjacent_speech_frames_threshold);
|
||||
SpeechLevelEstimator(const SpeechLevelEstimator&) = delete;
|
||||
SpeechLevelEstimator& operator=(const SpeechLevelEstimator&) = delete;
|
||||
|
||||
// Updates the level estimation.
|
||||
void Update(float rms_dbfs, float peak_dbfs, float speech_probability);
|
||||
// Returns the estimated speech plus noise level.
|
||||
float level_dbfs() const { return level_dbfs_; }
|
||||
// Returns true if the estimator is confident on its current estimate.
|
||||
bool is_confident() const { return is_confident_; }
|
||||
|
||||
void Reset();
|
||||
|
||||
private:
|
||||
// Part of the level estimator state used for check-pointing and restore ops.
|
||||
struct LevelEstimatorState {
|
||||
bool operator==(const LevelEstimatorState& s) const;
|
||||
inline bool operator!=(const LevelEstimatorState& s) const {
|
||||
return !(*this == s);
|
||||
}
|
||||
// TODO(bugs.webrtc.org/7494): Remove `time_to_confidence_ms` if redundant.
|
||||
int time_to_confidence_ms;
|
||||
struct Ratio {
|
||||
float numerator;
|
||||
float denominator;
|
||||
float GetRatio() const;
|
||||
} level_dbfs;
|
||||
};
|
||||
static_assert(std::is_trivially_copyable<LevelEstimatorState>::value, "");
|
||||
|
||||
void UpdateIsConfident();
|
||||
|
||||
void ResetLevelEstimatorState(LevelEstimatorState& state) const;
|
||||
|
||||
void DumpDebugData() const;
|
||||
|
||||
ApmDataDumper* const apm_data_dumper_;
|
||||
|
||||
const float initial_speech_level_dbfs_;
|
||||
const int adjacent_speech_frames_threshold_;
|
||||
LevelEstimatorState preliminary_state_;
|
||||
LevelEstimatorState reliable_state_;
|
||||
float level_dbfs_;
|
||||
bool is_confident_;
|
||||
int num_adjacent_speech_frames_;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_SPEECH_LEVEL_ESTIMATOR_H_
|
||||
|
|
@ -0,0 +1,105 @@
|
|||
/*
|
||||
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/speech_probability_buffer.h"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
// NOTE(review): not referenced in the part of the file reviewed here;
// presumably the threshold used to decide whether the buffered segment is
// active - confirm against its use further below.
constexpr float kActivityThreshold = 0.9f;
|
||||
// Number of probabilities held by the circular buffer.
constexpr int kNumAnalysisFrames = 100;
|
||||
// We use 12 in AGC2 adaptive digital, but with a slightly different logic.
// A run of high probability observations no longer than this value is handled
// as a transient when a low probability observation follows it.
|
||||
constexpr int kTransientWidthThreshold = 7;
|
||||
|
||||
} // namespace
|
||||
|
||||
// Creates a buffer holding `kNumAnalysisFrames` probabilities.
// `low_probability_threshold` is expected to be in [0, 1]; this is verified
// with RTC_DCHECK.
SpeechProbabilityBuffer::SpeechProbabilityBuffer(
    float low_probability_threshold)
    : low_probability_threshold_(low_probability_threshold),
      probabilities_(kNumAnalysisFrames) {
  // Validate the threshold range.
  RTC_DCHECK_GE(low_probability_threshold, 0.0f);
  RTC_DCHECK_LE(low_probability_threshold, 1.0f);
  // `Update()` indexes into the storage, hence it must not be empty.
  RTC_DCHECK(!probabilities_.empty());
}
|
||||
|
||||
void SpeechProbabilityBuffer::Update(float probability) {
|
||||
// Remove the oldest entry if the circular buffer is full.
|
||||
if (buffer_is_full_) {
|
||||
const float oldest_probability = probabilities_[buffer_index_];
|
||||
sum_probabilities_ -= oldest_probability;
|
||||
}
|
||||
|
||||
// Check for transients.
|
||||
if (probability <= low_probability_threshold_) {
|
||||
// Set a probability lower than the threshold to zero.
|
||||
probability = 0.0f;
|
||||
|
||||
// Check if this has been a transient.
|
||||
if (num_high_probability_observations_ <= kTransientWidthThreshold) {
|
||||
RemoveTransient();
|
||||
}
|
||||
num_high_probability_observations_ = 0;
|
||||
} else if (num_high_probability_observations_ <= kTransientWidthThreshold) {
|
||||
++num_high_probability_observations_;
|
||||
}
|
||||
|
||||
// Update the circular buffer and the current sum.
|
||||
probabilities_[buffer_index_] = probability;
|
||||
sum_probabilities_ += probability;
|
||||
|
||||
// Increment the buffer index and check for wrap-around.
|
||||
if (++buffer_index_ >= kNumAnalysisFrames) {
|
||||
buffer_index_ = 0;
|
||||
buffer_is_full_ = true;
|
||||
}
|
||||
}
|
||||
|
||||
void SpeechProbabilityBuffer::RemoveTransient() {
|
||||
// Don't expect to be here if high-activity region is longer than
|
||||
// `kTransientWidthThreshold` or there has not been any transient.
|
||||
RTC_DCHECK_LE(num_high_probability_observations_, kTransientWidthThreshold);
|
||||
|
||||
// Replace previously added probabilities with zero.
|
||||
int index =
|
||||
(buffer_index_ > 0) ? (buffer_index_ - 1) : (kNumAnalysisFrames - 1);
|
||||
|
||||
while (num_high_probability_observations_-- > 0) {
|
||||
sum_probabilities_ -= probabilities_[index];
|
||||
probabilities_[index] = 0.0f;
|
||||
|
||||
// Update the circular buffer index.
|
||||
index = (index > 0) ? (index - 1) : (kNumAnalysisFrames - 1);
|
||||
}
|
||||
}
|
||||
|
||||
bool SpeechProbabilityBuffer::IsActiveSegment() const {
|
||||
if (!buffer_is_full_) {
|
||||
return false;
|
||||
}
|
||||
if (sum_probabilities_ < kActivityThreshold * kNumAnalysisFrames) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// Forgets all past observations. The entries of `probabilities_` are not
// rewritten here; only the bookkeeping members are reset.
void SpeechProbabilityBuffer::Reset() {
  // Rewind the circular buffer and mark it as not full.
  buffer_index_ = 0;
  buffer_is_full_ = false;

  // Clear the running sum and the transient detection counter.
  sum_probabilities_ = 0.0f;
  num_high_probability_observations_ = 0;
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,80 @@
|
|||
/*
|
||||
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "rtc_base/gtest_prod_util.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// This class implements a circular buffer that stores speech probabilities
|
||||
// for a speech segment and estimates speech activity for that segment.
|
||||
class SpeechProbabilityBuffer {
|
||||
public:
|
||||
// Ctor. The value of `low_probability_threshold` is required to be on the
|
||||
// range [0.0f, 1.0f].
|
||||
explicit SpeechProbabilityBuffer(float low_probability_threshold);
|
||||
~SpeechProbabilityBuffer() {}
|
||||
SpeechProbabilityBuffer(const SpeechProbabilityBuffer&) = delete;
|
||||
SpeechProbabilityBuffer& operator=(const SpeechProbabilityBuffer&) = delete;
|
||||
|
||||
// Adds `probability` in the buffer and computes an updatds sum of the buffer
|
||||
// probabilities. Value of `probability` is required to be on the range
|
||||
// [0.0f, 1.0f].
|
||||
void Update(float probability);
|
||||
|
||||
// Resets the histogram, forgets the past.
|
||||
void Reset();
|
||||
|
||||
// Returns true if the segment is active (a long enough segment with an
|
||||
// average speech probability above `low_probability_threshold`).
|
||||
bool IsActiveSegment() const;
|
||||
|
||||
private:
|
||||
void RemoveTransient();
|
||||
|
||||
// Use only for testing.
|
||||
float GetSumProbabilities() const { return sum_probabilities_; }
|
||||
|
||||
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
|
||||
CheckSumAfterInitialization);
|
||||
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, CheckSumAfterUpdate);
|
||||
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, CheckSumAfterReset);
|
||||
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
|
||||
CheckSumAfterTransientNotRemoved);
|
||||
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
|
||||
CheckSumAfterTransientRemoved);
|
||||
|
||||
const float low_probability_threshold_;
|
||||
|
||||
// Sum of probabilities stored in `probabilities_`. Must be updated if
|
||||
// `probabilities_` is updated.
|
||||
float sum_probabilities_ = 0.0f;
|
||||
|
||||
// Circular buffer for probabilities.
|
||||
std::vector<float> probabilities_;
|
||||
|
||||
// Current index of the circular buffer, where the newest data will be written
|
||||
// to, therefore, pointing to the oldest data if buffer is full.
|
||||
int buffer_index_ = 0;
|
||||
|
||||
// Indicates if the buffer is full and adding a new value removes the oldest
|
||||
// value.
|
||||
int buffer_is_full_ = false;
|
||||
|
||||
int num_high_probability_observations_ = 0;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_
|
||||
|
|
@ -0,0 +1,113 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/vad_wrapper.h"
|
||||
|
||||
#include <array>
|
||||
#include <utility>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "common_audio/resampler/include/push_resampler.h"
|
||||
#include "modules/audio_processing/agc2/agc2_common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/common.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
|
||||
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
|
||||
#include "rtc_base/checks.h"
|
||||
|
||||
namespace webrtc {
|
||||
namespace {
|
||||
|
||||
// Number of frames per second; sample rates are divided by this value to
// obtain the number of samples per frame.
constexpr int kNumFramesPerSecond = 100;
|
||||
|
||||
class MonoVadImpl : public VoiceActivityDetectorWrapper::MonoVad {
|
||||
public:
|
||||
explicit MonoVadImpl(const AvailableCpuFeatures& cpu_features)
|
||||
: features_extractor_(cpu_features), rnn_vad_(cpu_features) {}
|
||||
MonoVadImpl(const MonoVadImpl&) = delete;
|
||||
MonoVadImpl& operator=(const MonoVadImpl&) = delete;
|
||||
~MonoVadImpl() = default;
|
||||
|
||||
int SampleRateHz() const override { return rnn_vad::kSampleRate24kHz; }
|
||||
void Reset() override { rnn_vad_.Reset(); }
|
||||
float Analyze(rtc::ArrayView<const float> frame) override {
|
||||
RTC_DCHECK_EQ(frame.size(), rnn_vad::kFrameSize10ms24kHz);
|
||||
std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
|
||||
const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
|
||||
/*samples=*/{frame.data(), rnn_vad::kFrameSize10ms24kHz},
|
||||
feature_vector);
|
||||
return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
|
||||
}
|
||||
|
||||
private:
|
||||
rnn_vad::FeaturesExtractor features_extractor_;
|
||||
rnn_vad::RnnVad rnn_vad_;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
// Delegates to the three-argument constructor using `kVadResetPeriodMs` as
// the VAD reset period.
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
    const AvailableCpuFeatures& cpu_features,
    int sample_rate_hz)
    : VoiceActivityDetectorWrapper(kVadResetPeriodMs,
                                   cpu_features,
                                   sample_rate_hz) {}
|
||||
|
||||
// Delegates to the custom-VAD constructor, passing a newly created
// `MonoVadImpl` built from `cpu_features`.
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
    int vad_reset_period_ms,
    const AvailableCpuFeatures& cpu_features,
    int sample_rate_hz)
    : VoiceActivityDetectorWrapper(vad_reset_period_ms,
                                   std::make_unique<MonoVadImpl>(cpu_features),
                                   sample_rate_hz) {}
|
||||
|
||||
// Takes ownership of `vad`, converts the reset period from milliseconds to
// frames, sizes the resampling buffer for one frame at the VAD sample rate and
// finally calls `Initialize(sample_rate_hz)`.
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
    int vad_reset_period_ms,
    std::unique_ptr<MonoVad> vad,
    int sample_rate_hz)
    : vad_reset_period_frames_(
          rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
      time_to_vad_reset_(vad_reset_period_frames_),
      vad_(std::move(vad)) {
  RTC_DCHECK(vad_);
  // The reset period must span more than one frame.
  RTC_DCHECK_GT(vad_reset_period_frames_, 1);

  // One frame at the sample rate required by the wrapped VAD.
  const auto vad_frame_size =
      rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond);
  resampled_buffer_.resize(vad_frame_size);

  Initialize(sample_rate_hz);
}
|
||||
|
||||
// Defaulted destructor, defined out of line in this file.
VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;
|
||||
|
||||
void VoiceActivityDetectorWrapper::Initialize(int sample_rate_hz) {
|
||||
RTC_DCHECK_GT(sample_rate_hz, 0);
|
||||
frame_size_ = rtc::CheckedDivExact(sample_rate_hz, kNumFramesPerSecond);
|
||||
int status =
|
||||
resampler_.InitializeIfNeeded(sample_rate_hz, vad_->SampleRateHz(),
|
||||
/*num_channels=*/1);
|
||||
constexpr int kStatusOk = 0;
|
||||
RTC_DCHECK_EQ(status, kStatusOk);
|
||||
vad_->Reset();
|
||||
}
|
||||
|
||||
// Resamples the first channel of `frame` to the VAD sample rate and returns
// the speech probability computed by the wrapped VAD. The VAD is reset every
// `vad_reset_period_frames_` calls.
float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
  // Count down to the next periodic VAD reset.
  if (--time_to_vad_reset_ <= 0) {
    vad_->Reset();
    time_to_vad_reset_ = vad_reset_period_frames_;
  }

  // Only the first channel is analyzed.
  RTC_DCHECK_EQ(frame.samples_per_channel(), frame_size_);
  const auto* first_channel = frame.channel(0).data();
  resampler_.Resample(first_channel, frame_size_, resampled_buffer_.data(),
                      resampled_buffer_.size());

  return vad_->Analyze(resampled_buffer_);
}
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,82 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "api/array_view.h"
|
||||
#include "common_audio/resampler/include/push_resampler.h"
|
||||
#include "modules/audio_processing/agc2/cpu_features.h"
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// Wraps a single-channel Voice Activity Detector (VAD) which is used to analyze
|
||||
// the first channel of the input audio frames. Takes care of resampling the
|
||||
// input frames to match the sample rate of the wrapped VAD and periodically
|
||||
// resets the VAD.
|
||||
class VoiceActivityDetectorWrapper {
|
||||
public:
|
||||
// Single channel VAD interface.
|
||||
class MonoVad {
|
||||
public:
|
||||
virtual ~MonoVad() = default;
|
||||
// Returns the sample rate (Hz) required for the input frames analyzed by
|
||||
// `ComputeProbability`.
|
||||
virtual int SampleRateHz() const = 0;
|
||||
// Resets the internal state.
|
||||
virtual void Reset() = 0;
|
||||
// Analyzes an audio frame and returns the speech probability.
|
||||
virtual float Analyze(rtc::ArrayView<const float> frame) = 0;
|
||||
};
|
||||
|
||||
// Ctor. Uses `cpu_features` to instantiate the default VAD.
|
||||
VoiceActivityDetectorWrapper(const AvailableCpuFeatures& cpu_features,
|
||||
int sample_rate_hz);
|
||||
|
||||
// Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
|
||||
// `MonoVad::Reset()`; it must be equal to or greater than the duration of two
|
||||
// frames. Uses `cpu_features` to instantiate the default VAD.
|
||||
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
|
||||
const AvailableCpuFeatures& cpu_features,
|
||||
int sample_rate_hz);
|
||||
// Ctor. Uses a custom `vad`.
|
||||
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
|
||||
std::unique_ptr<MonoVad> vad,
|
||||
int sample_rate_hz);
|
||||
|
||||
VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete;
|
||||
VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) =
|
||||
delete;
|
||||
~VoiceActivityDetectorWrapper();
|
||||
|
||||
// Initializes the VAD wrapper.
|
||||
void Initialize(int sample_rate_hz);
|
||||
|
||||
// Analyzes the first channel of `frame` and returns the speech probability.
|
||||
// `frame` must be a 10 ms frame with the sample rate specified in the last
|
||||
// `Initialize()` call.
|
||||
float Analyze(AudioFrameView<const float> frame);
|
||||
|
||||
private:
|
||||
const int vad_reset_period_frames_;
|
||||
int frame_size_;
|
||||
int time_to_vad_reset_;
|
||||
PushResampler<float> resampler_;
|
||||
std::unique_ptr<MonoVad> vad_;
|
||||
std::vector<float> resampled_buffer_;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
|
||||
|
|
@ -0,0 +1,39 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#include "modules/audio_processing/agc2/vector_float_frame.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
namespace {
|
||||
|
||||
// Returns one pointer per channel in `*x`, each pointing to the first sample
// of the corresponding channel. The pointers refer to storage owned by `*x`.
std::vector<float*> ConstructChannelPointers(
    std::vector<std::vector<float>>* x) {
  std::vector<float*> channel_ptrs;
  // The number of channels is known upfront: allocate once.
  channel_ptrs.reserve(x->size());
  for (auto& channel : *x) {
    channel_ptrs.push_back(channel.data());
  }
  return channel_ptrs;
}
|
||||
} // namespace
|
||||
|
||||
// Allocates `num_channels` channels of `samples_per_channel` samples, all set
// to `start_value`, then builds the per-channel pointers and the frame view
// on top of that storage.
VectorFloatFrame::VectorFloatFrame(int num_channels,
                                   int samples_per_channel,
                                   float start_value)
    : channels_(num_channels,
                std::vector<float>(samples_per_channel, start_value)),
      // Pointers into `channels_`, which is initialized just above.
      channel_ptrs_(ConstructChannelPointers(&channels_)),
      float_frame_view_(channel_ptrs_.data(),
                        channels_.size(),
                        samples_per_channel) {}
|
||||
|
||||
// Defaulted destructor, defined out of line in this file.
VectorFloatFrame::~VectorFloatFrame() = default;
|
||||
|
||||
} // namespace webrtc
|
||||
|
|
@ -0,0 +1,42 @@
|
|||
/*
|
||||
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
|
||||
*
|
||||
* Use of this source code is governed by a BSD-style license
|
||||
* that can be found in the LICENSE file in the root of the source
|
||||
* tree. An additional intellectual property rights grant can be found
|
||||
* in the file PATENTS. All contributing project authors may
|
||||
* be found in the AUTHORS file in the root of the source tree.
|
||||
*/
|
||||
|
||||
#ifndef MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_
|
||||
#define MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "modules/audio_processing/include/audio_frame_view.h"
|
||||
|
||||
namespace webrtc {
|
||||
|
||||
// A construct consisting of a multi-channel audio frame, and a FloatFrame view
|
||||
// of it.
|
||||
class VectorFloatFrame {
|
||||
public:
|
||||
VectorFloatFrame(int num_channels,
|
||||
int samples_per_channel,
|
||||
float start_value);
|
||||
const AudioFrameView<float>& float_frame_view() { return float_frame_view_; }
|
||||
AudioFrameView<const float> float_frame_view() const {
|
||||
return float_frame_view_;
|
||||
}
|
||||
|
||||
~VectorFloatFrame();
|
||||
|
||||
private:
|
||||
std::vector<std::vector<float>> channels_;
|
||||
std::vector<float*> channel_ptrs_;
|
||||
AudioFrameView<float> float_frame_view_;
|
||||
};
|
||||
|
||||
} // namespace webrtc
|
||||
|
||||
#endif // MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_
|
||||
Loading…
Add table
Add a link
Reference in a new issue