arduino-audio-tools
All Classes Namespaces Files Functions Variables Typedefs Enumerations Friends Modules Pages
TfLiteAudioStream.h
1#pragma once
2
3// Configure FFT to output 16 bit fixed point.
4#define FIXED_POINT 16
5
6//#include <MicroTFLite.h>
7#include <TensorFlowLite.h>
8#include <cmath>
9#include <cstdint>
10#include "AudioTools/CoreAudio/BaseStream.h"
11#include "AudioTools/CoreAudio/AudioOutput.h"
12#include "AudioTools/CoreAudio/Buffers.h"
13#include "tensorflow/lite/c/common.h"
14#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h"
15#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h"
16#include "tensorflow/lite/micro/all_ops_resolver.h"
17#include "tensorflow/lite/micro/kernels/micro_ops.h"
18#include "tensorflow/lite/micro/micro_interpreter.h"
19#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
20#include "tensorflow/lite/micro/system_setup.h"
21#include "tensorflow/lite/schema/schema_generated.h"
22
30namespace audio_tools {
31
32// Forward Declarations
33class TfLiteAudioStreamBase;
34class TfLiteAbstractRecognizeCommands;
35
43 public:
44 virtual bool begin(TfLiteAudioStreamBase *parent) = 0;
45 virtual int read(int16_t*data, int len) = 0;
46};
47
55 public:
56 virtual bool begin(TfLiteAudioStreamBase *parent) = 0;
57 virtual bool write(const int16_t sample) = 0;
58};
59
69 const unsigned char* model = nullptr;
70 TfLiteReader *reader = nullptr;
71 TfLiteWriter *writer = nullptr;
72 TfLiteAbstractRecognizeCommands *recognizeCommands=nullptr;
73 bool useAllOpsResolver = false;
74 // callback for command handler
75 void (*respondToCommand)(const char* found_command, uint8_t score,
76 bool is_new_command) = nullptr;
77
78 // Create an area of memory to use for input, output, and intermediate arrays.
79 // The size of this will depend on the model you’re using, and may need to be
80 // determined by experimentation.
81 size_t kTensorArenaSize = 10 * 1024;
82
83 // Keeping these as constant expressions allow us to allocate fixed-sized
84 // arrays on the stack for our working memory.
85
86 // The size of the input time series data we pass to the FFT to produce
87 // the frequency information. This has to be a power of two, and since
88 // we're dealing with 30ms of 16KHz inputs, which means 480 samples, this
89 // is the next value.
90 // int kMaxAudioSampleSize = 320; //512; // 480
91 int sample_rate = 16000;
92
93 // Number of audio channels - is usually 1. If 2 we reduce it to 1 by
94 // averaging the 2 channels
95 int channels = 1;
96
97 // The following values are derived from values used during model training.
98 // If you change the way you preprocess the input, update all these constants.
99 int kFeatureSliceSize = 40;
100 int kFeatureSliceCount = 49;
101 int kFeatureSliceStrideMs = 20;
102 int kFeatureSliceDurationMs = 30;
103
104 // number of new slices to collect before evaluating the model
105 int kSlicesToProcess = 2;
106
107 // Parameters for RecognizeCommands
108 int32_t average_window_duration_ms = 1000;
109 uint8_t detection_threshold = 50;
110 int32_t suppression_ms = 1500;
111 int32_t minimum_count = 3;
112
113 // input for FrontendConfig
114 float filterbank_lower_band_limit = 125.0;
115 float filterbank_upper_band_limit = 7500.0;
116 float noise_reduction_smoothing_bits = 10;
117 float noise_reduction_even_smoothing = 0.025;
118 float noise_reduction_odd_smoothing = 0.06;
119 float noise_reduction_min_signal_remaining = 0.05;
120 bool pcan_gain_control_enable_pcan = 1;
121 float pcan_gain_control_strength = 0.95;
122 float pcan_gain_control_offset = 80.0;
123 float pcan_gain_control_gain_bits = 21;
124 bool log_scale_enable_log = 1;
125 uint8_t log_scale_scale_shift = 6;
126
128 template<int N>
129 void setCategories(const char* (&array)[N]){
130 labels = array;
131 kCategoryCount = N;
132 }
133
  /// Provides the number of defined labels/categories.
  int categoryCount() {
    return kCategoryCount;
  }
137
  /// Total number of int8 feature values (slice size x slice count).
  int featureElementCount() {
    return kFeatureSliceSize * kFeatureSliceCount;
  }
141
  /// Number of samples per feature slice: duration in ms times samples per ms.
  int audioSampleSize() {
    return kFeatureSliceDurationMs * (sample_rate / 1000);
  }
145
  /// Number of samples by which consecutive feature slices are shifted.
  int strideSampleSize() {
    return kFeatureSliceStrideMs * (sample_rate / 1000);
  }
