arduino-audio-tools
Loading...
Searching...
No Matches
TfLiteAudioStream.h
1#pragma once
2
3// Configure FFT to output 16 bit fixed point.
4#define FIXED_POINT 16
5
6//#include <MicroTFLite.h>
7//#include <TensorFlowLite.h>
8#include "Chirale_TensorFlowLite.h" // https://github.com/spaziochirale/Chirale_TensorFlowLite
9#include <cmath>
10#include <cstdint>
11#include "AudioTools/CoreAudio/BaseStream.h"
12#include "AudioTools/CoreAudio/AudioOutput.h"
13#include "AudioTools/CoreAudio/Buffers.h"
14#include "tensorflow/lite/c/common.h"
15#include "tensorflow/lite/experimental/microfrontend/lib/frontend.h"
16#include "tensorflow/lite/experimental/microfrontend/lib/frontend_util.h"
17#include "tensorflow/lite/micro/all_ops_resolver.h"
18#include "tensorflow/lite/micro/kernels/micro_ops.h"
19#include "tensorflow/lite/micro/micro_interpreter.h"
20#include "tensorflow/lite/micro/micro_mutable_op_resolver.h"
21#include "tensorflow/lite/micro/system_setup.h"
22#include "tensorflow/lite/schema/schema_generated.h"
23
31namespace audio_tools {
32
33// Forward Declarations
34class TfLiteAudioStreamBase;
35class TfLiteAbstractRecognizeCommands;
36
44 public:
45 virtual bool begin(TfLiteAudioStreamBase *parent) = 0;
46 virtual int read(int16_t*data, int len) = 0;
47};
48
56 public:
57 virtual bool begin(TfLiteAudioStreamBase *parent) = 0;
58 virtual bool write(const int16_t sample) = 0;
59};
60
70 const unsigned char* model = nullptr;
71 TfLiteReader *reader = nullptr;
72 TfLiteWriter *writer = nullptr;
73 TfLiteAbstractRecognizeCommands *recognizeCommands=nullptr;
74 bool useAllOpsResolver = false;
75 // callback for command handler
76 void (*respondToCommand)(const char* found_command, uint8_t score,
77 bool is_new_command) = nullptr;
78
79 // Create an area of memory to use for input, output, and intermediate arrays.
80 // The size of this will depend on the model you’re using, and may need to be
81 // determined by experimentation.
82 size_t kTensorArenaSize = 10 * 1024;
83
84 // Keeping these as constant expressions allow us to allocate fixed-sized
85 // arrays on the stack for our working memory.
86
87 // The size of the input time series data we pass to the FFT to produce
88 // the frequency information. This has to be a power of two, and since
89 // we're dealing with 30ms of 16KHz inputs, which means 480 samples, this
90 // is the next value.
91 // int kMaxAudioSampleSize = 320; //512; // 480
92 int sample_rate = 16000;
93
94 // Number of audio channels - is usually 1. If 2 we reduce it to 1 by
95 // averaging the 2 channels
96 int channels = 1;
97
98 // The following values are derived from values used during model training.
99 // If you change the way you preprocess the input, update all these constants.
100 int kFeatureSliceSize = 40;
101 int kFeatureSliceCount = 49;
102 int kFeatureSliceStrideMs = 20;
103 int kFeatureSliceDurationMs = 30;
104
105 // number of new slices to collect before evaluating the model
106 int kSlicesToProcess = 2;
107
108 // Parameters for RecognizeCommands
109 int32_t average_window_duration_ms = 1000;
110 uint8_t detection_threshold = 50;
111 int32_t suppression_ms = 1500;
112 int32_t minimum_count = 3;
113
114 // input for FrontendConfig
115 float filterbank_lower_band_limit = 125.0;
116 float filterbank_upper_band_limit = 7500.0;
117 float noise_reduction_smoothing_bits = 10;
118 float noise_reduction_even_smoothing = 0.025;
119 float noise_reduction_odd_smoothing = 0.06;
120 float noise_reduction_min_signal_remaining = 0.05;
121 bool pcan_gain_control_enable_pcan = 1;
122 float pcan_gain_control_strength = 0.95;
123 float pcan_gain_control_offset = 80.0;
124 float pcan_gain_control_gain_bits = 21;
125 bool log_scale_enable_log = 1;
126 uint8_t log_scale_scale_shift = 6;
127
129 template<int N>
130 void setCategories(const char* (&array)[N]){
131 labels = array;
132 kCategoryCount = N;
133 }
134
  /// Number of configured labels (0 until setCategories() was called)
  int categoryCount() {
    return kCategoryCount;
  }
138
  /// Total number of feature values fed to the model (slice size x slice count)
  int featureElementCount() {
    return kFeatureSliceSize * kFeatureSliceCount;
  }
142
  /// Number of audio samples that make up one feature slice
  /// (slice duration in ms times samples per ms)
  int audioSampleSize() {
    return kFeatureSliceDurationMs * (sample_rate / 1000);
  }
