// GoMySQL - A MySQL client library for Go // // Copyright 2010-2011 Phil Bayfield. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package mysql // Imports import ( "bytes" "fmt" "io" "log" "net" "os" "strings" "sync" "time" ) // Constants const ( // General VERSION = "0.3.2" DEFAULT_PORT = "3306" DEFAULT_SOCKET = "/var/run/mysqld/mysqld.sock" MAX_PACKET_SIZE = 1<<24 - 1 PROTOCOL_41 = 41 PROTOCOL_40 = 40 DEFAULT_PROTOCOL = PROTOCOL_41 // Connection types TCP = "tcp" UNIX = "unix" // Log methods LOG_SCREEN = 0x0 LOG_FILE = 0x1 // Result storage methods RESULT_UNUSED = 0x0 RESULT_STORED = 0x1 RESULT_USED = 0x2 RESULT_FREE = 0x3 ) // Client struct type Client struct { // Mutex for thread safety sync.Mutex // Logging LogLevel uint8 LogType uint8 LogFile *os.File // Credentials network string raddr string user string passwd string dbname string // Connection conn io.ReadWriteCloser r *reader w *writer connected bool Reconnect bool // Sequence protocol uint8 sequence uint8 // Server settings serverVersion string serverProtocol uint8 serverFlags ClientFlag serverCharset uint8 serverStatus ServerStatus scrambleBuff []byte // Result AffectedRows uint64 LastInsertId uint64 Warnings uint16 result *Result } // Create new client func NewClient(protocol ...uint8) (c *Client) { if len(protocol) == 0 { protocol = make([]uint8, 1) protocol[0] = DEFAULT_PROTOCOL } c = &Client{ protocol: protocol[0], } return } // Connect to server via TCP func DialTCP(raddr, user, passwd string, dbname ...string) (c *Client, err error) { c = NewClient(DEFAULT_PROTOCOL) // Add port if not set if strings.Index(raddr, ":") == -1 { raddr += ":" + DEFAULT_PORT } // Connect to server err = c.Connect(TCP, raddr, user, passwd, dbname...) return } // Connect to server via unix socket func DialUnix(raddr, user, passwd string, dbname ...string) (c *Client, err error) { c = NewClient(DEFAULT_PROTOCOL) // Use default socket if socket is empty if raddr == "" { raddr = DEFAULT_SOCKET } // Connect to server err = c.Connect(UNIX, raddr, user, passwd, dbname...) return } // Connect to the server func (c *Client) Connect(network, raddr, user, passwd string, dbname ...string) (err error) { // Log connect c.log(1, "=== Begin connect ===") // Check not already connected if c.checkConn() { return &ClientError{CR_ALREADY_CONNECTED, CR_ALREADY_CONNECTED_STR} } // Reset client c.reset() // Store connection credentials c.network = network c.raddr = raddr c.user = user c.passwd = passwd if len(dbname) > 0 { c.dbname = dbname[0] } // Call connect err = c.connect() if err != nil { return } // Set connected c.connected = true return } // Close connection to server func (c *Client) Close() (err error) { // Log close c.log(1, "=== Begin close ===") // Check connection if !c.checkConn() { return &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR} } // Reset client c.reset() // Send close command c.command(COM_QUIT) // Close connection c.conn.Close() // Log disconnect c.log(1, "Disconnected") // Set connected c.connected = false return } // Change the current database func (c *Client) ChangeDb(dbname string) (err error) { // Auto reconnect defer func() { if err != nil && c.checkNet(err) && c.Reconnect { c.log(1, "!!! Lost connection to server !!!") c.connected = false err = c.reconnect() if err == nil { err = c.ChangeDb(dbname) } } }() // Log changeDb c.log(1, "=== Begin change db to '%s' ===", dbname) // Pre-run checks if !c.checkConn() || c.checkResult() { return &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR} } // Reset client c.reset() // Send close command err = c.command(COM_INIT_DB, dbname) if err != nil { return } // Read result from server c.sequence++ _, err = c.getResult(PACKET_OK | PACKET_ERROR) return } // Send a query/queries to the server func (c *Client) Query(sql string) (err error) { // Auto reconnect defer func() { if err != nil && c.checkNet(err) && c.Reconnect { c.log(1, "!!! Lost connection to server !!!") c.connected = false err = c.reconnect() if err == nil { err = c.Query(sql) } } }() // Log query c.log(1, "=== Begin query '%s' ===", sql) // Pre-run checks if !c.checkConn() || c.checkResult() { return &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR} } // Reset client c.reset() // Send query command err = c.command(COM_QUERY, sql) if err != nil { return } // Read result from server c.sequence++ _, err = c.getResult(PACKET_OK | PACKET_ERROR | PACKET_RESULT) if err != nil || c.result == nil { return } // Store fields err = c.getFields() return } // Fetch all rows for a result and store it, returning the result set func (c *Client) StoreResult() (result *Result, err error) { // Auto reconnect defer func() { err = c.simpleReconnect(err) }() // Log store result c.log(1, "=== Begin store result ===") // Check result if !c.checkResult() { return nil, &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR} } // Check if result already used/stored if c.result.mode != RESULT_UNUSED { return nil, &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR} } // Set storage mode c.result.mode = RESULT_STORED // Store all rows err = c.getAllRows() if err != nil { return } c.result.allRead = true return c.result, nil } // Use a result set, does not store rows func (c *Client) UseResult() (result *Result, err error) { // Auto reconnect defer func() { err = c.simpleReconnect(err) }() // Log use result c.log(1, "=== Begin use result ===") // Check result if !c.checkResult() { return nil, &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR} } // Check if result already used/stored if c.result.mode != RESULT_UNUSED { return nil, &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR} } // Set storage mode c.result.mode = RESULT_USED return c.result, nil } // Free the current result func (c *Client) FreeResult() (err error) { // Auto reconnect defer func() { err = c.simpleReconnect(err) }() // Log free result c.log(1, "=== Begin free result ===") // Check result if !c.checkResult() { return &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR} } // Check for unread rows if !c.result.allRead { // Read all rows err = c.getAllRows() if err != nil { return } } // Reset some of the properties to ensure any pointers are "destroyed" c.result.c = nil c.result.fieldCount = 0 c.result.fieldPos = 0 c.result.fields = nil c.result.rowPos = 0 c.result.rows = nil c.result.mode = RESULT_UNUSED c.result.allRead = false // Unset the result c.result = nil return } // Check if more results are available func (c *Client) MoreResults() bool { return c.serverStatus&SERVER_MORE_RESULTS_EXISTS > 0 } // Move to the next available result func (c *Client) NextResult() (more bool, err error) { // Auto reconnect defer func() { err = c.simpleReconnect(err) }() // Log next result c.log(1, "=== Begin next result ===") // Pre-run checks if !c.checkConn() || c.checkResult() { return false, &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR} } // Check for more results more = c.MoreResults() if !more { return } // Read result from server c.sequence++ _, err = c.getResult(PACKET_OK | PACKET_ERROR | PACKET_RESULT) // Store fields err = c.getFields() return } // Set autocommit func (c *Client) SetAutoCommit(state bool) (err error) { // Log set autocommit c.log(1, "=== Begin set autocommit ===") // Use set autocommit query sql := "set autocommit=" if state { sql += "1" } else { sql += "0" } return c.Query(sql) } // Start a transaction func (c *Client) Start() (err error) { // Log start transaction c.log(1, "=== Begin start transaction ===") // Use start transaction query return c.Query("start transaction") } // Commit a transaction func (c *Client) Commit() (err error) { // Log commit c.log(1, "=== Begin commit ===") // Use commit query return c.Query("commit") } // Rollback a transaction func (c *Client) Rollback() (err error) { // Log rollback c.log(1, "=== Begin rollback ===") // Use rollback query return c.Query("rollback") } // Escape a string func (c *Client) Escape(s string) (esc string) { var prev byte var b bytes.Buffer for i := 0; i < len(s); i++ { switch s[i] { case '\'', '"': if prev != '\\' { b.WriteString(s[:i]) b.WriteByte('\\') s = s[i:] i = 0 } } prev = s[i] } b.WriteString(s) return b.String() } // Initialise a new statment func (c *Client) InitStmt() (stmt *Statement, err error) { // Check connection if !c.checkConn() { return nil, &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR} } // Create new statement stmt = new(Statement) stmt.c = c return } // Initialise and prepare a new statement func (c *Client) Prepare(sql string) (stmt *Statement, err error) { // Initialise a new statement stmt, err = c.InitStmt() if err != nil { return } // Prepare statement err = stmt.Prepare(sql) return } // Reset the client func (c *Client) reset() { c.sequence = 0 c.serverStatus = 0 c.AffectedRows = 0 c.LastInsertId = 0 c.Warnings = 0 c.result = nil } // Format errors func (c *Client) fmtError(str Error, args ...interface{}) Error { return Error(fmt.Sprintf(string(str), args...)) } // Logging func (c *Client) log(level uint8, format string, args ...interface{}) { // If logging is disabled, ignore if level > c.LogLevel { return } // Log based on logging type switch c.LogType { // Log to screen case LOG_SCREEN: log.Printf(format, args...) // Log to file case LOG_FILE: // If file pointer is nil return if c.LogFile == nil { return } // This is the same as log package does internally for logging // to the screen (via stderr) just requires an io.Writer l := log.New(c.LogFile, "", log.Ldate|log.Ltime) l.Printf(format, args...) // Not set default: return } } // Provide detailed log output for server capabilities func (c *Client) logCaps() { c.log(3, "=== Server Capabilities ===") c.log(3, "Long password support: %d", c.serverFlags&CLIENT_LONG_PASSWORD) c.log(3, "Found rows: %d", c.serverFlags&CLIENT_FOUND_ROWS>>1) c.log(3, "All column flags: %d", c.serverFlags&CLIENT_LONG_FLAG>>2) c.log(3, "Connect with database support: %d", c.serverFlags&CLIENT_CONNECT_WITH_DB>>3) c.log(3, "No schema support: %d", c.serverFlags&CLIENT_NO_SCHEMA>>4) c.log(3, "Compression support: %d", c.serverFlags&CLIENT_COMPRESS>>5) c.log(3, "ODBC support: %d", c.serverFlags&CLIENT_ODBC>>6) c.log(3, "Load data local support: %d", c.serverFlags&CLIENT_LOCAL_FILES>>7) c.log(3, "Ignore spaces: %d", c.serverFlags&CLIENT_IGNORE_SPACE>>8) c.log(3, "4.1 protocol support: %d", c.serverFlags&CLIENT_PROTOCOL_41>>9) c.log(3, "Interactive client: %d", c.serverFlags&CLIENT_INTERACTIVE>>10) c.log(3, "Switch to SSL: %d", c.serverFlags&CLIENT_SSL>>11) c.log(3, "Ignore sigpipes: %d", c.serverFlags&CLIENT_IGNORE_SIGPIPE>>12) c.log(3, "Transaction support: %d", c.serverFlags&CLIENT_TRANSACTIONS>>13) c.log(3, "4.1 protocol authentication: %d", c.serverFlags&CLIENT_SECURE_CONN>>15) } // Provide detailed log output for the server status flags func (c *Client) logStatus() { c.log(3, "=== Server Status ===") c.log(3, "In transaction: %d", c.serverStatus&SERVER_STATUS_IN_TRANS) c.log(3, "Auto commit enabled: %d", c.serverStatus&SERVER_STATUS_AUTOCOMMIT>>1) c.log(3, "More results exist: %d", c.serverStatus&SERVER_MORE_RESULTS_EXISTS>>3) c.log(3, "No good indexes were used: %d", c.serverStatus&SERVER_QUERY_NO_GOOD_INDEX_USED>>4) c.log(3, "No indexes were used: %d", c.serverStatus&SERVER_QUERY_NO_INDEX_USED>>5) c.log(3, "Cursor exists: %d", c.serverStatus&SERVER_STATUS_CURSOR_EXISTS>>6) c.log(3, "Last row has been sent: %d", c.serverStatus&SERVER_STATUS_LAST_ROW_SENT>>7) c.log(3, "Database dropped: %d", c.serverStatus&SERVER_STATUS_DB_DROPPED>>8) c.log(3, "No backslash escapes: %d", c.serverStatus&SERVER_STATUS_NO_BACKSLASH_ESCAPES>>9) c.log(3, "Metadata has changed: %d", c.serverStatus&SERVER_STATUS_METADATA_CHANGED>>10) } // Check if connected // @todo expand to perform an actual connection check func (c *Client) checkConn() bool { if c.connected { return true } return false } // Check if a result exists func (c *Client) checkResult() bool { if c.result != nil { return true } return false } // Check if a network error occurred func (c *Client) checkNet(err error) bool { if cErr, ok := err.(*ClientError); ok { if cErr.Errno == CR_SERVER_GONE_ERROR || cErr.Errno == CR_SERVER_LOST { return true } } return false } // Performs the actual connect func (c *Client) connect() (err error) { // Connect to server err = c.dial() if err != nil { return } // Read initial packet from server err = c.init() if err != nil { return } // Send auth packet to server c.sequence++ err = c.auth() if err != nil { return } // Read result from server c.sequence++ eof, err := c.getResult(PACKET_OK | PACKET_ERROR | PACKET_EOF) // If eof need to authenticate with a 3.23 password if eof { c.sequence++ // Create packet p := &packetPassword{ scrambleBuff: scramble323(c.scrambleBuff, []byte(c.passwd)), } p.sequence = c.sequence // Write packet err = c.w.writePacket(p) if err != nil { return } c.log(1, "[%d] Sent old password packet", p.sequence) // Read result c.sequence++ _, err = c.getResult(PACKET_OK | PACKET_ERROR) } return } // Connect to server func (c *Client) dial() (err error) { // Log connect c.log(1, "Connecting to server via %s to %s", c.network, c.raddr) // Connect to server c.conn, err = net.Dial(c.network, c.raddr) if err != nil { // Store error state if c.network == UNIX { err = &ClientError{CR_CONNECTION_ERROR, c.fmtError(CR_CONNECTION_ERROR_STR, c.raddr)} } if c.network == TCP { err = &ClientError{CR_CONN_HOST_ERROR, c.fmtError(CR_CONN_HOST_ERROR_STR, c.raddr)} } // Log error if cErr, ok := err.(*ClientError); ok { c.log(1, string(cErr.Err)) } return } // Log connect success c.log(1, "Connected to server") // Create reader and writer c.r = newReader(c.conn) c.w = newWriter(c.conn) // Set the reader default protocol c.r.protocol = c.protocol return } // Read initial packet from server func (c *Client) init() (err error) { // Log read packet c.log(1, "Reading handshake initialization packet from server") // Read packet p, err := c.r.readPacket(PACKET_INIT) if err != nil { return } err = c.checkSequence(p.(*packetInit).sequence) if err != nil { return } // Log success c.log(1, "[%d] Received handshake initialization packet", p.(*packetInit).sequence) // Assign values c.serverVersion = p.(*packetInit).serverVersion c.serverProtocol = p.(*packetInit).protocolVersion c.serverFlags = ClientFlag(p.(*packetInit).serverCaps) c.serverCharset = p.(*packetInit).serverLanguage c.serverStatus = ServerStatus(p.(*packetInit).serverStatus) c.scrambleBuff = p.(*packetInit).scrambleBuff // Extended logging [level 2+] if c.LogLevel > 1 { // Log server info c.log(2, "Server version: %s", c.serverVersion) c.log(2, "Protocol version: %d", c.serverProtocol) } // Full logging [level 3] if c.LogLevel > 2 { c.logCaps() c.logStatus() } // If we're using 4.1 protocol and server doesn't support, drop to 4.0 if c.protocol == PROTOCOL_41 && c.serverFlags&CLIENT_PROTOCOL_41 == 0 { c.protocol = PROTOCOL_40 c.r.protocol = PROTOCOL_40 } return } // Send auth packet to the server func (c *Client) auth() (err error) { // Log write packet c.log(1, "Sending authentication packet to server") // Construct packet p := &packetAuth{ clientFlags: uint32(CLIENT_MULTI_STATEMENTS | CLIENT_MULTI_RESULTS), maxPacketSize: MAX_PACKET_SIZE, charsetNumber: c.serverCharset, user: c.user, } // Add protocol and sequence p.protocol = c.protocol p.sequence = c.sequence // Adjust client flags based on server support if c.serverFlags&CLIENT_LONG_PASSWORD > 0 { p.clientFlags |= uint32(CLIENT_LONG_PASSWORD) } if c.serverFlags&CLIENT_LONG_FLAG > 0 { p.clientFlags |= uint32(CLIENT_LONG_FLAG) } if c.serverFlags&CLIENT_TRANSACTIONS > 0 { p.clientFlags |= uint32(CLIENT_TRANSACTIONS) } // Check protocol if c.protocol == PROTOCOL_41 { p.clientFlags |= uint32(CLIENT_PROTOCOL_41 | CLIENT_SECURE_CONN) p.scrambleBuff = scramble41(c.scrambleBuff, []byte(c.passwd)) // To specify a db name if c.serverFlags&CLIENT_CONNECT_WITH_DB > 0 && len(c.dbname) > 0 { p.clientFlags |= uint32(CLIENT_CONNECT_WITH_DB) p.database = c.dbname } } else { p.scrambleBuff = scramble323(c.scrambleBuff, []byte(c.passwd)) } // Write packet err = c.w.writePacket(p) if err != nil { return } // Log write success c.log(1, "[%d] Sent authentication packet", p.sequence) return } // Simple non-recovered reconnect func (c *Client) simpleReconnect(err error) error { if err != nil && c.checkNet(err) && c.Reconnect { c.log(1, "!!! Lost connection to server !!!") c.connected = false rcErr := c.reconnect() if rcErr != nil { return rcErr } } return err } // Perform reconnect if a network error occurs func (c *Client) reconnect() (err error) { // Log auto reconnect c.log(1, "=== Begin auto reconnect attempt ===") // Reset the client c.reset() // Attempt to reconnect for i := 0; i < 10; i++ { err = c.connect() if err == nil { c.connected = true break } time.Sleep(2000000000) } return } // Send a command to the server func (c *Client) command(command command, args ...interface{}) (err error) { // Log write packet c.log(1, "Sending command packet to server") // Simple validation, arg count switch command { // No args case COM_QUIT, COM_STATISTICS, COM_PROCESS_INFO, COM_DEBUG, COM_PING: if len(args) != 0 { return &ClientError{CR_UNKNOWN_ERROR, CR_UNKNOWN_ERROR_STR} } // 1 arg case COM_INIT_DB, COM_QUERY, COM_REFRESH, COM_SHUTDOWN, COM_PROCESS_KILL, COM_STMT_PREPARE, COM_STMT_CLOSE, COM_STMT_RESET: if len(args) != 1 { return &ClientError{CR_UNKNOWN_ERROR, CR_UNKNOWN_ERROR_STR} } // 1 or 2 args case COM_FIELD_LIST: if len(args) != 1 && len(args) != 2 { return &ClientError{CR_UNKNOWN_ERROR, CR_UNKNOWN_ERROR_STR} } // 2 args case COM_STMT_FETCH: if len(args) != 2 { return &ClientError{CR_UNKNOWN_ERROR, CR_UNKNOWN_ERROR_STR} } // 4 args case COM_CHANGE_USER: if len(args) != 4 { return &ClientError{CR_UNKNOWN_ERROR, CR_UNKNOWN_ERROR_STR} } // Everything else e.g. replication unsupported default: return &ClientError{CR_NOT_IMPLEMENTED, CR_NOT_IMPLEMENTED_STR} } // Construct packet p := &packetCommand{ command: command, args: args, } // Add protocol and sequence p.protocol = c.protocol p.sequence = c.sequence // Write packet err = c.w.writePacket(p) if err != nil { return &ClientError{CR_SERVER_LOST, CR_SERVER_LOST_STR} } // Log write success c.log(1, "[%d] Sent command packet", p.sequence) return } // Get field packets for a result func (c *Client) getFields() (err error) { // Check for a valid result if c.result == nil { return &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR} } // Read fields till EOF is returned for { c.sequence++ eof, err := c.getResult(PACKET_FIELD | PACKET_EOF) if err != nil { return } if eof { break } } return } // Get next row for a result func (c *Client) getRow() (eof bool, err error) { // Check for a valid result if c.result == nil { return false, &ClientError{CR_NO_RESULT_SET, CR_NO_RESULT_SET_STR} } // Read next row packet or EOF c.sequence++ eof, err = c.getResult(PACKET_ROW | PACKET_EOF) return } // Get all rows for the result func (c *Client) getAllRows() (err error) { for { eof, err := c.getRow() if err != nil { return } if eof { break } } return } // Get result func (c *Client) getResult(types packetType) (eof bool, err error) { // Log read result c.log(1, "Reading result packet from server") // Get result packet p, err := c.r.readPacket(types) if err != nil { return } // Process result packet switch p.(type) { default: err = &ClientError{CR_UNKNOWN_ERROR, CR_UNKNOWN_ERROR_STR} case *packetOK: err = handleOK(p.(*packetOK), c, &c.AffectedRows, &c.LastInsertId, &c.Warnings) case *packetError: err = handleError(p.(*packetError), c) case *packetEOF: eof = true err = handleEOF(p.(*packetEOF), c) case *packetResultSet: c.result = &Result{c: c} err = handleResultSet(p.(*packetResultSet), c, c.result) case *packetField: err = handleField(p.(*packetField), c, c.result) case *packetRowData: err = handleRow(p.(*packetRowData), c, c.result) } return } // Sequence check func (c *Client) checkSequence(sequence uint8) (err error) { if sequence != c.sequence { c.log(1, "Sequence doesn't match - expected %d but got %d, commands out of sync", c.sequence, sequence) return &ClientError{CR_COMMANDS_OUT_OF_SYNC, CR_COMMANDS_OUT_OF_SYNC_STR} } return }