Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ TAG := $(shell git rev-list --tags --max-count=1)
VERSION := $(shell git describe --tags ${TAG})
.PHONY: build check fmt lint test test-race vet test-cover-html help install proto ui compose-up-dev
.DEFAULT_GOAL := build
PROTON_COMMIT := "4144445eb0f9cbd1a801a3d0aa5cfce4cc0ea551"
PROTON_COMMIT := "b1687af73f994fa9612a023c850aa97c35735af8"

ui:
@echo " > generating ui build"
Expand Down
5 changes: 4 additions & 1 deletion billing/customer/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,9 @@ func (s *Service) Update(ctx context.Context, customer Customer) (Customer, erro
return Customer{}, err
}

// Always infer org_id from existing customer (ignore from request for security)
customer.OrgID = existingCustomer.OrgID

// update a customer in stripe
stripeCustomer, err := s.stripeClient.Customers.Update(existingCustomer.ProviderID, &stripe.CustomerParams{
Params: stripe.Params{
Expand All @@ -184,7 +187,7 @@ func (s *Service) Update(ctx context.Context, customer Customer) (Customer, erro
Name: &customer.Name,
Phone: &customer.Phone,
Metadata: map[string]string{
"org_id": customer.OrgID,
"org_id": existingCustomer.OrgID,
"managed_by": "frontier",
},
})
Expand Down
25 changes: 15 additions & 10 deletions internal/api/v1beta1connect/billing_customer.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,9 +105,9 @@ func (h *ConnectHandler) UpdateBillingAccount(ctx context.Context, request *conn
}
}

// Ignore org_id from request - it will be inferred from billing account ID
updatedCustomer, err := h.customerService.Update(ctx, customer.Customer{
ID: request.Msg.GetId(),
OrgID: request.Msg.GetOrgId(),
Name: request.Msg.GetBody().GetName(),
Email: request.Msg.GetBody().GetEmail(),
Phone: request.Msg.GetBody().GetPhone(),
Expand All @@ -119,7 +119,6 @@ func (h *ConnectHandler) UpdateBillingAccount(ctx context.Context, request *conn
if err != nil {
errorLogger.LogServiceError(ctx, request, "UpdateBillingAccount.Update", err,
zap.String("customer_id", request.Msg.GetId()),
zap.String("org_id", request.Msg.GetOrgId()),
zap.String("customer_name", request.Msg.GetBody().GetName()),
zap.String("customer_email", request.Msg.GetBody().GetEmail()),
zap.String("currency", request.Msg.GetBody().GetCurrency()))
Expand Down Expand Up @@ -447,14 +446,20 @@ func (h *ConnectHandler) UpdateBillingAccountDetails(ctx context.Context, reques
return nil, connect.NewError(connect.CodeInternal, ErrInternalServerError)
}

// Add audit log
audit.GetAuditor(ctx, request.Msg.GetOrgId()).LogWithAttrs(audit.BillingAccountDetailsUpdatedEvent, audit.Target{
ID: request.Msg.GetId(),
Type: "billing_account",
}, map[string]string{
"credit_min": fmt.Sprintf("%d", details.CreditMin),
"due_in_days": fmt.Sprintf("%d", details.DueInDays),
})
// Add audit log - infer org_id from billing account
customerOb, err := h.customerService.GetByID(ctx, request.Msg.GetId())
if err == nil {
audit.GetAuditor(ctx, customerOb.OrgID).LogWithAttrs(audit.BillingAccountDetailsUpdatedEvent, audit.Target{
ID: request.Msg.GetId(),
Type: "billing_account",
}, map[string]string{
"credit_min": fmt.Sprintf("%d", details.CreditMin),
"due_in_days": fmt.Sprintf("%d", details.DueInDays),
})
} else {
errorLogger.LogServiceError(ctx, request, "UpdateBillingAccountDetails.GetByID", err,
zap.String("customer_id", request.Msg.GetId()))
}

return connect.NewResponse(&frontierv1beta1.UpdateBillingAccountDetailsResponse{}), nil
}
Expand Down
8 changes: 8 additions & 0 deletions internal/api/v1beta1connect/billing_customer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,14 @@ func TestHandler_UpdateBillingAccountDetails(t *testing.T) {
if tt.request.Msg.GetDueInDays() >= 0 {
mockCustomerService.EXPECT().UpdateDetails(mock.Anything, tt.request.Msg.GetId(), mock.Anything).
Return(tt.mockUpdateDetails, tt.mockUpdateError)
// Mock GetByID call used for fetching org_id for audit log (only called if UpdateDetails succeeds)
if tt.mockUpdateError == nil {
mockCustomerService.EXPECT().GetByID(mock.Anything, tt.request.Msg.GetId()).
Return(customer.Customer{
ID: tt.request.Msg.GetId(),
OrgID: "test-org-id",
}, nil)
}
}

handler := &ConnectHandler{
Expand Down
9 changes: 9 additions & 0 deletions internal/api/v1beta1connect/v1beta1connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,15 @@ func (h *ConnectHandler) GetOrgIDFromCheckoutID(ctx context.Context, checkoutID
return customer.OrgID, nil
}

// GetOrgIDFromBillingAccountID returns the organization ID for a given billing account
func (h *ConnectHandler) GetOrgIDFromBillingAccountID(ctx context.Context, billingAccountID string) (string, error) {
customer, err := h.customerService.GetByID(ctx, billingAccountID)
if err != nil {
return "", err
}
return customer.OrgID, nil
}

func ExtractLogger(ctx context.Context) *zap.Logger {
if logger, ok := ctx.Value(loggerContextKey).(*zap.Logger); ok {
return logger
Expand Down
63 changes: 36 additions & 27 deletions pkg/server/connect_interceptors/authorization.go
Original file line number Diff line number Diff line change
Expand Up @@ -762,60 +762,71 @@ var authorizationValidationMap = map[string]func(ctx context.Context, handler *v
},
"/raystack.frontier.v1beta1.FrontierService/GetBillingAccount": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error {
pbreq := req.(*connect.Request[frontierv1beta1.GetBillingAccountRequest])
if err := ensureBillingAccountBelongToOrg(ctx, handler, pbreq.Msg.GetOrgId(), pbreq.Msg.GetId()); err != nil {
orgID, err := getOrgForBillingAccount(ctx, handler, pbreq.Msg.GetId())
if err != nil {
return err
}
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: pbreq.Msg.GetOrgId()}, schema.GetPermission, req)
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: orgID}, schema.GetPermission, req)
},
"/raystack.frontier.v1beta1.FrontierService/GetBillingBalance": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error {
pbreq := req.(*connect.Request[frontierv1beta1.GetBillingBalanceRequest])
if err := ensureBillingAccountBelongToOrg(ctx, handler, pbreq.Msg.GetOrgId(), pbreq.Msg.GetId()); err != nil {
orgID, err := getOrgForBillingAccount(ctx, handler, pbreq.Msg.GetId())
if err != nil {
return err
}
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: pbreq.Msg.GetOrgId()}, schema.GetPermission, req)
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: orgID}, schema.GetPermission, req)
},
"/raystack.frontier.v1beta1.FrontierService/CheckCreditEntitlement": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error {
pbreq := req.(*connect.Request[frontierv1beta1.CheckCreditEntitlementRequest])
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: pbreq.Msg.GetOrgId()}, schema.GetPermission, req)
},
"/raystack.frontier.v1beta1.FrontierService/UpdateBillingAccount": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error {
pbreq := req.(*connect.Request[frontierv1beta1.UpdateBillingAccountRequest])
if err := ensureBillingAccountBelongToOrg(ctx, handler, pbreq.Msg.GetOrgId(), pbreq.Msg.GetId()); err != nil {
orgID, err := getOrgForBillingAccount(ctx, handler, pbreq.Msg.GetId())
if err != nil {
return err
}
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: pbreq.Msg.GetOrgId()}, schema.DeletePermission, req)
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: orgID}, schema.DeletePermission, req)
},
"/raystack.frontier.v1beta1.FrontierService/DeleteBillingAccount": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error {
pbreq := req.(*connect.Request[frontierv1beta1.DeleteBillingAccountRequest])
if err := ensureBillingAccountBelongToOrg(ctx, handler, pbreq.Msg.GetOrgId(), pbreq.Msg.GetId()); err != nil {
orgID, err := getOrgForBillingAccount(ctx, handler, pbreq.Msg.GetId())
if err != nil {
return err
}
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: pbreq.Msg.GetOrgId()}, schema.DeletePermission, req)
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: orgID}, schema.DeletePermission, req)
},
"/raystack.frontier.v1beta1.FrontierService/EnableBillingAccount": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error {
pbreq := req.(*connect.Request[frontierv1beta1.EnableBillingAccountRequest])
if err := ensureBillingAccountBelongToOrg(ctx, handler, pbreq.Msg.GetOrgId(), pbreq.Msg.GetId()); err != nil {
orgID, err := getOrgForBillingAccount(ctx, handler, pbreq.Msg.GetId())
if err != nil {
return err
}
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: pbreq.Msg.GetOrgId()}, schema.DeletePermission, req)
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: orgID}, schema.DeletePermission, req)
},
"/raystack.frontier.v1beta1.FrontierService/DisableBillingAccount": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error {
pbreq := req.(*connect.Request[frontierv1beta1.DisableBillingAccountRequest])
if err := ensureBillingAccountBelongToOrg(ctx, handler, pbreq.Msg.GetOrgId(), pbreq.Msg.GetId()); err != nil {
orgID, err := getOrgForBillingAccount(ctx, handler, pbreq.Msg.GetId())
if err != nil {
return err
}
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: pbreq.Msg.GetOrgId()}, schema.DeletePermission, req)
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: orgID}, schema.DeletePermission, req)
},
"/raystack.frontier.v1beta1.FrontierService/HasTrialed": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error {
pbreq := req.(*connect.Request[frontierv1beta1.HasTrialedRequest])
if err := ensureBillingAccountBelongToOrg(ctx, handler, pbreq.Msg.GetOrgId(), pbreq.Msg.GetId()); err != nil {
orgID, err := getOrgForBillingAccount(ctx, handler, pbreq.Msg.GetId())
if err != nil {
return err
}
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: pbreq.Msg.GetOrgId()}, schema.GetPermission, req)
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: orgID}, schema.GetPermission, req)
},
"/raystack.frontier.v1beta1.FrontierService/RegisterBillingAccount": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error {
pbreq := req.(*connect.Request[frontierv1beta1.RegisterBillingAccountRequest])
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: pbreq.Msg.GetOrgId()}, schema.DeletePermission, req)
orgID, err := getOrgForBillingAccount(ctx, handler, pbreq.Msg.GetId())
if err != nil {
return err
}
return handler.IsAuthorized(ctx, relation.Object{Namespace: schema.OrganizationNamespace, ID: orgID}, schema.DeletePermission, req)
},