146
  /// Number of samples the analysis window advances between two slices
  int strideSampleSize() {
    return kFeatureSliceStrideMs * (sample_rate / 1000);
  }
150
151 private:
152 int kCategoryCount = 0;
153 const char** labels = nullptr;
154};
155
163 public:
164 // convert float to int8
165 static int8_t quantize(float value, float scale, float zero_point){
166 if(scale==0.0&&zero_point==0) return value;
167 return value / scale + zero_point;
168 }
169 // convert int8 to float
170 static float dequantize(int8_t value, float scale, float zero_point){
171 if(scale==0.0&&zero_point==0) return value;
172 return (value - zero_point) * scale;
173 }
174
  /// Dequantizes an int8 value and rescales the result to the interval
  /// [-new_range, new_range]; the result is clipped to that range.
  static float dequantizeToNewRange(int8_t value, float scale, float zero_point, float new_range){
    float deq = (static_cast<float>(value) - zero_point) * scale;
    return clip(deq * new_range, new_range);
  }
179
180 static float clip(float value, float range){
181 if (value>=0.0){
182 return value > range ? range : value;
183 } else {
184 return -value < -range ? -range : value;
185 }
186 }
187};
188
196 public:
197 virtual bool begin(TfLiteConfig cfg) = 0;
198 virtual TfLiteStatus getCommand(const TfLiteTensor* latest_results, const int32_t current_time_ms,
199 const char** found_command,uint8_t* score,bool* is_new_command) = 0;
200
201};
202
217 public:
218
220 }
221
  /// Setup parameters from config; fails when no labels were defined
  /// via TfLiteConfig::setCategories().
  bool begin(TfLiteConfig cfg) override {
    TRACED();
    this->cfg = cfg;
    if (cfg.labels == nullptr) {
      LOGE("config.labels not defined");
      return false;
    }
    return true;
  }
232
233 // Call this with the results of running a model on sample data.
234 virtual TfLiteStatus getCommand(const TfLiteTensor* latest_results,
235 const int32_t current_time_ms,
236 const char** found_command,
237 uint8_t* score,
238 bool* is_new_command) override {
239
240 TRACED();
241 this->current_time_ms = current_time_ms;
242 this->time_since_last_top = current_time_ms - previous_time_ms;
243
244 deleteOldRecords(current_time_ms - cfg.average_window_duration_ms);
245 int idx = resultCategoryIdx(latest_results->data.int8);
246 Result row(current_time_ms, idx, latest_results->data.int8[idx]);
247 result_queue.push_back(row);
248
249 TfLiteStatus result = validate(latest_results);
250 if (result!=kTfLiteOk){
251 return result;
252 }
253 return evaluate(found_command, score, is_new_command);
254 }
255
256 protected:
257 struct Result {
258 int32_t time_ms;
259 int category=0;
260 int8_t score=0;
261
262 Result() = default;
263 Result(int32_t time_ms,int category, int8_t score){
264 this->time_ms = time_ms;
265 this->category = category;
266 this->score = score;
267 }
268 };
269
270 TfLiteConfig cfg;
271 Vector <Result> result_queue;
272 int previous_cateogory=-1;
273 int32_t current_time_ms=0;
274 int32_t previous_time_ms=0;
275 int32_t time_since_last_top=0;
276
278 int resultCategoryIdx(int8_t* score) {
279 int result = -1;
280 uint8_t top_score = std::numeric_limits<uint8_t>::min();
281 for (int j=0;j<categoryCount();j++){
282 if (score[j]>top_score){
283 result = j;
284 }
285 }
286 return result;
287 }
288
291 return cfg.categoryCount();
292 }
293
295 void deleteOldRecords(int32_t limit) {
296 if (result_queue.empty()) return;
297 while (result_queue[0].time_ms<limit){
298 result_queue.pop_front();
299 }
300 }
301
  /// Finds the result: sums the queued scores per category, picks the
  /// category with the highest total and reports its average score.
  /// A command is flagged as new only when it differs from the previous one,
  /// exceeds detection_threshold and lies outside the suppression window.
  TfLiteStatus evaluate(const char** found_command, uint8_t* result_score, bool* is_new_command) {
    TRACED();
    // NOTE(review): variable length arrays are a gcc/clang extension,
    // not standard C++ - confirm all target toolchains support them
    float totals[categoryCount()]={0};
    int count[categoryCount()]={0};
    // calculate totals
    for (int j=0;j<result_queue.size();j++){
      int idx = result_queue[j].category;
      totals[idx] += result_queue[j].score;
      count[idx]++;
    }

    // find category with the highest total
    int maxIdx = -1;
    float max = -100000;
    for (int j=0;j<categoryCount();j++){
      if (totals[j]>max){
        max = totals[j];
        maxIdx = j;
      }
    }

    if (maxIdx==-1){
      LOGE("Could not find max category")
      return kTfLiteError;
    }

    // determine result: average score of the winning category
    *result_score = totals[maxIdx] / count[maxIdx];
    *found_command = cfg.labels[maxIdx];

    // de-duplicate: only report a new command when it changed, is loud
    // enough and enough time has passed since the last report
    if (previous_cateogory!=maxIdx
        && *result_score > cfg.detection_threshold
        && time_since_last_top > cfg.suppression_ms){
      previous_time_ms = current_time_ms;
      previous_cateogory = maxIdx;
      *is_new_command = true;
    } else {
      *is_new_command = false;
    }

    LOGD("Category: %s, score: %d, is_new: %d",*found_command, *result_score, *is_new_command);

    return kTfLiteOk;
  }
