path: root/youtube/proto.py
from math import ceil
import base64
import io
import traceback


def byte(n):
    return bytes((n,))


def varint_encode(offset):
    '''In this encoding, the most significant bit of each byte is 1 if more bytes follow and 0 if it is the last byte.
    The remaining 7 bits of each byte are data, and the 7-bit groups are ordered least significant group first (little-endian).
    For example, the value aaaaaaabbbbbbbccccccc (three 7-bit groups) is encoded as:
    1ccccccc 1bbbbbbb 0aaaaaaa

    YouTube parameters use this encoding for offsets and for the lengths of length-prefixed data.
    See https://developers.google.com/protocol-buffers/docs/encoding#varints for more info.'''
    needed_bytes = ceil(offset.bit_length()/7) or 1 # (0).bit_length() returns 0, but we need 1 in that case.
    encoded_bytes = bytearray(needed_bytes)
    for i in range(0, needed_bytes - 1):
        encoded_bytes[i] = (offset & 127) | 128  # 7 least significant bits
        offset = offset >> 7
    encoded_bytes[-1] = offset & 127  # leave first bit as zero for last byte

    return bytes(encoded_bytes)
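
# Illustrative example (not part of the original module): 300 is 0b100101100,
# which splits into the 7-bit groups 0000010 and 0101100; they are emitted
# least significant group first with continuation bits, i.e. 0xac 0x02.
#   >>> varint_encode(300)
#   b'\xac\x02'
#   >>> varint_encode(0)
#   b'\x00'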


def varint_decode(encoded):
    decoded = 0
    for i, byte in enumerate(encoded):
        decoded |= (byte & 127) << 7*i

        if not (byte & 128):
            break
    return decoded
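
# Illustrative round trip (not part of the original module): decoding reverses
# the 7-bit little-endian groups and stops at the byte whose high bit is 0.
#   >>> varint_decode(b'\xac\x02')
#   300
#   >>> varint_decode(varint_encode(1234567))
#   1234567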


def string(field_number, data):
    data = as_bytes(data)
    return _proto_field(2, field_number, varint_encode(len(data)) + data)


nested = string


def uint(field_number, value):
    return _proto_field(0, field_number, varint_encode(value))


def _proto_field(wire_type, field_number, data):
    ''' See https://developers.google.com/protocol-buffers/docs/encoding#structure '''
    return varint_encode((field_number << 3) | wire_type) + data
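
# Illustrative examples (not part of the original module): the tag is
# (field_number << 3) | wire_type, so field 1 with wire type 0 (varint) gives
# a 0x08 tag byte and field 2 with wire type 2 (length-prefixed) gives 0x12.
#   >>> uint(1, 150)
#   b'\x08\x96\x01'
#   >>> string(2, 'hello')
#   b'\x12\x05hello'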


def percent_b64encode(data):
    return base64.urlsafe_b64encode(data).replace(b'=', b'%3D')


def unpadded_b64encode(data):
    return base64.urlsafe_b64encode(data).replace(b'=', b'')


def as_bytes(value):
    if isinstance(value, str):
        return value.encode('utf-8')
    return value


def read_varint(data):
    result = 0
    i = 0
    while True:
        try:
            byte = data.read(1)[0]
        except IndexError:
            if i == 0:
                raise EOFError()
            raise Exception('Unterminated varint starting at ' + str(data.tell() - i))
        result |= (byte & 127) << 7*i
        if not byte & 128:
            break

        i += 1
    return result
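
# Illustrative example (not part of the original module): read_varint consumes
# the varint from a file-like object and leaves the position just past it.
#   >>> buf = io.BytesIO(b'\x96\x01\x12')
#   >>> read_varint(buf)
#   150
#   >>> buf.tell()
#   2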


def read_group(data, end_sequence):
    start = data.tell()
    index = data.original.find(end_sequence, start)
    if index == -1:
        raise Exception('Unterminated group')
    data.seek(index + len(end_sequence))
    return data.original[start:index]

def read_protobuf(data):
    data_original = data
    data = io.BytesIO(data)
    data.original = data_original
    while True:
        try:
            tag = read_varint(data)
        except EOFError:
            break
        wire_type = tag & 7
        field_number = tag >> 3

        if wire_type == 0:
            value = read_varint(data)
        elif wire_type == 1:
            value = data.read(8)
        elif wire_type == 2:
            length = read_varint(data)
            value = data.read(length)
        elif wire_type == 3:
            end_bytes = varint_encode((field_number << 3) | 4)
            value = read_group(data, end_bytes)
        elif wire_type == 5:
            value = data.read(4)
        else:
            raise Exception("Unknown wire type: " + str(wire_type) + ", Tag: " + bytes_to_hex(succinct_encode(tag)) + ", at position " + str(data.tell()))
        yield (wire_type, field_number, value)


def parse(data, include_wire_type=False):
    '''Returns a dict mapping field numbers to values

    data is the protobuf structure, which must not be b64-encoded'''
    if include_wire_type:
        return {field_number: [wire_type, value]
                for wire_type, field_number, value in read_protobuf(data)}
    return {field_number: value
            for _, field_number, value in read_protobuf(data)}
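
# Illustrative example (not part of the original module): the bytes produced by
# uint(1, 150) + string(2, 'hello') parse back into a field_number -> value dict.
#   >>> parse(b'\x08\x96\x01\x12\x05hello')
#   {1: 150, 2: b'hello'}
#   >>> parse(b'\x08\x96\x01', include_wire_type=True)
#   {1: [0, 150]}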


base64_enc_funcs = {
    'base64': base64.urlsafe_b64encode,
    'base64s': unpadded_b64encode,
    'base64p': percent_b64encode,
}


def _make_protobuf(data):
    '''
    Input: a recursive list of protobuf objects and/or base64 directives
    Output: a protobuf bytestring
    Each protobuf object has the form [wire_type, field_number, field_data].
    A list/tuple of length 2 whose first element is a base64 type stands for
    a base64 encoding of its second element: (base64 type, data).
    The base64 types are:
    - base64: url-safe base64 with equals-sign padding
    - base64s: url-safe base64 without padding
    - base64p: url-safe base64 with padding equals signs replaced by %3D
    '''
    # If data is a dict, it must map field_number to [wire_type, value]
    if isinstance(data, dict):
        new_data = []
        for field_num, (wire_type, value) in sorted(data.items()):
            new_data.append((wire_type, field_num, value))
        data = new_data
    if isinstance(data, str):
        return data.encode('utf-8')
    elif len(data) == 2 and data[0] in base64_enc_funcs:
        return base64_enc_funcs[data[0]](_make_protobuf(data[1]))
    elif isinstance(data, list):
        result = b''
        for field in data:
            if field[0] == 0:
                result += uint(field[1], field[2])
            elif field[0] == 2:
                result += string(field[1], _make_protobuf(field[2]))
            else:
                raise NotImplementedError('Wire type ' + str(field[0])
                    + ' not implemented')
        return result
    return data


def make_protobuf(data):
    return _make_protobuf(data).decode('ascii')
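
# Illustrative round trip (not part of the original module): wrapping the field
# list in a base64 directive keeps the output ASCII-safe, and parsing the
# decoded bytes recovers the original fields.
#   >>> token = make_protobuf(['base64p', [[0, 1, 150], [2, 2, 'hello']]])
#   >>> parse(b64_to_bytes(token))
#   {1: 150, 2: b'hello'}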


def _set_protobuf_value(data, *path, value):
    if not path:
        return value
    op = path[0]
    if op in base64_enc_funcs:
        inner_data = b64_to_bytes(data)
        return base64_enc_funcs[op](
            _set_protobuf_value(inner_data, *path[1:], value=value)
        )
    pb_dict = parse(data, include_wire_type=True)
    pb_dict[op][1] = _set_protobuf_value(
        pb_dict[op][1], *path[1:], value=value
    )
    return _make_protobuf(pb_dict)


def set_protobuf_value(data, *path, value):
    '''Set a field's value in a raw protobuf structure

    path is a sequence of field numbers and/or base64 encoding directives.

    The directives are:
        base64: normal url-safe base64 with equals-sign padding
        base64s ("stripped"): no padding
        base64p: %3D instead of = for padding

    Returns (new_protobuf, err). err is None on success; otherwise err is a
    traceback string and new_protobuf is None.'''
    try:
        new_protobuf = _set_protobuf_value(data, *path, value=value)
        return new_protobuf.decode('ascii'), None
    except Exception:
        return None, traceback.format_exc()
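
# Illustrative example (not part of the original module): replace field 1 inside
# a base64p-encoded structure; on success err is None and the other fields are
# preserved.
#   >>> token = make_protobuf(['base64p', [[0, 1, 150], [2, 2, 'hello']]])
#   >>> new_token, err = set_protobuf_value(token, 'base64p', 1, value=42)
#   >>> err is None
#   True
#   >>> parse(b64_to_bytes(new_token))
#   {1: 42, 2: b'hello'}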


def b64_to_bytes(data):
    if isinstance(data, bytes):
        data = data.decode('ascii')
    data = data.replace("%3D", "=")
    return base64.urlsafe_b64decode(data + "="*((4 - len(data) % 4) % 4))