// subscriptions - org_id and billing_id are now inferred from subscription_id
Expand Down Expand Up @@ -1052,14 +1063,16 @@ var authorizationValidationMap = map[string]func(ctx context.Context, handler *v
},
"/raystack.frontier.v1beta1.AdminService/GetBillingAccountDetails": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error {
pbReq := req.(*connect.Request[frontierv1beta1.GetBillingAccountDetailsRequest])
if err := ensureBillingAccountBelongToOrg(ctx, handler, pbReq.Msg.GetOrgId(), pbReq.Msg.GetId()); err != nil {
_, err := getOrgForBillingAccount(ctx, handler, pbReq.Msg.GetId())
if err != nil {
return err
}
return handler.IsSuperUser(ctx, req)
},
"/raystack.frontier.v1beta1.AdminService/UpdateBillingAccountDetails": func(ctx context.Context, handler *v1beta1connect.ConnectHandler, req connect.AnyRequest) error {
pbReq := req.(*connect.Request[frontierv1beta1.UpdateBillingAccountDetailsRequest])
if err := ensureBillingAccountBelongToOrg(ctx, handler, pbReq.Msg.GetOrgId(), pbReq.Msg.GetId()); err != nil {
_, err := getOrgForBillingAccount(ctx, handler, pbReq.Msg.GetId())
if err != nil {
return err
}
return handler.IsSuperUser(ctx, req)
Expand Down Expand Up @@ -1122,18 +1135,14 @@ func ensureRoleBelongToOrg(ctx context.Context, handler *v1beta1connect.ConnectH
return nil
}

func ensureBillingAccountBelongToOrg(ctx context.Context, handler *v1beta1connect.ConnectHandler, orgID, billingID string) error {
acc, err := handler.GetBillingAccount(ctx, connect.NewRequest(&frontierv1beta1.GetBillingAccountRequest{
OrgId: orgID,
Id: billingID,
}))
func getOrgForBillingAccount(ctx context.Context, handler *v1beta1connect.ConnectHandler, billingID string) (string, error) {
// Infer org_id from billing_id (don't trust org_id from request for security)
orgID, err := handler.GetOrgIDFromBillingAccountID(ctx, billingID)
if err != nil {
return err
}
if acc.Msg.GetBillingAccount().GetOrgId() != orgID {
return ErrDeniedInvalidArgs
return "", err
}
return nil
// Return the inferred org_id for authorization check
return orgID, nil
}

func ensureSubscriptionBelongToOrg(ctx context.Context, handler *v1beta1connect.ConnectHandler, subID string) (string, error) {
Expand Down
13 changes: 12 additions & 1 deletion proto/apidocs.swagger.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -5547,6 +5547,7 @@ paths:
$ref: '#/definitions/googlerpcStatus'
parameters:
- name: org_id
description: 'DEPRECATED: org_id will be inferred from billing account id'
in: path
required: true
type: string
Expand Down Expand Up @@ -5607,6 +5608,7 @@ paths:
$ref: '#/definitions/googlerpcStatus'
parameters:
- name: org_id
description: 'DEPRECATED: org_id will be inferred from billing account id'
in: path
required: true
type: string
Expand Down Expand Up @@ -5652,6 +5654,7 @@ paths:
$ref: '#/definitions/googlerpcStatus'
parameters:
- name: org_id
description: 'DEPRECATED: org_id will be inferred from billing account id'
in: path
required: true
type: string
Expand Down Expand Up @@ -5707,6 +5710,7 @@ paths:
$ref: '#/definitions/googlerpcStatus'
parameters:
- name: org_id
description: 'DEPRECATED: org_id will be inferred from billing account id'
in: path
required: true
type: string
Expand Down Expand Up @@ -5753,6 +5757,7 @@ paths:
$ref: '#/definitions/googlerpcStatus'
parameters:
- name: org_id
description: 'DEPRECATED: org_id will be inferred from billing account id'
in: path
required: true
type: string
Expand Down Expand Up @@ -5804,6 +5809,7 @@ paths:
$ref: '#/definitions/googlerpcStatus'
parameters:
- name: org_id
description: 'DEPRECATED: org_id will be inferred from billing account id'
in: path
required: true
type: string
Expand Down Expand Up @@ -5855,11 +5861,15 @@ paths:
$ref: '#/definitions/googlerpcStatus'
parameters:
- name: org_id
description: 'DEPRECATED: org_id will be inferred from plan_id'
in: path
required: true
type: string
- name: id
description: ID of the billing account to check
description: |-
ID of the billing account to check

DEPRECATED: billing_id will be inferred from plan_id
in: path
required: true
type: string
Expand Down Expand Up @@ -5906,6 +5916,7 @@ paths:
$ref: '#/definitions/googlerpcStatus'
parameters:
- name: org_id
description: 'DEPRECATED: org_id will be inferred from billing account id'
in: path
required: true
type: string
Expand Down
Loading
Loading