65 files changed, 13566 insertions(+), 48 deletions(-)
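This diff adds the atoma feed-parsing package (Atom, RSS 2.0, JSON Feed and OPML parsers) together with a vendored copy of attrs 18.2.0. Before the file-by-file listing, here is a minimal usage sketch based on the entry points defined in python/atoma/atom.py; the sample feed bytes are made up for illustration, but the functions and attribute names are the ones introduced below:

import atoma

ATOM_SAMPLE = b"""<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <id>urn:example:feed</id>
  <title>Example Feed</title>
  <updated>2018-01-01T00:00:00Z</updated>
  <entry>
    <id>urn:example:entry-1</id>
    <title>First post</title>
    <updated>2018-01-01T00:00:00Z</updated>
    <link href="https://example.com/first-post" rel="alternate"/>
  </entry>
</feed>"""

# parse_atom_bytes returns an AtomFeed built from the attrs classes
# defined in atom.py; text constructs expose their content via .value
feed = atoma.parse_atom_bytes(ATOM_SAMPLE)
print(feed.title.value)                  # "Example Feed"
for entry in feed.entries:
    print(entry.id_, entry.title.value)

Equivalent byte-string entry points exist for the other formats (parse_rss_bytes, parse_json_feed_bytes, parse_opml_bytes), and python/atoma/simple.py adds simple_parse_bytes, which tries the RSS, Atom and JSON Feed parsers in turn and normalizes the result into a single Feed/Article model.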
diff --git a/python/atoma/__init__.py b/python/atoma/__init__.py new file mode 100644 index 0000000..0768081 --- /dev/null +++ b/python/atoma/__init__.py @@ -0,0 +1,12 @@ +from .atom import parse_atom_file, parse_atom_bytes +from .rss import parse_rss_file, parse_rss_bytes +from .json_feed import ( + parse_json_feed, parse_json_feed_file, parse_json_feed_bytes +) +from .opml import parse_opml_file, parse_opml_bytes +from .exceptions import ( + FeedParseError, FeedDocumentError, FeedXMLError, FeedJSONError +) +from .const import VERSION + +__version__ = VERSION diff --git a/python/atoma/atom.py b/python/atoma/atom.py new file mode 100644 index 0000000..d4e676c --- /dev/null +++ b/python/atoma/atom.py @@ -0,0 +1,284 @@ +from datetime import datetime +import enum +from io import BytesIO +from typing import Optional, List +from xml.etree.ElementTree import Element + +import attr + +from .utils import ( + parse_xml, get_child, get_text, get_datetime, FeedParseError, ns +) + + +class AtomTextType(enum.Enum): + text = "text" + html = "html" + xhtml = "xhtml" + + +@attr.s +class AtomTextConstruct: + text_type: str = attr.ib() + lang: Optional[str] = attr.ib() + value: str = attr.ib() + + +@attr.s +class AtomEntry: + title: AtomTextConstruct = attr.ib() + id_: str = attr.ib() + + # Should be mandatory but many feeds use published instead + updated: Optional[datetime] = attr.ib() + + authors: List['AtomPerson'] = attr.ib() + contributors: List['AtomPerson'] = attr.ib() + links: List['AtomLink'] = attr.ib() + categories: List['AtomCategory'] = attr.ib() + published: Optional[datetime] = attr.ib() + rights: Optional[AtomTextConstruct] = attr.ib() + summary: Optional[AtomTextConstruct] = attr.ib() + content: Optional[AtomTextConstruct] = attr.ib() + source: Optional['AtomFeed'] = attr.ib() + + +@attr.s +class AtomFeed: + title: Optional[AtomTextConstruct] = attr.ib() + id_: str = attr.ib() + + # Should be mandatory but many feeds do not include it + updated: Optional[datetime] = attr.ib() + + authors: List['AtomPerson'] = attr.ib() + contributors: List['AtomPerson'] = attr.ib() + links: List['AtomLink'] = attr.ib() + categories: List['AtomCategory'] = attr.ib() + generator: Optional['AtomGenerator'] = attr.ib() + subtitle: Optional[AtomTextConstruct] = attr.ib() + rights: Optional[AtomTextConstruct] = attr.ib() + icon: Optional[str] = attr.ib() + logo: Optional[str] = attr.ib() + + entries: List[AtomEntry] = attr.ib() + + +@attr.s +class AtomPerson: + name: str = attr.ib() + uri: Optional[str] = attr.ib() + email: Optional[str] = attr.ib() + + +@attr.s +class AtomLink: + href: str = attr.ib() + rel: Optional[str] = attr.ib() + type_: Optional[str] = attr.ib() + hreflang: Optional[str] = attr.ib() + title: Optional[str] = attr.ib() + length: Optional[int] = attr.ib() + + +@attr.s +class AtomCategory: + term: str = attr.ib() + scheme: Optional[str] = attr.ib() + label: Optional[str] = attr.ib() + + +@attr.s +class AtomGenerator: + name: str = attr.ib() + uri: Optional[str] = attr.ib() + version: Optional[str] = attr.ib() + + +def _get_generator(element: Element, name, + optional: bool=True) -> Optional[AtomGenerator]: + child = get_child(element, name, optional) + if child is None: + return None + + return AtomGenerator( + child.text.strip(), + child.attrib.get('uri'), + child.attrib.get('version'), + ) + + +def _get_text_construct(element: Element, name, + optional: bool=True) -> Optional[AtomTextConstruct]: + child = get_child(element, name, optional) + if child is None: + return None + + try: + 
text_type = AtomTextType(child.attrib['type']) + except KeyError: + text_type = AtomTextType.text + + try: + lang = child.lang + except AttributeError: + lang = None + + if child.text is None: + if optional: + return None + + raise FeedParseError( + 'Could not parse atom feed: "{}" text is required but is empty' + .format(name) + ) + + return AtomTextConstruct( + text_type, + lang, + child.text.strip() + ) + + +def _get_person(element: Element) -> Optional[AtomPerson]: + try: + return AtomPerson( + get_text(element, 'feed:name', optional=False), + get_text(element, 'feed:uri'), + get_text(element, 'feed:email') + ) + except FeedParseError: + return None + + +def _get_link(element: Element) -> AtomLink: + length = element.attrib.get('length') + length = int(length) if length else None + return AtomLink( + element.attrib['href'], + element.attrib.get('rel'), + element.attrib.get('type'), + element.attrib.get('hreflang'), + element.attrib.get('title'), + length + ) + + +def _get_category(element: Element) -> AtomCategory: + return AtomCategory( + element.attrib['term'], + element.attrib.get('scheme'), + element.attrib.get('label'), + ) + + +def _get_entry(element: Element, + default_authors: List[AtomPerson]) -> AtomEntry: + root = element + + # Mandatory + title = _get_text_construct(root, 'feed:title') + id_ = get_text(root, 'feed:id') + + # Optional + try: + source = _parse_atom(get_child(root, 'feed:source', optional=False), + parse_entries=False) + except FeedParseError: + source = None + source_authors = [] + else: + source_authors = source.authors + + authors = [_get_person(e) + for e in root.findall('feed:author', ns)] or default_authors + authors = [a for a in authors if a is not None] + authors = authors or default_authors or source_authors + + contributors = [_get_person(e) + for e in root.findall('feed:contributor', ns) if e] + contributors = [c for c in contributors if c is not None] + + links = [_get_link(e) for e in root.findall('feed:link', ns)] + categories = [_get_category(e) for e in root.findall('feed:category', ns)] + + updated = get_datetime(root, 'feed:updated') + published = get_datetime(root, 'feed:published') + rights = _get_text_construct(root, 'feed:rights') + summary = _get_text_construct(root, 'feed:summary') + content = _get_text_construct(root, 'feed:content') + + return AtomEntry( + title, + id_, + updated, + authors, + contributors, + links, + categories, + published, + rights, + summary, + content, + source + ) + + +def _parse_atom(root: Element, parse_entries: bool=True) -> AtomFeed: + # Mandatory + id_ = get_text(root, 'feed:id', optional=False) + + # Optional + title = _get_text_construct(root, 'feed:title') + updated = get_datetime(root, 'feed:updated') + authors = [_get_person(e) + for e in root.findall('feed:author', ns) if e] + authors = [a for a in authors if a is not None] + contributors = [_get_person(e) + for e in root.findall('feed:contributor', ns) if e] + contributors = [c for c in contributors if c is not None] + links = [_get_link(e) + for e in root.findall('feed:link', ns)] + categories = [_get_category(e) + for e in root.findall('feed:category', ns)] + + generator = _get_generator(root, 'feed:generator') + subtitle = _get_text_construct(root, 'feed:subtitle') + rights = _get_text_construct(root, 'feed:rights') + icon = get_text(root, 'feed:icon') + logo = get_text(root, 'feed:logo') + + if parse_entries: + entries = [_get_entry(e, authors) + for e in root.findall('feed:entry', ns)] + else: + entries = [] + + atom_feed = AtomFeed( + title, + 
id_, + updated, + authors, + contributors, + links, + categories, + generator, + subtitle, + rights, + icon, + logo, + entries + ) + return atom_feed + + +def parse_atom_file(filename: str) -> AtomFeed: + """Parse an Atom feed from a local XML file.""" + root = parse_xml(filename).getroot() + return _parse_atom(root) + + +def parse_atom_bytes(data: bytes) -> AtomFeed: + """Parse an Atom feed from a byte-string containing XML data.""" + root = parse_xml(BytesIO(data)).getroot() + return _parse_atom(root) diff --git a/python/atoma/const.py b/python/atoma/const.py new file mode 100644 index 0000000..d52d0f6 --- /dev/null +++ b/python/atoma/const.py @@ -0,0 +1 @@ +VERSION = '0.0.13' diff --git a/python/atoma/exceptions.py b/python/atoma/exceptions.py new file mode 100644 index 0000000..88170c5 --- /dev/null +++ b/python/atoma/exceptions.py @@ -0,0 +1,14 @@ +class FeedParseError(Exception): + """Document is an invalid feed.""" + + +class FeedDocumentError(Exception): + """Document is not a supported file.""" + + +class FeedXMLError(FeedDocumentError): + """Document is not valid XML.""" + + +class FeedJSONError(FeedDocumentError): + """Document is not valid JSON.""" diff --git a/python/atoma/json_feed.py b/python/atoma/json_feed.py new file mode 100644 index 0000000..410ff4a --- /dev/null +++ b/python/atoma/json_feed.py @@ -0,0 +1,223 @@ +from datetime import datetime, timedelta +import json +from typing import Optional, List + +import attr + +from .exceptions import FeedParseError, FeedJSONError +from .utils import try_parse_date + + +@attr.s +class JSONFeedAuthor: + + name: Optional[str] = attr.ib() + url: Optional[str] = attr.ib() + avatar: Optional[str] = attr.ib() + + +@attr.s +class JSONFeedAttachment: + + url: str = attr.ib() + mime_type: str = attr.ib() + title: Optional[str] = attr.ib() + size_in_bytes: Optional[int] = attr.ib() + duration: Optional[timedelta] = attr.ib() + + +@attr.s +class JSONFeedItem: + + id_: str = attr.ib() + url: Optional[str] = attr.ib() + external_url: Optional[str] = attr.ib() + title: Optional[str] = attr.ib() + content_html: Optional[str] = attr.ib() + content_text: Optional[str] = attr.ib() + summary: Optional[str] = attr.ib() + image: Optional[str] = attr.ib() + banner_image: Optional[str] = attr.ib() + date_published: Optional[datetime] = attr.ib() + date_modified: Optional[datetime] = attr.ib() + author: Optional[JSONFeedAuthor] = attr.ib() + + tags: List[str] = attr.ib() + attachments: List[JSONFeedAttachment] = attr.ib() + + +@attr.s +class JSONFeed: + + version: str = attr.ib() + title: str = attr.ib() + home_page_url: Optional[str] = attr.ib() + feed_url: Optional[str] = attr.ib() + description: Optional[str] = attr.ib() + user_comment: Optional[str] = attr.ib() + next_url: Optional[str] = attr.ib() + icon: Optional[str] = attr.ib() + favicon: Optional[str] = attr.ib() + author: Optional[JSONFeedAuthor] = attr.ib() + expired: bool = attr.ib() + + items: List[JSONFeedItem] = attr.ib() + + +def _get_items(root: dict) -> List[JSONFeedItem]: + rv = [] + items = root.get('items', []) + if not items: + return rv + + for item in items: + rv.append(_get_item(item)) + + return rv + + +def _get_item(item_dict: dict) -> JSONFeedItem: + return JSONFeedItem( + id_=_get_text(item_dict, 'id', optional=False), + url=_get_text(item_dict, 'url'), + external_url=_get_text(item_dict, 'external_url'), + title=_get_text(item_dict, 'title'), + content_html=_get_text(item_dict, 'content_html'), + content_text=_get_text(item_dict, 'content_text'), + 
summary=_get_text(item_dict, 'summary'), + image=_get_text(item_dict, 'image'), + banner_image=_get_text(item_dict, 'banner_image'), + date_published=_get_datetime(item_dict, 'date_published'), + date_modified=_get_datetime(item_dict, 'date_modified'), + author=_get_author(item_dict), + tags=_get_tags(item_dict, 'tags'), + attachments=_get_attachments(item_dict, 'attachments') + ) + + +def _get_attachments(root, name) -> List[JSONFeedAttachment]: + rv = list() + for attachment_dict in root.get(name, []): + rv.append(JSONFeedAttachment( + _get_text(attachment_dict, 'url', optional=False), + _get_text(attachment_dict, 'mime_type', optional=False), + _get_text(attachment_dict, 'title'), + _get_int(attachment_dict, 'size_in_bytes'), + _get_duration(attachment_dict, 'duration_in_seconds') + )) + return rv + + +def _get_tags(root, name) -> List[str]: + tags = root.get(name, []) + return [tag for tag in tags if isinstance(tag, str)] + + +def _get_datetime(root: dict, name, optional: bool=True) -> Optional[datetime]: + text = _get_text(root, name, optional) + if text is None: + return None + + return try_parse_date(text) + + +def _get_expired(root: dict) -> bool: + if root.get('expired') is True: + return True + + return False + + +def _get_author(root: dict) -> Optional[JSONFeedAuthor]: + author_dict = root.get('author') + if not author_dict: + return None + + rv = JSONFeedAuthor( + name=_get_text(author_dict, 'name'), + url=_get_text(author_dict, 'url'), + avatar=_get_text(author_dict, 'avatar'), + ) + if rv.name is None and rv.url is None and rv.avatar is None: + return None + + return rv + + +def _get_int(root: dict, name: str, optional: bool=True) -> Optional[int]: + rv = root.get(name) + if not optional and rv is None: + raise FeedParseError('Could not parse feed: "{}" int is required but ' + 'is empty'.format(name)) + + if optional and rv is None: + return None + + if not isinstance(rv, int): + raise FeedParseError('Could not parse feed: "{}" is not an int' + .format(name)) + + return rv + + +def _get_duration(root: dict, name: str, + optional: bool=True) -> Optional[timedelta]: + duration = _get_int(root, name, optional) + if duration is None: + return None + + return timedelta(seconds=duration) + + +def _get_text(root: dict, name: str, optional: bool=True) -> Optional[str]: + rv = root.get(name) + if not optional and rv is None: + raise FeedParseError('Could not parse feed: "{}" text is required but ' + 'is empty'.format(name)) + + if optional and rv is None: + return None + + if not isinstance(rv, str): + raise FeedParseError('Could not parse feed: "{}" is not a string' + .format(name)) + + return rv + + +def parse_json_feed(root: dict) -> JSONFeed: + return JSONFeed( + version=_get_text(root, 'version', optional=False), + title=_get_text(root, 'title', optional=False), + home_page_url=_get_text(root, 'home_page_url'), + feed_url=_get_text(root, 'feed_url'), + description=_get_text(root, 'description'), + user_comment=_get_text(root, 'user_comment'), + next_url=_get_text(root, 'next_url'), + icon=_get_text(root, 'icon'), + favicon=_get_text(root, 'favicon'), + author=_get_author(root), + expired=_get_expired(root), + items=_get_items(root) + ) + + +def parse_json_feed_file(filename: str) -> JSONFeed: + """Parse a JSON feed from a local json file.""" + with open(filename) as f: + try: + root = json.load(f) + except json.decoder.JSONDecodeError: + raise FeedJSONError('Not a valid JSON document') + + return parse_json_feed(root) + + +def parse_json_feed_bytes(data: bytes) -> JSONFeed: + 
"""Parse a JSON feed from a byte-string containing JSON data.""" + try: + root = json.loads(data) + except json.decoder.JSONDecodeError: + raise FeedJSONError('Not a valid JSON document') + + return parse_json_feed(root) diff --git a/python/atoma/opml.py b/python/atoma/opml.py new file mode 100644 index 0000000..a73105e --- /dev/null +++ b/python/atoma/opml.py @@ -0,0 +1,107 @@ +from datetime import datetime +from io import BytesIO +from typing import Optional, List +from xml.etree.ElementTree import Element + +import attr + +from .utils import parse_xml, get_text, get_int, get_datetime + + +@attr.s +class OPMLOutline: + text: Optional[str] = attr.ib() + type: Optional[str] = attr.ib() + xml_url: Optional[str] = attr.ib() + description: Optional[str] = attr.ib() + html_url: Optional[str] = attr.ib() + language: Optional[str] = attr.ib() + title: Optional[str] = attr.ib() + version: Optional[str] = attr.ib() + + outlines: List['OPMLOutline'] = attr.ib() + + +@attr.s +class OPML: + title: Optional[str] = attr.ib() + owner_name: Optional[str] = attr.ib() + owner_email: Optional[str] = attr.ib() + date_created: Optional[datetime] = attr.ib() + date_modified: Optional[datetime] = attr.ib() + expansion_state: Optional[str] = attr.ib() + + vertical_scroll_state: Optional[int] = attr.ib() + window_top: Optional[int] = attr.ib() + window_left: Optional[int] = attr.ib() + window_bottom: Optional[int] = attr.ib() + window_right: Optional[int] = attr.ib() + + outlines: List[OPMLOutline] = attr.ib() + + +def _get_outlines(element: Element) -> List[OPMLOutline]: + rv = list() + + for outline in element.findall('outline'): + rv.append(OPMLOutline( + outline.attrib.get('text'), + outline.attrib.get('type'), + outline.attrib.get('xmlUrl'), + outline.attrib.get('description'), + outline.attrib.get('htmlUrl'), + outline.attrib.get('language'), + outline.attrib.get('title'), + outline.attrib.get('version'), + _get_outlines(outline) + )) + + return rv + + +def _parse_opml(root: Element) -> OPML: + head = root.find('head') + body = root.find('body') + + return OPML( + get_text(head, 'title'), + get_text(head, 'ownerName'), + get_text(head, 'ownerEmail'), + get_datetime(head, 'dateCreated'), + get_datetime(head, 'dateModified'), + get_text(head, 'expansionState'), + get_int(head, 'vertScrollState'), + get_int(head, 'windowTop'), + get_int(head, 'windowLeft'), + get_int(head, 'windowBottom'), + get_int(head, 'windowRight'), + outlines=_get_outlines(body) + ) + + +def parse_opml_file(filename: str) -> OPML: + """Parse an OPML document from a local XML file.""" + root = parse_xml(filename).getroot() + return _parse_opml(root) + + +def parse_opml_bytes(data: bytes) -> OPML: + """Parse an OPML document from a byte-string containing XML data.""" + root = parse_xml(BytesIO(data)).getroot() + return _parse_opml(root) + + +def get_feed_list(opml_obj: OPML) -> List[str]: + """Walk an OPML document to extract the list of feed it contains.""" + rv = list() + + def collect(obj): + for outline in obj.outlines: + if outline.type == 'rss' and outline.xml_url: + rv.append(outline.xml_url) + + if outline.outlines: + collect(outline) + + collect(opml_obj) + return rv diff --git a/python/atoma/rss.py b/python/atoma/rss.py new file mode 100644 index 0000000..f447a2f --- /dev/null +++ b/python/atoma/rss.py @@ -0,0 +1,221 @@ +from datetime import datetime +from io import BytesIO +from typing import Optional, List +from xml.etree.ElementTree import Element + +import attr + +from .utils import ( + parse_xml, get_child, get_text, 
get_int, get_datetime, FeedParseError +) + + +@attr.s +class RSSImage: + url: str = attr.ib() + title: Optional[str] = attr.ib() + link: str = attr.ib() + width: int = attr.ib() + height: int = attr.ib() + description: Optional[str] = attr.ib() + + +@attr.s +class RSSEnclosure: + url: str = attr.ib() + length: Optional[int] = attr.ib() + type: Optional[str] = attr.ib() + + +@attr.s +class RSSSource: + title: str = attr.ib() + url: Optional[str] = attr.ib() + + +@attr.s +class RSSItem: + title: Optional[str] = attr.ib() + link: Optional[str] = attr.ib() + description: Optional[str] = attr.ib() + author: Optional[str] = attr.ib() + categories: List[str] = attr.ib() + comments: Optional[str] = attr.ib() + enclosures: List[RSSEnclosure] = attr.ib() + guid: Optional[str] = attr.ib() + pub_date: Optional[datetime] = attr.ib() + source: Optional[RSSSource] = attr.ib() + + # Extension + content_encoded: Optional[str] = attr.ib() + + +@attr.s +class RSSChannel: + title: Optional[str] = attr.ib() + link: Optional[str] = attr.ib() + description: Optional[str] = attr.ib() + language: Optional[str] = attr.ib() + copyright: Optional[str] = attr.ib() + managing_editor: Optional[str] = attr.ib() + web_master: Optional[str] = attr.ib() + pub_date: Optional[datetime] = attr.ib() + last_build_date: Optional[datetime] = attr.ib() + categories: List[str] = attr.ib() + generator: Optional[str] = attr.ib() + docs: Optional[str] = attr.ib() + ttl: Optional[int] = attr.ib() + image: Optional[RSSImage] = attr.ib() + + items: List[RSSItem] = attr.ib() + + # Extension + content_encoded: Optional[str] = attr.ib() + + +def _get_image(element: Element, name, + optional: bool=True) -> Optional[RSSImage]: + child = get_child(element, name, optional) + if child is None: + return None + + return RSSImage( + get_text(child, 'url', optional=False), + get_text(child, 'title'), + get_text(child, 'link', optional=False), + get_int(child, 'width') or 88, + get_int(child, 'height') or 31, + get_text(child, 'description') + ) + + +def _get_source(element: Element, name, + optional: bool=True) -> Optional[RSSSource]: + child = get_child(element, name, optional) + if child is None: + return None + + return RSSSource( + child.text.strip(), + child.attrib.get('url'), + ) + + +def _get_enclosure(element: Element) -> RSSEnclosure: + length = element.attrib.get('length') + try: + length = int(length) + except (TypeError, ValueError): + length = None + + return RSSEnclosure( + element.attrib['url'], + length, + element.attrib.get('type'), + ) + + +def _get_link(element: Element) -> Optional[str]: + """Attempt to retrieve item link. + + Use the GUID as a fallback if it is a permalink. 
+ """ + link = get_text(element, 'link') + if link is not None: + return link + + guid = get_child(element, 'guid') + if guid is not None and guid.attrib.get('isPermaLink') == 'true': + return get_text(element, 'guid') + + return None + + +def _get_item(element: Element) -> RSSItem: + root = element + + title = get_text(root, 'title') + link = _get_link(root) + description = get_text(root, 'description') + author = get_text(root, 'author') + categories = [e.text for e in root.findall('category')] + comments = get_text(root, 'comments') + enclosure = [_get_enclosure(e) for e in root.findall('enclosure')] + guid = get_text(root, 'guid') + pub_date = get_datetime(root, 'pubDate') + source = _get_source(root, 'source') + + content_encoded = get_text(root, 'content:encoded') + + return RSSItem( + title, + link, + description, + author, + categories, + comments, + enclosure, + guid, + pub_date, + source, + content_encoded + ) + + +def _parse_rss(root: Element) -> RSSChannel: + rss_version = root.get('version') + if rss_version != '2.0': + raise FeedParseError('Cannot process RSS feed version "{}"' + .format(rss_version)) + + root = root.find('channel') + + title = get_text(root, 'title') + link = get_text(root, 'link') + description = get_text(root, 'description') + language = get_text(root, 'language') + copyright = get_text(root, 'copyright') + managing_editor = get_text(root, 'managingEditor') + web_master = get_text(root, 'webMaster') + pub_date = get_datetime(root, 'pubDate') + last_build_date = get_datetime(root, 'lastBuildDate') + categories = [e.text for e in root.findall('category')] + generator = get_text(root, 'generator') + docs = get_text(root, 'docs') + ttl = get_int(root, 'ttl') + + image = _get_image(root, 'image') + items = [_get_item(e) for e in root.findall('item')] + + content_encoded = get_text(root, 'content:encoded') + + return RSSChannel( + title, + link, + description, + language, + copyright, + managing_editor, + web_master, + pub_date, + last_build_date, + categories, + generator, + docs, + ttl, + image, + items, + content_encoded + ) + + +def parse_rss_file(filename: str) -> RSSChannel: + """Parse an RSS feed from a local XML file.""" + root = parse_xml(filename).getroot() + return _parse_rss(root) + + +def parse_rss_bytes(data: bytes) -> RSSChannel: + """Parse an RSS feed from a byte-string containing XML data.""" + root = parse_xml(BytesIO(data)).getroot() + return _parse_rss(root) diff --git a/python/atoma/simple.py b/python/atoma/simple.py new file mode 100644 index 0000000..98bb3e1 --- /dev/null +++ b/python/atoma/simple.py @@ -0,0 +1,224 @@ +"""Simple API that abstracts away the differences between feed types.""" + +from datetime import datetime, timedelta +import html +import os +from typing import Optional, List, Tuple +import urllib.parse + +import attr + +from . 
import atom, rss, json_feed +from .exceptions import ( + FeedParseError, FeedDocumentError, FeedXMLError, FeedJSONError +) + + +@attr.s +class Attachment: + link: str = attr.ib() + mime_type: Optional[str] = attr.ib() + title: Optional[str] = attr.ib() + size_in_bytes: Optional[int] = attr.ib() + duration: Optional[timedelta] = attr.ib() + + +@attr.s +class Article: + id: str = attr.ib() + title: Optional[str] = attr.ib() + link: Optional[str] = attr.ib() + content: str = attr.ib() + published_at: Optional[datetime] = attr.ib() + updated_at: Optional[datetime] = attr.ib() + attachments: List[Attachment] = attr.ib() + + +@attr.s +class Feed: + title: str = attr.ib() + subtitle: Optional[str] = attr.ib() + link: Optional[str] = attr.ib() + updated_at: Optional[datetime] = attr.ib() + articles: List[Article] = attr.ib() + + +def _adapt_atom_feed(atom_feed: atom.AtomFeed) -> Feed: + articles = list() + for entry in atom_feed.entries: + if entry.content is not None: + content = entry.content.value + elif entry.summary is not None: + content = entry.summary.value + else: + content = '' + published_at, updated_at = _get_article_dates(entry.published, + entry.updated) + # Find article link and attachments + article_link = None + attachments = list() + for candidate_link in entry.links: + if candidate_link.rel in ('alternate', None): + article_link = candidate_link.href + elif candidate_link.rel == 'enclosure': + attachments.append(Attachment( + title=_get_attachment_title(candidate_link.title, + candidate_link.href), + link=candidate_link.href, + mime_type=candidate_link.type_, + size_in_bytes=candidate_link.length, + duration=None + )) + + if entry.title is None: + entry_title = None + elif entry.title.text_type in (atom.AtomTextType.html, + atom.AtomTextType.xhtml): + entry_title = html.unescape(entry.title.value).strip() + else: + entry_title = entry.title.value + + articles.append(Article( + entry.id_, + entry_title, + article_link, + content, + published_at, + updated_at, + attachments + )) + + # Find feed link + link = None + for candidate_link in atom_feed.links: + if candidate_link.rel == 'self': + link = candidate_link.href + break + + return Feed( + atom_feed.title.value if atom_feed.title else atom_feed.id_, + atom_feed.subtitle.value if atom_feed.subtitle else None, + link, + atom_feed.updated, + articles + ) + + +def _adapt_rss_channel(rss_channel: rss.RSSChannel) -> Feed: + articles = list() + for item in rss_channel.items: + attachments = [ + Attachment(link=e.url, mime_type=e.type, size_in_bytes=e.length, + title=_get_attachment_title(None, e.url), duration=None) + for e in item.enclosures + ] + articles.append(Article( + item.guid or item.link, + item.title, + item.link, + item.content_encoded or item.description or '', + item.pub_date, + None, + attachments + )) + + if rss_channel.title is None and rss_channel.link is None: + raise FeedParseError('RSS feed does not have a title nor a link') + + return Feed( + rss_channel.title if rss_channel.title else rss_channel.link, + rss_channel.description, + rss_channel.link, + rss_channel.pub_date, + articles + ) + + +def _adapt_json_feed(json_feed: json_feed.JSONFeed) -> Feed: + articles = list() + for item in json_feed.items: + attachments = [ + Attachment(a.url, a.mime_type, + _get_attachment_title(a.title, a.url), + a.size_in_bytes, a.duration) + for a in item.attachments + ] + articles.append(Article( + item.id_, + item.title, + item.url, + item.content_html or item.content_text or '', + item.date_published, + item.date_modified, + 
attachments + )) + + return Feed( + json_feed.title, + json_feed.description, + json_feed.feed_url, + None, + articles + ) + + +def _get_article_dates(published_at: Optional[datetime], + updated_at: Optional[datetime] + ) -> Tuple[Optional[datetime], Optional[datetime]]: + if published_at and updated_at: + return published_at, updated_at + + if updated_at: + return updated_at, None + + if published_at: + return published_at, None + + raise FeedParseError('Article does not have proper dates') + + +def _get_attachment_title(attachment_title: Optional[str], link: str) -> str: + if attachment_title: + return attachment_title + + parsed_link = urllib.parse.urlparse(link) + return os.path.basename(parsed_link.path) + + +def _simple_parse(pairs, content) -> Feed: + is_xml = True + is_json = True + for parser, adapter in pairs: + try: + return adapter(parser(content)) + except FeedXMLError: + is_xml = False + except FeedJSONError: + is_json = False + except FeedParseError: + continue + + if not is_xml and not is_json: + raise FeedDocumentError('File is not a supported feed type') + + raise FeedParseError('File is not a valid supported feed') + + +def simple_parse_file(filename: str) -> Feed: + """Parse an Atom, RSS or JSON feed from a local file.""" + pairs = ( + (rss.parse_rss_file, _adapt_rss_channel), + (atom.parse_atom_file, _adapt_atom_feed), + (json_feed.parse_json_feed_file, _adapt_json_feed) + ) + return _simple_parse(pairs, filename) + + +def simple_parse_bytes(data: bytes) -> Feed: + """Parse an Atom, RSS or JSON feed from a byte-string containing data.""" + pairs = ( + (rss.parse_rss_bytes, _adapt_rss_channel), + (atom.parse_atom_bytes, _adapt_atom_feed), + (json_feed.parse_json_feed_bytes, _adapt_json_feed) + ) + return _simple_parse(pairs, data) diff --git a/python/atoma/utils.py b/python/atoma/utils.py new file mode 100644 index 0000000..4dc1ab5 --- /dev/null +++ b/python/atoma/utils.py @@ -0,0 +1,84 @@ +from datetime import datetime, timezone +from xml.etree.ElementTree import Element +from typing import Optional + +import dateutil.parser +from defusedxml.ElementTree import parse as defused_xml_parse, ParseError + +from .exceptions import FeedXMLError, FeedParseError + +ns = { + 'content': 'http://purl.org/rss/1.0/modules/content/', + 'feed': 'http://www.w3.org/2005/Atom' +} + + +def parse_xml(xml_content): + try: + return defused_xml_parse(xml_content) + except ParseError: + raise FeedXMLError('Not a valid XML document') + + +def get_child(element: Element, name, + optional: bool=True) -> Optional[Element]: + child = element.find(name, namespaces=ns) + + if child is None and not optional: + raise FeedParseError( + 'Could not parse feed: "{}" does not have a "{}"' + .format(element.tag, name) + ) + + elif child is None: + return None + + return child + + +def get_text(element: Element, name, optional: bool=True) -> Optional[str]: + child = get_child(element, name, optional) + if child is None: + return None + + if child.text is None: + if optional: + return None + + raise FeedParseError( + 'Could not parse feed: "{}" text is required but is empty' + .format(name) + ) + + return child.text.strip() + + +def get_int(element: Element, name, optional: bool=True) -> Optional[int]: + text = get_text(element, name, optional) + if text is None: + return None + + return int(text) + + +def get_datetime(element: Element, name, + optional: bool=True) -> Optional[datetime]: + text = get_text(element, name, optional) + if text is None: + return None + + return try_parse_date(text) + + +def 
try_parse_date(date_str: str) -> Optional[datetime]: + try: + date = dateutil.parser.parse(date_str, fuzzy=True) + except (ValueError, OverflowError): + return None + + if date.tzinfo is None: + # TZ naive datetime, make it a TZ aware datetime by assuming it + # contains UTC time + date = date.replace(tzinfo=timezone.utc) + + return date diff --git a/python/attr/__init__.py b/python/attr/__init__.py new file mode 100644 index 0000000..debfd57 --- /dev/null +++ b/python/attr/__init__.py @@ -0,0 +1,65 @@ +from __future__ import absolute_import, division, print_function + +from functools import partial + +from . import converters, exceptions, filters, validators +from ._config import get_run_validators, set_run_validators +from ._funcs import asdict, assoc, astuple, evolve, has +from ._make import ( + NOTHING, + Attribute, + Factory, + attrib, + attrs, + fields, + fields_dict, + make_class, + validate, +) + + +__version__ = "18.2.0" + +__title__ = "attrs" +__description__ = "Classes Without Boilerplate" +__url__ = "https://www.attrs.org/" +__uri__ = __url__ +__doc__ = __description__ + " <" + __uri__ + ">" + +__author__ = "Hynek Schlawack" +__email__ = "hs@ox.cx" + +__license__ = "MIT" +__copyright__ = "Copyright (c) 2015 Hynek Schlawack" + + +s = attributes = attrs +ib = attr = attrib +dataclass = partial(attrs, auto_attribs=True) # happy Easter ;) + +__all__ = [ + "Attribute", + "Factory", + "NOTHING", + "asdict", + "assoc", + "astuple", + "attr", + "attrib", + "attributes", + "attrs", + "converters", + "evolve", + "exceptions", + "fields", + "fields_dict", + "filters", + "get_run_validators", + "has", + "ib", + "make_class", + "s", + "set_run_validators", + "validate", + "validators", +] diff --git a/python/attr/__init__.pyi b/python/attr/__init__.pyi new file mode 100644 index 0000000..492fb85 --- /dev/null +++ b/python/attr/__init__.pyi @@ -0,0 +1,252 @@ +from typing import ( + Any, + Callable, + Dict, + Generic, + List, + Optional, + Sequence, + Mapping, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +# `import X as X` is required to make these public +from . import exceptions as exceptions +from . import filters as filters +from . import converters as converters +from . import validators as validators + +_T = TypeVar("_T") +_C = TypeVar("_C", bound=type) + +_ValidatorType = Callable[[Any, Attribute, _T], Any] +_ConverterType = Callable[[Any], _T] +_FilterType = Callable[[Attribute, Any], bool] +# FIXME: in reality, if multiple validators are passed they must be in a list or tuple, +# but those are invariant and so would prevent subtypes of _ValidatorType from working +# when passed in a list or tuple. +_ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]] + +# _make -- + +NOTHING: object + +# NOTE: Factory lies about its return type to make this possible: `x: List[int] = Factory(list)` +# Work around mypy issue #4554 in the common case by using an overload. +@overload +def Factory(factory: Callable[[], _T]) -> _T: ... +@overload +def Factory( + factory: Union[Callable[[Any], _T], Callable[[], _T]], + takes_self: bool = ..., +) -> _T: ... + +class Attribute(Generic[_T]): + name: str + default: Optional[_T] + validator: Optional[_ValidatorType[_T]] + repr: bool + cmp: bool + hash: Optional[bool] + init: bool + converter: Optional[_ConverterType[_T]] + metadata: Dict[Any, Any] + type: Optional[Type[_T]] + kw_only: bool + def __lt__(self, x: Attribute) -> bool: ... + def __le__(self, x: Attribute) -> bool: ... + def __gt__(self, x: Attribute) -> bool: ... 
+ def __ge__(self, x: Attribute) -> bool: ... + +# NOTE: We had several choices for the annotation to use for type arg: +# 1) Type[_T] +# - Pros: Handles simple cases correctly +# - Cons: Might produce less informative errors in the case of conflicting TypeVars +# e.g. `attr.ib(default='bad', type=int)` +# 2) Callable[..., _T] +# - Pros: Better error messages than #1 for conflicting TypeVars +# - Cons: Terrible error messages for validator checks. +# e.g. attr.ib(type=int, validator=validate_str) +# -> error: Cannot infer function type argument +# 3) type (and do all of the work in the mypy plugin) +# - Pros: Simple here, and we could customize the plugin with our own errors. +# - Cons: Would need to write mypy plugin code to handle all the cases. +# We chose option #1. + +# `attr` lies about its return type to make the following possible: +# attr() -> Any +# attr(8) -> int +# attr(validator=<some callable>) -> Whatever the callable expects. +# This makes this type of assignments possible: +# x: int = attr(8) +# +# This form catches explicit None or no default but with no other arguments returns Any. +@overload +def attrib( + default: None = ..., + validator: None = ..., + repr: bool = ..., + cmp: bool = ..., + hash: Optional[bool] = ..., + init: bool = ..., + convert: None = ..., + metadata: Optional[Mapping[Any, Any]] = ..., + type: None = ..., + converter: None = ..., + factory: None = ..., + kw_only: bool = ..., +) -> Any: ... + +# This form catches an explicit None or no default and infers the type from the other arguments. +@overload +def attrib( + default: None = ..., + validator: Optional[_ValidatorArgType[_T]] = ..., + repr: bool = ..., + cmp: bool = ..., + hash: Optional[bool] = ..., + init: bool = ..., + convert: Optional[_ConverterType[_T]] = ..., + metadata: Optional[Mapping[Any, Any]] = ..., + type: Optional[Type[_T]] = ..., + converter: Optional[_ConverterType[_T]] = ..., + factory: Optional[Callable[[], _T]] = ..., + kw_only: bool = ..., +) -> _T: ... + +# This form catches an explicit default argument. +@overload +def attrib( + default: _T, + validator: Optional[_ValidatorArgType[_T]] = ..., + repr: bool = ..., + cmp: bool = ..., + hash: Optional[bool] = ..., + init: bool = ..., + convert: Optional[_ConverterType[_T]] = ..., + metadata: Optional[Mapping[Any, Any]] = ..., + type: Optional[Type[_T]] = ..., + converter: Optional[_ConverterType[_T]] = ..., + factory: Optional[Callable[[], _T]] = ..., + kw_only: bool = ..., +) -> _T: ... + +# This form covers type=non-Type: e.g. forward references (str), Any +@overload +def attrib( + default: Optional[_T] = ..., + validator: Optional[_ValidatorArgType[_T]] = ..., + repr: bool = ..., + cmp: bool = ..., + hash: Optional[bool] = ..., + init: bool = ..., + convert: Optional[_ConverterType[_T]] = ..., + metadata: Optional[Mapping[Any, Any]] = ..., + type: object = ..., + converter: Optional[_ConverterType[_T]] = ..., + factory: Optional[Callable[[], _T]] = ..., + kw_only: bool = ..., +) -> Any: ... +@overload +def attrs( + maybe_cls: _C, + these: Optional[Dict[str, Any]] = ..., + repr_ns: Optional[str] = ..., + repr: bool = ..., + cmp: bool = ..., + hash: Optional[bool] = ..., + init: bool = ..., + slots: bool = ..., + frozen: bool = ..., + weakref_slot: bool = ..., + str: bool = ..., + auto_attribs: bool = ..., + kw_only: bool = ..., + cache_hash: bool = ..., +) -> _C: ... 
+@overload +def attrs( + maybe_cls: None = ..., + these: Optional[Dict[str, Any]] = ..., + repr_ns: Optional[str] = ..., + repr: bool = ..., + cmp: bool = ..., + hash: Optional[bool] = ..., + init: bool = ..., + slots: bool = ..., + frozen: bool = ..., + weakref_slot: bool = ..., + str: bool = ..., + auto_attribs: bool = ..., + kw_only: bool = ..., + cache_hash: bool = ..., +) -> Callable[[_C], _C]: ... + +# TODO: add support for returning NamedTuple from the mypy plugin +class _Fields(Tuple[Attribute, ...]): + def __getattr__(self, name: str) -> Attribute: ... + +def fields(cls: type) -> _Fields: ... +def fields_dict(cls: type) -> Dict[str, Attribute]: ... +def validate(inst: Any) -> None: ... + +# TODO: add support for returning a proper attrs class from the mypy plugin +# we use Any instead of _CountingAttr so that e.g. `make_class('Foo', [attr.ib()])` is valid +def make_class( + name: str, + attrs: Union[List[str], Tuple[str, ...], Dict[str, Any]], + bases: Tuple[type, ...] = ..., + repr_ns: Optional[str] = ..., + repr: bool = ..., + cmp: bool = ..., + hash: Optional[bool] = ..., + init: bool = ..., + slots: bool = ..., + frozen: bool = ..., + weakref_slot: bool = ..., + str: bool = ..., + auto_attribs: bool = ..., + kw_only: bool = ..., + cache_hash: bool = ..., +) -> type: ... + +# _funcs -- + +# TODO: add support for returning TypedDict from the mypy plugin +# FIXME: asdict/astuple do not honor their factory args. waiting on one of these: +# https://github.com/python/mypy/issues/4236 +# https://github.com/python/typing/issues/253 +def asdict( + inst: Any, + recurse: bool = ..., + filter: Optional[_FilterType] = ..., + dict_factory: Type[Mapping[Any, Any]] = ..., + retain_collection_types: bool = ..., +) -> Dict[str, Any]: ... + +# TODO: add support for returning NamedTuple from the mypy plugin +def astuple( + inst: Any, + recurse: bool = ..., + filter: Optional[_FilterType] = ..., + tuple_factory: Type[Sequence] = ..., + retain_collection_types: bool = ..., +) -> Tuple[Any, ...]: ... +def has(cls: type) -> bool: ... +def assoc(inst: _T, **changes: Any) -> _T: ... +def evolve(inst: _T, **changes: Any) -> _T: ... + +# _config -- + +def set_run_validators(run: bool) -> None: ... +def get_run_validators() -> bool: ... + +# aliases -- + +s = attributes = attrs +ib = attr = attrib +dataclass = attrs # Technically, partial(attrs, auto_attribs=True) ;) diff --git a/python/attr/_compat.py b/python/attr/_compat.py new file mode 100644 index 0000000..5bb0659 --- /dev/null +++ b/python/attr/_compat.py @@ -0,0 +1,163 @@ +from __future__ import absolute_import, division, print_function + +import platform +import sys +import types +import warnings + + +PY2 = sys.version_info[0] == 2 +PYPY = platform.python_implementation() == "PyPy" + + +if PYPY or sys.version_info[:2] >= (3, 6): + ordered_dict = dict +else: + from collections import OrderedDict + + ordered_dict = OrderedDict + + +if PY2: + from UserDict import IterableUserDict + + # We 'bundle' isclass instead of using inspect as importing inspect is + # fairly expensive (order of 10-15 ms for a modern machine in 2016) + def isclass(klass): + return isinstance(klass, (type, types.ClassType)) + + # TYPE is used in exceptions, repr(int) is different on Python 2 and 3. + TYPE = "type" + + def iteritems(d): + return d.iteritems() + + # Python 2 is bereft of a read-only dict proxy, so we make one! + class ReadOnlyDict(IterableUserDict): + """ + Best-effort read-only dict wrapper. 
+ """ + + def __setitem__(self, key, val): + # We gently pretend we're a Python 3 mappingproxy. + raise TypeError( + "'mappingproxy' object does not support item assignment" + ) + + def update(self, _): + # We gently pretend we're a Python 3 mappingproxy. + raise AttributeError( + "'mappingproxy' object has no attribute 'update'" + ) + + def __delitem__(self, _): + # We gently pretend we're a Python 3 mappingproxy. + raise TypeError( + "'mappingproxy' object does not support item deletion" + ) + + def clear(self): + # We gently pretend we're a Python 3 mappingproxy. + raise AttributeError( + "'mappingproxy' object has no attribute 'clear'" + ) + + def pop(self, key, default=None): + # We gently pretend we're a Python 3 mappingproxy. + raise AttributeError( + "'mappingproxy' object has no attribute 'pop'" + ) + + def popitem(self): + # We gently pretend we're a Python 3 mappingproxy. + raise AttributeError( + "'mappingproxy' object has no attribute 'popitem'" + ) + + def setdefault(self, key, default=None): + # We gently pretend we're a Python 3 mappingproxy. + raise AttributeError( + "'mappingproxy' object has no attribute 'setdefault'" + ) + + def __repr__(self): + # Override to be identical to the Python 3 version. + return "mappingproxy(" + repr(self.data) + ")" + + def metadata_proxy(d): + res = ReadOnlyDict() + res.data.update(d) # We blocked update, so we have to do it like this. + return res + + +else: + + def isclass(klass): + return isinstance(klass, type) + + TYPE = "class" + + def iteritems(d): + return d.items() + + def metadata_proxy(d): + return types.MappingProxyType(dict(d)) + + +def import_ctypes(): + """ + Moved into a function for testability. + """ + import ctypes + + return ctypes + + +if not PY2: + + def just_warn(*args, **kw): + """ + We only warn on Python 3 because we are not aware of any concrete + consequences of not setting the cell on Python 2. + """ + warnings.warn( + "Missing ctypes. Some features like bare super() or accessing " + "__class__ will not work with slots classes.", + RuntimeWarning, + stacklevel=2, + ) + + +else: + + def just_warn(*args, **kw): # pragma: nocover + """ + We only warn on Python 3 because we are not aware of any concrete + consequences of not setting the cell on Python 2. + """ + + +def make_set_closure_cell(): + """ + Moved into a function for testability. + """ + if PYPY: # pragma: no cover + + def set_closure_cell(cell, value): + cell.__setstate__((value,)) + + else: + try: + ctypes = import_ctypes() + + set_closure_cell = ctypes.pythonapi.PyCell_Set + set_closure_cell.argtypes = (ctypes.py_object, ctypes.py_object) + set_closure_cell.restype = ctypes.c_int + except Exception: + # We try best effort to set the cell, but sometimes it's not + # possible. For example on Jython or on GAE. + set_closure_cell = just_warn + return set_closure_cell + + +set_closure_cell = make_set_closure_cell() diff --git a/python/attr/_config.py b/python/attr/_config.py new file mode 100644 index 0000000..8ec9209 --- /dev/null +++ b/python/attr/_config.py @@ -0,0 +1,23 @@ +from __future__ import absolute_import, division, print_function + + +__all__ = ["set_run_validators", "get_run_validators"] + +_run_validators = True + + +def set_run_validators(run): + """ + Set whether or not validators are run. By default, they are run. + """ + if not isinstance(run, bool): + raise TypeError("'run' must be bool.") + global _run_validators + _run_validators = run + + +def get_run_validators(): + """ + Return whether or not validators are run. 
+ """ + return _run_validators diff --git a/python/attr/_funcs.py b/python/attr/_funcs.py new file mode 100644 index 0000000..b61d239 --- /dev/null +++ b/python/attr/_funcs.py @@ -0,0 +1,290 @@ +from __future__ import absolute_import, division, print_function + +import copy + +from ._compat import iteritems +from ._make import NOTHING, _obj_setattr, fields +from .exceptions import AttrsAttributeNotFoundError + + +def asdict( + inst, + recurse=True, + filter=None, + dict_factory=dict, + retain_collection_types=False, +): + """ + Return the ``attrs`` attribute values of *inst* as a dict. + + Optionally recurse into other ``attrs``-decorated classes. + + :param inst: Instance of an ``attrs``-decorated class. + :param bool recurse: Recurse into classes that are also + ``attrs``-decorated. + :param callable filter: A callable whose return code determines whether an + attribute or element is included (``True``) or dropped (``False``). Is + called with the :class:`attr.Attribute` as the first argument and the + value as the second argument. + :param callable dict_factory: A callable to produce dictionaries from. For + example, to produce ordered dictionaries instead of normal Python + dictionaries, pass in ``collections.OrderedDict``. + :param bool retain_collection_types: Do not convert to ``list`` when + encountering an attribute whose type is ``tuple`` or ``set``. Only + meaningful if ``recurse`` is ``True``. + + :rtype: return type of *dict_factory* + + :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` + class. + + .. versionadded:: 16.0.0 *dict_factory* + .. versionadded:: 16.1.0 *retain_collection_types* + """ + attrs = fields(inst.__class__) + rv = dict_factory() + for a in attrs: + v = getattr(inst, a.name) + if filter is not None and not filter(a, v): + continue + if recurse is True: + if has(v.__class__): + rv[a.name] = asdict( + v, True, filter, dict_factory, retain_collection_types + ) + elif isinstance(v, (tuple, list, set)): + cf = v.__class__ if retain_collection_types is True else list + rv[a.name] = cf( + [ + _asdict_anything( + i, filter, dict_factory, retain_collection_types + ) + for i in v + ] + ) + elif isinstance(v, dict): + df = dict_factory + rv[a.name] = df( + ( + _asdict_anything( + kk, filter, df, retain_collection_types + ), + _asdict_anything( + vv, filter, df, retain_collection_types + ), + ) + for kk, vv in iteritems(v) + ) + else: + rv[a.name] = v + else: + rv[a.name] = v + return rv + + +def _asdict_anything(val, filter, dict_factory, retain_collection_types): + """ + ``asdict`` only works on attrs instances, this works on anything. + """ + if getattr(val.__class__, "__attrs_attrs__", None) is not None: + # Attrs class. + rv = asdict(val, True, filter, dict_factory, retain_collection_types) + elif isinstance(val, (tuple, list, set)): + cf = val.__class__ if retain_collection_types is True else list + rv = cf( + [ + _asdict_anything( + i, filter, dict_factory, retain_collection_types + ) + for i in val + ] + ) + elif isinstance(val, dict): + df = dict_factory + rv = df( + ( + _asdict_anything(kk, filter, df, retain_collection_types), + _asdict_anything(vv, filter, df, retain_collection_types), + ) + for kk, vv in iteritems(val) + ) + else: + rv = val + return rv + + +def astuple( + inst, + recurse=True, + filter=None, + tuple_factory=tuple, + retain_collection_types=False, +): + """ + Return the ``attrs`` attribute values of *inst* as a tuple. + + Optionally recurse into other ``attrs``-decorated classes. 
+ + :param inst: Instance of an ``attrs``-decorated class. + :param bool recurse: Recurse into classes that are also + ``attrs``-decorated. + :param callable filter: A callable whose return code determines whether an + attribute or element is included (``True``) or dropped (``False``). Is + called with the :class:`attr.Attribute` as the first argument and the + value as the second argument. + :param callable tuple_factory: A callable to produce tuples from. For + example, to produce lists instead of tuples. + :param bool retain_collection_types: Do not convert to ``list`` + or ``dict`` when encountering an attribute which type is + ``tuple``, ``dict`` or ``set``. Only meaningful if ``recurse`` is + ``True``. + + :rtype: return type of *tuple_factory* + + :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` + class. + + .. versionadded:: 16.2.0 + """ + attrs = fields(inst.__class__) + rv = [] + retain = retain_collection_types # Very long. :/ + for a in attrs: + v = getattr(inst, a.name) + if filter is not None and not filter(a, v): + continue + if recurse is True: + if has(v.__class__): + rv.append( + astuple( + v, + recurse=True, + filter=filter, + tuple_factory=tuple_factory, + retain_collection_types=retain, + ) + ) + elif isinstance(v, (tuple, list, set)): + cf = v.__class__ if retain is True else list + rv.append( + cf( + [ + astuple( + j, + recurse=True, + filter=filter, + tuple_factory=tuple_factory, + retain_collection_types=retain, + ) + if has(j.__class__) + else j + for j in v + ] + ) + ) + elif isinstance(v, dict): + df = v.__class__ if retain is True else dict + rv.append( + df( + ( + astuple( + kk, + tuple_factory=tuple_factory, + retain_collection_types=retain, + ) + if has(kk.__class__) + else kk, + astuple( + vv, + tuple_factory=tuple_factory, + retain_collection_types=retain, + ) + if has(vv.__class__) + else vv, + ) + for kk, vv in iteritems(v) + ) + ) + else: + rv.append(v) + else: + rv.append(v) + return rv if tuple_factory is list else tuple_factory(rv) + + +def has(cls): + """ + Check whether *cls* is a class with ``attrs`` attributes. + + :param type cls: Class to introspect. + :raise TypeError: If *cls* is not a class. + + :rtype: :class:`bool` + """ + return getattr(cls, "__attrs_attrs__", None) is not None + + +def assoc(inst, **changes): + """ + Copy *inst* and apply *changes*. + + :param inst: Instance of a class with ``attrs`` attributes. + :param changes: Keyword changes in the new copy. + + :return: A copy of inst with *changes* incorporated. + + :raise attr.exceptions.AttrsAttributeNotFoundError: If *attr_name* couldn't + be found on *cls*. + :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` + class. + + .. deprecated:: 17.1.0 + Use :func:`evolve` instead. + """ + import warnings + + warnings.warn( + "assoc is deprecated and will be removed after 2018/01.", + DeprecationWarning, + stacklevel=2, + ) + new = copy.copy(inst) + attrs = fields(inst.__class__) + for k, v in iteritems(changes): + a = getattr(attrs, k, NOTHING) + if a is NOTHING: + raise AttrsAttributeNotFoundError( + "{k} is not an attrs attribute on {cl}.".format( + k=k, cl=new.__class__ + ) + ) + _obj_setattr(new, k, v) + return new + + +def evolve(inst, **changes): + """ + Create a new instance, based on *inst* with *changes* applied. + + :param inst: Instance of a class with ``attrs`` attributes. + :param changes: Keyword changes in the new copy. + + :return: A copy of inst with *changes* incorporated. 
+ + :raise TypeError: If *attr_name* couldn't be found in the class + ``__init__``. + :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` + class. + + .. versionadded:: 17.1.0 + """ + cls = inst.__class__ + attrs = fields(cls) + for a in attrs: + if not a.init: + continue + attr_name = a.name # To deal with private attributes. + init_name = attr_name if attr_name[0] != "_" else attr_name[1:] + if init_name not in changes: + changes[init_name] = getattr(inst, attr_name) + return cls(**changes) diff --git a/python/attr/_make.py b/python/attr/_make.py new file mode 100644 index 0000000..f7fd05e --- /dev/null +++ b/python/attr/_make.py @@ -0,0 +1,2034 @@ +from __future__ import absolute_import, division, print_function + +import copy +import hashlib +import linecache +import sys +import threading +import warnings + +from operator import itemgetter + +from . import _config +from ._compat import ( + PY2, + isclass, + iteritems, + metadata_proxy, + ordered_dict, + set_closure_cell, +) +from .exceptions import ( + DefaultAlreadySetError, + FrozenInstanceError, + NotAnAttrsClassError, + PythonTooOldError, + UnannotatedAttributeError, +) + + +# This is used at least twice, so cache it here. +_obj_setattr = object.__setattr__ +_init_converter_pat = "__attr_converter_{}" +_init_factory_pat = "__attr_factory_{}" +_tuple_property_pat = ( + " {attr_name} = _attrs_property(_attrs_itemgetter({index}))" +) +_classvar_prefixes = ("typing.ClassVar", "t.ClassVar", "ClassVar") +# we don't use a double-underscore prefix because that triggers +# name mangling when trying to create a slot for the field +# (when slots=True) +_hash_cache_field = "_attrs_cached_hash" + +_empty_metadata_singleton = metadata_proxy({}) + + +class _Nothing(object): + """ + Sentinel class to indicate the lack of a value when ``None`` is ambiguous. + + ``_Nothing`` is a singleton. There is only ever one of it. + """ + + _singleton = None + + def __new__(cls): + if _Nothing._singleton is None: + _Nothing._singleton = super(_Nothing, cls).__new__(cls) + return _Nothing._singleton + + def __repr__(self): + return "NOTHING" + + +NOTHING = _Nothing() +""" +Sentinel to indicate the lack of a value when ``None`` is ambiguous. +""" + + +def attrib( + default=NOTHING, + validator=None, + repr=True, + cmp=True, + hash=None, + init=True, + convert=None, + metadata=None, + type=None, + converter=None, + factory=None, + kw_only=False, +): + """ + Create a new attribute on a class. + + .. warning:: + + Does *not* do anything unless the class is also decorated with + :func:`attr.s`! + + :param default: A value that is used if an ``attrs``-generated ``__init__`` + is used and no value is passed while instantiating or the attribute is + excluded using ``init=False``. + + If the value is an instance of :class:`Factory`, its callable will be + used to construct a new value (useful for mutable data types like lists + or dicts). + + If a default is not set (or set manually to ``attr.NOTHING``), a value + *must* be supplied when instantiating; otherwise a :exc:`TypeError` + will be raised. + + The default can also be set using decorator notation as shown below. + + :type default: Any value. + + :param callable factory: Syntactic sugar for + ``default=attr.Factory(callable)``. + + :param validator: :func:`callable` that is called by ``attrs``-generated + ``__init__`` methods after the instance has been initialized. They + receive the initialized instance, the :class:`Attribute`, and the + passed value. 
+ + The return value is *not* inspected so the validator has to throw an + exception itself. + + If a ``list`` is passed, its items are treated as validators and must + all pass. + + Validators can be globally disabled and re-enabled using + :func:`get_run_validators`. + + The validator can also be set using decorator notation as shown below. + + :type validator: ``callable`` or a ``list`` of ``callable``\\ s. + + :param bool repr: Include this attribute in the generated ``__repr__`` + method. + :param bool cmp: Include this attribute in the generated comparison methods + (``__eq__`` et al). + :param hash: Include this attribute in the generated ``__hash__`` + method. If ``None`` (default), mirror *cmp*'s value. This is the + correct behavior according the Python spec. Setting this value to + anything else than ``None`` is *discouraged*. + :type hash: ``bool`` or ``None`` + :param bool init: Include this attribute in the generated ``__init__`` + method. It is possible to set this to ``False`` and set a default + value. In that case this attributed is unconditionally initialized + with the specified default value or factory. + :param callable converter: :func:`callable` that is called by + ``attrs``-generated ``__init__`` methods to converter attribute's value + to the desired format. It is given the passed-in value, and the + returned value will be used as the new value of the attribute. The + value is converted before being passed to the validator, if any. + :param metadata: An arbitrary mapping, to be used by third-party + components. See :ref:`extending_metadata`. + :param type: The type of the attribute. In Python 3.6 or greater, the + preferred method to specify the type is using a variable annotation + (see `PEP 526 <https://www.python.org/dev/peps/pep-0526/>`_). + This argument is provided for backward compatibility. + Regardless of the approach used, the type will be stored on + ``Attribute.type``. + + Please note that ``attrs`` doesn't do anything with this metadata by + itself. You can use it as part of your own code or for + :doc:`static type checking <types>`. + :param kw_only: Make this attribute keyword-only (Python 3+) + in the generated ``__init__`` (if ``init`` is ``False``, this + parameter is ignored). + + .. versionadded:: 15.2.0 *convert* + .. versionadded:: 16.3.0 *metadata* + .. versionchanged:: 17.1.0 *validator* can be a ``list`` now. + .. versionchanged:: 17.1.0 + *hash* is ``None`` and therefore mirrors *cmp* by default. + .. versionadded:: 17.3.0 *type* + .. deprecated:: 17.4.0 *convert* + .. versionadded:: 17.4.0 *converter* as a replacement for the deprecated + *convert* to achieve consistency with other noun-based arguments. + .. versionadded:: 18.1.0 + ``factory=f`` is syntactic sugar for ``default=attr.Factory(f)``. + .. versionadded:: 18.2.0 *kw_only* + """ + if hash is not None and hash is not True and hash is not False: + raise TypeError( + "Invalid value for hash. Must be True, False, or None." + ) + + if convert is not None: + if converter is not None: + raise RuntimeError( + "Can't pass both `convert` and `converter`. " + "Please use `converter` only." + ) + warnings.warn( + "The `convert` argument is deprecated in favor of `converter`. " + "It will be removed after 2019/01.", + DeprecationWarning, + stacklevel=2, + ) + converter = convert + + if factory is not None: + if default is not NOTHING: + raise ValueError( + "The `default` and `factory` arguments are mutually " + "exclusive." 
+ ) + if not callable(factory): + raise ValueError("The `factory` argument must be a callable.") + default = Factory(factory) + + if metadata is None: + metadata = {} + + return _CountingAttr( + default=default, + validator=validator, + repr=repr, + cmp=cmp, + hash=hash, + init=init, + converter=converter, + metadata=metadata, + type=type, + kw_only=kw_only, + ) + + +def _make_attr_tuple_class(cls_name, attr_names): + """ + Create a tuple subclass to hold `Attribute`s for an `attrs` class. + + The subclass is a bare tuple with properties for names. + + class MyClassAttributes(tuple): + __slots__ = () + x = property(itemgetter(0)) + """ + attr_class_name = "{}Attributes".format(cls_name) + attr_class_template = [ + "class {}(tuple):".format(attr_class_name), + " __slots__ = ()", + ] + if attr_names: + for i, attr_name in enumerate(attr_names): + attr_class_template.append( + _tuple_property_pat.format(index=i, attr_name=attr_name) + ) + else: + attr_class_template.append(" pass") + globs = {"_attrs_itemgetter": itemgetter, "_attrs_property": property} + eval(compile("\n".join(attr_class_template), "", "exec"), globs) + + return globs[attr_class_name] + + +# Tuple class for extracted attributes from a class definition. +# `base_attrs` is a subset of `attrs`. +_Attributes = _make_attr_tuple_class( + "_Attributes", + [ + # all attributes to build dunder methods for + "attrs", + # attributes that have been inherited + "base_attrs", + # map inherited attributes to their originating classes + "base_attrs_map", + ], +) + + +def _is_class_var(annot): + """ + Check whether *annot* is a typing.ClassVar. + + The string comparison hack is used to avoid evaluating all string + annotations which would put attrs-based classes at a performance + disadvantage compared to plain old classes. + """ + return str(annot).startswith(_classvar_prefixes) + + +def _get_annotations(cls): + """ + Get annotations for *cls*. + """ + anns = getattr(cls, "__annotations__", None) + if anns is None: + return {} + + # Verify that the annotations aren't merely inherited. + for base_cls in cls.__mro__[1:]: + if anns is getattr(base_cls, "__annotations__", None): + return {} + + return anns + + +def _counter_getter(e): + """ + Key function for sorting to avoid re-creating a lambda for every class. + """ + return e[1].counter + + +def _transform_attrs(cls, these, auto_attribs, kw_only): + """ + Transform all `_CountingAttr`s on a class into `Attribute`s. + + If *these* is passed, use that and don't look for them on the class. + + Return an `_Attributes`. + """ + cd = cls.__dict__ + anns = _get_annotations(cls) + + if these is not None: + ca_list = [(name, ca) for name, ca in iteritems(these)] + + if not isinstance(these, ordered_dict): + ca_list.sort(key=_counter_getter) + elif auto_attribs is True: + ca_names = { + name + for name, attr in cd.items() + if isinstance(attr, _CountingAttr) + } + ca_list = [] + annot_names = set() + for attr_name, type in anns.items(): + if _is_class_var(type): + continue + annot_names.add(attr_name) + a = cd.get(attr_name, NOTHING) + if not isinstance(a, _CountingAttr): + if a is NOTHING: + a = attrib() + else: + a = attrib(default=a) + ca_list.append((attr_name, a)) + + unannotated = ca_names - annot_names + if len(unannotated) > 0: + raise UnannotatedAttributeError( + "The following `attr.ib`s lack a type annotation: " + + ", ".join( + sorted(unannotated, key=lambda n: cd.get(n).counter) + ) + + "." 
+ ) + else: + ca_list = sorted( + ( + (name, attr) + for name, attr in cd.items() + if isinstance(attr, _CountingAttr) + ), + key=lambda e: e[1].counter, + ) + + own_attrs = [ + Attribute.from_counting_attr( + name=attr_name, ca=ca, type=anns.get(attr_name) + ) + for attr_name, ca in ca_list + ] + + base_attrs = [] + base_attr_map = {} # A dictionary of base attrs to their classes. + taken_attr_names = {a.name: a for a in own_attrs} + + # Traverse the MRO and collect attributes. + for base_cls in cls.__mro__[1:-1]: + sub_attrs = getattr(base_cls, "__attrs_attrs__", None) + if sub_attrs is not None: + for a in sub_attrs: + prev_a = taken_attr_names.get(a.name) + # Only add an attribute if it hasn't been defined before. This + # allows for overwriting attribute definitions by subclassing. + if prev_a is None: + base_attrs.append(a) + taken_attr_names[a.name] = a + base_attr_map[a.name] = base_cls + + attr_names = [a.name for a in base_attrs + own_attrs] + + AttrsClass = _make_attr_tuple_class(cls.__name__, attr_names) + + if kw_only: + own_attrs = [a._assoc(kw_only=True) for a in own_attrs] + base_attrs = [a._assoc(kw_only=True) for a in base_attrs] + + attrs = AttrsClass(base_attrs + own_attrs) + + had_default = False + was_kw_only = False + for a in attrs: + if ( + was_kw_only is False + and had_default is True + and a.default is NOTHING + and a.init is True + and a.kw_only is False + ): + raise ValueError( + "No mandatory attributes allowed after an attribute with a " + "default value or factory. Attribute in question: %r" % (a,) + ) + elif ( + had_default is False + and a.default is not NOTHING + and a.init is not False + and + # Keyword-only attributes without defaults can be specified + # after keyword-only attributes with defaults. + a.kw_only is False + ): + had_default = True + if was_kw_only is True and a.kw_only is False: + raise ValueError( + "Non keyword-only attributes are not allowed after a " + "keyword-only attribute. Attribute in question: {a!r}".format( + a=a + ) + ) + if was_kw_only is False and a.init is True and a.kw_only is True: + was_kw_only = True + + return _Attributes((attrs, base_attrs, base_attr_map)) + + +def _frozen_setattrs(self, name, value): + """ + Attached to frozen classes as __setattr__. + """ + raise FrozenInstanceError() + + +def _frozen_delattrs(self, name): + """ + Attached to frozen classes as __delattr__. + """ + raise FrozenInstanceError() + + +class _ClassBuilder(object): + """ + Iteratively build *one* class. 
+ """ + + __slots__ = ( + "_cls", + "_cls_dict", + "_attrs", + "_base_names", + "_attr_names", + "_slots", + "_frozen", + "_weakref_slot", + "_cache_hash", + "_has_post_init", + "_delete_attribs", + "_base_attr_map", + ) + + def __init__( + self, + cls, + these, + slots, + frozen, + weakref_slot, + auto_attribs, + kw_only, + cache_hash, + ): + attrs, base_attrs, base_map = _transform_attrs( + cls, these, auto_attribs, kw_only + ) + + self._cls = cls + self._cls_dict = dict(cls.__dict__) if slots else {} + self._attrs = attrs + self._base_names = set(a.name for a in base_attrs) + self._base_attr_map = base_map + self._attr_names = tuple(a.name for a in attrs) + self._slots = slots + self._frozen = frozen or _has_frozen_base_class(cls) + self._weakref_slot = weakref_slot + self._cache_hash = cache_hash + self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False)) + self._delete_attribs = not bool(these) + + self._cls_dict["__attrs_attrs__"] = self._attrs + + if frozen: + self._cls_dict["__setattr__"] = _frozen_setattrs + self._cls_dict["__delattr__"] = _frozen_delattrs + + def __repr__(self): + return "<_ClassBuilder(cls={cls})>".format(cls=self._cls.__name__) + + def build_class(self): + """ + Finalize class based on the accumulated configuration. + + Builder cannot be used after calling this method. + """ + if self._slots is True: + return self._create_slots_class() + else: + return self._patch_original_class() + + def _patch_original_class(self): + """ + Apply accumulated methods and return the class. + """ + cls = self._cls + base_names = self._base_names + + # Clean class of attribute definitions (`attr.ib()`s). + if self._delete_attribs: + for name in self._attr_names: + if ( + name not in base_names + and getattr(cls, name, None) is not None + ): + try: + delattr(cls, name) + except AttributeError: + # This can happen if a base class defines a class + # variable and we want to set an attribute with the + # same name by using only a type annotation. + pass + + # Attach our dunder methods. + for name, value in self._cls_dict.items(): + setattr(cls, name, value) + + return cls + + def _create_slots_class(self): + """ + Build and return a new class with a `__slots__` attribute. + """ + base_names = self._base_names + cd = { + k: v + for k, v in iteritems(self._cls_dict) + if k not in tuple(self._attr_names) + ("__dict__", "__weakref__") + } + + weakref_inherited = False + + # Traverse the MRO to check for an existing __weakref__. + for base_cls in self._cls.__mro__[1:-1]: + if "__weakref__" in getattr(base_cls, "__dict__", ()): + weakref_inherited = True + break + + names = self._attr_names + if ( + self._weakref_slot + and "__weakref__" not in getattr(self._cls, "__slots__", ()) + and "__weakref__" not in names + and not weakref_inherited + ): + names += ("__weakref__",) + + # We only add the names of attributes that aren't inherited. + # Settings __slots__ to inherited attributes wastes memory. + slot_names = [name for name in names if name not in base_names] + if self._cache_hash: + slot_names.append(_hash_cache_field) + cd["__slots__"] = tuple(slot_names) + + qualname = getattr(self._cls, "__qualname__", None) + if qualname is not None: + cd["__qualname__"] = qualname + + # __weakref__ is not writable. + state_attr_names = tuple( + an for an in self._attr_names if an != "__weakref__" + ) + + def slots_getstate(self): + """ + Automatically created by attrs. 
+ """ + return tuple(getattr(self, name) for name in state_attr_names) + + def slots_setstate(self, state): + """ + Automatically created by attrs. + """ + __bound_setattr = _obj_setattr.__get__(self, Attribute) + for name, value in zip(state_attr_names, state): + __bound_setattr(name, value) + + # slots and frozen require __getstate__/__setstate__ to work + cd["__getstate__"] = slots_getstate + cd["__setstate__"] = slots_setstate + + # Create new class based on old class and our methods. + cls = type(self._cls)(self._cls.__name__, self._cls.__bases__, cd) + + # The following is a fix for + # https://github.com/python-attrs/attrs/issues/102. On Python 3, + # if a method mentions `__class__` or uses the no-arg super(), the + # compiler will bake a reference to the class in the method itself + # as `method.__closure__`. Since we replace the class with a + # clone, we rewrite these references so it keeps working. + for item in cls.__dict__.values(): + if isinstance(item, (classmethod, staticmethod)): + # Class- and staticmethods hide their functions inside. + # These might need to be rewritten as well. + closure_cells = getattr(item.__func__, "__closure__", None) + else: + closure_cells = getattr(item, "__closure__", None) + + if not closure_cells: # Catch None or the empty list. + continue + for cell in closure_cells: + if cell.cell_contents is self._cls: + set_closure_cell(cell, cls) + + return cls + + def add_repr(self, ns): + self._cls_dict["__repr__"] = self._add_method_dunders( + _make_repr(self._attrs, ns=ns) + ) + return self + + def add_str(self): + repr = self._cls_dict.get("__repr__") + if repr is None: + raise ValueError( + "__str__ can only be generated if a __repr__ exists." + ) + + def __str__(self): + return self.__repr__() + + self._cls_dict["__str__"] = self._add_method_dunders(__str__) + return self + + def make_unhashable(self): + self._cls_dict["__hash__"] = None + return self + + def add_hash(self): + self._cls_dict["__hash__"] = self._add_method_dunders( + _make_hash( + self._attrs, frozen=self._frozen, cache_hash=self._cache_hash + ) + ) + + return self + + def add_init(self): + self._cls_dict["__init__"] = self._add_method_dunders( + _make_init( + self._attrs, + self._has_post_init, + self._frozen, + self._slots, + self._cache_hash, + self._base_attr_map, + ) + ) + + return self + + def add_cmp(self): + cd = self._cls_dict + + cd["__eq__"], cd["__ne__"], cd["__lt__"], cd["__le__"], cd[ + "__gt__" + ], cd["__ge__"] = ( + self._add_method_dunders(meth) for meth in _make_cmp(self._attrs) + ) + + return self + + def _add_method_dunders(self, method): + """ + Add __module__ and __qualname__ to a *method* if possible. + """ + try: + method.__module__ = self._cls.__module__ + except AttributeError: + pass + + try: + method.__qualname__ = ".".join( + (self._cls.__qualname__, method.__name__) + ) + except AttributeError: + pass + + return method + + +def attrs( + maybe_cls=None, + these=None, + repr_ns=None, + repr=True, + cmp=True, + hash=None, + init=True, + slots=False, + frozen=False, + weakref_slot=True, + str=False, + auto_attribs=False, + kw_only=False, + cache_hash=False, +): + r""" + A class decorator that adds `dunder + <https://wiki.python.org/moin/DunderAlias>`_\ -methods according to the + specified attributes using :func:`attr.ib` or the *these* argument. + + :param these: A dictionary of name to :func:`attr.ib` mappings. This is + useful to avoid the definition of your attributes within the class body + because you can't (e.g. 
if you want to add ``__repr__`` methods to + Django models) or don't want to. + + If *these* is not ``None``, ``attrs`` will *not* search the class body + for attributes and will *not* remove any attributes from it. + + If *these* is an ordered dict (:class:`dict` on Python 3.6+, + :class:`collections.OrderedDict` otherwise), the order is deduced from + the order of the attributes inside *these*. Otherwise the order + of the definition of the attributes is used. + + :type these: :class:`dict` of :class:`str` to :func:`attr.ib` + + :param str repr_ns: When using nested classes, there's no way in Python 2 + to automatically detect that. Therefore it's possible to set the + namespace explicitly for a more meaningful ``repr`` output. + :param bool repr: Create a ``__repr__`` method with a human readable + representation of ``attrs`` attributes.. + :param bool str: Create a ``__str__`` method that is identical to + ``__repr__``. This is usually not necessary except for + :class:`Exception`\ s. + :param bool cmp: Create ``__eq__``, ``__ne__``, ``__lt__``, ``__le__``, + ``__gt__``, and ``__ge__`` methods that compare the class as if it were + a tuple of its ``attrs`` attributes. But the attributes are *only* + compared, if the types of both classes are *identical*! + :param hash: If ``None`` (default), the ``__hash__`` method is generated + according how *cmp* and *frozen* are set. + + 1. If *both* are True, ``attrs`` will generate a ``__hash__`` for you. + 2. If *cmp* is True and *frozen* is False, ``__hash__`` will be set to + None, marking it unhashable (which it is). + 3. If *cmp* is False, ``__hash__`` will be left untouched meaning the + ``__hash__`` method of the base class will be used (if base class is + ``object``, this means it will fall back to id-based hashing.). + + Although not recommended, you can decide for yourself and force + ``attrs`` to create one (e.g. if the class is immutable even though you + didn't freeze it programmatically) by passing ``True`` or not. Both of + these cases are rather special and should be used carefully. + + See the `Python documentation \ + <https://docs.python.org/3/reference/datamodel.html#object.__hash__>`_ + and the `GitHub issue that led to the default behavior \ + <https://github.com/python-attrs/attrs/issues/136>`_ for more details. + :type hash: ``bool`` or ``None`` + :param bool init: Create a ``__init__`` method that initializes the + ``attrs`` attributes. Leading underscores are stripped for the + argument name. If a ``__attrs_post_init__`` method exists on the + class, it will be called after the class is fully initialized. + :param bool slots: Create a slots_-style class that's more + memory-efficient. See :ref:`slots` for further ramifications. + :param bool frozen: Make instances immutable after initialization. If + someone attempts to modify a frozen instance, + :exc:`attr.exceptions.FrozenInstanceError` is raised. + + Please note: + + 1. This is achieved by installing a custom ``__setattr__`` method + on your class so you can't implement an own one. + + 2. True immutability is impossible in Python. + + 3. This *does* have a minor a runtime performance :ref:`impact + <how-frozen>` when initializing new instances. In other words: + ``__init__`` is slightly slower with ``frozen=True``. + + 4. If a class is frozen, you cannot modify ``self`` in + ``__attrs_post_init__`` or a self-written ``__init__``. You can + circumvent that limitation by using + ``object.__setattr__(self, "attribute_name", value)``. + + .. 
_slots: https://docs.python.org/3/reference/datamodel.html#slots + :param bool weakref_slot: Make instances weak-referenceable. This has no + effect unless ``slots`` is also enabled. + :param bool auto_attribs: If True, collect `PEP 526`_-annotated attributes + (Python 3.6 and later only) from the class body. + + In this case, you **must** annotate every field. If ``attrs`` + encounters a field that is set to an :func:`attr.ib` but lacks a type + annotation, an :exc:`attr.exceptions.UnannotatedAttributeError` is + raised. Use ``field_name: typing.Any = attr.ib(...)`` if you don't + want to set a type. + + If you assign a value to those attributes (e.g. ``x: int = 42``), that + value becomes the default value like if it were passed using + ``attr.ib(default=42)``. Passing an instance of :class:`Factory` also + works as expected. + + Attributes annotated as :data:`typing.ClassVar` are **ignored**. + + .. _`PEP 526`: https://www.python.org/dev/peps/pep-0526/ + :param bool kw_only: Make all attributes keyword-only (Python 3+) + in the generated ``__init__`` (if ``init`` is ``False``, this + parameter is ignored). + :param bool cache_hash: Ensure that the object's hash code is computed + only once and stored on the object. If this is set to ``True``, + hashing must be either explicitly or implicitly enabled for this + class. If the hash code is cached, then no attributes of this + class which participate in hash code computation may be mutated + after object creation. + + + .. versionadded:: 16.0.0 *slots* + .. versionadded:: 16.1.0 *frozen* + .. versionadded:: 16.3.0 *str* + .. versionadded:: 16.3.0 Support for ``__attrs_post_init__``. + .. versionchanged:: 17.1.0 + *hash* supports ``None`` as value which is also the default now. + .. versionadded:: 17.3.0 *auto_attribs* + .. versionchanged:: 18.1.0 + If *these* is passed, no attributes are deleted from the class body. + .. versionchanged:: 18.1.0 If *these* is ordered, the order is retained. + .. versionadded:: 18.2.0 *weakref_slot* + .. deprecated:: 18.2.0 + ``__lt__``, ``__le__``, ``__gt__``, and ``__ge__`` now raise a + :class:`DeprecationWarning` if the classes compared are subclasses of + each other. ``__eq`` and ``__ne__`` never tried to compared subclasses + to each other. + .. versionadded:: 18.2.0 *kw_only* + .. versionadded:: 18.2.0 *cache_hash* + """ + + def wrap(cls): + if getattr(cls, "__class__", None) is None: + raise TypeError("attrs only works with new-style classes.") + + builder = _ClassBuilder( + cls, + these, + slots, + frozen, + weakref_slot, + auto_attribs, + kw_only, + cache_hash, + ) + + if repr is True: + builder.add_repr(repr_ns) + if str is True: + builder.add_str() + if cmp is True: + builder.add_cmp() + + if hash is not True and hash is not False and hash is not None: + # Can't use `hash in` because 1 == True for example. + raise TypeError( + "Invalid value for hash. Must be True, False, or None." + ) + elif hash is False or (hash is None and cmp is False): + if cache_hash: + raise TypeError( + "Invalid value for cache_hash. To use hash caching," + " hashing must be either explicitly or implicitly " + "enabled." + ) + elif hash is True or (hash is None and cmp is True and frozen is True): + builder.add_hash() + else: + if cache_hash: + raise TypeError( + "Invalid value for cache_hash. To use hash caching," + " hashing must be either explicitly or implicitly " + "enabled." 
+ ) + builder.make_unhashable() + + if init is True: + builder.add_init() + else: + if cache_hash: + raise TypeError( + "Invalid value for cache_hash. To use hash caching," + " init must be True." + ) + + return builder.build_class() + + # maybe_cls's type depends on the usage of the decorator. It's a class + # if it's used as `@attrs` but ``None`` if used as `@attrs()`. + if maybe_cls is None: + return wrap + else: + return wrap(maybe_cls) + + +_attrs = attrs +""" +Internal alias so we can use it in functions that take an argument called +*attrs*. +""" + + +if PY2: + + def _has_frozen_base_class(cls): + """ + Check whether *cls* has a frozen ancestor by looking at its + __setattr__. + """ + return ( + getattr(cls.__setattr__, "__module__", None) + == _frozen_setattrs.__module__ + and cls.__setattr__.__name__ == _frozen_setattrs.__name__ + ) + + +else: + + def _has_frozen_base_class(cls): + """ + Check whether *cls* has a frozen ancestor by looking at its + __setattr__. + """ + return cls.__setattr__ == _frozen_setattrs + + +def _attrs_to_tuple(obj, attrs): + """ + Create a tuple of all values of *obj*'s *attrs*. + """ + return tuple(getattr(obj, a.name) for a in attrs) + + +def _make_hash(attrs, frozen, cache_hash): + attrs = tuple( + a + for a in attrs + if a.hash is True or (a.hash is None and a.cmp is True) + ) + + tab = " " + + # We cache the generated hash methods for the same kinds of attributes. + sha1 = hashlib.sha1() + sha1.update(repr(attrs).encode("utf-8")) + unique_filename = "<attrs generated hash %s>" % (sha1.hexdigest(),) + type_hash = hash(unique_filename) + + method_lines = ["def __hash__(self):"] + + def append_hash_computation_lines(prefix, indent): + """ + Generate the code for actually computing the hash code. + Below this will either be returned directly or used to compute + a value which is then cached, depending on the value of cache_hash + """ + method_lines.extend( + [indent + prefix + "hash((", indent + " %d," % (type_hash,)] + ) + + for a in attrs: + method_lines.append(indent + " self.%s," % a.name) + + method_lines.append(indent + " ))") + + if cache_hash: + method_lines.append(tab + "if self.%s is None:" % _hash_cache_field) + if frozen: + append_hash_computation_lines( + "object.__setattr__(self, '%s', " % _hash_cache_field, tab * 2 + ) + method_lines.append(tab * 2 + ")") # close __setattr__ + else: + append_hash_computation_lines( + "self.%s = " % _hash_cache_field, tab * 2 + ) + method_lines.append(tab + "return self.%s" % _hash_cache_field) + else: + append_hash_computation_lines("return ", tab) + + script = "\n".join(method_lines) + globs = {} + locs = {} + bytecode = compile(script, unique_filename, "exec") + eval(bytecode, globs, locs) + + # In order of debuggers like PDB being able to step through the code, + # we add a fake linecache entry. + linecache.cache[unique_filename] = ( + len(script), + None, + script.splitlines(True), + unique_filename, + ) + + return locs["__hash__"] + + +def _add_hash(cls, attrs): + """ + Add a hash method to *cls*. + """ + cls.__hash__ = _make_hash(attrs, frozen=False, cache_hash=False) + return cls + + +def __ne__(self, other): + """ + Check equality and either forward a NotImplemented or return the result + negated. + """ + result = self.__eq__(other) + if result is NotImplemented: + return NotImplemented + + return not result + + +WARNING_CMP_ISINSTANCE = ( + "Comparision of subclasses using __%s__ is deprecated and will be removed " + "in 2019." 
+) + + +def _make_cmp(attrs): + attrs = [a for a in attrs if a.cmp] + + # We cache the generated eq methods for the same kinds of attributes. + sha1 = hashlib.sha1() + sha1.update(repr(attrs).encode("utf-8")) + unique_filename = "<attrs generated eq %s>" % (sha1.hexdigest(),) + lines = [ + "def __eq__(self, other):", + " if other.__class__ is not self.__class__:", + " return NotImplemented", + ] + # We can't just do a big self.x = other.x and... clause due to + # irregularities like nan == nan is false but (nan,) == (nan,) is true. + if attrs: + lines.append(" return (") + others = [" ) == ("] + for a in attrs: + lines.append(" self.%s," % (a.name,)) + others.append(" other.%s," % (a.name,)) + + lines += others + [" )"] + else: + lines.append(" return True") + + script = "\n".join(lines) + globs = {} + locs = {} + bytecode = compile(script, unique_filename, "exec") + eval(bytecode, globs, locs) + + # In order of debuggers like PDB being able to step through the code, + # we add a fake linecache entry. + linecache.cache[unique_filename] = ( + len(script), + None, + script.splitlines(True), + unique_filename, + ) + eq = locs["__eq__"] + ne = __ne__ + + def attrs_to_tuple(obj): + """ + Save us some typing. + """ + return _attrs_to_tuple(obj, attrs) + + def __lt__(self, other): + """ + Automatically created by attrs. + """ + if isinstance(other, self.__class__): + if other.__class__ is not self.__class__: + warnings.warn( + WARNING_CMP_ISINSTANCE % ("lt",), DeprecationWarning + ) + return attrs_to_tuple(self) < attrs_to_tuple(other) + else: + return NotImplemented + + def __le__(self, other): + """ + Automatically created by attrs. + """ + if isinstance(other, self.__class__): + if other.__class__ is not self.__class__: + warnings.warn( + WARNING_CMP_ISINSTANCE % ("le",), DeprecationWarning + ) + return attrs_to_tuple(self) <= attrs_to_tuple(other) + else: + return NotImplemented + + def __gt__(self, other): + """ + Automatically created by attrs. + """ + if isinstance(other, self.__class__): + if other.__class__ is not self.__class__: + warnings.warn( + WARNING_CMP_ISINSTANCE % ("gt",), DeprecationWarning + ) + return attrs_to_tuple(self) > attrs_to_tuple(other) + else: + return NotImplemented + + def __ge__(self, other): + """ + Automatically created by attrs. + """ + if isinstance(other, self.__class__): + if other.__class__ is not self.__class__: + warnings.warn( + WARNING_CMP_ISINSTANCE % ("ge",), DeprecationWarning + ) + return attrs_to_tuple(self) >= attrs_to_tuple(other) + else: + return NotImplemented + + return eq, ne, __lt__, __le__, __gt__, __ge__ + + +def _add_cmp(cls, attrs=None): + """ + Add comparison methods to *cls*. + """ + if attrs is None: + attrs = cls.__attrs_attrs__ + + cls.__eq__, cls.__ne__, cls.__lt__, cls.__le__, cls.__gt__, cls.__ge__ = _make_cmp( # noqa + attrs + ) + + return cls + + +_already_repring = threading.local() + + +def _make_repr(attrs, ns): + """ + Make a repr method for *attr_names* adding *ns* to the full name. + """ + attr_names = tuple(a.name for a in attrs if a.repr) + + def __repr__(self): + """ + Automatically created by attrs. + """ + try: + working_set = _already_repring.working_set + except AttributeError: + working_set = set() + _already_repring.working_set = working_set + + if id(self) in working_set: + return "..." 
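+
+        # Beyond this point the instance is not yet in the working set; it
+        # is added below so that any nested repr of a self-referential
+        # structure hits the "..." guard above instead of recursing forever.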
+ real_cls = self.__class__ + if ns is None: + qualname = getattr(real_cls, "__qualname__", None) + if qualname is not None: + class_name = qualname.rsplit(">.", 1)[-1] + else: + class_name = real_cls.__name__ + else: + class_name = ns + "." + real_cls.__name__ + + # Since 'self' remains on the stack (i.e.: strongly referenced) for the + # duration of this call, it's safe to depend on id(...) stability, and + # not need to track the instance and therefore worry about properties + # like weakref- or hash-ability. + working_set.add(id(self)) + try: + result = [class_name, "("] + first = True + for name in attr_names: + if first: + first = False + else: + result.append(", ") + result.extend((name, "=", repr(getattr(self, name, NOTHING)))) + return "".join(result) + ")" + finally: + working_set.remove(id(self)) + + return __repr__ + + +def _add_repr(cls, ns=None, attrs=None): + """ + Add a repr method to *cls*. + """ + if attrs is None: + attrs = cls.__attrs_attrs__ + + cls.__repr__ = _make_repr(attrs, ns) + return cls + + +def _make_init(attrs, post_init, frozen, slots, cache_hash, base_attr_map): + attrs = [a for a in attrs if a.init or a.default is not NOTHING] + + # We cache the generated init methods for the same kinds of attributes. + sha1 = hashlib.sha1() + sha1.update(repr(attrs).encode("utf-8")) + unique_filename = "<attrs generated init {0}>".format(sha1.hexdigest()) + + script, globs, annotations = _attrs_to_init_script( + attrs, frozen, slots, post_init, cache_hash, base_attr_map + ) + locs = {} + bytecode = compile(script, unique_filename, "exec") + attr_dict = dict((a.name, a) for a in attrs) + globs.update({"NOTHING": NOTHING, "attr_dict": attr_dict}) + if frozen is True: + # Save the lookup overhead in __init__ if we need to circumvent + # immutability. + globs["_cached_setattr"] = _obj_setattr + eval(bytecode, globs, locs) + + # In order of debuggers like PDB being able to step through the code, + # we add a fake linecache entry. + linecache.cache[unique_filename] = ( + len(script), + None, + script.splitlines(True), + unique_filename, + ) + + __init__ = locs["__init__"] + __init__.__annotations__ = annotations + return __init__ + + +def _add_init(cls, frozen): + """ + Add a __init__ method to *cls*. If *frozen* is True, make it immutable. + """ + cls.__init__ = _make_init( + cls.__attrs_attrs__, + getattr(cls, "__attrs_post_init__", False), + frozen, + _is_slot_cls(cls), + cache_hash=False, + base_attr_map={}, + ) + return cls + + +def fields(cls): + """ + Return the tuple of ``attrs`` attributes for a class. + + The tuple also allows accessing the fields by their names (see below for + examples). + + :param type cls: Class to introspect. + + :raise TypeError: If *cls* is not a class. + :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` + class. + + :rtype: tuple (with name accessors) of :class:`attr.Attribute` + + .. versionchanged:: 16.2.0 Returned tuple allows accessing the fields + by name. + """ + if not isclass(cls): + raise TypeError("Passed object must be a class.") + attrs = getattr(cls, "__attrs_attrs__", None) + if attrs is None: + raise NotAnAttrsClassError( + "{cls!r} is not an attrs-decorated class.".format(cls=cls) + ) + return attrs + + +def fields_dict(cls): + """ + Return an ordered dictionary of ``attrs`` attributes for a class, whose + keys are the attribute names. + + :param type cls: Class to introspect. + + :raise TypeError: If *cls* is not a class. 
+ :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` + class. + + :rtype: an ordered dict where keys are attribute names and values are + :class:`attr.Attribute`\\ s. This will be a :class:`dict` if it's + naturally ordered like on Python 3.6+ or an + :class:`~collections.OrderedDict` otherwise. + + .. versionadded:: 18.1.0 + """ + if not isclass(cls): + raise TypeError("Passed object must be a class.") + attrs = getattr(cls, "__attrs_attrs__", None) + if attrs is None: + raise NotAnAttrsClassError( + "{cls!r} is not an attrs-decorated class.".format(cls=cls) + ) + return ordered_dict(((a.name, a) for a in attrs)) + + +def validate(inst): + """ + Validate all attributes on *inst* that have a validator. + + Leaves all exceptions through. + + :param inst: Instance of a class with ``attrs`` attributes. + """ + if _config._run_validators is False: + return + + for a in fields(inst.__class__): + v = a.validator + if v is not None: + v(inst, a, getattr(inst, a.name)) + + +def _is_slot_cls(cls): + return "__slots__" in cls.__dict__ + + +def _is_slot_attr(a_name, base_attr_map): + """ + Check if the attribute name comes from a slot class. + """ + return a_name in base_attr_map and _is_slot_cls(base_attr_map[a_name]) + + +def _attrs_to_init_script( + attrs, frozen, slots, post_init, cache_hash, base_attr_map +): + """ + Return a script of an initializer for *attrs* and a dict of globals. + + The globals are expected by the generated script. + + If *frozen* is True, we cannot set the attributes directly so we use + a cached ``object.__setattr__``. + """ + lines = [] + any_slot_ancestors = any( + _is_slot_attr(a.name, base_attr_map) for a in attrs + ) + if frozen is True: + if slots is True: + lines.append( + # Circumvent the __setattr__ descriptor to save one lookup per + # assignment. + # Note _setattr will be used again below if cache_hash is True + "_setattr = _cached_setattr.__get__(self, self.__class__)" + ) + + def fmt_setter(attr_name, value_var): + return "_setattr('%(attr_name)s', %(value_var)s)" % { + "attr_name": attr_name, + "value_var": value_var, + } + + def fmt_setter_with_converter(attr_name, value_var): + conv_name = _init_converter_pat.format(attr_name) + return "_setattr('%(attr_name)s', %(conv)s(%(value_var)s))" % { + "attr_name": attr_name, + "value_var": value_var, + "conv": conv_name, + } + + else: + # Dict frozen classes assign directly to __dict__. + # But only if the attribute doesn't come from an ancestor slot + # class. + # Note _inst_dict will be used again below if cache_hash is True + lines.append("_inst_dict = self.__dict__") + if any_slot_ancestors: + lines.append( + # Circumvent the __setattr__ descriptor to save one lookup + # per assignment. + "_setattr = _cached_setattr.__get__(self, self.__class__)" + ) + + def fmt_setter(attr_name, value_var): + if _is_slot_attr(attr_name, base_attr_map): + res = "_setattr('%(attr_name)s', %(value_var)s)" % { + "attr_name": attr_name, + "value_var": value_var, + } + else: + res = "_inst_dict['%(attr_name)s'] = %(value_var)s" % { + "attr_name": attr_name, + "value_var": value_var, + } + return res + + def fmt_setter_with_converter(attr_name, value_var): + conv_name = _init_converter_pat.format(attr_name) + if _is_slot_attr(attr_name, base_attr_map): + tmpl = "_setattr('%(attr_name)s', %(c)s(%(value_var)s))" + else: + tmpl = "_inst_dict['%(attr_name)s'] = %(c)s(%(value_var)s)" + return tmpl % { + "attr_name": attr_name, + "value_var": value_var, + "c": conv_name, + } + + else: + # Not frozen. 
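+        # Plain assignment suffices here because __setattr__ has not been
+        # replaced, so there is no descriptor to circumvent.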
+ def fmt_setter(attr_name, value): + return "self.%(attr_name)s = %(value)s" % { + "attr_name": attr_name, + "value": value, + } + + def fmt_setter_with_converter(attr_name, value_var): + conv_name = _init_converter_pat.format(attr_name) + return "self.%(attr_name)s = %(conv)s(%(value_var)s)" % { + "attr_name": attr_name, + "value_var": value_var, + "conv": conv_name, + } + + args = [] + kw_only_args = [] + attrs_to_validate = [] + + # This is a dictionary of names to validator and converter callables. + # Injecting this into __init__ globals lets us avoid lookups. + names_for_globals = {} + annotations = {"return": None} + + for a in attrs: + if a.validator: + attrs_to_validate.append(a) + attr_name = a.name + arg_name = a.name.lstrip("_") + has_factory = isinstance(a.default, Factory) + if has_factory and a.default.takes_self: + maybe_self = "self" + else: + maybe_self = "" + if a.init is False: + if has_factory: + init_factory_name = _init_factory_pat.format(a.name) + if a.converter is not None: + lines.append( + fmt_setter_with_converter( + attr_name, + init_factory_name + "({0})".format(maybe_self), + ) + ) + conv_name = _init_converter_pat.format(a.name) + names_for_globals[conv_name] = a.converter + else: + lines.append( + fmt_setter( + attr_name, + init_factory_name + "({0})".format(maybe_self), + ) + ) + names_for_globals[init_factory_name] = a.default.factory + else: + if a.converter is not None: + lines.append( + fmt_setter_with_converter( + attr_name, + "attr_dict['{attr_name}'].default".format( + attr_name=attr_name + ), + ) + ) + conv_name = _init_converter_pat.format(a.name) + names_for_globals[conv_name] = a.converter + else: + lines.append( + fmt_setter( + attr_name, + "attr_dict['{attr_name}'].default".format( + attr_name=attr_name + ), + ) + ) + elif a.default is not NOTHING and not has_factory: + arg = "{arg_name}=attr_dict['{attr_name}'].default".format( + arg_name=arg_name, attr_name=attr_name + ) + if a.kw_only: + kw_only_args.append(arg) + else: + args.append(arg) + if a.converter is not None: + lines.append(fmt_setter_with_converter(attr_name, arg_name)) + names_for_globals[ + _init_converter_pat.format(a.name) + ] = a.converter + else: + lines.append(fmt_setter(attr_name, arg_name)) + elif has_factory: + arg = "{arg_name}=NOTHING".format(arg_name=arg_name) + if a.kw_only: + kw_only_args.append(arg) + else: + args.append(arg) + lines.append( + "if {arg_name} is not NOTHING:".format(arg_name=arg_name) + ) + init_factory_name = _init_factory_pat.format(a.name) + if a.converter is not None: + lines.append( + " " + fmt_setter_with_converter(attr_name, arg_name) + ) + lines.append("else:") + lines.append( + " " + + fmt_setter_with_converter( + attr_name, + init_factory_name + "({0})".format(maybe_self), + ) + ) + names_for_globals[ + _init_converter_pat.format(a.name) + ] = a.converter + else: + lines.append(" " + fmt_setter(attr_name, arg_name)) + lines.append("else:") + lines.append( + " " + + fmt_setter( + attr_name, + init_factory_name + "({0})".format(maybe_self), + ) + ) + names_for_globals[init_factory_name] = a.default.factory + else: + if a.kw_only: + kw_only_args.append(arg_name) + else: + args.append(arg_name) + if a.converter is not None: + lines.append(fmt_setter_with_converter(attr_name, arg_name)) + names_for_globals[ + _init_converter_pat.format(a.name) + ] = a.converter + else: + lines.append(fmt_setter(attr_name, arg_name)) + + if a.init is True and a.converter is None and a.type is not None: + annotations[arg_name] = a.type + + if attrs_to_validate: 
# we can skip this if there are no validators. + names_for_globals["_config"] = _config + lines.append("if _config._run_validators is True:") + for a in attrs_to_validate: + val_name = "__attr_validator_{}".format(a.name) + attr_name = "__attr_{}".format(a.name) + lines.append( + " {}(self, {}, self.{})".format(val_name, attr_name, a.name) + ) + names_for_globals[val_name] = a.validator + names_for_globals[attr_name] = a + if post_init: + lines.append("self.__attrs_post_init__()") + + # because this is set only after __attrs_post_init is called, a crash + # will result if post-init tries to access the hash code. This seemed + # preferable to setting this beforehand, in which case alteration to + # field values during post-init combined with post-init accessing the + # hash code would result in silent bugs. + if cache_hash: + if frozen: + if slots: + # if frozen and slots, then _setattr defined above + init_hash_cache = "_setattr('%s', %s)" + else: + # if frozen and not slots, then _inst_dict defined above + init_hash_cache = "_inst_dict['%s'] = %s" + else: + init_hash_cache = "self.%s = %s" + lines.append(init_hash_cache % (_hash_cache_field, "None")) + + args = ", ".join(args) + if kw_only_args: + if PY2: + raise PythonTooOldError( + "Keyword-only arguments only work on Python 3 and later." + ) + + args += "{leading_comma}*, {kw_only_args}".format( + leading_comma=", " if args else "", + kw_only_args=", ".join(kw_only_args), + ) + return ( + """\ +def __init__(self, {args}): + {lines} +""".format( + args=args, lines="\n ".join(lines) if lines else "pass" + ), + names_for_globals, + annotations, + ) + + +class Attribute(object): + """ + *Read-only* representation of an attribute. + + :attribute name: The name of the attribute. + + Plus *all* arguments of :func:`attr.ib`. + + For the version history of the fields, see :func:`attr.ib`. + """ + + __slots__ = ( + "name", + "default", + "validator", + "repr", + "cmp", + "hash", + "init", + "metadata", + "type", + "converter", + "kw_only", + ) + + def __init__( + self, + name, + default, + validator, + repr, + cmp, + hash, + init, + convert=None, + metadata=None, + type=None, + converter=None, + kw_only=False, + ): + # Cache this descriptor here to speed things up later. + bound_setattr = _obj_setattr.__get__(self, Attribute) + + # Despite the big red warning, people *do* instantiate `Attribute` + # themselves. + if convert is not None: + if converter is not None: + raise RuntimeError( + "Can't pass both `convert` and `converter`. " + "Please use `converter` only." + ) + warnings.warn( + "The `convert` argument is deprecated in favor of `converter`." + " It will be removed after 2019/01.", + DeprecationWarning, + stacklevel=2, + ) + converter = convert + + bound_setattr("name", name) + bound_setattr("default", default) + bound_setattr("validator", validator) + bound_setattr("repr", repr) + bound_setattr("cmp", cmp) + bound_setattr("hash", hash) + bound_setattr("init", init) + bound_setattr("converter", converter) + bound_setattr( + "metadata", + ( + metadata_proxy(metadata) + if metadata + else _empty_metadata_singleton + ), + ) + bound_setattr("type", type) + bound_setattr("kw_only", kw_only) + + def __setattr__(self, name, value): + raise FrozenInstanceError() + + @property + def convert(self): + warnings.warn( + "The `convert` attribute is deprecated in favor of `converter`. 
" + "It will be removed after 2019/01.", + DeprecationWarning, + stacklevel=2, + ) + return self.converter + + @classmethod + def from_counting_attr(cls, name, ca, type=None): + # type holds the annotated value. deal with conflicts: + if type is None: + type = ca.type + elif ca.type is not None: + raise ValueError( + "Type annotation and type argument cannot both be present" + ) + inst_dict = { + k: getattr(ca, k) + for k in Attribute.__slots__ + if k + not in ( + "name", + "validator", + "default", + "type", + "convert", + ) # exclude methods and deprecated alias + } + return cls( + name=name, + validator=ca._validator, + default=ca._default, + type=type, + **inst_dict + ) + + # Don't use attr.assoc since fields(Attribute) doesn't work + def _assoc(self, **changes): + """ + Copy *self* and apply *changes*. + """ + new = copy.copy(self) + + new._setattrs(changes.items()) + + return new + + # Don't use _add_pickle since fields(Attribute) doesn't work + def __getstate__(self): + """ + Play nice with pickle. + """ + return tuple( + getattr(self, name) if name != "metadata" else dict(self.metadata) + for name in self.__slots__ + ) + + def __setstate__(self, state): + """ + Play nice with pickle. + """ + self._setattrs(zip(self.__slots__, state)) + + def _setattrs(self, name_values_pairs): + bound_setattr = _obj_setattr.__get__(self, Attribute) + for name, value in name_values_pairs: + if name != "metadata": + bound_setattr(name, value) + else: + bound_setattr( + name, + metadata_proxy(value) + if value + else _empty_metadata_singleton, + ) + + +_a = [ + Attribute( + name=name, + default=NOTHING, + validator=None, + repr=True, + cmp=True, + hash=(name != "metadata"), + init=True, + ) + for name in Attribute.__slots__ + if name != "convert" # XXX: remove once `convert` is gone +] + +Attribute = _add_hash( + _add_cmp(_add_repr(Attribute, attrs=_a), attrs=_a), + attrs=[a for a in _a if a.hash], +) + + +class _CountingAttr(object): + """ + Intermediate representation of attributes that uses a counter to preserve + the order in which the attributes have been defined. + + *Internal* data structure of the attrs library. Running into is most + likely the result of a bug like a forgotten `@attr.s` decorator. + """ + + __slots__ = ( + "counter", + "_default", + "repr", + "cmp", + "hash", + "init", + "metadata", + "_validator", + "converter", + "type", + "kw_only", + ) + __attrs_attrs__ = tuple( + Attribute( + name=name, + default=NOTHING, + validator=None, + repr=True, + cmp=True, + hash=True, + init=True, + kw_only=False, + ) + for name in ("counter", "_default", "repr", "cmp", "hash", "init") + ) + ( + Attribute( + name="metadata", + default=None, + validator=None, + repr=True, + cmp=True, + hash=False, + init=True, + kw_only=False, + ), + ) + cls_counter = 0 + + def __init__( + self, + default, + validator, + repr, + cmp, + hash, + init, + converter, + metadata, + type, + kw_only, + ): + _CountingAttr.cls_counter += 1 + self.counter = _CountingAttr.cls_counter + self._default = default + # If validator is a list/tuple, wrap it using helper validator. + if validator and isinstance(validator, (list, tuple)): + self._validator = and_(*validator) + else: + self._validator = validator + self.repr = repr + self.cmp = cmp + self.hash = hash + self.init = init + self.converter = converter + self.metadata = metadata + self.type = type + self.kw_only = kw_only + + def validator(self, meth): + """ + Decorator that adds *meth* to the list of validators. + + Returns *meth* unchanged. + + .. 
versionadded:: 17.1.0 + """ + if self._validator is None: + self._validator = meth + else: + self._validator = and_(self._validator, meth) + return meth + + def default(self, meth): + """ + Decorator that allows to set the default for an attribute. + + Returns *meth* unchanged. + + :raises DefaultAlreadySetError: If default has been set before. + + .. versionadded:: 17.1.0 + """ + if self._default is not NOTHING: + raise DefaultAlreadySetError() + + self._default = Factory(meth, takes_self=True) + + return meth + + +_CountingAttr = _add_cmp(_add_repr(_CountingAttr)) + + +@attrs(slots=True, init=False, hash=True) +class Factory(object): + """ + Stores a factory callable. + + If passed as the default value to :func:`attr.ib`, the factory is used to + generate a new value. + + :param callable factory: A callable that takes either none or exactly one + mandatory positional argument depending on *takes_self*. + :param bool takes_self: Pass the partially initialized instance that is + being initialized as a positional argument. + + .. versionadded:: 17.1.0 *takes_self* + """ + + factory = attrib() + takes_self = attrib() + + def __init__(self, factory, takes_self=False): + """ + `Factory` is part of the default machinery so if we want a default + value here, we have to implement it ourselves. + """ + self.factory = factory + self.takes_self = takes_self + + +def make_class(name, attrs, bases=(object,), **attributes_arguments): + """ + A quick way to create a new class called *name* with *attrs*. + + :param name: The name for the new class. + :type name: str + + :param attrs: A list of names or a dictionary of mappings of names to + attributes. + + If *attrs* is a list or an ordered dict (:class:`dict` on Python 3.6+, + :class:`collections.OrderedDict` otherwise), the order is deduced from + the order of the names or attributes inside *attrs*. Otherwise the + order of the definition of the attributes is used. + :type attrs: :class:`list` or :class:`dict` + + :param tuple bases: Classes that the new class will subclass. + + :param attributes_arguments: Passed unmodified to :func:`attr.s`. + + :return: A new class with *attrs*. + :rtype: type + + .. versionadded:: 17.1.0 *bases* + .. versionchanged:: 18.1.0 If *attrs* is ordered, the order is retained. + """ + if isinstance(attrs, dict): + cls_dict = attrs + elif isinstance(attrs, (list, tuple)): + cls_dict = dict((a, attrib()) for a in attrs) + else: + raise TypeError("attrs argument must be a dict or a list.") + + post_init = cls_dict.pop("__attrs_post_init__", None) + type_ = type( + name, + bases, + {} if post_init is None else {"__attrs_post_init__": post_init}, + ) + # For pickling to work, the __module__ variable needs to be set to the + # frame where the class is created. Bypass this step in environments where + # sys._getframe is not defined (Jython for example) or sys._getframe is not + # defined for arguments greater than 0 (IronPython). + try: + type_.__module__ = sys._getframe(1).f_globals.get( + "__name__", "__main__" + ) + except (AttributeError, ValueError): + pass + + return _attrs(these=cls_dict, **attributes_arguments)(type_) + + +# These are required by within this module so we define them here and merely +# import into .validators. + + +@attrs(slots=True, hash=True) +class _AndValidator(object): + """ + Compose many validators to a single one. 
+ """ + + _validators = attrib() + + def __call__(self, inst, attr, value): + for v in self._validators: + v(inst, attr, value) + + +def and_(*validators): + """ + A validator that composes multiple validators into one. + + When called on a value, it runs all wrapped validators. + + :param validators: Arbitrary number of validators. + :type validators: callables + + .. versionadded:: 17.1.0 + """ + vals = [] + for validator in validators: + vals.extend( + validator._validators + if isinstance(validator, _AndValidator) + else [validator] + ) + + return _AndValidator(tuple(vals)) diff --git a/python/attr/converters.py b/python/attr/converters.py new file mode 100644 index 0000000..37c4a07 --- /dev/null +++ b/python/attr/converters.py @@ -0,0 +1,78 @@ +""" +Commonly useful converters. +""" + +from __future__ import absolute_import, division, print_function + +from ._make import NOTHING, Factory + + +def optional(converter): + """ + A converter that allows an attribute to be optional. An optional attribute + is one which can be set to ``None``. + + :param callable converter: the converter that is used for non-``None`` + values. + + .. versionadded:: 17.1.0 + """ + + def optional_converter(val): + if val is None: + return None + return converter(val) + + return optional_converter + + +def default_if_none(default=NOTHING, factory=None): + """ + A converter that allows to replace ``None`` values by *default* or the + result of *factory*. + + :param default: Value to be used if ``None`` is passed. Passing an instance + of :class:`attr.Factory` is supported, however the ``takes_self`` option + is *not*. + :param callable factory: A callable that takes not parameters whose result + is used if ``None`` is passed. + + :raises TypeError: If **neither** *default* or *factory* is passed. + :raises TypeError: If **both** *default* and *factory* are passed. + :raises ValueError: If an instance of :class:`attr.Factory` is passed with + ``takes_self=True``. + + .. versionadded:: 18.2.0 + """ + if default is NOTHING and factory is None: + raise TypeError("Must pass either `default` or `factory`.") + + if default is not NOTHING and factory is not None: + raise TypeError( + "Must pass either `default` or `factory` but not both." + ) + + if factory is not None: + default = Factory(factory) + + if isinstance(default, Factory): + if default.takes_self: + raise ValueError( + "`takes_self` is not supported by default_if_none." + ) + + def default_if_none_converter(val): + if val is not None: + return val + + return default.factory() + + else: + + def default_if_none_converter(val): + if val is not None: + return val + + return default + + return default_if_none_converter diff --git a/python/attr/converters.pyi b/python/attr/converters.pyi new file mode 100644 index 0000000..63b2a38 --- /dev/null +++ b/python/attr/converters.pyi @@ -0,0 +1,12 @@ +from typing import TypeVar, Optional, Callable, overload +from . import _ConverterType + +_T = TypeVar("_T") + +def optional( + converter: _ConverterType[_T] +) -> _ConverterType[Optional[_T]]: ... +@overload +def default_if_none(default: _T) -> _ConverterType[_T]: ... +@overload +def default_if_none(*, factory: Callable[[], _T]) -> _ConverterType[_T]: ... 
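The two converters above are easiest to see in use. A minimal sketch (the
``Config`` class and its fields are invented for illustration; it relies on
the ``attr.ib(converter=...)`` hook documented earlier, which also runs on
default values):

    import attr
    from attr.converters import default_if_none, optional

    @attr.s
    class Config(object):  # illustrative class only
        # None passes through; any other value is coerced with int().
        port = attr.ib(converter=optional(int), default=None)
        # None is swapped for a fresh list built by the factory.
        tags = attr.ib(converter=default_if_none(factory=list), default=None)

    c = Config(port="8080")
    assert c.port == 8080 and c.tags == []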
diff --git a/python/attr/exceptions.py b/python/attr/exceptions.py new file mode 100644 index 0000000..b12e41e --- /dev/null +++ b/python/attr/exceptions.py @@ -0,0 +1,57 @@ +from __future__ import absolute_import, division, print_function + + +class FrozenInstanceError(AttributeError): + """ + A frozen/immutable instance has been attempted to be modified. + + It mirrors the behavior of ``namedtuples`` by using the same error message + and subclassing :exc:`AttributeError`. + + .. versionadded:: 16.1.0 + """ + + msg = "can't set attribute" + args = [msg] + + +class AttrsAttributeNotFoundError(ValueError): + """ + An ``attrs`` function couldn't find an attribute that the user asked for. + + .. versionadded:: 16.2.0 + """ + + +class NotAnAttrsClassError(ValueError): + """ + A non-``attrs`` class has been passed into an ``attrs`` function. + + .. versionadded:: 16.2.0 + """ + + +class DefaultAlreadySetError(RuntimeError): + """ + A default has been set using ``attr.ib()`` and is attempted to be reset + using the decorator. + + .. versionadded:: 17.1.0 + """ + + +class UnannotatedAttributeError(RuntimeError): + """ + A class with ``auto_attribs=True`` has an ``attr.ib()`` without a type + annotation. + + .. versionadded:: 17.3.0 + """ + + +class PythonTooOldError(RuntimeError): + """ + An ``attrs`` feature requiring a more recent python version has been used. + + .. versionadded:: 18.2.0 + """ diff --git a/python/attr/exceptions.pyi b/python/attr/exceptions.pyi new file mode 100644 index 0000000..48fffcc --- /dev/null +++ b/python/attr/exceptions.pyi @@ -0,0 +1,7 @@ +class FrozenInstanceError(AttributeError): + msg: str = ... + +class AttrsAttributeNotFoundError(ValueError): ... +class NotAnAttrsClassError(ValueError): ... +class DefaultAlreadySetError(RuntimeError): ... +class UnannotatedAttributeError(RuntimeError): ... diff --git a/python/attr/filters.py b/python/attr/filters.py new file mode 100644 index 0000000..f1c69b8 --- /dev/null +++ b/python/attr/filters.py @@ -0,0 +1,52 @@ +""" +Commonly useful filters for :func:`attr.asdict`. +""" + +from __future__ import absolute_import, division, print_function + +from ._compat import isclass +from ._make import Attribute + + +def _split_what(what): + """ + Returns a tuple of `frozenset`s of classes and attributes. + """ + return ( + frozenset(cls for cls in what if isclass(cls)), + frozenset(cls for cls in what if isinstance(cls, Attribute)), + ) + + +def include(*what): + """ + Whitelist *what*. + + :param what: What to whitelist. + :type what: :class:`list` of :class:`type` or :class:`attr.Attribute`\\ s + + :rtype: :class:`callable` + """ + cls, attrs = _split_what(what) + + def include_(attribute, value): + return value.__class__ in cls or attribute in attrs + + return include_ + + +def exclude(*what): + """ + Blacklist *what*. + + :param what: What to blacklist. + :type what: :class:`list` of classes or :class:`attr.Attribute`\\ s. + + :rtype: :class:`callable` + """ + cls, attrs = _split_what(what) + + def exclude_(attribute, value): + return value.__class__ not in cls and attribute not in attrs + + return exclude_ diff --git a/python/attr/filters.pyi b/python/attr/filters.pyi new file mode 100644 index 0000000..a618140 --- /dev/null +++ b/python/attr/filters.pyi @@ -0,0 +1,5 @@ +from typing import Union +from . import Attribute, _FilterType + +def include(*what: Union[type, Attribute]) -> _FilterType: ... +def exclude(*what: Union[type, Attribute]) -> _FilterType: ... 
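Both filters plug into the *filter* argument of :func:`attr.asdict`
mentioned in the module docstring above. A short usage sketch, with
invented field names:

    import attr
    from attr import asdict, fields, filters

    @attr.s
    class User(object):  # illustrative class only
        name = attr.ib()
        password = attr.ib()

    u = User(name="alice", password="hunter2")
    # Blacklist the password attribute when serializing.
    assert asdict(u, filter=filters.exclude(fields(User).password)) == {
        "name": "alice"
    }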
diff --git a/python/attr/py.typed b/python/attr/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/python/attr/py.typed diff --git a/python/attr/validators.py b/python/attr/validators.py new file mode 100644 index 0000000..f12d0aa --- /dev/null +++ b/python/attr/validators.py @@ -0,0 +1,170 @@ +""" +Commonly useful validators. +""" + +from __future__ import absolute_import, division, print_function + +from ._make import _AndValidator, and_, attrib, attrs + + +__all__ = ["and_", "in_", "instance_of", "optional", "provides"] + + +@attrs(repr=False, slots=True, hash=True) +class _InstanceOfValidator(object): + type = attrib() + + def __call__(self, inst, attr, value): + """ + We use a callable class to be able to change the ``__repr__``. + """ + if not isinstance(value, self.type): + raise TypeError( + "'{name}' must be {type!r} (got {value!r} that is a " + "{actual!r}).".format( + name=attr.name, + type=self.type, + actual=value.__class__, + value=value, + ), + attr, + self.type, + value, + ) + + def __repr__(self): + return "<instance_of validator for type {type!r}>".format( + type=self.type + ) + + +def instance_of(type): + """ + A validator that raises a :exc:`TypeError` if the initializer is called + with a wrong type for this particular attribute (checks are performed using + :func:`isinstance` therefore it's also valid to pass a tuple of types). + + :param type: The type to check for. + :type type: type or tuple of types + + :raises TypeError: With a human readable error message, the attribute + (of type :class:`attr.Attribute`), the expected type, and the value it + got. + """ + return _InstanceOfValidator(type) + + +@attrs(repr=False, slots=True, hash=True) +class _ProvidesValidator(object): + interface = attrib() + + def __call__(self, inst, attr, value): + """ + We use a callable class to be able to change the ``__repr__``. + """ + if not self.interface.providedBy(value): + raise TypeError( + "'{name}' must provide {interface!r} which {value!r} " + "doesn't.".format( + name=attr.name, interface=self.interface, value=value + ), + attr, + self.interface, + value, + ) + + def __repr__(self): + return "<provides validator for interface {interface!r}>".format( + interface=self.interface + ) + + +def provides(interface): + """ + A validator that raises a :exc:`TypeError` if the initializer is called + with an object that does not provide the requested *interface* (checks are + performed using ``interface.providedBy(value)`` (see `zope.interface + <https://zopeinterface.readthedocs.io/en/latest/>`_). + + :param zope.interface.Interface interface: The interface to check for. + + :raises TypeError: With a human readable error message, the attribute + (of type :class:`attr.Attribute`), the expected interface, and the + value it got. + """ + return _ProvidesValidator(interface) + + +@attrs(repr=False, slots=True, hash=True) +class _OptionalValidator(object): + validator = attrib() + + def __call__(self, inst, attr, value): + if value is None: + return + + self.validator(inst, attr, value) + + def __repr__(self): + return "<optional validator for {what} or None>".format( + what=repr(self.validator) + ) + + +def optional(validator): + """ + A validator that makes an attribute optional. An optional attribute is one + which can be set to ``None`` in addition to satisfying the requirements of + the sub-validator. + + :param validator: A validator (or a list of validators) that is used for + non-``None`` values. + :type validator: callable or :class:`list` of callables. + + .. 
versionadded:: 15.1.0 + .. versionchanged:: 17.1.0 *validator* can be a list of validators. + """ + if isinstance(validator, list): + return _OptionalValidator(_AndValidator(validator)) + return _OptionalValidator(validator) + + +@attrs(repr=False, slots=True, hash=True) +class _InValidator(object): + options = attrib() + + def __call__(self, inst, attr, value): + try: + in_options = value in self.options + except TypeError as e: # e.g. `1 in "abc"` + in_options = False + + if not in_options: + raise ValueError( + "'{name}' must be in {options!r} (got {value!r})".format( + name=attr.name, options=self.options, value=value + ) + ) + + def __repr__(self): + return "<in_ validator with options {options!r}>".format( + options=self.options + ) + + +def in_(options): + """ + A validator that raises a :exc:`ValueError` if the initializer is called + with a value that does not belong in the options provided. The check is + performed using ``value in options``. + + :param options: Allowed options. + :type options: list, tuple, :class:`enum.Enum`, ... + + :raises ValueError: With a human readable error message, the attribute (of + type :class:`attr.Attribute`), the expected options, and the value it + got. + + .. versionadded:: 17.1.0 + """ + return _InValidator(options) diff --git a/python/attr/validators.pyi b/python/attr/validators.pyi new file mode 100644 index 0000000..abbaedf --- /dev/null +++ b/python/attr/validators.pyi @@ -0,0 +1,14 @@ +from typing import Container, List, Union, TypeVar, Type, Any, Optional, Tuple +from . import _ValidatorType + +_T = TypeVar("_T") + +def instance_of( + type: Union[Tuple[Type[_T], ...], Type[_T]] +) -> _ValidatorType[_T]: ... +def provides(interface: Any) -> _ValidatorType[Any]: ... +def optional( + validator: Union[_ValidatorType[_T], List[_ValidatorType[_T]]] +) -> _ValidatorType[Optional[_T]]: ... +def in_(options: Container[_T]) -> _ValidatorType[_T]: ... +def and_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ... diff --git a/python/dateutil/__init__.py b/python/dateutil/__init__.py new file mode 100644 index 0000000..796ef3d --- /dev/null +++ b/python/dateutil/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +from ._version import VERSION as __version__ diff --git a/python/dateutil/_common.py b/python/dateutil/_common.py new file mode 100644 index 0000000..e8b4af7 --- /dev/null +++ b/python/dateutil/_common.py @@ -0,0 +1,34 @@ +""" +Common code used in multiple modules. +""" + + +class weekday(object): + __slots__ = ["weekday", "n"] + + def __init__(self, weekday, n=None): + self.weekday = weekday + self.n = n + + def __call__(self, n): + if n == self.n: + return self + else: + return self.__class__(self.weekday, n) + + def __eq__(self, other): + try: + if self.weekday != other.weekday or self.n != other.n: + return False + except AttributeError: + return False + return True + + __hash__ = None + + def __repr__(self): + s = ("MO", "TU", "WE", "TH", "FR", "SA", "SU")[self.weekday] + if not self.n: + return s + else: + return "%s(%+d)" % (s, self.n) diff --git a/python/dateutil/_version.py b/python/dateutil/_version.py new file mode 100644 index 0000000..c1a0357 --- /dev/null +++ b/python/dateutil/_version.py @@ -0,0 +1,10 @@ +""" +Contains information about the dateutil version. 
+""" + +VERSION_MAJOR = 2 +VERSION_MINOR = 6 +VERSION_PATCH = 1 + +VERSION_TUPLE = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) +VERSION = '.'.join(map(str, VERSION_TUPLE)) diff --git a/python/dateutil/easter.py b/python/dateutil/easter.py new file mode 100644 index 0000000..e4def97 --- /dev/null +++ b/python/dateutil/easter.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +""" +This module offers a generic easter computing method for any given year, using +Western, Orthodox or Julian algorithms. +""" + +import datetime + +__all__ = ["easter", "EASTER_JULIAN", "EASTER_ORTHODOX", "EASTER_WESTERN"] + +EASTER_JULIAN = 1 +EASTER_ORTHODOX = 2 +EASTER_WESTERN = 3 + + +def easter(year, method=EASTER_WESTERN): + """ + This method was ported from the work done by GM Arts, + on top of the algorithm by Claus Tondering, which was + based in part on the algorithm of Ouding (1940), as + quoted in "Explanatory Supplement to the Astronomical + Almanac", P. Kenneth Seidelmann, editor. + + This algorithm implements three different easter + calculation methods: + + 1 - Original calculation in Julian calendar, valid in + dates after 326 AD + 2 - Original method, with date converted to Gregorian + calendar, valid in years 1583 to 4099 + 3 - Revised method, in Gregorian calendar, valid in + years 1583 to 4099 as well + + These methods are represented by the constants: + + * ``EASTER_JULIAN = 1`` + * ``EASTER_ORTHODOX = 2`` + * ``EASTER_WESTERN = 3`` + + The default method is method 3. + + More about the algorithm may be found at: + + http://users.chariot.net.au/~gmarts/eastalg.htm + + and + + http://www.tondering.dk/claus/calendar.html + + """ + + if not (1 <= method <= 3): + raise ValueError("invalid method") + + # g - Golden year - 1 + # c - Century + # h - (23 - Epact) mod 30 + # i - Number of days from March 21 to Paschal Full Moon + # j - Weekday for PFM (0=Sunday, etc) + # p - Number of days from March 21 to Sunday on or before PFM + # (-6 to 28 methods 1 & 3, to 56 for method 2) + # e - Extra days to add for method 2 (converting Julian + # date to Gregorian date) + + y = year + g = y % 19 + e = 0 + if method < 3: + # Old method + i = (19*g + 15) % 30 + j = (y + y//4 + i) % 7 + if method == 2: + # Extra dates to convert Julian to Gregorian date + e = 10 + if y > 1600: + e = e + y//100 - 16 - (y//100 - 16)//4 + else: + # New method + c = y//100 + h = (c - c//4 - (8*c + 13)//25 + 19*g + 15) % 30 + i = h - (h//28)*(1 - (h//28)*(29//(h + 1))*((21 - g)//11)) + j = (y + y//4 + i + 2 - c + c//4) % 7 + + # p can be from -6 to 56 corresponding to dates 22 March to 23 May + # (later dates apply to method 2, although 23 May never actually occurs) + p = i - j + e + d = 1 + (p + 27 + (p + 6)//40) % 31 + m = 3 + (p + 26)//30 + return datetime.date(int(y), int(m), int(d)) diff --git a/python/dateutil/parser.py b/python/dateutil/parser.py new file mode 100644 index 0000000..595331f --- /dev/null +++ b/python/dateutil/parser.py @@ -0,0 +1,1374 @@ +# -*- coding: utf-8 -*- +""" +This module offers a generic date/time string parser which is able to parse +most known formats to represent a date and/or time. + +This module attempts to be forgiving with regards to unlikely input formats, +returning a datetime object even for dates which are ambiguous. If an element +of a date/time stamp is omitted, the following rules are applied: +- If AM or PM is left unspecified, a 24-hour clock is assumed, however, an hour + on a 12-hour clock (``0 <= hour <= 12``) *must* be specified if AM or PM is + specified. 
+- If a time zone is omitted, a timezone-naive datetime is returned. + +If any other elements are missing, they are taken from the +:class:`datetime.datetime` object passed to the parameter ``default``. If this +results in a day number exceeding the valid number of days per month, the +value falls back to the end of the month. + +Additional resources about date/time string formats can be found below: + +- `A summary of the international standard date and time notation + <http://www.cl.cam.ac.uk/~mgk25/iso-time.html>`_ +- `W3C Date and Time Formats <http://www.w3.org/TR/NOTE-datetime>`_ +- `Time Formats (Planetary Rings Node) <http://pds-rings.seti.org/tools/time_formats.html>`_ +- `CPAN ParseDate module + <http://search.cpan.org/~muir/Time-modules-2013.0912/lib/Time/ParseDate.pm>`_ +- `Java SimpleDateFormat Class + <https://docs.oracle.com/javase/6/docs/api/java/text/SimpleDateFormat.html>`_ +""" +from __future__ import unicode_literals + +import datetime +import string +import time +import collections +import re +from io import StringIO +from calendar import monthrange + +from six import text_type, binary_type, integer_types + +from . import relativedelta +from . import tz + +__all__ = ["parse", "parserinfo"] + + +class _timelex(object): + # Fractional seconds are sometimes split by a comma + _split_decimal = re.compile("([.,])") + + def __init__(self, instream): + if isinstance(instream, binary_type): + instream = instream.decode() + + if isinstance(instream, text_type): + instream = StringIO(instream) + + if getattr(instream, 'read', None) is None: + raise TypeError('Parser must be a string or character stream, not ' + '{itype}'.format(itype=instream.__class__.__name__)) + + self.instream = instream + self.charstack = [] + self.tokenstack = [] + self.eof = False + + def get_token(self): + """ + This function breaks the time string into lexical units (tokens), which + can be parsed by the parser. Lexical units are demarcated by changes in + the character set, so any continuous string of letters is considered + one unit, and any continuous string of numbers is considered one unit. + + The main complication arises from the fact that dots ('.') can be used + either as separators (e.g. "Sep.20.2009") or as decimal points (e.g. + "4:30:21.447"). As such, it is necessary to read the full context of + any dot-separated strings before breaking them into tokens; this + function therefore maintains a "token stack", for when the ambiguous + context demands that multiple tokens be parsed at once. + """ + if self.tokenstack: + return self.tokenstack.pop(0) + + seenletters = False + token = None + state = None + + while not self.eof: + # We only realize that we've reached the end of a token when we + # find a character that's not part of the current token - since + # that character may be part of the next token, it's stored in the + # charstack. + if self.charstack: + nextchar = self.charstack.pop(0) + else: + nextchar = self.instream.read(1) + while nextchar == '\x00': + nextchar = self.instream.read(1) + + if not nextchar: + self.eof = True + break + elif not state: + # First character of the token - determines if we're starting + # to parse a word, a number or something else.
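+                    # State codes used below: 'a' = a run of letters, '0' = a run + # of digits, 'a.' = letters containing dots, '0.' = digits + # containing dots (a dot may be a separator or a decimal point).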
+ token = nextchar + if self.isword(nextchar): + state = 'a' + elif self.isnum(nextchar): + state = '0' + elif self.isspace(nextchar): + token = ' ' + break # emit token + else: + break # emit token + elif state == 'a': + # If we've already started reading a word, we keep reading + # letters until we find something that's not part of a word. + seenletters = True + if self.isword(nextchar): + token += nextchar + elif nextchar == '.': + token += nextchar + state = 'a.' + else: + self.charstack.append(nextchar) + break # emit token + elif state == '0': + # If we've already started reading a number, we keep reading + # numbers until we find something that doesn't fit. + if self.isnum(nextchar): + token += nextchar + elif nextchar == '.' or (nextchar == ',' and len(token) >= 2): + token += nextchar + state = '0.' + else: + self.charstack.append(nextchar) + break # emit token + elif state == 'a.': + # If we've seen some letters and a dot separator, continue + # parsing, and the tokens will be broken up later. + seenletters = True + if nextchar == '.' or self.isword(nextchar): + token += nextchar + elif self.isnum(nextchar) and token[-1] == '.': + token += nextchar + state = '0.' + else: + self.charstack.append(nextchar) + break # emit token + elif state == '0.': + # If we've seen at least one dot separator, keep going, we'll + # break up the tokens later. + if nextchar == '.' or self.isnum(nextchar): + token += nextchar + elif self.isword(nextchar) and token[-1] == '.': + token += nextchar + state = 'a.' + else: + self.charstack.append(nextchar) + break # emit token + + if (state in ('a.', '0.') and (seenletters or token.count('.') > 1 or + token[-1] in '.,')): + l = self._split_decimal.split(token) + token = l[0] + for tok in l[1:]: + if tok: + self.tokenstack.append(tok) + + if state == '0.' and token.count('.') == 0: + token = token.replace(',', '.') + + return token + + def __iter__(self): + return self + + def __next__(self): + token = self.get_token() + if token is None: + raise StopIteration + + return token + + def next(self): + return self.__next__() # Python 2.x support + + @classmethod + def split(cls, s): + return list(cls(s)) + + @classmethod + def isword(cls, nextchar): + """ Whether or not the next character is part of a word """ + return nextchar.isalpha() + + @classmethod + def isnum(cls, nextchar): + """ Whether the next character is part of a number """ + return nextchar.isdigit() + + @classmethod + def isspace(cls, nextchar): + """ Whether the next character is whitespace """ + return nextchar.isspace() + + +class _resultbase(object): + + def __init__(self): + for attr in self.__slots__: + setattr(self, attr, None) + + def _repr(self, classname): + l = [] + for attr in self.__slots__: + value = getattr(self, attr) + if value is not None: + l.append("%s=%s" % (attr, repr(value))) + return "%s(%s)" % (classname, ", ".join(l)) + + def __len__(self): + return (sum(getattr(self, attr) is not None + for attr in self.__slots__)) + + def __repr__(self): + return self._repr(self.__class__.__name__) + + +class parserinfo(object): + """ + Class which handles what inputs are accepted. Subclass this to customize + the language and acceptable values for each parameter. + + :param dayfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the day (``True``) or month (``False``). If + ``yearfirst`` is set to ``True``, this distinguishes between YDM + and YMD. Default is ``False``. 
+ + :param yearfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the year. If ``True``, the first number is taken + to be the year, otherwise the last number is taken to be the year. + Default is ``False``. + """ + + # m from a.m/p.m, t from ISO T separator + JUMP = [" ", ".", ",", ";", "-", "/", "'", + "at", "on", "and", "ad", "m", "t", "of", + "st", "nd", "rd", "th"] + + WEEKDAYS = [("Mon", "Monday"), + ("Tue", "Tuesday"), + ("Wed", "Wednesday"), + ("Thu", "Thursday"), + ("Fri", "Friday"), + ("Sat", "Saturday"), + ("Sun", "Sunday")] + MONTHS = [("Jan", "January"), + ("Feb", "February"), + ("Mar", "March"), + ("Apr", "April"), + ("May", "May"), + ("Jun", "June"), + ("Jul", "July"), + ("Aug", "August"), + ("Sep", "Sept", "September"), + ("Oct", "October"), + ("Nov", "November"), + ("Dec", "December")] + HMS = [("h", "hour", "hours"), + ("m", "minute", "minutes"), + ("s", "second", "seconds")] + AMPM = [("am", "a"), + ("pm", "p")] + UTCZONE = ["UTC", "GMT", "Z"] + PERTAIN = ["of"] + TZOFFSET = {} + + def __init__(self, dayfirst=False, yearfirst=False): + self._jump = self._convert(self.JUMP) + self._weekdays = self._convert(self.WEEKDAYS) + self._months = self._convert(self.MONTHS) + self._hms = self._convert(self.HMS) + self._ampm = self._convert(self.AMPM) + self._utczone = self._convert(self.UTCZONE) + self._pertain = self._convert(self.PERTAIN) + + self.dayfirst = dayfirst + self.yearfirst = yearfirst + + self._year = time.localtime().tm_year + self._century = self._year // 100 * 100 + + def _convert(self, lst): + dct = {} + for i, v in enumerate(lst): + if isinstance(v, tuple): + for v in v: + dct[v.lower()] = i + else: + dct[v.lower()] = i + return dct + + def jump(self, name): + return name.lower() in self._jump + + def weekday(self, name): + if len(name) >= min(len(n) for n in self._weekdays.keys()): + try: + return self._weekdays[name.lower()] + except KeyError: + pass + return None + + def month(self, name): + if len(name) >= min(len(n) for n in self._months.keys()): + try: + return self._months[name.lower()] + 1 + except KeyError: + pass + return None + + def hms(self, name): + try: + return self._hms[name.lower()] + except KeyError: + return None + + def ampm(self, name): + try: + return self._ampm[name.lower()] + except KeyError: + return None + + def pertain(self, name): + return name.lower() in self._pertain + + def utczone(self, name): + return name.lower() in self._utczone + + def tzoffset(self, name): + if name in self._utczone: + return 0 + + return self.TZOFFSET.get(name) + + def convertyear(self, year, century_specified=False): + if year < 100 and not century_specified: + year += self._century + if abs(year - self._year) >= 50: + if year < self._year: + year += 100 + else: + year -= 100 + return year + + def validate(self, res): + # move to info + if res.year is not None: + res.year = self.convertyear(res.year, res.century_specified) + + if res.tzoffset == 0 and not res.tzname or res.tzname == 'Z': + res.tzname = "UTC" + res.tzoffset = 0 + elif res.tzoffset != 0 and res.tzname and self.utczone(res.tzname): + res.tzoffset = 0 + return True + + +class _ymd(list): + def __init__(self, tzstr, *args, **kwargs): + super(self.__class__, self).__init__(*args, **kwargs) + self.century_specified = False + self.tzstr = tzstr + + @staticmethod + def token_could_be_year(token, year): + try: + return int(token) == year + except ValueError: + return False + + @staticmethod + def find_potential_year_tokens(year, tokens): + return [token 
for token in tokens if _ymd.token_could_be_year(token, year)] + + def find_probable_year_index(self, tokens): + """ + Attempt to deduce if a pre-100 year was lost + due to padded zeros being taken off. + """ + for index, token in enumerate(self): + potential_year_tokens = _ymd.find_potential_year_tokens(token, tokens) + if len(potential_year_tokens) == 1 and len(potential_year_tokens[0]) > 2: + return index + + def append(self, val): + if hasattr(val, '__len__'): + if val.isdigit() and len(val) > 2: + self.century_specified = True + elif val > 100: + self.century_specified = True + + super(self.__class__, self).append(int(val)) + + def resolve_ymd(self, mstridx, yearfirst, dayfirst): + len_ymd = len(self) + year, month, day = (None, None, None) + + if len_ymd > 3: + raise ValueError("More than three YMD values") + elif len_ymd == 1 or (mstridx != -1 and len_ymd == 2): + # One member, or two members with a month string + if mstridx != -1: + month = self[mstridx] + del self[mstridx] + + if len_ymd > 1 or mstridx == -1: + if self[0] > 31: + year = self[0] + else: + day = self[0] + + elif len_ymd == 2: + # Two members with numbers + if self[0] > 31: + # 99-01 + year, month = self + elif self[1] > 31: + # 01-99 + month, year = self + elif dayfirst and self[1] <= 12: + # 13-01 + day, month = self + else: + # 01-13 + month, day = self + + elif len_ymd == 3: + # Three members + if mstridx == 0: + month, day, year = self + elif mstridx == 1: + if self[0] > 31 or (yearfirst and self[2] <= 31): + # 99-Jan-01 + year, month, day = self + else: + # 01-Jan-01 + # Give precedence to day-first, since + # two-digit years are usually hand-written. + day, month, year = self + + elif mstridx == 2: + # WTF!? + if self[1] > 31: + # 01-99-Jan + day, year, month = self + else: + # 99-01-Jan + year, day, month = self + + else: + if self[0] > 31 or \ + self.find_probable_year_index(_timelex.split(self.tzstr)) == 0 or \ + (yearfirst and self[1] <= 12 and self[2] <= 31): + # 99-01-01 + if dayfirst and self[2] <= 12: + year, day, month = self + else: + year, month, day = self + elif self[0] > 12 or (dayfirst and self[1] <= 12): + # 13-01-01 + day, month, year = self + else: + # 01-13-01 + month, day, year = self + + return year, month, day + + +class parser(object): + def __init__(self, info=None): + self.info = info or parserinfo() + + def parse(self, timestr, default=None, ignoretz=False, tzinfos=None, **kwargs): + """ + Parse the date/time string into a :class:`datetime.datetime` object. + + :param timestr: + Any date/time string using the supported formats. + + :param default: + The default datetime object; if this is a datetime object and not + ``None``, elements specified in ``timestr`` replace elements in the + default object. + + :param ignoretz: + If set to ``True``, time zones in parsed strings are ignored and a + naive :class:`datetime.datetime` object is returned. + + :param tzinfos: + Additional time zone names / aliases which may be present in the + string. This argument maps time zone names (and optionally offsets + from those time zones) to time zones. This parameter can be a + dictionary with timezone aliases mapping time zone names to time + zones or a function taking two parameters (``tzname`` and + ``tzoffset``) and returning a time zone. + + The timezones to which the names are mapped can be an integer + offset from UTC in minutes or a :class:`tzinfo` object. + + ..
doctest:: + :options: +NORMALIZE_WHITESPACE + + >>> from dateutil.parser import parse + >>> from dateutil.tz import gettz + >>> tzinfos = {"BRST": -10800, "CST": gettz("America/Chicago")} + >>> parse("2012-01-19 17:21:00 BRST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, tzinfo=tzoffset(u'BRST', -10800)) + >>> parse("2012-01-19 17:21:00 CST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, + tzinfo=tzfile('/usr/share/zoneinfo/America/Chicago')) + + This parameter is ignored if ``ignoretz`` is set. + + :param **kwargs: + Keyword arguments as passed to ``_parse()``. + + :return: + Returns a :class:`datetime.datetime` object or, if the + ``fuzzy_with_tokens`` option is ``True``, returns a tuple, the + first element being a :class:`datetime.datetime` object, the second + a tuple containing the fuzzy tokens. + + :raises ValueError: + Raised for invalid or unknown string format, if the provided + :class:`tzinfo` is not in a valid format, or if an invalid date + would be created. + + :raises TypeError: + Raised for non-string or character stream input. + + :raises OverflowError: + Raised if the parsed date exceeds the largest valid C integer on + your system. + """ + + if default is None: + default = datetime.datetime.now().replace(hour=0, minute=0, + second=0, microsecond=0) + + res, skipped_tokens = self._parse(timestr, **kwargs) + + if res is None: + raise ValueError("Unknown string format") + + if len(res) == 0: + raise ValueError("String does not contain a date.") + + repl = {} + for attr in ("year", "month", "day", "hour", + "minute", "second", "microsecond"): + value = getattr(res, attr) + if value is not None: + repl[attr] = value + + if 'day' not in repl: + # If the default day exceeds the last day of the month, fall back to + # the end of the month. + cyear = default.year if res.year is None else res.year + cmonth = default.month if res.month is None else res.month + cday = default.day if res.day is None else res.day + + if cday > monthrange(cyear, cmonth)[1]: + repl['day'] = monthrange(cyear, cmonth)[1] + + ret = default.replace(**repl) + + if res.weekday is not None and not res.day: + ret = ret+relativedelta.relativedelta(weekday=res.weekday) + + if not ignoretz: + if (isinstance(tzinfos, collections.Callable) or + tzinfos and res.tzname in tzinfos): + + if isinstance(tzinfos, collections.Callable): + tzdata = tzinfos(res.tzname, res.tzoffset) + else: + tzdata = tzinfos.get(res.tzname) + + if isinstance(tzdata, datetime.tzinfo): + tzinfo = tzdata + elif isinstance(tzdata, text_type): + tzinfo = tz.tzstr(tzdata) + elif isinstance(tzdata, integer_types): + tzinfo = tz.tzoffset(res.tzname, tzdata) + else: + raise ValueError("Offset must be tzinfo subclass, " + "tz string, or int offset.") + ret = ret.replace(tzinfo=tzinfo) + elif res.tzname and res.tzname in time.tzname: + ret = ret.replace(tzinfo=tz.tzlocal()) + elif res.tzoffset == 0: + ret = ret.replace(tzinfo=tz.tzutc()) + elif res.tzoffset: + ret = ret.replace(tzinfo=tz.tzoffset(res.tzname, res.tzoffset)) + + if kwargs.get('fuzzy_with_tokens', False): + return ret, skipped_tokens + else: + return ret + + class _result(_resultbase): + __slots__ = ["year", "month", "day", "weekday", + "hour", "minute", "second", "microsecond", + "tzname", "tzoffset", "ampm"] + + def _parse(self, timestr, dayfirst=None, yearfirst=None, fuzzy=False, + fuzzy_with_tokens=False): + """ + Private method which performs the heavy lifting of parsing, called from + ``parse()``, which passes on its ``kwargs`` to this function. 
+ + :param timestr: + The string to parse. + + :param dayfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the day (``True``) or month (``False``). If + ``yearfirst`` is set to ``True``, this distinguishes between YDM + and YMD. If set to ``None``, this value is retrieved from the + current :class:`parserinfo` object (which itself defaults to + ``False``). + + :param yearfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the year. If ``True``, the first number is taken + to be the year, otherwise the last number is taken to be the year. + If this is set to ``None``, the value is retrieved from the current + :class:`parserinfo` object (which itself defaults to ``False``). + + :param fuzzy: + Whether to allow fuzzy parsing, allowing for strings like "Today is + January 1, 2047 at 8:21:00AM". + + :param fuzzy_with_tokens: + If ``True``, ``fuzzy`` is automatically set to ``True``, and the parser + will return a tuple where the first element is the parsed + :class:`datetime.datetime` timestamp and the second element is + a tuple containing the portions of the string which were ignored: + + .. doctest:: + + >>> from dateutil.parser import parse + >>> parse("Today is January 1, 2047 at 8:21:00AM", fuzzy_with_tokens=True) + (datetime.datetime(2047, 1, 1, 8, 21), (u'Today is ', u' ', u'at ')) + + """ + if fuzzy_with_tokens: + fuzzy = True + + info = self.info + + if dayfirst is None: + dayfirst = info.dayfirst + + if yearfirst is None: + yearfirst = info.yearfirst + + res = self._result() + l = _timelex.split(timestr) # Splits the timestr into tokens + + # keep up with the last token skipped so we can recombine + # consecutively skipped tokens (-2 for when i begins at 0).
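+        # e.g. fuzzy-parsing "Today is January 1, 2047 at 8:21:00AM" leaves + # skipped_tokens == ['Today is ', ' ', 'at '] once consecutive + # skipped tokens have been recombined.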
+ last_skipped_token_i = -2 + skipped_tokens = list() + + try: + # year/month/day list + ymd = _ymd(timestr) + + # Index of the month string in ymd + mstridx = -1 + + len_l = len(l) + i = 0 + while i < len_l: + + # Check if it's a number + try: + value_repr = l[i] + value = float(value_repr) + except ValueError: + value = None + + if value is not None: + # Token is a number + len_li = len(l[i]) + i += 1 + + if (len(ymd) == 3 and len_li in (2, 4) + and res.hour is None and (i >= len_l or (l[i] != ':' and + info.hms(l[i]) is None))): + # 19990101T23[59] + s = l[i-1] + res.hour = int(s[:2]) + + if len_li == 4: + res.minute = int(s[2:]) + + elif len_li == 6 or (len_li > 6 and l[i-1].find('.') == 6): + # YYMMDD or HHMMSS[.ss] + s = l[i-1] + + if not ymd and l[i-1].find('.') == -1: + #ymd.append(info.convertyear(int(s[:2]))) + + ymd.append(s[:2]) + ymd.append(s[2:4]) + ymd.append(s[4:]) + else: + # 19990101T235959[.59] + res.hour = int(s[:2]) + res.minute = int(s[2:4]) + res.second, res.microsecond = _parsems(s[4:]) + + elif len_li in (8, 12, 14): + # YYYYMMDD + s = l[i-1] + ymd.append(s[:4]) + ymd.append(s[4:6]) + ymd.append(s[6:8]) + + if len_li > 8: + res.hour = int(s[8:10]) + res.minute = int(s[10:12]) + + if len_li > 12: + res.second = int(s[12:]) + + elif ((i < len_l and info.hms(l[i]) is not None) or + (i+1 < len_l and l[i] == ' ' and + info.hms(l[i+1]) is not None)): + + # HH[ ]h or MM[ ]m or SS[.ss][ ]s + if l[i] == ' ': + i += 1 + + idx = info.hms(l[i]) + + while True: + if idx == 0: + res.hour = int(value) + + if value % 1: + res.minute = int(60*(value % 1)) + + elif idx == 1: + res.minute = int(value) + + if value % 1: + res.second = int(60*(value % 1)) + + elif idx == 2: + res.second, res.microsecond = \ + _parsems(value_repr) + + i += 1 + + if i >= len_l or idx == 2: + break + + # 12h00 + try: + value_repr = l[i] + value = float(value_repr) + except ValueError: + break + else: + i += 1 + idx += 1 + + if i < len_l: + newidx = info.hms(l[i]) + + if newidx is not None: + idx = newidx + + elif (i == len_l and l[i-2] == ' ' and + info.hms(l[i-3]) is not None): + # X h MM or X m SS + idx = info.hms(l[i-3]) + + if idx == 0: # h + res.minute = int(value) + + sec_remainder = value % 1 + if sec_remainder: + res.second = int(60 * sec_remainder) + elif idx == 1: # m + res.second, res.microsecond = \ + _parsems(value_repr) + + # We don't need to advance the tokens here because the + # i == len_l call indicates that we're looking at all + # the tokens already. 
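+                        # e.g. for "3 h 15" the trailing 15 becomes the minutes; + # for "3 m 20" it becomes the seconds (any fractional + # part is parsed into microseconds).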
+ + elif i+1 < len_l and l[i] == ':': + # HH:MM[:SS[.ss]] + res.hour = int(value) + i += 1 + value = float(l[i]) + res.minute = int(value) + + if value % 1: + res.second = int(60*(value % 1)) + + i += 1 + + if i < len_l and l[i] == ':': + res.second, res.microsecond = _parsems(l[i+1]) + i += 2 + + elif i < len_l and l[i] in ('-', '/', '.'): + sep = l[i] + ymd.append(value_repr) + i += 1 + + if i < len_l and not info.jump(l[i]): + try: + # 01-01[-01] + ymd.append(l[i]) + except ValueError: + # 01-Jan[-01] + value = info.month(l[i]) + + if value is not None: + ymd.append(value) + assert mstridx == -1 + mstridx = len(ymd)-1 + else: + return None, None + + i += 1 + + if i < len_l and l[i] == sep: + # We have three members + i += 1 + value = info.month(l[i]) + + if value is not None: + ymd.append(value) + assert mstridx == -1 + mstridx = len(ymd)-1 + else: + ymd.append(l[i]) + + i += 1 + elif i >= len_l or info.jump(l[i]): + if i+1 < len_l and info.ampm(l[i+1]) is not None: + # 12 am + res.hour = int(value) + + if res.hour < 12 and info.ampm(l[i+1]) == 1: + res.hour += 12 + elif res.hour == 12 and info.ampm(l[i+1]) == 0: + res.hour = 0 + + i += 1 + else: + # Year, month or day + ymd.append(value) + i += 1 + elif info.ampm(l[i]) is not None: + + # 12am + res.hour = int(value) + + if res.hour < 12 and info.ampm(l[i]) == 1: + res.hour += 12 + elif res.hour == 12 and info.ampm(l[i]) == 0: + res.hour = 0 + i += 1 + + elif not fuzzy: + return None, None + else: + i += 1 + continue + + # Check weekday + value = info.weekday(l[i]) + if value is not None: + res.weekday = value + i += 1 + continue + + # Check month name + value = info.month(l[i]) + if value is not None: + ymd.append(value) + assert mstridx == -1 + mstridx = len(ymd)-1 + + i += 1 + if i < len_l: + if l[i] in ('-', '/'): + # Jan-01[-99] + sep = l[i] + i += 1 + ymd.append(l[i]) + i += 1 + + if i < len_l and l[i] == sep: + # Jan-01-99 + i += 1 + ymd.append(l[i]) + i += 1 + + elif (i+3 < len_l and l[i] == l[i+2] == ' ' + and info.pertain(l[i+1])): + # Jan of 01 + # In this case, 01 is clearly year + try: + value = int(l[i+3]) + except ValueError: + # Wrong guess + pass + else: + # Convert it here to become unambiguous + ymd.append(str(info.convertyear(value))) + i += 4 + continue + + # Check am/pm + value = info.ampm(l[i]) + if value is not None: + # For fuzzy parsing, 'a' or 'am' (both valid English words) + # may erroneously trigger the AM/PM flag. Deal with that + # here. + val_is_ampm = True + + # If there's already an AM/PM flag, this one isn't one.
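+                    # e.g. the stray "a" in a fuzzy string such as + # "10 pm on a Tuesday" must stay a plain word rather than + # become a second AM flag.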
+ if fuzzy and res.ampm is not None: + val_is_ampm = False + + # If AM/PM is found and hour is not, raise a ValueError + if res.hour is None: + if fuzzy: + val_is_ampm = False + else: + raise ValueError('No hour specified with ' + + 'AM or PM flag.') + elif not 0 <= res.hour <= 12: + # If AM/PM is found, it's a 12 hour clock, so raise + # an error for invalid range + if fuzzy: + val_is_ampm = False + else: + raise ValueError('Invalid hour specified for ' + + '12-hour clock.') + + if val_is_ampm: + if value == 1 and res.hour < 12: + res.hour += 12 + elif value == 0 and res.hour == 12: + res.hour = 0 + + res.ampm = value + + elif fuzzy: + last_skipped_token_i = self._skip_token(skipped_tokens, + last_skipped_token_i, i, l) + i += 1 + continue + + # Check for a timezone name + if (res.hour is not None and len(l[i]) <= 5 and + res.tzname is None and res.tzoffset is None and + not [x for x in l[i] if x not in + string.ascii_uppercase]): + res.tzname = l[i] + res.tzoffset = info.tzoffset(res.tzname) + i += 1 + + # Check for something like GMT+3, or BRST+3. Notice + # that it doesn't mean "I am 3 hours after GMT", but + # "my time +3 is GMT". If found, we reverse the + # logic so that timezone parsing code will get it + # right. + if i < len_l and l[i] in ('+', '-'): + l[i] = ('+', '-')[l[i] == '+'] + res.tzoffset = None + if info.utczone(res.tzname): + # With something like GMT+3, the timezone + # is *not* GMT. + res.tzname = None + + continue + + # Check for a numbered timezone + if res.hour is not None and l[i] in ('+', '-'): + signal = (-1, 1)[l[i] == '+'] + i += 1 + len_li = len(l[i]) + + if len_li == 4: + # -0300 + res.tzoffset = int(l[i][:2])*3600+int(l[i][2:])*60 + elif i+1 < len_l and l[i+1] == ':': + # -03:00 + res.tzoffset = int(l[i])*3600+int(l[i+2])*60 + i += 2 + elif len_li <= 2: + # -[0]3 + res.tzoffset = int(l[i][:2])*3600 + else: + return None, None + i += 1 + + res.tzoffset *= signal + + # Look for a timezone name between parenthesis + if (i+3 < len_l and + info.jump(l[i]) and l[i+1] == '(' and l[i+3] == ')' and + 3 <= len(l[i+2]) <= 5 and + not [x for x in l[i+2] + if x not in string.ascii_uppercase]): + # -0300 (BRST) + res.tzname = l[i+2] + i += 4 + continue + + # Check jumps + if not (info.jump(l[i]) or fuzzy): + return None, None + + last_skipped_token_i = self._skip_token(skipped_tokens, + last_skipped_token_i, i, l) + i += 1 + + # Process year/month/day + year, month, day = ymd.resolve_ymd(mstridx, yearfirst, dayfirst) + if year is not None: + res.year = year + res.century_specified = ymd.century_specified + + if month is not None: + res.month = month + + if day is not None: + res.day = day + + except (IndexError, ValueError, AssertionError): + return None, None + + if not info.validate(res): + return None, None + + if fuzzy_with_tokens: + return res, tuple(skipped_tokens) + else: + return res, None + + @staticmethod + def _skip_token(skipped_tokens, last_skipped_token_i, i, l): + if last_skipped_token_i == i - 1: + # recombine the tokens + skipped_tokens[-1] += l[i] + else: + # just append + skipped_tokens.append(l[i]) + last_skipped_token_i = i + return last_skipped_token_i + + +DEFAULTPARSER = parser() + + +def parse(timestr, parserinfo=None, **kwargs): + """ + + Parse a string in one of the supported formats, using the + ``parserinfo`` parameters. + + :param timestr: + A string containing a date/time stamp. + + :param parserinfo: + A :class:`parserinfo` object containing parameters for the parser. 
If ``None``, the default arguments to the :class:`parserinfo` + constructor are used. + + The ``**kwargs`` parameter takes the following keyword arguments: + + :param default: + The default datetime object; if this is a datetime object and not + ``None``, elements specified in ``timestr`` replace elements in the + default object. + + :param ignoretz: + If set to ``True``, time zones in parsed strings are ignored and a naive + :class:`datetime.datetime` object is returned. + + :param tzinfos: + Additional time zone names / aliases which may be present in the + string. This argument maps time zone names (and optionally offsets + from those time zones) to time zones. This parameter can be a + dictionary with timezone aliases mapping time zone names to time + zones or a function taking two parameters (``tzname`` and + ``tzoffset``) and returning a time zone. + + The timezones to which the names are mapped can be an integer + offset from UTC in minutes or a :class:`tzinfo` object. + + .. doctest:: + :options: +NORMALIZE_WHITESPACE + + >>> from dateutil.parser import parse + >>> from dateutil.tz import gettz + >>> tzinfos = {"BRST": -10800, "CST": gettz("America/Chicago")} + >>> parse("2012-01-19 17:21:00 BRST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, tzinfo=tzoffset(u'BRST', -10800)) + >>> parse("2012-01-19 17:21:00 CST", tzinfos=tzinfos) + datetime.datetime(2012, 1, 19, 17, 21, + tzinfo=tzfile('/usr/share/zoneinfo/America/Chicago')) + + This parameter is ignored if ``ignoretz`` is set. + + :param dayfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the day (``True``) or month (``False``). If + ``yearfirst`` is set to ``True``, this distinguishes between YDM and + YMD. If set to ``None``, this value is retrieved from the current + :class:`parserinfo` object (which itself defaults to ``False``). + + :param yearfirst: + Whether to interpret the first value in an ambiguous 3-integer date + (e.g. 01/05/09) as the year. If ``True``, the first number is taken to + be the year, otherwise the last number is taken to be the year. If + this is set to ``None``, the value is retrieved from the current + :class:`parserinfo` object (which itself defaults to ``False``). + + :param fuzzy: + Whether to allow fuzzy parsing, allowing for strings like "Today is + January 1, 2047 at 8:21:00AM". + + :param fuzzy_with_tokens: + If ``True``, ``fuzzy`` is automatically set to ``True``, and the parser + will return a tuple where the first element is the parsed + :class:`datetime.datetime` timestamp and the second element is + a tuple containing the portions of the string which were ignored: + + .. doctest:: + + >>> from dateutil.parser import parse + >>> parse("Today is January 1, 2047 at 8:21:00AM", fuzzy_with_tokens=True) + (datetime.datetime(2047, 1, 1, 8, 21), (u'Today is ', u' ', u'at ')) + + :return: + Returns a :class:`datetime.datetime` object or, if the + ``fuzzy_with_tokens`` option is ``True``, returns a tuple, the + first element being a :class:`datetime.datetime` object, the second + a tuple containing the fuzzy tokens. + + :raises ValueError: + Raised for invalid or unknown string format, if the provided + :class:`tzinfo` is not in a valid format, or if an invalid date + would be created. + + :raises OverflowError: + Raised if the parsed date exceeds the largest valid C integer on + your system.
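+ + A short illustration of ``dayfirst`` and ``yearfirst`` on an ambiguous + date (a sketch assuming the default :class:`parserinfo`): + + .. doctest:: + + >>> parse("01/05/09", dayfirst=True) + datetime.datetime(2009, 5, 1, 0, 0) + >>> parse("01/05/09", yearfirst=True) + datetime.datetime(2001, 5, 9, 0, 0)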
+ """ + if parserinfo: + return parser(parserinfo).parse(timestr, **kwargs) + else: + return DEFAULTPARSER.parse(timestr, **kwargs) + + +class _tzparser(object): + + class _result(_resultbase): + + __slots__ = ["stdabbr", "stdoffset", "dstabbr", "dstoffset", + "start", "end"] + + class _attr(_resultbase): + __slots__ = ["month", "week", "weekday", + "yday", "jyday", "day", "time"] + + def __repr__(self): + return self._repr("") + + def __init__(self): + _resultbase.__init__(self) + self.start = self._attr() + self.end = self._attr() + + def parse(self, tzstr): + res = self._result() + l = _timelex.split(tzstr) + try: + + len_l = len(l) + + i = 0 + while i < len_l: + # BRST+3[BRDT[+2]] + j = i + while j < len_l and not [x for x in l[j] + if x in "0123456789:,-+"]: + j += 1 + if j != i: + if not res.stdabbr: + offattr = "stdoffset" + res.stdabbr = "".join(l[i:j]) + else: + offattr = "dstoffset" + res.dstabbr = "".join(l[i:j]) + i = j + if (i < len_l and (l[i] in ('+', '-') or l[i][0] in + "0123456789")): + if l[i] in ('+', '-'): + # Yes, that's right. See the TZ variable + # documentation. + signal = (1, -1)[l[i] == '+'] + i += 1 + else: + signal = -1 + len_li = len(l[i]) + if len_li == 4: + # -0300 + setattr(res, offattr, (int(l[i][:2])*3600 + + int(l[i][2:])*60)*signal) + elif i+1 < len_l and l[i+1] == ':': + # -03:00 + setattr(res, offattr, + (int(l[i])*3600+int(l[i+2])*60)*signal) + i += 2 + elif len_li <= 2: + # -[0]3 + setattr(res, offattr, + int(l[i][:2])*3600*signal) + else: + return None + i += 1 + if res.dstabbr: + break + else: + break + + if i < len_l: + for j in range(i, len_l): + if l[j] == ';': + l[j] = ',' + + assert l[i] == ',' + + i += 1 + + if i >= len_l: + pass + elif (8 <= l.count(',') <= 9 and + not [y for x in l[i:] if x != ',' + for y in x if y not in "0123456789"]): + # GMT0BST,3,0,30,3600,10,0,26,7200[,3600] + for x in (res.start, res.end): + x.month = int(l[i]) + i += 2 + if l[i] == '-': + value = int(l[i+1])*-1 + i += 1 + else: + value = int(l[i]) + i += 2 + if value: + x.week = value + x.weekday = (int(l[i])-1) % 7 + else: + x.day = int(l[i]) + i += 2 + x.time = int(l[i]) + i += 2 + if i < len_l: + if l[i] in ('-', '+'): + signal = (-1, 1)[l[i] == "+"] + i += 1 + else: + signal = 1 + res.dstoffset = (res.stdoffset+int(l[i]))*signal + elif (l.count(',') == 2 and l[i:].count('/') <= 2 and + not [y for x in l[i:] if x not in (',', '/', 'J', 'M', + '.', '-', ':') + for y in x if y not in "0123456789"]): + for x in (res.start, res.end): + if l[i] == 'J': + # non-leap year day (1 based) + i += 1 + x.jyday = int(l[i]) + elif l[i] == 'M': + # month[-.]week[-.]weekday + i += 1 + x.month = int(l[i]) + i += 1 + assert l[i] in ('-', '.') + i += 1 + x.week = int(l[i]) + if x.week == 5: + x.week = -1 + i += 1 + assert l[i] in ('-', '.') + i += 1 + x.weekday = (int(l[i])-1) % 7 + else: + # year day (zero based) + x.yday = int(l[i])+1 + + i += 1 + + if i < len_l and l[i] == '/': + i += 1 + # start time + len_li = len(l[i]) + if len_li == 4: + # -0300 + x.time = (int(l[i][:2])*3600+int(l[i][2:])*60) + elif i+1 < len_l and l[i+1] == ':': + # -03:00 + x.time = int(l[i])*3600+int(l[i+2])*60 + i += 2 + if i+1 < len_l and l[i+1] == ':': + i += 2 + x.time += int(l[i]) + elif len_li <= 2: + # -[0]3 + x.time = (int(l[i][:2])*3600) + else: + return None + i += 1 + + assert i == len_l or l[i] == ',' + + i += 1 + + assert i >= len_l + + except (IndexError, ValueError, AssertionError): + return None + + return res + + +DEFAULTTZPARSER = _tzparser() + + +def _parsetz(tzstr): + return 
DEFAULTTZPARSER.parse(tzstr) + + +def _parsems(value): + """Parse a I[.F] seconds value into (seconds, microseconds).""" + if "." not in value: + return int(value), 0 + else: + i, f = value.split(".") + return int(i), int(f.ljust(6, "0")[:6]) + + +# vim:ts=4:sw=4:et diff --git a/python/dateutil/relativedelta.py b/python/dateutil/relativedelta.py new file mode 100644 index 0000000..0e66afc --- /dev/null +++ b/python/dateutil/relativedelta.py @@ -0,0 +1,549 @@ +# -*- coding: utf-8 -*- +import datetime +import calendar + +import operator +from math import copysign + +from six import integer_types +from warnings import warn + +from ._common import weekday + +MO, TU, WE, TH, FR, SA, SU = weekdays = tuple(weekday(x) for x in range(7)) + +__all__ = ["relativedelta", "MO", "TU", "WE", "TH", "FR", "SA", "SU"] + + +class relativedelta(object): + """ + The relativedelta type is based on the specification of the excellent + work done by M.-A. Lemburg in his + `mx.DateTime <http://www.egenix.com/files/python/mxDateTime.html>`_ extension. + However, notice that this type does *NOT* implement the same algorithm as + his work. Do *NOT* expect it to behave like mx.DateTime's counterpart. + + There are two different ways to build a relativedelta instance. The + first one is passing it two date/datetime instances:: + + relativedelta(datetime1, datetime2) + + The second one is passing it any number of the following keyword arguments:: + + relativedelta(arg1=x,arg2=y,arg3=z...) + + year, month, day, hour, minute, second, microsecond: + Absolute information (argument is singular); adding or subtracting a + relativedelta with absolute information does not perform an arithmetic + operation, but rather REPLACES the corresponding value in the + original datetime with the value(s) in relativedelta. + + years, months, weeks, days, hours, minutes, seconds, microseconds: + Relative information, may be negative (argument is plural); adding + or subtracting a relativedelta with relative information performs + the corresponding arithmetic operation on the original datetime value + with the information in the relativedelta. + + weekday: + One of the weekday instances (MO, TU, etc.). These instances may + receive a parameter N, specifying the Nth weekday, which could + be positive or negative (like MO(+1) or MO(-2)). Not specifying + it is the same as specifying +1. You can also use an integer, + where 0=MO. + + leapdays: + Will add given days to the date found, if year is a leap + year, and the date found is after the 28th of February. + + yearday, nlyearday: + Set the yearday or the non-leap year day (jump leap days). + These are converted to day/month/leapdays information. + + Here is the behavior of operations with relativedelta: + + 1. Calculate the absolute year, using the 'year' argument, or the + original datetime year, if the argument is not present. + + 2. Add the relative 'years' argument to the absolute year. + + 3. Do steps 1 and 2 for month/months. + + 4. Calculate the absolute day, using the 'day' argument, or the + original datetime day, if the argument is not present. Then, + subtract from the day until it fits in the year and month + found after their operations. + + 5. Add the relative 'days' argument to the absolute day. Notice + that the 'weeks' argument is multiplied by 7 and added to + 'days'. + + 6. Do steps 1 and 2 for hour/hours, minute/minutes, second/seconds, + microsecond/microseconds. + + 7. If the 'weekday' argument is present, calculate the weekday, + with the given (wday, nth) tuple.
wday is the index of the + weekday (0-6, 0=Mon), and nth is the number of weeks to add + forward or backward, depending on its signal. Notice that if + the calculated date is already Monday, for example, using + (0, 1) or (0, -1) won't change the day. + """ + + def __init__(self, dt1=None, dt2=None, + years=0, months=0, days=0, leapdays=0, weeks=0, + hours=0, minutes=0, seconds=0, microseconds=0, + year=None, month=None, day=None, weekday=None, + yearday=None, nlyearday=None, + hour=None, minute=None, second=None, microsecond=None): + + # Check for non-integer values in integer-only quantities + if any(x is not None and x != int(x) for x in (years, months)): + raise ValueError("Non-integer years and months are " + "ambiguous and not currently supported.") + + if dt1 and dt2: + # datetime is a subclass of date. So both must be date + if not (isinstance(dt1, datetime.date) and + isinstance(dt2, datetime.date)): + raise TypeError("relativedelta only diffs datetime/date") + + # We allow two dates, or two datetimes, so we coerce them to be + # of the same type + if (isinstance(dt1, datetime.datetime) != + isinstance(dt2, datetime.datetime)): + if not isinstance(dt1, datetime.datetime): + dt1 = datetime.datetime.fromordinal(dt1.toordinal()) + elif not isinstance(dt2, datetime.datetime): + dt2 = datetime.datetime.fromordinal(dt2.toordinal()) + + self.years = 0 + self.months = 0 + self.days = 0 + self.leapdays = 0 + self.hours = 0 + self.minutes = 0 + self.seconds = 0 + self.microseconds = 0 + self.year = None + self.month = None + self.day = None + self.weekday = None + self.hour = None + self.minute = None + self.second = None + self.microsecond = None + self._has_time = 0 + + # Get year / month delta between the two + months = (dt1.year - dt2.year) * 12 + (dt1.month - dt2.month) + self._set_months(months) + + # Remove the year/month delta so the timedelta is just well-defined + # time units (seconds, days and microseconds) + dtm = self.__radd__(dt2) + + # If we've overshot our target, make an adjustment + if dt1 < dt2: + compare = operator.gt + increment = 1 + else: + compare = operator.lt + increment = -1 + + while compare(dt1, dtm): + months += increment + self._set_months(months) + dtm = self.__radd__(dt2) + + # Get the timedelta between the "months-adjusted" date and dt1 + delta = dt1 - dtm + self.seconds = delta.seconds + delta.days * 86400 + self.microseconds = delta.microseconds + else: + # Relative information + self.years = years + self.months = months + self.days = days + weeks * 7 + self.leapdays = leapdays + self.hours = hours + self.minutes = minutes + self.seconds = seconds + self.microseconds = microseconds + + # Absolute information + self.year = year + self.month = month + self.day = day + self.hour = hour + self.minute = minute + self.second = second + self.microsecond = microsecond + + if any(x is not None and int(x) != x + for x in (year, month, day, hour, + minute, second, microsecond)): + # For now we'll deprecate floats - later it'll be an error. + warn("Non-integer value passed as absolute information. 
" + + "This is not a well-defined condition and will raise " + + "errors in future versions.", DeprecationWarning) + + if isinstance(weekday, integer_types): + self.weekday = weekdays[weekday] + else: + self.weekday = weekday + + yday = 0 + if nlyearday: + yday = nlyearday + elif yearday: + yday = yearday + if yearday > 59: + self.leapdays = -1 + if yday: + ydayidx = [31, 59, 90, 120, 151, 181, 212, + 243, 273, 304, 334, 366] + for idx, ydays in enumerate(ydayidx): + if yday <= ydays: + self.month = idx+1 + if idx == 0: + self.day = yday + else: + self.day = yday-ydayidx[idx-1] + break + else: + raise ValueError("invalid year day (%d)" % yday) + + self._fix() + + def _fix(self): + if abs(self.microseconds) > 999999: + s = _sign(self.microseconds) + div, mod = divmod(self.microseconds * s, 1000000) + self.microseconds = mod * s + self.seconds += div * s + if abs(self.seconds) > 59: + s = _sign(self.seconds) + div, mod = divmod(self.seconds * s, 60) + self.seconds = mod * s + self.minutes += div * s + if abs(self.minutes) > 59: + s = _sign(self.minutes) + div, mod = divmod(self.minutes * s, 60) + self.minutes = mod * s + self.hours += div * s + if abs(self.hours) > 23: + s = _sign(self.hours) + div, mod = divmod(self.hours * s, 24) + self.hours = mod * s + self.days += div * s + if abs(self.months) > 11: + s = _sign(self.months) + div, mod = divmod(self.months * s, 12) + self.months = mod * s + self.years += div * s + if (self.hours or self.minutes or self.seconds or self.microseconds + or self.hour is not None or self.minute is not None or + self.second is not None or self.microsecond is not None): + self._has_time = 1 + else: + self._has_time = 0 + + @property + def weeks(self): + return self.days // 7 + + @weeks.setter + def weeks(self, value): + self.days = self.days - (self.weeks * 7) + value * 7 + + def _set_months(self, months): + self.months = months + if abs(self.months) > 11: + s = _sign(self.months) + div, mod = divmod(self.months * s, 12) + self.months = mod * s + self.years = div * s + else: + self.years = 0 + + def normalized(self): + """ + Return a version of this object represented entirely using integer + values for the relative attributes. + + >>> relativedelta(days=1.5, hours=2).normalized() + relativedelta(days=1, hours=14) + + :return: + Returns a :class:`dateutil.relativedelta.relativedelta` object. 
+ """ + # Cascade remainders down (rounding each to roughly nearest microsecond) + days = int(self.days) + + hours_f = round(self.hours + 24 * (self.days - days), 11) + hours = int(hours_f) + + minutes_f = round(self.minutes + 60 * (hours_f - hours), 10) + minutes = int(minutes_f) + + seconds_f = round(self.seconds + 60 * (minutes_f - minutes), 8) + seconds = int(seconds_f) + + microseconds = round(self.microseconds + 1e6 * (seconds_f - seconds)) + + # Constructor carries overflow back up with call to _fix() + return self.__class__(years=self.years, months=self.months, + days=days, hours=hours, minutes=minutes, + seconds=seconds, microseconds=microseconds, + leapdays=self.leapdays, year=self.year, + month=self.month, day=self.day, + weekday=self.weekday, hour=self.hour, + minute=self.minute, second=self.second, + microsecond=self.microsecond) + + def __add__(self, other): + if isinstance(other, relativedelta): + return self.__class__(years=other.years + self.years, + months=other.months + self.months, + days=other.days + self.days, + hours=other.hours + self.hours, + minutes=other.minutes + self.minutes, + seconds=other.seconds + self.seconds, + microseconds=(other.microseconds + + self.microseconds), + leapdays=other.leapdays or self.leapdays, + year=(other.year if other.year is not None + else self.year), + month=(other.month if other.month is not None + else self.month), + day=(other.day if other.day is not None + else self.day), + weekday=(other.weekday if other.weekday is not None + else self.weekday), + hour=(other.hour if other.hour is not None + else self.hour), + minute=(other.minute if other.minute is not None + else self.minute), + second=(other.second if other.second is not None + else self.second), + microsecond=(other.microsecond if other.microsecond + is not None else + self.microsecond)) + if isinstance(other, datetime.timedelta): + return self.__class__(years=self.years, + months=self.months, + days=self.days + other.days, + hours=self.hours, + minutes=self.minutes, + seconds=self.seconds + other.seconds, + microseconds=self.microseconds + other.microseconds, + leapdays=self.leapdays, + year=self.year, + month=self.month, + day=self.day, + weekday=self.weekday, + hour=self.hour, + minute=self.minute, + second=self.second, + microsecond=self.microsecond) + if not isinstance(other, datetime.date): + return NotImplemented + elif self._has_time and not isinstance(other, datetime.datetime): + other = datetime.datetime.fromordinal(other.toordinal()) + year = (self.year or other.year)+self.years + month = self.month or other.month + if self.months: + assert 1 <= abs(self.months) <= 12 + month += self.months + if month > 12: + year += 1 + month -= 12 + elif month < 1: + year -= 1 + month += 12 + day = min(calendar.monthrange(year, month)[1], + self.day or other.day) + repl = {"year": year, "month": month, "day": day} + for attr in ["hour", "minute", "second", "microsecond"]: + value = getattr(self, attr) + if value is not None: + repl[attr] = value + days = self.days + if self.leapdays and month > 2 and calendar.isleap(year): + days += self.leapdays + ret = (other.replace(**repl) + + datetime.timedelta(days=days, + hours=self.hours, + minutes=self.minutes, + seconds=self.seconds, + microseconds=self.microseconds)) + if self.weekday: + weekday, nth = self.weekday.weekday, self.weekday.n or 1 + jumpdays = (abs(nth) - 1) * 7 + if nth > 0: + jumpdays += (7 - ret.weekday() + weekday) % 7 + else: + jumpdays += (ret.weekday() - weekday) % 7 + jumpdays *= -1 + ret += 
datetime.timedelta(days=jumpdays) + return ret + + def __radd__(self, other): + return self.__add__(other) + + def __rsub__(self, other): + return self.__neg__().__radd__(other) + + def __sub__(self, other): + if not isinstance(other, relativedelta): + return NotImplemented # In case the other object defines __rsub__ + return self.__class__(years=self.years - other.years, + months=self.months - other.months, + days=self.days - other.days, + hours=self.hours - other.hours, + minutes=self.minutes - other.minutes, + seconds=self.seconds - other.seconds, + microseconds=self.microseconds - other.microseconds, + leapdays=self.leapdays or other.leapdays, + year=(self.year if self.year is not None + else other.year), + month=(self.month if self.month is not None else + other.month), + day=(self.day if self.day is not None else + other.day), + weekday=(self.weekday if self.weekday is not None else + other.weekday), + hour=(self.hour if self.hour is not None else + other.hour), + minute=(self.minute if self.minute is not None else + other.minute), + second=(self.second if self.second is not None else + other.second), + microsecond=(self.microsecond if self.microsecond + is not None else + other.microsecond)) + + def __neg__(self): + return self.__class__(years=-self.years, + months=-self.months, + days=-self.days, + hours=-self.hours, + minutes=-self.minutes, + seconds=-self.seconds, + microseconds=-self.microseconds, + leapdays=self.leapdays, + year=self.year, + month=self.month, + day=self.day, + weekday=self.weekday, + hour=self.hour, + minute=self.minute, + second=self.second, + microsecond=self.microsecond) + + def __bool__(self): + return not (not self.years and + not self.months and + not self.days and + not self.hours and + not self.minutes and + not self.seconds and + not self.microseconds and + not self.leapdays and + self.year is None and + self.month is None and + self.day is None and + self.weekday is None and + self.hour is None and + self.minute is None and + self.second is None and + self.microsecond is None) + # Compatibility with Python 2.x + __nonzero__ = __bool__ + + def __mul__(self, other): + try: + f = float(other) + except TypeError: + return NotImplemented + + return self.__class__(years=int(self.years * f), + months=int(self.months * f), + days=int(self.days * f), + hours=int(self.hours * f), + minutes=int(self.minutes * f), + seconds=int(self.seconds * f), + microseconds=int(self.microseconds * f), + leapdays=self.leapdays, + year=self.year, + month=self.month, + day=self.day, + weekday=self.weekday, + hour=self.hour, + minute=self.minute, + second=self.second, + microsecond=self.microsecond) + + __rmul__ = __mul__ + + def __eq__(self, other): + if not isinstance(other, relativedelta): + return NotImplemented + if self.weekday or other.weekday: + if not self.weekday or not other.weekday: + return False + if self.weekday.weekday != other.weekday.weekday: + return False + n1, n2 = self.weekday.n, other.weekday.n + if n1 != n2 and not ((not n1 or n1 == 1) and (not n2 or n2 == 1)): + return False + return (self.years == other.years and + self.months == other.months and + self.days == other.days and + self.hours == other.hours and + self.minutes == other.minutes and + self.seconds == other.seconds and + self.microseconds == other.microseconds and + self.leapdays == other.leapdays and + self.year == other.year and + self.month == other.month and + self.day == other.day and + self.hour == other.hour and + self.minute == other.minute and + self.second == other.second and + 
self.microsecond == other.microsecond) + + __hash__ = None + + def __ne__(self, other): + return not self.__eq__(other) + + def __div__(self, other): + try: + reciprocal = 1 / float(other) + except TypeError: + return NotImplemented + + return self.__mul__(reciprocal) + + __truediv__ = __div__ + + def __repr__(self): + l = [] + for attr in ["years", "months", "days", "leapdays", + "hours", "minutes", "seconds", "microseconds"]: + value = getattr(self, attr) + if value: + l.append("{attr}={value:+g}".format(attr=attr, value=value)) + for attr in ["year", "month", "day", "weekday", + "hour", "minute", "second", "microsecond"]: + value = getattr(self, attr) + if value is not None: + l.append("{attr}={value}".format(attr=attr, value=repr(value))) + return "{classname}({attrs})".format(classname=self.__class__.__name__, + attrs=", ".join(l)) + + +def _sign(x): + return int(copysign(1, x)) + +# vim:ts=4:sw=4:et diff --git a/python/dateutil/rrule.py b/python/dateutil/rrule.py new file mode 100644 index 0000000..429f8fc --- /dev/null +++ b/python/dateutil/rrule.py @@ -0,0 +1,1610 @@ +# -*- coding: utf-8 -*- +""" +The rrule module offers a small, complete, and very fast, implementation of +the recurrence rules documented in the +`iCalendar RFC <http://www.ietf.org/rfc/rfc2445.txt>`_, +including support for caching of results. +""" +import itertools +import datetime +import calendar +import sys + +try: + from math import gcd +except ImportError: + from fractions import gcd + +from six import advance_iterator, integer_types +from six.moves import _thread, range +import heapq + +from ._common import weekday as weekdaybase + +# For warning about deprecation of until and count +from warnings import warn + +__all__ = ["rrule", "rruleset", "rrulestr", + "YEARLY", "MONTHLY", "WEEKLY", "DAILY", + "HOURLY", "MINUTELY", "SECONDLY", + "MO", "TU", "WE", "TH", "FR", "SA", "SU"] + +# Every mask is 7 days longer to handle cross-year weekly periods. +M366MASK = tuple([1]*31+[2]*29+[3]*31+[4]*30+[5]*31+[6]*30 + + [7]*31+[8]*31+[9]*30+[10]*31+[11]*30+[12]*31+[1]*7) +M365MASK = list(M366MASK) +M29, M30, M31 = list(range(1, 30)), list(range(1, 31)), list(range(1, 32)) +MDAY366MASK = tuple(M31+M29+M31+M30+M31+M30+M31+M31+M30+M31+M30+M31+M31[:7]) +MDAY365MASK = list(MDAY366MASK) +M29, M30, M31 = list(range(-29, 0)), list(range(-30, 0)), list(range(-31, 0)) +NMDAY366MASK = tuple(M31+M29+M31+M30+M31+M30+M31+M31+M30+M31+M30+M31+M31[:7]) +NMDAY365MASK = list(NMDAY366MASK) +M366RANGE = (0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, 366) +M365RANGE = (0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365) +WDAYMASK = [0, 1, 2, 3, 4, 5, 6]*55 +del M29, M30, M31, M365MASK[59], MDAY365MASK[59], NMDAY365MASK[31] +MDAY365MASK = tuple(MDAY365MASK) +M365MASK = tuple(M365MASK) + +FREQNAMES = ['YEARLY', 'MONTHLY', 'WEEKLY', 'DAILY', 'HOURLY', 'MINUTELY', 'SECONDLY'] + +(YEARLY, + MONTHLY, + WEEKLY, + DAILY, + HOURLY, + MINUTELY, + SECONDLY) = list(range(7)) + +# Imported on demand. +easter = None +parser = None + + +class weekday(weekdaybase): + """ + This version of weekday does not allow n = 0. + """ + def __init__(self, wkday, n=None): + if n == 0: + raise ValueError("Can't create weekday with n==0") + + super(weekday, self).__init__(wkday, n) + + +MO, TU, WE, TH, FR, SA, SU = weekdays = tuple(weekday(x) for x in range(7)) + + +def _invalidates_cache(f): + """ + Decorator for rruleset methods which may invalidate the + cached length. 
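+
+    A sketch of the intended use (mirroring how ``rruleset`` applies it to
+    its own mutating methods further below):
+
+        @_invalidates_cache
+        def rdate(self, rdate):
+            self._rdate.append(rdate)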
+    """
+    def inner_func(self, *args, **kwargs):
+        rv = f(self, *args, **kwargs)
+        self._invalidate_cache()
+        return rv
+
+    return inner_func
+
+
+class rrulebase(object):
+    def __init__(self, cache=False):
+        if cache:
+            self._cache = []
+            self._cache_lock = _thread.allocate_lock()
+            self._invalidate_cache()
+        else:
+            self._cache = None
+            self._cache_complete = False
+        self._len = None
+
+    def __iter__(self):
+        if self._cache_complete:
+            return iter(self._cache)
+        elif self._cache is None:
+            return self._iter()
+        else:
+            return self._iter_cached()
+
+    def _invalidate_cache(self):
+        if self._cache is not None:
+            self._cache = []
+            self._cache_complete = False
+            self._cache_gen = self._iter()
+
+            if self._cache_lock.locked():
+                self._cache_lock.release()
+
+        self._len = None
+
+    def _iter_cached(self):
+        i = 0
+        gen = self._cache_gen
+        cache = self._cache
+        acquire = self._cache_lock.acquire
+        release = self._cache_lock.release
+        while gen:
+            if i == len(cache):
+                acquire()
+                if self._cache_complete:
+                    break
+                try:
+                    for j in range(10):
+                        cache.append(advance_iterator(gen))
+                except StopIteration:
+                    self._cache_gen = gen = None
+                    self._cache_complete = True
+                    break
+                release()
+            yield cache[i]
+            i += 1
+        while i < self._len:
+            yield cache[i]
+            i += 1
+
+    def __getitem__(self, item):
+        if self._cache_complete:
+            return self._cache[item]
+        elif isinstance(item, slice):
+            if item.step and item.step < 0:
+                return list(iter(self))[item]
+            else:
+                return list(itertools.islice(self,
+                                             item.start or 0,
+                                             item.stop or sys.maxsize,
+                                             item.step or 1))
+        elif item >= 0:
+            gen = iter(self)
+            try:
+                for i in range(item+1):
+                    res = advance_iterator(gen)
+            except StopIteration:
+                raise IndexError
+            return res
+        else:
+            return list(iter(self))[item]
+
+    def __contains__(self, item):
+        if self._cache_complete:
+            return item in self._cache
+        else:
+            for i in self:
+                if i == item:
+                    return True
+                elif i > item:
+                    return False
+        return False
+
+    # __len__() introduces a large performance penalty.
+    def count(self):
+        """ Returns the number of recurrences in this set. It will have to go
+        through the whole recurrence, if this hasn't been done before. """
+        if self._len is None:
+            for x in self:
+                pass
+        return self._len
+
+    def before(self, dt, inc=False):
+        """ Returns the last recurrence before the given datetime instance. The
+        inc keyword defines what happens if dt is an occurrence. With
+        inc=True, if dt itself is an occurrence, it will be returned. """
+        if self._cache_complete:
+            gen = self._cache
+        else:
+            gen = self
+        last = None
+        if inc:
+            for i in gen:
+                if i > dt:
+                    break
+                last = i
+        else:
+            for i in gen:
+                if i >= dt:
+                    break
+                last = i
+        return last
+
+    def after(self, dt, inc=False):
+        """ Returns the first recurrence after the given datetime instance. The
+        inc keyword defines what happens if dt is an occurrence. With
+        inc=True, if dt itself is an occurrence, it will be returned. """
+        if self._cache_complete:
+            gen = self._cache
+        else:
+            gen = self
+        if inc:
+            for i in gen:
+                if i >= dt:
+                    return i
+        else:
+            for i in gen:
+                if i > dt:
+                    return i
+        return None
+
+    def xafter(self, dt, count=None, inc=False):
+        """
+        Generator which yields up to `count` recurrences after the given
+        datetime instance, equivalent to `after`.
+
+        :param dt:
+            The datetime at which to start generating recurrences.
+
+        :param count:
+            The maximum number of recurrences to generate. If `None` (default),
+            dates are generated until the recurrence rule is exhausted.
+ + :param inc: + If `dt` is an instance of the rule and `inc` is `True`, it is + included in the output. + + :yields: Yields a sequence of `datetime` objects. + """ + + if self._cache_complete: + gen = self._cache + else: + gen = self + + # Select the comparison function + if inc: + comp = lambda dc, dtc: dc >= dtc + else: + comp = lambda dc, dtc: dc > dtc + + # Generate dates + n = 0 + for d in gen: + if comp(d, dt): + if count is not None: + n += 1 + if n > count: + break + + yield d + + def between(self, after, before, inc=False, count=1): + """ Returns all the occurrences of the rrule between after and before. + The inc keyword defines what happens if after and/or before are + themselves occurrences. With inc=True, they will be included in the + list, if they are found in the recurrence set. """ + if self._cache_complete: + gen = self._cache + else: + gen = self + started = False + l = [] + if inc: + for i in gen: + if i > before: + break + elif not started: + if i >= after: + started = True + l.append(i) + else: + l.append(i) + else: + for i in gen: + if i >= before: + break + elif not started: + if i > after: + started = True + l.append(i) + else: + l.append(i) + return l + + +class rrule(rrulebase): + """ + That's the base of the rrule operation. It accepts all the keywords + defined in the RFC as its constructor parameters (except byday, + which was renamed to byweekday) and more. The constructor prototype is:: + + rrule(freq) + + Where freq must be one of YEARLY, MONTHLY, WEEKLY, DAILY, HOURLY, MINUTELY, + or SECONDLY. + + .. note:: + Per RFC section 3.3.10, recurrence instances falling on invalid dates + and times are ignored rather than coerced: + + Recurrence rules may generate recurrence instances with an invalid + date (e.g., February 30) or nonexistent local time (e.g., 1:30 AM + on a day where the local time is moved forward by an hour at 1:00 + AM). Such recurrence instances MUST be ignored and MUST NOT be + counted as part of the recurrence set. + + This can lead to possibly surprising behavior when, for example, the + start date occurs at the end of the month: + + >>> from dateutil.rrule import rrule, MONTHLY + >>> from datetime import datetime + >>> start_date = datetime(2014, 12, 31) + >>> list(rrule(freq=MONTHLY, count=4, dtstart=start_date)) + ... # doctest: +NORMALIZE_WHITESPACE + [datetime.datetime(2014, 12, 31, 0, 0), + datetime.datetime(2015, 1, 31, 0, 0), + datetime.datetime(2015, 3, 31, 0, 0), + datetime.datetime(2015, 5, 31, 0, 0)] + + Additionally, it supports the following keyword arguments: + + :param cache: + If given, it must be a boolean value specifying to enable or disable + caching of results. If you will use the same rrule instance multiple + times, enabling caching will improve the performance considerably. + :param dtstart: + The recurrence start. Besides being the base for the recurrence, + missing parameters in the final recurrence instances will also be + extracted from this date. If not given, datetime.now() will be used + instead. + :param interval: + The interval between each freq iteration. For example, when using + YEARLY, an interval of 2 means once every two years, but with HOURLY, + it means once every two hours. The default interval is 1. + :param wkst: + The week start day. Must be one of the MO, TU, WE constants, or an + integer, specifying the first day of the week. This will affect + recurrences based on weekly periods. 
The default week start is taken
+        from calendar.firstweekday(), and may be modified by
+        calendar.setfirstweekday().
+    :param count:
+        How many occurrences will be generated.
+
+        .. note::
+            As of version 2.5.0, the use of the ``until`` keyword together
+            with the ``count`` keyword is deprecated per RFC-2445 Sec. 4.3.10.
+    :param until:
+        If given, this must be a datetime instance that will specify the
+        limit of the recurrence. The last recurrence in the rule is the
+        greatest datetime that is less than or equal to the value specified
+        in the ``until`` parameter.
+
+        .. note::
+            As of version 2.5.0, the use of the ``until`` keyword together
+            with the ``count`` keyword is deprecated per RFC-2445 Sec. 4.3.10.
+    :param bysetpos:
+        If given, it must be either an integer, or a sequence of integers,
+        positive or negative. Each given integer will specify an occurrence
+        number, corresponding to the nth occurrence of the rule inside the
+        frequency period. For example, a bysetpos of -1 combined with a
+        MONTHLY frequency, and a byweekday of (MO, TU, WE, TH, FR), will
+        result in the last work day of every month.
+    :param bymonth:
+        If given, it must be either an integer, or a sequence of integers,
+        meaning the months to apply the recurrence to.
+    :param bymonthday:
+        If given, it must be either an integer, or a sequence of integers,
+        meaning the month days to apply the recurrence to.
+    :param byyearday:
+        If given, it must be either an integer, or a sequence of integers,
+        meaning the year days to apply the recurrence to.
+    :param byweekno:
+        If given, it must be either an integer, or a sequence of integers,
+        meaning the week numbers to apply the recurrence to. Week numbers
+        have the meaning described in ISO 8601, that is, the first week of
+        the year is the one containing at least four days of the new year.
+    :param byweekday:
+        If given, it must be either an integer (0 == MO), a sequence of
+        integers, one of the weekday constants (MO, TU, etc), or a sequence
+        of these constants. When given, these variables will define the
+        weekdays where the recurrence will be applied. It's also possible to
+        use an argument n for the weekday instances, which will mean the nth
+        occurrence of this weekday in the period. For example, with MONTHLY,
+        or with YEARLY and BYMONTH, using FR(+1) in byweekday will specify
+        the first Friday of the month where the recurrence happens. Notice
+        that in the RFC documentation, this is specified as BYDAY, but was
+        renamed to avoid the ambiguity of that keyword.
+    :param byhour:
+        If given, it must be either an integer, or a sequence of integers,
+        meaning the hours to apply the recurrence to.
+    :param byminute:
+        If given, it must be either an integer, or a sequence of integers,
+        meaning the minutes to apply the recurrence to.
+    :param bysecond:
+        If given, it must be either an integer, or a sequence of integers,
+        meaning the seconds to apply the recurrence to.
+    :param byeaster:
+        If given, it must be either an integer, or a sequence of integers,
+        positive or negative. Each integer will define an offset from Easter
+        Sunday. Passing the offset 0 to byeaster will yield Easter Sunday
+        itself. This is an extension to the RFC specification.
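+
+    For example, a WEEKLY rule restricted with ``byweekday`` yields only the
+    matching weekdays (an illustrative run; the dates are arbitrary):
+
+        >>> from dateutil.rrule import rrule, WEEKLY, MO
+        >>> from datetime import datetime
+        >>> list(rrule(freq=WEEKLY, count=3, byweekday=MO,
+        ...            dtstart=datetime(2014, 1, 1)))
+        ... # doctest: +NORMALIZE_WHITESPACE
+        [datetime.datetime(2014, 1, 6, 0, 0),
+         datetime.datetime(2014, 1, 13, 0, 0),
+         datetime.datetime(2014, 1, 20, 0, 0)]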
+ """ + def __init__(self, freq, dtstart=None, + interval=1, wkst=None, count=None, until=None, bysetpos=None, + bymonth=None, bymonthday=None, byyearday=None, byeaster=None, + byweekno=None, byweekday=None, + byhour=None, byminute=None, bysecond=None, + cache=False): + super(rrule, self).__init__(cache) + global easter + if not dtstart: + dtstart = datetime.datetime.now().replace(microsecond=0) + elif not isinstance(dtstart, datetime.datetime): + dtstart = datetime.datetime.fromordinal(dtstart.toordinal()) + else: + dtstart = dtstart.replace(microsecond=0) + self._dtstart = dtstart + self._tzinfo = dtstart.tzinfo + self._freq = freq + self._interval = interval + self._count = count + + # Cache the original byxxx rules, if they are provided, as the _byxxx + # attributes do not necessarily map to the inputs, and this can be + # a problem in generating the strings. Only store things if they've + # been supplied (the string retrieval will just use .get()) + self._original_rule = {} + + if until and not isinstance(until, datetime.datetime): + until = datetime.datetime.fromordinal(until.toordinal()) + self._until = until + + if count is not None and until: + warn("Using both 'count' and 'until' is inconsistent with RFC 2445" + " and has been deprecated in dateutil. Future versions will " + "raise an error.", DeprecationWarning) + + if wkst is None: + self._wkst = calendar.firstweekday() + elif isinstance(wkst, integer_types): + self._wkst = wkst + else: + self._wkst = wkst.weekday + + if bysetpos is None: + self._bysetpos = None + elif isinstance(bysetpos, integer_types): + if bysetpos == 0 or not (-366 <= bysetpos <= 366): + raise ValueError("bysetpos must be between 1 and 366, " + "or between -366 and -1") + self._bysetpos = (bysetpos,) + else: + self._bysetpos = tuple(bysetpos) + for pos in self._bysetpos: + if pos == 0 or not (-366 <= pos <= 366): + raise ValueError("bysetpos must be between 1 and 366, " + "or between -366 and -1") + + if self._bysetpos: + self._original_rule['bysetpos'] = self._bysetpos + + if (byweekno is None and byyearday is None and bymonthday is None and + byweekday is None and byeaster is None): + if freq == YEARLY: + if bymonth is None: + bymonth = dtstart.month + self._original_rule['bymonth'] = None + bymonthday = dtstart.day + self._original_rule['bymonthday'] = None + elif freq == MONTHLY: + bymonthday = dtstart.day + self._original_rule['bymonthday'] = None + elif freq == WEEKLY: + byweekday = dtstart.weekday() + self._original_rule['byweekday'] = None + + # bymonth + if bymonth is None: + self._bymonth = None + else: + if isinstance(bymonth, integer_types): + bymonth = (bymonth,) + + self._bymonth = tuple(sorted(set(bymonth))) + + if 'bymonth' not in self._original_rule: + self._original_rule['bymonth'] = self._bymonth + + # byyearday + if byyearday is None: + self._byyearday = None + else: + if isinstance(byyearday, integer_types): + byyearday = (byyearday,) + + self._byyearday = tuple(sorted(set(byyearday))) + self._original_rule['byyearday'] = self._byyearday + + # byeaster + if byeaster is not None: + if not easter: + from dateutil import easter + if isinstance(byeaster, integer_types): + self._byeaster = (byeaster,) + else: + self._byeaster = tuple(sorted(byeaster)) + + self._original_rule['byeaster'] = self._byeaster + else: + self._byeaster = None + + # bymonthday + if bymonthday is None: + self._bymonthday = () + self._bynmonthday = () + else: + if isinstance(bymonthday, integer_types): + bymonthday = (bymonthday,) + + bymonthday = set(bymonthday) # 
Ensure it's unique + + self._bymonthday = tuple(sorted(x for x in bymonthday if x > 0)) + self._bynmonthday = tuple(sorted(x for x in bymonthday if x < 0)) + + # Storing positive numbers first, then negative numbers + if 'bymonthday' not in self._original_rule: + self._original_rule['bymonthday'] = tuple( + itertools.chain(self._bymonthday, self._bynmonthday)) + + # byweekno + if byweekno is None: + self._byweekno = None + else: + if isinstance(byweekno, integer_types): + byweekno = (byweekno,) + + self._byweekno = tuple(sorted(set(byweekno))) + + self._original_rule['byweekno'] = self._byweekno + + # byweekday / bynweekday + if byweekday is None: + self._byweekday = None + self._bynweekday = None + else: + # If it's one of the valid non-sequence types, convert to a + # single-element sequence before the iterator that builds the + # byweekday set. + if isinstance(byweekday, integer_types) or hasattr(byweekday, "n"): + byweekday = (byweekday,) + + self._byweekday = set() + self._bynweekday = set() + for wday in byweekday: + if isinstance(wday, integer_types): + self._byweekday.add(wday) + elif not wday.n or freq > MONTHLY: + self._byweekday.add(wday.weekday) + else: + self._bynweekday.add((wday.weekday, wday.n)) + + if not self._byweekday: + self._byweekday = None + elif not self._bynweekday: + self._bynweekday = None + + if self._byweekday is not None: + self._byweekday = tuple(sorted(self._byweekday)) + orig_byweekday = [weekday(x) for x in self._byweekday] + else: + orig_byweekday = tuple() + + if self._bynweekday is not None: + self._bynweekday = tuple(sorted(self._bynweekday)) + orig_bynweekday = [weekday(*x) for x in self._bynweekday] + else: + orig_bynweekday = tuple() + + if 'byweekday' not in self._original_rule: + self._original_rule['byweekday'] = tuple(itertools.chain( + orig_byweekday, orig_bynweekday)) + + # byhour + if byhour is None: + if freq < HOURLY: + self._byhour = set((dtstart.hour,)) + else: + self._byhour = None + else: + if isinstance(byhour, integer_types): + byhour = (byhour,) + + if freq == HOURLY: + self._byhour = self.__construct_byset(start=dtstart.hour, + byxxx=byhour, + base=24) + else: + self._byhour = set(byhour) + + self._byhour = tuple(sorted(self._byhour)) + self._original_rule['byhour'] = self._byhour + + # byminute + if byminute is None: + if freq < MINUTELY: + self._byminute = set((dtstart.minute,)) + else: + self._byminute = None + else: + if isinstance(byminute, integer_types): + byminute = (byminute,) + + if freq == MINUTELY: + self._byminute = self.__construct_byset(start=dtstart.minute, + byxxx=byminute, + base=60) + else: + self._byminute = set(byminute) + + self._byminute = tuple(sorted(self._byminute)) + self._original_rule['byminute'] = self._byminute + + # bysecond + if bysecond is None: + if freq < SECONDLY: + self._bysecond = ((dtstart.second,)) + else: + self._bysecond = None + else: + if isinstance(bysecond, integer_types): + bysecond = (bysecond,) + + self._bysecond = set(bysecond) + + if freq == SECONDLY: + self._bysecond = self.__construct_byset(start=dtstart.second, + byxxx=bysecond, + base=60) + else: + self._bysecond = set(bysecond) + + self._bysecond = tuple(sorted(self._bysecond)) + self._original_rule['bysecond'] = self._bysecond + + if self._freq >= HOURLY: + self._timeset = None + else: + self._timeset = [] + for hour in self._byhour: + for minute in self._byminute: + for second in self._bysecond: + self._timeset.append( + datetime.time(hour, minute, second, + tzinfo=self._tzinfo)) + self._timeset.sort() + self._timeset = 
tuple(self._timeset) + + def __str__(self): + """ + Output a string that would generate this RRULE if passed to rrulestr. + This is mostly compatible with RFC2445, except for the + dateutil-specific extension BYEASTER. + """ + + output = [] + h, m, s = [None] * 3 + if self._dtstart: + output.append(self._dtstart.strftime('DTSTART:%Y%m%dT%H%M%S')) + h, m, s = self._dtstart.timetuple()[3:6] + + parts = ['FREQ=' + FREQNAMES[self._freq]] + if self._interval != 1: + parts.append('INTERVAL=' + str(self._interval)) + + if self._wkst: + parts.append('WKST=' + repr(weekday(self._wkst))[0:2]) + + if self._count is not None: + parts.append('COUNT=' + str(self._count)) + + if self._until: + parts.append(self._until.strftime('UNTIL=%Y%m%dT%H%M%S')) + + if self._original_rule.get('byweekday') is not None: + # The str() method on weekday objects doesn't generate + # RFC2445-compliant strings, so we should modify that. + original_rule = dict(self._original_rule) + wday_strings = [] + for wday in original_rule['byweekday']: + if wday.n: + wday_strings.append('{n:+d}{wday}'.format( + n=wday.n, + wday=repr(wday)[0:2])) + else: + wday_strings.append(repr(wday)) + + original_rule['byweekday'] = wday_strings + else: + original_rule = self._original_rule + + partfmt = '{name}={vals}' + for name, key in [('BYSETPOS', 'bysetpos'), + ('BYMONTH', 'bymonth'), + ('BYMONTHDAY', 'bymonthday'), + ('BYYEARDAY', 'byyearday'), + ('BYWEEKNO', 'byweekno'), + ('BYDAY', 'byweekday'), + ('BYHOUR', 'byhour'), + ('BYMINUTE', 'byminute'), + ('BYSECOND', 'bysecond'), + ('BYEASTER', 'byeaster')]: + value = original_rule.get(key) + if value: + parts.append(partfmt.format(name=name, vals=(','.join(str(v) + for v in value)))) + + output.append(';'.join(parts)) + return '\n'.join(output) + + def replace(self, **kwargs): + """Return new rrule with same attributes except for those attributes given new + values by whichever keyword arguments are specified.""" + new_kwargs = {"interval": self._interval, + "count": self._count, + "dtstart": self._dtstart, + "freq": self._freq, + "until": self._until, + "wkst": self._wkst, + "cache": False if self._cache is None else True } + new_kwargs.update(self._original_rule) + new_kwargs.update(kwargs) + return rrule(**new_kwargs) + + def _iter(self): + year, month, day, hour, minute, second, weekday, yearday, _ = \ + self._dtstart.timetuple() + + # Some local variables to speed things up a bit + freq = self._freq + interval = self._interval + wkst = self._wkst + until = self._until + bymonth = self._bymonth + byweekno = self._byweekno + byyearday = self._byyearday + byweekday = self._byweekday + byeaster = self._byeaster + bymonthday = self._bymonthday + bynmonthday = self._bynmonthday + bysetpos = self._bysetpos + byhour = self._byhour + byminute = self._byminute + bysecond = self._bysecond + + ii = _iterinfo(self) + ii.rebuild(year, month) + + getdayset = {YEARLY: ii.ydayset, + MONTHLY: ii.mdayset, + WEEKLY: ii.wdayset, + DAILY: ii.ddayset, + HOURLY: ii.ddayset, + MINUTELY: ii.ddayset, + SECONDLY: ii.ddayset}[freq] + + if freq < HOURLY: + timeset = self._timeset + else: + gettimeset = {HOURLY: ii.htimeset, + MINUTELY: ii.mtimeset, + SECONDLY: ii.stimeset}[freq] + if ((freq >= HOURLY and + self._byhour and hour not in self._byhour) or + (freq >= MINUTELY and + self._byminute and minute not in self._byminute) or + (freq >= SECONDLY and + self._bysecond and second not in self._bysecond)): + timeset = () + else: + timeset = gettimeset(hour, minute, second) + + total = 0 + count = self._count + while True: 
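+            # Each pass of this loop handles one frequency period: build the
+            # candidate dayset, filter it through the byxxx masks, yield the
+            # survivors, then advance by `interval` periods.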
+ # Get dayset with the right frequency + dayset, start, end = getdayset(year, month, day) + + # Do the "hard" work ;-) + filtered = False + for i in dayset[start:end]: + if ((bymonth and ii.mmask[i] not in bymonth) or + (byweekno and not ii.wnomask[i]) or + (byweekday and ii.wdaymask[i] not in byweekday) or + (ii.nwdaymask and not ii.nwdaymask[i]) or + (byeaster and not ii.eastermask[i]) or + ((bymonthday or bynmonthday) and + ii.mdaymask[i] not in bymonthday and + ii.nmdaymask[i] not in bynmonthday) or + (byyearday and + ((i < ii.yearlen and i+1 not in byyearday and + -ii.yearlen+i not in byyearday) or + (i >= ii.yearlen and i+1-ii.yearlen not in byyearday and + -ii.nextyearlen+i-ii.yearlen not in byyearday)))): + dayset[i] = None + filtered = True + + # Output results + if bysetpos and timeset: + poslist = [] + for pos in bysetpos: + if pos < 0: + daypos, timepos = divmod(pos, len(timeset)) + else: + daypos, timepos = divmod(pos-1, len(timeset)) + try: + i = [x for x in dayset[start:end] + if x is not None][daypos] + time = timeset[timepos] + except IndexError: + pass + else: + date = datetime.date.fromordinal(ii.yearordinal+i) + res = datetime.datetime.combine(date, time) + if res not in poslist: + poslist.append(res) + poslist.sort() + for res in poslist: + if until and res > until: + self._len = total + return + elif res >= self._dtstart: + if count is not None: + count -= 1 + if count < 0: + self._len = total + return + total += 1 + yield res + else: + for i in dayset[start:end]: + if i is not None: + date = datetime.date.fromordinal(ii.yearordinal + i) + for time in timeset: + res = datetime.datetime.combine(date, time) + if until and res > until: + self._len = total + return + elif res >= self._dtstart: + if count is not None: + count -= 1 + if count < 0: + self._len = total + return + + total += 1 + yield res + + # Handle frequency and interval + fixday = False + if freq == YEARLY: + year += interval + if year > datetime.MAXYEAR: + self._len = total + return + ii.rebuild(year, month) + elif freq == MONTHLY: + month += interval + if month > 12: + div, mod = divmod(month, 12) + month = mod + year += div + if month == 0: + month = 12 + year -= 1 + if year > datetime.MAXYEAR: + self._len = total + return + ii.rebuild(year, month) + elif freq == WEEKLY: + if wkst > weekday: + day += -(weekday+1+(6-wkst))+self._interval*7 + else: + day += -(weekday-wkst)+self._interval*7 + weekday = wkst + fixday = True + elif freq == DAILY: + day += interval + fixday = True + elif freq == HOURLY: + if filtered: + # Jump to one iteration before next day + hour += ((23-hour)//interval)*interval + + if byhour: + ndays, hour = self.__mod_distance(value=hour, + byxxx=self._byhour, + base=24) + else: + ndays, hour = divmod(hour+interval, 24) + + if ndays: + day += ndays + fixday = True + + timeset = gettimeset(hour, minute, second) + elif freq == MINUTELY: + if filtered: + # Jump to one iteration before next day + minute += ((1439-(hour*60+minute))//interval)*interval + + valid = False + rep_rate = (24*60) + for j in range(rep_rate // gcd(interval, rep_rate)): + if byminute: + nhours, minute = \ + self.__mod_distance(value=minute, + byxxx=self._byminute, + base=60) + else: + nhours, minute = divmod(minute+interval, 60) + + div, hour = divmod(hour+nhours, 24) + if div: + day += div + fixday = True + filtered = False + + if not byhour or hour in byhour: + valid = True + break + + if not valid: + raise ValueError('Invalid combination of interval and ' + + 'byhour resulting in empty rule.') + + timeset = 
gettimeset(hour, minute, second) + elif freq == SECONDLY: + if filtered: + # Jump to one iteration before next day + second += (((86399 - (hour * 3600 + minute * 60 + second)) + // interval) * interval) + + rep_rate = (24 * 3600) + valid = False + for j in range(0, rep_rate // gcd(interval, rep_rate)): + if bysecond: + nminutes, second = \ + self.__mod_distance(value=second, + byxxx=self._bysecond, + base=60) + else: + nminutes, second = divmod(second+interval, 60) + + div, minute = divmod(minute+nminutes, 60) + if div: + hour += div + div, hour = divmod(hour, 24) + if div: + day += div + fixday = True + + if ((not byhour or hour in byhour) and + (not byminute or minute in byminute) and + (not bysecond or second in bysecond)): + valid = True + break + + if not valid: + raise ValueError('Invalid combination of interval, ' + + 'byhour and byminute resulting in empty' + + ' rule.') + + timeset = gettimeset(hour, minute, second) + + if fixday and day > 28: + daysinmonth = calendar.monthrange(year, month)[1] + if day > daysinmonth: + while day > daysinmonth: + day -= daysinmonth + month += 1 + if month == 13: + month = 1 + year += 1 + if year > datetime.MAXYEAR: + self._len = total + return + daysinmonth = calendar.monthrange(year, month)[1] + ii.rebuild(year, month) + + def __construct_byset(self, start, byxxx, base): + """ + If a `BYXXX` sequence is passed to the constructor at the same level as + `FREQ` (e.g. `FREQ=HOURLY,BYHOUR={2,4,7},INTERVAL=3`), there are some + specifications which cannot be reached given some starting conditions. + + This occurs whenever the interval is not coprime with the base of a + given unit and the difference between the starting position and the + ending position is not coprime with the greatest common denominator + between the interval and the base. For example, with a FREQ of hourly + starting at 17:00 and an interval of 4, the only valid values for + BYHOUR would be {21, 1, 5, 9, 13, 17}, because 4 and 24 are not + coprime. + + :param start: + Specifies the starting position. + :param byxxx: + An iterable containing the list of allowed values. + :param base: + The largest allowable value for the specified frequency (e.g. + 24 hours, 60 minutes). + + This does not preserve the type of the iterable, returning a set, since + the values should be unique and the order is irrelevant, this will + speed up later lookups. + + In the event of an empty set, raises a :exception:`ValueError`, as this + results in an empty rrule. + """ + + cset = set() + + # Support a single byxxx value. + if isinstance(byxxx, integer_types): + byxxx = (byxxx, ) + + for num in byxxx: + i_gcd = gcd(self._interval, base) + # Use divmod rather than % because we need to wrap negative nums. + if i_gcd == 1 or divmod(num - start, i_gcd)[1] == 0: + cset.add(num) + + if len(cset) == 0: + raise ValueError("Invalid rrule byxxx generates an empty set.") + + return cset + + def __mod_distance(self, value, byxxx, base): + """ + Calculates the next value in a sequence where the `FREQ` parameter is + specified along with a `BYXXX` parameter at the same "level" + (e.g. `HOURLY` specified with `BYHOUR`). + + :param value: + The old value of the component. + :param byxxx: + The `BYXXX` set, which should have been generated by + `rrule._construct_byset`, or something else which checks that a + valid rule is present. + :param base: + The largest allowable value for the specified frequency (e.g. + 24 hours, 60 minutes). 
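+
+        For example (illustrative numbers): with ``interval=4`` and
+        ``base=24``, stepping from ``value=22`` with ``byxxx={2}`` returns
+        ``(1, 2)``: one day of carry, landing on hour 2.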
+ + If a valid value is not found after `base` iterations (the maximum + number before the sequence would start to repeat), this raises a + :exception:`ValueError`, as no valid values were found. + + This returns a tuple of `divmod(n*interval, base)`, where `n` is the + smallest number of `interval` repetitions until the next specified + value in `byxxx` is found. + """ + accumulator = 0 + for ii in range(1, base + 1): + # Using divmod() over % to account for negative intervals + div, value = divmod(value + self._interval, base) + accumulator += div + if value in byxxx: + return (accumulator, value) + + +class _iterinfo(object): + __slots__ = ["rrule", "lastyear", "lastmonth", + "yearlen", "nextyearlen", "yearordinal", "yearweekday", + "mmask", "mrange", "mdaymask", "nmdaymask", + "wdaymask", "wnomask", "nwdaymask", "eastermask"] + + def __init__(self, rrule): + for attr in self.__slots__: + setattr(self, attr, None) + self.rrule = rrule + + def rebuild(self, year, month): + # Every mask is 7 days longer to handle cross-year weekly periods. + rr = self.rrule + if year != self.lastyear: + self.yearlen = 365 + calendar.isleap(year) + self.nextyearlen = 365 + calendar.isleap(year + 1) + firstyday = datetime.date(year, 1, 1) + self.yearordinal = firstyday.toordinal() + self.yearweekday = firstyday.weekday() + + wday = datetime.date(year, 1, 1).weekday() + if self.yearlen == 365: + self.mmask = M365MASK + self.mdaymask = MDAY365MASK + self.nmdaymask = NMDAY365MASK + self.wdaymask = WDAYMASK[wday:] + self.mrange = M365RANGE + else: + self.mmask = M366MASK + self.mdaymask = MDAY366MASK + self.nmdaymask = NMDAY366MASK + self.wdaymask = WDAYMASK[wday:] + self.mrange = M366RANGE + + if not rr._byweekno: + self.wnomask = None + else: + self.wnomask = [0]*(self.yearlen+7) + # no1wkst = firstwkst = self.wdaymask.index(rr._wkst) + no1wkst = firstwkst = (7-self.yearweekday+rr._wkst) % 7 + if no1wkst >= 4: + no1wkst = 0 + # Number of days in the year, plus the days we got + # from last year. + wyearlen = self.yearlen+(self.yearweekday-rr._wkst) % 7 + else: + # Number of days in the year, minus the days we + # left in last year. + wyearlen = self.yearlen-no1wkst + div, mod = divmod(wyearlen, 7) + numweeks = div+mod//4 + for n in rr._byweekno: + if n < 0: + n += numweeks+1 + if not (0 < n <= numweeks): + continue + if n > 1: + i = no1wkst+(n-1)*7 + if no1wkst != firstwkst: + i -= 7-firstwkst + else: + i = no1wkst + for j in range(7): + self.wnomask[i] = 1 + i += 1 + if self.wdaymask[i] == rr._wkst: + break + if 1 in rr._byweekno: + # Check week number 1 of next year as well + # TODO: Check -numweeks for next year. + i = no1wkst+numweeks*7 + if no1wkst != firstwkst: + i -= 7-firstwkst + if i < self.yearlen: + # If week starts in next year, we + # don't care about it. + for j in range(7): + self.wnomask[i] = 1 + i += 1 + if self.wdaymask[i] == rr._wkst: + break + if no1wkst: + # Check last week number of last year as + # well. If no1wkst is 0, either the year + # started on week start, or week number 1 + # got days from last year, so there are no + # days from last year's last week number in + # this year. 
+ if -1 not in rr._byweekno: + lyearweekday = datetime.date(year-1, 1, 1).weekday() + lno1wkst = (7-lyearweekday+rr._wkst) % 7 + lyearlen = 365+calendar.isleap(year-1) + if lno1wkst >= 4: + lno1wkst = 0 + lnumweeks = 52+(lyearlen + + (lyearweekday-rr._wkst) % 7) % 7//4 + else: + lnumweeks = 52+(self.yearlen-no1wkst) % 7//4 + else: + lnumweeks = -1 + if lnumweeks in rr._byweekno: + for i in range(no1wkst): + self.wnomask[i] = 1 + + if (rr._bynweekday and (month != self.lastmonth or + year != self.lastyear)): + ranges = [] + if rr._freq == YEARLY: + if rr._bymonth: + for month in rr._bymonth: + ranges.append(self.mrange[month-1:month+1]) + else: + ranges = [(0, self.yearlen)] + elif rr._freq == MONTHLY: + ranges = [self.mrange[month-1:month+1]] + if ranges: + # Weekly frequency won't get here, so we may not + # care about cross-year weekly periods. + self.nwdaymask = [0]*self.yearlen + for first, last in ranges: + last -= 1 + for wday, n in rr._bynweekday: + if n < 0: + i = last+(n+1)*7 + i -= (self.wdaymask[i]-wday) % 7 + else: + i = first+(n-1)*7 + i += (7-self.wdaymask[i]+wday) % 7 + if first <= i <= last: + self.nwdaymask[i] = 1 + + if rr._byeaster: + self.eastermask = [0]*(self.yearlen+7) + eyday = easter.easter(year).toordinal()-self.yearordinal + for offset in rr._byeaster: + self.eastermask[eyday+offset] = 1 + + self.lastyear = year + self.lastmonth = month + + def ydayset(self, year, month, day): + return list(range(self.yearlen)), 0, self.yearlen + + def mdayset(self, year, month, day): + dset = [None]*self.yearlen + start, end = self.mrange[month-1:month+1] + for i in range(start, end): + dset[i] = i + return dset, start, end + + def wdayset(self, year, month, day): + # We need to handle cross-year weeks here. + dset = [None]*(self.yearlen+7) + i = datetime.date(year, month, day).toordinal()-self.yearordinal + start = i + for j in range(7): + dset[i] = i + i += 1 + # if (not (0 <= i < self.yearlen) or + # self.wdaymask[i] == self.rrule._wkst): + # This will cross the year boundary, if necessary. + if self.wdaymask[i] == self.rrule._wkst: + break + return dset, start, i + + def ddayset(self, year, month, day): + dset = [None] * self.yearlen + i = datetime.date(year, month, day).toordinal() - self.yearordinal + dset[i] = i + return dset, i, i + 1 + + def htimeset(self, hour, minute, second): + tset = [] + rr = self.rrule + for minute in rr._byminute: + for second in rr._bysecond: + tset.append(datetime.time(hour, minute, second, + tzinfo=rr._tzinfo)) + tset.sort() + return tset + + def mtimeset(self, hour, minute, second): + tset = [] + rr = self.rrule + for second in rr._bysecond: + tset.append(datetime.time(hour, minute, second, tzinfo=rr._tzinfo)) + tset.sort() + return tset + + def stimeset(self, hour, minute, second): + return (datetime.time(hour, minute, second, + tzinfo=self.rrule._tzinfo),) + + +class rruleset(rrulebase): + """ The rruleset type allows more complex recurrence setups, mixing + multiple rules, dates, exclusion rules, and exclusion dates. The type + constructor takes the following keyword arguments: + + :param cache: If True, caching of results will be enabled, improving + performance of multiple queries considerably. 
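+
+    A minimal usage sketch (dates picked arbitrarily):
+
+        >>> from dateutil.rrule import rrule, rruleset, DAILY
+        >>> from datetime import datetime
+        >>> rset = rruleset()
+        >>> rset.rrule(rrule(DAILY, count=4, dtstart=datetime(2014, 1, 1)))
+        >>> rset.exdate(datetime(2014, 1, 3))
+        >>> list(rset)
+        ... # doctest: +NORMALIZE_WHITESPACE
+        [datetime.datetime(2014, 1, 1, 0, 0),
+         datetime.datetime(2014, 1, 2, 0, 0),
+         datetime.datetime(2014, 1, 4, 0, 0)]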
""" + + class _genitem(object): + def __init__(self, genlist, gen): + try: + self.dt = advance_iterator(gen) + genlist.append(self) + except StopIteration: + pass + self.genlist = genlist + self.gen = gen + + def __next__(self): + try: + self.dt = advance_iterator(self.gen) + except StopIteration: + if self.genlist[0] is self: + heapq.heappop(self.genlist) + else: + self.genlist.remove(self) + heapq.heapify(self.genlist) + + next = __next__ + + def __lt__(self, other): + return self.dt < other.dt + + def __gt__(self, other): + return self.dt > other.dt + + def __eq__(self, other): + return self.dt == other.dt + + def __ne__(self, other): + return self.dt != other.dt + + def __init__(self, cache=False): + super(rruleset, self).__init__(cache) + self._rrule = [] + self._rdate = [] + self._exrule = [] + self._exdate = [] + + @_invalidates_cache + def rrule(self, rrule): + """ Include the given :py:class:`rrule` instance in the recurrence set + generation. """ + self._rrule.append(rrule) + + @_invalidates_cache + def rdate(self, rdate): + """ Include the given :py:class:`datetime` instance in the recurrence + set generation. """ + self._rdate.append(rdate) + + @_invalidates_cache + def exrule(self, exrule): + """ Include the given rrule instance in the recurrence set exclusion + list. Dates which are part of the given recurrence rules will not + be generated, even if some inclusive rrule or rdate matches them. + """ + self._exrule.append(exrule) + + @_invalidates_cache + def exdate(self, exdate): + """ Include the given datetime instance in the recurrence set + exclusion list. Dates included that way will not be generated, + even if some inclusive rrule or rdate matches them. """ + self._exdate.append(exdate) + + def _iter(self): + rlist = [] + self._rdate.sort() + self._genitem(rlist, iter(self._rdate)) + for gen in [iter(x) for x in self._rrule]: + self._genitem(rlist, gen) + exlist = [] + self._exdate.sort() + self._genitem(exlist, iter(self._exdate)) + for gen in [iter(x) for x in self._exrule]: + self._genitem(exlist, gen) + lastdt = None + total = 0 + heapq.heapify(rlist) + heapq.heapify(exlist) + while rlist: + ritem = rlist[0] + if not lastdt or lastdt != ritem.dt: + while exlist and exlist[0] < ritem: + exitem = exlist[0] + advance_iterator(exitem) + if exlist and exlist[0] is exitem: + heapq.heapreplace(exlist, exitem) + if not exlist or ritem != exlist[0]: + total += 1 + yield ritem.dt + lastdt = ritem.dt + advance_iterator(ritem) + if rlist and rlist[0] is ritem: + heapq.heapreplace(rlist, ritem) + self._len = total + + +class _rrulestr(object): + + _freq_map = {"YEARLY": YEARLY, + "MONTHLY": MONTHLY, + "WEEKLY": WEEKLY, + "DAILY": DAILY, + "HOURLY": HOURLY, + "MINUTELY": MINUTELY, + "SECONDLY": SECONDLY} + + _weekday_map = {"MO": 0, "TU": 1, "WE": 2, "TH": 3, + "FR": 4, "SA": 5, "SU": 6} + + def _handle_int(self, rrkwargs, name, value, **kwargs): + rrkwargs[name.lower()] = int(value) + + def _handle_int_list(self, rrkwargs, name, value, **kwargs): + rrkwargs[name.lower()] = [int(x) for x in value.split(',')] + + _handle_INTERVAL = _handle_int + _handle_COUNT = _handle_int + _handle_BYSETPOS = _handle_int_list + _handle_BYMONTH = _handle_int_list + _handle_BYMONTHDAY = _handle_int_list + _handle_BYYEARDAY = _handle_int_list + _handle_BYEASTER = _handle_int_list + _handle_BYWEEKNO = _handle_int_list + _handle_BYHOUR = _handle_int_list + _handle_BYMINUTE = _handle_int_list + _handle_BYSECOND = _handle_int_list + + def _handle_FREQ(self, rrkwargs, name, value, **kwargs): + 
rrkwargs["freq"] = self._freq_map[value] + + def _handle_UNTIL(self, rrkwargs, name, value, **kwargs): + global parser + if not parser: + from dateutil import parser + try: + rrkwargs["until"] = parser.parse(value, + ignoretz=kwargs.get("ignoretz"), + tzinfos=kwargs.get("tzinfos")) + except ValueError: + raise ValueError("invalid until date") + + def _handle_WKST(self, rrkwargs, name, value, **kwargs): + rrkwargs["wkst"] = self._weekday_map[value] + + def _handle_BYWEEKDAY(self, rrkwargs, name, value, **kwargs): + """ + Two ways to specify this: +1MO or MO(+1) + """ + l = [] + for wday in value.split(','): + if '(' in wday: + # If it's of the form TH(+1), etc. + splt = wday.split('(') + w = splt[0] + n = int(splt[1][:-1]) + elif len(wday): + # If it's of the form +1MO + for i in range(len(wday)): + if wday[i] not in '+-0123456789': + break + n = wday[:i] or None + w = wday[i:] + if n: + n = int(n) + else: + raise ValueError("Invalid (empty) BYDAY specification.") + + l.append(weekdays[self._weekday_map[w]](n)) + rrkwargs["byweekday"] = l + + _handle_BYDAY = _handle_BYWEEKDAY + + def _parse_rfc_rrule(self, line, + dtstart=None, + cache=False, + ignoretz=False, + tzinfos=None): + if line.find(':') != -1: + name, value = line.split(':') + if name != "RRULE": + raise ValueError("unknown parameter name") + else: + value = line + rrkwargs = {} + for pair in value.split(';'): + name, value = pair.split('=') + name = name.upper() + value = value.upper() + try: + getattr(self, "_handle_"+name)(rrkwargs, name, value, + ignoretz=ignoretz, + tzinfos=tzinfos) + except AttributeError: + raise ValueError("unknown parameter '%s'" % name) + except (KeyError, ValueError): + raise ValueError("invalid '%s': %s" % (name, value)) + return rrule(dtstart=dtstart, cache=cache, **rrkwargs) + + def _parse_rfc(self, s, + dtstart=None, + cache=False, + unfold=False, + forceset=False, + compatible=False, + ignoretz=False, + tzinfos=None): + global parser + if compatible: + forceset = True + unfold = True + s = s.upper() + if not s.strip(): + raise ValueError("empty string") + if unfold: + lines = s.splitlines() + i = 0 + while i < len(lines): + line = lines[i].rstrip() + if not line: + del lines[i] + elif i > 0 and line[0] == " ": + lines[i-1] += line[1:] + del lines[i] + else: + i += 1 + else: + lines = s.split() + if (not forceset and len(lines) == 1 and (s.find(':') == -1 or + s.startswith('RRULE:'))): + return self._parse_rfc_rrule(lines[0], cache=cache, + dtstart=dtstart, ignoretz=ignoretz, + tzinfos=tzinfos) + else: + rrulevals = [] + rdatevals = [] + exrulevals = [] + exdatevals = [] + for line in lines: + if not line: + continue + if line.find(':') == -1: + name = "RRULE" + value = line + else: + name, value = line.split(':', 1) + parms = name.split(';') + if not parms: + raise ValueError("empty property name") + name = parms[0] + parms = parms[1:] + if name == "RRULE": + for parm in parms: + raise ValueError("unsupported RRULE parm: "+parm) + rrulevals.append(value) + elif name == "RDATE": + for parm in parms: + if parm != "VALUE=DATE-TIME": + raise ValueError("unsupported RDATE parm: "+parm) + rdatevals.append(value) + elif name == "EXRULE": + for parm in parms: + raise ValueError("unsupported EXRULE parm: "+parm) + exrulevals.append(value) + elif name == "EXDATE": + for parm in parms: + if parm != "VALUE=DATE-TIME": + raise ValueError("unsupported EXDATE parm: "+parm) + exdatevals.append(value) + elif name == "DTSTART": + for parm in parms: + raise ValueError("unsupported DTSTART parm: "+parm) + if not 
parser: + from dateutil import parser + dtstart = parser.parse(value, ignoretz=ignoretz, + tzinfos=tzinfos) + else: + raise ValueError("unsupported property: "+name) + if (forceset or len(rrulevals) > 1 or rdatevals + or exrulevals or exdatevals): + if not parser and (rdatevals or exdatevals): + from dateutil import parser + rset = rruleset(cache=cache) + for value in rrulevals: + rset.rrule(self._parse_rfc_rrule(value, dtstart=dtstart, + ignoretz=ignoretz, + tzinfos=tzinfos)) + for value in rdatevals: + for datestr in value.split(','): + rset.rdate(parser.parse(datestr, + ignoretz=ignoretz, + tzinfos=tzinfos)) + for value in exrulevals: + rset.exrule(self._parse_rfc_rrule(value, dtstart=dtstart, + ignoretz=ignoretz, + tzinfos=tzinfos)) + for value in exdatevals: + for datestr in value.split(','): + rset.exdate(parser.parse(datestr, + ignoretz=ignoretz, + tzinfos=tzinfos)) + if compatible and dtstart: + rset.rdate(dtstart) + return rset + else: + return self._parse_rfc_rrule(rrulevals[0], + dtstart=dtstart, + cache=cache, + ignoretz=ignoretz, + tzinfos=tzinfos) + + def __call__(self, s, **kwargs): + return self._parse_rfc(s, **kwargs) + + +rrulestr = _rrulestr() + +# vim:ts=4:sw=4:et diff --git a/python/dateutil/tz/__init__.py b/python/dateutil/tz/__init__.py new file mode 100644 index 0000000..b0a5043 --- /dev/null +++ b/python/dateutil/tz/__init__.py @@ -0,0 +1,5 @@ +from .tz import * + +__all__ = ["tzutc", "tzoffset", "tzlocal", "tzfile", "tzrange", + "tzstr", "tzical", "tzwin", "tzwinlocal", "gettz", + "enfold", "datetime_ambiguous", "datetime_exists"] diff --git a/python/dateutil/tz/_common.py b/python/dateutil/tz/_common.py new file mode 100644 index 0000000..f1cf2af --- /dev/null +++ b/python/dateutil/tz/_common.py @@ -0,0 +1,394 @@ +from six import PY3 + +from functools import wraps + +from datetime import datetime, timedelta, tzinfo + + +ZERO = timedelta(0) + +__all__ = ['tzname_in_python2', 'enfold'] + + +def tzname_in_python2(namefunc): + """Change unicode output into bytestrings in Python 2 + + tzname() API changed in Python 3. It used to return bytes, but was changed + to unicode strings + """ + def adjust_encoding(*args, **kwargs): + name = namefunc(*args, **kwargs) + if name is not None and not PY3: + name = name.encode() + + return name + + return adjust_encoding + + +# The following is adapted from Alexander Belopolsky's tz library +# https://github.com/abalkin/tz +if hasattr(datetime, 'fold'): + # This is the pre-python 3.6 fold situation + def enfold(dt, fold=1): + """ + Provides a unified interface for assigning the ``fold`` attribute to + datetimes both before and after the implementation of PEP-495. + + :param fold: + The value for the ``fold`` attribute in the returned datetime. This + should be either 0 or 1. + + :return: + Returns an object for which ``getattr(dt, 'fold', 0)`` returns + ``fold`` for all versions of Python. In versions prior to + Python 3.6, this is a ``_DatetimeWithFold`` object, which is a + subclass of :py:class:`datetime.datetime` with the ``fold`` + attribute added, if ``fold`` is 1. + + .. versionadded:: 2.6.0 + """ + return dt.replace(fold=fold) + +else: + class _DatetimeWithFold(datetime): + """ + This is a class designed to provide a PEP 495-compliant interface for + Python versions before 3.6. It is used only for dates in a fold, so + the ``fold`` attribute is fixed at ``1``. + + .. 
versionadded:: 2.6.0 + """ + __slots__ = () + + @property + def fold(self): + return 1 + + def enfold(dt, fold=1): + """ + Provides a unified interface for assigning the ``fold`` attribute to + datetimes both before and after the implementation of PEP-495. + + :param fold: + The value for the ``fold`` attribute in the returned datetime. This + should be either 0 or 1. + + :return: + Returns an object for which ``getattr(dt, 'fold', 0)`` returns + ``fold`` for all versions of Python. In versions prior to + Python 3.6, this is a ``_DatetimeWithFold`` object, which is a + subclass of :py:class:`datetime.datetime` with the ``fold`` + attribute added, if ``fold`` is 1. + + .. versionadded:: 2.6.0 + """ + if getattr(dt, 'fold', 0) == fold: + return dt + + args = dt.timetuple()[:6] + args += (dt.microsecond, dt.tzinfo) + + if fold: + return _DatetimeWithFold(*args) + else: + return datetime(*args) + + +def _validate_fromutc_inputs(f): + """ + The CPython version of ``fromutc`` checks that the input is a ``datetime`` + object and that ``self`` is attached as its ``tzinfo``. + """ + @wraps(f) + def fromutc(self, dt): + if not isinstance(dt, datetime): + raise TypeError("fromutc() requires a datetime argument") + if dt.tzinfo is not self: + raise ValueError("dt.tzinfo is not self") + + return f(self, dt) + + return fromutc + + +class _tzinfo(tzinfo): + """ + Base class for all ``dateutil`` ``tzinfo`` objects. + """ + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + + dt = dt.replace(tzinfo=self) + + wall_0 = enfold(dt, fold=0) + wall_1 = enfold(dt, fold=1) + + same_offset = wall_0.utcoffset() == wall_1.utcoffset() + same_dt = wall_0.replace(tzinfo=None) == wall_1.replace(tzinfo=None) + + return same_dt and not same_offset + + def _fold_status(self, dt_utc, dt_wall): + """ + Determine the fold status of a "wall" datetime, given a representation + of the same datetime as a (naive) UTC datetime. This is calculated based + on the assumption that ``dt.utcoffset() - dt.dst()`` is constant for all + datetimes, and that this offset is the actual number of hours separating + ``dt_utc`` and ``dt_wall``. + + :param dt_utc: + Representation of the datetime as UTC + + :param dt_wall: + Representation of the datetime as "wall time". This parameter must + either have a `fold` attribute or have a fold-naive + :class:`datetime.tzinfo` attached, otherwise the calculation may + fail. + """ + if self.is_ambiguous(dt_wall): + delta_wall = dt_wall - dt_utc + _fold = int(delta_wall == (dt_utc.utcoffset() - dt_utc.dst())) + else: + _fold = 0 + + return _fold + + def _fold(self, dt): + return getattr(dt, 'fold', 0) + + def _fromutc(self, dt): + """ + Given a timezone-aware datetime in a given timezone, calculates a + timezone-aware datetime in a new timezone. + + Since this is the one time that we *know* we have an unambiguous + datetime object, we take this opportunity to determine whether the + datetime is ambiguous and in a "fold" state (e.g. if it's the first + occurence, chronologically, of the ambiguous datetime). + + :param dt: + A timezone-aware :class:`datetime.datetime` object. 
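+
+        :return:
+            A :class:`datetime.datetime` expressing the same instant as
+            ``dt`` in this zone's wall time; fold handling is layered on top
+            by the public ``fromutc`` wrapper.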
+ """ + + # Re-implement the algorithm from Python's datetime.py + dtoff = dt.utcoffset() + if dtoff is None: + raise ValueError("fromutc() requires a non-None utcoffset() " + "result") + + # The original datetime.py code assumes that `dst()` defaults to + # zero during ambiguous times. PEP 495 inverts this presumption, so + # for pre-PEP 495 versions of python, we need to tweak the algorithm. + dtdst = dt.dst() + if dtdst is None: + raise ValueError("fromutc() requires a non-None dst() result") + delta = dtoff - dtdst + + dt += delta + # Set fold=1 so we can default to being in the fold for + # ambiguous dates. + dtdst = enfold(dt, fold=1).dst() + if dtdst is None: + raise ValueError("fromutc(): dt.dst gave inconsistent " + "results; cannot convert") + return dt + dtdst + + @_validate_fromutc_inputs + def fromutc(self, dt): + """ + Given a timezone-aware datetime in a given timezone, calculates a + timezone-aware datetime in a new timezone. + + Since this is the one time that we *know* we have an unambiguous + datetime object, we take this opportunity to determine whether the + datetime is ambiguous and in a "fold" state (e.g. if it's the first + occurance, chronologically, of the ambiguous datetime). + + :param dt: + A timezone-aware :class:`datetime.datetime` object. + """ + dt_wall = self._fromutc(dt) + + # Calculate the fold status given the two datetimes. + _fold = self._fold_status(dt, dt_wall) + + # Set the default fold value for ambiguous dates + return enfold(dt_wall, fold=_fold) + + +class tzrangebase(_tzinfo): + """ + This is an abstract base class for time zones represented by an annual + transition into and out of DST. Child classes should implement the following + methods: + + * ``__init__(self, *args, **kwargs)`` + * ``transitions(self, year)`` - this is expected to return a tuple of + datetimes representing the DST on and off transitions in standard + time. + + A fully initialized ``tzrangebase`` subclass should also provide the + following attributes: + * ``hasdst``: Boolean whether or not the zone uses DST. + * ``_dst_offset`` / ``_std_offset``: :class:`datetime.timedelta` objects + representing the respective UTC offsets. + * ``_dst_abbr`` / ``_std_abbr``: Strings representing the timezone short + abbreviations in DST and STD, respectively. + * ``_hasdst``: Whether or not the zone has DST. + + .. 
versionadded:: 2.6.0 + """ + def __init__(self): + raise NotImplementedError('tzrangebase is an abstract base class') + + def utcoffset(self, dt): + isdst = self._isdst(dt) + + if isdst is None: + return None + elif isdst: + return self._dst_offset + else: + return self._std_offset + + def dst(self, dt): + isdst = self._isdst(dt) + + if isdst is None: + return None + elif isdst: + return self._dst_base_offset + else: + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + if self._isdst(dt): + return self._dst_abbr + else: + return self._std_abbr + + def fromutc(self, dt): + """ Given a datetime in UTC, return local time """ + if not isinstance(dt, datetime): + raise TypeError("fromutc() requires a datetime argument") + + if dt.tzinfo is not self: + raise ValueError("dt.tzinfo is not self") + + # Get transitions - if there are none, fixed offset + transitions = self.transitions(dt.year) + if transitions is None: + return dt + self.utcoffset(dt) + + # Get the transition times in UTC + dston, dstoff = transitions + + dston -= self._std_offset + dstoff -= self._std_offset + + utc_transitions = (dston, dstoff) + dt_utc = dt.replace(tzinfo=None) + + isdst = self._naive_isdst(dt_utc, utc_transitions) + + if isdst: + dt_wall = dt + self._dst_offset + else: + dt_wall = dt + self._std_offset + + _fold = int(not isdst and self.is_ambiguous(dt_wall)) + + return enfold(dt_wall, fold=_fold) + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + if not self.hasdst: + return False + + start, end = self.transitions(dt.year) + + dt = dt.replace(tzinfo=None) + return (end <= dt < end + self._dst_base_offset) + + def _isdst(self, dt): + if not self.hasdst: + return False + elif dt is None: + return None + + transitions = self.transitions(dt.year) + + if transitions is None: + return False + + dt = dt.replace(tzinfo=None) + + isdst = self._naive_isdst(dt, transitions) + + # Handle ambiguous dates + if not isdst and self.is_ambiguous(dt): + return not self._fold(dt) + else: + return isdst + + def _naive_isdst(self, dt, transitions): + dston, dstoff = transitions + + dt = dt.replace(tzinfo=None) + + if dston < dstoff: + isdst = dston <= dt < dstoff + else: + isdst = not dstoff <= dt < dston + + return isdst + + @property + def _dst_base_offset(self): + return self._dst_offset - self._std_offset + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s(...)" % self.__class__.__name__ + + __reduce__ = object.__reduce__ + + +def _total_seconds(td): + # Python 2.6 doesn't have a total_seconds() method on timedelta objects + return ((td.seconds + td.days * 86400) * 1000000 + + td.microseconds) // 1000000 + + +_total_seconds = getattr(timedelta, 'total_seconds', _total_seconds) diff --git a/python/dateutil/tz/tz.py b/python/dateutil/tz/tz.py new file mode 100644 index 0000000..9468282 --- /dev/null +++ b/python/dateutil/tz/tz.py @@ -0,0 +1,1511 @@ +# -*- coding: utf-8 -*- +""" +This module offers timezone implementations subclassing the abstract +:py:`datetime.tzinfo` type. 
There are classes to handle tzfile format files +(usually are in :file:`/etc/localtime`, :file:`/usr/share/zoneinfo`, etc), TZ +environment string (in all known formats), given ranges (with help from +relative deltas), local machine timezone, fixed offset timezone, and UTC +timezone. +""" +import datetime +import struct +import time +import sys +import os +import bisect + +from six import string_types +from ._common import tzname_in_python2, _tzinfo, _total_seconds +from ._common import tzrangebase, enfold +from ._common import _validate_fromutc_inputs + +try: + from .win import tzwin, tzwinlocal +except ImportError: + tzwin = tzwinlocal = None + +ZERO = datetime.timedelta(0) +EPOCH = datetime.datetime.utcfromtimestamp(0) +EPOCHORDINAL = EPOCH.toordinal() + + +class tzutc(datetime.tzinfo): + """ + This is a tzinfo object that represents the UTC time zone. + """ + def utcoffset(self, dt): + return ZERO + + def dst(self, dt): + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + return "UTC" + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + return False + + @_validate_fromutc_inputs + def fromutc(self, dt): + """ + Fast track version of fromutc() returns the original ``dt`` object for + any valid :py:class:`datetime.datetime` object. + """ + return dt + + def __eq__(self, other): + if not isinstance(other, (tzutc, tzoffset)): + return NotImplemented + + return (isinstance(other, tzutc) or + (isinstance(other, tzoffset) and other._offset == ZERO)) + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s()" % self.__class__.__name__ + + __reduce__ = object.__reduce__ + + +class tzoffset(datetime.tzinfo): + """ + A simple class for representing a fixed offset from UTC. + + :param name: + The timezone name, to be returned when ``tzname()`` is called. + + :param offset: + The time zone offset in seconds, or (since version 2.6.0, represented + as a :py:class:`datetime.timedelta` object. + """ + def __init__(self, name, offset): + self._name = name + + try: + # Allow a timedelta + offset = _total_seconds(offset) + except (TypeError, AttributeError): + pass + self._offset = datetime.timedelta(seconds=offset) + + def utcoffset(self, dt): + return self._offset + + def dst(self, dt): + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + return self._name + + @_validate_fromutc_inputs + def fromutc(self, dt): + return dt + self._offset + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + return False + + def __eq__(self, other): + if not isinstance(other, tzoffset): + return NotImplemented + + return self._offset == other._offset + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s(%s, %s)" % (self.__class__.__name__, + repr(self._name), + int(_total_seconds(self._offset))) + + __reduce__ = object.__reduce__ + + +class tzlocal(_tzinfo): + """ + A :class:`tzinfo` subclass built around the ``time`` timezone functions. 
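+
+    A usage sketch (the result shown assumes a machine configured for US
+    Eastern time and is illustrative only):
+
+        >>> from datetime import datetime
+        >>> from dateutil.tz import tzlocal, tzutc
+        >>> datetime(2016, 3, 13, 5, tzinfo=tzutc()).astimezone(tzlocal())  # doctest: +SKIP
+        datetime.datetime(2016, 3, 13, 0, 0, tzinfo=tzlocal())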
+ """ + def __init__(self): + super(tzlocal, self).__init__() + + self._std_offset = datetime.timedelta(seconds=-time.timezone) + if time.daylight: + self._dst_offset = datetime.timedelta(seconds=-time.altzone) + else: + self._dst_offset = self._std_offset + + self._dst_saved = self._dst_offset - self._std_offset + self._hasdst = bool(self._dst_saved) + + def utcoffset(self, dt): + if dt is None and self._hasdst: + return None + + if self._isdst(dt): + return self._dst_offset + else: + return self._std_offset + + def dst(self, dt): + if dt is None and self._hasdst: + return None + + if self._isdst(dt): + return self._dst_offset - self._std_offset + else: + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + return time.tzname[self._isdst(dt)] + + def is_ambiguous(self, dt): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + naive_dst = self._naive_is_dst(dt) + return (not naive_dst and + (naive_dst != self._naive_is_dst(dt - self._dst_saved))) + + def _naive_is_dst(self, dt): + timestamp = _datetime_to_timestamp(dt) + return time.localtime(timestamp + time.timezone).tm_isdst + + def _isdst(self, dt, fold_naive=True): + # We can't use mktime here. It is unstable when deciding if + # the hour near to a change is DST or not. + # + # timestamp = time.mktime((dt.year, dt.month, dt.day, dt.hour, + # dt.minute, dt.second, dt.weekday(), 0, -1)) + # return time.localtime(timestamp).tm_isdst + # + # The code above yields the following result: + # + # >>> import tz, datetime + # >>> t = tz.tzlocal() + # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() + # 'BRDT' + # >>> datetime.datetime(2003,2,16,0,tzinfo=t).tzname() + # 'BRST' + # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() + # 'BRST' + # >>> datetime.datetime(2003,2,15,22,tzinfo=t).tzname() + # 'BRDT' + # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() + # 'BRDT' + # + # Here is a more stable implementation: + # + if not self._hasdst: + return False + + # Check for ambiguous times: + dstval = self._naive_is_dst(dt) + fold = getattr(dt, 'fold', None) + + if self.is_ambiguous(dt): + if fold is not None: + return not self._fold(dt) + else: + return True + + return dstval + + def __eq__(self, other): + if not isinstance(other, tzlocal): + return NotImplemented + + return (self._std_offset == other._std_offset and + self._dst_offset == other._dst_offset) + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s()" % self.__class__.__name__ + + __reduce__ = object.__reduce__ + + +class _ttinfo(object): + __slots__ = ["offset", "delta", "isdst", "abbr", + "isstd", "isgmt", "dstoffset"] + + def __init__(self): + for attr in self.__slots__: + setattr(self, attr, None) + + def __repr__(self): + l = [] + for attr in self.__slots__: + value = getattr(self, attr) + if value is not None: + l.append("%s=%s" % (attr, repr(value))) + return "%s(%s)" % (self.__class__.__name__, ", ".join(l)) + + def __eq__(self, other): + if not isinstance(other, _ttinfo): + return NotImplemented + + return (self.offset == other.offset and + self.delta == other.delta and + self.isdst == other.isdst and + self.abbr == other.abbr and + self.isstd == other.isstd and + self.isgmt == other.isgmt and + self.dstoffset == other.dstoffset) + + __hash__ = None + + def __ne__(self, other): + 
return not (self == other) + + def __getstate__(self): + state = {} + for name in self.__slots__: + state[name] = getattr(self, name, None) + return state + + def __setstate__(self, state): + for name in self.__slots__: + if name in state: + setattr(self, name, state[name]) + + +class _tzfile(object): + """ + Lightweight class for holding the relevant transition and time zone + information read from binary tzfiles. + """ + attrs = ['trans_list', 'trans_list_utc', 'trans_idx', 'ttinfo_list', + 'ttinfo_std', 'ttinfo_dst', 'ttinfo_before', 'ttinfo_first'] + + def __init__(self, **kwargs): + for attr in self.attrs: + setattr(self, attr, kwargs.get(attr, None)) + + +class tzfile(_tzinfo): + """ + This is a ``tzinfo`` subclass thant allows one to use the ``tzfile(5)`` + format timezone files to extract current and historical zone information. + + :param fileobj: + This can be an opened file stream or a file name that the time zone + information can be read from. + + :param filename: + This is an optional parameter specifying the source of the time zone + information in the event that ``fileobj`` is a file object. If omitted + and ``fileobj`` is a file stream, this parameter will be set either to + ``fileobj``'s ``name`` attribute or to ``repr(fileobj)``. + + See `Sources for Time Zone and Daylight Saving Time Data + <http://www.twinsun.com/tz/tz-link.htm>`_ for more information. Time zone + files can be compiled from the `IANA Time Zone database files + <https://www.iana.org/time-zones>`_ with the `zic time zone compiler + <https://www.freebsd.org/cgi/man.cgi?query=zic&sektion=8>`_ + """ + + def __init__(self, fileobj, filename=None): + super(tzfile, self).__init__() + + file_opened_here = False + if isinstance(fileobj, string_types): + self._filename = fileobj + fileobj = open(fileobj, 'rb') + file_opened_here = True + elif filename is not None: + self._filename = filename + elif hasattr(fileobj, "name"): + self._filename = fileobj.name + else: + self._filename = repr(fileobj) + + if fileobj is not None: + if not file_opened_here: + fileobj = _ContextWrapper(fileobj) + + with fileobj as file_stream: + tzobj = self._read_tzfile(file_stream) + + self._set_tzdata(tzobj) + + def _set_tzdata(self, tzobj): + """ Set the time zone data of this object from a _tzfile object """ + # Copy the relevant attributes over as private attributes + for attr in _tzfile.attrs: + setattr(self, '_' + attr, getattr(tzobj, attr)) + + def _read_tzfile(self, fileobj): + out = _tzfile() + + # From tzfile(5): + # + # The time zone information files used by tzset(3) + # begin with the magic characters "TZif" to identify + # them as time zone information files, followed by + # sixteen bytes reserved for future use, followed by + # six four-byte values of type long, written in a + # ``standard'' byte order (the high-order byte + # of the value is written first). + if fileobj.read(4).decode() != "TZif": + raise ValueError("magic not found") + + fileobj.read(16) + + ( + # The number of UTC/local indicators stored in the file. + ttisgmtcnt, + + # The number of standard/wall indicators stored in the file. + ttisstdcnt, + + # The number of leap seconds for which data is + # stored in the file. + leapcnt, + + # The number of "transition times" for which data + # is stored in the file. + timecnt, + + # The number of "local time types" for which data + # is stored in the file (must not be zero). + typecnt, + + # The number of characters of "time zone + # abbreviation strings" stored in the file. 
+ charcnt, + + ) = struct.unpack(">6l", fileobj.read(24)) + + # The above header is followed by tzh_timecnt four-byte + # values of type long, sorted in ascending order. + # These values are written in ``standard'' byte order. + # Each is used as a transition time (as returned by + # time(2)) at which the rules for computing local time + # change. + + if timecnt: + out.trans_list_utc = list(struct.unpack(">%dl" % timecnt, + fileobj.read(timecnt*4))) + else: + out.trans_list_utc = [] + + # Next come tzh_timecnt one-byte values of type unsigned + # char; each one tells which of the different types of + # ``local time'' types described in the file is associated + # with the same-indexed transition time. These values + # serve as indices into an array of ttinfo structures that + # appears next in the file. + + if timecnt: + out.trans_idx = struct.unpack(">%dB" % timecnt, + fileobj.read(timecnt)) + else: + out.trans_idx = [] + + # Each ttinfo structure is written as a four-byte value + # for tt_gmtoff of type long, in a standard byte + # order, followed by a one-byte value for tt_isdst + # and a one-byte value for tt_abbrind. In each + # structure, tt_gmtoff gives the number of + # seconds to be added to UTC, tt_isdst tells whether + # tm_isdst should be set by localtime(3), and + # tt_abbrind serves as an index into the array of + # time zone abbreviation characters that follow the + # ttinfo structure(s) in the file. + + ttinfo = [] + + for i in range(typecnt): + ttinfo.append(struct.unpack(">lbb", fileobj.read(6))) + + abbr = fileobj.read(charcnt).decode() + + # Then there are tzh_leapcnt pairs of four-byte + # values, written in standard byte order; the + # first value of each pair gives the time (as + # returned by time(2)) at which a leap second + # occurs; the second gives the total number of + # leap seconds to be applied after the given time. + # The pairs of values are sorted in ascending order + # by time. + + # Not used, for now (but seek for correct file position) + if leapcnt: + fileobj.seek(leapcnt * 8, os.SEEK_CUR) + + # Then there are tzh_ttisstdcnt standard/wall + # indicators, each stored as a one-byte value; + # they tell whether the transition times associated + # with local time types were specified as standard + # time or wall clock time, and are used when + # a time zone file is used in handling POSIX-style + # time zone environment variables. + + if ttisstdcnt: + isstd = struct.unpack(">%db" % ttisstdcnt, + fileobj.read(ttisstdcnt)) + + # Finally, there are tzh_ttisgmtcnt UTC/local + # indicators, each stored as a one-byte value; + # they tell whether the transition times associated + # with local time types were specified as UTC or + # local time, and are used when a time zone file + # is used in handling POSIX-style time zone envi- + # ronment variables. + + if ttisgmtcnt: + isgmt = struct.unpack(">%db" % ttisgmtcnt, + fileobj.read(ttisgmtcnt)) + + # Build ttinfo list + out.ttinfo_list = [] + for i in range(typecnt): + gmtoff, isdst, abbrind = ttinfo[i] + # Round to full-minutes if that's not the case. Python's + # datetime doesn't accept sub-minute timezones. Check + # http://python.org/sf/1447945 for some information. 
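+            # A worked example with a hypothetical historical LMT offset:
+            # -4762 seconds rounds to -79 minutes, since
+            # 60 * ((-4762 + 30) // 60) == -4740.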
+ gmtoff = 60 * ((gmtoff + 30) // 60) + tti = _ttinfo() + tti.offset = gmtoff + tti.dstoffset = datetime.timedelta(0) + tti.delta = datetime.timedelta(seconds=gmtoff) + tti.isdst = isdst + tti.abbr = abbr[abbrind:abbr.find('\x00', abbrind)] + tti.isstd = (ttisstdcnt > i and isstd[i] != 0) + tti.isgmt = (ttisgmtcnt > i and isgmt[i] != 0) + out.ttinfo_list.append(tti) + + # Replace ttinfo indexes for ttinfo objects. + out.trans_idx = [out.ttinfo_list[idx] for idx in out.trans_idx] + + # Set standard, dst, and before ttinfos. before will be + # used when a given time is before any transitions, + # and will be set to the first non-dst ttinfo, or to + # the first dst, if all of them are dst. + out.ttinfo_std = None + out.ttinfo_dst = None + out.ttinfo_before = None + if out.ttinfo_list: + if not out.trans_list_utc: + out.ttinfo_std = out.ttinfo_first = out.ttinfo_list[0] + else: + for i in range(timecnt-1, -1, -1): + tti = out.trans_idx[i] + if not out.ttinfo_std and not tti.isdst: + out.ttinfo_std = tti + elif not out.ttinfo_dst and tti.isdst: + out.ttinfo_dst = tti + + if out.ttinfo_std and out.ttinfo_dst: + break + else: + if out.ttinfo_dst and not out.ttinfo_std: + out.ttinfo_std = out.ttinfo_dst + + for tti in out.ttinfo_list: + if not tti.isdst: + out.ttinfo_before = tti + break + else: + out.ttinfo_before = out.ttinfo_list[0] + + # Now fix transition times to become relative to wall time. + # + # I'm not sure about this. In my tests, the tz source file + # is setup to wall time, and in the binary file isstd and + # isgmt are off, so it should be in wall time. OTOH, it's + # always in gmt time. Let me know if you have comments + # about this. + laststdoffset = None + out.trans_list = [] + for i, tti in enumerate(out.trans_idx): + if not tti.isdst: + offset = tti.offset + laststdoffset = offset + else: + if laststdoffset is not None: + # Store the DST offset as well and update it in the list + tti.dstoffset = tti.offset - laststdoffset + out.trans_idx[i] = tti + + offset = laststdoffset or 0 + + out.trans_list.append(out.trans_list_utc[i] + offset) + + # In case we missed any DST offsets on the way in for some reason, make + # a second pass over the list, looking for the /next/ DST offset. + laststdoffset = None + for i in reversed(range(len(out.trans_idx))): + tti = out.trans_idx[i] + if tti.isdst: + if not (tti.dstoffset or laststdoffset is None): + tti.dstoffset = tti.offset - laststdoffset + else: + laststdoffset = tti.offset + + if not isinstance(tti.dstoffset, datetime.timedelta): + tti.dstoffset = datetime.timedelta(seconds=tti.dstoffset) + + out.trans_idx[i] = tti + + out.trans_idx = tuple(out.trans_idx) + out.trans_list = tuple(out.trans_list) + out.trans_list_utc = tuple(out.trans_list_utc) + + return out + + def _find_last_transition(self, dt, in_utc=False): + # If there's no list, there are no transitions to find + if not self._trans_list: + return None + + timestamp = _datetime_to_timestamp(dt) + + # Find where the timestamp fits in the transition list - if the + # timestamp is a transition time, it's part of the "after" period. 
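+        # bisect_right gives the "after" behaviour: e.g.
+        # bisect.bisect_right([t0, t1], t1) == 2, so a timestamp equal to a
+        # transition time is governed by the rules that follow it.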
+ trans_list = self._trans_list_utc if in_utc else self._trans_list + idx = bisect.bisect_right(trans_list, timestamp) + + # We want to know when the previous transition was, so subtract off 1 + return idx - 1 + + def _get_ttinfo(self, idx): + # For no list or after the last transition, default to _ttinfo_std + if idx is None or (idx + 1) >= len(self._trans_list): + return self._ttinfo_std + + # If there is a list and the time is before it, return _ttinfo_before + if idx < 0: + return self._ttinfo_before + + return self._trans_idx[idx] + + def _find_ttinfo(self, dt): + idx = self._resolve_ambiguous_time(dt) + + return self._get_ttinfo(idx) + + def fromutc(self, dt): + """ + The ``tzfile`` implementation of :py:func:`datetime.tzinfo.fromutc`. + + :param dt: + A :py:class:`datetime.datetime` object. + + :raises TypeError: + Raised if ``dt`` is not a :py:class:`datetime.datetime` object. + + :raises ValueError: + Raised if this is called with a ``dt`` which does not have this + ``tzinfo`` attached. + + :return: + Returns a :py:class:`datetime.datetime` object representing the + wall time in ``self``'s time zone. + """ + # These isinstance checks are in datetime.tzinfo, so we'll preserve + # them, even if we don't care about duck typing. + if not isinstance(dt, datetime.datetime): + raise TypeError("fromutc() requires a datetime argument") + + if dt.tzinfo is not self: + raise ValueError("dt.tzinfo is not self") + + # First treat UTC as wall time and get the transition we're in. + idx = self._find_last_transition(dt, in_utc=True) + tti = self._get_ttinfo(idx) + + dt_out = dt + datetime.timedelta(seconds=tti.offset) + + fold = self.is_ambiguous(dt_out, idx=idx) + + return enfold(dt_out, fold=int(fold)) + + def is_ambiguous(self, dt, idx=None): + """ + Whether or not the "wall time" of a given datetime is ambiguous in this + zone. + + :param dt: + A :py:class:`datetime.datetime`, naive or time zone aware. + + + :return: + Returns ``True`` if ambiguous, ``False`` otherwise. + + .. versionadded:: 2.6.0 + """ + if idx is None: + idx = self._find_last_transition(dt) + + # Calculate the difference in offsets from current to previous + timestamp = _datetime_to_timestamp(dt) + tti = self._get_ttinfo(idx) + + if idx is None or idx <= 0: + return False + + od = self._get_ttinfo(idx - 1).offset - tti.offset + tt = self._trans_list[idx] # Transition time + + return timestamp < tt + od + + def _resolve_ambiguous_time(self, dt): + idx = self._find_last_transition(dt) + + # If we have no transitions, return the index + _fold = self._fold(dt) + if idx is None or idx == 0: + return idx + + # If it's ambiguous and we're in a fold, shift to a different index. + idx_offset = int(not _fold and self.is_ambiguous(dt, idx)) + + return idx - idx_offset + + def utcoffset(self, dt): + if dt is None: + return None + + if not self._ttinfo_std: + return ZERO + + return self._find_ttinfo(dt).delta + + def dst(self, dt): + if dt is None: + return None + + if not self._ttinfo_dst: + return ZERO + + tti = self._find_ttinfo(dt) + + if not tti.isdst: + return ZERO + + # The documentation says that utcoffset()-dst() must + # be constant for every dt. 
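+        # dstoffset was computed when the file was read, as the offset minus
+        # the preceding standard offset, so returning it here (rather than
+        # utcoffset()) keeps utcoffset() - dst() constant across transitions.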
+ return tti.dstoffset + + @tzname_in_python2 + def tzname(self, dt): + if not self._ttinfo_std or dt is None: + return None + return self._find_ttinfo(dt).abbr + + def __eq__(self, other): + if not isinstance(other, tzfile): + return NotImplemented + return (self._trans_list == other._trans_list and + self._trans_idx == other._trans_idx and + self._ttinfo_list == other._ttinfo_list) + + __hash__ = None + + def __ne__(self, other): + return not (self == other) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, repr(self._filename)) + + def __reduce__(self): + return self.__reduce_ex__(None) + + def __reduce_ex__(self, protocol): + return (self.__class__, (None, self._filename), self.__dict__) + + +class tzrange(tzrangebase): + """ + The ``tzrange`` object is a time zone specified by a set of offsets and + abbreviations, equivalent to the way the ``TZ`` variable can be specified + in POSIX-like systems, but using Python delta objects to specify DST + start, end and offsets. + + :param stdabbr: + The abbreviation for standard time (e.g. ``'EST'``). + + :param stdoffset: + An integer or :class:`datetime.timedelta` object or equivalent + specifying the base offset from UTC. + + If unspecified, +00:00 is used. + + :param dstabbr: + The abbreviation for DST / "Summer" time (e.g. ``'EDT'``). + + If specified, with no other DST information, DST is assumed to occur + and the default behavior or ``dstoffset``, ``start`` and ``end`` is + used. If unspecified and no other DST information is specified, it + is assumed that this zone has no DST. + + If this is unspecified and other DST information is *is* specified, + DST occurs in the zone but the time zone abbreviation is left + unchanged. + + :param dstoffset: + A an integer or :class:`datetime.timedelta` object or equivalent + specifying the UTC offset during DST. If unspecified and any other DST + information is specified, it is assumed to be the STD offset +1 hour. + + :param start: + A :class:`relativedelta.relativedelta` object or equivalent specifying + the time and time of year that daylight savings time starts. To specify, + for example, that DST starts at 2AM on the 2nd Sunday in March, pass: + + ``relativedelta(hours=2, month=3, day=1, weekday=SU(+2))`` + + If unspecified and any other DST information is specified, the default + value is 2 AM on the first Sunday in April. + + :param end: + A :class:`relativedelta.relativedelta` object or equivalent representing + the time and time of year that daylight savings time ends, with the + same specification method as in ``start``. One note is that this should + point to the first time in the *standard* zone, so if a transition + occurs at 2AM in the DST zone and the clocks are set back 1 hour to 1AM, + set the `hours` parameter to +1. + + + **Examples:** + + .. testsetup:: tzrange + + from dateutil.tz import tzrange, tzstr + + .. doctest:: tzrange + + >>> tzstr('EST5EDT') == tzrange("EST", -18000, "EDT") + True + + >>> from dateutil.relativedelta import * + >>> range1 = tzrange("EST", -18000, "EDT") + >>> range2 = tzrange("EST", -18000, "EDT", -14400, + ... relativedelta(hours=+2, month=4, day=1, + ... weekday=SU(+1)), + ... relativedelta(hours=+1, month=10, day=31, + ... 
weekday=SU(-1))) + >>> tzstr('EST5EDT') == range1 == range2 + True + + """ + def __init__(self, stdabbr, stdoffset=None, + dstabbr=None, dstoffset=None, + start=None, end=None): + + global relativedelta + from dateutil import relativedelta + + self._std_abbr = stdabbr + self._dst_abbr = dstabbr + + try: + stdoffset = _total_seconds(stdoffset) + except (TypeError, AttributeError): + pass + + try: + dstoffset = _total_seconds(dstoffset) + except (TypeError, AttributeError): + pass + + if stdoffset is not None: + self._std_offset = datetime.timedelta(seconds=stdoffset) + else: + self._std_offset = ZERO + + if dstoffset is not None: + self._dst_offset = datetime.timedelta(seconds=dstoffset) + elif dstabbr and stdoffset is not None: + self._dst_offset = self._std_offset + datetime.timedelta(hours=+1) + else: + self._dst_offset = ZERO + + if dstabbr and start is None: + self._start_delta = relativedelta.relativedelta( + hours=+2, month=4, day=1, weekday=relativedelta.SU(+1)) + else: + self._start_delta = start + + if dstabbr and end is None: + self._end_delta = relativedelta.relativedelta( + hours=+1, month=10, day=31, weekday=relativedelta.SU(-1)) + else: + self._end_delta = end + + self._dst_base_offset_ = self._dst_offset - self._std_offset + self.hasdst = bool(self._start_delta) + + def transitions(self, year): + """ + For a given year, get the DST on and off transition times, expressed + always on the standard time side. For zones with no transitions, this + function returns ``None``. + + :param year: + The year whose transitions you would like to query. + + :return: + Returns a :class:`tuple` of :class:`datetime.datetime` objects, + ``(dston, dstoff)`` for zones with an annual DST transition, or + ``None`` for fixed offset zones. + """ + if not self.hasdst: + return None + + base_year = datetime.datetime(year, 1, 1) + + start = base_year + self._start_delta + end = base_year + self._end_delta + + return (start, end) + + def __eq__(self, other): + if not isinstance(other, tzrange): + return NotImplemented + + return (self._std_abbr == other._std_abbr and + self._dst_abbr == other._dst_abbr and + self._std_offset == other._std_offset and + self._dst_offset == other._dst_offset and + self._start_delta == other._start_delta and + self._end_delta == other._end_delta) + + @property + def _dst_base_offset(self): + return self._dst_base_offset_ + + +class tzstr(tzrange): + """ + ``tzstr`` objects are time zone objects specified by a time-zone string as + it would be passed to a ``TZ`` variable on POSIX-style systems (see + the `GNU C Library: TZ Variable`_ for more details). + + There is one notable exception, which is that POSIX-style time zones use an + inverted offset format, so normally ``GMT+3`` would be parsed as an offset + 3 hours *behind* GMT. The ``tzstr`` time zone object will parse this as an + offset 3 hours *ahead* of GMT. If you would like to maintain the POSIX + behavior, pass a ``True`` value to ``posix_offset``. + + The :class:`tzrange` object provides the same functionality, but is + specified using :class:`relativedelta.relativedelta` objects. rather than + strings. + + :param s: + A time zone string in ``TZ`` variable format. This can be a + :class:`bytes` (2.x: :class:`str`), :class:`str` (2.x: :class:`unicode`) + or a stream emitting unicode characters (e.g. :class:`StringIO`). + + :param posix_offset: + Optional. If set to ``True``, interpret strings such as ``GMT+3`` or + ``UTC+3`` as being 3 hours *behind* UTC rather than ahead, per the + POSIX standard. + + .. 
_`GNU C Library: TZ Variable`: + https://www.gnu.org/software/libc/manual/html_node/TZ-Variable.html + """ + def __init__(self, s, posix_offset=False): + global parser + from dateutil import parser + + self._s = s + + res = parser._parsetz(s) + if res is None: + raise ValueError("unknown string format") + + # Here we break the compatibility with the TZ variable handling. + # GMT-3 actually *means* the timezone -3. + if res.stdabbr in ("GMT", "UTC") and not posix_offset: + res.stdoffset *= -1 + + # We must initialize it first, since _delta() needs + # _std_offset and _dst_offset set. Use False in start/end + # to avoid building it two times. + tzrange.__init__(self, res.stdabbr, res.stdoffset, + res.dstabbr, res.dstoffset, + start=False, end=False) + + if not res.dstabbr: + self._start_delta = None + self._end_delta = None + else: + self._start_delta = self._delta(res.start) + if self._start_delta: + self._end_delta = self._delta(res.end, isend=1) + + self.hasdst = bool(self._start_delta) + + def _delta(self, x, isend=0): + from dateutil import relativedelta + kwargs = {} + if x.month is not None: + kwargs["month"] = x.month + if x.weekday is not None: + kwargs["weekday"] = relativedelta.weekday(x.weekday, x.week) + if x.week > 0: + kwargs["day"] = 1 + else: + kwargs["day"] = 31 + elif x.day: + kwargs["day"] = x.day + elif x.yday is not None: + kwargs["yearday"] = x.yday + elif x.jyday is not None: + kwargs["nlyearday"] = x.jyday + if not kwargs: + # Default is to start on first sunday of april, and end + # on last sunday of october. + if not isend: + kwargs["month"] = 4 + kwargs["day"] = 1 + kwargs["weekday"] = relativedelta.SU(+1) + else: + kwargs["month"] = 10 + kwargs["day"] = 31 + kwargs["weekday"] = relativedelta.SU(-1) + if x.time is not None: + kwargs["seconds"] = x.time + else: + # Default is 2AM. + kwargs["seconds"] = 7200 + if isend: + # Convert to standard time, to follow the documented way + # of working with the extra hour. See the documentation + # of the tzinfo class. + delta = self._dst_offset - self._std_offset + kwargs["seconds"] -= delta.seconds + delta.days * 86400 + return relativedelta.relativedelta(**kwargs) + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, repr(self._s)) + + +class _tzicalvtzcomp(object): + def __init__(self, tzoffsetfrom, tzoffsetto, isdst, + tzname=None, rrule=None): + self.tzoffsetfrom = datetime.timedelta(seconds=tzoffsetfrom) + self.tzoffsetto = datetime.timedelta(seconds=tzoffsetto) + self.tzoffsetdiff = self.tzoffsetto - self.tzoffsetfrom + self.isdst = isdst + self.tzname = tzname + self.rrule = rrule + + +class _tzicalvtz(_tzinfo): + def __init__(self, tzid, comps=[]): + super(_tzicalvtz, self).__init__() + + self._tzid = tzid + self._comps = comps + self._cachedate = [] + self._cachecomp = [] + + def _find_comp(self, dt): + if len(self._comps) == 1: + return self._comps[0] + + dt = dt.replace(tzinfo=None) + + try: + return self._cachecomp[self._cachedate.index((dt, self._fold(dt)))] + except ValueError: + pass + + lastcompdt = None + lastcomp = None + + for comp in self._comps: + compdt = self._find_compdt(comp, dt) + + if compdt and (not lastcompdt or lastcompdt < compdt): + lastcompdt = compdt + lastcomp = comp + + if not lastcomp: + # RFC says nothing about what to do when a given + # time is before the first onset date. We'll look for the + # first standard component, or the first component, if + # none is found. 
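+            # (The for/else below implements exactly that: the else clause
+            # only runs when the loop found no STANDARD component.)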
+ for comp in self._comps: + if not comp.isdst: + lastcomp = comp + break + else: + lastcomp = comp[0] + + self._cachedate.insert(0, (dt, self._fold(dt))) + self._cachecomp.insert(0, lastcomp) + + if len(self._cachedate) > 10: + self._cachedate.pop() + self._cachecomp.pop() + + return lastcomp + + def _find_compdt(self, comp, dt): + if comp.tzoffsetdiff < ZERO and self._fold(dt): + dt -= comp.tzoffsetdiff + + compdt = comp.rrule.before(dt, inc=True) + + return compdt + + def utcoffset(self, dt): + if dt is None: + return None + + return self._find_comp(dt).tzoffsetto + + def dst(self, dt): + comp = self._find_comp(dt) + if comp.isdst: + return comp.tzoffsetdiff + else: + return ZERO + + @tzname_in_python2 + def tzname(self, dt): + return self._find_comp(dt).tzname + + def __repr__(self): + return "<tzicalvtz %s>" % repr(self._tzid) + + __reduce__ = object.__reduce__ + + +class tzical(object): + """ + This object is designed to parse an iCalendar-style ``VTIMEZONE`` structure + as set out in `RFC 2445`_ Section 4.6.5 into one or more `tzinfo` objects. + + :param `fileobj`: + A file or stream in iCalendar format, which should be UTF-8 encoded + with CRLF endings. + + .. _`RFC 2445`: https://www.ietf.org/rfc/rfc2445.txt + """ + def __init__(self, fileobj): + global rrule + from dateutil import rrule + + if isinstance(fileobj, string_types): + self._s = fileobj + # ical should be encoded in UTF-8 with CRLF + fileobj = open(fileobj, 'r') + else: + self._s = getattr(fileobj, 'name', repr(fileobj)) + fileobj = _ContextWrapper(fileobj) + + self._vtz = {} + + with fileobj as fobj: + self._parse_rfc(fobj.read()) + + def keys(self): + """ + Retrieves the available time zones as a list. + """ + return list(self._vtz.keys()) + + def get(self, tzid=None): + """ + Retrieve a :py:class:`datetime.tzinfo` object by its ``tzid``. + + :param tzid: + If there is exactly one time zone available, omitting ``tzid`` + or passing :py:const:`None` value returns it. Otherwise a valid + key (which can be retrieved from :func:`keys`) is required. + + :raises ValueError: + Raised if ``tzid`` is not specified but there are either more + or fewer than 1 zone defined. + + :returns: + Returns either a :py:class:`datetime.tzinfo` object representing + the relevant time zone or :py:const:`None` if the ``tzid`` was + not found. 
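+
+        A usage sketch (``US-Eastern.ics`` is a hypothetical file containing
+        a single ``VTIMEZONE`` definition):
+
+            >>> tzc = tzical('US-Eastern.ics')  # doctest: +SKIP
+            >>> tz = tzc.get()  # doctest: +SKIP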
+ """ + if tzid is None: + if len(self._vtz) == 0: + raise ValueError("no timezones defined") + elif len(self._vtz) > 1: + raise ValueError("more than one timezone available") + tzid = next(iter(self._vtz)) + + return self._vtz.get(tzid) + + def _parse_offset(self, s): + s = s.strip() + if not s: + raise ValueError("empty offset") + if s[0] in ('+', '-'): + signal = (-1, +1)[s[0] == '+'] + s = s[1:] + else: + signal = +1 + if len(s) == 4: + return (int(s[:2]) * 3600 + int(s[2:]) * 60) * signal + elif len(s) == 6: + return (int(s[:2]) * 3600 + int(s[2:4]) * 60 + int(s[4:])) * signal + else: + raise ValueError("invalid offset: " + s) + + def _parse_rfc(self, s): + lines = s.splitlines() + if not lines: + raise ValueError("empty string") + + # Unfold + i = 0 + while i < len(lines): + line = lines[i].rstrip() + if not line: + del lines[i] + elif i > 0 and line[0] == " ": + lines[i-1] += line[1:] + del lines[i] + else: + i += 1 + + tzid = None + comps = [] + invtz = False + comptype = None + for line in lines: + if not line: + continue + name, value = line.split(':', 1) + parms = name.split(';') + if not parms: + raise ValueError("empty property name") + name = parms[0].upper() + parms = parms[1:] + if invtz: + if name == "BEGIN": + if value in ("STANDARD", "DAYLIGHT"): + # Process component + pass + else: + raise ValueError("unknown component: "+value) + comptype = value + founddtstart = False + tzoffsetfrom = None + tzoffsetto = None + rrulelines = [] + tzname = None + elif name == "END": + if value == "VTIMEZONE": + if comptype: + raise ValueError("component not closed: "+comptype) + if not tzid: + raise ValueError("mandatory TZID not found") + if not comps: + raise ValueError( + "at least one component is needed") + # Process vtimezone + self._vtz[tzid] = _tzicalvtz(tzid, comps) + invtz = False + elif value == comptype: + if not founddtstart: + raise ValueError("mandatory DTSTART not found") + if tzoffsetfrom is None: + raise ValueError( + "mandatory TZOFFSETFROM not found") + if tzoffsetto is None: + raise ValueError( + "mandatory TZOFFSETFROM not found") + # Process component + rr = None + if rrulelines: + rr = rrule.rrulestr("\n".join(rrulelines), + compatible=True, + ignoretz=True, + cache=True) + comp = _tzicalvtzcomp(tzoffsetfrom, tzoffsetto, + (comptype == "DAYLIGHT"), + tzname, rr) + comps.append(comp) + comptype = None + else: + raise ValueError("invalid component end: "+value) + elif comptype: + if name == "DTSTART": + rrulelines.append(line) + founddtstart = True + elif name in ("RRULE", "RDATE", "EXRULE", "EXDATE"): + rrulelines.append(line) + elif name == "TZOFFSETFROM": + if parms: + raise ValueError( + "unsupported %s parm: %s " % (name, parms[0])) + tzoffsetfrom = self._parse_offset(value) + elif name == "TZOFFSETTO": + if parms: + raise ValueError( + "unsupported TZOFFSETTO parm: "+parms[0]) + tzoffsetto = self._parse_offset(value) + elif name == "TZNAME": + if parms: + raise ValueError( + "unsupported TZNAME parm: "+parms[0]) + tzname = value + elif name == "COMMENT": + pass + else: + raise ValueError("unsupported property: "+name) + else: + if name == "TZID": + if parms: + raise ValueError( + "unsupported TZID parm: "+parms[0]) + tzid = value + elif name in ("TZURL", "LAST-MODIFIED", "COMMENT"): + pass + else: + raise ValueError("unsupported property: "+name) + elif name == "BEGIN" and value == "VTIMEZONE": + tzid = None + comps = [] + invtz = True + + def __repr__(self): + return "%s(%s)" % (self.__class__.__name__, repr(self._s)) + + +if sys.platform != "win32": + 
TZFILES = ["/etc/localtime", "localtime"] + TZPATHS = ["/usr/share/zoneinfo", + "/usr/lib/zoneinfo", + "/usr/share/lib/zoneinfo", + "/etc/zoneinfo"] +else: + TZFILES = [] + TZPATHS = [] + + +def gettz(name=None): + tz = None + if not name: + try: + name = os.environ["TZ"] + except KeyError: + pass + if name is None or name == ":": + for filepath in TZFILES: + if not os.path.isabs(filepath): + filename = filepath + for path in TZPATHS: + filepath = os.path.join(path, filename) + if os.path.isfile(filepath): + break + else: + continue + if os.path.isfile(filepath): + try: + tz = tzfile(filepath) + break + except (IOError, OSError, ValueError): + pass + else: + tz = tzlocal() + else: + if name.startswith(":"): + name = name[:-1] + if os.path.isabs(name): + if os.path.isfile(name): + tz = tzfile(name) + else: + tz = None + else: + for path in TZPATHS: + filepath = os.path.join(path, name) + if not os.path.isfile(filepath): + filepath = filepath.replace(' ', '_') + if not os.path.isfile(filepath): + continue + try: + tz = tzfile(filepath) + break + except (IOError, OSError, ValueError): + pass + else: + tz = None + if tzwin is not None: + try: + tz = tzwin(name) + except WindowsError: + tz = None + + if not tz: + from dateutil.zoneinfo import get_zonefile_instance + tz = get_zonefile_instance().get(name) + + if not tz: + for c in name: + # name must have at least one offset to be a tzstr + if c in "0123456789": + try: + tz = tzstr(name) + except ValueError: + pass + break + else: + if name in ("GMT", "UTC"): + tz = tzutc() + elif name in time.tzname: + tz = tzlocal() + return tz + + +def datetime_exists(dt, tz=None): + """ + Given a datetime and a time zone, determine whether or not a given datetime + would fall in a gap. + + :param dt: + A :class:`datetime.datetime` (whose time zone will be ignored if ``tz`` + is provided.) + + :param tz: + A :class:`datetime.tzinfo` with support for the ``fold`` attribute. If + ``None`` or not provided, the datetime's own time zone will be used. + + :return: + Returns a boolean value whether or not the "wall time" exists in ``tz``. + """ + if tz is None: + if dt.tzinfo is None: + raise ValueError('Datetime is naive and no time zone provided.') + tz = dt.tzinfo + + dt = dt.replace(tzinfo=None) + + # This is essentially a test of whether or not the datetime can survive + # a round trip to UTC. + dt_rt = dt.replace(tzinfo=tz).astimezone(tzutc()).astimezone(tz) + dt_rt = dt_rt.replace(tzinfo=None) + + return dt == dt_rt + + +def datetime_ambiguous(dt, tz=None): + """ + Given a datetime and a time zone, determine whether or not a given datetime + is ambiguous (i.e if there are two times differentiated only by their DST + status). + + :param dt: + A :class:`datetime.datetime` (whose time zone will be ignored if ``tz`` + is provided.) + + :param tz: + A :class:`datetime.tzinfo` with support for the ``fold`` attribute. If + ``None`` or not provided, the datetime's own time zone will be used. + + :return: + Returns a boolean value whether or not the "wall time" is ambiguous in + ``tz``. + + .. versionadded:: 2.6.0 + """ + if tz is None: + if dt.tzinfo is None: + raise ValueError('Datetime is naive and no time zone provided.') + + tz = dt.tzinfo + + # If a time zone defines its own "is_ambiguous" function, we'll use that. 
+ is_ambiguous_fn = getattr(tz, 'is_ambiguous', None) + if is_ambiguous_fn is not None: + try: + return tz.is_ambiguous(dt) + except: + pass + + # If it doesn't come out and tell us it's ambiguous, we'll just check if + # the fold attribute has any effect on this particular date and time. + dt = dt.replace(tzinfo=tz) + wall_0 = enfold(dt, fold=0) + wall_1 = enfold(dt, fold=1) + + same_offset = wall_0.utcoffset() == wall_1.utcoffset() + same_dst = wall_0.dst() == wall_1.dst() + + return not (same_offset and same_dst) + + +def _datetime_to_timestamp(dt): + """ + Convert a :class:`datetime.datetime` object to an epoch timestamp in seconds + since January 1, 1970, ignoring the time zone. + """ + return _total_seconds((dt.replace(tzinfo=None) - EPOCH)) + + +class _ContextWrapper(object): + """ + Class for wrapping contexts so that they are passed through in a + with statement. + """ + def __init__(self, context): + self.context = context + + def __enter__(self): + return self.context + + def __exit__(*args, **kwargs): + pass + +# vim:ts=4:sw=4:et diff --git a/python/dateutil/tz/win.py b/python/dateutil/tz/win.py new file mode 100644 index 0000000..36a1c26 --- /dev/null +++ b/python/dateutil/tz/win.py @@ -0,0 +1,332 @@ +# This code was originally contributed by Jeffrey Harris. +import datetime +import struct + +from six.moves import winreg +from six import text_type + +try: + import ctypes + from ctypes import wintypes +except ValueError: + # ValueError is raised on non-Windows systems for some horrible reason. + raise ImportError("Running tzwin on non-Windows system") + +from ._common import tzrangebase + +__all__ = ["tzwin", "tzwinlocal", "tzres"] + +ONEWEEK = datetime.timedelta(7) + +TZKEYNAMENT = r"SOFTWARE\Microsoft\Windows NT\CurrentVersion\Time Zones" +TZKEYNAME9X = r"SOFTWARE\Microsoft\Windows\CurrentVersion\Time Zones" +TZLOCALKEYNAME = r"SYSTEM\CurrentControlSet\Control\TimeZoneInformation" + + +def _settzkeyname(): + handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) + try: + winreg.OpenKey(handle, TZKEYNAMENT).Close() + TZKEYNAME = TZKEYNAMENT + except WindowsError: + TZKEYNAME = TZKEYNAME9X + handle.Close() + return TZKEYNAME + + +TZKEYNAME = _settzkeyname() + + +class tzres(object): + """ + Class for accessing `tzres.dll`, which contains timezone name related + resources. + + .. versionadded:: 2.5.0 + """ + p_wchar = ctypes.POINTER(wintypes.WCHAR) # Pointer to a wide char + + def __init__(self, tzres_loc='tzres.dll'): + # Load the user32 DLL so we can load strings from tzres + user32 = ctypes.WinDLL('user32') + + # Specify the LoadStringW function + user32.LoadStringW.argtypes = (wintypes.HINSTANCE, + wintypes.UINT, + wintypes.LPWSTR, + ctypes.c_int) + + self.LoadStringW = user32.LoadStringW + self._tzres = ctypes.WinDLL(tzres_loc) + self.tzres_loc = tzres_loc + + def load_name(self, offset): + """ + Load a timezone name from a DLL offset (integer). + + >>> from dateutil.tzwin import tzres + >>> tzr = tzres() + >>> print(tzr.load_name(112)) + 'Eastern Standard Time' + + :param offset: + A positive integer value referring to a string from the tzres dll. + + ..note: + Offsets found in the registry are generally of the form + `@tzres.dll,-114`. The offset in this case if 114, not -114. 
+ + """ + resource = self.p_wchar() + lpBuffer = ctypes.cast(ctypes.byref(resource), wintypes.LPWSTR) + nchar = self.LoadStringW(self._tzres._handle, offset, lpBuffer, 0) + return resource[:nchar] + + def name_from_string(self, tzname_str): + """ + Parse strings as returned from the Windows registry into the time zone + name as defined in the registry. + + >>> from dateutil.tzwin import tzres + >>> tzr = tzres() + >>> print(tzr.name_from_string('@tzres.dll,-251')) + 'Dateline Daylight Time' + >>> print(tzr.name_from_string('Eastern Standard Time')) + 'Eastern Standard Time' + + :param tzname_str: + A timezone name string as returned from a Windows registry key. + + :return: + Returns the localized timezone string from tzres.dll if the string + is of the form `@tzres.dll,-offset`, else returns the input string. + """ + if not tzname_str.startswith('@'): + return tzname_str + + name_splt = tzname_str.split(',-') + try: + offset = int(name_splt[1]) + except: + raise ValueError("Malformed timezone string.") + + return self.load_name(offset) + + +class tzwinbase(tzrangebase): + """tzinfo class based on win32's timezones available in the registry.""" + def __init__(self): + raise NotImplementedError('tzwinbase is an abstract base class') + + def __eq__(self, other): + # Compare on all relevant dimensions, including name. + if not isinstance(other, tzwinbase): + return NotImplemented + + return (self._std_offset == other._std_offset and + self._dst_offset == other._dst_offset and + self._stddayofweek == other._stddayofweek and + self._dstdayofweek == other._dstdayofweek and + self._stdweeknumber == other._stdweeknumber and + self._dstweeknumber == other._dstweeknumber and + self._stdhour == other._stdhour and + self._dsthour == other._dsthour and + self._stdminute == other._stdminute and + self._dstminute == other._dstminute and + self._std_abbr == other._std_abbr and + self._dst_abbr == other._dst_abbr) + + @staticmethod + def list(): + """Return a list of all time zones known to the system.""" + with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: + with winreg.OpenKey(handle, TZKEYNAME) as tzkey: + result = [winreg.EnumKey(tzkey, i) + for i in range(winreg.QueryInfoKey(tzkey)[0])] + return result + + def display(self): + return self._display + + def transitions(self, year): + """ + For a given year, get the DST on and off transition times, expressed + always on the standard time side. For zones with no transitions, this + function returns ``None``. + + :param year: + The year whose transitions you would like to query. + + :return: + Returns a :class:`tuple` of :class:`datetime.datetime` objects, + ``(dston, dstoff)`` for zones with an annual DST transition, or + ``None`` for fixed offset zones. 
+ """ + + if not self.hasdst: + return None + + dston = picknthweekday(year, self._dstmonth, self._dstdayofweek, + self._dsthour, self._dstminute, + self._dstweeknumber) + + dstoff = picknthweekday(year, self._stdmonth, self._stddayofweek, + self._stdhour, self._stdminute, + self._stdweeknumber) + + # Ambiguous dates default to the STD side + dstoff -= self._dst_base_offset + + return dston, dstoff + + def _get_hasdst(self): + return self._dstmonth != 0 + + @property + def _dst_base_offset(self): + return self._dst_base_offset_ + + +class tzwin(tzwinbase): + + def __init__(self, name): + self._name = name + + # multiple contexts only possible in 2.7 and 3.1, we still support 2.6 + with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: + tzkeyname = text_type("{kn}\\{name}").format(kn=TZKEYNAME, name=name) + with winreg.OpenKey(handle, tzkeyname) as tzkey: + keydict = valuestodict(tzkey) + + self._std_abbr = keydict["Std"] + self._dst_abbr = keydict["Dlt"] + + self._display = keydict["Display"] + + # See http://ww_winreg.jsiinc.com/SUBA/tip0300/rh0398.htm + tup = struct.unpack("=3l16h", keydict["TZI"]) + stdoffset = -tup[0]-tup[1] # Bias + StandardBias * -1 + dstoffset = stdoffset-tup[2] # + DaylightBias * -1 + self._std_offset = datetime.timedelta(minutes=stdoffset) + self._dst_offset = datetime.timedelta(minutes=dstoffset) + + # for the meaning see the win32 TIME_ZONE_INFORMATION structure docs + # http://msdn.microsoft.com/en-us/library/windows/desktop/ms725481(v=vs.85).aspx + (self._stdmonth, + self._stddayofweek, # Sunday = 0 + self._stdweeknumber, # Last = 5 + self._stdhour, + self._stdminute) = tup[4:9] + + (self._dstmonth, + self._dstdayofweek, # Sunday = 0 + self._dstweeknumber, # Last = 5 + self._dsthour, + self._dstminute) = tup[12:17] + + self._dst_base_offset_ = self._dst_offset - self._std_offset + self.hasdst = self._get_hasdst() + + def __repr__(self): + return "tzwin(%s)" % repr(self._name) + + def __reduce__(self): + return (self.__class__, (self._name,)) + + +class tzwinlocal(tzwinbase): + def __init__(self): + with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: + with winreg.OpenKey(handle, TZLOCALKEYNAME) as tzlocalkey: + keydict = valuestodict(tzlocalkey) + + self._std_abbr = keydict["StandardName"] + self._dst_abbr = keydict["DaylightName"] + + try: + tzkeyname = text_type('{kn}\\{sn}').format(kn=TZKEYNAME, + sn=self._std_abbr) + with winreg.OpenKey(handle, tzkeyname) as tzkey: + _keydict = valuestodict(tzkey) + self._display = _keydict["Display"] + except OSError: + self._display = None + + stdoffset = -keydict["Bias"]-keydict["StandardBias"] + dstoffset = stdoffset-keydict["DaylightBias"] + + self._std_offset = datetime.timedelta(minutes=stdoffset) + self._dst_offset = datetime.timedelta(minutes=dstoffset) + + # For reasons unclear, in this particular key, the day of week has been + # moved to the END of the SYSTEMTIME structure. 
+ tup = struct.unpack("=8h", keydict["StandardStart"]) + + (self._stdmonth, + self._stdweeknumber, # Last = 5 + self._stdhour, + self._stdminute) = tup[1:5] + + self._stddayofweek = tup[7] + + tup = struct.unpack("=8h", keydict["DaylightStart"]) + + (self._dstmonth, + self._dstweeknumber, # Last = 5 + self._dsthour, + self._dstminute) = tup[1:5] + + self._dstdayofweek = tup[7] + + self._dst_base_offset_ = self._dst_offset - self._std_offset + self.hasdst = self._get_hasdst() + + def __repr__(self): + return "tzwinlocal()" + + def __str__(self): + # str will return the standard name, not the daylight name. + return "tzwinlocal(%s)" % repr(self._std_abbr) + + def __reduce__(self): + return (self.__class__, ()) + + +def picknthweekday(year, month, dayofweek, hour, minute, whichweek): + """ dayofweek == 0 means Sunday, whichweek 5 means last instance """ + first = datetime.datetime(year, month, 1, hour, minute) + + # This will work if dayofweek is ISO weekday (1-7) or Microsoft-style (0-6), + # Because 7 % 7 = 0 + weekdayone = first.replace(day=((dayofweek - first.isoweekday()) % 7) + 1) + wd = weekdayone + ((whichweek - 1) * ONEWEEK) + if (wd.month != month): + wd -= ONEWEEK + + return wd + + +def valuestodict(key): + """Convert a registry key's values to a dictionary.""" + dout = {} + size = winreg.QueryInfoKey(key)[1] + tz_res = None + + for i in range(size): + key_name, value, dtype = winreg.EnumValue(key, i) + if dtype == winreg.REG_DWORD or dtype == winreg.REG_DWORD_LITTLE_ENDIAN: + # If it's a DWORD (32-bit integer), it's stored as unsigned - convert + # that to a proper signed integer + if value & (1 << 31): + value = value - (1 << 32) + elif dtype == winreg.REG_SZ: + # If it's a reference to the tzres DLL, load the actual string + if value.startswith('@tzres'): + tz_res = tz_res or tzres() + value = tz_res.name_from_string(value) + + value = value.rstrip('\x00') # Remove trailing nulls + + dout[key_name] = value + + return dout diff --git a/python/dateutil/tzwin.py b/python/dateutil/tzwin.py new file mode 100644 index 0000000..cebc673 --- /dev/null +++ b/python/dateutil/tzwin.py @@ -0,0 +1,2 @@ +# tzwin has moved to dateutil.tz.win +from .tz.win import * diff --git a/python/dateutil/zoneinfo/__init__.py b/python/dateutil/zoneinfo/__init__.py new file mode 100644 index 0000000..a2ed4f9 --- /dev/null +++ b/python/dateutil/zoneinfo/__init__.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +import warnings +import json + +from tarfile import TarFile +from pkgutil import get_data +from io import BytesIO +from contextlib import closing + +from dateutil.tz import tzfile + +__all__ = ["get_zonefile_instance", "gettz", "gettz_db_metadata", "rebuild"] + +ZONEFILENAME = "dateutil-zoneinfo.tar.gz" +METADATA_FN = 'METADATA' + +# python2.6 compatability. Note that TarFile.__exit__ != TarFile.close, but +# it's close enough for python2.6 +tar_open = TarFile.open +if not hasattr(TarFile, '__exit__'): + def tar_open(*args, **kwargs): + return closing(TarFile.open(*args, **kwargs)) + + +class tzfile(tzfile): + def __reduce__(self): + return (gettz, (self._filename,)) + + +def getzoneinfofile_stream(): + try: + return BytesIO(get_data(__name__, ZONEFILENAME)) + except IOError as e: # TODO switch to FileNotFoundError? 
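+        # Note: on Python 3, IOError is an alias of OSError and
+        # FileNotFoundError is one of its subclasses, so this clause would
+        # still catch a missing zonefile after such a switch.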
+ warnings.warn("I/O error({0}): {1}".format(e.errno, e.strerror)) + return None + + +class ZoneInfoFile(object): + def __init__(self, zonefile_stream=None): + if zonefile_stream is not None: + with tar_open(fileobj=zonefile_stream, mode='r') as tf: + # dict comprehension does not work on python2.6 + # TODO: get back to the nicer syntax when we ditch python2.6 + # self.zones = {zf.name: tzfile(tf.extractfile(zf), + # filename = zf.name) + # for zf in tf.getmembers() if zf.isfile()} + self.zones = dict((zf.name, tzfile(tf.extractfile(zf), + filename=zf.name)) + for zf in tf.getmembers() + if zf.isfile() and zf.name != METADATA_FN) + # deal with links: They'll point to their parent object. Less + # waste of memory + # links = {zl.name: self.zones[zl.linkname] + # for zl in tf.getmembers() if zl.islnk() or zl.issym()} + links = dict((zl.name, self.zones[zl.linkname]) + for zl in tf.getmembers() if + zl.islnk() or zl.issym()) + self.zones.update(links) + try: + metadata_json = tf.extractfile(tf.getmember(METADATA_FN)) + metadata_str = metadata_json.read().decode('UTF-8') + self.metadata = json.loads(metadata_str) + except KeyError: + # no metadata in tar file + self.metadata = None + else: + self.zones = dict() + self.metadata = None + + def get(self, name, default=None): + """ + Wrapper for :func:`ZoneInfoFile.zones.get`. This is a convenience method + for retrieving zones from the zone dictionary. + + :param name: + The name of the zone to retrieve. (Generally IANA zone names) + + :param default: + The value to return in the event of a missing key. + + .. versionadded:: 2.6.0 + + """ + return self.zones.get(name, default) + + +# The current API has gettz as a module function, although in fact it taps into +# a stateful class. So as a workaround for now, without changing the API, we +# will create a new "global" class instance the first time a user requests a +# timezone. Ugly, but adheres to the api. +# +# TODO: Remove after deprecation period. +_CLASS_ZONE_INSTANCE = list() + + +def get_zonefile_instance(new_instance=False): + """ + This is a convenience function which provides a :class:`ZoneInfoFile` + instance using the data provided by the ``dateutil`` package. By default, it + caches a single instance of the ZoneInfoFile object and returns that. + + :param new_instance: + If ``True``, a new instance of :class:`ZoneInfoFile` is instantiated and + used as the cached instance for the next call. Otherwise, new instances + are created only as necessary. + + :return: + Returns a :class:`ZoneInfoFile` object. + + .. versionadded:: 2.6 + """ + if new_instance: + zif = None + else: + zif = getattr(get_zonefile_instance, '_cached_instance', None) + + if zif is None: + zif = ZoneInfoFile(getzoneinfofile_stream()) + + get_zonefile_instance._cached_instance = zif + + return zif + + +def gettz(name): + """ + This retrieves a time zone from the local zoneinfo tarball that is packaged + with dateutil. + + :param name: + An IANA-style time zone name, as found in the zoneinfo file. + + :return: + Returns a :class:`dateutil.tz.tzfile` time zone object. + + .. warning:: + It is generally inadvisable to use this function, and it is only + provided for API compatibility with earlier versions. This is *not* + equivalent to ``dateutil.tz.gettz()``, which selects an appropriate + time zone based on the inputs, favoring system zoneinfo. This is ONLY + for accessing the dateutil-specific zoneinfo (which may be out of + date compared to the system zoneinfo). + + .. 
deprecated:: 2.6 + If you need to use a specific zoneinfofile over the system zoneinfo, + instantiate a :class:`dateutil.zoneinfo.ZoneInfoFile` object and call + :func:`dateutil.zoneinfo.ZoneInfoFile.get(name)` instead. + + Use :func:`get_zonefile_instance` to retrieve an instance of the + dateutil-provided zoneinfo. + """ + warnings.warn("zoneinfo.gettz() will be removed in future versions, " + "to use the dateutil-provided zoneinfo files, instantiate a " + "ZoneInfoFile object and use ZoneInfoFile.zones.get() " + "instead. See the documentation for details.", + DeprecationWarning) + + if len(_CLASS_ZONE_INSTANCE) == 0: + _CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream())) + return _CLASS_ZONE_INSTANCE[0].zones.get(name) + + +def gettz_db_metadata(): + """ Get the zonefile metadata + + See `zonefile_metadata`_ + + :returns: + A dictionary with the database metadata + + .. deprecated:: 2.6 + See deprecation warning in :func:`zoneinfo.gettz`. To get metadata, + query the attribute ``zoneinfo.ZoneInfoFile.metadata``. + """ + warnings.warn("zoneinfo.gettz_db_metadata() will be removed in future " + "versions, to use the dateutil-provided zoneinfo files, " + "ZoneInfoFile object and query the 'metadata' attribute " + "instead. See the documentation for details.", + DeprecationWarning) + + if len(_CLASS_ZONE_INSTANCE) == 0: + _CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream())) + return _CLASS_ZONE_INSTANCE[0].metadata diff --git a/python/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz b/python/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz Binary files differnew file mode 100644 index 0000000..613c0ff --- /dev/null +++ b/python/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz diff --git a/python/dateutil/zoneinfo/rebuild.py b/python/dateutil/zoneinfo/rebuild.py new file mode 100644 index 0000000..9d53bb8 --- /dev/null +++ b/python/dateutil/zoneinfo/rebuild.py @@ -0,0 +1,52 @@ +import logging +import os +import tempfile +import shutil +import json +from subprocess import check_call + +from dateutil.zoneinfo import tar_open, METADATA_FN, ZONEFILENAME + + +def rebuild(filename, tag=None, format="gz", zonegroups=[], metadata=None): + """Rebuild the internal timezone info in dateutil/zoneinfo/zoneinfo*tar* + + filename is the timezone tarball from ftp.iana.org/tz. + + """ + tmpdir = tempfile.mkdtemp() + zonedir = os.path.join(tmpdir, "zoneinfo") + moduledir = os.path.dirname(__file__) + try: + with tar_open(filename) as tf: + for name in zonegroups: + tf.extract(name, tmpdir) + filepaths = [os.path.join(tmpdir, n) for n in zonegroups] + try: + check_call(["zic", "-d", zonedir] + filepaths) + except OSError as e: + _print_on_nosuchfile(e) + raise + # write metadata file + with open(os.path.join(zonedir, METADATA_FN), 'w') as f: + json.dump(metadata, f, indent=4, sort_keys=True) + target = os.path.join(moduledir, ZONEFILENAME) + with tar_open(target, "w:%s" % format) as tf: + for entry in os.listdir(zonedir): + entrypath = os.path.join(zonedir, entry) + tf.add(entrypath, entry) + finally: + shutil.rmtree(tmpdir) + + +def _print_on_nosuchfile(e): + """Print helpful troubleshooting message + + e is an exception raised by subprocess.check_call() + + """ + if e.errno == 2: + logging.error( + "Could not find zic. 
Perhaps you need to install " + "libc-bin or some other package that provides it, " + "or it's not in your PATH?") diff --git a/python/defusedxml/ElementTree.py b/python/defusedxml/ElementTree.py new file mode 100644 index 0000000..41b2ea8 --- /dev/null +++ b/python/defusedxml/ElementTree.py @@ -0,0 +1,112 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xml.etree.ElementTree facade +""" +from __future__ import print_function, absolute_import + +import sys +from xml.etree.ElementTree import TreeBuilder as _TreeBuilder +from xml.etree.ElementTree import parse as _parse +from xml.etree.ElementTree import tostring + +from .common import PY3 + + +if PY3: + import importlib +else: + from xml.etree.ElementTree import XMLParser as _XMLParser + from xml.etree.ElementTree import iterparse as _iterparse + from xml.etree.ElementTree import ParseError + + +from .common import (DTDForbidden, EntitiesForbidden, + ExternalReferenceForbidden, _generate_etree_functions) + +__origin__ = "xml.etree.ElementTree" + + +def _get_py3_cls(): + """Python 3.3 hides the pure Python code but defusedxml requires it. + + The code is based on test.support.import_fresh_module(). + """ + pymodname = "xml.etree.ElementTree" + cmodname = "_elementtree" + + pymod = sys.modules.pop(pymodname, None) + cmod = sys.modules.pop(cmodname, None) + + sys.modules[cmodname] = None + pure_pymod = importlib.import_module(pymodname) + if cmod is not None: + sys.modules[cmodname] = cmod + else: + sys.modules.pop(cmodname) + sys.modules[pymodname] = pymod + + _XMLParser = pure_pymod.XMLParser + _iterparse = pure_pymod.iterparse + ParseError = pure_pymod.ParseError + + return _XMLParser, _iterparse, ParseError + + +if PY3: + _XMLParser, _iterparse, ParseError = _get_py3_cls() + + +class DefusedXMLParser(_XMLParser): + + def __init__(self, html=0, target=None, encoding=None, + forbid_dtd=False, forbid_entities=True, + forbid_external=True): + # Python 2.x old style class + _XMLParser.__init__(self, html, target, encoding) + self.forbid_dtd = forbid_dtd + self.forbid_entities = forbid_entities + self.forbid_external = forbid_external + if PY3: + parser = self.parser + else: + parser = self._parser + if self.forbid_dtd: + parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl + if self.forbid_entities: + parser.EntityDeclHandler = self.defused_entity_decl + parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl + if self.forbid_external: + parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler + + def defused_start_doctype_decl(self, name, sysid, pubid, + has_internal_subset): + raise DTDForbidden(name, sysid, pubid) + + def defused_entity_decl(self, name, is_parameter_entity, value, base, + sysid, pubid, notation_name): + raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name) + + def defused_unparsed_entity_decl(self, name, base, sysid, pubid, + notation_name): + # expat 1.2 + raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name) + + def defused_external_entity_ref_handler(self, context, base, sysid, + pubid): + raise ExternalReferenceForbidden(context, base, sysid, pubid) + + +# aliases +XMLTreeBuilder = XMLParse = DefusedXMLParser + +parse, iterparse, fromstring = _generate_etree_functions(DefusedXMLParser, + _TreeBuilder, _parse, + _iterparse) +XML = fromstring + + +__all__ = ['XML', 'XMLParse', 
'XMLTreeBuilder', 'fromstring', 'iterparse',
+           'parse', 'tostring']
diff --git a/python/defusedxml/__init__.py b/python/defusedxml/__init__.py
new file mode 100644
index 0000000..590a5a9
--- /dev/null
+++ b/python/defusedxml/__init__.py
@@ -0,0 +1,45 @@
+# defusedxml
+#
+# Copyright (c) 2013 by Christian Heimes <christian@python.org>
+# Licensed to PSF under a Contributor Agreement.
+# See http://www.python.org/psf/license for licensing details.
+"""Defuse XML bomb denial of service vulnerabilities
+"""
+from __future__ import print_function, absolute_import
+
+from .common import (DefusedXmlException, DTDForbidden, EntitiesForbidden,
+                     ExternalReferenceForbidden, NotSupportedError,
+                     _apply_defusing)
+
+
+def defuse_stdlib():
+    """Monkey patch and defuse all stdlib packages
+
+    :warning: The monkey patch is an EXPERIMENTAL feature.
+    """
+    defused = {}
+
+    from . import cElementTree
+    from . import ElementTree
+    from . import minidom
+    from . import pulldom
+    from . import sax
+    from . import expatbuilder
+    from . import expatreader
+    from . import xmlrpc
+
+    xmlrpc.monkey_patch()
+    defused[xmlrpc] = None
+
+    for defused_mod in [cElementTree, ElementTree, minidom, pulldom, sax,
+                        expatbuilder, expatreader]:
+        stdlib_mod = _apply_defusing(defused_mod)
+        defused[defused_mod] = stdlib_mod
+
+    return defused
+
+
+__version__ = "0.5.0"
+
+__all__ = ['DefusedXmlException', 'DTDForbidden', 'EntitiesForbidden',
+           'ExternalReferenceForbidden', 'NotSupportedError']
diff --git a/python/defusedxml/cElementTree.py b/python/defusedxml/cElementTree.py
new file mode 100644
index 0000000..cc13689
--- /dev/null
+++ b/python/defusedxml/cElementTree.py
@@ -0,0 +1,30 @@
+# defusedxml
+#
+# Copyright (c) 2013 by Christian Heimes <christian@python.org>
+# Licensed to PSF under a Contributor Agreement.
+# See http://www.python.org/psf/license for licensing details.
+"""Defused xml.etree.cElementTree
+"""
+from __future__ import absolute_import
+
+from xml.etree.cElementTree import TreeBuilder as _TreeBuilder
+from xml.etree.cElementTree import parse as _parse
+from xml.etree.cElementTree import tostring
+# iterparse from ElementTree!
+from xml.etree.ElementTree import iterparse as _iterparse
+
+from .ElementTree import DefusedXMLParser
+from .common import _generate_etree_functions
+
+__origin__ = "xml.etree.cElementTree"
+
+
+XMLTreeBuilder = XMLParse = DefusedXMLParser
+
+parse, iterparse, fromstring = _generate_etree_functions(DefusedXMLParser,
+                                                         _TreeBuilder, _parse,
+                                                         _iterparse)
+XML = fromstring
+
+__all__ = ['XML', 'XMLParse', 'XMLTreeBuilder', 'fromstring', 'iterparse',
+           'parse', 'tostring']
diff --git a/python/defusedxml/common.py b/python/defusedxml/common.py
new file mode 100644
index 0000000..668b609
--- /dev/null
+++ b/python/defusedxml/common.py
@@ -0,0 +1,120 @@
+# defusedxml
+#
+# Copyright (c) 2013 by Christian Heimes <christian@python.org>
+# Licensed to PSF under a Contributor Agreement.
+# See http://www.python.org/psf/license for licensing details.
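For orientation while reviewing, here is a minimal usage sketch of the
defuse_stdlib() helper above. This is a hypothetical snippet, not part of the
commit, and assumes the vendored defusedxml package is on the import path:

    import defusedxml

    # Patch the stdlib parsers in place: after this call,
    # xml.etree.ElementTree.fromstring is the defused variant that
    # rejects entity declarations outright.
    defusedxml.defuse_stdlib()

    import xml.etree.ElementTree as ET
    try:
        ET.fromstring('<!DOCTYPE x [<!ENTITY a "boom">]><x>&a;</x>')
    except defusedxml.EntitiesForbidden:
        print("stdlib ElementTree is now defused")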
+"""Common constants, exceptions and helpe functions +""" +import sys + +PY3 = sys.version_info[0] == 3 + + +class DefusedXmlException(ValueError): + """Base exception + """ + + def __repr__(self): + return str(self) + + +class DTDForbidden(DefusedXmlException): + """Document type definition is forbidden + """ + + def __init__(self, name, sysid, pubid): + super(DTDForbidden, self).__init__() + self.name = name + self.sysid = sysid + self.pubid = pubid + + def __str__(self): + tpl = "DTDForbidden(name='{}', system_id={!r}, public_id={!r})" + return tpl.format(self.name, self.sysid, self.pubid) + + +class EntitiesForbidden(DefusedXmlException): + """Entity definition is forbidden + """ + + def __init__(self, name, value, base, sysid, pubid, notation_name): + super(EntitiesForbidden, self).__init__() + self.name = name + self.value = value + self.base = base + self.sysid = sysid + self.pubid = pubid + self.notation_name = notation_name + + def __str__(self): + tpl = "EntitiesForbidden(name='{}', system_id={!r}, public_id={!r})" + return tpl.format(self.name, self.sysid, self.pubid) + + +class ExternalReferenceForbidden(DefusedXmlException): + """Resolving an external reference is forbidden + """ + + def __init__(self, context, base, sysid, pubid): + super(ExternalReferenceForbidden, self).__init__() + self.context = context + self.base = base + self.sysid = sysid + self.pubid = pubid + + def __str__(self): + tpl = "ExternalReferenceForbidden(system_id='{}', public_id={})" + return tpl.format(self.sysid, self.pubid) + + +class NotSupportedError(DefusedXmlException): + """The operation is not supported + """ + + +def _apply_defusing(defused_mod): + assert defused_mod is sys.modules[defused_mod.__name__] + stdlib_name = defused_mod.__origin__ + __import__(stdlib_name, {}, {}, ["*"]) + stdlib_mod = sys.modules[stdlib_name] + stdlib_names = set(dir(stdlib_mod)) + for name, obj in vars(defused_mod).items(): + if name.startswith("_") or name not in stdlib_names: + continue + setattr(stdlib_mod, name, obj) + return stdlib_mod + + +def _generate_etree_functions(DefusedXMLParser, _TreeBuilder, + _parse, _iterparse): + """Factory for functions needed by etree, dependent on whether + cElementTree or ElementTree is used.""" + + def parse(source, parser=None, forbid_dtd=False, forbid_entities=True, + forbid_external=True): + if parser is None: + parser = DefusedXMLParser(target=_TreeBuilder(), + forbid_dtd=forbid_dtd, + forbid_entities=forbid_entities, + forbid_external=forbid_external) + return _parse(source, parser) + + def iterparse(source, events=None, parser=None, forbid_dtd=False, + forbid_entities=True, forbid_external=True): + if parser is None: + parser = DefusedXMLParser(target=_TreeBuilder(), + forbid_dtd=forbid_dtd, + forbid_entities=forbid_entities, + forbid_external=forbid_external) + return _iterparse(source, events, parser) + + def fromstring(text, forbid_dtd=False, forbid_entities=True, + forbid_external=True): + parser = DefusedXMLParser(target=_TreeBuilder(), + forbid_dtd=forbid_dtd, + forbid_entities=forbid_entities, + forbid_external=forbid_external) + parser.feed(text) + return parser.close() + + return parse, iterparse, fromstring diff --git a/python/defusedxml/expatbuilder.py b/python/defusedxml/expatbuilder.py new file mode 100644 index 0000000..0eb6b91 --- /dev/null +++ b/python/defusedxml/expatbuilder.py @@ -0,0 +1,110 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. 
+# See http://www.python.org/psf/license for licensing details. +"""Defused xml.dom.expatbuilder +""" +from __future__ import print_function, absolute_import + +from xml.dom.expatbuilder import ExpatBuilder as _ExpatBuilder +from xml.dom.expatbuilder import Namespaces as _Namespaces + +from .common import (DTDForbidden, EntitiesForbidden, + ExternalReferenceForbidden) + +__origin__ = "xml.dom.expatbuilder" + + +class DefusedExpatBuilder(_ExpatBuilder): + """Defused document builder""" + + def __init__(self, options=None, forbid_dtd=False, forbid_entities=True, + forbid_external=True): + _ExpatBuilder.__init__(self, options) + self.forbid_dtd = forbid_dtd + self.forbid_entities = forbid_entities + self.forbid_external = forbid_external + + def defused_start_doctype_decl(self, name, sysid, pubid, + has_internal_subset): + raise DTDForbidden(name, sysid, pubid) + + def defused_entity_decl(self, name, is_parameter_entity, value, base, + sysid, pubid, notation_name): + raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name) + + def defused_unparsed_entity_decl(self, name, base, sysid, pubid, + notation_name): + # expat 1.2 + raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name) + + def defused_external_entity_ref_handler(self, context, base, sysid, + pubid): + raise ExternalReferenceForbidden(context, base, sysid, pubid) + + def install(self, parser): + _ExpatBuilder.install(self, parser) + + if self.forbid_dtd: + parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl + if self.forbid_entities: + # if self._options.entities: + parser.EntityDeclHandler = self.defused_entity_decl + parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl + if self.forbid_external: + parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler + + +class DefusedExpatBuilderNS(_Namespaces, DefusedExpatBuilder): + """Defused document builder that supports namespaces.""" + + def install(self, parser): + DefusedExpatBuilder.install(self, parser) + if self._options.namespace_declarations: + parser.StartNamespaceDeclHandler = ( + self.start_namespace_decl_handler) + + def reset(self): + DefusedExpatBuilder.reset(self) + self._initNamespaces() + + +def parse(file, namespaces=True, forbid_dtd=False, forbid_entities=True, + forbid_external=True): + """Parse a document, returning the resulting Document node. + + 'file' may be either a file name or an open file object. + """ + if namespaces: + build_builder = DefusedExpatBuilderNS + else: + build_builder = DefusedExpatBuilder + builder = build_builder(forbid_dtd=forbid_dtd, + forbid_entities=forbid_entities, + forbid_external=forbid_external) + + if isinstance(file, str): + fp = open(file, 'rb') + try: + result = builder.parseFile(fp) + finally: + fp.close() + else: + result = builder.parseFile(file) + return result + + +def parseString(string, namespaces=True, forbid_dtd=False, + forbid_entities=True, forbid_external=True): + """Parse a document from a string, returning the resulting + Document node. 
+ """ + if namespaces: + build_builder = DefusedExpatBuilderNS + else: + build_builder = DefusedExpatBuilder + builder = build_builder(forbid_dtd=forbid_dtd, + forbid_entities=forbid_entities, + forbid_external=forbid_external) + return builder.parseString(string) diff --git a/python/defusedxml/expatreader.py b/python/defusedxml/expatreader.py new file mode 100644 index 0000000..ef6bc39 --- /dev/null +++ b/python/defusedxml/expatreader.py @@ -0,0 +1,59 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xml.sax.expatreader +""" +from __future__ import print_function, absolute_import + +from xml.sax.expatreader import ExpatParser as _ExpatParser + +from .common import (DTDForbidden, EntitiesForbidden, + ExternalReferenceForbidden) + +__origin__ = "xml.sax.expatreader" + + +class DefusedExpatParser(_ExpatParser): + """Defused SAX driver for the pyexpat C module.""" + + def __init__(self, namespaceHandling=0, bufsize=2 ** 16 - 20, + forbid_dtd=False, forbid_entities=True, + forbid_external=True): + _ExpatParser.__init__(self, namespaceHandling, bufsize) + self.forbid_dtd = forbid_dtd + self.forbid_entities = forbid_entities + self.forbid_external = forbid_external + + def defused_start_doctype_decl(self, name, sysid, pubid, + has_internal_subset): + raise DTDForbidden(name, sysid, pubid) + + def defused_entity_decl(self, name, is_parameter_entity, value, base, + sysid, pubid, notation_name): + raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name) + + def defused_unparsed_entity_decl(self, name, base, sysid, pubid, + notation_name): + # expat 1.2 + raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name) + + def defused_external_entity_ref_handler(self, context, base, sysid, + pubid): + raise ExternalReferenceForbidden(context, base, sysid, pubid) + + def reset(self): + _ExpatParser.reset(self) + parser = self._parser + if self.forbid_dtd: + parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl + if self.forbid_entities: + parser.EntityDeclHandler = self.defused_entity_decl + parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl + if self.forbid_external: + parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler + + +def create_parser(*args, **kwargs): + return DefusedExpatParser(*args, **kwargs) diff --git a/python/defusedxml/lxml.py b/python/defusedxml/lxml.py new file mode 100644 index 0000000..7f3ee0b --- /dev/null +++ b/python/defusedxml/lxml.py @@ -0,0 +1,153 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Example code for lxml.etree protection + +The code has NO protection against decompression bombs. 
+""" +from __future__ import print_function, absolute_import + +import threading +from lxml import etree as _etree + +from .common import DTDForbidden, EntitiesForbidden, NotSupportedError + +LXML3 = _etree.LXML_VERSION[0] >= 3 + +__origin__ = "lxml.etree" + +tostring = _etree.tostring + + +class RestrictedElement(_etree.ElementBase): + """A restricted Element class that filters out instances of some classes + """ + __slots__ = () + # blacklist = (etree._Entity, etree._ProcessingInstruction, etree._Comment) + blacklist = _etree._Entity + + def _filter(self, iterator): + blacklist = self.blacklist + for child in iterator: + if isinstance(child, blacklist): + continue + yield child + + def __iter__(self): + iterator = super(RestrictedElement, self).__iter__() + return self._filter(iterator) + + def iterchildren(self, tag=None, reversed=False): + iterator = super(RestrictedElement, self).iterchildren( + tag=tag, reversed=reversed) + return self._filter(iterator) + + def iter(self, tag=None, *tags): + iterator = super(RestrictedElement, self).iter(tag=tag, *tags) + return self._filter(iterator) + + def iterdescendants(self, tag=None, *tags): + iterator = super(RestrictedElement, + self).iterdescendants(tag=tag, *tags) + return self._filter(iterator) + + def itersiblings(self, tag=None, preceding=False): + iterator = super(RestrictedElement, self).itersiblings( + tag=tag, preceding=preceding) + return self._filter(iterator) + + def getchildren(self): + iterator = super(RestrictedElement, self).__iter__() + return list(self._filter(iterator)) + + def getiterator(self, tag=None): + iterator = super(RestrictedElement, self).getiterator(tag) + return self._filter(iterator) + + +class GlobalParserTLS(threading.local): + """Thread local context for custom parser instances + """ + parser_config = { + 'resolve_entities': False, + # 'remove_comments': True, + # 'remove_pis': True, + } + + element_class = RestrictedElement + + def createDefaultParser(self): + parser = _etree.XMLParser(**self.parser_config) + element_class = self.element_class + if self.element_class is not None: + lookup = _etree.ElementDefaultClassLookup(element=element_class) + parser.set_element_class_lookup(lookup) + return parser + + def setDefaultParser(self, parser): + self._default_parser = parser + + def getDefaultParser(self): + parser = getattr(self, "_default_parser", None) + if parser is None: + parser = self.createDefaultParser() + self.setDefaultParser(parser) + return parser + + +_parser_tls = GlobalParserTLS() +getDefaultParser = _parser_tls.getDefaultParser + + +def check_docinfo(elementtree, forbid_dtd=False, forbid_entities=True): + """Check docinfo of an element tree for DTD and entity declarations + + The check for entity declarations needs lxml 3 or newer. lxml 2.x does + not support dtd.iterentities(). 
+ """ + docinfo = elementtree.docinfo + if docinfo.doctype: + if forbid_dtd: + raise DTDForbidden(docinfo.doctype, + docinfo.system_url, + docinfo.public_id) + if forbid_entities and not LXML3: + # lxml < 3 has no iterentities() + raise NotSupportedError("Unable to check for entity declarations " + "in lxml 2.x") + + if forbid_entities: + for dtd in docinfo.internalDTD, docinfo.externalDTD: + if dtd is None: + continue + for entity in dtd.iterentities(): + raise EntitiesForbidden(entity.name, entity.content, None, + None, None, None) + + +def parse(source, parser=None, base_url=None, forbid_dtd=False, + forbid_entities=True): + if parser is None: + parser = getDefaultParser() + elementtree = _etree.parse(source, parser, base_url=base_url) + check_docinfo(elementtree, forbid_dtd, forbid_entities) + return elementtree + + +def fromstring(text, parser=None, base_url=None, forbid_dtd=False, + forbid_entities=True): + if parser is None: + parser = getDefaultParser() + rootelement = _etree.fromstring(text, parser, base_url=base_url) + elementtree = rootelement.getroottree() + check_docinfo(elementtree, forbid_dtd, forbid_entities) + return rootelement + + +XML = fromstring + + +def iterparse(*args, **kwargs): + raise NotSupportedError("defused lxml.etree.iterparse not available") diff --git a/python/defusedxml/minidom.py b/python/defusedxml/minidom.py new file mode 100644 index 0000000..0fd8684 --- /dev/null +++ b/python/defusedxml/minidom.py @@ -0,0 +1,42 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xml.dom.minidom +""" +from __future__ import print_function, absolute_import + +from xml.dom.minidom import _do_pulldom_parse +from . import expatbuilder as _expatbuilder +from . import pulldom as _pulldom + +__origin__ = "xml.dom.minidom" + + +def parse(file, parser=None, bufsize=None, forbid_dtd=False, + forbid_entities=True, forbid_external=True): + """Parse a file into a DOM by filename or file object.""" + if parser is None and not bufsize: + return _expatbuilder.parse(file, forbid_dtd=forbid_dtd, + forbid_entities=forbid_entities, + forbid_external=forbid_external) + else: + return _do_pulldom_parse(_pulldom.parse, (file,), + {'parser': parser, 'bufsize': bufsize, + 'forbid_dtd': forbid_dtd, 'forbid_entities': forbid_entities, + 'forbid_external': forbid_external}) + + +def parseString(string, parser=None, forbid_dtd=False, + forbid_entities=True, forbid_external=True): + """Parse a file into a DOM from a string.""" + if parser is None: + return _expatbuilder.parseString(string, forbid_dtd=forbid_dtd, + forbid_entities=forbid_entities, + forbid_external=forbid_external) + else: + return _do_pulldom_parse(_pulldom.parseString, (string,), + {'parser': parser, 'forbid_dtd': forbid_dtd, + 'forbid_entities': forbid_entities, + 'forbid_external': forbid_external}) diff --git a/python/defusedxml/pulldom.py b/python/defusedxml/pulldom.py new file mode 100644 index 0000000..fc9e466 --- /dev/null +++ b/python/defusedxml/pulldom.py @@ -0,0 +1,34 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. 
+"""Defused xml.dom.pulldom +""" +from __future__ import print_function, absolute_import + +from xml.dom.pulldom import parse as _parse +from xml.dom.pulldom import parseString as _parseString +from .sax import make_parser + +__origin__ = "xml.dom.pulldom" + + +def parse(stream_or_string, parser=None, bufsize=None, forbid_dtd=False, + forbid_entities=True, forbid_external=True): + if parser is None: + parser = make_parser() + parser.forbid_dtd = forbid_dtd + parser.forbid_entities = forbid_entities + parser.forbid_external = forbid_external + return _parse(stream_or_string, parser, bufsize) + + +def parseString(string, parser=None, forbid_dtd=False, + forbid_entities=True, forbid_external=True): + if parser is None: + parser = make_parser() + parser.forbid_dtd = forbid_dtd + parser.forbid_entities = forbid_entities + parser.forbid_external = forbid_external + return _parseString(string, parser) diff --git a/python/defusedxml/sax.py b/python/defusedxml/sax.py new file mode 100644 index 0000000..534d0ca --- /dev/null +++ b/python/defusedxml/sax.py @@ -0,0 +1,49 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xml.sax +""" +from __future__ import print_function, absolute_import + +from xml.sax import InputSource as _InputSource +from xml.sax import ErrorHandler as _ErrorHandler + +from . import expatreader + +__origin__ = "xml.sax" + + +def parse(source, handler, errorHandler=_ErrorHandler(), forbid_dtd=False, + forbid_entities=True, forbid_external=True): + parser = make_parser() + parser.setContentHandler(handler) + parser.setErrorHandler(errorHandler) + parser.forbid_dtd = forbid_dtd + parser.forbid_entities = forbid_entities + parser.forbid_external = forbid_external + parser.parse(source) + + +def parseString(string, handler, errorHandler=_ErrorHandler(), + forbid_dtd=False, forbid_entities=True, + forbid_external=True): + from io import BytesIO + + if errorHandler is None: + errorHandler = _ErrorHandler() + parser = make_parser() + parser.setContentHandler(handler) + parser.setErrorHandler(errorHandler) + parser.forbid_dtd = forbid_dtd + parser.forbid_entities = forbid_entities + parser.forbid_external = forbid_external + + inpsrc = _InputSource() + inpsrc.setByteStream(BytesIO(string)) + parser.parse(inpsrc) + + +def make_parser(parser_list=[]): + return expatreader.create_parser() diff --git a/python/defusedxml/xmlrpc.py b/python/defusedxml/xmlrpc.py new file mode 100644 index 0000000..2a456e6 --- /dev/null +++ b/python/defusedxml/xmlrpc.py @@ -0,0 +1,157 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. 
+"""Defused xmlrpclib + +Also defuses gzip bomb +""" +from __future__ import print_function, absolute_import + +import io + +from .common import ( + DTDForbidden, EntitiesForbidden, ExternalReferenceForbidden, PY3) + +if PY3: + __origin__ = "xmlrpc.client" + from xmlrpc.client import ExpatParser + from xmlrpc import client as xmlrpc_client + from xmlrpc import server as xmlrpc_server + from xmlrpc.client import gzip_decode as _orig_gzip_decode + from xmlrpc.client import GzipDecodedResponse as _OrigGzipDecodedResponse +else: + __origin__ = "xmlrpclib" + from xmlrpclib import ExpatParser + import xmlrpclib as xmlrpc_client + xmlrpc_server = None + from xmlrpclib import gzip_decode as _orig_gzip_decode + from xmlrpclib import GzipDecodedResponse as _OrigGzipDecodedResponse + +try: + import gzip +except ImportError: + gzip = None + + +# Limit maximum request size to prevent resource exhaustion DoS +# Also used to limit maximum amount of gzip decoded data in order to prevent +# decompression bombs +# A value of -1 or smaller disables the limit +MAX_DATA = 30 * 1024 * 1024 # 30 MB + + +def defused_gzip_decode(data, limit=None): + """gzip encoded data -> unencoded data + + Decode data using the gzip content encoding as described in RFC 1952 + """ + if not gzip: + raise NotImplementedError + if limit is None: + limit = MAX_DATA + f = io.BytesIO(data) + gzf = gzip.GzipFile(mode="rb", fileobj=f) + try: + if limit < 0: # no limit + decoded = gzf.read() + else: + decoded = gzf.read(limit + 1) + except IOError: + raise ValueError("invalid data") + f.close() + gzf.close() + if limit >= 0 and len(decoded) > limit: + raise ValueError("max gzipped payload length exceeded") + return decoded + + +class DefusedGzipDecodedResponse(gzip.GzipFile if gzip else object): + """a file-like object to decode a response encoded with the gzip + method, as described in RFC 1952. 
+ """ + + def __init__(self, response, limit=None): + # response doesn't support tell() and read(), required by + # GzipFile + if not gzip: + raise NotImplementedError + self.limit = limit = limit if limit is not None else MAX_DATA + if limit < 0: # no limit + data = response.read() + self.readlength = None + else: + data = response.read(limit + 1) + self.readlength = 0 + if limit >= 0 and len(data) > limit: + raise ValueError("max payload length exceeded") + self.stringio = io.BytesIO(data) + gzip.GzipFile.__init__(self, mode="rb", fileobj=self.stringio) + + def read(self, n): + if self.limit >= 0: + left = self.limit - self.readlength + n = min(n, left + 1) + data = gzip.GzipFile.read(self, n) + self.readlength += len(data) + if self.readlength > self.limit: + raise ValueError("max payload length exceeded") + return data + else: + return gzip.GzipFile.read(self, n) + + def close(self): + gzip.GzipFile.close(self) + self.stringio.close() + + +class DefusedExpatParser(ExpatParser): + + def __init__(self, target, forbid_dtd=False, forbid_entities=True, + forbid_external=True): + ExpatParser.__init__(self, target) + self.forbid_dtd = forbid_dtd + self.forbid_entities = forbid_entities + self.forbid_external = forbid_external + parser = self._parser + if self.forbid_dtd: + parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl + if self.forbid_entities: + parser.EntityDeclHandler = self.defused_entity_decl + parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl + if self.forbid_external: + parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler + + def defused_start_doctype_decl(self, name, sysid, pubid, + has_internal_subset): + raise DTDForbidden(name, sysid, pubid) + + def defused_entity_decl(self, name, is_parameter_entity, value, base, + sysid, pubid, notation_name): + raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name) + + def defused_unparsed_entity_decl(self, name, base, sysid, pubid, + notation_name): + # expat 1.2 + raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name) + + def defused_external_entity_ref_handler(self, context, base, sysid, + pubid): + raise ExternalReferenceForbidden(context, base, sysid, pubid) + + +def monkey_patch(): + xmlrpc_client.FastParser = DefusedExpatParser + xmlrpc_client.GzipDecodedResponse = DefusedGzipDecodedResponse + xmlrpc_client.gzip_decode = defused_gzip_decode + if xmlrpc_server: + xmlrpc_server.gzip_decode = defused_gzip_decode + + +def unmonkey_patch(): + xmlrpc_client.FastParser = None + xmlrpc_client.GzipDecodedResponse = _OrigGzipDecodedResponse + xmlrpc_client.gzip_decode = _orig_gzip_decode + if xmlrpc_server: + xmlrpc_server.gzip_decode = _orig_gzip_decode diff --git a/python/six.py b/python/six.py new file mode 100644 index 0000000..6bf4fd3 --- /dev/null +++ b/python/six.py @@ -0,0 +1,891 @@ +# Copyright (c) 2010-2017 Benjamin Peterson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. 
+# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Utilities for writing code that runs on Python 2 and 3""" + +from __future__ import absolute_import + +import functools +import itertools +import operator +import sys +import types + +__author__ = "Benjamin Peterson <benjamin@python.org>" +__version__ = "1.11.0" + + +# Useful for very coarse version differentiation. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 +PY34 = sys.version_info[0:2] >= (3, 4) + +if PY3: + string_types = str, + integer_types = int, + class_types = type, + text_type = str + binary_type = bytes + + MAXSIZE = sys.maxsize +else: + string_types = basestring, + integer_types = (int, long) + class_types = (type, types.ClassType) + text_type = unicode + binary_type = str + + if sys.platform.startswith("java"): + # Jython always uses 32 bits. + MAXSIZE = int((1 << 31) - 1) + else: + # It's possible to have sizeof(long) != sizeof(Py_ssize_t). + class X(object): + + def __len__(self): + return 1 << 31 + try: + len(X()) + except OverflowError: + # 32-bit + MAXSIZE = int((1 << 31) - 1) + else: + # 64-bit + MAXSIZE = int((1 << 63) - 1) + del X + + +def _add_doc(func, doc): + """Add documentation to a function.""" + func.__doc__ = doc + + +def _import_module(name): + """Import module, returning the module after the last dot.""" + __import__(name) + return sys.modules[name] + + +class _LazyDescr(object): + + def __init__(self, name): + self.name = name + + def __get__(self, obj, tp): + result = self._resolve() + setattr(obj, self.name, result) # Invokes __set__. + try: + # This is a bit ugly, but it avoids running this again by + # removing this descriptor. 
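+            # Once resolved, the value sits in the owner's __dict__, so the
+            # descriptor on the class is dead weight; deleting it can fail
+            # with AttributeError (e.g. when it was already removed), which
+            # is why the handler below simply passes.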
+ delattr(obj.__class__, self.name) + except AttributeError: + pass + return result + + +class MovedModule(_LazyDescr): + + def __init__(self, name, old, new=None): + super(MovedModule, self).__init__(name) + if PY3: + if new is None: + new = name + self.mod = new + else: + self.mod = old + + def _resolve(self): + return _import_module(self.mod) + + def __getattr__(self, attr): + _module = self._resolve() + value = getattr(_module, attr) + setattr(self, attr, value) + return value + + +class _LazyModule(types.ModuleType): + + def __init__(self, name): + super(_LazyModule, self).__init__(name) + self.__doc__ = self.__class__.__doc__ + + def __dir__(self): + attrs = ["__doc__", "__name__"] + attrs += [attr.name for attr in self._moved_attributes] + return attrs + + # Subclasses should override this + _moved_attributes = [] + + +class MovedAttribute(_LazyDescr): + + def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): + super(MovedAttribute, self).__init__(name) + if PY3: + if new_mod is None: + new_mod = name + self.mod = new_mod + if new_attr is None: + if old_attr is None: + new_attr = name + else: + new_attr = old_attr + self.attr = new_attr + else: + self.mod = old_mod + if old_attr is None: + old_attr = name + self.attr = old_attr + + def _resolve(self): + module = _import_module(self.mod) + return getattr(module, self.attr) + + +class _SixMetaPathImporter(object): + + """ + A meta path importer to import six.moves and its submodules. + + This class implements a PEP302 finder and loader. It should be compatible + with Python 2.5 and all existing versions of Python3 + """ + + def __init__(self, six_module_name): + self.name = six_module_name + self.known_modules = {} + + def _add_module(self, mod, *fullnames): + for fullname in fullnames: + self.known_modules[self.name + "." + fullname] = mod + + def _get_module(self, fullname): + return self.known_modules[self.name + "." + fullname] + + def find_module(self, fullname, path=None): + if fullname in self.known_modules: + return self + return None + + def __get_module(self, fullname): + try: + return self.known_modules[fullname] + except KeyError: + raise ImportError("This loader does not know module " + fullname) + + def load_module(self, fullname): + try: + # in case of a reload + return sys.modules[fullname] + except KeyError: + pass + mod = self.__get_module(fullname) + if isinstance(mod, MovedModule): + mod = mod._resolve() + else: + mod.__loader__ = self + sys.modules[fullname] = mod + return mod + + def is_package(self, fullname): + """ + Return true, if the named module is a package. 
+ + We need this method to get correct spec objects with + Python 3.4 (see PEP451) + """ + return hasattr(self.__get_module(fullname), "__path__") + + def get_code(self, fullname): + """Return None + + Required, if is_package is implemented""" + self.__get_module(fullname) # eventually raises ImportError + return None + get_source = get_code # same as get_code + +_importer = _SixMetaPathImporter(__name__) + + +class _MovedItems(_LazyModule): + + """Lazy loading of moved objects""" + __path__ = [] # mark as package + + +_moved_attributes = [ + MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), + MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), + MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), + MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), + MovedAttribute("intern", "__builtin__", "sys"), + MovedAttribute("map", "itertools", "builtins", "imap", "map"), + MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), + MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), + MovedAttribute("getoutput", "commands", "subprocess"), + MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), + MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), + MovedAttribute("reduce", "__builtin__", "functools"), + MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), + MovedAttribute("StringIO", "StringIO", "io"), + MovedAttribute("UserDict", "UserDict", "collections"), + MovedAttribute("UserList", "UserList", "collections"), + MovedAttribute("UserString", "UserString", "collections"), + MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), + MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), + MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), + MovedModule("builtins", "__builtin__"), + MovedModule("configparser", "ConfigParser"), + MovedModule("copyreg", "copy_reg"), + MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), + MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"), + MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), + MovedModule("http_cookies", "Cookie", "http.cookies"), + MovedModule("html_entities", "htmlentitydefs", "html.entities"), + MovedModule("html_parser", "HTMLParser", "html.parser"), + MovedModule("http_client", "httplib", "http.client"), + MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), + MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"), + MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), + MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), + MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), + MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), + MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), + MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), + MovedModule("cPickle", "cPickle", "pickle"), + MovedModule("queue", "Queue"), + MovedModule("reprlib", "repr"), + MovedModule("socketserver", "SocketServer"), + MovedModule("_thread", "thread", "_thread"), + MovedModule("tkinter", "Tkinter"), + MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), + MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), + MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), + MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), + 
MovedModule("tkinter_tix", "Tix", "tkinter.tix"), + MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), + MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), + MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), + MovedModule("tkinter_colorchooser", "tkColorChooser", + "tkinter.colorchooser"), + MovedModule("tkinter_commondialog", "tkCommonDialog", + "tkinter.commondialog"), + MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), + MovedModule("tkinter_font", "tkFont", "tkinter.font"), + MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), + MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", + "tkinter.simpledialog"), + MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), + MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), + MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), + MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), + MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), + MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), +] +# Add windows specific modules. +if sys.platform == "win32": + _moved_attributes += [ + MovedModule("winreg", "_winreg"), + ] + +for attr in _moved_attributes: + setattr(_MovedItems, attr.name, attr) + if isinstance(attr, MovedModule): + _importer._add_module(attr, "moves." + attr.name) +del attr + +_MovedItems._moved_attributes = _moved_attributes + +moves = _MovedItems(__name__ + ".moves") +_importer._add_module(moves, "moves") + + +class Module_six_moves_urllib_parse(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_parse""" + + +_urllib_parse_moved_attributes = [ + MovedAttribute("ParseResult", "urlparse", "urllib.parse"), + MovedAttribute("SplitResult", "urlparse", "urllib.parse"), + MovedAttribute("parse_qs", "urlparse", "urllib.parse"), + MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), + MovedAttribute("urldefrag", "urlparse", "urllib.parse"), + MovedAttribute("urljoin", "urlparse", "urllib.parse"), + MovedAttribute("urlparse", "urlparse", "urllib.parse"), + MovedAttribute("urlsplit", "urlparse", "urllib.parse"), + MovedAttribute("urlunparse", "urlparse", "urllib.parse"), + MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), + MovedAttribute("quote", "urllib", "urllib.parse"), + MovedAttribute("quote_plus", "urllib", "urllib.parse"), + MovedAttribute("unquote", "urllib", "urllib.parse"), + MovedAttribute("unquote_plus", "urllib", "urllib.parse"), + MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"), + MovedAttribute("urlencode", "urllib", "urllib.parse"), + MovedAttribute("splitquery", "urllib", "urllib.parse"), + MovedAttribute("splittag", "urllib", "urllib.parse"), + MovedAttribute("splituser", "urllib", "urllib.parse"), + MovedAttribute("splitvalue", "urllib", "urllib.parse"), + MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), + MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), + MovedAttribute("uses_params", "urlparse", "urllib.parse"), + MovedAttribute("uses_query", "urlparse", "urllib.parse"), + MovedAttribute("uses_relative", "urlparse", "urllib.parse"), +] +for attr in _urllib_parse_moved_attributes: + setattr(Module_six_moves_urllib_parse, attr.name, attr) +del attr + +Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes + +_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), + 
"moves.urllib_parse", "moves.urllib.parse") + + +class Module_six_moves_urllib_error(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_error""" + + +_urllib_error_moved_attributes = [ + MovedAttribute("URLError", "urllib2", "urllib.error"), + MovedAttribute("HTTPError", "urllib2", "urllib.error"), + MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), +] +for attr in _urllib_error_moved_attributes: + setattr(Module_six_moves_urllib_error, attr.name, attr) +del attr + +Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes + +_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), + "moves.urllib_error", "moves.urllib.error") + + +class Module_six_moves_urllib_request(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_request""" + + +_urllib_request_moved_attributes = [ + MovedAttribute("urlopen", "urllib2", "urllib.request"), + MovedAttribute("install_opener", "urllib2", "urllib.request"), + MovedAttribute("build_opener", "urllib2", "urllib.request"), + MovedAttribute("pathname2url", "urllib", "urllib.request"), + MovedAttribute("url2pathname", "urllib", "urllib.request"), + MovedAttribute("getproxies", "urllib", "urllib.request"), + MovedAttribute("Request", "urllib2", "urllib.request"), + MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), + MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), + MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), + MovedAttribute("BaseHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), + MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), + MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), + MovedAttribute("FileHandler", "urllib2", "urllib.request"), + MovedAttribute("FTPHandler", "urllib2", "urllib.request"), + MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), + MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), + MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), + MovedAttribute("urlretrieve", "urllib", "urllib.request"), + MovedAttribute("urlcleanup", "urllib", "urllib.request"), + MovedAttribute("URLopener", "urllib", "urllib.request"), + MovedAttribute("FancyURLopener", "urllib", "urllib.request"), + MovedAttribute("proxy_bypass", "urllib", "urllib.request"), + MovedAttribute("parse_http_list", "urllib2", "urllib.request"), + MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"), +] +for attr in _urllib_request_moved_attributes: + setattr(Module_six_moves_urllib_request, attr.name, attr) +del attr + +Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes + +_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), + "moves.urllib_request", "moves.urllib.request") + + 
+class Module_six_moves_urllib_response(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_response""" + + +_urllib_response_moved_attributes = [ + MovedAttribute("addbase", "urllib", "urllib.response"), + MovedAttribute("addclosehook", "urllib", "urllib.response"), + MovedAttribute("addinfo", "urllib", "urllib.response"), + MovedAttribute("addinfourl", "urllib", "urllib.response"), +] +for attr in _urllib_response_moved_attributes: + setattr(Module_six_moves_urllib_response, attr.name, attr) +del attr + +Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes + +_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), + "moves.urllib_response", "moves.urllib.response") + + +class Module_six_moves_urllib_robotparser(_LazyModule): + + """Lazy loading of moved objects in six.moves.urllib_robotparser""" + + +_urllib_robotparser_moved_attributes = [ + MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), +] +for attr in _urllib_robotparser_moved_attributes: + setattr(Module_six_moves_urllib_robotparser, attr.name, attr) +del attr + +Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes + +_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), + "moves.urllib_robotparser", "moves.urllib.robotparser") + + +class Module_six_moves_urllib(types.ModuleType): + + """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" + __path__ = [] # mark as package + parse = _importer._get_module("moves.urllib_parse") + error = _importer._get_module("moves.urllib_error") + request = _importer._get_module("moves.urllib_request") + response = _importer._get_module("moves.urllib_response") + robotparser = _importer._get_module("moves.urllib_robotparser") + + def __dir__(self): + return ['parse', 'error', 'request', 'response', 'robotparser'] + +_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), + "moves.urllib") + + +def add_move(move): + """Add an item to six.moves.""" + setattr(_MovedItems, move.name, move) + + +def remove_move(name): + """Remove item from six.moves.""" + try: + delattr(_MovedItems, name) + except AttributeError: + try: + del moves.__dict__[name] + except KeyError: + raise AttributeError("no such move, %r" % (name,)) + + +if PY3: + _meth_func = "__func__" + _meth_self = "__self__" + + _func_closure = "__closure__" + _func_code = "__code__" + _func_defaults = "__defaults__" + _func_globals = "__globals__" +else: + _meth_func = "im_func" + _meth_self = "im_self" + + _func_closure = "func_closure" + _func_code = "func_code" + _func_defaults = "func_defaults" + _func_globals = "func_globals" + + +try: + advance_iterator = next +except NameError: + def advance_iterator(it): + return it.next() +next = advance_iterator + + +try: + callable = callable +except NameError: + def callable(obj): + return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) + + +if PY3: + def get_unbound_function(unbound): + return unbound + + create_bound_method = types.MethodType + + def create_unbound_method(func, cls): + return func + + Iterator = object +else: + def get_unbound_function(unbound): + return unbound.im_func + + def create_bound_method(func, obj): + return types.MethodType(func, obj, obj.__class__) + + def create_unbound_method(func, cls): + return types.MethodType(func, None, cls) + + class Iterator(object): + + def next(self): + return type(self).__next__(self) + + 
callable = callable +_add_doc(get_unbound_function, + """Get the function out of a possibly unbound function""") + + +get_method_function = operator.attrgetter(_meth_func) +get_method_self = operator.attrgetter(_meth_self) +get_function_closure = operator.attrgetter(_func_closure) +get_function_code = operator.attrgetter(_func_code) +get_function_defaults = operator.attrgetter(_func_defaults) +get_function_globals = operator.attrgetter(_func_globals) + + +if PY3: + def iterkeys(d, **kw): + return iter(d.keys(**kw)) + + def itervalues(d, **kw): + return iter(d.values(**kw)) + + def iteritems(d, **kw): + return iter(d.items(**kw)) + + def iterlists(d, **kw): + return iter(d.lists(**kw)) + + viewkeys = operator.methodcaller("keys") + + viewvalues = operator.methodcaller("values") + + viewitems = operator.methodcaller("items") +else: + def iterkeys(d, **kw): + return d.iterkeys(**kw) + + def itervalues(d, **kw): + return d.itervalues(**kw) + + def iteritems(d, **kw): + return d.iteritems(**kw) + + def iterlists(d, **kw): + return d.iterlists(**kw) + + viewkeys = operator.methodcaller("viewkeys") + + viewvalues = operator.methodcaller("viewvalues") + + viewitems = operator.methodcaller("viewitems") + +_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") +_add_doc(itervalues, "Return an iterator over the values of a dictionary.") +_add_doc(iteritems, + "Return an iterator over the (key, value) pairs of a dictionary.") +_add_doc(iterlists, + "Return an iterator over the (key, [values]) pairs of a dictionary.") + + +if PY3: + def b(s): + return s.encode("latin-1") + + def u(s): + return s + unichr = chr + import struct + int2byte = struct.Struct(">B").pack + del struct + byte2int = operator.itemgetter(0) + indexbytes = operator.getitem + iterbytes = iter + import io + StringIO = io.StringIO + BytesIO = io.BytesIO + _assertCountEqual = "assertCountEqual" + if sys.version_info[1] <= 1: + _assertRaisesRegex = "assertRaisesRegexp" + _assertRegex = "assertRegexpMatches" + else: + _assertRaisesRegex = "assertRaisesRegex" + _assertRegex = "assertRegex" +else: + def b(s): + return s + # Workaround for standalone backslash + + def u(s): + return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") + unichr = unichr + int2byte = chr + + def byte2int(bs): + return ord(bs[0]) + + def indexbytes(buf, i): + return ord(buf[i]) + iterbytes = functools.partial(itertools.imap, ord) + import StringIO + StringIO = BytesIO = StringIO.StringIO + _assertCountEqual = "assertItemsEqual" + _assertRaisesRegex = "assertRaisesRegexp" + _assertRegex = "assertRegexpMatches" +_add_doc(b, """Byte literal""") +_add_doc(u, """Text literal""") + + +def assertCountEqual(self, *args, **kwargs): + return getattr(self, _assertCountEqual)(*args, **kwargs) + + +def assertRaisesRegex(self, *args, **kwargs): + return getattr(self, _assertRaisesRegex)(*args, **kwargs) + + +def assertRegex(self, *args, **kwargs): + return getattr(self, _assertRegex)(*args, **kwargs) + + +if PY3: + exec_ = getattr(moves.builtins, "exec") + + def reraise(tp, value, tb=None): + try: + if value is None: + value = tp() + if value.__traceback__ is not tb: + raise value.with_traceback(tb) + raise value + finally: + value = None + tb = None + +else: + def exec_(_code_, _globs_=None, _locs_=None): + """Execute code in a namespace.""" + if _globs_ is None: + frame = sys._getframe(1) + _globs_ = frame.f_globals + if _locs_ is None: + _locs_ = frame.f_locals + del frame + elif _locs_ is None: + _locs_ = _globs_ + exec("""exec _code_ in _globs_, 
_locs_""") + + exec_("""def reraise(tp, value, tb=None): + try: + raise tp, value, tb + finally: + tb = None +""") + + +if sys.version_info[:2] == (3, 2): + exec_("""def raise_from(value, from_value): + try: + if from_value is None: + raise value + raise value from from_value + finally: + value = None +""") +elif sys.version_info[:2] > (3, 2): + exec_("""def raise_from(value, from_value): + try: + raise value from from_value + finally: + value = None +""") +else: + def raise_from(value, from_value): + raise value + + +print_ = getattr(moves.builtins, "print", None) +if print_ is None: + def print_(*args, **kwargs): + """The new-style print function for Python 2.4 and 2.5.""" + fp = kwargs.pop("file", sys.stdout) + if fp is None: + return + + def write(data): + if not isinstance(data, basestring): + data = str(data) + # If the file has an encoding, encode unicode with it. + if (isinstance(fp, file) and + isinstance(data, unicode) and + fp.encoding is not None): + errors = getattr(fp, "errors", None) + if errors is None: + errors = "strict" + data = data.encode(fp.encoding, errors) + fp.write(data) + want_unicode = False + sep = kwargs.pop("sep", None) + if sep is not None: + if isinstance(sep, unicode): + want_unicode = True + elif not isinstance(sep, str): + raise TypeError("sep must be None or a string") + end = kwargs.pop("end", None) + if end is not None: + if isinstance(end, unicode): + want_unicode = True + elif not isinstance(end, str): + raise TypeError("end must be None or a string") + if kwargs: + raise TypeError("invalid keyword arguments to print()") + if not want_unicode: + for arg in args: + if isinstance(arg, unicode): + want_unicode = True + break + if want_unicode: + newline = unicode("\n") + space = unicode(" ") + else: + newline = "\n" + space = " " + if sep is None: + sep = space + if end is None: + end = newline + for i, arg in enumerate(args): + if i: + write(sep) + write(arg) + write(end) +if sys.version_info[:2] < (3, 3): + _print = print_ + + def print_(*args, **kwargs): + fp = kwargs.get("file", sys.stdout) + flush = kwargs.pop("flush", False) + _print(*args, **kwargs) + if flush and fp is not None: + fp.flush() + +_add_doc(reraise, """Reraise an exception.""") + +if sys.version_info[0:2] < (3, 4): + def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, + updated=functools.WRAPPER_UPDATES): + def wrapper(f): + f = functools.wraps(wrapped, assigned, updated)(f) + f.__wrapped__ = wrapped + return f + return wrapper +else: + wraps = functools.wraps + + +def with_metaclass(meta, *bases): + """Create a base class with a metaclass.""" + # This requires a bit of explanation: the basic idea is to make a dummy + # metaclass for one level of class instantiation that replaces itself with + # the actual metaclass. 
+ class metaclass(type): + + def __new__(cls, name, this_bases, d): + return meta(name, bases, d) + + @classmethod + def __prepare__(cls, name, this_bases): + return meta.__prepare__(name, bases) + return type.__new__(metaclass, 'temporary_class', (), {}) + + +def add_metaclass(metaclass): + """Class decorator for creating a class with a metaclass.""" + def wrapper(cls): + orig_vars = cls.__dict__.copy() + slots = orig_vars.get('__slots__') + if slots is not None: + if isinstance(slots, str): + slots = [slots] + for slots_var in slots: + orig_vars.pop(slots_var) + orig_vars.pop('__dict__', None) + orig_vars.pop('__weakref__', None) + return metaclass(cls.__name__, cls.__bases__, orig_vars) + return wrapper + + +def python_2_unicode_compatible(klass): + """ + A decorator that defines __unicode__ and __str__ methods under Python 2. + Under Python 3 it does nothing. + + To support Python 2 and 3 with a single code base, define a __str__ method + returning text and apply this decorator to the class. + """ + if PY2: + if '__str__' not in klass.__dict__: + raise ValueError("@python_2_unicode_compatible cannot be applied " + "to %s because it doesn't define __str__()." % + klass.__name__) + klass.__unicode__ = klass.__str__ + klass.__str__ = lambda self: self.__unicode__().encode('utf-8') + return klass + + +# Complete the moves implementation. +# This code is at the end of this module to speed up module loading. +# Turn this module into a package. +__path__ = [] # required for PEP 302 and PEP 451 +__package__ = __name__ # see PEP 366 @ReservedAssignment +if globals().get("__spec__") is not None: + __spec__.submodule_search_locations = [] # PEP 451 @UndefinedVariable +# Remove other six meta path importers, since they cause problems. This can +# happen if six is removed from sys.modules and then reloaded. (Setuptools does +# this for some reason.) +if sys.meta_path: + for i, importer in enumerate(sys.meta_path): + # Here's some real nastiness: Another "instance" of the six module might + # be floating around. Therefore, we can't use isinstance() to check for + # the six meta path importer, since the other six instance will have + # inserted an importer with different class. + if (type(importer).__name__ == "_SixMetaPathImporter" and + importer.name == __name__): + del sys.meta_path[i] + break + del i, importer +# Finally, add the importer to the meta path import hook. 
+sys.meta_path.append(_importer)
@@ -6,7 +6,7 @@ from youtube import yt_app
 from youtube import util
 # these are just so the files get run - they import yt_app and add routes to it
-from youtube import watch, search, playlist, channel, local_playlist, comments, post_comment
+from youtube import watch, search, playlist, channel, local_playlist, comments, post_comment, subscriptions
 import settings
diff --git a/settings.py b/settings.py
index 15d2e01..bd6710c 100644
--- a/settings.py
+++ b/settings.py
@@ -119,6 +119,12 @@ For security reasons, enabling this is not recommended.''',
         ],
     }),

+    ('autocheck_subscriptions', {
+        'type': bool,
+        'default': False,
+        'comment': '',
+    }),
+
     ('gather_googlevideo_domains', {
         'type': bool,
         'default': False,
diff --git a/youtube/channel.py b/youtube/channel.py
index 4c7d380..de75eaa 100644
--- a/youtube/channel.py
+++ b/youtube/channel.py
@@ -1,5 +1,5 @@
 import base64
-from youtube import util, yt_data_extract, local_playlist
+from youtube import util, yt_data_extract, local_playlist, subscriptions
 from youtube import yt_app
 import urllib
@@ -83,13 +83,15 @@ def channel_ctoken(channel_id, page, sort, tab, view=1):
     return base64.urlsafe_b64encode(pointless_nest).decode('ascii')

-def get_channel_tab(channel_id, page="1", sort=3, tab='videos', view=1):
+def get_channel_tab(channel_id, page="1", sort=3, tab='videos', view=1, print_status=True):
     ctoken = channel_ctoken(channel_id, page, sort, tab, view).replace('=', '%3D')
     url = "https://www.youtube.com/browse_ajax?ctoken=" + ctoken

-    print("Sending channel tab ajax request")
+    if print_status:
+        print("Sending channel tab ajax request")
     content = util.fetch_url(url, util.desktop_ua + headers_1, debug_name='channel_tab')
-    print("Finished recieving channel tab response")
+    if print_status:
+        print("Finished receiving channel tab response")

     return content
@@ -312,7 +314,7 @@ def get_channel_page(channel_id, tab='videos'):
         info['current_sort'] = sort
     elif tab == 'search':
         info['search_box_value'] = query
-
+    info['subscribed'] = subscriptions.is_subscribed(info['channel_id'])
     return flask.render_template('channel.html',
         parameters_dictionary = request.args,
@@ -352,7 +354,7 @@ def get_channel_page_general_url(base_url, tab, request):
         info['current_sort'] = sort
     elif tab == 'search':
         info['search_box_value'] = query
-
+    info['subscribed'] = subscriptions.is_subscribed(info['channel_id'])
     return flask.render_template('channel.html',
         parameters_dictionary = request.args,
diff --git a/youtube/local_playlist.py b/youtube/local_playlist.py
index 4b92315..2375ba2 100644
--- a/youtube/local_playlist.py
+++ b/youtube/local_playlist.py
@@ -34,33 +34,7 @@ def add_to_playlist(name, video_info_list):
             if id not in ids:
                 file.write(info + "\n")
                 missing_thumbnails.append(id)
-    gevent.spawn(download_thumbnails, name, missing_thumbnails)
-
-def download_thumbnail(playlist_name, video_id):
-    url = "https://i.ytimg.com/vi/" + video_id + "/mqdefault.jpg"
-    save_location = os.path.join(thumbnails_directory, playlist_name, video_id + ".jpg")
-    try:
-        thumbnail = util.fetch_url(url, report_text="Saved local playlist thumbnail: " + video_id)
-    except urllib.error.HTTPError as e:
-        print("Failed to download thumbnail for " + video_id + ": " + str(e))
-        return
-    try:
-        f = open(save_location, 'wb')
-    except FileNotFoundError:
-        os.makedirs(os.path.join(thumbnails_directory, playlist_name))
-        f = open(save_location, 'wb')
-    f.write(thumbnail)
-    f.close()
-
-def download_thumbnails(playlist_name, ids):
-    # only do 5 at a time
-    # do the n where n 
is divisible by 5 - i = -1 - for i in range(0, int(len(ids)/5) - 1 ): - gevent.joinall([gevent.spawn(download_thumbnail, playlist_name, ids[j]) for j in range(i*5, i*5 + 5)]) - # do the remainders (< 5) - gevent.joinall([gevent.spawn(download_thumbnail, playlist_name, ids[j]) for j in range(i*5 + 5, len(ids))]) - + gevent.spawn(util.download_thumbnails, os.path.join(thumbnails_directory, name), missing_thumbnails) def get_local_playlist_videos(name, offset=0, amount=50): @@ -88,7 +62,7 @@ def get_local_playlist_videos(name, offset=0, amount=50): except json.decoder.JSONDecodeError: if not video_json.strip() == '': print('Corrupt playlist video entry: ' + video_json) - gevent.spawn(download_thumbnails, name, missing_thumbnails) + gevent.spawn(util.download_thumbnails, os.path.join(thumbnails_directory, name), missing_thumbnails) return videos[offset:offset+amount], len(videos) def get_playlist_names(): diff --git a/youtube/static/dark_theme.css b/youtube/static/dark_theme.css index d75bb82..7389aa5 100644 --- a/youtube/static/dark_theme.css +++ b/youtube/static/dark_theme.css @@ -9,7 +9,7 @@ a:link { } a:visited { - color: ##7755ff; + color: #7755ff; } a:not([href]){ @@ -23,3 +23,16 @@ a:not([href]){ .setting-item{ background-color: #444444; } + + +.muted{ + background-color: #111111; + color: gray; +} + +.muted a:link { + color: #10547f; +} + + + diff --git a/youtube/static/gray_theme.css b/youtube/static/gray_theme.css index a064252..a5f3b23 100644 --- a/youtube/static/gray_theme.css +++ b/youtube/static/gray_theme.css @@ -11,3 +11,7 @@ body{ .setting-item{ background-color: #eeeeee; } + +.muted{ + background-color: #888888; +} diff --git a/youtube/static/light_theme.css b/youtube/static/light_theme.css index 50286bb..9e37449 100644 --- a/youtube/static/light_theme.css +++ b/youtube/static/light_theme.css @@ -8,7 +8,11 @@ body{ color: #000000; } - .setting-item{ background-color: #f8f8f8; } + +.muted{ + background-color: #888888; +} + diff --git a/youtube/static/shared.css b/youtube/static/shared.css index a79e42b..72d290a 100644 --- a/youtube/static/shared.css +++ b/youtube/static/shared.css @@ -1,7 +1,10 @@ +* { + box-sizing: border-box; +} + h1, h2, h3, h4, h5, h6, div, button{ margin:0; padding:0; - } address{ diff --git a/youtube/subscriptions.py b/youtube/subscriptions.py index 47f1ea3..2e821de 100644 --- a/youtube/subscriptions.py +++ b/youtube/subscriptions.py @@ -1,18 +1,828 @@ +from youtube import util, yt_data_extract, channel +from youtube import yt_app +import settings + +import sqlite3 +import os +import time +import gevent +import json +import traceback +import contextlib +import defusedxml.ElementTree import urllib +import math +import secrets +import collections +import calendar # bullshit! 
https://bugs.python.org/issue6280 + +import flask +from flask import request + + +thumbnails_directory = os.path.join(settings.data_dir, "subscription_thumbnails") + +# https://stackabuse.com/a-sqlite-tutorial-with-python/ + +database_path = os.path.join(settings.data_dir, "subscriptions.sqlite") + +def open_database(): + if not os.path.exists(settings.data_dir): + os.makedirs(settings.data_dir) + connection = sqlite3.connect(database_path, check_same_thread=False) + + try: + cursor = connection.cursor() + cursor.execute('''PRAGMA foreign_keys = 1''') + # Create tables if they don't exist + cursor.execute('''CREATE TABLE IF NOT EXISTS subscribed_channels ( + id integer PRIMARY KEY, + yt_channel_id text UNIQUE NOT NULL, + channel_name text NOT NULL, + time_last_checked integer DEFAULT 0, + next_check_time integer DEFAULT 0, + muted integer DEFAULT 0 + )''') + cursor.execute('''CREATE TABLE IF NOT EXISTS videos ( + id integer PRIMARY KEY, + sql_channel_id integer NOT NULL REFERENCES subscribed_channels(id) ON UPDATE CASCADE ON DELETE CASCADE, + video_id text UNIQUE NOT NULL, + title text NOT NULL, + duration text, + time_published integer NOT NULL, + is_time_published_exact integer DEFAULT 0, + time_noticed integer NOT NULL, + description text, + watched integer default 0 + )''') + cursor.execute('''CREATE TABLE IF NOT EXISTS tag_associations ( + id integer PRIMARY KEY, + tag text NOT NULL, + sql_channel_id integer NOT NULL REFERENCES subscribed_channels(id) ON UPDATE CASCADE ON DELETE CASCADE, + UNIQUE(tag, sql_channel_id) + )''') + cursor.execute('''CREATE TABLE IF NOT EXISTS db_info ( + version integer DEFAULT 1 + )''') + + connection.commit() + except: + connection.rollback() + connection.close() + raise + + # https://stackoverflow.com/questions/19522505/using-sqlite3-in-python-with-with-keyword + return contextlib.closing(connection) + +def with_open_db(function, *args, **kwargs): + with open_database() as connection: + with connection as cursor: + return function(cursor, *args, **kwargs) + +def _is_subscribed(cursor, channel_id): + result = cursor.execute('''SELECT EXISTS( + SELECT 1 + FROM subscribed_channels + WHERE yt_channel_id=? 
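+            -- EXISTS stops at the first matching row; LIMIT 1 just makes that explicit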
+ LIMIT 1 + )''', [channel_id]).fetchone() + return bool(result[0]) + +def is_subscribed(channel_id): + if not os.path.exists(database_path): + return False + + return with_open_db(_is_subscribed, channel_id) + +def _subscribe(channels): + ''' channels is a list of (channel_id, channel_name) ''' + channels = list(channels) + with open_database() as connection: + with connection as cursor: + channel_ids_to_check = [channel[0] for channel in channels if not _is_subscribed(cursor, channel[0])] + + rows = ( (channel_id, channel_name, 0, 0) for channel_id, channel_name in channels) + cursor.executemany('''INSERT OR IGNORE INTO subscribed_channels (yt_channel_id, channel_name, time_last_checked, next_check_time) + VALUES (?, ?, ?, ?)''', rows) + + if settings.autocheck_subscriptions: + # important that this is after the changes have been committed to database + # otherwise the autochecker (other thread) tries checking the channel before it's in the database + channel_names.update(channels) + check_channels_if_necessary(channel_ids_to_check) + +def delete_thumbnails(to_delete): + for thumbnail in to_delete: + try: + video_id = thumbnail[0:-4] + if video_id in existing_thumbnails: + os.remove(os.path.join(thumbnails_directory, thumbnail)) + existing_thumbnails.remove(video_id) + except Exception: + print('Failed to delete thumbnail: ' + thumbnail) + traceback.print_exc() + +def _unsubscribe(cursor, channel_ids): + ''' channel_ids is a list of channel_ids ''' + to_delete = [] + for channel_id in channel_ids: + rows = cursor.execute('''SELECT video_id + FROM videos + WHERE sql_channel_id = ( + SELECT id + FROM subscribed_channels + WHERE yt_channel_id=? + )''', (channel_id,)).fetchall() + to_delete += [row[0] + '.jpg' for row in rows] + + gevent.spawn(delete_thumbnails, to_delete) + cursor.executemany("DELETE FROM subscribed_channels WHERE yt_channel_id=?", ((channel_id, ) for channel_id in channel_ids)) + +def _get_videos(cursor, number_per_page, offset, tag = None): + '''Returns a full page of videos with an offset, and a value good enough to be used as the total number of videos''' + # We ask for the next 9 pages from the database + # Then the actual length of the results tell us if there are more than 9 pages left, and if not, how many there actually are + # This is done since there are only 9 page buttons on display at a time + # If there are more than 9 pages left, we give a fake value in place of the real number of results if the entire database was queried without limit + # This fake value is sufficient to get the page button generation macro to display 9 page buttons + # If we wish to display more buttons this logic must change + # We cannot use tricks with the sql id for the video since we frequently have filters and other restrictions in place on the results anyway + # TODO: This is probably not the ideal solution + if tag is not None: + db_videos = cursor.execute('''SELECT video_id, title, duration, time_published, is_time_published_exact, channel_name + FROM videos + INNER JOIN subscribed_channels on videos.sql_channel_id = subscribed_channels.id + INNER JOIN tag_associations on videos.sql_channel_id = tag_associations.sql_channel_id + WHERE tag = ? AND muted = 0 + ORDER BY time_noticed DESC, time_published DESC + LIMIT ? 
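+                   -- number_per_page*9: fetch up to nine pages ahead (see the comment above)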
OFFSET ?''', (tag, number_per_page*9, offset)).fetchall() + else: + db_videos = cursor.execute('''SELECT video_id, title, duration, time_published, is_time_published_exact, channel_name + FROM videos + INNER JOIN subscribed_channels on videos.sql_channel_id = subscribed_channels.id + WHERE muted = 0 + ORDER BY time_noticed DESC, time_published DESC + LIMIT ? OFFSET ?''', (number_per_page*9, offset)).fetchall() + + pseudo_number_of_videos = offset + len(db_videos) + + videos = [] + for db_video in db_videos[0:number_per_page]: + videos.append({ + 'id': db_video[0], + 'title': db_video[1], + 'duration': db_video[2], + 'published': exact_timestamp(db_video[3]) if db_video[4] else posix_to_dumbed_down(db_video[3]), + 'author': db_video[5], + }) + + return videos, pseudo_number_of_videos + + + + +def _get_subscribed_channels(cursor): + for item in cursor.execute('''SELECT channel_name, yt_channel_id, muted + FROM subscribed_channels + ORDER BY channel_name COLLATE NOCASE'''): + yield item + + +def _add_tags(cursor, channel_ids, tags): + pairs = [(tag, yt_channel_id) for tag in tags for yt_channel_id in channel_ids] + cursor.executemany('''INSERT OR IGNORE INTO tag_associations (tag, sql_channel_id) + SELECT ?, id FROM subscribed_channels WHERE yt_channel_id = ? ''', pairs) + + +def _remove_tags(cursor, channel_ids, tags): + pairs = [(tag, yt_channel_id) for tag in tags for yt_channel_id in channel_ids] + cursor.executemany('''DELETE FROM tag_associations + WHERE tag = ? AND sql_channel_id = ( + SELECT id FROM subscribed_channels WHERE yt_channel_id = ? + )''', pairs) + + + +def _get_tags(cursor, channel_id): + return [row[0] for row in cursor.execute('''SELECT tag + FROM tag_associations + WHERE sql_channel_id = ( + SELECT id FROM subscribed_channels WHERE yt_channel_id = ? + )''', (channel_id,))] + +def _get_all_tags(cursor): + return [row[0] for row in cursor.execute('''SELECT DISTINCT tag FROM tag_associations''')] + +def _get_channel_names(cursor, channel_ids): + ''' returns list of (channel_id, channel_name) ''' + result = [] + for channel_id in channel_ids: + row = cursor.execute('''SELECT channel_name + FROM subscribed_channels + WHERE yt_channel_id = ?''', (channel_id,)).fetchone() + result.append( (channel_id, row[0]) ) + return result + + +def _channels_with_tag(cursor, tag, order=False, exclude_muted=False, include_muted_status=False): + ''' returns list of (channel_id, channel_name) ''' + + statement = '''SELECT yt_channel_id, channel_name''' + + if include_muted_status: + statement += ''', muted''' + + statement += ''' + FROM subscribed_channels + WHERE subscribed_channels.id IN ( + SELECT tag_associations.sql_channel_id FROM tag_associations WHERE tag=? + ) + ''' + if exclude_muted: + statement += '''AND muted != 1\n''' + if order: + statement += '''ORDER BY channel_name COLLATE NOCASE''' + + return cursor.execute(statement, [tag]).fetchall() + +def _schedule_checking(cursor, channel_id, next_check_time): + cursor.execute('''UPDATE subscribed_channels SET next_check_time = ? 
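+                          -- POSIX seconds, same clock as time_last_checked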
WHERE yt_channel_id = ?''', [int(next_check_time), channel_id])
+
+def _is_muted(cursor, channel_id):
+    return bool(cursor.execute('''SELECT muted FROM subscribed_channels WHERE yt_channel_id=?''', [channel_id]).fetchone()[0])
+
+units = collections.OrderedDict([
+    ('year', 31536000),   # 365*24*3600
+    ('month', 2592000),   # 30*24*3600
+    ('week', 604800),     # 7*24*3600
+    ('day', 86400),       # 24*3600
+    ('hour', 3600),
+    ('minute', 60),
+    ('second', 1),
+])
+def youtube_timestamp_to_posix(dumb_timestamp):
+    ''' Given a dumbed down timestamp such as 1 year ago, 3 hours ago,
+        approximates the unix time (seconds since 1/1/1970) '''
+    dumb_timestamp = dumb_timestamp.lower()
+    now = time.time()
+    if dumb_timestamp == "just now":
+        return now
+    split = dumb_timestamp.split(' ')
+    quantifier, unit = int(split[0]), split[1]
+    if quantifier > 1:
+        unit = unit[:-1] # remove trailing s from the unit
+    return now - quantifier*units[unit]
+
+def posix_to_dumbed_down(posix_time):
+    '''Inverse of youtube_timestamp_to_posix.'''
+    delta = int(time.time() - posix_time)
+    assert delta >= 0
+
+    if delta == 0:
+        return '0 seconds ago'
+
+    for unit_name, unit_time in units.items():
+        if delta >= unit_time:
+            quantifier = round(delta/unit_time)
+            if quantifier == 1:
+                return '1 ' + unit_name + ' ago'
+            else:
+                return str(quantifier) + ' ' + unit_name + 's ago'
+    else:
+        raise Exception('no matching unit for delta: ' + str(delta))
+
+def exact_timestamp(posix_time):
+    result = time.strftime('%I:%M %p %m/%d/%y', time.localtime(posix_time))
+    if result[0] == '0': # remove the leading 0 from the hour (e.g. 01:00 PM)
+        return result[1:]
+    return result
+
+try:
+    existing_thumbnails = set(os.path.splitext(name)[0] for name in os.listdir(thumbnails_directory))
+except FileNotFoundError:
+    existing_thumbnails = set()
+
+
+# --- Manual checking system. Rate limited in order to support very large numbers of channels to be checked ---
+# Auto checking system plugs into this for convenience, though it doesn't really need the rate limiting
+
+check_channels_queue = util.RateLimitedQueue()
+checking_channels = set()
+
+# Only used for printing channel checking status to the console without opening the database
+channel_names = dict()
+
+def check_channel_worker():
+    while True:
+        channel_id = check_channels_queue.get()
+        try:
+            _get_upstream_videos(channel_id)
+        finally:
+            checking_channels.remove(channel_id)
+
+for i in range(0,5):
+    gevent.spawn(check_channel_worker)
+# ----------------------------
+
+
+
+# --- Auto checking system - Spaghetti code ---
+
+if settings.autocheck_subscriptions:
+    # job application format: dict with keys (channel_id, channel_name, next_check_time)
+    autocheck_job_application = gevent.queue.Queue() # only really meant to hold 1 item, just reusing gevent's wait and timeout machinery
+
+    autocheck_jobs = [] # list of dicts with the keys (channel_id, channel_name, next_check_time). Stores all the channels that need to be autochecked and when to check them
+    with open_database() as connection:
+        with connection as cursor:
+            now = time.time()
+            for row in cursor.execute('''SELECT yt_channel_id, channel_name, next_check_time FROM subscribed_channels WHERE muted != 1''').fetchall():
+
+                if row[2] is None:
+                    next_check_time = 0
+                else:
+                    next_check_time = row[2]
+
+                # expired, check randomly within the next hour
+                # note: even if the check isn't scheduled in the past right now, it might end up in the past if it's due soon and we don't start dispatching by then; see below where time_until_earliest_job is negative
+                if next_check_time < now:
+                    next_check_time = now + 3600*secrets.randbelow(60)/60
+                    row = (row[0], row[1], next_check_time)
+                    _schedule_checking(cursor, row[0], next_check_time)
+                autocheck_jobs.append({'channel_id': row[0], 'channel_name': row[1], 'next_check_time': next_check_time})
+
+
+
+    def autocheck_dispatcher():
+        '''Scans the auto_check_list. Sleeps until the earliest job is due, then adds that channel to the checking queue above. Can be sent a new job through autocheck_job_application'''
+        while True:
+            if len(autocheck_jobs) == 0:
+                new_job = autocheck_job_application.get()
+                autocheck_jobs.append(new_job)
+            else:
+                earliest_job_index = min(range(0, len(autocheck_jobs)), key=lambda index: autocheck_jobs[index]['next_check_time']) # https://stackoverflow.com/a/11825864
+                earliest_job = autocheck_jobs[earliest_job_index]
+                time_until_earliest_job = earliest_job['next_check_time'] - time.time()
+
+                if time_until_earliest_job <= -5: # should not happen unless we're running extremely slow
+                    print('ERROR: autocheck_dispatcher got job scheduled in the past, skipping and rescheduling: ' + earliest_job['channel_id'] + ', ' + earliest_job['channel_name'] + ', ' + str(earliest_job['next_check_time']))
+                    next_check_time = time.time() + 3600*secrets.randbelow(60)/60
+                    with_open_db(_schedule_checking, earliest_job['channel_id'], next_check_time)
+                    autocheck_jobs[earliest_job_index]['next_check_time'] = next_check_time
+                    continue
+
+                # make sure it's not muted
+                if with_open_db(_is_muted, earliest_job['channel_id']):
+                    del autocheck_jobs[earliest_job_index]
+                    continue
+
+                if time_until_earliest_job > 0: # it can become less than zero (in the past) when it's set to go off while the dispatcher is doing something else at that moment
+                    try:
+                        new_job = autocheck_job_application.get(timeout=time_until_earliest_job) # sleep for time_until_earliest_job, but allow interruption by new jobs
+                    except gevent.queue.Empty: # no new jobs
+                        pass
+                    else: # new job, add it to the list
+                        autocheck_jobs.append(new_job)
+                        continue
+
+                # no new jobs, time to execute the earliest job
+                channel_names[earliest_job['channel_id']] = earliest_job['channel_name']
+                checking_channels.add(earliest_job['channel_id'])
+                check_channels_queue.put(earliest_job['channel_id'])
+                del autocheck_jobs[earliest_job_index]
+
+
+    gevent.spawn(autocheck_dispatcher)
+# ----------------------------
+
+
+
+def check_channels_if_necessary(channel_ids):
+    for channel_id in channel_ids:
+        if channel_id not in checking_channels:
+            checking_channels.add(channel_id)
+            check_channels_queue.put(channel_id)

-with open("subscriptions.txt", 'r', encoding='utf-8') as file:
-    subscriptions = file.read()
-
-# Line format: "channel_id channel_name"
-# Example:
-# UCYO_jab_esuFRV4b17AJtAw 3Blue1Brown
-subscriptions = ((line[0:24], line[25: ]) for line in subscriptions.splitlines())
-def get_new_videos():
-    for channel_id, channel_name in subscriptions:
-
+def _get_upstream_videos(channel_id):
+    try:
+        channel_status_name = channel_names[channel_id]
+    except KeyError:
+        channel_status_name = channel_id
+    print("Checking channel: " + channel_status_name)
+    tasks = (
+        gevent.spawn(channel.get_channel_tab, channel_id, print_status=False), # channel page, needed for video durations
+        gevent.spawn(util.fetch_url, 'https://www.youtube.com/feeds/videos.xml?channel_id=' + channel_id) # atoma feed, needed for exact published times
+    )
+    gevent.joinall(tasks)
+    channel_tab, feed = tasks[0].value, tasks[1].value
+
+    # extract published times from atoma feed
+    times_published = {}
+    try:
+        def remove_bullshit(tag):
+            '''Remove XML namespace bullshit from tagname. https://bugs.python.org/issue18304'''
+            if '}' in tag:
+                return tag[tag.rfind('}')+1:]
+            return tag
+
+        def find_element(base, tag_name):
+            for element in base:
+                if remove_bullshit(element.tag) == tag_name:
+                    return element
+            return None
+
+        root = defusedxml.ElementTree.fromstring(feed.decode('utf-8'))
+        assert remove_bullshit(root.tag) == 'feed'
+        for entry in root:
+            if (remove_bullshit(entry.tag) != 'entry'):
+                continue
+
+            # it's yt:videoId in the xml but the yt: is turned into a namespace which is removed by remove_bullshit
+            video_id_element = find_element(entry, 'videoId')
+            time_published_element = find_element(entry, 'published')
+            assert video_id_element is not None
+            assert time_published_element is not None
+
+            time_published = int(calendar.timegm(time.strptime(time_published_element.text, '%Y-%m-%dT%H:%M:%S+00:00')))
+            times_published[video_id_element.text] = time_published
+
+    except (AssertionError, defusedxml.ElementTree.ParseError):
+        print('Failed to read atoma feed for ' + channel_status_name)
+        traceback.print_exc()
+
+    with open_database() as connection:
+        with connection as cursor:
+            is_first_check = cursor.execute('''SELECT time_last_checked FROM subscribed_channels WHERE yt_channel_id=?''', [channel_id]).fetchone()[0] in (None, 0)
+            video_add_time = int(time.time())
+
+            videos = []
+            channel_videos = channel.extract_info(json.loads(channel_tab), 'videos')['items']
+            for i, video_item in enumerate(channel_videos):
+                if 'description' not in video_item:
+                    video_item['description'] = ''
+
+                if video_item['id'] in times_published:
+                    time_published = times_published[video_item['id']]
+                    is_time_published_exact = True
+                else:
+                    is_time_published_exact = False
+                    try:
+                        time_published = youtube_timestamp_to_posix(video_item['published']) - i # subtract a few seconds off the videos so they will be in the right order
+                    except KeyError:
+                        print(video_item)
+                        continue # no usable published time; skip this video rather than crash on an unbound time_published below
+                if is_first_check:
+                    time_noticed = time_published # don't want a crazy ordering on first check, since we're ordering by time_noticed
+                else:
+                    time_noticed = video_add_time
+                videos.append((channel_id, video_item['id'], video_item['title'], video_item['duration'], time_published, is_time_published_exact, time_noticed, video_item['description']))
+
+
+            if len(videos) == 0:
+                average_upload_period = 4*7*24*3600 # assume 1 month for channel with no videos
+            elif len(videos) < 5:
+                average_upload_period = int((time.time() - videos[len(videos)-1][4])/len(videos))
+            else:
+                average_upload_period = int((time.time() - videos[4][4])/5) # equivalent to averaging the time between videos for the last 5 videos
+
+            # calculate when to check next for auto checking
+            # add some quantization and randomness to make pattern analysis by Youtube slightly harder
+            quantized_upload_period = average_upload_period - 
(average_upload_period % (4*3600)) + 4*3600 # round up to nearest 4 hours + randomized_upload_period = quantized_upload_period*(1 + secrets.randbelow(50)/50*0.5) # randomly between 1x and 1.5x + next_check_delay = randomized_upload_period/10 # check at 10x the channel posting rate. might want to fine tune this number + next_check_time = int(time.time() + next_check_delay) + + # calculate how many new videos there are + row = cursor.execute('''SELECT video_id + FROM videos + INNER JOIN subscribed_channels ON videos.sql_channel_id = subscribed_channels.id + WHERE yt_channel_id=? + ORDER BY time_published DESC + LIMIT 1''', [channel_id]).fetchone() + if row is None: + number_of_new_videos = len(videos) + else: + latest_video_id = row[0] + index = 0 + for video in videos: + if video[1] == latest_video_id: + break + index += 1 + number_of_new_videos = index + + cursor.executemany('''INSERT OR IGNORE INTO videos (sql_channel_id, video_id, title, duration, time_published, is_time_published_exact, time_noticed, description) + VALUES ((SELECT id FROM subscribed_channels WHERE yt_channel_id=?), ?, ?, ?, ?, ?, ?, ?)''', videos) + cursor.execute('''UPDATE subscribed_channels + SET time_last_checked = ?, next_check_time = ? + WHERE yt_channel_id=?''', [int(time.time()), next_check_time, channel_id]) + + if settings.autocheck_subscriptions: + if not _is_muted(cursor, channel_id): + autocheck_job_application.put({'channel_id': channel_id, 'channel_name': channel_names[channel_id], 'next_check_time': next_check_time}) + + if number_of_new_videos == 0: + print('No new videos from ' + channel_status_name) + elif number_of_new_videos == 1: + print('1 new video from ' + channel_status_name) + else: + print(str(number_of_new_videos) + ' new videos from ' + channel_status_name) + + + +def check_all_channels(): + with open_database() as connection: + with connection as cursor: + channel_id_name_list = cursor.execute('''SELECT yt_channel_id, channel_name + FROM subscribed_channels + WHERE muted != 1''').fetchall() + + channel_names.update(channel_id_name_list) + check_channels_if_necessary([item[0] for item in channel_id_name_list]) + + +def check_tags(tags): + channel_id_name_list = [] + with open_database() as connection: + with connection as cursor: + for tag in tags: + channel_id_name_list += _channels_with_tag(cursor, tag, exclude_muted=True) + + channel_names.update(channel_id_name_list) + check_channels_if_necessary([item[0] for item in channel_id_name_list]) + + +def check_specific_channels(channel_ids): + with open_database() as connection: + with connection as cursor: + channel_id_name_list = [] + for channel_id in channel_ids: + channel_id_name_list += cursor.execute('''SELECT yt_channel_id, channel_name + FROM subscribed_channels + WHERE yt_channel_id=?''', [channel_id]).fetchall() + channel_names.update(channel_id_name_list) + check_channels_if_necessary(channel_ids) + + + +@yt_app.route('/import_subscriptions', methods=['POST']) +def import_subscriptions(): + + # check if the post request has the file part + if 'subscriptions_file' not in request.files: + #flash('No file part') + return flask.redirect(util.URL_ORIGIN + request.full_path) + file = request.files['subscriptions_file'] + # if user does not select file, browser also + # submit an empty part without filename + if file.filename == '': + #flash('No selected file') + return flask.redirect(util.URL_ORIGIN + request.full_path) + + + mime_type = file.mimetype + + if mime_type == 'application/json': + file = file.read().decode('utf-8') + try: + 
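+            # Google Takeout's subscriptions.json is a JSON array; each item's 'snippet'
+            # holds the channel id and title that are read out below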
file = json.loads(file)
+        except json.decoder.JSONDecodeError:
+            traceback.print_exc()
+            return '400 Bad Request: Invalid json file', 400
+
+        try:
+            # use a list, not a generator, so that a bad structure is caught here by the
+            # except clause instead of later when _subscribe iterates over the channels
+            channels = [(item['snippet']['resourceId']['channelId'], item['snippet']['title']) for item in file]
+        except (KeyError, IndexError):
+            traceback.print_exc()
+            return '400 Bad Request: Unknown json structure', 400
+    elif mime_type in ('application/xml', 'text/xml', 'text/x-opml'):
+        file = file.read().decode('utf-8')
+        try:
+            root = defusedxml.ElementTree.fromstring(file)
+            assert root.tag == 'opml'
+            channels = []
+            for outline_element in root[0][0]:
+                if (outline_element.tag != 'outline') or ('xmlUrl' not in outline_element.attrib):
+                    continue
+
+                channel_name = outline_element.attrib['text']
+                channel_rss_url = outline_element.attrib['xmlUrl']
+                channel_id = channel_rss_url[channel_rss_url.find('channel_id=')+11:].strip()
+                channels.append( (channel_id, channel_name) )
+
+        except (AssertionError, IndexError, defusedxml.ElementTree.ParseError):
+            return '400 Bad Request: Unable to read opml xml file, or the file is not the expected format', 400
+    else:
+        return '400 Bad Request: Unsupported file format: ' + mime_type + '. Only subscription.json files (from Google Takeout) and XML OPML files exported from Youtube\'s subscription manager page are supported', 400
+
+    _subscribe(channels)
+
+    return flask.redirect(util.URL_ORIGIN + '/subscription_manager', 303)
+
+
+
+@yt_app.route('/subscription_manager', methods=['GET'])
+def get_subscription_manager_page():
+    group_by_tags = request.args.get('group_by_tags', '0') == '1'
+    with open_database() as connection:
+        with connection as cursor:
+            if group_by_tags:
+                tag_groups = []
+
+                for tag in _get_all_tags(cursor):
+                    sub_list = []
+                    for channel_id, channel_name, muted in _channels_with_tag(cursor, tag, order=True, include_muted_status=True):
+                        sub_list.append({
+                            'channel_url': util.URL_ORIGIN + '/channel/' + channel_id,
+                            'channel_name': channel_name,
+                            'channel_id': channel_id,
+                            'muted': muted,
+                            'tags': [t for t in _get_tags(cursor, channel_id) if t != tag],
+                        })
+
+                    tag_groups.append( (tag, sub_list) )
+
+                # Channels with no tags
+                channel_list = cursor.execute('''SELECT yt_channel_id, channel_name, muted
+                                                 FROM subscribed_channels
+                                                 WHERE id NOT IN (
+                                                     SELECT sql_channel_id FROM tag_associations
+                                                 )
+                                                 ORDER BY channel_name COLLATE NOCASE''').fetchall()
+                if channel_list:
+                    sub_list = []
+                    for channel_id, channel_name, muted in channel_list:
+                        sub_list.append({
+                            'channel_url': util.URL_ORIGIN + '/channel/' + channel_id,
+                            'channel_name': channel_name,
+                            'channel_id': channel_id,
+                            'muted': muted,
+                            'tags': [],
+                        })
+
+                    tag_groups.append( ('No tags', sub_list) )
+            else:
+                sub_list = []
+                for channel_name, channel_id, muted in _get_subscribed_channels(cursor):
+                    sub_list.append({
+                        'channel_url': util.URL_ORIGIN + '/channel/' + channel_id,
+                        'channel_name': channel_name,
+                        'channel_id': channel_id,
+                        'muted': muted,
+                        'tags': _get_tags(cursor, channel_id),
+                    })
+
+    if group_by_tags:
+        return flask.render_template('subscription_manager.html',
+            group_by_tags = True,
+            tag_groups = tag_groups,
+        )
+    else:
+        return flask.render_template('subscription_manager.html',
+            group_by_tags = False,
+            sub_list = sub_list,
+        )
+
+def list_from_comma_separated_tags(string):
+    return [tag.strip() for tag in string.split(',') if tag.strip()]
+
+
+@yt_app.route('/subscription_manager', methods=['POST'])
+def post_subscription_manager_page():
+    action = request.values['action']
+
open_database() as connection: + with connection as cursor: + if action == 'add_tags': + _add_tags(cursor, request.values.getlist('channel_ids'), [tag.lower() for tag in list_from_comma_separated_tags(request.values['tags'])]) + elif action == 'remove_tags': + _remove_tags(cursor, request.values.getlist('channel_ids'), [tag.lower() for tag in list_from_comma_separated_tags(request.values['tags'])]) + elif action == 'unsubscribe': + _unsubscribe(cursor, request.values.getlist('channel_ids')) + elif action == 'unsubscribe_verify': + unsubscribe_list = _get_channel_names(cursor, request.values.getlist('channel_ids')) + return flask.render_template('unsubscribe_verify.html', unsubscribe_list = unsubscribe_list) + + elif action == 'mute': + cursor.executemany('''UPDATE subscribed_channels + SET muted = 1 + WHERE yt_channel_id = ?''', [(ci,) for ci in request.values.getlist('channel_ids')]) + elif action == 'unmute': + cursor.executemany('''UPDATE subscribed_channels + SET muted = 0 + WHERE yt_channel_id = ?''', [(ci,) for ci in request.values.getlist('channel_ids')]) + else: + flask.abort(400) + + return flask.redirect(util.URL_ORIGIN + request.full_path, 303) + +@yt_app.route('/subscriptions', methods=['GET']) +@yt_app.route('/feed/subscriptions', methods=['GET']) def get_subscriptions_page(): + page = int(request.args.get('page', 1)) + with open_database() as connection: + with connection as cursor: + tag = request.args.get('tag', None) + videos, number_of_videos_in_db = _get_videos(cursor, 60, (page - 1)*60, tag) + for video in videos: + video['thumbnail'] = util.URL_ORIGIN + '/data/subscription_thumbnails/' + video['id'] + '.jpg' + video['type'] = 'video' + video['item_size'] = 'small' + yt_data_extract.add_extra_html_info(video) + + tags = _get_all_tags(cursor) + + + subscription_list = [] + for channel_name, channel_id, muted in _get_subscribed_channels(cursor): + subscription_list.append({ + 'channel_url': util.URL_ORIGIN + '/channel/' + channel_id, + 'channel_name': channel_name, + 'channel_id': channel_id, + 'muted': muted, + }) + + return flask.render_template('subscriptions.html', + videos = videos, + num_pages = math.ceil(number_of_videos_in_db/60), + parameters_dictionary = request.args, + tags = tags, + current_tag = tag, + subscription_list = subscription_list, + ) + +@yt_app.route('/subscriptions', methods=['POST']) +@yt_app.route('/feed/subscriptions', methods=['POST']) +def post_subscriptions_page(): + action = request.values['action'] + if action == 'subscribe': + if len(request.values.getlist('channel_id')) != len(request.values.getlist('channel_name')): + return '400 Bad Request, length of channel_id != length of channel_name', 400 + _subscribe(zip(request.values.getlist('channel_id'), request.values.getlist('channel_name'))) + + elif action == 'unsubscribe': + with_open_db(_unsubscribe, request.values.getlist('channel_id')) + + elif action == 'refresh': + type = request.values['type'] + if type == 'all': + check_all_channels() + elif type == 'tag': + check_tags(request.values.getlist('tag_name')) + elif type == 'channel': + check_specific_channels(request.values.getlist('channel_id')) + else: + flask.abort(400) + else: + flask.abort(400) + + return '', 204 + + +@yt_app.route('/data/subscription_thumbnails/<thumbnail>') +def serve_subscription_thumbnail(thumbnail): + '''Serves thumbnail from disk if it's been saved already. 
If not, downloads the thumbnail, saves to disk, and serves it.'''
+    assert thumbnail[-4:] == '.jpg'
+    video_id = thumbnail[0:-4]
+    thumbnail_path = os.path.join(thumbnails_directory, thumbnail)
+
+    if video_id in existing_thumbnails:
+        try:
+            f = open(thumbnail_path, 'rb')
+        except FileNotFoundError:
+            existing_thumbnails.remove(video_id)
+        else:
+            image = f.read()
+            f.close()
+            return flask.Response(image, mimetype='image/jpeg')
+
+    url = "https://i.ytimg.com/vi/" + video_id + "/mqdefault.jpg"
+    try:
+        image = util.fetch_url(url, report_text="Saved thumbnail: " + video_id)
+    except urllib.error.HTTPError as e:
+        print("Failed to download thumbnail for " + video_id + ": " + str(e))
+        flask.abort(e.code)
+    try:
+        f = open(thumbnail_path, 'wb')
+    except FileNotFoundError:
+        os.makedirs(thumbnails_directory, exist_ok = True)
+        f = open(thumbnail_path, 'wb')
+    f.write(image)
+    f.close()
+    existing_thumbnails.add(video_id)
+
+    return flask.Response(image, mimetype='image/jpeg')
+
+
diff --git a/youtube/templates/channel.html b/youtube/templates/channel.html
index ed04988..48041a0 100644
--- a/youtube/templates/channel.html
+++ b/youtube/templates/channel.html
@@ -23,6 +23,9 @@
             grid-column:2;
             margin-left: 5px;
         }
+        .summary .subscribe-unsubscribe, .summary .short-description{
+            margin-top: 10px;
+        }
         main .channel-tabs{
             grid-row:2;
             grid-column: 1 / span 2;
@@ -89,6 +92,12 @@
         <div class="summary">
             <h2 class="title">{{ channel_name }}</h2>
             <p class="short-description">{{ short_description }}</p>
+            <form method="POST" action="/youtube.com/subscriptions" class="subscribe-unsubscribe">
+                <input type="submit" value="{{ 'Unsubscribe' if subscribed else 'Subscribe' }}">
+                <input type="hidden" name="channel_id" value="{{ channel_id }}">
+                <input type="hidden" name="channel_name" value="{{ channel_name }}">
+                <input type="hidden" name="action" value="{{ 'unsubscribe' if subscribed else 'subscribe' }}">
+            </form>
         </div>
         <nav class="channel-tabs">
             {% for tab_name in ('Videos', 'Playlists', 'About') %}
diff --git a/youtube/templates/subscription_manager.html b/youtube/templates/subscription_manager.html
new file mode 100644
index 0000000..ed6a320
--- /dev/null
+++ b/youtube/templates/subscription_manager.html
@@ -0,0 +1,139 @@
+{% set page_title = 'Subscription Manager' %}
+{% extends "base.html" %}
+{% block style %}
+    .import-export{
+        display: flex;
+        flex-direction: row;
+    }
+    .subscriptions-import-form{
+        background-color: var(--interface-color);
+        display: flex;
+        flex-direction: column;
+        align-items: flex-start;
+        max-width: 300px;
+        padding:10px;
+    }
+    .subscriptions-import-form h2{
+        font-size: 20px;
+        margin-bottom: 10px;
+    }
+
+    .import-submit-button{
+        margin-top:15px;
+        align-self: flex-end;
+    }
+
+
+    .subscriptions-export-links{
+        margin: 0px 0px 0px 20px;
+        background-color: var(--interface-color);
+        list-style: none;
+        max-width: 300px;
+        padding:10px;
+    }
+
+    .sub-list-controls{
+        background-color: var(--interface-color);
+        padding:10px;
+    }
+
+
+    .tag-group-list{
+        list-style: none;
+        margin-left: 10px;
+        margin-right: 10px;
+        padding: 0px;
+    }
+    .tag-group{
+        border-style: solid;
+        margin-bottom: 10px;
+    }
+
+    .sub-list{
+        list-style: none;
+        padding:10px;
+        column-width: 300px;
+        column-gap: 40px;
+    }
+    .sub-list-item{
+        display:flex;
+        margin-bottom: 10px;
+        break-inside:avoid;
+        background-color: var(--interface-color);
+    }
+    .tag-list{
+        margin-left:15px;
+        font-weight:bold;
+    }
+    .sub-list-item-name{
+        margin-left:15px;
+    }
+    .sub-list-checkbox{
+        height: 1.5em;
+        min-width: 1.5em;
+        /* need min-width, otherwise the browser doesn't respect the width and squishes the checkbox down when there are too many tags */
+    }
+{% endblock style %}
+
+
+{% macro subscription_list(sub_list) %}
+    {% for subscription in sub_list %}
+        <li class="sub-list-item {{ 'muted' if subscription['muted'] else '' }}">
+            <input class="sub-list-checkbox" name="channel_ids" value="{{ subscription['channel_id'] }}" form="subscription-manager-form" type="checkbox">
+            <a href="{{ subscription['channel_url'] }}" class="sub-list-item-name" title="{{ subscription['channel_name'] }}">{{ subscription['channel_name'] }}</a>
+            <span class="tag-list">{{ ', '.join(subscription['tags']) }}</span>
+        </li>
+    {% endfor %}
+{% endmacro %}
+
+
+
+{% block main %}
+    <div class="import-export">
+        <form class="subscriptions-import-form" enctype="multipart/form-data" action="/youtube.com/import_subscriptions" method="POST">
+            <h2>Import subscriptions</h2>
+            <input type="file" id="subscriptions-import" accept="application/json, application/xml, text/x-opml" name="subscriptions_file">
+            <input type="submit" value="Import" class="import-submit-button">
+        </form>
+
+        <ul class="subscriptions-export-links">
+            <li><a href="/youtube.com/subscriptions.opml">Export subscriptions (OPML)</a></li>
+            <li><a href="/youtube.com/subscriptions.xml">Export subscriptions (RSS)</a></li>
+        </ul>
+    </div>
+
+    <hr>
+
+    <form id="subscription-manager-form" class="sub-list-controls" method="POST">
+        {% if group_by_tags %}
+            <a class="sort-button" href="/https://www.youtube.com/subscription_manager?group_by_tags=0">Don't group</a>
+        {% else %}
+            <a class="sort-button" href="/https://www.youtube.com/subscription_manager?group_by_tags=1">Group by tags</a>
+        {% endif %}
+        <input type="text" name="tags">
+        <button type="submit" name="action" value="add_tags">Add tags</button>
+        <button type="submit" name="action" value="remove_tags">Remove tags</button>
+        <button type="submit" name="action" value="unsubscribe_verify">Unsubscribe</button>
+        <button type="submit" name="action" value="mute">Mute</button>
+        <button type="submit" name="action" value="unmute">Unmute</button>
+        <input type="reset" value="Clear Selection">
+    </form>
+
+
+    {% if group_by_tags %}
+        <ul class="tag-group-list">
+            {% for tag_name, sub_list in tag_groups %}
+                <li class="tag-group">
+                    <h2 class="tag-group-name">{{ tag_name }}</h2>
+                    <ol class="sub-list">
+                        {{ subscription_list(sub_list) }}
+                    </ol>
+                </li>
+            {% endfor %}
+        </ul>
+    {% else %}
+        <ol class="sub-list">
+            {{ subscription_list(sub_list) }}
+        </ol>
+    {% endif %}
+
+{% endblock main %}
diff --git a/youtube/templates/subscriptions.html b/youtube/templates/subscriptions.html
new file mode 100644
index 0000000..370ca23
--- /dev/null
+++ b/youtube/templates/subscriptions.html
@@ -0,0 +1,113 @@
+{% set page_title = 'Subscriptions' %}
+{% extends "base.html" %}
+{% import "common_elements.html" as common_elements %}
+
+{% block style %}
+    main{
+        display:flex;
+        flex-direction: row;
+    }
+    .video-section{
+        flex-grow: 1;
+    }
+    .video-section .page-button-row{
+        justify-content: center;
+    }
+    .subscriptions-sidebar{
+        flex-basis: 300px;
+        background-color: var(--interface-color);
+        border-left: 2px solid;
+    }
+    .sidebar-links{
+        display:flex;
+        justify-content: space-between;
+        padding-left:10px;
+        padding-right: 10px;
+    }
+
+    .sidebar-list{
+        list-style: none;
+        padding-left:10px;
+        padding-right: 10px;
+    }
+    .sidebar-list-item{
+        display:flex;
+        justify-content: space-between;
+        margin-bottom: 5px;
+    }
+    .sub-refresh-list .sidebar-item-name{
text-overflow: clip; + white-space: nowrap; + overflow: hidden; + max-width: 200px; + } +{% endblock style %} + +{% block main %} + <div class="video-section"> + <nav class="item-grid"> + {% for video_info in videos %} + {{ common_elements.item(video_info, include_author=false) }} + {% endfor %} + </nav> + + <nav class="page-button-row"> + {{ common_elements.page_buttons(num_pages, '/youtube.com/subscriptions', parameters_dictionary) }} + </nav> + </div> + + <div class="subscriptions-sidebar"> + <div class="sidebar-links"> + <a href="/youtube.com/subscription_manager" class="sub-manager-link">Subscription Manager</a> + <form method="POST" class="refresh-all"> + <input type="submit" value="Check All"> + <input type="hidden" name="action" value="refresh"> + <input type="hidden" name="type" value="all"> + </form> + </div> + + <hr> + + <ol class="sidebar-list tags"> + {% if current_tag %} + <li class="sidebar-list-item"> + <a href="/youtube.com/subscriptions" class="sidebar-item-name">Any tag</a> + </li> + {% endif %} + + {% for tag in tags %} + <li class="sidebar-list-item"> + {% if tag == current_tag %} + <span class="sidebar-item-name">{{ tag }}</span> + {% else %} + <a href="?tag={{ tag|urlencode }}" class="sidebar-item-name">{{ tag }}</a> + {% endif %} + <form method="POST" class="sidebar-item-refresh"> + <input type="submit" value="Check"> + <input type="hidden" name="action" value="refresh"> + <input type="hidden" name="type" value="tag"> + <input type="hidden" name="tag_name" value="{{ tag }}"> + </form> + </li> + {% endfor %} + </ol> + + <hr> + + <ol class="sidebar-list sub-refresh-list"> + {% for subscription in subscription_list %} + <li class="sidebar-list-item {{ 'muted' if subscription['muted'] else '' }}"> + <a href="{{ subscription['channel_url'] }}" class="sidebar-item-name" title="{{ subscription['channel_name'] }}">{{ subscription['channel_name'] }}</a> + <form method="POST" class="sidebar-item-refresh"> + <input type="submit" value="Check"> + <input type="hidden" name="action" value="refresh"> + <input type="hidden" name="type" value="channel"> + <input type="hidden" name="channel_id" value="{{ subscription['channel_id'] }}"> + </form> + </li> + {% endfor %} + </ol> + + </div> + +{% endblock main %} diff --git a/youtube/templates/unsubscribe_verify.html b/youtube/templates/unsubscribe_verify.html new file mode 100644 index 0000000..98581c0 --- /dev/null +++ b/youtube/templates/unsubscribe_verify.html @@ -0,0 +1,19 @@ +{% set page_title = 'Unsubscribe?' 
%} +{% extends "base.html" %} + +{% block main %} + <span>Are you sure you want to unsubscribe from these channels?</span> + <form class="subscriptions-import-form" action="/youtube.com/subscription_manager" method="POST"> + {% for channel_id, channel_name in unsubscribe_list %} + <input type="hidden" name="channel_ids" value="{{ channel_id }}"> + {% endfor %} + + <input type="hidden" name="action" value="unsubscribe"> + <input type="submit" value="Yes, unsubscribe"> + </form> + <ul> + {% for channel_id, channel_name in unsubscribe_list %} + <li><a href="{{ '/https://www.youtube.com/channel/' + channel_id }}" title="{{ channel_name }}">{{ channel_name }}</a></li> + {% endfor %} + </ul> +{% endblock main %} diff --git a/youtube/util.py b/youtube/util.py index 2f80f11..2205645 100644 --- a/youtube/util.py +++ b/youtube/util.py @@ -6,6 +6,9 @@ import urllib.parse import re import time import os +import gevent +import gevent.queue +import gevent.lock # The trouble with the requests library: It ships its own certificate bundle via certifi # instead of using the system certificate store, meaning self-signed certificates @@ -183,6 +186,84 @@ desktop_ua = (('User-Agent', desktop_user_agent),) +class RateLimitedQueue(gevent.queue.Queue): + ''' Does initial_burst (def. 30) at first, then alternates between waiting waiting_period (def. 5) seconds and doing subsequent_bursts (def. 10) queries. After 5 seconds with nothing left in the queue, resets rate limiting. ''' + + def __init__(self, initial_burst=30, waiting_period=5, subsequent_bursts=10): + self.initial_burst = initial_burst + self.waiting_period = waiting_period + self.subsequent_bursts = subsequent_bursts + + self.count_since_last_wait = 0 + self.surpassed_initial = False + + self.lock = gevent.lock.BoundedSemaphore(1) + self.currently_empty = False + self.empty_start = 0 + gevent.queue.Queue.__init__(self) + + + def get(self): + self.lock.acquire() # blocks if another greenlet currently has the lock + if self.count_since_last_wait >= self.subsequent_bursts and self.surpassed_initial: + gevent.sleep(self.waiting_period) + self.count_since_last_wait = 0 + + elif self.count_since_last_wait >= self.initial_burst and not self.surpassed_initial: + self.surpassed_initial = True + gevent.sleep(self.waiting_period) + self.count_since_last_wait = 0 + + self.count_since_last_wait += 1 + + if not self.currently_empty and self.empty(): + self.currently_empty = True + self.empty_start = time.monotonic() + + item = gevent.queue.Queue.get(self) # blocks when nothing left + + if self.currently_empty: + if time.monotonic() - self.empty_start >= self.waiting_period: + self.count_since_last_wait = 0 + self.surpassed_initial = False + + self.currently_empty = False + + self.lock.release() + + return item + + + +def download_thumbnail(save_directory, video_id): + url = "https://i.ytimg.com/vi/" + video_id + "/mqdefault.jpg" + save_location = os.path.join(save_directory, video_id + ".jpg") + try: + thumbnail = fetch_url(url, report_text="Saved thumbnail: " + video_id) + except urllib.error.HTTPError as e: + print("Failed to download thumbnail for " + video_id + ": " + str(e)) + return False + try: + f = open(save_location, 'wb') + except FileNotFoundError: + os.makedirs(save_directory, exist_ok = True) + f = open(save_location, 'wb') + f.write(thumbnail) + f.close() + return True + +def download_thumbnails(save_directory, ids): + if not isinstance(ids, (list, tuple)): + ids = list(ids) + # only do 5 at a time + # do the n where n is divisible by 5 + i = -1 + for 
i in range(0, int(len(ids)/5)):
+        gevent.joinall([gevent.spawn(download_thumbnail, save_directory, ids[j]) for j in range(i*5, i*5 + 5)])
+    # do the remainders (< 5); (i+1)*5 also covers the case where the loop above never ran (i == -1)
+    gevent.joinall([gevent.spawn(download_thumbnail, save_directory, ids[j]) for j in range((i+1)*5, len(ids))])
+
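Editor's note: the batching arithmetic in download_thumbnails is easy to get wrong (the earlier revision, removed from local_playlist.py above, skipped the last full group of five and could launch more than five downloads at once for some lengths). Below is a standalone sketch of the same index math, using a hypothetical batches() helper so it can be checked without gevent or network access; it is an illustration, not part of the patch.

    def batches(ids, size=5):
        '''Mirrors the index arithmetic of the fixed loop in download_thumbnails.'''
        groups = []
        i = -1
        for i in range(0, int(len(ids)/size)):      # full groups of `size`
            groups.append(ids[i*size:(i + 1)*size])
        remainder = ids[(i + 1)*size:]              # covers i == -1 when len(ids) < size
        if remainder:
            groups.append(remainder)
        return groups

    for n in range(13):
        ids = ['vid%02d' % k for k in range(n)]
        groups = batches(ids)
        assert all(len(g) <= 5 for g in groups)           # never more than 5 at a time
        assert [v for g in groups for v in g] == ids      # nothing skipped or duplicated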