# -*- coding: binary -*-
|
|
# Copyright (C) 2008-2009 TOMITA Masahiro
|
|
# mailto:tommy@tmtm.org
|
|
|
|
require "socket"
|
|
require "timeout"
|
|
require "digest/sha1"
|
|
require "thread"
|
|
require "stringio"
|
|
|
|
class RbMysql
|
|
# MySQL network protocol
|
|
class Protocol
|
|
|
|
# Protocol version this client speaks; InitialPacket.parse rejects any other value.
VERSION = 10
# Largest payload of a single packet (2**24-1 bytes). #read keeps reading
# follow-up packets while a packet has exactly this length.
MAX_PACKET_LENGTH = 2**24-1
|
|
|
|
# Convert Numeric to LengthCodedBinary
# === Argument
# num :: [Integer or nil] value to encode. nil encodes the single byte 0xfb.
# === Return
# String :: binary string of 1, 3, 4 or 9 bytes
def self.lcb(num)
  # Built with pack so the marker is a binary string whatever the source encoding is.
  return [0xfb].pack("C") if num.nil?
  return [num].pack("C") if num < 251
  return [252, num].pack("Cv") if num < 65536
  return [253, num & 0xffff, num >> 16].pack("CvC") if num < 16777216
  # 64-bit value written as two little-endian 32-bit halves, low half first.
  [254, num & 0xffffffff, num >> 32].pack("CVV")
end
|
|
|
|
# Convert String to LengthCodedString
# === Argument
# str :: [String] value to encode. It is passed through Charset.to_binary first.
# === Return
# String :: length prefix (see lcb) followed by the converted string
def self.lcs(str)
  bin = Charset.to_binary str
  # The prefix counts bytes on the wire, so use bytesize rather than the character count.
  lcb(bin.bytesize) + bin
end
|
|
|
|
# Convert LengthCodedBinary to Integer
# === Argument
# lcb :: [String] LengthCodedBinary. The decoded bytes are removed from it.
# === Return
# Integer or nil :: nil when lcb is empty or starts with the 0xfb marker
def self.lcb2int!(lcb)
  return nil if lcb.empty?
  # Read the marker as a byte value so the comparison does not depend on
  # the encoding of either the data or this source file.
  marker = lcb.getbyte(0)
  lcb.slice!(0)
  case marker
  when 0xfb
    nil
  when 0xfc
    lcb.slice!(0, 2).unpack("v").first
  when 0xfd
    # 3-byte little-endian value: low byte, then the upper 16 bits.
    low, high = lcb.slice!(0, 3).unpack("Cv")
    (high << 8) + low
  when 0xfe
    # 8-byte little-endian value read as two 32-bit halves.
    low, high = lcb.slice!(0, 8).unpack("VV")
    (high << 32) + low
  else
    marker
  end
end
|
|
|
|
# Convert LengthCodedString to String
# === Argument
# lcs :: [String] LengthCodedString. The decoded bytes are removed from it.
# === Return
# String or nil :: nil when lcb2int! yields no length
def self.lcs2str!(lcs)
  len = lcb2int! lcs
  len && lcs.slice!(0, len)
end
|
|
|
|
# Check whether packet data is an EOF packet: first byte 0xfe and five bytes in total.
# === Argument
# data :: [String] packet data
# === Return
# true or false
def self.eof_packet?(data)
  # Byte-level checks, so the result does not depend on the string's encoding.
  data.getbyte(0) == 0xfe && data.bytesize == 5
