diff --git a/flytecopilot/cmd/sidecar_test.go b/flytecopilot/cmd/sidecar_test.go index c8a89c60e55..31bde463491 100644 --- a/flytecopilot/cmd/sidecar_test.go +++ b/flytecopilot/cmd/sidecar_test.go @@ -6,8 +6,8 @@ import ( "os" "testing" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" "github.com/flyteorg/flyte/v2/flytecopilot/cmd/containerwatcher" "github.com/flyteorg/flyte/v2/flytestdlib/promutils" diff --git a/flytecopilot/data/download.go b/flytecopilot/data/download.go index a79c016e935..fccaa38c8ca 100644 --- a/flytecopilot/data/download.go +++ b/flytecopilot/data/download.go @@ -15,10 +15,10 @@ import ( "time" "github.com/ghodss/yaml" - "github.com/golang/protobuf/jsonpb" //nolint: staticcheck - "github.com/golang/protobuf/proto" //nolint: staticcheck - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" "github.com/flyteorg/flyte/v2/flytestdlib/futures" "github.com/flyteorg/flyte/v2/flytestdlib/logger" @@ -289,7 +289,7 @@ func (d Downloader) handleError(_ context.Context, b *core.Error, toFilePath str func (d Downloader) handleGeneric(ctx context.Context, b *structpb.Struct, toFilePath string, writeToFile bool) (interface{}, error) { if writeToFile && b != nil { - m := jsonpb.Marshaler{} + m := protojson.MarshalOptions{} writer, err := os.Create(toFilePath) if err != nil { return nil, errors.Wrapf(err, "failed to open file at path %s", toFilePath) @@ -300,7 +300,12 @@ func (d Downloader) handleGeneric(ctx context.Context, b *structpb.Struct, toFil logger.Errorf(ctx, "failed to close File write stream. Error: %s", err) } }() - return b, m.Marshal(writer, b) + raw, err := m.Marshal(b) + if err != nil { + return nil, err + } + _, err = writer.Write(raw) + return b, err } return b, nil } @@ -310,7 +315,6 @@ func (d Downloader) handlePrimitive(primitive *core.Primitive, toFilePath string var toByteArray func() ([]byte, error) var v interface{} - var err error switch primitive.GetValue().(type) { case *core.Primitive_StringValue: @@ -334,18 +338,18 @@ func (d Downloader) handlePrimitive(primitive *core.Primitive, toFilePath string return []byte(strconv.FormatFloat(primitive.GetFloatValue(), 'f', -1, 64)), nil } case *core.Primitive_Datetime: - v = primitive.GetDatetime().AsTime() - if err != nil { + if err := primitive.GetDatetime().CheckValid(); err != nil { return nil, err } + v = primitive.GetDatetime().AsTime() toByteArray = func() ([]byte, error) { return []byte(primitive.GetDatetime().AsTime().Format(time.RFC3339Nano)), nil } case *core.Primitive_Duration: - v = primitive.GetDuration().AsDuration() - if err != nil { + if err := primitive.GetDuration().CheckValid(); err != nil { return nil, err } + v = primitive.GetDuration().AsDuration() toByteArray = func() ([]byte, error) { return []byte(primitive.GetDuration().AsDuration().String()), nil } @@ -533,6 +537,10 @@ func (d Downloader) DownloadInputs(ctx context.Context, inputRef storage.DataRef logger.Errorf(ctx, "Failed to download inputs from [%s], err [%s]", inputRef, err) return errors.Wrapf(err, "failed to download input metadata message from remote store") } + if len(inputs.GetLiterals()) == 0 { + return nil + } + varMap, lMap, err := d.RecursiveDownload(ctx, inputs, outputDir, true) if err != nil { return errors.Wrapf(err, "failed to download input variable from remote store") diff --git a/flytecopilot/data/download_test.go b/flytecopilot/data/download_test.go index 2fb23847f63..83b6df9b189 100644 --- a/flytecopilot/data/download_test.go +++ b/flytecopilot/data/download_test.go @@ -265,7 +265,7 @@ func TestRecursiveDownload(t *testing.T) { } // Mock reading the offloaded metadata - err = s.WriteProtobuf(context.Background(), storage.DataReference("s3://container/offloaded"), storage.Options{}, &core.Literal{ + err = s.WriteProtobuf(context.Background(), "s3://container/offloaded", storage.Options{}, &core.Literal{ Value: &core.Literal_Map{ Map: &core.LiteralMap{ Literals: map[string]*core.Literal{ diff --git a/flytecopilot/data/upload.go b/flytecopilot/data/upload.go index 8412e3a9ee8..ca3b49b2639 100644 --- a/flytecopilot/data/upload.go +++ b/flytecopilot/data/upload.go @@ -9,8 +9,8 @@ import ( "path/filepath" "reflect" - "github.com/golang/protobuf/proto" //nolint: staticcheck "github.com/pkg/errors" + "google.golang.org/protobuf/proto" "github.com/flyteorg/flyte/v2/flyteidl2/clients/go/coreutils" "github.com/flyteorg/flyte/v2/flytestdlib/futures" diff --git a/flyteidl2/clients/go/coreutils/extract_literal_test.go b/flyteidl2/clients/go/coreutils/extract_literal_test.go index 8781c6b3a52..66562f5ec31 100644 --- a/flyteidl2/clients/go/coreutils/extract_literal_test.go +++ b/flyteidl2/clients/go/coreutils/extract_literal_test.go @@ -8,8 +8,8 @@ import ( "testing" "time" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" + structpb "google.golang.org/protobuf/types/known/structpb" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" ) diff --git a/flyteidl2/clients/go/coreutils/literals.go b/flyteidl2/clients/go/coreutils/literals.go index b3d4bfed36c..1925ee216b6 100644 --- a/flyteidl2/clients/go/coreutils/literals.go +++ b/flyteidl2/clients/go/coreutils/literals.go @@ -11,11 +11,11 @@ import ( "strings" "time" - "github.com/golang/protobuf/jsonpb" //nolint: staticcheck - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" "github.com/shamaton/msgpack/v2" + "google.golang.org/protobuf/encoding/protojson" "google.golang.org/protobuf/types/known/durationpb" + structpb "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/flyteorg/flyte/v2/flytestdlib/storage" @@ -378,8 +378,8 @@ func MakeLiteralForSimpleType(t core.SimpleType, s string) (*core.Literal, error switch t { case core.SimpleType_STRUCT: st := &structpb.Struct{} - unmarshaler := jsonpb.Unmarshaler{AllowUnknownFields: true} - err := unmarshaler.Unmarshal(strings.NewReader(s), st) + unmarshaler := protojson.UnmarshalOptions{DiscardUnknown: true} + err := unmarshaler.Unmarshal([]byte(s), st) if err != nil { return nil, errors.Wrapf(err, "failed to load generic type as json.") } diff --git a/flyteidl2/clients/go/coreutils/literals_test.go b/flyteidl2/clients/go/coreutils/literals_test.go index ac851cab4b6..99aa043c9ee 100644 --- a/flyteidl2/clients/go/coreutils/literals_test.go +++ b/flyteidl2/clients/go/coreutils/literals_test.go @@ -13,11 +13,11 @@ import ( "github.com/flyteorg/flyte/v2/flytestdlib/storage" "github.com/go-test/deep" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" "github.com/shamaton/msgpack/v2" "github.com/stretchr/testify/assert" "google.golang.org/protobuf/types/known/durationpb" + structpb "google.golang.org/protobuf/types/known/structpb" "google.golang.org/protobuf/types/known/timestamppb" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" diff --git a/flyteplugins/go/tasks/pluginmachinery/core/phase.go b/flyteplugins/go/tasks/pluginmachinery/core/phase.go index de3f48cdbd9..e7cd81d5019 100644 --- a/flyteplugins/go/tasks/pluginmachinery/core/phase.go +++ b/flyteplugins/go/tasks/pluginmachinery/core/phase.go @@ -4,7 +4,7 @@ import ( "fmt" "time" - structpb "github.com/golang/protobuf/ptypes/struct" + structpb "google.golang.org/protobuf/types/known/structpb" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" ) diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go index 126a9f7f941..50488f57991 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot.go @@ -7,8 +7,8 @@ import ( "strconv" "time" - "github.com/golang/protobuf/proto" //nolint: staticcheck "github.com/pkg/errors" + "google.golang.org/protobuf/proto" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go index f74f3e702af..054c008b5f7 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/copilot_test.go @@ -7,8 +7,8 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go index 49682428000..aa8cfcacf0a 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/k8s_resource_adds_test.go @@ -6,7 +6,6 @@ import ( "reflect" "testing" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" v12 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" @@ -64,7 +63,7 @@ func TestGetExecutionEnvVars(t *testing.T) { envVars := GetExecutionEnvVars(mock, tt.consoleURL) assert.Len(t, envVars, tt.expectedEnvVars) if tt.expectedEnvVar != nil { - assert.True(t, proto.Equal(&envVars[5], tt.expectedEnvVar)) + assert.Equal(t, tt.expectedEnvVar, &envVars[5]) } } } diff --git a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go index 1b1a67f1638..ce2ccd48f0e 100644 --- a/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go +++ b/flyteplugins/go/tasks/pluginmachinery/flytek8s/pod_helper.go @@ -9,8 +9,8 @@ import ( "strings" "time" - "github.com/golang/protobuf/proto" //nolint: staticcheck "github.com/imdario/mergo" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" diff --git a/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go b/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go index cedd8fb4d2a..fddb7aa530b 100644 --- a/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/ioutils/remote_file_output_reader_test.go @@ -10,7 +10,7 @@ import ( "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" - "google.golang.org/protobuf/runtime/protoiface" + "google.golang.org/protobuf/proto" ) type MemoryMetadata struct { @@ -54,7 +54,7 @@ func TestReadOrigin(t *testing.T) { }, } store := &storageMocks.ComposedProtobufStore{} - store.EXPECT().ReadProtobuf(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, ref storage.DataReference, msg protoiface.MessageV1) { + store.EXPECT().ReadProtobuf(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, ref storage.DataReference, msg proto.Message) { assert.NotNil(t, msg) casted := msg.(*core.ErrorDocument) casted.Error = errorDoc.Error @@ -89,7 +89,7 @@ func TestReadOrigin(t *testing.T) { }, } store := &storageMocks.ComposedProtobufStore{} - store.EXPECT().ReadProtobuf(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, ref storage.DataReference, msg protoiface.MessageV1) { + store.EXPECT().ReadProtobuf(mock.Anything, mock.Anything, mock.Anything).Run(func(ctx context.Context, ref storage.DataReference, msg proto.Message) { assert.NotNil(t, msg) casted := msg.(*core.ErrorDocument) casted.Error = errorDoc.Error diff --git a/flyteplugins/go/tasks/pluginmachinery/ioutils/task_reader_test.go b/flyteplugins/go/tasks/pluginmachinery/ioutils/task_reader_test.go index c27f8788b39..dbfccb73836 100644 --- a/flyteplugins/go/tasks/pluginmachinery/ioutils/task_reader_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/ioutils/task_reader_test.go @@ -5,8 +5,8 @@ import ( "fmt" "testing" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/core/mocks" "github.com/flyteorg/flyte/v2/flytestdlib/contextutils" diff --git a/flyteplugins/go/tasks/pluginmachinery/secret/secrets_test.go b/flyteplugins/go/tasks/pluginmachinery/secret/secrets_test.go index 17891551dd8..10048d2282e 100644 --- a/flyteplugins/go/tasks/pluginmachinery/secret/secrets_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/secret/secrets_test.go @@ -12,6 +12,8 @@ import ( "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/secret/config" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/secret/mocks" + secretUtils "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/utils/secrets" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" ) func TestSecretsWebhook_Mutate(t *testing.T) { @@ -23,12 +25,16 @@ func TestSecretsWebhook_Mutate(t *testing.T) { }) namespace := "test-namespace" + secretAnnotations, err := secretUtils.MarshalSecretsToMapStrings([]*core.Secret{ + { + Key: "my_key", + }, + }) + assert.NoError(t, err) podWithAnnotations := &corev1.Pod{ ObjectMeta: v1.ObjectMeta{ - Namespace: namespace, - Annotations: map[string]string{ - "flyte.secrets/s0": "nnsxsorcnv4v623fperca", - }, + Namespace: namespace, + Annotations: secretAnnotations, }, } diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go index 80d453bc537..fda9d208fe0 100755 --- a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils.go @@ -3,16 +3,15 @@ package utils import ( "encoding/json" "fmt" - "strings" - "github.com/golang/protobuf/jsonpb" //nolint: staticcheck - "github.com/golang/protobuf/proto" //nolint: staticcheck - structpb "github.com/golang/protobuf/ptypes/struct" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" ) -var jsonPbMarshaler = jsonpb.Marshaler{} -var jsonPbUnmarshaler = &jsonpb.Unmarshaler{ - AllowUnknownFields: true, +var jsonPbMarshaler = protojson.MarshalOptions{} +var jsonPbUnmarshaler = protojson.UnmarshalOptions{ + DiscardUnknown: true, } // Deprecated: Use flytestdlib/utils.UnmarshalStructToPb instead. @@ -21,12 +20,12 @@ func UnmarshalStruct(structObj *structpb.Struct, msg proto.Message) error { return fmt.Errorf("nil Struct Object passed") } - jsonObj, err := jsonPbMarshaler.MarshalToString(structObj) + jsonObj, err := jsonPbMarshaler.Marshal(structObj) if err != nil { return err } - if err = jsonPbUnmarshaler.Unmarshal(strings.NewReader(jsonObj), msg); err != nil { + if err = jsonPbUnmarshaler.Unmarshal(jsonObj, msg); err != nil { return err } @@ -39,12 +38,12 @@ func MarshalStruct(in proto.Message, out *structpb.Struct) error { return fmt.Errorf("nil Struct Object passed") } - jsonObj, err := jsonPbMarshaler.MarshalToString(in) + jsonObj, err := jsonPbMarshaler.Marshal(in) if err != nil { return err } - if err = jsonpb.UnmarshalString(jsonObj, out); err != nil { + if err = jsonPbUnmarshaler.Unmarshal(jsonObj, out); err != nil { return err } @@ -53,7 +52,8 @@ func MarshalStruct(in proto.Message, out *structpb.Struct) error { // Deprecated: Use flytestdlib/utils.MarshalToString instead. func MarshalToString(msg proto.Message) (string, error) { - return jsonPbMarshaler.MarshalToString(msg) + b, err := jsonPbMarshaler.Marshal(msg) + return string(b), err } // Deprecated: Use flytestdlib/utils.MarshalObjToStruct instead. @@ -66,7 +66,7 @@ func MarshalObjToStruct(input interface{}) (*structpb.Struct, error) { // Turn JSON into a protobuf struct structObj := &structpb.Struct{} - if err := jsonpb.UnmarshalString(string(b), structObj); err != nil { + if err := jsonPbUnmarshaler.Unmarshal(b, structObj); err != nil { return nil, err } return structObj, nil diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils_test.go b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils_test.go index abe1b7d2a25..3f34c4ef07e 100644 --- a/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/marshal_utils_test.go @@ -5,8 +5,8 @@ import ( "testing" "github.com/go-test/deep" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" + structpb "google.golang.org/protobuf/types/known/structpb" v1 "k8s.io/api/core/v1" ) diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler.go b/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler.go index b4997a63844..1bc0c84927a 100644 --- a/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler.go @@ -2,13 +2,12 @@ package secrets import ( "fmt" + "strconv" + "strings" "github.com/flyteorg/flyte/v2/flyteplugins/go/tasks/pluginmachinery/encoding" "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" - "github.com/golang/protobuf/proto" //nolint: staticcheck - - "strconv" - "strings" + "google.golang.org/protobuf/encoding/prototext" ) const ( @@ -36,7 +35,7 @@ func decodeSecret(encoded string) (string, error) { } func marshalSecret(s *core.Secret) string { - return encodeSecret(proto.MarshalTextString(s)) + return encodeSecret(prototext.MarshalOptions{Multiline: false}.Format(s)) } func unmarshalSecret(encoded string) (*core.Secret, error) { @@ -46,7 +45,7 @@ func unmarshalSecret(encoded string) (*core.Secret, error) { } s := &core.Secret{} - err = proto.UnmarshalText(decoded, s) + err = prototext.Unmarshal([]byte(decoded), s) return s, err } diff --git a/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler_test.go b/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler_test.go index b07899ee635..f769ed26839 100644 --- a/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler_test.go +++ b/flyteplugins/go/tasks/pluginmachinery/utils/secrets/marshaler_test.go @@ -35,14 +35,14 @@ func TestMarshalSecretsToMapStrings(t *testing.T) { Group: ";':/\\", }, }}, want: map[string]string{ - "flyte.secrets/s0": "m4zg54lqhiqceozhhixvyxbcbi", + "flyte.secrets/s0": "m4zg54lqhirdwjz1f4ofyiq", }, wantErr: false}, {name: "Without group", args: args{secrets: []*core.Secret{ { Key: "my_key", }, }}, want: map[string]string{ - "flyte.secrets/s0": "nnsxsoraejwxsx2lmv3secq", + "flyte.secrets/s0": "nnsxsorcnv3v512fpera", }, wantErr: false}, } for _, tt := range tests { diff --git a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go index 386b0d960a8..8b44c5f450b 100644 --- a/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/dask/dask_test.go @@ -7,9 +7,9 @@ import ( "time" daskAPI "github.com/dask/dask-kubernetes/v2023/dask_kubernetes/operator/go_client/pkg/apis/kubernetes.dask.org/v1" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/structpb" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go index 1db7cdc0e4b..ff18dee236c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/mpi/mpi_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" - structpb "github.com/golang/protobuf/ptypes/struct" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + structpb "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go index 896b6960aa7..c5780c0fa5d 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/pytorch/pytorch_test.go @@ -7,10 +7,10 @@ import ( "testing" "time" - structpb "github.com/golang/protobuf/ptypes/struct" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + structpb "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go index 068526581e9..fa3290e685c 100644 --- a/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/kfoperators/tensorflow/tensorflow_test.go @@ -7,11 +7,11 @@ import ( "testing" "time" - structpb "github.com/golang/protobuf/ptypes/struct" kubeflowv1 "github.com/kubeflow/training-operator/pkg/apis/kubeflow.org/v1" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go b/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go index ca33040c92e..890f0aed191 100644 --- a/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/pod/sidecar_test.go @@ -8,10 +8,10 @@ import ( "path" "testing" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/api/resource" metav1 "k8s.io/apimachinery/pkg/apis/meta/v1" diff --git a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go index 2d7e2ffa26d..17b6d0a21bb 100644 --- a/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go +++ b/flyteplugins/go/tasks/plugins/k8s/spark/spark_test.go @@ -10,9 +10,9 @@ import ( sj "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" sparkOp "github.com/GoogleCloudPlatform/spark-on-k8s-operator/pkg/apis/sparkoperator.k8s.io/v1beta2" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" + structpb "google.golang.org/protobuf/types/known/structpb" corev1 "k8s.io/api/core/v1" v1 "k8s.io/apimachinery/pkg/apis/meta/v1" "sigs.k8s.io/controller-runtime/pkg/client" diff --git a/flytestdlib/app/error.go b/flytestdlib/app/error.go index 8433f149f07..55984aee091 100644 --- a/flytestdlib/app/error.go +++ b/flytestdlib/app/error.go @@ -3,9 +3,10 @@ package app import ( "fmt" - "github.com/golang/protobuf/proto" //nolint: staticcheck "google.golang.org/grpc/codes" "google.golang.org/grpc/status" + "google.golang.org/protobuf/proto" + "google.golang.org/protobuf/protoadapt" ) type ServerError interface { @@ -27,7 +28,7 @@ func (e *serverError) Code() codes.Code { } func (e *serverError) WithDetails(details proto.Message) (ServerError, error) { - s, err := e.status.WithDetails(details) + s, err := e.status.WithDetails(protoadapt.MessageV1Of(details)) if err != nil { return nil, err } diff --git a/flytestdlib/flytestdlib/storage/mocks/mocks.go b/flytestdlib/flytestdlib/storage/mocks/mocks.go index a55277e13df..dad6cfbc9b8 100644 --- a/flytestdlib/flytestdlib/storage/mocks/mocks.go +++ b/flytestdlib/flytestdlib/storage/mocks/mocks.go @@ -10,7 +10,7 @@ import ( "github.com/flyteorg/flyte/v2/flytestdlib/storage" mock "github.com/stretchr/testify/mock" - "github.com/golang/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" ) // NewMetadata creates a new instance of Metadata. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. diff --git a/flytestdlib/pbhash/pbhash.go b/flytestdlib/pbhash/pbhash.go index d2f46b2a296..e599c0f4972 100644 --- a/flytestdlib/pbhash/pbhash.go +++ b/flytestdlib/pbhash/pbhash.go @@ -6,13 +6,13 @@ import ( "encoding/base64" goObjectHash "github.com/benlaurie/objecthash/go/objecthash" - "github.com/golang/protobuf/jsonpb" //nolint: staticcheck - "github.com/golang/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" "github.com/flyteorg/flyte/v2/flytestdlib/logger" ) -var marshaller = &jsonpb.Marshaler{} +var marshaller = protojson.MarshalOptions{} func fromHashToByteArray(input [32]byte) []byte { output := make([]byte, 32) @@ -24,16 +24,17 @@ func fromHashToByteArray(input [32]byte) []byte { func ComputeHash(ctx context.Context, pb proto.Message) ([]byte, error) { // We marshal the pb object to JSON first which should provide a consistent mapping of pb to json fields as stated // here: https://developers.google.com/protocol-buffers/docs/proto3#json - // jsonpb marshalling includes: + // protojson marshalling includes: // - sorting map values to provide a stable output // - omitting empty values which supports backwards compatibility of old protobuf definitions // We do not use protobuf marshalling because it does not guarantee stable output because of how it handles // unknown fields and ordering of fields. https://github.com/protocolbuffers/protobuf/issues/2830 - pbJSON, err := marshaller.MarshalToString(pb) + pbJSONBytes, err := marshaller.Marshal(pb) if err != nil { logger.Warning(ctx, "failed to marshal pb [%+v] to JSON with err %v", pb, err) return nil, err } + pbJSON := string(pbJSONBytes) // Deterministically hash the JSON object to a byte array. The library will sort the map keys of the JSON object // so that we do not run into the issues from pb marshalling. diff --git a/flytestdlib/pbhash/pbhash_test.go b/flytestdlib/pbhash/pbhash_test.go index 75735b41352..23abae1286e 100644 --- a/flytestdlib/pbhash/pbhash_test.go +++ b/flytestdlib/pbhash/pbhash_test.go @@ -5,74 +5,41 @@ import ( "testing" "time" - "github.com/golang/protobuf/proto" - "github.com/golang/protobuf/ptypes" - "github.com/golang/protobuf/ptypes/duration" - "github.com/golang/protobuf/ptypes/timestamp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/durationpb" + "google.golang.org/protobuf/types/known/structpb" + "google.golang.org/protobuf/types/known/timestamppb" ) -// Mock a Protobuf generated GO object -type mockProtoMessage struct { - Integer int64 `protobuf:"varint,1,opt,name=integer,proto3" json:"integer,omitempty"` - FloatValue float64 `protobuf:"fixed64,2,opt,name=float_value,json=floatValue,proto3" json:"float_value,omitempty"` - StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` - Boolean bool `protobuf:"varint,4,opt,name=boolean,proto3" json:"boolean,omitempty"` - Datetime *timestamp.Timestamp `protobuf:"bytes,5,opt,name=datetime,proto3" json:"datetime,omitempty"` - Duration *duration.Duration `protobuf:"bytes,6,opt,name=duration,proto3" json:"duration,omitempty"` - MapValue map[string]string `protobuf:"bytes,7,rep,name=map_value,json=mapValue,proto3" json:"map_value,omitempty" protobuf_key:"bytes,1,opt,name=key,proto3" protobuf_val:"bytes,2,opt,name=value,proto3"` - Collections []string `protobuf:"bytes,8,rep,name=collections,proto3" json:"collections,omitempty"` -} - -func (mockProtoMessage) Reset() { -} - -func (m mockProtoMessage) String() string { - return proto.CompactTextString(m) -} - -func (mockProtoMessage) ProtoMessage() { -} - -// Mock an older version of the above pb object that doesn't have some fields -type mockOlderProto struct { - Integer int64 `protobuf:"varint,1,opt,name=integer,proto3" json:"integer,omitempty"` - FloatValue float64 `protobuf:"fixed64,2,opt,name=float_value,json=floatValue,proto3" json:"float_value,omitempty"` - StringValue string `protobuf:"bytes,3,opt,name=string_value,json=stringValue,proto3" json:"string_value,omitempty"` - Boolean bool `protobuf:"varint,4,opt,name=boolean,proto3" json:"boolean,omitempty"` -} +var sampleTime = timestamppb.New(time.Date(2019, 03, 29, 12, 0, 0, 0, time.UTC)) -func (mockOlderProto) Reset() { -} - -func (m mockOlderProto) String() string { - return proto.CompactTextString(m) -} +func makeStruct(t *testing.T, fields map[string]interface{}) *structpb.Struct { + t.Helper() -func (mockOlderProto) ProtoMessage() { + s, err := structpb.NewStruct(fields) + require.NoError(t, err) + return s } -var sampleTime, _ = ptypes.TimestampProto( - time.Date(2019, 03, 29, 12, 0, 0, 0, time.UTC)) - func TestProtoHash(t *testing.T) { - mockProto := &mockProtoMessage{ - Integer: 18, - FloatValue: 1.3, - StringValue: "lets test this", - Boolean: true, - Datetime: sampleTime, - Duration: ptypes.DurationProto(time.Millisecond), - MapValue: map[string]string{ + mockProto := makeStruct(t, map[string]interface{}{ + "integer": 18, + "floatValue": 1.3, + "stringValue": "lets test this", + "boolean": true, + "datetime": sampleTime.AsTime().Format(time.RFC3339Nano), + "duration": durationpb.New(time.Millisecond).AsDuration().String(), + "mapValue": map[string]interface{}{ "z": "last", "a": "first", }, - Collections: []string{"1", "2", "3"}, - } + "collections": []interface{}{"1", "2", "3"}, + }) - expectedHashedMockProto := []byte{0x62, 0x95, 0xb2, 0x2c, 0x23, 0xf5, 0x35, 0x6d, 0x3, 0x56, 0x4d, 0xc7, 0x8f, 0xae, - 0x2d, 0x2b, 0xbd, 0x7, 0xff, 0xdb, 0x7e, 0xe5, 0xf4, 0x25, 0x8f, 0xbc, 0xb2, 0xc, 0xad, 0xa5, 0x48, 0x44} - expectedHashString := "YpWyLCP1NW0DVk3Hj64tK70H/9t+5fQlj7yyDK2lSEQ=" + expectedHashedMockProto := []byte{0x45, 0xd1, 0xe, 0x9, 0x5e, 0xe3, 0xf7, 0x3e, 0xe9, 0x9, 0xe9, 0xc9, 0x27, 0xd6, + 0xf5, 0x79, 0x81, 0xf6, 0x52, 0x48, 0x3f, 0x71, 0x8c, 0x2, 0x87, 0x1, 0x98, 0x58, 0x5b, 0x7e, 0xf, 0xda} + expectedHashString := "RdEOCV7j9z7pCenJJ9b1eYH2Ukg/cYwChwGYWFt+D9o=" t.Run("TestFullProtoHash", func(t *testing.T) { hashedBytes, err := ComputeHash(context.Background(), mockProto) @@ -86,7 +53,10 @@ func TestProtoHash(t *testing.T) { }) t.Run("TestFullProtoHashReorderKeys", func(t *testing.T) { - mockProto.MapValue = map[string]string{"a": "first", "z": "last"} + mockProto.Fields["mapValue"] = structpb.NewStructValue(makeStruct(t, map[string]interface{}{ + "a": "first", + "z": "last", + })) hashedBytes, err := ComputeHash(context.Background(), mockProto) assert.Nil(t, err) assert.Equal(t, expectedHashedMockProto, hashedBytes) @@ -100,18 +70,18 @@ func TestProtoHash(t *testing.T) { func TestPartialFilledProtoHash(t *testing.T) { - mockProtoOmitEmpty := &mockProtoMessage{ - Integer: 18, - FloatValue: 1.3, - StringValue: "lets test this", - Boolean: true, - } + mockProtoOmitEmpty := makeStruct(t, map[string]interface{}{ + "integer": 18, + "floatValue": 1.3, + "stringValue": "lets test this", + "boolean": true, + }) - expectedHashedMockProtoOmitEmpty := []byte{0x1a, 0x13, 0xcc, 0x4c, 0xab, 0xc9, 0x7d, 0x43, 0xc7, 0x2b, 0xc5, 0x37, - 0xbc, 0x49, 0xa8, 0x8b, 0xfc, 0x1d, 0x54, 0x1c, 0x7b, 0x21, 0x04, 0x8f, 0xab, 0x28, 0xc6, 0x5c, 0x06, 0x73, - 0xaa, 0xe2} + expectedHashedMockProtoOmitEmpty := []byte{0x6d, 0xfa, 0xc1, 0xc2, 0xe0, 0xee, 0xad, 0xe2, 0xa5, 0xad, 0x7d, 0x9e, + 0xad, 0x1c, 0x94, 0x11, 0x6a, 0x21, 0x23, 0xe1, 0xfb, 0xe2, 0x35, 0xd5, 0x37, 0x89, 0xf3, 0xfc, 0xa, 0xfb, + 0x3d, 0xe9} - expectedHashStringOmitEmpty := "GhPMTKvJfUPHK8U3vEmoi/wdVBx7IQSPqyjGXAZzquI=" + expectedHashStringOmitEmpty := "bfrBwuDureKlrX2erRyUEWohI+H74jXVN4nz/Ar7Pek=" t.Run("TestPartial", func(t *testing.T) { hashedBytes, err := ComputeHash(context.Background(), mockProtoOmitEmpty) @@ -124,12 +94,12 @@ func TestPartialFilledProtoHash(t *testing.T) { assert.Equal(t, hashedString, expectedHashStringOmitEmpty) }) - mockOldProtoMessage := &mockOlderProto{ - Integer: 18, - FloatValue: 1.3, - StringValue: "lets test this", - Boolean: true, - } + mockOldProtoMessage := makeStruct(t, map[string]interface{}{ + "integer": 18, + "floatValue": 1.3, + "stringValue": "lets test this", + "boolean": true, + }) t.Run("TestOlderProto", func(t *testing.T) { hashedBytes, err := ComputeHash(context.Background(), mockOldProtoMessage) diff --git a/flytestdlib/storage/mocks/mocks.go b/flytestdlib/storage/mocks/mocks.go index 83180ebabde..dad6cfbc9b8 100644 --- a/flytestdlib/storage/mocks/mocks.go +++ b/flytestdlib/storage/mocks/mocks.go @@ -9,8 +9,8 @@ import ( "io" "github.com/flyteorg/flyte/v2/flytestdlib/storage" - "github.com/golang/protobuf/proto" mock "github.com/stretchr/testify/mock" + "google.golang.org/protobuf/proto" ) // NewMetadata creates a new instance of Metadata. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. diff --git a/flytestdlib/storage/protobuf_store.go b/flytestdlib/storage/protobuf_store.go index 44e04d6acbd..59747e6eb8a 100644 --- a/flytestdlib/storage/protobuf_store.go +++ b/flytestdlib/storage/protobuf_store.go @@ -6,9 +6,9 @@ import ( "fmt" "time" - "github.com/golang/protobuf/proto" //nolint: staticcheck errs "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" + "google.golang.org/protobuf/proto" "github.com/flyteorg/flyte/v2/flytestdlib/ioutils" "github.com/flyteorg/flyte/v2/flytestdlib/logger" diff --git a/flytestdlib/storage/protobuf_store_test.go b/flytestdlib/storage/protobuf_store_test.go index 019c61e7d45..b6562ac5898 100644 --- a/flytestdlib/storage/protobuf_store_test.go +++ b/flytestdlib/storage/protobuf_store_test.go @@ -9,56 +9,29 @@ import ( "net/http/httptest" "testing" - "github.com/golang/protobuf/proto" errs "github.com/pkg/errors" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/wrapperspb" "github.com/flyteorg/flyte/v2/flytestdlib/promutils" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/core" "github.com/flyteorg/stow/s3" ) -type mockProtoMessage struct { - X int64 `protobuf:"varint,2,opt,name=x,json=x,proto3" json:"x,omitempty"` -} - -type mockBigDataProtoMessage struct { - X []byte `protobuf:"bytes,1,opt,name=X,proto3" json:"X,omitempty"` -} - -func (mockProtoMessage) Reset() { -} - -func (m mockProtoMessage) String() string { - return proto.CompactTextString(m) -} - -func (mockProtoMessage) ProtoMessage() { -} - -func (mockBigDataProtoMessage) Reset() { -} - -func (m mockBigDataProtoMessage) String() string { - return proto.CompactTextString(m) -} - -func (mockBigDataProtoMessage) ProtoMessage() { -} - func TestDefaultProtobufStore(t *testing.T) { t.Run("Read after Write", func(t *testing.T) { testScope := promutils.NewTestScope() s, err := NewDataStore(&Config{Type: TypeMemory}, testScope) assert.NoError(t, err) - err = s.WriteProtobuf(context.TODO(), "hello", Options{}, &mockProtoMessage{X: 5}) + err = s.WriteProtobuf(context.TODO(), "hello", Options{}, wrapperspb.Int64(5)) assert.NoError(t, err) - m := &mockProtoMessage{} + m := &wrapperspb.Int64Value{} err = s.ReadProtobuf(context.TODO(), "hello", m) assert.NoError(t, err) - assert.Equal(t, int64(5), m.X) + assert.Equal(t, int64(5), m.Value) }) t.Run("RefreshConfig", func(t *testing.T) { @@ -108,6 +81,33 @@ func TestDefaultProtobufStore(t *testing.T) { }) } +func TestDefaultProtobufStore_EmptyLiteralMap(t *testing.T) { + testScope := promutils.NewTestScope() + s, err := NewDataStore(&Config{Type: TypeMemory}, testScope) + require.NoError(t, err) + + ref := DataReference("empty-literal-map") + require.NoError(t, s.WriteProtobuf(context.TODO(), ref, Options{}, &core.LiteralMap{})) + + raw, err := s.ReadRaw(context.TODO(), ref) + require.NoError(t, err) + defer func() { + require.NoError(t, raw.Close()) + }() + + rawBytes, err := io.ReadAll(raw) + require.NoError(t, err) + assert.Empty(t, rawBytes) + + got := &core.LiteralMap{ + Literals: map[string]*core.Literal{ + "stale": {}, + }, + } + require.NoError(t, s.ReadProtobuf(context.TODO(), ref, got)) + assert.Empty(t, got.GetLiterals()) +} + func TestDefaultProtobufStore_BigDataReadAfterWrite(t *testing.T) { t.Run("Read after Write with Big Data", func(t *testing.T) { testScope := promutils.NewTestScope() @@ -127,15 +127,15 @@ func TestDefaultProtobufStore_BigDataReadAfterWrite(t *testing.T) { _, err = rand.Read(bigD) assert.NoError(t, err) - mockMessage := &mockBigDataProtoMessage{X: bigD} + mockMessage := wrapperspb.Bytes(bigD) err = s.WriteProtobuf(context.TODO(), DataReference("bigK"), Options{}, mockMessage) assert.NoError(t, err) - m := &mockBigDataProtoMessage{} + m := &wrapperspb.BytesValue{} err = s.ReadProtobuf(context.TODO(), DataReference("bigK"), m) assert.NoError(t, err) - assert.Equal(t, bigD, m.X) + assert.Equal(t, bigD, m.Value) }) } @@ -159,13 +159,13 @@ func TestDefaultProtobufStore_HardErrors(t *testing.T) { } pbErroneousStore := NewDefaultProtobufStoreWithMetrics(store, metrics.protoMetrics) t.Run("Test if hard write errors are handled correctly", func(t *testing.T) { - err := pbErroneousStore.WriteProtobuf(ctx, k1, Options{}, &mockProtoMessage{X: 5}) + err := pbErroneousStore.WriteProtobuf(ctx, k1, Options{}, wrapperspb.Int64(5)) assert.False(t, IsFailedWriteToCache(err)) assert.Equal(t, dummyWriteErrorMsg, errs.Cause(err).Error()) }) t.Run("Test if hard read errors are handled correctly", func(t *testing.T) { - m := &mockProtoMessage{} + m := &wrapperspb.Int64Value{} err := pbErroneousStore.ReadProtobuf(ctx, k1, m) assert.False(t, IsFailedWriteToCache(err)) assert.Equal(t, dummyReadErrorMsg, errs.Cause(err).Error()) diff --git a/flytestdlib/storage/storage.go b/flytestdlib/storage/storage.go index 6a2ba589ef9..7ef5ac92ae3 100644 --- a/flytestdlib/storage/storage.go +++ b/flytestdlib/storage/storage.go @@ -15,7 +15,7 @@ import ( "time" "github.com/flyteorg/stow" - "github.com/golang/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" ) // DataReference defines a reference to data location. diff --git a/flytestdlib/utils/marshal_utils.go b/flytestdlib/utils/marshal_utils.go index d0fde425455..90180f4b0b8 100644 --- a/flytestdlib/utils/marshal_utils.go +++ b/flytestdlib/utils/marshal_utils.go @@ -4,17 +4,16 @@ import ( "bytes" "encoding/json" "fmt" - "strings" - "github.com/golang/protobuf/jsonpb" //nolint: staticcheck - "github.com/golang/protobuf/proto" //nolint: staticcheck - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" ) -var jsonPbMarshaler = jsonpb.Marshaler{} -var jsonPbUnmarshaler = jsonpb.Unmarshaler{ - AllowUnknownFields: true, +var jsonPbMarshaler = protojson.MarshalOptions{} +var jsonPbUnmarshaler = protojson.UnmarshalOptions{ + DiscardUnknown: true, } // UnmarshalStructToPb unmarshals a proto struct into a proto message using jsonPb marshaler. @@ -27,12 +26,12 @@ func UnmarshalStructToPb(structObj *structpb.Struct, msg proto.Message) error { return fmt.Errorf("nil proto.Message object passed") } - jsonObj, err := jsonPbMarshaler.MarshalToString(structObj) + jsonObj, err := jsonPbMarshaler.Marshal(structObj) if err != nil { return errors.WithMessage(err, "Failed to marshal strcutObj input") } - if err = UnmarshalStringToPb(jsonObj, msg); err != nil { + if err = UnmarshalBytesToPb(jsonObj, msg); err != nil { return errors.WithMessage(err, "Failed to unmarshal json obj into proto") } @@ -46,9 +45,11 @@ func MarshalPbToStruct(in proto.Message) (out *structpb.Struct, err error) { } var buf bytes.Buffer - if err := jsonPbMarshaler.Marshal(&buf, in); err != nil { + b, err := jsonPbMarshaler.Marshal(in) + if err != nil { return nil, errors.WithMessage(err, "Failed to marshal input proto message") } + buf.Write(b) out = &structpb.Struct{} if err = UnmarshalBytesToPb(buf.Bytes(), out); err != nil { @@ -60,27 +61,37 @@ func MarshalPbToStruct(in proto.Message) (out *structpb.Struct, err error) { // MarshalPbToString marshals a proto message using jsonPb marshaler to string. func MarshalPbToString(msg proto.Message) (string, error) { - return jsonPbMarshaler.MarshalToString(msg) + if msg == nil { + return "", fmt.Errorf("nil proto message passed") + } + + b, err := jsonPbMarshaler.Marshal(msg) + return string(b), err } // UnmarshalStringToPb unmarshals a string to a proto message func UnmarshalStringToPb(s string, msg proto.Message) error { - return jsonPbUnmarshaler.Unmarshal(strings.NewReader(s), msg) + return jsonPbUnmarshaler.Unmarshal([]byte(s), msg) } // MarshalPbToBytes marshals a proto message to a byte slice func MarshalPbToBytes(msg proto.Message) ([]byte, error) { + if msg == nil { + return nil, fmt.Errorf("nil proto message passed") + } + var buf bytes.Buffer - err := jsonPbMarshaler.Marshal(&buf, msg) + b, err := jsonPbMarshaler.Marshal(msg) if err != nil { return nil, err } + buf.Write(b) return buf.Bytes(), nil } // UnmarshalBytesToPb unmarshals a byte slice to a proto message func UnmarshalBytesToPb(b []byte, msg proto.Message) error { - return jsonPbUnmarshaler.Unmarshal(bytes.NewReader(b), msg) + return jsonPbUnmarshaler.Unmarshal(b, msg) } // MarshalObjToStruct marshals obj into a struct. Will use jsonPb if input is a proto message, otherwise, it'll use json diff --git a/flytestdlib/utils/marshal_utils_test.go b/flytestdlib/utils/marshal_utils_test.go index 3ac0c3f3390..ad86ca4f8eb 100644 --- a/flytestdlib/utils/marshal_utils_test.go +++ b/flytestdlib/utils/marshal_utils_test.go @@ -4,9 +4,9 @@ import ( "testing" "github.com/go-test/deep" - "github.com/golang/protobuf/proto" - structpb "github.com/golang/protobuf/ptypes/struct" "github.com/stretchr/testify/assert" + "google.golang.org/protobuf/proto" + structpb "google.golang.org/protobuf/types/known/structpb" v1 "k8s.io/api/core/v1" "k8s.io/apimachinery/pkg/util/json" diff --git a/flytestdlib/utils/prototest/test_type.pb.go b/flytestdlib/utils/prototest/test_type.pb.go index 7495748680b..103f4d8d453 100644 --- a/flytestdlib/utils/prototest/test_type.pb.go +++ b/flytestdlib/utils/prototest/test_type.pb.go @@ -10,7 +10,6 @@ import ( reflect "reflect" sync "sync" - proto "github.com/golang/protobuf/proto" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) @@ -22,10 +21,6 @@ const ( _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) ) -// This is a compile-time assertion that a sufficiently up-to-date version -// of the legacy proto package is being used. -const _ = proto.ProtoPackageIsVersion4 - type TestProto struct { state protoimpl.MessageState sizeCache protoimpl.SizeCache diff --git a/go.mod b/go.mod index f7c24b527fd..7bccc1bf006 100644 --- a/go.mod +++ b/go.mod @@ -27,7 +27,6 @@ require ( github.com/ghodss/yaml v1.0.0 github.com/go-test/deep v1.1.1 github.com/go-viper/mapstructure/v2 v2.4.0 - github.com/golang/protobuf v1.5.4 github.com/googleapis/gax-go/v2 v2.15.0 github.com/grpc-ecosystem/go-grpc-middleware/providers/prometheus v1.1.0 github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.3 @@ -144,6 +143,7 @@ require ( github.com/golang-jwt/jwt/v5 v5.3.0 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/golang/mock v1.6.0 // indirect + github.com/golang/protobuf v1.5.4 // indirect github.com/google/btree v1.1.3 // indirect github.com/google/cel-go v0.26.0 // indirect github.com/google/gnostic-models v0.7.0 // indirect diff --git a/runs/repository/transformers/task.go b/runs/repository/transformers/task.go index ec7ccb16c61..2a7c8246d1e 100644 --- a/runs/repository/transformers/task.go +++ b/runs/repository/transformers/task.go @@ -5,7 +5,7 @@ import ( "database/sql" "strings" - "github.com/golang/protobuf/proto" //nolint: staticcheck + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "github.com/flyteorg/flyte/v2/flytestdlib/logger" diff --git a/runs/service/internal_run_service.go b/runs/service/internal_run_service.go index 32eebc334b6..577e7686dd0 100644 --- a/runs/service/internal_run_service.go +++ b/runs/service/internal_run_service.go @@ -5,10 +5,11 @@ import ( "database/sql" "errors" "fmt" - "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" "io" "time" + "github.com/flyteorg/flyte/v2/gen/go/flyteidl2/common" + "connectrpc.com/connect" grpcstatus "google.golang.org/genproto/googleapis/rpc/status" "google.golang.org/protobuf/proto" diff --git a/runs/service/run_service_test.go b/runs/service/run_service_test.go index 3c40bd59176..6c7a580525a 100644 --- a/runs/service/run_service_test.go +++ b/runs/service/run_service_test.go @@ -13,10 +13,10 @@ import ( "time" "connectrpc.com/connect" - "github.com/golang/protobuf/proto" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/mock" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/proto" "google.golang.org/protobuf/types/known/timestamppb" "github.com/flyteorg/flyte/v2/flytestdlib/storage"