347
  /// Checks the input data: verifies the shape [1, categoryCount()], the
  /// int8 element type and that timestamps are monotonically increasing.
  TfLiteStatus validate(const TfLiteTensor* latest_results) {
    // shape must be a 2-d tensor of [1, categoryCount()]
    if ((latest_results->dims->size != 2) ||
        (latest_results->dims->data[0] != 1) ||
        (latest_results->dims->data[1] != categoryCount())) {
      LOGE(
          "The results for recognition should contain %d "
          "elements, but there are "
          "%d in an %d-dimensional shape",
          categoryCount(), (int)latest_results->dims->data[1],
          (int)latest_results->dims->size);
      return kTfLiteError;
    }

    // the model output must be quantized int8
    if (latest_results->type != kTfLiteInt8) {
      LOGE("The results for recognition should be int8 elements, but are %d",
          (int)latest_results->type);
      return kTfLiteError;
    }

    // results must arrive in increasing time order
    if ((!result_queue.empty()) &&
        (current_time_ms < result_queue[0].time_ms)) {
      LOGE("Results must be in increasing time order: timestamp %d < %d",
          (int)current_time_ms, (int)result_queue[0].time_ms);
      return kTfLiteError;
    }
    return kTfLiteOk;
  }
376
377};
378
379
388 public:
389 virtual void setInterpreter(tflite::MicroInterpreter* p_interpreter) = 0;
390 virtual TfLiteConfig defaultConfig() = 0;
391 virtual bool begin(TfLiteConfig config) = 0;
392 virtual int availableToWrite() = 0;
393
395 virtual size_t write(const uint8_t* data, size_t len)= 0;
396 virtual tflite::MicroInterpreter& interpreter()= 0;
397
399 virtual TfLiteConfig &config()= 0;
400
402 virtual int8_t* modelInputBuffer()= 0;
403};
404
412 public:
  /// Default constructor - the actual setup happens in begin()
  TfLiteMicroSpeachWriter() = default;
414
416 if (p_buffer != nullptr) delete p_buffer;
417 if (p_audio_samples != nullptr) delete p_audio_samples;
418 }
419
  /// Call begin before starting the processing: caches the configuration,
  /// sets up the recognizer and the audio frontend and allocates the
  /// ring buffer, feature data and sample buffer (only on the first call).
  virtual bool begin(TfLiteAudioStreamBase *parent) {
    TRACED();
    this->parent = parent;
    cfg = parent->config();
    current_time = 0;
    kMaxAudioSampleSize = cfg.audioSampleSize();
    kStrideSampleSize = cfg.strideSampleSize();
    // number of samples that overlap between two consecutive slices
    kKeepSampleSize = kMaxAudioSampleSize - kStrideSampleSize;

    if (!setup_recognizer()) {
      LOGE("setup_recognizer");
      return false;
    }

    // setup FrontendConfig
    TfLiteStatus init_status = initializeMicroFeatures();
    if (init_status != kTfLiteOk) {
      return false;
    }

    // Allocate ring buffer
    if (p_buffer == nullptr) {
      p_buffer = new audio_tools::RingBuffer<int16_t>(kMaxAudioSampleSize);
      LOGD("Allocating buffer for %d samples", kMaxAudioSampleSize);
    }

    // Initialize the feature data to default values.
    if (p_feature_data == nullptr) {
      p_feature_data = new int8_t[cfg.featureElementCount()];
      memset(p_feature_data, 0, cfg.featureElementCount());
    }

    // allocate p_audio_samples
    if (p_audio_samples == nullptr) {
      p_audio_samples = new int16_t[kMaxAudioSampleSize];
      memset(p_audio_samples, 0, kMaxAudioSampleSize * sizeof(int16_t));
    }

    return true;
  }
461
  /// Feeds a single sample into the ring buffer; when the buffer is full a
  /// new feature slice is generated and - after cfg.kSlicesToProcess new
  /// slices - the model is invoked.
  virtual bool write(int16_t sample) {
    TRACED();
    if (!write1(sample)){
      // buffer is full: advance the time by one stride
      current_time += cfg.kFeatureSliceStrideMs;
      // count the new slice
      total_slice_count++;

      int8_t* feature_buffer = addSlice();
      if (total_slice_count >= cfg.kSlicesToProcess) {
        processSlices(feature_buffer);
        // reset total_slice_count
        total_slice_count = 0;
      }
    }
    return true;
  }