149
150 private:
151 int kCategoryCount = 0;
152 const char** labels = nullptr;
153};
154
162 public:
163 // convert float to int8
164 static int8_t quantize(float value, float scale, float zero_point){
165 if(scale==0.0&&zero_point==0) return value;
166 return value / scale + zero_point;
167 }
168 // convert int8 to float
169 static float dequantize(int8_t value, float scale, float zero_point){
170 if(scale==0.0&&zero_point==0) return value;
171 return (value - zero_point) * scale;
172 }
173
174 static float dequantizeToNewRange(int8_t value, float scale, float zero_point, float new_range){
175 float deq = (static_cast<float>(value) - zero_point) * scale;
176 return clip(deq * new_range, new_range);
177 }
178
179 static float clip(float value, float range){
180 if (value>=0.0){
181 return value > range ? range : value;
182 } else {
183 return -value < -range ? -range : value;
184 }
185 }
186};
187
195 public:
196 virtual bool begin(TfLiteConfig cfg) = 0;
197 virtual TfLiteStatus getCommand(const TfLiteTensor* latest_results, const int32_t current_time_ms,
198 const char** found_command,uint8_t* score,bool* is_new_command) = 0;
199
200};
201
216 public:
217
219 }
220
222 bool begin(TfLiteConfig cfg) override {
223 TRACED();
224 this->cfg = cfg;
225 if (cfg.labels == nullptr) {
226 LOGE("config.labels not defined");
227 return false;
228 }
229 return true;
230 }
231
  // Call this with the results of running a model on sample data.
  /// Records the latest model output in the sliding result window, then
  /// evaluates the window to determine the recognized command.
  /// @param latest_results int8 output tensor of the model
  /// @param current_time_ms timestamp of this result; must be increasing
  /// @param found_command receives the label of the winning category
  /// @param score receives the averaged score of the winning category
  /// @param is_new_command true only when a new command was detected
  virtual TfLiteStatus getCommand(const TfLiteTensor* latest_results,
                                  const int32_t current_time_ms,
                                  const char** found_command,
                                  uint8_t* score,
                                  bool* is_new_command) override {

    TRACED();
    this->current_time_ms = current_time_ms;
    // elapsed time since the last reported top result (used for suppression)
    this->time_since_last_top = current_time_ms - previous_time_ms;

    // drop results that fell out of the averaging window
    deleteOldRecords(current_time_ms - cfg.average_window_duration_ms);
    // queue the best category of the latest result
    int idx = resultCategoryIdx(latest_results->data.int8);
    Result row(current_time_ms, idx, latest_results->data.int8[idx]);
    result_queue.push_back(row);

    // NOTE(review): validation happens after the result was already queued,
    // so an invalid result still ends up in the window - confirm intended
    TfLiteStatus result = validate(latest_results);
    if (result!=kTfLiteOk){
      return result;
    }
    return evaluate(found_command, score, is_new_command);
  }
254
255 protected:
256 struct Result {
257 int32_t time_ms;
258 int category=0;
259 int8_t score=0;
260
261 Result() = default;
262 Result(int32_t time_ms,int category, int8_t score){
263 this->time_ms = time_ms;
264 this->category = category;
265 this->score = score;
266 }
267 };
268
269 TfLiteConfig cfg;
270 Vector <Result> result_queue;
271 int previous_cateogory=-1;
272 int32_t current_time_ms=0;
273 int32_t previous_time_ms=0;
274 int32_t time_since_last_top=0;
275
277 int resultCategoryIdx(int8_t* score) {
278 int result = -1;
279 uint8_t top_score = std::numeric_limits<uint8_t>::min();
280 for (int j=0;j<categoryCount();j++){
281 if (score[j]>top_score){
282 result = j;
283 }
284 }
285 return result;
286 }
287
290 return cfg.categoryCount();
291 }
292
294 void deleteOldRecords(int32_t limit) {
295 if (result_queue.empty()) return;
296 while (result_queue[0].time_ms<limit){
297 result_queue.pop_front();
298 }
299 }
300
  /// Finds the overall result by averaging the queued per-category scores.
  TfLiteStatus evaluate(const char** found_command, uint8_t* result_score, bool* is_new_command) {
    TRACED();
    // per-category score sums and counts over the averaging window
    // (variable-length arrays rely on a compiler extension - GCC/Clang)
    float totals[categoryCount()]={0};
    int count[categoryCount()]={0};
    // calculate totals
    for (int j=0;j<result_queue.size();j++){
      int idx = result_queue[j].category;
      totals[idx] += result_queue[j].score;
      count[idx]++;
    }

    // find the category with the highest total
    int maxIdx = -1;
    float max = -100000;
    for (int j=0;j<categoryCount();j++){
      if (totals[j]>max){
        max = totals[j];
        maxIdx = j;
      }
    }

    if (maxIdx==-1){
      LOGE("Could not find max category")
      return kTfLiteError;
    }

    // determine result: average score of the winning category
    *result_score = totals[maxIdx] / count[maxIdx];
    *found_command = cfg.labels[maxIdx];

    // report a new command only if the category changed, the score passes the
    // threshold and enough time has elapsed since the last detection
    if (previous_cateogory!=maxIdx
        && *result_score > cfg.detection_threshold
        && time_since_last_top > cfg.suppression_ms){
      previous_time_ms = current_time_ms;
      previous_cateogory = maxIdx;
      *is_new_command = true;
    } else {
      *is_new_command = false;
    }

    LOGD("Category: %s, score: %d, is_new: %d",*found_command, *result_score, *is_new_command);

    return kTfLiteOk;
  }
