Skip to content

Commit a37e0bb

Browse files
committed
Move statements into a new stmt package
1 parent e4cb587 commit a37e0bb

File tree

5 files changed

+92
-89
lines changed

5 files changed

+92
-89
lines changed

client/stmt.go

Lines changed: 21 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -8,49 +8,31 @@ import (
88
"runtime"
99

1010
"github.com/go-mysql-org/go-mysql/mysql"
11+
"github.com/go-mysql-org/go-mysql/stmt"
1112
"github.com/go-mysql-org/go-mysql/utils"
1213
"github.com/pingcap/errors"
1314
)
1415

1516
type Stmt struct {
16-
conn *Conn
17-
id uint32
18-
19-
params int
20-
columns int
17+
conn *Conn
2118
warnings int
2219

23-
// Field definitions from the PREPARE response (for proxy passthrough)
24-
paramFields []*mysql.Field
25-
columnFields []*mysql.Field
20+
// PreparedStmt contains common fields shared with server.Stmt for proxy passthrough
21+
stmt.PreparedStmt
2622
}
2723

2824
func (s *Stmt) ParamNum() int {
29-
return s.params
25+
return s.Params
3026
}
3127

3228
func (s *Stmt) ColumnNum() int {
33-
return s.columns
29+
return s.Columns
3430
}
3531

3632
func (s *Stmt) WarningsNum() int {
3733
return s.warnings
3834
}
3935

