# Source code for websocket._abnf

"""

"""

"""
websocket - WebSocket client library for Python

Copyright (C) 2010 Hiroki Ohtani(liris)

    This library is free software; you can redistribute it and/or
    modify it under the terms of the GNU Lesser General Public
    License as published by the Free Software Foundation; either
    version 2.1 of the License, or (at your option) any later version.

    This library is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
    Lesser General Public License for more details.

    You should have received a copy of the GNU Lesser General Public
    License along with this library; if not, write to the Free Software
    Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA  02110-1301  USA

"""
import array
import os
import struct

from ._exceptions import *
from ._utils import validate_utf8
from threading import Lock

try:
    import numpy
except ImportError:
    numpy = None


def _mask(_m, _d):
    """
    Pure-Python XOR masking: payload byte ``i`` is XOR-ed with ``_m[i % 4]``.

    ``ABNF.mask`` calls this with ``array.array("B", ...)`` copies of the mask
    key and the payload when numpy is unavailable.  ``_d`` is modified in
    place and the masked payload is returned as ``bytes``.
    """
    for i, byte in enumerate(_d):
        _d[i] = byte ^ _m[i % 4]

    return _d.tobytes()


# The fallback above is always defined so that ``_mask`` exists whatever
# optional packages are installed.  When numpy is importable, ABNF.mask takes
# its numpy branch and never calls ``_mask``; otherwise prefer wsaccel.
if not numpy:
    try:
        # If wsaccel is available we use compiled routines to mask data.
        from wsaccel.xormask import XorMaskerSimple
    except ImportError:
        # wsaccel is not available; keep the pure-Python implementation.
        pass
    else:
        def _mask(_m, _d):  # noqa: F811 - deliberately replaces the fallback
            return XorMaskerSimple(_m).process(_d)


__all__ = [
    'ABNF', 'continuous_frame', 'frame_buffer',
    'STATUS_NORMAL',
    'STATUS_GOING_AWAY',
    'STATUS_PROTOCOL_ERROR',
    'STATUS_UNSUPPORTED_DATA_TYPE',
    'STATUS_STATUS_NOT_AVAILABLE',
    'STATUS_ABNORMAL_CLOSED',
    'STATUS_INVALID_PAYLOAD',
    'STATUS_POLICY_VIOLATION',
    'STATUS_MESSAGE_TOO_BIG',
    'STATUS_INVALID_EXTENSION',
    'STATUS_UNEXPECTED_CONDITION',
    'STATUS_SERVICE_RESTART',
    'STATUS_TRY_AGAIN_LATER',
    'STATUS_BAD_GATEWAY',
    'STATUS_TLS_HANDSHAKE_ERROR',
]

# closing frame status codes (RFC 6455 section 7.4.1 plus the IANA
# "WebSocket Close Code Number Registry").
STATUS_NORMAL = 1000
STATUS_GOING_AWAY = 1001
STATUS_PROTOCOL_ERROR = 1002
STATUS_UNSUPPORTED_DATA_TYPE = 1003
STATUS_STATUS_NOT_AVAILABLE = 1005
STATUS_ABNORMAL_CLOSED = 1006
STATUS_INVALID_PAYLOAD = 1007
STATUS_POLICY_VIOLATION = 1008
STATUS_MESSAGE_TOO_BIG = 1009
STATUS_INVALID_EXTENSION = 1010
STATUS_UNEXPECTED_CONDITION = 1011
STATUS_SERVICE_RESTART = 1012
STATUS_TRY_AGAIN_LATER = 1013
STATUS_BAD_GATEWAY = 1014
STATUS_TLS_HANDSHAKE_ERROR = 1015

# Status codes accepted in a received close frame (in addition to the
# 3000-4999 application range).  1005, 1006 and 1015 are deliberately absent:
# RFC 6455 section 7.4.1 forbids putting them in a close frame.
VALID_CLOSE_STATUS = (
    STATUS_NORMAL,
    STATUS_GOING_AWAY,
    STATUS_PROTOCOL_ERROR,
    STATUS_UNSUPPORTED_DATA_TYPE,
    STATUS_INVALID_PAYLOAD,
    STATUS_POLICY_VIOLATION,
    STATUS_MESSAGE_TOO_BIG,
    STATUS_INVALID_EXTENSION,
    STATUS_UNEXPECTED_CONDITION,
    STATUS_SERVICE_RESTART,
    STATUS_TRY_AGAIN_LATER,
    STATUS_BAD_GATEWAY,
)


