@@ -277,7 +277,7 @@ impl BaseTransformImpl for EdgeImpulseAudioInfer {
277277 // Copy the input buffer to the output buffer (passthrough)
278278 out_map. copy_from_slice ( & in_map) ;
279279
280- // Convert input samples from S16LE to f32 normalized [-1, 1] for inference
280+ // Convert input samples from S16LE to f32 for inference
281281 let samples: Vec < f32 > = in_map
282282 . chunks_exact ( 2 )
283283 . map ( |chunk| {
@@ -378,8 +378,13 @@ impl BaseTransformImpl for EdgeImpulseAudioInfer {
378378 } ;
379379
380380 if let Some ( mut model) = model {
381+ gst:: debug!( CAT , obj = self . obj( ) , "Got model, getting parameters..." ) ;
382+
381383 let params = match model. parameters ( ) {
382- Ok ( p) => p,
384+ Ok ( p) => {
385+ gst:: debug!( CAT , obj = self . obj( ) , "Successfully got model parameters" ) ;
386+ p
387+ }
383388 Err ( e) => {
384389 gst:: error!(
385390 CAT ,
@@ -391,26 +396,46 @@ impl BaseTransformImpl for EdgeImpulseAudioInfer {
391396 }
392397 } ;
393398
394- let slice_size = params. slice_size as usize ;
399+ // Get the required number of raw audio samples from the model parameters
400+ // For FFI mode, we should use the raw sample count from metadata
401+ #[ cfg( feature = "ffi" ) ]
402+ let required_samples =
403+ { edge_impulse_runner:: ffi:: ModelMetadata :: get ( ) . raw_sample_count } ;
404+ #[ cfg( not( feature = "ffi" ) ) ]
405+ let required_samples = { params. slice_size as usize } ;
406+ gst:: debug!(
407+ CAT ,
408+ obj = self . obj( ) ,
409+ "Model expects {} raw audio samples for inference (ModelMetadata::get().raw_sample_count)" ,
410+ required_samples
411+ ) ;
395412
396- // Add new samples to buffer
413+ // Add new samples to buffer (convert S16LE to f32, no normalization)
397414 let mut sample_buffer = self . sample_buffer . lock ( ) . unwrap ( ) ;
415+ let samples_len = samples. len ( ) ;
398416 sample_buffer. extend ( samples) ;
399417
400- // Process when we have enough samples
401- if sample_buffer. len ( ) >= slice_size {
418+ gst:: debug!(
419+ CAT ,
420+ obj = self . obj( ) ,
421+ "Buffer status: {} samples in buffer, need {} samples, received {} new samples (total: {})" ,
422+ sample_buffer. len( ) ,
423+ required_samples,
424+ samples_len,
425+ sample_buffer. len( ) + samples_len
426+ ) ;
427+
428+ // Run inference when we have enough samples
429+ while sample_buffer. len ( ) >= required_samples {
402430 let now = std:: time:: Instant :: now ( ) ;
403-
404- // Extract features for inference
405- let features: Vec < f32 > = sample_buffer. drain ( ..slice_size) . collect ( ) ;
406-
431+ // Take exactly the number of raw samples we need
432+ let features: Vec < f32 > = sample_buffer. drain ( ..required_samples) . collect ( ) ;
407433 gst:: debug!(
408434 CAT ,
409435 obj = self . obj( ) ,
410436 "Running inference with {} samples" ,
411437 features. len( )
412438 ) ;
413-
414439 // Run inference
415440 match model. infer ( features, None ) {
416441 Ok ( result) => {
@@ -463,12 +488,6 @@ impl BaseTransformImpl for EdgeImpulseAudioInfer {
463488 }
464489 Err ( e) => {
465490 gst:: error!( CAT , obj = self . obj( ) , "Inference failed: {}" , e) ;
466- let s = crate :: common:: create_error_message (
467- "audio" ,
468- inbuf. pts ( ) . unwrap_or ( gst:: ClockTime :: ZERO ) ,
469- e. to_string ( ) ,
470- ) ;
471- let _ = self . obj ( ) . post_message ( gst:: message:: Element :: new ( s) ) ;
472491 }
473492 }
474493 }
@@ -477,6 +496,101 @@ impl BaseTransformImpl for EdgeImpulseAudioInfer {
477496 state. model = Some ( model) ;
478497 }
479498
499+ // Handle end-of-stream inference if we have remaining samples
500+ let is_eos = inbuf. size ( ) == 0 ;
501+ if is_eos {
502+ let mut sample_buffer = self . sample_buffer . lock ( ) . unwrap ( ) ;
503+ if !sample_buffer. is_empty ( ) {
504+ // For EIM mode, we need to get the required samples from the model parameters
505+ // Since we don't have a model in EIM mode without a valid model path, we'll use a default
506+ let required_samples = 16000 ; // Default for most audio models
507+
508+ gst:: debug!(
509+ CAT ,
510+ obj = self . obj( ) ,
511+ "End of stream reached with {} samples in buffer, running final inference" ,
512+ sample_buffer. len( )
513+ ) ;
514+
515+ // Take as many samples as we have, pad with zeros if needed
516+ let mut final_samples: Vec < f32 > = sample_buffer. drain ( ..) . collect ( ) ;
517+ if final_samples. len ( ) < required_samples {
518+ final_samples. extend ( vec ! [ 0.0 ; required_samples - final_samples. len( ) ] ) ;
519+ gst:: debug!(
520+ CAT ,
521+ obj = self . obj( ) ,
522+ "Using {} real samples + {} zero padding = {} total samples for classification" ,
523+ final_samples. len( ) - ( required_samples - final_samples. len( ) ) ,
524+ required_samples - final_samples. len( ) ,
525+ final_samples. len( )
526+ ) ;
527+ } else {
528+ // Take exactly the required number of samples
529+ final_samples = final_samples[ ..required_samples] . to_vec ( ) ;
530+ gst:: debug!(
531+ CAT ,
532+ obj = self . obj( ) ,
533+ "Using exactly {} raw audio samples for classification (no padding)" ,
534+ final_samples. len( )
535+ ) ;
536+ }
537+
538+ if let Some ( mut model) = state. model . take ( ) {
539+ let now = std:: time:: Instant :: now ( ) ;
540+ match model. infer ( final_samples, None ) {
541+ Ok ( result) => {
542+ let elapsed = now. elapsed ( ) ;
543+ let mut result_value = serde_json:: to_value ( & result. result ) . unwrap ( ) ;
544+ if let Some ( classification) = result_value. get_mut ( "classification" ) {
545+ if classification. is_array ( ) {
546+ let mut map = serde_json:: Map :: new ( ) ;
547+ for entry in classification. as_array ( ) . unwrap ( ) {
548+ if let ( Some ( label) , Some ( value) ) =
549+ ( entry. get ( "label" ) , entry. get ( "value" ) )
550+ {
551+ if let ( Some ( label) , Some ( value) ) =
552+ ( label. as_str ( ) , value. as_f64 ( ) )
553+ {
554+ map. insert (
555+ label. to_string ( ) ,
556+ serde_json:: Value :: from ( value) ,
557+ ) ;
558+ }
559+ }
560+ }
561+ * classification = serde_json:: Value :: Object ( map) ;
562+ }
563+ }
564+ let result_json =
565+ serde_json:: to_string ( & result_value) . unwrap_or_else ( |e| {
566+ gst:: warning!(
567+ CAT ,
568+ obj = self . obj( ) ,
569+ "Failed to serialize result: {}" ,
570+ e
571+ ) ;
572+ String :: from ( "{}" )
573+ } ) ;
574+
575+ let s = crate :: common:: create_inference_message (
576+ "audio" ,
577+ inbuf. pts ( ) . unwrap_or ( gst:: ClockTime :: ZERO ) ,
578+ "classification" ,
579+ result_json,
580+ elapsed. as_millis ( ) as u32 ,
581+ ) ;
582+
583+ let _ = self . obj ( ) . post_message ( gst:: message:: Element :: new ( s) ) ;
584+ }
585+ Err ( e) => {
586+ gst:: error!( CAT , obj = self . obj( ) , "Final inference failed: {}" , e) ;
587+ }
588+ }
589+ state. model = Some ( model) ;
590+ }
591+ }
592+ }
593+
480594 Ok ( gst:: FlowSuccess :: Ok )
481595 }
482596
0 commit comments