Skip to content

Commit 8abbc87

Browse files
authored
Store prepared statements field definitions (#1073)
* Store prepared statements field definitions * Add a test * Make `paramFields` and `columnFields` private * Add comments * Move statements into a new `stmt` package * Return an error if we can't parse
1 parent 6ec39f6 commit 8abbc87

File tree

5 files changed

+148
-29
lines changed

5 files changed

+148
-29
lines changed

client/stmt.go

Lines changed: 26 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -8,25 +8,25 @@ import (
88
"runtime"
99

1010
"github.com/go-mysql-org/go-mysql/mysql"
11+
"github.com/go-mysql-org/go-mysql/stmt"
1112
"github.com/go-mysql-org/go-mysql/utils"
1213
"github.com/pingcap/errors"
1314
)
1415

1516
type Stmt struct {
16-
conn *Conn
17-
id uint32
18-
19-
params int
20-
columns int
17+
conn *Conn
2118
warnings int
19+
20+
// PreparedStmt contains common fields shared with server.Stmt for proxy passthrough
21+
stmt.PreparedStmt
2222
}
2323

2424
func (s *Stmt) ParamNum() int {
25-
return s.params
25+
return s.Params
2626
}
2727

2828
func (s *Stmt) ColumnNum() int {
29-
return s.columns
29+
return s.Columns
3030
}
3131

3232
func (s *Stmt) WarningsNum() int {
@@ -50,7 +50,7 @@ func (s *Stmt) ExecuteSelectStreaming(result *mysql.Result, perRowCb SelectPerRo
5050
}
5151

5252
func (s *Stmt) Close() error {
53-
if err := s.conn.writeCommandUint32(mysql.COM_STMT_CLOSE, s.id); err != nil {
53+
if err := s.conn.writeCommandUint32(mysql.COM_STMT_CLOSE, s.ID); err != nil {
5454
return errors.Trace(err)
5555
}
5656

@@ -60,10 +60,10 @@ func (s *Stmt) Close() error {
6060
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
6161
func (s *Stmt) write(args ...interface{}) error {
6262
defer clear(s.conn.queryAttributes)
63-
paramsNum := s.params
63+
paramsNum := s.Params
6464

6565
if len(args) != paramsNum {
66-
return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
66+
return fmt.Errorf("argument mismatch, need %d but got %d", s.Params, len(args))
6767
}
6868

6969
if (s.conn.capability&mysql.CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) {
@@ -187,7 +187,7 @@ func (s *Stmt) write(args ...interface{}) error {
187187

188188
data.Write([]byte{0, 0, 0, 0})
189189
data.WriteByte(mysql.COM_STMT_EXECUTE)
190-
data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)})
190+
data.Write([]byte{byte(s.ID), byte(s.ID >> 8), byte(s.ID >> 16), byte(s.ID >> 24)})
191191

192192
flags := mysql.CURSOR_TYPE_NO_CURSOR
193193
if paramsNum > 0 {
@@ -254,15 +254,15 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
254254
pos := 1
255255

256256
// for statement id
257-
s.id = binary.LittleEndian.Uint32(data[pos:])
257+
s.ID = binary.LittleEndian.Uint32(data[pos:])
258258
pos += 4
259259

260260
// number columns
261-
s.columns = int(binary.LittleEndian.Uint16(data[pos:]))
261+
s.Columns = int(binary.LittleEndian.Uint16(data[pos:]))
262262
pos += 2
263263

264264
// number params
265-
s.params = int(binary.LittleEndian.Uint16(data[pos:]))
265+
s.Params = int(binary.LittleEndian.Uint16(data[pos:]))
266266
pos += 2
267267

268268
// reserved
@@ -274,11 +274,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
274274
// pos += 2
275275
}
276276

277-
if s.params > 0 {
278-
for range s.params {
279-
if _, err := s.conn.ReadPacket(); err != nil {
277+
if s.Params > 0 {
278+
s.RawParamFields = make([][]byte, s.Params)
279+
for i := range s.Params {
280+
data, err := s.conn.ReadPacket()
281+
if err != nil {
280282
return nil, errors.Trace(err)
281283
}
284+
s.RawParamFields[i] = data
282285
}
283286
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
284287
if packet, err := s.conn.ReadPacket(); err != nil {
@@ -289,12 +292,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
289292
}
290293
}
291294

292-
if s.columns > 0 {
293-
// TODO process when CLIENT_CACHE_METADATA enabled
294-
for range s.columns {
295-
if _, err := s.conn.ReadPacket(); err != nil {
295+
if s.Columns > 0 {
296+
s.RawColumnFields = make([][]byte, s.Columns)
297+
for i := range s.Columns {
298+
data, err := s.conn.ReadPacket()
299+
if err != nil {
296300
return nil, errors.Trace(err)
297301
}
302+
s.RawColumnFields[i] = data
298303
}
299304
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
300305
if packet, err := s.conn.ReadPacket(); err != nil {

server/command.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77

88
"github.com/go-mysql-org/go-mysql/mysql"
99
"github.com/go-mysql-org/go-mysql/replication"
10+
"github.com/go-mysql-org/go-mysql/stmt"
1011
"github.com/go-mysql-org/go-mysql/utils"
1112
)
1213

@@ -112,6 +113,10 @@ func (c *Conn) dispatch(data []byte) interface{} {
112113
if st.Params, st.Columns, st.Context, err = c.h.HandleStmtPrepare(st.Query); err != nil {
113114
return err
114115
} else {
116+
if provider, ok := st.Context.(*stmt.PreparedStmt); ok {
117+
st.RawParamFields = provider.RawParamFields
118+
st.RawColumnFields = provider.RawColumnFields
119+
}
115120
st.ResetParams()
116121
c.stmts[c.stmtID] = st
117122
return st

server/stmt.go

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strconv"
88

99
"github.com/go-mysql-org/go-mysql/mysql"
10+
"github.com/go-mysql-org/go-mysql/stmt"
1011
"github.com/pingcap/errors"
1112
)
1213

@@ -16,15 +17,13 @@ var (
1617
)
1718

1819
type Stmt struct {
19-
ID uint32
2020
Query string
21-
22-
Params int
23-
Columns int
24-
25-
Args []interface{}
21+
Args []interface{}
2622

2723
Context interface{}
24+
25+
// PreparedStmt contains common fields shared with client.Stmt for proxy passthrough
26+
stmt.PreparedStmt
2827
}
2928

3029
func (s *Stmt) Rest(params int, columns int, context interface{}) {
@@ -61,7 +60,11 @@ func (c *Conn) writePrepare(s *Stmt) error {
6160
if s.Params > 0 {
6261
for i := 0; i < s.Params; i++ {
6362
data = data[0:4]
64-
data = append(data, paramFieldData...)
63+
if s.RawParamFields != nil && i < len(s.RawParamFields) {
64+
data = append(data, s.RawParamFields[i]...)
65+
} else {
66+
data = append(data, paramFieldData...)
67+
}
6568

6669
if err := c.WritePacket(data); err != nil {
6770
return errors.Trace(err)
@@ -76,7 +79,11 @@ func (c *Conn) writePrepare(s *Stmt) error {
7679
if s.Columns > 0 {
7780
for i := 0; i < s.Columns; i++ {
7881
data = data[0:4]
79-
data = append(data, columnFieldData...)
82+
if s.RawColumnFields != nil && i < len(s.RawColumnFields) {
83+
data = append(data, s.RawColumnFields[i]...)
84+
} else {
85+
data = append(data, columnFieldData...)
86+
}
8087

8188
if err := c.WritePacket(data); err != nil {
8289
return errors.Trace(err)

server/stmt_test.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ package server
33
import (
44
"testing"
55

6+
"github.com/go-mysql-org/go-mysql/mysql"
7+
"github.com/go-mysql-org/go-mysql/stmt"
68
"github.com/stretchr/testify/require"
79
)
810

@@ -46,3 +48,52 @@ func TestHandleStmtExecute(t *testing.T) {
4648
}
4749
}
4850
}
51+
52+
type mockPrepareHandler struct {
53+
EmptyHandler
54+
context any
55+
paramCount, columnCount int
56+
}
57+
58+
func (h *mockPrepareHandler) HandleStmtPrepare(query string) (int, int, any, error) {
59+
return h.paramCount, h.columnCount, h.context, nil
60+
}
61+
62+
func TestStmtPrepareWithoutPreparedStmt(t *testing.T) {
63+
c := &Conn{
64+
h: &mockPrepareHandler{context: "plain string", paramCount: 1, columnCount: 1},
65+
stmts: make(map[uint32]*Stmt),
66+
}
67+
68+
result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT * FROM t"...))
69+
70+
st := result.(*Stmt)
71+
require.Nil(t, st.RawParamFields)
72+
require.Nil(t, st.RawColumnFields)
73+
}
74+
75+
func TestStmtPrepareWithPreparedStmt(t *testing.T) {
76+
paramField := &mysql.Field{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG}
77+
columnField := &mysql.Field{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG}
78+
79+
provider := &stmt.PreparedStmt{
80+
RawParamFields: [][]byte{paramField.Dump()},
81+
RawColumnFields: [][]byte{columnField.Dump()},
82+
}
83+
c := &Conn{
84+
h: &mockPrepareHandler{context: provider, paramCount: 1, columnCount: 1},
85+
stmts: make(map[uint32]*Stmt),
86+
}
87+
88+
result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT id FROM t WHERE id = ?"...))
89+
90+
st := result.(*Stmt)
91+
require.NotNil(t, st.RawParamFields)
92+
require.NotNil(t, st.RawColumnFields)
93+
paramFields, err := st.GetParamFields()
94+
require.NoError(t, err)
95+
require.Equal(t, mysql.MYSQL_TYPE_LONG, paramFields[0].Type)
96+
columnFields, err := st.GetColumnFields()
97+
require.NoError(t, err)
98+
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, columnFields[0].Type)
99+
}

stmt/stmt.go

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
package stmt
2+
3+
import "github.com/go-mysql-org/go-mysql/mysql"
4+
5+
type PreparedStmt struct {
6+
ID uint32
7+
Params int
8+
Columns int
9+
10+
RawParamFields [][]byte
11+
RawColumnFields [][]byte
12+
13+
paramFields []*mysql.Field
14+
columnFields []*mysql.Field
15+
}
16+
17+
func (s *PreparedStmt) GetParamFields() ([]*mysql.Field, error) {
18+
if s.RawParamFields == nil {
19+
return nil, nil
20+
}
21+
if s.paramFields == nil {
22+
fields := make([]*mysql.Field, len(s.RawParamFields))
23+
for i, raw := range s.RawParamFields {
24+
field := &mysql.Field{}
25+
if err := field.Parse(raw); err != nil {
26+
return nil, err
27+
}
28+
fields[i] = field
29+
}
30+
s.paramFields = fields
31+
}
32+
return s.paramFields, nil
33+
}
34+
35+
func (s *PreparedStmt) GetColumnFields() ([]*mysql.Field, error) {
36+
if s.RawColumnFields == nil {
37+
return nil, nil
38+
}
39+
if s.columnFields == nil {
40+
fields := make([]*mysql.Field, len(s.RawColumnFields))
41+
for i, raw := range s.RawColumnFields {
42+
field := &mysql.Field{}
43+
if err := field.Parse(raw); err != nil {
44+
return nil, err
45+
}
46+
fields[i] = field
47+
}
48+
s.columnFields = fields
49+
}
50+
return s.columnFields, nil
51+
}

0 commit comments

Comments
 (0)