479
480 protected:
481 TfLiteConfig cfg;
482 TfLiteAudioStreamBase *parent=nullptr;
483 int8_t* p_feature_data = nullptr;
484 int16_t* p_audio_samples = nullptr;
485 audio_tools::RingBuffer<int16_t>* p_buffer = nullptr;
486 FrontendState g_micro_features_state;
487 FrontendConfig config;
488 int kMaxAudioSampleSize;
489 int kStrideSampleSize;
490 int kKeepSampleSize;
491 int16_t last_value;
492 int8_t channel = 0;
493 int32_t current_time = 0;
494 int16_t total_slice_count = 0;
495
496 virtual bool setup_recognizer() {
497 // setup default p_recognizer if not defined
498 if (cfg.recognizeCommands == nullptr) {
499 static TfLiteMicroSpeechRecognizeCommands static_recognizer;
500 cfg.recognizeCommands = &static_recognizer;
501 }
502 return cfg.recognizeCommands->begin(cfg);
503 }
504
506 virtual bool write1(const int16_t sample) {
507 if (cfg.channels == 1) {
508 p_buffer->write(sample);
509 } else {
510 if (channel == 0) {
511 last_value = sample;
512 channel = 1;
513 } else
514 // calculate avg of 2 channels and convert it to int8_t
515 p_buffer->write(((sample / 2) + (last_value / 2)));
516 channel = 0;
517 }
518 return p_buffer->availableForWrite() > 0;
519 }
520
521 // If we can avoid recalculating some slices, just move the existing
522 // data up in the spectrogram, to perform something like this: last time
523 // = 80ms current time = 120ms
524 // +-----------+ +-----------+
525 // | data@20ms | --> | data@60ms |
526 // +-----------+ -- +-----------+
527 // | data@40ms | -- --> | data@80ms |
528 // +-----------+ -- -- +-----------+
529 // | data@60ms | -- -- | <empty> |
530 // +-----------+ -- +-----------+
531 // | data@80ms | -- | <empty> |
532 // +-----------+ +-----------+
  /// Adds one new feature slice at the end of the spectrogram; the older
  /// slices are shifted towards the beginning (see diagram above).
  virtual int8_t* addSlice() {
    TRACED();
    // shift p_feature_data left by one slice
    memmove(p_feature_data, p_feature_data + cfg.kFeatureSliceSize,
            (cfg.kFeatureSliceCount - 1) * cfg.kFeatureSliceSize);

    // copy data from buffer to p_audio_samples
    int audio_samples_size =
        p_buffer->readArray(p_audio_samples, kMaxAudioSampleSize);

    // check size
    if (audio_samples_size != kMaxAudioSampleSize) {
      LOGE("audio_samples_size=%d != kMaxAudioSampleSize=%d",
          audio_samples_size, kMaxAudioSampleSize);
    }

    // keep some data to be reprocessed - move by kStrideSampleSize
    p_buffer->writeArray(p_audio_samples + kStrideSampleSize, kKeepSampleSize);

    // the new slice data will always be stored at the end
    int8_t* new_slice_data =
        p_feature_data + ((cfg.kFeatureSliceCount - 1) * cfg.kFeatureSliceSize);
    size_t num_samples_read = 0;
    if (generateMicroFeatures(p_audio_samples, audio_samples_size,
                              new_slice_data, cfg.kFeatureSliceSize,
                              &num_samples_read) != kTfLiteOk) {
      LOGE("Error generateMicroFeatures");
    }
    // printFeatures();
    return p_feature_data;
  }
564
565 // Process multiple slice of audio data
  /// Processes the collected slices: copies the feature buffer into the
  /// model input, invokes the interpreter and hands the output tensor to
  /// the recognizer; a recognized command triggers respondToCommand().
  virtual bool processSlices(int8_t* feature_buffer) {
    LOGI("->slices: %d", total_slice_count);
    // Copy feature buffer to input tensor
    memcpy(parent->modelInputBuffer(), feature_buffer, cfg.featureElementCount());

    // Run the model on the spectrogram input and make sure it succeeds.
    TfLiteStatus invoke_status = parent->interpreter().Invoke();
    if (invoke_status != kTfLiteOk) {
      LOGE("Invoke failed");
      return false;
    }

    // Obtain a pointer to the output tensor
    TfLiteTensor* output = parent->interpreter().output(0);

    // Determine whether a command was recognized
    const char* found_command = nullptr;
    uint8_t score = 0;
    bool is_new_command = false;

    TfLiteStatus process_status = cfg.recognizeCommands->getCommand(
        output, current_time, &found_command, &score, &is_new_command);
    if (process_status != kTfLiteOk) {
      LOGE("TfLiteMicroSpeechRecognizeCommands::getCommand() failed");
      return false;
    }
    // Do something based on the recognized command. The default
    // implementation just prints to the error console, but you should replace
    // this with your own function for a real application.
    respondToCommand(found_command, score, is_new_command);
    return true;
  }
