Skip to content

Commit fd59075

Browse files
committed
Move statements into a new stmt package
1 parent e4cb587 commit fd59075

File tree

5 files changed

+94
-91
lines changed

5 files changed

+94
-91
lines changed

client/stmt.go

Lines changed: 22 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -8,47 +8,28 @@ import (
88
"runtime"
99

1010
"github.com/go-mysql-org/go-mysql/mysql"
11+
"github.com/go-mysql-org/go-mysql/stmt"
1112
"github.com/go-mysql-org/go-mysql/utils"
1213
"github.com/pingcap/errors"
1314
)
1415

1516
type Stmt struct {
1617
conn *Conn
17-
id uint32
1818

19-
params int
20-
columns int
21-
warnings int
22-
23-
// Field definitions from the PREPARE response (for proxy passthrough)
24-
paramFields []*mysql.Field
25-
columnFields []*mysql.Field
19+
// PreparedStmt contains common fields shared with server.Stmt for proxy passthrough
20+
stmt.PreparedStmt
2621
}
2722

2823
func (s *Stmt) ParamNum() int {
29-
return s.params
24+
return s.Params
3025
}
3126

3227
func (s *Stmt) ColumnNum() int {
33-
return s.columns
28+
return s.Columns
3429
}
3530

3631
func (s *Stmt) WarningsNum() int {
37-
return s.warnings
38-
}
39-
40-
// GetParamFields returns the parameter field definitions from the PREPARE response.
41-
// Implements server.StmtFieldsProvider for proxy passthrough.
42-
// The caller should not modify the returned slice.
43-
func (s *Stmt) GetParamFields() []*mysql.Field {
44-
return s.paramFields
45-
}
46-
47-
// GetColumnFields returns the column field definitions from the PREPARE response.
48-
// Implements server.StmtFieldsProvider for proxy passthrough.
49-
// The caller should not modify the returned slice.
50-
func (s *Stmt) GetColumnFields() []*mysql.Field {
51-
return s.columnFields
32+
return s.Warnings
5233
}
5334

