Skip to content

Commit ed8cca1

Browse files
authored
Fix audio inference (#20)
1 parent 2ba1b5b commit ed8cca1

File tree

4 files changed

+146
-39
lines changed

4 files changed

+146
-39
lines changed

Cargo.lock

Lines changed: 2 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ regex = "1.10.3"
3232
serde = "1.0"
3333
serde_json = "1.0"
3434
tokio = { version = "1.0", features = ["rt-multi-thread"] }
35-
edge-impulse-runner = { git = "https://github.com/edgeimpulse/edge-impulse-runner-rs.git", rev = "ab83018", default-features = false }
35+
edge-impulse-runner = { git = "https://github.com/edgeimpulse/edge-impulse-runner-rs.git", rev = "20ab935", default-features = false }
3636
tempfile = "3.10"
3737

3838
[dev-dependencies]

examples/audio_inference.rs

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,9 @@ fn create_pipeline(
7272
};
7373

7474
// Create pipeline elements
75-
let capsfilter1 = gst::ElementFactory::make("capsfilter").build()?;
7675
let audioconvert1 = gst::ElementFactory::make("audioconvert").build()?;
7776
let audioresample1 = gst::ElementFactory::make("audioresample").build()?;
78-
let capsfilter2 = gst::ElementFactory::make("capsfilter").build()?;
77+
let capsfilter1 = gst::ElementFactory::make("capsfilter").build()?;
7978
let mut edgeimpulseinfer_factory = gst::ElementFactory::make("edgeimpulseaudioinfer");
8079

8180
// Set model path if provided (EIM mode)
@@ -92,54 +91,48 @@ fn create_pipeline(
9291
let edgeimpulseinfer = edgeimpulseinfer_factory.build()?;
9392
let audioconvert2 = gst::ElementFactory::make("audioconvert").build()?;
9493
let audioresample2 = gst::ElementFactory::make("audioresample").build()?;
95-
let capsfilter3 = gst::ElementFactory::make("capsfilter").build()?;
96-
let sink = gst::ElementFactory::make("autoaudiosink").build()?;
94+
let capsfilter2 = gst::ElementFactory::make("capsfilter").build()?;
95+
let sink = gst::ElementFactory::make("fakesink").build()?;
9796

98-
// Configure caps
97+
// Configure caps - the Edge Impulse element expects S16LE mono audio at 16kHz
9998
let caps1 = gst::Caps::builder("audio/x-raw")
100-
.field("format", "F32LE")
101-
.build();
102-
capsfilter1.set_property("caps", &caps1);
103-
104-
let caps2 = gst::Caps::builder("audio/x-raw")
10599
.field("format", "S16LE")
106100
.field("channels", 1)
107101
.field("rate", 16000)
108102
.field("layout", "interleaved")
109103
.build();
110-
capsfilter2.set_property("caps", &caps2);
104+
capsfilter1.set_property("caps", &caps1);
111105

112-
let caps3 = gst::Caps::builder("audio/x-raw")
106+
// Configure output caps for the sink - standard audio format
107+
let caps2 = gst::Caps::builder("audio/x-raw")
113108
.field("format", "F32LE")
114109
.field("channels", 2)
115110
.field("rate", 44100)
116111
.build();
117-
capsfilter3.set_property("caps", &caps3);
112+
capsfilter2.set_property("caps", &caps2);
118113

119114
// Add elements to pipeline
120115
pipeline.add_many(&[
121-
&capsfilter1,
122116
&audioconvert1,
123117
&audioresample1,
124-
&capsfilter2,
118+
&capsfilter1,
125119
&edgeimpulseinfer,
126120
&audioconvert2,
127121
&audioresample2,
128-
&capsfilter3,
122+
&capsfilter2,
129123
&sink,
130124
])?;
131125

132126
// Link elements
133127
gst::Element::link_many(&[
134128
&source,
135-
&capsfilter1,
136129
&audioconvert1,
137130
&audioresample1,
138-
&capsfilter2,
131+
&capsfilter1,
139132
&edgeimpulseinfer,
140133
&audioconvert2,
141134
&audioresample2,
142-
&capsfilter3,
135+
&capsfilter2,
143136
&sink,
144137
])?;
145138

src/audio/imp.rs

Lines changed: 131 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -277,7 +277,7 @@ impl BaseTransformImpl for EdgeImpulseAudioInfer {
277277
// Copy the input buffer to the output buffer (passthrough)
278278
out_map.copy_from_slice(&in_map);
279279

280-
// Convert input samples from S16LE to f32 normalized [-1, 1] for inference
280+
// Convert input samples from S16LE to f32 for inference
281281
let samples: Vec<f32> = in_map
282282
.chunks_exact(2)
283283
.map(|chunk| {
@@ -378,8 +378,13 @@ impl BaseTransformImpl for EdgeImpulseAudioInfer {
378378
};
379379

380380
if let Some(mut model) = model {
381+
gst::debug!(CAT, obj = self.obj(), "Got model, getting parameters...");
382+
381383
let params = match model.parameters() {
382-
Ok(p) => p,
384+
Ok(p) => {
385+
gst::debug!(CAT, obj = self.obj(), "Successfully got model parameters");
386+
p
387+
}
383388
Err(e) => {
384389
gst::error!(
385390
CAT,
@@ -391,26 +396,46 @@ impl BaseTransformImpl for EdgeImpulseAudioInfer {
391396
}
392397
};
393398

394-
let slice_size = params.slice_size as usize;
399+
// Get the required number of raw audio samples from the model parameters
400+
// For FFI mode, we should use the raw sample count from metadata
401+
#[cfg(feature = "ffi")]
402+
let required_samples =
403+
{ edge_impulse_runner::ffi::ModelMetadata::get().raw_sample_count };
404+
#[cfg(not(feature = "ffi"))]
405+
let required_samples = { params.slice_size as usize };
406+
gst::debug!(
407+
CAT,
408+
obj = self.obj(),
409+
"Model expects {} raw audio samples for inference (ModelMetadata::get().raw_sample_count)",
410+
required_samples
411+
);
395412

396-
// Add new samples to buffer
413+
// Add new samples to buffer (convert S16LE to f32, no normalization)
397414
let mut sample_buffer = self.sample_buffer.lock().unwrap();
415+
let samples_len = samples.len();
398416
sample_buffer.extend(samples);
399417

400-
// Process when we have enough samples
401-
if sample_buffer.len() >= slice_size {
418+
gst::debug!(
419+
CAT,
420+
obj = self.obj(),
421+
"Buffer status: {} samples in buffer, need {} samples, received {} new samples (total: {})",
422+
sample_buffer.len(),
423+
required_samples,
424+
samples_len,
425+
sample_buffer.len() + samples_len
426+
);
427+
428+
// Run inference when we have enough samples
429+
while sample_buffer.len() >= required_samples {
402430
let now = std::time::Instant::now();
403-
404-
// Extract features for inference
405-
let features: Vec<f32> = sample_buffer.drain(..slice_size).collect();
406-
431+
// Take exactly the number of raw samples we need
432+
let features: Vec<f32> = sample_buffer.drain(..required_samples).collect();
407433
gst::debug!(
408434
CAT,
409435
obj = self.obj(),
410436
"Running inference with {} samples",
411437
features.len()
412438
);
413-
414439
// Run inference
415440
match model.infer(features, None) {
416441
Ok(result) => {
@@ -463,12 +488,6 @@ impl BaseTransformImpl for EdgeImpulseAudioInfer {
463488
}
464489
Err(e) => {
465490
gst::error!(CAT, obj = self.obj(), "Inference failed: {}", e);
466-
let s = crate::common::create_error_message(
467-
"audio",
468-
inbuf.pts().unwrap_or(gst::ClockTime::ZERO),
469-
e.to_string(),
470-
);
471-
let _ = self.obj().post_message(gst::message::Element::new(s));
472491
}
473492
}
474493
}
@@ -477,6 +496,101 @@ impl BaseTransformImpl for EdgeImpulseAudioInfer {
477496
state.model = Some(model);
478497
}
479498

499+
// Handle end-of-stream inference if we have remaining samples
500+
let is_eos = inbuf.size() == 0;
501+
if is_eos {
502+
let mut sample_buffer = self.sample_buffer.lock().unwrap();
503+
if !sample_buffer.is_empty() {
504+
// For EIM mode, we need to get the required samples from the model parameters
505+
// Since we don't have a model in EIM mode without a valid model path, we'll use a default
506+
let required_samples = 16000; // Default for most audio models
507+
508+
gst::debug!(
509+
CAT,
510+
obj = self.obj(),
511+
"End of stream reached with {} samples in buffer, running final inference",
512+
sample_buffer.len()
513+
);
514+
515+
// Take as many samples as we have, pad with zeros if needed
516+
let mut final_samples: Vec<f32> = sample_buffer.drain(..).collect();
517+
if final_samples.len() < required_samples {
518+
final_samples.extend(vec![0.0; required_samples - final_samples.len()]);
519+
gst::debug!(
520+
CAT,
521+
obj = self.obj(),
522+
"Using {} real samples + {} zero padding = {} total samples for classification",
523+
final_samples.len() - (required_samples - final_samples.len()),
524+
required_samples - final_samples.len(),
525+
final_samples.len()
526+
);
527+
} else {
528+
// Take exactly the required number of samples
529+
final_samples = final_samples[..required_samples].to_vec();
530+
gst::debug!(
531+
CAT,
532+
obj = self.obj(),
533+
"Using exactly {} raw audio samples for classification (no padding)",
534+
final_samples.len()
535+
);
536+
}
537+
538+
if let Some(mut model) = state.model.take() {
539+
let now = std::time::Instant::now();
540+
match model.infer(final_samples, None) {
541+
Ok(result) => {
542+
let elapsed = now.elapsed();
543+
let mut result_value = serde_json::to_value(&result.result).unwrap();
544+
if let Some(classification) = result_value.get_mut("classification") {
545+
if classification.is_array() {
546+
let mut map = serde_json::Map::new();
547+
for entry in classification.as_array().unwrap() {
548+
if let (Some(label), Some(value)) =
549+
(entry.get("label"), entry.get("value"))
550+
{
551+
if let (Some(label), Some(value)) =
552+
(label.as_str(), value.as_f64())
553+
{
554+
map.insert(
555+
label.to_string(),
556+
serde_json::Value::from(value),
557+
);
558+
}
559+
}
560+
}
561+
*classification = serde_json::Value::Object(map);
562+
}
563+
}
564+
let result_json =
565+
serde_json::to_string(&result_value).unwrap_or_else(|e| {
566+
gst::warning!(
567+
CAT,
568+
obj = self.obj(),
569+
"Failed to serialize result: {}",
570+
e
571+
);
572+
String::from("{}")
573+
});
574+
575+
let s = crate::common::create_inference_message(
576+
"audio",
577+
inbuf.pts().unwrap_or(gst::ClockTime::ZERO),
578+
"classification",
579+
result_json,
580+
elapsed.as_millis() as u32,
581+
);
582+
583+
let _ = self.obj().post_message(gst::message::Element::new(s));
584+
}
585+
Err(e) => {
586+
gst::error!(CAT, obj = self.obj(), "Final inference failed: {}", e);
587+
}
588+
}
589+
state.model = Some(model);
590+
}
591+
}
592+
}
593+
480594
Ok(gst::FlowSuccess::Ok)
481595
}
482596

0 commit comments

Comments
 (0)