end
|
|
|
|
# Convert netdata to Ruby value
# === Argument
# data :: [String] packet data. The bytes of the decoded value are removed from it.
# type :: [Integer] field type
# unsigned :: [true or false] true if value is unsigned
# === Return
# Object :: converted value.
# === Exception
# RuntimeError :: type is not handled here
def self.net2value(data, type, unsigned)
  case type
  when Field::TYPE_STRING, Field::TYPE_VAR_STRING, Field::TYPE_NEWDECIMAL, Field::TYPE_BLOB
    return Protocol.lcs2str!(data)
  when Field::TYPE_TINY
    v = data.slice!(0).ord
    return (unsigned || v < 128) ? v : v - 256
  when Field::TYPE_SHORT
    v = data.slice!(0, 2).unpack("v").first
    return (unsigned || v < 32768) ? v : v - 65536
  when Field::TYPE_INT24, Field::TYPE_LONG
    v = data.slice!(0, 4).unpack("V").first
    return (unsigned || v < 2**32 / 2) ? v : v - 2**32
  when Field::TYPE_LONGLONG
    n1, n2 = data.slice!(0, 8).unpack("VV")
    v = (n2 << 32) | n1
    return (unsigned || v < 2**64 / 2) ? v : v - 2**64
  when Field::TYPE_FLOAT
    return data.slice!(0, 4).unpack("e").first
  when Field::TYPE_DOUBLE
    return data.slice!(0, 8).unpack("E").first
  when Field::TYPE_DATE, Field::TYPE_DATETIME, Field::TYPE_TIMESTAMP
    len = data.slice!(0).ord
    y, m, d, h, mi, s, bs = data.slice!(0, len).unpack("vCCCCCV")
    # The 7th argument of RbMysql::Time.new is the sign flag and the 8th the
    # fractional part (see the TYPE_TIME call below). These values are never
    # negative, so the fraction must not land in the sign position.
    return RbMysql::Time.new(y, m, d, h, mi, s, false) if bs.nil?
    return RbMysql::Time.new(y, m, d, h, mi, s, false, bs)
  when Field::TYPE_TIME
    len = data.slice!(0).ord
    sign, d, h, mi, s, sp = data.slice!(0, len).unpack("CVCCCV")
    h = d.to_i * 24 + h.to_i
    # A zero-length value leaves sign nil; to_i keeps that from reading as negative.
    return RbMysql::Time.new(0, 0, 0, h, mi, s, sign.to_i != 0, sp)
  when Field::TYPE_YEAR
    return data.slice!(0, 2).unpack("v").first
  when Field::TYPE_BIT
    return Protocol.lcs2str!(data)
  else
    raise "not implemented: type=#{type}"
  end
end
|
|
|
|
# Convert Ruby value to netdata
# === Argument
# v :: [Object] Ruby value: nil, Integer, Float, String, RbMysql::Time or ::Time.
# === Return
# Array :: two elements: the field type (Integer) and the netdata (String)
# === Exception
# ProtocolError :: value too large / value is not supported
def self.value2net(v)
  case v
  when nil
    type = Field::TYPE_NULL
    val = ""
  when Integer
    if v >= 0
      # Non-negative values use the narrowest type that fits, flagged with 0x8000.
      if v < 256
        type = Field::TYPE_TINY | 0x8000
        val = [v].pack("C")
      elsif v < 256**2
        type = Field::TYPE_SHORT | 0x8000
        val = [v].pack("v")
      elsif v < 256**4
        type = Field::TYPE_LONG | 0x8000
        val = [v].pack("V")
      elsif v < 256**8
        type = Field::TYPE_LONGLONG | 0x8000
        val = [v & 0xffffffff, v >> 32].pack("VV")
      else
        raise ProtocolError, "value too large: #{v}"
      end
    else
      # Negative values use the narrowest type whose signed range holds them.
      magnitude = -v
      if magnitude <= 256 / 2
        type = Field::TYPE_TINY
        val = [v].pack("C")
      elsif magnitude <= 256**2 / 2
        type = Field::TYPE_SHORT
        val = [v].pack("v")
      elsif magnitude <= 256**4 / 2
        type = Field::TYPE_LONG
        val = [v].pack("V")
      elsif magnitude <= 256**8 / 2
        type = Field::TYPE_LONGLONG
        val = [v & 0xffffffff, v >> 32].pack("VV")
      else
        raise ProtocolError, "value too large: #{v}"
      end
    end
  when Float
    type = Field::TYPE_DOUBLE
    val = [v].pack("E")
  when String
    type = Field::TYPE_STRING
    val = Protocol.lcs(v)
  when RbMysql::Time, ::Time
    type = Field::TYPE_DATETIME
    # Length byte 7, then year (2 bytes) and month/day/hour/minute/second (1 byte each).
    val = [7, v.year, v.month, v.day, v.hour, v.min, v.sec].pack("CvCCCCC")
  else
    raise ProtocolError, "class #{v.class} is not supported"
  end
  return type, val