346
  /// Checks the input data: tensor shape, element type and time ordering.
  TfLiteStatus validate(const TfLiteTensor* latest_results) {
    // the model output must be a [1 x categoryCount] tensor
    if ((latest_results->dims->size != 2) ||
        (latest_results->dims->data[0] != 1) ||
        (latest_results->dims->data[1] != categoryCount())) {
      LOGE(
          "The results for recognition should contain %d "
          "elements, but there are "
          "%d in an %d-dimensional shape",
          categoryCount(), (int)latest_results->dims->data[1],
          (int)latest_results->dims->size);
      return kTfLiteError;
    }

    if (latest_results->type != kTfLiteInt8) {
      LOGE("The results for recognition should be int8 elements, but are %d",
           (int)latest_results->type);
      return kTfLiteError;
    }

    // timestamps must be monotonically increasing
    if ((!result_queue.empty()) &&
        (current_time_ms < result_queue[0].time_ms)) {
      LOGE("Results must be in increasing time order: timestamp %d < %d",
           (int)current_time_ms, (int)result_queue[0].time_ms);
      return kTfLiteError;
    }
    return kTfLiteOk;
  }
375
376};
377
378
387 public:
388 virtual void setInterpreter(tflite::MicroInterpreter* p_interpreter) = 0;
389 virtual TfLiteConfig defaultConfig() = 0;
390 virtual bool begin(TfLiteConfig config) = 0;
391 virtual int availableToWrite() = 0;
392
394 virtual size_t write(const uint8_t* data, size_t len)= 0;
395 virtual tflite::MicroInterpreter& interpreter()= 0;
396
398 virtual TfLiteConfig &config()= 0;
399
401 virtual int8_t* modelInputBuffer()= 0;
402};
403
411 public:
412 TfLiteMicroSpeachWriter() = default;
413
415 if (p_buffer != nullptr) delete p_buffer;
416 if (p_audio_samples != nullptr) delete p_audio_samples;
417 }
418
  /// Call begin before starting the processing: sets up the recognizer and
  /// the audio frontend, and allocates the working buffers.
  virtual bool begin(TfLiteAudioStreamBase *parent) {
    TRACED();
    this->parent = parent;
    cfg = parent->config();
    current_time = 0;
    // derived sample counts: per slice, per stride and the overlap to keep
    kMaxAudioSampleSize = cfg.audioSampleSize();
    kStrideSampleSize = cfg.strideSampleSize();
    kKeepSampleSize = kMaxAudioSampleSize - kStrideSampleSize;

    if (!setup_recognizer()) {
      LOGE("setup_recognizer");
      return false;
    }

    // setup FrontendConfig
    TfLiteStatus init_status = initializeMicroFeatures();
    if (init_status != kTfLiteOk) {
      return false;
    }

    // Allocate ring buffer (lazily; kept across repeated begin() calls)
    if (p_buffer == nullptr) {
      p_buffer = new audio_tools::RingBuffer<int16_t>(kMaxAudioSampleSize);
      LOGD("Allocating buffer for %d samples", kMaxAudioSampleSize);
    }

    // Initialize the feature data to default values.
    if (p_feature_data == nullptr) {
      p_feature_data = new int8_t[cfg.featureElementCount()];
      memset(p_feature_data, 0, cfg.featureElementCount());
    }

    // allocate p_audio_samples
    if (p_audio_samples == nullptr) {
      p_audio_samples = new int16_t[kMaxAudioSampleSize];
      memset(p_audio_samples, 0, kMaxAudioSampleSize * sizeof(int16_t));
    }

    return true;
  }
460
  /// Feeds a single sample; once the ring buffer is full a new feature slice
  /// is generated, and after cfg.kSlicesToProcess slices the model is run.
  virtual bool write(int16_t sample) {
    TRACED();
    // write1 returns false when the buffer has no more room -> process it
    if (!write1(sample)){
      // advance time by one stride
      current_time += cfg.kFeatureSliceStrideMs;
      // count the new slice
      total_slice_count++;

      int8_t* feature_buffer = addSlice();
      if (total_slice_count >= cfg.kSlicesToProcess) {
        processSlices(feature_buffer);
        // reset total_slice_count
        total_slice_count = 0;
      }
    }
    return true;
  }