598
601 for (int i = 0; i < cfg.kFeatureSliceCount; i++) {
602 for (int j = 0; j < cfg.kFeatureSliceSize; j++) {
603 Serial.print(p_feature_data[(i * cfg.kFeatureSliceSize) + j]);
604 Serial.print(" ");
605 }
606 Serial.println();
607 }
608 Serial.println("------------");
609 }
610
  /// Initializes the micro-frontend (window, filterbank, noise reduction,
  /// PCAN gain control and log scaling) from the configuration values and
  /// populates the frontend state.
  virtual TfLiteStatus initializeMicroFeatures() {
    TRACED();
    config.window.size_ms = cfg.kFeatureSliceDurationMs;
    config.window.step_size_ms = cfg.kFeatureSliceStrideMs;
    config.filterbank.num_channels = cfg.kFeatureSliceSize;
    config.filterbank.lower_band_limit = cfg.filterbank_lower_band_limit;
    config.filterbank.upper_band_limit = cfg.filterbank_upper_band_limit;
    config.noise_reduction.smoothing_bits = cfg.noise_reduction_smoothing_bits;
    config.noise_reduction.even_smoothing = cfg.noise_reduction_even_smoothing;
    config.noise_reduction.odd_smoothing = cfg.noise_reduction_odd_smoothing;
    config.noise_reduction.min_signal_remaining = cfg.noise_reduction_min_signal_remaining;
    config.pcan_gain_control.enable_pcan = cfg.pcan_gain_control_enable_pcan;
    config.pcan_gain_control.strength = cfg.pcan_gain_control_strength;
    config.pcan_gain_control.offset = cfg.pcan_gain_control_offset ;
    config.pcan_gain_control.gain_bits = cfg.pcan_gain_control_gain_bits;
    config.log_scale.enable_log = cfg.log_scale_enable_log;
    config.log_scale.scale_shift = cfg.log_scale_scale_shift;
    if (!FrontendPopulateState(&config, &g_micro_features_state,
                               cfg.sample_rate)) {
      LOGE("frontendPopulateState() failed");
      return kTfLiteError;
    }
    return kTfLiteOk;
  }
635
  /// Runs the frontend (FFT + filterbank) on the input samples and scales
  /// the resulting feature values into the int8 range expected by the
  /// quantized model.
  virtual TfLiteStatus generateMicroFeatures(const int16_t* input,
                                             int input_size, int8_t* output,
                                             int output_size,
                                             size_t* num_samples_read) {
    TRACED();
    const int16_t* frontend_input = input;

    // Apply FFT
    FrontendOutput frontend_output = FrontendProcessSamples(
        &g_micro_features_state, frontend_input, input_size, num_samples_read);

    // Check size
    if (output_size != frontend_output.size) {
      LOGE("output_size=%d, frontend_output.size=%d", output_size,
          frontend_output.size);
    }

    // printf("input_size: %d, num_samples_read: %d,output_size: %d,
    // frontend_output.size:%d \n", input_size, *num_samples_read, output_size,
    // frontend_output.size);

    // // check generated features
    // if (input_size != *num_samples_read){
    //   LOGE("audio_samples_size=%d vs num_samples_read=%d", input_size,
    //   *num_samples_read);
    // }

    for (size_t i = 0; i < frontend_output.size; ++i) {
      // These scaling values are derived from those used in input_data.py in
      // the training pipeline. The feature pipeline outputs 16-bit signed
      // integers in roughly a 0 to 670 range. In training, these are then
      // arbitrarily divided by 25.6 to get float values in the rough range of
      // 0.0 to 26.0. This scaling is performed for historical reasons, to match
      // up with the output of other feature generators. The process is then
      // further complicated when we quantize the model. This means we have to
      // scale the 0.0 to 26.0 real values to the -128 to 127 signed integer
      // numbers. All this means that to get matching values from our integer
      // feature output into the tensor input, we have to perform: input =
      // (((feature / 25.6) / 26.0) * 256) - 128 To simplify this and perform it
      // in 32-bit integer math, we rearrange to: input = (feature * 256) /
      // (25.6 * 26.0) - 128
      constexpr int32_t value_scale = 256;
      constexpr int32_t value_div =
          static_cast<int32_t>((25.6f * 26.0f) + 0.5f);
      int32_t value =
          ((frontend_output.values[i] * value_scale) + (value_div / 2)) /
          value_div;
      value -= 128;
      // clamp to the int8 range
      if (value < -128) {
        value = -128;
      }
      if (value > 127) {
        value = 127;
      }
      output[i] = value;
    }

    return kTfLiteOk;
  }