end
|
|
|
|
# SQLSTATE recorded by the most recent #read: "00000" for a normal packet,
# the five-character code of an error packet, or "" for an error packet
# that carries no "#" marker.
attr_reader :sqlstate
|
|
|
|
# Make socket connection to server.
# === Argument
# host :: [String] if nil, "" or "localhost" then a UNIXSocket is opened. Otherwise TCP is used.
# port :: [Integer] port number used by TCPSocket. Falls back to ENV["MYSQL_TCP_PORT"],
#         the "mysql" service entry, then MYSQL_TCP_PORT.
# socket :: [Object] for a remote host, a truthy value is stored as the connection itself
#           and no TCPSocket is opened. It is not consulted for the UNIXSocket case.
# conn_timeout :: [Integer] connect timeout (sec).
# read_timeout :: [Integer] read timeout (sec).
# write_timeout :: [Integer] write timeout (sec).
# === Exception
# ClientError :: connection timeout
def initialize(host, port, socket, conn_timeout, read_timeout, write_timeout)
  begin
    Timeout.timeout conn_timeout do
      if host.nil? || host.empty? || host == "localhost"
        # The path comes from the environment or the default, overwriting the argument.
        # NOTE(review): looks deliberate since socket may be a connection object - confirm.
        socket = ENV["MYSQL_UNIX_PORT"] || MYSQL_UNIX_PORT
        @sock = UNIXSocket.new socket
      elsif socket
        @sock = socket
      else
        port ||= ENV["MYSQL_TCP_PORT"] || (Socket.getservbyname("mysql","tcp") rescue MYSQL_TCP_PORT)
        @sock = TCPSocket.new host, port
      end
    end
  rescue Timeout::Error
    raise ClientError, "connection timeout"
  end
  @read_timeout = read_timeout
  @write_timeout = write_timeout
  @seq = 0    # packet counter. reset by each command
  @mutex = Mutex.new
end
|
|
|
|
# Close the underlying socket.
def close
  @sock.close
end
|
|
|
|
# Run the given block while holding this connection's mutex.
# === Return
# Object :: the block's value
def synchronize
  @mutex.synchronize { yield }
end
|
|
|
|
# Reset sequence number
def reset
  @seq = 0
end
|
|
|
|
# Read one packet data
#
# Packets of exactly MAX_PACKET_LENGTH bytes are continued by the next packet,
# so reading repeats until a shorter packet arrives and the payloads are joined.
# === Return
# String
# === Exception
# EOFError :: the socket ended before a full header or payload arrived
# ClientError :: read timeout
# RbMysql::ServerError :: the server sent an error packet
def read
  ret = String.new
  len = nil
  loop do
    begin
      Timeout.timeout @read_timeout do
        header = @sock.read(4)
        raise EOFError unless header && header.length == 4
        len1, len2 = header.unpack("Cv")
        len = (len2 << 8) + len1
        # Ignore the sequence number -- protocol differences between 4.x and 5.x
        @seq = (@seq + 1) % 256
        body = @sock.read(len)
        # A nil or short payload means the connection ended mid-packet.
        raise EOFError unless body && body.length == len
        ret.concat body
      end
    rescue Timeout::Error
      raise ClientError, "read timeout"
    end
    break unless len == MAX_PACKET_LENGTH
  end

  @sqlstate = "00000"

  # Error packet (first byte 0xff, compared as a byte value)
  if ret.getbyte(0) == 0xff
    _f, errno, marker, @sqlstate, message = ret.unpack("Cvaa5a*")
    unless marker == "#"
      _f, errno, message = ret.unpack("Cva*")    # Version 4.0 Error
      @sqlstate = ""
    end
    if RbMysql::ServerError::ERROR_MAP.key? errno
      raise RbMysql::ServerError::ERROR_MAP[errno].new(message, @sqlstate)
    end
    raise RbMysql::ServerError.new(message, @sqlstate)
  end
  ret