478
479 protected:
480 TfLiteConfig cfg;
481 TfLiteAudioStreamBase *parent=nullptr;
482 int8_t* p_feature_data = nullptr;
483 int16_t* p_audio_samples = nullptr;
484 audio_tools::RingBuffer<int16_t>* p_buffer = nullptr;
485 FrontendState g_micro_features_state;
486 FrontendConfig config;
487 int kMaxAudioSampleSize;
488 int kStrideSampleSize;
489 int kKeepSampleSize;
490 int16_t last_value;
491 int8_t channel = 0;
492 int32_t current_time = 0;
493 int16_t total_slice_count = 0;
494
  /// Sets up the default recognizer if none was provided in the config.
  virtual bool setup_recognizer() {
    // setup default p_recognizer if not defined
    if (cfg.recognizeCommands == nullptr) {
      // NOTE(review): a single static instance is shared by all writers
      static TfLiteMicroSpeechRecognizeCommands static_recognizer;
      cfg.recognizeCommands = &static_recognizer;
    }
    return cfg.recognizeCommands->begin(cfg);
  }
503
505 virtual bool write1(const int16_t sample) {
506 if (cfg.channels == 1) {
507 p_buffer->write(sample);
508 } else {
509 if (channel == 0) {
510 last_value = sample;
511 channel = 1;
512 } else
513 // calculate avg of 2 channels and convert it to int8_t
514 p_buffer->write(((sample / 2) + (last_value / 2)));
515 channel = 0;
516 }
517 return p_buffer->availableForWrite() > 0;
518 }
519
520 // If we can avoid recalculating some slices, just move the existing
521 // data up in the spectrogram, to perform something like this: last time
522 // = 80ms current time = 120ms
523 // +-----------+ +-----------+
524 // | data@20ms | --> | data@60ms |
525 // +-----------+ -- +-----------+
526 // | data@40ms | -- --> | data@80ms |
527 // +-----------+ -- -- +-----------+
528 // | data@60ms | -- -- | <empty> |
529 // +-----------+ -- +-----------+
530 // | data@80ms | -- | <empty> |
531 // +-----------+ +-----------+
  /// Produces the next feature slice: shifts the feature matrix up by one
  /// slice, reads a window of audio from the ring buffer, re-queues the
  /// overlapping part for the next stride and computes the new slice into
  /// the last row. Returns the complete feature matrix.
  virtual int8_t* addSlice() {
    TRACED();
    // shift p_feature_data up by one slice
    memmove(p_feature_data, p_feature_data + cfg.kFeatureSliceSize,
            (cfg.kFeatureSliceCount - 1) * cfg.kFeatureSliceSize);

    // copy data from buffer to p_audio_samples
    int audio_samples_size =
        p_buffer->readArray(p_audio_samples, kMaxAudioSampleSize);

    // check size
    if (audio_samples_size != kMaxAudioSampleSize) {
      LOGE("audio_samples_size=%d != kMaxAudioSampleSize=%d",
           audio_samples_size, kMaxAudioSampleSize);
    }

    // keep some data to be reprocessed - move by kStrideSampleSize
    p_buffer->writeArray(p_audio_samples + kStrideSampleSize, kKeepSampleSize);

    // the new slice data will always be stored at the end
    int8_t* new_slice_data =
        p_feature_data + ((cfg.kFeatureSliceCount - 1) * cfg.kFeatureSliceSize);
    size_t num_samples_read = 0;
    if (generateMicroFeatures(p_audio_samples, audio_samples_size,
                              new_slice_data, cfg.kFeatureSliceSize,
                              &num_samples_read) != kTfLiteOk) {
      LOGE("Error generateMicroFeatures");
    }
    // printFeatures();
    return p_feature_data;
  }
563
  // Process multiple slices of audio data
  /// Copies the feature matrix into the model input, invokes the model and
  /// routes the output through the command recognizer to the callback.
  virtual bool processSlices(int8_t* feature_buffer) {
    LOGI("->slices: %d", total_slice_count);
    // Copy feature buffer to input tensor
    memcpy(parent->modelInputBuffer(), feature_buffer, cfg.featureElementCount());

    // Run the model on the spectrogram input and make sure it succeeds.
    TfLiteStatus invoke_status = parent->interpreter().Invoke();
    if (invoke_status != kTfLiteOk) {
      LOGE("Invoke failed");
      return false;
    }

    // Obtain a pointer to the output tensor
    TfLiteTensor* output = parent->interpreter().output(0);

    // Determine whether a command was recognized
    const char* found_command = nullptr;
    uint8_t score = 0;
    bool is_new_command = false;

    TfLiteStatus process_status = cfg.recognizeCommands->getCommand(
        output, current_time, &found_command, &score, &is_new_command);
    if (process_status != kTfLiteOk) {
      LOGE("TfLiteMicroSpeechRecognizeCommands::getCommand() failed");
      return false;
    }
    // Do something based on the recognized command. The default
    // implementation just prints to the error console, but you should replace
    // this with your own function for a real application.
    respondToCommand(found_command, score, is_new_command);
    return true;
  }
