Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions assemblyai/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
Sentiment,
SentimentType,
Settings,
SpeakerOptions,
SpeechModel,
StatusResult,
SummarizationModel,
Expand Down Expand Up @@ -114,6 +115,7 @@
"Sentiment",
"SentimentType",
"Settings",
"SpeakerOptions",
"SpeechModel",
"StatusResult",
"SummarizationModel",
Expand Down
2 changes: 1 addition & 1 deletion assemblyai/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.41.3"
__version__ = "0.41.4"
70 changes: 66 additions & 4 deletions assemblyai/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,43 @@ class SpeechModel(str, Enum):
"The model optimized for accuracy, low latency, ease of use, and multi-language support"


class SpeakerOptions(BaseModel):
"""
Speaker options for controlling speaker diarization parameters
"""

min_speakers_expected: Optional[int] = Field(
None, ge=1, description="Minimum number of speakers expected in the audio"
)
max_speakers_expected: Optional[int] = Field(
None, ge=1, description="Maximum number of speakers expected in the audio"
)

if pydantic_v2:

@field_validator("max_speakers_expected")
@classmethod
def validate_max_speakers(cls, v, info):
if v is not None and info.data.get("min_speakers_expected") is not None:
min_speakers = info.data["min_speakers_expected"]
if v < min_speakers:
raise ValueError(
"max_speakers_expected must be greater than or equal to min_speakers_expected"
)
return v
else:

@validator("max_speakers_expected")
def validate_max_speakers(cls, v, values):
if v is not None and values.get("min_speakers_expected") is not None:
min_speakers = values["min_speakers_expected"]
if v < min_speakers:
raise ValueError(
"max_speakers_expected must be greater than or equal to min_speakers_expected"
)
return v


class RawTranscriptionConfig(BaseModel):
language_code: Optional[Union[str, LanguageCode]] = None
"""
Expand Down Expand Up @@ -546,6 +583,9 @@ class RawTranscriptionConfig(BaseModel):
speakers_expected: Optional[int] = None
"The number of speakers you expect to be in your audio file."

speaker_options: Optional[SpeakerOptions] = None
"Advanced options for controlling speaker diarization parameters."

content_safety: Optional[bool] = None
"Enable Content Safety Detection."

Expand Down Expand Up @@ -633,6 +673,7 @@ def __init__(
redact_pii_sub: Optional[PIISubstitutionPolicy] = None,
speaker_labels: Optional[bool] = None,
speakers_expected: Optional[int] = None,
speaker_options: Optional[SpeakerOptions] = None,
content_safety: Optional[bool] = None,
content_safety_confidence: Optional[int] = None,
iab_categories: Optional[bool] = None,
Expand Down Expand Up @@ -675,6 +716,7 @@ def __init__(
redact_pii_sub: The replacement logic for detected PII.
speaker_labels: Enable Speaker Diarization.
speakers_expected: The number of speakers you expect to hear in your audio file. Up to 10 speakers are supported.
speaker_options: Advanced options for controlling speaker diarization parameters, including min and max speakers expected.
content_safety: Enable Content Safety Detection.
iab_categories: Enable Topic Detection.
custom_spelling: Customize how words are spelled and formatted using to and from values.
Expand Down Expand Up @@ -722,7 +764,7 @@ def __init__(
redact_pii_policies,
redact_pii_sub,
)
self.set_speaker_diarization(speaker_labels, speakers_expected)
self.set_speaker_diarization(speaker_labels, speakers_expected, speaker_options)
self.set_content_safety(content_safety, content_safety_confidence)
self.iab_categories = iab_categories
self.set_custom_spelling(custom_spelling, override=True)
Expand Down Expand Up @@ -934,6 +976,12 @@ def speakers_expected(self) -> Optional[int]:

return self._raw_transcription_config.speakers_expected

@property
def speaker_options(self) -> Optional[SpeakerOptions]:
"Returns the advanced speaker diarization options."

return self._raw_transcription_config.speaker_options

@property
def content_safety(self) -> Optional[bool]:
"Returns the status of the Content Safety feature."
Expand Down Expand Up @@ -1162,21 +1210,32 @@ def set_speaker_diarization(
self,
enable: Optional[bool] = True,
speakers_expected: Optional[int] = None,
speaker_options: Optional[SpeakerOptions] = None,
) -> Self:
"""
Whether to enable Speaker Diarization on the transcript.