end
|
|
|
|
# Write one packet data
#
# Data longer than MAX_PACKET_LENGTH is split into several packets. When the
# last chunk is exactly MAX_PACKET_LENGTH bytes an empty packet follows,
# because #read treats a full-length packet as "more to come".
# === Argument
# data [String / IO] ::
# === Exception
# ClientError :: write timeout
def write(data)
  begin
    @sock.sync = false
    data = StringIO.new data if data.is_a? String
    last_length = nil
    while chunk = data.read(MAX_PACKET_LENGTH)
      Timeout.timeout @write_timeout do
        # Header: 3-byte little-endian length, then the sequence number.
        @sock.write [chunk.length % 256, chunk.length / 256, @seq].pack("CvC")
        @sock.write chunk
      end
      @seq = (@seq + 1) % 256
      last_length = chunk.length
    end
    if last_length == MAX_PACKET_LENGTH
      # Terminate the full-length run with an empty packet.
      Timeout.timeout @write_timeout do
        @sock.write [0, 0, @seq].pack("CvC")
      end
      @seq = (@seq + 1) % 256
    end
    @sock.sync = true
    Timeout.timeout @write_timeout do
      @sock.flush
    end
  rescue Timeout::Error
    raise ClientError, "write timeout"
  end
end
|
|
|
|
# Send one packet
# === Argument
# packet :: [*Packet] object responding to serialize
def send_packet(packet)
  write packet.serialize
end
|
|
|
|
# Read EOF packet
#
# The packet is read but not validated.
# === Return
# String :: the packet data
def read_eof_packet
  # EOF packet is different between MySQL 4.x and 5.x, so ignore.
  # raise ProtocolError, "packet is not EOF" unless Protocol.eof_packet? data
  read
end
|
|
|
|
# Read initial packet
# === Return
# InitialPacket ::
# === Exception
# ProtocolError :: invalid packet
def read_initial_packet
  InitialPacket.parse read
end
|
|
|
|
# Read result packet
# === Return
# ResultPacket ::
def read_result_packet
  ResultPacket.parse read
end
|
|
|
|
# Read field packet
# === Return
# FieldPacket :: packet data
# === Exception
# ProtocolError :: invalid packet
def read_field_packet
  FieldPacket.parse read
end
|
|
|
|
# Read prepare result packet
# === Return
# PrepareResultPacket ::
# === Exception
# ProtocolError :: invalid packet
def read_prepare_result_packet
  PrepareResultPacket.parse read
end
|
|
|
|
# client->server packet base class
class TxPacket
end
|
|
|
|
# server->client packet base class
class RxPacket
end
|
|
|
|
# Initial packet
class InitialPacket < RxPacket
  # Parse packet data into an InitialPacket.
  # === Argument
  # data :: [String] packet data
  # === Return
  # InitialPacket ::
  # === Exception
  # ProtocolError :: protocol version is not VERSION, or the byte after the first scramble part is not 0
  def self.parse(data)
    protocol_version, server_version, thread_id, scramble_buff, f0,
      server_capabilities, server_charset, server_status, _f1,
      rest_scramble_buff = data.unpack("CZ*Va8CvCva13Z13")
    raise ProtocolError, "unsupported version: #{protocol_version}" unless protocol_version == VERSION
    raise ProtocolError, "invalid packet: f0=#{f0}" unless f0 == 0
    # Remove the f1 check to backport https://github.com/tmtm/ruby-mysql/commit/07ddfafafbd1d46bbb71c7cb54ae0f03bc998d27
    # raise ProtocolError, "invalid packet: f1=#{f1.inspect}" unless f1 == "\0\0\0\0\0\0\0\0\0\0\0\0\0"
    # The scramble arrives in two pieces; join them into one buffer.
    scramble_buff.concat rest_scramble_buff
    self.new protocol_version, server_version, thread_id, server_capabilities, server_charset, server_status, scramble_buff
  end

  attr_accessor :protocol_version, :server_version, :thread_id, :server_capabilities, :server_charset, :server_status, :scramble_buff

  # Arguments are assigned positionally, in the order of the accessors above.
  def initialize(*args)
    @protocol_version, @server_version, @thread_id, @server_capabilities, @server_charset, @server_status, @scramble_buff = args
  end

  # Scramble a plain password with this packet's scramble buffer.
  # === Argument
  # plain :: [String or nil] plain password
  # === Return
  # String :: "" when plain is nil or empty, otherwise
  #           SHA1(plain) XOR SHA1(scramble_buff + SHA1(SHA1(plain)))
  def crypt_password(plain)
    return "" if plain.nil? || plain.empty?
    hash_stage1 = Digest::SHA1.digest plain
    hash_stage2 = Digest::SHA1.digest hash_stage1
    mask = Digest::SHA1.digest(@scramble_buff + hash_stage2)
    hash_stage1.unpack("C*").zip(mask.unpack("C*")).map { |a, b| a ^ b }.pack("C*")
  end
