Checking for equality if either input can be `str` or `bytes`

Time:08-27

I am trying to write a function that checks whether two values with ASCII-only content are equal, where each value can be either a str or a bytes object.

Right now I have:

import typing as typ


def is_equal_str_bytes(
    a: typ.Union[str, bytes],
    b: typ.Union[str, bytes],
) -> bool:
    if isinstance(a, str):
        a = a.encode()
    if isinstance(b, str):
        b = b.encode()
    return a == b

This works with any combination of str and bytes arguments, whereas the == operator (rightfully) returns False if the two types differ.

import itertools


ss = "ciao", b"ciao"
for a, b in itertools.product(ss, repeat=2):
    print(f"{a!r:<8} {b!r:<8} {is_equal_str_bytes(a, b)} {a == b}")
# 'ciao'   'ciao'   True True
# 'ciao'   b'ciao'  True False
# b'ciao'  'ciao'   True False
# b'ciao'  b'ciao'  True True

Is there a simpler / faster way?

Answer:

Some benchmarks with random equal strings/bytes of a million characters (on TIO with Python 3.8 pre-release, but I got similar times with 3.10.2):

  186.88 us  s.encode()
  187.39 us  s.encode("utf-8")
  183.85 us  s.encode("ascii")
   94.62 us  b.decode()
   94.27 us  b.decode("utf-8")
  137.91 us  b.decode("ascii")
   79.93 us  s == s2
   82.69 us  b == b2
  182.72 us  s + "a"
  177.06 us  b + b"a"
    0.08 us  len(s)
    0.07 us  len(b)
    1.14 us  s[:1000].encode()
    0.97 us  b[:1000].decode()
    2.06 us  s[::1000].encode()
    1.45 us  b[::1000].decode()
    1.91 us  hash(s)
    1.56 us  hash(b)
  508.62 us  hash(s2)
  546.00 us  hash(b2)
    2.85 us  str(s)
 9142.59 us  str(b)
13541.64 us  repr(s)
 9100.34 us  repr(b)

Thoughts based on that:

  • For simpler code, I thought we could apply str or repr to both values and then compare the resulting strings (e.g. after removing the b prefix), but the benchmark shows that this would be very slow.
  • Getting the lengths is very cheap, so I'd compare those first. Return False if different, otherwise continue.
  • If you've already hashed them, or are going to hash them afterwards anyway, you could compare the hashes (return False if they differ, otherwise continue). See ASCII str / bytes hash collision for why an ASCII string and the equal ASCII bytes have the same hash. (I'm not sure this is guaranteed by the language, so relying on it might not be safe.) Note that hashing the first time is slow (see the times for hashing s2/b2), but subsequent lookups of the stored hash are fast (see the times for hashing s/b).
  • Decoding seems faster than encoding, so do that instead.
  • Only decode if the types differ (one is string and one is bytes), otherwise just use ==.
  • It's wasteful to decode a million bytes if the very first one is already a mismatch. So it might be worth decoding and comparing shorter chunks instead of the whole thing, or testing a short prefix or a sparse cross-section before testing the whole thing.
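The two cheap pre-checks from the list above (length and hash) can be tried out directly. This is a small demonstration assuming CPython: as far as I know, the shared hash of an ASCII str and the equal bytes is a CPython implementation detail rather than a language guarantee, which is why the hash line is guarded.

```python
import platform

# ASCII-only content: one character per byte, so the lengths agree.
ascii_pair = ("ciao", b"ciao")
print(len(ascii_pair[0]) == len(ascii_pair[1]))  # True

# In CPython both hashes are computed over the same underlying bytes.
# This is an implementation detail, so only check it on CPython.
if platform.python_implementation() == "CPython":
    print(hash(ascii_pair[0]) == hash(ascii_pair[1]))  # True

# With non-ASCII text the shortcuts break down: UTF-8 uses more than one
# byte for "é", so even the lengths differ.
text = "perché"
print(len(text), len(text.encode()))  # 6 7
```

This is also why the length and hash pre-checks are only safe under the question's ASCII-only assumption.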

So here's a potentially faster version using the above optimizations (not tested or benchmarked, partly because the gain depends on your data):

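import typing as typ


def is_equal_str_bytes(
    a: typ.Union[str, bytes],
    b: typ.Union[str, bytes],
) -> bool:
    # Cheap pre-checks: for ASCII-only content a str and its bytes counterpart
    # have the same length and (in CPython) the same hash.
    if len(a) != len(b):
        return False
    if hash(a) != hash(b):
        return False
    # Same type: no conversion needed.
    if type(a) is type(b):
        return a == b
    if isinstance(a, bytes):  # make a=str, b=bytes
        a, b = b, a
    # Look for an early mismatch in a prefix and in a sparse cross-section
    # before decoding the whole bytes object.
    if a[:1000] != b[:1000].decode():
        return False
    if a[::1000] != b[::1000].decode():
        return False
    return a == b.decode()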
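The last bullet above also suggests decoding and comparing in chunks, which the function above does not do (it only checks a prefix and a cross-section). A minimal sketch of a chunked comparison, assuming ASCII-only content so that every slice of the bytes decodes on its own; the name `is_equal_chunked` and the 4096-byte chunk size are illustrative choices, not benchmarked.

```python
from typing import Union


def is_equal_chunked(
    a: Union[str, bytes],
    b: Union[str, bytes],
    chunk_size: int = 4096,
) -> bool:
    """Compare str/bytes chunk by chunk, stopping at the first mismatch.

    Assumes ASCII-only content: character i of the str corresponds to
    byte i of the bytes object, and every slice decodes independently.
    """
    if len(a) != len(b):
        return False
    if type(a) is type(b):
        return a == b
    if isinstance(a, bytes):  # make a=str, b=bytes
        a, b = b, a
    for start in range(0, len(a), chunk_size):
        stop = start + chunk_size
        # Decode only this slice; return as soon as one slice differs.
        if a[start:stop] != b[start:stop].decode():
            return False
    return True


print(is_equal_chunked("ciao", b"ciao"))  # True
print(is_equal_chunked(b"ciao", "miao"))  # False
```

For inputs that turn out equal this probably does more Python-level work than a single `decode()`, so it is likely to pay off only when mismatches are common and tend to appear early.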

My benchmark code:

import os
from timeit import repeat

n = 10**6
b = bytes(x & 127 for x in os.urandom(n))
s = b.decode()
assert hash(s) == hash(b)

setup = '''
from __main__ import s, b
s2 = b.decode()  # Always fresh so it doesn't have a hash stored already 
b2 = s.encode()
assert s2 is not s and b2 is not b
'''

exprs = [
    's.encode()',
    's.encode("utf-8")',
    's.encode("ascii")',
    'b.decode()',
    'b.decode("utf-8")',
    'b.decode("ascii")',
    's == s2',
    'b == b2',
    's + "a"',
    'b + b"a"',
    'len(s)',
    'len(b)',
    's[:1000].encode()',
    'b[:1000].decode()',
    's[::1000].encode()',
    'b[::1000].decode()',
    'hash(s)',
    'hash(b)',
    'hash(s2)',
    'hash(b2)',
    'str(s)',
    'str(b)',
    'repr(s)',
    'repr(b)',
]

for _ in range(3):
    for e in exprs:
        number = 100 if exprs.index(e) < exprs.index('hash(s)') else 1
        t = min(repeat(e, setup, number=number)) / number
        print('%8.2f us ' % (t * 1e6), e)
    print()

Answer:

I'm afraid there is no simpler way to do it, unless converting everything to str straight at the source is an option.
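If you do control where the values come from, normalizing them once as they enter your code removes the need for a mixed-type comparison altogether. A minimal sketch; the helper name `as_str` is hypothetical, and it assumes the bytes are valid UTF-8 (which includes ASCII):

```python
from typing import Union


def as_str(value: Union[str, bytes]) -> str:
    """Hypothetical helper: convert bytes to str once, where data enters."""
    # bytes.decode() defaults to UTF-8, which also covers ASCII-only data.
    return value.decode() if isinstance(value, bytes) else value


# Normalize at the boundary; afterwards plain == is all that is needed.
incoming = ["ciao", b"ciao"]
normalized = [as_str(v) for v in incoming]
print(normalized[0] == normalized[1])  # True
```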

If you want to make the function itself marginally faster, you can add a check up front that skips the conversion entirely when both arguments already have the same type:

# Same type (str/str or bytes/bytes): plain == is already correct
if type(a) is type(b):
    return a == b
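Combined with the function from the question, the check goes first so that only mixed str/bytes pairs pay for a conversion. A sketch keeping the question's name and its use of `encode()`:

```python
from typing import Union


def is_equal_str_bytes(a: Union[str, bytes], b: Union[str, bytes]) -> bool:
    # Same type: plain == is already correct, so skip any conversion.
    if type(a) is type(b):
        return a == b
    # Mixed types: exactly one argument is a str; encode it and compare bytes.
    if isinstance(a, str):
        a = a.encode()
    if isinstance(b, str):
        b = b.encode()
    return a == b


print(is_equal_str_bytes("ciao", b"ciao"))  # True
print(is_equal_str_bytes(b"ciao", b"ciao"))  # True
```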

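A third option would be to introduce a new subclass, of str for example, and give it a comparison method that also accepts bytes, or a dummy decode() method. Then use it instead of the builtin class with builtins.str = my_str (in Python 3 the module is called builtins; __builtin__ was its Python 2 name). Note that this rebinding only affects explicit calls to str: string literals and strings returned by other builtins are still plain str.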
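A minimal sketch of the subclass idea, showing only the comparison variant and wrapping values explicitly rather than rebinding builtins.str. The class name `AsciiStr` is hypothetical, and it assumes ASCII-only content (bytes that are not valid ASCII simply compare unequal):

```python
class AsciiStr(str):
    """Hypothetical str subclass whose comparisons also accept ASCII bytes."""

    def __eq__(self, other):
        if isinstance(other, bytes):
            try:
                other = other.decode("ascii")
            except UnicodeDecodeError:
                # Non-ASCII bytes never match under the ASCII-only assumption.
                return False
        return str.__eq__(self, other)

    def __ne__(self, other):
        # The inherited str.__ne__ does not know about bytes, so mirror __eq__.
        result = self.__eq__(other)
        return result if result is NotImplemented else not result

    # Defining __eq__ sets __hash__ to None; keep str's hash so instances
    # still work as dict keys and set members.
    __hash__ = str.__hash__


s = AsciiStr("ciao")
print(s == b"ciao", b"ciao" == s, s != b"ciao")  # True True False
```

`b"ciao" == s` also works because bytes returns NotImplemented when compared with a str, so Python falls back to the reflected `AsciiStr.__eq__`. Defining `__ne__` explicitly matters here: without it, `s != b"ciao"` would use str's own comparison and return True.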
