New function get_elements_by_class() and get_elements_by_attribute() which return lists of hits, instead of only one

Deleted duplicate codes, calling new functions from old ones

New Tests for Function get_element_by_attribute() + get_elements_by_class() + get_elements_by_attribute() + Bug fixing of functions

Review as requested by dstftw

changes as requested

Checked with flake8
This commit is contained in:
Thomas Christlieb 2017-02-09 13:13:03 +01:00
parent 5abcca9060
commit c4835496cf
2 changed files with 51 additions and 10 deletions

View File

@ -34,6 +34,9 @@ from youtube_dl.utils import (
find_xpath_attr, find_xpath_attr,
fix_xml_ampersands, fix_xml_ampersands,
get_element_by_class, get_element_by_class,
get_element_by_attribute,
get_elements_by_class,
get_elements_by_attribute,
InAdvancePagedList, InAdvancePagedList,
intlist_to_bytes, intlist_to_bytes,
is_html, is_html,
@ -1124,6 +1127,32 @@ The first line
self.assertEqual(get_element_by_class('foo', html), 'nice') self.assertEqual(get_element_by_class('foo', html), 'nice')
self.assertEqual(get_element_by_class('no-such-class', html), None) self.assertEqual(get_element_by_class('no-such-class', html), None)
def test_get_element_by_attribute(self):
html = '''
<span class="foo bar">nice</span>
'''
self.assertEqual(get_element_by_attribute('class', 'foo bar', html), 'nice')
self.assertEqual(get_element_by_attribute('class', 'foo', html), None)
self.assertEqual(get_element_by_attribute('class', 'no-such-foo', html), None)
def test_get_elements_by_class(self):
html = '''
<span class="foo bar">nice</span><span class="foo bar">also nice</span>
'''
self.assertEqual(get_elements_by_class('foo', html), ['nice', 'also nice'])
self.assertEqual(get_elements_by_class('no-such-class', html), [])
def test_get_elements_by_attribute(self):
html = '''
<span class="foo bar">nice</span><span class="foo bar">also nice</span>
'''
self.assertEqual(get_elements_by_attribute('class', 'foo bar', html), ['nice', 'also nice'])
self.assertEqual(get_elements_by_attribute('class', 'foo', html), [])
self.assertEqual(get_elements_by_attribute('class', 'no-such-foo', html), [])
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()

View File

@ -337,17 +337,30 @@ def get_element_by_id(id, html):
def get_element_by_class(class_name, html): def get_element_by_class(class_name, html):
return get_element_by_attribute( """Return the content of the first tag with the specified class in the passed HTML document"""
retval = get_elements_by_class(class_name, html)
return retval[0] if retval else None
def get_element_by_attribute(attribute, value, html, escape_value=True):
retval = get_elements_by_attribute(attribute, value, html, escape_value)
return retval[0] if retval else None
def get_elements_by_class(class_name, html):
"""Return the content of all tags with the specified class in the passed HTML document as a list"""
return get_elements_by_attribute(
'class', r'[^\'"]*\b%s\b[^\'"]*' % re.escape(class_name), 'class', r'[^\'"]*\b%s\b[^\'"]*' % re.escape(class_name),
html, escape_value=False) html, escape_value=False)
def get_element_by_attribute(attribute, value, html, escape_value=True): def get_elements_by_attribute(attribute, value, html, escape_value=True):
"""Return the content of the tag with the specified attribute in the passed HTML document""" """Return the content of the tag with the specified attribute in the passed HTML document"""
value = re.escape(value) if escape_value else value value = re.escape(value) if escape_value else value
m = re.search(r'''(?xs) retlist = []
for m in re.finditer(r'''(?xs)
<([a-zA-Z0-9:._-]+) <([a-zA-Z0-9:._-]+)
(?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'))*? (?:\s+[a-zA-Z0-9:._-]+(?:=[a-zA-Z0-9:._-]*|="[^"]*"|='[^']*'))*?
\s+%s=['"]?%s['"]? \s+%s=['"]?%s['"]?
@ -355,16 +368,15 @@ def get_element_by_attribute(attribute, value, html, escape_value=True):
\s*> \s*>
(?P<content>.*?) (?P<content>.*?)
</\1> </\1>
''' % (re.escape(attribute), value), html) ''' % (re.escape(attribute), value), html):
res = m.group('content')
if not m: if res.startswith('"') or res.startswith("'"):
return None res = res[1:-1]
res = m.group('content')
if res.startswith('"') or res.startswith("'"): retlist.append(unescapeHTML(res))
res = res[1:-1]
return unescapeHTML(res) return retlist
class HTMLAttributeParser(compat_HTMLParser): class HTMLAttributeParser(compat_HTMLParser):