Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions backend/internal/handler/dto/mappers.go
Original file line number Diff line number Diff line change
Expand Up @@ -609,6 +609,8 @@ func usageLogFromServiceUser(l *service.UsageLog) UsageLog {
ImageSize: l.ImageSize,
ImageInputSize: l.ImageInputSize,
ImageOutputSize: l.ImageOutputSize,
ImageOutputTokens: l.ImageOutputTokens,
ImageOutputCost: l.ImageOutputCost,
ImageSizeSource: l.ImageSizeSource,
ImageSizeBreakdown: l.ImageSizeBreakdown,
MediaType: l.MediaType,
Expand Down
2 changes: 2 additions & 0 deletions backend/internal/handler/dto/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,8 @@ type UsageLog struct {
ImageSize *string `json:"image_size"`
ImageInputSize *string `json:"image_input_size"`
ImageOutputSize *string `json:"image_output_size"`
ImageOutputTokens int `json:"image_output_tokens"`
ImageOutputCost float64 `json:"image_output_cost"`
ImageSizeSource *string `json:"image_size_source"`
ImageSizeBreakdown map[string]int `json:"image_size_breakdown"`
MediaType *string `json:"media_type"`
Expand Down
10 changes: 7 additions & 3 deletions backend/internal/service/billing_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ type ModelPricing struct {
LongContextInputMultiplier float64 // 长上下文整次会话输入倍率
LongContextOutputMultiplier float64 // 长上下文整次会话输出倍率
ImageOutputPricePerToken float64 // 图片输出 token 价格 (USD)
ImageOutputPriceExplicit bool // 是否由渠道定价显式设定(为 true 时即使 == 0 也不回退)
}

const (
Expand Down Expand Up @@ -409,7 +410,7 @@ func (s *BillingService) GetModelPricing(model string) (*ModelPricing, error) {
}

// GetModelPricingWithChannel 获取模型定价,渠道配置的价格覆盖默认值
// 仅覆盖渠道中非 nil 的价格字段,nil 字段使用默认定价
// 渠道存在时,未配置的图片输出价格归零(不回退到 LiteLLM)
func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing *ChannelModelPricing) (*ModelPricing, error) {
pricing, err := s.GetModelPricing(model)
if err != nil {
Expand Down Expand Up @@ -437,7 +438,10 @@ func (s *BillingService) GetModelPricingWithChannel(model string, channelPricing
}
if channelPricing.ImageOutputPrice != nil {
pricing.ImageOutputPricePerToken = *channelPricing.ImageOutputPrice
} else {
pricing.ImageOutputPricePerToken = 0
}
pricing.ImageOutputPriceExplicit = true
return pricing, nil
}

Expand Down Expand Up @@ -570,8 +574,8 @@ func (s *BillingService) computeTokenBreakdown(
// 图片输出 token 费用(独立费率)
if tokens.ImageOutputTokens > 0 {
imgPrice := pricing.ImageOutputPricePerToken
if imgPrice == 0 {
imgPrice = outputPrice // 回退到常规输出价格
if imgPrice == 0 && !pricing.ImageOutputPriceExplicit {
imgPrice = outputPrice
}
bd.ImageOutputCost = float64(tokens.ImageOutputTokens) * imgPrice
}
Expand Down
59 changes: 59 additions & 0 deletions backend/internal/service/billing_service_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -950,3 +950,62 @@ func TestGetModelPricingWithChannel_UnknownModelReturnsError(t *testing.T) {
require.Nil(t, pricing)
require.Contains(t, err.Error(), "pricing not found")
}

func TestGetModelPricingWithChannel_NilImageOutputPriceZerosAndMarksExplicit(t *testing.T) {
svc := newTestBillingService()

chPricing := &ChannelModelPricing{
InputPrice: testPtrFloat64(10e-6),
OutputPrice: testPtrFloat64(20e-6),
// ImageOutputPrice intentionally nil
}
pricing, err := svc.GetModelPricingWithChannel("claude-sonnet-4", chPricing)
require.NoError(t, err)

require.Equal(t, 0.0, pricing.ImageOutputPricePerToken)
require.True(t, pricing.ImageOutputPriceExplicit)
}

func TestComputeTokenBreakdown_ExplicitZeroImagePrice_NoFallback(t *testing.T) {
svc := newTestBillingService()

pricing := &ModelPricing{
InputPricePerToken: 3e-6,
OutputPricePerToken: 15e-6,
ImageOutputPricePerToken: 0,
ImageOutputPriceExplicit: true,
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 200,
ImageOutputTokens: 50,
}
bd := svc.computeTokenBreakdown(pricing, tokens, 1.0, "", false)

// ImageOutputTokens should NOT fall back to outputPrice
require.Equal(t, 0.0, bd.ImageOutputCost)
// textOutputTokens = 200 - 50 = 150
require.InDelta(t, 150*15e-6, bd.OutputCost, 1e-12)
}

func TestComputeTokenBreakdown_NonExplicitZeroImagePrice_FallsBackToOutput(t *testing.T) {
svc := newTestBillingService()

pricing := &ModelPricing{
InputPricePerToken: 3e-6,
OutputPricePerToken: 15e-6,
ImageOutputPricePerToken: 0,
ImageOutputPriceExplicit: false,
}
tokens := UsageTokens{
InputTokens: 100,
OutputTokens: 200,
ImageOutputTokens: 50,
}
bd := svc.computeTokenBreakdown(pricing, tokens, 1.0, "", false)

// Should fall back to outputPrice since not explicit
require.InDelta(t, 50*15e-6, bd.ImageOutputCost, 1e-12)
// textOutputTokens = 200 - 50 = 150
require.InDelta(t, 150*15e-6, bd.OutputCost, 1e-12)
}
7 changes: 5 additions & 2 deletions backend/internal/service/gateway_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -8837,8 +8837,11 @@ func (s *GatewayService) calculateRecordUsageCost(
imageMultiplier float64,
opts *recordUsageOpts,
) *CostBreakdown {
// 图片生成计费
// 图片生成:渠道定价为 token 计费时走 token 路径,否则走图片计费
if result.ImageCount > 0 {
if resolved := s.resolveChannelPricing(ctx, billingModel, apiKey); resolved != nil && resolved.Mode == BillingModeToken {
return s.calculateTokenCost(ctx, result, apiKey, billingModel, multiplier, opts)
}
return s.calculateImageCost(ctx, result, apiKey, billingModel, imageMultiplier)
}

Expand Down Expand Up @@ -9016,7 +9019,7 @@ func (s *GatewayService) buildRecordUsageLog(
SubscriptionID: optionalSubscriptionID(subscription),
CreatedAt: time.Now(),
}
if result.ImageCount > 0 {
if result.ImageCount > 0 && (cost == nil || cost.BillingMode != string(BillingModeToken)) {
usageLog.RateMultiplier = imageMultiplier
}
if cost != nil {
Expand Down
35 changes: 31 additions & 4 deletions backend/internal/service/model_pricing_resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ type ResolvedPricing struct {

// 是否支持缓存细分
SupportsCacheBreakdown bool

// 渠道定价原始配置(用于区间模式下获取 ImageOutputPrice)
channelPricing *ChannelModelPricing
}

// ModelPricingResolver 统一模型定价解析器。
Expand Down Expand Up @@ -71,8 +74,9 @@ func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput)
}
if mode == BillingModePerRequest || mode == BillingModeImage {
resolved := &ResolvedPricing{
Mode: mode,
Source: PricingSourceChannel,
Mode: mode,
Source: PricingSourceChannel,
channelPricing: chPricing,
}
r.applyRequestTierOverrides(chPricing, resolved)
return resolved
Expand All @@ -93,6 +97,7 @@ func (r *ModelPricingResolver) Resolve(ctx context.Context, input PricingInput)
// 2. 如果有 GroupID,尝试渠道覆盖
if chPricing != nil {
resolved.Source = PricingSourceChannel
resolved.channelPricing = chPricing
r.applyTokenOverrides(chPricing, resolved)
} else if input.GroupID != nil {
r.applyChannelOverrides(ctx, *input.GroupID, input.Model, resolved)
Expand Down Expand Up @@ -120,6 +125,7 @@ func (r *ModelPricingResolver) applyChannelOverrides(ctx context.Context, groupI
}

resolved.Source = PricingSourceChannel
resolved.channelPricing = chPricing
resolved.Mode = chPricing.BillingMode
if resolved.Mode == "" {
resolved.Mode = BillingModeToken
Expand All @@ -141,6 +147,16 @@ func (r *ModelPricingResolver) applyTokenOverrides(chPricing *ChannelModelPricin
// 如果有有效的区间定价,使用区间
if len(validIntervals) > 0 {
resolved.Intervals = validIntervals
// 区间不匹配时回退到 BasePricing,也需要覆盖图片价格
if resolved.BasePricing == nil {
resolved.BasePricing = &ModelPricing{}
}
if chPricing.ImageOutputPrice != nil {
resolved.BasePricing.ImageOutputPricePerToken = *chPricing.ImageOutputPrice
} else {
resolved.BasePricing.ImageOutputPricePerToken = 0
}
resolved.BasePricing.ImageOutputPriceExplicit = true
return
}

Expand All @@ -166,9 +182,13 @@ func (r *ModelPricingResolver) applyTokenOverrides(chPricing *ChannelModelPricin
resolved.BasePricing.CacheReadPricePerToken = *chPricing.CacheReadPrice
resolved.BasePricing.CacheReadPricePerTokenPriority = *chPricing.CacheReadPrice
}
// 渠道定价覆盖一切:显式配置则用配置值,未配置则归零(不回退到 LiteLLM)
if chPricing.ImageOutputPrice != nil {
resolved.BasePricing.ImageOutputPricePerToken = *chPricing.ImageOutputPrice
} else {
resolved.BasePricing.ImageOutputPricePerToken = 0
}
resolved.BasePricing.ImageOutputPriceExplicit = true
}

// applyRequestTierOverrides 应用按次/图片模式的渠道覆盖
Expand Down Expand Up @@ -205,11 +225,11 @@ func (r *ModelPricingResolver) GetIntervalPricing(resolved *ResolvedPricing, tot
return resolved.BasePricing
}

return intervalToModelPricing(iv, resolved.SupportsCacheBreakdown)
return intervalToModelPricing(iv, resolved.SupportsCacheBreakdown, resolved.channelPricing)
}

// intervalToModelPricing 将区间定价转换为 ModelPricing
func intervalToModelPricing(iv *PricingInterval, supportsCacheBreakdown bool) *ModelPricing {
func intervalToModelPricing(iv *PricingInterval, supportsCacheBreakdown bool, chPricing *ChannelModelPricing) *ModelPricing {
pricing := &ModelPricing{
SupportsCacheBreakdown: supportsCacheBreakdown,
}
Expand All @@ -230,6 +250,13 @@ func intervalToModelPricing(iv *PricingInterval, supportsCacheBreakdown bool) *M
pricing.CacheReadPricePerToken = *iv.CacheReadPrice
pricing.CacheReadPricePerTokenPriority = *iv.CacheReadPrice
}
// 渠道定价存在时,ImageOutputPrice 显式覆盖
if chPricing != nil {
pricing.ImageOutputPriceExplicit = true
if chPricing.ImageOutputPrice != nil {
pricing.ImageOutputPricePerToken = *chPricing.ImageOutputPrice
}
}
return pricing
}

Expand Down
66 changes: 66 additions & 0 deletions backend/internal/service/model_pricing_resolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -661,3 +661,69 @@ func TestFilterValidIntervals(t *testing.T) {
})
}
}

// ===========================================================================
// 9. ImageOutputPriceExplicit tests
// ===========================================================================

func TestApplyTokenOverrides_FlatSetsImageOutputPriceExplicit(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(3e-6),
OutputPrice: testPtrFloat64(15e-6),
// ImageOutputPrice intentionally nil
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})

require.Equal(t, PricingSourceChannel, resolved.Source)
require.True(t, resolved.BasePricing.ImageOutputPriceExplicit)
require.Equal(t, 0.0, resolved.BasePricing.ImageOutputPricePerToken)
}

func TestApplyTokenOverrides_FlatWithImageOutputPriceSetsExplicit(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
InputPrice: testPtrFloat64(3e-6),
OutputPrice: testPtrFloat64(15e-6),
ImageOutputPrice: testPtrFloat64(50e-6),
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})

require.True(t, resolved.BasePricing.ImageOutputPriceExplicit)
require.InDelta(t, 50e-6, resolved.BasePricing.ImageOutputPricePerToken, 1e-12)
}

func TestApplyTokenOverrides_IntervalSetsImageOutputPriceExplicit(t *testing.T) {
r := newResolverWithChannel(t, []ChannelModelPricing{{
Platform: "anthropic",
Models: []string{"claude-sonnet-4"},
BillingMode: BillingModeToken,
// No ImageOutputPrice
Intervals: []PricingInterval{
{MinTokens: 0, MaxTokens: testPtrInt(100000), InputPrice: testPtrFloat64(3e-6), OutputPrice: testPtrFloat64(15e-6)},
},
}})
resolved := r.Resolve(context.Background(), PricingInput{
Model: "claude-sonnet-4",
GroupID: groupIDPtr(),
})

// BasePricing should have explicit mark (for interval fallback)
require.True(t, resolved.BasePricing.ImageOutputPriceExplicit)
require.Equal(t, 0.0, resolved.BasePricing.ImageOutputPricePerToken)

// intervalToModelPricing should also have explicit mark
pricing := r.GetIntervalPricing(resolved, 50000)
require.True(t, pricing.ImageOutputPriceExplicit)
require.Equal(t, 0.0, pricing.ImageOutputPricePerToken)
}
7 changes: 5 additions & 2 deletions backend/internal/service/openai_gateway_service.go
Original file line number Diff line number Diff line change
Expand Up @@ -5801,7 +5801,7 @@ func (s *OpenAIGatewayService) RecordUsage(ctx context.Context, input *OpenAIRec
usageLog.TotalCost = cost.TotalCost
usageLog.ActualCost = cost.ActualCost
}
if result.ImageCount > 0 {
if result.ImageCount > 0 && (cost == nil || cost.BillingMode != string(BillingModeToken)) {
usageLog.RateMultiplier = imageMultiplier
} else {
usageLog.RateMultiplier = multiplier
Expand Down Expand Up @@ -5895,7 +5895,10 @@ func (s *OpenAIGatewayService) calculateOpenAIRecordUsageCost(
) (*CostBreakdown, error) {
billingModel := firstUsageBillingModel(billingModels)
if result != nil && result.ImageCount > 0 {
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, imageMultiplier), nil
// 渠道定价为 token 计费时走 token 路径,否则走图片计费
if resolved := s.resolveOpenAIChannelPricing(ctx, billingModel, apiKey); resolved == nil || resolved.Mode != BillingModeToken {
return s.calculateOpenAIImageCost(ctx, billingModel, apiKey, result, imageMultiplier), nil
}
}
if len(billingModels) == 0 || billingModel == "" {
return nil, errors.New("openai usage billing model is empty")
Expand Down
Loading
Loading