end
|
|
|
|
# Authentication packet
class AuthenticationPacket < TxPacket
  attr_accessor :client_flags, :max_packet_size, :charset_number, :username, :scrambled_password, :databasename

  # Arguments are assigned positionally, in the order of the accessors above.
  def initialize(*args)
    @client_flags, @max_packet_size, @charset_number, @username, @scrambled_password, @databasename = args
  end

  # Build the packet payload.
  # === Return
  # String :: client flags and max packet size (32-bit little-endian each), charset number,
  #           23 NUL bytes, NUL-terminated username, length-coded scrambled password,
  #           NUL-terminated database name
  def serialize
    # NOTE(review): charset_number goes through Protocol.lcb, so values of 251
    # or more take more than one byte - confirm that is intended.
    [
      client_flags,
      max_packet_size,
      Protocol.lcb(charset_number),
      "",                   # always 0x00 * 23
      username,
      Protocol.lcs(scrambled_password),
      databasename
    ].pack("VVa*a23Z*A*Z*")
  end
end
|
|
|
|
# Quit packet
class QuitPacket < TxPacket
  # === Return
  # String :: the single COM_QUIT command byte
  def serialize
    [COM_QUIT].pack("C")
  end
end
|
|
|
|
# Query packet
class QueryPacket < TxPacket
  attr_accessor :query

  # The first argument is stored as the query text.
  def initialize(*args)
    @query, = args
  end

  # === Return
  # String :: the COM_QUERY command byte followed by the query text
  def serialize
    [COM_QUERY, query].pack("Ca*")
  end
end
|
|
|
|
# Result packet
class ResultPacket < RxPacket
  # Parse packet data into a ResultPacket.
  # === Argument
  # data :: [String] packet data. The parsed leading fields are removed from it.
  # === Return
  # ResultPacket :: when the field count is 0 all attributes are filled in,
  #                 otherwise only field_count is set.
  def self.parse(data)
    field_count = Protocol.lcb2int! data
    return self.new(field_count) unless field_count == 0

    affected_rows = Protocol.lcb2int! data
    insert_id = Protocol.lcb2int!(data)
    server_status, warning_count, message = data.unpack("vva*")
    self.new(field_count, affected_rows, insert_id, server_status, warning_count, message)
  end

  attr_accessor :field_count, :affected_rows, :insert_id, :server_status, :warning_count, :message

  # Arguments are assigned positionally, in the order of the accessors above.
  def initialize(*args)
    @field_count, @affected_rows, @insert_id, @server_status, @warning_count, @message = args
  end
end
|
|
|
|
# Field packet
class FieldPacket < RxPacket
  # Parse packet data into a FieldPacket.
  # === Argument
  # data :: [String] packet data. The parsed leading strings are removed from it.
  # === Return
  # FieldPacket ::
  # === Exception
  # ProtocolError :: the 16-bit filler after decimals is not 0
  def self.parse(data)
    _first = Protocol.lcs2str! data    # leading length-coded string is read and discarded
    db = Protocol.lcs2str! data
    table = Protocol.lcs2str! data
    org_table = Protocol.lcs2str! data
    name = Protocol.lcs2str! data
    org_name = Protocol.lcs2str! data
    # Fixed-size part; whatever follows it is kept in data for the default value.
    _f0, charsetnr, length, type, flags, decimals, f1, data = data.unpack("CvVCvCva*")
    raise ProtocolError, "invalid packet: f1=#{f1}" unless f1 == 0
    default = Protocol.lcs2str! data
    self.new(db, table, org_table, name, org_name, charsetnr, length, type, flags, decimals, default)
  end

  attr_accessor :db, :table, :org_table, :name, :org_name, :charsetnr, :length, :type, :flags, :decimals, :default

  # Arguments are assigned positionally, in the order of the accessors above.
  def initialize(*args)
    @db, @table, @org_table, @name, @org_name, @charsetnr, @length, @type, @flags, @decimals, @default = args
  end
