7 #include <TensorFlowLite.h>
10 #include "AudioTools/CoreAudio/AudioOutput.h"
11 #include "AudioTools/CoreAudio/Buffers.h"
12 #include "tensorflow/lite/c/common.h"
13 #include "tensorflow/lite/experimental/microfrontend/lib/frontend.h"
14 #include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h"
15 #include "tensorflow/lite/micro/all_ops_resolver.h"
16 #include "tensorflow/lite/micro/kernels/micro_ops.h"
17 #include "tensorflow/lite/micro/micro_interpreter.h"
18 #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
19 #include "tensorflow/lite/micro/system_setup.h"
20 #include "tensorflow/lite/schema/schema_generated.h"
32 class TfLiteAudioStreamBase;
33 class TfLiteAbstractRecognizeCommands;
// Pure-virtual source hook implemented by audio generators.
// NOTE(review): readBytes() multiplies the returned value by sizeof(int16_t),
// so the result is presumably the number of samples stored in `data` - confirm.
44 virtual int read(int16_t*data,
int len) = 0;
// Pure-virtual sink hook that accepts one 16-bit sample per call; the stream's
// byte-oriented write() forwards each sample of its input buffer to this method.
56 virtual bool write(
const int16_t sample) = 0;
// Model data that is handed to tflite::GetModel() by setModel().
68 const unsigned char* model =
nullptr;
// Selects tflite::AllOpsResolver when the interpreter is created in setupInterpreter().
72 bool useAllOpsResolver =
false;
// Optional callback that receives the recognised label, its score and the
// new-command flag; it is only called when it is not nullptr.
74 void (*respondToCommand)(
const char* found_command, uint8_t score,
75 bool is_new_command) =
nullptr;
// Size in bytes of the uint8_t tensor arena allocated for the interpreter.
80 size_t kTensorArenaSize = 10 * 1024;
// Audio sample rate used by audioSampleSize() and strideSampleSize() to convert
// milliseconds into samples. NOTE(review): presumably in Hz - confirm.
90 int sample_rate = 16000;
// Values per feature slice (also used as the filterbank channel count) and the
// number of slices; their product is the feature element count.
98 int kFeatureSliceSize = 40;
99 int kFeatureSliceCount = 49;
// Window step and window length in milliseconds passed to the audio frontend.
100 int kFeatureSliceStrideMs = 20;
101 int kFeatureSliceDurationMs = 30;
// Number of slices that must accumulate before processSlices() is called.
104 int kSlicesToProcess = 2;
// Recognition post-processing parameters.
// NOTE(review): average_window_duration_ms and minimum_count are not referenced
// in the visible code - confirm how they are used.
107 int32_t average_window_duration_ms = 1000;
// A command is only flagged as new when its score is above this threshold ...
108 uint8_t detection_threshold = 50;
// ... and more than this many milliseconds passed since the last new command.
109 int32_t suppression_ms = 1500;
110 int32_t minimum_count = 3;
// Audio frontend settings; initializeMicroFeatures() copies each of them into
// the matching FrontendConfig field.
113 float filterbank_lower_band_limit = 125.0;
114 float filterbank_upper_band_limit = 7500.0;
115 float noise_reduction_smoothing_bits = 10;
116 float noise_reduction_even_smoothing = 0.025;
117 float noise_reduction_odd_smoothing = 0.06;
118 float noise_reduction_min_signal_remaining = 0.05;
119 bool pcan_gain_control_enable_pcan = 1;
120 float pcan_gain_control_strength = 0.95;
121 float pcan_gain_control_offset = 80.0;
122 float pcan_gain_control_gain_bits = 21;
123 bool log_scale_enable_log = 1;
124 uint8_t log_scale_scale_shift = 6;
// Returns kCategoryCount, the configured number of categories (0 by default).
133 int categoryCount() {
134 return kCategoryCount;
// Returns the total number of feature values: slice size times slice count.
// The int8 feature buffer is allocated with exactly this many elements.
137 int featureElementCount() {
138 return kFeatureSliceSize * kFeatureSliceCount;
141 int audioSampleSize() {
142 return kFeatureSliceDurationMs * (sample_rate / 1000);
145 int strideSampleSize() {
146 return kFeatureSliceStrideMs * (sample_rate / 1000);
// Value returned by categoryCount().
// NOTE(review): presumably the number of entries in `labels` - confirm.
150 int kCategoryCount = 0;
// Label strings indexed by category; the recognizer reports labels[maxIdx] as
// the found command and logs an error when this pointer is nullptr.
151 const char** labels =
nullptr;
/// Converts a float into the int8 representation of a quantized tensor.
///
/// @param value       real value to convert
/// @param scale       quantization scale of the tensor
/// @param zero_point  quantization zero point of the tensor
/// @return value / scale + zero_point, truncated toward zero and saturated to
///         the int8 range; when scale and zero_point are both 0 the value
///         itself is converted.
///
/// Converting a float that does not fit into int8_t is undefined behaviour, so
/// the result is clamped to [-128, 127] first. A NaN (e.g. 0 / 0 when scale is
/// 0 but zero_point is not) is mapped to 0.
static int8_t quantize(float value, float scale, float zero_point) {
  float q = value;
  if (!(scale == 0.0f && zero_point == 0.0f)) {
    q = value / scale + zero_point;
  }
  // NaN is the only value that is not equal to itself.
  if (q != q) {
    return 0;
  }
  if (q > 127.0f) {
    return 127;
  }
  if (q < -128.0f) {
    return -128;
  }
  return static_cast<int8_t>(q);
}
// Converts a quantized int8 value back to float as (value - zero_point) * scale.
// When scale and zero_point are both 0 the value is returned unchanged.
168 static float dequantize(int8_t value,
float scale,
float zero_point){
169 if(scale==0.0&&zero_point==0)
return value;
170 return (value - zero_point) * scale;
// Dequantizes `value` as (value - zero_point) * scale, multiplies the result by
// new_range and passes it through clip() with new_range as the limit.
173 static float dequantizeToNewRange(int8_t value,
float scale,
float zero_point,
float new_range){
174 float deq = (
static_cast<float>(value) - zero_point) * scale;
175 return clip(deq * new_range, new_range);
/// Limits `value` to the symmetric interval [-range, +range].
///
/// @param value  value to limit
/// @param range  non-negative magnitude of the allowed interval
/// @return range if value > range, -range if value < -range, else value
///
/// The previous negative-side test `-value < -range` is equivalent to
/// `value > range`, which can never hold for a negative value and a positive
/// range, so negative overshoot was returned unclipped.
static float clip(float value, float range) {
  if (value > range) {
    return range;
  }
  if (value < -range) {
    return -range;
  }
  return value;
}
// Pure-virtual recognition hook: receives the latest output tensor and the
// current time in milliseconds and reports the label, its score and the
// new-command flag through the three output pointers.
196 virtual TfLiteStatus getCommand(
const TfLiteTensor* latest_results,
const int32_t current_time_ms,
197 const char** found_command,uint8_t* score,
bool* is_new_command) = 0;
224 if (cfg.labels ==
nullptr) {
225 LOGE(
"config.labels not defined");
// Recognizer implementation of getCommand(): stores the current time, derives
// the time since the last reported command, appends the int8 results to
// result_queue, validates the tensor and returns the outcome of evaluate().
// NOTE(review): the loop that provides `idx` and the handling of a failed
// validation are not visible in this view.
232 virtual TfLiteStatus getCommand(
const TfLiteTensor* latest_results,
233 const int32_t current_time_ms,
234 const char** found_command,
236 bool* is_new_command)
override {
239 this->current_time_ms = current_time_ms;
240 this->time_since_last_top = current_time_ms - previous_time_ms;
// One queue entry per result index: time stamp, category index and int8 score.
244 Result row(current_time_ms, idx, latest_results->data.int8[idx]);
245 result_queue.push_back(row);
247 TfLiteStatus result =
validate(latest_results);
248 if (result!=kTfLiteOk){
251 return evaluate(found_command, score, is_new_command);
// Constructs one queue entry from a time stamp, a category index and a score.
// NOTE(review): the assignment of `score` is not visible in this view.
261 Result(int32_t time_ms,
int category, int8_t score){
262 this->time_ms = time_ms;
263 this->category = category;
// Category index that was last reported as a new command (-1 until then).
270 int previous_cateogory=-1;
// Time stamp passed to the most recent getCommand() call.
271 int32_t current_time_ms=0;
// Time stamp at which the last new command was reported.
272 int32_t previous_time_ms=0;
// current_time_ms - previous_time_ms, refreshed on every getCommand() call.
273 int32_t time_since_last_top=0;
278 uint8_t top_score = std::numeric_limits<uint8_t>::min();
280 if (score[j]>top_score){
289 return cfg.categoryCount();
294 if (result_queue.empty())
return;
295 while (result_queue[0].time_ms<limit){
296 result_queue.pop_front();
// Determines the result from the queued entries and reports it through the
// three output pointers: the label of category maxIdx, its score computed as
// totals[maxIdx] / count[maxIdx], and whether it counts as a new command.
// NOTE(review): the declarations of totals/count and the selection of maxIdx
// are not visible in this view.
301 TfLiteStatus
evaluate(
const char** found_command, uint8_t* result_score,
bool* is_new_command) {
// Accumulate the score of every queued entry in the total of its category.
306 for (
int j=0;j<result_queue.size();j++){
307 int idx = result_queue[j].category;
308 totals[idx] += result_queue[j].score;
323 LOGE(
"Could not find max category")
328 *result_score = totals[maxIdx] / count[maxIdx];
329 *found_command = cfg.labels[maxIdx];
// A command is new only if the category changed, the score is above the
// detection threshold and the suppression time has elapsed.
331 if (previous_cateogory!=maxIdx
332 && *result_score > cfg.detection_threshold
333 && time_since_last_top > cfg.suppression_ms){
334 previous_time_ms = current_time_ms;
335 previous_cateogory = maxIdx;
336 *is_new_command =
true;
338 *is_new_command =
false;
341 LOGD(
"Category: %s, score: %d, is_new: %d",*found_command, *result_score, *is_new_command);
// Sanity checks for the result tensor; each failed check logs an error.
// NOTE(review): the status values returned by the individual branches are not
// visible in this view.
347 TfLiteStatus
validate(
const TfLiteTensor* latest_results) {
// Shape check: the tensor must be 2-dimensional with a first dimension of 1.
348 if ((latest_results->dims->size != 2) ||
349 (latest_results->dims->data[0] != 1) ||
352 "The results for recognition should contain %d "
353 "elements, but there are "
354 "%d in an %d-dimensional shape",
356 (
int)latest_results->dims->size);
// Type check: only int8 results are accepted.
360 if (latest_results->type != kTfLiteInt8) {
361 LOGE(
"The results for recognition should be int8 elements, but are %d",
362 (
int)latest_results->type);
// Ordering check: the current time must not be earlier than the oldest queued entry.
366 if ((!result_queue.empty()) &&
367 (current_time_ms < result_queue[0].time_ms)) {
368 LOGE(
"Results must be in increasing time order: timestamp %d < %d",
369 (
int)current_time_ms, (
int)result_queue[0].time_ms);
// Installs the interpreter that the stream should use.
387 virtual void setInterpreter(tflite::MicroInterpreter* p_interpreter) = 0;
// NOTE(review): presumably the number of bytes write() can accept - confirm.
390 virtual int availableToWrite() = 0;
// Byte-oriented audio input of the stream.
393 virtual size_t write(
const uint8_t* data,
size_t len)= 0;
// Gives access to the interpreter; the speech writer uses it to invoke the
// model and to read output tensor 0.
394 virtual tflite::MicroInterpreter& interpreter()= 0;
414 if (p_buffer !=
nullptr)
delete p_buffer;
415 if (p_audio_samples !=
nullptr)
delete p_audio_samples;
421 this->parent = parent;
424 kMaxAudioSampleSize = cfg.audioSampleSize();
425 kStrideSampleSize = cfg.strideSampleSize();
426 kKeepSampleSize = kMaxAudioSampleSize - kStrideSampleSize;
428 if (!setup_recognizer()) {
429 LOGE(
"setup_recognizer");
434 TfLiteStatus init_status = initializeMicroFeatures();
435 if (init_status != kTfLiteOk) {
440 if (p_buffer ==
nullptr) {
442 LOGD(
"Allocating buffer for %d samples", kMaxAudioSampleSize);
446 if (p_feature_data ==
nullptr) {
447 p_feature_data =
new int8_t[cfg.featureElementCount()];
448 memset(p_feature_data, 0, cfg.featureElementCount());
452 if (p_audio_samples ==
nullptr) {
453 p_audio_samples =
new int16_t[kMaxAudioSampleSize];
454 memset(p_audio_samples, 0, kMaxAudioSampleSize *
sizeof(int16_t));
// Consumes one audio sample. In the visible code the time is advanced by one
// stride, a new feature slice is added via addSlice() and, once
// total_slice_count has reached cfg.kSlicesToProcess, processSlices() is run
// and the counter is reset.
// NOTE(review): the conditions that guard these steps and the place where
// total_slice_count is incremented are not visible in this view.
460 virtual bool write(int16_t sample) {
464 current_time += cfg.kFeatureSliceStrideMs;
468 int8_t* feature_buffer = addSlice();
469 if (total_slice_count >= cfg.kSlicesToProcess) {
470 processSlices(feature_buffer);
472 total_slice_count = 0;
// Owning stream; provides the interpreter and the model input buffer.
480 TfLiteAudioStreamBase *parent=
nullptr;
// Feature matrix with cfg.featureElementCount() int8 values (allocated with new[]).
481 int8_t* p_feature_data =
nullptr;
// Scratch array holding the kMaxAudioSampleSize samples of one window (allocated with new[]).
482 int16_t* p_audio_samples =
nullptr;
// State and configuration of the audio frontend, set up by initializeMicroFeatures().
484 FrontendState g_micro_features_state;
485 FrontendConfig config;
// Samples per window (cfg.audioSampleSize()) and per stride (cfg.strideSampleSize()).
486 int kMaxAudioSampleSize;
487 int kStrideSampleSize;
// Running time in milliseconds, advanced by cfg.kFeatureSliceStrideMs in write().
491 int32_t current_time = 0;
// Slices collected since the last call of processSlices().
492 int16_t total_slice_count = 0;
// Makes sure a recognizer is available: when none is configured, a
// function-local static TfLiteMicroSpeechRecognizeCommands is installed in cfg.
// Returns the result of the recognizer's begin(cfg).
494 virtual bool setup_recognizer() {
496 if (cfg.recognizeCommands ==
nullptr) {
497 static TfLiteMicroSpeechRecognizeCommands static_recognizer;
498 cfg.recognizeCommands = &static_recognizer;
500 return cfg.recognizeCommands->begin(cfg);
// Stores a sample in p_buffer: with one channel the sample is written as is,
// otherwise the mean of `sample` and last_value is written (each value is
// halved before the addition).
// NOTE(review): how last_value is maintained is not visible in this view.
504 virtual bool write1(
const int16_t sample) {
505 if (cfg.channels == 1) {
506 p_buffer->
write(sample);
513 p_buffer->
write(((sample / 2) + (last_value / 2)));
// Appends a new feature slice for the buffered audio window and returns the
// start of the feature matrix.
531 virtual int8_t* addSlice() {
// Shift the matrix up by one slice so that the last slot becomes free.
534 memmove(p_feature_data, p_feature_data + cfg.kFeatureSliceSize,
535 (cfg.kFeatureSliceCount - 1) * cfg.kFeatureSliceSize);
// Fetch one complete window of samples from the buffer.
538 int audio_samples_size =
539 p_buffer->
readArray(p_audio_samples, kMaxAudioSampleSize);
542 if (audio_samples_size != kMaxAudioSampleSize) {
543 LOGE(
"audio_samples_size=%d != kMaxAudioSampleSize=%d",
544 audio_samples_size, kMaxAudioSampleSize);
// Write back the samples after the first stride so that they are part of the next window.
548 p_buffer->
writeArray(p_audio_samples + kStrideSampleSize, kKeepSampleSize);
// The new features are generated directly into the last slice of the matrix.
551 int8_t* new_slice_data =
552 p_feature_data + ((cfg.kFeatureSliceCount - 1) * cfg.kFeatureSliceSize);
553 size_t num_samples_read = 0;
554 if (generateMicroFeatures(p_audio_samples, audio_samples_size,
555 new_slice_data, cfg.kFeatureSliceSize,
556 &num_samples_read) != kTfLiteOk) {
557 LOGE(
"Error generateMicroFeatures");
560 return p_feature_data;
// Runs the model on the feature matrix and forwards the output tensor to the
// configured recognizer.
// NOTE(review): the return statements and the declaration of `score` are not
// visible in this view.
564 virtual bool processSlices(int8_t* feature_buffer) {
565 LOGI(
"->slices: %d", total_slice_count);
// Copy all feature values into the model input buffer of the parent stream.
567 memcpy(parent->modelInputBuffer(), feature_buffer, cfg.featureElementCount());
570 TfLiteStatus invoke_status = parent->interpreter().Invoke();
571 if (invoke_status != kTfLiteOk) {
572 LOGE(
"Invoke failed");
// Output tensor 0 is evaluated by the recognizer.
577 TfLiteTensor* output = parent->interpreter().output(0);
580 const char* found_command =
nullptr;
582 bool is_new_command =
false;
584 TfLiteStatus process_status = cfg.recognizeCommands->getCommand(
585 output, current_time, &found_command, &score, &is_new_command);
586 if (process_status != kTfLiteOk) {
587 LOGE(
"TfLiteMicroSpeechRecognizeCommands::getCommand() failed");
599 for (
int i = 0; i < cfg.kFeatureSliceCount; i++) {
600 for (
int j = 0; j < cfg.kFeatureSliceSize; j++) {
601 Serial.print(p_feature_data[(i * cfg.kFeatureSliceSize) + j]);
606 Serial.println(
"------------");
// Copies the frontend related settings from cfg into the FrontendConfig and
// initializes g_micro_features_state with FrontendPopulateState(); a failure
// of that call is logged.
// NOTE(review): the remaining arguments of FrontendPopulateState() and the
// returned status values are not visible in this view.
609 virtual TfLiteStatus initializeMicroFeatures() {
// Window length and step in milliseconds.
611 config.window.size_ms = cfg.kFeatureSliceDurationMs;
612 config.window.step_size_ms = cfg.kFeatureSliceStrideMs;
// One filterbank channel per value of a feature slice.
613 config.filterbank.num_channels = cfg.kFeatureSliceSize;
614 config.filterbank.lower_band_limit = cfg.filterbank_lower_band_limit;
615 config.filterbank.upper_band_limit = cfg.filterbank_upper_band_limit;
616 config.noise_reduction.smoothing_bits = cfg.noise_reduction_smoothing_bits;
617 config.noise_reduction.even_smoothing = cfg.noise_reduction_even_smoothing;
618 config.noise_reduction.odd_smoothing = cfg.noise_reduction_odd_smoothing;
619 config.noise_reduction.min_signal_remaining = cfg.noise_reduction_min_signal_remaining;
620 config.pcan_gain_control.enable_pcan = cfg.pcan_gain_control_enable_pcan;
621 config.pcan_gain_control.strength = cfg.pcan_gain_control_strength;
622 config.pcan_gain_control.offset = cfg.pcan_gain_control_offset ;
623 config.pcan_gain_control.gain_bits = cfg.pcan_gain_control_gain_bits;
624 config.log_scale.enable_log = cfg.log_scale_enable_log;
625 config.log_scale.scale_shift = cfg.log_scale_scale_shift;
626 if (!FrontendPopulateState(&config, &g_micro_features_state,
628 LOGE(
"frontendPopulateState() failed");
// Passes `input_size` samples to the audio frontend and converts the frontend
// output into feature values.
// NOTE(review): the output_size parameter, the final scaling step and the
// store into `output` are not visible in this view.
634 virtual TfLiteStatus generateMicroFeatures(
const int16_t* input,
635 int input_size, int8_t* output,
637 size_t* num_samples_read) {
639 const int16_t* frontend_input = input;
// The frontend reports through num_samples_read how many samples it consumed.
642 FrontendOutput frontend_output = FrontendProcessSamples(
643 &g_micro_features_state, frontend_input, input_size, num_samples_read);
// The frontend must deliver exactly as many values as the caller expects.
646 if (output_size != frontend_output.size) {
647 LOGE(
"output_size=%d, frontend_output.size=%d", output_size,
648 frontend_output.size);
661 for (
size_t i = 0; i < frontend_output.size; ++i) {
// Scaling constants: 25.6 * 26 + 0.5 truncates to 666; value_div / 2 is added
// before the division, presumably to round to nearest - confirm.
675 constexpr int32_t value_scale = 256;
676 constexpr int32_t value_div =
677 static_cast<int32_t
>((25.6f * 26.0f) + 0.5f);
679 ((frontend_output.values[i] * value_scale) + (value_div / 2)) /
696 bool is_new_command) {
697 if (cfg.respondToCommand !=
nullptr) {
698 cfg.respondToCommand(found_command, score, is_new_command);
701 if (is_new_command) {
703 snprintf(buffer, 80,
"Result: %s, score: %d, is_new: %s", found_command,
704 score, is_new_command ?
"true" :
"false");
705 Serial.println(buffer);
720 this->increment = increment;
726 p_interpreter = &parent->interpreter();
727 input = p_interpreter->input(0);
728 output = p_interpreter->output(0);
729 channels = parent->
config().channels;
// Generates audio samples with the model: for every frame the value actX is
// quantized into element 0 of the input tensor, the interpreter is invoked and
// element 0 of the int8 output tensor is dequantized, scaled to `range` and
// stored in data[j]. The loop advances by `channels` samples per frame.
// NOTE(review): the computation of actX and range, the body of the channel
// loop and the return value are not visible in this view.
733 virtual int read(int16_t*data,
int sampleCount)
override {
735 float two_pi = 2 * PI;
736 for (
int j=0; j<sampleCount; j+=channels){
738 input->data.int8[0] = TfLiteQuantizer::quantize(actX,input->params.scale, input->params.zero_point);
741 TfLiteStatus invoke_status = p_interpreter->Invoke();
744 if(kTfLiteOk!= invoke_status){
745 LOGE(
"invoke_status not ok");
// Only int8 output tensors are supported.
748 if(kTfLiteInt8 != output->type){
749 LOGE(
"Output type is not kTfLiteInt8");
754 data[j] = TfLiteQuantizer::dequantizeToNewRange(output->data.int8[0], output->params.scale, output->params.zero_point, range);
756 LOGD(
"%f->%d / %d->%d",actX, input->data.int8[0], output->data.int8[0], data[j]);
// Handles the remaining channels of the frame.
757 for (
int i=1;i<channels;i++){
759 LOGD(
"generate data for channels");
// Input tensor 0 and output tensor 0 of the interpreter below.
775 TfLiteTensor* input =
nullptr;
776 TfLiteTensor* output =
nullptr;
// Interpreter obtained from the parent stream.
777 tflite::MicroInterpreter* p_interpreter =
nullptr;
790 if (p_tensor_arena !=
nullptr)
delete[] p_tensor_arena;
797 this->p_interpreter = p_interpreter;
812 p_tensor_arena =
new uint8_t[cfg.kTensorArenaSize];
814 if (cfg.categoryCount()>0){
817 if (!setupWriter()) {
822 LOGW(
"categoryCount=%d", cfg.categoryCount());
827 if (!setModel(cfg.model)) {
831 if (!setupInterpreter()) {
836 LOGI(
"AllocateTensors");
837 TfLiteStatus allocate_status = p_interpreter->AllocateTensors();
838 if (allocate_status != kTfLiteOk) {
839 LOGE(
"AllocateTensors() failed");
845 p_tensor = p_interpreter->input(0);
846 if (cfg.categoryCount()>0){
847 if ((p_tensor->dims->size != 2) || (p_tensor->dims->data[0] != 1) ||
848 (p_tensor->dims->data[1] !=
849 (cfg.kFeatureSliceCount * cfg.kFeatureSliceSize)) ||
850 (p_tensor->type != kTfLiteInt8)) {
851 LOGE(
"Bad input tensor parameters in model");
857 p_tensor_buffer = p_tensor->data.int8;
858 if (p_tensor_buffer ==
nullptr) {
859 LOGE(
"p_tensor_buffer is null");
864 if (cfg.reader!=
nullptr){
865 cfg.reader->begin(
this);
// Byte-oriented audio input of the stream. An error is logged when no writer
// is configured.
// NOTE(review): the value returned in the error case is not visible in this view.
878 virtual size_t write(
const uint8_t* data,
size_t len)
override {
880 if (cfg.writer==
nullptr){
881 LOGE(
"cfg.output is null");
884 int16_t* samples = (int16_t*)data;
885 int16_t sample_count = len / 2;
886 for (
int j = 0; j < sample_count; j++) {
887 cfg.writer->write(samples[j]);
893 virtual int available()
override {
return cfg.reader !=
nullptr ? DEFAULT_BUFFER_SIZE : 0; }
// Byte-oriented audio output of the stream: when a reader is configured, the
// request is converted from bytes to 16-bit samples, forwarded to the reader
// and the reader's result is converted back to bytes.
// NOTE(review): the value returned without a reader is not visible in this view.
896 virtual size_t readBytes(uint8_t *data,
size_t len)
override {
898 if (cfg.reader!=
nullptr){
899 return cfg.reader->read((int16_t*)data, (
int) len/
sizeof(int16_t)) *
sizeof(int16_t);
907 return *p_interpreter;
917 return p_tensor_buffer;
// Model obtained from tflite::GetModel() in setModel().
921 const tflite::Model* p_model =
nullptr;
// Interpreter in use; either injected or created by setupInterpreter().
922 tflite::MicroInterpreter* p_interpreter =
nullptr;
// Input tensor 0 of the interpreter.
923 TfLiteTensor* p_tensor =
nullptr;
// NOTE(review): not referenced in the visible code - confirm its purpose.
924 bool is_setup =
false;
// Tensor arena of cfg.kTensorArenaSize bytes (allocated with new[]).
929 uint8_t* p_tensor_arena =
nullptr;
// int8 data area of the input tensor.
930 int8_t* p_tensor_buffer =
nullptr;
// Loads the model from the given data via tflite::GetModel() and compares its
// schema version with TFLITE_SCHEMA_VERSION; a mismatch is reported with both
// version numbers.
// NOTE(review): the returned values are not visible in this view.
932 virtual bool setModel(
const unsigned char* model) {
934 p_model = tflite::GetModel(model);
935 if (p_model->version() != TFLITE_SCHEMA_VERSION) {
937 "Model provided is schema version %d not equal "
938 "to supported version %d.",
939 p_model->version(), TFLITE_SCHEMA_VERSION);
// Makes sure a writer is available: when none is configured, a function-local
// static TfLiteMicroSpeachWriter is installed in cfg. Returns the result of
// the writer's begin(this).
945 virtual bool setupWriter() {
946 if (cfg.writer ==
nullptr) {
947 static TfLiteMicroSpeachWriter writer;
948 cfg.writer = &writer;
950 return cfg.writer->begin(
this);
959 virtual bool setupInterpreter() {
960 if (p_interpreter ==
nullptr) {
962 if (cfg.useAllOpsResolver) {
963 tflite::AllOpsResolver resolver;
964 static tflite::MicroInterpreter static_interpreter{
965 p_model, resolver, p_tensor_arena, cfg.kTensorArenaSize};
966 p_interpreter = &static_interpreter;
969 static tflite::MicroMutableOpResolver<4> micro_op_resolver{};
970 if (micro_op_resolver.AddDepthwiseConv2D() != kTfLiteOk) {
973 if (micro_op_resolver.AddFullyConnected() != kTfLiteOk) {
976 if (micro_op_resolver.AddSoftmax() != kTfLiteOk) {
979 if (micro_op_resolver.AddReshape() != kTfLiteOk) {
983 static tflite::MicroInterpreter static_interpreter{
984 p_model, micro_op_resolver, p_tensor_arena, cfg.kTensorArenaSize};
985 p_interpreter = &static_interpreter;