Skip to content

Commit 96fac3a

Browse files
authored
[mypyc] Add urlsafe_b64encode and urlsafe_b64decode to librt.base64 (#20274)
The URL-safe variants only add a post-processing step (when encoding) or a pre-processing step (when decoding) and otherwise share their implementation with the non-URL-safe variants. Breaking the ABI is acceptable because the base64 functions are still experimental. I checked that at least gcc and clang are able to vectorize the pre-/post-processing loops (when using -O3).
1 parent 89782cc commit 96fac3a

File tree

6 files changed

+218
-33
lines changed

6 files changed

+218
-33
lines changed
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,4 @@
11
def b64encode(s: bytes) -> bytes: ...
22
def b64decode(s: bytes | str) -> bytes: ...
3+
def urlsafe_b64encode(s: bytes) -> bytes: ...
4+
def urlsafe_b64decode(s: bytes | str) -> bytes: ...

mypyc/lib-rt/librt_base64.c

Lines changed: 88 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,14 +9,42 @@
99

1010
static PyObject *
1111
b64decode_handle_invalid_input(
12-
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen);
12+
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen, bool freesrc);
1313

1414
#define BASE64_MAXBIN ((PY_SSIZE_T_MAX - 3) / 2)
1515

1616
#define STACK_BUFFER_SIZE 1024
1717