597
600 for (int i = 0; i < cfg.kFeatureSliceCount; i++) {
601 for (int j = 0; j < cfg.kFeatureSliceSize; j++) {
602 Serial.print(p_feature_data[(i * cfg.kFeatureSliceSize) + j]);
603 Serial.print(" ");
604 }
605 Serial.println();
606 }
607 Serial.println("------------");
608 }
609
  /// Transfers the settings from the TfLiteConfig into the FrontendConfig
  /// and initializes the audio frontend state.
  virtual TfLiteStatus initializeMicroFeatures() {
    TRACED();
    // analysis window geometry
    config.window.size_ms = cfg.kFeatureSliceDurationMs;
    config.window.step_size_ms = cfg.kFeatureSliceStrideMs;
    // one filterbank channel per feature value in a slice
    config.filterbank.num_channels = cfg.kFeatureSliceSize;
    config.filterbank.lower_band_limit = cfg.filterbank_lower_band_limit;
    config.filterbank.upper_band_limit = cfg.filterbank_upper_band_limit;
    config.noise_reduction.smoothing_bits = cfg.noise_reduction_smoothing_bits;
    config.noise_reduction.even_smoothing = cfg.noise_reduction_even_smoothing;
    config.noise_reduction.odd_smoothing = cfg.noise_reduction_odd_smoothing;
    config.noise_reduction.min_signal_remaining = cfg.noise_reduction_min_signal_remaining;
    config.pcan_gain_control.enable_pcan = cfg.pcan_gain_control_enable_pcan;
    config.pcan_gain_control.strength = cfg.pcan_gain_control_strength;
    config.pcan_gain_control.offset = cfg.pcan_gain_control_offset ;
    config.pcan_gain_control.gain_bits = cfg.pcan_gain_control_gain_bits;
    config.log_scale.enable_log = cfg.log_scale_enable_log;
    config.log_scale.scale_shift = cfg.log_scale_scale_shift;
    if (!FrontendPopulateState(&config, &g_micro_features_state,
                               cfg.sample_rate)) {
      LOGE("frontendPopulateState() failed");
      return kTfLiteError;
    }
    return kTfLiteOk;
  }
634
  /// Runs the audio frontend (FFT + filterbank) over the input samples and
  /// rescales the 16-bit feature values into int8 for the model input.
  virtual TfLiteStatus generateMicroFeatures(const int16_t* input,
                                             int input_size, int8_t* output,
                                             int output_size,
                                             size_t* num_samples_read) {
    TRACED();
    const int16_t* frontend_input = input;

    // Apply FFT
    FrontendOutput frontend_output = FrontendProcessSamples(
        &g_micro_features_state, frontend_input, input_size, num_samples_read);

    // Check size (mismatch is logged but processing continues)
    if (output_size != frontend_output.size) {
      LOGE("output_size=%d, frontend_output.size=%d", output_size,
           frontend_output.size);
    }

    // printf("input_size: %d, num_samples_read: %d,output_size: %d,
    // frontend_output.size:%d \n", input_size, *num_samples_read, output_size,
    // frontend_output.size);

    // // check generated features
    // if (input_size != *num_samples_read){
    //   LOGE("audio_samples_size=%d vs num_samples_read=%d", input_size,
    //   *num_samples_read);
    // }

    for (size_t i = 0; i < frontend_output.size; ++i) {
      // These scaling values are derived from those used in input_data.py in
      // the training pipeline. The feature pipeline outputs 16-bit signed
      // integers in roughly a 0 to 670 range. In training, these are then
      // arbitrarily divided by 25.6 to get float values in the rough range of
      // 0.0 to 26.0. This scaling is performed for historical reasons, to match
      // up with the output of other feature generators. The process is then
      // further complicated when we quantize the model. This means we have to
      // scale the 0.0 to 26.0 real values to the -128 to 127 signed integer
      // numbers. All this means that to get matching values from our integer
      // feature output into the tensor input, we have to perform: input =
      // (((feature / 25.6) / 26.0) * 256) - 128 To simplify this and perform it
      // in 32-bit integer math, we rearrange to: input = (feature * 256) /
      // (25.6 * 26.0) - 128
      constexpr int32_t value_scale = 256;
      constexpr int32_t value_div =
          static_cast<int32_t>((25.6f * 26.0f) + 0.5f);
      int32_t value =
          ((frontend_output.values[i] * value_scale) + (value_div / 2)) /
          value_div;
      value -= 128;
      // clamp into the int8 range
      if (value < -128) {
        value = -128;
      }
      if (value > 127) {
        value = 127;
      }
      output[i] = value;
    }

    return kTfLiteOk;
  }
