diff --git a/assemblyai/__init__.py b/assemblyai/__init__.py index 740a93f..241d536 100644 --- a/assemblyai/__init__.py +++ b/assemblyai/__init__.py @@ -53,6 +53,7 @@ Sentiment, SentimentType, Settings, + SpeakerOptions, SpeechModel, StatusResult, SummarizationModel, @@ -114,6 +115,7 @@ "Sentiment", "SentimentType", "Settings", + "SpeakerOptions", "SpeechModel", "StatusResult", "SummarizationModel", diff --git a/assemblyai/__version__.py b/assemblyai/__version__.py index ef838c9..3ce4899 100644 --- a/assemblyai/__version__.py +++ b/assemblyai/__version__.py @@ -1 +1 @@ -__version__ = "0.41.3" +__version__ = "0.41.4" diff --git a/assemblyai/types.py b/assemblyai/types.py index ca1aa51..ad8a6fa 100644 --- a/assemblyai/types.py +++ b/assemblyai/types.py @@ -489,6 +489,43 @@ class SpeechModel(str, Enum): "The model optimized for accuracy, low latency, ease of use, and multi-language support" +class SpeakerOptions(BaseModel): + """ + Speaker options for controlling speaker diarization parameters + """ + + min_speakers_expected: Optional[int] = Field( + None, ge=1, description="Minimum number of speakers expected in the audio" + ) + max_speakers_expected: Optional[int] = Field( + None, ge=1, description="Maximum number of speakers expected in the audio" + ) + + if pydantic_v2: + + @field_validator("max_speakers_expected") + @classmethod + def validate_max_speakers(cls, v, info): + if v is not None and info.data.get("min_speakers_expected") is not None: + min_speakers = info.data["min_speakers_expected"] + if v < min_speakers: + raise ValueError( + "max_speakers_expected must be greater than or equal to min_speakers_expected" + ) + return v + else: + + @validator("max_speakers_expected") + def validate_max_speakers(cls, v, values): + if v is not None and values.get("min_speakers_expected") is not None: + min_speakers = values["min_speakers_expected"] + if v < min_speakers: + raise ValueError( + "max_speakers_expected must be greater than or equal to min_speakers_expected" + ) + return v + + class RawTranscriptionConfig(BaseModel): language_code: Optional[Union[str, LanguageCode]] = None """ @@ -546,6 +583,9 @@ class RawTranscriptionConfig(BaseModel): speakers_expected: Optional[int] = None "The number of speakers you expect to be in your audio file." + speaker_options: Optional[SpeakerOptions] = None + "Advanced options for controlling speaker diarization parameters." + content_safety: Optional[bool] = None "Enable Content Safety Detection." @@ -633,6 +673,7 @@ def __init__( redact_pii_sub: Optional[PIISubstitutionPolicy] = None, speaker_labels: Optional[bool] = None, speakers_expected: Optional[int] = None, + speaker_options: Optional[SpeakerOptions] = None, content_safety: Optional[bool] = None, content_safety_confidence: Optional[int] = None, iab_categories: Optional[bool] = None, @@ -675,6 +716,7 @@ def __init__( redact_pii_sub: The replacement logic for detected PII. speaker_labels: Enable Speaker Diarization. speakers_expected: The number of speakers you expect to hear in your audio file. Up to 10 speakers are supported. + speaker_options: Advanced options for controlling speaker diarization parameters, including min and max speakers expected. content_safety: Enable Content Safety Detection. iab_categories: Enable Topic Detection. custom_spelling: Customize how words are spelled and formatted using to and from values. @@ -722,7 +764,7 @@ def __init__( redact_pii_policies, redact_pii_sub, ) - self.set_speaker_diarization(speaker_labels, speakers_expected) + self.set_speaker_diarization(speaker_labels, speakers_expected, speaker_options) self.set_content_safety(content_safety, content_safety_confidence) self.iab_categories = iab_categories self.set_custom_spelling(custom_spelling, override=True) @@ -934,6 +976,12 @@ def speakers_expected(self) -> Optional[int]: return self._raw_transcription_config.speakers_expected + @property + def speaker_options(self) -> Optional[SpeakerOptions]: + "Returns the advanced speaker diarization options." + + return self._raw_transcription_config.speaker_options + @property def content_safety(self) -> Optional[bool]: "Returns the status of the Content Safety feature." @@ -1162,6 +1210,7 @@ def set_speaker_diarization( self, enable: Optional[bool] = True, speakers_expected: Optional[int] = None, + speaker_options: Optional[SpeakerOptions] = None, ) -> Self: """ Whether to enable Speaker Diarization on the transcript. @@ -1169,14 +1218,24 @@ def set_speaker_diarization( Args: `enable`: Enable Speaker Diarization `speakers_expected`: The number of speakers in the audio file. + `speaker_options`: Advanced options for controlling speaker diarization parameters. """ - if not enable: + # If enable is explicitly False, clear all speaker settings + if enable is False: self._raw_transcription_config.speaker_labels = None self._raw_transcription_config.speakers_expected = None + self._raw_transcription_config.speaker_options = None + # If enable is True or None, set the values (allow setting speaker_options even when enable is None) else: - self._raw_transcription_config.speaker_labels = True - self._raw_transcription_config.speakers_expected = speakers_expected + # Only set speaker_labels to True if enable is explicitly True + if enable is True: + self._raw_transcription_config.speaker_labels = True + # Always set these if provided, regardless of enable value + if speakers_expected is not None: + self._raw_transcription_config.speakers_expected = speakers_expected + if speaker_options is not None: + self._raw_transcription_config.speaker_options = speaker_options return self @@ -1712,6 +1771,9 @@ class BaseTranscript(BaseModel): speakers_expected: Optional[int] = None "The number of speakers you expect to be in your audio file." + speaker_options: Optional[SpeakerOptions] = None + "Advanced options for controlling speaker diarization parameters." + content_safety: Optional[bool] = None "Enable Content Safety Detection." diff --git a/tests/unit/test_speaker_options.py b/tests/unit/test_speaker_options.py new file mode 100644 index 0000000..122c498 --- /dev/null +++ b/tests/unit/test_speaker_options.py @@ -0,0 +1,96 @@ +import pytest + +import assemblyai as aai + + +def test_speaker_options_creation(): + """Test that SpeakerOptions can be created with valid parameters.""" + speaker_options = aai.SpeakerOptions( + min_speakers_expected=2, max_speakers_expected=5 + ) + assert speaker_options.min_speakers_expected == 2 + assert speaker_options.max_speakers_expected == 5 + + +def test_speaker_options_validation(): + """Test that SpeakerOptions validates max >= min.""" + with pytest.raises( + ValueError, + match="max_speakers_expected must be greater than or equal to min_speakers_expected", + ): + aai.SpeakerOptions(min_speakers_expected=5, max_speakers_expected=2) + + +def test_speaker_options_min_only(): + """Test that SpeakerOptions can be created with only min_speakers_expected.""" + speaker_options = aai.SpeakerOptions(min_speakers_expected=3) + assert speaker_options.min_speakers_expected == 3 + assert speaker_options.max_speakers_expected is None + + +def test_speaker_options_max_only(): + """Test that SpeakerOptions can be created with only max_speakers_expected.""" + speaker_options = aai.SpeakerOptions(max_speakers_expected=5) + assert speaker_options.min_speakers_expected is None + assert speaker_options.max_speakers_expected == 5 + + +def test_transcription_config_with_speaker_options(): + """Test that TranscriptionConfig accepts speaker_options parameter.""" + speaker_options = aai.SpeakerOptions( + min_speakers_expected=2, max_speakers_expected=4 + ) + + config = aai.TranscriptionConfig( + speaker_labels=True, speaker_options=speaker_options + ) + + assert config.speaker_labels is True + assert config.speaker_options == speaker_options + assert config.speaker_options.min_speakers_expected == 2 + assert config.speaker_options.max_speakers_expected == 4 + + +def test_set_speaker_diarization_with_speaker_options(): + """Test setting speaker diarization with speaker_options.""" + speaker_options = aai.SpeakerOptions( + min_speakers_expected=1, max_speakers_expected=3 + ) + + config = aai.TranscriptionConfig() + config.set_speaker_diarization( + enable=True, speakers_expected=2, speaker_options=speaker_options + ) + + assert config.speaker_labels is True + assert config.speakers_expected == 2 + assert config.speaker_options == speaker_options + + +def test_set_speaker_diarization_disable_clears_speaker_options(): + """Test that disabling speaker diarization clears speaker_options.""" + speaker_options = aai.SpeakerOptions(min_speakers_expected=2) + + config = aai.TranscriptionConfig() + config.set_speaker_diarization(enable=True, speaker_options=speaker_options) + + # Verify it was set + assert config.speaker_options == speaker_options + + # Now disable + config.set_speaker_diarization(enable=False) + + assert config.speaker_labels is None + assert config.speakers_expected is None + assert config.speaker_options is None + + +def test_speaker_options_in_raw_config(): + """Test that speaker_options is properly set in the raw config.""" + speaker_options = aai.SpeakerOptions( + min_speakers_expected=2, max_speakers_expected=5 + ) + + config = aai.TranscriptionConfig(speaker_options=speaker_options) + + assert config.raw.speaker_options == speaker_options