18+
// Rewrite the first `len` bytes of `buf` in place, replacing '+' with '-'
// and '/' with '_'. Every other byte is stored back unchanged.
static void
convert_encoded_to_urlsafe(char *buf, size_t len) {
    // Each element is loaded, mapped and stored unconditionally; the loop is
    // kept in this simple form to enable SIMD optimizations.
    for (size_t pos = 0; pos < len; ++pos) {
        const char c = buf[pos];
        buf[pos] = (c == '+') ? '-' : (c == '/') ? '_' : c;
    }
}
31+
32+
// Copy `len` bytes from `src` to `buf`, replacing '-' with '+' and '_' with
// '/'. Every other byte is copied unchanged. Exactly `len` bytes are written;
// no terminator is appended.
static void
convert_urlsafe_to_encoded(const char *src, size_t len, char *buf) {
    // Each element is loaded, mapped and stored unconditionally; the loop is
    // kept in this simple form to enable SIMD optimizations.
    for (size_t pos = 0; pos < len; ++pos) {
        const char c = src[pos];
        buf[pos] = (c == '-') ? '+' : (c == '_') ? '/' : c;
    }
}
45+
1846
static PyObject *
19-
b64encode_internal(PyObject *obj) {
47+
b64encode_internal(PyObject *obj, bool urlsafe) {
2048
unsigned char *ascii_data;
2149
char *bin_data;
2250
int leftbits = 0;
@@ -53,6 +81,11 @@ b64encode_internal(PyObject *obj) {
5381
}
5482
size_t actual_len;
5583
base64_encode(bin_data, bin_len, buf, &actual_len, 0);
84+
85+
if (urlsafe) {
86+
convert_encoded_to_urlsafe(buf, actual_len);
87+
}
88+
5689
PyObject *res = PyBytes_FromStringAndSize(buf, actual_len);
5790
if (buflen > STACK_BUFFER_SIZE)
5891
PyMem_Free(buf);
@@ -65,7 +98,16 @@ b64encode(PyObject *self, PyObject *const *args, size_t nargs) {
6598
PyErr_SetString(PyExc_TypeError, "b64encode() takes exactly one argument");
6699
return 0;
67100
}
68-
return b64encode_internal(args[0]);
101+
return b64encode_internal(args[0], false);
102+
}
103+
104+
static PyObject*
105+
urlsafe_b64encode(PyObject *self, PyObject *const *args, size_t nargs) {
106+
if (nargs != 1) {
107+
PyErr_SetString(PyExc_TypeError, "urlsafe_b64encode() takes exactly one argument");
108+
return 0;
109+
}
110+
return b64encode_internal(args[0], true);
69111
}
70112

71113
static inline int
@@ -75,7 +117,7 @@ is_valid_base64_char(char c, bool allow_padding) {
75117
}
76118

77119
static PyObject *
78-
b64decode_internal(PyObject *arg) {
120+
b64decode_internal(PyObject *arg, bool urlsafe) {
79121
const char *src;
80122
Py_ssize_t srclen_ssz;
81123

@@ -102,6 +144,15 @@ b64decode_internal(PyObject *arg) {
102144
return PyBytes_FromStringAndSize(NULL, 0);
103145
}
104146

147+
if (urlsafe) {
148+
char *new_src = PyMem_Malloc(srclen_ssz + 1);
149+
if (new_src == NULL) {
150+
return PyErr_NoMemory();
151+
}
152+
convert_urlsafe_to_encoded(src, srclen_ssz, new_src);
153+
src = new_src;
154+
}
155+
105156
// Quickly ignore invalid characters at the end. Other invalid characters
106157
// are also accepted, but they need a slow path.
107158
while (srclen_ssz > 0 && !is_valid_base64_char(src[srclen_ssz - 1], true)) {
@@ -123,6 +174,9 @@ b64decode_internal(PyObject *arg) {
123174
// Allocate output bytes (uninitialized) of the max capacity
124175
PyObject *out_bytes = PyBytes_FromStringAndSize(NULL, (Py_ssize_t)max_out);
125176
if (out_bytes == NULL) {
177+
if (urlsafe) {
178+
PyMem_Free((void *)src);
179+
}
126180
return NULL; // Propagate memory error
127181
}
128182

@@ -134,9 +188,12 @@ b64decode_internal(PyObject *arg) {
134188
if (ret != 1) {
135189
if (ret == 0) {
136190
// Slow path: handle non-base64 input
137-
return b64decode_handle_invalid_input(out_bytes, outbuf, max_out, src, srclen);
191+
return b64decode_handle_invalid_input(out_bytes, outbuf, max_out, src, srclen, urlsafe);
138192
}
139193
Py_DECREF(out_bytes);
194+
if (urlsafe) {
195+
PyMem_Free((void *)src);
196+
}
140197
if (ret == -1) {
141198
PyErr_SetString(PyExc_NotImplementedError, "base64 codec not available in this build");
142199
} else {
@@ -145,6 +202,10 @@ b64decode_internal(PyObject *arg) {
145202
return NULL;
146203
}
147204

205+
if (urlsafe) {
206+
PyMem_Free((void *)src);
207+
}
208+
148209
// Sanity-check contract (decoder must not overflow our buffer)
149210
if (outlen > max_out) {
150211
Py_DECREF(out_bytes);
@@ -164,14 +225,17 @@ b64decode_internal(PyObject *arg) {
164225
// with stdlib b64decode.
165226
static PyObject *
166227
b64decode_handle_invalid_input(
167-
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen)
228+
PyObject *out_bytes, char *outbuf, size_t max_out, const char *src, size_t srclen, bool freesrc)
168229
{
169230
// Copy input to a temporary buffer, with non-base64 characters and extra suffix
170231
// characters removed
171232
size_t newbuf_len = 0;
172233
char *newbuf = PyMem_Malloc(srclen);
173234
if (newbuf == NULL) {
174235
Py_DECREF(out_bytes);
236+
if (freesrc) {
237+
PyMem_Free((void *)src);
238+
}
175239
return PyErr_NoMemory();
176240
}
177241

@@ -208,6 +272,9 @@ b64decode_handle_invalid_input(
208272

209273
// Stdlib always performs a non-strict padding check
210274
if (newbuf_len % 4 != 0) {
275+
if (freesrc) {
276+
PyMem_Free((void *)src);
277+
}
211278
Py_DECREF(out_bytes);
212279
PyMem_Free(newbuf);
213280
PyErr_SetString(PyExc_ValueError, "Incorrect padding");
@@ -217,6 +284,9 @@ b64decode_handle_invalid_input(
217284
size_t outlen = max_out;
218285
int ret = base64_decode(newbuf, newbuf_len, outbuf, &outlen, 0);
219286
PyMem_Free(newbuf);
287+
if (freesrc) {
288+
PyMem_Free((void *)src);
289+
}
220290

221291
if (ret != 1) {
222292
Py_DECREF(out_bytes);
@@ -239,14 +309,22 @@ b64decode_handle_invalid_input(
239309
return out_bytes;
240310
}
241311

242-
243312
static PyObject*
244313
b64decode(PyObject *self, PyObject *const *args, size_t nargs) {
245314
if (nargs != 1) {
246315
PyErr_SetString(PyExc_TypeError, "b64decode() takes exactly one argument");
247316
return 0;
248317
}
249-
return b64decode_internal(args[0]);
318+
return b64decode_internal(args[0], false);
319+
}
320+
321+
static PyObject*
322+
urlsafe_b64decode(PyObject *self, PyObject *const *args, size_t nargs) {
323+
if (nargs != 1) {
324+
PyErr_SetString(PyExc_TypeError, "urlsafe_b64decode() takes exactly one argument");
325+
return 0;
326+
}
327+
return b64decode_internal(args[0], true);
250328
}
251329

252330
#endif
@@ -255,6 +333,8 @@ static PyMethodDef librt_base64_module_methods[] = {
255333
#ifdef MYPYC_EXPERIMENTAL
256334
{"b64encode", (PyCFunction)b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes object using Base64.")},
257335
{"b64decode", (PyCFunction)b64decode, METH_FASTCALL, PyDoc_STR("Decode a Base64 encoded bytes object or ASCII string.")},
336+
{"urlsafe_b64encode", (PyCFunction)urlsafe_b64encode, METH_FASTCALL, PyDoc_STR("Encode bytes object using URL and file system safe Base64 alphabet.")},
337+
{"urlsafe_b64decode", (PyCFunction)urlsafe_b64decode, METH_FASTCALL, PyDoc_STR("Decode bytes or ASCII string using URL and file system safe Base64 alphabet.")},
258338
#endif
259339
{NULL, NULL, 0, NULL}
260340
};

mypyc/lib-rt/librt_base64.h

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,16 +12,16 @@ import_librt_base64(void)
1212

1313
#else // MYPYC_EXPERIMENTAL
1414

15-
#define LIBRT_BASE64_ABI_VERSION 0
16-
#define LIBRT_BASE64_API_VERSION 1
15+
#define LIBRT_BASE64_ABI_VERSION 1
16+
#define LIBRT_BASE64_API_VERSION 2
1717
#define LIBRT_BASE64_API_LEN 4
1818

1919
static void *LibRTBase64_API[LIBRT_BASE64_API_LEN];
2020

2121
#define LibRTBase64_ABIVersion (*(int (*)(void)) LibRTBase64_API[0])
2222
#define LibRTBase64_APIVersion (*(int (*)(void)) LibRTBase64_API[1])
23-
#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source)) LibRTBase64_API[2])
24-
#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source)) LibRTBase64_API[3])
23+
#define LibRTBase64_b64encode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[2])
24+
#define LibRTBase64_b64decode_internal (*(PyObject* (*)(PyObject *source, bool urlsafe)) LibRTBase64_API[3])
2525

2626
static int
2727
import_librt_base64(void)

mypyc/primitives/misc_ops.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,18 @@
473473
return_type=bytes_rprimitive,
474474
c_function_name="LibRTBase64_b64encode_internal",
475475
error_kind=ERR_MAGIC,
476+
extra_int_constants=[(0, bool_rprimitive)],
477+
experimental=True,
478+
capsule="librt.base64",
479+
)
480+
481+
function_op(
482+
name="librt.base64.urlsafe_b64encode",
483+
arg_types=[bytes_rprimitive],
484+
return_type=bytes_rprimitive,
485+
c_function_name="LibRTBase64_b64encode_internal",
486+
error_kind=ERR_MAGIC,
487+
extra_int_constants=[(1, bool_rprimitive)],
476488
experimental=True,
477489
capsule="librt.base64",
478490
)
@@ -483,6 +495,18 @@
483495
return_type=bytes_rprimitive,
484496
c_function_name="LibRTBase64_b64decode_internal",
485497
error_kind=ERR_MAGIC,
498+
extra_int_constants=[(0, bool_rprimitive)],
499+
experimental=True,
500+
capsule="librt.base64",
501+
)
502+
503+
function_op(
504+
name="librt.base64.urlsafe_b64decode",
505+
arg_types=[RUnion([bytes_rprimitive, str_rprimitive])],
506+
return_type=bytes_rprimitive,
507+
c_function_name="LibRTBase64_b64decode_internal",
508+
error_kind=ERR_MAGIC,
509+
extra_int_constants=[(1, bool_rprimitive)],
486510
experimental=True,
487511
capsule="librt.base64",
488512
)

mypyc/test-data/irbuild-base64.test

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
[case testBase64_experimental]
2-
from librt.base64 import b64encode, b64decode
2+
from librt.base64 import b64encode, b64decode, urlsafe_b64encode, urlsafe_b64decode
33

44
def enc(b: bytes) -> bytes:
55
return b64encode(b)
@@ -9,22 +9,47 @@ def dec_bytes(b: bytes) -> bytes:
99

1010
def dec_str(b: str) -> bytes:
1111
return b64decode(b)
12+
13+
def uenc(b: bytes) -> bytes:
14+
return urlsafe_b64encode(b)
15+
16+
def udec_bytes(b: bytes) -> bytes:
17+
return urlsafe_b64decode(b)
18+
19+
def udec_str(b: str) -> bytes:
20+
return urlsafe_b64decode(b)
1221
[out]
1322
def enc(b):
1423
b, r0 :: bytes
1524
L0:
16-
r0 = LibRTBase64_b64encode_internal(b)
25+
r0 = LibRTBase64_b64encode_internal(b, 0)
1726
return r0
1827
def dec_bytes(b):
1928
b, r0 :: bytes
2029
L0:
21-
r0 = LibRTBase64_b64decode_internal(b)
30+
r0 = LibRTBase64_b64decode_internal(b, 0)
2231
return r0
2332
def dec_str(b):
2433
b :: str
2534
r0 :: bytes
2635
L0:
27-
r0 = LibRTBase64_b64decode_internal(b)
36+
r0 = LibRTBase64_b64decode_internal(b, 0)
37+
return r0
38+
def uenc(b):
39+
b, r0 :: bytes
40+
L0:
41+
r0 = LibRTBase64_b64encode_internal(b, 1)
42+
return r0
43+
def udec_bytes(b):
44+
b, r0 :: bytes
45+
L0:
46+
r0 = LibRTBase64_b64decode_internal(b, 1)
47+
return r0
48+
def udec_str(b):
49+
b :: str
50+
r0 :: bytes
51+
L0:
52+
r0 = LibRTBase64_b64decode_internal(b, 1)
2853
return r0
2954

3055
[case testBase64ExperimentalDisabled]

0 commit comments

Comments
 (0)