class ABNF(object):
    """
    ABNF frame class.
    See http://tools.ietf.org/html/rfc5234
    and http://tools.ietf.org/html/rfc6455#section-5.2
    """

    # operation code values.
    OPCODE_CONT = 0x0
    OPCODE_TEXT = 0x1
    OPCODE_BINARY = 0x2
    OPCODE_CLOSE = 0x8
    OPCODE_PING = 0x9
    OPCODE_PONG = 0xa

    # available operation code value tuple
    OPCODES = (OPCODE_CONT, OPCODE_TEXT, OPCODE_BINARY, OPCODE_CLOSE,
               OPCODE_PING, OPCODE_PONG)

    # opcode human readable string
    OPCODE_MAP = {
        OPCODE_CONT: "cont",
        OPCODE_TEXT: "text",
        OPCODE_BINARY: "binary",
        OPCODE_CLOSE: "close",
        OPCODE_PING: "ping",
        OPCODE_PONG: "pong"
    }

    # data length threshold.
    # format() writes payloads shorter than LENGTH_7 in the 7-bit length
    # field, those shorter than LENGTH_16 with a 16-bit extended length, and
    # anything below LENGTH_63 with a 64-bit extended length.
    LENGTH_7 = 0x7e
    LENGTH_16 = 1 << 16
    LENGTH_63 = 1 << 63

    def __init__(self, fin=0, rsv1=0, rsv2=0, rsv3=0,
                 opcode=OPCODE_TEXT, mask=1, data=""):
        """
        Constructor for ABNF. Please check RFC for arguments.

        ``data=None`` is normalised to ``""``.  ``get_mask_key`` is the
        callable that format() invokes with ``4`` to obtain the masking key;
        it defaults to ``os.urandom`` and may be replaced per instance.
        """
        self.fin = fin
        self.rsv1 = rsv1
        self.rsv2 = rsv2
        self.rsv3 = rsv3
        self.opcode = opcode
        self.mask = mask
        if data is None:
            data = ""
        self.data = data
        self.get_mask_key = os.urandom

    def validate(self, skip_utf8_validation=False):
        """
        Validate the ABNF frame.

        Parameters
        ----------
        skip_utf8_validation: bool
            skip utf8 validation of a close frame's reason text.

        Raises
        ------
        WebSocketProtocolException
            if an RSV bit is set, the opcode is unknown, a ping frame is
            fragmented, or a close frame has a malformed payload or an
            unacceptable status code.
        """
        if self.rsv1 or self.rsv2 or self.rsv3:
            raise WebSocketProtocolException("rsv is not implemented, yet")

        if self.opcode not in ABNF.OPCODES:
            # Format the message here: passing the opcode as a second
            # exception argument left the "%r" placeholder unexpanded.
            raise WebSocketProtocolException(
                "Invalid opcode %r" % (self.opcode,))

        if self.opcode == ABNF.OPCODE_PING and not self.fin:
            raise WebSocketProtocolException("Invalid ping frame.")

        if self.opcode == ABNF.OPCODE_CLOSE:
            length = len(self.data)
            if not length:
                # An empty close frame carries no status code to check.
                return
            # A non-empty close payload starts with a 2-byte status code and
            # may not exceed 125 bytes.
            if length == 1 or length >= 126:
                raise WebSocketProtocolException("Invalid close frame.")
            if (length > 2 and not skip_utf8_validation
                    and not validate_utf8(self.data[2:])):
                raise WebSocketProtocolException("Invalid close frame.")

            # Status code is big-endian in the first two payload bytes.
            code = struct.unpack("!H", self.data[:2])[0]
            if not self._is_valid_close_status(code):
                raise WebSocketProtocolException("Invalid close opcode.")

    @staticmethod
    def _is_valid_close_status(code):
        # Registered codes from VALID_CLOSE_STATUS, or the 3000-4999 range.
        return code in VALID_CLOSE_STATUS or (3000 <= code < 5000)

    def __str__(self):
        return "fin=" + str(self.fin) \
            + " opcode=" + str(self.opcode) \
            + " data=" + str(self.data)

    @staticmethod
    def create_frame(data, opcode, fin=1):
        """
        Create frame to send text, binary and other data.

        Parameters
        ----------
        data: <type>
            data to send. This is string value(byte array).
            If opcode is OPCODE_TEXT and this value is a str,
            it is encoded as UTF-8 automatically.
        opcode: <type>
            operation code. please see OPCODE_XXX.
        fin: <type>
            fin flag. if set to 0, create continue fragmentation.
        """
        if opcode == ABNF.OPCODE_TEXT and isinstance(data, str):
            data = data.encode("utf-8")
        # mask must be set if send data from client
        return ABNF(fin, 0, 0, 0, opcode, 1, data)

    def format(self):
        """
        Format this object to string(byte array) to send data to server.

        Raises
        ------
        ValueError
            if fin/rsv flags are not 0 or 1, the opcode is unknown, or the
            payload length does not fit in 63 bits.
        """
        if any(x not in (0, 1) for x in [self.fin, self.rsv1, self.rsv2, self.rsv3]):
            raise ValueError("not 0 or 1")
        if self.opcode not in ABNF.OPCODES:
            raise ValueError("Invalid OPCODE")
        length = len(self.data)
        if length >= ABNF.LENGTH_63:
            raise ValueError("data is too long")

        frame_header = chr(self.fin << 7
                           | self.rsv1 << 6
                           | self.rsv2 << 5
                           | self.rsv3 << 4
                           | self.opcode).encode('latin-1')
        if length < ABNF.LENGTH_7:
            frame_header += chr(self.mask << 7 | length).encode('latin-1')
        elif length < ABNF.LENGTH_16:
            frame_header += chr(self.mask << 7 | 0x7e).encode('latin-1')
            frame_header += struct.pack("!H", length)
        else:
            frame_header += chr(self.mask << 7 | 0x7f).encode('latin-1')
            frame_header += struct.pack("!Q", length)

        if not self.mask:
            return frame_header + self.data
        mask_key = self.get_mask_key(4)
        return frame_header + self._get_masked(mask_key)

    def _get_masked(self, mask_key):
        """Return the mask key followed by the masked payload."""
        s = ABNF.mask(mask_key, self.data)
        if isinstance(mask_key, str):
            # mask() interprets a str key as latin-1 bytes; the key written to
            # the wire must use the same encoding or the peer unmasks with a
            # different (and longer) key.
            mask_key = mask_key.encode('latin-1')
        return mask_key + s

    @staticmethod
    def mask(mask_key, data):
        """
        Mask or unmask data. Just do xor for each byte

        Parameters
        ----------
        mask_key: <type>
            4 byte string.
        data: <type>
            data to mask/unmask.

        Returns
        -------
        bytes
            the masked data.  The ``data`` argument is never modified.
        """
        if data is None:
            data = ""
        if isinstance(mask_key, str):
            mask_key = mask_key.encode('latin-1')
        if isinstance(data, str):
            data = data.encode('latin-1')

        if numpy:
            if not data:
                return b""
            # Work byte-by-byte so the result does not depend on the host's
            # byte order, and build a new array instead of padding ``data``
            # in place (which would mutate a caller's bytearray).
            payload = numpy.frombuffer(data, dtype=numpy.uint8)
            key = numpy.resize(numpy.frombuffer(mask_key[:4], dtype=numpy.uint8),
                               len(payload))
            return numpy.bitwise_xor(payload, key).tobytes()
        else:
            return _mask(array.array("B", mask_key), array.array("B", data))
