diff --git a/client/stmt.go b/client/stmt.go index 106e176de..cd64f3524 100644 --- a/client/stmt.go +++ b/client/stmt.go @@ -8,25 +8,25 @@ import ( "runtime" "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/stmt" "github.com/go-mysql-org/go-mysql/utils" "github.com/pingcap/errors" ) type Stmt struct { - conn *Conn - id uint32 - - params int - columns int + conn *Conn warnings int + + // PreparedStmt contains common fields shared with server.Stmt for proxy passthrough + stmt.PreparedStmt } func (s *Stmt) ParamNum() int { - return s.params + return s.Params } func (s *Stmt) ColumnNum() int { - return s.columns + return s.Columns } func (s *Stmt) WarningsNum() int { @@ -50,7 +50,7 @@ func (s *Stmt) ExecuteSelectStreaming(result *mysql.Result, perRowCb SelectPerRo } func (s *Stmt) Close() error { - if err := s.conn.writeCommandUint32(mysql.COM_STMT_CLOSE, s.id); err != nil { + if err := s.conn.writeCommandUint32(mysql.COM_STMT_CLOSE, s.ID); err != nil { return errors.Trace(err) } @@ -60,10 +60,10 @@ func (s *Stmt) Close() error { // https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html func (s *Stmt) write(args ...interface{}) error { defer clear(s.conn.queryAttributes) - paramsNum := s.params + paramsNum := s.Params if len(args) != paramsNum { - return fmt.Errorf("argument mismatch, need %d but got %d", s.params, len(args)) + return fmt.Errorf("argument mismatch, need %d but got %d", s.Params, len(args)) } if (s.conn.capability&mysql.CLIENT_QUERY_ATTRIBUTES > 0) && (s.conn.includeLine >= 0) { @@ -187,7 +187,7 @@ func (s *Stmt) write(args ...interface{}) error { data.Write([]byte{0, 0, 0, 0}) data.WriteByte(mysql.COM_STMT_EXECUTE) - data.Write([]byte{byte(s.id), byte(s.id >> 8), byte(s.id >> 16), byte(s.id >> 24)}) + data.Write([]byte{byte(s.ID), byte(s.ID >> 8), byte(s.ID >> 16), byte(s.ID >> 24)}) flags := mysql.CURSOR_TYPE_NO_CURSOR if paramsNum > 0 { @@ -254,15 +254,15 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { pos := 1 // for statement id - s.id = binary.LittleEndian.Uint32(data[pos:]) + s.ID = binary.LittleEndian.Uint32(data[pos:]) pos += 4 // number columns - s.columns = int(binary.LittleEndian.Uint16(data[pos:])) + s.Columns = int(binary.LittleEndian.Uint16(data[pos:])) pos += 2 // number params - s.params = int(binary.LittleEndian.Uint16(data[pos:])) + s.Params = int(binary.LittleEndian.Uint16(data[pos:])) pos += 2 // reserved @@ -274,11 +274,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { // pos += 2 } - if s.params > 0 { - for range s.params { - if _, err := s.conn.ReadPacket(); err != nil { + if s.Params > 0 { + s.RawParamFields = make([][]byte, s.Params) + for i := range s.Params { + data, err := s.conn.ReadPacket() + if err != nil { return nil, errors.Trace(err) } + s.RawParamFields[i] = data } if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 { if packet, err := s.conn.ReadPacket(); err != nil { @@ -289,12 +292,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) { } } - if s.columns > 0 { - // TODO process when CLIENT_CACHE_METADATA enabled - for range s.columns { - if _, err := s.conn.ReadPacket(); err != nil { + if s.Columns > 0 { + s.RawColumnFields = make([][]byte, s.Columns) + for i := range s.Columns { + data, err := s.conn.ReadPacket() + if err != nil { return nil, errors.Trace(err) } + s.RawColumnFields[i] = data } if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 { if packet, err := s.conn.ReadPacket(); err != nil { diff --git a/server/command.go b/server/command.go index e23244bbd..a645798a9 100644 --- a/server/command.go +++ b/server/command.go @@ -7,6 +7,7 @@ import ( "github.com/go-mysql-org/go-mysql/mysql" "github.com/go-mysql-org/go-mysql/replication" + "github.com/go-mysql-org/go-mysql/stmt" "github.com/go-mysql-org/go-mysql/utils" ) @@ -112,6 +113,10 @@ func (c *Conn) dispatch(data []byte) interface{} { if st.Params, st.Columns, st.Context, err = c.h.HandleStmtPrepare(st.Query); err != nil { return err } else { + if provider, ok := st.Context.(*stmt.PreparedStmt); ok { + st.RawParamFields = provider.RawParamFields + st.RawColumnFields = provider.RawColumnFields + } st.ResetParams() c.stmts[c.stmtID] = st return st diff --git a/server/stmt.go b/server/stmt.go index ca9eae796..553c9e695 100644 --- a/server/stmt.go +++ b/server/stmt.go @@ -7,6 +7,7 @@ import ( "strconv" "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/stmt" "github.com/pingcap/errors" ) @@ -16,15 +17,13 @@ var ( ) type Stmt struct { - ID uint32 Query string - - Params int - Columns int - - Args []interface{} + Args []interface{} Context interface{} + + // PreparedStmt contains common fields shared with client.Stmt for proxy passthrough + stmt.PreparedStmt } func (s *Stmt) Rest(params int, columns int, context interface{}) { @@ -61,7 +60,11 @@ func (c *Conn) writePrepare(s *Stmt) error { if s.Params > 0 { for i := 0; i < s.Params; i++ { data = data[0:4] - data = append(data, paramFieldData...) + if s.RawParamFields != nil && i < len(s.RawParamFields) { + data = append(data, s.RawParamFields[i]...) + } else { + data = append(data, paramFieldData...) + } if err := c.WritePacket(data); err != nil { return errors.Trace(err) @@ -76,7 +79,11 @@ func (c *Conn) writePrepare(s *Stmt) error { if s.Columns > 0 { for i := 0; i < s.Columns; i++ { data = data[0:4] - data = append(data, columnFieldData...) + if s.RawColumnFields != nil && i < len(s.RawColumnFields) { + data = append(data, s.RawColumnFields[i]...) + } else { + data = append(data, columnFieldData...) + } if err := c.WritePacket(data); err != nil { return errors.Trace(err) diff --git a/server/stmt_test.go b/server/stmt_test.go index bf9142f54..935597f68 100644 --- a/server/stmt_test.go +++ b/server/stmt_test.go @@ -3,6 +3,8 @@ package server import ( "testing" + "github.com/go-mysql-org/go-mysql/mysql" + "github.com/go-mysql-org/go-mysql/stmt" "github.com/stretchr/testify/require" ) @@ -46,3 +48,52 @@ func TestHandleStmtExecute(t *testing.T) { } } } + +type mockPrepareHandler struct { + EmptyHandler + context any + paramCount, columnCount int +} + +func (h *mockPrepareHandler) HandleStmtPrepare(query string) (int, int, any, error) { + return h.paramCount, h.columnCount, h.context, nil +} + +func TestStmtPrepareWithoutPreparedStmt(t *testing.T) { + c := &Conn{ + h: &mockPrepareHandler{context: "plain string", paramCount: 1, columnCount: 1}, + stmts: make(map[uint32]*Stmt), + } + + result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT * FROM t"...)) + + st := result.(*Stmt) + require.Nil(t, st.RawParamFields) + require.Nil(t, st.RawColumnFields) +} + +func TestStmtPrepareWithPreparedStmt(t *testing.T) { + paramField := &mysql.Field{Name: []byte("?"), Type: mysql.MYSQL_TYPE_LONG} + columnField := &mysql.Field{Name: []byte("id"), Type: mysql.MYSQL_TYPE_LONGLONG} + + provider := &stmt.PreparedStmt{ + RawParamFields: [][]byte{paramField.Dump()}, + RawColumnFields: [][]byte{columnField.Dump()}, + } + c := &Conn{ + h: &mockPrepareHandler{context: provider, paramCount: 1, columnCount: 1}, + stmts: make(map[uint32]*Stmt), + } + + result := c.dispatch(append([]byte{mysql.COM_STMT_PREPARE}, "SELECT id FROM t WHERE id = ?"...)) + + st := result.(*Stmt) + require.NotNil(t, st.RawParamFields) + require.NotNil(t, st.RawColumnFields) + paramFields, err := st.GetParamFields() + require.NoError(t, err) + require.Equal(t, mysql.MYSQL_TYPE_LONG, paramFields[0].Type) + columnFields, err := st.GetColumnFields() + require.NoError(t, err) + require.Equal(t, mysql.MYSQL_TYPE_LONGLONG, columnFields[0].Type) +} diff --git a/stmt/stmt.go b/stmt/stmt.go new file mode 100644 index 000000000..0f75724c0 --- /dev/null +++ b/stmt/stmt.go @@ -0,0 +1,51 @@ +package stmt + +import "github.com/go-mysql-org/go-mysql/mysql" + +type PreparedStmt struct { + ID uint32 + Params int + Columns int + + RawParamFields [][]byte + RawColumnFields [][]byte + + paramFields []*mysql.Field + columnFields []*mysql.Field +} + +func (s *PreparedStmt) GetParamFields() ([]*mysql.Field, error) { + if s.RawParamFields == nil { + return nil, nil + } + if s.paramFields == nil { + fields := make([]*mysql.Field, len(s.RawParamFields)) + for i, raw := range s.RawParamFields { + field := &mysql.Field{} + if err := field.Parse(raw); err != nil { + return nil, err + } + fields[i] = field + } + s.paramFields = fields + } + return s.paramFields, nil +} + +func (s *PreparedStmt) GetColumnFields() ([]*mysql.Field, error) { + if s.RawColumnFields == nil { + return nil, nil + } + if s.columnFields == nil { + fields := make([]*mysql.Field, len(s.RawColumnFields)) + for i, raw := range s.RawColumnFields { + field := &mysql.Field{} + if err := field.Parse(raw); err != nil { + return nil, err + } + fields[i] = field + } + s.columnFields = fields + } + return s.columnFields, nil +}