cant download with pytube even with version 12 - python
I have seen other posts about this error, and I have applied all the recommended so-called 'fixes', but none of them are working in my case.
I have windows 11, AMD 5700g, Radeon Rx 560, 32gb RAM.
So here's the error text in full:
Traceback (most recent call last):
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\__main__.py", line 181, in fmt_streams
extract.apply_signature(stream_manifest, self.vid_info, self.js)
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\extract.py", line 409, in apply_signature
cipher = Cipher(js=js)
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\cipher.py", line 43, in __init__
self.throttling_plan = get_throttling_plan(js)
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\cipher.py", line 405, in get_throttling_plan
raw_code = get_throttling_function_code(js)
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\cipher.py", line 311, in get_throttling_function_code
name = re.escape(get_throttling_function_name(js))
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\cipher.py", line 296, in get_throttling_function_name
raise RegexMatchError(
pytube.exceptions.RegexMatchError: get_throttling_function_name: could not find match for multiple
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "C:\Users\Redje\Desktop\PythonFiles\YouTubeDownloader\YTD1.py", line 5, in <module>
stream = video.streams.get_highest_resolution()
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\__main__.py", line 296, in streams
return StreamQuery(self.fmt_streams)
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\__main__.py", line 188, in fmt_streams
extract.apply_signature(stream_manifest, self.vid_info, self.js)
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\extract.py", line 409, in apply_signature
cipher = Cipher(js=js)
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\cipher.py", line 43, in __init__
self.throttling_plan = get_throttling_plan(js)
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\cipher.py", line 405, in get_throttling_plan
raw_code = get_throttling_function_code(js)
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\cipher.py", line 311, in get_throttling_function_code
name = re.escape(get_throttling_function_name(js))
File "C:\Users\Redje\AppData\Local\Programs\Python\Python310\lib\site-packages\pytube\cipher.py", line 296, in get_throttling_function_name
raise RegexMatchError(
pytube.exceptions.RegexMatchError: get_throttling_function_name: could not find match for multiple
I have tweaked 'cipher.py' according to an earlier post, and no luck. Any help would be appreciated.
Here is my cipher.py code:
"""
This module contains all logic necessary to decipher the signature.
YouTube's strategy to restrict downloading videos is to send a ciphered version
of the signature to the client, along with the decryption algorithm obfuscated
in JavaScript. For the clients to play the videos, JavaScript must take the
ciphered version, cycle it through a series of "transform functions," and then
signs the media URL with the output.
This module is responsible for (1) finding and extracting those "transform
functions" (2) maps them to Python equivalents and (3) taking the ciphered
signature and decoding it.
"""
import logging
import re
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Tuple
from pytube.exceptions import ExtractError, RegexMatchError
from pytube.helpers import cache, regex_search
from pytube.parser import find_object_from_startpoint, throttling_array_split
logger = logging.getLogger(__name__)
class Cipher:
def __init__(self, js: str):
self.transform_plan: List[str] = get_transform_plan(js)
var_regex = re.compile(r"^\w+\W")
var_match = var_regex.search(self.transform_plan[0])
if not var_match:
raise RegexMatchError(
caller="__init__", pattern=var_regex.pattern
)
var = var_match.group(0)[:-1]
self.transform_map = get_transform_map(js, var)
self.js_func_patterns = [
r"\w+\.(\w+)\(\w,(\d+)\)",
r"\w+\[(\"\w+\")\]\(\w,(\d+)\)"
]
self.throttling_plan = get_throttling_plan(js)
self.throttling_array = get_throttling_function_array(js)
self.calculated_n = None
def calculate_n(self, initial_n: list):
"""Converts n to the correct value to prevent throttling."""
if self.calculated_n:
return self.calculated_n
# First, update all instances of 'b' with the list(initial_n)
for i in range(len(self.throttling_array)):
if self.throttling_array[i] == 'b':
self.throttling_array[i] = initial_n
for step in self.throttling_plan:
curr_func = self.throttling_array[int(step[0])]
if not callable(curr_func):
logger.debug(f'{curr_func} is not callable.')
logger.debug(f'Throttling array:\n{self.throttling_array}\n')
raise ExtractError(f'{curr_func} is not callable.')
first_arg = self.throttling_array[int(step[1])]
if len(step) == 2:
curr_func(first_arg)
elif len(step) == 3:
second_arg = self.throttling_array[int(step[2])]
curr_func(first_arg, second_arg)
self.calculated_n = ''.join(initial_n)
return self.calculated_n
def get_signature(self, ciphered_signature: str) -> str:
"""Decipher the signature.
Taking the ciphered signature, applies the transform functions.
:param str ciphered_signature:
The ciphered signature sent in the ``player_config``.
:rtype: str
:returns:
Decrypted signature required to download the media content.
"""
signature = list(ciphered_signature)
for js_func in self.transform_plan:
name, argument = self.parse_function(js_func) # type: ignore
signature = self.transform_map[name](signature, argument)
logger.debug(
"applied transform function\n"
"output: %s\n"
"js_function: %s\n"
"argument: %d\n"
"function: %s",
"".join(signature),
name,
argument,
self.transform_map[name],
)
return "".join(signature)
#cache
def parse_function(self, js_func: str) -> Tuple[str, int]:
"""Parse the Javascript transform function.
Break a JavaScript transform function down into a two element ``tuple``
containing the function name and some integer-based argument.
:param str js_func:
The JavaScript version of the transform function.
:rtype: tuple
:returns:
two element tuple containing the function name and an argument.
**Example**:
parse_function('DE.AJ(a,15)')
('AJ', 15)
"""
logger.debug("parsing transform function")
for pattern in self.js_func_patterns:
regex = re.compile(pattern)
parse_match = regex.search(js_func)
if parse_match:
fn_name, fn_arg = parse_match.groups()
return fn_name, int(fn_arg)
raise RegexMatchError(
caller="parse_function", pattern="js_func_patterns"
)
def get_initial_function_name(js: str) -> str:
"""Extract the name of the function responsible for computing the signature.
:param str js:
The contents of the base.js asset file.
:rtype: str
:returns:
Function name from regex match
"""
function_patterns = [
r"\b[cs]\s*&&\s*[adf]\.set\([^,]+\s*,\s*encodeURIComponent\s*\(\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\b[a-zA-Z0-9]+\s*&&\s*[a-zA-Z0-9]+\.set\([^,]+\s*,\s*encodeURIComponent\s*\(\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r'(?:\b|[^a-zA-Z0-9$])(?P<sig>[a-zA-Z0-9$]{2})\s*=\s*function\(\s*a\s*\)\s*{\s*a\s*=\s*a\.split\(\s*""\s*\)', # noqa: E501
r'(?P<sig>[a-zA-Z0-9$]+)\s*=\s*function\(\s*a\s*\)\s*{\s*a\s*=\s*a\.split\(\s*""\s*\)', # noqa: E501
r'(["\'])signature\1\s*,\s*(?P<sig>[a-zA-Z0-9$]+)\(',
r"\.sig\|\|(?P<sig>[a-zA-Z0-9$]+)\(",
r"yt\.akamaized\.net/\)\s*\|\|\s*.*?\s*[cs]\s*&&\s*[adf]\.set\([^,]+\s*,\s*(?:encodeURIComponent\s*\()?\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\b[cs]\s*&&\s*[adf]\.set\([^,]+\s*,\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\b[a-zA-Z0-9]+\s*&&\s*[a-zA-Z0-9]+\.set\([^,]+\s*,\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\bc\s*&&\s*a\.set\([^,]+\s*,\s*\([^)]*\)\s*\(\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\bc\s*&&\s*[a-zA-Z0-9]+\.set\([^,]+\s*,\s*\([^)]*\)\s*\(\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
r"\bc\s*&&\s*[a-zA-Z0-9]+\.set\([^,]+\s*,\s*\([^)]*\)\s*\(\s*(?P<sig>[a-zA-Z0-9$]+)\(", # noqa: E501
]
logger.debug("finding initial function name")
for pattern in function_patterns:
regex = re.compile(pattern)
function_match = regex.search(js)
if function_match:
logger.debug("finished regex search, matched: %s", pattern)
return function_match.group(1)
raise RegexMatchError(
caller="get_initial_function_name", pattern="multiple"
)
def get_transform_plan(js: str) -> List[str]:
"""Extract the "transform plan".
The "transform plan" is the functions that the ciphered signature is
cycled through to obtain the actual signature.
:param str js:
The contents of the base.js asset file.
**Example**:
['DE.AJ(a,15)',
'DE.VR(a,3)',
'DE.AJ(a,51)',
'DE.VR(a,3)',
'DE.kT(a,51)',
'DE.kT(a,8)',
'DE.VR(a,3)',
'DE.kT(a,21)']
"""
name = re.escape(get_initial_function_name(js))
pattern = r"%s=function\(\w\){[a-z=\.\(\"\)]*;(.*);(?:.+)}" % name
logger.debug("getting transform plan")
return regex_search(pattern, js, group=1).split(";")
def get_transform_object(js: str, var: str) -> List[str]:
"""Extract the "transform object".
The "transform object" contains the function definitions referenced in the
"transform plan". The ``var`` argument is the obfuscated variable name
which contains these functions, for example, given the function call
``DE.AJ(a,15)`` returned by the transform plan, "DE" would be the var.
:param str js:
The contents of the base.js asset file.
:param str var:
The obfuscated variable name that stores an object with all functions
that descrambles the signature.
**Example**:
>>> get_transform_object(js, 'DE')
['AJ:function(a){a.reverse()}',
'VR:function(a,b){a.splice(0,b)}',
'kT:function(a,b){var c=a[0];a[0]=a[b%a.length];a[b]=c}']
"""
pattern = r"var %s={(.*?)};" % re.escape(var)
logger.debug("getting transform object")
regex = re.compile(pattern, flags=re.DOTALL)
transform_match = regex.search(js)
if not transform_match:
raise RegexMatchError(caller="get_transform_object", pattern=pattern)
return transform_match.group(1).replace("\n", " ").split(", ")
def get_transform_map(js: str, var: str) -> Dict:
"""Build a transform function lookup.
Build a lookup table of obfuscated JavaScript function names to the
Python equivalents.
:param str js:
The contents of the base.js asset file.
:param str var:
The obfuscated variable name that stores an object with all functions
that descrambles the signature.
"""
transform_object = get_transform_object(js, var)
mapper = {}
for obj in transform_object:
# AJ:function(a){a.reverse()} => AJ, function(a){a.reverse()}
name, function = obj.split(":", 1)
fn = map_functions(function)
mapper[name] = fn
return mapper
def get_throttling_function_name(js: str) -> str:
"""Extract the name of the function that computes the throttling parameter.
:param str js:
The contents of the base.js asset file.
:rtype: str
:returns:
The name of the function used to compute the throttling parameter.
"""
function_patterns = [
# https://github.com/ytdl-org/youtube-dl/issues/29326#issuecomment-865985377
# https://github.com/yt-dlp/yt-dlp/commit/48416bc4a8f1d5ff07d5977659cb8ece7640dcd8
# var Bpa = [iha];
# ...
# a.C && (b = a.get("n")) && (b = Bpa[0](b), a.set("n", b),
# Bpa.length || iha("")) }};
# In the above case, `iha` is the relevant function name
r'a\.[a-zA-Z]\s*&&\s*\([a-z]\s*=\s*a\.get\("n"\)\)\s*&&\s*'
r'\([a-z]\s*=\s*([a-zA-Z0-9$]{2, 3})(\[\d+\])?\([a-z]\)',
]
logger.debug('Finding throttling function name')
for pattern in function_patterns:
regex = re.compile(pattern)
function_match = regex.search(js)
if function_match:
logger.debug("finished regex search, matched: %s", pattern)
if len(function_match.groups()) == 1:
return function_match.group(1)
idx = function_match.group(2)
if idx:
idx = idx.strip("[]")
array = re.search(
r'var {nfunc}\s*=\s*(\[.+?\]);'.format(
nfunc=escape(function_match.group(1))),
js
)
if array:
array = array.group(1).strip("[]").split(",")
array = [x.strip() for x in array]
return array[int(idx)]
raise RegexMatchError(
caller="get_throttling_function_name", pattern="multiple"
)
def get_throttling_function_code(js: str) -> str:
"""Extract the raw code for the throttling function.
:param str js:
The contents of the base.js asset file.
:rtype: str
:returns:
The name of the function used to compute the throttling parameter.
"""
# Begin by extracting the correct function name
name = re.escape(get_throttling_function_name(js))
# Identify where the function is defined
pattern_start = r"%s=function\(\w\)" % name
regex = re.compile(pattern_start)
match = regex.search(js)
# Extract the code within curly braces for the function itself, and merge any split lines
code_lines_list = find_object_from_startpoint(js, match.span()[1]).split('\n')
joined_lines = "".join(code_lines_list)
# Prepend function definition (e.g. `Dea=function(a)`)
return match.group(0) + joined_lines
def get_throttling_function_array(js: str) -> List[Any]:
"""Extract the "c" array.
:param str js:
The contents of the base.js asset file.
:returns:
The array of various integers, arrays, and functions.
"""
raw_code = get_throttling_function_code(js)
array_start = r",c=\["
array_regex = re.compile(array_start)
match = array_regex.search(raw_code)
array_raw = find_object_from_startpoint(raw_code, match.span()[1] - 1)
str_array = throttling_array_split(array_raw)
converted_array = []
for el in str_array:
try:
converted_array.append(int(el))
continue
except ValueError:
# Not an integer value.
pass
if el == 'null':
converted_array.append(None)
continue
if el.startswith('"') and el.endswith('"'):
# Convert e.g. '"abcdef"' to string without quotation marks, 'abcdef'
converted_array.append(el[1:-1])
continue
if el.startswith('function'):
mapper = (
(r"{for\(\w=\(\w%\w\.length\+\w\.length\)%\w\.length;\w--;\)\w\.unshift\(\w.pop\(\)\)}", throttling_unshift), # noqa:E501
(r"{\w\.reverse\(\)}", throttling_reverse),
(r"{\w\.push\(\w\)}", throttling_push),
(r";var\s\w=\w\[0\];\w\[0\]=\w\[\w\];\w\[\w\]=\w}", throttling_swap),
(r"case\s\d+", throttling_cipher_function),
(r"\w\.splice\(0,1,\w\.splice\(\w,1,\w\[0\]\)\[0\]\)", throttling_nested_splice), # noqa:E501
(r";\w\.splice\(\w,1\)}", js_splice),
(r"\w\.splice\(-\w\)\.reverse\(\)\.forEach\(function\(\w\){\w\.unshift\(\w\)}\)", throttling_prepend), # noqa:E501
(r"for\(var \w=\w\.length;\w;\)\w\.push\(\w\.splice\(--\w,1\)\[0\]\)}", throttling_reverse), # noqa:E501
)
found = False
for pattern, fn in mapper:
if re.search(pattern, el):
converted_array.append(fn)
found = True
if found:
continue
converted_array.append(el)
# Replace null elements with array itself
for i in range(len(converted_array)):
if converted_array[i] is None:
converted_array[i] = converted_array
return converted_array
def get_throttling_plan(js: str):
"""Extract the "throttling plan".
The "throttling plan" is a list of tuples used for calling functions
in the c array. The first element of the tuple is the index of the
function to call, and any remaining elements of the tuple are arguments
to pass to that function.
:param str js:
The contents of the base.js asset file.
:returns:
The full function code for computing the throttlign parameter.
"""
raw_code = get_throttling_function_code(js)
transform_start = r"try{"
plan_regex = re.compile(transform_start)
match = plan_regex.search(raw_code)
transform_plan_raw = find_object_from_startpoint(raw_code, match.span()[1] - 1)
# Steps are either c[x](c[y]) or c[x](c[y],c[z])
step_start = r"c\[(\d+)\]\(c\[(\d+)\](,c(\[(\d+)\]))?\)"
step_regex = re.compile(step_start)
matches = step_regex.findall(transform_plan_raw)
transform_steps = []
for match in matches:
if match[4] != '':
transform_steps.append((match[0],match[1],match[4]))
else:
transform_steps.append((match[0],match[1]))
return transform_steps
def reverse(arr: List, _: Optional[Any]):
"""Reverse elements in a list.
This function is equivalent to:
.. code-block:: javascript
function(a, b) { a.reverse() }
This method takes an unused ``b`` variable as their transform functions
universally sent two arguments.
**Example**:
>>> reverse([1, 2, 3, 4])
[4, 3, 2, 1]
"""
return arr[::-1]
def splice(arr: List, b: int):
"""Add/remove items to/from a list.
This function is equivalent to:
.. code-block:: javascript
function(a, b) { a.splice(0, b) }
**Example**:
>>> splice([1, 2, 3, 4], 2)
[1, 2]
"""
return arr[b:]
def swap(arr: List, b: int):
"""Swap positions at b modulus the list length.
This function is equivalent to:
.. code-block:: javascript
function(a, b) { var c=a[0];a[0]=a[b%a.length];a[b]=c }
**Example**:
>>> swap([1, 2, 3, 4], 2)
[3, 2, 1, 4]
"""
r = b % len(arr)
return list(chain([arr[r]], arr[1:r], [arr[0]], arr[r + 1 :]))
def throttling_reverse(arr: list):
"""Reverses the input list.
Needs to do an in-place reversal so that the passed list gets changed.
To accomplish this, we create a reversed copy, and then change each
indvidual element.
"""
reverse_copy = arr.copy()[::-1]
for i in range(len(reverse_copy)):
arr[i] = reverse_copy[i]
def throttling_push(d: list, e: Any):
"""Pushes an element onto a list."""
d.append(e)
def throttling_mod_func(d: list, e: int):
"""Perform the modular function from the throttling array functions.
In the javascript, the modular operation is as follows:
e = (e % d.length + d.length) % d.length
We simply translate this to python here.
"""
return (e % len(d) + len(d)) % len(d)
def throttling_unshift(d: list, e: int):
"""Rotates the elements of the list to the right.
In the javascript, the operation is as follows:
for(e=(e%d.length+d.length)%d.length;e--;)d.unshift(d.pop())
"""
e = throttling_mod_func(d, e)
new_arr = d[-e:] + d[:-e]
d.clear()
for el in new_arr:
d.append(el)
def throttling_cipher_function(d: list, e: str):
"""This ciphers d with e to generate a new list.
In the javascript, the operation is as follows:
var h = [A-Za-z0-9-_], f = 96; // simplified from switch-case loop
d.forEach(
function(l,m,n){
this.push(
n[m]=h[
(h.indexOf(l)-h.indexOf(this[m])+m-32+f--)%h.length
]
)
},
e.split("")
)
"""
h = list('ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_')
f = 96
# by naming it "this" we can more closely reflect the js
this = list(e)
# This is so we don't run into weirdness with enumerate while
# we change the input list
copied_list = d.copy()
for m, l in enumerate(copied_list):
bracket_val = (h.index(l) - h.index(this[m]) + m - 32 + f) % len(h)
this.append(
h[bracket_val]
)
d[m] = h[bracket_val]
f -= 1
def throttling_nested_splice(d: list, e: int):
"""Nested splice function in throttling js.
In the javascript, the operation is as follows:
function(d,e){
e=(e%d.length+d.length)%d.length;
d.splice(
0,
1,
d.splice(
e,
1,
d[0]
)[0]
)
}
While testing, all this seemed to do is swap element 0 and e,
but the actual process is preserved in case there was an edge
case that was not considered.
"""
e = throttling_mod_func(d, e)
inner_splice = js_splice(
d,
e,
1,
d[0]
)
js_splice(
d,
0,
1,
inner_splice[0]
)
def throttling_prepend(d: list, e: int):
"""
In the javascript, the operation is as follows:
function(d,e){
e=(e%d.length+d.length)%d.length;
d.splice(-e).reverse().forEach(
function(f){
d.unshift(f)
}
)
}
Effectively, this moves the last e elements of d to the beginning.
"""
start_len = len(d)
# First, calculate e
e = throttling_mod_func(d, e)
# Then do the prepending
new_arr = d[-e:] + d[:-e]
# And update the input list
d.clear()
for el in new_arr:
d.append(el)
end_len = len(d)
assert start_len == end_len
def throttling_swap(d: list, e: int):
"""Swap positions of the 0'th and e'th elements in-place."""
e = throttling_mod_func(d, e)
f = d[0]
d[0] = d[e]
d[e] = f
def js_splice(arr: list, start: int, delete_count=None, *items):
"""Implementation of javascript's splice function.
:param list arr:
Array to splice
:param int start:
Index at which to start changing the array
:param int delete_count:
Number of elements to delete from the array
:param *items:
Items to add to the array
Reference: https://developer.mozilla.org/en-US/docs/Web/JavaScript/Reference/Global_Objects/Array/splice # noqa:E501
"""
# Special conditions for start value
try:
if start > len(arr):
start = len(arr)
# If start is negative, count backwards from end
if start < 0:
start = len(arr) - start
except TypeError:
# Non-integer start values are treated as 0 in js
start = 0
# Special condition when delete_count is greater than remaining elements
if not delete_count or delete_count >= len(arr) - start:
delete_count = len(arr) - start # noqa: N806
deleted_elements = arr[start:start + delete_count]
# Splice appropriately.
new_arr = arr[:start] + list(items) + arr[start + delete_count:]
# Replace contents of input array
arr.clear()
for el in new_arr:
arr.append(el)
return deleted_elements
def map_functions(js_func: str) -> Callable:
"""For a given JavaScript transform function, return the Python equivalent.
:param str js_func:
The JavaScript version of the transform function.
"""
mapper = (
# function(a){a.reverse()}
(r"{\w\.reverse\(\)}", reverse),
# function(a,b){a.splice(0,b)}
(r"{\w\.splice\(0,\w\)}", splice),
# function(a,b){var c=a[0];a[0]=a[b%a.length];a[b]=c}
(r"{var\s\w=\w\[0\];\w\[0\]=\w\[\w\%\w.length\];\w\[\w\]=\w}", swap),
# function(a,b){var c=a[0];a[0]=a[b%a.length];a[b%a.length]=c}
(
r"{var\s\w=\w\[0\];\w\[0\]=\w\[\w\%\w.length\];\w\[\w\%\w.length\]=\w}",
swap,
),
)
for pattern, fn in mapper:
if re.search(pattern, js_func):
return fn
raise RegexMatchError(caller="map_functions", pattern="multiple")
Any help would be appreciated!
Related
Check which optional parameters are supplied to a function call
My goal is to run through all the *.py files in a directory and look at each call to a specific function test_func. This function has some optional parameters and I need to audit when the function is called with the optional parameters. My thought is to use the ast library (specifically ast.walk()). I suppose this is a static analysis problem. # function definition def test_func( name: str, *, user: Optional['User'] = None, request: Optional[WebRequest] = None, **kwargs ) -> bool: pass # somewhere in another file ... test_func('name0') test_func('name1', request=request) test_func('name1') test_func('name2', user=user) # figure out something like below: # name0 is never given any optional parameters # name1 is sometimes given request # name2 is always given user
Here is a POC : import typing from typing import Optional class User: pass class WebRequest: pass # function definition def test_func( name: str, *, user: Optional['User'] = None, request: Optional[WebRequest] = None, **kwargs ) -> bool: pass # somewhere in another file ... test_func('name0') test_func('name1', request=WebRequest()) test_func('name1') test_func('name2', user=User()) # figure out something like below: # name0 is never given any optional parameters # name1 is sometimes given request # name2 is always given user with open(__file__, "rt") as py_file: py_code = py_file.read() import collections each_call_kwargs_names_by_arg0_value: typing.Dict[str, typing.List[typing.Tuple[str, ...]]] = collections.defaultdict(list) import ast tree = ast.parse(py_code) for node in ast.walk(tree): if isinstance(node, ast.Call): if hasattr(node.func, "id"): name = node.func.id elif hasattr(node.func, "attr"): name = node.func.attr elif hasattr(node.func, "value"): name = node.func.value.id else: raise NotImplementedError print(name) if name == "test_func": arg0_value = typing.cast(ast.Str, node.args[0]).s each_call_kwargs_names_by_arg0_value[arg0_value].append( tuple(keyword.arg for keyword in node.keywords) ) for arg0_value, each_call_kwargs_names in each_call_kwargs_names_by_arg0_value.items(): frequency = "NEVER" if all(len(call_args) == 0 for call_args in each_call_kwargs_names) else \ "ALWAYS" if all(len(call_args) != 0 for call_args in each_call_kwargs_names) else \ "SOMETIMES" print(f"{arg0_value!r} {frequency}: {each_call_kwargs_names}") # Output : # 'name0' NEVER: [()] # 'name1' SOMETIMES: [('request',), ()] # 'name2' ALWAYS: [('user',)]
You can use a recursive generator function to traverse an ast of your Python code: import ast def get_calls(d, f = ['test_func']): if isinstance(d, ast.Call) and d.func.id in f: yield None if not d.args else d.args[0].value, [i.arg for i in d.keywords] for i in getattr(d, '_fields', []): vals = (m if isinstance((m:=getattr(d, i)), list) else [m]) yield from [j for k in vals for j in get_calls(k, f = f)] Putting it all together: import os, collections d = collections.defaultdict(list) for f in os.listdir(os.getcwd()): if f.endswith('.py'): with open(f) as f: for a, b in get_calls(ast.parse(f.read())): d[a].append(b) r = {a:{'verdict':'never' if not any(b) else 'always' if all(b) else 'sometimes', 'params':[i[0] for i in b if i]} for a, b in d.items()} Output: {'name0': {'verdict': 'never', 'params': []}, 'name1': {'verdict': 'sometimes', 'params': ['request']}, 'name2': {'verdict': 'always', 'params': ['user']}}
#rdflib (python): how to get a URIRef from a string such as 'ns:xxx'?
I have a RDFlib graph g, whose NameSpaceManager is aware of some namespaces. How do I get a URIRef from a string such as 'ns:xxx', where ns is the prefix associated to a namespace known by g.namespace_manager? Basically, I'm looking for a method which does the inverse operation of URIRef's n3(g.namespace_manager). I'm pretty confident that there is a way to do it, as a similar function is needed to parse turtle files, or sparql queries, but I can't find it. Otherwise of course, it must not be very difficult to write it. TIA
from rdflib import Graph, Namespace from rdflib.namespace import RDF g = Graph() NS = Namespace("http://example.com/") # then, say Xxx is a class and Aaa is an instance of Xxx... g.add((NS.Aaa, RDF.type, NS.Xxx)) # so use NS.Xxx (or NS["Xxx"]) to get a URIRef of NS.Xxx from Namespace NS print(type(NS)) # --> <class 'rdflib.term.URIRef'> print(type(NS.Xxx)) # --> <class 'rdflib.term.URIRef'> print(NS.Xxx) # --> "http://example.com/Xxx" If you want to bind a prefix within a graph, you use the rdflib Graph class' bind() method so, for the code above, you would use: g.bind("ns", NS) Now the graph, if serialized with a format that knows about prefixes, like Turtle, will use "ns". The above data would be: #prefix ns: <http://example.com/> . #prefix rdf: <http://www.w3.org/1999/02/22-rdf-syntax-ns#> . ns:Aaa rdf:type ns:Xxx . So, in Python, if you want to make the URI "http://example.com/Xxx"" from the string "ns:Xxx" you should have everything you need: the namespace declaration: NS = Namespace("http://example.com/") the prefix/namespace binding g.bind("ns", NS) IFF, on the other hand, you didn't declare the namespace yourself but it's in a graph and you only have the short form URI "ns:Xxx", you can do this to list all bound prefixes & namespaces used in the graph: for n in g.namespace_manager.namespaces(): print(n) returns, for the data above: ('xml', rdflib.term.URIRef('http://www.w3.org/XML/1998/namespace')) ('rdf', rdflib.term.URIRef('http://www.w3.org/1999/02/22-rdf-syntax-ns#')) ('rdfs', rdflib.term.URIRef('http://www.w3.org/2000/01/rdf-schema#')) ('xsd', rdflib.term.URIRef('http://www.w3.org/2001/XMLSchema#')) ('eg', rdflib.term.URIRef('http://example.com/')) So, if you know "eg:Xxx", you can split off the "eg" part and make the URI you want like this: print( [str(x[1]) for x in g.namespace_manager.namespaces() if x[0] == s.split(":")[0]] [0] + s.split(":")[1] ) prints: http://example.com/Xxx
from rdflib.namespace import NamespaceManager import rdflib as rdflib class MyNamespacesInfo: def __init__(self, namespace_manager: NamespaceManager): # as I don't know how to get the namespace from a prefix from the API # I construct a dict self.pref2ns = {} for pref, ns in namespace_manager.namespaces(): self.pref2ns[pref] = ns def uriref(self, n3uri: str) -> rdflib.URIRef: # n3uri: either 'ns:xxx', '<http://..../xxx>' or 'http://..../xxx' if n3uri[0] == '<': if n3uri[len(n3uri)-1] == '>': return rdflib.URIRef(n3uri[1:-1]) else: raise ValueError("Illegal uri: ", n3uri) else: return self.prefixed_2_uriref(n3uri, laxist=True) def prefixed_2_uriref(self, short_uri: str, laxist=True) -> rdflib.URIRef: # param short_uri eg. 'ns:xxx', where ns declared in namespace_manager # using laxist = True, you also can pass a long uri s = short_uri.split(':') if len(s) < 2: if laxist: return rdflib.URIRef(short_uri) else: raise ValueError('Not a prefix:localname string: ' + short_uri) prefix = s[0] ns = self.pref2ns.get(prefix) if ns == None: if laxist: return rdflib.URIRef(short_uri) else: raise ValueError('Unknown prefix: ' + prefix) else: x = ns + s[1] for i in range(2, len(s)): x = x + (':' + s[i]) return x # example of use: g = rdflib.Graph() g.parse('http://www.semanlink.net/tag/rdf_tools.rdf') ns_info = MyNamespacesInfo(g.namespace_manager) # all of following calls print http://www.semanlink.net/tag/rdflib x = ns_info.uriref('tag:rdflib') print(x) x = ns_info.uriref('http://www.semanlink.net/tag/rdflib') print(x) x = ns_info.uriref('<http://www.semanlink.net/tag/rdflib>') print(x)
Run python script to replace betacode with greek letters LaTeX
I want to convert the betacode in an existing .tex-File to normal greek letters. For example: I want to replace: \bcode{lo/gos} with simple: λόγος And so on for all other glyphs. Fortunately there seems to be a python-script that is supposed to do just that. But, being completely inexperienced I simply don’t know how to run it. Here is the code of the python sript: # beta2unicode.py # # Version 2004-11-23 # # James Tauber # http://jtauber.com/ # # You are free to redistribute this, but please inform me of any errors # # USAGE: # # trie = beta2unicodeTrie() # beta = "LO/GOS\n"; # unicode, remainder = trie.convert(beta) # # - to get final sigma, string must end in \n # - remainder will contain rest of beta if not all can be converted class Trie: def __init__(self): self.root = [None, {}] def add(self, key, value): curr_node = self.root for ch in key: curr_node = curr_node[1].setdefault(ch, [None, {}]) curr_node[0] = value def find(self, key): curr_node = self.root for ch in key: try: curr_node = curr_node[1][ch] except KeyError: return None return curr_node[0] def findp(self, key): curr_node = self.root remainder = key for ch in key: try: curr_node = curr_node[1][ch] except KeyError: return (curr_node[0], remainder) remainder = remainder[1:] return (curr_node[0], remainder) def convert(self, keystring): valuestring = "" key = keystring while key: value, key = self.findp(key) if not value: return (valuestring, key) valuestring += value return (valuestring, key) def beta2unicodeTrie(): t = Trie() t.add("*A", u"\u0391") t.add("*B", u"\u0392") t.add("*G", u"\u0393") t.add("*D", u"\u0394") t.add("*E", u"\u0395") t.add("*Z", u"\u0396") t.add("*H", u"\u0397") t.add("*Q", u"\u0398") t.add("*I", u"\u0399") t.add("*K", u"\u039A") t.add("*L", u"\u039B") t.add("*M", u"\u039C") t.add("*N", u"\u039D") t.add("*C", u"\u039E") t.add("*O", u"\u039F") t.add("*P", u"\u03A0") t.add("*R", u"\u03A1") t.add("*S", u"\u03A3") t.add("*T", u"\u03A4") t.add("*U", u"\u03A5") t.add("*F", u"\u03A6") t.add("*X", u"\u03A7") t.add("*Y", u"\u03A8") t.add("*W", u"\u03A9") t.add("A", u"\u03B1") t.add("B", u"\u03B2") t.add("G", u"\u03B3") t.add("D", u"\u03B4") t.add("E", u"\u03B5") t.add("Z", u"\u03B6") t.add("H", u"\u03B7") t.add("Q", u"\u03B8") t.add("I", u"\u03B9") t.add("K", u"\u03BA") t.add("L", u"\u03BB") t.add("M", u"\u03BC") t.add("N", u"\u03BD") t.add("C", u"\u03BE") t.add("O", u"\u03BF") t.add("P", u"\u03C0") t.add("R", u"\u03C1") t.add("S\n", u"\u03C2") t.add("S,", u"\u03C2,") t.add("S.", u"\u03C2.") t.add("S:", u"\u03C2:") t.add("S;", u"\u03C2;") t.add("S]", u"\u03C2]") t.add("S#", u"\u03C2#") t.add("S_", u"\u03C2_") t.add("S", u"\u03C3") t.add("T", u"\u03C4") t.add("U", u"\u03C5") t.add("F", u"\u03C6") t.add("X", u"\u03C7") t.add("Y", u"\u03C8") t.add("W", u"\u03C9") t.add("I+", U"\u03CA") t.add("U+", U"\u03CB") t.add("A)", u"\u1F00") t.add("A(", u"\u1F01") t.add("A)\\", u"\u1F02") t.add("A(\\", u"\u1F03") t.add("A)/", u"\u1F04") t.add("A(/", u"\u1F05") t.add("E)", u"\u1F10") t.add("E(", u"\u1F11") t.add("E)\\", u"\u1F12") t.add("E(\\", u"\u1F13") t.add("E)/", u"\u1F14") t.add("E(/", u"\u1F15") t.add("H)", u"\u1F20") t.add("H(", u"\u1F21") t.add("H)\\", u"\u1F22") t.add("H(\\", u"\u1F23") t.add("H)/", u"\u1F24") t.add("H(/", u"\u1F25") t.add("I)", u"\u1F30") t.add("I(", u"\u1F31") t.add("I)\\", u"\u1F32") t.add("I(\\", u"\u1F33") t.add("I)/", u"\u1F34") t.add("I(/", u"\u1F35") t.add("O)", u"\u1F40") t.add("O(", u"\u1F41") t.add("O)\\", u"\u1F42") t.add("O(\\", u"\u1F43") t.add("O)/", u"\u1F44") t.add("O(/", u"\u1F45") t.add("U)", u"\u1F50") t.add("U(", u"\u1F51") t.add("U)\\", u"\u1F52") t.add("U(\\", u"\u1F53") t.add("U)/", u"\u1F54") t.add("U(/", u"\u1F55") t.add("W)", u"\u1F60") t.add("W(", u"\u1F61") t.add("W)\\", u"\u1F62") t.add("W(\\", u"\u1F63") t.add("W)/", u"\u1F64") t.add("W(/", u"\u1F65") t.add("A)=", u"\u1F06") t.add("A(=", u"\u1F07") t.add("H)=", u"\u1F26") t.add("H(=", u"\u1F27") t.add("I)=", u"\u1F36") t.add("I(=", u"\u1F37") t.add("U)=", u"\u1F56") t.add("U(=", u"\u1F57") t.add("W)=", u"\u1F66") t.add("W(=", u"\u1F67") t.add("*A)", u"\u1F08") t.add("*)A", u"\u1F08") t.add("*A(", u"\u1F09") t.add("*(A", u"\u1F09") # t.add("*(\A", u"\u1F0B") t.add("*A)/", u"\u1F0C") t.add("*)/A", u"\u1F0C") t.add("*A(/", u"\u1F0F") t.add("*(/A", u"\u1F0F") t.add("*E)", u"\u1F18") t.add("*)E", u"\u1F18") t.add("*E(", u"\u1F19") t.add("*(E", u"\u1F19") # t.add("*(\E", u"\u1F1B") t.add("*E)/", u"\u1F1C") t.add("*)/E", u"\u1F1C") t.add("*E(/", u"\u1F1D") t.add("*(/E", u"\u1F1D") t.add("*H)", u"\u1F28") t.add("*)H", u"\u1F28") t.add("*H(", u"\u1F29") t.add("*(H", u"\u1F29") t.add("*H)\\", u"\u1F2A") t.add(")\\*H", u"\u1F2A") t.add("*)\\H", u"\u1F2A") # t.add("*H)/", u"\u1F2C") t.add("*)/H", u"\u1F2C") # t.add("*)=H", u"\u1F2E") t.add("(/*H", u"\u1F2F") t.add("*(/H", u"\u1F2F") t.add("*I)", u"\u1F38") t.add("*)I", u"\u1F38") t.add("*I(", u"\u1F39") t.add("*(I", u"\u1F39") # # t.add("*I)/", u"\u1F3C") t.add("*)/I", u"\u1F3C") # # t.add("*I(/", u"\u1F3F") t.add("*(/I", u"\u1F3F") # t.add("*O)", u"\u1F48") t.add("*)O", u"\u1F48") t.add("*O(", u"\u1F49") t.add("*(O", u"\u1F49") # # t.add("*(\O", u"\u1F4B") t.add("*O)/", u"\u1F4C") t.add("*)/O", u"\u1F4C") t.add("*O(/", u"\u1F4F") t.add("*(/O", u"\u1F4F") # t.add("*U(", u"\u1F59") t.add("*(U", u"\u1F59") # t.add("*(/U", u"\u1F5D") # t.add("*(=U", u"\u1F5F") t.add("*W)", u"\u1F68") t.add("*W(", u"\u1F69") t.add("*(W", u"\u1F69") # # t.add("*W)/", u"\u1F6C") t.add("*)/W", u"\u1F6C") t.add("*W(/", u"\u1F6F") t.add("*(/W", u"\u1F6F") t.add("*A)=", u"\u1F0E") t.add("*)=A", u"\u1F0E") t.add("*A(=", u"\u1F0F") t.add("*W)=", u"\u1F6E") t.add("*)=W", u"\u1F6E") t.add("*W(=", u"\u1F6F") t.add("*(=W", u"\u1F6F") t.add("A\\", u"\u1F70") t.add("A/", u"\u1F71") t.add("E\\", u"\u1F72") t.add("E/", u"\u1F73") t.add("H\\", u"\u1F74") t.add("H/", u"\u1F75") t.add("I\\", u"\u1F76") t.add("I/", u"\u1F77") t.add("O\\", u"\u1F78") t.add("O/", u"\u1F79") t.add("U\\", u"\u1F7A") t.add("U/", u"\u1F7B") t.add("W\\", u"\u1F7C") t.add("W/", u"\u1F7D") t.add("A)/|", u"\u1F84") t.add("A(/|", u"\u1F85") t.add("H)|", u"\u1F90") t.add("H(|", u"\u1F91") t.add("H)/|", u"\u1F94") t.add("H)=|", u"\u1F96") t.add("H(=|", u"\u1F97") t.add("W)|", u"\u1FA0") t.add("W(=|", u"\u1FA7") t.add("A=", u"\u1FB6") t.add("H=", u"\u1FC6") t.add("I=", u"\u1FD6") t.add("U=", u"\u1FE6") t.add("W=", u"\u1FF6") t.add("I\\+", u"\u1FD2") t.add("I/+", u"\u1FD3") t.add("I+/", u"\u1FD3") t.add("U\\+", u"\u1FE2") t.add("U/+", u"\u1FE3") t.add("A|", u"\u1FB3") t.add("A/|", u"\u1FB4") t.add("H|", u"\u1FC3") t.add("H/|", u"\u1FC4") t.add("W|", u"\u1FF3") t.add("W|/", u"\u1FF4") t.add("W/|", u"\u1FF4") t.add("A=|", u"\u1FB7") t.add("H=|", u"\u1FC7") t.add("W=|", u"\u1FF7") t.add("R(", u"\u1FE4") t.add("*R(", u"\u1FEC") t.add("*(R", u"\u1FEC") # t.add("~", u"~") # t.add("-", u"-") # t.add("(null)", u"(null)") # t.add("&", "&") t.add("0", u"0") t.add("1", u"1") t.add("2", u"2") t.add("3", u"3") t.add("4", u"4") t.add("5", u"5") t.add("6", u"6") t.add("7", u"7") t.add("8", u"8") t.add("9", u"9") t.add("#", u"#") t.add("$", u"$") t.add(" ", u" ") t.add(".", u".") t.add(",", u",") t.add("'", u"'") t.add(":", u":") t.add(";", u";") t.add("_", u"_") t.add("[", u"[") t.add("]", u"]") t.add("\n", u"") return t t = beta2unicodeTrie() import sys for line in file(sys.argv[1]): a, b = t.convert(line) if b: print a.encode("utf-8"), b raise Exception print a.encode("utf-8") And here is a little .tex-file with which it should work. \documentclass[12pt]{scrbook} \usepackage[polutonikogreek, ngerman]{babel} \usepackage[ngerman]{betababel} \usepackage{fontspec} %\defaultfontfeatures{Ligatures=TeX} %\newfontfeature{Microtype}{protrusion=default;expansion=default;} \begin{document} \bcode{lo/gos} \end{document} In case the script does not work: would it be possible to convert all the strings within the \bcode-Makro with something like regex? For example the "o/" to the ό and so on? What would be the weapon of choice here?
Do I have python installed? Try python -V at a shell prompt. Your code is python 2 code, so you will a python 2 version. I need to install Python Most straight forward way if you don't need a complex environment (and you don't for this problem) is just to go to python.org. Don't forget you need python 2. Running the program Generally it will be as simple as: python beta2unicode.py myfile.tex-file And to capture the output: python beta2unicode.py myfile.tex-file > myfile.not-tex-file Does the script work? Almost. You will need to replace the code at the end of the script that starts the same way this does, with this: import sys t = beta2unicodeTrie() import re BCODE = re.compile(r'\\bcode{[^}]*}') for line in open(sys.argv[1]): matches = BCODE.search(line) for match in BCODE.findall(line): bcode = match[7:-1] a, b = t.convert(bcode.upper()) if b: raise IOError("failed conversion '%s' in '%s'" % (b, line)) converted = a.encode("utf-8") line = line.replace(match, converted) print(line.rstrip()) Results \documentclass[12pt]{scrbook} \usepackage[polutonikogreek, ngerman]{babel} \usepackage[ngerman]{betababel} \usepackage{fontspec} %\defaultfontfeatures{Ligatures=TeX} %\newfontfeature{Microtype}{protrusion=default;expansion=default;} \begin{document} λόγοσ \end{document}
Weird bug in python project
I'm working on a big project in Python, and I've run into a bizarre error I can't explain. In one of my classes, I have a private method being called during instantiation: def _convertIndex(self, dimInd, dimName): '''Private function that converts numbers to numbers and non-integers using the subclass\'s convertIndex.''' print dimInd, ' is dimInd' try: return int(dimName) except: if dimName == '*': return 0 else: print self.param.sets, ' is self.param.sets' print type(self.param.sets), ' is the type of self.param.sets' print self.param.sets[dimInd], ' is the param at dimind' return self.param.sets[dimInd].value(dimName) What it's printing out: 0 is dimInd [<coremcs.SymbolicSet.SymbolicSet object at 0x10618ad90>] is self.param.sets <type 'list'> is the type of self.param.sets <SymbolicSet BAZ=['baz1', 'baz2', 'baz3', 'baz4']> is the param at dimind ====================================================================== ERROR: testParameterSet (GtapTest.TestGtapParameter) ---------------------------------------------------------------------- Traceback (most recent call last): File "/Users/myuser/Documents/workspace/ilucmc/gtapmcs/test/GtapTest.py", line 116, in testParameterSet pset = ParameterSet(prmFile, dstFile, GtapParameter) File "/Users/myuser/Documents/workspace/ilucmc/coremcs/ParameterSet.py", line 103, in __init__ self.distroDict, corrDefs = AbsBaseDistro.readFile(distFile, self.paramDict) File "/Users/myuser/Documents/workspace/ilucmc/coremcs/Distro.py", line 359, in readFile distro = cls.parseDistro(param, target, distroType, args) File "/Users/myuser/Documents/workspace/ilucmc/coremcs/Distro.py", line 301, in parseDistro return cls(param, target, distro, dim_names, argDict) File "/Users/myuser/Documents/workspace/ilucmc/coremcs/Distro.py", line 150, in __init__ self.dim_indxs = list(starmap(self._convertIndex, enumerate(dim_names))) # convert to numeric values and save in dim_indxs File "/Users/myuser/Documents/workspace/ilucmc/coremcs/Distro.py", line 194, in _convertIndex print self.param.sets[dimInd], ' is the param at dimind' IndexError: list index out of range Obviously this isn't the code for the whole class, but it represents something that I don't understand. The error is coming when I index into self.param.sets. Apparently, dimInd is out of range. the problem is, dimInd is 0, and self.param.sets is a list of length 1 (as shown from the print statements), so why can't I index into it? EDIT: For what it's worth, the __init__ method looks like this: ''' Stores a definitions of a distribution to be applied to a header variable. See the file setup/gtap/DistroDoc.txt for the details. ''' def __init__(self, param, target, distType, dim_names, argDict): self.name = param.name self.dim_names = dim_names self.dim_indxs = [] self.target = target.lower() if target else None self.distType = distType.lower() if distType else None self.rv = None self.argDict = {} self.modifier = defaultdict(lambda: None) self.param = param # Separate args into modifiers and distribution arguments for k, v in argDict.iteritems(): if k[0] == '_': # modifiers start with underscore self.modifier[k] = v else: self.argDict[k] = v # distribution arguments do not have underscore if self.target == 'index': print dim_names self.dim_indxs = list(starmap(self._convertIndex, enumerate(dim_names))) # convert to numeric values and save in dim_indxs if distType == 'discrete': entries = self.modifier['_entries'] if not entries: raise DistributionSpecError("Not enough arguments given to discrete distribution.") modDict = {k[1:]: float(v) for k, v in self.modifier.iteritems() if k[1:] in getOptionalArgs(DiscreteDist.__init__)} self.rv = DiscreteDist(entries, **modDict) return sig = DistroGen.signature(distType, self.argDict.keys()) gen = DistroGen.generator(sig) if gen is None: raise DistributionSpecError("Unknown distribution signature %s" % str(sig)) self.rv = gen.makeRV(self.argDict) # generate a frozen RV with the specified arguments self.isFactor = gen.isFactor
pretty print assertEqual() for HTML strings
I want to compare two strings in a python unittest which contain html. Is there a method which outputs the result in a human friendly (diff like) version?
A simple method is to strip whitespace from the HTML and split it into a list. Python 2.7's unittest (or the backported unittest2) then gives a human-readable diff between the lists. import re def split_html(html): return re.split(r'\s*\n\s*', html.strip()) def test_render_html(): expected = ['<div>', '...', '</div>'] got = split_html(render_html()) self.assertEqual(expected, got) If I'm writing a test for working code, I usually first set expected = [], insert a self.maxDiff = None before the assert and let the test fail once. The expected list can then be copy-pasted from the test output. You might need to tweak how whitespace is stripped depending on what your HTML looks like.
I submitted a patch to do this some years back. The patch was rejected but you can still view it on the python bug list. I doubt you would want to hack your unittest.py to apply the patch (if it even still works after all this time), but here's the function for reducing two strings a manageable size while still keeping at least part of what differs. So long as all you didn't want the complete differences this might be what you want: def shortdiff(x,y): '''shortdiff(x,y) Compare strings x and y and display differences. If the strings are too long, shorten them to fit in one line, while still keeping at least some difference. ''' import difflib LINELEN = 79 def limit(s): if len(s) > LINELEN: return s[:LINELEN-3] + '...' return s def firstdiff(s, t): span = 1000 for pos in range(0, max(len(s), len(t)), span): if s[pos:pos+span] != t[pos:pos+span]: for index in range(pos, pos+span): if s[index:index+1] != t[index:index+1]: return index left = LINELEN/4 index = firstdiff(x, y) if index > left + 7: x = x[:left] + '...' + x[index-4:index+LINELEN] y = y[:left] + '...' + y[index-4:index+LINELEN] else: x, y = x[:LINELEN+1], y[:LINELEN+1] left = 0 cruncher = difflib.SequenceMatcher(None) xtags = ytags = "" cruncher.set_seqs(x, y) editchars = { 'replace': ('^', '^'), 'delete': ('-', ''), 'insert': ('', '+'), 'equal': (' ',' ') } for tag, xi1, xi2, yj1, yj2 in cruncher.get_opcodes(): lx, ly = xi2 - xi1, yj2 - yj1 edits = editchars[tag] xtags += edits[0] * lx ytags += edits[1] * ly # Include ellipsis in edits line. if left: xtags = xtags[:left] + '...' + xtags[left+3:] ytags = ytags[:left] + '...' + ytags[left+3:] diffs = [ x, xtags, y, ytags ] if max([len(s) for s in diffs]) < LINELEN: return '\n'.join(diffs) diffs = [ limit(s) for s in diffs ] return '\n'.join(diffs)
Maybe this is a quite 'verbose' solution. You could add a new 'equality function' for your user defined type (e.g: HTMLString) which you have to define first: class HTMLString(str): pass Now you have to define a type equality function: def assertHTMLStringEqual(first, second): if first != second: message = ... # TODO here: format your message, e.g a diff raise AssertionError(message) All you have to do is format your message as you like. You can also use a class method in your specific TestCase as a type equality function. This gives you more functionality to format your message, since unittest.TestCase does this a lot. Now you have to register this equality function in your unittest.TestCase: ... def __init__(self): self.addTypeEqualityFunc(HTMLString, assertHTMLStringEqual) The same for a class method: ... def __init__(self): self.addTypeEqualityFunc(HTMLString, 'assertHTMLStringEqual') And now you can use it in your tests: def test_something(self): htmlstring1 = HTMLString(...) htmlstring2 = HTMLString(...) self.assertEqual(htmlstring1, htmlstring2) This should work well with python 2.7.
I (the one asking this question) use BeautfulSoup now: def assertEqualHTML(string1, string2, file1='', file2=''): u''' Compare two unicode strings containing HTML. A human friendly diff goes to logging.error() if there are not equal, and an exception gets raised. ''' from BeautifulSoup import BeautifulSoup as bs import difflib def short(mystr): max=20 if len(mystr)>max: return mystr[:max] return mystr p=[] for mystr, file in [(string1, file1), (string2, file2)]: if not isinstance(mystr, unicode): raise Exception(u'string ist not unicode: %r %s' % (short(mystr), file)) soup=bs(mystr) pretty=soup.prettify() p.append(pretty) if p[0]!=p[1]: for line in difflib.unified_diff(p[0].splitlines(), p[1].splitlines(), fromfile=file1, tofile=file2): logging.error(line) raise Exception('Not equal %s %s' % (file1, file2))