Skip to content

Commit fcaff1f

Browse files
committed
Return an error if we can't parse
1 parent a37e0bb commit fcaff1f

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

server/stmt_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ func TestStmtPrepareWithPreparedStmt(t *testing.T) {
9090
st := result.(*Stmt)
9191
require.NotNil(t, st.RawParamFields)
9292
require.NotNil(t, st.RawColumnFields)
93-
require.Equal(t, mysql.MYSQL_TYPE_LONG, st.GetParamFields()[0].Type)
94-
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, st.GetColumnFields()[0].Type)
93+
paramFields, err := st.GetParamFields()
94+
require.NoError(t, err)
95+
require.Equal(t, mysql.MYSQL_TYPE_LONG, paramFields[0].Type)
96+
columnFields, err := st.GetColumnFields()
97+
require.NoError(t, err)
98+
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, columnFields[0].Type)
9599
}

stmt/stmt.go

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,30 +14,38 @@ type PreparedStmt struct {
1414
columnFields []*mysql.Field
1515
}
1616

17-
func (s *PreparedStmt) GetParamFields() []*mysql.Field {
17+
func (s *PreparedStmt) GetParamFields() ([]*mysql.Field, error) {
1818
if s.RawParamFields == nil {
19-
return nil
19+
return nil, nil
2020
}
2121
if s.paramFields == nil {
22-
s.paramFields = make([]*mysql.Field, len(s.RawParamFields))
22+
fields := make([]*mysql.Field, len(s.RawParamFields))
2323
for i, raw := range s.RawParamFields {
24-
s.paramFields[i] = &mysql.Field{}
25-
_ = s.paramFields[i].Parse(raw)
24+
field := &mysql.Field{}
25+
if err := field.Parse(raw); err != nil {
26+
return nil, err
27+
}
28+
fields[i] = field
2629
}
30+
s.paramFields = fields
2731
}
28-
return s.paramFields
32+
return s.paramFields, nil
2933
}
3034

31-
func (s *PreparedStmt) GetColumnFields() []*mysql.Field {
35+
func (s *PreparedStmt) GetColumnFields() ([]*mysql.Field, error) {
3236
if s.RawColumnFields == nil {
33-
return nil
37+
return nil, nil
3438
}
3539
if s.columnFields == nil {
36-
s.columnFields = make([]*mysql.Field, len(s.RawColumnFields))
40+
fields := make([]*mysql.Field, len(s.RawColumnFields))
3741
for i, raw := range s.RawColumnFields {
38-
s.columnFields[i] = &mysql.Field{}
39-
_ = s.columnFields[i].Parse(raw)
42+
field := &mysql.Field{}
43+
if err := field.Parse(raw); err != nil {
44+
return nil, err
45+
}
46+
fields[i] = field
4047
}
48+
s.columnFields = fields
4149
}
42-
return s.columnFields
50+
return s.columnFields, nil
4351
}

0 commit comments

Comments
 (0)