class frame_buffer(object):
    """
    Read WebSocket frames through ``recv_fn``.

    The parsed header, length and mask of the frame in progress are stored on
    the instance and only cleared once the payload has been read, so a
    ``recv_frame`` call interrupted by an exception from ``recv_fn`` resumes
    from where it stopped on the next call.
    """
    _HEADER_MASK_INDEX = 5
    _HEADER_LENGTH_INDEX = 6

    def __init__(self, recv_fn, skip_utf8_validation):
        self.recv = recv_fn
        self.skip_utf8_validation = skip_utf8_validation
        # Buffers over the packets from the layer beneath until desired amount
        # bytes of bytes are received.
        self.recv_buffer = []
        self.clear()
        self.lock = Lock()

    def clear(self):
        """Forget the per-frame state (header, length, mask)."""
        self.header = None
        self.length = None
        self.mask = None

    def has_received_header(self):
        # NOTE: despite the name, True means the header still has to be read.
        return self.header is None

    def recv_header(self):
        """Read the 2-byte frame header and store it as a 7-tuple in self.header."""
        header = self.recv_strict(2)
        b1 = header[0]
        fin = b1 >> 7 & 1
        rsv1 = b1 >> 6 & 1
        rsv2 = b1 >> 5 & 1
        rsv3 = b1 >> 4 & 1
        opcode = b1 & 0xf
        b2 = header[1]
        has_mask = b2 >> 7 & 1
        length_bits = b2 & 0x7f

        self.header = (fin, rsv1, rsv2, rsv3, opcode, has_mask, length_bits)

    def has_mask(self):
        if not self.header:
            return False
        return self.header[frame_buffer._HEADER_MASK_INDEX]

    def has_received_length(self):
        # NOTE: despite the name, True means the length still has to be read.
        return self.length is None

    def recv_length(self):
        """Resolve the payload length, reading a 16/64-bit extension if signalled."""
        bits = self.header[frame_buffer._HEADER_LENGTH_INDEX]
        length_bits = bits & 0x7f
        if length_bits == 0x7e:
            v = self.recv_strict(2)
            self.length = struct.unpack("!H", v)[0]
        elif length_bits == 0x7f:
            v = self.recv_strict(8)
            self.length = struct.unpack("!Q", v)[0]
        else:
            self.length = length_bits

    def has_received_mask(self):
        # NOTE: despite the name, True means the mask still has to be read.
        return self.mask is None

    def recv_mask(self):
        """Read the 4-byte mask key, or store "" when the frame is unmasked."""
        self.mask = self.recv_strict(4) if self.has_mask() else ""

    def recv_frame(self):
        """
        Read one complete frame, unmask its payload if needed and validate it.

        The whole read happens while holding ``self.lock``.
        """
        with self.lock:
            # Header
            if self.has_received_header():
                self.recv_header()
            (fin, rsv1, rsv2, rsv3, opcode, has_mask, _) = self.header

            # Frame length
            if self.has_received_length():
                self.recv_length()
            length = self.length

            # Mask
            if self.has_received_mask():
                self.recv_mask()
            mask = self.mask

            # Payload
            payload = self.recv_strict(length)
            if has_mask:
                payload = ABNF.mask(mask, payload)

            # Reset for next frame
            self.clear()

            frame = ABNF(fin, rsv1, rsv2, rsv3, opcode, has_mask, payload)
            frame.validate(self.skip_utf8_validation)

        return frame

    def recv_strict(self, bufsize):
        """
        Return exactly ``bufsize`` bytes, calling ``self.recv`` as needed.

        Bytes received beyond ``bufsize`` stay in ``recv_buffer`` for the next
        call.  NOTE(review): the loop assumes ``recv`` never returns b"" —
        presumably it raises on a closed connection; confirm, since an empty
        read would spin forever.
        """
        shortage = bufsize - sum(len(x) for x in self.recv_buffer)
        while shortage > 0:
            # Limit buffer size that we pass to socket.recv() to avoid
            # fragmenting the heap -- the number of bytes recv() actually
            # reads is limited by socket buffer and is relatively small,
            # yet passing large numbers repeatedly causes lots of large
            # buffers allocated and then shrunk, which results in
            # fragmentation.
            bytes_ = self.recv(min(16384, shortage))
            self.recv_buffer.append(bytes_)
            shortage -= len(bytes_)

        unified = b"".join(self.recv_buffer)

        if shortage == 0:
            self.recv_buffer = []
            return unified
        else:
            self.recv_buffer = [unified[bufsize:]]
            return unified[:bufsize]