40-
// GetParamFields returns the parameter field definitions from the PREPARE response.
41-
// Implements server.StmtFieldsProvider for proxy passthrough.
42-
// The caller should not modify the returned slice.
43-
func (s *Stmt) GetParamFields() []*mysql.Field {
44-
return s.paramFields
45-
}
46-
47-
// GetColumnFields returns the column field definitions from the PREPARE response.
48-
// Implements server.StmtFieldsProvider for proxy passthrough.
49-
// The caller should not modify the returned slice.
50-
func (s *Stmt) GetColumnFields() []*mysql.Field {
51-
return s.columnFields
52-
}
53-
5436
func (s *Stmt) Execute(args ...interface{}) (*mysql.Result, error) {
5537
if err := s.write(args...); err != nil {
5638
return nil, errors.Trace(err)
@@ -68,7 +50,7 @@ func (s *Stmt) ExecuteSelectStreaming(result *mysql.Result, perRowCb SelectPerRo
6850
}
6951

7052
func (s *Stmt) Close() error {
71-
if err := s.conn.writeCommandUint32(mysql.COM_STMT_CLOSE, s.id); err != nil {
53+
if err := s.conn.writeCommandUint32(mysql.COM_STMT_CLOSE, s.ID); err != nil {
7254
return errors.Trace(err)
7355
}
7456

@@ -78,10 +60,10 @@ func (s *Stmt) Close() error {
7860
// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
7961
func (s *Stmt) write(args ...interface{}) error {
8062
defer clear(s.conn.queryAttributes)
81-
paramsNum := s.params
63+
paramsNum := s.Params
8264

8365
if len(args) != paramsNum {
84-
return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args))
66+
return fmt.Errorf("argument mismatch, need %d but got %d", s.Params, len(args))
8567
}
8668

8769
if (s.conn.capability&mysql.CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) {
@@ -205,7 +187,7 @@ func (s *Stmt) write(args ...interface{}) error {
205187

206188
data.Write([]byte{0, 0, 0, 0})
207189
data.WriteByte(mysql.COM_STMT_EXECUTE)
208-
data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)})
190+
data.Write([]byte{byte(s.ID), byte(s.ID >> 8), byte(s.ID >> 16), byte(s.ID >> 24)})
209191

210192
flags := mysql.CURSOR_TYPE_NO_CURSOR
211193
if paramsNum > 0 {
@@ -272,15 +254,15 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
272254
pos := 1
273255

274256
// for statement id
275-
s.id = binary.LittleEndian.Uint32(data[pos:])
257+
s.ID = binary.LittleEndian.Uint32(data[pos:])
276258
pos += 4
277259

278260
// number columns
279-
s.columns = int(binary.LittleEndian.Uint16(data[pos:]))
261+
s.Columns = int(binary.LittleEndian.Uint16(data[pos:]))
280262
pos += 2
281263

282264
// number params
283-
s.params = int(binary.LittleEndian.Uint16(data[pos:]))
265+
s.Params = int(binary.LittleEndian.Uint16(data[pos:]))
284266
pos += 2
285267

286268
// reserved
@@ -292,17 +274,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
292274
// pos += 2
293275
}
294276

295-
if s.params > 0 {
296-
s.paramFields = make([]*mysql.Field, s.params)
297-
for i := range s.params {
277+
if s.Params > 0 {
278+
s.RawParamFields = make([][]byte, s.Params)
279+
for i := range s.Params {
298280
data, err := s.conn.ReadPacket()
299281
if err != nil {
300282
return nil, errors.Trace(err)
301283
}
302-
s.paramFields[i] = &mysql.Field{}
303-
if err := s.paramFields[i].Parse(data); err != nil {
304-
return nil, errors.Trace(err)
305-
}
284+
s.RawParamFields[i] = data
306285
}
307286
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
308287
if packet, err := s.conn.ReadPacket(); err != nil {
@@ -313,17 +292,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
313292
}
314293
}
315294

316-
if s.columns > 0 {
317-
s.columnFields = make([]*mysql.Field, s.columns)
318-
for i := range s.columns {
295+
if s.Columns > 0 {
296+
s.RawColumnFields = make([][]byte, s.Columns)
297+
for i := range s.Columns {
319298
data, err := s.conn.ReadPacket()
320299
if err != nil {
321300
return nil, errors.Trace(err)
322301
}
323-
s.columnFields[i] = &mysql.Field{}
324-
if err := s.columnFields[i].Parse(data); err != nil {
325-
return nil, errors.Trace(err)
326-
}
302+
s.RawColumnFields[i] = data
327303
}
328304
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
329305
if packet, err := s.conn.ReadPacket(); err != nil {

server/command.go

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -7,16 +7,10 @@ import (
77

88
"github.com/go-mysql-org/go-mysql/mysql"
99
"github.com/go-mysql-org/go-mysql/replication"
10+
"github.com/go-mysql-org/go-mysql/stmt"
1011
"github.com/go-mysql-org/go-mysql/utils"
1112
)
1213

13-
// StmtFieldsProvider is an optional interface that prepared statement contexts can implement
14-
// to provide field definitions for proxy passthrough scenarios.
15-
type StmtFieldsProvider interface {
16-
GetParamFields() []*mysql.Field
17-
GetColumnFields() []*mysql.Field
18-
}
19-
2014
// Handler is what a server needs to implement the client-server protocol
2115
type Handler interface {
2216
// handle COM_INIT_DB command, you can check whether the dbName is valid, or other.
@@ -119,11 +113,9 @@ func (c *Conn) dispatch(data []byte) interface{} {
119113
if st.Params, st.Columns, st.Context, err = c.h.HandleStmtPrepare(st.Query); err != nil {
120114
return err
121115
} else {
122-
// If context provides field definitions (e.g., from a backend prepared statement),
123-
// use them for accurate metadata passthrough in proxy scenarios.
124-
if provider, ok := st.Context.(StmtFieldsProvider); ok {
125-
st.ParamFields = provider.GetParamFields()
126-
st.ColumnFields = provider.GetColumnFields()
116+
if provider, ok := st.Context.(*stmt.PreparedStmt); ok {
117+
st.RawParamFields = provider.RawParamFields
118+
st.RawColumnFields = provider.RawColumnFields
127119
}
128120
st.ResetParams()
129121
c.stmts[c.stmtID] = st

server/stmt.go

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strconv"
88

99
"github.com/go-mysql-org/go-mysql/mysql"
10+
"github.com/go-mysql-org/go-mysql/stmt"
1011
"github.com/pingcap/errors"
1112
)
1213

@@ -16,19 +17,13 @@ var (
1617
)
1718

1819
type Stmt struct {
19-
ID uint32
2020
Query string
21-
22-
Params int
23-
Columns int
24-
25-
Args []interface{}
21+
Args []interface{}
2622

2723
Context interface{}
2824

29-
// Field definitions for proxy passthrough (optional, uses dummy fields if nil)
30-
ParamFields []*mysql.Field
31-
ColumnFields []*mysql.Field
25+
// PreparedStmt contains common fields shared with client.Stmt for proxy passthrough
26+
stmt.PreparedStmt
3227
}
3328

3429
func (s *Stmt) Rest(params int, columns int, context interface{}) {
@@ -65,8 +60,8 @@ func (c *Conn) writePrepare(s *Stmt) error {
6560
if s.Params > 0 {
6661
for i := 0; i < s.Params; i++ {
6762
data = data[0:4]
68-
if s.ParamFields != nil && i < len(s.ParamFields) {
69-
data = append(data, s.ParamFields[i].Dump()...)
63+
if s.RawParamFields != nil && i < len(s.RawParamFields) {
64+
data = append(data, s.RawParamFields[i]...)
7065
} else {
7166
data = append(data, paramFieldData...)
7267
}
@@ -84,8 +79,8 @@ func (c *Conn) writePrepare(s *Stmt) error {
8479
if s.Columns > 0 {
8580
for i := 0; i < s.Columns; i++ {
8681
data = data[0:4]
87-
if s.ColumnFields != nil && i < len(s.ColumnFields) {
88-
data = append(data, s.ColumnFields[i].Dump()...)
82+
if s.RawColumnFields != nil && i < len(s.RawColumnFields) {
83+
data = append(data, s.RawColumnFields[i]...)
8984
} else {
9085
data = append(data, columnFieldData...)
9186
}

server/stmt_test.go

Lines changed: 16 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"testing"
55

66
"github.com/go-mysql-org/go-mysql/mysql"
7+
"github.com/go-mysql-org/go-mysql/stmt"
78
"github.com/stretchr/testify/require"
89
)
910

@@ -58,30 +59,26 @@ func (h *mockPrepareHandler) HandleStmtPrepare(query string) (int, int, any, err
5859
return h.paramCount, h.columnCount, h.context, nil
5960
}
6061

61-
func TestStmtPrepareWithoutFieldsProvider(t *testing.T) {
62+
func TestStmtPrepareWithoutPreparedStmt(t *testing.T) {
6263
c := &Conn{
6364
h: &mockPrepareHandler{context: "plain string", paramCount: 1, columnCount: 1},
6465
stmts: make(map[uint32]*Stmt),
6566
}
6667

6768
result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT * FROM t"...))
6869

69-
stmt := result.(*Stmt)
70-
require.Nil(t, stmt.ParamFields)
71-
require.Nil(t, stmt.ColumnFields)
70+
st := result.(*Stmt)
71+
require.Nil(t, st.RawParamFields)
72+
require.Nil(t, st.RawColumnFields)
7273
}
7374

74-
type mockFieldsProvider struct {
75-
paramFields, columnFields []*mysql.Field
76-
}
77-
78-
func (m *mockFieldsProvider) GetParamFields() []*mysql.Field { return m.paramFields }
79-
func (m *mockFieldsProvider) GetColumnFields() []*mysql.Field { return m.columnFields }
75+
func TestStmtPrepareWithPreparedStmt(t *testing.T) {
76+
paramField := &mysql.Field{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG}
77+
columnField := &mysql.Field{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG}
8078

81-
func TestStmtPrepareWithFieldsProvider(t *testing.T) {
82-
provider := &mockFieldsProvider{
83-
paramFields: []*mysql.Field{{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG}},
84-
columnFields: []*mysql.Field{{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG}},
79+
provider := &stmt.PreparedStmt{
80+
RawParamFields: [][]byte{paramField.Dump()},
81+
RawColumnFields: [][]byte{columnField.Dump()},
8582
}
8683
c := &Conn{
8784
h: &mockPrepareHandler{context: provider, paramCount: 1, columnCount: 1},
@@ -90,9 +87,9 @@ func TestStmtPrepareWithFieldsProvider(t *testing.T) {
9087

9188
result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT id FROM t WHERE id = ?"...))
9289

93-
stmt := result.(*Stmt)
94-
require.NotNil(t, stmt.ParamFields)
95-
require.NotNil(t, stmt.ColumnFields)
96-
require.Equal(t, mysql.MYSQL_TYPE_LONG, stmt.ParamFields[0].Type)
97-
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, stmt.ColumnFields[0].Type)
90+
st := result.(*Stmt)
91+
require.NotNil(t, st.RawParamFields)
92+
require.NotNil(t, st.RawColumnFields)
93+
require.Equal(t, mysql.MYSQL_TYPE_LONG, st.GetParamFields()[0].Type)
94+
require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, st.GetColumnFields()[0].Type)
9895
}

stmt/stmt.go

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
package stmt
2+
3+
import "github.com/go-mysql-org/go-mysql/mysql"
4+
5+
type PreparedStmt struct {
6+
ID uint32
7+
Params int
8+
Columns int
9+
10+
RawParamFields [][]byte
11+
RawColumnFields [][]byte
12+
13+
paramFields []*mysql.Field
14+
columnFields []*mysql.Field
15+
}
16+
17+
func (s *PreparedStmt) GetParamFields() []*mysql.Field {
18+
if s.RawParamFields == nil {
19+
return nil
20+
}
21+
if s.paramFields == nil {
22+
s.paramFields = make([]*mysql.Field, len(s.RawParamFields))
23+
for i, raw := range s.RawParamFields {
24+
s.paramFields[i] = &mysql.Field{}
25+
_ = s.paramFields[i].Parse(raw)
26+
}
27+
}
28+
return s.paramFields
29+
}
30+
31+
func (s *PreparedStmt) GetColumnFields() []*mysql.Field {
32+
if s.RawColumnFields == nil {
33+
return nil
34+
}
35+
if s.columnFields == nil {
36+
s.columnFields = make([]*mysql.Field, len(s.RawColumnFields))
37+
for i, raw := range s.RawColumnFields {
38+
s.columnFields[i] = &mysql.Field{}
39+
_ = s.columnFields[i].Parse(raw)
40+
}
41+
}
42+
return s.columnFields
43+
}

0 commit comments

Comments
 (0)