695
697 virtual void respondToCommand(const char* found_command, uint8_t score,
698 bool is_new_command) {
699 if (cfg.respondToCommand != nullptr) {
700 cfg.respondToCommand(found_command, score, is_new_command);
701 } else {
702 TRACED();
703 if (is_new_command) {
704 char buffer[80];
705 snprintf(buffer, 80, "Result: %s, score: %d, is_new: %s", found_command,
706 score, is_new_command ? "true" : "false");
707 Serial.println(buffer);
708 }
709 }
710 }
711};
712
721 public: TfLiteSineReader(int16_t range=32767, float increment=0.01 ){
722 this->increment = increment;
723 this->range = range;
724 }
725
  /// Setup on first call: caches the interpreter, its input/output tensors
  /// and the configured number of channels.
  virtual bool begin(TfLiteAudioStreamBase *parent) override {
    // setup on first call
    p_interpreter = &parent->interpreter();
    input = p_interpreter->input(0);
    output = p_interpreter->output(0);
    channels = parent->config().channels;
    return true;
  }
734
  /// Fills the buffer with samples generated by the sine model: feeds the
  /// current phase (actX) into the model and dequantizes the prediction
  /// into the requested int16 range; the value is replicated to all
  /// channels. Returns the number of samples produced (j on error).
  virtual int read(int16_t*data, int sampleCount) override {
    TRACED();
    float two_pi = 2 * PI;
    for (int j=0; j<sampleCount; j+=channels){
      // Quantize the input from floating-point to integer
      input->data.int8[0] = TfLiteQuantizer::quantize(actX,input->params.scale, input->params.zero_point);

      // Invoke TF Model
      TfLiteStatus invoke_status = p_interpreter->Invoke();

      // Check the result
      if(kTfLiteOk!= invoke_status){
        LOGE("invoke_status not ok");
        return j;
      }
      if(kTfLiteInt8 != output->type){
        LOGE("Output type is not kTfLiteInt8");
        return j;
      }

      // Dequantize the output and convert it to the int16 range
      data[j] = TfLiteQuantizer::dequantizeToNewRange(output->data.int8[0], output->params.scale, output->params.zero_point, range);
      // printf("%d\n", data[j]); // for debugging using the Serial Plotter
      LOGD("%f->%d / %d->%d",actX, input->data.int8[0], output->data.int8[0], data[j]);
      // replicate the sample to the remaining channels
      for (int i=1;i<channels;i++){
        data[j+i] = data[j];
        LOGD("generate data for channels");
      }
      // Increment X and wrap the phase at 2*pi
      actX += increment;
      if (actX>two_pi){
        actX-=two_pi;
      }
    }
    return sampleCount;
  }
771
772 protected:
773 float actX=0;
774 float increment=0.1;
775 int16_t range=0;
776 int channels;
777 TfLiteTensor* input = nullptr;
778 TfLiteTensor* output = nullptr;
779 tflite::MicroInterpreter* p_interpreter = nullptr;
780};
781
789 public:
792 if (p_tensor_arena != nullptr) delete[] p_tensor_arena;
793 }
794
795
  /// Optionally define your own interpreter: begin() only creates a default
  /// one when none was set here.
  void setInterpreter(tflite::MicroInterpreter* p_interpreter) {
    TRACED();
    this->p_interpreter = p_interpreter;
  }
801
802 // Provides the default configuration
803 virtual TfLiteConfig defaultConfig() override {
804 TfLiteConfig def;
805 return def;
806 }
807
809 virtual bool begin(TfLiteConfig config) override {
810 TRACED();
811 cfg = config;
812
813 // alloatme memory
814 p_tensor_arena = new uint8_t[cfg.kTensorArenaSize];
815
816 if (cfg.categoryCount()>0){
817
818 // setup the feature provider
819 if (!setupWriter()) {
820 LOGE("setupWriter");
821 return false;
822 }
823 } else {
824 LOGW("categoryCount=%d", cfg.categoryCount());
825 }
826
827 // Map the model into a usable data structure. This doesn't involve any
828 // copying or parsing, it's a very lightweight operation.
829 if (!setModel(cfg.model)) {
830 return false;
831 }
832
833 if (!setupInterpreter()) {
834 return false;
835 }
836
837 // Allocate memory from the p_tensor_arena for the model's tensors.
838 LOGI("AllocateTensors");
839 TfLiteStatus allocate_status = p_interpreter->AllocateTensors();
840 if (allocate_status != kTfLiteOk) {
841 LOGE("AllocateTensors() failed");
842 return false;
843 }
844
845 // Get information about the memory area to use for the model's input.
846 LOGI("Get Input");
847 p_tensor = p_interpreter->input(0);
848 if (cfg.categoryCount()>0){
849 if ((p_tensor->dims->size != 2) || (p_tensor->dims->data[0] != 1) ||
850 (p_tensor->dims->data[1] !=
851 (cfg.kFeatureSliceCount * cfg.kFeatureSliceSize)) ||
852 (p_tensor->type != kTfLiteInt8)) {
853 LOGE("Bad input tensor parameters in model");
854 return false;
855 }
856 }
857
858 LOGI("Get Buffer");
859 p_tensor_buffer = p_tensor->data.int8;
860 if (p_tensor_buffer == nullptr) {
861 LOGE("p_tensor_buffer is null");
862 return false;
863 }
864
865 // setup reader
866 if (cfg.reader!=nullptr){
867 cfg.reader->begin(this);
868 }
869
870 // all good if we made it here
871 is_setup = true;
872 LOGI("done");
873 return true;
874 }
875
  /// Constant streaming: we can always accept new data to analyze
  virtual int availableToWrite() override { return DEFAULT_BUFFER_SIZE; }
