Repo created

Fr4nz D13trich 2025-11-22 14:04:28 +01:00
parent 81b91f4139
commit f8c34fa5ee
22732 changed files with 4815320 additions and 2 deletions


@@ -0,0 +1,216 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/adaptive_digital_gain_controller.h"
#include <algorithm>
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
namespace {
using AdaptiveDigitalConfig =
AudioProcessing::Config::GainController2::AdaptiveDigital;
constexpr int kHeadroomHistogramMin = 0;
constexpr int kHeadroomHistogramMax = 50;
constexpr int kGainDbHistogramMax = 30;
// Computes the gain for `input_level_dbfs` to reach `-config.headroom_db`.
// Clamps the gain in [0, `config.max_gain_db`]. `config.headroom_db` is a
// safety margin to allow transient peaks to exceed the target peak level
// without clipping.
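// Worked example (illustrative values, not the tuned defaults): with
// `headroom_db` = 6 and `max_gain_db` = 30, an input level of -50 dBFS is
// below -(6 + 30) dBFS and gets the full 30 dB of gain, -20 dBFS gets
// -6 - (-20) = 14 dB, and -3 dBFS is already above -6 dBFS and gets 0 dB.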
float ComputeGainDb(float input_level_dbfs,
const AdaptiveDigitalConfig& config) {
// If the level is very low, apply the maximum gain.
if (input_level_dbfs < -(config.headroom_db + config.max_gain_db)) {
return config.max_gain_db;
}
// We expect to end up here most of the time: the level is below
// -headroom, but we can boost it to -headroom.
if (input_level_dbfs < -config.headroom_db) {
return -config.headroom_db - input_level_dbfs;
}
// The level is too high and we can't boost.
RTC_DCHECK_GE(input_level_dbfs, -config.headroom_db);
return 0.0f;
}
// Returns `target_gain_db` if applying such a gain to `input_noise_level_dbfs`
// does not exceed `max_output_noise_level_dbfs`. Otherwise lowers and returns
// `target_gain_db` so that the output noise level equals
// `max_output_noise_level_dbfs`.
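// For example (illustrative values): with input noise at -70 dBFS and a
// -55 dBFS cap, `max_allowed_gain_db` is 15 dB, so larger target gains are
// lowered to 15 dB.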
float LimitGainByNoise(float target_gain_db,
float input_noise_level_dbfs,
float max_output_noise_level_dbfs,
ApmDataDumper& apm_data_dumper) {
const float max_allowed_gain_db =
max_output_noise_level_dbfs - input_noise_level_dbfs;
apm_data_dumper.DumpRaw("agc2_adaptive_gain_applier_max_allowed_gain_db",
max_allowed_gain_db);
return std::min(target_gain_db, std::max(max_allowed_gain_db, 0.0f));
}
float LimitGainByLowConfidence(float target_gain_db,
float last_gain_db,
float limiter_audio_level_dbfs,
bool estimate_is_confident) {
if (estimate_is_confident ||
limiter_audio_level_dbfs <= kLimiterThresholdForAgcGainDbfs) {
return target_gain_db;
}
const float limiter_level_dbfs_before_gain =
limiter_audio_level_dbfs - last_gain_db;
// Compute a new gain so that `limiter_level_dbfs_before_gain` +
// `new_target_gain_db` does not exceed `kLimiterThresholdForAgcGainDbfs`.
const float new_target_gain_db = std::max(
kLimiterThresholdForAgcGainDbfs - limiter_level_dbfs_before_gain, 0.0f);
return std::min(new_target_gain_db, target_gain_db);
}
// Computes how the gain should change during this frame.
// Returns the gain change in dB relative to `last_gain_db`.
float ComputeGainChangeThisFrameDb(float target_gain_db,
float last_gain_db,
bool gain_increase_allowed,
float max_gain_decrease_db,
float max_gain_increase_db) {
RTC_DCHECK_GT(max_gain_decrease_db, 0);
RTC_DCHECK_GT(max_gain_increase_db, 0);
float target_gain_difference_db = target_gain_db - last_gain_db;
if (!gain_increase_allowed) {
target_gain_difference_db = std::min(target_gain_difference_db, 0.0f);
}
return rtc::SafeClamp(target_gain_difference_db, -max_gain_decrease_db,
max_gain_increase_db);
}
} // namespace
AdaptiveDigitalGainController::AdaptiveDigitalGainController(
ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
int adjacent_speech_frames_threshold)
: apm_data_dumper_(apm_data_dumper),
gain_applier_(
/*hard_clip_samples=*/false,
/*initial_gain_factor=*/DbToRatio(config.initial_gain_db)),
config_(config),
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
max_gain_change_db_per_10ms_(config_.max_gain_change_db_per_second *
kFrameDurationMs / 1000.0f),
calls_since_last_gain_log_(0),
frames_to_gain_increase_allowed_(adjacent_speech_frames_threshold),
last_gain_db_(config_.initial_gain_db) {
RTC_DCHECK_GT(max_gain_change_db_per_10ms_, 0.0f);
RTC_DCHECK_GE(frames_to_gain_increase_allowed_, 1);
RTC_DCHECK_GE(config_.max_output_noise_level_dbfs, -90.0f);
RTC_DCHECK_LE(config_.max_output_noise_level_dbfs, 0.0f);
}
void AdaptiveDigitalGainController::Process(const FrameInfo& info,
AudioFrameView<float> frame) {
RTC_DCHECK_GE(info.speech_level_dbfs, -150.0f);
RTC_DCHECK_GE(frame.num_channels(), 1);
RTC_DCHECK(
frame.samples_per_channel() == 80 || frame.samples_per_channel() == 160 ||
frame.samples_per_channel() == 320 || frame.samples_per_channel() == 480)
<< "`frame` does not look like a 10 ms frame for an APM supported sample "
"rate";
// Compute the input level used to select the desired gain.
RTC_DCHECK_GT(info.headroom_db, 0.0f);
const float input_level_dbfs = info.speech_level_dbfs + info.headroom_db;
const float target_gain_db = LimitGainByLowConfidence(
LimitGainByNoise(ComputeGainDb(input_level_dbfs, config_),
info.noise_rms_dbfs, config_.max_output_noise_level_dbfs,
*apm_data_dumper_),
last_gain_db_, info.limiter_envelope_dbfs, info.speech_level_reliable);
// Forbid increasing the gain until enough adjacent speech frames are
// observed.
bool first_confident_speech_frame = false;
if (info.speech_probability < kVadConfidenceThreshold) {
frames_to_gain_increase_allowed_ = adjacent_speech_frames_threshold_;
} else if (frames_to_gain_increase_allowed_ > 0) {
frames_to_gain_increase_allowed_--;
first_confident_speech_frame = frames_to_gain_increase_allowed_ == 0;
}
apm_data_dumper_->DumpRaw(
"agc2_adaptive_gain_applier_frames_to_gain_increase_allowed",
frames_to_gain_increase_allowed_);
const bool gain_increase_allowed = frames_to_gain_increase_allowed_ == 0;
float max_gain_increase_db = max_gain_change_db_per_10ms_;
if (first_confident_speech_frame) {
// No gain increase happened while waiting for a long enough speech
// sequence. Therefore, temporarily allow a faster gain increase.
RTC_DCHECK(gain_increase_allowed);
max_gain_increase_db *= adjacent_speech_frames_threshold_;
}
const float gain_change_this_frame_db = ComputeGainChangeThisFrameDb(
target_gain_db, last_gain_db_, gain_increase_allowed,
/*max_gain_decrease_db=*/max_gain_change_db_per_10ms_,
max_gain_increase_db);
apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_want_to_change_by_db",
target_gain_db - last_gain_db_);
apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_will_change_by_db",
gain_change_this_frame_db);
// Optimization: avoid calling math functions if gain does not
// change.
if (gain_change_this_frame_db != 0.f) {
gain_applier_.SetGainFactor(
DbToRatio(last_gain_db_ + gain_change_this_frame_db));
}
gain_applier_.ApplyGain(frame);
// Remember that the gain has changed for the next iteration.
last_gain_db_ = last_gain_db_ + gain_change_this_frame_db;
apm_data_dumper_->DumpRaw("agc2_adaptive_gain_applier_applied_gain_db",
last_gain_db_);
// Log every 10 seconds.
calls_since_last_gain_log_++;
if (calls_since_last_gain_log_ == 1000) {
calls_since_last_gain_log_ = 0;
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedSpeechLevel",
-info.speech_level_dbfs, 0, 100, 101);
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.EstimatedNoiseLevel",
-info.noise_rms_dbfs, 0, 100, 101);
RTC_HISTOGRAM_COUNTS_LINEAR(
"WebRTC.Audio.Agc2.Headroom", info.headroom_db, kHeadroomHistogramMin,
kHeadroomHistogramMax,
kHeadroomHistogramMax - kHeadroomHistogramMin + 1);
RTC_HISTOGRAM_COUNTS_LINEAR("WebRTC.Audio.Agc2.DigitalGainApplied",
last_gain_db_, 0, kGainDbHistogramMax,
kGainDbHistogramMax + 1);
RTC_LOG(LS_INFO) << "AGC2 adaptive digital"
<< " | speech_dbfs: " << info.speech_level_dbfs
<< " | noise_dbfs: " << info.noise_rms_dbfs
<< " | headroom_db: " << info.headroom_db
<< " | gain_db: " << last_gain_db_;
}
}
} // namespace webrtc


@@ -0,0 +1,66 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_CONTROLLER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_CONTROLLER_H_
#include <vector>
#include "modules/audio_processing/agc2/gain_applier.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/include/audio_processing.h"
namespace webrtc {
class ApmDataDumper;
// Selects the target digital gain, decides when and how quickly to adapt to the
// target and applies the current gain to 10 ms frames.
class AdaptiveDigitalGainController {
public:
// Information about a frame to process.
struct FrameInfo {
float speech_probability; // Probability of speech in the [0, 1] range.
float speech_level_dbfs; // Estimated speech level (dBFS).
bool speech_level_reliable; // True with reliable speech level estimation.
float noise_rms_dbfs; // Estimated noise RMS level (dBFS).
float headroom_db; // Headroom (dB).
// TODO(bugs.webrtc.org/7494): Remove `limiter_envelope_dbfs`.
float limiter_envelope_dbfs; // Envelope level from the limiter (dBFS).
};
AdaptiveDigitalGainController(
ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
int adjacent_speech_frames_threshold);
AdaptiveDigitalGainController(const AdaptiveDigitalGainController&) = delete;
AdaptiveDigitalGainController& operator=(
const AdaptiveDigitalGainController&) = delete;
// Analyzes `info`, updates the digital gain and applies it to a 10 ms
// `frame`. Supports any sample rate supported by APM.
void Process(const FrameInfo& info, AudioFrameView<float> frame);
private:
ApmDataDumper* const apm_data_dumper_;
GainApplier gain_applier_;
const AudioProcessing::Config::GainController2::AdaptiveDigital config_;
const int adjacent_speech_frames_threshold_;
const float max_gain_change_db_per_10ms_;
int calls_since_last_gain_log_;
int frames_to_gain_increase_allowed_;
float last_gain_db_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_ADAPTIVE_DIGITAL_GAIN_CONTROLLER_H_
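
A minimal usage sketch for the controller above, assuming a mono 10 ms frame and illustrative `FrameInfo` values; the helper function is hypothetical, not part of the commit.

#include "modules/audio_processing/agc2/adaptive_digital_gain_controller.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"

namespace webrtc {

void SketchProcessOneFrame(AudioFrameView<float> frame) {
  ApmDataDumper data_dumper(/*instance_index=*/0);
  const AudioProcessing::Config::GainController2::AdaptiveDigital config;
  AdaptiveDigitalGainController controller(
      &data_dumper, config,
      /*adjacent_speech_frames_threshold=*/kAdjacentSpeechFramesThreshold);
  AdaptiveDigitalGainController::FrameInfo info;
  info.speech_probability = 0.98f;  // Above kVadConfidenceThreshold.
  info.speech_level_dbfs = -42.0f;
  info.speech_level_reliable = true;
  info.noise_rms_dbfs = -78.0f;
  info.headroom_db = 6.0f;
  info.limiter_envelope_dbfs = -10.0f;
  controller.Process(info, frame);  // Adapts the gain and applies it in place.
}

}  // namespace webrtc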


@@ -0,0 +1,62 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_
#define MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_
namespace webrtc {
constexpr float kMinFloatS16Value = -32768.0f;
constexpr float kMaxFloatS16Value = 32767.0f;
constexpr float kMaxAbsFloatS16Value = 32768.0f;
// Minimum audio level in dBFS scale for S16 samples.
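// Note: 90.31 dB = 20 * log10(32768), the full dynamic range of S16 samples.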
constexpr float kMinLevelDbfs = -90.31f;
constexpr int kFrameDurationMs = 10;
constexpr int kSubFramesInFrame = 20;
constexpr int kMaximalNumberOfSamplesPerChannel = 480;
// Adaptive digital gain applier settings.
// Limiter level (dBFS) above which the adaptive digital gain starts being
// decreased.
constexpr float kLimiterThresholdForAgcGainDbfs = -1.0f;
// Number of milliseconds to wait to periodically reset the VAD.
constexpr int kVadResetPeriodMs = 1500;
// Speech probability threshold to detect speech activity.
constexpr float kVadConfidenceThreshold = 0.95f;
// Minimum number of adjacent speech frames having a sufficiently high speech
// probability to reliably detect speech activity.
constexpr int kAdjacentSpeechFramesThreshold = 12;
// Number of milliseconds of speech frames to observe to make the estimator
// confident.
constexpr float kLevelEstimatorTimeToConfidenceMs = 400;
constexpr float kLevelEstimatorLeakFactor =
1.0f - 1.0f / kLevelEstimatorTimeToConfidenceMs;
// Saturation Protector settings.
constexpr float kSaturationProtectorInitialHeadroomDb = 20.0f;
constexpr int kSaturationProtectorBufferSize = 4;
// Number of interpolation points for each region of the limiter.
// These values have been tuned to limit the interpolated gain curve error given
// the limiter parameters and allowing a maximum error of +/- 32768^-1.
constexpr int kInterpolatedGainCurveKneePoints = 22;
constexpr int kInterpolatedGainCurveBeyondKneePoints = 10;
constexpr int kInterpolatedGainCurveTotalPoints =
kInterpolatedGainCurveKneePoints + kInterpolatedGainCurveBeyondKneePoints;
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_AGC2_COMMON_H_


@@ -0,0 +1,94 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/agc2_testing_common.h"
#include <math.h>
#include "rtc_base/checks.h"
namespace webrtc {
namespace test {
std::vector<double> LinSpace(double l, double r, int num_points) {
RTC_CHECK_GE(num_points, 2);
std::vector<double> points(num_points);
const double step = (r - l) / (num_points - 1.0);
points[0] = l;
for (int i = 1; i < num_points - 1; i++) {
points[i] = static_cast<double>(l) + i * step;
}
points[num_points - 1] = r;
return points;
}
WhiteNoiseGenerator::WhiteNoiseGenerator(int min_amplitude, int max_amplitude)
: rand_gen_(42),
min_amplitude_(min_amplitude),
max_amplitude_(max_amplitude) {
RTC_DCHECK_LT(min_amplitude_, max_amplitude_);
RTC_DCHECK_LE(kMinS16, min_amplitude_);
RTC_DCHECK_LE(min_amplitude_, kMaxS16);
RTC_DCHECK_LE(kMinS16, max_amplitude_);
RTC_DCHECK_LE(max_amplitude_, kMaxS16);
}
float WhiteNoiseGenerator::operator()() {
return static_cast<float>(rand_gen_.Rand(min_amplitude_, max_amplitude_));
}
SineGenerator::SineGenerator(float amplitude,
float frequency_hz,
int sample_rate_hz)
: amplitude_(amplitude),
frequency_hz_(frequency_hz),
sample_rate_hz_(sample_rate_hz),
x_radians_(0.0f) {
RTC_DCHECK_GT(amplitude_, 0);
RTC_DCHECK_LE(amplitude_, kMaxS16);
}
float SineGenerator::operator()() {
constexpr float kPi = 3.1415926536f;
x_radians_ += frequency_hz_ / sample_rate_hz_ * 2 * kPi;
if (x_radians_ >= 2 * kPi) {
x_radians_ -= 2 * kPi;
}
// Use sinf instead of std::sinf for libstdc++ compatibility.
return amplitude_ * sinf(x_radians_);
}
PulseGenerator::PulseGenerator(float pulse_amplitude,
float no_pulse_amplitude,
float frequency_hz,
int sample_rate_hz)
: pulse_amplitude_(pulse_amplitude),
no_pulse_amplitude_(no_pulse_amplitude),
samples_period_(
static_cast<int>(static_cast<float>(sample_rate_hz) / frequency_hz)),
sample_counter_(0) {
RTC_DCHECK_GE(pulse_amplitude_, kMinS16);
RTC_DCHECK_LE(pulse_amplitude_, kMaxS16);
RTC_DCHECK_GT(no_pulse_amplitude_, kMinS16);
RTC_DCHECK_LE(no_pulse_amplitude_, kMaxS16);
RTC_DCHECK_GT(sample_rate_hz, frequency_hz);
}
float PulseGenerator::operator()() {
sample_counter_++;
if (sample_counter_ >= samples_period_) {
sample_counter_ -= samples_period_;
}
return static_cast<float>(sample_counter_ == 0 ? pulse_amplitude_
: no_pulse_amplitude_);
}
} // namespace test
} // namespace webrtc


@@ -0,0 +1,82 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
#define MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
#include <limits>
#include <vector>
#include "rtc_base/random.h"
namespace webrtc {
namespace test {
constexpr float kMinS16 =
static_cast<float>(std::numeric_limits<int16_t>::min());
constexpr float kMaxS16 =
static_cast<float>(std::numeric_limits<int16_t>::max());
// Level Estimator test parameters.
constexpr float kDecayMs = 20.0f;
// Limiter parameters.
constexpr float kLimiterMaxInputLevelDbFs = 1.f;
constexpr float kLimiterKneeSmoothnessDb = 1.f;
constexpr float kLimiterCompressionRatio = 5.f;
// Returns evenly spaced `num_points` numbers over a specified interval [l, r].
std::vector<double> LinSpace(double l, double r, int num_points);
// Generates white noise.
class WhiteNoiseGenerator {
public:
WhiteNoiseGenerator(int min_amplitude, int max_amplitude);
float operator()();
private:
Random rand_gen_;
const int min_amplitude_;
const int max_amplitude_;
};
// Generates a sine function.
class SineGenerator {
public:
SineGenerator(float amplitude, float frequency_hz, int sample_rate_hz);
float operator()();
private:
const float amplitude_;
const float frequency_hz_;
const int sample_rate_hz_;
float x_radians_;
};
// Generates periodic pulses.
class PulseGenerator {
public:
PulseGenerator(float pulse_amplitude,
float no_pulse_amplitude,
float frequency_hz,
int sample_rate_hz);
float operator()();
private:
const float pulse_amplitude_;
const float no_pulse_amplitude_;
const int samples_period_;
int sample_counter_;
};
} // namespace test
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_AGC2_TESTING_COMMON_H_
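
A small usage sketch for the generators above; the tone parameters (10 ms mono at 16 kHz, 440 Hz) and the function name are arbitrary choices for illustration.

#include <array>

#include "modules/audio_processing/agc2/agc2_testing_common.h"

std::array<float, 160> SketchMakeTone() {
  // 10 ms at 16 kHz = 160 samples; the amplitude stays in the S16 float range.
  webrtc::test::SineGenerator generator(/*amplitude=*/1000.0f,
                                        /*frequency_hz=*/440.0f,
                                        /*sample_rate_hz=*/16000);
  std::array<float, 160> samples;
  for (float& sample : samples) {
    sample = generator();
  }
  return samples;
}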


@@ -0,0 +1,60 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/biquad_filter.h"
#include "rtc_base/arraysize.h"
namespace webrtc {
BiQuadFilter::BiQuadFilter(const Config& config)
: config_(config), state_({}) {}
BiQuadFilter::~BiQuadFilter() = default;
void BiQuadFilter::SetConfig(const Config& config) {
config_ = config;
state_ = {};
}
void BiQuadFilter::Reset() {
state_ = {};
}
void BiQuadFilter::Process(rtc::ArrayView<const float> x,
rtc::ArrayView<float> y) {
RTC_DCHECK_EQ(x.size(), y.size());
const float config_a0 = config_.a[0];
const float config_a1 = config_.a[1];
const float config_b0 = config_.b[0];
const float config_b1 = config_.b[1];
const float config_b2 = config_.b[2];
float state_a0 = state_.a[0];
float state_a1 = state_.a[1];
float state_b0 = state_.b[0];
float state_b1 = state_.b[1];
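// Difference equation (with the a0 coefficient normalized to 1):
// y[n] = b0*x[n] + b1*x[n-1] + b2*x[n-2] - a1*y[n-1] - a2*y[n-2].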
for (size_t k = 0, x_size = x.size(); k < x_size; ++k) {
// Use a temporary variable for `x[k]` to allow in-place processing.
const float tmp = x[k];
float y_k = config_b0 * tmp + config_b1 * state_b0 + config_b2 * state_b1 -
config_a0 * state_a0 - config_a1 * state_a1;
state_b1 = state_b0;
state_b0 = tmp;
state_a1 = state_a0;
state_a0 = y_k;
y[k] = y_k;
}
state_.a[0] = state_a0;
state_.a[1] = state_a1;
state_.b[0] = state_b0;
state_.b[1] = state_b1;
}
} // namespace webrtc


@@ -0,0 +1,56 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
#include "api/array_view.h"
namespace webrtc {
// Transposed direct form I implementation of a bi-quad filter.
//        b[0] + b[1] • z^(-1) + b[2] • z^(-2)
// H(z) = ------------------------------------
//          1 + a[1] • z^(-1) + a[2] • z^(-2)
class BiQuadFilter {
public:
// Normalized filter coefficients.
// Computed as `[b, a] = scipy.signal.butter(N=2, Wn, btype)`.
struct Config {
float b[3]; // b[0], b[1], b[2].
float a[2]; // a[1], a[2].
};
explicit BiQuadFilter(const Config& config);
BiQuadFilter(const BiQuadFilter&) = delete;
BiQuadFilter& operator=(const BiQuadFilter&) = delete;
~BiQuadFilter();
// Sets the filter configuration and resets the internal state.
void SetConfig(const Config& config);
// Zeroes the filter state.
void Reset();
// Filters `x` and writes the output in `y`, which must have the same length
// as `x`. In-place processing is supported.
void Process(rtc::ArrayView<const float> x, rtc::ArrayView<float> y);
private:
Config config_;
struct State {
float b[2];
float a[2];
} state_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_BIQUAD_FILTER_H_
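
A usage sketch for the filter above, reusing the high-pass coefficients from the unit test that follows; the frame size and function name are illustrative.

#include <array>

#include "modules/audio_processing/agc2/biquad_filter.h"

void SketchHighPassInPlace(std::array<float, 160>& samples) {
  // Computed as scipy.signal.butter(N=2, Wn=60/24000, btype='highpass');
  // b = {b[0], b[1], b[2]} and a = {a[1], a[2]}, with a[0] normalized to 1.
  const webrtc::BiQuadFilter::Config config{
      {0.99446179f, -1.98892358f, 0.99446179f},
      {-1.98889291f, 0.98895425f}};
  webrtc::BiQuadFilter filter(config);
  filter.Process(samples, samples);  // In-place processing is supported.
}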


@@ -0,0 +1,175 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/biquad_filter.h"
#include <algorithm>
#include <array>
#include <cmath>
// TODO(bugs.webrtc.org/8948): Add when the issue is fixed.
// #include "test/fpe_observer.h"
#include "rtc_base/gunit.h"
namespace webrtc {
namespace {
constexpr int kFrameSize = 8;
constexpr int kNumFrames = 4;
using FloatArraySequence =
std::array<std::array<float, kFrameSize>, kNumFrames>;
constexpr FloatArraySequence kBiQuadInputSeq = {
{{{-87.166290f, -8.029022f, 101.619583f, -0.294296f, -5.825764f, -8.890625f,
10.310432f, 54.845333f}},
{{-64.647644f, -6.883945f, 11.059189f, -95.242538f, -108.870834f,
11.024944f, 63.044102f, -52.709583f}},
{{-32.350529f, -18.108028f, -74.022339f, -8.986874f, -1.525581f,
103.705513f, 6.346226f, -14.319557f}},
{{22.645832f, -64.597153f, 55.462521f, -109.393188f, 10.117825f,
-40.019642f, -98.612228f, -8.330326f}}}};
// Computed as `scipy.signal.butter(N=2, Wn=60/24000, btype='highpass')`.
constexpr BiQuadFilter::Config kBiQuadConfig{
{0.99446179f, -1.98892358f, 0.99446179f},
{-1.98889291f, 0.98895425f}};
// Compared against scipy. The expected output is generated as follows:
//   zi = np.float32([0, 0])
//   for i in range(4):
//     yn, zi = scipy.signal.lfilter(B, A, x[i], zi=zi)
//     print(yn)
constexpr FloatArraySequence kBiQuadOutputSeq = {
{{{-86.68354497f, -7.02175351f, 102.10290352f, -0.37487333f, -5.87205847f,
-8.85521608f, 10.33772563f, 54.51157181f}},
{{-64.92531604f, -6.76395978f, 11.15534507f, -94.68073341f, -107.18177856f,
13.24642474f, 64.84288941f, -50.97822629f}},
{{-30.1579652f, -15.64850899f, -71.06662821f, -5.5883229f, 1.91175353f,
106.5572003f, 8.57183046f, -12.06298473f}},
{{24.84286614f, -62.18094158f, 57.91488056f, -106.65685933f, 13.38760103f,
-36.60367134f, -94.44880104f, -3.59920354f}}}};
// Fails for every pair of elements from two equally sized
// rtc::ArrayView<float> views whose relative error is above `tolerance`. If
// the expected value of a pair is 0, the absolute error is checked against
// `tolerance` instead.
void ExpectNearRelative(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed,
const float tolerance) {
// The relative error is undefined when the expected value is 0.
// When that happens, check the absolute error instead. `safe_den` is used
// below to implement such logic.
auto safe_den = [](float x) { return (x == 0.0f) ? 1.0f : std::fabs(x); };
ASSERT_EQ(expected.size(), computed.size());
for (size_t i = 0; i < expected.size(); ++i) {
const float abs_diff = std::fabs(expected[i] - computed[i]);
// No failure when the values are equal.
if (abs_diff == 0.0f) {
continue;
}
SCOPED_TRACE(i);
SCOPED_TRACE(expected[i]);
SCOPED_TRACE(computed[i]);
EXPECT_LE(abs_diff / safe_den(expected[i]), tolerance);
}
}
// Checks that filtering works when different containers are used both as input
// and as output.
TEST(BiQuadFilterTest, FilterNotInPlace) {
BiQuadFilter filter(kBiQuadConfig);
std::array<float, kFrameSize> samples;
// TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
for (int i = 0; i < kNumFrames; ++i) {
SCOPED_TRACE(i);
filter.Process(kBiQuadInputSeq[i], samples);
ExpectNearRelative(kBiQuadOutputSeq[i], samples, 2e-4f);
}
}
// Checks that filtering works when the same container is used both as input and
// as output.
TEST(BiQuadFilterTest, FilterInPlace) {
BiQuadFilter filter(kBiQuadConfig);
std::array<float, kFrameSize> samples;
// TODO(https://bugs.webrtc.org/8948): Add when the issue is fixed.
// FloatingPointExceptionObserver fpe_observer;
for (int i = 0; i < kNumFrames; ++i) {
SCOPED_TRACE(i);
std::copy(kBiQuadInputSeq[i].begin(), kBiQuadInputSeq[i].end(),
samples.begin());
filter.Process({samples}, {samples});
ExpectNearRelative(kBiQuadOutputSeq[i], samples, 2e-4f);
}
}
// Checks that different configurations produce different outputs.
TEST(BiQuadFilterTest, SetConfigDifferentOutput) {
BiQuadFilter filter(/*config=*/{{0.97803048f, -1.95606096f, 0.97803048f},
{-1.95557824f, 0.95654368f}});
std::array<float, kFrameSize> samples1;
for (int i = 0; i < kNumFrames; ++i) {
filter.Process(kBiQuadInputSeq[i], samples1);
}
filter.SetConfig(
{{0.09763107f, 0.19526215f, 0.09763107f}, {-0.94280904f, 0.33333333f}});
std::array<float, kFrameSize> samples2;
for (int i = 0; i < kNumFrames; ++i) {
filter.Process(kBiQuadInputSeq[i], samples2);
}
EXPECT_NE(samples1, samples2);
}
// Checks that the filter state is reset when `SetConfig()` is called, even
// when the filter coefficients are unchanged.
TEST(BiQuadFilterTest, SetConfigResetsState) {
BiQuadFilter filter(kBiQuadConfig);
std::array<float, kFrameSize> samples1;
for (int i = 0; i < kNumFrames; ++i) {
filter.Process(kBiQuadInputSeq[i], samples1);
}
filter.SetConfig(kBiQuadConfig);
std::array<float, kFrameSize> samples2;
for (int i = 0; i < kNumFrames; ++i) {
filter.Process(kBiQuadInputSeq[i], samples2);
}
EXPECT_EQ(samples1, samples2);
}
// Checks that when `Reset()` is called the filter state is reset.
TEST(BiQuadFilterTest, Reset) {
BiQuadFilter filter(kBiQuadConfig);
std::array<float, kFrameSize> samples1;
for (int i = 0; i < kNumFrames; ++i) {
filter.Process(kBiQuadInputSeq[i], samples1);
}
filter.Reset();
std::array<float, kFrameSize> samples2;
for (int i = 0; i < kNumFrames; ++i) {
filter.Process(kBiQuadInputSeq[i], samples2);
}
EXPECT_EQ(samples1, samples2);
}
} // namespace
} // namespace webrtc


@@ -0,0 +1,384 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/clipping_predictor.h"
#include <algorithm>
#include <memory>
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/clipping_predictor_level_buffer.h"
#include "modules/audio_processing/agc2/gain_map_internal.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
constexpr int kClippingPredictorMaxGainChange = 15;
// Returns an input volume in the [`min_input_volume`, `max_input_volume`] range
// that reduces `gain_error_db`, which is a gain error estimated when
// `input_volume` was applied, according to a fixed gain map.
int ComputeVolumeUpdate(int gain_error_db,
int input_volume,
int min_input_volume,
int max_input_volume) {
RTC_DCHECK_GE(input_volume, 0);
RTC_DCHECK_LE(input_volume, max_input_volume);
if (gain_error_db == 0) {
return input_volume;
}
int new_volume = input_volume;
if (gain_error_db > 0) {
while (kGainMap[new_volume] - kGainMap[input_volume] < gain_error_db &&
new_volume < max_input_volume) {
++new_volume;
}
} else {
while (kGainMap[new_volume] - kGainMap[input_volume] > gain_error_db &&
new_volume > min_input_volume) {
--new_volume;
}
}
return new_volume;
}
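// The crest factor is the difference between the peak and the RMS level (both
// in dBFS). For reference, a pure sine measures about 3 dB while unclipped
// speech typically measures well above that; a drop in crest factor hints at
// flattened (clipping) peaks.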
float ComputeCrestFactor(const ClippingPredictorLevelBuffer::Level& level) {
const float crest_factor =
FloatS16ToDbfs(level.max) - FloatS16ToDbfs(std::sqrt(level.average));
return crest_factor;
}
// Crest factor-based clipping prediction and clipped level step estimation.
class ClippingEventPredictor : public ClippingPredictor {
public:
// ClippingEventPredictor with `num_channels` channels (limited to values
// higher than zero); window size `window_length` and reference window size
// `reference_window_length` (both referring to the number of frames in the
// respective sliding windows and limited to values higher than zero);
// reference window delay `reference_window_delay` (delay in frames, limited
// to values zero and higher with an additional requirement of
// `window_length` < `reference_window_length` + `reference_window_delay`);
// and an estimation peak threshold `clipping_threshold` and a crest factor
// drop threshold `crest_factor_margin` (both in dB).
ClippingEventPredictor(int num_channels,
int window_length,
int reference_window_length,
int reference_window_delay,
float clipping_threshold,
float crest_factor_margin)
: window_length_(window_length),
reference_window_length_(reference_window_length),
reference_window_delay_(reference_window_delay),
clipping_threshold_(clipping_threshold),
crest_factor_margin_(crest_factor_margin) {
RTC_DCHECK_GT(num_channels, 0);
RTC_DCHECK_GT(window_length, 0);
RTC_DCHECK_GT(reference_window_length, 0);
RTC_DCHECK_GE(reference_window_delay, 0);
RTC_DCHECK_GT(reference_window_length + reference_window_delay,
window_length);
const int buffer_length = GetMinFramesProcessed();
RTC_DCHECK_GT(buffer_length, 0);
for (int i = 0; i < num_channels; ++i) {
ch_buffers_.push_back(
std::make_unique<ClippingPredictorLevelBuffer>(buffer_length));
}
}
ClippingEventPredictor(const ClippingEventPredictor&) = delete;
ClippingEventPredictor& operator=(const ClippingEventPredictor&) = delete;
~ClippingEventPredictor() {}
void Reset() {
const int num_channels = ch_buffers_.size();
for (int i = 0; i < num_channels; ++i) {
ch_buffers_[i]->Reset();
}
}
// Analyzes a frame of audio and stores the framewise metrics in
// `ch_buffers_`.
void Analyze(const AudioFrameView<const float>& frame) {
const int num_channels = frame.num_channels();
RTC_DCHECK_EQ(num_channels, ch_buffers_.size());
const int samples_per_channel = frame.samples_per_channel();
RTC_DCHECK_GT(samples_per_channel, 0);
for (int channel = 0; channel < num_channels; ++channel) {
float sum_squares = 0.0f;
float peak = 0.0f;
for (const auto& sample : frame.channel(channel)) {
sum_squares += sample * sample;
peak = std::max(std::fabs(sample), peak);
}
ch_buffers_[channel]->Push(
{sum_squares / static_cast<float>(samples_per_channel), peak});
}
}
// Estimates the analog gain adjustment for channel `channel` using a
// sliding window over the frame-wise metrics in `ch_buffers_`. Returns an
// estimate for the clipped level step equal to `default_step`
// if at least `GetMinFramesProcessed()` frames have been processed since the
// last reset and a clipping event is predicted. `level`, `min_mic_level`, and
// `max_mic_level` are limited to [0, 255] and `default_step` to [1, 255].
absl::optional<int> EstimateClippedLevelStep(int channel,
int level,
int default_step,
int min_mic_level,
int max_mic_level) const {
RTC_CHECK_GE(channel, 0);
RTC_CHECK_LT(channel, ch_buffers_.size());
RTC_DCHECK_GE(level, 0);
RTC_DCHECK_LE(level, 255);
RTC_DCHECK_GT(default_step, 0);
RTC_DCHECK_LE(default_step, 255);
RTC_DCHECK_GE(min_mic_level, 0);
RTC_DCHECK_LE(min_mic_level, 255);
RTC_DCHECK_GE(max_mic_level, 0);
RTC_DCHECK_LE(max_mic_level, 255);
if (level <= min_mic_level) {
return absl::nullopt;
}
if (PredictClippingEvent(channel)) {
const int new_level =
rtc::SafeClamp(level - default_step, min_mic_level, max_mic_level);
const int step = level - new_level;
if (step > 0) {
return step;
}
}
return absl::nullopt;
}
private:
int GetMinFramesProcessed() const {
return reference_window_delay_ + reference_window_length_;
}
// Predicts clipping events based on the processed audio frames. Returns
// true if a clipping event is likely.
bool PredictClippingEvent(int channel) const {
const auto metrics =
ch_buffers_[channel]->ComputePartialMetrics(0, window_length_);
if (!metrics.has_value() ||
!(FloatS16ToDbfs(metrics.value().max) > clipping_threshold_)) {
return false;
}
const auto reference_metrics = ch_buffers_[channel]->ComputePartialMetrics(
reference_window_delay_, reference_window_length_);
if (!reference_metrics.has_value()) {
return false;
}
const float crest_factor = ComputeCrestFactor(metrics.value());
const float reference_crest_factor =
ComputeCrestFactor(reference_metrics.value());
if (crest_factor < reference_crest_factor - crest_factor_margin_) {
return true;
}
return false;
}
std::vector<std::unique_ptr<ClippingPredictorLevelBuffer>> ch_buffers_;
const int window_length_;
const int reference_window_length_;
const int reference_window_delay_;
const float clipping_threshold_;
const float crest_factor_margin_;
};
// Performs crest factor-based clipping peak prediction.
class ClippingPeakPredictor : public ClippingPredictor {
public:
// Ctor. ClippingPeakPredictor with `num_channels` channels (limited to values
// higher than zero); window size `window_length` and reference window size
// `reference_window_length` (both referring to the number of frames in the
// respective sliding windows and limited to values higher than zero);
// reference window delay `reference_window_delay` (delay in frames, limited
// to values zero and higher with an additional requirement of
// `window_length` < `reference_window_length` + `reference_window_delay`);
// and a clipping prediction threshold `clipping_threshold` (in dB). Adaptive
// clipped level step estimation is used if `adaptive_step_estimation` is
// true.
explicit ClippingPeakPredictor(int num_channels,
int window_length,
int reference_window_length,
int reference_window_delay,
int clipping_threshold,
bool adaptive_step_estimation)
: window_length_(window_length),
reference_window_length_(reference_window_length),
reference_window_delay_(reference_window_delay),
clipping_threshold_(clipping_threshold),
adaptive_step_estimation_(adaptive_step_estimation) {
RTC_DCHECK_GT(num_channels, 0);
RTC_DCHECK_GT(window_length, 0);
RTC_DCHECK_GT(reference_window_length, 0);
RTC_DCHECK_GE(reference_window_delay, 0);
RTC_DCHECK_GT(reference_window_length + reference_window_delay,
window_length);
const int buffer_length = GetMinFramesProcessed();
RTC_DCHECK_GT(buffer_length, 0);
for (int i = 0; i < num_channels; ++i) {
ch_buffers_.push_back(
std::make_unique<ClippingPredictorLevelBuffer>(buffer_length));
}
}
ClippingPeakPredictor(const ClippingPeakPredictor&) = delete;
ClippingPeakPredictor& operator=(const ClippingPeakPredictor&) = delete;
~ClippingPeakPredictor() {}
void Reset() {
const int num_channels = ch_buffers_.size();
for (int i = 0; i < num_channels; ++i) {
ch_buffers_[i]->Reset();
}
}
// Analyzes a frame of audio and stores the framewise metrics in
// `ch_buffers_`.
void Analyze(const AudioFrameView<const float>& frame) {
const int num_channels = frame.num_channels();
RTC_DCHECK_EQ(num_channels, ch_buffers_.size());
const int samples_per_channel = frame.samples_per_channel();
RTC_DCHECK_GT(samples_per_channel, 0);
for (int channel = 0; channel < num_channels; ++channel) {
float sum_squares = 0.0f;
float peak = 0.0f;
for (const auto& sample : frame.channel(channel)) {
sum_squares += sample * sample;
peak = std::max(std::fabs(sample), peak);
}
ch_buffers_[channel]->Push(
{sum_squares / static_cast<float>(samples_per_channel), peak});
}
}
// Estimates the analog gain adjustment for channel `channel` using a
// sliding window over the frame-wise metrics in `ch_buffers_`. Returns an
// estimate for the clipped level step (equal to `default_step` if
// `adaptive_step_estimation_` is false) if at
// least `GetMinFramesProcessed()` frames have been processed since the last
// reset and a clipping event is predicted. `level`, `min_mic_level`, and
// `max_mic_level` are limited to [0, 255] and `default_step` to [1, 255].
absl::optional<int> EstimateClippedLevelStep(int channel,
int level,
int default_step,
int min_mic_level,
int max_mic_level) const {
RTC_DCHECK_GE(channel, 0);
RTC_DCHECK_LT(channel, ch_buffers_.size());
RTC_DCHECK_GE(level, 0);
RTC_DCHECK_LE(level, 255);
RTC_DCHECK_GT(default_step, 0);
RTC_DCHECK_LE(default_step, 255);
RTC_DCHECK_GE(min_mic_level, 0);
RTC_DCHECK_LE(min_mic_level, 255);
RTC_DCHECK_GE(max_mic_level, 0);
RTC_DCHECK_LE(max_mic_level, 255);
if (level <= min_mic_level) {
return absl::nullopt;
}
absl::optional<float> estimate_db = EstimatePeakValue(channel);
if (estimate_db.has_value() && estimate_db.value() > clipping_threshold_) {
int step = 0;
if (!adaptive_step_estimation_) {
step = default_step;
} else {
const int estimated_gain_change =
rtc::SafeClamp(-static_cast<int>(std::ceil(estimate_db.value())),
-kClippingPredictorMaxGainChange, 0);
step =
std::max(level - ComputeVolumeUpdate(estimated_gain_change, level,
min_mic_level, max_mic_level),
default_step);
}
const int new_level =
rtc::SafeClamp(level - step, min_mic_level, max_mic_level);
if (level > new_level) {
return level - new_level;
}
}
return absl::nullopt;
}
private:
int GetMinFramesProcessed() const {
return reference_window_delay_ + reference_window_length_;
}
// Predicts clipping sample peaks based on the processed audio frames.
// Returns the estimated peak value if clipping is predicted. Otherwise
// returns absl::nullopt.
absl::optional<float> EstimatePeakValue(int channel) const {
const auto reference_metrics = ch_buffers_[channel]->ComputePartialMetrics(
reference_window_delay_, reference_window_length_);
if (!reference_metrics.has_value()) {
return absl::nullopt;
}
const auto metrics =
ch_buffers_[channel]->ComputePartialMetrics(0, window_length_);
if (!metrics.has_value() ||
!(FloatS16ToDbfs(metrics.value().max) > clipping_threshold_)) {
return absl::nullopt;
}
const float reference_crest_factor =
ComputeCrestFactor(reference_metrics.value());
const float& mean_squares = metrics.value().average;
const float projected_peak =
reference_crest_factor + FloatS16ToDbfs(std::sqrt(mean_squares));
return projected_peak;
}
std::vector<std::unique_ptr<ClippingPredictorLevelBuffer>> ch_buffers_;
const int window_length_;
const int reference_window_length_;
const int reference_window_delay_;
const int clipping_threshold_;
const bool adaptive_step_estimation_;
};
} // namespace
std::unique_ptr<ClippingPredictor> CreateClippingPredictor(
int num_channels,
const AudioProcessing::Config::GainController1::AnalogGainController::
ClippingPredictor& config) {
if (!config.enabled) {
RTC_LOG(LS_INFO) << "[AGC2] Clipping prediction disabled.";
return nullptr;
}
RTC_LOG(LS_INFO) << "[AGC2] Clipping prediction enabled.";
using ClippingPredictorMode = AudioProcessing::Config::GainController1::
AnalogGainController::ClippingPredictor::Mode;
switch (config.mode) {
case ClippingPredictorMode::kClippingEventPrediction:
return std::make_unique<ClippingEventPredictor>(
num_channels, config.window_length, config.reference_window_length,
config.reference_window_delay, config.clipping_threshold,
config.crest_factor_margin);
case ClippingPredictorMode::kAdaptiveStepClippingPeakPrediction:
return std::make_unique<ClippingPeakPredictor>(
num_channels, config.window_length, config.reference_window_length,
config.reference_window_delay, config.clipping_threshold,
/*adaptive_step_estimation=*/true);
case ClippingPredictorMode::kFixedStepClippingPeakPrediction:
return std::make_unique<ClippingPeakPredictor>(
num_channels, config.window_length, config.reference_window_length,
config.reference_window_delay, config.clipping_threshold,
/*adaptive_step_estimation=*/false);
}
RTC_DCHECK_NOTREACHED();
}
} // namespace webrtc


@@ -0,0 +1,62 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "modules/audio_processing/include/audio_processing.h"
namespace webrtc {
// Frame-wise clipping prediction and clipped level step estimation. Analyzes
// 10 ms multi-channel frames and estimates an analog mic level decrease step
// to possibly avoid clipping when predicted. `Analyze()` and
// `EstimateClippedLevelStep()` can be called in any order.
class ClippingPredictor {
public:
virtual ~ClippingPredictor() = default;
virtual void Reset() = 0;
// Analyzes a 10 ms multi-channel audio frame.
virtual void Analyze(const AudioFrameView<const float>& frame) = 0;
// Predicts if clipping is going to occur for the specified `channel` in the
// near-future and, if so, it returns a recommended analog mic level decrease
// step. Returns absl::nullopt if clipping is not predicted.
// `level` is the current analog mic level, `default_step` is the amount the
// mic level is lowered by the analog controller with every clipping event and
// `min_mic_level` and `max_mic_level` define the range of allowed analog mic
// levels.
virtual absl::optional<int> EstimateClippedLevelStep(
int channel,
int level,
int default_step,
int min_mic_level,
int max_mic_level) const = 0;
};
// Creates a ClippingPredictor based on the provided `config`. When enabled,
// the following must hold for `config`:
// `window_length < reference_window_length + reference_window_delay`.
// Returns `nullptr` if `config.enabled` is false.
std::unique_ptr<ClippingPredictor> CreateClippingPredictor(
int num_channels,
const AudioProcessing::Config::GainController1::AnalogGainController::
ClippingPredictor& config);
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_H_
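
A construction sketch for the factory above; enabling the config is the only strictly required step, since a disabled config yields `nullptr` (the function name is illustrative).

#include <memory>

#include "modules/audio_processing/agc2/clipping_predictor.h"

std::unique_ptr<webrtc::ClippingPredictor> SketchCreatePredictor(
    int num_channels) {
  webrtc::AudioProcessing::Config::GainController1::AnalogGainController::
      ClippingPredictor config;
  config.enabled = true;  // Leaving this false returns nullptr.
  return webrtc::CreateClippingPredictor(num_channels, config);
}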


@@ -0,0 +1,77 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/clipping_predictor_level_buffer.h"
#include <algorithm>
#include <cmath>
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
namespace webrtc {
bool ClippingPredictorLevelBuffer::Level::operator==(const Level& level) const {
constexpr float kEpsilon = 1e-6f;
return std::fabs(average - level.average) < kEpsilon &&
std::fabs(max - level.max) < kEpsilon;
}
ClippingPredictorLevelBuffer::ClippingPredictorLevelBuffer(int capacity)
: tail_(-1), size_(0), data_(std::max(1, capacity)) {
if (capacity > kMaxCapacity) {
RTC_LOG(LS_WARNING) << "[agc]: ClippingPredictorLevelBuffer exceeds the "
<< "maximum allowed capacity. Capacity: " << capacity;
}
RTC_DCHECK(!data_.empty());
}
void ClippingPredictorLevelBuffer::Reset() {
tail_ = -1;
size_ = 0;
}
void ClippingPredictorLevelBuffer::Push(Level level) {
++tail_;
if (tail_ == Capacity()) {
tail_ = 0;
}
if (size_ < Capacity()) {
size_++;
}
data_[tail_] = level;
}
// TODO(bugs.webrtc.org/12774): Optimize partial computation for long buffers.
absl::optional<ClippingPredictorLevelBuffer::Level>
ClippingPredictorLevelBuffer::ComputePartialMetrics(int delay,
int num_items) const {
RTC_DCHECK_GE(delay, 0);
RTC_DCHECK_LT(delay, Capacity());
RTC_DCHECK_GT(num_items, 0);
RTC_DCHECK_LE(num_items, Capacity());
RTC_DCHECK_LE(delay + num_items, Capacity());
if (delay + num_items > Size()) {
return absl::nullopt;
}
float sum = 0.0f;
float max = 0.0f;
for (int i = 0; i < num_items && i < Size(); ++i) {
int idx = tail_ - delay - i;
if (idx < 0) {
idx += Capacity();
}
sum += data_[idx].average;
max = std::fmax(data_[idx].max, max);
}
return absl::optional<Level>({sum / static_cast<float>(num_items), max});
}
} // namespace webrtc


@@ -0,0 +1,71 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
namespace webrtc {
// A circular buffer to store frame-wise `Level` items for clipping prediction.
// The current implementation is not optimized for large buffer lengths.
class ClippingPredictorLevelBuffer {
public:
struct Level {
float average;
float max;
bool operator==(const Level& level) const;
};
// Recommended maximum capacity. It is possible to create a buffer with a
// larger capacity, but the implementation is not optimized for large values.
static constexpr int kMaxCapacity = 100;
// Ctor. Sets the buffer capacity to max(1, `capacity`) and logs a warning
// message if the capacity is greater than `kMaxCapacity`.
explicit ClippingPredictorLevelBuffer(int capacity);
~ClippingPredictorLevelBuffer() {}
ClippingPredictorLevelBuffer(const ClippingPredictorLevelBuffer&) = delete;
ClippingPredictorLevelBuffer& operator=(const ClippingPredictorLevelBuffer&) =
delete;
void Reset();
// Returns the current number of items stored in the buffer.
int Size() const { return size_; }
// Returns the capacity of the buffer.
int Capacity() const { return data_.size(); }
// Adds a `level` item into the circular buffer `data_`. Stores at most
// `Capacity()` items. If more items are pushed, the new item replaces the
// least recently pushed item.
void Push(Level level);
// If at least `num_items` + `delay` items have been pushed, returns the
// average and maximum value for the `num_items` most recently pushed items
// with delays in [`delay`, `delay` + `num_items`) (a delay equal to zero corresponds
// to the most recently pushed item). The value of `delay` is limited to
// [0, N] and `num_items` to [1, M] where N + M is the capacity of the buffer.
absl::optional<Level> ComputePartialMetrics(int delay, int num_items) const;
private:
int tail_;
int size_;
std::vector<Level> data_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_CLIPPING_PREDICTOR_LEVEL_BUFFER_H_
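
A usage sketch for the buffer above; the pushed levels are made-up numbers.

#include "absl/types/optional.h"
#include "modules/audio_processing/agc2/clipping_predictor_level_buffer.h"
#include "rtc_base/checks.h"

void SketchLevelBuffer() {
  webrtc::ClippingPredictorLevelBuffer buffer(/*capacity=*/10);
  buffer.Push({/*average=*/0.04f, /*max=*/0.5f});  // Oldest.
  buffer.Push({/*average=*/0.09f, /*max=*/0.8f});
  buffer.Push({/*average=*/0.16f, /*max=*/0.9f});  // Most recent.
  // Metrics over the two most recent items:
  // average = (0.16 + 0.09) / 2 = 0.125 and max = max(0.9, 0.8) = 0.9.
  const absl::optional<webrtc::ClippingPredictorLevelBuffer::Level> metrics =
      buffer.ComputePartialMetrics(/*delay=*/0, /*num_items=*/2);
  RTC_DCHECK(metrics.has_value());
}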


@@ -0,0 +1,229 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/compute_interpolated_gain_curve.h"
#include <algorithm>
#include <cmath>
#include <queue>
#include <tuple>
#include <utility>
#include <vector>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/agc2_testing_common.h"
#include "modules/audio_processing/agc2/limiter_db_gain_curve.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
std::pair<double, double> ComputeLinearApproximationParams(
const LimiterDbGainCurve* limiter,
const double x) {
const double m = limiter->GetGainFirstDerivativeLinear(x);
const double q = limiter->GetGainLinear(x) - m * x;
return {m, q};
}
double ComputeAreaUnderPiecewiseLinearApproximation(
const LimiterDbGainCurve* limiter,
const double x0,
const double x1) {
RTC_CHECK_LT(x0, x1);
// Linear approximation in x0 and x1.
double m0, q0, m1, q1;
std::tie(m0, q0) = ComputeLinearApproximationParams(limiter, x0);
std::tie(m1, q1) = ComputeLinearApproximationParams(limiter, x1);
// Intersection point between two adjacent linear pieces.
RTC_CHECK_NE(m1, m0);
const double x_split = (q0 - q1) / (m1 - m0);
RTC_CHECK_LT(x0, x_split);
RTC_CHECK_LT(x_split, x1);
auto area_under_linear_piece = [](double x_l, double x_r, double m,
double q) {
return x_r * (m * x_r / 2.0 + q) - x_l * (m * x_l / 2.0 + q);
};
return area_under_linear_piece(x0, x_split, m0, q0) +
area_under_linear_piece(x_split, x1, m1, q1);
}
// Computes the approximation error in the limiter region for a given interval.
// The error is computed as the difference between the areas beneath the limiter
// curve to approximate and its linear under-approximation.
double LimiterUnderApproximationNegativeError(const LimiterDbGainCurve* limiter,
const double x0,
const double x1) {
const double area_limiter = limiter->GetGainIntegralLinear(x0, x1);
const double area_interpolated_curve =
ComputeAreaUnderPiecewiseLinearApproximation(limiter, x0, x1);
RTC_CHECK_GE(area_limiter, area_interpolated_curve);
return area_limiter - area_interpolated_curve;
}
// Automatically finds where to sample the beyond-knee region of a limiter using
// a greedy optimization algorithm that iteratively decreases the approximation
// error.
// The solution is sub-optimal because the algorithm is greedy and the points
// are assigned by halving intervals (starting with the whole beyond-knee region
// as a single interval). However, even if sub-optimal, this algorithm works
// well in practice and it is efficiently implemented using priority queues.
std::vector<double> SampleLimiterRegion(const LimiterDbGainCurve* limiter) {
static_assert(kInterpolatedGainCurveBeyondKneePoints > 2, "");
struct Interval {
Interval() = default; // Ctor required by std::priority_queue.
Interval(double l, double r, double e) : x0(l), x1(r), error(e) {
RTC_CHECK(x0 < x1);
}
bool operator<(const Interval& other) const { return error < other.error; }
double x0;
double x1;
double error;
};
std::priority_queue<Interval, std::vector<Interval>> q;
q.emplace(limiter->limiter_start_linear(), limiter->max_input_level_linear(),
LimiterUnderApproximationNegativeError(
limiter, limiter->limiter_start_linear(),
limiter->max_input_level_linear()));
// Iteratively find points by halving the interval with greatest error.
while (q.size() < kInterpolatedGainCurveBeyondKneePoints) {
// Get the interval with highest error.
const auto interval = q.top();
q.pop();
// Split `interval` and enqueue.
double x_split = (interval.x0 + interval.x1) / 2.0;
q.emplace(interval.x0, x_split,
LimiterUnderApproximationNegativeError(limiter, interval.x0,
x_split)); // Left.
q.emplace(x_split, interval.x1,
LimiterUnderApproximationNegativeError(limiter, x_split,
interval.x1)); // Right.
}
// Copy x1 values and sort them.
RTC_CHECK_EQ(q.size(), kInterpolatedGainCurveBeyondKneePoints);
std::vector<double> samples(kInterpolatedGainCurveBeyondKneePoints);
for (size_t i = 0; i < kInterpolatedGainCurveBeyondKneePoints; ++i) {
const auto interval = q.top();
q.pop();
samples[i] = interval.x1;
}
RTC_CHECK(q.empty());
std::sort(samples.begin(), samples.end());
return samples;
}
// Computes the parameters to over-approximate the knee region via linear
// interpolation. Over-approximating is saturation-safe since the knee region is
// convex.
void PrecomputeKneeApproxParams(const LimiterDbGainCurve* limiter,
test::InterpolatedParameters* parameters) {
static_assert(kInterpolatedGainCurveKneePoints > 2, "");
// Get `kInterpolatedGainCurveKneePoints` - 1 equally spaced points.
const std::vector<double> points = test::LinSpace(
limiter->knee_start_linear(), limiter->limiter_start_linear(),
kInterpolatedGainCurveKneePoints - 1);
// Set the first two points. The second is computed to help with the beginning
// of the knee region, which has high curvature.
parameters->computed_approximation_params_x[0] = points[0];
parameters->computed_approximation_params_x[1] =
(points[0] + points[1]) / 2.0;
// Copy the remaining points.
std::copy(std::begin(points) + 1, std::end(points),
std::begin(parameters->computed_approximation_params_x) + 2);
// Compute (m, q) pairs for each linear piece y = mx + q.
for (size_t i = 0; i < kInterpolatedGainCurveKneePoints - 1; ++i) {
const double x0 = parameters->computed_approximation_params_x[i];
const double x1 = parameters->computed_approximation_params_x[i + 1];
const double y0 = limiter->GetGainLinear(x0);
const double y1 = limiter->GetGainLinear(x1);
RTC_CHECK_NE(x1, x0);
parameters->computed_approximation_params_m[i] = (y1 - y0) / (x1 - x0);
parameters->computed_approximation_params_q[i] =
y0 - parameters->computed_approximation_params_m[i] * x0;
}
}
// Computes the parameters to under-approximate the beyond-knee region via linear
// interpolation and greedy sampling. Under-approximating is saturation-safe
// since the beyond-knee region is concave.
void PrecomputeBeyondKneeApproxParams(
const LimiterDbGainCurve* limiter,
test::InterpolatedParameters* parameters) {
// Find points on which the linear pieces are tangent to the gain curve.
const auto samples = SampleLimiterRegion(limiter);
// Parametrize each linear piece.
double m, q;
std::tie(m, q) = ComputeLinearApproximationParams(
limiter,
parameters
->computed_approximation_params_x[kInterpolatedGainCurveKneePoints -
1]);
parameters
->computed_approximation_params_m[kInterpolatedGainCurveKneePoints - 1] =
m;
parameters
->computed_approximation_params_q[kInterpolatedGainCurveKneePoints - 1] =
q;
for (size_t i = 0; i < samples.size(); ++i) {
std::tie(m, q) = ComputeLinearApproximationParams(limiter, samples[i]);
parameters
->computed_approximation_params_m[i +
kInterpolatedGainCurveKneePoints] = m;
parameters
->computed_approximation_params_q[i +
kInterpolatedGainCurveKneePoints] = q;
}
// Find the point of intersection between adjacent linear pieces. They will be
// used as boundaries between adjacent linear pieces.
for (size_t i = kInterpolatedGainCurveKneePoints;
i < kInterpolatedGainCurveKneePoints +
kInterpolatedGainCurveBeyondKneePoints;
++i) {
RTC_CHECK_NE(parameters->computed_approximation_params_m[i],
parameters->computed_approximation_params_m[i - 1]);
parameters->computed_approximation_params_x[i] =
( // Formula: (q0 - q1) / (m1 - m0).
parameters->computed_approximation_params_q[i - 1] -
parameters->computed_approximation_params_q[i]) /
(parameters->computed_approximation_params_m[i] -
parameters->computed_approximation_params_m[i - 1]);
}
}
} // namespace
namespace test {
InterpolatedParameters ComputeInterpolatedGainCurveApproximationParams() {
InterpolatedParameters parameters;
LimiterDbGainCurve limiter;
parameters.computed_approximation_params_x.fill(0.0f);
parameters.computed_approximation_params_m.fill(0.0f);
parameters.computed_approximation_params_q.fill(0.0f);
PrecomputeKneeApproxParams(&limiter, &parameters);
PrecomputeBeyondKneeApproxParams(&limiter, &parameters);
return parameters;
}
} // namespace test
} // namespace webrtc
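
The greedy interval-halving idea used by SampleLimiterRegion() above, restated on a toy concave curve (y = sqrt(x) on [1, 4]). The curvature-based error proxy and the bounds are illustrative; the production code uses the area-based under-approximation error instead.

#include <algorithm>
#include <cmath>
#include <queue>
#include <vector>

std::vector<double> SketchSamplePoints(int num_points) {
  struct Interval {
    double x0, x1, error;
    bool operator<(const Interval& other) const { return error < other.error; }
  };
  // Error proxy: interval width times the curvature magnitude at the midpoint
  // (for y = sqrt(x), y'' = -x^(-3/2) / 4).
  auto error = [](double x0, double x1) {
    const double mid = 0.5 * (x0 + x1);
    return (x1 - x0) * 0.25 * std::pow(mid, -1.5);
  };
  std::priority_queue<Interval> queue;
  queue.push({1.0, 4.0, error(1.0, 4.0)});
  // Repeatedly halve the interval with the largest error.
  while (static_cast<int>(queue.size()) < num_points) {
    const Interval top = queue.top();
    queue.pop();
    const double mid = 0.5 * (top.x0 + top.x1);
    queue.push({top.x0, mid, error(top.x0, mid)});
    queue.push({mid, top.x1, error(mid, top.x1)});
  }
  // Collect and sort the right endpoints, as SampleLimiterRegion() does.
  std::vector<double> samples;
  while (!queue.empty()) {
    samples.push_back(queue.top().x1);
    queue.pop();
  }
  std::sort(samples.begin(), samples.end());
  return samples;
}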


@@ -0,0 +1,48 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_
#define MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_
#include <array>
#include "modules/audio_processing/agc2/agc2_common.h"
namespace webrtc {
namespace test {
// Parameters for interpolated gain curve using under-approximation to
// avoid saturation.
//
// The saturation gain is defined in order to let hard-clipping occur for
// those samples having a level that falls in the saturation region. It is an
// upper bound of the actual gain to apply - i.e., that returned by the
// limiter.
// Knee and beyond-knee regions approximation parameters.
// The gain curve is approximated as a piece-wise linear function.
// `computed_approximation_params_x` holds the boundaries between adjacent
// linear pieces; `computed_approximation_params_m` and
// `computed_approximation_params_q` hold the slope and the y-intercept of
// each piece.
struct InterpolatedParameters {
std::array<float, kInterpolatedGainCurveTotalPoints>
computed_approximation_params_x;
std::array<float, kInterpolatedGainCurveTotalPoints>
computed_approximation_params_m;
std::array<float, kInterpolatedGainCurveTotalPoints>
computed_approximation_params_q;
};
InterpolatedParameters ComputeInterpolatedGainCurveApproximationParams();
} // namespace test
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_COMPUTE_INTERPOLATED_GAIN_CURVE_H_

View file

@ -0,0 +1,62 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/cpu_features.h"
#include "rtc_base/strings/string_builder.h"
#include "rtc_base/system/arch.h"
#include "system_wrappers/include/cpu_features_wrapper.h"
namespace webrtc {
std::string AvailableCpuFeatures::ToString() const {
char buf[64];
rtc::SimpleStringBuilder builder(buf);
bool first = true;
if (sse2) {
builder << (first ? "SSE2" : "_SSE2");
first = false;
}
if (avx2) {
builder << (first ? "AVX2" : "_AVX2");
first = false;
}
if (neon) {
builder << (first ? "NEON" : "_NEON");
first = false;
}
if (first) {
return "none";
}
return builder.str();
}
// Detects available CPU features.
AvailableCpuFeatures GetAvailableCpuFeatures() {
#if defined(WEBRTC_ARCH_X86_FAMILY)
return {/*sse2=*/GetCPUInfo(kSSE2) != 0,
/*avx2=*/GetCPUInfo(kAVX2) != 0,
/*neon=*/false};
#elif defined(WEBRTC_HAS_NEON)
return {/*sse2=*/false,
/*avx2=*/false,
/*neon=*/true};
#else
return {/*sse2=*/false,
/*avx2=*/false,
/*neon=*/false};
#endif
}
AvailableCpuFeatures NoAvailableCpuFeatures() {
return {/*sse2=*/false, /*avx2=*/false, /*neon=*/false};
}
} // namespace webrtc

View file

@ -0,0 +1,39 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
#define MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
#include <string>
namespace webrtc {
// Collection of flags indicating which CPU features are available on the
// current platform. True means available.
struct AvailableCpuFeatures {
AvailableCpuFeatures(bool sse2, bool avx2, bool neon)
: sse2(sse2), avx2(avx2), neon(neon) {}
// Intel.
bool sse2;
bool avx2;
// ARM.
bool neon;
std::string ToString() const;
};
// Detects what CPU features are available.
AvailableCpuFeatures GetAvailableCpuFeatures();
// Returns the CPU feature flags all set to false.
AvailableCpuFeatures NoAvailableCpuFeatures();
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_CPU_FEATURES_H_
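A minimal usage sketch of the API above; the dispatch branches are illustrative assumptions, not WebRTC code:

#include "modules/audio_processing/agc2/cpu_features.h"
#include "rtc_base/logging.h"

void LogAndDispatch() {
  const webrtc::AvailableCpuFeatures features =
      webrtc::GetAvailableCpuFeatures();
  RTC_LOG(LS_INFO) << "CPU features: " << features.ToString();
  if (features.avx2) {
    // Select an AVX2-optimized code path (illustrative).
  } else if (features.sse2 || features.neon) {
    // Select a generic SIMD code path (illustrative).
  } else {
    // Scalar fallback.
  }
}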

View file

@ -0,0 +1,121 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/fixed_digital_level_estimator.h"
#include <algorithm>
#include <cmath>
#include "api/array_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr float kInitialFilterStateLevel = 0.0f;
// Instant attack.
constexpr float kAttackFilterConstant = 0.0f;
// Limiter decay constant.
// Computed as `10 ** (-1/20 * subframe_duration / kDecayMs)` where:
// - `subframe_duration` is `kFrameDurationMs / kSubFramesInFrame`;
// - `kDecayMs` is defined in agc2_testing_common.h.
constexpr float kDecayFilterConstant = 0.9971259f;
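// Example check (assuming kFrameDurationMs = 10, kSubFramesInFrame = 20 and
// kDecayMs = 20): subframe_duration = 10 / 20 = 0.5 ms, so the constant is
// 10 ** (-1/20 * 0.5 / 20) = 10 ** -0.00125 ~ 0.9971259.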
} // namespace
FixedDigitalLevelEstimator::FixedDigitalLevelEstimator(
int sample_rate_hz,
ApmDataDumper* apm_data_dumper)
: apm_data_dumper_(apm_data_dumper),
filter_state_level_(kInitialFilterStateLevel) {
SetSampleRate(sample_rate_hz);
CheckParameterCombination();
RTC_DCHECK(apm_data_dumper_);
apm_data_dumper_->DumpRaw("agc2_level_estimator_samplerate", sample_rate_hz);
}
void FixedDigitalLevelEstimator::CheckParameterCombination() {
RTC_DCHECK_GT(samples_in_frame_, 0);
RTC_DCHECK_LE(kSubFramesInFrame, samples_in_frame_);
RTC_DCHECK_EQ(samples_in_frame_ % kSubFramesInFrame, 0);
RTC_DCHECK_GT(samples_in_sub_frame_, 1);
}
std::array<float, kSubFramesInFrame> FixedDigitalLevelEstimator::ComputeLevel(
const AudioFrameView<const float>& float_frame) {
RTC_DCHECK_GT(float_frame.num_channels(), 0);
RTC_DCHECK_EQ(float_frame.samples_per_channel(), samples_in_frame_);
// Compute max envelope without smoothing.
std::array<float, kSubFramesInFrame> envelope{};
for (int channel_idx = 0; channel_idx < float_frame.num_channels();
++channel_idx) {
const auto channel = float_frame.channel(channel_idx);
for (int sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
for (int sample_in_sub_frame = 0;
sample_in_sub_frame < samples_in_sub_frame_; ++sample_in_sub_frame) {
envelope[sub_frame] =
std::max(envelope[sub_frame],
std::abs(channel[sub_frame * samples_in_sub_frame_ +
sample_in_sub_frame]));
}
}
}
// Make sure envelope increases happen one step earlier so that the
// corresponding *gain decrease* doesn't miss a sudden signal
// increase due to interpolation.
for (int sub_frame = 0; sub_frame < kSubFramesInFrame - 1; ++sub_frame) {
if (envelope[sub_frame] < envelope[sub_frame + 1]) {
envelope[sub_frame] = envelope[sub_frame + 1];
}
}
// Add attack / decay smoothing.
for (int sub_frame = 0; sub_frame < kSubFramesInFrame; ++sub_frame) {
const float envelope_value = envelope[sub_frame];
if (envelope_value > filter_state_level_) {
envelope[sub_frame] = envelope_value * (1 - kAttackFilterConstant) +
filter_state_level_ * kAttackFilterConstant;
} else {
envelope[sub_frame] = envelope_value * (1 - kDecayFilterConstant) +
filter_state_level_ * kDecayFilterConstant;
}
filter_state_level_ = envelope[sub_frame];
// Dump data for debug.
RTC_DCHECK(apm_data_dumper_);
const auto channel = float_frame.channel(0);
apm_data_dumper_->DumpRaw("agc2_level_estimator_samples",
samples_in_sub_frame_,
&channel[sub_frame * samples_in_sub_frame_]);
apm_data_dumper_->DumpRaw("agc2_level_estimator_level",
envelope[sub_frame]);
}
return envelope;
}
void FixedDigitalLevelEstimator::SetSampleRate(int sample_rate_hz) {
samples_in_frame_ =
rtc::CheckedDivExact(sample_rate_hz * kFrameDurationMs, 1000);
samples_in_sub_frame_ =
rtc::CheckedDivExact(samples_in_frame_, kSubFramesInFrame);
CheckParameterCombination();
}
void FixedDigitalLevelEstimator::Reset() {
filter_state_level_ = kInitialFilterStateLevel;
}
} // namespace webrtc
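A minimal usage sketch for the estimator above, assuming a 48 kHz mono stream and that `AudioFrameView` exposes the pointer-array constructor used elsewhere in this module:

#include <vector>

#include "modules/audio_processing/agc2/fixed_digital_level_estimator.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"

void LevelEstimatorSketch() {
  webrtc::ApmDataDumper data_dumper(/*instance_index=*/0);
  webrtc::FixedDigitalLevelEstimator estimator(/*sample_rate_hz=*/48000,
                                               &data_dumper);
  // One 10 ms mono frame: 480 samples at 48 kHz (24 per sub-frame).
  std::vector<float> samples(480, 0.0f);
  const float* channels[] = {samples.data()};
  webrtc::AudioFrameView<const float> frame(channels, /*num_channels=*/1,
                                            /*samples_per_channel=*/480);
  // Returns kSubFramesInFrame smoothed level values in FloatS16 scale.
  const auto levels = estimator.ComputeLevel(frame);
  (void)levels;
}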

View file

@ -0,0 +1,66 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_
#include <array>
#include <vector>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
class ApmDataDumper;
// Produces a smooth signal level estimate from an input audio
// stream. The estimate smoothing is done through exponential
// filtering.
class FixedDigitalLevelEstimator {
public:
// Sample rates are allowed if the number of samples in a frame
// (sample_rate_hz * kFrameDurationMs / 1000) is divisible by
// kSubFramesInFrame. For kFrameDurationMs=10 and
// kSubFramesInFrame=20, this means that sample_rate_hz has to be
// divisible by 2000.
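// For example, at 48000 Hz each 10 ms frame has 480 samples and each of the
// 20 sub-frames has 24 samples; 48000 is divisible by 2000, so the rate is
// allowed.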
FixedDigitalLevelEstimator(int sample_rate_hz,
ApmDataDumper* apm_data_dumper);
FixedDigitalLevelEstimator(const FixedDigitalLevelEstimator&) = delete;
FixedDigitalLevelEstimator& operator=(const FixedDigitalLevelEstimator&) =
delete;
// The input is assumed to be in FloatS16 format. Scaled input will
// produce similarly scaled output. A frame with kFrameDurationMs ms
// of audio produces a level estimate in the same scale. The level
// estimate contains kSubFramesInFrame values.
std::array<float, kSubFramesInFrame> ComputeLevel(
const AudioFrameView<const float>& float_frame);
// The sample rate may be changed at any time (but not concurrently) from the
// value passed to the constructor. The class is not thread-safe.
void SetSampleRate(int sample_rate_hz);
// Resets the level estimator internal state.
void Reset();
float LastAudioLevel() const { return filter_state_level_; }
private:
void CheckParameterCombination();
ApmDataDumper* const apm_data_dumper_ = nullptr;
float filter_state_level_;
int samples_in_frame_;
int samples_in_sub_frame_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_FIXED_DIGITAL_LEVEL_ESTIMATOR_H_

View file

@ -0,0 +1,103 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/gain_applier.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
// Returns true when the gain factor is so close to 1 that it would
// not affect int16 samples.
bool GainCloseToOne(float gain_factor) {
return 1.f - 1.f / kMaxFloatS16Value <= gain_factor &&
gain_factor <= 1.f + 1.f / kMaxFloatS16Value;
}
void ClipSignal(AudioFrameView<float> signal) {
for (int k = 0; k < signal.num_channels(); ++k) {
rtc::ArrayView<float> channel_view = signal.channel(k);
for (auto& sample : channel_view) {
sample = rtc::SafeClamp(sample, kMinFloatS16Value, kMaxFloatS16Value);
}
}
}
void ApplyGainWithRamping(float last_gain_linear,
float gain_at_end_of_frame_linear,
float inverse_samples_per_channel,
AudioFrameView<float> float_frame) {
// Do not modify the signal.
if (last_gain_linear == gain_at_end_of_frame_linear &&
GainCloseToOne(gain_at_end_of_frame_linear)) {
return;
}
// Gain is constant and different from 1.
if (last_gain_linear == gain_at_end_of_frame_linear) {
for (int k = 0; k < float_frame.num_channels(); ++k) {
rtc::ArrayView<float> channel_view = float_frame.channel(k);
for (auto& sample : channel_view) {
sample *= gain_at_end_of_frame_linear;
}
}
return;
}
// The gain changes. We have to change slowly to avoid discontinuities.
const float increment = (gain_at_end_of_frame_linear - last_gain_linear) *
inverse_samples_per_channel;
float gain = last_gain_linear;
for (int i = 0; i < float_frame.samples_per_channel(); ++i) {
for (int ch = 0; ch < float_frame.num_channels(); ++ch) {
float_frame.channel(ch)[i] *= gain;
}
gain += increment;
}
}
} // namespace
GainApplier::GainApplier(bool hard_clip_samples, float initial_gain_factor)
: hard_clip_samples_(hard_clip_samples),
last_gain_factor_(initial_gain_factor),
current_gain_factor_(initial_gain_factor) {}
void GainApplier::ApplyGain(AudioFrameView<float> signal) {
if (static_cast<int>(signal.samples_per_channel()) != samples_per_channel_) {
Initialize(signal.samples_per_channel());
}
ApplyGainWithRamping(last_gain_factor_, current_gain_factor_,
inverse_samples_per_channel_, signal);
last_gain_factor_ = current_gain_factor_;
if (hard_clip_samples_) {
ClipSignal(signal);
}
}
// TODO(bugs.webrtc.org/7494): Remove once switched to gains in dB.
void GainApplier::SetGainFactor(float gain_factor) {
RTC_DCHECK_GT(gain_factor, 0.f);
current_gain_factor_ = gain_factor;
}
void GainApplier::Initialize(int samples_per_channel) {
RTC_DCHECK_GT(samples_per_channel, 0);
samples_per_channel_ = samples_per_channel;
inverse_samples_per_channel_ = 1.f / samples_per_channel_;
}
} // namespace webrtc

View file

@ -0,0 +1,44 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_
#include <stddef.h>
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
class GainApplier {
public:
GainApplier(bool hard_clip_samples, float initial_gain_factor);
void ApplyGain(AudioFrameView<float> signal);
void SetGainFactor(float gain_factor);
float GetGainFactor() const { return current_gain_factor_; }
private:
void Initialize(int samples_per_channel);
// Whether to clip samples after gain is applied. If 'true', result
// will fit in FloatS16 range.
const bool hard_clip_samples_;
float last_gain_factor_;
// If this value is not equal to 'last_gain_factor', gain will be
// ramped from 'last_gain_factor_' to this value during the next
// 'ApplyGain'.
float current_gain_factor_;
int samples_per_channel_ = -1;
float inverse_samples_per_channel_ = -1.f;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_GAIN_APPLIER_H_
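A minimal usage sketch; the wrapper function is hypothetical and `frame` stands for one 10 ms frame of FloatS16 audio:

#include "modules/audio_processing/agc2/gain_applier.h"

void ApplyHalfGain(webrtc::AudioFrameView<float> frame) {
  webrtc::GainApplier gain_applier(/*hard_clip_samples=*/true,
                                   /*initial_gain_factor=*/1.0f);
  gain_applier.SetGainFactor(0.5f);  // Roughly -6 dB.
  // The gain ramps from 1.0 to 0.5 across this frame's samples; subsequent
  // frames would get the constant 0.5 factor.
  gain_applier.ApplyGain(frame);
}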

View file

@ -0,0 +1,93 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/gain_applier.h"
#include <math.h>
#include <algorithm>
#include <limits>
#include "modules/audio_processing/agc2/vector_float_frame.h"
#include "rtc_base/gunit.h"
namespace webrtc {
TEST(AutomaticGainController2GainApplier, InitialGainIsRespected) {
constexpr float initial_signal_level = 123.f;
constexpr float gain_factor = 10.f;
VectorFloatFrame fake_audio(1, 1, initial_signal_level);
GainApplier gain_applier(true, gain_factor);
gain_applier.ApplyGain(fake_audio.float_frame_view());
EXPECT_NEAR(fake_audio.float_frame_view().channel(0)[0],
initial_signal_level * gain_factor, 0.1f);
}
TEST(AutomaticGainController2GainApplier, ClippingIsDone) {
constexpr float initial_signal_level = 30000.f;
constexpr float gain_factor = 10.f;
VectorFloatFrame fake_audio(1, 1, initial_signal_level);
GainApplier gain_applier(true, gain_factor);
gain_applier.ApplyGain(fake_audio.float_frame_view());
EXPECT_NEAR(fake_audio.float_frame_view().channel(0)[0],
std::numeric_limits<int16_t>::max(), 0.1f);
}
TEST(AutomaticGainController2GainApplier, ClippingIsNotDone) {
constexpr float initial_signal_level = 30000.f;
constexpr float gain_factor = 10.f;
VectorFloatFrame fake_audio(1, 1, initial_signal_level);
GainApplier gain_applier(false, gain_factor);
gain_applier.ApplyGain(fake_audio.float_frame_view());
EXPECT_NEAR(fake_audio.float_frame_view().channel(0)[0],
initial_signal_level * gain_factor, 0.1f);
}
TEST(AutomaticGainController2GainApplier, RampingIsDone) {
constexpr float initial_signal_level = 30000.f;
constexpr float initial_gain_factor = 1.f;
constexpr float target_gain_factor = 0.5f;
constexpr int num_channels = 3;
constexpr int samples_per_channel = 4;
VectorFloatFrame fake_audio(num_channels, samples_per_channel,
initial_signal_level);
GainApplier gain_applier(false, initial_gain_factor);
gain_applier.SetGainFactor(target_gain_factor);
gain_applier.ApplyGain(fake_audio.float_frame_view());
// The maximal gain change should be close to that in linear interpolation.
for (int channel = 0; channel < num_channels; ++channel) {
float max_signal_change = 0.f;
float last_signal_level = initial_signal_level;
for (const auto sample : fake_audio.float_frame_view().channel(channel)) {
const float current_change = fabs(last_signal_level - sample);
max_signal_change = std::max(max_signal_change, current_change);
last_signal_level = sample;
}
const float total_gain_change =
fabs((initial_gain_factor - target_gain_factor) * initial_signal_level);
EXPECT_NEAR(max_signal_change, total_gain_change / samples_per_channel,
0.1f);
}
// Next frame should have the desired level.
VectorFloatFrame next_fake_audio_frame(num_channels, samples_per_channel,
initial_signal_level);
gain_applier.ApplyGain(next_fake_audio_frame.float_frame_view());
// The last sample should have the new gain.
EXPECT_NEAR(next_fake_audio_frame.float_frame_view().channel(0)[0],
initial_signal_level * target_gain_factor, 0.1f);
}
} // namespace webrtc

View file

@ -0,0 +1,46 @@
/*
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_GAIN_MAP_INTERNAL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_GAIN_MAP_INTERNAL_H_
namespace webrtc {
static constexpr int kGainMapSize = 256;
// Maps input volumes, which are values in the [0, 255] range, to gains in dB.
// The values below are generated with numpy as follows:
// SI = 2 # Initial slope.
// SF = 0.25 # Final slope.
// D = 8/256 # Quantization factor.
// x = np.linspace(0, 255, 256) # Input volumes.
// y = (SF * x + (SI - SF) * (1 - np.exp(-D*x)) / D - 56).round()
static const int kGainMap[kGainMapSize] = {
-56, -54, -52, -50, -48, -47, -45, -43, -42, -40, -38, -37, -35, -34, -33,
-31, -30, -29, -27, -26, -25, -24, -23, -22, -20, -19, -18, -17, -16, -15,
-14, -14, -13, -12, -11, -10, -9, -8, -8, -7, -6, -5, -5, -4, -3,
-2, -2, -1, 0, 0, 1, 1, 2, 3, 3, 4, 4, 5, 5, 6,
6, 7, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 13, 13,
13, 14, 14, 15, 15, 15, 16, 16, 17, 17, 17, 18, 18, 18, 19,
19, 19, 20, 20, 21, 21, 21, 22, 22, 22, 23, 23, 23, 24, 24,
24, 24, 25, 25, 25, 26, 26, 26, 27, 27, 27, 28, 28, 28, 28,
29, 29, 29, 30, 30, 30, 30, 31, 31, 31, 32, 32, 32, 32, 33,
33, 33, 33, 34, 34, 34, 35, 35, 35, 35, 36, 36, 36, 36, 37,
37, 37, 38, 38, 38, 38, 39, 39, 39, 39, 40, 40, 40, 40, 41,
41, 41, 41, 42, 42, 42, 42, 43, 43, 43, 44, 44, 44, 44, 45,
45, 45, 45, 46, 46, 46, 46, 47, 47, 47, 47, 48, 48, 48, 48,
49, 49, 49, 49, 50, 50, 50, 50, 51, 51, 51, 51, 52, 52, 52,
52, 53, 53, 53, 53, 54, 54, 54, 54, 55, 55, 55, 55, 56, 56,
56, 56, 57, 57, 57, 57, 58, 58, 58, 58, 59, 59, 59, 59, 60,
60, 60, 60, 61, 61, 61, 61, 62, 62, 62, 62, 63, 63, 63, 63,
64};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_GAIN_MAP_INTERNAL_H_
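A sketch of how the map can be read; the helper below is hypothetical, but `ComputeVolumeUpdate()` in input_volume_controller.cc (next file) walks the map in the same spirit:

#include "modules/audio_processing/agc2/gain_map_internal.h"

// Hypothetical helper: gain difference in dB when moving from `volume_from`
// to `volume_to`, both in [0, 255].
int GainDifferenceDb(int volume_from, int volume_to) {
  return webrtc::kGainMap[volume_to] - webrtc::kGainMap[volume_from];
}
// E.g., GainDifferenceDb(0, 2) == -52 - (-56) == 4 dB.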

View file

@ -0,0 +1,580 @@
/*
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/input_volume_controller.h"
#include <algorithm>
#include <cmath>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/gain_map_internal.h"
#include "modules/audio_processing/agc2/input_volume_stats_reporter.h"
#include "modules/audio_processing/include/audio_frame_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
#include "system_wrappers/include/field_trial.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
namespace {
// Amount of error we tolerate in the microphone input volume (presumably due to
// OS quantization) before we assume the user has manually adjusted the volume.
constexpr int kVolumeQuantizationSlack = 25;
constexpr int kMaxInputVolume = 255;
static_assert(kGainMapSize > kMaxInputVolume, "gain map too small");
// Maximum absolute RMS error.
constexpr int kMaxAbsRmsErrorDbfs = 15;
static_assert(kMaxAbsRmsErrorDbfs > 0, "");
using Agc1ClippingPredictorConfig = AudioProcessing::Config::GainController1::
AnalogGainController::ClippingPredictor;
// TODO(webrtc:7494): Hardcode clipping predictor parameters and remove this
// function once it is no longer needed in the ctor.
Agc1ClippingPredictorConfig CreateClippingPredictorConfig(bool enabled) {
Agc1ClippingPredictorConfig config;
config.enabled = enabled;
return config;
}
// Returns an input volume in the [`min_input_volume`, `kMaxInputVolume`] range
// that reduces `gain_error_db`, which is a gain error estimated when
// `input_volume` was applied, according to a fixed gain map.
int ComputeVolumeUpdate(int gain_error_db,
int input_volume,
int min_input_volume) {
RTC_DCHECK_GE(input_volume, 0);
RTC_DCHECK_LE(input_volume, kMaxInputVolume);
if (gain_error_db == 0) {
return input_volume;
}
int new_volume = input_volume;
if (gain_error_db > 0) {
while (kGainMap[new_volume] - kGainMap[input_volume] < gain_error_db &&
new_volume < kMaxInputVolume) {
++new_volume;
}
} else {
while (kGainMap[new_volume] - kGainMap[input_volume] > gain_error_db &&
new_volume > min_input_volume) {
--new_volume;
}
}
return new_volume;
}
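// Example, using `kGainMap` from gain_map_internal.h: starting at
// `input_volume` 0 (gain -56 dB) with `gain_error_db` +4, the upward loop
// stops at volume 2 (gain -52 dB), since kGainMap[2] - kGainMap[0] == 4.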
// Returns the maximum across channels of the proportion of samples that are
// at full-scale (and presumably clipped).
float ComputeClippedRatio(const float* const* audio,
size_t num_channels,
size_t samples_per_channel) {
RTC_DCHECK_GT(samples_per_channel, 0);
int num_clipped = 0;
for (size_t ch = 0; ch < num_channels; ++ch) {
int num_clipped_in_ch = 0;
for (size_t i = 0; i < samples_per_channel; ++i) {
RTC_DCHECK(audio[ch]);
if (audio[ch][i] >= 32767.0f || audio[ch][i] <= -32768.0f) {
++num_clipped_in_ch;
}
}
num_clipped = std::max(num_clipped, num_clipped_in_ch);
}
return static_cast<float>(num_clipped) / (samples_per_channel);
}
void LogClippingMetrics(int clipping_rate) {
RTC_LOG(LS_INFO) << "[AGC2] Input clipping rate: " << clipping_rate << "%";
RTC_HISTOGRAM_COUNTS_LINEAR(/*name=*/"WebRTC.Audio.Agc.InputClippingRate",
/*sample=*/clipping_rate, /*min=*/0, /*max=*/100,
/*bucket_count=*/50);
}
// Compares `speech_level_dbfs` to the [`target_range_min_dbfs`,
// `target_range_max_dbfs`] range and returns the error to be compensated via
// input volume adjustment. Returns a positive value when the level is below
// the range, a negative value when the level is above the range, zero
// otherwise.
int GetSpeechLevelRmsErrorDb(float speech_level_dbfs,
int target_range_min_dbfs,
int target_range_max_dbfs) {
constexpr float kMinSpeechLevelDbfs = -90.0f;
constexpr float kMaxSpeechLevelDbfs = 30.0f;
RTC_DCHECK_GE(speech_level_dbfs, kMinSpeechLevelDbfs);
RTC_DCHECK_LE(speech_level_dbfs, kMaxSpeechLevelDbfs);
speech_level_dbfs = rtc::SafeClamp<float>(
speech_level_dbfs, kMinSpeechLevelDbfs, kMaxSpeechLevelDbfs);
int rms_error_db = 0;
if (speech_level_dbfs > target_range_max_dbfs) {
rms_error_db = std::round(target_range_max_dbfs - speech_level_dbfs);
} else if (speech_level_dbfs < target_range_min_dbfs) {
rms_error_db = std::round(target_range_min_dbfs - speech_level_dbfs);
}
return rms_error_db;
}
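// Example: with the default target range [-50, -30] dBFS, a speech level of
// -60.4 dBFS is below the range and yields std::round(-50 - (-60.4)) = 10,
// i.e. a +10 dB error to be compensated by raising the input volume.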
} // namespace
MonoInputVolumeController::MonoInputVolumeController(
int min_input_volume_after_clipping,
int min_input_volume,
int update_input_volume_wait_frames,
float speech_probability_threshold,
float speech_ratio_threshold)
: min_input_volume_(min_input_volume),
min_input_volume_after_clipping_(min_input_volume_after_clipping),
max_input_volume_(kMaxInputVolume),
update_input_volume_wait_frames_(
std::max(update_input_volume_wait_frames, 1)),
speech_probability_threshold_(speech_probability_threshold),
speech_ratio_threshold_(speech_ratio_threshold) {
RTC_DCHECK_GE(min_input_volume_, 0);
RTC_DCHECK_LE(min_input_volume_, 255);
RTC_DCHECK_GE(min_input_volume_after_clipping_, 0);
RTC_DCHECK_LE(min_input_volume_after_clipping_, 255);
RTC_DCHECK_GE(max_input_volume_, 0);
RTC_DCHECK_LE(max_input_volume_, 255);
RTC_DCHECK_GE(update_input_volume_wait_frames_, 0);
RTC_DCHECK_GE(speech_probability_threshold_, 0.0f);
RTC_DCHECK_LE(speech_probability_threshold_, 1.0f);
RTC_DCHECK_GE(speech_ratio_threshold_, 0.0f);
RTC_DCHECK_LE(speech_ratio_threshold_, 1.0f);
}
MonoInputVolumeController::~MonoInputVolumeController() = default;
void MonoInputVolumeController::Initialize() {
max_input_volume_ = kMaxInputVolume;
capture_output_used_ = true;
check_volume_on_next_process_ = true;
frames_since_update_input_volume_ = 0;
speech_frames_since_update_input_volume_ = 0;
is_first_frame_ = true;
}
// A speech segment is considered active if at least
// `update_input_volume_wait_frames_` new frames have been processed since the
// previous update and the ratio of non-silence frames (i.e., frames with a
// `speech_probability` higher than `speech_probability_threshold_`) is at least
// `speech_ratio_threshold_`.
void MonoInputVolumeController::Process(absl::optional<int> rms_error_db,
float speech_probability) {
if (check_volume_on_next_process_) {
check_volume_on_next_process_ = false;
// We have to wait until the first process call to check the volume,
// because Chromium doesn't guarantee it to be valid any earlier.
CheckVolumeAndReset();
}
// Count frames with a high speech probability as speech.
if (speech_probability >= speech_probability_threshold_) {
++speech_frames_since_update_input_volume_;
}
// Reset the counters and maybe update the input volume.
if (++frames_since_update_input_volume_ >= update_input_volume_wait_frames_) {
const float speech_ratio =
static_cast<float>(speech_frames_since_update_input_volume_) /
static_cast<float>(update_input_volume_wait_frames_);
// Always reset the counters regardless of whether the volume changes or
// not.
frames_since_update_input_volume_ = 0;
speech_frames_since_update_input_volume_ = 0;
// Update the input volume if allowed.
if (!is_first_frame_ && speech_ratio >= speech_ratio_threshold_ &&
rms_error_db.has_value()) {
UpdateInputVolume(*rms_error_db);
}
}
is_first_frame_ = false;
}
void MonoInputVolumeController::HandleClipping(int clipped_level_step) {
RTC_DCHECK_GT(clipped_level_step, 0);
// Always decrease the maximum input volume, even if the current input volume
// is below threshold.
SetMaxLevel(std::max(min_input_volume_after_clipping_,
max_input_volume_ - clipped_level_step));
if (log_to_histograms_) {
RTC_HISTOGRAM_BOOLEAN("WebRTC.Audio.AgcClippingAdjustmentAllowed",
last_recommended_input_volume_ - clipped_level_step >=
min_input_volume_after_clipping_);
}
if (last_recommended_input_volume_ > min_input_volume_after_clipping_) {
// Don't try to adjust the input volume if we're already below the limit. As
// a consequence, if the user has brought the input volume above the limit,
// we will still not react until the postproc updates the input volume.
SetInputVolume(
std::max(min_input_volume_after_clipping_,
last_recommended_input_volume_ - clipped_level_step));
frames_since_update_input_volume_ = 0;
speech_frames_since_update_input_volume_ = 0;
is_first_frame_ = false;
}
}
void MonoInputVolumeController::SetInputVolume(int new_volume) {
int applied_input_volume = recommended_input_volume_;
if (applied_input_volume == 0) {
RTC_DLOG(LS_INFO)
<< "[AGC2] The applied input volume is zero, taking no action.";
return;
}
if (applied_input_volume < 0 || applied_input_volume > kMaxInputVolume) {
RTC_LOG(LS_ERROR) << "[AGC2] Invalid value for the applied input volume: "
<< applied_input_volume;
return;
}
// Detect manual input volume adjustments by checking if the
// `applied_input_volume` is outside of the `[last_recommended_input_volume_ -
// kVolumeQuantizationSlack, last_recommended_input_volume_ +
// kVolumeQuantizationSlack]` range.
if (applied_input_volume >
last_recommended_input_volume_ + kVolumeQuantizationSlack ||
applied_input_volume <
last_recommended_input_volume_ - kVolumeQuantizationSlack) {
RTC_DLOG(LS_INFO)
<< "[AGC2] The input volume was manually adjusted. Updating "
"stored input volume from "
<< last_recommended_input_volume_ << " to " << applied_input_volume;
last_recommended_input_volume_ = applied_input_volume;
// Always allow the user to increase the volume.
if (last_recommended_input_volume_ > max_input_volume_) {
SetMaxLevel(last_recommended_input_volume_);
}
// Take no action in this case, since we can't be sure when the volume
// was manually adjusted.
frames_since_update_input_volume_ = 0;
speech_frames_since_update_input_volume_ = 0;
is_first_frame_ = false;
return;
}
new_volume = std::min(new_volume, max_input_volume_);
if (new_volume == last_recommended_input_volume_) {
return;
}
recommended_input_volume_ = new_volume;
RTC_DLOG(LS_INFO) << "[AGC2] Applied input volume: " << applied_input_volume
<< " | last recommended input volume: "
<< last_recommended_input_volume_
<< " | newly recommended input volume: " << new_volume;
last_recommended_input_volume_ = new_volume;
}
void MonoInputVolumeController::SetMaxLevel(int input_volume) {
RTC_DCHECK_GE(input_volume, min_input_volume_after_clipping_);
max_input_volume_ = input_volume;
RTC_DLOG(LS_INFO) << "[AGC2] Maximum input volume updated: "
<< max_input_volume_;
}
void MonoInputVolumeController::HandleCaptureOutputUsedChange(
bool capture_output_used) {
if (capture_output_used_ == capture_output_used) {
return;
}
capture_output_used_ = capture_output_used;
if (capture_output_used) {
// When we start using the output, we should reset things to be safe.
check_volume_on_next_process_ = true;
}
}
int MonoInputVolumeController::CheckVolumeAndReset() {
int input_volume = recommended_input_volume_;
// Reasons for taking action at startup:
// 1) A person starting a call is expected to be heard.
// 2) Independent of interpretation of `input_volume` == 0 we should raise it
// so the AGC can do its job properly.
if (input_volume == 0 && !startup_) {
RTC_DLOG(LS_INFO)
<< "[AGC2] The applied input volume is zero, taking no action.";
return 0;
}
if (input_volume < 0 || input_volume > kMaxInputVolume) {
RTC_LOG(LS_ERROR) << "[AGC2] Invalid value for the applied input volume: "
<< input_volume;
return -1;
}
RTC_DLOG(LS_INFO) << "[AGC2] Initial input volume: " << input_volume;
if (input_volume < min_input_volume_) {
input_volume = min_input_volume_;
RTC_DLOG(LS_INFO)
<< "[AGC2] The initial input volume is too low, raising to "
<< input_volume;
recommended_input_volume_ = input_volume;
}
last_recommended_input_volume_ = input_volume;
startup_ = false;
frames_since_update_input_volume_ = 0;
speech_frames_since_update_input_volume_ = 0;
is_first_frame_ = true;
return 0;
}
void MonoInputVolumeController::UpdateInputVolume(int rms_error_db) {
RTC_DLOG(LS_INFO) << "[AGC2] RMS error: " << rms_error_db << " dB";
// Prevent too large microphone input volume changes by clamping the RMS
// error.
rms_error_db =
rtc::SafeClamp(rms_error_db, -kMaxAbsRmsErrorDbfs, kMaxAbsRmsErrorDbfs);
if (rms_error_db == 0) {
return;
}
SetInputVolume(ComputeVolumeUpdate(
rms_error_db, last_recommended_input_volume_, min_input_volume_));
}
InputVolumeController::InputVolumeController(int num_capture_channels,
const Config& config)
: num_capture_channels_(num_capture_channels),
min_input_volume_(config.min_input_volume),
capture_output_used_(true),
clipped_level_step_(config.clipped_level_step),
clipped_ratio_threshold_(config.clipped_ratio_threshold),
clipped_wait_frames_(config.clipped_wait_frames),
clipping_predictor_(CreateClippingPredictor(
num_capture_channels,
CreateClippingPredictorConfig(config.enable_clipping_predictor))),
use_clipping_predictor_step_(
!!clipping_predictor_ &&
CreateClippingPredictorConfig(config.enable_clipping_predictor)
.use_predicted_step),
frames_since_clipped_(config.clipped_wait_frames),
clipping_rate_log_counter_(0),
clipping_rate_log_(0.0f),
target_range_max_dbfs_(config.target_range_max_dbfs),
target_range_min_dbfs_(config.target_range_min_dbfs),
channel_controllers_(num_capture_channels) {
RTC_LOG(LS_INFO)
<< "[AGC2] Input volume controller enabled. Minimum input volume: "
<< min_input_volume_;
for (auto& controller : channel_controllers_) {
controller = std::make_unique<MonoInputVolumeController>(
config.clipped_level_min, min_input_volume_,
config.update_input_volume_wait_frames,
config.speech_probability_threshold, config.speech_ratio_threshold);
}
RTC_DCHECK(!channel_controllers_.empty());
RTC_DCHECK_GT(clipped_level_step_, 0);
RTC_DCHECK_LE(clipped_level_step_, 255);
RTC_DCHECK_GT(clipped_ratio_threshold_, 0.0f);
RTC_DCHECK_LT(clipped_ratio_threshold_, 1.0f);
RTC_DCHECK_GT(clipped_wait_frames_, 0);
channel_controllers_[0]->ActivateLogging();
}
InputVolumeController::~InputVolumeController() {}
void InputVolumeController::Initialize() {
for (auto& controller : channel_controllers_) {
controller->Initialize();
}
capture_output_used_ = true;
AggregateChannelLevels();
clipping_rate_log_ = 0.0f;
clipping_rate_log_counter_ = 0;
applied_input_volume_ = absl::nullopt;
}
void InputVolumeController::AnalyzeInputAudio(int applied_input_volume,
const AudioBuffer& audio_buffer) {
RTC_DCHECK_GE(applied_input_volume, 0);
RTC_DCHECK_LE(applied_input_volume, 255);
SetAppliedInputVolume(applied_input_volume);
RTC_DCHECK_EQ(audio_buffer.num_channels(), channel_controllers_.size());
const float* const* audio = audio_buffer.channels_const();
size_t samples_per_channel = audio_buffer.num_frames();
RTC_DCHECK(audio);
AggregateChannelLevels();
if (!capture_output_used_) {
return;
}
if (!!clipping_predictor_) {
AudioFrameView<const float> frame = AudioFrameView<const float>(
audio, num_capture_channels_, static_cast<int>(samples_per_channel));
clipping_predictor_->Analyze(frame);
}
// Check for clipped samples. We do this in the preprocessing phase in order
// to catch clipped echo as well.
//
// If we find a sufficiently clipped frame, drop the current microphone
// input volume and enforce a new maximum input volume, lowered by the same
// amount from the current maximum. This harsh treatment is an effort to avoid
// repeated clipped echo events.
float clipped_ratio =
ComputeClippedRatio(audio, num_capture_channels_, samples_per_channel);
clipping_rate_log_ = std::max(clipped_ratio, clipping_rate_log_);
clipping_rate_log_counter_++;
constexpr int kNumFramesIn30Seconds = 3000;
if (clipping_rate_log_counter_ == kNumFramesIn30Seconds) {
LogClippingMetrics(std::round(100.0f * clipping_rate_log_));
clipping_rate_log_ = 0.0f;
clipping_rate_log_counter_ = 0;
}
if (frames_since_clipped_ < clipped_wait_frames_) {
++frames_since_clipped_;
return;
}
const bool clipping_detected = clipped_ratio > clipped_ratio_threshold_;
bool clipping_predicted = false;
int predicted_step = 0;
if (!!clipping_predictor_) {
for (int channel = 0; channel < num_capture_channels_; ++channel) {
const auto step = clipping_predictor_->EstimateClippedLevelStep(
channel, recommended_input_volume_, clipped_level_step_,
channel_controllers_[channel]->min_input_volume_after_clipping(),
kMaxInputVolume);
if (step.has_value()) {
predicted_step = std::max(predicted_step, step.value());
clipping_predicted = true;
}
}
}
if (clipping_detected) {
RTC_DLOG(LS_INFO) << "[AGC2] Clipping detected (ratio: " << clipped_ratio
<< ")";
}
int step = clipped_level_step_;
if (clipping_predicted) {
predicted_step = std::max(predicted_step, clipped_level_step_);
RTC_DLOG(LS_INFO) << "[AGC2] Clipping predicted (volume down step: "
<< predicted_step << ")";
if (use_clipping_predictor_step_) {
step = predicted_step;
}
}
if (clipping_detected ||
(clipping_predicted && use_clipping_predictor_step_)) {
for (auto& state_ch : channel_controllers_) {
state_ch->HandleClipping(step);
}
frames_since_clipped_ = 0;
if (!!clipping_predictor_) {
clipping_predictor_->Reset();
}
}
AggregateChannelLevels();
}
absl::optional<int> InputVolumeController::RecommendInputVolume(
float speech_probability,
absl::optional<float> speech_level_dbfs) {
// Only process if applied input volume is set.
if (!applied_input_volume_.has_value()) {
RTC_LOG(LS_ERROR) << "[AGC2] Applied input volume not set.";
return absl::nullopt;
}
AggregateChannelLevels();
const int volume_after_clipping_handling = recommended_input_volume_;
if (!capture_output_used_) {
return applied_input_volume_;
}
absl::optional<int> rms_error_db;
if (speech_level_dbfs.has_value()) {
// Compute the error for all frames (both speech and non-speech frames).
rms_error_db = GetSpeechLevelRmsErrorDb(
*speech_level_dbfs, target_range_min_dbfs_, target_range_max_dbfs_);
}
for (auto& controller : channel_controllers_) {
controller->Process(rms_error_db, speech_probability);
}
AggregateChannelLevels();
if (volume_after_clipping_handling != recommended_input_volume_) {
// The recommended input volume was adjusted in order to match the target
// level.
UpdateHistogramOnRecommendedInputVolumeChangeToMatchTarget(
recommended_input_volume_);
}
applied_input_volume_ = absl::nullopt;
return recommended_input_volume();
}
void InputVolumeController::HandleCaptureOutputUsedChange(
bool capture_output_used) {
for (auto& controller : channel_controllers_) {
controller->HandleCaptureOutputUsedChange(capture_output_used);
}
capture_output_used_ = capture_output_used;
}
void InputVolumeController::SetAppliedInputVolume(int input_volume) {
applied_input_volume_ = input_volume;
for (auto& controller : channel_controllers_) {
controller->set_stream_analog_level(input_volume);
}
AggregateChannelLevels();
}
void InputVolumeController::AggregateChannelLevels() {
int new_recommended_input_volume =
channel_controllers_[0]->recommended_analog_level();
channel_controlling_gain_ = 0;
for (size_t ch = 1; ch < channel_controllers_.size(); ++ch) {
int input_volume = channel_controllers_[ch]->recommended_analog_level();
if (input_volume < new_recommended_input_volume) {
new_recommended_input_volume = input_volume;
channel_controlling_gain_ = static_cast<int>(ch);
}
}
// Enforce the minimum input volume when a recommendation is made.
if (applied_input_volume_.has_value() && *applied_input_volume_ > 0) {
new_recommended_input_volume =
std::max(new_recommended_input_volume, min_input_volume_);
}
recommended_input_volume_ = new_recommended_input_volume;
}
} // namespace webrtc
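A minimal call-order sketch for the controller above, assuming `audio_buffer` is a configured webrtc::AudioBuffer holding one 10 ms capture frame; the speech probability and level values are illustrative, and in production the controller would persist across frames rather than being rebuilt per call:

#include "absl/types/optional.h"
#include "modules/audio_processing/agc2/input_volume_controller.h"
#include "modules/audio_processing/audio_buffer.h"

int RecommendVolumeSketch(const webrtc::AudioBuffer& audio_buffer,
                          int applied_input_volume) {
  webrtc::InputVolumeController::Config config;
  webrtc::InputVolumeController controller(/*num_capture_channels=*/1, config);
  controller.Initialize();  // Not yet folded into the ctor (see TODO above).
  // Before digital processing: clipping detection and prediction.
  controller.AnalyzeInputAudio(applied_input_volume, audio_buffer);
  // After echo cancellation and noise suppression, with externally estimated
  // speech probability and speech level.
  const absl::optional<int> recommended = controller.RecommendInputVolume(
      /*speech_probability=*/0.9f, /*speech_level_dbfs=*/-42.0f);
  return recommended.value_or(applied_input_volume);
}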

View file

@ -0,0 +1,282 @@
/*
* Copyright (c) 2013 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_CONTROLLER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_CONTROLLER_H_
#include <memory>
#include <vector>
#include "absl/types/optional.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/clipping_predictor.h"
#include "modules/audio_processing/audio_buffer.h"
#include "modules/audio_processing/include/audio_processing.h"
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
class MonoInputVolumeController;
// The input volume controller recommends what volume to use, handles volume
// changes and clipping detection and prediction. In particular, it handles
// changes triggered by the user (e.g., volume set to zero by a HW mute button).
// This class is not thread-safe.
// TODO(bugs.webrtc.org/7494): Use applied/recommended input volume naming
// convention.
class InputVolumeController final {
public:
// Config for the constructor.
struct Config {
// Minimum input volume that can be recommended. Not enforced when the
// applied input volume is zero outside startup.
int min_input_volume = 20;
// Lowest input volume level that will be applied in response to clipping.
int clipped_level_min = 70;
// Amount by which the input volume is lowered with every clipping event.
// Limited to (0, 255].
int clipped_level_step = 15;
// Proportion of clipped samples required to declare a clipping event.
// Limited to (0.0f, 1.0f).
float clipped_ratio_threshold = 0.1f;
// Time in frames to wait after a clipping event before checking again.
// Limited to values higher than 0.
int clipped_wait_frames = 300;
// Enables clipping prediction functionality.
bool enable_clipping_predictor = true;
// Speech level target range (dBFS). If the speech level is in the range
// [`target_range_min_dbfs`, `target_range_max_dbfs`], no input volume
// adjustments are done based on the speech level. For speech levels below
// and above the range, the targets `target_range_min_dbfs` and
// `target_range_max_dbfs` are used, respectively.
int target_range_max_dbfs = -30;
int target_range_min_dbfs = -50;
// Number of wait frames between the recommended input volume updates.
int update_input_volume_wait_frames = 100;
// Speech probability threshold: speech probabilities below the threshold
// are considered silence. Limited to [0.0f, 1.0f].
float speech_probability_threshold = 0.7f;
// Minimum speech frame ratio for volume updates to be allowed. Limited to
// [0.0f, 1.0f].
float speech_ratio_threshold = 0.6f;
};
// Ctor. `num_capture_channels` specifies the number of channels for the audio
// passed to `AnalyzeInputAudio()` and `RecommendInputVolume()`.
InputVolumeController(int num_capture_channels, const Config& config);
~InputVolumeController();
InputVolumeController(const InputVolumeController&) = delete;
InputVolumeController& operator=(const InputVolumeController&) = delete;
// TODO(webrtc:7494): Integrate initialization into ctor and remove.
void Initialize();
// Analyzes `audio_buffer` before `RecommendInputVolume()` is called so that
// the analysis can be performed before digital processing operations take
// place (e.g., echo cancellation). The analysis consists of input clipping
// detection and prediction (if enabled).
void AnalyzeInputAudio(int applied_input_volume,
const AudioBuffer& audio_buffer);
// Adjusts the recommended input volume upwards/downwards based on the result
// of `AnalyzeInputAudio()` and on `speech_level_dbfs` (if specified). Must
// be called after `AnalyzeInputAudio()`. The value of `speech_probability`
// is expected to be in the range [0, 1] and `speech_level_dbfs` in the range
// [-90, 30] and both should be estimated after echo cancellation and noise
// suppression are applied. Returns a non-empty input volume recommendation if
// available. If `capture_output_used_` is true, returns the applied input
// volume.
absl::optional<int> RecommendInputVolume(
float speech_probability,
absl::optional<float> speech_level_dbfs);
// Stores whether the capture output will be used or not. Call when the
// capture stream output has been flagged to be used/not-used. If unused, the
// controller disregards all incoming audio.
void HandleCaptureOutputUsedChange(bool capture_output_used);
// Returns true if clipping prediction is enabled.
// TODO(bugs.webrtc.org/7494): Deprecate this method.
bool clipping_predictor_enabled() const { return !!clipping_predictor_; }
// Returns true if clipping prediction is used to adjust the input volume.
// TODO(bugs.webrtc.org/7494): Deprecate this method.
bool use_clipping_predictor_step() const {
return use_clipping_predictor_step_;
}
// Only use for testing: Use `RecommendInputVolume()` elsewhere.
// Returns the value of a member variable, needed for testing
// `AnalyzeInputAudio()`.
int recommended_input_volume() const { return recommended_input_volume_; }
// Only use for testing.
bool capture_output_used() const { return capture_output_used_; }
private:
friend class InputVolumeControllerTestHelper;
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest, MinInputVolumeDefault);
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest, MinInputVolumeDisabled);
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest,
MinInputVolumeOutOfRangeAbove);
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest,
MinInputVolumeOutOfRangeBelow);
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerTest, MinInputVolumeEnabled50);
FRIEND_TEST_ALL_PREFIXES(InputVolumeControllerParametrizedTest,
ClippingParametersVerified);
// Sets the applied input volume and resets the recommended input volume.
void SetAppliedInputVolume(int input_volume);
void AggregateChannelLevels();
const int num_capture_channels_;
// Minimum input volume that can be recommended.
const int min_input_volume_;
// TODO(bugs.webrtc.org/7494): Once
// `AudioProcessingImpl::recommended_stream_analog_level()` becomes a trivial
// getter, leave uninitialized.
// Recommended input volume. After `SetAppliedInputVolume()` is called it
// holds the observed input volume. Possibly updated by
// `AnalyzeInputAudio()` and `RecommendInputVolume()`; after these calls, holds
// the recommended input volume.
int recommended_input_volume_ = 0;
// Applied input volume. After `SetAppliedInputVolume()` is called it holds
// the current applied volume.
absl::optional<int> applied_input_volume_;
bool capture_output_used_;
// Clipping detection and prediction.
const int clipped_level_step_;
const float clipped_ratio_threshold_;
const int clipped_wait_frames_;
const std::unique_ptr<ClippingPredictor> clipping_predictor_;
const bool use_clipping_predictor_step_;
int frames_since_clipped_;
int clipping_rate_log_counter_;
float clipping_rate_log_;
// Target range minimum and maximum. If the speech level is in the range
// [`target_range_min_dbfs`, `target_range_max_dbfs`], no volume adjustments
// take place. Instead, the digital gain controller is assumed to adapt to
// compensate for the speech level RMS error.
const int target_range_max_dbfs_;
const int target_range_min_dbfs_;
// Channel controllers updating the gain upwards/downwards.
std::vector<std::unique_ptr<MonoInputVolumeController>> channel_controllers_;
int channel_controlling_gain_ = 0;
};
// TODO(bugs.webrtc.org/7494): Use applied/recommended input volume naming
// convention.
class MonoInputVolumeController {
public:
MonoInputVolumeController(int min_input_volume_after_clipping,
int min_input_volume,
int update_input_volume_wait_frames,
float speech_probability_threshold,
float speech_ratio_threshold);
~MonoInputVolumeController();
MonoInputVolumeController(const MonoInputVolumeController&) = delete;
MonoInputVolumeController& operator=(const MonoInputVolumeController&) =
delete;
void Initialize();
void HandleCaptureOutputUsedChange(bool capture_output_used);
// Sets the current input volume.
void set_stream_analog_level(int input_volume) {
recommended_input_volume_ = input_volume;
}
// Lowers the recommended input volume in response to clipping based on the
// suggested reduction `clipped_level_step`. Must be called after
// `set_stream_analog_level()`.
void HandleClipping(int clipped_level_step);
// TODO(bugs.webrtc.org/7494): Rename, audio not passed to the method anymore.
// Adjusts the recommended input volume upwards/downwards depending on the
// result of `HandleClipping()` and on `rms_error_dbfs`. Updates are only
// allowed for active speech segments and when `rms_error_dbfs` is not empty.
// Must be called after `HandleClipping()`.
void Process(absl::optional<int> rms_error_dbfs, float speech_probability);
// Returns the recommended input volume. Must be called after `Process()`.
int recommended_analog_level() const { return recommended_input_volume_; }
void ActivateLogging() { log_to_histograms_ = true; }
int min_input_volume_after_clipping() const {
return min_input_volume_after_clipping_;
}
// Only used for testing.
int min_input_volume() const { return min_input_volume_; }
private:
// Sets a new input volume, after first checking that it hasn't been updated
// by the user, in which case no action is taken.
void SetInputVolume(int new_volume);
// Sets the maximum input volume that the input volume controller is allowed
// to apply. The volume must be at least `min_input_volume_after_clipping_`.
void SetMaxLevel(int input_volume);
int CheckVolumeAndReset();
// Updates the recommended input volume. If the volume slider needs to be
// moved, we check first if the user has adjusted it, in which case we take no
// action and cache the updated level.
void UpdateInputVolume(int rms_error_dbfs);
const int min_input_volume_;
const int min_input_volume_after_clipping_;
int max_input_volume_;
int last_recommended_input_volume_ = 0;
bool capture_output_used_ = true;
bool check_volume_on_next_process_ = true;
bool startup_ = true;
// TODO(bugs.webrtc.org/7494): Create a separate member for the applied
// input volume.
// Recommended input volume. After `set_stream_analog_level()` is
// called, it holds the observed applied input volume. Possibly updated by
// `HandleClipping()` and `Process()`; after these calls, holds the
// recommended input volume.
int recommended_input_volume_ = 0;
bool log_to_histograms_ = false;
// Counters for frames and speech frames since the last update in the
// recommended input volume.
const int update_input_volume_wait_frames_;
int frames_since_update_input_volume_ = 0;
int speech_frames_since_update_input_volume_ = 0;
bool is_first_frame_ = true;
// Speech probability threshold for a frame to be considered speech (instead
// of silence). Limited to [0.0f, 1.0f].
const float speech_probability_threshold_;
// Minimum ratio of speech frames. Limited to [0.0f, 1.0f].
const float speech_ratio_threshold_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_CONTROLLER_H_

View file

@ -0,0 +1,171 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/input_volume_stats_reporter.h"
#include <cmath>
#include "absl/strings/string_view.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
#include "rtc_base/strings/string_builder.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
namespace {
using InputVolumeType = InputVolumeStatsReporter::InputVolumeType;
constexpr int kFramesIn60Seconds = 6000;
constexpr int kMinInputVolume = 0;
constexpr int kMaxInputVolume = 255;
constexpr int kMaxUpdate = kMaxInputVolume - kMinInputVolume;
int ComputeAverageUpdate(int sum_updates, int num_updates) {
RTC_DCHECK_GE(sum_updates, 0);
RTC_DCHECK_LE(sum_updates, kMaxUpdate * kFramesIn60Seconds);
RTC_DCHECK_GE(num_updates, 0);
RTC_DCHECK_LE(num_updates, kFramesIn60Seconds);
if (num_updates == 0) {
return 0;
}
return std::round(static_cast<float>(sum_updates) /
static_cast<float>(num_updates));
}
constexpr absl::string_view MetricNamePrefix(
InputVolumeType input_volume_type) {
switch (input_volume_type) {
case InputVolumeType::kApplied:
return "WebRTC.Audio.Apm.AppliedInputVolume.";
case InputVolumeType::kRecommended:
return "WebRTC.Audio.Apm.RecommendedInputVolume.";
}
}
metrics::Histogram* CreateVolumeHistogram(InputVolumeType input_volume_type) {
char buffer[64];
rtc::SimpleStringBuilder builder(buffer);
builder << MetricNamePrefix(input_volume_type) << "OnChange";
return metrics::HistogramFactoryGetCountsLinear(/*name=*/builder.str(),
/*min=*/1,
/*max=*/kMaxInputVolume,
/*bucket_count=*/50);
}
metrics::Histogram* CreateRateHistogram(InputVolumeType input_volume_type,
absl::string_view name) {
char buffer[64];
rtc::SimpleStringBuilder builder(buffer);
builder << MetricNamePrefix(input_volume_type) << name;
return metrics::HistogramFactoryGetCountsLinear(/*name=*/builder.str(),
/*min=*/1,
/*max=*/kFramesIn60Seconds,
/*bucket_count=*/50);
}
metrics::Histogram* CreateAverageHistogram(InputVolumeType input_volume_type,
absl::string_view name) {
char buffer[64];
rtc::SimpleStringBuilder builder(buffer);
builder << MetricNamePrefix(input_volume_type) << name;
return metrics::HistogramFactoryGetCountsLinear(/*name=*/builder.str(),
/*min=*/1,
/*max=*/kMaxUpdate,
/*bucket_count=*/50);
}
} // namespace
InputVolumeStatsReporter::InputVolumeStatsReporter(InputVolumeType type)
: histograms_(
{.on_volume_change = CreateVolumeHistogram(type),
.decrease_rate = CreateRateHistogram(type, "DecreaseRate"),
.decrease_average = CreateAverageHistogram(type, "DecreaseAverage"),
.increase_rate = CreateRateHistogram(type, "IncreaseRate"),
.increase_average = CreateAverageHistogram(type, "IncreaseAverage"),
.update_rate = CreateRateHistogram(type, "UpdateRate"),
.update_average = CreateAverageHistogram(type, "UpdateAverage")}),
cannot_log_stats_(!histograms_.AllPointersSet()) {
if (cannot_log_stats_) {
RTC_LOG(LS_WARNING) << "Will not log any `" << MetricNamePrefix(type)
<< "*` histogram stats.";
}
}
InputVolumeStatsReporter::~InputVolumeStatsReporter() = default;
void InputVolumeStatsReporter::UpdateStatistics(int input_volume) {
if (cannot_log_stats_) {
// Since the stats cannot be logged, do not bother updating them.
return;
}
RTC_DCHECK_GE(input_volume, kMinInputVolume);
RTC_DCHECK_LE(input_volume, kMaxInputVolume);
if (previous_input_volume_.has_value() &&
input_volume != previous_input_volume_.value()) {
// Update stats when the input volume changes.
metrics::HistogramAdd(histograms_.on_volume_change, input_volume);
// Update stats that are periodically logged.
const int volume_change = input_volume - previous_input_volume_.value();
if (volume_change < 0) {
++volume_update_stats_.num_decreases;
volume_update_stats_.sum_decreases -= volume_change;
} else {
++volume_update_stats_.num_increases;
volume_update_stats_.sum_increases += volume_change;
}
}
// Periodically log input volume change metrics.
if (++log_volume_update_stats_counter_ >= kFramesIn60Seconds) {
LogVolumeUpdateStats();
volume_update_stats_ = {};
log_volume_update_stats_counter_ = 0;
}
previous_input_volume_ = input_volume;
}
void InputVolumeStatsReporter::LogVolumeUpdateStats() const {
// Decrease rate and average.
metrics::HistogramAdd(histograms_.decrease_rate,
volume_update_stats_.num_decreases);
if (volume_update_stats_.num_decreases > 0) {
int average_decrease = ComputeAverageUpdate(
volume_update_stats_.sum_decreases, volume_update_stats_.num_decreases);
metrics::HistogramAdd(histograms_.decrease_average, average_decrease);
}
// Increase rate and average.
metrics::HistogramAdd(histograms_.increase_rate,
volume_update_stats_.num_increases);
if (volume_update_stats_.num_increases > 0) {
int average_increase = ComputeAverageUpdate(
volume_update_stats_.sum_increases, volume_update_stats_.num_increases);
metrics::HistogramAdd(histograms_.increase_average, average_increase);
}
// Update rate and average.
int num_updates =
volume_update_stats_.num_decreases + volume_update_stats_.num_increases;
metrics::HistogramAdd(histograms_.update_rate, num_updates);
if (num_updates > 0) {
int average_update = ComputeAverageUpdate(
volume_update_stats_.sum_decreases + volume_update_stats_.sum_increases,
num_updates);
metrics::HistogramAdd(histograms_.update_average, average_update);
}
}
void UpdateHistogramOnRecommendedInputVolumeChangeToMatchTarget(int volume) {
RTC_HISTOGRAM_COUNTS_LINEAR(
"WebRTC.Audio.Apm.RecommendedInputVolume.OnChangeToMatchTarget", volume,
1, kMaxInputVolume, 50);
}
} // namespace webrtc

View file

@ -0,0 +1,96 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_STATS_REPORTER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_STATS_REPORTER_H_
#include "absl/types/optional.h"
#include "rtc_base/gtest_prod_util.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
// Input volume statistics calculator. Computes aggregate stats based on the
// framewise input volume observed by `UpdateStatistics()`. Periodically logs
// the statistics into a histogram.
class InputVolumeStatsReporter {
public:
enum class InputVolumeType {
kApplied = 0,
kRecommended = 1,
};
explicit InputVolumeStatsReporter(InputVolumeType input_volume_type);
InputVolumeStatsReporter(const InputVolumeStatsReporter&) = delete;
InputVolumeStatsReporter& operator=(const InputVolumeStatsReporter&) = delete;
~InputVolumeStatsReporter();
// Updates the stats based on `input_volume`. Periodically logs the stats into
// a histogram.
void UpdateStatistics(int input_volume);
private:
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
CheckVolumeUpdateStatsForEmptyStats);
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
CheckVolumeUpdateStatsAfterNoVolumeChange);
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
CheckVolumeUpdateStatsAfterVolumeIncrease);
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
CheckVolumeUpdateStatsAfterVolumeDecrease);
FRIEND_TEST_ALL_PREFIXES(InputVolumeStatsReporterTest,
CheckVolumeUpdateStatsAfterReset);
// Stores input volume update stats to enable calculation of update rate and
// average update separately for volume increases and decreases.
struct VolumeUpdateStats {
int num_decreases = 0;
int num_increases = 0;
int sum_decreases = 0;
int sum_increases = 0;
} volume_update_stats_;
// Returns a copy of the stored statistics. Use only for testing.
VolumeUpdateStats volume_update_stats() const { return volume_update_stats_; }
// Computes aggregate stats and logs them into histograms.
void LogVolumeUpdateStats() const;
// Histograms.
struct Histograms {
metrics::Histogram* const on_volume_change;
metrics::Histogram* const decrease_rate;
metrics::Histogram* const decrease_average;
metrics::Histogram* const increase_rate;
metrics::Histogram* const increase_average;
metrics::Histogram* const update_rate;
metrics::Histogram* const update_average;
bool AllPointersSet() const {
return !!on_volume_change && !!decrease_rate && !!decrease_average &&
!!increase_rate && !!increase_average && !!update_rate &&
!!update_average;
}
} histograms_;
// True if the stats cannot be logged.
const bool cannot_log_stats_;
int log_volume_update_stats_counter_ = 0;
absl::optional<int> previous_input_volume_ = absl::nullopt;
};
// Updates the histogram that keeps track of recommended input volume changes
// required in order to match the target level in the input volume adaptation
// process.
void UpdateHistogramOnRecommendedInputVolumeChangeToMatchTarget(int volume);
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_INPUT_VOLUME_STATS_REPORTER_H_

View file

@ -0,0 +1,204 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/interpolated_gain_curve.h"
#include <algorithm>
#include <iterator>
#include "absl/strings/string_view.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/strings/string_builder.h"
namespace webrtc {
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
InterpolatedGainCurve::approximation_params_x_;
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
InterpolatedGainCurve::approximation_params_m_;
constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
InterpolatedGainCurve::approximation_params_q_;
InterpolatedGainCurve::InterpolatedGainCurve(
ApmDataDumper* apm_data_dumper,
absl::string_view histogram_name_prefix)
: region_logger_(
(rtc::StringBuilder("WebRTC.Audio.")
<< histogram_name_prefix << ".FixedDigitalGainCurveRegion.Identity")
.str(),
(rtc::StringBuilder("WebRTC.Audio.")
<< histogram_name_prefix << ".FixedDigitalGainCurveRegion.Knee")
.str(),
(rtc::StringBuilder("WebRTC.Audio.")
<< histogram_name_prefix << ".FixedDigitalGainCurveRegion.Limiter")
.str(),
(rtc::StringBuilder("WebRTC.Audio.")
<< histogram_name_prefix
<< ".FixedDigitalGainCurveRegion.Saturation")
.str()),
apm_data_dumper_(apm_data_dumper) {}
InterpolatedGainCurve::~InterpolatedGainCurve() {
if (stats_.available) {
RTC_DCHECK(apm_data_dumper_);
apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_identity",
stats_.look_ups_identity_region);
apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_knee",
stats_.look_ups_knee_region);
apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_limiter",
stats_.look_ups_limiter_region);
apm_data_dumper_->DumpRaw("agc2_interp_gain_curve_lookups_saturation",
stats_.look_ups_saturation_region);
region_logger_.LogRegionStats(stats_);
}
}
InterpolatedGainCurve::RegionLogger::RegionLogger(
absl::string_view identity_histogram_name,
absl::string_view knee_histogram_name,
absl::string_view limiter_histogram_name,
absl::string_view saturation_histogram_name)
: identity_histogram(
metrics::HistogramFactoryGetCounts(identity_histogram_name,
1,
10000,
50)),
knee_histogram(metrics::HistogramFactoryGetCounts(knee_histogram_name,
1,
10000,
50)),
limiter_histogram(
metrics::HistogramFactoryGetCounts(limiter_histogram_name,
1,
10000,
50)),
saturation_histogram(
metrics::HistogramFactoryGetCounts(saturation_histogram_name,
1,
10000,
50)) {}
InterpolatedGainCurve::RegionLogger::~RegionLogger() = default;
void InterpolatedGainCurve::RegionLogger::LogRegionStats(
const InterpolatedGainCurve::Stats& stats) const {
using Region = InterpolatedGainCurve::GainCurveRegion;
const int duration_s =
stats.region_duration_frames / (1000 / kFrameDurationMs);
switch (stats.region) {
case Region::kIdentity: {
if (identity_histogram) {
metrics::HistogramAdd(identity_histogram, duration_s);
}
break;
}
case Region::kKnee: {
if (knee_histogram) {
metrics::HistogramAdd(knee_histogram, duration_s);
}
break;
}
case Region::kLimiter: {
if (limiter_histogram) {
metrics::HistogramAdd(limiter_histogram, duration_s);
}
break;
}
case Region::kSaturation: {
if (saturation_histogram) {
metrics::HistogramAdd(saturation_histogram, duration_s);
}
break;
}
default: {
RTC_DCHECK_NOTREACHED();
}
}
}
void InterpolatedGainCurve::UpdateStats(float input_level) const {
stats_.available = true;
GainCurveRegion region;
if (input_level < approximation_params_x_[0]) {
stats_.look_ups_identity_region++;
region = GainCurveRegion::kIdentity;
} else if (input_level <
approximation_params_x_[kInterpolatedGainCurveKneePoints - 1]) {
stats_.look_ups_knee_region++;
region = GainCurveRegion::kKnee;
} else if (input_level < kMaxInputLevelLinear) {
stats_.look_ups_limiter_region++;
region = GainCurveRegion::kLimiter;
} else {
stats_.look_ups_saturation_region++;
region = GainCurveRegion::kSaturation;
}
if (region == stats_.region) {
++stats_.region_duration_frames;
} else {
region_logger_.LogRegionStats(stats_);
stats_.region_duration_frames = 0;
stats_.region = region;
}
}
// Looks up a gain to apply given a non-negative input level.
// The cost of this operation depends on the region in which `input_level`
// falls.
// For the identity and the saturation regions the cost is O(1).
// For the other regions, namely knee and limiter, the cost is
// O(2 + log2(`kInterpolatedGainCurveTotalPoints`)), plus O(1) for the
// linear interpolation (one product and one sum).
float InterpolatedGainCurve::LookUpGainToApply(float input_level) const {
UpdateStats(input_level);
if (input_level <= approximation_params_x_[0]) {
// Identity region.
return 1.0f;
}
if (input_level >= kMaxInputLevelLinear) {
// Saturating lower bound. The saturating samples exactly hit the clipping
// level. This method has the lowest harmonic distortion, but it
// may reduce the amplitude of the non-saturating samples too much.
return 32768.f / input_level;
}
// Knee and limiter regions; find the linear piece index. Spelling
// out the complete type was the only way to silence both the clang
// plugin and the windows compilers.
std::array<float, kInterpolatedGainCurveTotalPoints>::const_iterator it =
std::lower_bound(approximation_params_x_.begin(),
approximation_params_x_.end(), input_level);
const size_t index = std::distance(approximation_params_x_.begin(), it) - 1;
RTC_DCHECK_LE(0, index);
RTC_DCHECK_LT(index, approximation_params_m_.size());
RTC_DCHECK_LE(approximation_params_x_[index], input_level);
if (index < approximation_params_m_.size() - 1) {
RTC_DCHECK_LE(input_level, approximation_params_x_[index + 1]);
}
// Piece-wise linear interpolation.
const float gain = approximation_params_m_[index] * input_level +
approximation_params_q_[index];
RTC_DCHECK_LE(0.f, gain);
return gain;
}
} // namespace webrtc
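// A self-contained sketch of the same look-up scheme on a toy three-piece
// curve (assumed values, not the production parameters): `std::lower_bound`
// finds the first knot above the input and the preceding piece's slope and
// intercept are applied. In the production code, inputs beyond the last knot
// are handled by the saturation branch instead.
#include <algorithm>
#include <array>
#include <cstddef>

constexpr std::array<float, 3> kToyX = {1.0f, 2.0f, 4.0f};         // Knots.
constexpr std::array<float, 3> kToyM = {-0.10f, -0.05f, -0.025f};  // Slopes.
constexpr std::array<float, 3> kToyQ = {1.10f, 1.00f, 0.90f};  // Intercepts.

float ToyLookUpGain(float x) {
  if (x <= kToyX.front()) {
    return 1.0f;  // Identity region.
  }
  const auto it = std::lower_bound(kToyX.begin(), kToyX.end(), x);
  const size_t index = std::distance(kToyX.begin(), it) - 1;
  return kToyM[index] * x + kToyQ[index];  // Linear piece.
}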

View file

@ -0,0 +1,152 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_
#define MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_
#include <array>
#include "absl/strings/string_view.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "rtc_base/gtest_prod_util.h"
#include "system_wrappers/include/metrics.h"
namespace webrtc {
class ApmDataDumper;
constexpr float kInputLevelScalingFactor = 32768.0f;
// Defined as DbfsToLinear(kLimiterMaxInputLevelDbFs)
constexpr float kMaxInputLevelLinear = static_cast<float>(36766.300710566735);
// Interpolated gain curve using under-approximation to avoid saturation.
//
// The goal of this class is to allow fast look-ups of accurate estimates of
// the gain to apply given an estimated input level.
class InterpolatedGainCurve {
public:
enum class GainCurveRegion {
kIdentity = 0,
kKnee = 1,
kLimiter = 2,
kSaturation = 3
};
struct Stats {
// Region in which the output level equals the input one.
size_t look_ups_identity_region = 0;
// Smoothing between the identity and the limiter regions.
size_t look_ups_knee_region = 0;
// Limiter region in which the output and input levels are linearly related.
size_t look_ups_limiter_region = 0;
// Region in which saturation may occur since the input level is beyond the
// maximum expected by the limiter.
size_t look_ups_saturation_region = 0;
// True if stats have been populated.
bool available = false;
// The current region, and for how many frames the level has been
// in that region.
GainCurveRegion region = GainCurveRegion::kIdentity;
int64_t region_duration_frames = 0;
};
InterpolatedGainCurve(ApmDataDumper* apm_data_dumper,
absl::string_view histogram_name_prefix);
~InterpolatedGainCurve();
InterpolatedGainCurve(const InterpolatedGainCurve&) = delete;
InterpolatedGainCurve& operator=(const InterpolatedGainCurve&) = delete;
Stats get_stats() const { return stats_; }
// Given a non-negative input level (linear scale), returns a scalar factor to
// apply to a sub-frame. Levels above kLimiterMaxInputLevelDbFs are reduced to
// 0 dBFS after this gain is applied.
float LookUpGainToApply(float input_level) const;
private:
// For comparing 'approximation_params_*_' with ones computed by
// ComputeInterpolatedGainCurve.
FRIEND_TEST_ALL_PREFIXES(GainController2InterpolatedGainCurve,
CheckApproximationParams);
struct RegionLogger {
metrics::Histogram* identity_histogram;
metrics::Histogram* knee_histogram;
metrics::Histogram* limiter_histogram;
metrics::Histogram* saturation_histogram;
RegionLogger(absl::string_view identity_histogram_name,
absl::string_view knee_histogram_name,
absl::string_view limiter_histogram_name,
absl::string_view saturation_histogram_name);
~RegionLogger();
void LogRegionStats(const InterpolatedGainCurve::Stats& stats) const;
} region_logger_;
void UpdateStats(float input_level) const;
ApmDataDumper* const apm_data_dumper_;
static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
approximation_params_x_ = {
{30057.296875, 30148.986328125, 30240.67578125, 30424.052734375,
30607.4296875, 30790.806640625, 30974.18359375, 31157.560546875,
31340.939453125, 31524.31640625, 31707.693359375, 31891.0703125,
32074.447265625, 32257.82421875, 32441.201171875, 32624.580078125,
32807.95703125, 32991.33203125, 33174.7109375, 33358.08984375,
33541.46484375, 33724.84375, 33819.53515625, 34009.5390625,
34200.05859375, 34389.81640625, 34674.48828125, 35054.375,
35434.86328125, 35814.81640625, 36195.16796875, 36575.03125}};
static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
approximation_params_m_ = {
{-3.515235675877192989e-07, -1.050251626111275982e-06,
-2.085213736791047268e-06, -3.443004743530764244e-06,
-4.773849468620028347e-06, -6.077375928725814447e-06,
-7.353257842623861507e-06, -8.601219633419532329e-06,
-9.821013009059242904e-06, -1.101243378798244521e-05,
-1.217532644659513608e-05, -1.330956911260727793e-05,
-1.441507538402220234e-05, -1.549179251014720649e-05,
-1.653970684856176376e-05, -1.755882840370759368e-05,
-1.854918446042574942e-05, -1.951086778717581183e-05,
-2.044398024736437947e-05, -2.1348627342376858e-05,
-2.222496914328075945e-05, -2.265374678245279938e-05,
-2.242570917587727308e-05, -2.220122041762806475e-05,
-2.19802095671184361e-05, -2.176260204578284174e-05,
-2.133731686626560986e-05, -2.092481918225530535e-05,
-2.052459603874012828e-05, -2.013615448959171772e-05,
-1.975903069251216948e-05, -1.939277899509761482e-05}};
static constexpr std::array<float, kInterpolatedGainCurveTotalPoints>
approximation_params_q_ = {
{1.010565876960754395, 1.031631827354431152, 1.062929749488830566,
1.104239225387573242, 1.144973039627075195, 1.185109615325927734,
1.224629044532775879, 1.263512492179870605, 1.301741957664489746,
1.339300632476806641, 1.376173257827758789, 1.412345528602600098,
1.447803974151611328, 1.482536554336547852, 1.516532182693481445,
1.549780607223510742, 1.582272171974182129, 1.613999366760253906,
1.644955039024353027, 1.675132393836975098, 1.704526185989379883,
1.718986630439758301, 1.711274504661560059, 1.703639745712280273,
1.696081161499023438, 1.688597679138183594, 1.673851132392883301,
1.659391283988952637, 1.645209431648254395, 1.631297469139099121,
1.617647409439086914, 1.604251742362976074}};
// Stats.
mutable Stats stats_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_INTERPOLATED_GAIN_CURVE_H_

View file

@ -0,0 +1,155 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/limiter.h"
#include <algorithm>
#include <array>
#include <cmath>
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
// This constant affects the way scaling factors are interpolated for the first
// sub-frame of a frame. Only when the first sub-frame has an estimated level
// greater than that of the previously analyzed sub-frame, linear interpolation
// is replaced with a power function, which reduces the chance of over-shooting
// (and hence saturation) at the cost of reduced fixed gain effectiveness.
constexpr float kAttackFirstSubframeInterpolationPower = 8.0f;
void InterpolateFirstSubframe(float last_factor,
float current_factor,
rtc::ArrayView<float> subframe) {
const int n = rtc::dchecked_cast<int>(subframe.size());
constexpr float p = kAttackFirstSubframeInterpolationPower;
for (int i = 0; i < n; ++i) {
subframe[i] = std::pow(1.f - static_cast<float>(i) / n, p) *
(last_factor - current_factor) + current_factor;
}
}
void ComputePerSampleSubframeFactors(
const std::array<float, kSubFramesInFrame + 1>& scaling_factors,
int samples_per_channel,
rtc::ArrayView<float> per_sample_scaling_factors) {
const int num_subframes = scaling_factors.size() - 1;
const int subframe_size =
rtc::CheckedDivExact(samples_per_channel, num_subframes);
// Handle first sub-frame differently in case of attack.
const bool is_attack = scaling_factors[0] > scaling_factors[1];
if (is_attack) {
InterpolateFirstSubframe(
scaling_factors[0], scaling_factors[1],
rtc::ArrayView<float>(
per_sample_scaling_factors.subview(0, subframe_size)));
}
for (int i = is_attack ? 1 : 0; i < num_subframes; ++i) {
const int subframe_start = i * subframe_size;
const float scaling_start = scaling_factors[i];
const float scaling_end = scaling_factors[i + 1];
const float scaling_diff = (scaling_end - scaling_start) / subframe_size;
for (int j = 0; j < subframe_size; ++j) {
per_sample_scaling_factors[subframe_start + j] =
scaling_start + scaling_diff * j;
}
}
}
void ScaleSamples(rtc::ArrayView<const float> per_sample_scaling_factors,
AudioFrameView<float> signal) {
const int samples_per_channel = signal.samples_per_channel();
RTC_DCHECK_EQ(samples_per_channel, per_sample_scaling_factors.size());
for (int i = 0; i < signal.num_channels(); ++i) {
rtc::ArrayView<float> channel = signal.channel(i);
for (int j = 0; j < samples_per_channel; ++j) {
channel[j] = rtc::SafeClamp(channel[j] * per_sample_scaling_factors[j],
kMinFloatS16Value, kMaxFloatS16Value);
}
}
}
void CheckLimiterSampleRate(int sample_rate_hz) {
// Check that per_sample_scaling_factors_ is large enough.
RTC_DCHECK_LE(sample_rate_hz,
kMaximalNumberOfSamplesPerChannel * 1000 / kFrameDurationMs);
}
} // namespace
Limiter::Limiter(int sample_rate_hz,
ApmDataDumper* apm_data_dumper,
absl::string_view histogram_name)
: interp_gain_curve_(apm_data_dumper, histogram_name),
level_estimator_(sample_rate_hz, apm_data_dumper),
apm_data_dumper_(apm_data_dumper) {
CheckLimiterSampleRate(sample_rate_hz);
}
Limiter::~Limiter() = default;
void Limiter::Process(AudioFrameView<float> signal) {
const std::array<float, kSubFramesInFrame> level_estimate =
level_estimator_.ComputeLevel(signal);
RTC_DCHECK_EQ(level_estimate.size() + 1, scaling_factors_.size());
scaling_factors_[0] = last_scaling_factor_;
std::transform(level_estimate.begin(), level_estimate.end(),
scaling_factors_.begin() + 1, [this](float x) {
return interp_gain_curve_.LookUpGainToApply(x);
});
const int samples_per_channel = signal.samples_per_channel();
RTC_DCHECK_LE(samples_per_channel, kMaximalNumberOfSamplesPerChannel);
auto per_sample_scaling_factors = rtc::ArrayView<float>(
&per_sample_scaling_factors_[0], samples_per_channel);
ComputePerSampleSubframeFactors(scaling_factors_, samples_per_channel,
per_sample_scaling_factors);
ScaleSamples(per_sample_scaling_factors, signal);
last_scaling_factor_ = scaling_factors_.back();
// Dump data for debug.
apm_data_dumper_->DumpRaw("agc2_limiter_last_scaling_factor",
last_scaling_factor_);
apm_data_dumper_->DumpRaw(
"agc2_limiter_region",
static_cast<int>(interp_gain_curve_.get_stats().region));
}
InterpolatedGainCurve::Stats Limiter::GetGainCurveStats() const {
return interp_gain_curve_.get_stats();
}
void Limiter::SetSampleRate(int sample_rate_hz) {
CheckLimiterSampleRate(sample_rate_hz);
level_estimator_.SetSampleRate(sample_rate_hz);
}
void Limiter::Reset() {
level_estimator_.Reset();
}
float Limiter::LastAudioLevel() const {
return level_estimator_.LastAudioLevel();
}
} // namespace webrtc
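// A condensed sketch of the scaling-factor interpolation above on a toy
// 4-sample sub-frame (assumed values). During an attack the power curve drops
// quickly towards the lower target factor, which limits over-shooting;
// otherwise plain linear interpolation is used.
#include <cmath>
#include <cstdio>

int main() {
  constexpr float kLastFactor = 1.0f;     // Factor at the end of last frame.
  constexpr float kCurrentFactor = 0.5f;  // First sub-frame target factor.
  constexpr int kSamples = 4;
  constexpr float kPower = 8.0f;  // kAttackFirstSubframeInterpolationPower.
  const bool is_attack = kLastFactor > kCurrentFactor;
  for (int i = 0; i < kSamples; ++i) {
    const float t = static_cast<float>(i) / kSamples;
    const float factor =
        is_attack
            ? std::pow(1.0f - t, kPower) * (kLastFactor - kCurrentFactor) +
                  kCurrentFactor
            : kLastFactor + (kCurrentFactor - kLastFactor) * t;
    std::printf("sample %d: scaling factor %.4f\n", i, factor);
  }
  return 0;
}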

View file

@ -0,0 +1,63 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_
#include <vector>
#include "absl/strings/string_view.h"
#include "modules/audio_processing/agc2/fixed_digital_level_estimator.h"
#include "modules/audio_processing/agc2/interpolated_gain_curve.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
class ApmDataDumper;
class Limiter {
public:
Limiter(int sample_rate_hz,
ApmDataDumper* apm_data_dumper,
absl::string_view histogram_name_prefix);
Limiter(const Limiter& limiter) = delete;
Limiter& operator=(const Limiter& limiter) = delete;
~Limiter();
// Applies limiter and hard-clipping to `signal`.
void Process(AudioFrameView<float> signal);
InterpolatedGainCurve::Stats GetGainCurveStats() const;
// Supported rates must be
// * supported by FixedDigitalLevelEstimator
// * below kMaximalNumberOfSamplesPerChannel*1000/kFrameDurationMs
// so that samples_per_channel fit in the
// per_sample_scaling_factors_ array.
void SetSampleRate(int sample_rate_hz);
// Resets the internal state.
void Reset();
float LastAudioLevel() const;
private:
const InterpolatedGainCurve interp_gain_curve_;
FixedDigitalLevelEstimator level_estimator_;
ApmDataDumper* const apm_data_dumper_ = nullptr;
// Work array containing the sub-frame scaling factors to be interpolated.
std::array<float, kSubFramesInFrame + 1> scaling_factors_ = {};
std::array<float, kMaximalNumberOfSamplesPerChannel>
per_sample_scaling_factors_ = {};
float last_scaling_factor_ = 1.f;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_LIMITER_H_

View file

@ -0,0 +1,138 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/limiter_db_gain_curve.h"
#include <cmath>
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
double ComputeKneeStart(double max_input_level_db,
double knee_smoothness_db,
double compression_ratio) {
RTC_CHECK_LT((compression_ratio - 1.0) * knee_smoothness_db /
(2.0 * compression_ratio),
max_input_level_db);
return -knee_smoothness_db / 2.0 -
max_input_level_db / (compression_ratio - 1.0);
}
std::array<double, 3> ComputeKneeRegionPolynomial(double knee_start_dbfs,
double knee_smoothness_db,
double compression_ratio) {
const double a = (1.0 - compression_ratio) /
(2.0 * knee_smoothness_db * compression_ratio);
const double b = 1.0 - 2.0 * a * knee_start_dbfs;
const double c = a * knee_start_dbfs * knee_start_dbfs;
return {{a, b, c}};
}
double ComputeLimiterD1(double max_input_level_db, double compression_ratio) {
return (std::pow(10.0, -max_input_level_db / (20.0 * compression_ratio)) *
(1.0 - compression_ratio) / compression_ratio) /
kMaxAbsFloatS16Value;
}
constexpr double ComputeLimiterD2(double compression_ratio) {
return (1.0 - 2.0 * compression_ratio) / compression_ratio;
}
double ComputeLimiterI2(double max_input_level_db,
double compression_ratio,
double gain_curve_limiter_i1) {
RTC_CHECK_NE(gain_curve_limiter_i1, 0.f);
return std::pow(10.0, -max_input_level_db / (20.0 * compression_ratio)) /
gain_curve_limiter_i1 /
std::pow(kMaxAbsFloatS16Value, gain_curve_limiter_i1 - 1);
}
} // namespace
LimiterDbGainCurve::LimiterDbGainCurve()
: max_input_level_linear_(DbfsToFloatS16(max_input_level_db_)),
knee_start_dbfs_(ComputeKneeStart(max_input_level_db_,
knee_smoothness_db_,
compression_ratio_)),
knee_start_linear_(DbfsToFloatS16(knee_start_dbfs_)),
limiter_start_dbfs_(knee_start_dbfs_ + knee_smoothness_db_),
limiter_start_linear_(DbfsToFloatS16(limiter_start_dbfs_)),
knee_region_polynomial_(ComputeKneeRegionPolynomial(knee_start_dbfs_,
knee_smoothness_db_,
compression_ratio_)),
gain_curve_limiter_d1_(
ComputeLimiterD1(max_input_level_db_, compression_ratio_)),
gain_curve_limiter_d2_(ComputeLimiterD2(compression_ratio_)),
gain_curve_limiter_i1_(1.0 / compression_ratio_),
gain_curve_limiter_i2_(ComputeLimiterI2(max_input_level_db_,
compression_ratio_,
gain_curve_limiter_i1_)) {
static_assert(knee_smoothness_db_ > 0.0f, "");
static_assert(compression_ratio_ > 1.0f, "");
RTC_CHECK_GE(max_input_level_db_, knee_start_dbfs_ + knee_smoothness_db_);
}
constexpr double LimiterDbGainCurve::max_input_level_db_;
constexpr double LimiterDbGainCurve::knee_smoothness_db_;
constexpr double LimiterDbGainCurve::compression_ratio_;
double LimiterDbGainCurve::GetOutputLevelDbfs(double input_level_dbfs) const {
if (input_level_dbfs < knee_start_dbfs_) {
return input_level_dbfs;
} else if (input_level_dbfs < limiter_start_dbfs_) {
return GetKneeRegionOutputLevelDbfs(input_level_dbfs);
}
return GetCompressorRegionOutputLevelDbfs(input_level_dbfs);
}
double LimiterDbGainCurve::GetGainLinear(double input_level_linear) const {
if (input_level_linear < knee_start_linear_) {
return 1.0;
}
return DbfsToFloatS16(
GetOutputLevelDbfs(FloatS16ToDbfs(input_level_linear))) /
input_level_linear;
}
// Computes the first derivative of GetGainLinear() at `x`.
double LimiterDbGainCurve::GetGainFirstDerivativeLinear(double x) const {
// Beyond-knee region only.
RTC_CHECK_GE(x, limiter_start_linear_ - 1e-7 * kMaxAbsFloatS16Value);
return gain_curve_limiter_d1_ *
std::pow(x / kMaxAbsFloatS16Value, gain_curve_limiter_d2_);
}
// Computes the integral of GetGainLinear() in the range [x0, x1].
double LimiterDbGainCurve::GetGainIntegralLinear(double x0, double x1) const {
RTC_CHECK_LE(x0, x1); // Valid interval.
RTC_CHECK_GE(x0, limiter_start_linear_); // Beyond-knee region only.
auto limiter_integral = [this](const double& x) {
return gain_curve_limiter_i2_ * std::pow(x, gain_curve_limiter_i1_);
};
return limiter_integral(x1) - limiter_integral(x0);
}
double LimiterDbGainCurve::GetKneeRegionOutputLevelDbfs(
double input_level_dbfs) const {
return knee_region_polynomial_[0] * input_level_dbfs * input_level_dbfs +
knee_region_polynomial_[1] * input_level_dbfs +
knee_region_polynomial_[2];
}
double LimiterDbGainCurve::GetCompressorRegionOutputLevelDbfs(
double input_level_dbfs) const {
return (input_level_dbfs - max_input_level_db_) / compression_ratio_;
}
} // namespace webrtc
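// A worked example with illustrative parameters (not the production
// constants): max input level 1 dB, knee smoothness 1 dB, compression
// ratio 5. It shows where the regions fall and that the compressor region
// maps the maximum input level to exactly 0 dBFS.
#include <cstdio>

int main() {
  const double max_input_level_db = 1.0;
  const double knee_smoothness_db = 1.0;
  const double compression_ratio = 5.0;
  // Same formula as ComputeKneeStart() above.
  const double knee_start_dbfs =
      -knee_smoothness_db / 2.0 -
      max_input_level_db / (compression_ratio - 1.0);  // -0.75 dBFS.
  const double limiter_start_dbfs =
      knee_start_dbfs + knee_smoothness_db;  // 0.25 dBFS.
  std::printf("identity below %.2f dBFS, knee up to %.2f dBFS\n",
              knee_start_dbfs, limiter_start_dbfs);
  // Compressor region: 1/ratio dB of output per dB of input, reaching 0 dBFS
  // exactly at the maximum input level.
  std::printf("output at max input: %.2f dBFS\n",
              (max_input_level_db - max_input_level_db) / compression_ratio);
  return 0;
}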

View file

@ -0,0 +1,76 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_
#define MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_
#include <array>
#include "modules/audio_processing/agc2/agc2_testing_common.h"
namespace webrtc {
// A class for computing a limiter gain curve (in dB scale) given a set of
// hard-coded parameters (namely, kLimiterDbGainCurveMaxInputLevelDbFs,
// kLimiterDbGainCurveKneeSmoothnessDb, and
// kLimiterDbGainCurveCompressionRatio). The generated curve consists of four
// regions: identity (linear), knee (quadratic polynomial), compression
// (linear), saturation (linear). The aforementioned constants are used to shape
// the different regions.
class LimiterDbGainCurve {
public:
LimiterDbGainCurve();
double max_input_level_db() const { return max_input_level_db_; }
double max_input_level_linear() const { return max_input_level_linear_; }
double knee_start_linear() const { return knee_start_linear_; }
double limiter_start_linear() const { return limiter_start_linear_; }
// These methods can be marked 'constexpr' in C++ 14.
double GetOutputLevelDbfs(double input_level_dbfs) const;
double GetGainLinear(double input_level_linear) const;
double GetGainFirstDerivativeLinear(double x) const;
double GetGainIntegralLinear(double x0, double x1) const;
private:
double GetKneeRegionOutputLevelDbfs(double input_level_dbfs) const;
double GetCompressorRegionOutputLevelDbfs(double input_level_dbfs) const;
static constexpr double max_input_level_db_ = test::kLimiterMaxInputLevelDbFs;
static constexpr double knee_smoothness_db_ = test::kLimiterKneeSmoothnessDb;
static constexpr double compression_ratio_ = test::kLimiterCompressionRatio;
const double max_input_level_linear_;
// Do not modify signal with level <= knee_start_dbfs_.
const double knee_start_dbfs_;
const double knee_start_linear_;
// The upper end of the knee region; the knee region lies between
// knee_start_dbfs_ and limiter_start_dbfs_.
const double limiter_start_dbfs_;
const double limiter_start_linear_;
// Coefficients {a, b, c} of the knee region polynomial
// ax^2 + bx + c in the dB scale.
const std::array<double, 3> knee_region_polynomial_;
// Parameters for the computation of the first derivative of GetGainLinear().
const double gain_curve_limiter_d1_;
const double gain_curve_limiter_d2_;
// Parameters for the computation of the integral of GetGainLinear().
const double gain_curve_limiter_i1_;
const double gain_curve_limiter_i2_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_LIMITER_DB_GAIN_CURVE_H_

View file

@ -0,0 +1,60 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/limiter.h"
#include "common_audio/include/audio_util.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/agc2_testing_common.h"
#include "modules/audio_processing/agc2/vector_float_frame.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/gunit.h"
namespace webrtc {
TEST(Limiter, LimiterShouldConstructAndRun) {
const int sample_rate_hz = 48000;
ApmDataDumper apm_data_dumper(0);
Limiter limiter(sample_rate_hz, &apm_data_dumper, "");
VectorFloatFrame vectors_with_float_frame(1, sample_rate_hz / 100,
kMaxAbsFloatS16Value);
limiter.Process(vectors_with_float_frame.float_frame_view());
}
TEST(Limiter, OutputVolumeAboveThreshold) {
const int sample_rate_hz = 48000;
const float input_level =
(kMaxAbsFloatS16Value + DbfsToFloatS16(test::kLimiterMaxInputLevelDbFs)) /
2.f;
ApmDataDumper apm_data_dumper(0);
Limiter limiter(sample_rate_hz, &apm_data_dumper, "");
// Give the level estimator time to adapt.
for (int i = 0; i < 5; ++i) {
VectorFloatFrame vectors_with_float_frame(1, sample_rate_hz / 100,
input_level);
limiter.Process(vectors_with_float_frame.float_frame_view());
}
VectorFloatFrame vectors_with_float_frame(1, sample_rate_hz / 100,
input_level);
limiter.Process(vectors_with_float_frame.float_frame_view());
rtc::ArrayView<const float> channel =
vectors_with_float_frame.float_frame_view().channel(0);
for (const auto& sample : channel) {
EXPECT_LT(0.9f * kMaxAbsFloatS16Value, sample);
}
}
} // namespace webrtc

View file

@ -0,0 +1,172 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/noise_level_estimator.h"
#include <stddef.h>
#include <algorithm>
#include <cmath>
#include <numeric>
#include "api/array_view.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr int kFramesPerSecond = 100;
float FrameEnergy(const AudioFrameView<const float>& audio) {
float energy = 0.0f;
for (int k = 0; k < audio.num_channels(); ++k) {
float channel_energy =
std::accumulate(audio.channel(k).begin(), audio.channel(k).end(), 0.0f,
[](float a, float b) -> float { return a + b * b; });
energy = std::max(channel_energy, energy);
}
return energy;
}
float EnergyToDbfs(float signal_energy, int num_samples) {
RTC_DCHECK_GE(signal_energy, 0.0f);
const float rms_square = signal_energy / num_samples;
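// -90.309 dB = 20 * log10(1 / 32768): the dBFS level of a unit-RMS signal
// relative to S16 full scale; lower energies are clamped to this floor.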
constexpr float kMinDbfs = -90.30899869919436f;
if (rms_square <= 1.0f) {
return kMinDbfs;
}
return 10.0f * std::log10(rms_square) + kMinDbfs;
}
// Updates the noise floor with instant decay and slow attack. This tuning is
// specific to AGC2 so that (i) the gain can be promptly increased if the noise
// floor drops (instant decay) and (ii) the gain reduction is slowed down in
// case of music or fast speech, which can cause the noise floor to be
// overestimated.
float SmoothNoiseFloorEstimate(float current_estimate, float new_estimate) {
constexpr float kAttack = 0.5f;
if (current_estimate < new_estimate) {
// Attack phase.
return kAttack * new_estimate + (1.0f - kAttack) * current_estimate;
}
// Instant decay.
return new_estimate;
}
class NoiseFloorEstimator : public NoiseLevelEstimator {
public:
// Update the noise floor every 5 seconds.
static constexpr int kUpdatePeriodNumFrames = 500;
static_assert(kUpdatePeriodNumFrames >= 200,
"A too small value may cause noise level overestimation.");
static_assert(kUpdatePeriodNumFrames <= 1500,
"A too large value may make AGC2 slow at reacting to increased "
"noise levels.");
NoiseFloorEstimator(ApmDataDumper* data_dumper) : data_dumper_(data_dumper) {
RTC_DCHECK(data_dumper_);
// Initially assume that 48 kHz will be used. `Analyze()` will detect the
// used sample rate and call `Initialize()` again if needed.
Initialize(/*sample_rate_hz=*/48000);
}
NoiseFloorEstimator(const NoiseFloorEstimator&) = delete;
NoiseFloorEstimator& operator=(const NoiseFloorEstimator&) = delete;
~NoiseFloorEstimator() = default;
float Analyze(const AudioFrameView<const float>& frame) override {
// Detect sample rate changes.
const int sample_rate_hz =
static_cast<int>(frame.samples_per_channel() * kFramesPerSecond);
if (sample_rate_hz != sample_rate_hz_) {
Initialize(sample_rate_hz);
}
const float frame_energy = FrameEnergy(frame);
if (frame_energy <= min_noise_energy_) {
// Ignore frames when muted or below the minimum measurable energy.
if (data_dumper_)
data_dumper_->DumpRaw("agc2_noise_floor_estimator_preliminary_level",
noise_energy_);
return EnergyToDbfs(noise_energy_,
static_cast<int>(frame.samples_per_channel()));
}
if (preliminary_noise_energy_set_) {
preliminary_noise_energy_ =
std::min(preliminary_noise_energy_, frame_energy);
} else {
preliminary_noise_energy_ = frame_energy;
preliminary_noise_energy_set_ = true;
}
if (data_dumper_)
data_dumper_->DumpRaw("agc2_noise_floor_estimator_preliminary_level",
preliminary_noise_energy_);
if (counter_ == 0) {
// Full period observed.
first_period_ = false;
// Update the estimated noise floor energy with the preliminary
// estimation.
noise_energy_ = SmoothNoiseFloorEstimate(
/*current_estimate=*/noise_energy_,
/*new_estimate=*/preliminary_noise_energy_);
// Reset for a new observation period.
counter_ = kUpdatePeriodNumFrames;
preliminary_noise_energy_set_ = false;
} else if (first_period_) {
// While analyzing the signal during the initial period, continuously
// update the estimated noise energy, which is monotonic.
noise_energy_ = preliminary_noise_energy_;
counter_--;
} else {
// During the observation period it's only allowed to lower the energy.
noise_energy_ = std::min(noise_energy_, preliminary_noise_energy_);
counter_--;
}
float noise_rms_dbfs = EnergyToDbfs(
noise_energy_, static_cast<int>(frame.samples_per_channel()));
if (data_dumper_)
data_dumper_->DumpRaw("agc2_noise_rms_dbfs", noise_rms_dbfs);
return noise_rms_dbfs;
}
private:
void Initialize(int sample_rate_hz) {
sample_rate_hz_ = sample_rate_hz;
first_period_ = true;
preliminary_noise_energy_set_ = false;
// Initialize the minimum noise energy to -84 dBFS.
min_noise_energy_ = sample_rate_hz * 2.0f * 2.0f / kFramesPerSecond;
preliminary_noise_energy_ = min_noise_energy_;
noise_energy_ = min_noise_energy_;
counter_ = kUpdatePeriodNumFrames;
}
ApmDataDumper* const data_dumper_;
int sample_rate_hz_;
float min_noise_energy_;
bool first_period_;
bool preliminary_noise_energy_set_;
float preliminary_noise_energy_;
float noise_energy_;
int counter_;
};
} // namespace
std::unique_ptr<NoiseLevelEstimator> CreateNoiseFloorEstimator(
ApmDataDumper* data_dumper) {
return std::make_unique<NoiseFloorEstimator>(data_dumper);
}
} // namespace webrtc
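// The asymmetric smoothing is easiest to see in isolation. A sketch with
// assumed energy values, mirroring SmoothNoiseFloorEstimate() above:
#include <cstdio>

// Slow attack (50% per update), instant decay.
float Smooth(float current_estimate, float new_estimate) {
  constexpr float kAttack = 0.5f;
  if (current_estimate < new_estimate) {
    return kAttack * new_estimate + (1.0f - kAttack) * current_estimate;
  }
  return new_estimate;
}

int main() {
  float noise_energy = 100.0f;
  noise_energy = Smooth(noise_energy, 200.0f);  // 150: slow attack.
  noise_energy = Smooth(noise_energy, 200.0f);  // 175: still converging.
  noise_energy = Smooth(noise_energy, 10.0f);   // 10: instant decay.
  std::printf("final noise energy: %.1f\n", noise_energy);
  return 0;
}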

View file

@ -0,0 +1,36 @@
/*
* Copyright (c) 2016 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_
#include <memory>
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
class ApmDataDumper;
// Noise level estimator interface.
class NoiseLevelEstimator {
public:
virtual ~NoiseLevelEstimator() = default;
// Analyzes a 10 ms `frame`, updates the noise level estimation and returns
// the value for the latter in dBFS.
virtual float Analyze(const AudioFrameView<const float>& frame) = 0;
};
// Creates a noise level estimator based on noise floor detection.
std::unique_ptr<NoiseLevelEstimator> CreateNoiseFloorEstimator(
ApmDataDumper* data_dumper);
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_NOISE_LEVEL_ESTIMATOR_H_

View file

@ -0,0 +1,3 @@
include_rules = [
"+third_party/rnnoise",
]

View file

@ -0,0 +1,91 @@
/*
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
#include <algorithm>
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
namespace {
constexpr int kAutoCorrelationFftOrder = 9; // Length-512 FFT.
static_assert(1 << kAutoCorrelationFftOrder >
kNumLags12kHz + kBufSize12kHz - kMaxPitch12kHz,
"");
} // namespace
AutoCorrelationCalculator::AutoCorrelationCalculator()
: fft_(1 << kAutoCorrelationFftOrder, Pffft::FftType::kReal),
tmp_(fft_.CreateBuffer()),
X_(fft_.CreateBuffer()),
H_(fft_.CreateBuffer()) {}
AutoCorrelationCalculator::~AutoCorrelationCalculator() = default;
// The auto-correlations coefficients are computed as follows:
// |.........|...........| <- pitch buffer
// [ x (fixed) ]
// [ y_0 ]
// [ y_{m-1} ]
// x and y are sub-arrays of equal length; x is never moved, whereas y slides.
// The cross-correlation between y_0 and x corresponds to the auto-correlation
// for the maximum pitch period. Hence, the first value in `auto_corr` has an
// inverted lag equal to 0 that corresponds to a lag equal to the maximum
// pitch period.
void AutoCorrelationCalculator::ComputeOnPitchBuffer(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
RTC_DCHECK_LT(auto_corr.size(), kMaxPitch12kHz);
RTC_DCHECK_GT(pitch_buf.size(), kMaxPitch12kHz);
constexpr int kFftFrameSize = 1 << kAutoCorrelationFftOrder;
constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
static_assert(kConvolutionLength == kFrameSize20ms12kHz,
"Mismatch between pitch buffer size, frame size and maximum "
"pitch period.");
static_assert(kFftFrameSize > kNumLags12kHz + kConvolutionLength,
"The FFT length is not sufficiently big to avoid cyclic "
"convolution errors.");
auto tmp = tmp_->GetView();
// Compute the FFT for the reversed reference frame - i.e.,
// pitch_buf[-kConvolutionLength:].
std::reverse_copy(pitch_buf.end() - kConvolutionLength, pitch_buf.end(),
tmp.begin());
std::fill(tmp.begin() + kConvolutionLength, tmp.end(), 0.f);
fft_.ForwardTransform(*tmp_, H_.get(), /*ordered=*/false);
// Compute the FFT for the sliding frames chunk. The sliding frames are
// defined as pitch_buf[i:i+kConvolutionLength] where i in
// [0, kNumLags12kHz). The chunk includes all of them, hence it is
// defined as pitch_buf[:kNumLags12kHz+kConvolutionLength].
std::copy(pitch_buf.begin(),
pitch_buf.begin() + kConvolutionLength + kNumLags12kHz,
tmp.begin());
std::fill(tmp.begin() + kNumLags12kHz + kConvolutionLength, tmp.end(), 0.f);
fft_.ForwardTransform(*tmp_, X_.get(), /*ordered=*/false);
// Convolve in the frequency domain.
constexpr float kScalingFactor = 1.f / static_cast<float>(kFftFrameSize);
std::fill(tmp.begin(), tmp.end(), 0.f);
fft_.FrequencyDomainConvolve(*X_, *H_, tmp_.get(), kScalingFactor);
fft_.BackwardTransform(*tmp_, tmp_.get(), /*ordered=*/false);
// Extract the auto-correlation coefficients.
std::copy(tmp.begin() + kConvolutionLength - 1,
tmp.begin() + kConvolutionLength + kNumLags12kHz - 1,
auto_corr.begin());
}
} // namespace rnn_vad
} // namespace webrtc
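// A time-domain reference of the computation above, for exposition only:
// O(kNumLags12kHz * kConvolutionLength) instead of the FFT path, with the
// same inverted-lag convention and, up to floating point error, same output.
#include <numeric>

#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"

namespace webrtc {
namespace rnn_vad {

void ComputeOnPitchBufferReference(
    rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
    rtc::ArrayView<float, kNumLags12kHz> auto_corr) {
  constexpr int kConvolutionLength = kBufSize12kHz - kMaxPitch12kHz;
  // x is the fixed reference frame (the most recent samples).
  const float* x = pitch_buf.data() + kBufSize12kHz - kConvolutionLength;
  for (int inverted_lag = 0; inverted_lag < kNumLags12kHz; ++inverted_lag) {
    // y slides forward starting from the maximum pitch period.
    const float* y = pitch_buf.data() + inverted_lag;
    auto_corr[inverted_lag] =
        std::inner_product(x, x + kConvolutionLength, y, 0.f);
  }
}

}  // namespace rnn_vad
}  // namespace webrtc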

View file

@ -0,0 +1,49 @@
/*
* Copyright (c) 2019 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_
#include <memory>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/utility/pffft_wrapper.h"
namespace webrtc {
namespace rnn_vad {
// Class to compute the auto correlation on the pitch buffer for a target pitch
// interval.
class AutoCorrelationCalculator {
public:
AutoCorrelationCalculator();
AutoCorrelationCalculator(const AutoCorrelationCalculator&) = delete;
AutoCorrelationCalculator& operator=(const AutoCorrelationCalculator&) =
delete;
~AutoCorrelationCalculator();
// Computes the auto-correlation coefficients for a target pitch interval.
// `auto_corr` indexes are inverted lags.
void ComputeOnPitchBuffer(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buf,
rtc::ArrayView<float, kNumLags12kHz> auto_corr);
private:
Pffft fft_;
std::unique_ptr<Pffft::FloatBuffer> tmp_;
std::unique_ptr<Pffft::FloatBuffer> X_;
std::unique_ptr<Pffft::FloatBuffer> H_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_AUTO_CORRELATION_H_

View file

@ -0,0 +1,77 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_
#include <stddef.h>
namespace webrtc {
namespace rnn_vad {
constexpr double kPi = 3.14159265358979323846;
constexpr int kSampleRate24kHz = 24000;
constexpr int kFrameSize10ms24kHz = kSampleRate24kHz / 100;
constexpr int kFrameSize20ms24kHz = kFrameSize10ms24kHz * 2;
// Pitch buffer.
constexpr int kMinPitch24kHz = kSampleRate24kHz / 800; // 0.00125 s.
constexpr int kMaxPitch24kHz = kSampleRate24kHz / 62.5; // 0.016 s.
constexpr int kBufSize24kHz = kMaxPitch24kHz + kFrameSize20ms24kHz;
static_assert((kBufSize24kHz & 1) == 0, "The buffer size must be even.");
// 24 kHz analysis.
// Define a higher minimum pitch period for the initial search. This is used to
// avoid searching for very short periods, for which a refinement step is
// responsible.
constexpr int kInitialMinPitch24kHz = 3 * kMinPitch24kHz;
static_assert(kMinPitch24kHz < kInitialMinPitch24kHz, "");
static_assert(kInitialMinPitch24kHz < kMaxPitch24kHz, "");
static_assert(kMaxPitch24kHz > kInitialMinPitch24kHz, "");
// Number of (inverted) lags during the initial pitch search phase at 24 kHz.
constexpr int kInitialNumLags24kHz = kMaxPitch24kHz - kInitialMinPitch24kHz;
// Number of (inverted) lags during the pitch search refinement phase at 24 kHz.
constexpr int kRefineNumLags24kHz = kMaxPitch24kHz + 1;
static_assert(
kRefineNumLags24kHz > kInitialNumLags24kHz,
"The refinement step must search the pitch in an extended pitch range.");
// 12 kHz analysis.
constexpr int kSampleRate12kHz = 12000;
constexpr int kFrameSize10ms12kHz = kSampleRate12kHz / 100;
constexpr int kFrameSize20ms12kHz = kFrameSize10ms12kHz * 2;
constexpr int kBufSize12kHz = kBufSize24kHz / 2;
constexpr int kInitialMinPitch12kHz = kInitialMinPitch24kHz / 2;
constexpr int kMaxPitch12kHz = kMaxPitch24kHz / 2;
static_assert(kMaxPitch12kHz > kInitialMinPitch12kHz, "");
// The inverted lags for the pitch interval [`kInitialMinPitch12kHz`,
// `kMaxPitch12kHz`] are in the range [0, `kNumLags12kHz`].
constexpr int kNumLags12kHz = kMaxPitch12kHz - kInitialMinPitch12kHz;
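// With the constants above: kMaxPitch12kHz = 192, kInitialMinPitch12kHz = 45,
// hence kNumLags12kHz = 147.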
// 48 kHz constants.
constexpr int kMinPitch48kHz = kMinPitch24kHz * 2;
constexpr int kMaxPitch48kHz = kMaxPitch24kHz * 2;
// Spectral features.
constexpr int kNumBands = 22;
constexpr int kNumLowerBands = 6;
static_assert((0 < kNumLowerBands) && (kNumLowerBands < kNumBands), "");
constexpr int kCepstralCoeffsHistorySize = 8;
static_assert(kCepstralCoeffsHistorySize > 2,
"The history size must at least be 3 to compute first and second "
"derivatives.");
constexpr int kFeatureVectorSize = 42;
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_COMMON_H_

View file

@ -0,0 +1,90 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include <array>
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// Computed as `scipy.signal.butter(N=2, Wn=60/24000, btype='highpass')`.
constexpr BiQuadFilter::Config kHpfConfig24k{
{0.99446179f, -1.98892358f, 0.99446179f},
{-1.98889291f, 0.98895425f}};
} // namespace
FeaturesExtractor::FeaturesExtractor(const AvailableCpuFeatures& cpu_features)
: use_high_pass_filter_(false),
hpf_(kHpfConfig24k),
pitch_buf_24kHz_(),
pitch_buf_24kHz_view_(pitch_buf_24kHz_.GetBufferView()),
lp_residual_(kBufSize24kHz),
lp_residual_view_(lp_residual_.data(), kBufSize24kHz),
pitch_estimator_(cpu_features),
reference_frame_view_(pitch_buf_24kHz_.GetMostRecentValuesView()) {
RTC_DCHECK_EQ(kBufSize24kHz, lp_residual_.size());
Reset();
}
FeaturesExtractor::~FeaturesExtractor() = default;
void FeaturesExtractor::Reset() {
pitch_buf_24kHz_.Reset();
spectral_features_extractor_.Reset();
if (use_high_pass_filter_) {
hpf_.Reset();
}
}
bool FeaturesExtractor::CheckSilenceComputeFeatures(
rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
rtc::ArrayView<float, kFeatureVectorSize> feature_vector) {
// Pre-processing.
if (use_high_pass_filter_) {
std::array<float, kFrameSize10ms24kHz> samples_filtered;
hpf_.Process(samples, samples_filtered);
// Feed buffer with the pre-processed version of `samples`.
pitch_buf_24kHz_.Push(samples_filtered);
} else {
// Feed buffer with `samples`.
pitch_buf_24kHz_.Push(samples);
}
// Extract the LP residual.
float lpc_coeffs[kNumLpcCoefficients];
ComputeAndPostProcessLpcCoefficients(pitch_buf_24kHz_view_, lpc_coeffs);
ComputeLpResidual(lpc_coeffs, pitch_buf_24kHz_view_, lp_residual_view_);
// Estimate pitch on the LP-residual and write the normalized pitch period
// into the output vector (normalization based on training data stats).
pitch_period_48kHz_ = pitch_estimator_.Estimate(lp_residual_view_);
feature_vector[kFeatureVectorSize - 2] = 0.01f * (pitch_period_48kHz_ - 300);
// Extract lagged frames (according to the estimated pitch period).
RTC_DCHECK_LE(pitch_period_48kHz_ / 2, kMaxPitch24kHz);
auto lagged_frame = pitch_buf_24kHz_view_.subview(
kMaxPitch24kHz - pitch_period_48kHz_ / 2, kFrameSize20ms24kHz);
// Analyze reference and lagged frames checking if silence has been detected
// and write the feature vector.
return spectral_features_extractor_.CheckSilenceComputeFeatures(
reference_frame_view_, {lagged_frame.data(), kFrameSize20ms24kHz},
{feature_vector.data() + kNumLowerBands, kNumBands - kNumLowerBands},
{feature_vector.data(), kNumLowerBands},
{feature_vector.data() + kNumBands, kNumLowerBands},
{feature_vector.data() + kNumBands + kNumLowerBands, kNumLowerBands},
{feature_vector.data() + kNumBands + 2 * kNumLowerBands, kNumLowerBands},
&feature_vector[kFeatureVectorSize - 1]);
}
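// For reference, the apparent layout of the 42-dim feature vector given the
// offsets above (band-feature names inferred from the spectral features
// extractor, not guaranteed by this file):
//   [0, 6)   lower-band cepstral coefficients
//   [6, 22)  higher-band cepstral coefficients
//   [22, 28) first derivative of the lower-band coefficients
//   [28, 34) second derivative of the lower-band coefficients
//   [34, 40) band-wise cross-correlations
//   [40]     normalized pitch period, 0.01f * (pitch_period_48kHz_ - 300)
//   [41]     spectral variability estimate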
} // namespace rnn_vad
} // namespace webrtc

View file

@ -0,0 +1,61 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/biquad_filter.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
namespace webrtc {
namespace rnn_vad {
// Feature extractor to feed the VAD RNN.
class FeaturesExtractor {
public:
explicit FeaturesExtractor(const AvailableCpuFeatures& cpu_features);
FeaturesExtractor(const FeaturesExtractor&) = delete;
FeaturesExtractor& operator=(const FeaturesExtractor&) = delete;
~FeaturesExtractor();
void Reset();
// Analyzes the samples, computes the feature vector and returns true if
// silence is detected (false if not). When silence is detected,
// `feature_vector` is partially written and therefore must not be used to
// feed the VAD RNN.
bool CheckSilenceComputeFeatures(
rtc::ArrayView<const float, kFrameSize10ms24kHz> samples,
rtc::ArrayView<float, kFeatureVectorSize> feature_vector);
private:
const bool use_high_pass_filter_;
// TODO(bugs.webrtc.org/7494): Remove HPF depending on how AGC2 is used in APM
// and on whether an HPF is already used as pre-processing step in APM.
BiQuadFilter hpf_;
SequenceBuffer<float, kBufSize24kHz, kFrameSize10ms24kHz, kFrameSize20ms24kHz>
pitch_buf_24kHz_;
rtc::ArrayView<const float, kBufSize24kHz> pitch_buf_24kHz_view_;
std::vector<float> lp_residual_;
rtc::ArrayView<float, kBufSize24kHz> lp_residual_view_;
PitchEstimator pitch_estimator_;
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame_view_;
SpectralFeaturesExtractor spectral_features_extractor_;
int pitch_period_48kHz_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_FEATURES_EXTRACTION_H_

View file

@ -0,0 +1,141 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/lp_residual.h"
#include <algorithm>
#include <array>
#include <cmath>
#include <numeric>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// Computes auto-correlation coefficients for `x` and writes them in
// `auto_corr`. The lag values are in {0, ..., max_lag - 1}, where max_lag
// equals the size of `auto_corr`.
void ComputeAutoCorrelation(
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
constexpr int max_lag = auto_corr.size();
RTC_DCHECK_LT(max_lag, x.size());
for (int lag = 0; lag < max_lag; ++lag) {
auto_corr[lag] =
std::inner_product(x.begin(), x.end() - lag, x.begin() + lag, 0.f);
}
}
// Applies denoising to the auto-correlation coefficients.
void DenoiseAutoCorrelation(
rtc::ArrayView<float, kNumLpcCoefficients> auto_corr) {
// Assume -40 dB white noise floor.
auto_corr[0] *= 1.0001f;
// Hard-coded values obtained as
// [np.float32((0.008*0.008*i*i)) for i in range(1,5)].
auto_corr[1] -= auto_corr[1] * 0.000064f;
auto_corr[2] -= auto_corr[2] * 0.000256f;
auto_corr[3] -= auto_corr[3] * 0.000576f;
auto_corr[4] -= auto_corr[4] * 0.001024f;
static_assert(kNumLpcCoefficients == 5, "Update `auto_corr`.");
}
// Computes the initial inverse filter coefficients given the auto-correlation
// coefficients of an input frame.
void ComputeInitialInverseFilterCoefficients(
rtc::ArrayView<const float, kNumLpcCoefficients> auto_corr,
rtc::ArrayView<float, kNumLpcCoefficients - 1> lpc_coeffs) {
float error = auto_corr[0];
for (int i = 0; i < kNumLpcCoefficients - 1; ++i) {
float reflection_coeff = 0.f;
for (int j = 0; j < i; ++j) {
reflection_coeff += lpc_coeffs[j] * auto_corr[i - j];
}
reflection_coeff += auto_corr[i + 1];
// Avoid division by numbers close to zero.
constexpr float kMinErrorMagnitude = 1e-6f;
if (std::fabs(error) < kMinErrorMagnitude) {
error = std::copysign(kMinErrorMagnitude, error);
}
reflection_coeff /= -error;
// Update LPC coefficients and total error.
lpc_coeffs[i] = reflection_coeff;
for (int j = 0; j < ((i + 1) >> 1); ++j) {
const float tmp1 = lpc_coeffs[j];
const float tmp2 = lpc_coeffs[i - 1 - j];
lpc_coeffs[j] = tmp1 + reflection_coeff * tmp2;
lpc_coeffs[i - 1 - j] = tmp2 + reflection_coeff * tmp1;
}
error -= reflection_coeff * reflection_coeff * error;
if (error < 0.001f * auto_corr[0]) {
break;
}
}
}
} // namespace
void ComputeAndPostProcessLpcCoefficients(
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs) {
std::array<float, kNumLpcCoefficients> auto_corr;
ComputeAutoCorrelation(x, auto_corr);
if (auto_corr[0] == 0.f) { // Empty frame.
std::fill(lpc_coeffs.begin(), lpc_coeffs.end(), 0);
return;
}
DenoiseAutoCorrelation(auto_corr);
std::array<float, kNumLpcCoefficients - 1> lpc_coeffs_pre{};
ComputeInitialInverseFilterCoefficients(auto_corr, lpc_coeffs_pre);
// LPC coefficients post-processing.
// TODO(bugs.webrtc.org/9076): Consider removing these steps.
lpc_coeffs_pre[0] *= 0.9f;
lpc_coeffs_pre[1] *= 0.9f * 0.9f;
lpc_coeffs_pre[2] *= 0.9f * 0.9f * 0.9f;
lpc_coeffs_pre[3] *= 0.9f * 0.9f * 0.9f * 0.9f;
constexpr float kC = 0.8f;
lpc_coeffs[0] = lpc_coeffs_pre[0] + kC;
lpc_coeffs[1] = lpc_coeffs_pre[1] + kC * lpc_coeffs_pre[0];
lpc_coeffs[2] = lpc_coeffs_pre[2] + kC * lpc_coeffs_pre[1];
lpc_coeffs[3] = lpc_coeffs_pre[3] + kC * lpc_coeffs_pre[2];
lpc_coeffs[4] = kC * lpc_coeffs_pre[3];
static_assert(kNumLpcCoefficients == 5, "Update `lpc_coeffs(_pre)`.");
}
void ComputeLpResidual(
rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
rtc::ArrayView<const float> x,
rtc::ArrayView<float> y) {
RTC_DCHECK_GT(x.size(), kNumLpcCoefficients);
RTC_DCHECK_EQ(x.size(), y.size());
// The code below implements the following operation:
  // y[i] = x[i] + dot_product({x[i - 1], ..., x[i - kNumLpcCoefficients]},
  //                           lpc_coeffs)
// Edge case: i < kNumLpcCoefficients.
y[0] = x[0];
for (int i = 1; i < kNumLpcCoefficients; ++i) {
y[i] =
std::inner_product(x.crend() - i, x.crend(), lpc_coeffs.cbegin(), x[i]);
}
// Regular case.
auto last = x.crend();
for (int i = kNumLpcCoefficients; rtc::SafeLt(i, y.size()); ++i, --last) {
y[i] = std::inner_product(last - kNumLpcCoefficients, last,
lpc_coeffs.cbegin(), x[i]);
}
}
} // namespace rnn_vad
} // namespace webrtc

View file

@ -0,0 +1,41 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_
#include <stddef.h>
#include "api/array_view.h"
namespace webrtc {
namespace rnn_vad {
// Linear predictive coding (LPC) inverse filter length.
constexpr int kNumLpcCoefficients = 5;
// Given a frame `x`, computes a post-processed version of LPC coefficients
// tailored for pitch estimation.
void ComputeAndPostProcessLpcCoefficients(
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kNumLpcCoefficients> lpc_coeffs);
// Computes the LP residual for the input frame `x` and the LPC coefficients
// `lpc_coeffs`. `y` and `x` can point to the same array for in-place
// computation.
void ComputeLpResidual(
rtc::ArrayView<const float, kNumLpcCoefficients> lpc_coeffs,
rtc::ArrayView<const float> x,
rtc::ArrayView<float> y);
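// A minimal usage sketch, assuming a hypothetical 480-sample analysis frame:
//
//   std::array<float, 480> frame{};  // Filled with audio samples.
//   std::array<float, kNumLpcCoefficients> lpc_coeffs;
//   ComputeAndPostProcessLpcCoefficients(frame, lpc_coeffs);
//   ComputeLpResidual(lpc_coeffs, frame, frame);  // In-place computation.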
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_LP_RESIDUAL_H_

View file

@ -0,0 +1,70 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search.h"
#include <array>
#include <cstddef>
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
PitchEstimator::PitchEstimator(const AvailableCpuFeatures& cpu_features)
: cpu_features_(cpu_features),
y_energy_24kHz_(kRefineNumLags24kHz, 0.f),
pitch_buffer_12kHz_(kBufSize12kHz),
auto_correlation_12kHz_(kNumLags12kHz) {}
PitchEstimator::~PitchEstimator() = default;
int PitchEstimator::Estimate(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer) {
rtc::ArrayView<float, kBufSize12kHz> pitch_buffer_12kHz_view(
pitch_buffer_12kHz_.data(), kBufSize12kHz);
RTC_DCHECK_EQ(pitch_buffer_12kHz_.size(), pitch_buffer_12kHz_view.size());
rtc::ArrayView<float, kNumLags12kHz> auto_correlation_12kHz_view(
auto_correlation_12kHz_.data(), kNumLags12kHz);
RTC_DCHECK_EQ(auto_correlation_12kHz_.size(),
auto_correlation_12kHz_view.size());
// TODO(bugs.chromium.org/10480): Use `cpu_features_` to estimate pitch.
// Perform the initial pitch search at 12 kHz.
Decimate2x(pitch_buffer, pitch_buffer_12kHz_view);
auto_corr_calculator_.ComputeOnPitchBuffer(pitch_buffer_12kHz_view,
auto_correlation_12kHz_view);
CandidatePitchPeriods pitch_periods = ComputePitchPeriod12kHz(
pitch_buffer_12kHz_view, auto_correlation_12kHz_view, cpu_features_);
  // The refinement is done using the pitch buffer that contains 24 kHz
  // samples. Therefore, adapt the inverted lags in `pitch_periods` from 12 to
  // 24 kHz.
pitch_periods.best *= 2;
pitch_periods.second_best *= 2;
// Refine the initial pitch period estimation from 12 kHz to 48 kHz.
// Pre-compute frame energies at 24 kHz.
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy_24kHz_view(
y_energy_24kHz_.data(), kRefineNumLags24kHz);
RTC_DCHECK_EQ(y_energy_24kHz_.size(), y_energy_24kHz_view.size());
ComputeSlidingFrameSquareEnergies24kHz(pitch_buffer, y_energy_24kHz_view,
cpu_features_);
// Estimation at 48 kHz.
const int pitch_lag_48kHz = ComputePitchPeriod48kHz(
pitch_buffer, y_energy_24kHz_view, pitch_periods, cpu_features_);
last_pitch_48kHz_ = ComputeExtendedPitchPeriod48kHz(
pitch_buffer, y_energy_24kHz_view,
/*initial_pitch_period_48kHz=*/kMaxPitch48kHz - pitch_lag_48kHz,
last_pitch_48kHz_, cpu_features_);
return last_pitch_48kHz_.period;
}
} // namespace rnn_vad
} // namespace webrtc

View file

@ -0,0 +1,54 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_
#include <memory>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/auto_correlation.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
namespace rnn_vad {
// Pitch estimator.
class PitchEstimator {
public:
explicit PitchEstimator(const AvailableCpuFeatures& cpu_features);
PitchEstimator(const PitchEstimator&) = delete;
PitchEstimator& operator=(const PitchEstimator&) = delete;
~PitchEstimator();
// Returns the estimated pitch period at 48 kHz.
int Estimate(rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer);
private:
FRIEND_TEST_ALL_PREFIXES(RnnVadTest, PitchSearchWithinTolerance);
float GetLastPitchStrengthForTesting() const {
return last_pitch_48kHz_.strength;
}
const AvailableCpuFeatures cpu_features_;
PitchInfo last_pitch_48kHz_{};
AutoCorrelationCalculator auto_corr_calculator_;
std::vector<float> y_energy_24kHz_;
std::vector<float> pitch_buffer_12kHz_;
std::vector<float> auto_correlation_12kHz_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_H_

View file

@ -0,0 +1,513 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/pitch_search_internal.h"
#include <stdlib.h>
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <numeric>
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "rtc_base/system/arch.h"
namespace webrtc {
namespace rnn_vad {
namespace {
float ComputeAutoCorrelation(
int inverted_lag,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
const VectorMath& vector_math) {
RTC_DCHECK_LT(inverted_lag, kBufSize24kHz);
RTC_DCHECK_LT(inverted_lag, kRefineNumLags24kHz);
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
return vector_math.DotProduct(
pitch_buffer.subview(/*offset=*/kMaxPitch24kHz),
pitch_buffer.subview(inverted_lag, kFrameSize20ms24kHz));
}
// Given an auto-correlation coefficient `curr_auto_correlation` and its
// neighboring values `prev_auto_correlation` and `next_auto_correlation`,
// computes a pseudo-interpolation offset to be applied to the pitch period
// associated with `curr_auto_correlation`. The output is a lag in {-1, 0, +1}.
// TODO(bugs.webrtc.org/9076): Consider removing this method.
// `GetPitchPseudoInterpolationOffset()` is relevant only if the spectral
// analysis works at a sample rate that is twice that of the pitch buffer; in
// particular, it is not relevant for the estimated pitch period feature fed
// into the RNN.
int GetPitchPseudoInterpolationOffset(float prev_auto_correlation,
float curr_auto_correlation,
float next_auto_correlation) {
if ((next_auto_correlation - prev_auto_correlation) >
0.7f * (curr_auto_correlation - prev_auto_correlation)) {
return 1; // `next_auto_correlation` is the largest auto-correlation
// coefficient.
} else if ((prev_auto_correlation - next_auto_correlation) >
0.7f * (curr_auto_correlation - next_auto_correlation)) {
return -1; // `prev_auto_correlation` is the largest auto-correlation
// coefficient.
}
return 0;
}
// Refines a pitch period `lag` (encoded as a lag) with pseudo-interpolation.
// The output sample rate is twice that of `lag`.
int PitchPseudoInterpolationLagPitchBuf(
int lag,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
const VectorMath& vector_math) {
int offset = 0;
// Cannot apply pseudo-interpolation at the boundaries.
if (lag > 0 && lag < kMaxPitch24kHz) {
const int inverted_lag = kMaxPitch24kHz - lag;
offset = GetPitchPseudoInterpolationOffset(
ComputeAutoCorrelation(inverted_lag + 1, pitch_buffer, vector_math),
ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math),
ComputeAutoCorrelation(inverted_lag - 1, pitch_buffer, vector_math));
}
return 2 * lag + offset;
}
// Integer multipliers used in ComputeExtendedPitchPeriod48kHz() when
// looking for sub-harmonics.
// The values have been chosen to serve the following algorithm. Given the
// initial pitch period T, we examine whether one of its harmonics is the true
// fundamental frequency. We consider T/k with k in {2, ..., 15}. For each of
// these harmonics, in addition to the pitch strength of itself, we choose one
// multiple of its pitch period, n*T/k, to validate it (by averaging their pitch
// strengths). The multiplier n is chosen so that n*T/k is used only one time
// over all k. When for example k = 4, we should also expect a peak at 3*T/4.
// When k = 8 instead we don't want to look at 2*T/8, since we have already
// checked T/4 before. Instead, we look at T*3/8.
// The array can be generated in Python as follows:
// from fractions import Fraction
// # Smallest positive integer not in X.
// def mex(X):
// for i in range(1, int(max(X)+2)):
// if i not in X:
// return i
// # Visited multiples of the period.
// S = {1}
// for n in range(2, 16):
// sn = mex({n * i for i in S} | {1})
// S = S | {Fraction(1, n), Fraction(sn, n)}
// print(sn, end=', ')
constexpr std::array<int, 14> kSubHarmonicMultipliers = {
{3, 2, 3, 2, 5, 2, 3, 2, 3, 2, 5, 2, 3, 2}};
struct Range {
int min;
int max;
};
// Number of analyzed pitches to the left (right) of a pitch candidate.
constexpr int kPitchNeighborhoodRadius = 2;
// Creates a pitch period interval centered in `inverted_lag` with hard-coded
// radius. Clipping is applied so that the interval is always valid for a 24 kHz
// pitch buffer.
Range CreateInvertedLagRange(int inverted_lag) {
return {std::max(inverted_lag - kPitchNeighborhoodRadius, 0),
std::min(inverted_lag + kPitchNeighborhoodRadius,
kInitialNumLags24kHz - 1)};
}
constexpr int kNumPitchCandidates = 2; // Best and second best.
// Maximum number of analyzed pitch periods.
constexpr int kMaxPitchPeriods24kHz =
kNumPitchCandidates * (2 * kPitchNeighborhoodRadius + 1);
// Collection of inverted lags.
class InvertedLagsIndex {
public:
InvertedLagsIndex() : num_entries_(0) {}
// Adds an inverted lag to the index. Cannot add more than
// `kMaxPitchPeriods24kHz` values.
void Append(int inverted_lag) {
RTC_DCHECK_LT(num_entries_, kMaxPitchPeriods24kHz);
inverted_lags_[num_entries_++] = inverted_lag;
}
const int* data() const { return inverted_lags_.data(); }
int size() const { return num_entries_; }
private:
std::array<int, kMaxPitchPeriods24kHz> inverted_lags_;
int num_entries_;
};
// Computes the auto correlation coefficients for the inverted lags in the
// closed interval `inverted_lags`. Updates `inverted_lags_index` by appending
// the inverted lags for the computed auto correlation values.
void ComputeAutoCorrelation(
Range inverted_lags,
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<float, kInitialNumLags24kHz> auto_correlation,
InvertedLagsIndex& inverted_lags_index,
const VectorMath& vector_math) {
// Check valid range.
RTC_DCHECK_LE(inverted_lags.min, inverted_lags.max);
  // Zero the cells neighboring the range so that the pseudo-interpolation can
  // read them without requiring a full zero initialization of
  // `auto_correlation`.
if (inverted_lags.min > 0) {
auto_correlation[inverted_lags.min - 1] = 0.f;
}
if (inverted_lags.max < kInitialNumLags24kHz - 1) {
auto_correlation[inverted_lags.max + 1] = 0.f;
}
// Check valid `inverted_lag` indexes.
RTC_DCHECK_GE(inverted_lags.min, 0);
RTC_DCHECK_LT(inverted_lags.max, kInitialNumLags24kHz);
for (int inverted_lag = inverted_lags.min; inverted_lag <= inverted_lags.max;
++inverted_lag) {
auto_correlation[inverted_lag] =
ComputeAutoCorrelation(inverted_lag, pitch_buffer, vector_math);
inverted_lags_index.Append(inverted_lag);
}
}
// Searches for the strongest pitch period at 24 kHz and returns the
// corresponding inverted lag at 48 kHz.
int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const int> inverted_lags,
rtc::ArrayView<const float, kInitialNumLags24kHz> auto_correlation,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
const VectorMath& vector_math) {
static_assert(kMaxPitch24kHz > kInitialNumLags24kHz, "");
static_assert(kMaxPitch24kHz < kBufSize24kHz, "");
int best_inverted_lag = 0; // Pitch period.
float best_numerator = -1.f; // Pitch strength numerator.
float best_denominator = 0.f; // Pitch strength denominator.
for (int inverted_lag : inverted_lags) {
// A pitch candidate must have positive correlation.
if (auto_correlation[inverted_lag] > 0.f) {
// Auto-correlation energy normalized by frame energy.
const float numerator =
auto_correlation[inverted_lag] * auto_correlation[inverted_lag];
const float denominator = y_energy[inverted_lag];
// Compare numerator/denominator ratios without using divisions.
if (numerator * best_denominator > best_numerator * denominator) {
best_inverted_lag = inverted_lag;
best_numerator = numerator;
best_denominator = denominator;
}
}
}
// Pseudo-interpolation to transform `best_inverted_lag` (24 kHz pitch) to a
// 48 kHz pitch period.
if (best_inverted_lag == 0 || best_inverted_lag >= kInitialNumLags24kHz - 1) {
// Cannot apply pseudo-interpolation at the boundaries.
return best_inverted_lag * 2;
}
int offset = GetPitchPseudoInterpolationOffset(
auto_correlation[best_inverted_lag + 1],
auto_correlation[best_inverted_lag],
auto_correlation[best_inverted_lag - 1]);
// TODO(bugs.webrtc.org/9076): When retraining, check if `offset` below should
// be subtracted since `inverted_lag` is an inverted lag but offset is a lag.
return 2 * best_inverted_lag + offset;
}
// Returns an alternative pitch period for `pitch_period` given a `multiplier`
// and a `divisor` of the period.
constexpr int GetAlternativePitchPeriod(int pitch_period,
int multiplier,
int divisor) {
RTC_DCHECK_GT(divisor, 0);
// Same as `round(multiplier * pitch_period / divisor)`.
return (2 * multiplier * pitch_period + divisor) / (2 * divisor);
}
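// For instance, GetAlternativePitchPeriod(150, 1, 4) evaluates
// (2 * 1 * 150 + 4) / (2 * 4) = 304 / 8 = 38 - i.e., 150 / 4 = 37.5 rounded to
// the nearest integer.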
// Returns true if the alternative pitch period is stronger than the initial one
// given the last estimated pitch and the value of `period_divisor` used to
// compute the alternative pitch period via `GetAlternativePitchPeriod()`.
bool IsAlternativePitchStrongerThanInitial(PitchInfo last,
PitchInfo initial,
PitchInfo alternative,
int period_divisor) {
// Initial pitch period candidate thresholds for a sample rate of 24 kHz.
  // Computed as [5*k*k for k in range(2, 16)].
constexpr std::array<int, 14> kInitialPitchPeriodThresholds = {
{20, 45, 80, 125, 180, 245, 320, 405, 500, 605, 720, 845, 980, 1125}};
static_assert(
kInitialPitchPeriodThresholds.size() == kSubHarmonicMultipliers.size(),
"");
RTC_DCHECK_GE(last.period, 0);
RTC_DCHECK_GE(initial.period, 0);
RTC_DCHECK_GE(alternative.period, 0);
RTC_DCHECK_GE(period_divisor, 2);
// Compute a term that lowers the threshold when `alternative.period` is close
// to the last estimated period `last.period` - i.e., pitch tracking.
float lower_threshold_term = 0.f;
if (std::abs(alternative.period - last.period) <= 1) {
    // The candidate pitch period is within 1 sample of the last one.
    // Make the candidate at `alternative.period` very easy to accept.
lower_threshold_term = last.strength;
} else if (std::abs(alternative.period - last.period) == 2 &&
initial.period >
kInitialPitchPeriodThresholds[period_divisor - 2]) {
    // The candidate pitch period is 2 samples away from the last one and the
    // period `initial.period` (from which `alternative.period` has been
    // derived) is greater than a threshold. Make `alternative.period` easy to
    // accept.
lower_threshold_term = 0.5f * last.strength;
}
// Set the threshold based on the strength of the initial estimate
// `initial.period`. Also reduce the chance of false positives caused by a
// bias towards high frequencies (originating from short-term correlations).
float threshold =
std::max(0.3f, 0.7f * initial.strength - lower_threshold_term);
  if (alternative.period < 2 * kMinPitch24kHz) {
    // Even higher frequency. This check must precede the `3 * kMinPitch24kHz`
    // one below, otherwise this branch would be unreachable.
    threshold = std::max(0.5f, 0.9f * initial.strength - lower_threshold_term);
  } else if (alternative.period < 3 * kMinPitch24kHz) {
    // High frequency.
    threshold = std::max(0.4f, 0.85f * initial.strength - lower_threshold_term);
  }
return alternative.strength > threshold;
}
} // namespace
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
rtc::ArrayView<float, kBufSize12kHz> dst) {
// TODO(bugs.webrtc.org/9076): Consider adding anti-aliasing filter.
static_assert(2 * kBufSize12kHz == kBufSize24kHz, "");
for (int i = 0; i < kBufSize12kHz; ++i) {
dst[i] = src[2 * i];
}
}
void ComputeSlidingFrameSquareEnergies24kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy,
AvailableCpuFeatures cpu_features) {
VectorMath vector_math(cpu_features);
static_assert(kFrameSize20ms24kHz < kBufSize24kHz, "");
const auto frame_20ms_view = pitch_buffer.subview(0, kFrameSize20ms24kHz);
float yy = vector_math.DotProduct(frame_20ms_view, frame_20ms_view);
y_energy[0] = yy;
static_assert(kMaxPitch24kHz - 1 + kFrameSize20ms24kHz < kBufSize24kHz, "");
static_assert(kMaxPitch24kHz < kRefineNumLags24kHz, "");
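  // Slide the frame one sample at a time: subtract the energy of the sample
  // that leaves the frame and add the energy of the sample that enters it.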
for (int inverted_lag = 0; inverted_lag < kMaxPitch24kHz; ++inverted_lag) {
yy -= pitch_buffer[inverted_lag] * pitch_buffer[inverted_lag];
yy += pitch_buffer[inverted_lag + kFrameSize20ms24kHz] *
pitch_buffer[inverted_lag + kFrameSize20ms24kHz];
yy = std::max(1.f, yy);
y_energy[inverted_lag + 1] = yy;
}
}
CandidatePitchPeriods ComputePitchPeriod12kHz(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation,
AvailableCpuFeatures cpu_features) {
static_assert(kMaxPitch12kHz > kNumLags12kHz, "");
static_assert(kMaxPitch12kHz < kBufSize12kHz, "");
// Stores a pitch candidate period and strength information.
struct PitchCandidate {
// Pitch period encoded as inverted lag.
int period_inverted_lag = 0;
// Pitch strength encoded as a ratio.
float strength_numerator = -1.f;
float strength_denominator = 0.f;
// Compare the strength of two pitch candidates.
bool HasStrongerPitchThan(const PitchCandidate& b) const {
// Comparing the numerator/denominator ratios without using divisions.
return strength_numerator * b.strength_denominator >
b.strength_numerator * strength_denominator;
}
};
VectorMath vector_math(cpu_features);
static_assert(kFrameSize20ms12kHz + 1 < kBufSize12kHz, "");
const auto frame_view = pitch_buffer.subview(0, kFrameSize20ms12kHz + 1);
float denominator = 1.f + vector_math.DotProduct(frame_view, frame_view);
// Search best and second best pitches by looking at the scaled
// auto-correlation.
PitchCandidate best;
PitchCandidate second_best;
second_best.period_inverted_lag = 1;
for (int inverted_lag = 0; inverted_lag < kNumLags12kHz; ++inverted_lag) {
// A pitch candidate must have positive correlation.
if (auto_correlation[inverted_lag] > 0.f) {
PitchCandidate candidate{
inverted_lag,
auto_correlation[inverted_lag] * auto_correlation[inverted_lag],
denominator};
if (candidate.HasStrongerPitchThan(second_best)) {
if (candidate.HasStrongerPitchThan(best)) {
second_best = best;
best = candidate;
} else {
second_best = candidate;
}
}
}
    // Update `denominator` for the next inverted lag.
const float y_old = pitch_buffer[inverted_lag];
const float y_new = pitch_buffer[inverted_lag + kFrameSize20ms12kHz];
denominator -= y_old * y_old;
denominator += y_new * y_new;
denominator = std::max(0.f, denominator);
}
return {best.period_inverted_lag, second_best.period_inverted_lag};
}
int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
CandidatePitchPeriods pitch_candidates,
AvailableCpuFeatures cpu_features) {
// Compute the auto-correlation terms only for neighbors of the two pitch
// candidates (best and second best).
std::array<float, kInitialNumLags24kHz> auto_correlation;
InvertedLagsIndex inverted_lags_index;
// Create two inverted lag ranges so that `r1` precedes `r2`.
const bool swap_candidates =
pitch_candidates.best > pitch_candidates.second_best;
const Range r1 = CreateInvertedLagRange(
swap_candidates ? pitch_candidates.second_best : pitch_candidates.best);
const Range r2 = CreateInvertedLagRange(
swap_candidates ? pitch_candidates.best : pitch_candidates.second_best);
// Check valid ranges.
RTC_DCHECK_LE(r1.min, r1.max);
RTC_DCHECK_LE(r2.min, r2.max);
// Check `r1` precedes `r2`.
RTC_DCHECK_LE(r1.min, r2.min);
RTC_DCHECK_LE(r1.max, r2.max);
VectorMath vector_math(cpu_features);
if (r1.max + 1 >= r2.min) {
// Overlapping or adjacent ranges.
ComputeAutoCorrelation({r1.min, r2.max}, pitch_buffer, auto_correlation,
inverted_lags_index, vector_math);
} else {
// Disjoint ranges.
ComputeAutoCorrelation(r1, pitch_buffer, auto_correlation,
inverted_lags_index, vector_math);
ComputeAutoCorrelation(r2, pitch_buffer, auto_correlation,
inverted_lags_index, vector_math);
}
return ComputePitchPeriod48kHz(pitch_buffer, inverted_lags_index,
auto_correlation, y_energy, vector_math);
}
PitchInfo ComputeExtendedPitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
int initial_pitch_period_48kHz,
PitchInfo last_pitch_48kHz,
AvailableCpuFeatures cpu_features) {
RTC_DCHECK_LE(kMinPitch48kHz, initial_pitch_period_48kHz);
RTC_DCHECK_LE(initial_pitch_period_48kHz, kMaxPitch48kHz);
// Stores information for a refined pitch candidate.
struct RefinedPitchCandidate {
int period;
float strength;
// Additional strength data used for the final pitch estimation.
float xy; // Auto-correlation.
float y_energy; // Energy of the sliding frame `y`.
};
const float x_energy = y_energy[kMaxPitch24kHz];
const auto pitch_strength = [x_energy](float xy, float y_energy) {
RTC_DCHECK_GE(x_energy * y_energy, 0.f);
return xy / std::sqrt(1.f + x_energy * y_energy);
};
VectorMath vector_math(cpu_features);
// Initialize the best pitch candidate with `initial_pitch_period_48kHz`.
RefinedPitchCandidate best_pitch;
best_pitch.period =
std::min(initial_pitch_period_48kHz / 2, kMaxPitch24kHz - 1);
best_pitch.xy = ComputeAutoCorrelation(kMaxPitch24kHz - best_pitch.period,
pitch_buffer, vector_math);
best_pitch.y_energy = y_energy[kMaxPitch24kHz - best_pitch.period];
best_pitch.strength = pitch_strength(best_pitch.xy, best_pitch.y_energy);
// Keep a copy of the initial pitch candidate.
const PitchInfo initial_pitch{best_pitch.period, best_pitch.strength};
// 24 kHz version of the last estimated pitch.
const PitchInfo last_pitch{last_pitch_48kHz.period / 2,
last_pitch_48kHz.strength};
// Find `max_period_divisor` such that the result of
// `GetAlternativePitchPeriod(initial_pitch_period, 1, max_period_divisor)`
// equals `kMinPitch24kHz`.
const int max_period_divisor =
(2 * initial_pitch.period) / (2 * kMinPitch24kHz - 1);
for (int period_divisor = 2; period_divisor <= max_period_divisor;
++period_divisor) {
PitchInfo alternative_pitch;
alternative_pitch.period = GetAlternativePitchPeriod(
initial_pitch.period, /*multiplier=*/1, period_divisor);
RTC_DCHECK_GE(alternative_pitch.period, kMinPitch24kHz);
// When looking at `alternative_pitch.period`, we also look at one of its
// sub-harmonics. `kSubHarmonicMultipliers` is used to know where to look.
// `period_divisor` == 2 is a special case since `dual_alternative_period`
// might be greater than the maximum pitch period.
int dual_alternative_period = GetAlternativePitchPeriod(
initial_pitch.period, kSubHarmonicMultipliers[period_divisor - 2],
period_divisor);
RTC_DCHECK_GT(dual_alternative_period, 0);
if (period_divisor == 2 && dual_alternative_period > kMaxPitch24kHz) {
dual_alternative_period = initial_pitch.period;
}
RTC_DCHECK_NE(alternative_pitch.period, dual_alternative_period)
<< "The lower pitch period and the additional sub-harmonic must not "
"coincide.";
// Compute an auto-correlation score for the primary pitch candidate
// `alternative_pitch.period` by also looking at its possible sub-harmonic
// `dual_alternative_period`.
const float xy_primary_period = ComputeAutoCorrelation(
kMaxPitch24kHz - alternative_pitch.period, pitch_buffer, vector_math);
// TODO(webrtc:10480): Copy `xy_primary_period` if the secondary period is
// equal to the primary one.
const float xy_secondary_period = ComputeAutoCorrelation(
kMaxPitch24kHz - dual_alternative_period, pitch_buffer, vector_math);
const float xy = 0.5f * (xy_primary_period + xy_secondary_period);
const float yy =
0.5f * (y_energy[kMaxPitch24kHz - alternative_pitch.period] +
y_energy[kMaxPitch24kHz - dual_alternative_period]);
alternative_pitch.strength = pitch_strength(xy, yy);
// Maybe update best period.
if (IsAlternativePitchStrongerThanInitial(
last_pitch, initial_pitch, alternative_pitch, period_divisor)) {
best_pitch = {alternative_pitch.period, alternative_pitch.strength, xy,
yy};
}
}
// Final pitch strength and period.
best_pitch.xy = std::max(0.f, best_pitch.xy);
RTC_DCHECK_LE(0.f, best_pitch.y_energy);
float final_pitch_strength =
(best_pitch.y_energy <= best_pitch.xy)
? 1.f
: best_pitch.xy / (best_pitch.y_energy + 1.f);
final_pitch_strength = std::min(best_pitch.strength, final_pitch_strength);
int final_pitch_period_48kHz = std::max(
kMinPitch48kHz, PitchPseudoInterpolationLagPitchBuf(
best_pitch.period, pitch_buffer, vector_math));
return {final_pitch_period_48kHz, final_pitch_strength};
}
} // namespace rnn_vad
} // namespace webrtc

View file

@ -0,0 +1,114 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_
#include <stddef.h>
#include <array>
#include <utility>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
namespace webrtc {
namespace rnn_vad {
// Performs 2x decimation without any anti-aliasing filter.
void Decimate2x(rtc::ArrayView<const float, kBufSize24kHz> src,
rtc::ArrayView<float, kBufSize12kHz> dst);
// Key concepts and keywords used below in this file.
//
// The pitch estimation relies on a pitch buffer, which is an array-like data
// structure designed as follows:
//
// |....A....|.....B.....|
//
// The part on the left, named `A`, contains the oldest samples, whereas `B`
// contains the most recent ones. The size of `A` corresponds to the maximum
// pitch period, that of `B` to the analysis frame size (e.g., 16 ms and 20 ms
// respectively).
//
// Pitch estimation is essentially based on the analysis of two 20 ms frames
// extracted from the pitch buffer. One frame, called `x`, is kept fixed and
// corresponds to `B` - i.e., the most recent 20 ms. The other frame, called
// `y`, is extracted from different parts of the buffer instead.
//
// The offset between `x` and `y` corresponds to a specific pitch period.
// For instance, if `y` is positioned at the beginning of the pitch buffer,
// then the cross-correlation between `x` and `y` can be used as an indication
// of the strength of the maximum pitch period.
//
// Such an offset can be encoded in two ways:
// - As a lag, which is the number of samples between the first item of `y` and
//   the first item of `x` - i.e., the pitch period expressed in samples
// - As an inverted lag, which is the index in the pitch buffer of the first
//   item in `y`; the lag and the inverted lag always add up to the maximum
//   pitch period (the size of `A`)
//
//       |<->| lag
// |....A....|.....B.....|
// |<--->| inverted lag
//       |.....y.....| `y` 20 ms frame
//
// The inverted lag has the advantage of being directly usable as an index for
// the pitch buffer and for the sliding frame energies (see `y_energy` below).
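// For example, with a hypothetical maximum pitch period of 600 samples, a
// pitch period of 450 samples corresponds to a lag of 450 and to an inverted
// lag of 600 - 450 = 150.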
// Computes the sum of squared samples for every sliding frame `y` in the pitch
// buffer. The indexes of `y_energy` are inverted lags.
void ComputeSlidingFrameSquareEnergies24kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<float, kRefineNumLags24kHz> y_energy,
AvailableCpuFeatures cpu_features);
// Top-2 pitch period candidates. Unit: number of samples - i.e., inverted lags.
struct CandidatePitchPeriods {
int best;
int second_best;
};
// Computes the candidate pitch periods at 12 kHz given a view on the 12 kHz
// pitch buffer and the auto-correlation values (having inverted lags as
// indexes).
CandidatePitchPeriods ComputePitchPeriod12kHz(
rtc::ArrayView<const float, kBufSize12kHz> pitch_buffer,
rtc::ArrayView<const float, kNumLags12kHz> auto_correlation,
AvailableCpuFeatures cpu_features);
// Computes the pitch period at 48 kHz given a view on the 24 kHz pitch buffer,
// the energies for the sliding frames `y` at 24 kHz and the pitch period
// candidates at 24 kHz (encoded as inverted lags). The returned value is also
// encoded as an inverted lag, at 48 kHz.
int ComputePitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
CandidatePitchPeriods pitch_candidates_24kHz,
AvailableCpuFeatures cpu_features);
struct PitchInfo {
int period;
float strength;
};
// Computes the pitch period at 48 kHz searching in an extended pitch range
// given a view on the 24 kHz pitch buffer, the energies for the sliding frames
// `y` at 24 kHz, the initial 48 kHz estimation (computed by
// `ComputePitchPeriod48kHz()`) and the last estimated pitch.
PitchInfo ComputeExtendedPitchPeriod48kHz(
rtc::ArrayView<const float, kBufSize24kHz> pitch_buffer,
rtc::ArrayView<const float, kRefineNumLags24kHz> y_energy,
int initial_pitch_period_48kHz,
PitchInfo last_pitch_48kHz,
AvailableCpuFeatures cpu_features);
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_PITCH_SEARCH_INTERNAL_H_

View file

@ -0,0 +1,65 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_
#include <array>
#include <cstring>
#include <type_traits>
#include "api/array_view.h"
namespace webrtc {
namespace rnn_vad {
// Ring buffer for N arrays of type T, each one of size S.
template <typename T, int S, int N>
class RingBuffer {
static_assert(S > 0, "");
static_assert(N > 0, "");
static_assert(std::is_arithmetic<T>::value,
"Integral or floating point required.");
public:
RingBuffer() : tail_(0) {}
RingBuffer(const RingBuffer&) = delete;
RingBuffer& operator=(const RingBuffer&) = delete;
~RingBuffer() = default;
// Set the ring buffer values to zero.
void Reset() { buffer_.fill(0); }
// Replace the least recently pushed array in the buffer with `new_values`.
void Push(rtc::ArrayView<const T, S> new_values) {
std::memcpy(buffer_.data() + S * tail_, new_values.data(), S * sizeof(T));
tail_ += 1;
if (tail_ == N)
tail_ = 0;
}
  // Return an array view onto the array with a given delay. A view on the most
  // recently and the least recently pushed arrays is returned when `delay` is
  // 0 and N - 1 respectively.
rtc::ArrayView<const T, S> GetArrayView(int delay) const {
RTC_DCHECK_LE(0, delay);
RTC_DCHECK_LT(delay, N);
int offset = tail_ - 1 - delay;
if (offset < 0)
offset += N;
return {buffer_.data() + S * offset, S};
}
private:
int tail_; // Index of the least recently pushed sub-array.
std::array<T, S * N> buffer_{};
};
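// A minimal usage sketch: a ring buffer holding the last three 4-sample
// frames.
//
//   RingBuffer<float, 4, 3> frames;
//   std::array<float, 4> frame = {1.f, 2.f, 3.f, 4.f};
//   frames.Push(frame);
//   rtc::ArrayView<const float, 4> newest = frames.GetArrayView(/*delay=*/0);
//   rtc::ArrayView<const float, 4> oldest = frames.GetArrayView(/*delay=*/2);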
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RING_BUFFER_H_

View file

@ -0,0 +1,91 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "rtc_base/checks.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
namespace webrtc {
namespace rnn_vad {
namespace {
using ::rnnoise::kInputLayerInputSize;
static_assert(kFeatureVectorSize == kInputLayerInputSize, "");
using ::rnnoise::kInputDenseBias;
using ::rnnoise::kInputDenseWeights;
using ::rnnoise::kInputLayerOutputSize;
static_assert(kInputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
using ::rnnoise::kHiddenGruBias;
using ::rnnoise::kHiddenGruRecurrentWeights;
using ::rnnoise::kHiddenGruWeights;
using ::rnnoise::kHiddenLayerOutputSize;
static_assert(kHiddenLayerOutputSize <= kGruLayerMaxUnits, "");
using ::rnnoise::kOutputDenseBias;
using ::rnnoise::kOutputDenseWeights;
using ::rnnoise::kOutputLayerOutputSize;
static_assert(kOutputLayerOutputSize <= kFullyConnectedLayerMaxUnits, "");
} // namespace
RnnVad::RnnVad(const AvailableCpuFeatures& cpu_features)
: input_(kInputLayerInputSize,
kInputLayerOutputSize,
kInputDenseBias,
kInputDenseWeights,
ActivationFunction::kTansigApproximated,
cpu_features,
/*layer_name=*/"FC1"),
hidden_(kInputLayerOutputSize,
kHiddenLayerOutputSize,
kHiddenGruBias,
kHiddenGruWeights,
kHiddenGruRecurrentWeights,
cpu_features,
/*layer_name=*/"GRU1"),
output_(kHiddenLayerOutputSize,
kOutputLayerOutputSize,
kOutputDenseBias,
kOutputDenseWeights,
ActivationFunction::kSigmoidApproximated,
// The output layer is just 24x1. The unoptimized code is faster.
NoAvailableCpuFeatures(),
/*layer_name=*/"FC2") {
// Input-output chaining size checks.
RTC_DCHECK_EQ(input_.size(), hidden_.input_size())
<< "The input and the hidden layers sizes do not match.";
RTC_DCHECK_EQ(hidden_.size(), output_.input_size())
<< "The hidden and the output layers sizes do not match.";
}
RnnVad::~RnnVad() = default;
void RnnVad::Reset() {
hidden_.Reset();
}
float RnnVad::ComputeVadProbability(
rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
bool is_silence) {
if (is_silence) {
Reset();
return 0.f;
}
input_.ComputeOutput(feature_vector);
hidden_.ComputeOutput(input_);
output_.ComputeOutput(hidden_);
RTC_DCHECK_EQ(output_.size(), 1);
return output_.data()[0];
}
} // namespace rnn_vad
} // namespace webrtc

View file

@ -0,0 +1,53 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_
#include <stddef.h>
#include <sys/types.h>
#include <array>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
namespace webrtc {
namespace rnn_vad {
// Recurrent network with hard-coded architecture and weights for voice activity
// detection.
class RnnVad {
public:
explicit RnnVad(const AvailableCpuFeatures& cpu_features);
RnnVad(const RnnVad&) = delete;
RnnVad& operator=(const RnnVad&) = delete;
~RnnVad();
void Reset();
// Observes `feature_vector` and `is_silence`, updates the RNN and returns the
// current voice probability.
float ComputeVadProbability(
rtc::ArrayView<const float, kFeatureVectorSize> feature_vector,
bool is_silence);
private:
FullyConnectedLayer input_;
GatedRecurrentLayer hidden_;
FullyConnectedLayer output_;
};
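// A minimal usage sketch:
//
//   RnnVad vad(GetAvailableCpuFeatures());
//   std::array<float, kFeatureVectorSize> features{};  // Extracted features.
//   const float p = vad.ComputeVadProbability(features, /*is_silence=*/false);
//   // `p` is the voice probability in [0, 1].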
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_H_

View file

@ -0,0 +1,104 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/rnn_fc.h"
#include <algorithm>
#include <numeric>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
namespace webrtc {
namespace rnn_vad {
namespace {
std::vector<float> GetScaledParams(rtc::ArrayView<const int8_t> params) {
std::vector<float> scaled_params(params.size());
std::transform(params.begin(), params.end(), scaled_params.begin(),
[](int8_t x) -> float {
return ::rnnoise::kWeightsScale * static_cast<float>(x);
});
return scaled_params;
}
// TODO(bugs.chromium.org/10480): Hard-code optimized layout and remove this
// function to improve setup time.
// Casts and scales `weights` and re-arranges the layout.
std::vector<float> PreprocessWeights(rtc::ArrayView<const int8_t> weights,
int output_size) {
if (output_size == 1) {
return GetScaledParams(weights);
}
// Transpose, scale and cast.
const int input_size = rtc::CheckedDivExact(
rtc::dchecked_cast<int>(weights.size()), output_size);
std::vector<float> w(weights.size());
for (int o = 0; o < output_size; ++o) {
for (int i = 0; i < input_size; ++i) {
w[o * input_size + i] = rnnoise::kWeightsScale *
static_cast<float>(weights[i * output_size + o]);
}
}
return w;
}
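// For example, with `input_size` = 2 and `output_size` = 3, the weight stored
// at `weights[i * 3 + o]` (input-major) ends up at `w[o * 2 + i]`
// (output-major), so that the weights of each output unit are contiguous for
// the dot products in `ComputeOutput()`.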
rtc::FunctionView<float(float)> GetActivationFunction(
ActivationFunction activation_function) {
switch (activation_function) {
case ActivationFunction::kTansigApproximated:
return ::rnnoise::TansigApproximated;
case ActivationFunction::kSigmoidApproximated:
return ::rnnoise::SigmoidApproximated;
}
}
} // namespace
FullyConnectedLayer::FullyConnectedLayer(
const int input_size,
const int output_size,
const rtc::ArrayView<const int8_t> bias,
const rtc::ArrayView<const int8_t> weights,
ActivationFunction activation_function,
const AvailableCpuFeatures& cpu_features,
absl::string_view layer_name)
: input_size_(input_size),
output_size_(output_size),
bias_(GetScaledParams(bias)),
weights_(PreprocessWeights(weights, output_size)),
vector_math_(cpu_features),
activation_function_(GetActivationFunction(activation_function)) {
RTC_DCHECK_LE(output_size_, kFullyConnectedLayerMaxUnits)
<< "Insufficient FC layer over-allocation (" << layer_name << ").";
RTC_DCHECK_EQ(output_size_, bias_.size())
<< "Mismatching output size and bias terms array size (" << layer_name
<< ").";
RTC_DCHECK_EQ(input_size_ * output_size_, weights_.size())
<< "Mismatching input-output size and weight coefficients array size ("
<< layer_name << ").";
}
FullyConnectedLayer::~FullyConnectedLayer() = default;
void FullyConnectedLayer::ComputeOutput(rtc::ArrayView<const float> input) {
RTC_DCHECK_EQ(input.size(), input_size_);
rtc::ArrayView<const float> weights(weights_);
for (int o = 0; o < output_size_; ++o) {
output_[o] = activation_function_(
bias_[o] + vector_math_.DotProduct(
input, weights.subview(o * input_size_, input_size_)));
}
}
} // namespace rnn_vad
} // namespace webrtc

View file

@ -0,0 +1,72 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_
#include <array>
#include <vector>
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "api/function_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
namespace webrtc {
namespace rnn_vad {
// Activation function for a neural network cell.
enum class ActivationFunction { kTansigApproximated, kSigmoidApproximated };
// Maximum number of units for an FC layer.
constexpr int kFullyConnectedLayerMaxUnits = 24;
// Fully-connected layer with a custom activation function which owns the output
// buffer.
class FullyConnectedLayer {
public:
// Ctor. `output_size` cannot be greater than `kFullyConnectedLayerMaxUnits`.
FullyConnectedLayer(int input_size,
int output_size,
rtc::ArrayView<const int8_t> bias,
rtc::ArrayView<const int8_t> weights,
ActivationFunction activation_function,
const AvailableCpuFeatures& cpu_features,
absl::string_view layer_name);
FullyConnectedLayer(const FullyConnectedLayer&) = delete;
FullyConnectedLayer& operator=(const FullyConnectedLayer&) = delete;
~FullyConnectedLayer();
// Returns the size of the input vector.
int input_size() const { return input_size_; }
// Returns the pointer to the first element of the output buffer.
const float* data() const { return output_.data(); }
// Returns the size of the output buffer.
int size() const { return output_size_; }
// Computes the fully-connected layer output.
void ComputeOutput(rtc::ArrayView<const float> input);
private:
const int input_size_;
const int output_size_;
const std::vector<float> bias_;
const std::vector<float> weights_;
const VectorMath vector_math_;
rtc::FunctionView<float(float)> activation_function_;
// Over-allocated array with size equal to `output_size_`.
std::array<float, kFullyConnectedLayerMaxUnits> output_;
};
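// A minimal usage sketch with hypothetical zero-valued quantized parameters
// for a 2-input/4-output layer:
//
//   constexpr std::array<int8_t, 4> kBias{};
//   constexpr std::array<int8_t, 8> kWeights{};
//   FullyConnectedLayer fc(/*input_size=*/2, /*output_size=*/4, kBias,
//                          kWeights, ActivationFunction::kTansigApproximated,
//                          NoAvailableCpuFeatures(), /*layer_name=*/"demo");
//   const float input[2] = {0.5f, -0.5f};
//   fc.ComputeOutput(input);
//   rtc::ArrayView<const float> output(fc.data(), fc.size());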
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_FC_H_

View file

@ -0,0 +1,198 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/rnn_gru.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "third_party/rnnoise/src/rnn_activations.h"
#include "third_party/rnnoise/src/rnn_vad_weights.h"
namespace webrtc {
namespace rnn_vad {
namespace {
constexpr int kNumGruGates = 3; // Update, reset, output.
std::vector<float> PreprocessGruTensor(rtc::ArrayView<const int8_t> tensor_src,
int output_size) {
// Transpose, cast and scale.
// `n` is the size of the first dimension of the 3-dim tensor `weights`.
const int n = rtc::CheckedDivExact(rtc::dchecked_cast<int>(tensor_src.size()),
output_size * kNumGruGates);
const int stride_src = kNumGruGates * output_size;
const int stride_dst = n * output_size;
std::vector<float> tensor_dst(tensor_src.size());
for (int g = 0; g < kNumGruGates; ++g) {
for (int o = 0; o < output_size; ++o) {
for (int i = 0; i < n; ++i) {
tensor_dst[g * stride_dst + o * n + i] =
::rnnoise::kWeightsScale *
static_cast<float>(
tensor_src[i * stride_src + g * output_size + o]);
}
}
}
return tensor_dst;
}
// Computes the output for the update or the reset gate.
// Operation: `g = sigmoid(W^T∙i + R^T∙s + b)` where
// - `g`: output gate vector
// - `W`: weights matrix
// - `i`: input vector
// - `R`: recurrent weights matrix
// - `s`: state gate vector
// - `b`: bias vector
void ComputeUpdateResetGate(int input_size,
int output_size,
const VectorMath& vector_math,
rtc::ArrayView<const float> input,
rtc::ArrayView<const float> state,
rtc::ArrayView<const float> bias,
rtc::ArrayView<const float> weights,
rtc::ArrayView<const float> recurrent_weights,
rtc::ArrayView<float> gate) {
RTC_DCHECK_EQ(input.size(), input_size);
RTC_DCHECK_EQ(state.size(), output_size);
RTC_DCHECK_EQ(bias.size(), output_size);
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
RTC_DCHECK_GE(gate.size(), output_size); // `gate` is over-allocated.
for (int o = 0; o < output_size; ++o) {
float x = bias[o];
x += vector_math.DotProduct(input,
weights.subview(o * input_size, input_size));
x += vector_math.DotProduct(
state, recurrent_weights.subview(o * output_size, output_size));
gate[o] = ::rnnoise::SigmoidApproximated(x);
}
}
// Computes the output for the state gate.
// Operation: `s' = u .* s + (1 - u) .* ReLU(W^T∙i + R^T∙(s .* r) + b)` where
// - `s'`: output state gate vector
// - `s`: previous state gate vector
// - `u`: update gate vector
// - `W`: weights matrix
// - `i`: input vector
// - `R`: recurrent weights matrix
// - `r`: reset gate vector
// - `b`: bias vector
// - `.*` element-wise product
void ComputeStateGate(int input_size,
int output_size,
const VectorMath& vector_math,
rtc::ArrayView<const float> input,
rtc::ArrayView<const float> update,
rtc::ArrayView<const float> reset,
rtc::ArrayView<const float> bias,
rtc::ArrayView<const float> weights,
rtc::ArrayView<const float> recurrent_weights,
rtc::ArrayView<float> state) {
RTC_DCHECK_EQ(input.size(), input_size);
RTC_DCHECK_GE(update.size(), output_size); // `update` is over-allocated.
RTC_DCHECK_GE(reset.size(), output_size); // `reset` is over-allocated.
RTC_DCHECK_EQ(bias.size(), output_size);
RTC_DCHECK_EQ(weights.size(), input_size * output_size);
RTC_DCHECK_EQ(recurrent_weights.size(), output_size * output_size);
RTC_DCHECK_EQ(state.size(), output_size);
std::array<float, kGruLayerMaxUnits> reset_x_state;
for (int o = 0; o < output_size; ++o) {
reset_x_state[o] = state[o] * reset[o];
}
for (int o = 0; o < output_size; ++o) {
float x = bias[o];
x += vector_math.DotProduct(input,
weights.subview(o * input_size, input_size));
x += vector_math.DotProduct(
{reset_x_state.data(), static_cast<size_t>(output_size)},
recurrent_weights.subview(o * output_size, output_size));
state[o] = update[o] * state[o] + (1.f - update[o]) * std::max(0.f, x);
}
}
} // namespace
GatedRecurrentLayer::GatedRecurrentLayer(
const int input_size,
const int output_size,
const rtc::ArrayView<const int8_t> bias,
const rtc::ArrayView<const int8_t> weights,
const rtc::ArrayView<const int8_t> recurrent_weights,
const AvailableCpuFeatures& cpu_features,
absl::string_view layer_name)
: input_size_(input_size),
output_size_(output_size),
bias_(PreprocessGruTensor(bias, output_size)),
weights_(PreprocessGruTensor(weights, output_size)),
recurrent_weights_(PreprocessGruTensor(recurrent_weights, output_size)),
vector_math_(cpu_features) {
RTC_DCHECK_LE(output_size_, kGruLayerMaxUnits)
<< "Insufficient GRU layer over-allocation (" << layer_name << ").";
RTC_DCHECK_EQ(kNumGruGates * output_size_, bias_.size())
<< "Mismatching output size and bias terms array size (" << layer_name
<< ").";
RTC_DCHECK_EQ(kNumGruGates * input_size_ * output_size_, weights_.size())
<< "Mismatching input-output size and weight coefficients array size ("
<< layer_name << ").";
RTC_DCHECK_EQ(kNumGruGates * output_size_ * output_size_,
recurrent_weights_.size())
<< "Mismatching input-output size and recurrent weight coefficients array"
" size ("
<< layer_name << ").";
Reset();
}
GatedRecurrentLayer::~GatedRecurrentLayer() = default;
void GatedRecurrentLayer::Reset() {
state_.fill(0.f);
}
void GatedRecurrentLayer::ComputeOutput(rtc::ArrayView<const float> input) {
RTC_DCHECK_EQ(input.size(), input_size_);
// The tensors below are organized as a sequence of flattened tensors for the
// `update`, `reset` and `state` gates.
rtc::ArrayView<const float> bias(bias_);
rtc::ArrayView<const float> weights(weights_);
rtc::ArrayView<const float> recurrent_weights(recurrent_weights_);
// Strides to access to the flattened tensors for a specific gate.
const int stride_weights = input_size_ * output_size_;
const int stride_recurrent_weights = output_size_ * output_size_;
rtc::ArrayView<float> state(state_.data(), output_size_);
// Update gate.
std::array<float, kGruLayerMaxUnits> update;
ComputeUpdateResetGate(
input_size_, output_size_, vector_math_, input, state,
bias.subview(0, output_size_), weights.subview(0, stride_weights),
recurrent_weights.subview(0, stride_recurrent_weights), update);
// Reset gate.
std::array<float, kGruLayerMaxUnits> reset;
ComputeUpdateResetGate(input_size_, output_size_, vector_math_, input, state,
bias.subview(output_size_, output_size_),
weights.subview(stride_weights, stride_weights),
recurrent_weights.subview(stride_recurrent_weights,
stride_recurrent_weights),
reset);
// State gate.
ComputeStateGate(input_size_, output_size_, vector_math_, input, update,
reset, bias.subview(2 * output_size_, output_size_),
weights.subview(2 * stride_weights, stride_weights),
recurrent_weights.subview(2 * stride_recurrent_weights,
stride_recurrent_weights),
state);
}
} // namespace rnn_vad
} // namespace webrtc

View file

@ -0,0 +1,70 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_
#include <array>
#include <vector>
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
namespace webrtc {
namespace rnn_vad {
// Maximum number of units for a GRU layer.
constexpr int kGruLayerMaxUnits = 24;
// Recurrent layer with gated recurrent units (GRUs) with sigmoid and ReLU as
// activation functions for the update/reset and output gates respectively.
class GatedRecurrentLayer {
public:
// Ctor. `output_size` cannot be greater than `kGruLayerMaxUnits`.
GatedRecurrentLayer(int input_size,
int output_size,
rtc::ArrayView<const int8_t> bias,
rtc::ArrayView<const int8_t> weights,
rtc::ArrayView<const int8_t> recurrent_weights,
const AvailableCpuFeatures& cpu_features,
absl::string_view layer_name);
GatedRecurrentLayer(const GatedRecurrentLayer&) = delete;
GatedRecurrentLayer& operator=(const GatedRecurrentLayer&) = delete;
~GatedRecurrentLayer();
// Returns the size of the input vector.
int input_size() const { return input_size_; }
// Returns the pointer to the first element of the output buffer.
const float* data() const { return state_.data(); }
// Returns the size of the output buffer.
int size() const { return output_size_; }
// Resets the GRU state.
void Reset();
  // Computes the recurrent layer output and updates the state.
void ComputeOutput(rtc::ArrayView<const float> input);
private:
const int input_size_;
const int output_size_;
const std::vector<float> bias_;
const std::vector<float> weights_;
const std::vector<float> recurrent_weights_;
const VectorMath vector_math_;
// Over-allocated array with size equal to `output_size_`.
std::array<float, kGruLayerMaxUnits> state_;
};
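// A minimal usage sketch with hypothetical zero-valued quantized parameters
// for a 2-input/2-output layer (array sizes: 3 * output, 3 * input * output
// and 3 * output * output, for the three GRU gates):
//
//   constexpr std::array<int8_t, 6> kBias{};
//   constexpr std::array<int8_t, 12> kWeights{};
//   constexpr std::array<int8_t, 12> kRecurrentWeights{};
//   GatedRecurrentLayer gru(/*input_size=*/2, /*output_size=*/2, kBias,
//                           kWeights, kRecurrentWeights,
//                           NoAvailableCpuFeatures(), /*layer_name=*/"demo");
//   const float input[2] = {0.1f, -0.2f};
//   gru.ComputeOutput(input);  // `gru.data()` now holds the updated state.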
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_RNN_GRU_H_

View file

@ -0,0 +1,123 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include <array>
#include <cstdio>
#include <string>
#include <vector>
#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "common_audio/resampler/push_sinc_resampler.h"
#include "common_audio/wav_file.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_compare.h"
ABSL_FLAG(std::string, i, "", "Path to the input wav file");
ABSL_FLAG(std::string, f, "", "Path to the output features file");
ABSL_FLAG(std::string, o, "", "Path to the output VAD probabilities file");
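// Example invocation (file names are hypothetical):
//   rnn_vad_tool -i speech.wav -o vad_probs.dat -f features.dat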
namespace webrtc {
namespace rnn_vad {
namespace test {
int main(int argc, char* argv[]) {
absl::ParseCommandLine(argc, argv);
rtc::LogMessage::LogToDebug(rtc::LS_INFO);
// Open wav input file and check properties.
const std::string input_wav_file = absl::GetFlag(FLAGS_i);
WavReader wav_reader(input_wav_file);
if (wav_reader.num_channels() != 1) {
RTC_LOG(LS_ERROR) << "Only mono wav files are supported";
return 1;
}
if (wav_reader.sample_rate() % 100 != 0) {
RTC_LOG(LS_ERROR) << "The sample rate rate must allow 10 ms frames.";
return 1;
}
RTC_LOG(LS_INFO) << "Input sample rate: " << wav_reader.sample_rate();
// Init output files.
const std::string output_vad_probs_file = absl::GetFlag(FLAGS_o);
  FILE* vad_probs_file = fopen(output_vad_probs_file.c_str(), "wb");
  if (vad_probs_file == nullptr) {
    RTC_LOG(LS_ERROR) << "Cannot open the output VAD probabilities file.";
    return 1;
  }
FILE* features_file = nullptr;
const std::string output_feature_file = absl::GetFlag(FLAGS_f);
if (!output_feature_file.empty()) {
features_file = fopen(output_feature_file.c_str(), "wb");
}
// Initialize.
const int frame_size_10ms =
rtc::CheckedDivExact(wav_reader.sample_rate(), 100);
std::vector<float> samples_10ms;
samples_10ms.resize(frame_size_10ms);
std::array<float, kFrameSize10ms24kHz> samples_10ms_24kHz;
PushSincResampler resampler(frame_size_10ms, kFrameSize10ms24kHz);
const AvailableCpuFeatures cpu_features = GetAvailableCpuFeatures();
FeaturesExtractor features_extractor(cpu_features);
std::array<float, kFeatureVectorSize> feature_vector;
RnnVad rnn_vad(cpu_features);
// Compute VAD probabilities.
while (true) {
// Read frame at the input sample rate.
const size_t read_samples =
wav_reader.ReadSamples(frame_size_10ms, samples_10ms.data());
if (rtc::SafeLt(read_samples, frame_size_10ms)) {
break; // EOF.
}
// Resample input.
resampler.Resample(samples_10ms.data(), samples_10ms.size(),
samples_10ms_24kHz.data(), samples_10ms_24kHz.size());
// Extract features and feed the RNN.
bool is_silence = features_extractor.CheckSilenceComputeFeatures(
samples_10ms_24kHz, feature_vector);
float vad_probability =
rnn_vad.ComputeVadProbability(feature_vector, is_silence);
// Write voice probability.
RTC_DCHECK_GE(vad_probability, 0.f);
RTC_DCHECK_GE(1.f, vad_probability);
fwrite(&vad_probability, sizeof(float), 1, vad_probs_file);
// Write features.
if (features_file) {
const float float_is_silence = is_silence ? 1.f : 0.f;
fwrite(&float_is_silence, sizeof(float), 1, features_file);
if (is_silence) {
// Do not write uninitialized values.
feature_vector.fill(0.f);
}
fwrite(feature_vector.data(), sizeof(float), kFeatureVectorSize,
features_file);
}
}
// Close output file(s).
fclose(vad_probs_file);
RTC_LOG(LS_INFO) << "VAD probabilities written to " << output_vad_probs_file;
if (features_file) {
fclose(features_file);
RTC_LOG(LS_INFO) << "features written to " << output_feature_file;
}
return 0;
}
} // namespace test
} // namespace rnn_vad
} // namespace webrtc
int main(int argc, char* argv[]) {
return webrtc::rnn_vad::test::main(argc, argv);
}
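A hypothetical invocation of the tool above (the binary name and file paths are assumptions; the flags are the ones declared at the top of the file):

  rnn_vad_tool -i input_mono.wav -o vad_probabilities.dat -f features.dat

The -f flag is optional: when it is omitted, only the VAD probabilities are written.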

View file

@ -0,0 +1,79 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
#include <algorithm>
#include <cstring>
#include <type_traits>
#include <vector>
#include "api/array_view.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace rnn_vad {
// Linear buffer implementation to (i) push fixed-size chunks of sequential data
// and (ii) view contiguous parts of the buffer. The buffer and the pushed
// chunks have size S and N respectively. For instance, when S = 2N the first
// half of the sequence buffer is replaced with its second half, and the new N
// values are written at the end of the buffer.
// The class also provides a view on the most recent M values, where 0 < M <= S
// and by default M = N.
template <typename T, int S, int N, int M = N>
class SequenceBuffer {
static_assert(N <= S,
"The new chunk size cannot be larger than the sequence buffer "
"size.");
static_assert(std::is_arithmetic<T>::value,
"Integral or floating point required.");
public:
SequenceBuffer() : buffer_(S) {
RTC_DCHECK_EQ(S, buffer_.size());
Reset();
}
SequenceBuffer(const SequenceBuffer&) = delete;
SequenceBuffer& operator=(const SequenceBuffer&) = delete;
~SequenceBuffer() = default;
int size() const { return S; }
int chunks_size() const { return N; }
// Sets the sequence buffer values to zero.
void Reset() { std::fill(buffer_.begin(), buffer_.end(), 0); }
// Returns a view on the whole buffer.
rtc::ArrayView<const T, S> GetBufferView() const {
return {buffer_.data(), S};
}
// Returns a view on the M most recent values of the buffer.
rtc::ArrayView<const T, M> GetMostRecentValuesView() const {
static_assert(M <= S,
"The number of most recent values cannot be larger than the "
"sequence buffer size.");
return {buffer_.data() + S - M, M};
}
  // Shifts the buffer left by N items and appends N new items at the end.
void Push(rtc::ArrayView<const T, N> new_values) {
// Make space for the new values.
if (S > N)
std::memmove(buffer_.data(), buffer_.data() + N, (S - N) * sizeof(T));
// Copy the new values at the end of the buffer.
std::memcpy(buffer_.data() + S - N, new_values.data(), N * sizeof(T));
}
private:
std::vector<T> buffer_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SEQUENCE_BUFFER_H_
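A minimal usage sketch for SequenceBuffer (the free function and the values are illustrative, not part of the tree):

#include <array>

#include "modules/audio_processing/agc2/rnn_vad/sequence_buffer.h"

void SequenceBufferExample() {
  // S = 6 samples of history, pushed in chunks of N = 3.
  webrtc::rnn_vad::SequenceBuffer<float, /*S=*/6, /*N=*/3> buf;
  const std::array<float, 3> chunk1 = {1.f, 2.f, 3.f};
  const std::array<float, 3> chunk2 = {4.f, 5.f, 6.f};
  buf.Push(chunk1);  // Buffer content: 0 0 0 1 2 3.
  buf.Push(chunk2);  // Buffer content: 1 2 3 4 5 6.
  // View on the M = N = 3 most recent values: 4 5 6.
  const auto recent = buf.GetMostRecentValuesView();
  static_cast<void>(recent);
}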

View file

@ -0,0 +1,214 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/spectral_features.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <numeric>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
namespace {
constexpr float kSilenceThreshold = 0.04f;
// Computes the new cepstral difference stats and pushes them into the passed
// symmetric matrix buffer.
void UpdateCepstralDifferenceStats(
rtc::ArrayView<const float, kNumBands> new_cepstral_coeffs,
const RingBuffer<float, kNumBands, kCepstralCoeffsHistorySize>& ring_buf,
SymmetricMatrixBuffer<float, kCepstralCoeffsHistorySize>* sym_matrix_buf) {
RTC_DCHECK(sym_matrix_buf);
// Compute the new cepstral distance stats.
std::array<float, kCepstralCoeffsHistorySize - 1> distances;
for (int i = 0; i < kCepstralCoeffsHistorySize - 1; ++i) {
const int delay = i + 1;
auto old_cepstral_coeffs = ring_buf.GetArrayView(delay);
distances[i] = 0.f;
for (int k = 0; k < kNumBands; ++k) {
const float c = new_cepstral_coeffs[k] - old_cepstral_coeffs[k];
distances[i] += c * c;
}
}
// Push the new spectral distance stats into the symmetric matrix buffer.
sym_matrix_buf->Push(distances);
}
// Computes the first half of the Vorbis window.
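// On a 2H-sample frame, the Vorbis window is
// w(i) = sin(pi/2 * sin^2(pi * (i + 0.5) / (2H)));
// only the first half (H samples) is computed below and scaled by `scaling`.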
std::array<float, kFrameSize20ms24kHz / 2> ComputeScaledHalfVorbisWindow(
float scaling = 1.f) {
constexpr int kHalfSize = kFrameSize20ms24kHz / 2;
std::array<float, kHalfSize> half_window{};
for (int i = 0; i < kHalfSize; ++i) {
half_window[i] =
scaling *
std::sin(0.5 * kPi * std::sin(0.5 * kPi * (i + 0.5) / kHalfSize) *
std::sin(0.5 * kPi * (i + 0.5) / kHalfSize));
}
return half_window;
}
// Computes the forward FFT on a 20 ms frame to which a given window function is
// applied. The Fourier coefficient corresponding to the Nyquist frequency is
// set to zero (it is never used and zeroing it simplifies the code).
void ComputeWindowedForwardFft(
rtc::ArrayView<const float, kFrameSize20ms24kHz> frame,
const std::array<float, kFrameSize20ms24kHz / 2>& half_window,
Pffft::FloatBuffer* fft_input_buffer,
Pffft::FloatBuffer* fft_output_buffer,
Pffft* fft) {
RTC_DCHECK_EQ(frame.size(), 2 * half_window.size());
// Apply windowing.
auto in = fft_input_buffer->GetView();
for (int i = 0, j = kFrameSize20ms24kHz - 1;
rtc::SafeLt(i, half_window.size()); ++i, --j) {
in[i] = frame[i] * half_window[i];
in[j] = frame[j] * half_window[i];
}
fft->ForwardTransform(*fft_input_buffer, fft_output_buffer, /*ordered=*/true);
// Set the Nyquist frequency coefficient to zero.
auto out = fft_output_buffer->GetView();
out[1] = 0.f;
}
} // namespace
SpectralFeaturesExtractor::SpectralFeaturesExtractor()
: half_window_(ComputeScaledHalfVorbisWindow(
1.f / static_cast<float>(kFrameSize20ms24kHz))),
fft_(kFrameSize20ms24kHz, Pffft::FftType::kReal),
fft_buffer_(fft_.CreateBuffer()),
reference_frame_fft_(fft_.CreateBuffer()),
lagged_frame_fft_(fft_.CreateBuffer()),
dct_table_(ComputeDctTable()) {}
SpectralFeaturesExtractor::~SpectralFeaturesExtractor() = default;
void SpectralFeaturesExtractor::Reset() {
cepstral_coeffs_ring_buf_.Reset();
cepstral_diffs_buf_.Reset();
}
bool SpectralFeaturesExtractor::CheckSilenceComputeFeatures(
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame,
rtc::ArrayView<const float, kFrameSize20ms24kHz> lagged_frame,
rtc::ArrayView<float, kNumBands - kNumLowerBands> higher_bands_cepstrum,
rtc::ArrayView<float, kNumLowerBands> average,
rtc::ArrayView<float, kNumLowerBands> first_derivative,
rtc::ArrayView<float, kNumLowerBands> second_derivative,
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr,
float* variability) {
// Compute the Opus band energies for the reference frame.
ComputeWindowedForwardFft(reference_frame, half_window_, fft_buffer_.get(),
reference_frame_fft_.get(), &fft_);
spectral_correlator_.ComputeAutoCorrelation(
reference_frame_fft_->GetConstView(), reference_frame_bands_energy_);
// Check if the reference frame has silence.
const float tot_energy =
std::accumulate(reference_frame_bands_energy_.begin(),
reference_frame_bands_energy_.end(), 0.f);
if (tot_energy < kSilenceThreshold) {
return true;
}
// Compute the Opus band energies for the lagged frame.
ComputeWindowedForwardFft(lagged_frame, half_window_, fft_buffer_.get(),
lagged_frame_fft_.get(), &fft_);
spectral_correlator_.ComputeAutoCorrelation(lagged_frame_fft_->GetConstView(),
lagged_frame_bands_energy_);
// Log of the band energies for the reference frame.
std::array<float, kNumBands> log_bands_energy;
ComputeSmoothedLogMagnitudeSpectrum(reference_frame_bands_energy_,
log_bands_energy);
// Reference frame cepstrum.
std::array<float, kNumBands> cepstrum;
ComputeDct(log_bands_energy, dct_table_, cepstrum);
// Ad-hoc correction terms for the first two cepstral coefficients.
cepstrum[0] -= 12.f;
cepstrum[1] -= 4.f;
// Update the ring buffer and the cepstral difference stats.
cepstral_coeffs_ring_buf_.Push(cepstrum);
UpdateCepstralDifferenceStats(cepstrum, cepstral_coeffs_ring_buf_,
&cepstral_diffs_buf_);
// Write the higher bands cepstral coefficients.
RTC_DCHECK_EQ(cepstrum.size() - kNumLowerBands, higher_bands_cepstrum.size());
std::copy(cepstrum.begin() + kNumLowerBands, cepstrum.end(),
higher_bands_cepstrum.begin());
// Compute and write remaining features.
ComputeAvgAndDerivatives(average, first_derivative, second_derivative);
ComputeNormalizedCepstralCorrelation(bands_cross_corr);
RTC_DCHECK(variability);
*variability = ComputeVariability();
return false;
}
void SpectralFeaturesExtractor::ComputeAvgAndDerivatives(
rtc::ArrayView<float, kNumLowerBands> average,
rtc::ArrayView<float, kNumLowerBands> first_derivative,
rtc::ArrayView<float, kNumLowerBands> second_derivative) const {
auto curr = cepstral_coeffs_ring_buf_.GetArrayView(0);
auto prev1 = cepstral_coeffs_ring_buf_.GetArrayView(1);
auto prev2 = cepstral_coeffs_ring_buf_.GetArrayView(2);
RTC_DCHECK_EQ(average.size(), first_derivative.size());
RTC_DCHECK_EQ(first_derivative.size(), second_derivative.size());
RTC_DCHECK_LE(average.size(), curr.size());
for (int i = 0; rtc::SafeLt(i, average.size()); ++i) {
// Average, kernel: [1, 1, 1].
average[i] = curr[i] + prev1[i] + prev2[i];
    // First derivative, kernel: [1, 0, -1].
first_derivative[i] = curr[i] - prev2[i];
// Second derivative, Laplacian kernel: [1, -2, 1].
second_derivative[i] = curr[i] - 2 * prev1[i] + prev2[i];
}
}
void SpectralFeaturesExtractor::ComputeNormalizedCepstralCorrelation(
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr) {
spectral_correlator_.ComputeCrossCorrelation(
reference_frame_fft_->GetConstView(), lagged_frame_fft_->GetConstView(),
bands_cross_corr_);
// Normalize.
for (int i = 0; rtc::SafeLt(i, bands_cross_corr_.size()); ++i) {
bands_cross_corr_[i] =
bands_cross_corr_[i] /
std::sqrt(0.001f + reference_frame_bands_energy_[i] *
lagged_frame_bands_energy_[i]);
}
// Cepstrum.
ComputeDct(bands_cross_corr_, dct_table_, bands_cross_corr);
// Ad-hoc correction terms for the first two cepstral coefficients.
bands_cross_corr[0] -= 1.3f;
bands_cross_corr[1] -= 0.9f;
}
float SpectralFeaturesExtractor::ComputeVariability() const {
// Compute cepstral variability score.
float variability = 0.f;
for (int delay1 = 0; delay1 < kCepstralCoeffsHistorySize; ++delay1) {
float min_dist = std::numeric_limits<float>::max();
for (int delay2 = 0; delay2 < kCepstralCoeffsHistorySize; ++delay2) {
if (delay1 == delay2) // The distance would be 0.
continue;
min_dist =
std::min(min_dist, cepstral_diffs_buf_.GetValue(delay1, delay2));
}
variability += min_dist;
}
// Normalize (based on training set stats).
// TODO(bugs.webrtc.org/10480): Isolate normalization from feature extraction.
return variability / kCepstralCoeffsHistorySize - 2.1f;
}
} // namespace rnn_vad
} // namespace webrtc

View file

@ -0,0 +1,79 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_
#include <array>
#include <cstddef>
#include <memory>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/ring_buffer.h"
#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
#include "modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h"
#include "modules/audio_processing/utility/pffft_wrapper.h"
namespace webrtc {
namespace rnn_vad {
// Class to compute spectral features.
class SpectralFeaturesExtractor {
public:
SpectralFeaturesExtractor();
SpectralFeaturesExtractor(const SpectralFeaturesExtractor&) = delete;
SpectralFeaturesExtractor& operator=(const SpectralFeaturesExtractor&) =
delete;
~SpectralFeaturesExtractor();
// Resets the internal state of the feature extractor.
void Reset();
  // Analyzes a pair of reference and lagged frames from the pitch buffer,
  // detects silence and computes features. If silence is detected, the output
  // is neither computed nor written. Returns true if silence is detected.
bool CheckSilenceComputeFeatures(
rtc::ArrayView<const float, kFrameSize20ms24kHz> reference_frame,
rtc::ArrayView<const float, kFrameSize20ms24kHz> lagged_frame,
rtc::ArrayView<float, kNumBands - kNumLowerBands> higher_bands_cepstrum,
rtc::ArrayView<float, kNumLowerBands> average,
rtc::ArrayView<float, kNumLowerBands> first_derivative,
rtc::ArrayView<float, kNumLowerBands> second_derivative,
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr,
float* variability);
private:
void ComputeAvgAndDerivatives(
rtc::ArrayView<float, kNumLowerBands> average,
rtc::ArrayView<float, kNumLowerBands> first_derivative,
rtc::ArrayView<float, kNumLowerBands> second_derivative) const;
void ComputeNormalizedCepstralCorrelation(
rtc::ArrayView<float, kNumLowerBands> bands_cross_corr);
float ComputeVariability() const;
const std::array<float, kFrameSize20ms24kHz / 2> half_window_;
Pffft fft_;
std::unique_ptr<Pffft::FloatBuffer> fft_buffer_;
std::unique_ptr<Pffft::FloatBuffer> reference_frame_fft_;
std::unique_ptr<Pffft::FloatBuffer> lagged_frame_fft_;
SpectralCorrelator spectral_correlator_;
std::array<float, kOpusBands24kHz> reference_frame_bands_energy_;
std::array<float, kOpusBands24kHz> lagged_frame_bands_energy_;
std::array<float, kOpusBands24kHz> bands_cross_corr_;
const std::array<float, kNumBands * kNumBands> dct_table_;
RingBuffer<float, kNumBands, kCepstralCoeffsHistorySize>
cepstral_coeffs_ring_buf_;
SymmetricMatrixBuffer<float, kCepstralCoeffsHistorySize> cepstral_diffs_buf_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_H_

View file

@ -0,0 +1,188 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/spectral_features_internal.h"
#include <algorithm>
#include <cmath>
#include <cstddef>
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// Weights for each FFT coefficient for each Opus band (Nyquist frequency
// excluded). The size of each band is specified in
// `kOpusScaleNumBins24kHz20ms`.
constexpr std::array<float, kFrameSize20ms24kHz / 2> kOpusBandWeights24kHz20ms =
{{
0.f, 0.25f, 0.5f, 0.75f, // Band 0
0.f, 0.25f, 0.5f, 0.75f, // Band 1
0.f, 0.25f, 0.5f, 0.75f, // Band 2
0.f, 0.25f, 0.5f, 0.75f, // Band 3
0.f, 0.25f, 0.5f, 0.75f, // Band 4
0.f, 0.25f, 0.5f, 0.75f, // Band 5
0.f, 0.25f, 0.5f, 0.75f, // Band 6
0.f, 0.25f, 0.5f, 0.75f, // Band 7
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
0.625f, 0.75f, 0.875f, // Band 8
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
0.625f, 0.75f, 0.875f, // Band 9
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
0.625f, 0.75f, 0.875f, // Band 10
0.f, 0.125f, 0.25f, 0.375f, 0.5f,
0.625f, 0.75f, 0.875f, // Band 11
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
0.9375f, // Band 12
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
0.9375f, // Band 13
0.f, 0.0625f, 0.125f, 0.1875f, 0.25f,
0.3125f, 0.375f, 0.4375f, 0.5f, 0.5625f,
0.625f, 0.6875f, 0.75f, 0.8125f, 0.875f,
0.9375f, // Band 14
0.f, 0.0416667f, 0.0833333f, 0.125f, 0.166667f,
0.208333f, 0.25f, 0.291667f, 0.333333f, 0.375f,
0.416667f, 0.458333f, 0.5f, 0.541667f, 0.583333f,
0.625f, 0.666667f, 0.708333f, 0.75f, 0.791667f,
0.833333f, 0.875f, 0.916667f, 0.958333f, // Band 15
0.f, 0.0416667f, 0.0833333f, 0.125f, 0.166667f,
0.208333f, 0.25f, 0.291667f, 0.333333f, 0.375f,
0.416667f, 0.458333f, 0.5f, 0.541667f, 0.583333f,
0.625f, 0.666667f, 0.708333f, 0.75f, 0.791667f,
0.833333f, 0.875f, 0.916667f, 0.958333f, // Band 16
0.f, 0.03125f, 0.0625f, 0.09375f, 0.125f,
0.15625f, 0.1875f, 0.21875f, 0.25f, 0.28125f,
0.3125f, 0.34375f, 0.375f, 0.40625f, 0.4375f,
0.46875f, 0.5f, 0.53125f, 0.5625f, 0.59375f,
0.625f, 0.65625f, 0.6875f, 0.71875f, 0.75f,
0.78125f, 0.8125f, 0.84375f, 0.875f, 0.90625f,
0.9375f, 0.96875f, // Band 17
0.f, 0.0208333f, 0.0416667f, 0.0625f, 0.0833333f,
0.104167f, 0.125f, 0.145833f, 0.166667f, 0.1875f,
0.208333f, 0.229167f, 0.25f, 0.270833f, 0.291667f,
0.3125f, 0.333333f, 0.354167f, 0.375f, 0.395833f,
0.416667f, 0.4375f, 0.458333f, 0.479167f, 0.5f,
0.520833f, 0.541667f, 0.5625f, 0.583333f, 0.604167f,
0.625f, 0.645833f, 0.666667f, 0.6875f, 0.708333f,
0.729167f, 0.75f, 0.770833f, 0.791667f, 0.8125f,
0.833333f, 0.854167f, 0.875f, 0.895833f, 0.916667f,
0.9375f, 0.958333f, 0.979167f // Band 18
}};
} // namespace
SpectralCorrelator::SpectralCorrelator()
: weights_(kOpusBandWeights24kHz20ms.begin(),
kOpusBandWeights24kHz20ms.end()) {}
SpectralCorrelator::~SpectralCorrelator() = default;
void SpectralCorrelator::ComputeAutoCorrelation(
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const {
ComputeCrossCorrelation(x, x, auto_corr);
}
void SpectralCorrelator::ComputeCrossCorrelation(
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const {
RTC_DCHECK_EQ(x.size(), kFrameSize20ms24kHz);
RTC_DCHECK_EQ(x.size(), y.size());
RTC_DCHECK_EQ(x[1], 0.f) << "The Nyquist coefficient must be zeroed.";
RTC_DCHECK_EQ(y[1], 0.f) << "The Nyquist coefficient must be zeroed.";
constexpr auto kOpusScaleNumBins24kHz20ms = GetOpusScaleNumBins24kHz20ms();
int k = 0; // Next Fourier coefficient index.
cross_corr[0] = 0.f;
for (int i = 0; i < kOpusBands24kHz - 1; ++i) {
cross_corr[i + 1] = 0.f;
for (int j = 0; j < kOpusScaleNumBins24kHz20ms[i]; ++j) { // Band size.
const float v = x[2 * k] * y[2 * k] + x[2 * k + 1] * y[2 * k + 1];
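      // Triangular band filters: the energy of each Fourier coefficient is
      // split between the two adjacent bands, with the fraction `weights_[k]`
      // assigned to band i + 1 and the remainder to band i.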
const float tmp = weights_[k] * v;
cross_corr[i] += v - tmp;
cross_corr[i + 1] += tmp;
k++;
}
}
cross_corr[0] *= 2.f; // The first band only gets half contribution.
RTC_DCHECK_EQ(k, kFrameSize20ms24kHz / 2); // Nyquist coefficient never used.
}
void ComputeSmoothedLogMagnitudeSpectrum(
rtc::ArrayView<const float> bands_energy,
rtc::ArrayView<float, kNumBands> log_bands_energy) {
RTC_DCHECK_LE(bands_energy.size(), kNumBands);
constexpr float kOneByHundred = 1e-2f;
constexpr float kLogOneByHundred = -2.f;
// Init.
float log_max = kLogOneByHundred;
float follow = kLogOneByHundred;
const auto smooth = [&log_max, &follow](float x) {
x = std::max(log_max - 7.f, std::max(follow - 1.5f, x));
log_max = std::max(log_max, x);
follow = std::max(follow - 1.5f, x);
return x;
};
// Smoothing over the bands for which the band energy is defined.
for (int i = 0; rtc::SafeLt(i, bands_energy.size()); ++i) {
log_bands_energy[i] = smooth(std::log10(kOneByHundred + bands_energy[i]));
}
// Smoothing over the remaining bands (zero energy).
for (int i = bands_energy.size(); i < kNumBands; ++i) {
log_bands_energy[i] = smooth(kLogOneByHundred);
}
}
std::array<float, kNumBands * kNumBands> ComputeDctTable() {
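  // DCT-II basis: dct_table[i * kNumBands + j] = cos(pi * j * (i + 0.5) / kNumBands),
  // with the j = 0 column scaled by sqrt(1/2); the overall sqrt(2 / kNumBands)
  // factor is applied in ComputeDct().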
std::array<float, kNumBands * kNumBands> dct_table;
const double k = std::sqrt(0.5);
for (int i = 0; i < kNumBands; ++i) {
for (int j = 0; j < kNumBands; ++j)
dct_table[i * kNumBands + j] = std::cos((i + 0.5) * j * kPi / kNumBands);
dct_table[i * kNumBands] *= k;
}
return dct_table;
}
void ComputeDct(rtc::ArrayView<const float> in,
rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
rtc::ArrayView<float> out) {
// DCT scaling factor - i.e., sqrt(2 / kNumBands).
constexpr float kDctScalingFactor = 0.301511345f;
constexpr float kDctScalingFactorError =
kDctScalingFactor * kDctScalingFactor -
2.f / static_cast<float>(kNumBands);
static_assert(
(kDctScalingFactorError >= 0.f && kDctScalingFactorError < 1e-1f) ||
(kDctScalingFactorError < 0.f && kDctScalingFactorError > -1e-1f),
"kNumBands changed and kDctScalingFactor has not been updated.");
RTC_DCHECK_NE(in.data(), out.data()) << "In-place DCT is not supported.";
RTC_DCHECK_LE(in.size(), kNumBands);
RTC_DCHECK_LE(1, out.size());
RTC_DCHECK_LE(out.size(), in.size());
for (int i = 0; rtc::SafeLt(i, out.size()); ++i) {
out[i] = 0.f;
for (int j = 0; rtc::SafeLt(j, in.size()); ++j) {
out[i] += in[j] * dct_table[j * kNumBands + i];
}
// TODO(bugs.webrtc.org/10480): Scaling factor in the DCT table.
out[i] *= kDctScalingFactor;
}
}
} // namespace rnn_vad
} // namespace webrtc

View file

@ -0,0 +1,100 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_
#include <stddef.h>
#include <array>
#include <vector>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
namespace webrtc {
namespace rnn_vad {
// At a sample rate of 24 kHz, the last 3 Opus bands are beyond the Nyquist
// frequency. However, band #19 gets the contributions from band #18 because
// of the symmetric triangular filter with peak response at 12 kHz.
constexpr int kOpusBands24kHz = 20;
static_assert(kOpusBands24kHz < kNumBands,
"The number of bands at 24 kHz must be less than those defined "
"in the Opus scale at 48 kHz.");
// Number of FFT frequency bins covered by each band in the Opus scale at a
// sample rate of 24 kHz for 20 ms frames.
// Declared here for unit testing.
constexpr std::array<int, kOpusBands24kHz - 1> GetOpusScaleNumBins24kHz20ms() {
return {4, 4, 4, 4, 4, 4, 4, 4, 8, 8, 8, 8, 16, 16, 16, 24, 24, 32, 48};
}
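// Note: the 19 sizes above sum to 240, i.e., kFrameSize20ms24kHz / 2, so the
// bands jointly cover every Fourier coefficient below the Nyquist frequency.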
// TODO(bugs.webrtc.org/10480): Move to a separate file.
// Class to compute band-wise spectral features in the Opus perceptual scale
// for 20 ms frames sampled at 24 kHz. The analysis methods apply triangular
// filters with peak response at each band boundary.
class SpectralCorrelator {
public:
// Ctor.
SpectralCorrelator();
SpectralCorrelator(const SpectralCorrelator&) = delete;
SpectralCorrelator& operator=(const SpectralCorrelator&) = delete;
~SpectralCorrelator();
// Computes the band-wise spectral auto-correlations.
// `x` must:
// - have size equal to `kFrameSize20ms24kHz`;
// - be encoded as vectors of interleaved real-complex FFT coefficients
//    where x[1] = 0 (the Nyquist frequency coefficient is omitted).
void ComputeAutoCorrelation(
rtc::ArrayView<const float> x,
rtc::ArrayView<float, kOpusBands24kHz> auto_corr) const;
// Computes the band-wise spectral cross-correlations.
// `x` and `y` must:
// - have size equal to `kFrameSize20ms24kHz`;
// - be encoded as vectors of interleaved real-complex FFT coefficients where
// x[1] = y[1] = 0 (the Nyquist frequency coefficient is omitted).
void ComputeCrossCorrelation(
rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y,
rtc::ArrayView<float, kOpusBands24kHz> cross_corr) const;
private:
const std::vector<float> weights_; // Weights for each Fourier coefficient.
};
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
// spectral_features.cc.
// Given a vector of Opus-bands energy coefficients, computes the log magnitude
// spectrum applying smoothing both over time and over frequency. Declared here
// for unit testing.
void ComputeSmoothedLogMagnitudeSpectrum(
rtc::ArrayView<const float> bands_energy,
rtc::ArrayView<float, kNumBands> log_bands_energy);
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
// spectral_features.cc.
// Creates a DCT table for arrays having size equal to `kNumBands`. Declared
// here for unit testing.
std::array<float, kNumBands * kNumBands> ComputeDctTable();
// TODO(bugs.webrtc.org/10480): Move to anonymous namespace in
// spectral_features.cc.
// Computes DCT for `in` given a pre-computed DCT table. In-place computation is
// not allowed and `out` can be smaller than `in` in order to only compute the
// first DCT coefficients. Declared here for unit testing.
void ComputeDct(rtc::ArrayView<const float> in,
rtc::ArrayView<const float, kNumBands * kNumBands> dct_table,
rtc::ArrayView<float> out);
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SPECTRAL_FEATURES_INTERNAL_H_

View file

@ -0,0 +1,95 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_
#include <algorithm>
#include <array>
#include <cstring>
#include <utility>
#include "api/array_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
// Data structure to buffer the results of pair-wise comparisons between items
// stored in a ring buffer. Every time that the oldest item is replaced in the
// ring buffer, the new one is compared to the remaining items in the ring
// buffer. The results of such comparisons need to be buffered and automatically
// removed when one of the two corresponding items that have been compared is
// removed from the ring buffer. It is assumed that the comparison is symmetric
// and that comparing an item with itself is not needed.
template <typename T, int S>
class SymmetricMatrixBuffer {
static_assert(S > 2, "");
public:
SymmetricMatrixBuffer() = default;
SymmetricMatrixBuffer(const SymmetricMatrixBuffer&) = delete;
SymmetricMatrixBuffer& operator=(const SymmetricMatrixBuffer&) = delete;
~SymmetricMatrixBuffer() = default;
// Sets the buffer values to zero.
void Reset() {
static_assert(std::is_arithmetic<T>::value,
"Integral or floating point required.");
buf_.fill(0);
}
// Pushes the results from the comparison between the most recent item and
// those that are still in the ring buffer. The first element in `values` must
// correspond to the comparison between the most recent item and the second
// most recent one in the ring buffer, whereas the last element in `values`
// must correspond to the comparison between the most recent item and the
// oldest one in the ring buffer.
void Push(rtc::ArrayView<T, S - 1> values) {
// Move the lower-right sub-matrix of size (S-2) x (S-2) one row up and one
// column left.
std::memmove(buf_.data(), buf_.data() + S, (buf_.size() - S) * sizeof(T));
// Copy new values in the last column in the right order.
for (int i = 0; rtc::SafeLt(i, values.size()); ++i) {
const int index = (S - 1 - i) * (S - 1) - 1;
RTC_DCHECK_GE(index, 0);
RTC_DCHECK_LT(index, buf_.size());
buf_[index] = values[i];
}
}
// Reads the value that corresponds to comparison of two items in the ring
// buffer having delay `delay1` and `delay2`. The two arguments must not be
// equal and both must be in {0, ..., S - 1}.
T GetValue(int delay1, int delay2) const {
int row = S - 1 - delay1;
int col = S - 1 - delay2;
RTC_DCHECK_NE(row, col) << "The diagonal cannot be accessed.";
if (row > col)
std::swap(row, col); // Swap to access the upper-right triangular part.
RTC_DCHECK_LE(0, row);
RTC_DCHECK_LT(row, S - 1) << "Not enforcing row < col and row != col.";
RTC_DCHECK_LE(1, col) << "Not enforcing row < col and row != col.";
RTC_DCHECK_LT(col, S);
const int index = row * (S - 1) + (col - 1);
RTC_DCHECK_LE(0, index);
RTC_DCHECK_LT(index, buf_.size());
return buf_[index];
}
private:
  // Encodes an upper-right triangular matrix (excluding its diagonal) using a
  // square matrix. This allows moving the data in Push() with a single
  // memmove operation.
std::array<T, (S - 1) * (S - 1)> buf_{};
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_SYMMETRIC_MATRIX_BUFFER_H_
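A minimal usage sketch, assuming the pairwise comparison is an absolute difference between the last S = 4 scalars pushed into a companion ring buffer (all names and values are illustrative):

#include <array>
#include <cmath>

#include "modules/audio_processing/agc2/rnn_vad/symmetric_matrix_buffer.h"

void SymmetricMatrixBufferExample() {
  constexpr int kS = 4;
  webrtc::rnn_vad::SymmetricMatrixBuffer<float, kS> diffs;
  // A new item `v` replaces the oldest one in a size-4 ring buffer that also
  // holds `prev1` (delay 1), `prev2` (delay 2) and `prev3` (delay 3).
  const float v = 5.f, prev1 = 4.f, prev2 = 2.f, prev3 = 1.f;
  std::array<float, kS - 1> values = {std::fabs(v - prev1),
                                      std::fabs(v - prev2),
                                      std::fabs(v - prev3)};
  diffs.Push(values);
  // Comparison between the newest item (delay 0) and the oldest one (delay 3).
  const float d = diffs.GetValue(/*delay1=*/0, /*delay2=*/3);  // 4.f.
  static_cast<void>(d);
}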

View file

@ -0,0 +1,143 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include <algorithm>
#include <fstream>
#include <memory>
#include <string>
#include <type_traits>
#include <vector>
#include "absl/strings/string_view.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
#include "test/gtest.h"
#include "test/testsupport/file_utils.h"
namespace webrtc {
namespace rnn_vad {
namespace {
// File reader for binary files that contain a sequence of values with
// arithmetic type `T`. The values of type `T` that are read are cast to float.
template <typename T>
class FloatFileReader : public FileReader {
public:
static_assert(std::is_arithmetic<T>::value, "");
explicit FloatFileReader(absl::string_view filename)
: is_(std::string(filename), std::ios::binary | std::ios::ate),
size_(is_.tellg() / sizeof(T)) {
RTC_CHECK(is_);
SeekBeginning();
}
FloatFileReader(const FloatFileReader&) = delete;
FloatFileReader& operator=(const FloatFileReader&) = delete;
~FloatFileReader() = default;
int size() const override { return size_; }
bool ReadChunk(rtc::ArrayView<float> dst) override {
const std::streamsize bytes_to_read = dst.size() * sizeof(T);
if (std::is_same<T, float>::value) {
is_.read(reinterpret_cast<char*>(dst.data()), bytes_to_read);
} else {
buffer_.resize(dst.size());
is_.read(reinterpret_cast<char*>(buffer_.data()), bytes_to_read);
std::transform(buffer_.begin(), buffer_.end(), dst.begin(),
[](const T& v) -> float { return static_cast<float>(v); });
}
return is_.gcount() == bytes_to_read;
}
bool ReadValue(float& dst) override { return ReadChunk({&dst, 1}); }
void SeekForward(int hop) override { is_.seekg(hop * sizeof(T), is_.cur); }
void SeekBeginning() override { is_.seekg(0, is_.beg); }
private:
std::ifstream is_;
const int size_;
std::vector<T> buffer_;
};
} // namespace
using webrtc::test::ResourcePath;
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed) {
ASSERT_EQ(expected.size(), computed.size());
for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
SCOPED_TRACE(i);
EXPECT_FLOAT_EQ(expected[i], computed[i]);
}
}
void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed,
float tolerance) {
ASSERT_EQ(expected.size(), computed.size());
for (int i = 0; rtc::SafeLt(i, expected.size()); ++i) {
SCOPED_TRACE(i);
EXPECT_NEAR(expected[i], computed[i], tolerance);
}
}
std::unique_ptr<FileReader> CreatePcmSamplesReader() {
return std::make_unique<FloatFileReader<int16_t>>(
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/samples",
"pcm"));
}
ChunksFileReader CreatePitchBuffer24kHzReader() {
auto reader = std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath(
"audio_processing/agc2/rnn_vad/pitch_buf_24k", "dat"));
const int num_chunks = rtc::CheckedDivExact(reader->size(), kBufSize24kHz);
return {/*chunk_size=*/kBufSize24kHz, num_chunks, std::move(reader)};
}
ChunksFileReader CreateLpResidualAndPitchInfoReader() {
constexpr int kPitchInfoSize = 2; // Pitch period and strength.
constexpr int kChunkSize = kBufSize24kHz + kPitchInfoSize;
auto reader = std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath(
"audio_processing/agc2/rnn_vad/pitch_lp_res", "dat"));
const int num_chunks = rtc::CheckedDivExact(reader->size(), kChunkSize);
return {kChunkSize, num_chunks, std::move(reader)};
}
std::unique_ptr<FileReader> CreateGruInputReader() {
return std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/gru_in",
"dat"));
}
std::unique_ptr<FileReader> CreateVadProbsReader() {
return std::make_unique<FloatFileReader<float>>(
/*filename=*/test::ResourcePath("audio_processing/agc2/rnn_vad/vad_prob",
"dat"));
}
PitchTestData::PitchTestData() {
FloatFileReader<float> reader(
/*filename=*/ResourcePath(
"audio_processing/agc2/rnn_vad/pitch_search_int", "dat"));
reader.ReadChunk(pitch_buffer_24k_);
reader.ReadChunk(square_energies_24k_);
reader.ReadChunk(auto_correlation_12k_);
// Reverse the order of the squared energy values.
// Required after the WebRTC CL 191703 which switched to forward computation.
std::reverse(square_energies_24k_.begin(), square_energies_24k_.end());
}
PitchTestData::~PitchTestData() = default;
} // namespace rnn_vad
} // namespace webrtc

View file

@ -0,0 +1,130 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
#include <array>
#include <fstream>
#include <limits>
#include <memory>
#include <string>
#include "absl/strings/string_view.h"
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
namespace rnn_vad {
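// Note: the smallest positive normalized float, not the lowest float value.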
constexpr float kFloatMin = std::numeric_limits<float>::min();
// Fails for every pair from two equally sized rtc::ArrayView<float> views such
// that the values in the pair do not match.
void ExpectEqualFloatArray(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed);
// Fails for every pair from two equally sized rtc::ArrayView<float> views such
// that their absolute error is above a given threshold.
void ExpectNearAbsolute(rtc::ArrayView<const float> expected,
rtc::ArrayView<const float> computed,
float tolerance);
// File reader interface.
class FileReader {
public:
virtual ~FileReader() = default;
// Number of values in the file.
virtual int size() const = 0;
// Reads `dst.size()` float values into `dst`, advances the internal file
// position according to the number of read bytes and returns true if the
// values are correctly read. If the number of remaining bytes in the file is
// not sufficient to read `dst.size()` float values, `dst` is partially
// modified and false is returned.
virtual bool ReadChunk(rtc::ArrayView<float> dst) = 0;
// Reads a single float value, advances the internal file position according
// to the number of read bytes and returns true if the value is correctly
// read. If the number of remaining bytes in the file is not sufficient to
// read one float, `dst` is not modified and false is returned.
virtual bool ReadValue(float& dst) = 0;
// Advances the internal file position by `hop` float values.
virtual void SeekForward(int hop) = 0;
// Resets the internal file position to BOF.
virtual void SeekBeginning() = 0;
};
// File reader for files that contain `num_chunks` chunks with size equal to
// `chunk_size`.
struct ChunksFileReader {
const int chunk_size;
const int num_chunks;
std::unique_ptr<FileReader> reader;
};
// Creates a reader for the PCM S16 samples file.
std::unique_ptr<FileReader> CreatePcmSamplesReader();
// Creates a reader for the 24 kHz pitch buffer test data.
ChunksFileReader CreatePitchBuffer24kHzReader();
// Creates a reader for the LP residual and pitch information test data.
ChunksFileReader CreateLpResidualAndPitchInfoReader();
// Creates a reader for the sequence of GRU input vectors.
std::unique_ptr<FileReader> CreateGruInputReader();
// Creates a reader for the VAD probabilities test data.
std::unique_ptr<FileReader> CreateVadProbsReader();
// Class to retrieve a test pitch buffer content and the expected output for the
// analysis steps.
class PitchTestData {
public:
PitchTestData();
~PitchTestData();
rtc::ArrayView<const float, kBufSize24kHz> PitchBuffer24kHzView() const {
return pitch_buffer_24k_;
}
rtc::ArrayView<const float, kRefineNumLags24kHz> SquareEnergies24kHzView()
const {
return square_energies_24k_;
}
rtc::ArrayView<const float, kNumLags12kHz> AutoCorrelation12kHzView() const {
return auto_correlation_12k_;
}
private:
std::array<float, kBufSize24kHz> pitch_buffer_24k_;
std::array<float, kRefineNumLags24kHz> square_energies_24k_;
std::array<float, kNumLags12kHz> auto_correlation_12k_;
};
// Writer for binary files.
class FileWriter {
public:
explicit FileWriter(absl::string_view file_path)
: os_(std::string(file_path), std::ios::binary) {}
FileWriter(const FileWriter&) = delete;
FileWriter& operator=(const FileWriter&) = delete;
~FileWriter() = default;
void WriteChunk(rtc::ArrayView<const float> value) {
const std::streamsize bytes_to_write = value.size() * sizeof(float);
os_.write(reinterpret_cast<const char*>(value.data()), bytes_to_write);
}
private:
std::ofstream os_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_TEST_UTILS_H_
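A sketch of iterating over one of the chunked test-data files via the helpers above (requires the corresponding resource file, as in the unit tests):

#include <vector>

#include "modules/audio_processing/agc2/rnn_vad/test_utils.h"
#include "rtc_base/checks.h"

void ReadAllPitchBufferChunks() {
  webrtc::rnn_vad::ChunksFileReader chunks =
      webrtc::rnn_vad::CreatePitchBuffer24kHzReader();
  std::vector<float> chunk(chunks.chunk_size);
  for (int i = 0; i < chunks.num_chunks; ++i) {
    RTC_CHECK(chunks.reader->ReadChunk(chunk));
    // `chunk` now holds one 24 kHz pitch buffer worth of test data.
  }
}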

View file

@ -0,0 +1,115 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
#define MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
// Defines WEBRTC_ARCH_X86_FAMILY, used below.
#include "rtc_base/system/arch.h"
#if defined(WEBRTC_HAS_NEON)
#include <arm_neon.h>
#endif
#if defined(WEBRTC_ARCH_X86_FAMILY)
#include <emmintrin.h>
#endif
#include <numeric>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
#include "rtc_base/system/arch.h"
namespace webrtc {
namespace rnn_vad {
// Provides optimizations for mathematical operations having vectors as
// operand(s).
class VectorMath {
public:
explicit VectorMath(AvailableCpuFeatures cpu_features)
: cpu_features_(cpu_features) {}
// Computes the dot product between two equally sized vectors.
float DotProduct(rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y) const {
RTC_DCHECK_EQ(x.size(), y.size());
#if defined(WEBRTC_ARCH_X86_FAMILY)
    // TODO(@dkaraush): compile with AVX support.
    // if (cpu_features_.avx2) {
    //   return DotProductAvx2(x, y);
    // }
    if (cpu_features_.sse2) {
__m128 accumulator = _mm_setzero_ps();
constexpr int kBlockSizeLog2 = 2;
constexpr int kBlockSize = 1 << kBlockSizeLog2;
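      // Largest multiple of `kBlockSize` that does not exceed `x.size()`; the
      // tail, if any, is handled by the scalar loop below.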
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
<< kBlockSizeLog2;
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
RTC_DCHECK_LE(i + kBlockSize, x.size());
const __m128 x_i = _mm_loadu_ps(&x[i]);
const __m128 y_i = _mm_loadu_ps(&y[i]);
// Multiply-add.
const __m128 z_j = _mm_mul_ps(x_i, y_i);
accumulator = _mm_add_ps(accumulator, z_j);
}
// Reduce `accumulator` by addition.
__m128 high = _mm_movehl_ps(accumulator, accumulator);
accumulator = _mm_add_ps(accumulator, high);
high = _mm_shuffle_ps(accumulator, accumulator, 1);
accumulator = _mm_add_ps(accumulator, high);
float dot_product = _mm_cvtss_f32(accumulator);
// Add the result for the last block if incomplete.
for (int i = incomplete_block_index;
i < rtc::dchecked_cast<int>(x.size()); ++i) {
dot_product += x[i] * y[i];
}
return dot_product;
}
#elif defined(WEBRTC_HAS_NEON) && defined(WEBRTC_ARCH_ARM64)
if (cpu_features_.neon) {
float32x4_t accumulator = vdupq_n_f32(0.f);
constexpr int kBlockSizeLog2 = 2;
constexpr int kBlockSize = 1 << kBlockSizeLog2;
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
<< kBlockSizeLog2;
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
RTC_DCHECK_LE(i + kBlockSize, x.size());
const float32x4_t x_i = vld1q_f32(&x[i]);
const float32x4_t y_i = vld1q_f32(&y[i]);
accumulator = vfmaq_f32(accumulator, x_i, y_i);
}
// Reduce `accumulator` by addition.
const float32x2_t tmp =
vpadd_f32(vget_low_f32(accumulator), vget_high_f32(accumulator));
float dot_product = vget_lane_f32(vpadd_f32(tmp, vrev64_f32(tmp)), 0);
// Add the result for the last block if incomplete.
for (int i = incomplete_block_index;
i < rtc::dchecked_cast<int>(x.size()); ++i) {
dot_product += x[i] * y[i];
}
return dot_product;
}
#endif
return std::inner_product(x.begin(), x.end(), y.begin(), 0.f);
}
private:
float DotProductAvx2(rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y) const;
const AvailableCpuFeatures cpu_features_;
};
} // namespace rnn_vad
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_RNN_VAD_VECTOR_MATH_H_
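A minimal usage sketch; the SIMD dispatch above is transparent to the caller:

#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"

float ExampleDotProduct() {
  const webrtc::AvailableCpuFeatures cpu_features =
      webrtc::GetAvailableCpuFeatures();
  const webrtc::rnn_vad::VectorMath vector_math(cpu_features);
  const float x[] = {1.f, 2.f, 3.f, 4.f, 5.f};
  const float y[] = {5.f, 4.f, 3.f, 2.f, 1.f};
  return vector_math.DotProduct(x, y);  // 35.f.
}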

View file

@ -0,0 +1,54 @@
/*
* Copyright (c) 2020 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include <immintrin.h>
#include "api/array_view.h"
#include "modules/audio_processing/agc2/rnn_vad/vector_math.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_conversions.h"
namespace webrtc {
namespace rnn_vad {
float VectorMath::DotProductAvx2(rtc::ArrayView<const float> x,
rtc::ArrayView<const float> y) const {
RTC_DCHECK(cpu_features_.avx2);
RTC_DCHECK_EQ(x.size(), y.size());
__m256 accumulator = _mm256_setzero_ps();
constexpr int kBlockSizeLog2 = 3;
constexpr int kBlockSize = 1 << kBlockSizeLog2;
const int incomplete_block_index = (x.size() >> kBlockSizeLog2)
<< kBlockSizeLog2;
for (int i = 0; i < incomplete_block_index; i += kBlockSize) {
RTC_DCHECK_LE(i + kBlockSize, x.size());
const __m256 x_i = _mm256_loadu_ps(&x[i]);
const __m256 y_i = _mm256_loadu_ps(&y[i]);
accumulator = _mm256_fmadd_ps(x_i, y_i, accumulator);
}
// Reduce `accumulator` by addition.
__m128 high = _mm256_extractf128_ps(accumulator, 1);
__m128 low = _mm256_extractf128_ps(accumulator, 0);
low = _mm_add_ps(high, low);
high = _mm_movehl_ps(high, low);
low = _mm_add_ps(high, low);
high = _mm_shuffle_ps(low, low, 1);
low = _mm_add_ss(high, low);
float dot_product = _mm_cvtss_f32(low);
// Add the result for the last block if incomplete.
for (int i = incomplete_block_index; i < rtc::dchecked_cast<int>(x.size());
++i) {
dot_product += x[i] * y[i];
}
return dot_product;
}
} // namespace rnn_vad
} // namespace webrtc

View file

@ -0,0 +1,183 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/saturation_protector.h"
#include <memory>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/saturation_protector_buffer.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
constexpr int kPeakEnveloperSuperFrameLengthMs = 400;
constexpr float kMinMarginDb = 12.0f;
constexpr float kMaxMarginDb = 25.0f;
constexpr float kAttack = 0.9988493699365052f;
constexpr float kDecay = 0.9997697679981565f;
// Saturation protector state. Defined outside of `SaturationProtectorImpl` to
// implement checkpoint and restore ops.
struct SaturationProtectorState {
bool operator==(const SaturationProtectorState& s) const {
return headroom_db == s.headroom_db &&
peak_delay_buffer == s.peak_delay_buffer &&
max_peaks_dbfs == s.max_peaks_dbfs &&
time_since_push_ms == s.time_since_push_ms;
}
inline bool operator!=(const SaturationProtectorState& s) const {
return !(*this == s);
}
float headroom_db;
SaturationProtectorBuffer peak_delay_buffer;
float max_peaks_dbfs;
int time_since_push_ms; // Time since the last ring buffer push operation.
};
// Resets the saturation protector state.
void ResetSaturationProtectorState(float initial_headroom_db,
SaturationProtectorState& state) {
state.headroom_db = initial_headroom_db;
state.peak_delay_buffer.Reset();
state.max_peaks_dbfs = kMinLevelDbfs;
state.time_since_push_ms = 0;
}
// Updates `state` by analyzing the estimated speech level `speech_level_dbfs`
// and the peak level `peak_dbfs` for an observed frame. `state` must not be
// modified without calling this function.
void UpdateSaturationProtectorState(float peak_dbfs,
float speech_level_dbfs,
SaturationProtectorState& state) {
// Get the max peak over `kPeakEnveloperSuperFrameLengthMs` ms.
state.max_peaks_dbfs = std::max(state.max_peaks_dbfs, peak_dbfs);
state.time_since_push_ms += kFrameDurationMs;
if (rtc::SafeGt(state.time_since_push_ms, kPeakEnveloperSuperFrameLengthMs)) {
// Push `max_peaks_dbfs` back into the ring buffer.
state.peak_delay_buffer.PushBack(state.max_peaks_dbfs);
// Reset.
state.max_peaks_dbfs = kMinLevelDbfs;
state.time_since_push_ms = 0;
}
// Update the headroom by comparing the estimated speech level and the delayed
// max speech peak.
const float delayed_peak_dbfs =
state.peak_delay_buffer.Front().value_or(state.max_peaks_dbfs);
const float difference_db = delayed_peak_dbfs - speech_level_dbfs;
if (difference_db > state.headroom_db) {
// Attack.
state.headroom_db =
state.headroom_db * kAttack + difference_db * (1.0f - kAttack);
} else {
// Decay.
state.headroom_db =
state.headroom_db * kDecay + difference_db * (1.0f - kDecay);
}
state.headroom_db =
rtc::SafeClamp<float>(state.headroom_db, kMinMarginDb, kMaxMarginDb);
}
// Saturation protector which recommends a headroom based on the recent peaks.
class SaturationProtectorImpl : public SaturationProtector {
public:
explicit SaturationProtectorImpl(float initial_headroom_db,
int adjacent_speech_frames_threshold,
ApmDataDumper* apm_data_dumper)
: apm_data_dumper_(apm_data_dumper),
initial_headroom_db_(initial_headroom_db),
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold) {
Reset();
}
SaturationProtectorImpl(const SaturationProtectorImpl&) = delete;
SaturationProtectorImpl& operator=(const SaturationProtectorImpl&) = delete;
~SaturationProtectorImpl() = default;
float HeadroomDb() override { return headroom_db_; }
void Analyze(float speech_probability,
float peak_dbfs,
float speech_level_dbfs) override {
if (speech_probability < kVadConfidenceThreshold) {
// Not a speech frame.
if (adjacent_speech_frames_threshold_ > 1) {
// When two or more adjacent speech frames are required in order to
// update the state, we need to decide whether to discard or confirm the
// updates based on the speech sequence length.
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// First non-speech frame after a long enough sequence of speech
// frames. Update the reliable state.
reliable_state_ = preliminary_state_;
} else if (num_adjacent_speech_frames_ > 0) {
// First non-speech frame after a too short sequence of speech frames.
// Reset to the last reliable state.
preliminary_state_ = reliable_state_;
}
}
num_adjacent_speech_frames_ = 0;
} else {
// Speech frame observed.
num_adjacent_speech_frames_++;
// Update preliminary level estimate.
UpdateSaturationProtectorState(peak_dbfs, speech_level_dbfs,
preliminary_state_);
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// `preliminary_state_` is now reliable. Update the headroom.
headroom_db_ = preliminary_state_.headroom_db;
}
}
DumpDebugData();
}
void Reset() override {
num_adjacent_speech_frames_ = 0;
headroom_db_ = initial_headroom_db_;
ResetSaturationProtectorState(initial_headroom_db_, preliminary_state_);
ResetSaturationProtectorState(initial_headroom_db_, reliable_state_);
}
private:
void DumpDebugData() {
apm_data_dumper_->DumpRaw(
"agc2_saturation_protector_preliminary_max_peak_dbfs",
preliminary_state_.max_peaks_dbfs);
apm_data_dumper_->DumpRaw(
"agc2_saturation_protector_reliable_max_peak_dbfs",
reliable_state_.max_peaks_dbfs);
}
ApmDataDumper* const apm_data_dumper_;
const float initial_headroom_db_;
const int adjacent_speech_frames_threshold_;
int num_adjacent_speech_frames_;
float headroom_db_;
SaturationProtectorState preliminary_state_;
SaturationProtectorState reliable_state_;
};
} // namespace
std::unique_ptr<SaturationProtector> CreateSaturationProtector(
float initial_headroom_db,
int adjacent_speech_frames_threshold,
ApmDataDumper* apm_data_dumper) {
return std::make_unique<SaturationProtectorImpl>(
initial_headroom_db, adjacent_speech_frames_threshold, apm_data_dumper);
}
} // namespace webrtc

View file

@ -0,0 +1,46 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
#include <memory>
namespace webrtc {
class ApmDataDumper;
// Saturation protector. Analyzes peak levels and recommends a headroom to
// reduce the chances of clipping.
class SaturationProtector {
public:
virtual ~SaturationProtector() = default;
// Returns the recommended headroom in dB.
virtual float HeadroomDb() = 0;
// Analyzes the peak level of a 10 ms frame along with its speech probability
// and the current speech level estimate to update the recommended headroom.
virtual void Analyze(float speech_probability,
float peak_dbfs,
float speech_level_dbfs) = 0;
// Resets the internal state.
virtual void Reset() = 0;
};
// Creates a saturation protector whose recommended headroom starts at
// `initial_headroom_db` dB.
std::unique_ptr<SaturationProtector> CreateSaturationProtector(
float initial_headroom_db,
int adjacent_speech_frames_threshold,
ApmDataDumper* apm_data_dumper);
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_H_
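A minimal sketch of driving the saturation protector once per 10 ms frame (the constructor arguments are illustrative, not necessarily the production defaults):

#include <memory>

#include "modules/audio_processing/agc2/saturation_protector.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"

void SaturationProtectorExample(float speech_probability,
                                float peak_dbfs,
                                float speech_level_dbfs) {
  webrtc::ApmDataDumper data_dumper(/*instance_index=*/0);
  std::unique_ptr<webrtc::SaturationProtector> protector =
      webrtc::CreateSaturationProtector(
          /*initial_headroom_db=*/20.f,
          /*adjacent_speech_frames_threshold=*/12, &data_dumper);
  // Once per 10 ms frame:
  protector->Analyze(speech_probability, peak_dbfs, speech_level_dbfs);
  const float headroom_db = protector->HeadroomDb();
  static_cast<void>(headroom_db);
}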

View file

@ -0,0 +1,77 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/saturation_protector_buffer.h"
#include "rtc_base/checks.h"
#include "rtc_base/numerics/safe_compare.h"
namespace webrtc {
SaturationProtectorBuffer::SaturationProtectorBuffer() = default;
SaturationProtectorBuffer::~SaturationProtectorBuffer() = default;
bool SaturationProtectorBuffer::operator==(
const SaturationProtectorBuffer& b) const {
RTC_DCHECK_LE(size_, buffer_.size());
RTC_DCHECK_LE(b.size_, b.buffer_.size());
if (size_ != b.size_) {
return false;
}
for (int i = 0, i0 = FrontIndex(), i1 = b.FrontIndex(); i < size_;
++i, ++i0, ++i1) {
if (buffer_[i0 % buffer_.size()] != b.buffer_[i1 % b.buffer_.size()]) {
return false;
}
}
return true;
}
int SaturationProtectorBuffer::Capacity() const {
return buffer_.size();
}
int SaturationProtectorBuffer::Size() const {
return size_;
}
void SaturationProtectorBuffer::Reset() {
next_ = 0;
size_ = 0;
}
void SaturationProtectorBuffer::PushBack(float v) {
RTC_DCHECK_GE(next_, 0);
RTC_DCHECK_GE(size_, 0);
RTC_DCHECK_LT(next_, buffer_.size());
RTC_DCHECK_LE(size_, buffer_.size());
buffer_[next_++] = v;
if (rtc::SafeEq(next_, buffer_.size())) {
next_ = 0;
}
if (rtc::SafeLt(size_, buffer_.size())) {
size_++;
}
}
absl::optional<float> SaturationProtectorBuffer::Front() const {
if (size_ == 0) {
return absl::nullopt;
}
RTC_DCHECK_LT(FrontIndex(), buffer_.size());
return buffer_[FrontIndex()];
}
int SaturationProtectorBuffer::FrontIndex() const {
return rtc::SafeEq(size_, buffer_.size()) ? next_ : 0;
}
} // namespace webrtc

View file

@ -0,0 +1,59 @@
/*
* Copyright (c) 2021 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_
#include <array>
#include "absl/types/optional.h"
#include "modules/audio_processing/agc2/agc2_common.h"
namespace webrtc {
// Ring buffer for the saturation protector which only supports (i) push back
// and (ii) read oldest item.
class SaturationProtectorBuffer {
public:
SaturationProtectorBuffer();
~SaturationProtectorBuffer();
bool operator==(const SaturationProtectorBuffer& b) const;
inline bool operator!=(const SaturationProtectorBuffer& b) const {
return !(*this == b);
}
// Maximum number of values that the buffer can contain.
int Capacity() const;
// Number of values in the buffer.
int Size() const;
void Reset();
// Pushes back `v`. If the buffer is full, the oldest value is replaced.
void PushBack(float v);
// Returns the oldest item in the buffer. Returns an empty value if the
// buffer is empty.
absl::optional<float> Front() const;
private:
int FrontIndex() const;
// `buffer_` has `size_` elements (up to the size of `buffer_`) and `next_` is
// the position where the next new value is written in `buffer_`.
std::array<float, kSaturationProtectorBufferSize> buffer_;
int next_ = 0;
int size_ = 0;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_SATURATION_PROTECTOR_BUFFER_H_
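A minimal usage sketch of the ring buffer (values are illustrative):

#include "modules/audio_processing/agc2/saturation_protector_buffer.h"

void SaturationProtectorBufferExample() {
  webrtc::SaturationProtectorBuffer buf;
  buf.PushBack(-42.f);
  buf.PushBack(-30.f);
  // `Front()` returns the oldest value still in the buffer, or an empty
  // optional if nothing has been pushed yet.
  const float front = buf.Front().value_or(0.f);  // -42.f.
  static_cast<void>(front);
}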

View file

@ -0,0 +1,174 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/speech_level_estimator.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"
#include "rtc_base/checks.h"
#include "rtc_base/logging.h"
#include "rtc_base/numerics/safe_minmax.h"
namespace webrtc {
namespace {
float ClampLevelEstimateDbfs(float level_estimate_dbfs) {
return rtc::SafeClamp<float>(level_estimate_dbfs, -90.0f, 30.0f);
}
// Returns the initial speech level estimate needed to apply the initial gain.
float GetInitialSpeechLevelEstimateDbfs(
const AudioProcessing::Config::GainController2::AdaptiveDigital& config) {
return ClampLevelEstimateDbfs(-kSaturationProtectorInitialHeadroomDb -
config.initial_gain_db - config.headroom_db);
}
} // namespace
bool SpeechLevelEstimator::LevelEstimatorState::operator==(
const SpeechLevelEstimator::LevelEstimatorState& b) const {
return time_to_confidence_ms == b.time_to_confidence_ms &&
level_dbfs.numerator == b.level_dbfs.numerator &&
level_dbfs.denominator == b.level_dbfs.denominator;
}
float SpeechLevelEstimator::LevelEstimatorState::Ratio::GetRatio() const {
RTC_DCHECK_NE(denominator, 0.f);
return numerator / denominator;
}
SpeechLevelEstimator::SpeechLevelEstimator(
ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
int adjacent_speech_frames_threshold)
: apm_data_dumper_(apm_data_dumper),
initial_speech_level_dbfs_(GetInitialSpeechLevelEstimateDbfs(config)),
adjacent_speech_frames_threshold_(adjacent_speech_frames_threshold),
level_dbfs_(initial_speech_level_dbfs_),
// TODO(bugs.webrtc.org/7494): Remove the init below when the AGC2 input
// volume controller temporal dependency is removed.
is_confident_(false) {
RTC_DCHECK(apm_data_dumper_);
RTC_DCHECK_GE(adjacent_speech_frames_threshold_, 1);
Reset();
}
void SpeechLevelEstimator::Update(float rms_dbfs,
float peak_dbfs,
float speech_probability) {
RTC_DCHECK_GT(rms_dbfs, -150.0f);
RTC_DCHECK_LT(rms_dbfs, 50.0f);
RTC_DCHECK_GT(peak_dbfs, -150.0f);
RTC_DCHECK_LT(peak_dbfs, 50.0f);
RTC_DCHECK_GE(speech_probability, 0.0f);
RTC_DCHECK_LE(speech_probability, 1.0f);
if (speech_probability < kVadConfidenceThreshold) {
// Not a speech frame.
if (adjacent_speech_frames_threshold_ > 1) {
// When two or more adjacent speech frames are required in order to update
// the state, we need to decide whether to discard or confirm the updates
// based on the speech sequence length.
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// First non-speech frame after a long enough sequence of speech frames.
// Update the reliable state.
reliable_state_ = preliminary_state_;
} else if (num_adjacent_speech_frames_ > 0) {
// First non-speech frame after a too short sequence of speech frames.
// Reset to the last reliable state.
preliminary_state_ = reliable_state_;
}
}
num_adjacent_speech_frames_ = 0;
} else {
// Speech frame observed.
num_adjacent_speech_frames_++;
// Update preliminary level estimate.
RTC_DCHECK_GE(preliminary_state_.time_to_confidence_ms, 0);
const bool buffer_is_full = preliminary_state_.time_to_confidence_ms == 0;
if (!buffer_is_full) {
preliminary_state_.time_to_confidence_ms -= kFrameDurationMs;
}
// Weighted average of levels with speech probability as weight.
RTC_DCHECK_GT(speech_probability, 0.0f);
const float leak_factor = buffer_is_full ? kLevelEstimatorLeakFactor : 1.0f;
preliminary_state_.level_dbfs.numerator =
preliminary_state_.level_dbfs.numerator * leak_factor +
rms_dbfs * speech_probability;
preliminary_state_.level_dbfs.denominator =
preliminary_state_.level_dbfs.denominator * leak_factor +
speech_probability;
const float level_dbfs = preliminary_state_.level_dbfs.GetRatio();
if (num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_) {
// `preliminary_state_` is now reliable. Update the level estimate.
level_dbfs_ = ClampLevelEstimateDbfs(level_dbfs);
}
}
UpdateIsConfident();
DumpDebugData();
}
void SpeechLevelEstimator::UpdateIsConfident() {
if (adjacent_speech_frames_threshold_ == 1) {
// When a single frame is enough to update the level estimate,
// `reliable_state_` is never used and can be ignored.
is_confident_ = preliminary_state_.time_to_confidence_ms == 0;
return;
}
// Once confident, it remains confident.
RTC_DCHECK(reliable_state_.time_to_confidence_ms != 0 ||
preliminary_state_.time_to_confidence_ms == 0);
// During the first long enough speech sequence, `reliable_state_` must be
// ignored since `preliminary_state_` is used.
is_confident_ =
reliable_state_.time_to_confidence_ms == 0 ||
(num_adjacent_speech_frames_ >= adjacent_speech_frames_threshold_ &&
preliminary_state_.time_to_confidence_ms == 0);
}
void SpeechLevelEstimator::Reset() {
ResetLevelEstimatorState(preliminary_state_);
ResetLevelEstimatorState(reliable_state_);
level_dbfs_ = initial_speech_level_dbfs_;
num_adjacent_speech_frames_ = 0;
}
void SpeechLevelEstimator::ResetLevelEstimatorState(
LevelEstimatorState& state) const {
state.time_to_confidence_ms = kLevelEstimatorTimeToConfidenceMs;
state.level_dbfs.numerator = initial_speech_level_dbfs_;
state.level_dbfs.denominator = 1.0f;
}
void SpeechLevelEstimator::DumpDebugData() const {
if (!apm_data_dumper_)
return;
apm_data_dumper_->DumpRaw("agc2_speech_level_dbfs", level_dbfs_);
apm_data_dumper_->DumpRaw("agc2_speech_level_is_confident", is_confident_);
apm_data_dumper_->DumpRaw(
"agc2_adaptive_level_estimator_num_adjacent_speech_frames",
num_adjacent_speech_frames_);
apm_data_dumper_->DumpRaw(
"agc2_adaptive_level_estimator_preliminary_level_estimate_num",
preliminary_state_.level_dbfs.numerator);
apm_data_dumper_->DumpRaw(
"agc2_adaptive_level_estimator_preliminary_level_estimate_den",
preliminary_state_.level_dbfs.denominator);
apm_data_dumper_->DumpRaw(
"agc2_adaptive_level_estimator_preliminary_time_to_confidence_ms",
preliminary_state_.time_to_confidence_ms);
apm_data_dumper_->DumpRaw(
"agc2_adaptive_level_estimator_reliable_time_to_confidence_ms",
reliable_state_.time_to_confidence_ms);
}
} // namespace webrtc
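The core of `Update()` is a probability-weighted running average of the frame RMS, with a leak factor applied once the time-to-confidence window has elapsed so that old frames are gradually forgotten. A standalone sketch of one update step (the leak value here is hypothetical; the real constant is `kLevelEstimatorLeakFactor` from agc2_common.h):

#include <cstdio>

int main() {
  // Hypothetical leak; the real constant is kLevelEstimatorLeakFactor.
  constexpr float kLeak = 0.995f;
  float numerator = -40.0f;  // Starts at the initial level estimate.
  float denominator = 1.0f;
  const bool window_elapsed = true;  // time_to_confidence_ms reached 0.
  // One update step for a speech frame.
  const float rms_dbfs = -20.0f;
  const float speech_probability = 0.8f;
  const float leak = window_elapsed ? kLeak : 1.0f;
  numerator = numerator * leak + rms_dbfs * speech_probability;
  denominator = denominator * leak + speech_probability;
  std::printf("level estimate: %.2f dBFS\n", numerator / denominator);
}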

View file

@ -0,0 +1,81 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SPEECH_LEVEL_ESTIMATOR_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SPEECH_LEVEL_ESTIMATOR_H_
#include <stddef.h>
#include <type_traits>
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/include/audio_processing.h"
namespace webrtc {
class ApmDataDumper;
// Active speech level estimator based on the analysis of the following
// framewise properties: RMS level (dBFS), peak level (dBFS), speech
// probability.
class SpeechLevelEstimator {
public:
SpeechLevelEstimator(
ApmDataDumper* apm_data_dumper,
const AudioProcessing::Config::GainController2::AdaptiveDigital& config,
int adjacent_speech_frames_threshold);
SpeechLevelEstimator(const SpeechLevelEstimator&) = delete;
SpeechLevelEstimator& operator=(const SpeechLevelEstimator&) = delete;
// Updates the level estimation.
void Update(float rms_dbfs, float peak_dbfs, float speech_probability);
// Returns the estimated speech plus noise level.
float level_dbfs() const { return level_dbfs_; }
// Returns true if the estimator is confident on its current estimate.
bool is_confident() const { return is_confident_; }
void Reset();
private:
// Part of the level estimator state used for check-pointing and restore ops.
struct LevelEstimatorState {
bool operator==(const LevelEstimatorState& s) const;
inline bool operator!=(const LevelEstimatorState& s) const {
return !(*this == s);
}
// TODO(bugs.webrtc.org/7494): Remove `time_to_confidence_ms` if redundant.
int time_to_confidence_ms;
struct Ratio {
float numerator;
float denominator;
float GetRatio() const;
} level_dbfs;
};
static_assert(std::is_trivially_copyable<LevelEstimatorState>::value, "");
void UpdateIsConfident();
void ResetLevelEstimatorState(LevelEstimatorState& state) const;
void DumpDebugData() const;
ApmDataDumper* const apm_data_dumper_;
const float initial_speech_level_dbfs_;
const int adjacent_speech_frames_threshold_;
LevelEstimatorState preliminary_state_;
LevelEstimatorState reliable_state_;
float level_dbfs_;
bool is_confident_;
int num_adjacent_speech_frames_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_SPEECH_LEVEL_ESTIMATOR_H_
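A minimal usage sketch against this interface (assumes a WebRTC build; the frame values and the threshold of 12 are made-up inputs for illustration):

#include "modules/audio_processing/agc2/speech_level_estimator.h"
#include "modules/audio_processing/logging/apm_data_dumper.h"

void SpeechLevelEstimatorExample() {
  webrtc::ApmDataDumper data_dumper(/*instance_index=*/0);
  webrtc::AudioProcessing::Config::GainController2::AdaptiveDigital config;
  webrtc::SpeechLevelEstimator estimator(
      &data_dumper, config, /*adjacent_speech_frames_threshold=*/12);
  // One 10 ms frame's worth of analysis results.
  estimator.Update(/*rms_dbfs=*/-20.0f, /*peak_dbfs=*/-12.0f,
                   /*speech_probability=*/0.9f);
  if (estimator.is_confident()) {
    const float speech_level_dbfs = estimator.level_dbfs();
    (void)speech_level_dbfs;
  }
}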

View file

@ -0,0 +1,105 @@
/*
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/speech_probability_buffer.h"
#include <algorithm>
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr float kActivityThreshold = 0.9f;
constexpr int kNumAnalysisFrames = 100;
// We use 12 in AGC2 adaptive digital, but with slightly different logic.
constexpr int kTransientWidthThreshold = 7;
} // namespace
SpeechProbabilityBuffer::SpeechProbabilityBuffer(
float low_probability_threshold)
: low_probability_threshold_(low_probability_threshold),
probabilities_(kNumAnalysisFrames) {
RTC_DCHECK_GE(low_probability_threshold, 0.0f);
RTC_DCHECK_LE(low_probability_threshold, 1.0f);
RTC_DCHECK(!probabilities_.empty());
}
void SpeechProbabilityBuffer::Update(float probability) {
// Remove the oldest entry if the circular buffer is full.
if (buffer_is_full_) {
const float oldest_probability = probabilities_[buffer_index_];
sum_probabilities_ -= oldest_probability;
}
// Check for transients.
if (probability <= low_probability_threshold_) {
// Set a probability lower than the threshold to zero.
probability = 0.0f;
// Check if this has been a transient.
if (num_high_probability_observations_ <= kTransientWidthThreshold) {
RemoveTransient();
}
num_high_probability_observations_ = 0;
} else if (num_high_probability_observations_ <= kTransientWidthThreshold) {
++num_high_probability_observations_;
}
// Update the circular buffer and the current sum.
probabilities_[buffer_index_] = probability;
sum_probabilities_ += probability;
// Increment the buffer index and check for wrap-around.
if (++buffer_index_ >= kNumAnalysisFrames) {
buffer_index_ = 0;
buffer_is_full_ = true;
}
}
void SpeechProbabilityBuffer::RemoveTransient() {
// We do not expect to get here if the high-activity region is longer than
// `kTransientWidthThreshold` frames or if there has been no transient.
RTC_DCHECK_LE(num_high_probability_observations_, kTransientWidthThreshold);
// Replace previously added probabilities with zero.
int index =
(buffer_index_ > 0) ? (buffer_index_ - 1) : (kNumAnalysisFrames - 1);
while (num_high_probability_observations_-- > 0) {
sum_probabilities_ -= probabilities_[index];
probabilities_[index] = 0.0f;
// Update the circular buffer index.
index = (index > 0) ? (index - 1) : (kNumAnalysisFrames - 1);
}
}
bool SpeechProbabilityBuffer::IsActiveSegment() const {
if (!buffer_is_full_) {
return false;
}
if (sum_probabilities_ < kActivityThreshold * kNumAnalysisFrames) {
return false;
}
return true;
}
void SpeechProbabilityBuffer::Reset() {
sum_probabilities_ = 0.0f;
// Empty the circular buffer.
buffer_index_ = 0;
buffer_is_full_ = false;
num_high_probability_observations_ = 0;
}
} // namespace webrtc
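With `kNumAnalysisFrames = 100` and `kActivityThreshold = 0.9`, `IsActiveSegment()` reduces to: the buffer is full and the probabilities sum to at least 90, i.e. the mean speech probability is at least 0.9. The transient-removal bookkeeping is easiest to see on a miniature example; the sketch below mirrors the logic above with smaller, hypothetical constants:

#include <cstdio>
#include <vector>

int main() {
  // Miniature version with hypothetical constants: transient width 2 and a
  // 5-frame buffer (the real code uses 7 and 100).
  const float kLowThreshold = 0.5f;
  const int kTransientWidth = 2;
  const int kSize = 5;
  std::vector<float> probs(kSize, 0.0f);
  float sum = 0.0f;
  int index = 0;
  int high_run = 0;
  for (float p : {0.9f, 0.8f, 0.1f}) {  // Two high frames, then a low one.
    if (p <= kLowThreshold) {
      p = 0.0f;
      if (high_run <= kTransientWidth) {
        // The high run was too short: zero it out so the blip does not
        // inflate the sum.
        int i = (index > 0) ? index - 1 : kSize - 1;
        while (high_run-- > 0) {
          sum -= probs[i];
          probs[i] = 0.0f;
          i = (i > 0) ? i - 1 : kSize - 1;
        }
      }
      high_run = 0;
    } else if (high_run <= kTransientWidth) {
      ++high_run;
    }
    probs[index] = p;
    sum += p;
    index = (index + 1) % kSize;
  }
  std::printf("sum after transient removal: %.1f\n", sum);  // Prints 0.0.
}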

View file

@ -0,0 +1,80 @@
/*
* Copyright (c) 2022 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_
#include <vector>
#include "rtc_base/gtest_prod_util.h"
namespace webrtc {
// This class implements a circular buffer that stores speech probabilities
// for a speech segment and estimates speech activity for that segment.
class SpeechProbabilityBuffer {
public:
// Ctor. The value of `low_probability_threshold` must be in the range
// [0.0f, 1.0f].
explicit SpeechProbabilityBuffer(float low_probability_threshold);
~SpeechProbabilityBuffer() = default;
SpeechProbabilityBuffer(const SpeechProbabilityBuffer&) = delete;
SpeechProbabilityBuffer& operator=(const SpeechProbabilityBuffer&) = delete;
// Adds `probability` to the buffer and computes an updated sum of the
// buffer probabilities. The value of `probability` must be in the range
// [0.0f, 1.0f].
void Update(float probability);
// Resets the buffer and forgets the past.
void Reset();
// Returns true if the segment is active (a long enough segment with an
// average speech probability above `low_probability_threshold`).
bool IsActiveSegment() const;
private:
void RemoveTransient();
// Use only for testing.
float GetSumProbabilities() const { return sum_probabilities_; }
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
CheckSumAfterInitialization);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, CheckSumAfterUpdate);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest, CheckSumAfterReset);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
CheckSumAfterTransientNotRemoved);
FRIEND_TEST_ALL_PREFIXES(SpeechProbabilityBufferTest,
CheckSumAfterTransientRemoved);
const float low_probability_threshold_;
// Sum of probabilities stored in `probabilities_`. Must be updated if
// `probabilities_` is updated.
float sum_probabilities_ = 0.0f;
// Circular buffer for probabilities.
std::vector<float> probabilities_;
// Current index of the circular buffer, where the newest value is written;
// when the buffer is full, it therefore points to the oldest value.
int buffer_index_ = 0;
// Indicates whether the buffer is full, in which case adding a new value
// overwrites the oldest one.
bool buffer_is_full_ = false;
int num_high_probability_observations_ = 0;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_SPEECH_PROBABILITY_BUFFER_H_
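A minimal usage sketch (assumes a WebRTC build; the probability values are made up):

#include "modules/audio_processing/agc2/speech_probability_buffer.h"

void SpeechProbabilityBufferExample() {
  webrtc::SpeechProbabilityBuffer buffer(/*low_probability_threshold=*/0.5f);
  // Feed one second of VAD output (100 frames of 10 ms each). Until the
  // buffer is full, IsActiveSegment() returns false.
  for (int i = 0; i < 100; ++i) {
    buffer.Update(/*probability=*/0.95f);
  }
  const bool active = buffer.IsActiveSegment();  // True: mean 0.95 >= 0.9.
  (void)active;
}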

View file

@ -0,0 +1,113 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/vad_wrapper.h"
#include <array>
#include <utility>
#include "api/array_view.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/agc2_common.h"
#include "modules/audio_processing/agc2/rnn_vad/common.h"
#include "modules/audio_processing/agc2/rnn_vad/features_extraction.h"
#include "modules/audio_processing/agc2/rnn_vad/rnn.h"
#include "rtc_base/checks.h"
namespace webrtc {
namespace {
constexpr int kNumFramesPerSecond = 100;
class MonoVadImpl : public VoiceActivityDetectorWrapper::MonoVad {
public:
explicit MonoVadImpl(const AvailableCpuFeatures& cpu_features)
: features_extractor_(cpu_features), rnn_vad_(cpu_features) {}
MonoVadImpl(const MonoVadImpl&) = delete;
MonoVadImpl& operator=(const MonoVadImpl&) = delete;
~MonoVadImpl() = default;
int SampleRateHz() const override { return rnn_vad::kSampleRate24kHz; }
void Reset() override { rnn_vad_.Reset(); }
float Analyze(rtc::ArrayView<const float> frame) override {
RTC_DCHECK_EQ(frame.size(), rnn_vad::kFrameSize10ms24kHz);
std::array<float, rnn_vad::kFeatureVectorSize> feature_vector;
const bool is_silence = features_extractor_.CheckSilenceComputeFeatures(
/*samples=*/{frame.data(), rnn_vad::kFrameSize10ms24kHz},
feature_vector);
return rnn_vad_.ComputeVadProbability(feature_vector, is_silence);
}
private:
rnn_vad::FeaturesExtractor features_extractor_;
rnn_vad::RnnVad rnn_vad_;
};
} // namespace
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
const AvailableCpuFeatures& cpu_features,
int sample_rate_hz)
: VoiceActivityDetectorWrapper(kVadResetPeriodMs,
cpu_features,
sample_rate_hz) {}
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features,
int sample_rate_hz)
: VoiceActivityDetectorWrapper(vad_reset_period_ms,
std::make_unique<MonoVadImpl>(cpu_features),
sample_rate_hz) {}
VoiceActivityDetectorWrapper::VoiceActivityDetectorWrapper(
int vad_reset_period_ms,
std::unique_ptr<MonoVad> vad,
int sample_rate_hz)
: vad_reset_period_frames_(
rtc::CheckedDivExact(vad_reset_period_ms, kFrameDurationMs)),
time_to_vad_reset_(vad_reset_period_frames_),
vad_(std::move(vad)) {
RTC_DCHECK(vad_);
RTC_DCHECK_GT(vad_reset_period_frames_, 1);
resampled_buffer_.resize(
rtc::CheckedDivExact(vad_->SampleRateHz(), kNumFramesPerSecond));
Initialize(sample_rate_hz);
}
VoiceActivityDetectorWrapper::~VoiceActivityDetectorWrapper() = default;
void VoiceActivityDetectorWrapper::Initialize(int sample_rate_hz) {
RTC_DCHECK_GT(sample_rate_hz, 0);
frame_size_ = rtc::CheckedDivExact(sample_rate_hz, kNumFramesPerSecond);
int status =
resampler_.InitializeIfNeeded(sample_rate_hz, vad_->SampleRateHz(),
/*num_channels=*/1);
constexpr int kStatusOk = 0;
RTC_DCHECK_EQ(status, kStatusOk);
vad_->Reset();
}
float VoiceActivityDetectorWrapper::Analyze(AudioFrameView<const float> frame) {
// Periodically reset the VAD.
time_to_vad_reset_--;
if (time_to_vad_reset_ <= 0) {
vad_->Reset();
time_to_vad_reset_ = vad_reset_period_frames_;
}
// Resample the first channel of `frame`.
RTC_DCHECK_EQ(frame.samples_per_channel(), frame_size_);
resampler_.Resample(frame.channel(0).data(), frame_size_,
resampled_buffer_.data(), resampled_buffer_.size());
return vad_->Analyze(resampled_buffer_);
}
} // namespace webrtc
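Two consequences of the arithmetic above are worth spelling out: the resampled buffer holds exactly one 10 ms frame at the VAD's native rate, and `rtc::CheckedDivExact` makes both the reset period and the sample rate hard requirements (they must be exact multiples). A standalone sketch of the sizing (the 1500 ms reset period is a hypothetical input):

#include <cstdio>

int main() {
  constexpr int kNumFramesPerSecond = 100;  // 10 ms frames.
  constexpr int kVadSampleRateHz = 24000;   // rnn_vad::kSampleRate24kHz.
  constexpr int kFrameDurationMs = 10;
  constexpr int kVadResetPeriodMs = 1500;   // Hypothetical reset period.
  // One 10 ms frame at the VAD rate: 240 samples.
  std::printf("resampled frame: %d samples\n",
              kVadSampleRateHz / kNumFramesPerSecond);
  // The VAD is reset every 150 analyzed frames.
  std::printf("reset every %d frames\n", kVadResetPeriodMs / kFrameDurationMs);
}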

View file

@ -0,0 +1,82 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
#define MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
#include <memory>
#include <vector>
#include "api/array_view.h"
#include "common_audio/resampler/include/push_resampler.h"
#include "modules/audio_processing/agc2/cpu_features.h"
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
// Wraps a single-channel Voice Activity Detector (VAD) which is used to analyze
// the first channel of the input audio frames. Takes care of resampling the
// input frames to match the sample rate of the wrapped VAD and periodically
// resets the VAD.
class VoiceActivityDetectorWrapper {
public:
// Single channel VAD interface.
class MonoVad {
public:
virtual ~MonoVad() = default;
// Returns the sample rate (Hz) required for the input frames analyzed by
// `Analyze()`.
virtual int SampleRateHz() const = 0;
// Resets the internal state.
virtual void Reset() = 0;
// Analyzes an audio frame and returns the speech probability.
virtual float Analyze(rtc::ArrayView<const float> frame) = 0;
};
// Ctor. Uses `cpu_features` to instantiate the default VAD.
VoiceActivityDetectorWrapper(const AvailableCpuFeatures& cpu_features,
int sample_rate_hz);
// Ctor. `vad_reset_period_ms` indicates the period in milliseconds to call
// `MonoVad::Reset()`; it must be equal to or greater than the duration of two
// frames. Uses `cpu_features` to instantiate the default VAD.
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
const AvailableCpuFeatures& cpu_features,
int sample_rate_hz);
// Ctor. Uses a custom `vad`.
VoiceActivityDetectorWrapper(int vad_reset_period_ms,
std::unique_ptr<MonoVad> vad,
int sample_rate_hz);
VoiceActivityDetectorWrapper(const VoiceActivityDetectorWrapper&) = delete;
VoiceActivityDetectorWrapper& operator=(const VoiceActivityDetectorWrapper&) =
delete;
~VoiceActivityDetectorWrapper();
// Initializes the VAD wrapper.
void Initialize(int sample_rate_hz);
// Analyzes the first channel of `frame` and returns the speech probability.
// `frame` must be a 10 ms frame with the sample rate specified in the last
// `Initialize()` call.
float Analyze(AudioFrameView<const float> frame);
private:
const int vad_reset_period_frames_;
int frame_size_;
int time_to_vad_reset_;
PushResampler<float> resampler_;
std::unique_ptr<MonoVad> vad_;
std::vector<float> resampled_buffer_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_VAD_WRAPPER_H_
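A minimal sketch of plugging in a custom `MonoVad` through the third constructor (the stub VAD below is hypothetical and exists only to illustrate the interface):

#include <memory>

#include "modules/audio_processing/agc2/vad_wrapper.h"

// Hypothetical stub that treats every frame as speech.
class AlwaysSpeechVad : public webrtc::VoiceActivityDetectorWrapper::MonoVad {
 public:
  int SampleRateHz() const override { return 24000; }
  void Reset() override {}
  float Analyze(rtc::ArrayView<const float>) override { return 1.0f; }
};

void VadWrapperExample() {
  webrtc::VoiceActivityDetectorWrapper vad(
      /*vad_reset_period_ms=*/1500, std::make_unique<AlwaysSpeechVad>(),
      /*sample_rate_hz=*/48000);
  // vad.Analyze(frame) now expects 10 ms, 48 kHz frames; it resamples the
  // first channel to 24 kHz and returns the stub's speech probability.
}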

View file

@ -0,0 +1,39 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#include "modules/audio_processing/agc2/vector_float_frame.h"
namespace webrtc {
namespace {
std::vector<float*> ConstructChannelPointers(
std::vector<std::vector<float>>* x) {
std::vector<float*> channel_ptrs;
for (auto& v : *x) {
channel_ptrs.push_back(v.data());
}
return channel_ptrs;
}
} // namespace
VectorFloatFrame::VectorFloatFrame(int num_channels,
int samples_per_channel,
float start_value)
: channels_(num_channels,
std::vector<float>(samples_per_channel, start_value)),
channel_ptrs_(ConstructChannelPointers(&channels_)),
float_frame_view_(channel_ptrs_.data(),
channels_.size(),
samples_per_channel) {}
VectorFloatFrame::~VectorFloatFrame() = default;
} // namespace webrtc

View file

@ -0,0 +1,42 @@
/*
* Copyright (c) 2018 The WebRTC project authors. All Rights Reserved.
*
* Use of this source code is governed by a BSD-style license
* that can be found in the LICENSE file in the root of the source
* tree. An additional intellectual property rights grant can be found
* in the file PATENTS. All contributing project authors may
* be found in the AUTHORS file in the root of the source tree.
*/
#ifndef MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_
#define MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_
#include <vector>
#include "modules/audio_processing/include/audio_frame_view.h"
namespace webrtc {
// A multi-channel audio frame together with an AudioFrameView of it.
class VectorFloatFrame {
public:
VectorFloatFrame(int num_channels,
int samples_per_channel,
float start_value);
const AudioFrameView<float>& float_frame_view() { return float_frame_view_; }
AudioFrameView<const float> float_frame_view() const {
return float_frame_view_;
}
~VectorFloatFrame();
private:
std::vector<std::vector<float>> channels_;
std::vector<float*> channel_ptrs_;
AudioFrameView<float> float_frame_view_;
};
} // namespace webrtc
#endif // MODULES_AUDIO_PROCESSING_AGC2_VECTOR_FLOAT_FRAME_H_
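A minimal usage sketch (assumes a WebRTC build; the channel count and frame size are arbitrary):

#include "modules/audio_processing/agc2/vector_float_frame.h"

void VectorFloatFrameExample() {
  // Two channels of 480 samples each (10 ms at 48 kHz), initialized to zero.
  webrtc::VectorFloatFrame frame(/*num_channels=*/2,
                                 /*samples_per_channel=*/480,
                                 /*start_value=*/0.0f);
  // Hand the view to code under test that consumes an AudioFrameView.
  const webrtc::AudioFrameView<float>& view = frame.float_frame_view();
  (void)view;
}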