end
|
|
|
|
# Prepare packet
class PreparePacket < TxPacket
  attr_accessor :query

  # The first argument is stored as the statement text.
  def initialize(*args)
    @query, = args
  end

  # === Return
  # String :: the COM_STMT_PREPARE command byte followed by the statement text
  def serialize
    [COM_STMT_PREPARE, query].pack("Ca*")
  end
end
|
|
|
|
# Prepare result packet
class PrepareResultPacket < RxPacket
  # Parse packet data into a PrepareResultPacket.
  # === Argument
  # data :: [String] packet data. Its first byte is removed.
  # === Return
  # PrepareResultPacket ::
  # === Exception
  # ProtocolError :: the first byte is not 0, or the filler byte is not 0
  def self.parse(data)
    raise ProtocolError, "invalid packet" unless data.slice!(0) == ?\0
    statement_id, field_count, param_count, f, warning_count = data.unpack("VvvCv")
    raise ProtocolError, "invalid packet" unless f == 0x00
    self.new statement_id, field_count, param_count, warning_count
  end

  attr_accessor :statement_id, :field_count, :param_count, :warning_count

  # Arguments are assigned positionally, in the order of the accessors above.
  def initialize(*args)
    @statement_id, @field_count, @param_count, @warning_count = args
  end
end
|
|
|
|
# Execute packet
class ExecutePacket < TxPacket
  attr_accessor :statement_id, :cursor_type, :values

  # Arguments are assigned positionally: statement_id, cursor_type, values.
  def initialize(*args)
    @statement_id, @cursor_type, @values = args
  end

  # Build the packet payload.
  # === Return
  # String :: command byte, statement id, cursor type, the constant 1, null bitmap,
  #           the constant 1, one 16-bit type per value, then the encoded values
  # === Exception
  # ProtocolError :: raised by Protocol.value2net for unsupported or too-large values
  def serialize
    nbm = null_bitmap values
    netvalues = String.new
    types = values.map do |v|
      t, n = Protocol.value2net v
      # nil values are carried by the null bitmap only.
      netvalues.concat n if v
      t
    end
    [RbMysql::COM_STMT_EXECUTE, statement_id, cursor_type, 1, nbm, 1, types.pack("v*"), netvalues].pack("CVCVa*Ca*a*")
  end

  private

  # make null bitmap
  #
  # If values is [1, nil, 2, 3, nil] then returns "\x12"(0b10010).
  def null_bitmap(values)
    # One byte per group of eight values; the first value of a group is the lowest bit.
    bitmap = values.each_slice(8).map do |vals|
      vals.reverse.inject(0) { |b, v| (b << 1) | (v ? 0 : 1) }
    end
    bitmap.pack("C*")
  end
end
|
|
|
|
# Fetch packet
class FetchPacket < TxPacket
  attr_accessor :statement_id, :fetch_length

  # Arguments are assigned positionally: statement_id, fetch_length.
  def initialize(*args)
    @statement_id, @fetch_length = args
  end

  # === Return
  # String :: command byte, then statement id and fetch length (32-bit little-endian each)
  def serialize
    [RbMysql::COM_STMT_FETCH, statement_id, fetch_length].pack("CVV")
  end
end
|
|
|
|
# Stmt close packet
class StmtClosePacket < TxPacket
  attr_accessor :statement_id

  # The first argument is stored as the statement id.
  def initialize(*args)
    @statement_id, = args
  end

  # === Return
  # String :: command byte followed by the statement id (32-bit little-endian)
  def serialize
    [RbMysql::COM_STMT_CLOSE, statement_id].pack("CV")
  end
end
|
|
|
|
# Field list packet
class FieldListPacket < TxPacket
  attr_accessor :table, :field

  # Arguments are assigned positionally: table, field.
  def initialize(*args)
    @table, @field = args
  end

  # === Return
  # String :: command byte, table name, a NUL byte, then the field value
  def serialize
    [RbMysql::COM_FIELD_LIST, "#{@table}\0#{@field}"].pack("Ca*")
  end
end
|
|
end
|
|
end
|
|
|