5435
func (s *Stmt) Execute(args ...interface{}) (*mysql.Result, error) {
@@ -68,7 +49,7 @@ func (s *Stmt) ExecuteSelectStreaming(result *mysql.Result, perRowCb SelectPerRo
6849
}
6950

7051
func (s *Stmt) Close() error {
71-
if err := s.conn.writeCommandUint32(mysql.COM_STMT_CLOSE, s.id); err != nil {
52+
if err := s.conn.writeCommandUint32(mysql.COM_STMT_CLOSE, s.ID); err != nil {
7253
return errors.Trace(err)
7354
}
7455

@@ -78,10 +59,10 @@ func (s *Stmt) Close() error {
7859
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
7960
func (s *Stmt) write(args ...interface{}) error {
8061
defer clear(s.conn.queryAttributes)
81-
paramsNum := s.params
62+
paramsNum := s.Params
8263

8364
if len(args) != paramsNum {
84-
return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
65+
return fmt.Errorf("argument mismatch, need %d but got %d", s.Params, len(args))
8566
}
8667

8768
if (s.conn.capability&mysql.CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) {
@@ -205,7 +186,7 @@ func (s *Stmt) write(args ...interface{}) error {
205186

206187
data.Write([]byte{0, 0, 0, 0})
207188
data.WriteByte(mysql.COM_STMT_EXECUTE)
208-
data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)})
189+
data.Write([]byte{byte(s.ID), byte(s.ID >> 8), byte(s.ID >> 16), byte(s.ID >> 24)})
209190

210191
flags := mysql.CURSOR_TYPE_NO_CURSOR
211192
if paramsNum > 0 {
@@ -272,37 +253,34 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
272253
pos := 1
273254

274255
// for statement id
275-
s.id = binary.LittleEndian.Uint32(data[pos:])
256+
s.ID = binary.LittleEndian.Uint32(data[pos:])
276257
pos += 4
277258

278259
// number columns
279-
s.columns = int(binary.LittleEndian.Uint16(data[pos:]))
260+
s.Columns = int(binary.LittleEndian.Uint16(data[pos:]))
280261
pos += 2
281262

282263
// number params
283-
s.params = int(binary.LittleEndian.Uint16(data[pos:]))
264+
s.Params = int(binary.LittleEndian.Uint16(data[pos:]))
284265
pos += 2
285266

286267
// reserved
287268
pos += 1
288269

289270
if len(data) >= 12 {
290271
// warnings
291-
s.warnings = int(binary.LittleEndian.Uint16(data[pos:]))
272+
s.Warnings = int(binary.LittleEndian.Uint16(data[pos:]))
292273
// pos += 2
293274
}
294275

295-
if s.params > 0 {
296-
s.paramFields = make([]*mysql.Field, s.params)
297-
for i := range s.params {
276+
if s.Params > 0 {
277+
s.RawParamFields = make([][]byte, s.Params)
278+
for i := range s.Params {
298279
data, err := s.conn.ReadPacket()
299280
if err != nil {
300281
return nil, errors.Trace(err)
301282
}
302-
s.paramFields[i] = &mysql.Field{}
303-
if err := s.paramFields[i].Parse(data); err != nil {
304-
return nil, errors.Trace(err)
305-
}
283+
s.RawParamFields[i] = data
306284
}
307285
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
308286
if packet, err := s.conn.ReadPacket(); err != nil {
@@ -313,17 +291,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
313291
}
314292
}
315293

316-
if s.columns > 0 {
317-
s.columnFields = make([]*mysql.Field, s.columns)
318-
for i := range s.columns {
294+
if s.Columns > 0 {
295+
s.RawColumnFields = make([][]byte, s.Columns)
296+
for i := range s.Columns {
319297
data, err := s.conn.ReadPacket()
320298
if err != nil {
321299
return nil, errors.Trace(err)
322300
}
323-
s.columnFields[i] = &mysql.Field{}
324-
if err := s.columnFields[i].Parse(data); err != nil {
325-
return nil, errors.Trace(err)
326-
}
301+
s.RawColumnFields[i] = data
327302
}
328303
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
329304
if packet, err := s.conn.ReadPacket(); err != nil {

server/command.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,10 @@ import (
77

88
"github.com/go-mysql-org/go-mysql/mysql"
99
"github.com/go-mysql-org/go-mysql/replication"
10+
"github.com/go-mysql-org/go-mysql/stmt"
1011
"github.com/go-mysql-org/go-mysql/utils"
1112
)
1213

13-
// StmtFieldsProvider is an optional interface that prepared statement contexts can implement
14-
// to provide field definitions for proxy passthrough scenarios.
15-
type StmtFieldsProvider interface {
16-
GetParamFields() []*mysql.Field
17-
GetColumnFields() []*mysql.Field
18-
}
19-
2014
// Handler is what a server needs to implement the client-server protocol
2115
type Handler interface {
2216
// handle COM_INIT_DB command, you can check whether the dbName is valid, or other.
@@ -119,11 +113,9 @@ func (c *Conn) dispatch(data []byte) interface{} {
119113
if st.Params, st.Columns, st.Context, err = c.h.HandleStmtPrepare(st.Query); err != nil {
120114
return err
121115
} else {
122-
// If context provides field definitions (e.g., from a backend prepared statement),
123-
// use them for accurate metadata passthrough in proxy scenarios.
124-
if provider, ok := st.Context.(StmtFieldsProvider); ok {
125-
st.ParamFields = provider.GetParamFields()
126-
st.ColumnFields = provider.GetColumnFields()
116+
if provider, ok := st.Context.(*stmt.PreparedStmt); ok {
117+
st.RawParamFields = provider.RawParamFields
118+
st.RawColumnFields = provider.RawColumnFields
127119
}
128120
st.ResetParams()
129121
c.stmts[c.stmtID] = st

server/stmt.go

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strconv"
88

99
"github.com/go-mysql-org/go-mysql/mysql"
10+
"github.com/go-mysql-org/go-mysql/stmt"
1011
"github.com/pingcap/errors"
1112
)
1213

@@ -16,19 +17,13 @@ var (
1617
)
1718

1819
type Stmt struct {
19-
ID uint32
2020
Query string
21-
22-
Params int
23-
Columns int
24-
25-
Args []interface{}
21+
Args []interface{}
2622

2723
Context interface{}
2824

29-
// Field definitions for proxy passthrough (optional, uses dummy fields if nil)
30-
ParamFields []*mysql.Field
31-
ColumnFields []*mysql.Field
25+
// PreparedStmt contains common fields shared with client.Stmt for proxy passthrough
26+
stmt.PreparedStmt
3227
}
3328

3429
func (s *Stmt) Rest(params int, columns int, context interface{}) {
@@ -65,8 +60,8 @@ func (c *Conn) writePrepare(s *Stmt) error {
6560
if s.Params > 0 {
6661
for i := 0; i < s.Params; i++ {
6762
data = data[0:4]
68-
if s.ParamFields != nil && i < len(s.ParamFields) {
69-
data = append(data, s.ParamFields[i].Dump()...)
63+
if s.RawParamFields != nil && i < len(s.RawParamFields) {
64+
data = append(data, s.RawParamFields[i]...)
7065
} else {
7166
data = append(data, paramFieldData...)
7267
}
@@ -84,8 +79,8 @@ func (c *Conn) writePrepare(s *Stmt) error {
8479
if s.Columns > 0 {
8580
for i := 0; i < s.Columns; i++ {
8681
data = data[0:4]
87-
if s.ColumnFields != nil && i < len(s.ColumnFields) {
88-
data = append(data, s.ColumnFields[i].Dump()...)
82+
if s.RawColumnFields != nil && i < len(s.RawColumnFields) {
83+
data = append(data, s.RawColumnFields[i]...)
8984
} else {
9085
data = append(data, columnFieldData...)
9186
}

server/stmt_test.go

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"testing"
55

66
"github.com/go-mysql-org/go-mysql/mysql"
7+
"github.com/go-mysql-org/go-mysql/stmt"
78
"github.com/stretchr/testify/require"
89
)
910

@@ -58,30 +59,26 @@ func (h *mockPrepareHandler) HandleStmtPrepare(query string) (int, int, any, err
5859
return h.paramCount, h.columnCount, h.context, nil
5960
}
6061

61-
func TestStmtPrepareWithoutFieldsProvider(t *testing.T) {
62+
func TestStmtPrepareWithoutPreparedStmt(t *testing.T) {
6263
c := &Conn{
6364
h: &mockPrepareHandler{context: "plain string", paramCount: 1, columnCount: 1},
6465
stmts: make(map[uint32]*Stmt),
6566
}
6667

6768
result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT * FROM t"...))
6869

69-
stmt := result.(*Stmt)
70-
require.Nil(t, stmt.ParamFields)
71-
require.Nil(t, stmt.ColumnFields)
70+
st := result.(*Stmt)
71+
require.Nil(t, st.RawParamFields)
72+
require.Nil(t, st.RawColumnFields)
7273
}
7374

74-
type mockFieldsProvider struct {
75-
paramFields, columnFields []*mysql.Field
76-
}
77-
78-
func (m *mockFieldsProvider) GetParamFields() []*mysql.Field { return m.paramFields }
79-
func (m *mockFieldsProvider) GetColumnFields() []*mysql.Field { return m.columnFields }
75+
func TestStmtPrepareWithPreparedStmt(t *testing.T) {
76+
paramField := &mysql.Field{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG}
77+
columnField := &mysql.Field{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG}
8078

81-
func TestStmtPrepareWithFieldsProvider(t *testing.T) {
82-
provider := &mockFieldsProvider{
83-
paramFields: []*mysql.Field{{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG}},
84-
columnFields: []*mysql.Field{{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG}},
79+
provider := &stmt.PreparedStmt{
80+
RawParamFields: [][]byte{paramField.Dump()},
81+
RawColumnFields: [][]byte{columnField.Dump()},
8582
}
8683
c := &Conn{
8784
h: &mockPrepareHandler{context: provider, paramCount: 1, columnCount: 1},
@@ -90,9 +87,9 @@ func TestStmtPrepareWithFieldsProvider(t *testing.T) {
9087

9188
result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT id FROM t WHERE id = ?"...))
9289

93-
stmt := result.(*Stmt)
94-
require.NotNil(t, stmt.ParamFields)
95-
require.NotNil(t, stmt.ColumnFields)
96-
require.Equal(t, mysql.MYSQL_TYPE_LONG, stmt.ParamFields[0].Type)
97-
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, stmt.ColumnFields[0].Type)
90+
st := result.(*Stmt)
91+
require.NotNil(t, st.RawParamFields)
92+
require.NotNil(t, st.RawColumnFields)
93+
require.Equal(t, mysql.MYSQL_TYPE_LONG, st.GetParamFields()[0].Type)
94+
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, st.GetColumnFields()[0].Type)
9895
}

stmt/stmt.go

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
package stmt
2+
3+
import "github.com/go-mysql-org/go-mysql/mysql"
4+
5+
type PreparedStmt struct {
6+
ID uint32
7+
Params int
8+
Columns int
9+
Warnings int
10+
11+
RawParamFields [][]byte
12+
RawColumnFields [][]byte
13+
14+
paramFields []*mysql.Field
15+
columnFields []*mysql.Field
16+
}
17+
18+
func (s *PreparedStmt) GetParamFields() []*mysql.Field {
19+
if s.RawParamFields == nil {
20+
return nil
21+
}
22+
if s.paramFields == nil {
23+
s.paramFields = make([]*mysql.Field, len(s.RawParamFields))
24+
for i, raw := range s.RawParamFields {
25+
s.paramFields[i] = &mysql.Field{}
26+
_ = s.paramFields[i].Parse(raw)
27+
}
28+
}
29+
return s.paramFields
30+
}
31+
32+
func (s *PreparedStmt) GetColumnFields() []*mysql.Field {
33+
if s.RawColumnFields == nil {
34+
return nil
35+
}
36+
if s.columnFields == nil {
37+
s.columnFields = make([]*mysql.Field, len(s.RawColumnFields))
38+
for i, raw := range s.RawColumnFields {
39+
s.columnFields[i] = &mysql.Field{}
40+
_ = s.columnFields[i].Parse(raw)
41+
}
42+
}
43+
return s.columnFields
44+
}

0 commit comments

Comments
 (0)