8#include "Chirale_TensorFlowLite.h"
11#include "AudioTools/CoreAudio/BaseStream.h"
12#include "AudioTools/CoreAudio/AudioOutput.h"
13#include "AudioTools/CoreAudio/Buffers.h"
14#include "tensorflow/lite/c/common.h"
15#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h"
16#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h"
17#include "tensorflow/lite/micro/all_ops_resolver.h"
18#include "tensorflow/lite/micro/kernels/micro_ops.h"
19#include "tensorflow/lite/micro/micro_interpreter.h"
20#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
21#include "tensorflow/lite/micro/system_setup.h"
22#include "tensorflow/lite/schema/schema_generated.h"
34class TfLiteAudioStreamBase;
35class TfLiteAbstractRecognizeCommands;
46 virtual int read(int16_t*data,
int len) = 0;
58 virtual bool write(
const int16_t sample) = 0;
70 const unsigned char* model =
nullptr;
74 bool useAllOpsResolver =
false;
76 void (*respondToCommand)(
const char* found_command, uint8_t score,
77 bool is_new_command) =
nullptr;
82 size_t kTensorArenaSize = 10 * 1024;
92 int sample_rate = 16000;
100 int kFeatureSliceSize = 40;
101 int kFeatureSliceCount = 49;
102 int kFeatureSliceStrideMs = 20;
103 int kFeatureSliceDurationMs = 30;
106 int kSlicesToProcess = 2;
109 int32_t average_window_duration_ms = 1000;
110 uint8_t detection_threshold = 50;
111 int32_t suppression_ms = 1500;
112 int32_t minimum_count = 3;
115 float filterbank_lower_band_limit = 125.0;
116 float filterbank_upper_band_limit = 7500.0;
117 float noise_reduction_smoothing_bits = 10;
118 float noise_reduction_even_smoothing = 0.025;
119 float noise_reduction_odd_smoothing = 0.06;
120 float noise_reduction_min_signal_remaining = 0.05;
121 bool pcan_gain_control_enable_pcan = 1;
122 float pcan_gain_control_strength = 0.95;
123 float pcan_gain_control_offset = 80.0;
124 float pcan_gain_control_gain_bits = 21;
125 bool log_scale_enable_log = 1;
126 uint8_t log_scale_scale_shift = 6;
135 int categoryCount() {
136 return kCategoryCount;
139 int featureElementCount() {
140 return kFeatureSliceSize * kFeatureSliceCount;
143 int audioSampleSize() {
144 return kFeatureSliceDurationMs * (sample_rate / 1000);
147 int strideSampleSize() {
148 return kFeatureSliceStrideMs * (sample_rate / 1000);
152 int kCategoryCount = 0;
153 const char** labels =
nullptr;
/// Clamps `value` to the symmetric interval [-range, range].
/// @param value value to clamp
/// @param range half-width of the allowed interval; expected to be >= 0
/// @return value limited to [-range, range]
static float clip(float value, float range) {
  if (value > range) {
    return range;
  }
  if (value < -range) {
    return -range;
  }
  return value;
}

/// Maps a real value into the int8 domain of a quantized tensor:
/// round(value / scale + zero_point), saturated to [-128, 127].
/// A scale and zero_point of 0 mean "not quantized": the value itself is
/// rounded and saturated.
/// @param value      real value to encode
/// @param scale      tensor quantization scale
/// @param zero_point tensor quantization zero point
/// @return the nearest representable int8 value
static int8_t quantize(float value, float scale, float zero_point) {
  float q = (scale == 0.0f && zero_point == 0.0f) ? value
                                                  : value / scale + zero_point;
  // Converting NaN to an integer is undefined behaviour; map it to 0.
  if (q != q) {
    return 0;
  }
  // Round half away from zero; the static_cast below truncates toward zero.
  q = q < 0.0f ? q - 0.5f : q + 0.5f;
  // Converting an out-of-range float to int8_t is undefined behaviour,
  // so saturate before the cast.
  if (q >= 127.0f) {
    return 127;
  }
  if (q <= -128.0f) {
    return -128;
  }
  return static_cast<int8_t>(q);
}

/// Decodes an int8 tensor value: (value - zero_point) * scale.
/// A scale and zero_point of 0 mean "not quantized": the raw value is returned.
static float dequantize(int8_t value, float scale, float zero_point) {
  if (scale == 0.0f && zero_point == 0.0f) {
    return value;
  }
  return (value - zero_point) * scale;
}

/// Decodes an int8 tensor value, scales it by `new_range`, and clamps the
/// result to [-new_range, new_range].
/// Uses dequantize() so the "not quantized" convention applies here as well.
static float dequantizeToNewRange(int8_t value, float scale, float zero_point,
                                  float new_range) {
  return clip(dequantize(value, scale, zero_point) * new_range, new_range);
}
// Recognizer hook: interprets the latest model output and reports a command.
// @param latest_results  output tensor of the most recent model invocation
//                        (the recognizer implementation in this file rejects
//                        anything but a 2-D int8 tensor with leading dim 1)
// @param current_time_ms timestamp of these results in milliseconds
// @param found_command   out: label of the selected command
// @param score           out: score of the selected command
// @param is_new_command  out: whether the command is reported as newly detected
// @return kTfLiteOk on success; callers treat any other status as failure
198 virtual TfLiteStatus getCommand(
const TfLiteTensor* latest_results,
const int32_t current_time_ms,
199 const char** found_command,uint8_t* score,
bool* is_new_command) = 0;
226 if (cfg.labels ==
nullptr) {
227 LOGE(
"config.labels not defined");
234 virtual TfLiteStatus getCommand(
const TfLiteTensor* latest_results,
235 const int32_t current_time_ms,
236 const char** found_command,
238 bool* is_new_command)
override {
241 this->current_time_ms = current_time_ms;
242 this->time_since_last_top = current_time_ms - previous_time_ms;
246 Result row(current_time_ms, idx, latest_results->data.int8[idx]);
247 result_queue.push_back(row);
249 TfLiteStatus result =
validate(latest_results);
250 if (result!=kTfLiteOk){
253 return evaluate(found_command, score, is_new_command);
263 Result(int32_t time_ms,
int category, int8_t score){
264 this->time_ms = time_ms;
265 this->category = category;
271 Vector <Result> result_queue;
272 int previous_cateogory=-1;
273 int32_t current_time_ms=0;
274 int32_t previous_time_ms=0;
275 int32_t time_since_last_top=0;
280 uint8_t top_score = std::numeric_limits<uint8_t>::min();
282 if (score[j]>top_score){
291 return cfg.categoryCount();
296 if (result_queue.empty())
return;
297 while (result_queue[0].time_ms<limit){
298 result_queue.pop_front();
303 TfLiteStatus
evaluate(
const char** found_command, uint8_t* result_score,
bool* is_new_command) {
308 for (
int j=0;j<result_queue.size();j++){
309 int idx = result_queue[j].category;
310 totals[idx] += result_queue[j].score;
325 LOGE(
"Could not find max category")
330 *result_score = totals[maxIdx] / count[maxIdx];
331 *found_command = cfg.labels[maxIdx];
333 if (previous_cateogory!=maxIdx
334 && *result_score > cfg.detection_threshold
335 && time_since_last_top > cfg.suppression_ms){
336 previous_time_ms = current_time_ms;
337 previous_cateogory = maxIdx;
338 *is_new_command =
true;
340 *is_new_command =
false;
343 LOGD(
"Category: %s, score: %d, is_new: %d",*found_command, *result_score, *is_new_command);
349 TfLiteStatus
validate(
const TfLiteTensor* latest_results) {
350 if ((latest_results->dims->size != 2) ||
351 (latest_results->dims->data[0] != 1) ||
354 "The results for recognition should contain %d "
355 "elements, but there are "
356 "%d in an %d-dimensional shape",
358 (
int)latest_results->dims->size);
362 if (latest_results->type != kTfLiteInt8) {
363 LOGE(
"The results for recognition should be int8 elements, but are %d",
364 (
int)latest_results->type);
368 if ((!result_queue.empty()) &&
369 (current_time_ms < result_queue[0].time_ms)) {
370 LOGE(
"Results must be in increasing time order: timestamp %d < %d",
371 (
int)current_time_ms, (
int)result_queue[0].time_ms);
// Pure-virtual interface used by the reader/writer helpers in this file.
// NOTE(review): presumably members of TfLiteAudioStreamBase — the class
// header is not visible here; confirm.

// Hands the interpreter to the stream.
389 virtual void setInterpreter(tflite::MicroInterpreter* p_interpreter) = 0;
// Capacity available for write(); NOTE(review): unit (bytes vs samples) not shown here.
392 virtual int availableToWrite() = 0;
// Accepts raw audio bytes; the implementation in this file reinterprets them
// as int16_t samples (len / 2 of them) and forwards each to cfg.writer.
395 virtual size_t write(
const uint8_t* data,
size_t len)= 0;
// Access to the interpreter; helpers call Invoke(), input(0) and output(0) on it.
396 virtual tflite::MicroInterpreter& interpreter()= 0;
416 if (p_buffer !=
nullptr)
delete p_buffer;
417 if (p_audio_samples !=
nullptr)
delete p_audio_samples;
423 this->parent = parent;
426 kMaxAudioSampleSize = cfg.audioSampleSize();
427 kStrideSampleSize = cfg.strideSampleSize();
428 kKeepSampleSize = kMaxAudioSampleSize - kStrideSampleSize;
430 if (!setup_recognizer()) {
431 LOGE(
"setup_recognizer");
436 TfLiteStatus init_status = initializeMicroFeatures();
437 if (init_status != kTfLiteOk) {
442 if (p_buffer ==
nullptr) {
444 LOGD(
"Allocating buffer for %d samples", kMaxAudioSampleSize);
448 if (p_feature_data ==
nullptr) {
449 p_feature_data =
new int8_t[cfg.featureElementCount()];
450 memset(p_feature_data, 0, cfg.featureElementCount());
454 if (p_audio_samples ==
nullptr) {
455 p_audio_samples =
new int16_t[kMaxAudioSampleSize];
456 memset(p_audio_samples, 0, kMaxAudioSampleSize *
sizeof(int16_t));
462 virtual bool write(int16_t sample) {
466 current_time += cfg.kFeatureSliceStrideMs;
470 int8_t* feature_buffer = addSlice();
471 if (total_slice_count >= cfg.kSlicesToProcess) {
472 processSlices(feature_buffer);
474 total_slice_count = 0;
482 TfLiteAudioStreamBase *parent=
nullptr;
483 int8_t* p_feature_data =
nullptr;
484 int16_t* p_audio_samples =
nullptr;
486 FrontendState g_micro_features_state;
487 FrontendConfig config;
488 int kMaxAudioSampleSize;
489 int kStrideSampleSize;
493 int32_t current_time = 0;
494 int16_t total_slice_count = 0;
// Makes sure cfg.recognizeCommands is set, then initializes it with cfg.
// When no recognizer was configured, a function-local static
// TfLiteMicroSpeechRecognizeCommands is installed as the default.
// NOTE(review): that static is a single object shared by every instance that
// falls back to it, so such streams would share one recognizer's state.
// @return the result of cfg.recognizeCommands->begin(cfg)
496 virtual bool setup_recognizer() {
498 if (cfg.recognizeCommands ==
nullptr) {
499 static TfLiteMicroSpeechRecognizeCommands static_recognizer;
500 cfg.recognizeCommands = &static_recognizer;
502 return cfg.recognizeCommands->begin(cfg);
506 virtual bool write1(
const int16_t sample) {
507 if (cfg.channels == 1) {
508 p_buffer->
write(sample);
515 p_buffer->
write(((sample / 2) + (last_value / 2)));
533 virtual int8_t* addSlice() {
536 memmove(p_feature_data, p_feature_data + cfg.kFeatureSliceSize,
537 (cfg.kFeatureSliceCount - 1) * cfg.kFeatureSliceSize);
540 int audio_samples_size =
541 p_buffer->
readArray(p_audio_samples, kMaxAudioSampleSize);
544 if (audio_samples_size != kMaxAudioSampleSize) {
545 LOGE(
"audio_samples_size=%d != kMaxAudioSampleSize=%d",
546 audio_samples_size, kMaxAudioSampleSize);
550 p_buffer->
writeArray(p_audio_samples + kStrideSampleSize, kKeepSampleSize);
553 int8_t* new_slice_data =
554 p_feature_data + ((cfg.kFeatureSliceCount - 1) * cfg.kFeatureSliceSize);
555 size_t num_samples_read = 0;
556 if (generateMicroFeatures(p_audio_samples, audio_samples_size,
557 new_slice_data, cfg.kFeatureSliceSize,
558 &num_samples_read) != kTfLiteOk) {
559 LOGE(
"Error generateMicroFeatures");
562 return p_feature_data;
566 virtual bool processSlices(int8_t* feature_buffer) {
567 LOGI(
"->slices: %d", total_slice_count);
569 memcpy(parent->modelInputBuffer(), feature_buffer, cfg.featureElementCount());
572 TfLiteStatus invoke_status = parent->interpreter().Invoke();
573 if (invoke_status != kTfLiteOk) {
574 LOGE(
"Invoke failed");
579 TfLiteTensor* output = parent->interpreter().output(0);
582 const char* found_command =
nullptr;
584 bool is_new_command =
false;
586 TfLiteStatus process_status = cfg.recognizeCommands->getCommand(
587 output, current_time, &found_command, &score, &is_new_command);
588 if (process_status != kTfLiteOk) {
589 LOGE(
"TfLiteMicroSpeechRecognizeCommands::getCommand() failed");
601 for (
int i = 0; i < cfg.kFeatureSliceCount; i++) {
602 for (
int j = 0; j < cfg.kFeatureSliceSize; j++) {
603 Serial.print(p_feature_data[(i * cfg.kFeatureSliceSize) + j]);
608 Serial.println(
"------------");
611 virtual TfLiteStatus initializeMicroFeatures() {
613 config.window.size_ms = cfg.kFeatureSliceDurationMs;
614 config.window.step_size_ms = cfg.kFeatureSliceStrideMs;
615 config.filterbank.num_channels = cfg.kFeatureSliceSize;
616 config.filterbank.lower_band_limit = cfg.filterbank_lower_band_limit;
617 config.filterbank.upper_band_limit = cfg.filterbank_upper_band_limit;
618 config.noise_reduction.smoothing_bits = cfg.noise_reduction_smoothing_bits;
619 config.noise_reduction.even_smoothing = cfg.noise_reduction_even_smoothing;
620 config.noise_reduction.odd_smoothing = cfg.noise_reduction_odd_smoothing;
621 config.noise_reduction.min_signal_remaining = cfg.noise_reduction_min_signal_remaining;
622 config.pcan_gain_control.enable_pcan = cfg.pcan_gain_control_enable_pcan;
623 config.pcan_gain_control.strength = cfg.pcan_gain_control_strength;
624 config.pcan_gain_control.offset = cfg.pcan_gain_control_offset ;
625 config.pcan_gain_control.gain_bits = cfg.pcan_gain_control_gain_bits;
626 config.log_scale.enable_log = cfg.log_scale_enable_log;
627 config.log_scale.scale_shift = cfg.log_scale_scale_shift;
628 if (!FrontendPopulateState(&config, &g_micro_features_state,
630 LOGE(
"frontendPopulateState() failed");
636 virtual TfLiteStatus generateMicroFeatures(
const int16_t* input,
637 int input_size, int8_t* output,
639 size_t* num_samples_read) {
641 const int16_t* frontend_input = input;
644 FrontendOutput frontend_output = FrontendProcessSamples(
645 &g_micro_features_state, frontend_input, input_size, num_samples_read);
648 if (output_size != frontend_output.size) {
649 LOGE(
"output_size=%d, frontend_output.size=%d", output_size,
650 frontend_output.size);
663 for (
size_t i = 0; i < frontend_output.size; ++i) {
677 constexpr int32_t value_scale = 256;
678 constexpr int32_t value_div =
679 static_cast<int32_t
>((25.6f * 26.0f) + 0.5f);
681 ((frontend_output.values[i] * value_scale) + (value_div / 2)) /
698 bool is_new_command) {
699 if (cfg.respondToCommand !=
nullptr) {
700 cfg.respondToCommand(found_command, score, is_new_command);
703 if (is_new_command) {
705 snprintf(buffer, 80,
"Result: %s, score: %d, is_new: %s", found_command,
706 score, is_new_command ?
"true" :
"false");
707 Serial.println(buffer);
722 this->increment = increment;
728 p_interpreter = &parent->interpreter();
729 input = p_interpreter->input(0);
730 output = p_interpreter->output(0);
731 channels = parent->
config().channels;
735 virtual int read(int16_t*data,
int sampleCount)
override {
737 float two_pi = 2 * PI;
738 for (
int j=0; j<sampleCount; j+=channels){
740 input->data.int8[0] = TfLiteQuantizer::quantize(actX,input->params.scale, input->params.zero_point);
743 TfLiteStatus invoke_status = p_interpreter->Invoke();
746 if(kTfLiteOk!= invoke_status){
747 LOGE(
"invoke_status not ok");
750 if(kTfLiteInt8 != output->type){
751 LOGE(
"Output type is not kTfLiteInt8");
756 data[j] = TfLiteQuantizer::dequantizeToNewRange(output->data.int8[0], output->params.scale, output->params.zero_point, range);
758 LOGD(
"%f->%d / %d->%d",actX, input->data.int8[0], output->data.int8[0], data[j]);
759 for (
int i=1;i<channels;i++){
761 LOGD(
"generate data for channels");
777 TfLiteTensor* input =
nullptr;
778 TfLiteTensor* output =
nullptr;
779 tflite::MicroInterpreter* p_interpreter =
nullptr;
792 if (p_tensor_arena !=
nullptr)
delete[] p_tensor_arena;
799 this->p_interpreter = p_interpreter;
814 p_tensor_arena =
new uint8_t[cfg.kTensorArenaSize];
816 if (cfg.categoryCount()>0){
819 if (!setupWriter()) {
824 LOGW(
"categoryCount=%d", cfg.categoryCount());
829 if (!setModel(cfg.model)) {
833 if (!setupInterpreter()) {
838 LOGI(
"AllocateTensors");
839 TfLiteStatus allocate_status = p_interpreter->AllocateTensors();
840 if (allocate_status != kTfLiteOk) {
841 LOGE(
"AllocateTensors() failed");
847 p_tensor = p_interpreter->input(0);
848 if (cfg.categoryCount()>0){
849 if ((p_tensor->dims->size != 2) || (p_tensor->dims->data[0] != 1) ||
850 (p_tensor->dims->data[1] !=
851 (cfg.kFeatureSliceCount * cfg.kFeatureSliceSize)) ||
852 (p_tensor->type != kTfLiteInt8)) {
853 LOGE(
"Bad input tensor parameters in model");
859 p_tensor_buffer = p_tensor->data.int8;
860 if (p_tensor_buffer ==
nullptr) {
861 LOGE(
"p_tensor_buffer is null");
866 if (cfg.reader!=
nullptr){
867 cfg.reader->begin(
this);
880 virtual size_t write(
const uint8_t* data,
size_t len)
override {
882 if (cfg.writer==
nullptr){
883 LOGE(
"cfg.output is null");
886 int16_t* samples = (int16_t*)data;
887 int16_t sample_count = len / 2;
888 for (
int j = 0; j < sample_count; j++) {
889 cfg.writer->write(samples[j]);
895 virtual int available()
override {
return cfg.reader !=
nullptr ? DEFAULT_BUFFER_SIZE : 0; }
898 virtual size_t readBytes(uint8_t *data,
size_t len)
override {
900 if (cfg.reader!=
nullptr){
901 return cfg.reader->read((int16_t*)data, (
int) len/
sizeof(int16_t)) *
sizeof(int16_t);
909 return *p_interpreter;
919 return p_tensor_buffer;
923 const tflite::Model* p_model =
nullptr;
924 tflite::MicroInterpreter* p_interpreter =
nullptr;
925 TfLiteTensor* p_tensor =
nullptr;
926 bool is_setup =
false;
931 uint8_t* p_tensor_arena =
nullptr;
932 int8_t* p_tensor_buffer =
nullptr;
934 virtual bool setModel(
const unsigned char* model) {
936 p_model = tflite::GetModel(model);
937 if (p_model->version() != TFLITE_SCHEMA_VERSION) {
939 "Model provided is schema version %d not equal "
940 "to supported version %d.",
941 p_model->version(), TFLITE_SCHEMA_VERSION);
947 virtual bool setupWriter() {
948 if (cfg.writer ==
nullptr) {
949 static TfLiteMicroSpeachWriter writer;
950 cfg.writer = &writer;
952 return cfg.writer->begin(
this);
961 virtual bool setupInterpreter() {
962 if (p_interpreter ==
nullptr) {
964 if (cfg.useAllOpsResolver) {
965 tflite::AllOpsResolver resolver;
966 static tflite::MicroInterpreter static_interpreter{
967 p_model, resolver, p_tensor_arena, cfg.kTensorArenaSize};
968 p_interpreter = &static_interpreter;
971 static tflite::MicroMutableOpResolver<4> micro_op_resolver{};
972 if (micro_op_resolver.AddDepthwiseConv2D() != kTfLiteOk) {
975 if (micro_op_resolver.AddFullyConnected() != kTfLiteOk) {
978 if (micro_op_resolver.AddSoftmax() != kTfLiteOk) {
981 if (micro_op_resolver.AddReshape() != kTfLiteOk) {
985 static tflite::MicroInterpreter static_interpreter{
986 p_model, micro_op_resolver, p_tensor_arena, cfg.kTensorArenaSize};
987 p_interpreter = &static_interpreter;