@@ -8,49 +8,31 @@ import (
88 "runtime"
99
1010 "github.com/go-mysql-org/go-mysql/mysql"
11+ "github.com/go-mysql-org/go-mysql/stmt"
1112 "github.com/go-mysql-org/go-mysql/utils"
1213 "github.com/pingcap/errors"
1314)
1415
1516type Stmt struct {
16- conn * Conn
17- id uint32
18-
19- params int
20- columns int
17+ conn * Conn
2118 warnings int
2219
23- // Field definitions from the PREPARE response (for proxy passthrough)
24- paramFields []* mysql.Field
25- columnFields []* mysql.Field
20+ // PreparedStmt contains common fields shared with server.Stmt for proxy passthrough
21+ stmt.PreparedStmt
2622}
2723
2824func (s * Stmt ) ParamNum () int {
29- return s .params
25+ return s .Params
3026}
3127
3228func (s * Stmt ) ColumnNum () int {
33- return s .columns
29+ return s .Columns
3430}
3531
3632func (s * Stmt ) WarningsNum () int {
3733 return s .warnings
3834}
3935
40- // GetParamFields returns the parameter field definitions from the PREPARE response.
41- // Implements server.StmtFieldsProvider for proxy passthrough.
42- // The caller should not modify the returned slice.
43- func (s * Stmt ) GetParamFields () []* mysql.Field {
44- return s .paramFields
45- }
46-
47- // GetColumnFields returns the column field definitions from the PREPARE response.
48- // Implements server.StmtFieldsProvider for proxy passthrough.
49- // The caller should not modify the returned slice.
50- func (s * Stmt ) GetColumnFields () []* mysql.Field {
51- return s .columnFields
52- }
53-
5436func (s * Stmt ) Execute (args ... interface {}) (* mysql.Result , error ) {
5537 if err := s .write (args ... ); err != nil {
5638 return nil , errors .Trace (err )
@@ -68,7 +50,7 @@ func (s *Stmt) ExecuteSelectStreaming(result *mysql.Result, perRowCb SelectPerRo
6850}
6951
7052func (s * Stmt ) Close () error {
71- if err := s .conn .writeCommandUint32 (mysql .COM_STMT_CLOSE , s .id ); err != nil {
53+ if err := s .conn .writeCommandUint32 (mysql .COM_STMT_CLOSE , s .ID ); err != nil {
7254 return errors .Trace (err )
7355 }
7456
@@ -78,10 +60,10 @@ func (s *Stmt) Close() error {
7860// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
7961func (s * Stmt ) write (args ... interface {}) error {
8062 defer clear (s .conn .queryAttributes )
81- paramsNum := s .params
63+ paramsNum := s .Params
8264
8365 if len (args ) != paramsNum {
84- return fmt .Errorf ("argument mismatch, need %d but got %d" , s .params , len (args ))
66+ return fmt .Errorf ("argument mismatch, need %d but got %d" , s .Params , len (args ))
8567 }
8668
8769 if (s .conn .capability & mysql .CLIENT_QUERY_ATTRIBUTES > 0 ) && (s .conn .includeLine >= 0 ) {
@@ -205,7 +187,7 @@ func (s *Stmt) write(args ...interface{}) error {
205187
206188 data .Write ([]byte {0 , 0 , 0 , 0 })
207189 data .WriteByte (mysql .COM_STMT_EXECUTE )
208- data .Write ([]byte {byte (s .id ), byte (s .id >> 8 ), byte (s .id >> 16 ), byte (s .id >> 24 )})
190+ data .Write ([]byte {byte (s .ID ), byte (s .ID >> 8 ), byte (s .ID >> 16 ), byte (s .ID >> 24 )})
209191
210192 flags := mysql .CURSOR_TYPE_NO_CURSOR
211193 if paramsNum > 0 {
@@ -272,15 +254,15 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
272254 pos := 1
273255
274256 // for statement id
275- s .id = binary .LittleEndian .Uint32 (data [pos :])
257+ s .ID = binary .LittleEndian .Uint32 (data [pos :])
276258 pos += 4
277259
278260 // number columns
279- s .columns = int (binary .LittleEndian .Uint16 (data [pos :]))
261+ s .Columns = int (binary .LittleEndian .Uint16 (data [pos :]))
280262 pos += 2
281263
282264 // number params
283- s .params = int (binary .LittleEndian .Uint16 (data [pos :]))
265+ s .Params = int (binary .LittleEndian .Uint16 (data [pos :]))
284266 pos += 2
285267
286268 // reserved
@@ -292,17 +274,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
292274 // pos += 2
293275 }
294276
295- if s .params > 0 {
296- s .paramFields = make ([]* mysql. Field , s .params )
297- for i := range s .params {
277+ if s .Params > 0 {
278+ s .RawParamFields = make ([][] byte , s .Params )
279+ for i := range s .Params {
298280 data , err := s .conn .ReadPacket ()
299281 if err != nil {
300282 return nil , errors .Trace (err )
301283 }
302- s .paramFields [i ] = & mysql.Field {}
303- if err := s .paramFields [i ].Parse (data ); err != nil {
304- return nil , errors .Trace (err )
305- }
284+ s .RawParamFields [i ] = data
306285 }
307286 if s .conn .capability & mysql .CLIENT_DEPRECATE_EOF == 0 {
308287 if packet , err := s .conn .ReadPacket (); err != nil {
@@ -313,17 +292,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
313292 }
314293 }
315294
316- if s .columns > 0 {
317- s .columnFields = make ([]* mysql. Field , s .columns )
318- for i := range s .columns {
295+ if s .Columns > 0 {
296+ s .RawColumnFields = make ([][] byte , s .Columns )
297+ for i := range s .Columns {
319298 data , err := s .conn .ReadPacket ()
320299 if err != nil {
321300 return nil , errors .Trace (err )
322301 }
323- s .columnFields [i ] = & mysql.Field {}
324- if err := s .columnFields [i ].Parse (data ); err != nil {
325- return nil , errors .Trace (err )
326- }
302+ s .RawColumnFields [i ] = data
327303 }
328304 if s .conn .capability & mysql .CLIENT_DEPRECATE_EOF == 0 {
329305 if packet , err := s .conn .ReadPacket (); err != nil {
0 commit comments