Skip to content

Commit 9d0777a

Browse files
committed
Return an error if we can't parse
1 parent fd59075 commit 9d0777a

File tree

2 files changed

+26
-14
lines changed

2 files changed

+26
-14
lines changed

server/stmt_test.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,10 @@ func TestStmtPrepareWithPreparedStmt(t *testing.T) {
9090
st := result.(*Stmt)
9191
require.NotNil(t, st.RawParamFields)
9292
require.NotNil(t, st.RawColumnFields)
93-
require.Equal(t, mysql.MYSQL_TYPE_LONG, st.GetParamFields()[0].Type)
94-
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, st.GetColumnFields()[0].Type)
93+
paramFields, err := st.GetParamFields()
94+
require.NoError(t, err)
95+
require.Equal(t, mysql.MYSQL_TYPE_LONG, paramFields[0].Type)
96+
columnFields, err := st.GetColumnFields()
97+
require.NoError(t, err)
98+
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, columnFields[0].Type)
9599
}

stmt/stmt.go

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -15,30 +15,38 @@ type PreparedStmt struct {
1515
columnFields []*mysql.Field
1616
}
1717

18-
func (s *PreparedStmt) GetParamFields() []*mysql.Field {
18+
func (s *PreparedStmt) GetParamFields() ([]*mysql.Field, error) {
1919
if s.RawParamFields == nil {
20-
return nil
20+
return nil, nil
2121
}
2222
if s.paramFields == nil {
23-
s.paramFields = make([]*mysql.Field, len(s.RawParamFields))
23+
fields := make([]*mysql.Field, len(s.RawParamFields))
2424
for i, raw := range s.RawParamFields {
25-
s.paramFields[i] = &mysql.Field{}
26-
_ = s.paramFields[i].Parse(raw)
25+
field := &mysql.Field{}
26+
if err := field.Parse(raw); err != nil {
27+
return nil, err
28+
}
29+
fields[i] = field
2730
}
31+
s.paramFields = fields
2832
}
29-
return s.paramFields
33+
return s.paramFields, nil
3034
}
3135

32-
func (s *PreparedStmt) GetColumnFields() []*mysql.Field {
36+
func (s *PreparedStmt) GetColumnFields() ([]*mysql.Field, error) {
3337
if s.RawColumnFields == nil {
34-
return nil
38+
return nil, nil
3539
}
3640
if s.columnFields == nil {
37-
s.columnFields = make([]*mysql.Field, len(s.RawColumnFields))
41+
fields := make([]*mysql.Field, len(s.RawColumnFields))
3842
for i, raw := range s.RawColumnFields {
39-
s.columnFields[i] = &mysql.Field{}
40-
_ = s.columnFields[i].Parse(raw)
43+
field := &mysql.Field{}
44+
if err := field.Parse(raw); err != nil {
45+
return nil, err
46+
}
47+
fields[i] = field
4148
}
49+
s.columnFields = fields
4250
}
43-
return s.columnFields
51+
return s.columnFields, nil
4452
}

0 commit comments

Comments
 (0)