Skip to content

Commit d6fca48

Browse files
committed
Store prepared statements field definitions
1 parent 6675966 commit d6fca48

File tree

3 files changed

+59
-7
lines changed

3 files changed

+59
-7
lines changed

client/stmt.go

Lines changed: 32 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,10 @@ type Stmt struct {
1919
params int
2020
columns int
2121
warnings int
22+
23+
// Field definitions from the PREPARE response (for proxy passthrough)
24+
ParamFields []*mysql.Field
25+
ColumnFields []*mysql.Field
2226
}
2327

2428
func (s *Stmt) ParamNum() int {
@@ -33,6 +37,18 @@ func (s *Stmt) WarningsNum() int {
3337
return s.warnings
3438
}
3539

40+
// GetParamFields returns the parameter field definitions from the PREPARE response.
41+
// Implements server.StmtFieldsProvider for proxy passthrough.
42+
func (s *Stmt) GetParamFields() []*mysql.Field {
43+
return s.ParamFields
44+
}
45+
46+
// GetColumnFields returns the column field definitions from the PREPARE response.
47+
// Implements server.StmtFieldsProvider for proxy passthrough.
48+
func (s *Stmt) GetColumnFields() []*mysql.Field {
49+
return s.ColumnFields
50+
}
51+
3652
func (s *Stmt) Execute(args ...interface{}) (*mysql.Result, error) {
3753
if err := s.write(args...); err != nil {
3854
return nil, errors.Trace(err)
@@ -275,8 +291,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
275291
}
276292

277293
if s.params > 0 {
278-
for range s.params {
279-
if _, err := s.conn.ReadPacket(); err != nil {
294+
s.ParamFields = make([]*mysql.Field, s.params)
295+
for i := range s.params {
296+
data, err := s.conn.ReadPacket()
297+
if err != nil {
298+
return nil, errors.Trace(err)
299+
}
300+
s.ParamFields[i] = &mysql.Field{}
301+
if err := s.ParamFields[i].Parse(data); err != nil {
280302
return nil, errors.Trace(err)
281303
}
282304
}
@@ -290,9 +312,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
290312
}
291313

292314
if s.columns > 0 {
293-
// TODO process when CLIENT_CACHE_METADATA enabled
294-
for range s.columns {
295-
if _, err := s.conn.ReadPacket(); err != nil {
315+
s.ColumnFields = make([]*mysql.Field, s.columns)
316+
for i := range s.columns {
317+
data, err := s.conn.ReadPacket()
318+
if err != nil {
319+
return nil, errors.Trace(err)
320+
}
321+
s.ColumnFields[i] = &mysql.Field{}
322+
if err := s.ColumnFields[i].Parse(data); err != nil {
296323
return nil, errors.Trace(err)
297324
}
298325
}

server/command.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ import (
1010
"github.com/go-mysql-org/go-mysql/utils"
1111
)
1212

13+
// StmtFieldsProvider is an optional interface that prepared statement contexts can implement
14+
// to provide field definitions for proxy passthrough scenarios.
15+
type StmtFieldsProvider interface {
16+
GetParamFields() []*mysql.Field
17+
GetColumnFields() []*mysql.Field
18+
}
19+
1320
// Handler is what a server needs to implement the client-server protocol
1421
type Handler interface {
1522
// handle COM_INIT_DB command, you can check whether the dbName is valid, or other.
@@ -112,6 +119,12 @@ func (c *Conn) dispatch(data []byte) interface{} {
112119
if st.Params, st.Columns, st.Context, err = c.h.HandleStmtPrepare(st.Query); err != nil {
113120
return err
114121
} else {
122+
// If context provides field definitions (e.g., from a backend prepared statement),
123+
// use them for accurate metadata passthrough in proxy scenarios.
124+
if provider, ok := st.Context.(StmtFieldsProvider); ok {
125+
st.ParamFields = provider.GetParamFields()
126+
st.ColumnFields = provider.GetColumnFields()
127+
}
115128
st.ResetParams()
116129
c.stmts[c.stmtID] = st
117130
return st

server/stmt.go

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ type Stmt struct {
2525
Args []interface{}
2626

2727
Context interface{}
28+
29+
// Field definitions for proxy passthrough (optional, uses dummy fields if nil)
30+
ParamFields []*mysql.Field
31+
ColumnFields []*mysql.Field
2832
}
2933

3034
func (s *Stmt) Rest(params int, columns int, context interface{}) {
@@ -61,7 +65,11 @@ func (c *Conn) writePrepare(s *Stmt) error {
6165
if s.Params > 0 {
6266
for i := 0; i < s.Params; i++ {
6367
data = data[0:4]
64-
data = append(data, paramFieldData...)
68+
if s.ParamFields != nil && i < len(s.ParamFields) {
69+
data = append(data, s.ParamFields[i].Dump()...)
70+
} else {
71+
data = append(data, paramFieldData...)
72+
}
6573

6674
if err := c.WritePacket(data); err != nil {
6775
return errors.Trace(err)
@@ -76,7 +84,11 @@ func (c *Conn) writePrepare(s *Stmt) error {
7684
if s.Columns > 0 {
7785
for i := 0; i < s.Columns; i++ {
7886
data = data[0:4]
79-
data = append(data, columnFieldData...)
87+
if s.ColumnFields != nil && i < len(s.ColumnFields) {
88+
data = append(data, s.ColumnFields[i].Dump()...)
89+
} else {
90+
data = append(data, columnFieldData...)
91+
}
8092

8193
if err := c.WritePacket(data); err != nil {
8294
return errors.Trace(err)

0 commit comments

Comments
 (0)