Args:
`enable`: Enable Speaker Diarization
`speakers_expected`: The number of speakers in the audio file.
`speaker_options`: Advanced options for controlling speaker diarization parameters.
"""

if not enable:
# If enable is explicitly False, clear all speaker settings
if enable is False:
self._raw_transcription_config.speaker_labels = None
self._raw_transcription_config.speakers_expected = None
self._raw_transcription_config.speaker_options = None
# If enable is True or None, set the values (allow setting speaker_options even when enable is None)
else:
self._raw_transcription_config.speaker_labels = True
self._raw_transcription_config.speakers_expected = speakers_expected
# Only set speaker_labels to True if enable is explicitly True
if enable is True:
self._raw_transcription_config.speaker_labels = True
# Always set these if provided, regardless of enable value
if speakers_expected is not None:
self._raw_transcription_config.speakers_expected = speakers_expected
if speaker_options is not None:
self._raw_transcription_config.speaker_options = speaker_options

return self

Expand Down Expand Up @@ -1712,6 +1771,9 @@ class BaseTranscript(BaseModel):
speakers_expected: Optional[int] = None
"The number of speakers you expect to be in your audio file."

speaker_options: Optional[SpeakerOptions] = None
"Advanced options for controlling speaker diarization parameters."

content_safety: Optional[bool] = None
"Enable Content Safety Detection."

Expand Down
96 changes: 96 additions & 0 deletions tests/unit/test_speaker_options.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
import pytest

import assemblyai as aai


def test_speaker_options_creation():
"""Test that SpeakerOptions can be created with valid parameters."""
speaker_options = aai.SpeakerOptions(
min_speakers_expected=2, max_speakers_expected=5
)
assert speaker_options.min_speakers_expected == 2
assert speaker_options.max_speakers_expected == 5


def test_speaker_options_validation():
"""Test that SpeakerOptions validates max >= min."""
with pytest.raises(
ValueError,
match="max_speakers_expected must be greater than or equal to min_speakers_expected",
):
aai.SpeakerOptions(min_speakers_expected=5, max_speakers_expected=2)


def test_speaker_options_min_only():
"""Test that SpeakerOptions can be created with only min_speakers_expected."""
speaker_options = aai.SpeakerOptions(min_speakers_expected=3)
assert speaker_options.min_speakers_expected == 3
assert speaker_options.max_speakers_expected is None


def test_speaker_options_max_only():
"""Test that SpeakerOptions can be created with only max_speakers_expected."""
speaker_options = aai.SpeakerOptions(max_speakers_expected=5)
assert speaker_options.min_speakers_expected is None
assert speaker_options.max_speakers_expected == 5


def test_transcription_config_with_speaker_options():
"""Test that TranscriptionConfig accepts speaker_options parameter."""
speaker_options = aai.SpeakerOptions(
min_speakers_expected=2, max_speakers_expected=4
)

config = aai.TranscriptionConfig(
speaker_labels=True, speaker_options=speaker_options
)

assert config.speaker_labels is True
assert config.speaker_options == speaker_options
assert config.speaker_options.min_speakers_expected == 2
assert config.speaker_options.max_speakers_expected == 4


def test_set_speaker_diarization_with_speaker_options():
"""Test setting speaker diarization with speaker_options."""
speaker_options = aai.SpeakerOptions(
min_speakers_expected=1, max_speakers_expected=3
)

config = aai.TranscriptionConfig()
config.set_speaker_diarization(
enable=True, speakers_expected=2, speaker_options=speaker_options
)

assert config.speaker_labels is True
assert config.speakers_expected == 2
assert config.speaker_options == speaker_options


def test_set_speaker_diarization_disable_clears_speaker_options():
"""Test that disabling speaker diarization clears speaker_options."""
speaker_options = aai.SpeakerOptions(min_speakers_expected=2)

config = aai.TranscriptionConfig()
config.set_speaker_diarization(enable=True, speaker_options=speaker_options)

# Verify it was set
assert config.speaker_options == speaker_options

# Now disable
config.set_speaker_diarization(enable=False)

assert config.speaker_labels is None
assert config.speakers_expected is None
assert config.speaker_options is None


def test_speaker_options_in_raw_config():
"""Test that speaker_options is properly set in the raw config."""
speaker_options = aai.SpeakerOptions(
min_speakers_expected=2, max_speakers_expected=5
)

config = aai.TranscriptionConfig(speaker_options=speaker_options)

assert config.raw.speaker_options == speaker_options