class continuous_frame(object):
    """
    Reassemble fragmented text/binary messages.

    ``cont_data`` holds ``[opcode, accumulated_data]`` for the message in
    progress; ``recving_frames`` holds the opcode of an unfinished fragmented
    message (or None).
    """

    def __init__(self, fire_cont_frame, skip_utf8_validation):
        self.fire_cont_frame = fire_cont_frame
        self.skip_utf8_validation = skip_utf8_validation
        self.cont_data = None
        self.recving_frames = None

    def validate(self, frame):
        """Reject a stray continuation frame, or a new data frame mid-message."""
        if not self.recving_frames and frame.opcode == ABNF.OPCODE_CONT:
            raise WebSocketProtocolException("Illegal frame")
        if self.recving_frames and \
                frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
            raise WebSocketProtocolException("Illegal frame")

    def add(self, frame):
        """Append ``frame`` to the message being assembled."""
        if self.cont_data:
            accumulated = self.cont_data[1]
            if isinstance(accumulated, bytes):
                # Switch to a bytearray so each further fragment is appended
                # in place; repeated bytes concatenation copied the whole
                # message for every fragment (quadratic in message size).
                accumulated = bytearray(accumulated)
            accumulated += frame.data
            self.cont_data[1] = accumulated
        else:
            if frame.opcode in (ABNF.OPCODE_TEXT, ABNF.OPCODE_BINARY):
                self.recving_frames = frame.opcode
            self.cont_data = [frame.opcode, frame.data]

        if frame.fin:
            self.recving_frames = None

    def is_fire(self, frame):
        """True when the accumulated data should be handed out now."""
        return frame.fin or self.fire_cont_frame

    def extract(self, frame):
        """
        Move the accumulated data into ``frame`` and return ``[opcode, frame]``.

        Raises WebSocketPayloadException when a complete text message is not
        valid UTF-8 (unless validation is skipped or fragments are fired
        individually).
        """
        data = self.cont_data
        self.cont_data = None
        payload = data[1]
        if isinstance(payload, bytearray):
            # Hand callers immutable bytes, as before the bytearray buffering.
            payload = bytes(payload)
        frame.data = payload
        if not self.fire_cont_frame and data[0] == ABNF.OPCODE_TEXT and not self.skip_utf8_validation and not validate_utf8(frame.data):
            raise WebSocketPayloadException(
                "cannot decode: " + repr(frame.data))

        return [data[0], frame]