@@ -8,47 +8,28 @@ import (
88 "runtime"
99
1010 "github.com/go-mysql-org/go-mysql/mysql"
11+ "github.com/go-mysql-org/go-mysql/stmt"
1112 "github.com/go-mysql-org/go-mysql/utils"
1213 "github.com/pingcap/errors"
1314)
1415
1516type Stmt struct {
1617 conn * Conn
17- id uint32
1818
19- params int
20- columns int
21- warnings int
22-
23- // Field definitions from the PREPARE response (for proxy passthrough)
24- paramFields []* mysql.Field
25- columnFields []* mysql.Field
19+ // PreparedStmt contains common fields shared with server.Stmt for proxy passthrough
20+ stmt.PreparedStmt
2621}
2722
2823func (s * Stmt ) ParamNum () int {
29- return s .params
24+ return s .Params
3025}
3126
3227func (s * Stmt ) ColumnNum () int {
33- return s .columns
28+ return s .Columns
3429}
3530
3631func (s * Stmt ) WarningsNum () int {
37- return s .warnings
38- }
39-
40- // GetParamFields returns the parameter field definitions from the PREPARE response.
41- // Implements server.StmtFieldsProvider for proxy passthrough.
42- // The caller should not modify the returned slice.
43- func (s * Stmt ) GetParamFields () []* mysql.Field {
44- return s .paramFields
45- }
46-
47- // GetColumnFields returns the column field definitions from the PREPARE response.
48- // Implements server.StmtFieldsProvider for proxy passthrough.
49- // The caller should not modify the returned slice.
50- func (s * Stmt ) GetColumnFields () []* mysql.Field {
51- return s .columnFields
32+ return s .Warnings
5233}
5334
5435func (s * Stmt ) Execute (args ... interface {}) (* mysql.Result , error ) {
@@ -68,7 +49,7 @@ func (s *Stmt) ExecuteSelectStreaming(result *mysql.Result, perRowCb SelectPerRo
6849}
6950
7051func (s * Stmt ) Close () error {
71- if err := s .conn .writeCommandUint32 (mysql .COM_STMT_CLOSE , s .id ); err != nil {
52+ if err := s .conn .writeCommandUint32 (mysql .COM_STMT_CLOSE , s .ID ); err != nil {
7253 return errors .Trace (err )
7354 }
7455
@@ -78,10 +59,10 @@ func (s *Stmt) Close() error {
7859// https://dev.mysql.com/doc/dev/mysql-server/latest/page_protocol_com_stmt_execute.html
7960func (s * Stmt ) write (args ... interface {}) error {
8061 defer clear (s .conn .queryAttributes )
81- paramsNum := s .params
62+ paramsNum := s .Params
8263
8364 if len (args ) != paramsNum {
84- return fmt .Errorf ("argument mismatch, need %d but got %d" , s .params , len (args ))
65+ return fmt .Errorf ("argument mismatch, need %d but got %d" , s .Params , len (args ))
8566 }
8667
8768 if (s .conn .capability & mysql .CLIENT_QUERY_ATTRIBUTES > 0 ) && (s .conn .includeLine >= 0 ) {
@@ -205,7 +186,7 @@ func (s *Stmt) write(args ...interface{}) error {
205186
206187 data .Write ([]byte {0 , 0 , 0 , 0 })
207188 data .WriteByte (mysql .COM_STMT_EXECUTE )
208- data .Write ([]byte {byte (s .id ), byte (s .id >> 8 ), byte (s .id >> 16 ), byte (s .id >> 24 )})
189+ data .Write ([]byte {byte (s .ID ), byte (s .ID >> 8 ), byte (s .ID >> 16 ), byte (s .ID >> 24 )})
209190
210191 flags := mysql .CURSOR_TYPE_NO_CURSOR
211192 if paramsNum > 0 {
@@ -272,37 +253,34 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
272253 pos := 1
273254
274255 // for statement id
275- s .id = binary .LittleEndian .Uint32 (data [pos :])
256+ s .ID = binary .LittleEndian .Uint32 (data [pos :])
276257 pos += 4
277258
278259 // number columns
279- s .columns = int (binary .LittleEndian .Uint16 (data [pos :]))
260+ s .Columns = int (binary .LittleEndian .Uint16 (data [pos :]))
280261 pos += 2
281262
282263 // number params
283- s .params = int (binary .LittleEndian .Uint16 (data [pos :]))
264+ s .Params = int (binary .LittleEndian .Uint16 (data [pos :]))
284265 pos += 2
285266
286267 // reserved
287268 pos += 1
288269
289270 if len (data ) >= 12 {
290271 // warnings
291- s .warnings = int (binary .LittleEndian .Uint16 (data [pos :]))
272+ s .Warnings = int (binary .LittleEndian .Uint16 (data [pos :]))
292273 // pos += 2
293274 }
294275
295- if s .params > 0 {
296- s .paramFields = make ([]* mysql. Field , s .params )
297- for i := range s .params {
276+ if s .Params > 0 {
277+ s .RawParamFields = make ([][] byte , s .Params )
278+ for i := range s .Params {
298279 data , err := s .conn .ReadPacket ()
299280 if err != nil {
300281 return nil , errors .Trace (err )
301282 }
302- s .paramFields [i ] = & mysql.Field {}
303- if err := s .paramFields [i ].Parse (data ); err != nil {
304- return nil , errors .Trace (err )
305- }
283+ s .RawParamFields [i ] = data
306284 }
307285 if s .conn .capability & mysql .CLIENT_DEPRECATE_EOF == 0 {
308286 if packet , err := s .conn .ReadPacket (); err != nil {
@@ -313,17 +291,14 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
313291 }
314292 }
315293
316- if s .columns > 0 {
317- s .columnFields = make ([]* mysql. Field , s .columns )
318- for i := range s .columns {
294+ if s .Columns > 0 {
295+ s .RawColumnFields = make ([][] byte , s .Columns )
296+ for i := range s .Columns {
319297 data , err := s .conn .ReadPacket ()
320298 if err != nil {
321299 return nil , errors .Trace (err )
322300 }
323- s .columnFields [i ] = & mysql.Field {}
324- if err := s .columnFields [i ].Parse (data ); err != nil {
325- return nil , errors .Trace (err )
326- }
301+ s .RawColumnFields [i ] = data
327302 }
328303 if s .conn .capability & mysql .CLIENT_DEPRECATE_EOF == 0 {
329304 if packet , err := s .conn .ReadPacket (); err != nil {
0 commit comments