694
696 virtual void respondToCommand(const char* found_command, uint8_t score,
697 bool is_new_command) {
698 if (cfg.respondToCommand != nullptr) {
699 cfg.respondToCommand(found_command, score, is_new_command);
700 } else {
701 TRACED();
702 if (is_new_command) {
703 char buffer[80];
704 snprintf(buffer, 80, "Result: %s, score: %d, is_new: %s", found_command,
705 score, is_new_command ? "true" : "false");
706 Serial.println(buffer);
707 }
708 }
709 }
710};
711
  /// @param range amplitude of the generated sine (default: int16 max)
  /// @param increment phase increment per generated sample
  /// NOTE(review): the ctor default (0.01) differs from the member default
  /// (0.1) - confirm which is intended
  public: TfLiteSineReader(int16_t range=32767, float increment=0.01 ){
    this->increment = increment;
    this->range = range;
  }
724
725 virtual bool begin(TfLiteAudioStreamBase *parent) override {
726 // setup on first call
727 p_interpreter = &parent->interpreter();
728 input = p_interpreter->input(0);
729 output = p_interpreter->output(0);
730 channels = parent->config().channels;
731 return true;
732 }
733
  /// Generates sampleCount samples of a sine wave by running the model once
  /// per x value; returns the number of samples produced (j on model error).
  virtual int read(int16_t*data, int sampleCount) override {
    TRACED();
    float two_pi = 2 * PI;
    for (int j=0; j<sampleCount; j+=channels){
      // Quantize the input from floating-point to integer
      input->data.int8[0] = TfLiteQuantizer::quantize(actX,input->params.scale, input->params.zero_point);

      // Invoke TF Model
      TfLiteStatus invoke_status = p_interpreter->Invoke();

      // Check the result
      if(kTfLiteOk!= invoke_status){
        LOGE("invoke_status not ok");
        return j;
      }
      if(kTfLiteInt8 != output->type){
        LOGE("Output type is not kTfLiteInt8");
        return j;
      }

      // Dequantize the output and convert it to the configured int16 range
      data[j] = TfLiteQuantizer::dequantizeToNewRange(output->data.int8[0], output->params.scale, output->params.zero_point, range);
      // printf("%d\n", data[j]); // for debugging using the Serial Plotter
      LOGD("%f->%d / %d->%d",actX, input->data.int8[0], output->data.int8[0], data[j]);
      // duplicate the sample into all remaining channels
      for (int i=1;i<channels;i++){
        data[j+i] = data[j];
        LOGD("generate data for channels");
      }
      // Increment X and wrap at 2*pi
      actX += increment;
      if (actX>two_pi){
        actX-=two_pi;
      }
    }
    return sampleCount;
  }
770
771 protected:
772 float actX=0;
773 float increment=0.1;
774 int16_t range=0;
775 int channels;
776 TfLiteTensor* input = nullptr;
777 TfLiteTensor* output = nullptr;
778 tflite::MicroInterpreter* p_interpreter = nullptr;
779};
780
788 public:
791 if (p_tensor_arena != nullptr) delete[] p_tensor_arena;
792 }
793
794
  /// Optionally define your own interpreter; must be called before begin().
  void setInterpreter(tflite::MicroInterpreter* p_interpreter) {
    TRACED();
    this->p_interpreter = p_interpreter;
  }
800
801 // Provides the default configuration
802 virtual TfLiteConfig defaultConfig() override {
803 TfLiteConfig def;
804 return def;
805 }
806
808 virtual bool begin(TfLiteConfig config) override {
809 TRACED();
810 cfg = config;
811
812 // alloatme memory
813 p_tensor_arena = new uint8_t[cfg.kTensorArenaSize];
814
815 if (cfg.categoryCount()>0){
816
817 // setup the feature provider
818 if (!setupWriter()) {
819 LOGE("setupWriter");
820 return false;
821 }
822 } else {
823 LOGW("categoryCount=%d", cfg.categoryCount());
824 }
825
826 // Map the model into a usable data structure. This doesn't involve any
827 // copying or parsing, it's a very lightweight operation.
828 if (!setModel(cfg.model)) {
829 return false;
830 }
831
832 if (!setupInterpreter()) {
833 return false;
834 }
835
836 // Allocate memory from the p_tensor_arena for the model's tensors.
837 LOGI("AllocateTensors");
838 TfLiteStatus allocate_status = p_interpreter->AllocateTensors();
839 if (allocate_status != kTfLiteOk) {
840 LOGE("AllocateTensors() failed");
841 return false;
842 }
843
844 // Get information about the memory area to use for the model's input.
845 LOGI("Get Input");
846 p_tensor = p_interpreter->input(0);
847 if (cfg.categoryCount()>0){
848 if ((p_tensor->dims->size != 2) || (p_tensor->dims->data[0] != 1) ||
849 (p_tensor->dims->data[1] !=
850 (cfg.kFeatureSliceCount * cfg.kFeatureSliceSize)) ||
851 (p_tensor->type != kTfLiteInt8)) {
852 LOGE("Bad input tensor parameters in model");
853 return false;
854 }
855 }
856
857 LOGI("Get Buffer");
858 p_tensor_buffer = p_tensor->data.int8;
859 if (p_tensor_buffer == nullptr) {
860 LOGE("p_tensor_buffer is null");
861 return false;
862 }
863
864 // setup reader
865 if (cfg.reader!=nullptr){
866 cfg.reader->begin(this);
867 }
868
869 // all good if we made it here
870 is_setup = true;
871 LOGI("done");
872 return true;
873 }
874
  /// Constant streaming: we can always accept a full default buffer.
  virtual int availableToWrite() override { return DEFAULT_BUFFER_SIZE; }