878
880 virtual size_t write(const uint8_t* data, size_t len) override {
881 TRACED();
882 if (cfg.writer==nullptr){
883 LOGE("cfg.output is null");
884 return 0;
885 }
886 int16_t* samples = (int16_t*)data;
887 int16_t sample_count = len / 2;
888 for (int j = 0; j < sample_count; j++) {
889 cfg.writer->write(samples[j]);
890 }
891 return len;
892 }
893
  /// We can provide audio data only when a cfg.reader is defined
  virtual int available() override { return cfg.reader != nullptr ? DEFAULT_BUFFER_SIZE : 0; }
896
898 virtual size_t readBytes(uint8_t *data, size_t len) override {
899 TRACED();
900 if (cfg.reader!=nullptr){
901 return cfg.reader->read((int16_t*)data, (int) len/sizeof(int16_t)) * sizeof(int16_t);
902 }else {
903 return 0;
904 }
905 }
906
  /// Provides the tf lite interpreter (valid after begin())
  tflite::MicroInterpreter& interpreter() override {
    return *p_interpreter;
  }
911
  /// Provides the TfLiteConfig information
  TfLiteConfig &config() override {
    return cfg;
  }
916
  /// Provides access to the model input buffer (valid after begin())
  int8_t* modelInputBuffer() override {
    return p_tensor_buffer;
  }
921
922 protected:
923 const tflite::Model* p_model = nullptr;
924 tflite::MicroInterpreter* p_interpreter = nullptr;
925 TfLiteTensor* p_tensor = nullptr;
926 bool is_setup = false;
927 TfLiteConfig cfg;
928 // Create an area of memory to use for input, output, and intermediate
929 // arrays. The size of this will depend on the model you're using, and may
930 // need to be determined by experimentation.
931 uint8_t* p_tensor_arena = nullptr;
932 int8_t* p_tensor_buffer = nullptr;
933
  /// Maps the model into a usable data structure (no copying or parsing)
  /// and verifies that its schema version matches the runtime.
  virtual bool setModel(const unsigned char* model) {
    TRACED();
    p_model = tflite::GetModel(model);
    if (p_model->version() != TFLITE_SCHEMA_VERSION) {
      LOGE(
          "Model provided is schema version %d not equal "
          "to supported version %d.",
          p_model->version(), TFLITE_SCHEMA_VERSION);
      return false;
    }
    return true;
  }
946
947 virtual bool setupWriter() {
948 if (cfg.writer == nullptr) {
949 static TfLiteMicroSpeachWriter writer;
950 cfg.writer = &writer;
951 }
952 return cfg.writer->begin(this);
953 }
954
955 // Pull in only the operation implementations we need.
956 // This relies on a complete list of all the ops needed by this graph.
957 // An easier approach is to just use the AllOpsResolver, but this will
958 // incur some penalty in code space for op implementations that are not
959 // needed by this graph.
960 //
  /// Creates the interpreter (only when none was provided via
  /// setInterpreter): either with the AllOpsResolver or with a
  /// MicroMutableOpResolver that registers just the 4 required ops.
  virtual bool setupInterpreter() {
    if (p_interpreter == nullptr) {
      TRACEI();
      if (cfg.useAllOpsResolver) {
        // NOTE(review): resolver is a local but the static interpreter may
        // keep referencing it after this call - confirm lifetime is safe
        tflite::AllOpsResolver resolver;
        static tflite::MicroInterpreter static_interpreter{
            p_model, resolver, p_tensor_arena, cfg.kTensorArenaSize};
        p_interpreter = &static_interpreter;
      } else {
        // NOLINTNEXTLINE(runtime-global-variables)
        static tflite::MicroMutableOpResolver<4> micro_op_resolver{};
        if (micro_op_resolver.AddDepthwiseConv2D() != kTfLiteOk) {
          return false;
        }
        if (micro_op_resolver.AddFullyConnected() != kTfLiteOk) {
          return false;
        }
        if (micro_op_resolver.AddSoftmax() != kTfLiteOk) {
          return false;
        }
        if (micro_op_resolver.AddReshape() != kTfLiteOk) {
          return false;
        }
        // Build an p_interpreter to run the model with.
        static tflite::MicroInterpreter static_interpreter{
            p_model, micro_op_resolver, p_tensor_arena, cfg.kTensorArenaSize};
        p_interpreter = &static_interpreter;
      }
    }
    return true;
  }