877
879 virtual size_t write(const uint8_t* data, size_t len) override {
880 TRACED();
881 if (cfg.writer==nullptr){
882 LOGE("cfg.output is null");
883 return 0;
884 }
885 int16_t* samples = (int16_t*)data;
886 int16_t sample_count = len / 2;
887 for (int j = 0; j < sample_count; j++) {
888 cfg.writer->write(samples[j]);
889 }
890 return len;
891 }
892
  /// We can provide audio data only when a reader (cfg.reader) is defined.
  virtual int available() override { return cfg.reader != nullptr ? DEFAULT_BUFFER_SIZE : 0; }
895
897 virtual size_t readBytes(uint8_t *data, size_t len) override {
898 TRACED();
899 if (cfg.reader!=nullptr){
900 return cfg.reader->read((int16_t*)data, (int) len/sizeof(int16_t)) * sizeof(int16_t);
901 }else {
902 return 0;
903 }
904 }
905
  /// Provides the tf lite interpreter (only valid after begin() succeeded).
  tflite::MicroInterpreter& interpreter() override {
    return *p_interpreter;
  }
910
  /// Provides access to the active TfLiteConfig.
  TfLiteConfig &config() override {
    return cfg;
  }
915
  /// Provides access to the model's int8 input tensor data.
  int8_t* modelInputBuffer() override {
    return p_tensor_buffer;
  }
920
921 protected:
922 const tflite::Model* p_model = nullptr;
923 tflite::MicroInterpreter* p_interpreter = nullptr;
924 TfLiteTensor* p_tensor = nullptr;
925 bool is_setup = false;
926 TfLiteConfig cfg;
927 // Create an area of memory to use for input, output, and intermediate
928 // arrays. The size of this will depend on the model you're using, and may
929 // need to be determined by experimentation.
930 uint8_t* p_tensor_arena = nullptr;
931 int8_t* p_tensor_buffer = nullptr;
932
933 virtual bool setModel(const unsigned char* model) {
934 TRACED();
935 p_model = tflite::GetModel(model);
936 if (p_model->version() != TFLITE_SCHEMA_VERSION) {
937 LOGE(
938 "Model provided is schema version %d not equal "
939 "to supported version %d.",
940 p_model->version(), TFLITE_SCHEMA_VERSION);
941 return false;
942 }
943 return true;
944 }
945
  /// Sets up the default writer if none was provided in the config.
  /// NOTE(review): the fallback writer is a single shared static instance
  virtual bool setupWriter() {
    if (cfg.writer == nullptr) {
      static TfLiteMicroSpeachWriter writer;
      cfg.writer = &writer;
    }
    return cfg.writer->begin(this);
  }
953
954 // Pull in only the operation implementations we need.
955 // This relies on a complete list of all the ops needed by this graph.
956 // An easier approach is to just use the AllOpsResolver, but this will
957 // incur some penalty in code space for op implementations that are not
958 // needed by this graph.
959 //
  /// Creates the interpreter if none was provided via setInterpreter().
  /// NOTE(review): the interpreters are function-local statics, so all
  /// instances of this class share a single interpreter.
  virtual bool setupInterpreter() {
    if (p_interpreter == nullptr) {
      TRACEI();
      if (cfg.useAllOpsResolver) {
        // NOTE(review): resolver is a local, but the static interpreter
        // keeps referring to it after this scope ends - verify lifetime
        tflite::AllOpsResolver resolver;
        static tflite::MicroInterpreter static_interpreter{
            p_model, resolver, p_tensor_arena, cfg.kTensorArenaSize};
        p_interpreter = &static_interpreter;
      } else {
        // NOLINTNEXTLINE(runtime-global-variables)
        // resolver limited to the 4 ops used by the micro_speech model
        static tflite::MicroMutableOpResolver<4> micro_op_resolver{};
        if (micro_op_resolver.AddDepthwiseConv2D() != kTfLiteOk) {
          return false;
        }
        if (micro_op_resolver.AddFullyConnected() != kTfLiteOk) {
          return false;
        }
        if (micro_op_resolver.AddSoftmax() != kTfLiteOk) {
          return false;
        }
        if (micro_op_resolver.AddReshape() != kTfLiteOk) {
          return false;
        }
        // Build an p_interpreter to run the model with.
        static tflite::MicroInterpreter static_interpreter{
            p_model, micro_op_resolver, p_tensor_arena, cfg.kTensorArenaSize};
        p_interpreter = &static_interpreter;
      }
    }
    return true;
  }
991};
992
993} // namespace audio_tools
Base class for all Audio Streams. It supports the boolean operator to test if the object is ready with...
Definition BaseStream.h:119
virtual int readArray(T data[], int len)
reads multiple values
Definition Buffers.h:37
virtual int writeArray(const T data[], int len)
Fills the buffer data.
Definition Buffers.h:59
Implements a typed Ringbuffer.
Definition Buffers.h:308
virtual int availableForWrite() override
provides the number of entries that are available to write
Definition Buffers.h:380
virtual bool write(T data) override
write add an entry to the buffer
Definition Buffers.h:358
Base class for implementing different primitive decoding models on top of the instantaneous results f...
Definition TfLiteAudioStream.h:194
Abstract TfLiteAudioStream to provide access to TfLiteAudioStream for Readers and Writers.
Definition TfLiteAudioStream.h:386
virtual size_t write(const uint8_t *data, size_t len)=0
process the data in batches of max kMaxAudioSampleSize.
virtual int8_t * modelInputBuffer()=0
Provides access to the model input buffer.
virtual TfLiteConfig & config()=0
Provides the TfLiteConfig information.
TfLiteAudioStream which uses Tensorflow Light to analyze the data. If it is used as a generator (wher...
Definition TfLiteAudioStream.h:787
virtual size_t write(const uint8_t *data, size_t len) override
process the data in batches of max kMaxAudioSampleSize.
Definition TfLiteAudioStream.h:879
tflite::MicroInterpreter & interpreter() override
Provides the tf lite interpreter.
Definition TfLiteAudioStream.h:907
TfLiteConfig & config() override
Provides the TfLiteConfig information.
Definition TfLiteAudioStream.h:912
virtual size_t readBytes(uint8_t *data, size_t len) override
provide audio data with cfg.input
Definition TfLiteAudioStream.h:897
virtual bool begin(TfLiteConfig config) override
Start the processing.
Definition TfLiteAudioStream.h:808
int8_t * modelInputBuffer() override
Provides access to the model input buffer.
Definition TfLiteAudioStream.h:917
void setInterpreter(tflite::MicroInterpreter *p_interpreter)
Optionally define your own p_interpreter.
Definition TfLiteAudioStream.h:796
virtual int availableToWrite() override
Constant streaming.
Definition TfLiteAudioStream.h:876
virtual int available() override
We can provide only some audio data when cfg.input is defined.
Definition TfLiteAudioStream.h:894
TfLiteMicroSpeachWriter for Audio Data.
Definition TfLiteAudioStream.h:410
void printFeatures()
For debugging: print feature matrix.
Definition TfLiteAudioStream.h:599
virtual bool write1(const int16_t sample)
Processes a single sample.
Definition TfLiteAudioStream.h:505
virtual bool begin(TfLiteAudioStreamBase *parent)
Call begin before starting the processing.
Definition TfLiteAudioStream.h:420
virtual void respondToCommand(const char *found_command, uint8_t score, bool is_new_command)
Overwrite this method to implement your own handler or provide callback.
Definition TfLiteAudioStream.h:696
This class is designed to apply a very primitive decoding model on top of the instantaneous results f...
Definition TfLiteAudioStream.h:215
TfLiteStatus validate(const TfLiteTensor *latest_results)
Checks the input data.
Definition TfLiteAudioStream.h:348
TfLiteStatus evaluate(const char **found_command, uint8_t *result_score, bool *is_new_command)
Finds the result.
Definition TfLiteAudioStream.h:302
void deleteOldRecords(int32_t limit)
Removes obsolete records from the queue.
Definition TfLiteAudioStream.h:294
int categoryCount()
Determines the number of categories.
Definition TfLiteAudioStream.h:289
bool begin(TfLiteConfig cfg) override
Setup parameters from config.
Definition TfLiteAudioStream.h:222
int resultCategoryIdx(int8_t *score)
finds the category with the biggest score
Definition TfLiteAudioStream.h:277
Quantizer that helps to quantize and dequantize between float and int8.
Definition TfLiteAudioStream.h:161
Input class which provides the next value if the TfLiteAudioStream is treated as an audio source.
Definition TfLiteAudioStream.h:42
Generate a sine output from a model that was trained on the sine method. (=hello_world)
Definition TfLiteAudioStream.h:719
Output class which interprets audio data if TfLiteAudioStream is treated as audio sink.
Definition TfLiteAudioStream.h:54
Generic Implementation of sound input and output for desktop environments using portaudio.
Definition AudioCodecsBase.h:10
Configuration settings for TfLiteAudioStream.
Definition TfLiteAudioStream.h:67
void setCategories(const char *(&array)[N])
Defines the labels.
Definition TfLiteAudioStream.h:129