992};
993
994} // namespace audio_tools
Base class for all Audio Streams. It support the boolean operator to test if the object is ready with...
Definition BaseStream.h:122
virtual int readArray(T data[], int len)
reads multiple values
Definition Buffers.h:33
virtual int writeArray(const T data[], int len)
Fills the buffer data.
Definition Buffers.h:55
Implements a typed Ringbuffer.
Definition Buffers.h:341
virtual int availableForWrite() override
provides the number of entries that are available to write
Definition Buffers.h:413
virtual bool write(T data) override
write add an entry to the buffer
Definition Buffers.h:391
Base class for implementing different primitive decoding models on top of the instantaneous results f...
Definition TfLiteAudioStream.h:195
Abstract TfLiteAudioStream to provide access to TfLiteAudioStream for Readers and Writers.
Definition TfLiteAudioStream.h:387
virtual size_t write(const uint8_t *data, size_t len)=0
process the data in batches of max kMaxAudioSampleSize.
virtual int8_t * modelInputBuffer()=0
Provides access to the model input buffer.
virtual TfLiteConfig & config()=0
Provides the TfLiteConfig information.
TfLiteAudioStream which uses Tensorflow Light to analyze the data. If it is used as a generator (wher...
Definition TfLiteAudioStream.h:788
virtual size_t write(const uint8_t *data, size_t len) override
process the data in batches of max kMaxAudioSampleSize.
Definition TfLiteAudioStream.h:880
tflite::MicroInterpreter & interpreter() override
Provides the tf lite interpreter.
Definition TfLiteAudioStream.h:908
TfLiteConfig & config() override
Provides the TfLiteConfig information.
Definition TfLiteAudioStream.h:913
virtual size_t readBytes(uint8_t *data, size_t len) override
provide audio data with cfg.input
Definition TfLiteAudioStream.h:898
virtual bool begin(TfLiteConfig config) override
Start the processing.
Definition TfLiteAudioStream.h:809
int8_t * modelInputBuffer() override
Provides access to the model input buffer.
Definition TfLiteAudioStream.h:918
void setInterpreter(tflite::MicroInterpreter *p_interpreter)
Optionally define your own p_interpreter.
Definition TfLiteAudioStream.h:797
virtual int availableToWrite() override
Constant streaming.
Definition TfLiteAudioStream.h:877
virtual int available() override
We can provide only some audio data when cfg.input is defined.
Definition TfLiteAudioStream.h:895
TfLiteMicroSpeachWriter for Audio Data.
Definition TfLiteAudioStream.h:411
void printFeatures()
For debugging: print feature matrix.
Definition TfLiteAudioStream.h:600
virtual bool write1(const int16_t sample)
Processes a single sample.
Definition TfLiteAudioStream.h:506
virtual bool begin(TfLiteAudioStreamBase *parent)
Call begin before starting the processing.
Definition TfLiteAudioStream.h:421
virtual void respondToCommand(const char *found_command, uint8_t score, bool is_new_command)
Overwrite this method to implement your own handler or provide callback.
Definition TfLiteAudioStream.h:697
This class is designed to apply a very primitive decoding model on top of the instantaneous results f...
Definition TfLiteAudioStream.h:216
TfLiteStatus validate(const TfLiteTensor *latest_results)
Checks the input data.
Definition TfLiteAudioStream.h:349
TfLiteStatus evaluate(const char **found_command, uint8_t *result_score, bool *is_new_command)
Finds the result.
Definition TfLiteAudioStream.h:303
void deleteOldRecords(int32_t limit)
Removes obsolete records from the queue.
Definition TfLiteAudioStream.h:295
int categoryCount()
Determines the number of categories.
Definition TfLiteAudioStream.h:290
bool begin(TfLiteConfig cfg) override
Setup parameters from config.
Definition TfLiteAudioStream.h:223
int resultCategoryIdx(int8_t *score)
finds the category with the biggest score
Definition TfLiteAudioStream.h:278
Quantizer that helps to quantize and dequantize between float and int8.
Definition TfLiteAudioStream.h:162
Input class which provides the next value if the TfLiteAudioStream is treated as an audio source.
Definition TfLiteAudioStream.h:43
Generate a sine output from a model that was trained on the sine method. (=hello_world)
Definition TfLiteAudioStream.h:720
Output class which interprets audio data if TfLiteAudioStream is treated as audio sink.
Definition TfLiteAudioStream.h:55
Generic Implementation of sound input and output for desktop environments using portaudio.
Definition AudioCodecsBase.h:10
Configuration settings for TfLiteAudioStream.
Definition TfLiteAudioStream.h:68
void setCategories(const char *(&array)[N])
Defines the labels.
Definition TfLiteAudioStream.h:130