diff options
| author | James Taylor <user234683@users.noreply.github.com> | 2019-02-16 23:41:52 -0800 | 
|---|---|---|
| committer | James Taylor <user234683@users.noreply.github.com> | 2019-02-16 23:41:52 -0800 | 
| commit | 3905e7e64059b45479894ba1fdfb0ef9cef64475 (patch) | |
| tree | 4c5dbbfd204d0351cac8412cc87a65fea49c1a52 | |
| parent | 24642455d0dc5841ddec99f456598c4f763c1e8a (diff) | |
| download | yt-local-3905e7e64059b45479894ba1fdfb0ef9cef64475.tar.lz yt-local-3905e7e64059b45479894ba1fdfb0ef9cef64475.tar.xz yt-local-3905e7e64059b45479894ba1fdfb0ef9cef64475.zip | |
basic subscriptions system
57 files changed, 12440 insertions, 23 deletions
| diff --git a/python/atoma/__init__.py b/python/atoma/__init__.py new file mode 100644 index 0000000..0768081 --- /dev/null +++ b/python/atoma/__init__.py @@ -0,0 +1,12 @@ +from .atom import parse_atom_file, parse_atom_bytes +from .rss import parse_rss_file, parse_rss_bytes +from .json_feed import ( +    parse_json_feed, parse_json_feed_file, parse_json_feed_bytes +) +from .opml import parse_opml_file, parse_opml_bytes +from .exceptions import ( +    FeedParseError, FeedDocumentError, FeedXMLError, FeedJSONError +) +from .const import VERSION + +__version__ = VERSION diff --git a/python/atoma/atom.py b/python/atoma/atom.py new file mode 100644 index 0000000..d4e676c --- /dev/null +++ b/python/atoma/atom.py @@ -0,0 +1,284 @@ +from datetime import datetime +import enum +from io import BytesIO +from typing import Optional, List +from xml.etree.ElementTree import Element + +import attr + +from .utils import ( +    parse_xml, get_child, get_text, get_datetime, FeedParseError, ns +) + + +class AtomTextType(enum.Enum): +    text = "text" +    html = "html" +    xhtml = "xhtml" + + +@attr.s +class AtomTextConstruct: +    text_type: str = attr.ib() +    lang: Optional[str] = attr.ib() +    value: str = attr.ib() + + +@attr.s +class AtomEntry: +    title: AtomTextConstruct = attr.ib() +    id_: str = attr.ib() + +    # Should be mandatory but many feeds use published instead +    updated: Optional[datetime] = attr.ib() + +    authors: List['AtomPerson'] = attr.ib() +    contributors: List['AtomPerson'] = attr.ib() +    links: List['AtomLink'] = attr.ib() +    categories: List['AtomCategory'] = attr.ib() +    published: Optional[datetime] = attr.ib() +    rights: Optional[AtomTextConstruct] = attr.ib() +    summary: Optional[AtomTextConstruct] = attr.ib() +    content: Optional[AtomTextConstruct] = attr.ib() +    source: Optional['AtomFeed'] = attr.ib() + + +@attr.s +class AtomFeed: +    title: Optional[AtomTextConstruct] = attr.ib() +    id_: str = attr.ib() + +    # Should be mandatory but many feeds do not include it +    updated: Optional[datetime] = attr.ib() + +    authors: List['AtomPerson'] = attr.ib() +    contributors: List['AtomPerson'] = attr.ib() +    links: List['AtomLink'] = attr.ib() +    categories: List['AtomCategory'] = attr.ib() +    generator: Optional['AtomGenerator'] = attr.ib() +    subtitle: Optional[AtomTextConstruct] = attr.ib() +    rights: Optional[AtomTextConstruct] = attr.ib() +    icon: Optional[str] = attr.ib() +    logo: Optional[str] = attr.ib() + +    entries: List[AtomEntry] = attr.ib() + + +@attr.s +class AtomPerson: +    name: str = attr.ib() +    uri: Optional[str] = attr.ib() +    email: Optional[str] = attr.ib() + + +@attr.s +class AtomLink: +    href: str = attr.ib() +    rel: Optional[str] = attr.ib() +    type_: Optional[str] = attr.ib() +    hreflang: Optional[str] = attr.ib() +    title: Optional[str] = attr.ib() +    length: Optional[int] = attr.ib() + + +@attr.s +class AtomCategory: +    term: str = attr.ib() +    scheme: Optional[str] = attr.ib() +    label: Optional[str] = attr.ib() + + +@attr.s +class AtomGenerator: +    name: str = attr.ib() +    uri: Optional[str] = attr.ib() +    version: Optional[str] = attr.ib() + + +def _get_generator(element: Element, name, +                   optional: bool=True) -> Optional[AtomGenerator]: +    child = get_child(element, name, optional) +    if child is None: +        return None + +    return AtomGenerator( +        child.text.strip(), +        child.attrib.get('uri'), +        child.attrib.get('version'), +    ) + + +def _get_text_construct(element: Element, name, +                        optional: bool=True) -> Optional[AtomTextConstruct]: +    child = get_child(element, name, optional) +    if child is None: +        return None + +    try: +        text_type = AtomTextType(child.attrib['type']) +    except KeyError: +        text_type = AtomTextType.text + +    try: +        lang = child.lang +    except AttributeError: +        lang = None + +    if child.text is None: +        if optional: +            return None + +        raise FeedParseError( +            'Could not parse atom feed: "{}" text is required but is empty' +            .format(name) +        ) + +    return AtomTextConstruct( +        text_type, +        lang, +        child.text.strip() +    ) + + +def _get_person(element: Element) -> Optional[AtomPerson]: +    try: +        return AtomPerson( +            get_text(element, 'feed:name', optional=False), +            get_text(element, 'feed:uri'), +            get_text(element, 'feed:email') +        ) +    except FeedParseError: +        return None + + +def _get_link(element: Element) -> AtomLink: +    length = element.attrib.get('length') +    length = int(length) if length else None +    return AtomLink( +        element.attrib['href'], +        element.attrib.get('rel'), +        element.attrib.get('type'), +        element.attrib.get('hreflang'), +        element.attrib.get('title'), +        length +    ) + + +def _get_category(element: Element) -> AtomCategory: +    return AtomCategory( +        element.attrib['term'], +        element.attrib.get('scheme'), +        element.attrib.get('label'), +    ) + + +def _get_entry(element: Element, +               default_authors: List[AtomPerson]) -> AtomEntry: +    root = element + +    # Mandatory +    title = _get_text_construct(root, 'feed:title') +    id_ = get_text(root, 'feed:id') + +    # Optional +    try: +        source = _parse_atom(get_child(root, 'feed:source', optional=False), +                             parse_entries=False) +    except FeedParseError: +        source = None +        source_authors = [] +    else: +        source_authors = source.authors + +    authors = [_get_person(e) +               for e in root.findall('feed:author', ns)] or default_authors +    authors = [a for a in authors if a is not None] +    authors = authors or default_authors or source_authors + +    contributors = [_get_person(e) +                    for e in root.findall('feed:contributor', ns) if e] +    contributors = [c for c in contributors if c is not None] + +    links = [_get_link(e) for e in root.findall('feed:link', ns)] +    categories = [_get_category(e) for e in root.findall('feed:category', ns)] + +    updated = get_datetime(root, 'feed:updated') +    published = get_datetime(root, 'feed:published') +    rights = _get_text_construct(root, 'feed:rights') +    summary = _get_text_construct(root, 'feed:summary') +    content = _get_text_construct(root, 'feed:content') + +    return AtomEntry( +        title, +        id_, +        updated, +        authors, +        contributors, +        links, +        categories, +        published, +        rights, +        summary, +        content, +        source +    ) + + +def _parse_atom(root: Element, parse_entries: bool=True) -> AtomFeed: +    # Mandatory +    id_ = get_text(root, 'feed:id', optional=False) + +    # Optional +    title = _get_text_construct(root, 'feed:title') +    updated = get_datetime(root, 'feed:updated') +    authors = [_get_person(e) +               for e in root.findall('feed:author', ns) if e] +    authors = [a for a in authors if a is not None] +    contributors = [_get_person(e) +                    for e in root.findall('feed:contributor', ns) if e] +    contributors = [c for c in contributors if c is not None] +    links = [_get_link(e) +             for e in root.findall('feed:link', ns)] +    categories = [_get_category(e) +                  for e in root.findall('feed:category', ns)] + +    generator = _get_generator(root, 'feed:generator') +    subtitle = _get_text_construct(root, 'feed:subtitle') +    rights = _get_text_construct(root, 'feed:rights') +    icon = get_text(root, 'feed:icon') +    logo = get_text(root, 'feed:logo') + +    if parse_entries: +        entries = [_get_entry(e, authors) +                   for e in root.findall('feed:entry', ns)] +    else: +        entries = [] + +    atom_feed = AtomFeed( +        title, +        id_, +        updated, +        authors, +        contributors, +        links, +        categories, +        generator, +        subtitle, +        rights, +        icon, +        logo, +        entries +    ) +    return atom_feed + + +def parse_atom_file(filename: str) -> AtomFeed: +    """Parse an Atom feed from a local XML file.""" +    root = parse_xml(filename).getroot() +    return _parse_atom(root) + + +def parse_atom_bytes(data: bytes) -> AtomFeed: +    """Parse an Atom feed from a byte-string containing XML data.""" +    root = parse_xml(BytesIO(data)).getroot() +    return _parse_atom(root) diff --git a/python/atoma/const.py b/python/atoma/const.py new file mode 100644 index 0000000..d52d0f6 --- /dev/null +++ b/python/atoma/const.py @@ -0,0 +1 @@ +VERSION = '0.0.13' diff --git a/python/atoma/exceptions.py b/python/atoma/exceptions.py new file mode 100644 index 0000000..88170c5 --- /dev/null +++ b/python/atoma/exceptions.py @@ -0,0 +1,14 @@ +class FeedParseError(Exception): +    """Document is an invalid feed.""" + + +class FeedDocumentError(Exception): +    """Document is not a supported file.""" + + +class FeedXMLError(FeedDocumentError): +    """Document is not valid XML.""" + + +class FeedJSONError(FeedDocumentError): +    """Document is not valid JSON.""" diff --git a/python/atoma/json_feed.py b/python/atoma/json_feed.py new file mode 100644 index 0000000..410ff4a --- /dev/null +++ b/python/atoma/json_feed.py @@ -0,0 +1,223 @@ +from datetime import datetime, timedelta +import json +from typing import Optional, List + +import attr + +from .exceptions import FeedParseError, FeedJSONError +from .utils import try_parse_date + + +@attr.s +class JSONFeedAuthor: + +    name: Optional[str] = attr.ib() +    url: Optional[str] = attr.ib() +    avatar: Optional[str] = attr.ib() + + +@attr.s +class JSONFeedAttachment: + +    url: str = attr.ib() +    mime_type: str = attr.ib() +    title: Optional[str] = attr.ib() +    size_in_bytes: Optional[int] = attr.ib() +    duration: Optional[timedelta] = attr.ib() + + +@attr.s +class JSONFeedItem: + +    id_: str = attr.ib() +    url: Optional[str] = attr.ib() +    external_url: Optional[str] = attr.ib() +    title: Optional[str] = attr.ib() +    content_html: Optional[str] = attr.ib() +    content_text: Optional[str] = attr.ib() +    summary: Optional[str] = attr.ib() +    image: Optional[str] = attr.ib() +    banner_image: Optional[str] = attr.ib() +    date_published: Optional[datetime] = attr.ib() +    date_modified: Optional[datetime] = attr.ib() +    author: Optional[JSONFeedAuthor] = attr.ib() + +    tags: List[str] = attr.ib() +    attachments: List[JSONFeedAttachment] = attr.ib() + + +@attr.s +class JSONFeed: + +    version: str = attr.ib() +    title: str = attr.ib() +    home_page_url: Optional[str] = attr.ib() +    feed_url: Optional[str] = attr.ib() +    description: Optional[str] = attr.ib() +    user_comment: Optional[str] = attr.ib() +    next_url: Optional[str] = attr.ib() +    icon: Optional[str] = attr.ib() +    favicon: Optional[str] = attr.ib() +    author: Optional[JSONFeedAuthor] = attr.ib() +    expired: bool = attr.ib() + +    items: List[JSONFeedItem] = attr.ib() + + +def _get_items(root: dict) -> List[JSONFeedItem]: +    rv = [] +    items = root.get('items', []) +    if not items: +        return rv + +    for item in items: +        rv.append(_get_item(item)) + +    return rv + + +def _get_item(item_dict: dict) -> JSONFeedItem: +    return JSONFeedItem( +        id_=_get_text(item_dict, 'id', optional=False), +        url=_get_text(item_dict, 'url'), +        external_url=_get_text(item_dict, 'external_url'), +        title=_get_text(item_dict, 'title'), +        content_html=_get_text(item_dict, 'content_html'), +        content_text=_get_text(item_dict, 'content_text'), +        summary=_get_text(item_dict, 'summary'), +        image=_get_text(item_dict, 'image'), +        banner_image=_get_text(item_dict, 'banner_image'), +        date_published=_get_datetime(item_dict, 'date_published'), +        date_modified=_get_datetime(item_dict, 'date_modified'), +        author=_get_author(item_dict), +        tags=_get_tags(item_dict, 'tags'), +        attachments=_get_attachments(item_dict, 'attachments') +    ) + + +def _get_attachments(root, name) -> List[JSONFeedAttachment]: +    rv = list() +    for attachment_dict in root.get(name, []): +        rv.append(JSONFeedAttachment( +            _get_text(attachment_dict, 'url', optional=False), +            _get_text(attachment_dict, 'mime_type', optional=False), +            _get_text(attachment_dict, 'title'), +            _get_int(attachment_dict, 'size_in_bytes'), +            _get_duration(attachment_dict, 'duration_in_seconds') +        )) +    return rv + + +def _get_tags(root, name) -> List[str]: +    tags = root.get(name, []) +    return [tag for tag in tags if isinstance(tag, str)] + + +def _get_datetime(root: dict, name, optional: bool=True) -> Optional[datetime]: +    text = _get_text(root, name, optional) +    if text is None: +        return None + +    return try_parse_date(text) + + +def _get_expired(root: dict) -> bool: +    if root.get('expired') is True: +        return True + +    return False + + +def _get_author(root: dict) -> Optional[JSONFeedAuthor]: +    author_dict = root.get('author') +    if not author_dict: +        return None + +    rv = JSONFeedAuthor( +        name=_get_text(author_dict, 'name'), +        url=_get_text(author_dict, 'url'), +        avatar=_get_text(author_dict, 'avatar'), +    ) +    if rv.name is None and rv.url is None and rv.avatar is None: +        return None + +    return rv + + +def _get_int(root: dict, name: str, optional: bool=True) -> Optional[int]: +    rv = root.get(name) +    if not optional and rv is None: +        raise FeedParseError('Could not parse feed: "{}" int is required but ' +                             'is empty'.format(name)) + +    if optional and rv is None: +        return None + +    if not isinstance(rv, int): +        raise FeedParseError('Could not parse feed: "{}" is not an int' +                             .format(name)) + +    return rv + + +def _get_duration(root: dict, name: str, +                  optional: bool=True) -> Optional[timedelta]: +    duration = _get_int(root, name, optional) +    if duration is None: +        return None + +    return timedelta(seconds=duration) + + +def _get_text(root: dict, name: str, optional: bool=True) -> Optional[str]: +    rv = root.get(name) +    if not optional and rv is None: +        raise FeedParseError('Could not parse feed: "{}" text is required but ' +                             'is empty'.format(name)) + +    if optional and rv is None: +        return None + +    if not isinstance(rv, str): +        raise FeedParseError('Could not parse feed: "{}" is not a string' +                             .format(name)) + +    return rv + + +def parse_json_feed(root: dict) -> JSONFeed: +    return JSONFeed( +        version=_get_text(root, 'version', optional=False), +        title=_get_text(root, 'title', optional=False), +        home_page_url=_get_text(root, 'home_page_url'), +        feed_url=_get_text(root, 'feed_url'), +        description=_get_text(root, 'description'), +        user_comment=_get_text(root, 'user_comment'), +        next_url=_get_text(root, 'next_url'), +        icon=_get_text(root, 'icon'), +        favicon=_get_text(root, 'favicon'), +        author=_get_author(root), +        expired=_get_expired(root), +        items=_get_items(root) +    ) + + +def parse_json_feed_file(filename: str) -> JSONFeed: +    """Parse a JSON feed from a local json file.""" +    with open(filename) as f: +        try: +            root = json.load(f) +        except json.decoder.JSONDecodeError: +            raise FeedJSONError('Not a valid JSON document') + +    return parse_json_feed(root) + + +def parse_json_feed_bytes(data: bytes) -> JSONFeed: +    """Parse a JSON feed from a byte-string containing JSON data.""" +    try: +        root = json.loads(data) +    except json.decoder.JSONDecodeError: +        raise FeedJSONError('Not a valid JSON document') + +    return parse_json_feed(root) diff --git a/python/atoma/opml.py b/python/atoma/opml.py new file mode 100644 index 0000000..a73105e --- /dev/null +++ b/python/atoma/opml.py @@ -0,0 +1,107 @@ +from datetime import datetime +from io import BytesIO +from typing import Optional, List +from xml.etree.ElementTree import Element + +import attr + +from .utils import parse_xml, get_text, get_int, get_datetime + + +@attr.s +class OPMLOutline: +    text: Optional[str] = attr.ib() +    type: Optional[str] = attr.ib() +    xml_url: Optional[str] = attr.ib() +    description: Optional[str] = attr.ib() +    html_url: Optional[str] = attr.ib() +    language: Optional[str] = attr.ib() +    title: Optional[str] = attr.ib() +    version: Optional[str] = attr.ib() + +    outlines: List['OPMLOutline'] = attr.ib() + + +@attr.s +class OPML: +    title: Optional[str] = attr.ib() +    owner_name: Optional[str] = attr.ib() +    owner_email: Optional[str] = attr.ib() +    date_created: Optional[datetime] = attr.ib() +    date_modified: Optional[datetime] = attr.ib() +    expansion_state: Optional[str] = attr.ib() + +    vertical_scroll_state: Optional[int] = attr.ib() +    window_top: Optional[int] = attr.ib() +    window_left: Optional[int] = attr.ib() +    window_bottom: Optional[int] = attr.ib() +    window_right: Optional[int] = attr.ib() + +    outlines: List[OPMLOutline] = attr.ib() + + +def _get_outlines(element: Element) -> List[OPMLOutline]: +    rv = list() + +    for outline in element.findall('outline'): +        rv.append(OPMLOutline( +            outline.attrib.get('text'), +            outline.attrib.get('type'), +            outline.attrib.get('xmlUrl'), +            outline.attrib.get('description'), +            outline.attrib.get('htmlUrl'), +            outline.attrib.get('language'), +            outline.attrib.get('title'), +            outline.attrib.get('version'), +            _get_outlines(outline) +        )) + +    return rv + + +def _parse_opml(root: Element) -> OPML: +    head = root.find('head') +    body = root.find('body') + +    return OPML( +        get_text(head, 'title'), +        get_text(head, 'ownerName'), +        get_text(head, 'ownerEmail'), +        get_datetime(head, 'dateCreated'), +        get_datetime(head, 'dateModified'), +        get_text(head, 'expansionState'), +        get_int(head, 'vertScrollState'), +        get_int(head, 'windowTop'), +        get_int(head, 'windowLeft'), +        get_int(head, 'windowBottom'), +        get_int(head, 'windowRight'), +        outlines=_get_outlines(body) +    ) + + +def parse_opml_file(filename: str) -> OPML: +    """Parse an OPML document from a local XML file.""" +    root = parse_xml(filename).getroot() +    return _parse_opml(root) + + +def parse_opml_bytes(data: bytes) -> OPML: +    """Parse an OPML document from a byte-string containing XML data.""" +    root = parse_xml(BytesIO(data)).getroot() +    return _parse_opml(root) + + +def get_feed_list(opml_obj: OPML) -> List[str]: +    """Walk an OPML document to extract the list of feed it contains.""" +    rv = list() + +    def collect(obj): +        for outline in obj.outlines: +            if outline.type == 'rss' and outline.xml_url: +                rv.append(outline.xml_url) + +            if outline.outlines: +                collect(outline) + +    collect(opml_obj) +    return rv diff --git a/python/atoma/rss.py b/python/atoma/rss.py new file mode 100644 index 0000000..f447a2f --- /dev/null +++ b/python/atoma/rss.py @@ -0,0 +1,221 @@ +from datetime import datetime +from io import BytesIO +from typing import Optional, List +from xml.etree.ElementTree import Element + +import attr + +from .utils import ( +    parse_xml, get_child, get_text, get_int, get_datetime, FeedParseError +) + + +@attr.s +class RSSImage: +    url: str = attr.ib() +    title: Optional[str] = attr.ib() +    link: str = attr.ib() +    width: int = attr.ib() +    height: int = attr.ib() +    description: Optional[str] = attr.ib() + + +@attr.s +class RSSEnclosure: +    url: str = attr.ib() +    length: Optional[int] = attr.ib() +    type: Optional[str] = attr.ib() + + +@attr.s +class RSSSource: +    title: str = attr.ib() +    url: Optional[str] = attr.ib() + + +@attr.s +class RSSItem: +    title: Optional[str] = attr.ib() +    link: Optional[str] = attr.ib() +    description: Optional[str] = attr.ib() +    author: Optional[str] = attr.ib() +    categories: List[str] = attr.ib() +    comments: Optional[str] = attr.ib() +    enclosures: List[RSSEnclosure] = attr.ib() +    guid: Optional[str] = attr.ib() +    pub_date: Optional[datetime] = attr.ib() +    source: Optional[RSSSource] = attr.ib() + +    # Extension +    content_encoded: Optional[str] = attr.ib() + + +@attr.s +class RSSChannel: +    title: Optional[str] = attr.ib() +    link: Optional[str] = attr.ib() +    description: Optional[str] = attr.ib() +    language: Optional[str] = attr.ib() +    copyright: Optional[str] = attr.ib() +    managing_editor: Optional[str] = attr.ib() +    web_master: Optional[str] = attr.ib() +    pub_date: Optional[datetime] = attr.ib() +    last_build_date: Optional[datetime] = attr.ib() +    categories: List[str] = attr.ib() +    generator: Optional[str] = attr.ib() +    docs: Optional[str] = attr.ib() +    ttl: Optional[int] = attr.ib() +    image: Optional[RSSImage] = attr.ib() + +    items: List[RSSItem] = attr.ib() + +    # Extension +    content_encoded: Optional[str] = attr.ib() + + +def _get_image(element: Element, name, +               optional: bool=True) -> Optional[RSSImage]: +    child = get_child(element, name, optional) +    if child is None: +        return None + +    return RSSImage( +        get_text(child, 'url', optional=False), +        get_text(child, 'title'), +        get_text(child, 'link', optional=False), +        get_int(child, 'width') or 88, +        get_int(child, 'height') or 31, +        get_text(child, 'description') +    ) + + +def _get_source(element: Element, name, +                optional: bool=True) -> Optional[RSSSource]: +    child = get_child(element, name, optional) +    if child is None: +        return None + +    return RSSSource( +        child.text.strip(), +        child.attrib.get('url'), +    ) + + +def _get_enclosure(element: Element) -> RSSEnclosure: +    length = element.attrib.get('length') +    try: +        length = int(length) +    except (TypeError, ValueError): +        length = None + +    return RSSEnclosure( +        element.attrib['url'], +        length, +        element.attrib.get('type'), +    ) + + +def _get_link(element: Element) -> Optional[str]: +    """Attempt to retrieve item link. + +    Use the GUID as a fallback if it is a permalink. +    """ +    link = get_text(element, 'link') +    if link is not None: +        return link + +    guid = get_child(element, 'guid') +    if guid is not None and guid.attrib.get('isPermaLink') == 'true': +        return get_text(element, 'guid') + +    return None + + +def _get_item(element: Element) -> RSSItem: +    root = element + +    title = get_text(root, 'title') +    link = _get_link(root) +    description = get_text(root, 'description') +    author = get_text(root, 'author') +    categories = [e.text for e in root.findall('category')] +    comments = get_text(root, 'comments') +    enclosure = [_get_enclosure(e) for e in root.findall('enclosure')] +    guid = get_text(root, 'guid') +    pub_date = get_datetime(root, 'pubDate') +    source = _get_source(root, 'source') + +    content_encoded = get_text(root, 'content:encoded') + +    return RSSItem( +        title, +        link, +        description, +        author, +        categories, +        comments, +        enclosure, +        guid, +        pub_date, +        source, +        content_encoded +    ) + + +def _parse_rss(root: Element) -> RSSChannel: +    rss_version = root.get('version') +    if rss_version != '2.0': +        raise FeedParseError('Cannot process RSS feed version "{}"' +                             .format(rss_version)) + +    root = root.find('channel') + +    title = get_text(root, 'title') +    link = get_text(root, 'link') +    description = get_text(root, 'description') +    language = get_text(root, 'language') +    copyright = get_text(root, 'copyright') +    managing_editor = get_text(root, 'managingEditor') +    web_master = get_text(root, 'webMaster') +    pub_date = get_datetime(root, 'pubDate') +    last_build_date = get_datetime(root, 'lastBuildDate') +    categories = [e.text for e in root.findall('category')] +    generator = get_text(root, 'generator') +    docs = get_text(root, 'docs') +    ttl = get_int(root, 'ttl') + +    image = _get_image(root, 'image') +    items = [_get_item(e) for e in root.findall('item')] + +    content_encoded = get_text(root, 'content:encoded') + +    return RSSChannel( +        title, +        link, +        description, +        language, +        copyright, +        managing_editor, +        web_master, +        pub_date, +        last_build_date, +        categories, +        generator, +        docs, +        ttl, +        image, +        items, +        content_encoded +    ) + + +def parse_rss_file(filename: str) -> RSSChannel: +    """Parse an RSS feed from a local XML file.""" +    root = parse_xml(filename).getroot() +    return _parse_rss(root) + + +def parse_rss_bytes(data: bytes) -> RSSChannel: +    """Parse an RSS feed from a byte-string containing XML data.""" +    root = parse_xml(BytesIO(data)).getroot() +    return _parse_rss(root) diff --git a/python/atoma/simple.py b/python/atoma/simple.py new file mode 100644 index 0000000..98bb3e1 --- /dev/null +++ b/python/atoma/simple.py @@ -0,0 +1,224 @@ +"""Simple API that abstracts away the differences between feed types.""" + +from datetime import datetime, timedelta +import html +import os +from typing import Optional, List, Tuple +import urllib.parse + +import attr + +from . import atom, rss, json_feed +from .exceptions import ( +    FeedParseError, FeedDocumentError, FeedXMLError, FeedJSONError +) + + +@attr.s +class Attachment: +    link: str = attr.ib() +    mime_type: Optional[str] = attr.ib() +    title: Optional[str] = attr.ib() +    size_in_bytes: Optional[int] = attr.ib() +    duration: Optional[timedelta] = attr.ib() + + +@attr.s +class Article: +    id: str = attr.ib() +    title: Optional[str] = attr.ib() +    link: Optional[str] = attr.ib() +    content: str = attr.ib() +    published_at: Optional[datetime] = attr.ib() +    updated_at: Optional[datetime] = attr.ib() +    attachments: List[Attachment] = attr.ib() + + +@attr.s +class Feed: +    title: str = attr.ib() +    subtitle: Optional[str] = attr.ib() +    link: Optional[str] = attr.ib() +    updated_at: Optional[datetime] = attr.ib() +    articles: List[Article] = attr.ib() + + +def _adapt_atom_feed(atom_feed: atom.AtomFeed) -> Feed: +    articles = list() +    for entry in atom_feed.entries: +        if entry.content is not None: +            content = entry.content.value +        elif entry.summary is not None: +            content = entry.summary.value +        else: +            content = '' +        published_at, updated_at = _get_article_dates(entry.published, +                                                      entry.updated) +        # Find article link and attachments +        article_link = None +        attachments = list() +        for candidate_link in entry.links: +            if candidate_link.rel in ('alternate', None): +                article_link = candidate_link.href +            elif candidate_link.rel == 'enclosure': +                attachments.append(Attachment( +                    title=_get_attachment_title(candidate_link.title, +                                                candidate_link.href), +                    link=candidate_link.href, +                    mime_type=candidate_link.type_, +                    size_in_bytes=candidate_link.length, +                    duration=None +                )) + +        if entry.title is None: +            entry_title = None +        elif entry.title.text_type in (atom.AtomTextType.html, +                                       atom.AtomTextType.xhtml): +            entry_title = html.unescape(entry.title.value).strip() +        else: +            entry_title = entry.title.value + +        articles.append(Article( +            entry.id_, +            entry_title, +            article_link, +            content, +            published_at, +            updated_at, +            attachments +        )) + +    # Find feed link +    link = None +    for candidate_link in atom_feed.links: +        if candidate_link.rel == 'self': +            link = candidate_link.href +            break + +    return Feed( +        atom_feed.title.value if atom_feed.title else atom_feed.id_, +        atom_feed.subtitle.value if atom_feed.subtitle else None, +        link, +        atom_feed.updated, +        articles +    ) + + +def _adapt_rss_channel(rss_channel: rss.RSSChannel) -> Feed: +    articles = list() +    for item in rss_channel.items: +        attachments = [ +            Attachment(link=e.url, mime_type=e.type, size_in_bytes=e.length, +                       title=_get_attachment_title(None, e.url), duration=None) +            for e in item.enclosures +        ] +        articles.append(Article( +            item.guid or item.link, +            item.title, +            item.link, +            item.content_encoded or item.description or '', +            item.pub_date, +            None, +            attachments +        )) + +    if rss_channel.title is None and rss_channel.link is None: +        raise FeedParseError('RSS feed does not have a title nor a link') + +    return Feed( +        rss_channel.title if rss_channel.title else rss_channel.link, +        rss_channel.description, +        rss_channel.link, +        rss_channel.pub_date, +        articles +    ) + + +def _adapt_json_feed(json_feed: json_feed.JSONFeed) -> Feed: +    articles = list() +    for item in json_feed.items: +        attachments = [ +            Attachment(a.url, a.mime_type, +                       _get_attachment_title(a.title, a.url), +                       a.size_in_bytes, a.duration) +            for a in item.attachments +        ] +        articles.append(Article( +            item.id_, +            item.title, +            item.url, +            item.content_html or item.content_text or '', +            item.date_published, +            item.date_modified, +            attachments +        )) + +    return Feed( +        json_feed.title, +        json_feed.description, +        json_feed.feed_url, +        None, +        articles +    ) + + +def _get_article_dates(published_at: Optional[datetime], +                       updated_at: Optional[datetime] +                       ) -> Tuple[Optional[datetime], Optional[datetime]]: +    if published_at and updated_at: +        return published_at, updated_at + +    if updated_at: +        return updated_at, None + +    if published_at: +        return published_at, None + +    raise FeedParseError('Article does not have proper dates') + + +def _get_attachment_title(attachment_title: Optional[str], link: str) -> str: +    if attachment_title: +        return attachment_title + +    parsed_link = urllib.parse.urlparse(link) +    return os.path.basename(parsed_link.path) + + +def _simple_parse(pairs, content) -> Feed: +    is_xml = True +    is_json = True +    for parser, adapter in pairs: +        try: +            return adapter(parser(content)) +        except FeedXMLError: +            is_xml = False +        except FeedJSONError: +            is_json = False +        except FeedParseError: +            continue + +    if not is_xml and not is_json: +        raise FeedDocumentError('File is not a supported feed type') + +    raise FeedParseError('File is not a valid supported feed') + + +def simple_parse_file(filename: str) -> Feed: +    """Parse an Atom, RSS or JSON feed from a local file.""" +    pairs = ( +        (rss.parse_rss_file, _adapt_rss_channel), +        (atom.parse_atom_file, _adapt_atom_feed), +        (json_feed.parse_json_feed_file, _adapt_json_feed) +    ) +    return _simple_parse(pairs, filename) + + +def simple_parse_bytes(data: bytes) -> Feed: +    """Parse an Atom, RSS or JSON feed from a byte-string containing data.""" +    pairs = ( +        (rss.parse_rss_bytes, _adapt_rss_channel), +        (atom.parse_atom_bytes, _adapt_atom_feed), +        (json_feed.parse_json_feed_bytes, _adapt_json_feed) +    ) +    return _simple_parse(pairs, data) diff --git a/python/atoma/utils.py b/python/atoma/utils.py new file mode 100644 index 0000000..4dc1ab5 --- /dev/null +++ b/python/atoma/utils.py @@ -0,0 +1,84 @@ +from datetime import datetime, timezone +from xml.etree.ElementTree import Element +from typing import Optional + +import dateutil.parser +from defusedxml.ElementTree import parse as defused_xml_parse, ParseError + +from .exceptions import FeedXMLError, FeedParseError + +ns = { +    'content': 'http://purl.org/rss/1.0/modules/content/', +    'feed': 'http://www.w3.org/2005/Atom' +} + + +def parse_xml(xml_content): +    try: +        return defused_xml_parse(xml_content) +    except ParseError: +        raise FeedXMLError('Not a valid XML document') + + +def get_child(element: Element, name, +              optional: bool=True) -> Optional[Element]: +    child = element.find(name, namespaces=ns) + +    if child is None and not optional: +        raise FeedParseError( +            'Could not parse feed: "{}" does not have a "{}"' +            .format(element.tag, name) +        ) + +    elif child is None: +        return None + +    return child + + +def get_text(element: Element, name, optional: bool=True) -> Optional[str]: +    child = get_child(element, name, optional) +    if child is None: +        return None + +    if child.text is None: +        if optional: +            return None + +        raise FeedParseError( +            'Could not parse feed: "{}" text is required but is empty' +            .format(name) +        ) + +    return child.text.strip() + + +def get_int(element: Element, name, optional: bool=True) -> Optional[int]: +    text = get_text(element, name, optional) +    if text is None: +        return None + +    return int(text) + + +def get_datetime(element: Element, name, +                 optional: bool=True) -> Optional[datetime]: +    text = get_text(element, name, optional) +    if text is None: +        return None + +    return try_parse_date(text) + + +def try_parse_date(date_str: str) -> Optional[datetime]: +    try: +        date = dateutil.parser.parse(date_str, fuzzy=True) +    except (ValueError, OverflowError): +        return None + +    if date.tzinfo is None: +        # TZ naive datetime, make it a TZ aware datetime by assuming it +        # contains UTC time +        date = date.replace(tzinfo=timezone.utc) + +    return date diff --git a/python/attr/__init__.py b/python/attr/__init__.py new file mode 100644 index 0000000..debfd57 --- /dev/null +++ b/python/attr/__init__.py @@ -0,0 +1,65 @@ +from __future__ import absolute_import, division, print_function + +from functools import partial + +from . import converters, exceptions, filters, validators +from ._config import get_run_validators, set_run_validators +from ._funcs import asdict, assoc, astuple, evolve, has +from ._make import ( +    NOTHING, +    Attribute, +    Factory, +    attrib, +    attrs, +    fields, +    fields_dict, +    make_class, +    validate, +) + + +__version__ = "18.2.0" + +__title__ = "attrs" +__description__ = "Classes Without Boilerplate" +__url__ = "https://www.attrs.org/" +__uri__ = __url__ +__doc__ = __description__ + " <" + __uri__ + ">" + +__author__ = "Hynek Schlawack" +__email__ = "hs@ox.cx" + +__license__ = "MIT" +__copyright__ = "Copyright (c) 2015 Hynek Schlawack" + + +s = attributes = attrs +ib = attr = attrib +dataclass = partial(attrs, auto_attribs=True)  # happy Easter ;) + +__all__ = [ +    "Attribute", +    "Factory", +    "NOTHING", +    "asdict", +    "assoc", +    "astuple", +    "attr", +    "attrib", +    "attributes", +    "attrs", +    "converters", +    "evolve", +    "exceptions", +    "fields", +    "fields_dict", +    "filters", +    "get_run_validators", +    "has", +    "ib", +    "make_class", +    "s", +    "set_run_validators", +    "validate", +    "validators", +] diff --git a/python/attr/__init__.pyi b/python/attr/__init__.pyi new file mode 100644 index 0000000..492fb85 --- /dev/null +++ b/python/attr/__init__.pyi @@ -0,0 +1,252 @@ +from typing import ( +    Any, +    Callable, +    Dict, +    Generic, +    List, +    Optional, +    Sequence, +    Mapping, +    Tuple, +    Type, +    TypeVar, +    Union, +    overload, +) + +# `import X as X` is required to make these public +from . import exceptions as exceptions +from . import filters as filters +from . import converters as converters +from . import validators as validators + +_T = TypeVar("_T") +_C = TypeVar("_C", bound=type) + +_ValidatorType = Callable[[Any, Attribute, _T], Any] +_ConverterType = Callable[[Any], _T] +_FilterType = Callable[[Attribute, Any], bool] +# FIXME: in reality, if multiple validators are passed they must be in a list or tuple, +# but those are invariant and so would prevent subtypes of _ValidatorType from working +# when passed in a list or tuple. +_ValidatorArgType = Union[_ValidatorType[_T], Sequence[_ValidatorType[_T]]] + +# _make -- + +NOTHING: object + +# NOTE: Factory lies about its return type to make this possible: `x: List[int] = Factory(list)` +# Work around mypy issue #4554 in the common case by using an overload. +@overload +def Factory(factory: Callable[[], _T]) -> _T: ... +@overload +def Factory( +    factory: Union[Callable[[Any], _T], Callable[[], _T]], +    takes_self: bool = ..., +) -> _T: ... + +class Attribute(Generic[_T]): +    name: str +    default: Optional[_T] +    validator: Optional[_ValidatorType[_T]] +    repr: bool +    cmp: bool +    hash: Optional[bool] +    init: bool +    converter: Optional[_ConverterType[_T]] +    metadata: Dict[Any, Any] +    type: Optional[Type[_T]] +    kw_only: bool +    def __lt__(self, x: Attribute) -> bool: ... +    def __le__(self, x: Attribute) -> bool: ... +    def __gt__(self, x: Attribute) -> bool: ... +    def __ge__(self, x: Attribute) -> bool: ... + +# NOTE: We had several choices for the annotation to use for type arg: +# 1) Type[_T] +#   - Pros: Handles simple cases correctly +#   - Cons: Might produce less informative errors in the case of conflicting TypeVars +#   e.g. `attr.ib(default='bad', type=int)` +# 2) Callable[..., _T] +#   - Pros: Better error messages than #1 for conflicting TypeVars +#   - Cons: Terrible error messages for validator checks. +#   e.g. attr.ib(type=int, validator=validate_str) +#        -> error: Cannot infer function type argument +# 3) type (and do all of the work in the mypy plugin) +#   - Pros: Simple here, and we could customize the plugin with our own errors. +#   - Cons: Would need to write mypy plugin code to handle all the cases. +# We chose option #1. + +# `attr` lies about its return type to make the following possible: +#     attr()    -> Any +#     attr(8)   -> int +#     attr(validator=<some callable>)  -> Whatever the callable expects. +# This makes this type of assignments possible: +#     x: int = attr(8) +# +# This form catches explicit None or no default but with no other arguments returns Any. +@overload +def attrib( +    default: None = ..., +    validator: None = ..., +    repr: bool = ..., +    cmp: bool = ..., +    hash: Optional[bool] = ..., +    init: bool = ..., +    convert: None = ..., +    metadata: Optional[Mapping[Any, Any]] = ..., +    type: None = ..., +    converter: None = ..., +    factory: None = ..., +    kw_only: bool = ..., +) -> Any: ... + +# This form catches an explicit None or no default and infers the type from the other arguments. +@overload +def attrib( +    default: None = ..., +    validator: Optional[_ValidatorArgType[_T]] = ..., +    repr: bool = ..., +    cmp: bool = ..., +    hash: Optional[bool] = ..., +    init: bool = ..., +    convert: Optional[_ConverterType[_T]] = ..., +    metadata: Optional[Mapping[Any, Any]] = ..., +    type: Optional[Type[_T]] = ..., +    converter: Optional[_ConverterType[_T]] = ..., +    factory: Optional[Callable[[], _T]] = ..., +    kw_only: bool = ..., +) -> _T: ... + +# This form catches an explicit default argument. +@overload +def attrib( +    default: _T, +    validator: Optional[_ValidatorArgType[_T]] = ..., +    repr: bool = ..., +    cmp: bool = ..., +    hash: Optional[bool] = ..., +    init: bool = ..., +    convert: Optional[_ConverterType[_T]] = ..., +    metadata: Optional[Mapping[Any, Any]] = ..., +    type: Optional[Type[_T]] = ..., +    converter: Optional[_ConverterType[_T]] = ..., +    factory: Optional[Callable[[], _T]] = ..., +    kw_only: bool = ..., +) -> _T: ... + +# This form covers type=non-Type: e.g. forward references (str), Any +@overload +def attrib( +    default: Optional[_T] = ..., +    validator: Optional[_ValidatorArgType[_T]] = ..., +    repr: bool = ..., +    cmp: bool = ..., +    hash: Optional[bool] = ..., +    init: bool = ..., +    convert: Optional[_ConverterType[_T]] = ..., +    metadata: Optional[Mapping[Any, Any]] = ..., +    type: object = ..., +    converter: Optional[_ConverterType[_T]] = ..., +    factory: Optional[Callable[[], _T]] = ..., +    kw_only: bool = ..., +) -> Any: ... +@overload +def attrs( +    maybe_cls: _C, +    these: Optional[Dict[str, Any]] = ..., +    repr_ns: Optional[str] = ..., +    repr: bool = ..., +    cmp: bool = ..., +    hash: Optional[bool] = ..., +    init: bool = ..., +    slots: bool = ..., +    frozen: bool = ..., +    weakref_slot: bool = ..., +    str: bool = ..., +    auto_attribs: bool = ..., +    kw_only: bool = ..., +    cache_hash: bool = ..., +) -> _C: ... +@overload +def attrs( +    maybe_cls: None = ..., +    these: Optional[Dict[str, Any]] = ..., +    repr_ns: Optional[str] = ..., +    repr: bool = ..., +    cmp: bool = ..., +    hash: Optional[bool] = ..., +    init: bool = ..., +    slots: bool = ..., +    frozen: bool = ..., +    weakref_slot: bool = ..., +    str: bool = ..., +    auto_attribs: bool = ..., +    kw_only: bool = ..., +    cache_hash: bool = ..., +) -> Callable[[_C], _C]: ... + +# TODO: add support for returning NamedTuple from the mypy plugin +class _Fields(Tuple[Attribute, ...]): +    def __getattr__(self, name: str) -> Attribute: ... + +def fields(cls: type) -> _Fields: ... +def fields_dict(cls: type) -> Dict[str, Attribute]: ... +def validate(inst: Any) -> None: ... + +# TODO: add support for returning a proper attrs class from the mypy plugin +# we use Any instead of _CountingAttr so that e.g. `make_class('Foo', [attr.ib()])` is valid +def make_class( +    name: str, +    attrs: Union[List[str], Tuple[str, ...], Dict[str, Any]], +    bases: Tuple[type, ...] = ..., +    repr_ns: Optional[str] = ..., +    repr: bool = ..., +    cmp: bool = ..., +    hash: Optional[bool] = ..., +    init: bool = ..., +    slots: bool = ..., +    frozen: bool = ..., +    weakref_slot: bool = ..., +    str: bool = ..., +    auto_attribs: bool = ..., +    kw_only: bool = ..., +    cache_hash: bool = ..., +) -> type: ... + +# _funcs -- + +# TODO: add support for returning TypedDict from the mypy plugin +# FIXME: asdict/astuple do not honor their factory args.  waiting on one of these: +# https://github.com/python/mypy/issues/4236 +# https://github.com/python/typing/issues/253 +def asdict( +    inst: Any, +    recurse: bool = ..., +    filter: Optional[_FilterType] = ..., +    dict_factory: Type[Mapping[Any, Any]] = ..., +    retain_collection_types: bool = ..., +) -> Dict[str, Any]: ... + +# TODO: add support for returning NamedTuple from the mypy plugin +def astuple( +    inst: Any, +    recurse: bool = ..., +    filter: Optional[_FilterType] = ..., +    tuple_factory: Type[Sequence] = ..., +    retain_collection_types: bool = ..., +) -> Tuple[Any, ...]: ... +def has(cls: type) -> bool: ... +def assoc(inst: _T, **changes: Any) -> _T: ... +def evolve(inst: _T, **changes: Any) -> _T: ... + +# _config -- + +def set_run_validators(run: bool) -> None: ... +def get_run_validators() -> bool: ... + +# aliases -- + +s = attributes = attrs +ib = attr = attrib +dataclass = attrs  # Technically, partial(attrs, auto_attribs=True) ;) diff --git a/python/attr/_compat.py b/python/attr/_compat.py new file mode 100644 index 0000000..5bb0659 --- /dev/null +++ b/python/attr/_compat.py @@ -0,0 +1,163 @@ +from __future__ import absolute_import, division, print_function + +import platform +import sys +import types +import warnings + + +PY2 = sys.version_info[0] == 2 +PYPY = platform.python_implementation() == "PyPy" + + +if PYPY or sys.version_info[:2] >= (3, 6): +    ordered_dict = dict +else: +    from collections import OrderedDict + +    ordered_dict = OrderedDict + + +if PY2: +    from UserDict import IterableUserDict + +    # We 'bundle' isclass instead of using inspect as importing inspect is +    # fairly expensive (order of 10-15 ms for a modern machine in 2016) +    def isclass(klass): +        return isinstance(klass, (type, types.ClassType)) + +    # TYPE is used in exceptions, repr(int) is different on Python 2 and 3. +    TYPE = "type" + +    def iteritems(d): +        return d.iteritems() + +    # Python 2 is bereft of a read-only dict proxy, so we make one! +    class ReadOnlyDict(IterableUserDict): +        """ +        Best-effort read-only dict wrapper. +        """ + +        def __setitem__(self, key, val): +            # We gently pretend we're a Python 3 mappingproxy. +            raise TypeError( +                "'mappingproxy' object does not support item assignment" +            ) + +        def update(self, _): +            # We gently pretend we're a Python 3 mappingproxy. +            raise AttributeError( +                "'mappingproxy' object has no attribute 'update'" +            ) + +        def __delitem__(self, _): +            # We gently pretend we're a Python 3 mappingproxy. +            raise TypeError( +                "'mappingproxy' object does not support item deletion" +            ) + +        def clear(self): +            # We gently pretend we're a Python 3 mappingproxy. +            raise AttributeError( +                "'mappingproxy' object has no attribute 'clear'" +            ) + +        def pop(self, key, default=None): +            # We gently pretend we're a Python 3 mappingproxy. +            raise AttributeError( +                "'mappingproxy' object has no attribute 'pop'" +            ) + +        def popitem(self): +            # We gently pretend we're a Python 3 mappingproxy. +            raise AttributeError( +                "'mappingproxy' object has no attribute 'popitem'" +            ) + +        def setdefault(self, key, default=None): +            # We gently pretend we're a Python 3 mappingproxy. +            raise AttributeError( +                "'mappingproxy' object has no attribute 'setdefault'" +            ) + +        def __repr__(self): +            # Override to be identical to the Python 3 version. +            return "mappingproxy(" + repr(self.data) + ")" + +    def metadata_proxy(d): +        res = ReadOnlyDict() +        res.data.update(d)  # We blocked update, so we have to do it like this. +        return res + + +else: + +    def isclass(klass): +        return isinstance(klass, type) + +    TYPE = "class" + +    def iteritems(d): +        return d.items() + +    def metadata_proxy(d): +        return types.MappingProxyType(dict(d)) + + +def import_ctypes(): +    """ +    Moved into a function for testability. +    """ +    import ctypes + +    return ctypes + + +if not PY2: + +    def just_warn(*args, **kw): +        """ +        We only warn on Python 3 because we are not aware of any concrete +        consequences of not setting the cell on Python 2. +        """ +        warnings.warn( +            "Missing ctypes.  Some features like bare super() or accessing " +            "__class__ will not work with slots classes.", +            RuntimeWarning, +            stacklevel=2, +        ) + + +else: + +    def just_warn(*args, **kw):  # pragma: nocover +        """ +        We only warn on Python 3 because we are not aware of any concrete +        consequences of not setting the cell on Python 2. +        """ + + +def make_set_closure_cell(): +    """ +    Moved into a function for testability. +    """ +    if PYPY:  # pragma: no cover + +        def set_closure_cell(cell, value): +            cell.__setstate__((value,)) + +    else: +        try: +            ctypes = import_ctypes() + +            set_closure_cell = ctypes.pythonapi.PyCell_Set +            set_closure_cell.argtypes = (ctypes.py_object, ctypes.py_object) +            set_closure_cell.restype = ctypes.c_int +        except Exception: +            # We try best effort to set the cell, but sometimes it's not +            # possible.  For example on Jython or on GAE. +            set_closure_cell = just_warn +    return set_closure_cell + + +set_closure_cell = make_set_closure_cell() diff --git a/python/attr/_config.py b/python/attr/_config.py new file mode 100644 index 0000000..8ec9209 --- /dev/null +++ b/python/attr/_config.py @@ -0,0 +1,23 @@ +from __future__ import absolute_import, division, print_function + + +__all__ = ["set_run_validators", "get_run_validators"] + +_run_validators = True + + +def set_run_validators(run): +    """ +    Set whether or not validators are run.  By default, they are run. +    """ +    if not isinstance(run, bool): +        raise TypeError("'run' must be bool.") +    global _run_validators +    _run_validators = run + + +def get_run_validators(): +    """ +    Return whether or not validators are run. +    """ +    return _run_validators diff --git a/python/attr/_funcs.py b/python/attr/_funcs.py new file mode 100644 index 0000000..b61d239 --- /dev/null +++ b/python/attr/_funcs.py @@ -0,0 +1,290 @@ +from __future__ import absolute_import, division, print_function + +import copy + +from ._compat import iteritems +from ._make import NOTHING, _obj_setattr, fields +from .exceptions import AttrsAttributeNotFoundError + + +def asdict( +    inst, +    recurse=True, +    filter=None, +    dict_factory=dict, +    retain_collection_types=False, +): +    """ +    Return the ``attrs`` attribute values of *inst* as a dict. + +    Optionally recurse into other ``attrs``-decorated classes. + +    :param inst: Instance of an ``attrs``-decorated class. +    :param bool recurse: Recurse into classes that are also +        ``attrs``-decorated. +    :param callable filter: A callable whose return code determines whether an +        attribute or element is included (``True``) or dropped (``False``).  Is +        called with the :class:`attr.Attribute` as the first argument and the +        value as the second argument. +    :param callable dict_factory: A callable to produce dictionaries from.  For +        example, to produce ordered dictionaries instead of normal Python +        dictionaries, pass in ``collections.OrderedDict``. +    :param bool retain_collection_types: Do not convert to ``list`` when +        encountering an attribute whose type is ``tuple`` or ``set``.  Only +        meaningful if ``recurse`` is ``True``. + +    :rtype: return type of *dict_factory* + +    :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` +        class. + +    ..  versionadded:: 16.0.0 *dict_factory* +    ..  versionadded:: 16.1.0 *retain_collection_types* +    """ +    attrs = fields(inst.__class__) +    rv = dict_factory() +    for a in attrs: +        v = getattr(inst, a.name) +        if filter is not None and not filter(a, v): +            continue +        if recurse is True: +            if has(v.__class__): +                rv[a.name] = asdict( +                    v, True, filter, dict_factory, retain_collection_types +                ) +            elif isinstance(v, (tuple, list, set)): +                cf = v.__class__ if retain_collection_types is True else list +                rv[a.name] = cf( +                    [ +                        _asdict_anything( +                            i, filter, dict_factory, retain_collection_types +                        ) +                        for i in v +                    ] +                ) +            elif isinstance(v, dict): +                df = dict_factory +                rv[a.name] = df( +                    ( +                        _asdict_anything( +                            kk, filter, df, retain_collection_types +                        ), +                        _asdict_anything( +                            vv, filter, df, retain_collection_types +                        ), +                    ) +                    for kk, vv in iteritems(v) +                ) +            else: +                rv[a.name] = v +        else: +            rv[a.name] = v +    return rv + + +def _asdict_anything(val, filter, dict_factory, retain_collection_types): +    """ +    ``asdict`` only works on attrs instances, this works on anything. +    """ +    if getattr(val.__class__, "__attrs_attrs__", None) is not None: +        # Attrs class. +        rv = asdict(val, True, filter, dict_factory, retain_collection_types) +    elif isinstance(val, (tuple, list, set)): +        cf = val.__class__ if retain_collection_types is True else list +        rv = cf( +            [ +                _asdict_anything( +                    i, filter, dict_factory, retain_collection_types +                ) +                for i in val +            ] +        ) +    elif isinstance(val, dict): +        df = dict_factory +        rv = df( +            ( +                _asdict_anything(kk, filter, df, retain_collection_types), +                _asdict_anything(vv, filter, df, retain_collection_types), +            ) +            for kk, vv in iteritems(val) +        ) +    else: +        rv = val +    return rv + + +def astuple( +    inst, +    recurse=True, +    filter=None, +    tuple_factory=tuple, +    retain_collection_types=False, +): +    """ +    Return the ``attrs`` attribute values of *inst* as a tuple. + +    Optionally recurse into other ``attrs``-decorated classes. + +    :param inst: Instance of an ``attrs``-decorated class. +    :param bool recurse: Recurse into classes that are also +        ``attrs``-decorated. +    :param callable filter: A callable whose return code determines whether an +        attribute or element is included (``True``) or dropped (``False``).  Is +        called with the :class:`attr.Attribute` as the first argument and the +        value as the second argument. +    :param callable tuple_factory: A callable to produce tuples from.  For +        example, to produce lists instead of tuples. +    :param bool retain_collection_types: Do not convert to ``list`` +        or ``dict`` when encountering an attribute which type is +        ``tuple``, ``dict`` or ``set``.  Only meaningful if ``recurse`` is +        ``True``. + +    :rtype: return type of *tuple_factory* + +    :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` +        class. + +    ..  versionadded:: 16.2.0 +    """ +    attrs = fields(inst.__class__) +    rv = [] +    retain = retain_collection_types  # Very long. :/ +    for a in attrs: +        v = getattr(inst, a.name) +        if filter is not None and not filter(a, v): +            continue +        if recurse is True: +            if has(v.__class__): +                rv.append( +                    astuple( +                        v, +                        recurse=True, +                        filter=filter, +                        tuple_factory=tuple_factory, +                        retain_collection_types=retain, +                    ) +                ) +            elif isinstance(v, (tuple, list, set)): +                cf = v.__class__ if retain is True else list +                rv.append( +                    cf( +                        [ +                            astuple( +                                j, +                                recurse=True, +                                filter=filter, +                                tuple_factory=tuple_factory, +                                retain_collection_types=retain, +                            ) +                            if has(j.__class__) +                            else j +                            for j in v +                        ] +                    ) +                ) +            elif isinstance(v, dict): +                df = v.__class__ if retain is True else dict +                rv.append( +                    df( +                        ( +                            astuple( +                                kk, +                                tuple_factory=tuple_factory, +                                retain_collection_types=retain, +                            ) +                            if has(kk.__class__) +                            else kk, +                            astuple( +                                vv, +                                tuple_factory=tuple_factory, +                                retain_collection_types=retain, +                            ) +                            if has(vv.__class__) +                            else vv, +                        ) +                        for kk, vv in iteritems(v) +                    ) +                ) +            else: +                rv.append(v) +        else: +            rv.append(v) +    return rv if tuple_factory is list else tuple_factory(rv) + + +def has(cls): +    """ +    Check whether *cls* is a class with ``attrs`` attributes. + +    :param type cls: Class to introspect. +    :raise TypeError: If *cls* is not a class. + +    :rtype: :class:`bool` +    """ +    return getattr(cls, "__attrs_attrs__", None) is not None + + +def assoc(inst, **changes): +    """ +    Copy *inst* and apply *changes*. + +    :param inst: Instance of a class with ``attrs`` attributes. +    :param changes: Keyword changes in the new copy. + +    :return: A copy of inst with *changes* incorporated. + +    :raise attr.exceptions.AttrsAttributeNotFoundError: If *attr_name* couldn't +        be found on *cls*. +    :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` +        class. + +    ..  deprecated:: 17.1.0 +        Use :func:`evolve` instead. +    """ +    import warnings + +    warnings.warn( +        "assoc is deprecated and will be removed after 2018/01.", +        DeprecationWarning, +        stacklevel=2, +    ) +    new = copy.copy(inst) +    attrs = fields(inst.__class__) +    for k, v in iteritems(changes): +        a = getattr(attrs, k, NOTHING) +        if a is NOTHING: +            raise AttrsAttributeNotFoundError( +                "{k} is not an attrs attribute on {cl}.".format( +                    k=k, cl=new.__class__ +                ) +            ) +        _obj_setattr(new, k, v) +    return new + + +def evolve(inst, **changes): +    """ +    Create a new instance, based on *inst* with *changes* applied. + +    :param inst: Instance of a class with ``attrs`` attributes. +    :param changes: Keyword changes in the new copy. + +    :return: A copy of inst with *changes* incorporated. + +    :raise TypeError: If *attr_name* couldn't be found in the class +        ``__init__``. +    :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` +        class. + +    ..  versionadded:: 17.1.0 +    """ +    cls = inst.__class__ +    attrs = fields(cls) +    for a in attrs: +        if not a.init: +            continue +        attr_name = a.name  # To deal with private attributes. +        init_name = attr_name if attr_name[0] != "_" else attr_name[1:] +        if init_name not in changes: +            changes[init_name] = getattr(inst, attr_name) +    return cls(**changes) diff --git a/python/attr/_make.py b/python/attr/_make.py new file mode 100644 index 0000000..f7fd05e --- /dev/null +++ b/python/attr/_make.py @@ -0,0 +1,2034 @@ +from __future__ import absolute_import, division, print_function + +import copy +import hashlib +import linecache +import sys +import threading +import warnings + +from operator import itemgetter + +from . import _config +from ._compat import ( +    PY2, +    isclass, +    iteritems, +    metadata_proxy, +    ordered_dict, +    set_closure_cell, +) +from .exceptions import ( +    DefaultAlreadySetError, +    FrozenInstanceError, +    NotAnAttrsClassError, +    PythonTooOldError, +    UnannotatedAttributeError, +) + + +# This is used at least twice, so cache it here. +_obj_setattr = object.__setattr__ +_init_converter_pat = "__attr_converter_{}" +_init_factory_pat = "__attr_factory_{}" +_tuple_property_pat = ( +    "    {attr_name} = _attrs_property(_attrs_itemgetter({index}))" +) +_classvar_prefixes = ("typing.ClassVar", "t.ClassVar", "ClassVar") +# we don't use a double-underscore prefix because that triggers +# name mangling when trying to create a slot for the field +# (when slots=True) +_hash_cache_field = "_attrs_cached_hash" + +_empty_metadata_singleton = metadata_proxy({}) + + +class _Nothing(object): +    """ +    Sentinel class to indicate the lack of a value when ``None`` is ambiguous. + +    ``_Nothing`` is a singleton. There is only ever one of it. +    """ + +    _singleton = None + +    def __new__(cls): +        if _Nothing._singleton is None: +            _Nothing._singleton = super(_Nothing, cls).__new__(cls) +        return _Nothing._singleton + +    def __repr__(self): +        return "NOTHING" + + +NOTHING = _Nothing() +""" +Sentinel to indicate the lack of a value when ``None`` is ambiguous. +""" + + +def attrib( +    default=NOTHING, +    validator=None, +    repr=True, +    cmp=True, +    hash=None, +    init=True, +    convert=None, +    metadata=None, +    type=None, +    converter=None, +    factory=None, +    kw_only=False, +): +    """ +    Create a new attribute on a class. + +    ..  warning:: + +        Does *not* do anything unless the class is also decorated with +        :func:`attr.s`! + +    :param default: A value that is used if an ``attrs``-generated ``__init__`` +        is used and no value is passed while instantiating or the attribute is +        excluded using ``init=False``. + +        If the value is an instance of :class:`Factory`, its callable will be +        used to construct a new value (useful for mutable data types like lists +        or dicts). + +        If a default is not set (or set manually to ``attr.NOTHING``), a value +        *must* be supplied when instantiating; otherwise a :exc:`TypeError` +        will be raised. + +        The default can also be set using decorator notation as shown below. + +    :type default: Any value. + +    :param callable factory: Syntactic sugar for +        ``default=attr.Factory(callable)``. + +    :param validator: :func:`callable` that is called by ``attrs``-generated +        ``__init__`` methods after the instance has been initialized.  They +        receive the initialized instance, the :class:`Attribute`, and the +        passed value. + +        The return value is *not* inspected so the validator has to throw an +        exception itself. + +        If a ``list`` is passed, its items are treated as validators and must +        all pass. + +        Validators can be globally disabled and re-enabled using +        :func:`get_run_validators`. + +        The validator can also be set using decorator notation as shown below. + +    :type validator: ``callable`` or a ``list`` of ``callable``\\ s. + +    :param bool repr: Include this attribute in the generated ``__repr__`` +        method. +    :param bool cmp: Include this attribute in the generated comparison methods +        (``__eq__`` et al). +    :param hash: Include this attribute in the generated ``__hash__`` +        method.  If ``None`` (default), mirror *cmp*'s value.  This is the +        correct behavior according the Python spec.  Setting this value to +        anything else than ``None`` is *discouraged*. +    :type hash: ``bool`` or ``None`` +    :param bool init: Include this attribute in the generated ``__init__`` +        method.  It is possible to set this to ``False`` and set a default +        value.  In that case this attributed is unconditionally initialized +        with the specified default value or factory. +    :param callable converter: :func:`callable` that is called by +        ``attrs``-generated ``__init__`` methods to converter attribute's value +        to the desired format.  It is given the passed-in value, and the +        returned value will be used as the new value of the attribute.  The +        value is converted before being passed to the validator, if any. +    :param metadata: An arbitrary mapping, to be used by third-party +        components.  See :ref:`extending_metadata`. +    :param type: The type of the attribute.  In Python 3.6 or greater, the +        preferred method to specify the type is using a variable annotation +        (see `PEP 526 <https://www.python.org/dev/peps/pep-0526/>`_). +        This argument is provided for backward compatibility. +        Regardless of the approach used, the type will be stored on +        ``Attribute.type``. + +        Please note that ``attrs`` doesn't do anything with this metadata by +        itself. You can use it as part of your own code or for +        :doc:`static type checking <types>`. +    :param kw_only: Make this attribute keyword-only (Python 3+) +        in the generated ``__init__`` (if ``init`` is ``False``, this +        parameter is ignored). + +    .. versionadded:: 15.2.0 *convert* +    .. versionadded:: 16.3.0 *metadata* +    .. versionchanged:: 17.1.0 *validator* can be a ``list`` now. +    .. versionchanged:: 17.1.0 +       *hash* is ``None`` and therefore mirrors *cmp* by default. +    .. versionadded:: 17.3.0 *type* +    .. deprecated:: 17.4.0 *convert* +    .. versionadded:: 17.4.0 *converter* as a replacement for the deprecated +       *convert* to achieve consistency with other noun-based arguments. +    .. versionadded:: 18.1.0 +       ``factory=f`` is syntactic sugar for ``default=attr.Factory(f)``. +    .. versionadded:: 18.2.0 *kw_only* +    """ +    if hash is not None and hash is not True and hash is not False: +        raise TypeError( +            "Invalid value for hash.  Must be True, False, or None." +        ) + +    if convert is not None: +        if converter is not None: +            raise RuntimeError( +                "Can't pass both `convert` and `converter`.  " +                "Please use `converter` only." +            ) +        warnings.warn( +            "The `convert` argument is deprecated in favor of `converter`.  " +            "It will be removed after 2019/01.", +            DeprecationWarning, +            stacklevel=2, +        ) +        converter = convert + +    if factory is not None: +        if default is not NOTHING: +            raise ValueError( +                "The `default` and `factory` arguments are mutually " +                "exclusive." +            ) +        if not callable(factory): +            raise ValueError("The `factory` argument must be a callable.") +        default = Factory(factory) + +    if metadata is None: +        metadata = {} + +    return _CountingAttr( +        default=default, +        validator=validator, +        repr=repr, +        cmp=cmp, +        hash=hash, +        init=init, +        converter=converter, +        metadata=metadata, +        type=type, +        kw_only=kw_only, +    ) + + +def _make_attr_tuple_class(cls_name, attr_names): +    """ +    Create a tuple subclass to hold `Attribute`s for an `attrs` class. + +    The subclass is a bare tuple with properties for names. + +    class MyClassAttributes(tuple): +        __slots__ = () +        x = property(itemgetter(0)) +    """ +    attr_class_name = "{}Attributes".format(cls_name) +    attr_class_template = [ +        "class {}(tuple):".format(attr_class_name), +        "    __slots__ = ()", +    ] +    if attr_names: +        for i, attr_name in enumerate(attr_names): +            attr_class_template.append( +                _tuple_property_pat.format(index=i, attr_name=attr_name) +            ) +    else: +        attr_class_template.append("    pass") +    globs = {"_attrs_itemgetter": itemgetter, "_attrs_property": property} +    eval(compile("\n".join(attr_class_template), "", "exec"), globs) + +    return globs[attr_class_name] + + +# Tuple class for extracted attributes from a class definition. +# `base_attrs` is a subset of `attrs`. +_Attributes = _make_attr_tuple_class( +    "_Attributes", +    [ +        # all attributes to build dunder methods for +        "attrs", +        # attributes that have been inherited +        "base_attrs", +        # map inherited attributes to their originating classes +        "base_attrs_map", +    ], +) + + +def _is_class_var(annot): +    """ +    Check whether *annot* is a typing.ClassVar. + +    The string comparison hack is used to avoid evaluating all string +    annotations which would put attrs-based classes at a performance +    disadvantage compared to plain old classes. +    """ +    return str(annot).startswith(_classvar_prefixes) + + +def _get_annotations(cls): +    """ +    Get annotations for *cls*. +    """ +    anns = getattr(cls, "__annotations__", None) +    if anns is None: +        return {} + +    # Verify that the annotations aren't merely inherited. +    for base_cls in cls.__mro__[1:]: +        if anns is getattr(base_cls, "__annotations__", None): +            return {} + +    return anns + + +def _counter_getter(e): +    """ +    Key function for sorting to avoid re-creating a lambda for every class. +    """ +    return e[1].counter + + +def _transform_attrs(cls, these, auto_attribs, kw_only): +    """ +    Transform all `_CountingAttr`s on a class into `Attribute`s. + +    If *these* is passed, use that and don't look for them on the class. + +    Return an `_Attributes`. +    """ +    cd = cls.__dict__ +    anns = _get_annotations(cls) + +    if these is not None: +        ca_list = [(name, ca) for name, ca in iteritems(these)] + +        if not isinstance(these, ordered_dict): +            ca_list.sort(key=_counter_getter) +    elif auto_attribs is True: +        ca_names = { +            name +            for name, attr in cd.items() +            if isinstance(attr, _CountingAttr) +        } +        ca_list = [] +        annot_names = set() +        for attr_name, type in anns.items(): +            if _is_class_var(type): +                continue +            annot_names.add(attr_name) +            a = cd.get(attr_name, NOTHING) +            if not isinstance(a, _CountingAttr): +                if a is NOTHING: +                    a = attrib() +                else: +                    a = attrib(default=a) +            ca_list.append((attr_name, a)) + +        unannotated = ca_names - annot_names +        if len(unannotated) > 0: +            raise UnannotatedAttributeError( +                "The following `attr.ib`s lack a type annotation: " +                + ", ".join( +                    sorted(unannotated, key=lambda n: cd.get(n).counter) +                ) +                + "." +            ) +    else: +        ca_list = sorted( +            ( +                (name, attr) +                for name, attr in cd.items() +                if isinstance(attr, _CountingAttr) +            ), +            key=lambda e: e[1].counter, +        ) + +    own_attrs = [ +        Attribute.from_counting_attr( +            name=attr_name, ca=ca, type=anns.get(attr_name) +        ) +        for attr_name, ca in ca_list +    ] + +    base_attrs = [] +    base_attr_map = {}  # A dictionary of base attrs to their classes. +    taken_attr_names = {a.name: a for a in own_attrs} + +    # Traverse the MRO and collect attributes. +    for base_cls in cls.__mro__[1:-1]: +        sub_attrs = getattr(base_cls, "__attrs_attrs__", None) +        if sub_attrs is not None: +            for a in sub_attrs: +                prev_a = taken_attr_names.get(a.name) +                # Only add an attribute if it hasn't been defined before.  This +                # allows for overwriting attribute definitions by subclassing. +                if prev_a is None: +                    base_attrs.append(a) +                    taken_attr_names[a.name] = a +                    base_attr_map[a.name] = base_cls + +    attr_names = [a.name for a in base_attrs + own_attrs] + +    AttrsClass = _make_attr_tuple_class(cls.__name__, attr_names) + +    if kw_only: +        own_attrs = [a._assoc(kw_only=True) for a in own_attrs] +        base_attrs = [a._assoc(kw_only=True) for a in base_attrs] + +    attrs = AttrsClass(base_attrs + own_attrs) + +    had_default = False +    was_kw_only = False +    for a in attrs: +        if ( +            was_kw_only is False +            and had_default is True +            and a.default is NOTHING +            and a.init is True +            and a.kw_only is False +        ): +            raise ValueError( +                "No mandatory attributes allowed after an attribute with a " +                "default value or factory.  Attribute in question: %r" % (a,) +            ) +        elif ( +            had_default is False +            and a.default is not NOTHING +            and a.init is not False +            and +            # Keyword-only attributes without defaults can be specified +            # after keyword-only attributes with defaults. +            a.kw_only is False +        ): +            had_default = True +        if was_kw_only is True and a.kw_only is False: +            raise ValueError( +                "Non keyword-only attributes are not allowed after a " +                "keyword-only attribute.  Attribute in question: {a!r}".format( +                    a=a +                ) +            ) +        if was_kw_only is False and a.init is True and a.kw_only is True: +            was_kw_only = True + +    return _Attributes((attrs, base_attrs, base_attr_map)) + + +def _frozen_setattrs(self, name, value): +    """ +    Attached to frozen classes as __setattr__. +    """ +    raise FrozenInstanceError() + + +def _frozen_delattrs(self, name): +    """ +    Attached to frozen classes as __delattr__. +    """ +    raise FrozenInstanceError() + + +class _ClassBuilder(object): +    """ +    Iteratively build *one* class. +    """ + +    __slots__ = ( +        "_cls", +        "_cls_dict", +        "_attrs", +        "_base_names", +        "_attr_names", +        "_slots", +        "_frozen", +        "_weakref_slot", +        "_cache_hash", +        "_has_post_init", +        "_delete_attribs", +        "_base_attr_map", +    ) + +    def __init__( +        self, +        cls, +        these, +        slots, +        frozen, +        weakref_slot, +        auto_attribs, +        kw_only, +        cache_hash, +    ): +        attrs, base_attrs, base_map = _transform_attrs( +            cls, these, auto_attribs, kw_only +        ) + +        self._cls = cls +        self._cls_dict = dict(cls.__dict__) if slots else {} +        self._attrs = attrs +        self._base_names = set(a.name for a in base_attrs) +        self._base_attr_map = base_map +        self._attr_names = tuple(a.name for a in attrs) +        self._slots = slots +        self._frozen = frozen or _has_frozen_base_class(cls) +        self._weakref_slot = weakref_slot +        self._cache_hash = cache_hash +        self._has_post_init = bool(getattr(cls, "__attrs_post_init__", False)) +        self._delete_attribs = not bool(these) + +        self._cls_dict["__attrs_attrs__"] = self._attrs + +        if frozen: +            self._cls_dict["__setattr__"] = _frozen_setattrs +            self._cls_dict["__delattr__"] = _frozen_delattrs + +    def __repr__(self): +        return "<_ClassBuilder(cls={cls})>".format(cls=self._cls.__name__) + +    def build_class(self): +        """ +        Finalize class based on the accumulated configuration. + +        Builder cannot be used after calling this method. +        """ +        if self._slots is True: +            return self._create_slots_class() +        else: +            return self._patch_original_class() + +    def _patch_original_class(self): +        """ +        Apply accumulated methods and return the class. +        """ +        cls = self._cls +        base_names = self._base_names + +        # Clean class of attribute definitions (`attr.ib()`s). +        if self._delete_attribs: +            for name in self._attr_names: +                if ( +                    name not in base_names +                    and getattr(cls, name, None) is not None +                ): +                    try: +                        delattr(cls, name) +                    except AttributeError: +                        # This can happen if a base class defines a class +                        # variable and we want to set an attribute with the +                        # same name by using only a type annotation. +                        pass + +        # Attach our dunder methods. +        for name, value in self._cls_dict.items(): +            setattr(cls, name, value) + +        return cls + +    def _create_slots_class(self): +        """ +        Build and return a new class with a `__slots__` attribute. +        """ +        base_names = self._base_names +        cd = { +            k: v +            for k, v in iteritems(self._cls_dict) +            if k not in tuple(self._attr_names) + ("__dict__", "__weakref__") +        } + +        weakref_inherited = False + +        # Traverse the MRO to check for an existing __weakref__. +        for base_cls in self._cls.__mro__[1:-1]: +            if "__weakref__" in getattr(base_cls, "__dict__", ()): +                weakref_inherited = True +                break + +        names = self._attr_names +        if ( +            self._weakref_slot +            and "__weakref__" not in getattr(self._cls, "__slots__", ()) +            and "__weakref__" not in names +            and not weakref_inherited +        ): +            names += ("__weakref__",) + +        # We only add the names of attributes that aren't inherited. +        # Settings __slots__ to inherited attributes wastes memory. +        slot_names = [name for name in names if name not in base_names] +        if self._cache_hash: +            slot_names.append(_hash_cache_field) +        cd["__slots__"] = tuple(slot_names) + +        qualname = getattr(self._cls, "__qualname__", None) +        if qualname is not None: +            cd["__qualname__"] = qualname + +        # __weakref__ is not writable. +        state_attr_names = tuple( +            an for an in self._attr_names if an != "__weakref__" +        ) + +        def slots_getstate(self): +            """ +            Automatically created by attrs. +            """ +            return tuple(getattr(self, name) for name in state_attr_names) + +        def slots_setstate(self, state): +            """ +            Automatically created by attrs. +            """ +            __bound_setattr = _obj_setattr.__get__(self, Attribute) +            for name, value in zip(state_attr_names, state): +                __bound_setattr(name, value) + +        # slots and frozen require __getstate__/__setstate__ to work +        cd["__getstate__"] = slots_getstate +        cd["__setstate__"] = slots_setstate + +        # Create new class based on old class and our methods. +        cls = type(self._cls)(self._cls.__name__, self._cls.__bases__, cd) + +        # The following is a fix for +        # https://github.com/python-attrs/attrs/issues/102.  On Python 3, +        # if a method mentions `__class__` or uses the no-arg super(), the +        # compiler will bake a reference to the class in the method itself +        # as `method.__closure__`.  Since we replace the class with a +        # clone, we rewrite these references so it keeps working. +        for item in cls.__dict__.values(): +            if isinstance(item, (classmethod, staticmethod)): +                # Class- and staticmethods hide their functions inside. +                # These might need to be rewritten as well. +                closure_cells = getattr(item.__func__, "__closure__", None) +            else: +                closure_cells = getattr(item, "__closure__", None) + +            if not closure_cells:  # Catch None or the empty list. +                continue +            for cell in closure_cells: +                if cell.cell_contents is self._cls: +                    set_closure_cell(cell, cls) + +        return cls + +    def add_repr(self, ns): +        self._cls_dict["__repr__"] = self._add_method_dunders( +            _make_repr(self._attrs, ns=ns) +        ) +        return self + +    def add_str(self): +        repr = self._cls_dict.get("__repr__") +        if repr is None: +            raise ValueError( +                "__str__ can only be generated if a __repr__ exists." +            ) + +        def __str__(self): +            return self.__repr__() + +        self._cls_dict["__str__"] = self._add_method_dunders(__str__) +        return self + +    def make_unhashable(self): +        self._cls_dict["__hash__"] = None +        return self + +    def add_hash(self): +        self._cls_dict["__hash__"] = self._add_method_dunders( +            _make_hash( +                self._attrs, frozen=self._frozen, cache_hash=self._cache_hash +            ) +        ) + +        return self + +    def add_init(self): +        self._cls_dict["__init__"] = self._add_method_dunders( +            _make_init( +                self._attrs, +                self._has_post_init, +                self._frozen, +                self._slots, +                self._cache_hash, +                self._base_attr_map, +            ) +        ) + +        return self + +    def add_cmp(self): +        cd = self._cls_dict + +        cd["__eq__"], cd["__ne__"], cd["__lt__"], cd["__le__"], cd[ +            "__gt__" +        ], cd["__ge__"] = ( +            self._add_method_dunders(meth) for meth in _make_cmp(self._attrs) +        ) + +        return self + +    def _add_method_dunders(self, method): +        """ +        Add __module__ and __qualname__ to a *method* if possible. +        """ +        try: +            method.__module__ = self._cls.__module__ +        except AttributeError: +            pass + +        try: +            method.__qualname__ = ".".join( +                (self._cls.__qualname__, method.__name__) +            ) +        except AttributeError: +            pass + +        return method + + +def attrs( +    maybe_cls=None, +    these=None, +    repr_ns=None, +    repr=True, +    cmp=True, +    hash=None, +    init=True, +    slots=False, +    frozen=False, +    weakref_slot=True, +    str=False, +    auto_attribs=False, +    kw_only=False, +    cache_hash=False, +): +    r""" +    A class decorator that adds `dunder +    <https://wiki.python.org/moin/DunderAlias>`_\ -methods according to the +    specified attributes using :func:`attr.ib` or the *these* argument. + +    :param these: A dictionary of name to :func:`attr.ib` mappings.  This is +        useful to avoid the definition of your attributes within the class body +        because you can't (e.g. if you want to add ``__repr__`` methods to +        Django models) or don't want to. + +        If *these* is not ``None``, ``attrs`` will *not* search the class body +        for attributes and will *not* remove any attributes from it. + +        If *these* is an ordered dict (:class:`dict` on Python 3.6+, +        :class:`collections.OrderedDict` otherwise), the order is deduced from +        the order of the attributes inside *these*.  Otherwise the order +        of the definition of the attributes is used. + +    :type these: :class:`dict` of :class:`str` to :func:`attr.ib` + +    :param str repr_ns: When using nested classes, there's no way in Python 2 +        to automatically detect that.  Therefore it's possible to set the +        namespace explicitly for a more meaningful ``repr`` output. +    :param bool repr: Create a ``__repr__`` method with a human readable +        representation of ``attrs`` attributes.. +    :param bool str: Create a ``__str__`` method that is identical to +        ``__repr__``.  This is usually not necessary except for +        :class:`Exception`\ s. +    :param bool cmp: Create ``__eq__``, ``__ne__``, ``__lt__``, ``__le__``, +        ``__gt__``, and ``__ge__`` methods that compare the class as if it were +        a tuple of its ``attrs`` attributes.  But the attributes are *only* +        compared, if the types of both classes are *identical*! +    :param hash: If ``None`` (default), the ``__hash__`` method is generated +        according how *cmp* and *frozen* are set. + +        1. If *both* are True, ``attrs`` will generate a ``__hash__`` for you. +        2. If *cmp* is True and *frozen* is False, ``__hash__`` will be set to +           None, marking it unhashable (which it is). +        3. If *cmp* is False, ``__hash__`` will be left untouched meaning the +           ``__hash__`` method of the base class will be used (if base class is +           ``object``, this means it will fall back to id-based hashing.). + +        Although not recommended, you can decide for yourself and force +        ``attrs`` to create one (e.g. if the class is immutable even though you +        didn't freeze it programmatically) by passing ``True`` or not.  Both of +        these cases are rather special and should be used carefully. + +        See the `Python documentation \ +        <https://docs.python.org/3/reference/datamodel.html#object.__hash__>`_ +        and the `GitHub issue that led to the default behavior \ +        <https://github.com/python-attrs/attrs/issues/136>`_ for more details. +    :type hash: ``bool`` or ``None`` +    :param bool init: Create a ``__init__`` method that initializes the +        ``attrs`` attributes.  Leading underscores are stripped for the +        argument name.  If a ``__attrs_post_init__`` method exists on the +        class, it will be called after the class is fully initialized. +    :param bool slots: Create a slots_-style class that's more +        memory-efficient.  See :ref:`slots` for further ramifications. +    :param bool frozen: Make instances immutable after initialization.  If +        someone attempts to modify a frozen instance, +        :exc:`attr.exceptions.FrozenInstanceError` is raised. + +        Please note: + +            1. This is achieved by installing a custom ``__setattr__`` method +               on your class so you can't implement an own one. + +            2. True immutability is impossible in Python. + +            3. This *does* have a minor a runtime performance :ref:`impact +               <how-frozen>` when initializing new instances.  In other words: +               ``__init__`` is slightly slower with ``frozen=True``. + +            4. If a class is frozen, you cannot modify ``self`` in +               ``__attrs_post_init__`` or a self-written ``__init__``. You can +               circumvent that limitation by using +               ``object.__setattr__(self, "attribute_name", value)``. + +        ..  _slots: https://docs.python.org/3/reference/datamodel.html#slots +    :param bool weakref_slot: Make instances weak-referenceable.  This has no +        effect unless ``slots`` is also enabled. +    :param bool auto_attribs: If True, collect `PEP 526`_-annotated attributes +        (Python 3.6 and later only) from the class body. + +        In this case, you **must** annotate every field.  If ``attrs`` +        encounters a field that is set to an :func:`attr.ib` but lacks a type +        annotation, an :exc:`attr.exceptions.UnannotatedAttributeError` is +        raised.  Use ``field_name: typing.Any = attr.ib(...)`` if you don't +        want to set a type. + +        If you assign a value to those attributes (e.g. ``x: int = 42``), that +        value becomes the default value like if it were passed using +        ``attr.ib(default=42)``.  Passing an instance of :class:`Factory` also +        works as expected. + +        Attributes annotated as :data:`typing.ClassVar` are **ignored**. + +        .. _`PEP 526`: https://www.python.org/dev/peps/pep-0526/ +    :param bool kw_only: Make all attributes keyword-only (Python 3+) +        in the generated ``__init__`` (if ``init`` is ``False``, this +        parameter is ignored). +    :param bool cache_hash: Ensure that the object's hash code is computed +        only once and stored on the object.  If this is set to ``True``, +        hashing must be either explicitly or implicitly enabled for this +        class.  If the hash code is cached, then no attributes of this +        class which participate in hash code computation may be mutated +        after object creation. + + +    .. versionadded:: 16.0.0 *slots* +    .. versionadded:: 16.1.0 *frozen* +    .. versionadded:: 16.3.0 *str* +    .. versionadded:: 16.3.0 Support for ``__attrs_post_init__``. +    .. versionchanged:: 17.1.0 +       *hash* supports ``None`` as value which is also the default now. +    .. versionadded:: 17.3.0 *auto_attribs* +    .. versionchanged:: 18.1.0 +       If *these* is passed, no attributes are deleted from the class body. +    .. versionchanged:: 18.1.0 If *these* is ordered, the order is retained. +    .. versionadded:: 18.2.0 *weakref_slot* +    .. deprecated:: 18.2.0 +       ``__lt__``, ``__le__``, ``__gt__``, and ``__ge__`` now raise a +       :class:`DeprecationWarning` if the classes compared are subclasses of +       each other. ``__eq`` and ``__ne__`` never tried to compared subclasses +       to each other. +    .. versionadded:: 18.2.0 *kw_only* +    .. versionadded:: 18.2.0 *cache_hash* +    """ + +    def wrap(cls): +        if getattr(cls, "__class__", None) is None: +            raise TypeError("attrs only works with new-style classes.") + +        builder = _ClassBuilder( +            cls, +            these, +            slots, +            frozen, +            weakref_slot, +            auto_attribs, +            kw_only, +            cache_hash, +        ) + +        if repr is True: +            builder.add_repr(repr_ns) +        if str is True: +            builder.add_str() +        if cmp is True: +            builder.add_cmp() + +        if hash is not True and hash is not False and hash is not None: +            # Can't use `hash in` because 1 == True for example. +            raise TypeError( +                "Invalid value for hash.  Must be True, False, or None." +            ) +        elif hash is False or (hash is None and cmp is False): +            if cache_hash: +                raise TypeError( +                    "Invalid value for cache_hash.  To use hash caching," +                    " hashing must be either explicitly or implicitly " +                    "enabled." +                ) +        elif hash is True or (hash is None and cmp is True and frozen is True): +            builder.add_hash() +        else: +            if cache_hash: +                raise TypeError( +                    "Invalid value for cache_hash.  To use hash caching," +                    " hashing must be either explicitly or implicitly " +                    "enabled." +                ) +            builder.make_unhashable() + +        if init is True: +            builder.add_init() +        else: +            if cache_hash: +                raise TypeError( +                    "Invalid value for cache_hash.  To use hash caching," +                    " init must be True." +                ) + +        return builder.build_class() + +    # maybe_cls's type depends on the usage of the decorator.  It's a class +    # if it's used as `@attrs` but ``None`` if used as `@attrs()`. +    if maybe_cls is None: +        return wrap +    else: +        return wrap(maybe_cls) + + +_attrs = attrs +""" +Internal alias so we can use it in functions that take an argument called +*attrs*. +""" + + +if PY2: + +    def _has_frozen_base_class(cls): +        """ +        Check whether *cls* has a frozen ancestor by looking at its +        __setattr__. +        """ +        return ( +            getattr(cls.__setattr__, "__module__", None) +            == _frozen_setattrs.__module__ +            and cls.__setattr__.__name__ == _frozen_setattrs.__name__ +        ) + + +else: + +    def _has_frozen_base_class(cls): +        """ +        Check whether *cls* has a frozen ancestor by looking at its +        __setattr__. +        """ +        return cls.__setattr__ == _frozen_setattrs + + +def _attrs_to_tuple(obj, attrs): +    """ +    Create a tuple of all values of *obj*'s *attrs*. +    """ +    return tuple(getattr(obj, a.name) for a in attrs) + + +def _make_hash(attrs, frozen, cache_hash): +    attrs = tuple( +        a +        for a in attrs +        if a.hash is True or (a.hash is None and a.cmp is True) +    ) + +    tab = "        " + +    # We cache the generated hash methods for the same kinds of attributes. +    sha1 = hashlib.sha1() +    sha1.update(repr(attrs).encode("utf-8")) +    unique_filename = "<attrs generated hash %s>" % (sha1.hexdigest(),) +    type_hash = hash(unique_filename) + +    method_lines = ["def __hash__(self):"] + +    def append_hash_computation_lines(prefix, indent): +        """ +        Generate the code for actually computing the hash code. +        Below this will either be returned directly or used to compute +        a value which is then cached, depending on the value of cache_hash +        """ +        method_lines.extend( +            [indent + prefix + "hash((", indent + "        %d," % (type_hash,)] +        ) + +        for a in attrs: +            method_lines.append(indent + "        self.%s," % a.name) + +        method_lines.append(indent + "    ))") + +    if cache_hash: +        method_lines.append(tab + "if self.%s is None:" % _hash_cache_field) +        if frozen: +            append_hash_computation_lines( +                "object.__setattr__(self, '%s', " % _hash_cache_field, tab * 2 +            ) +            method_lines.append(tab * 2 + ")")  # close __setattr__ +        else: +            append_hash_computation_lines( +                "self.%s = " % _hash_cache_field, tab * 2 +            ) +        method_lines.append(tab + "return self.%s" % _hash_cache_field) +    else: +        append_hash_computation_lines("return ", tab) + +    script = "\n".join(method_lines) +    globs = {} +    locs = {} +    bytecode = compile(script, unique_filename, "exec") +    eval(bytecode, globs, locs) + +    # In order of debuggers like PDB being able to step through the code, +    # we add a fake linecache entry. +    linecache.cache[unique_filename] = ( +        len(script), +        None, +        script.splitlines(True), +        unique_filename, +    ) + +    return locs["__hash__"] + + +def _add_hash(cls, attrs): +    """ +    Add a hash method to *cls*. +    """ +    cls.__hash__ = _make_hash(attrs, frozen=False, cache_hash=False) +    return cls + + +def __ne__(self, other): +    """ +    Check equality and either forward a NotImplemented or return the result +    negated. +    """ +    result = self.__eq__(other) +    if result is NotImplemented: +        return NotImplemented + +    return not result + + +WARNING_CMP_ISINSTANCE = ( +    "Comparision of subclasses using __%s__ is deprecated and will be removed " +    "in 2019." +) + + +def _make_cmp(attrs): +    attrs = [a for a in attrs if a.cmp] + +    # We cache the generated eq methods for the same kinds of attributes. +    sha1 = hashlib.sha1() +    sha1.update(repr(attrs).encode("utf-8")) +    unique_filename = "<attrs generated eq %s>" % (sha1.hexdigest(),) +    lines = [ +        "def __eq__(self, other):", +        "    if other.__class__ is not self.__class__:", +        "        return NotImplemented", +    ] +    # We can't just do a big self.x = other.x and... clause due to +    # irregularities like nan == nan is false but (nan,) == (nan,) is true. +    if attrs: +        lines.append("    return  (") +        others = ["    ) == ("] +        for a in attrs: +            lines.append("        self.%s," % (a.name,)) +            others.append("        other.%s," % (a.name,)) + +        lines += others + ["    )"] +    else: +        lines.append("    return True") + +    script = "\n".join(lines) +    globs = {} +    locs = {} +    bytecode = compile(script, unique_filename, "exec") +    eval(bytecode, globs, locs) + +    # In order of debuggers like PDB being able to step through the code, +    # we add a fake linecache entry. +    linecache.cache[unique_filename] = ( +        len(script), +        None, +        script.splitlines(True), +        unique_filename, +    ) +    eq = locs["__eq__"] +    ne = __ne__ + +    def attrs_to_tuple(obj): +        """ +        Save us some typing. +        """ +        return _attrs_to_tuple(obj, attrs) + +    def __lt__(self, other): +        """ +        Automatically created by attrs. +        """ +        if isinstance(other, self.__class__): +            if other.__class__ is not self.__class__: +                warnings.warn( +                    WARNING_CMP_ISINSTANCE % ("lt",), DeprecationWarning +                ) +            return attrs_to_tuple(self) < attrs_to_tuple(other) +        else: +            return NotImplemented + +    def __le__(self, other): +        """ +        Automatically created by attrs. +        """ +        if isinstance(other, self.__class__): +            if other.__class__ is not self.__class__: +                warnings.warn( +                    WARNING_CMP_ISINSTANCE % ("le",), DeprecationWarning +                ) +            return attrs_to_tuple(self) <= attrs_to_tuple(other) +        else: +            return NotImplemented + +    def __gt__(self, other): +        """ +        Automatically created by attrs. +        """ +        if isinstance(other, self.__class__): +            if other.__class__ is not self.__class__: +                warnings.warn( +                    WARNING_CMP_ISINSTANCE % ("gt",), DeprecationWarning +                ) +            return attrs_to_tuple(self) > attrs_to_tuple(other) +        else: +            return NotImplemented + +    def __ge__(self, other): +        """ +        Automatically created by attrs. +        """ +        if isinstance(other, self.__class__): +            if other.__class__ is not self.__class__: +                warnings.warn( +                    WARNING_CMP_ISINSTANCE % ("ge",), DeprecationWarning +                ) +            return attrs_to_tuple(self) >= attrs_to_tuple(other) +        else: +            return NotImplemented + +    return eq, ne, __lt__, __le__, __gt__, __ge__ + + +def _add_cmp(cls, attrs=None): +    """ +    Add comparison methods to *cls*. +    """ +    if attrs is None: +        attrs = cls.__attrs_attrs__ + +    cls.__eq__, cls.__ne__, cls.__lt__, cls.__le__, cls.__gt__, cls.__ge__ = _make_cmp(  # noqa +        attrs +    ) + +    return cls + + +_already_repring = threading.local() + + +def _make_repr(attrs, ns): +    """ +    Make a repr method for *attr_names* adding *ns* to the full name. +    """ +    attr_names = tuple(a.name for a in attrs if a.repr) + +    def __repr__(self): +        """ +        Automatically created by attrs. +        """ +        try: +            working_set = _already_repring.working_set +        except AttributeError: +            working_set = set() +            _already_repring.working_set = working_set + +        if id(self) in working_set: +            return "..." +        real_cls = self.__class__ +        if ns is None: +            qualname = getattr(real_cls, "__qualname__", None) +            if qualname is not None: +                class_name = qualname.rsplit(">.", 1)[-1] +            else: +                class_name = real_cls.__name__ +        else: +            class_name = ns + "." + real_cls.__name__ + +        # Since 'self' remains on the stack (i.e.: strongly referenced) for the +        # duration of this call, it's safe to depend on id(...) stability, and +        # not need to track the instance and therefore worry about properties +        # like weakref- or hash-ability. +        working_set.add(id(self)) +        try: +            result = [class_name, "("] +            first = True +            for name in attr_names: +                if first: +                    first = False +                else: +                    result.append(", ") +                result.extend((name, "=", repr(getattr(self, name, NOTHING)))) +            return "".join(result) + ")" +        finally: +            working_set.remove(id(self)) + +    return __repr__ + + +def _add_repr(cls, ns=None, attrs=None): +    """ +    Add a repr method to *cls*. +    """ +    if attrs is None: +        attrs = cls.__attrs_attrs__ + +    cls.__repr__ = _make_repr(attrs, ns) +    return cls + + +def _make_init(attrs, post_init, frozen, slots, cache_hash, base_attr_map): +    attrs = [a for a in attrs if a.init or a.default is not NOTHING] + +    # We cache the generated init methods for the same kinds of attributes. +    sha1 = hashlib.sha1() +    sha1.update(repr(attrs).encode("utf-8")) +    unique_filename = "<attrs generated init {0}>".format(sha1.hexdigest()) + +    script, globs, annotations = _attrs_to_init_script( +        attrs, frozen, slots, post_init, cache_hash, base_attr_map +    ) +    locs = {} +    bytecode = compile(script, unique_filename, "exec") +    attr_dict = dict((a.name, a) for a in attrs) +    globs.update({"NOTHING": NOTHING, "attr_dict": attr_dict}) +    if frozen is True: +        # Save the lookup overhead in __init__ if we need to circumvent +        # immutability. +        globs["_cached_setattr"] = _obj_setattr +    eval(bytecode, globs, locs) + +    # In order of debuggers like PDB being able to step through the code, +    # we add a fake linecache entry. +    linecache.cache[unique_filename] = ( +        len(script), +        None, +        script.splitlines(True), +        unique_filename, +    ) + +    __init__ = locs["__init__"] +    __init__.__annotations__ = annotations +    return __init__ + + +def _add_init(cls, frozen): +    """ +    Add a __init__ method to *cls*.  If *frozen* is True, make it immutable. +    """ +    cls.__init__ = _make_init( +        cls.__attrs_attrs__, +        getattr(cls, "__attrs_post_init__", False), +        frozen, +        _is_slot_cls(cls), +        cache_hash=False, +        base_attr_map={}, +    ) +    return cls + + +def fields(cls): +    """ +    Return the tuple of ``attrs`` attributes for a class. + +    The tuple also allows accessing the fields by their names (see below for +    examples). + +    :param type cls: Class to introspect. + +    :raise TypeError: If *cls* is not a class. +    :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` +        class. + +    :rtype: tuple (with name accessors) of :class:`attr.Attribute` + +    ..  versionchanged:: 16.2.0 Returned tuple allows accessing the fields +        by name. +    """ +    if not isclass(cls): +        raise TypeError("Passed object must be a class.") +    attrs = getattr(cls, "__attrs_attrs__", None) +    if attrs is None: +        raise NotAnAttrsClassError( +            "{cls!r} is not an attrs-decorated class.".format(cls=cls) +        ) +    return attrs + + +def fields_dict(cls): +    """ +    Return an ordered dictionary of ``attrs`` attributes for a class, whose +    keys are the attribute names. + +    :param type cls: Class to introspect. + +    :raise TypeError: If *cls* is not a class. +    :raise attr.exceptions.NotAnAttrsClassError: If *cls* is not an ``attrs`` +        class. + +    :rtype: an ordered dict where keys are attribute names and values are +        :class:`attr.Attribute`\\ s. This will be a :class:`dict` if it's +        naturally ordered like on Python 3.6+ or an +        :class:`~collections.OrderedDict` otherwise. + +    .. versionadded:: 18.1.0 +    """ +    if not isclass(cls): +        raise TypeError("Passed object must be a class.") +    attrs = getattr(cls, "__attrs_attrs__", None) +    if attrs is None: +        raise NotAnAttrsClassError( +            "{cls!r} is not an attrs-decorated class.".format(cls=cls) +        ) +    return ordered_dict(((a.name, a) for a in attrs)) + + +def validate(inst): +    """ +    Validate all attributes on *inst* that have a validator. + +    Leaves all exceptions through. + +    :param inst: Instance of a class with ``attrs`` attributes. +    """ +    if _config._run_validators is False: +        return + +    for a in fields(inst.__class__): +        v = a.validator +        if v is not None: +            v(inst, a, getattr(inst, a.name)) + + +def _is_slot_cls(cls): +    return "__slots__" in cls.__dict__ + + +def _is_slot_attr(a_name, base_attr_map): +    """ +    Check if the attribute name comes from a slot class. +    """ +    return a_name in base_attr_map and _is_slot_cls(base_attr_map[a_name]) + + +def _attrs_to_init_script( +    attrs, frozen, slots, post_init, cache_hash, base_attr_map +): +    """ +    Return a script of an initializer for *attrs* and a dict of globals. + +    The globals are expected by the generated script. + +    If *frozen* is True, we cannot set the attributes directly so we use +    a cached ``object.__setattr__``. +    """ +    lines = [] +    any_slot_ancestors = any( +        _is_slot_attr(a.name, base_attr_map) for a in attrs +    ) +    if frozen is True: +        if slots is True: +            lines.append( +                # Circumvent the __setattr__ descriptor to save one lookup per +                # assignment. +                # Note _setattr will be used again below if cache_hash is True +                "_setattr = _cached_setattr.__get__(self, self.__class__)" +            ) + +            def fmt_setter(attr_name, value_var): +                return "_setattr('%(attr_name)s', %(value_var)s)" % { +                    "attr_name": attr_name, +                    "value_var": value_var, +                } + +            def fmt_setter_with_converter(attr_name, value_var): +                conv_name = _init_converter_pat.format(attr_name) +                return "_setattr('%(attr_name)s', %(conv)s(%(value_var)s))" % { +                    "attr_name": attr_name, +                    "value_var": value_var, +                    "conv": conv_name, +                } + +        else: +            # Dict frozen classes assign directly to __dict__. +            # But only if the attribute doesn't come from an ancestor slot +            # class. +            # Note _inst_dict will be used again below if cache_hash is True +            lines.append("_inst_dict = self.__dict__") +            if any_slot_ancestors: +                lines.append( +                    # Circumvent the __setattr__ descriptor to save one lookup +                    # per assignment. +                    "_setattr = _cached_setattr.__get__(self, self.__class__)" +                ) + +            def fmt_setter(attr_name, value_var): +                if _is_slot_attr(attr_name, base_attr_map): +                    res = "_setattr('%(attr_name)s', %(value_var)s)" % { +                        "attr_name": attr_name, +                        "value_var": value_var, +                    } +                else: +                    res = "_inst_dict['%(attr_name)s'] = %(value_var)s" % { +                        "attr_name": attr_name, +                        "value_var": value_var, +                    } +                return res + +            def fmt_setter_with_converter(attr_name, value_var): +                conv_name = _init_converter_pat.format(attr_name) +                if _is_slot_attr(attr_name, base_attr_map): +                    tmpl = "_setattr('%(attr_name)s', %(c)s(%(value_var)s))" +                else: +                    tmpl = "_inst_dict['%(attr_name)s'] = %(c)s(%(value_var)s)" +                return tmpl % { +                    "attr_name": attr_name, +                    "value_var": value_var, +                    "c": conv_name, +                } + +    else: +        # Not frozen. +        def fmt_setter(attr_name, value): +            return "self.%(attr_name)s = %(value)s" % { +                "attr_name": attr_name, +                "value": value, +            } + +        def fmt_setter_with_converter(attr_name, value_var): +            conv_name = _init_converter_pat.format(attr_name) +            return "self.%(attr_name)s = %(conv)s(%(value_var)s)" % { +                "attr_name": attr_name, +                "value_var": value_var, +                "conv": conv_name, +            } + +    args = [] +    kw_only_args = [] +    attrs_to_validate = [] + +    # This is a dictionary of names to validator and converter callables. +    # Injecting this into __init__ globals lets us avoid lookups. +    names_for_globals = {} +    annotations = {"return": None} + +    for a in attrs: +        if a.validator: +            attrs_to_validate.append(a) +        attr_name = a.name +        arg_name = a.name.lstrip("_") +        has_factory = isinstance(a.default, Factory) +        if has_factory and a.default.takes_self: +            maybe_self = "self" +        else: +            maybe_self = "" +        if a.init is False: +            if has_factory: +                init_factory_name = _init_factory_pat.format(a.name) +                if a.converter is not None: +                    lines.append( +                        fmt_setter_with_converter( +                            attr_name, +                            init_factory_name + "({0})".format(maybe_self), +                        ) +                    ) +                    conv_name = _init_converter_pat.format(a.name) +                    names_for_globals[conv_name] = a.converter +                else: +                    lines.append( +                        fmt_setter( +                            attr_name, +                            init_factory_name + "({0})".format(maybe_self), +                        ) +                    ) +                names_for_globals[init_factory_name] = a.default.factory +            else: +                if a.converter is not None: +                    lines.append( +                        fmt_setter_with_converter( +                            attr_name, +                            "attr_dict['{attr_name}'].default".format( +                                attr_name=attr_name +                            ), +                        ) +                    ) +                    conv_name = _init_converter_pat.format(a.name) +                    names_for_globals[conv_name] = a.converter +                else: +                    lines.append( +                        fmt_setter( +                            attr_name, +                            "attr_dict['{attr_name}'].default".format( +                                attr_name=attr_name +                            ), +                        ) +                    ) +        elif a.default is not NOTHING and not has_factory: +            arg = "{arg_name}=attr_dict['{attr_name}'].default".format( +                arg_name=arg_name, attr_name=attr_name +            ) +            if a.kw_only: +                kw_only_args.append(arg) +            else: +                args.append(arg) +            if a.converter is not None: +                lines.append(fmt_setter_with_converter(attr_name, arg_name)) +                names_for_globals[ +                    _init_converter_pat.format(a.name) +                ] = a.converter +            else: +                lines.append(fmt_setter(attr_name, arg_name)) +        elif has_factory: +            arg = "{arg_name}=NOTHING".format(arg_name=arg_name) +            if a.kw_only: +                kw_only_args.append(arg) +            else: +                args.append(arg) +            lines.append( +                "if {arg_name} is not NOTHING:".format(arg_name=arg_name) +            ) +            init_factory_name = _init_factory_pat.format(a.name) +            if a.converter is not None: +                lines.append( +                    "    " + fmt_setter_with_converter(attr_name, arg_name) +                ) +                lines.append("else:") +                lines.append( +                    "    " +                    + fmt_setter_with_converter( +                        attr_name, +                        init_factory_name + "({0})".format(maybe_self), +                    ) +                ) +                names_for_globals[ +                    _init_converter_pat.format(a.name) +                ] = a.converter +            else: +                lines.append("    " + fmt_setter(attr_name, arg_name)) +                lines.append("else:") +                lines.append( +                    "    " +                    + fmt_setter( +                        attr_name, +                        init_factory_name + "({0})".format(maybe_self), +                    ) +                ) +            names_for_globals[init_factory_name] = a.default.factory +        else: +            if a.kw_only: +                kw_only_args.append(arg_name) +            else: +                args.append(arg_name) +            if a.converter is not None: +                lines.append(fmt_setter_with_converter(attr_name, arg_name)) +                names_for_globals[ +                    _init_converter_pat.format(a.name) +                ] = a.converter +            else: +                lines.append(fmt_setter(attr_name, arg_name)) + +        if a.init is True and a.converter is None and a.type is not None: +            annotations[arg_name] = a.type + +    if attrs_to_validate:  # we can skip this if there are no validators. +        names_for_globals["_config"] = _config +        lines.append("if _config._run_validators is True:") +        for a in attrs_to_validate: +            val_name = "__attr_validator_{}".format(a.name) +            attr_name = "__attr_{}".format(a.name) +            lines.append( +                "    {}(self, {}, self.{})".format(val_name, attr_name, a.name) +            ) +            names_for_globals[val_name] = a.validator +            names_for_globals[attr_name] = a +    if post_init: +        lines.append("self.__attrs_post_init__()") + +    # because this is set only after __attrs_post_init is called, a crash +    # will result if post-init tries to access the hash code.  This seemed +    # preferable to setting this beforehand, in which case alteration to +    # field values during post-init combined with post-init accessing the +    # hash code would result in silent bugs. +    if cache_hash: +        if frozen: +            if slots: +                # if frozen and slots, then _setattr defined above +                init_hash_cache = "_setattr('%s', %s)" +            else: +                # if frozen and not slots, then _inst_dict defined above +                init_hash_cache = "_inst_dict['%s'] = %s" +        else: +            init_hash_cache = "self.%s = %s" +        lines.append(init_hash_cache % (_hash_cache_field, "None")) + +    args = ", ".join(args) +    if kw_only_args: +        if PY2: +            raise PythonTooOldError( +                "Keyword-only arguments only work on Python 3 and later." +            ) + +        args += "{leading_comma}*, {kw_only_args}".format( +            leading_comma=", " if args else "", +            kw_only_args=", ".join(kw_only_args), +        ) +    return ( +        """\ +def __init__(self, {args}): +    {lines} +""".format( +            args=args, lines="\n    ".join(lines) if lines else "pass" +        ), +        names_for_globals, +        annotations, +    ) + + +class Attribute(object): +    """ +    *Read-only* representation of an attribute. + +    :attribute name: The name of the attribute. + +    Plus *all* arguments of :func:`attr.ib`. + +    For the version history of the fields, see :func:`attr.ib`. +    """ + +    __slots__ = ( +        "name", +        "default", +        "validator", +        "repr", +        "cmp", +        "hash", +        "init", +        "metadata", +        "type", +        "converter", +        "kw_only", +    ) + +    def __init__( +        self, +        name, +        default, +        validator, +        repr, +        cmp, +        hash, +        init, +        convert=None, +        metadata=None, +        type=None, +        converter=None, +        kw_only=False, +    ): +        # Cache this descriptor here to speed things up later. +        bound_setattr = _obj_setattr.__get__(self, Attribute) + +        # Despite the big red warning, people *do* instantiate `Attribute` +        # themselves. +        if convert is not None: +            if converter is not None: +                raise RuntimeError( +                    "Can't pass both `convert` and `converter`.  " +                    "Please use `converter` only." +                ) +            warnings.warn( +                "The `convert` argument is deprecated in favor of `converter`." +                "  It will be removed after 2019/01.", +                DeprecationWarning, +                stacklevel=2, +            ) +            converter = convert + +        bound_setattr("name", name) +        bound_setattr("default", default) +        bound_setattr("validator", validator) +        bound_setattr("repr", repr) +        bound_setattr("cmp", cmp) +        bound_setattr("hash", hash) +        bound_setattr("init", init) +        bound_setattr("converter", converter) +        bound_setattr( +            "metadata", +            ( +                metadata_proxy(metadata) +                if metadata +                else _empty_metadata_singleton +            ), +        ) +        bound_setattr("type", type) +        bound_setattr("kw_only", kw_only) + +    def __setattr__(self, name, value): +        raise FrozenInstanceError() + +    @property +    def convert(self): +        warnings.warn( +            "The `convert` attribute is deprecated in favor of `converter`.  " +            "It will be removed after 2019/01.", +            DeprecationWarning, +            stacklevel=2, +        ) +        return self.converter + +    @classmethod +    def from_counting_attr(cls, name, ca, type=None): +        # type holds the annotated value. deal with conflicts: +        if type is None: +            type = ca.type +        elif ca.type is not None: +            raise ValueError( +                "Type annotation and type argument cannot both be present" +            ) +        inst_dict = { +            k: getattr(ca, k) +            for k in Attribute.__slots__ +            if k +            not in ( +                "name", +                "validator", +                "default", +                "type", +                "convert", +            )  # exclude methods and deprecated alias +        } +        return cls( +            name=name, +            validator=ca._validator, +            default=ca._default, +            type=type, +            **inst_dict +        ) + +    # Don't use attr.assoc since fields(Attribute) doesn't work +    def _assoc(self, **changes): +        """ +        Copy *self* and apply *changes*. +        """ +        new = copy.copy(self) + +        new._setattrs(changes.items()) + +        return new + +    # Don't use _add_pickle since fields(Attribute) doesn't work +    def __getstate__(self): +        """ +        Play nice with pickle. +        """ +        return tuple( +            getattr(self, name) if name != "metadata" else dict(self.metadata) +            for name in self.__slots__ +        ) + +    def __setstate__(self, state): +        """ +        Play nice with pickle. +        """ +        self._setattrs(zip(self.__slots__, state)) + +    def _setattrs(self, name_values_pairs): +        bound_setattr = _obj_setattr.__get__(self, Attribute) +        for name, value in name_values_pairs: +            if name != "metadata": +                bound_setattr(name, value) +            else: +                bound_setattr( +                    name, +                    metadata_proxy(value) +                    if value +                    else _empty_metadata_singleton, +                ) + + +_a = [ +    Attribute( +        name=name, +        default=NOTHING, +        validator=None, +        repr=True, +        cmp=True, +        hash=(name != "metadata"), +        init=True, +    ) +    for name in Attribute.__slots__ +    if name != "convert"  # XXX: remove once `convert` is gone +] + +Attribute = _add_hash( +    _add_cmp(_add_repr(Attribute, attrs=_a), attrs=_a), +    attrs=[a for a in _a if a.hash], +) + + +class _CountingAttr(object): +    """ +    Intermediate representation of attributes that uses a counter to preserve +    the order in which the attributes have been defined. + +    *Internal* data structure of the attrs library.  Running into is most +    likely the result of a bug like a forgotten `@attr.s` decorator. +    """ + +    __slots__ = ( +        "counter", +        "_default", +        "repr", +        "cmp", +        "hash", +        "init", +        "metadata", +        "_validator", +        "converter", +        "type", +        "kw_only", +    ) +    __attrs_attrs__ = tuple( +        Attribute( +            name=name, +            default=NOTHING, +            validator=None, +            repr=True, +            cmp=True, +            hash=True, +            init=True, +            kw_only=False, +        ) +        for name in ("counter", "_default", "repr", "cmp", "hash", "init") +    ) + ( +        Attribute( +            name="metadata", +            default=None, +            validator=None, +            repr=True, +            cmp=True, +            hash=False, +            init=True, +            kw_only=False, +        ), +    ) +    cls_counter = 0 + +    def __init__( +        self, +        default, +        validator, +        repr, +        cmp, +        hash, +        init, +        converter, +        metadata, +        type, +        kw_only, +    ): +        _CountingAttr.cls_counter += 1 +        self.counter = _CountingAttr.cls_counter +        self._default = default +        # If validator is a list/tuple, wrap it using helper validator. +        if validator and isinstance(validator, (list, tuple)): +            self._validator = and_(*validator) +        else: +            self._validator = validator +        self.repr = repr +        self.cmp = cmp +        self.hash = hash +        self.init = init +        self.converter = converter +        self.metadata = metadata +        self.type = type +        self.kw_only = kw_only + +    def validator(self, meth): +        """ +        Decorator that adds *meth* to the list of validators. + +        Returns *meth* unchanged. + +        .. versionadded:: 17.1.0 +        """ +        if self._validator is None: +            self._validator = meth +        else: +            self._validator = and_(self._validator, meth) +        return meth + +    def default(self, meth): +        """ +        Decorator that allows to set the default for an attribute. + +        Returns *meth* unchanged. + +        :raises DefaultAlreadySetError: If default has been set before. + +        .. versionadded:: 17.1.0 +        """ +        if self._default is not NOTHING: +            raise DefaultAlreadySetError() + +        self._default = Factory(meth, takes_self=True) + +        return meth + + +_CountingAttr = _add_cmp(_add_repr(_CountingAttr)) + + +@attrs(slots=True, init=False, hash=True) +class Factory(object): +    """ +    Stores a factory callable. + +    If passed as the default value to :func:`attr.ib`, the factory is used to +    generate a new value. + +    :param callable factory: A callable that takes either none or exactly one +        mandatory positional argument depending on *takes_self*. +    :param bool takes_self: Pass the partially initialized instance that is +        being initialized as a positional argument. + +    .. versionadded:: 17.1.0  *takes_self* +    """ + +    factory = attrib() +    takes_self = attrib() + +    def __init__(self, factory, takes_self=False): +        """ +        `Factory` is part of the default machinery so if we want a default +        value here, we have to implement it ourselves. +        """ +        self.factory = factory +        self.takes_self = takes_self + + +def make_class(name, attrs, bases=(object,), **attributes_arguments): +    """ +    A quick way to create a new class called *name* with *attrs*. + +    :param name: The name for the new class. +    :type name: str + +    :param attrs: A list of names or a dictionary of mappings of names to +        attributes. + +        If *attrs* is a list or an ordered dict (:class:`dict` on Python 3.6+, +        :class:`collections.OrderedDict` otherwise), the order is deduced from +        the order of the names or attributes inside *attrs*.  Otherwise the +        order of the definition of the attributes is used. +    :type attrs: :class:`list` or :class:`dict` + +    :param tuple bases: Classes that the new class will subclass. + +    :param attributes_arguments: Passed unmodified to :func:`attr.s`. + +    :return: A new class with *attrs*. +    :rtype: type + +    .. versionadded:: 17.1.0 *bases* +    .. versionchanged:: 18.1.0 If *attrs* is ordered, the order is retained. +    """ +    if isinstance(attrs, dict): +        cls_dict = attrs +    elif isinstance(attrs, (list, tuple)): +        cls_dict = dict((a, attrib()) for a in attrs) +    else: +        raise TypeError("attrs argument must be a dict or a list.") + +    post_init = cls_dict.pop("__attrs_post_init__", None) +    type_ = type( +        name, +        bases, +        {} if post_init is None else {"__attrs_post_init__": post_init}, +    ) +    # For pickling to work, the __module__ variable needs to be set to the +    # frame where the class is created.  Bypass this step in environments where +    # sys._getframe is not defined (Jython for example) or sys._getframe is not +    # defined for arguments greater than 0 (IronPython). +    try: +        type_.__module__ = sys._getframe(1).f_globals.get( +            "__name__", "__main__" +        ) +    except (AttributeError, ValueError): +        pass + +    return _attrs(these=cls_dict, **attributes_arguments)(type_) + + +# These are required by within this module so we define them here and merely +# import into .validators. + + +@attrs(slots=True, hash=True) +class _AndValidator(object): +    """ +    Compose many validators to a single one. +    """ + +    _validators = attrib() + +    def __call__(self, inst, attr, value): +        for v in self._validators: +            v(inst, attr, value) + + +def and_(*validators): +    """ +    A validator that composes multiple validators into one. + +    When called on a value, it runs all wrapped validators. + +    :param validators: Arbitrary number of validators. +    :type validators: callables + +    .. versionadded:: 17.1.0 +    """ +    vals = [] +    for validator in validators: +        vals.extend( +            validator._validators +            if isinstance(validator, _AndValidator) +            else [validator] +        ) + +    return _AndValidator(tuple(vals)) diff --git a/python/attr/converters.py b/python/attr/converters.py new file mode 100644 index 0000000..37c4a07 --- /dev/null +++ b/python/attr/converters.py @@ -0,0 +1,78 @@ +""" +Commonly useful converters. +""" + +from __future__ import absolute_import, division, print_function + +from ._make import NOTHING, Factory + + +def optional(converter): +    """ +    A converter that allows an attribute to be optional. An optional attribute +    is one which can be set to ``None``. + +    :param callable converter: the converter that is used for non-``None`` +        values. + +    .. versionadded:: 17.1.0 +    """ + +    def optional_converter(val): +        if val is None: +            return None +        return converter(val) + +    return optional_converter + + +def default_if_none(default=NOTHING, factory=None): +    """ +    A converter that allows to replace ``None`` values by *default* or the +    result of *factory*. + +    :param default: Value to be used if ``None`` is passed. Passing an instance +       of :class:`attr.Factory` is supported, however the ``takes_self`` option +       is *not*. +    :param callable factory: A callable that takes not parameters whose result +       is used if ``None`` is passed. + +    :raises TypeError: If **neither** *default* or *factory* is passed. +    :raises TypeError: If **both** *default* and *factory* are passed. +    :raises ValueError: If an instance of :class:`attr.Factory` is passed with +       ``takes_self=True``. + +    .. versionadded:: 18.2.0 +    """ +    if default is NOTHING and factory is None: +        raise TypeError("Must pass either `default` or `factory`.") + +    if default is not NOTHING and factory is not None: +        raise TypeError( +            "Must pass either `default` or `factory` but not both." +        ) + +    if factory is not None: +        default = Factory(factory) + +    if isinstance(default, Factory): +        if default.takes_self: +            raise ValueError( +                "`takes_self` is not supported by default_if_none." +            ) + +        def default_if_none_converter(val): +            if val is not None: +                return val + +            return default.factory() + +    else: + +        def default_if_none_converter(val): +            if val is not None: +                return val + +            return default + +    return default_if_none_converter diff --git a/python/attr/converters.pyi b/python/attr/converters.pyi new file mode 100644 index 0000000..63b2a38 --- /dev/null +++ b/python/attr/converters.pyi @@ -0,0 +1,12 @@ +from typing import TypeVar, Optional, Callable, overload +from . import _ConverterType + +_T = TypeVar("_T") + +def optional( +    converter: _ConverterType[_T] +) -> _ConverterType[Optional[_T]]: ... +@overload +def default_if_none(default: _T) -> _ConverterType[_T]: ... +@overload +def default_if_none(*, factory: Callable[[], _T]) -> _ConverterType[_T]: ... diff --git a/python/attr/exceptions.py b/python/attr/exceptions.py new file mode 100644 index 0000000..b12e41e --- /dev/null +++ b/python/attr/exceptions.py @@ -0,0 +1,57 @@ +from __future__ import absolute_import, division, print_function + + +class FrozenInstanceError(AttributeError): +    """ +    A frozen/immutable instance has been attempted to be modified. + +    It mirrors the behavior of ``namedtuples`` by using the same error message +    and subclassing :exc:`AttributeError`. + +    .. versionadded:: 16.1.0 +    """ + +    msg = "can't set attribute" +    args = [msg] + + +class AttrsAttributeNotFoundError(ValueError): +    """ +    An ``attrs`` function couldn't find an attribute that the user asked for. + +    .. versionadded:: 16.2.0 +    """ + + +class NotAnAttrsClassError(ValueError): +    """ +    A non-``attrs`` class has been passed into an ``attrs`` function. + +    .. versionadded:: 16.2.0 +    """ + + +class DefaultAlreadySetError(RuntimeError): +    """ +    A default has been set using ``attr.ib()`` and is attempted to be reset +    using the decorator. + +    .. versionadded:: 17.1.0 +    """ + + +class UnannotatedAttributeError(RuntimeError): +    """ +    A class with ``auto_attribs=True`` has an ``attr.ib()`` without a type +    annotation. + +    .. versionadded:: 17.3.0 +    """ + + +class PythonTooOldError(RuntimeError): +    """ +    An ``attrs`` feature requiring a more recent python version has been used. + +    .. versionadded:: 18.2.0 +    """ diff --git a/python/attr/exceptions.pyi b/python/attr/exceptions.pyi new file mode 100644 index 0000000..48fffcc --- /dev/null +++ b/python/attr/exceptions.pyi @@ -0,0 +1,7 @@ +class FrozenInstanceError(AttributeError): +    msg: str = ... + +class AttrsAttributeNotFoundError(ValueError): ... +class NotAnAttrsClassError(ValueError): ... +class DefaultAlreadySetError(RuntimeError): ... +class UnannotatedAttributeError(RuntimeError): ... diff --git a/python/attr/filters.py b/python/attr/filters.py new file mode 100644 index 0000000..f1c69b8 --- /dev/null +++ b/python/attr/filters.py @@ -0,0 +1,52 @@ +""" +Commonly useful filters for :func:`attr.asdict`. +""" + +from __future__ import absolute_import, division, print_function + +from ._compat import isclass +from ._make import Attribute + + +def _split_what(what): +    """ +    Returns a tuple of `frozenset`s of classes and attributes. +    """ +    return ( +        frozenset(cls for cls in what if isclass(cls)), +        frozenset(cls for cls in what if isinstance(cls, Attribute)), +    ) + + +def include(*what): +    """ +    Whitelist *what*. + +    :param what: What to whitelist. +    :type what: :class:`list` of :class:`type` or :class:`attr.Attribute`\\ s + +    :rtype: :class:`callable` +    """ +    cls, attrs = _split_what(what) + +    def include_(attribute, value): +        return value.__class__ in cls or attribute in attrs + +    return include_ + + +def exclude(*what): +    """ +    Blacklist *what*. + +    :param what: What to blacklist. +    :type what: :class:`list` of classes or :class:`attr.Attribute`\\ s. + +    :rtype: :class:`callable` +    """ +    cls, attrs = _split_what(what) + +    def exclude_(attribute, value): +        return value.__class__ not in cls and attribute not in attrs + +    return exclude_ diff --git a/python/attr/filters.pyi b/python/attr/filters.pyi new file mode 100644 index 0000000..a618140 --- /dev/null +++ b/python/attr/filters.pyi @@ -0,0 +1,5 @@ +from typing import Union +from . import Attribute, _FilterType + +def include(*what: Union[type, Attribute]) -> _FilterType: ... +def exclude(*what: Union[type, Attribute]) -> _FilterType: ... diff --git a/python/attr/py.typed b/python/attr/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/python/attr/py.typed diff --git a/python/attr/validators.py b/python/attr/validators.py new file mode 100644 index 0000000..f12d0aa --- /dev/null +++ b/python/attr/validators.py @@ -0,0 +1,170 @@ +""" +Commonly useful validators. +""" + +from __future__ import absolute_import, division, print_function + +from ._make import _AndValidator, and_, attrib, attrs + + +__all__ = ["and_", "in_", "instance_of", "optional", "provides"] + + +@attrs(repr=False, slots=True, hash=True) +class _InstanceOfValidator(object): +    type = attrib() + +    def __call__(self, inst, attr, value): +        """ +        We use a callable class to be able to change the ``__repr__``. +        """ +        if not isinstance(value, self.type): +            raise TypeError( +                "'{name}' must be {type!r} (got {value!r} that is a " +                "{actual!r}).".format( +                    name=attr.name, +                    type=self.type, +                    actual=value.__class__, +                    value=value, +                ), +                attr, +                self.type, +                value, +            ) + +    def __repr__(self): +        return "<instance_of validator for type {type!r}>".format( +            type=self.type +        ) + + +def instance_of(type): +    """ +    A validator that raises a :exc:`TypeError` if the initializer is called +    with a wrong type for this particular attribute (checks are performed using +    :func:`isinstance` therefore it's also valid to pass a tuple of types). + +    :param type: The type to check for. +    :type type: type or tuple of types + +    :raises TypeError: With a human readable error message, the attribute +        (of type :class:`attr.Attribute`), the expected type, and the value it +        got. +    """ +    return _InstanceOfValidator(type) + + +@attrs(repr=False, slots=True, hash=True) +class _ProvidesValidator(object): +    interface = attrib() + +    def __call__(self, inst, attr, value): +        """ +        We use a callable class to be able to change the ``__repr__``. +        """ +        if not self.interface.providedBy(value): +            raise TypeError( +                "'{name}' must provide {interface!r} which {value!r} " +                "doesn't.".format( +                    name=attr.name, interface=self.interface, value=value +                ), +                attr, +                self.interface, +                value, +            ) + +    def __repr__(self): +        return "<provides validator for interface {interface!r}>".format( +            interface=self.interface +        ) + + +def provides(interface): +    """ +    A validator that raises a :exc:`TypeError` if the initializer is called +    with an object that does not provide the requested *interface* (checks are +    performed using ``interface.providedBy(value)`` (see `zope.interface +    <https://zopeinterface.readthedocs.io/en/latest/>`_). + +    :param zope.interface.Interface interface: The interface to check for. + +    :raises TypeError: With a human readable error message, the attribute +        (of type :class:`attr.Attribute`), the expected interface, and the +        value it got. +    """ +    return _ProvidesValidator(interface) + + +@attrs(repr=False, slots=True, hash=True) +class _OptionalValidator(object): +    validator = attrib() + +    def __call__(self, inst, attr, value): +        if value is None: +            return + +        self.validator(inst, attr, value) + +    def __repr__(self): +        return "<optional validator for {what} or None>".format( +            what=repr(self.validator) +        ) + + +def optional(validator): +    """ +    A validator that makes an attribute optional.  An optional attribute is one +    which can be set to ``None`` in addition to satisfying the requirements of +    the sub-validator. + +    :param validator: A validator (or a list of validators) that is used for +        non-``None`` values. +    :type validator: callable or :class:`list` of callables. + +    .. versionadded:: 15.1.0 +    .. versionchanged:: 17.1.0 *validator* can be a list of validators. +    """ +    if isinstance(validator, list): +        return _OptionalValidator(_AndValidator(validator)) +    return _OptionalValidator(validator) + + +@attrs(repr=False, slots=True, hash=True) +class _InValidator(object): +    options = attrib() + +    def __call__(self, inst, attr, value): +        try: +            in_options = value in self.options +        except TypeError as e:  # e.g. `1 in "abc"` +            in_options = False + +        if not in_options: +            raise ValueError( +                "'{name}' must be in {options!r} (got {value!r})".format( +                    name=attr.name, options=self.options, value=value +                ) +            ) + +    def __repr__(self): +        return "<in_ validator with options {options!r}>".format( +            options=self.options +        ) + + +def in_(options): +    """ +    A validator that raises a :exc:`ValueError` if the initializer is called +    with a value that does not belong in the options provided.  The check is +    performed using ``value in options``. + +    :param options: Allowed options. +    :type options: list, tuple, :class:`enum.Enum`, ... + +    :raises ValueError: With a human readable error message, the attribute (of +       type :class:`attr.Attribute`), the expected options, and the value it +       got. + +    .. versionadded:: 17.1.0 +    """ +    return _InValidator(options) diff --git a/python/attr/validators.pyi b/python/attr/validators.pyi new file mode 100644 index 0000000..abbaedf --- /dev/null +++ b/python/attr/validators.pyi @@ -0,0 +1,14 @@ +from typing import Container, List, Union, TypeVar, Type, Any, Optional, Tuple +from . import _ValidatorType + +_T = TypeVar("_T") + +def instance_of( +    type: Union[Tuple[Type[_T], ...], Type[_T]] +) -> _ValidatorType[_T]: ... +def provides(interface: Any) -> _ValidatorType[Any]: ... +def optional( +    validator: Union[_ValidatorType[_T], List[_ValidatorType[_T]]] +) -> _ValidatorType[Optional[_T]]: ... +def in_(options: Container[_T]) -> _ValidatorType[_T]: ... +def and_(*validators: _ValidatorType[_T]) -> _ValidatorType[_T]: ... diff --git a/python/dateutil/__init__.py b/python/dateutil/__init__.py new file mode 100644 index 0000000..796ef3d --- /dev/null +++ b/python/dateutil/__init__.py @@ -0,0 +1,2 @@ +# -*- coding: utf-8 -*- +from ._version import VERSION as __version__ diff --git a/python/dateutil/_common.py b/python/dateutil/_common.py new file mode 100644 index 0000000..e8b4af7 --- /dev/null +++ b/python/dateutil/_common.py @@ -0,0 +1,34 @@ +""" +Common code used in multiple modules. +""" + + +class weekday(object): +    __slots__ = ["weekday", "n"] + +    def __init__(self, weekday, n=None): +        self.weekday = weekday +        self.n = n + +    def __call__(self, n): +        if n == self.n: +            return self +        else: +            return self.__class__(self.weekday, n) + +    def __eq__(self, other): +        try: +            if self.weekday != other.weekday or self.n != other.n: +                return False +        except AttributeError: +            return False +        return True + +    __hash__ = None + +    def __repr__(self): +        s = ("MO", "TU", "WE", "TH", "FR", "SA", "SU")[self.weekday] +        if not self.n: +            return s +        else: +            return "%s(%+d)" % (s, self.n) diff --git a/python/dateutil/_version.py b/python/dateutil/_version.py new file mode 100644 index 0000000..c1a0357 --- /dev/null +++ b/python/dateutil/_version.py @@ -0,0 +1,10 @@ +""" +Contains information about the dateutil version. +""" + +VERSION_MAJOR = 2 +VERSION_MINOR = 6 +VERSION_PATCH = 1 + +VERSION_TUPLE = (VERSION_MAJOR, VERSION_MINOR, VERSION_PATCH) +VERSION = '.'.join(map(str, VERSION_TUPLE)) diff --git a/python/dateutil/easter.py b/python/dateutil/easter.py new file mode 100644 index 0000000..e4def97 --- /dev/null +++ b/python/dateutil/easter.py @@ -0,0 +1,89 @@ +# -*- coding: utf-8 -*- +""" +This module offers a generic easter computing method for any given year, using +Western, Orthodox or Julian algorithms. +""" + +import datetime + +__all__ = ["easter", "EASTER_JULIAN", "EASTER_ORTHODOX", "EASTER_WESTERN"] + +EASTER_JULIAN = 1 +EASTER_ORTHODOX = 2 +EASTER_WESTERN = 3 + + +def easter(year, method=EASTER_WESTERN): +    """ +    This method was ported from the work done by GM Arts, +    on top of the algorithm by Claus Tondering, which was +    based in part on the algorithm of Ouding (1940), as +    quoted in "Explanatory Supplement to the Astronomical +    Almanac", P.  Kenneth Seidelmann, editor. + +    This algorithm implements three different easter +    calculation methods: + +    1 - Original calculation in Julian calendar, valid in +        dates after 326 AD +    2 - Original method, with date converted to Gregorian +        calendar, valid in years 1583 to 4099 +    3 - Revised method, in Gregorian calendar, valid in +        years 1583 to 4099 as well + +    These methods are represented by the constants: + +    * ``EASTER_JULIAN   = 1`` +    * ``EASTER_ORTHODOX = 2`` +    * ``EASTER_WESTERN  = 3`` + +    The default method is method 3. + +    More about the algorithm may be found at: + +    http://users.chariot.net.au/~gmarts/eastalg.htm + +    and + +    http://www.tondering.dk/claus/calendar.html + +    """ + +    if not (1 <= method <= 3): +        raise ValueError("invalid method") + +    # g - Golden year - 1 +    # c - Century +    # h - (23 - Epact) mod 30 +    # i - Number of days from March 21 to Paschal Full Moon +    # j - Weekday for PFM (0=Sunday, etc) +    # p - Number of days from March 21 to Sunday on or before PFM +    #     (-6 to 28 methods 1 & 3, to 56 for method 2) +    # e - Extra days to add for method 2 (converting Julian +    #     date to Gregorian date) + +    y = year +    g = y % 19 +    e = 0 +    if method < 3: +        # Old method +        i = (19*g + 15) % 30 +        j = (y + y//4 + i) % 7 +        if method == 2: +            # Extra dates to convert Julian to Gregorian date +            e = 10 +            if y > 1600: +                e = e + y//100 - 16 - (y//100 - 16)//4 +    else: +        # New method +        c = y//100 +        h = (c - c//4 - (8*c + 13)//25 + 19*g + 15) % 30 +        i = h - (h//28)*(1 - (h//28)*(29//(h + 1))*((21 - g)//11)) +        j = (y + y//4 + i + 2 - c + c//4) % 7 + +    # p can be from -6 to 56 corresponding to dates 22 March to 23 May +    # (later dates apply to method 2, although 23 May never actually occurs) +    p = i - j + e +    d = 1 + (p + 27 + (p + 6)//40) % 31 +    m = 3 + (p + 26)//30 +    return datetime.date(int(y), int(m), int(d)) diff --git a/python/dateutil/parser.py b/python/dateutil/parser.py new file mode 100644 index 0000000..595331f --- /dev/null +++ b/python/dateutil/parser.py @@ -0,0 +1,1374 @@ +# -*- coding: utf-8 -*- +""" +This module offers a generic date/time string parser which is able to parse +most known formats to represent a date and/or time. + +This module attempts to be forgiving with regards to unlikely input formats, +returning a datetime object even for dates which are ambiguous. If an element +of a date/time stamp is omitted, the following rules are applied: +- If AM or PM is left unspecified, a 24-hour clock is assumed, however, an hour +  on a 12-hour clock (``0 <= hour <= 12``) *must* be specified if AM or PM is +  specified. +- If a time zone is omitted, a timezone-naive datetime is returned. + +If any other elements are missing, they are taken from the +:class:`datetime.datetime` object passed to the parameter ``default``. If this +results in a day number exceeding the valid number of days per month, the +value falls back to the end of the month. + +Additional resources about date/time string formats can be found below: + +- `A summary of the international standard date and time notation +  <http://www.cl.cam.ac.uk/~mgk25/iso-time.html>`_ +- `W3C Date and Time Formats <http://www.w3.org/TR/NOTE-datetime>`_ +- `Time Formats (Planetary Rings Node) <http://pds-rings.seti.org/tools/time_formats.html>`_ +- `CPAN ParseDate module +  <http://search.cpan.org/~muir/Time-modules-2013.0912/lib/Time/ParseDate.pm>`_ +- `Java SimpleDateFormat Class +  <https://docs.oracle.com/javase/6/docs/api/java/text/SimpleDateFormat.html>`_ +""" +from __future__ import unicode_literals + +import datetime +import string +import time +import collections +import re +from io import StringIO +from calendar import monthrange + +from six import text_type, binary_type, integer_types + +from . import relativedelta +from . import tz + +__all__ = ["parse", "parserinfo"] + + +class _timelex(object): +    # Fractional seconds are sometimes split by a comma +    _split_decimal = re.compile("([.,])") + +    def __init__(self, instream): +        if isinstance(instream, binary_type): +            instream = instream.decode() + +        if isinstance(instream, text_type): +            instream = StringIO(instream) + +        if getattr(instream, 'read', None) is None: +            raise TypeError('Parser must be a string or character stream, not ' +                            '{itype}'.format(itype=instream.__class__.__name__)) + +        self.instream = instream +        self.charstack = [] +        self.tokenstack = [] +        self.eof = False + +    def get_token(self): +        """ +        This function breaks the time string into lexical units (tokens), which +        can be parsed by the parser. Lexical units are demarcated by changes in +        the character set, so any continuous string of letters is considered +        one unit, any continuous string of numbers is considered one unit. + +        The main complication arises from the fact that dots ('.') can be used +        both as separators (e.g. "Sep.20.2009") or decimal points (e.g. +        "4:30:21.447"). As such, it is necessary to read the full context of +        any dot-separated strings before breaking it into tokens; as such, this +        function maintains a "token stack", for when the ambiguous context +        demands that multiple tokens be parsed at once. +        """ +        if self.tokenstack: +            return self.tokenstack.pop(0) + +        seenletters = False +        token = None +        state = None + +        while not self.eof: +            # We only realize that we've reached the end of a token when we +            # find a character that's not part of the current token - since +            # that character may be part of the next token, it's stored in the +            # charstack. +            if self.charstack: +                nextchar = self.charstack.pop(0) +            else: +                nextchar = self.instream.read(1) +                while nextchar == '\x00': +                    nextchar = self.instream.read(1) + +            if not nextchar: +                self.eof = True +                break +            elif not state: +                # First character of the token - determines if we're starting +                # to parse a word, a number or something else. +                token = nextchar +                if self.isword(nextchar): +                    state = 'a' +                elif self.isnum(nextchar): +                    state = '0' +                elif self.isspace(nextchar): +                    token = ' ' +                    break  # emit token +                else: +                    break  # emit token +            elif state == 'a': +                # If we've already started reading a word, we keep reading +                # letters until we find something that's not part of a word. +                seenletters = True +                if self.isword(nextchar): +                    token += nextchar +                elif nextchar == '.': +                    token += nextchar +                    state = 'a.' +                else: +                    self.charstack.append(nextchar) +                    break  # emit token +            elif state == '0': +                # If we've already started reading a number, we keep reading +                # numbers until we find something that doesn't fit. +                if self.isnum(nextchar): +                    token += nextchar +                elif nextchar == '.' or (nextchar == ',' and len(token) >= 2): +                    token += nextchar +                    state = '0.' +                else: +                    self.charstack.append(nextchar) +                    break  # emit token +            elif state == 'a.': +                # If we've seen some letters and a dot separator, continue +                # parsing, and the tokens will be broken up later. +                seenletters = True +                if nextchar == '.' or self.isword(nextchar): +                    token += nextchar +                elif self.isnum(nextchar) and token[-1] == '.': +                    token += nextchar +                    state = '0.' +                else: +                    self.charstack.append(nextchar) +                    break  # emit token +            elif state == '0.': +                # If we've seen at least one dot separator, keep going, we'll +                # break up the tokens later. +                if nextchar == '.' or self.isnum(nextchar): +                    token += nextchar +                elif self.isword(nextchar) and token[-1] == '.': +                    token += nextchar +                    state = 'a.' +                else: +                    self.charstack.append(nextchar) +                    break  # emit token + +        if (state in ('a.', '0.') and (seenletters or token.count('.') > 1 or +                                       token[-1] in '.,')): +            l = self._split_decimal.split(token) +            token = l[0] +            for tok in l[1:]: +                if tok: +                    self.tokenstack.append(tok) + +        if state == '0.' and token.count('.') == 0: +            token = token.replace(',', '.') + +        return token + +    def __iter__(self): +        return self + +    def __next__(self): +        token = self.get_token() +        if token is None: +            raise StopIteration + +        return token + +    def next(self): +        return self.__next__()  # Python 2.x support + +    @classmethod +    def split(cls, s): +        return list(cls(s)) + +    @classmethod +    def isword(cls, nextchar): +        """ Whether or not the next character is part of a word """ +        return nextchar.isalpha() + +    @classmethod +    def isnum(cls, nextchar): +        """ Whether the next character is part of a number """ +        return nextchar.isdigit() + +    @classmethod +    def isspace(cls, nextchar): +        """ Whether the next character is whitespace """ +        return nextchar.isspace() + + +class _resultbase(object): + +    def __init__(self): +        for attr in self.__slots__: +            setattr(self, attr, None) + +    def _repr(self, classname): +        l = [] +        for attr in self.__slots__: +            value = getattr(self, attr) +            if value is not None: +                l.append("%s=%s" % (attr, repr(value))) +        return "%s(%s)" % (classname, ", ".join(l)) + +    def __len__(self): +        return (sum(getattr(self, attr) is not None +                    for attr in self.__slots__)) + +    def __repr__(self): +        return self._repr(self.__class__.__name__) + + +class parserinfo(object): +    """ +    Class which handles what inputs are accepted. Subclass this to customize +    the language and acceptable values for each parameter. + +    :param dayfirst: +            Whether to interpret the first value in an ambiguous 3-integer date +            (e.g. 01/05/09) as the day (``True``) or month (``False``). If +            ``yearfirst`` is set to ``True``, this distinguishes between YDM +            and YMD. Default is ``False``. + +    :param yearfirst: +            Whether to interpret the first value in an ambiguous 3-integer date +            (e.g. 01/05/09) as the year. If ``True``, the first number is taken +            to be the year, otherwise the last number is taken to be the year. +            Default is ``False``. +    """ + +    # m from a.m/p.m, t from ISO T separator +    JUMP = [" ", ".", ",", ";", "-", "/", "'", +            "at", "on", "and", "ad", "m", "t", "of", +            "st", "nd", "rd", "th"] + +    WEEKDAYS = [("Mon", "Monday"), +                ("Tue", "Tuesday"), +                ("Wed", "Wednesday"), +                ("Thu", "Thursday"), +                ("Fri", "Friday"), +                ("Sat", "Saturday"), +                ("Sun", "Sunday")] +    MONTHS = [("Jan", "January"), +              ("Feb", "February"), +              ("Mar", "March"), +              ("Apr", "April"), +              ("May", "May"), +              ("Jun", "June"), +              ("Jul", "July"), +              ("Aug", "August"), +              ("Sep", "Sept", "September"), +              ("Oct", "October"), +              ("Nov", "November"), +              ("Dec", "December")] +    HMS = [("h", "hour", "hours"), +           ("m", "minute", "minutes"), +           ("s", "second", "seconds")] +    AMPM = [("am", "a"), +            ("pm", "p")] +    UTCZONE = ["UTC", "GMT", "Z"] +    PERTAIN = ["of"] +    TZOFFSET = {} + +    def __init__(self, dayfirst=False, yearfirst=False): +        self._jump = self._convert(self.JUMP) +        self._weekdays = self._convert(self.WEEKDAYS) +        self._months = self._convert(self.MONTHS) +        self._hms = self._convert(self.HMS) +        self._ampm = self._convert(self.AMPM) +        self._utczone = self._convert(self.UTCZONE) +        self._pertain = self._convert(self.PERTAIN) + +        self.dayfirst = dayfirst +        self.yearfirst = yearfirst + +        self._year = time.localtime().tm_year +        self._century = self._year // 100 * 100 + +    def _convert(self, lst): +        dct = {} +        for i, v in enumerate(lst): +            if isinstance(v, tuple): +                for v in v: +                    dct[v.lower()] = i +            else: +                dct[v.lower()] = i +        return dct + +    def jump(self, name): +        return name.lower() in self._jump + +    def weekday(self, name): +        if len(name) >= min(len(n) for n in self._weekdays.keys()): +            try: +                return self._weekdays[name.lower()] +            except KeyError: +                pass +        return None + +    def month(self, name): +        if len(name) >= min(len(n) for n in self._months.keys()): +            try: +                return self._months[name.lower()] + 1 +            except KeyError: +                pass +        return None + +    def hms(self, name): +        try: +            return self._hms[name.lower()] +        except KeyError: +            return None + +    def ampm(self, name): +        try: +            return self._ampm[name.lower()] +        except KeyError: +            return None + +    def pertain(self, name): +        return name.lower() in self._pertain + +    def utczone(self, name): +        return name.lower() in self._utczone + +    def tzoffset(self, name): +        if name in self._utczone: +            return 0 + +        return self.TZOFFSET.get(name) + +    def convertyear(self, year, century_specified=False): +        if year < 100 and not century_specified: +            year += self._century +            if abs(year - self._year) >= 50: +                if year < self._year: +                    year += 100 +                else: +                    year -= 100 +        return year + +    def validate(self, res): +        # move to info +        if res.year is not None: +            res.year = self.convertyear(res.year, res.century_specified) + +        if res.tzoffset == 0 and not res.tzname or res.tzname == 'Z': +            res.tzname = "UTC" +            res.tzoffset = 0 +        elif res.tzoffset != 0 and res.tzname and self.utczone(res.tzname): +            res.tzoffset = 0 +        return True + + +class _ymd(list): +    def __init__(self, tzstr, *args, **kwargs): +        super(self.__class__, self).__init__(*args, **kwargs) +        self.century_specified = False +        self.tzstr = tzstr + +    @staticmethod +    def token_could_be_year(token, year): +        try: +            return int(token) == year +        except ValueError: +            return False + +    @staticmethod +    def find_potential_year_tokens(year, tokens): +        return [token for token in tokens if _ymd.token_could_be_year(token, year)] + +    def find_probable_year_index(self, tokens): +        """ +        attempt to deduce if a pre 100 year was lost +         due to padded zeros being taken off +        """ +        for index, token in enumerate(self): +            potential_year_tokens = _ymd.find_potential_year_tokens(token, tokens) +            if len(potential_year_tokens) == 1 and len(potential_year_tokens[0]) > 2: +                return index + +    def append(self, val): +        if hasattr(val, '__len__'): +            if val.isdigit() and len(val) > 2: +                self.century_specified = True +        elif val > 100: +            self.century_specified = True + +        super(self.__class__, self).append(int(val)) + +    def resolve_ymd(self, mstridx, yearfirst, dayfirst): +        len_ymd = len(self) +        year, month, day = (None, None, None) + +        if len_ymd > 3: +            raise ValueError("More than three YMD values") +        elif len_ymd == 1 or (mstridx != -1 and len_ymd == 2): +            # One member, or two members with a month string +            if mstridx != -1: +                month = self[mstridx] +                del self[mstridx] + +            if len_ymd > 1 or mstridx == -1: +                if self[0] > 31: +                    year = self[0] +                else: +                    day = self[0] + +        elif len_ymd == 2: +            # Two members with numbers +            if self[0] > 31: +                # 99-01 +                year, month = self +            elif self[1] > 31: +                # 01-99 +                month, year = self +            elif dayfirst and self[1] <= 12: +                # 13-01 +                day, month = self +            else: +                # 01-13 +                month, day = self + +        elif len_ymd == 3: +            # Three members +            if mstridx == 0: +                month, day, year = self +            elif mstridx == 1: +                if self[0] > 31 or (yearfirst and self[2] <= 31): +                    # 99-Jan-01 +                    year, month, day = self +                else: +                    # 01-Jan-01 +                    # Give precendence to day-first, since +                    # two-digit years is usually hand-written. +                    day, month, year = self + +            elif mstridx == 2: +                # WTF!? +                if self[1] > 31: +                    # 01-99-Jan +                    day, year, month = self +                else: +                    # 99-01-Jan +                    year, day, month = self + +            else: +                if self[0] > 31 or \ +                    self.find_probable_year_index(_timelex.split(self.tzstr)) == 0 or \ +                   (yearfirst and self[1] <= 12 and self[2] <= 31): +                    # 99-01-01 +                    if dayfirst and self[2] <= 12: +                        year, day, month = self +                    else: +                        year, month, day = self +                elif self[0] > 12 or (dayfirst and self[1] <= 12): +                    # 13-01-01 +                    day, month, year = self +                else: +                    # 01-13-01 +                    month, day, year = self + +        return year, month, day + + +class parser(object): +    def __init__(self, info=None): +        self.info = info or parserinfo() + +    def parse(self, timestr, default=None, ignoretz=False, tzinfos=None, **kwargs): +        """ +        Parse the date/time string into a :class:`datetime.datetime` object. + +        :param timestr: +            Any date/time string using the supported formats. + +        :param default: +            The default datetime object, if this is a datetime object and not +            ``None``, elements specified in ``timestr`` replace elements in the +            default object. + +        :param ignoretz: +            If set ``True``, time zones in parsed strings are ignored and a +            naive :class:`datetime.datetime` object is returned. + +        :param tzinfos: +            Additional time zone names / aliases which may be present in the +            string. This argument maps time zone names (and optionally offsets +            from those time zones) to time zones. This parameter can be a +            dictionary with timezone aliases mapping time zone names to time +            zones or a function taking two parameters (``tzname`` and +            ``tzoffset``) and returning a time zone. + +            The timezones to which the names are mapped can be an integer +            offset from UTC in minutes or a :class:`tzinfo` object. + +            .. doctest:: +               :options: +NORMALIZE_WHITESPACE + +                >>> from dateutil.parser import parse +                >>> from dateutil.tz import gettz +                >>> tzinfos = {"BRST": -10800, "CST": gettz("America/Chicago")} +                >>> parse("2012-01-19 17:21:00 BRST", tzinfos=tzinfos) +                datetime.datetime(2012, 1, 19, 17, 21, tzinfo=tzoffset(u'BRST', -10800)) +                >>> parse("2012-01-19 17:21:00 CST", tzinfos=tzinfos) +                datetime.datetime(2012, 1, 19, 17, 21, +                                  tzinfo=tzfile('/usr/share/zoneinfo/America/Chicago')) + +            This parameter is ignored if ``ignoretz`` is set. + +        :param **kwargs: +            Keyword arguments as passed to ``_parse()``. + +        :return: +            Returns a :class:`datetime.datetime` object or, if the +            ``fuzzy_with_tokens`` option is ``True``, returns a tuple, the +            first element being a :class:`datetime.datetime` object, the second +            a tuple containing the fuzzy tokens. + +        :raises ValueError: +            Raised for invalid or unknown string format, if the provided +            :class:`tzinfo` is not in a valid format, or if an invalid date +            would be created. + +        :raises TypeError: +            Raised for non-string or character stream input. + +        :raises OverflowError: +            Raised if the parsed date exceeds the largest valid C integer on +            your system. +        """ + +        if default is None: +            default = datetime.datetime.now().replace(hour=0, minute=0, +                                                      second=0, microsecond=0) + +        res, skipped_tokens = self._parse(timestr, **kwargs) + +        if res is None: +            raise ValueError("Unknown string format") + +        if len(res) == 0: +            raise ValueError("String does not contain a date.") + +        repl = {} +        for attr in ("year", "month", "day", "hour", +                     "minute", "second", "microsecond"): +            value = getattr(res, attr) +            if value is not None: +                repl[attr] = value + +        if 'day' not in repl: +            # If the default day exceeds the last day of the month, fall back to +            # the end of the month. +            cyear = default.year if res.year is None else res.year +            cmonth = default.month if res.month is None else res.month +            cday = default.day if res.day is None else res.day + +            if cday > monthrange(cyear, cmonth)[1]: +                repl['day'] = monthrange(cyear, cmonth)[1] + +        ret = default.replace(**repl) + +        if res.weekday is not None and not res.day: +            ret = ret+relativedelta.relativedelta(weekday=res.weekday) + +        if not ignoretz: +            if (isinstance(tzinfos, collections.Callable) or +                    tzinfos and res.tzname in tzinfos): + +                if isinstance(tzinfos, collections.Callable): +                    tzdata = tzinfos(res.tzname, res.tzoffset) +                else: +                    tzdata = tzinfos.get(res.tzname) + +                if isinstance(tzdata, datetime.tzinfo): +                    tzinfo = tzdata +                elif isinstance(tzdata, text_type): +                    tzinfo = tz.tzstr(tzdata) +                elif isinstance(tzdata, integer_types): +                    tzinfo = tz.tzoffset(res.tzname, tzdata) +                else: +                    raise ValueError("Offset must be tzinfo subclass, " +                                     "tz string, or int offset.") +                ret = ret.replace(tzinfo=tzinfo) +            elif res.tzname and res.tzname in time.tzname: +                ret = ret.replace(tzinfo=tz.tzlocal()) +            elif res.tzoffset == 0: +                ret = ret.replace(tzinfo=tz.tzutc()) +            elif res.tzoffset: +                ret = ret.replace(tzinfo=tz.tzoffset(res.tzname, res.tzoffset)) + +        if kwargs.get('fuzzy_with_tokens', False): +            return ret, skipped_tokens +        else: +            return ret + +    class _result(_resultbase): +        __slots__ = ["year", "month", "day", "weekday", +                     "hour", "minute", "second", "microsecond", +                     "tzname", "tzoffset", "ampm"] + +    def _parse(self, timestr, dayfirst=None, yearfirst=None, fuzzy=False, +               fuzzy_with_tokens=False): +        """ +        Private method which performs the heavy lifting of parsing, called from +        ``parse()``, which passes on its ``kwargs`` to this function. + +        :param timestr: +            The string to parse. + +        :param dayfirst: +            Whether to interpret the first value in an ambiguous 3-integer date +            (e.g. 01/05/09) as the day (``True``) or month (``False``). If +            ``yearfirst`` is set to ``True``, this distinguishes between YDM +            and YMD. If set to ``None``, this value is retrieved from the +            current :class:`parserinfo` object (which itself defaults to +            ``False``). + +        :param yearfirst: +            Whether to interpret the first value in an ambiguous 3-integer date +            (e.g. 01/05/09) as the year. If ``True``, the first number is taken +            to be the year, otherwise the last number is taken to be the year. +            If this is set to ``None``, the value is retrieved from the current +            :class:`parserinfo` object (which itself defaults to ``False``). + +        :param fuzzy: +            Whether to allow fuzzy parsing, allowing for string like "Today is +            January 1, 2047 at 8:21:00AM". + +        :param fuzzy_with_tokens: +            If ``True``, ``fuzzy`` is automatically set to True, and the parser +            will return a tuple where the first element is the parsed +            :class:`datetime.datetime` datetimestamp and the second element is +            a tuple containing the portions of the string which were ignored: + +            .. doctest:: + +                >>> from dateutil.parser import parse +                >>> parse("Today is January 1, 2047 at 8:21:00AM", fuzzy_with_tokens=True) +                (datetime.datetime(2047, 1, 1, 8, 21), (u'Today is ', u' ', u'at ')) + +        """ +        if fuzzy_with_tokens: +            fuzzy = True + +        info = self.info + +        if dayfirst is None: +            dayfirst = info.dayfirst + +        if yearfirst is None: +            yearfirst = info.yearfirst + +        res = self._result() +        l = _timelex.split(timestr)         # Splits the timestr into tokens + +        # keep up with the last token skipped so we can recombine +        # consecutively skipped tokens (-2 for when i begins at 0). +        last_skipped_token_i = -2 +        skipped_tokens = list() + +        try: +            # year/month/day list +            ymd = _ymd(timestr) + +            # Index of the month string in ymd +            mstridx = -1 + +            len_l = len(l) +            i = 0 +            while i < len_l: + +                # Check if it's a number +                try: +                    value_repr = l[i] +                    value = float(value_repr) +                except ValueError: +                    value = None + +                if value is not None: +                    # Token is a number +                    len_li = len(l[i]) +                    i += 1 + +                    if (len(ymd) == 3 and len_li in (2, 4) +                        and res.hour is None and (i >= len_l or (l[i] != ':' and +                                                  info.hms(l[i]) is None))): +                        # 19990101T23[59] +                        s = l[i-1] +                        res.hour = int(s[:2]) + +                        if len_li == 4: +                            res.minute = int(s[2:]) + +                    elif len_li == 6 or (len_li > 6 and l[i-1].find('.') == 6): +                        # YYMMDD or HHMMSS[.ss] +                        s = l[i-1] + +                        if not ymd and l[i-1].find('.') == -1: +                            #ymd.append(info.convertyear(int(s[:2]))) + +                            ymd.append(s[:2]) +                            ymd.append(s[2:4]) +                            ymd.append(s[4:]) +                        else: +                            # 19990101T235959[.59] +                            res.hour = int(s[:2]) +                            res.minute = int(s[2:4]) +                            res.second, res.microsecond = _parsems(s[4:]) + +                    elif len_li in (8, 12, 14): +                        # YYYYMMDD +                        s = l[i-1] +                        ymd.append(s[:4]) +                        ymd.append(s[4:6]) +                        ymd.append(s[6:8]) + +                        if len_li > 8: +                            res.hour = int(s[8:10]) +                            res.minute = int(s[10:12]) + +                            if len_li > 12: +                                res.second = int(s[12:]) + +                    elif ((i < len_l and info.hms(l[i]) is not None) or +                          (i+1 < len_l and l[i] == ' ' and +                           info.hms(l[i+1]) is not None)): + +                        # HH[ ]h or MM[ ]m or SS[.ss][ ]s +                        if l[i] == ' ': +                            i += 1 + +                        idx = info.hms(l[i]) + +                        while True: +                            if idx == 0: +                                res.hour = int(value) + +                                if value % 1: +                                    res.minute = int(60*(value % 1)) + +                            elif idx == 1: +                                res.minute = int(value) + +                                if value % 1: +                                    res.second = int(60*(value % 1)) + +                            elif idx == 2: +                                res.second, res.microsecond = \ +                                    _parsems(value_repr) + +                            i += 1 + +                            if i >= len_l or idx == 2: +                                break + +                            # 12h00 +                            try: +                                value_repr = l[i] +                                value = float(value_repr) +                            except ValueError: +                                break +                            else: +                                i += 1 +                                idx += 1 + +                                if i < len_l: +                                    newidx = info.hms(l[i]) + +                                    if newidx is not None: +                                        idx = newidx + +                    elif (i == len_l and l[i-2] == ' ' and +                          info.hms(l[i-3]) is not None): +                        # X h MM or X m SS +                        idx = info.hms(l[i-3]) + +                        if idx == 0:               # h +                            res.minute = int(value) + +                            sec_remainder = value % 1 +                            if sec_remainder: +                                res.second = int(60 * sec_remainder) +                        elif idx == 1:             # m +                            res.second, res.microsecond = \ +                                _parsems(value_repr) + +                        # We don't need to advance the tokens here because the +                        # i == len_l call indicates that we're looking at all +                        # the tokens already. + +                    elif i+1 < len_l and l[i] == ':': +                        # HH:MM[:SS[.ss]] +                        res.hour = int(value) +                        i += 1 +                        value = float(l[i]) +                        res.minute = int(value) + +                        if value % 1: +                            res.second = int(60*(value % 1)) + +                        i += 1 + +                        if i < len_l and l[i] == ':': +                            res.second, res.microsecond = _parsems(l[i+1]) +                            i += 2 + +                    elif i < len_l and l[i] in ('-', '/', '.'): +                        sep = l[i] +                        ymd.append(value_repr) +                        i += 1 + +                        if i < len_l and not info.jump(l[i]): +                            try: +                                # 01-01[-01] +                                ymd.append(l[i]) +                            except ValueError: +                                # 01-Jan[-01] +                                value = info.month(l[i]) + +                                if value is not None: +                                    ymd.append(value) +                                    assert mstridx == -1 +                                    mstridx = len(ymd)-1 +                                else: +                                    return None, None + +                            i += 1 + +                            if i < len_l and l[i] == sep: +                                # We have three members +                                i += 1 +                                value = info.month(l[i]) + +                                if value is not None: +                                    ymd.append(value) +                                    mstridx = len(ymd)-1 +                                    assert mstridx == -1 +                                else: +                                    ymd.append(l[i]) + +                                i += 1 +                    elif i >= len_l or info.jump(l[i]): +                        if i+1 < len_l and info.ampm(l[i+1]) is not None: +                            # 12 am +                            res.hour = int(value) + +                            if res.hour < 12 and info.ampm(l[i+1]) == 1: +                                res.hour += 12 +                            elif res.hour == 12 and info.ampm(l[i+1]) == 0: +                                res.hour = 0 + +                            i += 1 +                        else: +                            # Year, month or day +                            ymd.append(value) +                        i += 1 +                    elif info.ampm(l[i]) is not None: + +                        # 12am +                        res.hour = int(value) + +                        if res.hour < 12 and info.ampm(l[i]) == 1: +                            res.hour += 12 +                        elif res.hour == 12 and info.ampm(l[i]) == 0: +                            res.hour = 0 +                        i += 1 + +                    elif not fuzzy: +                        return None, None +                    else: +                        i += 1 +                    continue + +                # Check weekday +                value = info.weekday(l[i]) +                if value is not None: +                    res.weekday = value +                    i += 1 +                    continue + +                # Check month name +                value = info.month(l[i]) +                if value is not None: +                    ymd.append(value) +                    assert mstridx == -1 +                    mstridx = len(ymd)-1 + +                    i += 1 +                    if i < len_l: +                        if l[i] in ('-', '/'): +                            # Jan-01[-99] +                            sep = l[i] +                            i += 1 +                            ymd.append(l[i]) +                            i += 1 + +                            if i < len_l and l[i] == sep: +                                # Jan-01-99 +                                i += 1 +                                ymd.append(l[i]) +                                i += 1 + +                        elif (i+3 < len_l and l[i] == l[i+2] == ' ' +                              and info.pertain(l[i+1])): +                            # Jan of 01 +                            # In this case, 01 is clearly year +                            try: +                                value = int(l[i+3]) +                            except ValueError: +                                # Wrong guess +                                pass +                            else: +                                # Convert it here to become unambiguous +                                ymd.append(str(info.convertyear(value))) +                            i += 4 +                    continue + +                # Check am/pm +                value = info.ampm(l[i]) +                if value is not None: +                    # For fuzzy parsing, 'a' or 'am' (both valid English words) +                    # may erroneously trigger the AM/PM flag. Deal with that +                    # here. +                    val_is_ampm = True + +                    # If there's already an AM/PM flag, this one isn't one. +                    if fuzzy and res.ampm is not None: +                        val_is_ampm = False + +                    # If AM/PM is found and hour is not, raise a ValueError +                    if res.hour is None: +                        if fuzzy: +                            val_is_ampm = False +                        else: +                            raise ValueError('No hour specified with ' + +                                             'AM or PM flag.') +                    elif not 0 <= res.hour <= 12: +                        # If AM/PM is found, it's a 12 hour clock, so raise +                        # an error for invalid range +                        if fuzzy: +                            val_is_ampm = False +                        else: +                            raise ValueError('Invalid hour specified for ' + +                                             '12-hour clock.') + +                    if val_is_ampm: +                        if value == 1 and res.hour < 12: +                            res.hour += 12 +                        elif value == 0 and res.hour == 12: +                            res.hour = 0 + +                        res.ampm = value + +                    elif fuzzy: +                        last_skipped_token_i = self._skip_token(skipped_tokens, +                                                    last_skipped_token_i, i, l) +                    i += 1 +                    continue + +                # Check for a timezone name +                if (res.hour is not None and len(l[i]) <= 5 and +                        res.tzname is None and res.tzoffset is None and +                        not [x for x in l[i] if x not in +                             string.ascii_uppercase]): +                    res.tzname = l[i] +                    res.tzoffset = info.tzoffset(res.tzname) +                    i += 1 + +                    # Check for something like GMT+3, or BRST+3. Notice +                    # that it doesn't mean "I am 3 hours after GMT", but +                    # "my time +3 is GMT". If found, we reverse the +                    # logic so that timezone parsing code will get it +                    # right. +                    if i < len_l and l[i] in ('+', '-'): +                        l[i] = ('+', '-')[l[i] == '+'] +                        res.tzoffset = None +                        if info.utczone(res.tzname): +                            # With something like GMT+3, the timezone +                            # is *not* GMT. +                            res.tzname = None + +                    continue + +                # Check for a numbered timezone +                if res.hour is not None and l[i] in ('+', '-'): +                    signal = (-1, 1)[l[i] == '+'] +                    i += 1 +                    len_li = len(l[i]) + +                    if len_li == 4: +                        # -0300 +                        res.tzoffset = int(l[i][:2])*3600+int(l[i][2:])*60 +                    elif i+1 < len_l and l[i+1] == ':': +                        # -03:00 +                        res.tzoffset = int(l[i])*3600+int(l[i+2])*60 +                        i += 2 +                    elif len_li <= 2: +                        # -[0]3 +                        res.tzoffset = int(l[i][:2])*3600 +                    else: +                        return None, None +                    i += 1 + +                    res.tzoffset *= signal + +                    # Look for a timezone name between parenthesis +                    if (i+3 < len_l and +                        info.jump(l[i]) and l[i+1] == '(' and l[i+3] == ')' and +                        3 <= len(l[i+2]) <= 5 and +                        not [x for x in l[i+2] +                             if x not in string.ascii_uppercase]): +                        # -0300 (BRST) +                        res.tzname = l[i+2] +                        i += 4 +                    continue + +                # Check jumps +                if not (info.jump(l[i]) or fuzzy): +                    return None, None + +                last_skipped_token_i = self._skip_token(skipped_tokens, +                                                last_skipped_token_i, i, l) +                i += 1 + +            # Process year/month/day +            year, month, day = ymd.resolve_ymd(mstridx, yearfirst, dayfirst) +            if year is not None: +                res.year = year +                res.century_specified = ymd.century_specified + +            if month is not None: +                res.month = month + +            if day is not None: +                res.day = day + +        except (IndexError, ValueError, AssertionError): +            return None, None + +        if not info.validate(res): +            return None, None + +        if fuzzy_with_tokens: +            return res, tuple(skipped_tokens) +        else: +            return res, None + +    @staticmethod +    def _skip_token(skipped_tokens, last_skipped_token_i, i, l): +        if last_skipped_token_i == i - 1: +            # recombine the tokens +            skipped_tokens[-1] += l[i] +        else: +            # just append +            skipped_tokens.append(l[i]) +        last_skipped_token_i = i +        return last_skipped_token_i + + +DEFAULTPARSER = parser() + + +def parse(timestr, parserinfo=None, **kwargs): +    """ + +    Parse a string in one of the supported formats, using the +    ``parserinfo`` parameters. + +    :param timestr: +        A string containing a date/time stamp. + +    :param parserinfo: +        A :class:`parserinfo` object containing parameters for the parser. +        If ``None``, the default arguments to the :class:`parserinfo` +        constructor are used. + +    The ``**kwargs`` parameter takes the following keyword arguments: + +    :param default: +        The default datetime object, if this is a datetime object and not +        ``None``, elements specified in ``timestr`` replace elements in the +        default object. + +    :param ignoretz: +        If set ``True``, time zones in parsed strings are ignored and a naive +        :class:`datetime` object is returned. + +    :param tzinfos: +            Additional time zone names / aliases which may be present in the +            string. This argument maps time zone names (and optionally offsets +            from those time zones) to time zones. This parameter can be a +            dictionary with timezone aliases mapping time zone names to time +            zones or a function taking two parameters (``tzname`` and +            ``tzoffset``) and returning a time zone. + +            The timezones to which the names are mapped can be an integer +            offset from UTC in minutes or a :class:`tzinfo` object. + +            .. doctest:: +               :options: +NORMALIZE_WHITESPACE + +                >>> from dateutil.parser import parse +                >>> from dateutil.tz import gettz +                >>> tzinfos = {"BRST": -10800, "CST": gettz("America/Chicago")} +                >>> parse("2012-01-19 17:21:00 BRST", tzinfos=tzinfos) +                datetime.datetime(2012, 1, 19, 17, 21, tzinfo=tzoffset(u'BRST', -10800)) +                >>> parse("2012-01-19 17:21:00 CST", tzinfos=tzinfos) +                datetime.datetime(2012, 1, 19, 17, 21, +                                  tzinfo=tzfile('/usr/share/zoneinfo/America/Chicago')) + +            This parameter is ignored if ``ignoretz`` is set. + +    :param dayfirst: +        Whether to interpret the first value in an ambiguous 3-integer date +        (e.g. 01/05/09) as the day (``True``) or month (``False``). If +        ``yearfirst`` is set to ``True``, this distinguishes between YDM and +        YMD. If set to ``None``, this value is retrieved from the current +        :class:`parserinfo` object (which itself defaults to ``False``). + +    :param yearfirst: +        Whether to interpret the first value in an ambiguous 3-integer date +        (e.g. 01/05/09) as the year. If ``True``, the first number is taken to +        be the year, otherwise the last number is taken to be the year. If +        this is set to ``None``, the value is retrieved from the current +        :class:`parserinfo` object (which itself defaults to ``False``). + +    :param fuzzy: +        Whether to allow fuzzy parsing, allowing for string like "Today is +        January 1, 2047 at 8:21:00AM". + +    :param fuzzy_with_tokens: +        If ``True``, ``fuzzy`` is automatically set to True, and the parser +        will return a tuple where the first element is the parsed +        :class:`datetime.datetime` datetimestamp and the second element is +        a tuple containing the portions of the string which were ignored: + +        .. doctest:: + +            >>> from dateutil.parser import parse +            >>> parse("Today is January 1, 2047 at 8:21:00AM", fuzzy_with_tokens=True) +            (datetime.datetime(2047, 1, 1, 8, 21), (u'Today is ', u' ', u'at ')) + +    :return: +        Returns a :class:`datetime.datetime` object or, if the +        ``fuzzy_with_tokens`` option is ``True``, returns a tuple, the +        first element being a :class:`datetime.datetime` object, the second +        a tuple containing the fuzzy tokens. + +    :raises ValueError: +        Raised for invalid or unknown string format, if the provided +        :class:`tzinfo` is not in a valid format, or if an invalid date +        would be created. + +    :raises OverflowError: +        Raised if the parsed date exceeds the largest valid C integer on +        your system. +    """ +    if parserinfo: +        return parser(parserinfo).parse(timestr, **kwargs) +    else: +        return DEFAULTPARSER.parse(timestr, **kwargs) + + +class _tzparser(object): + +    class _result(_resultbase): + +        __slots__ = ["stdabbr", "stdoffset", "dstabbr", "dstoffset", +                     "start", "end"] + +        class _attr(_resultbase): +            __slots__ = ["month", "week", "weekday", +                         "yday", "jyday", "day", "time"] + +        def __repr__(self): +            return self._repr("") + +        def __init__(self): +            _resultbase.__init__(self) +            self.start = self._attr() +            self.end = self._attr() + +    def parse(self, tzstr): +        res = self._result() +        l = _timelex.split(tzstr) +        try: + +            len_l = len(l) + +            i = 0 +            while i < len_l: +                # BRST+3[BRDT[+2]] +                j = i +                while j < len_l and not [x for x in l[j] +                                         if x in "0123456789:,-+"]: +                    j += 1 +                if j != i: +                    if not res.stdabbr: +                        offattr = "stdoffset" +                        res.stdabbr = "".join(l[i:j]) +                    else: +                        offattr = "dstoffset" +                        res.dstabbr = "".join(l[i:j]) +                    i = j +                    if (i < len_l and (l[i] in ('+', '-') or l[i][0] in +                                       "0123456789")): +                        if l[i] in ('+', '-'): +                            # Yes, that's right.  See the TZ variable +                            # documentation. +                            signal = (1, -1)[l[i] == '+'] +                            i += 1 +                        else: +                            signal = -1 +                        len_li = len(l[i]) +                        if len_li == 4: +                            # -0300 +                            setattr(res, offattr, (int(l[i][:2])*3600 + +                                                   int(l[i][2:])*60)*signal) +                        elif i+1 < len_l and l[i+1] == ':': +                            # -03:00 +                            setattr(res, offattr, +                                    (int(l[i])*3600+int(l[i+2])*60)*signal) +                            i += 2 +                        elif len_li <= 2: +                            # -[0]3 +                            setattr(res, offattr, +                                    int(l[i][:2])*3600*signal) +                        else: +                            return None +                        i += 1 +                    if res.dstabbr: +                        break +                else: +                    break + +            if i < len_l: +                for j in range(i, len_l): +                    if l[j] == ';': +                        l[j] = ',' + +                assert l[i] == ',' + +                i += 1 + +            if i >= len_l: +                pass +            elif (8 <= l.count(',') <= 9 and +                  not [y for x in l[i:] if x != ',' +                       for y in x if y not in "0123456789"]): +                # GMT0BST,3,0,30,3600,10,0,26,7200[,3600] +                for x in (res.start, res.end): +                    x.month = int(l[i]) +                    i += 2 +                    if l[i] == '-': +                        value = int(l[i+1])*-1 +                        i += 1 +                    else: +                        value = int(l[i]) +                    i += 2 +                    if value: +                        x.week = value +                        x.weekday = (int(l[i])-1) % 7 +                    else: +                        x.day = int(l[i]) +                    i += 2 +                    x.time = int(l[i]) +                    i += 2 +                if i < len_l: +                    if l[i] in ('-', '+'): +                        signal = (-1, 1)[l[i] == "+"] +                        i += 1 +                    else: +                        signal = 1 +                    res.dstoffset = (res.stdoffset+int(l[i]))*signal +            elif (l.count(',') == 2 and l[i:].count('/') <= 2 and +                  not [y for x in l[i:] if x not in (',', '/', 'J', 'M', +                                                     '.', '-', ':') +                       for y in x if y not in "0123456789"]): +                for x in (res.start, res.end): +                    if l[i] == 'J': +                        # non-leap year day (1 based) +                        i += 1 +                        x.jyday = int(l[i]) +                    elif l[i] == 'M': +                        # month[-.]week[-.]weekday +                        i += 1 +                        x.month = int(l[i]) +                        i += 1 +                        assert l[i] in ('-', '.') +                        i += 1 +                        x.week = int(l[i]) +                        if x.week == 5: +                            x.week = -1 +                        i += 1 +                        assert l[i] in ('-', '.') +                        i += 1 +                        x.weekday = (int(l[i])-1) % 7 +                    else: +                        # year day (zero based) +                        x.yday = int(l[i])+1 + +                    i += 1 + +                    if i < len_l and l[i] == '/': +                        i += 1 +                        # start time +                        len_li = len(l[i]) +                        if len_li == 4: +                            # -0300 +                            x.time = (int(l[i][:2])*3600+int(l[i][2:])*60) +                        elif i+1 < len_l and l[i+1] == ':': +                            # -03:00 +                            x.time = int(l[i])*3600+int(l[i+2])*60 +                            i += 2 +                            if i+1 < len_l and l[i+1] == ':': +                                i += 2 +                                x.time += int(l[i]) +                        elif len_li <= 2: +                            # -[0]3 +                            x.time = (int(l[i][:2])*3600) +                        else: +                            return None +                        i += 1 + +                    assert i == len_l or l[i] == ',' + +                    i += 1 + +                assert i >= len_l + +        except (IndexError, ValueError, AssertionError): +            return None + +        return res + + +DEFAULTTZPARSER = _tzparser() + + +def _parsetz(tzstr): +    return DEFAULTTZPARSER.parse(tzstr) + + +def _parsems(value): +    """Parse a I[.F] seconds value into (seconds, microseconds).""" +    if "." not in value: +        return int(value), 0 +    else: +        i, f = value.split(".") +        return int(i), int(f.ljust(6, "0")[:6]) + + +# vim:ts=4:sw=4:et diff --git a/python/dateutil/relativedelta.py b/python/dateutil/relativedelta.py new file mode 100644 index 0000000..0e66afc --- /dev/null +++ b/python/dateutil/relativedelta.py @@ -0,0 +1,549 @@ +# -*- coding: utf-8 -*- +import datetime +import calendar + +import operator +from math import copysign + +from six import integer_types +from warnings import warn + +from ._common import weekday + +MO, TU, WE, TH, FR, SA, SU = weekdays = tuple(weekday(x) for x in range(7)) + +__all__ = ["relativedelta", "MO", "TU", "WE", "TH", "FR", "SA", "SU"] + + +class relativedelta(object): +    """ +    The relativedelta type is based on the specification of the excellent +    work done by M.-A. Lemburg in his +    `mx.DateTime <http://www.egenix.com/files/python/mxDateTime.html>`_ extension. +    However, notice that this type does *NOT* implement the same algorithm as +    his work. Do *NOT* expect it to behave like mx.DateTime's counterpart. + +    There are two different ways to build a relativedelta instance. The +    first one is passing it two date/datetime classes:: + +        relativedelta(datetime1, datetime2) + +    The second one is passing it any number of the following keyword arguments:: + +        relativedelta(arg1=x,arg2=y,arg3=z...) + +        year, month, day, hour, minute, second, microsecond: +            Absolute information (argument is singular); adding or subtracting a +            relativedelta with absolute information does not perform an aritmetic +            operation, but rather REPLACES the corresponding value in the +            original datetime with the value(s) in relativedelta. + +        years, months, weeks, days, hours, minutes, seconds, microseconds: +            Relative information, may be negative (argument is plural); adding +            or subtracting a relativedelta with relative information performs +            the corresponding aritmetic operation on the original datetime value +            with the information in the relativedelta. + +        weekday: +            One of the weekday instances (MO, TU, etc). These instances may +            receive a parameter N, specifying the Nth weekday, which could +            be positive or negative (like MO(+1) or MO(-2). Not specifying +            it is the same as specifying +1. You can also use an integer, +            where 0=MO. + +        leapdays: +            Will add given days to the date found, if year is a leap +            year, and the date found is post 28 of february. + +        yearday, nlyearday: +            Set the yearday or the non-leap year day (jump leap days). +            These are converted to day/month/leapdays information. + +    Here is the behavior of operations with relativedelta: + +    1. Calculate the absolute year, using the 'year' argument, or the +       original datetime year, if the argument is not present. + +    2. Add the relative 'years' argument to the absolute year. + +    3. Do steps 1 and 2 for month/months. + +    4. Calculate the absolute day, using the 'day' argument, or the +       original datetime day, if the argument is not present. Then, +       subtract from the day until it fits in the year and month +       found after their operations. + +    5. Add the relative 'days' argument to the absolute day. Notice +       that the 'weeks' argument is multiplied by 7 and added to +       'days'. + +    6. Do steps 1 and 2 for hour/hours, minute/minutes, second/seconds, +       microsecond/microseconds. + +    7. If the 'weekday' argument is present, calculate the weekday, +       with the given (wday, nth) tuple. wday is the index of the +       weekday (0-6, 0=Mon), and nth is the number of weeks to add +       forward or backward, depending on its signal. Notice that if +       the calculated date is already Monday, for example, using +       (0, 1) or (0, -1) won't change the day. +    """ + +    def __init__(self, dt1=None, dt2=None, +                 years=0, months=0, days=0, leapdays=0, weeks=0, +                 hours=0, minutes=0, seconds=0, microseconds=0, +                 year=None, month=None, day=None, weekday=None, +                 yearday=None, nlyearday=None, +                 hour=None, minute=None, second=None, microsecond=None): + +        # Check for non-integer values in integer-only quantities +        if any(x is not None and x != int(x) for x in (years, months)): +            raise ValueError("Non-integer years and months are " +                             "ambiguous and not currently supported.") + +        if dt1 and dt2: +            # datetime is a subclass of date. So both must be date +            if not (isinstance(dt1, datetime.date) and +                    isinstance(dt2, datetime.date)): +                raise TypeError("relativedelta only diffs datetime/date") + +            # We allow two dates, or two datetimes, so we coerce them to be +            # of the same type +            if (isinstance(dt1, datetime.datetime) != +                    isinstance(dt2, datetime.datetime)): +                if not isinstance(dt1, datetime.datetime): +                    dt1 = datetime.datetime.fromordinal(dt1.toordinal()) +                elif not isinstance(dt2, datetime.datetime): +                    dt2 = datetime.datetime.fromordinal(dt2.toordinal()) + +            self.years = 0 +            self.months = 0 +            self.days = 0 +            self.leapdays = 0 +            self.hours = 0 +            self.minutes = 0 +            self.seconds = 0 +            self.microseconds = 0 +            self.year = None +            self.month = None +            self.day = None +            self.weekday = None +            self.hour = None +            self.minute = None +            self.second = None +            self.microsecond = None +            self._has_time = 0 + +            # Get year / month delta between the two +            months = (dt1.year - dt2.year) * 12 + (dt1.month - dt2.month) +            self._set_months(months) + +            # Remove the year/month delta so the timedelta is just well-defined +            # time units (seconds, days and microseconds) +            dtm = self.__radd__(dt2) + +            # If we've overshot our target, make an adjustment +            if dt1 < dt2: +                compare = operator.gt +                increment = 1 +            else: +                compare = operator.lt +                increment = -1 + +            while compare(dt1, dtm): +                months += increment +                self._set_months(months) +                dtm = self.__radd__(dt2) + +            # Get the timedelta between the "months-adjusted" date and dt1 +            delta = dt1 - dtm +            self.seconds = delta.seconds + delta.days * 86400 +            self.microseconds = delta.microseconds +        else: +            # Relative information +            self.years = years +            self.months = months +            self.days = days + weeks * 7 +            self.leapdays = leapdays +            self.hours = hours +            self.minutes = minutes +            self.seconds = seconds +            self.microseconds = microseconds + +            # Absolute information +            self.year = year +            self.month = month +            self.day = day +            self.hour = hour +            self.minute = minute +            self.second = second +            self.microsecond = microsecond + +            if any(x is not None and int(x) != x +                   for x in (year, month, day, hour, +                             minute, second, microsecond)): +                # For now we'll deprecate floats - later it'll be an error. +                warn("Non-integer value passed as absolute information. " + +                     "This is not a well-defined condition and will raise " + +                     "errors in future versions.", DeprecationWarning) + +            if isinstance(weekday, integer_types): +                self.weekday = weekdays[weekday] +            else: +                self.weekday = weekday + +            yday = 0 +            if nlyearday: +                yday = nlyearday +            elif yearday: +                yday = yearday +                if yearday > 59: +                    self.leapdays = -1 +            if yday: +                ydayidx = [31, 59, 90, 120, 151, 181, 212, +                           243, 273, 304, 334, 366] +                for idx, ydays in enumerate(ydayidx): +                    if yday <= ydays: +                        self.month = idx+1 +                        if idx == 0: +                            self.day = yday +                        else: +                            self.day = yday-ydayidx[idx-1] +                        break +                else: +                    raise ValueError("invalid year day (%d)" % yday) + +        self._fix() + +    def _fix(self): +        if abs(self.microseconds) > 999999: +            s = _sign(self.microseconds) +            div, mod = divmod(self.microseconds * s, 1000000) +            self.microseconds = mod * s +            self.seconds += div * s +        if abs(self.seconds) > 59: +            s = _sign(self.seconds) +            div, mod = divmod(self.seconds * s, 60) +            self.seconds = mod * s +            self.minutes += div * s +        if abs(self.minutes) > 59: +            s = _sign(self.minutes) +            div, mod = divmod(self.minutes * s, 60) +            self.minutes = mod * s +            self.hours += div * s +        if abs(self.hours) > 23: +            s = _sign(self.hours) +            div, mod = divmod(self.hours * s, 24) +            self.hours = mod * s +            self.days += div * s +        if abs(self.months) > 11: +            s = _sign(self.months) +            div, mod = divmod(self.months * s, 12) +            self.months = mod * s +            self.years += div * s +        if (self.hours or self.minutes or self.seconds or self.microseconds +                or self.hour is not None or self.minute is not None or +                self.second is not None or self.microsecond is not None): +            self._has_time = 1 +        else: +            self._has_time = 0 + +    @property +    def weeks(self): +        return self.days // 7 + +    @weeks.setter +    def weeks(self, value): +        self.days = self.days - (self.weeks * 7) + value * 7 + +    def _set_months(self, months): +        self.months = months +        if abs(self.months) > 11: +            s = _sign(self.months) +            div, mod = divmod(self.months * s, 12) +            self.months = mod * s +            self.years = div * s +        else: +            self.years = 0 + +    def normalized(self): +        """ +        Return a version of this object represented entirely using integer +        values for the relative attributes. + +        >>> relativedelta(days=1.5, hours=2).normalized() +        relativedelta(days=1, hours=14) + +        :return: +            Returns a :class:`dateutil.relativedelta.relativedelta` object. +        """ +        # Cascade remainders down (rounding each to roughly nearest microsecond) +        days = int(self.days) + +        hours_f = round(self.hours + 24 * (self.days - days), 11) +        hours = int(hours_f) + +        minutes_f = round(self.minutes + 60 * (hours_f - hours), 10) +        minutes = int(minutes_f) + +        seconds_f = round(self.seconds + 60 * (minutes_f - minutes), 8) +        seconds = int(seconds_f) + +        microseconds = round(self.microseconds + 1e6 * (seconds_f - seconds)) + +        # Constructor carries overflow back up with call to _fix() +        return self.__class__(years=self.years, months=self.months, +                              days=days, hours=hours, minutes=minutes, +                              seconds=seconds, microseconds=microseconds, +                              leapdays=self.leapdays, year=self.year, +                              month=self.month, day=self.day, +                              weekday=self.weekday, hour=self.hour, +                              minute=self.minute, second=self.second, +                              microsecond=self.microsecond) + +    def __add__(self, other): +        if isinstance(other, relativedelta): +            return self.__class__(years=other.years + self.years, +                                 months=other.months + self.months, +                                 days=other.days + self.days, +                                 hours=other.hours + self.hours, +                                 minutes=other.minutes + self.minutes, +                                 seconds=other.seconds + self.seconds, +                                 microseconds=(other.microseconds + +                                               self.microseconds), +                                 leapdays=other.leapdays or self.leapdays, +                                 year=(other.year if other.year is not None +                                       else self.year), +                                 month=(other.month if other.month is not None +                                        else self.month), +                                 day=(other.day if other.day is not None +                                      else self.day), +                                 weekday=(other.weekday if other.weekday is not None +                                          else self.weekday), +                                 hour=(other.hour if other.hour is not None +                                       else self.hour), +                                 minute=(other.minute if other.minute is not None +                                         else self.minute), +                                 second=(other.second if other.second is not None +                                         else self.second), +                                 microsecond=(other.microsecond if other.microsecond +                                              is not None else +                                              self.microsecond)) +        if isinstance(other, datetime.timedelta): +            return self.__class__(years=self.years, +                                  months=self.months, +                                  days=self.days + other.days, +                                  hours=self.hours, +                                  minutes=self.minutes, +                                  seconds=self.seconds + other.seconds, +                                  microseconds=self.microseconds + other.microseconds, +                                  leapdays=self.leapdays, +                                  year=self.year, +                                  month=self.month, +                                  day=self.day, +                                  weekday=self.weekday, +                                  hour=self.hour, +                                  minute=self.minute, +                                  second=self.second, +                                  microsecond=self.microsecond) +        if not isinstance(other, datetime.date): +            return NotImplemented +        elif self._has_time and not isinstance(other, datetime.datetime): +            other = datetime.datetime.fromordinal(other.toordinal()) +        year = (self.year or other.year)+self.years +        month = self.month or other.month +        if self.months: +            assert 1 <= abs(self.months) <= 12 +            month += self.months +            if month > 12: +                year += 1 +                month -= 12 +            elif month < 1: +                year -= 1 +                month += 12 +        day = min(calendar.monthrange(year, month)[1], +                  self.day or other.day) +        repl = {"year": year, "month": month, "day": day} +        for attr in ["hour", "minute", "second", "microsecond"]: +            value = getattr(self, attr) +            if value is not None: +                repl[attr] = value +        days = self.days +        if self.leapdays and month > 2 and calendar.isleap(year): +            days += self.leapdays +        ret = (other.replace(**repl) +               + datetime.timedelta(days=days, +                                    hours=self.hours, +                                    minutes=self.minutes, +                                    seconds=self.seconds, +                                    microseconds=self.microseconds)) +        if self.weekday: +            weekday, nth = self.weekday.weekday, self.weekday.n or 1 +            jumpdays = (abs(nth) - 1) * 7 +            if nth > 0: +                jumpdays += (7 - ret.weekday() + weekday) % 7 +            else: +                jumpdays += (ret.weekday() - weekday) % 7 +                jumpdays *= -1 +            ret += datetime.timedelta(days=jumpdays) +        return ret + +    def __radd__(self, other): +        return self.__add__(other) + +    def __rsub__(self, other): +        return self.__neg__().__radd__(other) + +    def __sub__(self, other): +        if not isinstance(other, relativedelta): +            return NotImplemented   # In case the other object defines __rsub__ +        return self.__class__(years=self.years - other.years, +                             months=self.months - other.months, +                             days=self.days - other.days, +                             hours=self.hours - other.hours, +                             minutes=self.minutes - other.minutes, +                             seconds=self.seconds - other.seconds, +                             microseconds=self.microseconds - other.microseconds, +                             leapdays=self.leapdays or other.leapdays, +                             year=(self.year if self.year is not None +                                   else other.year), +                             month=(self.month if self.month is not None else +                                    other.month), +                             day=(self.day if self.day is not None else +                                  other.day), +                             weekday=(self.weekday if self.weekday is not None else +                                      other.weekday), +                             hour=(self.hour if self.hour is not None else +                                   other.hour), +                             minute=(self.minute if self.minute is not None else +                                     other.minute), +                             second=(self.second if self.second is not None else +                                     other.second), +                             microsecond=(self.microsecond if self.microsecond +                                          is not None else +                                          other.microsecond)) + +    def __neg__(self): +        return self.__class__(years=-self.years, +                             months=-self.months, +                             days=-self.days, +                             hours=-self.hours, +                             minutes=-self.minutes, +                             seconds=-self.seconds, +                             microseconds=-self.microseconds, +                             leapdays=self.leapdays, +                             year=self.year, +                             month=self.month, +                             day=self.day, +                             weekday=self.weekday, +                             hour=self.hour, +                             minute=self.minute, +                             second=self.second, +                             microsecond=self.microsecond) + +    def __bool__(self): +        return not (not self.years and +                    not self.months and +                    not self.days and +                    not self.hours and +                    not self.minutes and +                    not self.seconds and +                    not self.microseconds and +                    not self.leapdays and +                    self.year is None and +                    self.month is None and +                    self.day is None and +                    self.weekday is None and +                    self.hour is None and +                    self.minute is None and +                    self.second is None and +                    self.microsecond is None) +    # Compatibility with Python 2.x +    __nonzero__ = __bool__ + +    def __mul__(self, other): +        try: +            f = float(other) +        except TypeError: +            return NotImplemented + +        return self.__class__(years=int(self.years * f), +                             months=int(self.months * f), +                             days=int(self.days * f), +                             hours=int(self.hours * f), +                             minutes=int(self.minutes * f), +                             seconds=int(self.seconds * f), +                             microseconds=int(self.microseconds * f), +                             leapdays=self.leapdays, +                             year=self.year, +                             month=self.month, +                             day=self.day, +                             weekday=self.weekday, +                             hour=self.hour, +                             minute=self.minute, +                             second=self.second, +                             microsecond=self.microsecond) + +    __rmul__ = __mul__ + +    def __eq__(self, other): +        if not isinstance(other, relativedelta): +            return NotImplemented +        if self.weekday or other.weekday: +            if not self.weekday or not other.weekday: +                return False +            if self.weekday.weekday != other.weekday.weekday: +                return False +            n1, n2 = self.weekday.n, other.weekday.n +            if n1 != n2 and not ((not n1 or n1 == 1) and (not n2 or n2 == 1)): +                return False +        return (self.years == other.years and +                self.months == other.months and +                self.days == other.days and +                self.hours == other.hours and +                self.minutes == other.minutes and +                self.seconds == other.seconds and +                self.microseconds == other.microseconds and +                self.leapdays == other.leapdays and +                self.year == other.year and +                self.month == other.month and +                self.day == other.day and +                self.hour == other.hour and +                self.minute == other.minute and +                self.second == other.second and +                self.microsecond == other.microsecond) + +    __hash__ = None + +    def __ne__(self, other): +        return not self.__eq__(other) + +    def __div__(self, other): +        try: +            reciprocal = 1 / float(other) +        except TypeError: +            return NotImplemented + +        return self.__mul__(reciprocal) + +    __truediv__ = __div__ + +    def __repr__(self): +        l = [] +        for attr in ["years", "months", "days", "leapdays", +                     "hours", "minutes", "seconds", "microseconds"]: +            value = getattr(self, attr) +            if value: +                l.append("{attr}={value:+g}".format(attr=attr, value=value)) +        for attr in ["year", "month", "day", "weekday", +                     "hour", "minute", "second", "microsecond"]: +            value = getattr(self, attr) +            if value is not None: +                l.append("{attr}={value}".format(attr=attr, value=repr(value))) +        return "{classname}({attrs})".format(classname=self.__class__.__name__, +                                             attrs=", ".join(l)) + + +def _sign(x): +    return int(copysign(1, x)) + +# vim:ts=4:sw=4:et diff --git a/python/dateutil/rrule.py b/python/dateutil/rrule.py new file mode 100644 index 0000000..429f8fc --- /dev/null +++ b/python/dateutil/rrule.py @@ -0,0 +1,1610 @@ +# -*- coding: utf-8 -*- +""" +The rrule module offers a small, complete, and very fast, implementation of +the recurrence rules documented in the +`iCalendar RFC <http://www.ietf.org/rfc/rfc2445.txt>`_, +including support for caching of results. +""" +import itertools +import datetime +import calendar +import sys + +try: +    from math import gcd +except ImportError: +    from fractions import gcd + +from six import advance_iterator, integer_types +from six.moves import _thread, range +import heapq + +from ._common import weekday as weekdaybase + +# For warning about deprecation of until and count +from warnings import warn + +__all__ = ["rrule", "rruleset", "rrulestr", +           "YEARLY", "MONTHLY", "WEEKLY", "DAILY", +           "HOURLY", "MINUTELY", "SECONDLY", +           "MO", "TU", "WE", "TH", "FR", "SA", "SU"] + +# Every mask is 7 days longer to handle cross-year weekly periods. +M366MASK = tuple([1]*31+[2]*29+[3]*31+[4]*30+[5]*31+[6]*30 + +                 [7]*31+[8]*31+[9]*30+[10]*31+[11]*30+[12]*31+[1]*7) +M365MASK = list(M366MASK) +M29, M30, M31 = list(range(1, 30)), list(range(1, 31)), list(range(1, 32)) +MDAY366MASK = tuple(M31+M29+M31+M30+M31+M30+M31+M31+M30+M31+M30+M31+M31[:7]) +MDAY365MASK = list(MDAY366MASK) +M29, M30, M31 = list(range(-29, 0)), list(range(-30, 0)), list(range(-31, 0)) +NMDAY366MASK = tuple(M31+M29+M31+M30+M31+M30+M31+M31+M30+M31+M30+M31+M31[:7]) +NMDAY365MASK = list(NMDAY366MASK) +M366RANGE = (0, 31, 60, 91, 121, 152, 182, 213, 244, 274, 305, 335, 366) +M365RANGE = (0, 31, 59, 90, 120, 151, 181, 212, 243, 273, 304, 334, 365) +WDAYMASK = [0, 1, 2, 3, 4, 5, 6]*55 +del M29, M30, M31, M365MASK[59], MDAY365MASK[59], NMDAY365MASK[31] +MDAY365MASK = tuple(MDAY365MASK) +M365MASK = tuple(M365MASK) + +FREQNAMES = ['YEARLY', 'MONTHLY', 'WEEKLY', 'DAILY', 'HOURLY', 'MINUTELY', 'SECONDLY'] + +(YEARLY, + MONTHLY, + WEEKLY, + DAILY, + HOURLY, + MINUTELY, + SECONDLY) = list(range(7)) + +# Imported on demand. +easter = None +parser = None + + +class weekday(weekdaybase): +    """ +    This version of weekday does not allow n = 0. +    """ +    def __init__(self, wkday, n=None): +        if n == 0: +            raise ValueError("Can't create weekday with n==0") + +        super(weekday, self).__init__(wkday, n) + + +MO, TU, WE, TH, FR, SA, SU = weekdays = tuple(weekday(x) for x in range(7)) + + +def _invalidates_cache(f): +    """ +    Decorator for rruleset methods which may invalidate the +    cached length. +    """ +    def inner_func(self, *args, **kwargs): +        rv = f(self, *args, **kwargs) +        self._invalidate_cache() +        return rv + +    return inner_func + + +class rrulebase(object): +    def __init__(self, cache=False): +        if cache: +            self._cache = [] +            self._cache_lock = _thread.allocate_lock() +            self._invalidate_cache() +        else: +            self._cache = None +            self._cache_complete = False +            self._len = None + +    def __iter__(self): +        if self._cache_complete: +            return iter(self._cache) +        elif self._cache is None: +            return self._iter() +        else: +            return self._iter_cached() + +    def _invalidate_cache(self): +        if self._cache is not None: +            self._cache = [] +            self._cache_complete = False +            self._cache_gen = self._iter() + +            if self._cache_lock.locked(): +                self._cache_lock.release() + +        self._len = None + +    def _iter_cached(self): +        i = 0 +        gen = self._cache_gen +        cache = self._cache +        acquire = self._cache_lock.acquire +        release = self._cache_lock.release +        while gen: +            if i == len(cache): +                acquire() +                if self._cache_complete: +                    break +                try: +                    for j in range(10): +                        cache.append(advance_iterator(gen)) +                except StopIteration: +                    self._cache_gen = gen = None +                    self._cache_complete = True +                    break +                release() +            yield cache[i] +            i += 1 +        while i < self._len: +            yield cache[i] +            i += 1 + +    def __getitem__(self, item): +        if self._cache_complete: +            return self._cache[item] +        elif isinstance(item, slice): +            if item.step and item.step < 0: +                return list(iter(self))[item] +            else: +                return list(itertools.islice(self, +                                             item.start or 0, +                                             item.stop or sys.maxsize, +                                             item.step or 1)) +        elif item >= 0: +            gen = iter(self) +            try: +                for i in range(item+1): +                    res = advance_iterator(gen) +            except StopIteration: +                raise IndexError +            return res +        else: +            return list(iter(self))[item] + +    def __contains__(self, item): +        if self._cache_complete: +            return item in self._cache +        else: +            for i in self: +                if i == item: +                    return True +                elif i > item: +                    return False +        return False + +    # __len__() introduces a large performance penality. +    def count(self): +        """ Returns the number of recurrences in this set. It will have go +            trough the whole recurrence, if this hasn't been done before. """ +        if self._len is None: +            for x in self: +                pass +        return self._len + +    def before(self, dt, inc=False): +        """ Returns the last recurrence before the given datetime instance. The +            inc keyword defines what happens if dt is an occurrence. With +            inc=True, if dt itself is an occurrence, it will be returned. """ +        if self._cache_complete: +            gen = self._cache +        else: +            gen = self +        last = None +        if inc: +            for i in gen: +                if i > dt: +                    break +                last = i +        else: +            for i in gen: +                if i >= dt: +                    break +                last = i +        return last + +    def after(self, dt, inc=False): +        """ Returns the first recurrence after the given datetime instance. The +            inc keyword defines what happens if dt is an occurrence. With +            inc=True, if dt itself is an occurrence, it will be returned.  """ +        if self._cache_complete: +            gen = self._cache +        else: +            gen = self +        if inc: +            for i in gen: +                if i >= dt: +                    return i +        else: +            for i in gen: +                if i > dt: +                    return i +        return None + +    def xafter(self, dt, count=None, inc=False): +        """ +        Generator which yields up to `count` recurrences after the given +        datetime instance, equivalent to `after`. + +        :param dt: +            The datetime at which to start generating recurrences. + +        :param count: +            The maximum number of recurrences to generate. If `None` (default), +            dates are generated until the recurrence rule is exhausted. + +        :param inc: +            If `dt` is an instance of the rule and `inc` is `True`, it is +            included in the output. + +        :yields: Yields a sequence of `datetime` objects. +        """ + +        if self._cache_complete: +            gen = self._cache +        else: +            gen = self + +        # Select the comparison function +        if inc: +            comp = lambda dc, dtc: dc >= dtc +        else: +            comp = lambda dc, dtc: dc > dtc + +        # Generate dates +        n = 0 +        for d in gen: +            if comp(d, dt): +                if count is not None: +                    n += 1 +                    if n > count: +                        break + +                yield d + +    def between(self, after, before, inc=False, count=1): +        """ Returns all the occurrences of the rrule between after and before. +        The inc keyword defines what happens if after and/or before are +        themselves occurrences. With inc=True, they will be included in the +        list, if they are found in the recurrence set. """ +        if self._cache_complete: +            gen = self._cache +        else: +            gen = self +        started = False +        l = [] +        if inc: +            for i in gen: +                if i > before: +                    break +                elif not started: +                    if i >= after: +                        started = True +                        l.append(i) +                else: +                    l.append(i) +        else: +            for i in gen: +                if i >= before: +                    break +                elif not started: +                    if i > after: +                        started = True +                        l.append(i) +                else: +                    l.append(i) +        return l + + +class rrule(rrulebase): +    """ +    That's the base of the rrule operation. It accepts all the keywords +    defined in the RFC as its constructor parameters (except byday, +    which was renamed to byweekday) and more. The constructor prototype is:: + +            rrule(freq) + +    Where freq must be one of YEARLY, MONTHLY, WEEKLY, DAILY, HOURLY, MINUTELY, +    or SECONDLY. + +    .. note:: +        Per RFC section 3.3.10, recurrence instances falling on invalid dates +        and times are ignored rather than coerced: + +            Recurrence rules may generate recurrence instances with an invalid +            date (e.g., February 30) or nonexistent local time (e.g., 1:30 AM +            on a day where the local time is moved forward by an hour at 1:00 +            AM).  Such recurrence instances MUST be ignored and MUST NOT be +            counted as part of the recurrence set. + +        This can lead to possibly surprising behavior when, for example, the +        start date occurs at the end of the month: + +        >>> from dateutil.rrule import rrule, MONTHLY +        >>> from datetime import datetime +        >>> start_date = datetime(2014, 12, 31) +        >>> list(rrule(freq=MONTHLY, count=4, dtstart=start_date)) +        ... # doctest: +NORMALIZE_WHITESPACE +        [datetime.datetime(2014, 12, 31, 0, 0), +         datetime.datetime(2015, 1, 31, 0, 0), +         datetime.datetime(2015, 3, 31, 0, 0), +         datetime.datetime(2015, 5, 31, 0, 0)] + +    Additionally, it supports the following keyword arguments: + +    :param cache: +        If given, it must be a boolean value specifying to enable or disable +        caching of results. If you will use the same rrule instance multiple +        times, enabling caching will improve the performance considerably. +    :param dtstart: +        The recurrence start. Besides being the base for the recurrence, +        missing parameters in the final recurrence instances will also be +        extracted from this date. If not given, datetime.now() will be used +        instead. +    :param interval: +        The interval between each freq iteration. For example, when using +        YEARLY, an interval of 2 means once every two years, but with HOURLY, +        it means once every two hours. The default interval is 1. +    :param wkst: +        The week start day. Must be one of the MO, TU, WE constants, or an +        integer, specifying the first day of the week. This will affect +        recurrences based on weekly periods. The default week start is got +        from calendar.firstweekday(), and may be modified by +        calendar.setfirstweekday(). +    :param count: +        How many occurrences will be generated. + +        .. note:: +            As of version 2.5.0, the use of the ``until`` keyword together +            with the ``count`` keyword is deprecated per RFC-2445 Sec. 4.3.10. +    :param until: +        If given, this must be a datetime instance, that will specify the +        limit of the recurrence. The last recurrence in the rule is the greatest +        datetime that is less than or equal to the value specified in the +        ``until`` parameter. + +        .. note:: +            As of version 2.5.0, the use of the ``until`` keyword together +            with the ``count`` keyword is deprecated per RFC-2445 Sec. 4.3.10. +    :param bysetpos: +        If given, it must be either an integer, or a sequence of integers, +        positive or negative. Each given integer will specify an occurrence +        number, corresponding to the nth occurrence of the rule inside the +        frequency period. For example, a bysetpos of -1 if combined with a +        MONTHLY frequency, and a byweekday of (MO, TU, WE, TH, FR), will +        result in the last work day of every month. +    :param bymonth: +        If given, it must be either an integer, or a sequence of integers, +        meaning the months to apply the recurrence to. +    :param bymonthday: +        If given, it must be either an integer, or a sequence of integers, +        meaning the month days to apply the recurrence to. +    :param byyearday: +        If given, it must be either an integer, or a sequence of integers, +        meaning the year days to apply the recurrence to. +    :param byweekno: +        If given, it must be either an integer, or a sequence of integers, +        meaning the week numbers to apply the recurrence to. Week numbers +        have the meaning described in ISO8601, that is, the first week of +        the year is that containing at least four days of the new year. +    :param byweekday: +        If given, it must be either an integer (0 == MO), a sequence of +        integers, one of the weekday constants (MO, TU, etc), or a sequence +        of these constants. When given, these variables will define the +        weekdays where the recurrence will be applied. It's also possible to +        use an argument n for the weekday instances, which will mean the nth +        occurrence of this weekday in the period. For example, with MONTHLY, +        or with YEARLY and BYMONTH, using FR(+1) in byweekday will specify the +        first friday of the month where the recurrence happens. Notice that in +        the RFC documentation, this is specified as BYDAY, but was renamed to +        avoid the ambiguity of that keyword. +    :param byhour: +        If given, it must be either an integer, or a sequence of integers, +        meaning the hours to apply the recurrence to. +    :param byminute: +        If given, it must be either an integer, or a sequence of integers, +        meaning the minutes to apply the recurrence to. +    :param bysecond: +        If given, it must be either an integer, or a sequence of integers, +        meaning the seconds to apply the recurrence to. +    :param byeaster: +        If given, it must be either an integer, or a sequence of integers, +        positive or negative. Each integer will define an offset from the +        Easter Sunday. Passing the offset 0 to byeaster will yield the Easter +        Sunday itself. This is an extension to the RFC specification. +     """ +    def __init__(self, freq, dtstart=None, +                 interval=1, wkst=None, count=None, until=None, bysetpos=None, +                 bymonth=None, bymonthday=None, byyearday=None, byeaster=None, +                 byweekno=None, byweekday=None, +                 byhour=None, byminute=None, bysecond=None, +                 cache=False): +        super(rrule, self).__init__(cache) +        global easter +        if not dtstart: +            dtstart = datetime.datetime.now().replace(microsecond=0) +        elif not isinstance(dtstart, datetime.datetime): +            dtstart = datetime.datetime.fromordinal(dtstart.toordinal()) +        else: +            dtstart = dtstart.replace(microsecond=0) +        self._dtstart = dtstart +        self._tzinfo = dtstart.tzinfo +        self._freq = freq +        self._interval = interval +        self._count = count + +        # Cache the original byxxx rules, if they are provided, as the _byxxx +        # attributes do not necessarily map to the inputs, and this can be +        # a problem in generating the strings. Only store things if they've +        # been supplied (the string retrieval will just use .get()) +        self._original_rule = {} + +        if until and not isinstance(until, datetime.datetime): +            until = datetime.datetime.fromordinal(until.toordinal()) +        self._until = until + +        if count is not None and until: +            warn("Using both 'count' and 'until' is inconsistent with RFC 2445" +                 " and has been deprecated in dateutil. Future versions will " +                 "raise an error.", DeprecationWarning) + +        if wkst is None: +            self._wkst = calendar.firstweekday() +        elif isinstance(wkst, integer_types): +            self._wkst = wkst +        else: +            self._wkst = wkst.weekday + +        if bysetpos is None: +            self._bysetpos = None +        elif isinstance(bysetpos, integer_types): +            if bysetpos == 0 or not (-366 <= bysetpos <= 366): +                raise ValueError("bysetpos must be between 1 and 366, " +                                 "or between -366 and -1") +            self._bysetpos = (bysetpos,) +        else: +            self._bysetpos = tuple(bysetpos) +            for pos in self._bysetpos: +                if pos == 0 or not (-366 <= pos <= 366): +                    raise ValueError("bysetpos must be between 1 and 366, " +                                     "or between -366 and -1") + +        if self._bysetpos: +            self._original_rule['bysetpos'] = self._bysetpos + +        if (byweekno is None and byyearday is None and bymonthday is None and +                byweekday is None and byeaster is None): +            if freq == YEARLY: +                if bymonth is None: +                    bymonth = dtstart.month +                    self._original_rule['bymonth'] = None +                bymonthday = dtstart.day +                self._original_rule['bymonthday'] = None +            elif freq == MONTHLY: +                bymonthday = dtstart.day +                self._original_rule['bymonthday'] = None +            elif freq == WEEKLY: +                byweekday = dtstart.weekday() +                self._original_rule['byweekday'] = None + +        # bymonth +        if bymonth is None: +            self._bymonth = None +        else: +            if isinstance(bymonth, integer_types): +                bymonth = (bymonth,) + +            self._bymonth = tuple(sorted(set(bymonth))) + +            if 'bymonth' not in self._original_rule: +                self._original_rule['bymonth'] = self._bymonth + +        # byyearday +        if byyearday is None: +            self._byyearday = None +        else: +            if isinstance(byyearday, integer_types): +                byyearday = (byyearday,) + +            self._byyearday = tuple(sorted(set(byyearday))) +            self._original_rule['byyearday'] = self._byyearday + +        # byeaster +        if byeaster is not None: +            if not easter: +                from dateutil import easter +            if isinstance(byeaster, integer_types): +                self._byeaster = (byeaster,) +            else: +                self._byeaster = tuple(sorted(byeaster)) + +            self._original_rule['byeaster'] = self._byeaster +        else: +            self._byeaster = None + +        # bymonthday +        if bymonthday is None: +            self._bymonthday = () +            self._bynmonthday = () +        else: +            if isinstance(bymonthday, integer_types): +                bymonthday = (bymonthday,) + +            bymonthday = set(bymonthday)            # Ensure it's unique + +            self._bymonthday = tuple(sorted(x for x in bymonthday if x > 0)) +            self._bynmonthday = tuple(sorted(x for x in bymonthday if x < 0)) + +            # Storing positive numbers first, then negative numbers +            if 'bymonthday' not in self._original_rule: +                self._original_rule['bymonthday'] = tuple( +                    itertools.chain(self._bymonthday, self._bynmonthday)) + +        # byweekno +        if byweekno is None: +            self._byweekno = None +        else: +            if isinstance(byweekno, integer_types): +                byweekno = (byweekno,) + +            self._byweekno = tuple(sorted(set(byweekno))) + +            self._original_rule['byweekno'] = self._byweekno + +        # byweekday / bynweekday +        if byweekday is None: +            self._byweekday = None +            self._bynweekday = None +        else: +            # If it's one of the valid non-sequence types, convert to a +            # single-element sequence before the iterator that builds the +            # byweekday set. +            if isinstance(byweekday, integer_types) or hasattr(byweekday, "n"): +                byweekday = (byweekday,) + +            self._byweekday = set() +            self._bynweekday = set() +            for wday in byweekday: +                if isinstance(wday, integer_types): +                    self._byweekday.add(wday) +                elif not wday.n or freq > MONTHLY: +                    self._byweekday.add(wday.weekday) +                else: +                    self._bynweekday.add((wday.weekday, wday.n)) + +            if not self._byweekday: +                self._byweekday = None +            elif not self._bynweekday: +                self._bynweekday = None + +            if self._byweekday is not None: +                self._byweekday = tuple(sorted(self._byweekday)) +                orig_byweekday = [weekday(x) for x in self._byweekday] +            else: +                orig_byweekday = tuple() + +            if self._bynweekday is not None: +                self._bynweekday = tuple(sorted(self._bynweekday)) +                orig_bynweekday = [weekday(*x) for x in self._bynweekday] +            else: +                orig_bynweekday = tuple() + +            if 'byweekday' not in self._original_rule: +                self._original_rule['byweekday'] = tuple(itertools.chain( +                    orig_byweekday, orig_bynweekday)) + +        # byhour +        if byhour is None: +            if freq < HOURLY: +                self._byhour = set((dtstart.hour,)) +            else: +                self._byhour = None +        else: +            if isinstance(byhour, integer_types): +                byhour = (byhour,) + +            if freq == HOURLY: +                self._byhour = self.__construct_byset(start=dtstart.hour, +                                                      byxxx=byhour, +                                                      base=24) +            else: +                self._byhour = set(byhour) + +            self._byhour = tuple(sorted(self._byhour)) +            self._original_rule['byhour'] = self._byhour + +        # byminute +        if byminute is None: +            if freq < MINUTELY: +                self._byminute = set((dtstart.minute,)) +            else: +                self._byminute = None +        else: +            if isinstance(byminute, integer_types): +                byminute = (byminute,) + +            if freq == MINUTELY: +                self._byminute = self.__construct_byset(start=dtstart.minute, +                                                        byxxx=byminute, +                                                        base=60) +            else: +                self._byminute = set(byminute) + +            self._byminute = tuple(sorted(self._byminute)) +            self._original_rule['byminute'] = self._byminute + +        # bysecond +        if bysecond is None: +            if freq < SECONDLY: +                self._bysecond = ((dtstart.second,)) +            else: +                self._bysecond = None +        else: +            if isinstance(bysecond, integer_types): +                bysecond = (bysecond,) + +            self._bysecond = set(bysecond) + +            if freq == SECONDLY: +                self._bysecond = self.__construct_byset(start=dtstart.second, +                                                        byxxx=bysecond, +                                                        base=60) +            else: +                self._bysecond = set(bysecond) + +            self._bysecond = tuple(sorted(self._bysecond)) +            self._original_rule['bysecond'] = self._bysecond + +        if self._freq >= HOURLY: +            self._timeset = None +        else: +            self._timeset = [] +            for hour in self._byhour: +                for minute in self._byminute: +                    for second in self._bysecond: +                        self._timeset.append( +                            datetime.time(hour, minute, second, +                                          tzinfo=self._tzinfo)) +            self._timeset.sort() +            self._timeset = tuple(self._timeset) + +    def __str__(self): +        """ +        Output a string that would generate this RRULE if passed to rrulestr. +        This is mostly compatible with RFC2445, except for the +        dateutil-specific extension BYEASTER. +        """ + +        output = [] +        h, m, s = [None] * 3 +        if self._dtstart: +            output.append(self._dtstart.strftime('DTSTART:%Y%m%dT%H%M%S')) +            h, m, s = self._dtstart.timetuple()[3:6] + +        parts = ['FREQ=' + FREQNAMES[self._freq]] +        if self._interval != 1: +            parts.append('INTERVAL=' + str(self._interval)) + +        if self._wkst: +            parts.append('WKST=' + repr(weekday(self._wkst))[0:2]) + +        if self._count is not None: +            parts.append('COUNT=' + str(self._count)) + +        if self._until: +            parts.append(self._until.strftime('UNTIL=%Y%m%dT%H%M%S')) + +        if self._original_rule.get('byweekday') is not None: +            # The str() method on weekday objects doesn't generate +            # RFC2445-compliant strings, so we should modify that. +            original_rule = dict(self._original_rule) +            wday_strings = [] +            for wday in original_rule['byweekday']: +                if wday.n: +                    wday_strings.append('{n:+d}{wday}'.format( +                        n=wday.n, +                        wday=repr(wday)[0:2])) +                else: +                    wday_strings.append(repr(wday)) + +            original_rule['byweekday'] = wday_strings +        else: +            original_rule = self._original_rule + +        partfmt = '{name}={vals}' +        for name, key in [('BYSETPOS', 'bysetpos'), +                          ('BYMONTH', 'bymonth'), +                          ('BYMONTHDAY', 'bymonthday'), +                          ('BYYEARDAY', 'byyearday'), +                          ('BYWEEKNO', 'byweekno'), +                          ('BYDAY', 'byweekday'), +                          ('BYHOUR', 'byhour'), +                          ('BYMINUTE', 'byminute'), +                          ('BYSECOND', 'bysecond'), +                          ('BYEASTER', 'byeaster')]: +            value = original_rule.get(key) +            if value: +                parts.append(partfmt.format(name=name, vals=(','.join(str(v) +                                                             for v in value)))) + +        output.append(';'.join(parts)) +        return '\n'.join(output) + +    def replace(self, **kwargs): +        """Return new rrule with same attributes except for those attributes given new +           values by whichever keyword arguments are specified.""" +        new_kwargs = {"interval": self._interval, +                      "count": self._count, +                      "dtstart": self._dtstart, +                      "freq": self._freq, +                      "until": self._until, +                      "wkst": self._wkst, +                      "cache": False if self._cache is None else True } +        new_kwargs.update(self._original_rule) +        new_kwargs.update(kwargs) +        return rrule(**new_kwargs) + +    def _iter(self): +        year, month, day, hour, minute, second, weekday, yearday, _ = \ +            self._dtstart.timetuple() + +        # Some local variables to speed things up a bit +        freq = self._freq +        interval = self._interval +        wkst = self._wkst +        until = self._until +        bymonth = self._bymonth +        byweekno = self._byweekno +        byyearday = self._byyearday +        byweekday = self._byweekday +        byeaster = self._byeaster +        bymonthday = self._bymonthday +        bynmonthday = self._bynmonthday +        bysetpos = self._bysetpos +        byhour = self._byhour +        byminute = self._byminute +        bysecond = self._bysecond + +        ii = _iterinfo(self) +        ii.rebuild(year, month) + +        getdayset = {YEARLY: ii.ydayset, +                     MONTHLY: ii.mdayset, +                     WEEKLY: ii.wdayset, +                     DAILY: ii.ddayset, +                     HOURLY: ii.ddayset, +                     MINUTELY: ii.ddayset, +                     SECONDLY: ii.ddayset}[freq] + +        if freq < HOURLY: +            timeset = self._timeset +        else: +            gettimeset = {HOURLY: ii.htimeset, +                          MINUTELY: ii.mtimeset, +                          SECONDLY: ii.stimeset}[freq] +            if ((freq >= HOURLY and +                 self._byhour and hour not in self._byhour) or +                (freq >= MINUTELY and +                 self._byminute and minute not in self._byminute) or +                (freq >= SECONDLY and +                 self._bysecond and second not in self._bysecond)): +                timeset = () +            else: +                timeset = gettimeset(hour, minute, second) + +        total = 0 +        count = self._count +        while True: +            # Get dayset with the right frequency +            dayset, start, end = getdayset(year, month, day) + +            # Do the "hard" work ;-) +            filtered = False +            for i in dayset[start:end]: +                if ((bymonth and ii.mmask[i] not in bymonth) or +                    (byweekno and not ii.wnomask[i]) or +                    (byweekday and ii.wdaymask[i] not in byweekday) or +                    (ii.nwdaymask and not ii.nwdaymask[i]) or +                    (byeaster and not ii.eastermask[i]) or +                    ((bymonthday or bynmonthday) and +                     ii.mdaymask[i] not in bymonthday and +                     ii.nmdaymask[i] not in bynmonthday) or +                    (byyearday and +                     ((i < ii.yearlen and i+1 not in byyearday and +                       -ii.yearlen+i not in byyearday) or +                      (i >= ii.yearlen and i+1-ii.yearlen not in byyearday and +                       -ii.nextyearlen+i-ii.yearlen not in byyearday)))): +                    dayset[i] = None +                    filtered = True + +            # Output results +            if bysetpos and timeset: +                poslist = [] +                for pos in bysetpos: +                    if pos < 0: +                        daypos, timepos = divmod(pos, len(timeset)) +                    else: +                        daypos, timepos = divmod(pos-1, len(timeset)) +                    try: +                        i = [x for x in dayset[start:end] +                             if x is not None][daypos] +                        time = timeset[timepos] +                    except IndexError: +                        pass +                    else: +                        date = datetime.date.fromordinal(ii.yearordinal+i) +                        res = datetime.datetime.combine(date, time) +                        if res not in poslist: +                            poslist.append(res) +                poslist.sort() +                for res in poslist: +                    if until and res > until: +                        self._len = total +                        return +                    elif res >= self._dtstart: +                        if count is not None: +                            count -= 1 +                            if count < 0: +                                self._len = total +                                return +                        total += 1 +                        yield res +            else: +                for i in dayset[start:end]: +                    if i is not None: +                        date = datetime.date.fromordinal(ii.yearordinal + i) +                        for time in timeset: +                            res = datetime.datetime.combine(date, time) +                            if until and res > until: +                                self._len = total +                                return +                            elif res >= self._dtstart: +                                if count is not None: +                                    count -= 1 +                                    if count < 0: +                                        self._len = total +                                        return + +                                total += 1 +                                yield res + +            # Handle frequency and interval +            fixday = False +            if freq == YEARLY: +                year += interval +                if year > datetime.MAXYEAR: +                    self._len = total +                    return +                ii.rebuild(year, month) +            elif freq == MONTHLY: +                month += interval +                if month > 12: +                    div, mod = divmod(month, 12) +                    month = mod +                    year += div +                    if month == 0: +                        month = 12 +                        year -= 1 +                    if year > datetime.MAXYEAR: +                        self._len = total +                        return +                ii.rebuild(year, month) +            elif freq == WEEKLY: +                if wkst > weekday: +                    day += -(weekday+1+(6-wkst))+self._interval*7 +                else: +                    day += -(weekday-wkst)+self._interval*7 +                weekday = wkst +                fixday = True +            elif freq == DAILY: +                day += interval +                fixday = True +            elif freq == HOURLY: +                if filtered: +                    # Jump to one iteration before next day +                    hour += ((23-hour)//interval)*interval + +                if byhour: +                    ndays, hour = self.__mod_distance(value=hour, +                                                      byxxx=self._byhour, +                                                      base=24) +                else: +                    ndays, hour = divmod(hour+interval, 24) + +                if ndays: +                    day += ndays +                    fixday = True + +                timeset = gettimeset(hour, minute, second) +            elif freq == MINUTELY: +                if filtered: +                    # Jump to one iteration before next day +                    minute += ((1439-(hour*60+minute))//interval)*interval + +                valid = False +                rep_rate = (24*60) +                for j in range(rep_rate // gcd(interval, rep_rate)): +                    if byminute: +                        nhours, minute = \ +                            self.__mod_distance(value=minute, +                                                byxxx=self._byminute, +                                                base=60) +                    else: +                        nhours, minute = divmod(minute+interval, 60) + +                    div, hour = divmod(hour+nhours, 24) +                    if div: +                        day += div +                        fixday = True +                        filtered = False + +                    if not byhour or hour in byhour: +                        valid = True +                        break + +                if not valid: +                    raise ValueError('Invalid combination of interval and ' + +                                     'byhour resulting in empty rule.') + +                timeset = gettimeset(hour, minute, second) +            elif freq == SECONDLY: +                if filtered: +                    # Jump to one iteration before next day +                    second += (((86399 - (hour * 3600 + minute * 60 + second)) +                                // interval) * interval) + +                rep_rate = (24 * 3600) +                valid = False +                for j in range(0, rep_rate // gcd(interval, rep_rate)): +                    if bysecond: +                        nminutes, second = \ +                            self.__mod_distance(value=second, +                                                byxxx=self._bysecond, +                                                base=60) +                    else: +                        nminutes, second = divmod(second+interval, 60) + +                    div, minute = divmod(minute+nminutes, 60) +                    if div: +                        hour += div +                        div, hour = divmod(hour, 24) +                        if div: +                            day += div +                            fixday = True + +                    if ((not byhour or hour in byhour) and +                            (not byminute or minute in byminute) and +                            (not bysecond or second in bysecond)): +                        valid = True +                        break + +                if not valid: +                    raise ValueError('Invalid combination of interval, ' + +                                     'byhour and byminute resulting in empty' + +                                     ' rule.') + +                timeset = gettimeset(hour, minute, second) + +            if fixday and day > 28: +                daysinmonth = calendar.monthrange(year, month)[1] +                if day > daysinmonth: +                    while day > daysinmonth: +                        day -= daysinmonth +                        month += 1 +                        if month == 13: +                            month = 1 +                            year += 1 +                            if year > datetime.MAXYEAR: +                                self._len = total +                                return +                        daysinmonth = calendar.monthrange(year, month)[1] +                    ii.rebuild(year, month) + +    def __construct_byset(self, start, byxxx, base): +        """ +        If a `BYXXX` sequence is passed to the constructor at the same level as +        `FREQ` (e.g. `FREQ=HOURLY,BYHOUR={2,4,7},INTERVAL=3`), there are some +        specifications which cannot be reached given some starting conditions. + +        This occurs whenever the interval is not coprime with the base of a +        given unit and the difference between the starting position and the +        ending position is not coprime with the greatest common denominator +        between the interval and the base. For example, with a FREQ of hourly +        starting at 17:00 and an interval of 4, the only valid values for +        BYHOUR would be {21, 1, 5, 9, 13, 17}, because 4 and 24 are not +        coprime. + +        :param start: +            Specifies the starting position. +        :param byxxx: +            An iterable containing the list of allowed values. +        :param base: +            The largest allowable value for the specified frequency (e.g. +            24 hours, 60 minutes). + +        This does not preserve the type of the iterable, returning a set, since +        the values should be unique and the order is irrelevant, this will +        speed up later lookups. + +        In the event of an empty set, raises a :exception:`ValueError`, as this +        results in an empty rrule. +        """ + +        cset = set() + +        # Support a single byxxx value. +        if isinstance(byxxx, integer_types): +            byxxx = (byxxx, ) + +        for num in byxxx: +            i_gcd = gcd(self._interval, base) +            # Use divmod rather than % because we need to wrap negative nums. +            if i_gcd == 1 or divmod(num - start, i_gcd)[1] == 0: +                cset.add(num) + +        if len(cset) == 0: +            raise ValueError("Invalid rrule byxxx generates an empty set.") + +        return cset + +    def __mod_distance(self, value, byxxx, base): +        """ +        Calculates the next value in a sequence where the `FREQ` parameter is +        specified along with a `BYXXX` parameter at the same "level" +        (e.g. `HOURLY` specified with `BYHOUR`). + +        :param value: +            The old value of the component. +        :param byxxx: +            The `BYXXX` set, which should have been generated by +            `rrule._construct_byset`, or something else which checks that a +            valid rule is present. +        :param base: +            The largest allowable value for the specified frequency (e.g. +            24 hours, 60 minutes). + +        If a valid value is not found after `base` iterations (the maximum +        number before the sequence would start to repeat), this raises a +        :exception:`ValueError`, as no valid values were found. + +        This returns a tuple of `divmod(n*interval, base)`, where `n` is the +        smallest number of `interval` repetitions until the next specified +        value in `byxxx` is found. +        """ +        accumulator = 0 +        for ii in range(1, base + 1): +            # Using divmod() over % to account for negative intervals +            div, value = divmod(value + self._interval, base) +            accumulator += div +            if value in byxxx: +                return (accumulator, value) + + +class _iterinfo(object): +    __slots__ = ["rrule", "lastyear", "lastmonth", +                 "yearlen", "nextyearlen", "yearordinal", "yearweekday", +                 "mmask", "mrange", "mdaymask", "nmdaymask", +                 "wdaymask", "wnomask", "nwdaymask", "eastermask"] + +    def __init__(self, rrule): +        for attr in self.__slots__: +            setattr(self, attr, None) +        self.rrule = rrule + +    def rebuild(self, year, month): +        # Every mask is 7 days longer to handle cross-year weekly periods. +        rr = self.rrule +        if year != self.lastyear: +            self.yearlen = 365 + calendar.isleap(year) +            self.nextyearlen = 365 + calendar.isleap(year + 1) +            firstyday = datetime.date(year, 1, 1) +            self.yearordinal = firstyday.toordinal() +            self.yearweekday = firstyday.weekday() + +            wday = datetime.date(year, 1, 1).weekday() +            if self.yearlen == 365: +                self.mmask = M365MASK +                self.mdaymask = MDAY365MASK +                self.nmdaymask = NMDAY365MASK +                self.wdaymask = WDAYMASK[wday:] +                self.mrange = M365RANGE +            else: +                self.mmask = M366MASK +                self.mdaymask = MDAY366MASK +                self.nmdaymask = NMDAY366MASK +                self.wdaymask = WDAYMASK[wday:] +                self.mrange = M366RANGE + +            if not rr._byweekno: +                self.wnomask = None +            else: +                self.wnomask = [0]*(self.yearlen+7) +                # no1wkst = firstwkst = self.wdaymask.index(rr._wkst) +                no1wkst = firstwkst = (7-self.yearweekday+rr._wkst) % 7 +                if no1wkst >= 4: +                    no1wkst = 0 +                    # Number of days in the year, plus the days we got +                    # from last year. +                    wyearlen = self.yearlen+(self.yearweekday-rr._wkst) % 7 +                else: +                    # Number of days in the year, minus the days we +                    # left in last year. +                    wyearlen = self.yearlen-no1wkst +                div, mod = divmod(wyearlen, 7) +                numweeks = div+mod//4 +                for n in rr._byweekno: +                    if n < 0: +                        n += numweeks+1 +                    if not (0 < n <= numweeks): +                        continue +                    if n > 1: +                        i = no1wkst+(n-1)*7 +                        if no1wkst != firstwkst: +                            i -= 7-firstwkst +                    else: +                        i = no1wkst +                    for j in range(7): +                        self.wnomask[i] = 1 +                        i += 1 +                        if self.wdaymask[i] == rr._wkst: +                            break +                if 1 in rr._byweekno: +                    # Check week number 1 of next year as well +                    # TODO: Check -numweeks for next year. +                    i = no1wkst+numweeks*7 +                    if no1wkst != firstwkst: +                        i -= 7-firstwkst +                    if i < self.yearlen: +                        # If week starts in next year, we +                        # don't care about it. +                        for j in range(7): +                            self.wnomask[i] = 1 +                            i += 1 +                            if self.wdaymask[i] == rr._wkst: +                                break +                if no1wkst: +                    # Check last week number of last year as +                    # well. If no1wkst is 0, either the year +                    # started on week start, or week number 1 +                    # got days from last year, so there are no +                    # days from last year's last week number in +                    # this year. +                    if -1 not in rr._byweekno: +                        lyearweekday = datetime.date(year-1, 1, 1).weekday() +                        lno1wkst = (7-lyearweekday+rr._wkst) % 7 +                        lyearlen = 365+calendar.isleap(year-1) +                        if lno1wkst >= 4: +                            lno1wkst = 0 +                            lnumweeks = 52+(lyearlen + +                                            (lyearweekday-rr._wkst) % 7) % 7//4 +                        else: +                            lnumweeks = 52+(self.yearlen-no1wkst) % 7//4 +                    else: +                        lnumweeks = -1 +                    if lnumweeks in rr._byweekno: +                        for i in range(no1wkst): +                            self.wnomask[i] = 1 + +        if (rr._bynweekday and (month != self.lastmonth or +                                year != self.lastyear)): +            ranges = [] +            if rr._freq == YEARLY: +                if rr._bymonth: +                    for month in rr._bymonth: +                        ranges.append(self.mrange[month-1:month+1]) +                else: +                    ranges = [(0, self.yearlen)] +            elif rr._freq == MONTHLY: +                ranges = [self.mrange[month-1:month+1]] +            if ranges: +                # Weekly frequency won't get here, so we may not +                # care about cross-year weekly periods. +                self.nwdaymask = [0]*self.yearlen +                for first, last in ranges: +                    last -= 1 +                    for wday, n in rr._bynweekday: +                        if n < 0: +                            i = last+(n+1)*7 +                            i -= (self.wdaymask[i]-wday) % 7 +                        else: +                            i = first+(n-1)*7 +                            i += (7-self.wdaymask[i]+wday) % 7 +                        if first <= i <= last: +                            self.nwdaymask[i] = 1 + +        if rr._byeaster: +            self.eastermask = [0]*(self.yearlen+7) +            eyday = easter.easter(year).toordinal()-self.yearordinal +            for offset in rr._byeaster: +                self.eastermask[eyday+offset] = 1 + +        self.lastyear = year +        self.lastmonth = month + +    def ydayset(self, year, month, day): +        return list(range(self.yearlen)), 0, self.yearlen + +    def mdayset(self, year, month, day): +        dset = [None]*self.yearlen +        start, end = self.mrange[month-1:month+1] +        for i in range(start, end): +            dset[i] = i +        return dset, start, end + +    def wdayset(self, year, month, day): +        # We need to handle cross-year weeks here. +        dset = [None]*(self.yearlen+7) +        i = datetime.date(year, month, day).toordinal()-self.yearordinal +        start = i +        for j in range(7): +            dset[i] = i +            i += 1 +            # if (not (0 <= i < self.yearlen) or +            #    self.wdaymask[i] == self.rrule._wkst): +            # This will cross the year boundary, if necessary. +            if self.wdaymask[i] == self.rrule._wkst: +                break +        return dset, start, i + +    def ddayset(self, year, month, day): +        dset = [None] * self.yearlen +        i = datetime.date(year, month, day).toordinal() - self.yearordinal +        dset[i] = i +        return dset, i, i + 1 + +    def htimeset(self, hour, minute, second): +        tset = [] +        rr = self.rrule +        for minute in rr._byminute: +            for second in rr._bysecond: +                tset.append(datetime.time(hour, minute, second, +                                          tzinfo=rr._tzinfo)) +        tset.sort() +        return tset + +    def mtimeset(self, hour, minute, second): +        tset = [] +        rr = self.rrule +        for second in rr._bysecond: +            tset.append(datetime.time(hour, minute, second, tzinfo=rr._tzinfo)) +        tset.sort() +        return tset + +    def stimeset(self, hour, minute, second): +        return (datetime.time(hour, minute, second, +                tzinfo=self.rrule._tzinfo),) + + +class rruleset(rrulebase): +    """ The rruleset type allows more complex recurrence setups, mixing +    multiple rules, dates, exclusion rules, and exclusion dates. The type +    constructor takes the following keyword arguments: + +    :param cache: If True, caching of results will be enabled, improving +                  performance of multiple queries considerably. """ + +    class _genitem(object): +        def __init__(self, genlist, gen): +            try: +                self.dt = advance_iterator(gen) +                genlist.append(self) +            except StopIteration: +                pass +            self.genlist = genlist +            self.gen = gen + +        def __next__(self): +            try: +                self.dt = advance_iterator(self.gen) +            except StopIteration: +                if self.genlist[0] is self: +                    heapq.heappop(self.genlist) +                else: +                    self.genlist.remove(self) +                    heapq.heapify(self.genlist) + +        next = __next__ + +        def __lt__(self, other): +            return self.dt < other.dt + +        def __gt__(self, other): +            return self.dt > other.dt + +        def __eq__(self, other): +            return self.dt == other.dt + +        def __ne__(self, other): +            return self.dt != other.dt + +    def __init__(self, cache=False): +        super(rruleset, self).__init__(cache) +        self._rrule = [] +        self._rdate = [] +        self._exrule = [] +        self._exdate = [] + +    @_invalidates_cache +    def rrule(self, rrule): +        """ Include the given :py:class:`rrule` instance in the recurrence set +            generation. """ +        self._rrule.append(rrule) + +    @_invalidates_cache +    def rdate(self, rdate): +        """ Include the given :py:class:`datetime` instance in the recurrence +            set generation. """ +        self._rdate.append(rdate) + +    @_invalidates_cache +    def exrule(self, exrule): +        """ Include the given rrule instance in the recurrence set exclusion +            list. Dates which are part of the given recurrence rules will not +            be generated, even if some inclusive rrule or rdate matches them. +        """ +        self._exrule.append(exrule) + +    @_invalidates_cache +    def exdate(self, exdate): +        """ Include the given datetime instance in the recurrence set +            exclusion list. Dates included that way will not be generated, +            even if some inclusive rrule or rdate matches them. """ +        self._exdate.append(exdate) + +    def _iter(self): +        rlist = [] +        self._rdate.sort() +        self._genitem(rlist, iter(self._rdate)) +        for gen in [iter(x) for x in self._rrule]: +            self._genitem(rlist, gen) +        exlist = [] +        self._exdate.sort() +        self._genitem(exlist, iter(self._exdate)) +        for gen in [iter(x) for x in self._exrule]: +            self._genitem(exlist, gen) +        lastdt = None +        total = 0 +        heapq.heapify(rlist) +        heapq.heapify(exlist) +        while rlist: +            ritem = rlist[0] +            if not lastdt or lastdt != ritem.dt: +                while exlist and exlist[0] < ritem: +                    exitem = exlist[0] +                    advance_iterator(exitem) +                    if exlist and exlist[0] is exitem: +                        heapq.heapreplace(exlist, exitem) +                if not exlist or ritem != exlist[0]: +                    total += 1 +                    yield ritem.dt +                lastdt = ritem.dt +            advance_iterator(ritem) +            if rlist and rlist[0] is ritem: +                heapq.heapreplace(rlist, ritem) +        self._len = total + + +class _rrulestr(object): + +    _freq_map = {"YEARLY": YEARLY, +                 "MONTHLY": MONTHLY, +                 "WEEKLY": WEEKLY, +                 "DAILY": DAILY, +                 "HOURLY": HOURLY, +                 "MINUTELY": MINUTELY, +                 "SECONDLY": SECONDLY} + +    _weekday_map = {"MO": 0, "TU": 1, "WE": 2, "TH": 3, +                    "FR": 4, "SA": 5, "SU": 6} + +    def _handle_int(self, rrkwargs, name, value, **kwargs): +        rrkwargs[name.lower()] = int(value) + +    def _handle_int_list(self, rrkwargs, name, value, **kwargs): +        rrkwargs[name.lower()] = [int(x) for x in value.split(',')] + +    _handle_INTERVAL = _handle_int +    _handle_COUNT = _handle_int +    _handle_BYSETPOS = _handle_int_list +    _handle_BYMONTH = _handle_int_list +    _handle_BYMONTHDAY = _handle_int_list +    _handle_BYYEARDAY = _handle_int_list +    _handle_BYEASTER = _handle_int_list +    _handle_BYWEEKNO = _handle_int_list +    _handle_BYHOUR = _handle_int_list +    _handle_BYMINUTE = _handle_int_list +    _handle_BYSECOND = _handle_int_list + +    def _handle_FREQ(self, rrkwargs, name, value, **kwargs): +        rrkwargs["freq"] = self._freq_map[value] + +    def _handle_UNTIL(self, rrkwargs, name, value, **kwargs): +        global parser +        if not parser: +            from dateutil import parser +        try: +            rrkwargs["until"] = parser.parse(value, +                                             ignoretz=kwargs.get("ignoretz"), +                                             tzinfos=kwargs.get("tzinfos")) +        except ValueError: +            raise ValueError("invalid until date") + +    def _handle_WKST(self, rrkwargs, name, value, **kwargs): +        rrkwargs["wkst"] = self._weekday_map[value] + +    def _handle_BYWEEKDAY(self, rrkwargs, name, value, **kwargs): +        """ +        Two ways to specify this: +1MO or MO(+1) +        """ +        l = [] +        for wday in value.split(','): +            if '(' in wday: +                # If it's of the form TH(+1), etc. +                splt = wday.split('(') +                w = splt[0] +                n = int(splt[1][:-1]) +            elif len(wday): +                # If it's of the form +1MO +                for i in range(len(wday)): +                    if wday[i] not in '+-0123456789': +                        break +                n = wday[:i] or None +                w = wday[i:] +                if n: +                    n = int(n) +            else: +                raise ValueError("Invalid (empty) BYDAY specification.") + +            l.append(weekdays[self._weekday_map[w]](n)) +        rrkwargs["byweekday"] = l + +    _handle_BYDAY = _handle_BYWEEKDAY + +    def _parse_rfc_rrule(self, line, +                         dtstart=None, +                         cache=False, +                         ignoretz=False, +                         tzinfos=None): +        if line.find(':') != -1: +            name, value = line.split(':') +            if name != "RRULE": +                raise ValueError("unknown parameter name") +        else: +            value = line +        rrkwargs = {} +        for pair in value.split(';'): +            name, value = pair.split('=') +            name = name.upper() +            value = value.upper() +            try: +                getattr(self, "_handle_"+name)(rrkwargs, name, value, +                                               ignoretz=ignoretz, +                                               tzinfos=tzinfos) +            except AttributeError: +                raise ValueError("unknown parameter '%s'" % name) +            except (KeyError, ValueError): +                raise ValueError("invalid '%s': %s" % (name, value)) +        return rrule(dtstart=dtstart, cache=cache, **rrkwargs) + +    def _parse_rfc(self, s, +                   dtstart=None, +                   cache=False, +                   unfold=False, +                   forceset=False, +                   compatible=False, +                   ignoretz=False, +                   tzinfos=None): +        global parser +        if compatible: +            forceset = True +            unfold = True +        s = s.upper() +        if not s.strip(): +            raise ValueError("empty string") +        if unfold: +            lines = s.splitlines() +            i = 0 +            while i < len(lines): +                line = lines[i].rstrip() +                if not line: +                    del lines[i] +                elif i > 0 and line[0] == " ": +                    lines[i-1] += line[1:] +                    del lines[i] +                else: +                    i += 1 +        else: +            lines = s.split() +        if (not forceset and len(lines) == 1 and (s.find(':') == -1 or +                                                  s.startswith('RRULE:'))): +            return self._parse_rfc_rrule(lines[0], cache=cache, +                                         dtstart=dtstart, ignoretz=ignoretz, +                                         tzinfos=tzinfos) +        else: +            rrulevals = [] +            rdatevals = [] +            exrulevals = [] +            exdatevals = [] +            for line in lines: +                if not line: +                    continue +                if line.find(':') == -1: +                    name = "RRULE" +                    value = line +                else: +                    name, value = line.split(':', 1) +                parms = name.split(';') +                if not parms: +                    raise ValueError("empty property name") +                name = parms[0] +                parms = parms[1:] +                if name == "RRULE": +                    for parm in parms: +                        raise ValueError("unsupported RRULE parm: "+parm) +                    rrulevals.append(value) +                elif name == "RDATE": +                    for parm in parms: +                        if parm != "VALUE=DATE-TIME": +                            raise ValueError("unsupported RDATE parm: "+parm) +                    rdatevals.append(value) +                elif name == "EXRULE": +                    for parm in parms: +                        raise ValueError("unsupported EXRULE parm: "+parm) +                    exrulevals.append(value) +                elif name == "EXDATE": +                    for parm in parms: +                        if parm != "VALUE=DATE-TIME": +                            raise ValueError("unsupported EXDATE parm: "+parm) +                    exdatevals.append(value) +                elif name == "DTSTART": +                    for parm in parms: +                        raise ValueError("unsupported DTSTART parm: "+parm) +                    if not parser: +                        from dateutil import parser +                    dtstart = parser.parse(value, ignoretz=ignoretz, +                                           tzinfos=tzinfos) +                else: +                    raise ValueError("unsupported property: "+name) +            if (forceset or len(rrulevals) > 1 or rdatevals +                    or exrulevals or exdatevals): +                if not parser and (rdatevals or exdatevals): +                    from dateutil import parser +                rset = rruleset(cache=cache) +                for value in rrulevals: +                    rset.rrule(self._parse_rfc_rrule(value, dtstart=dtstart, +                                                     ignoretz=ignoretz, +                                                     tzinfos=tzinfos)) +                for value in rdatevals: +                    for datestr in value.split(','): +                        rset.rdate(parser.parse(datestr, +                                                ignoretz=ignoretz, +                                                tzinfos=tzinfos)) +                for value in exrulevals: +                    rset.exrule(self._parse_rfc_rrule(value, dtstart=dtstart, +                                                      ignoretz=ignoretz, +                                                      tzinfos=tzinfos)) +                for value in exdatevals: +                    for datestr in value.split(','): +                        rset.exdate(parser.parse(datestr, +                                                 ignoretz=ignoretz, +                                                 tzinfos=tzinfos)) +                if compatible and dtstart: +                    rset.rdate(dtstart) +                return rset +            else: +                return self._parse_rfc_rrule(rrulevals[0], +                                             dtstart=dtstart, +                                             cache=cache, +                                             ignoretz=ignoretz, +                                             tzinfos=tzinfos) + +    def __call__(self, s, **kwargs): +        return self._parse_rfc(s, **kwargs) + + +rrulestr = _rrulestr() + +# vim:ts=4:sw=4:et diff --git a/python/dateutil/tz/__init__.py b/python/dateutil/tz/__init__.py new file mode 100644 index 0000000..b0a5043 --- /dev/null +++ b/python/dateutil/tz/__init__.py @@ -0,0 +1,5 @@ +from .tz import * + +__all__ = ["tzutc", "tzoffset", "tzlocal", "tzfile", "tzrange", +           "tzstr", "tzical", "tzwin", "tzwinlocal", "gettz", +           "enfold", "datetime_ambiguous", "datetime_exists"] diff --git a/python/dateutil/tz/_common.py b/python/dateutil/tz/_common.py new file mode 100644 index 0000000..f1cf2af --- /dev/null +++ b/python/dateutil/tz/_common.py @@ -0,0 +1,394 @@ +from six import PY3 + +from functools import wraps + +from datetime import datetime, timedelta, tzinfo + + +ZERO = timedelta(0) + +__all__ = ['tzname_in_python2', 'enfold'] + + +def tzname_in_python2(namefunc): +    """Change unicode output into bytestrings in Python 2 + +    tzname() API changed in Python 3. It used to return bytes, but was changed +    to unicode strings +    """ +    def adjust_encoding(*args, **kwargs): +        name = namefunc(*args, **kwargs) +        if name is not None and not PY3: +            name = name.encode() + +        return name + +    return adjust_encoding + + +# The following is adapted from Alexander Belopolsky's tz library +# https://github.com/abalkin/tz +if hasattr(datetime, 'fold'): +    # This is the pre-python 3.6 fold situation +    def enfold(dt, fold=1): +        """ +        Provides a unified interface for assigning the ``fold`` attribute to +        datetimes both before and after the implementation of PEP-495. + +        :param fold: +            The value for the ``fold`` attribute in the returned datetime. This +            should be either 0 or 1. + +        :return: +            Returns an object for which ``getattr(dt, 'fold', 0)`` returns +            ``fold`` for all versions of Python. In versions prior to +            Python 3.6, this is a ``_DatetimeWithFold`` object, which is a +            subclass of :py:class:`datetime.datetime` with the ``fold`` +            attribute added, if ``fold`` is 1. + +        .. versionadded:: 2.6.0 +        """ +        return dt.replace(fold=fold) + +else: +    class _DatetimeWithFold(datetime): +        """ +        This is a class designed to provide a PEP 495-compliant interface for +        Python versions before 3.6. It is used only for dates in a fold, so +        the ``fold`` attribute is fixed at ``1``. + +        .. versionadded:: 2.6.0 +        """ +        __slots__ = () + +        @property +        def fold(self): +            return 1 + +    def enfold(dt, fold=1): +        """ +        Provides a unified interface for assigning the ``fold`` attribute to +        datetimes both before and after the implementation of PEP-495. + +        :param fold: +            The value for the ``fold`` attribute in the returned datetime. This +            should be either 0 or 1. + +        :return: +            Returns an object for which ``getattr(dt, 'fold', 0)`` returns +            ``fold`` for all versions of Python. In versions prior to +            Python 3.6, this is a ``_DatetimeWithFold`` object, which is a +            subclass of :py:class:`datetime.datetime` with the ``fold`` +            attribute added, if ``fold`` is 1. + +        .. versionadded:: 2.6.0 +        """ +        if getattr(dt, 'fold', 0) == fold: +            return dt + +        args = dt.timetuple()[:6] +        args += (dt.microsecond, dt.tzinfo) + +        if fold: +            return _DatetimeWithFold(*args) +        else: +            return datetime(*args) + + +def _validate_fromutc_inputs(f): +    """ +    The CPython version of ``fromutc`` checks that the input is a ``datetime`` +    object and that ``self`` is attached as its ``tzinfo``. +    """ +    @wraps(f) +    def fromutc(self, dt): +        if not isinstance(dt, datetime): +            raise TypeError("fromutc() requires a datetime argument") +        if dt.tzinfo is not self: +            raise ValueError("dt.tzinfo is not self") + +        return f(self, dt) + +    return fromutc + + +class _tzinfo(tzinfo): +    """ +    Base class for all ``dateutil`` ``tzinfo`` objects. +    """ + +    def is_ambiguous(self, dt): +        """ +        Whether or not the "wall time" of a given datetime is ambiguous in this +        zone. + +        :param dt: +            A :py:class:`datetime.datetime`, naive or time zone aware. + + +        :return: +            Returns ``True`` if ambiguous, ``False`` otherwise. + +        .. versionadded:: 2.6.0 +        """ + +        dt = dt.replace(tzinfo=self) + +        wall_0 = enfold(dt, fold=0) +        wall_1 = enfold(dt, fold=1) + +        same_offset = wall_0.utcoffset() == wall_1.utcoffset() +        same_dt = wall_0.replace(tzinfo=None) == wall_1.replace(tzinfo=None) + +        return same_dt and not same_offset + +    def _fold_status(self, dt_utc, dt_wall): +        """ +        Determine the fold status of a "wall" datetime, given a representation +        of the same datetime as a (naive) UTC datetime. This is calculated based +        on the assumption that ``dt.utcoffset() - dt.dst()`` is constant for all +        datetimes, and that this offset is the actual number of hours separating +        ``dt_utc`` and ``dt_wall``. + +        :param dt_utc: +            Representation of the datetime as UTC + +        :param dt_wall: +            Representation of the datetime as "wall time". This parameter must +            either have a `fold` attribute or have a fold-naive +            :class:`datetime.tzinfo` attached, otherwise the calculation may +            fail. +        """ +        if self.is_ambiguous(dt_wall): +            delta_wall = dt_wall - dt_utc +            _fold = int(delta_wall == (dt_utc.utcoffset() - dt_utc.dst())) +        else: +            _fold = 0 + +        return _fold + +    def _fold(self, dt): +        return getattr(dt, 'fold', 0) + +    def _fromutc(self, dt): +        """ +        Given a timezone-aware datetime in a given timezone, calculates a +        timezone-aware datetime in a new timezone. + +        Since this is the one time that we *know* we have an unambiguous +        datetime object, we take this opportunity to determine whether the +        datetime is ambiguous and in a "fold" state (e.g. if it's the first +        occurence, chronologically, of the ambiguous datetime). + +        :param dt: +            A timezone-aware :class:`datetime.datetime` object. +        """ + +        # Re-implement the algorithm from Python's datetime.py +        dtoff = dt.utcoffset() +        if dtoff is None: +            raise ValueError("fromutc() requires a non-None utcoffset() " +                             "result") + +        # The original datetime.py code assumes that `dst()` defaults to +        # zero during ambiguous times. PEP 495 inverts this presumption, so +        # for pre-PEP 495 versions of python, we need to tweak the algorithm. +        dtdst = dt.dst() +        if dtdst is None: +            raise ValueError("fromutc() requires a non-None dst() result") +        delta = dtoff - dtdst + +        dt += delta +        # Set fold=1 so we can default to being in the fold for +        # ambiguous dates. +        dtdst = enfold(dt, fold=1).dst() +        if dtdst is None: +            raise ValueError("fromutc(): dt.dst gave inconsistent " +                             "results; cannot convert") +        return dt + dtdst + +    @_validate_fromutc_inputs +    def fromutc(self, dt): +        """ +        Given a timezone-aware datetime in a given timezone, calculates a +        timezone-aware datetime in a new timezone. + +        Since this is the one time that we *know* we have an unambiguous +        datetime object, we take this opportunity to determine whether the +        datetime is ambiguous and in a "fold" state (e.g. if it's the first +        occurance, chronologically, of the ambiguous datetime). + +        :param dt: +            A timezone-aware :class:`datetime.datetime` object. +        """ +        dt_wall = self._fromutc(dt) + +        # Calculate the fold status given the two datetimes. +        _fold = self._fold_status(dt, dt_wall) + +        # Set the default fold value for ambiguous dates +        return enfold(dt_wall, fold=_fold) + + +class tzrangebase(_tzinfo): +    """ +    This is an abstract base class for time zones represented by an annual +    transition into and out of DST. Child classes should implement the following +    methods: + +        * ``__init__(self, *args, **kwargs)`` +        * ``transitions(self, year)`` - this is expected to return a tuple of +          datetimes representing the DST on and off transitions in standard +          time. + +    A fully initialized ``tzrangebase`` subclass should also provide the +    following attributes: +        * ``hasdst``: Boolean whether or not the zone uses DST. +        * ``_dst_offset`` / ``_std_offset``: :class:`datetime.timedelta` objects +          representing the respective UTC offsets. +        * ``_dst_abbr`` / ``_std_abbr``: Strings representing the timezone short +          abbreviations in DST and STD, respectively. +        * ``_hasdst``: Whether or not the zone has DST. + +    .. versionadded:: 2.6.0 +    """ +    def __init__(self): +        raise NotImplementedError('tzrangebase is an abstract base class') + +    def utcoffset(self, dt): +        isdst = self._isdst(dt) + +        if isdst is None: +            return None +        elif isdst: +            return self._dst_offset +        else: +            return self._std_offset + +    def dst(self, dt): +        isdst = self._isdst(dt) + +        if isdst is None: +            return None +        elif isdst: +            return self._dst_base_offset +        else: +            return ZERO + +    @tzname_in_python2 +    def tzname(self, dt): +        if self._isdst(dt): +            return self._dst_abbr +        else: +            return self._std_abbr + +    def fromutc(self, dt): +        """ Given a datetime in UTC, return local time """ +        if not isinstance(dt, datetime): +            raise TypeError("fromutc() requires a datetime argument") + +        if dt.tzinfo is not self: +            raise ValueError("dt.tzinfo is not self") + +        # Get transitions - if there are none, fixed offset +        transitions = self.transitions(dt.year) +        if transitions is None: +            return dt + self.utcoffset(dt) + +        # Get the transition times in UTC +        dston, dstoff = transitions + +        dston -= self._std_offset +        dstoff -= self._std_offset + +        utc_transitions = (dston, dstoff) +        dt_utc = dt.replace(tzinfo=None) + +        isdst = self._naive_isdst(dt_utc, utc_transitions) + +        if isdst: +            dt_wall = dt + self._dst_offset +        else: +            dt_wall = dt + self._std_offset + +        _fold = int(not isdst and self.is_ambiguous(dt_wall)) + +        return enfold(dt_wall, fold=_fold) + +    def is_ambiguous(self, dt): +        """ +        Whether or not the "wall time" of a given datetime is ambiguous in this +        zone. + +        :param dt: +            A :py:class:`datetime.datetime`, naive or time zone aware. + + +        :return: +            Returns ``True`` if ambiguous, ``False`` otherwise. + +        .. versionadded:: 2.6.0 +        """ +        if not self.hasdst: +            return False + +        start, end = self.transitions(dt.year) + +        dt = dt.replace(tzinfo=None) +        return (end <= dt < end + self._dst_base_offset) + +    def _isdst(self, dt): +        if not self.hasdst: +            return False +        elif dt is None: +            return None + +        transitions = self.transitions(dt.year) + +        if transitions is None: +            return False + +        dt = dt.replace(tzinfo=None) + +        isdst = self._naive_isdst(dt, transitions) + +        # Handle ambiguous dates +        if not isdst and self.is_ambiguous(dt): +            return not self._fold(dt) +        else: +            return isdst + +    def _naive_isdst(self, dt, transitions): +        dston, dstoff = transitions + +        dt = dt.replace(tzinfo=None) + +        if dston < dstoff: +            isdst = dston <= dt < dstoff +        else: +            isdst = not dstoff <= dt < dston + +        return isdst + +    @property +    def _dst_base_offset(self): +        return self._dst_offset - self._std_offset + +    __hash__ = None + +    def __ne__(self, other): +        return not (self == other) + +    def __repr__(self): +        return "%s(...)" % self.__class__.__name__ + +    __reduce__ = object.__reduce__ + + +def _total_seconds(td): +    # Python 2.6 doesn't have a total_seconds() method on timedelta objects +    return ((td.seconds + td.days * 86400) * 1000000 + +            td.microseconds) // 1000000 + + +_total_seconds = getattr(timedelta, 'total_seconds', _total_seconds) diff --git a/python/dateutil/tz/tz.py b/python/dateutil/tz/tz.py new file mode 100644 index 0000000..9468282 --- /dev/null +++ b/python/dateutil/tz/tz.py @@ -0,0 +1,1511 @@ +# -*- coding: utf-8 -*- +""" +This module offers timezone implementations subclassing the abstract +:py:`datetime.tzinfo` type. There are classes to handle tzfile format files +(usually are in :file:`/etc/localtime`, :file:`/usr/share/zoneinfo`, etc), TZ +environment string (in all known formats), given ranges (with help from +relative deltas), local machine timezone, fixed offset timezone, and UTC +timezone. +""" +import datetime +import struct +import time +import sys +import os +import bisect + +from six import string_types +from ._common import tzname_in_python2, _tzinfo, _total_seconds +from ._common import tzrangebase, enfold +from ._common import _validate_fromutc_inputs + +try: +    from .win import tzwin, tzwinlocal +except ImportError: +    tzwin = tzwinlocal = None + +ZERO = datetime.timedelta(0) +EPOCH = datetime.datetime.utcfromtimestamp(0) +EPOCHORDINAL = EPOCH.toordinal() + + +class tzutc(datetime.tzinfo): +    """ +    This is a tzinfo object that represents the UTC time zone. +    """ +    def utcoffset(self, dt): +        return ZERO + +    def dst(self, dt): +        return ZERO + +    @tzname_in_python2 +    def tzname(self, dt): +        return "UTC" + +    def is_ambiguous(self, dt): +        """ +        Whether or not the "wall time" of a given datetime is ambiguous in this +        zone. + +        :param dt: +            A :py:class:`datetime.datetime`, naive or time zone aware. + + +        :return: +            Returns ``True`` if ambiguous, ``False`` otherwise. + +        .. versionadded:: 2.6.0 +        """ +        return False + +    @_validate_fromutc_inputs +    def fromutc(self, dt): +        """ +        Fast track version of fromutc() returns the original ``dt`` object for +        any valid :py:class:`datetime.datetime` object. +        """ +        return dt + +    def __eq__(self, other): +        if not isinstance(other, (tzutc, tzoffset)): +            return NotImplemented + +        return (isinstance(other, tzutc) or +                (isinstance(other, tzoffset) and other._offset == ZERO)) + +    __hash__ = None + +    def __ne__(self, other): +        return not (self == other) + +    def __repr__(self): +        return "%s()" % self.__class__.__name__ + +    __reduce__ = object.__reduce__ + + +class tzoffset(datetime.tzinfo): +    """ +    A simple class for representing a fixed offset from UTC. + +    :param name: +        The timezone name, to be returned when ``tzname()`` is called. + +    :param offset: +        The time zone offset in seconds, or (since version 2.6.0, represented +        as a :py:class:`datetime.timedelta` object. +    """ +    def __init__(self, name, offset): +        self._name = name + +        try: +            # Allow a timedelta +            offset = _total_seconds(offset) +        except (TypeError, AttributeError): +            pass +        self._offset = datetime.timedelta(seconds=offset) + +    def utcoffset(self, dt): +        return self._offset + +    def dst(self, dt): +        return ZERO + +    @tzname_in_python2 +    def tzname(self, dt): +        return self._name + +    @_validate_fromutc_inputs +    def fromutc(self, dt): +        return dt + self._offset + +    def is_ambiguous(self, dt): +        """ +        Whether or not the "wall time" of a given datetime is ambiguous in this +        zone. + +        :param dt: +            A :py:class:`datetime.datetime`, naive or time zone aware. + + +        :return: +            Returns ``True`` if ambiguous, ``False`` otherwise. + +        .. versionadded:: 2.6.0 +        """ +        return False + +    def __eq__(self, other): +        if not isinstance(other, tzoffset): +            return NotImplemented + +        return self._offset == other._offset + +    __hash__ = None + +    def __ne__(self, other): +        return not (self == other) + +    def __repr__(self): +        return "%s(%s, %s)" % (self.__class__.__name__, +                               repr(self._name), +                               int(_total_seconds(self._offset))) + +    __reduce__ = object.__reduce__ + + +class tzlocal(_tzinfo): +    """ +    A :class:`tzinfo` subclass built around the ``time`` timezone functions. +    """ +    def __init__(self): +        super(tzlocal, self).__init__() + +        self._std_offset = datetime.timedelta(seconds=-time.timezone) +        if time.daylight: +            self._dst_offset = datetime.timedelta(seconds=-time.altzone) +        else: +            self._dst_offset = self._std_offset + +        self._dst_saved = self._dst_offset - self._std_offset +        self._hasdst = bool(self._dst_saved) + +    def utcoffset(self, dt): +        if dt is None and self._hasdst: +            return None + +        if self._isdst(dt): +            return self._dst_offset +        else: +            return self._std_offset + +    def dst(self, dt): +        if dt is None and self._hasdst: +            return None + +        if self._isdst(dt): +            return self._dst_offset - self._std_offset +        else: +            return ZERO + +    @tzname_in_python2 +    def tzname(self, dt): +        return time.tzname[self._isdst(dt)] + +    def is_ambiguous(self, dt): +        """ +        Whether or not the "wall time" of a given datetime is ambiguous in this +        zone. + +        :param dt: +            A :py:class:`datetime.datetime`, naive or time zone aware. + + +        :return: +            Returns ``True`` if ambiguous, ``False`` otherwise. + +        .. versionadded:: 2.6.0 +        """ +        naive_dst = self._naive_is_dst(dt) +        return (not naive_dst and +                (naive_dst != self._naive_is_dst(dt - self._dst_saved))) + +    def _naive_is_dst(self, dt): +        timestamp = _datetime_to_timestamp(dt) +        return time.localtime(timestamp + time.timezone).tm_isdst + +    def _isdst(self, dt, fold_naive=True): +        # We can't use mktime here. It is unstable when deciding if +        # the hour near to a change is DST or not. +        # +        # timestamp = time.mktime((dt.year, dt.month, dt.day, dt.hour, +        #                         dt.minute, dt.second, dt.weekday(), 0, -1)) +        # return time.localtime(timestamp).tm_isdst +        # +        # The code above yields the following result: +        # +        # >>> import tz, datetime +        # >>> t = tz.tzlocal() +        # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() +        # 'BRDT' +        # >>> datetime.datetime(2003,2,16,0,tzinfo=t).tzname() +        # 'BRST' +        # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() +        # 'BRST' +        # >>> datetime.datetime(2003,2,15,22,tzinfo=t).tzname() +        # 'BRDT' +        # >>> datetime.datetime(2003,2,15,23,tzinfo=t).tzname() +        # 'BRDT' +        # +        # Here is a more stable implementation: +        # +        if not self._hasdst: +            return False + +        # Check for ambiguous times: +        dstval = self._naive_is_dst(dt) +        fold = getattr(dt, 'fold', None) + +        if self.is_ambiguous(dt): +            if fold is not None: +                return not self._fold(dt) +            else: +                return True + +        return dstval + +    def __eq__(self, other): +        if not isinstance(other, tzlocal): +            return NotImplemented + +        return (self._std_offset == other._std_offset and +                self._dst_offset == other._dst_offset) + +    __hash__ = None + +    def __ne__(self, other): +        return not (self == other) + +    def __repr__(self): +        return "%s()" % self.__class__.__name__ + +    __reduce__ = object.__reduce__ + + +class _ttinfo(object): +    __slots__ = ["offset", "delta", "isdst", "abbr", +                 "isstd", "isgmt", "dstoffset"] + +    def __init__(self): +        for attr in self.__slots__: +            setattr(self, attr, None) + +    def __repr__(self): +        l = [] +        for attr in self.__slots__: +            value = getattr(self, attr) +            if value is not None: +                l.append("%s=%s" % (attr, repr(value))) +        return "%s(%s)" % (self.__class__.__name__, ", ".join(l)) + +    def __eq__(self, other): +        if not isinstance(other, _ttinfo): +            return NotImplemented + +        return (self.offset == other.offset and +                self.delta == other.delta and +                self.isdst == other.isdst and +                self.abbr == other.abbr and +                self.isstd == other.isstd and +                self.isgmt == other.isgmt and +                self.dstoffset == other.dstoffset) + +    __hash__ = None + +    def __ne__(self, other): +        return not (self == other) + +    def __getstate__(self): +        state = {} +        for name in self.__slots__: +            state[name] = getattr(self, name, None) +        return state + +    def __setstate__(self, state): +        for name in self.__slots__: +            if name in state: +                setattr(self, name, state[name]) + + +class _tzfile(object): +    """ +    Lightweight class for holding the relevant transition and time zone +    information read from binary tzfiles. +    """ +    attrs = ['trans_list', 'trans_list_utc', 'trans_idx', 'ttinfo_list', +             'ttinfo_std', 'ttinfo_dst', 'ttinfo_before', 'ttinfo_first'] + +    def __init__(self, **kwargs): +        for attr in self.attrs: +            setattr(self, attr, kwargs.get(attr, None)) + + +class tzfile(_tzinfo): +    """ +    This is a ``tzinfo`` subclass thant allows one to use the ``tzfile(5)`` +    format timezone files to extract current and historical zone information. + +    :param fileobj: +        This can be an opened file stream or a file name that the time zone +        information can be read from. + +    :param filename: +        This is an optional parameter specifying the source of the time zone +        information in the event that ``fileobj`` is a file object. If omitted +        and ``fileobj`` is a file stream, this parameter will be set either to +        ``fileobj``'s ``name`` attribute or to ``repr(fileobj)``. + +    See `Sources for Time Zone and Daylight Saving Time Data +    <http://www.twinsun.com/tz/tz-link.htm>`_ for more information. Time zone +    files can be compiled from the `IANA Time Zone database files +    <https://www.iana.org/time-zones>`_ with the `zic time zone compiler +    <https://www.freebsd.org/cgi/man.cgi?query=zic&sektion=8>`_ +    """ + +    def __init__(self, fileobj, filename=None): +        super(tzfile, self).__init__() + +        file_opened_here = False +        if isinstance(fileobj, string_types): +            self._filename = fileobj +            fileobj = open(fileobj, 'rb') +            file_opened_here = True +        elif filename is not None: +            self._filename = filename +        elif hasattr(fileobj, "name"): +            self._filename = fileobj.name +        else: +            self._filename = repr(fileobj) + +        if fileobj is not None: +            if not file_opened_here: +                fileobj = _ContextWrapper(fileobj) + +            with fileobj as file_stream: +                tzobj = self._read_tzfile(file_stream) + +            self._set_tzdata(tzobj) + +    def _set_tzdata(self, tzobj): +        """ Set the time zone data of this object from a _tzfile object """ +        # Copy the relevant attributes over as private attributes +        for attr in _tzfile.attrs: +            setattr(self, '_' + attr, getattr(tzobj, attr)) + +    def _read_tzfile(self, fileobj): +        out = _tzfile() + +        # From tzfile(5): +        # +        # The time zone information files used by tzset(3) +        # begin with the magic characters "TZif" to identify +        # them as time zone information files, followed by +        # sixteen bytes reserved for future use, followed by +        # six four-byte values of type long, written in a +        # ``standard'' byte order (the high-order  byte +        # of the value is written first). +        if fileobj.read(4).decode() != "TZif": +            raise ValueError("magic not found") + +        fileobj.read(16) + +        ( +            # The number of UTC/local indicators stored in the file. +            ttisgmtcnt, + +            # The number of standard/wall indicators stored in the file. +            ttisstdcnt, + +            # The number of leap seconds for which data is +            # stored in the file. +            leapcnt, + +            # The number of "transition times" for which data +            # is stored in the file. +            timecnt, + +            # The number of "local time types" for which data +            # is stored in the file (must not be zero). +            typecnt, + +            # The  number  of  characters  of "time zone +            # abbreviation strings" stored in the file. +            charcnt, + +        ) = struct.unpack(">6l", fileobj.read(24)) + +        # The above header is followed by tzh_timecnt four-byte +        # values  of  type long,  sorted  in ascending order. +        # These values are written in ``standard'' byte order. +        # Each is used as a transition time (as  returned  by +        # time(2)) at which the rules for computing local time +        # change. + +        if timecnt: +            out.trans_list_utc = list(struct.unpack(">%dl" % timecnt, +                                                    fileobj.read(timecnt*4))) +        else: +            out.trans_list_utc = [] + +        # Next come tzh_timecnt one-byte values of type unsigned +        # char; each one tells which of the different types of +        # ``local time'' types described in the file is associated +        # with the same-indexed transition time. These values +        # serve as indices into an array of ttinfo structures that +        # appears next in the file. + +        if timecnt: +            out.trans_idx = struct.unpack(">%dB" % timecnt, +                                            fileobj.read(timecnt)) +        else: +            out.trans_idx = [] + +        # Each ttinfo structure is written as a four-byte value +        # for tt_gmtoff  of  type long,  in  a  standard  byte +        # order, followed  by a one-byte value for tt_isdst +        # and a one-byte  value  for  tt_abbrind.   In  each +        # structure, tt_gmtoff  gives  the  number  of +        # seconds to be added to UTC, tt_isdst tells whether +        # tm_isdst should be set by  localtime(3),  and +        # tt_abbrind serves  as an index into the array of +        # time zone abbreviation characters that follow the +        # ttinfo structure(s) in the file. + +        ttinfo = [] + +        for i in range(typecnt): +            ttinfo.append(struct.unpack(">lbb", fileobj.read(6))) + +        abbr = fileobj.read(charcnt).decode() + +        # Then there are tzh_leapcnt pairs of four-byte +        # values, written in  standard byte  order;  the +        # first  value  of  each pair gives the time (as +        # returned by time(2)) at which a leap second +        # occurs;  the  second  gives the  total  number of +        # leap seconds to be applied after the given time. +        # The pairs of values are sorted in ascending order +        # by time. + +        # Not used, for now (but seek for correct file position) +        if leapcnt: +            fileobj.seek(leapcnt * 8, os.SEEK_CUR) + +        # Then there are tzh_ttisstdcnt standard/wall +        # indicators, each stored as a one-byte value; +        # they tell whether the transition times associated +        # with local time types were specified as standard +        # time or wall clock time, and are used when +        # a time zone file is used in handling POSIX-style +        # time zone environment variables. + +        if ttisstdcnt: +            isstd = struct.unpack(">%db" % ttisstdcnt, +                                  fileobj.read(ttisstdcnt)) + +        # Finally, there are tzh_ttisgmtcnt UTC/local +        # indicators, each stored as a one-byte value; +        # they tell whether the transition times associated +        # with local time types were specified as UTC or +        # local time, and are used when a time zone file +        # is used in handling POSIX-style time zone envi- +        # ronment variables. + +        if ttisgmtcnt: +            isgmt = struct.unpack(">%db" % ttisgmtcnt, +                                  fileobj.read(ttisgmtcnt)) + +        # Build ttinfo list +        out.ttinfo_list = [] +        for i in range(typecnt): +            gmtoff, isdst, abbrind = ttinfo[i] +            # Round to full-minutes if that's not the case. Python's +            # datetime doesn't accept sub-minute timezones. Check +            # http://python.org/sf/1447945 for some information. +            gmtoff = 60 * ((gmtoff + 30) // 60) +            tti = _ttinfo() +            tti.offset = gmtoff +            tti.dstoffset = datetime.timedelta(0) +            tti.delta = datetime.timedelta(seconds=gmtoff) +            tti.isdst = isdst +            tti.abbr = abbr[abbrind:abbr.find('\x00', abbrind)] +            tti.isstd = (ttisstdcnt > i and isstd[i] != 0) +            tti.isgmt = (ttisgmtcnt > i and isgmt[i] != 0) +            out.ttinfo_list.append(tti) + +        # Replace ttinfo indexes for ttinfo objects. +        out.trans_idx = [out.ttinfo_list[idx] for idx in out.trans_idx] + +        # Set standard, dst, and before ttinfos. before will be +        # used when a given time is before any transitions, +        # and will be set to the first non-dst ttinfo, or to +        # the first dst, if all of them are dst. +        out.ttinfo_std = None +        out.ttinfo_dst = None +        out.ttinfo_before = None +        if out.ttinfo_list: +            if not out.trans_list_utc: +                out.ttinfo_std = out.ttinfo_first = out.ttinfo_list[0] +            else: +                for i in range(timecnt-1, -1, -1): +                    tti = out.trans_idx[i] +                    if not out.ttinfo_std and not tti.isdst: +                        out.ttinfo_std = tti +                    elif not out.ttinfo_dst and tti.isdst: +                        out.ttinfo_dst = tti + +                    if out.ttinfo_std and out.ttinfo_dst: +                        break +                else: +                    if out.ttinfo_dst and not out.ttinfo_std: +                        out.ttinfo_std = out.ttinfo_dst + +                for tti in out.ttinfo_list: +                    if not tti.isdst: +                        out.ttinfo_before = tti +                        break +                else: +                    out.ttinfo_before = out.ttinfo_list[0] + +        # Now fix transition times to become relative to wall time. +        # +        # I'm not sure about this. In my tests, the tz source file +        # is setup to wall time, and in the binary file isstd and +        # isgmt are off, so it should be in wall time. OTOH, it's +        # always in gmt time. Let me know if you have comments +        # about this. +        laststdoffset = None +        out.trans_list = [] +        for i, tti in enumerate(out.trans_idx): +            if not tti.isdst: +                offset = tti.offset +                laststdoffset = offset +            else: +                if laststdoffset is not None: +                    # Store the DST offset as well and update it in the list +                    tti.dstoffset = tti.offset - laststdoffset +                    out.trans_idx[i] = tti + +                offset = laststdoffset or 0 + +            out.trans_list.append(out.trans_list_utc[i] + offset) + +        # In case we missed any DST offsets on the way in for some reason, make +        # a second pass over the list, looking for the /next/ DST offset. +        laststdoffset = None +        for i in reversed(range(len(out.trans_idx))): +            tti = out.trans_idx[i] +            if tti.isdst: +                if not (tti.dstoffset or laststdoffset is None): +                    tti.dstoffset = tti.offset - laststdoffset +            else: +                laststdoffset = tti.offset + +            if not isinstance(tti.dstoffset, datetime.timedelta): +                tti.dstoffset = datetime.timedelta(seconds=tti.dstoffset) + +            out.trans_idx[i] = tti + +        out.trans_idx = tuple(out.trans_idx) +        out.trans_list = tuple(out.trans_list) +        out.trans_list_utc = tuple(out.trans_list_utc) + +        return out + +    def _find_last_transition(self, dt, in_utc=False): +        # If there's no list, there are no transitions to find +        if not self._trans_list: +            return None + +        timestamp = _datetime_to_timestamp(dt) + +        # Find where the timestamp fits in the transition list - if the +        # timestamp is a transition time, it's part of the "after" period. +        trans_list = self._trans_list_utc if in_utc else self._trans_list +        idx = bisect.bisect_right(trans_list, timestamp) + +        # We want to know when the previous transition was, so subtract off 1 +        return idx - 1         + +    def _get_ttinfo(self, idx): +        # For no list or after the last transition, default to _ttinfo_std +        if idx is None or (idx + 1) >= len(self._trans_list): +            return self._ttinfo_std + +        # If there is a list and the time is before it, return _ttinfo_before +        if idx < 0: +            return self._ttinfo_before + +        return self._trans_idx[idx] + +    def _find_ttinfo(self, dt): +        idx = self._resolve_ambiguous_time(dt) + +        return self._get_ttinfo(idx) + +    def fromutc(self, dt): +        """ +        The ``tzfile`` implementation of :py:func:`datetime.tzinfo.fromutc`. + +        :param dt: +            A :py:class:`datetime.datetime` object. + +        :raises TypeError: +            Raised if ``dt`` is not a :py:class:`datetime.datetime` object. + +        :raises ValueError: +            Raised if this is called with a ``dt`` which does not have this +            ``tzinfo`` attached. + +        :return: +            Returns a :py:class:`datetime.datetime` object representing the +            wall time in ``self``'s time zone. +        """ +        # These isinstance checks are in datetime.tzinfo, so we'll preserve +        # them, even if we don't care about duck typing. +        if not isinstance(dt, datetime.datetime): +            raise TypeError("fromutc() requires a datetime argument") + +        if dt.tzinfo is not self: +            raise ValueError("dt.tzinfo is not self") + +        # First treat UTC as wall time and get the transition we're in. +        idx = self._find_last_transition(dt, in_utc=True) +        tti = self._get_ttinfo(idx) + +        dt_out = dt + datetime.timedelta(seconds=tti.offset) + +        fold = self.is_ambiguous(dt_out, idx=idx) + +        return enfold(dt_out, fold=int(fold)) + +    def is_ambiguous(self, dt, idx=None): +        """ +        Whether or not the "wall time" of a given datetime is ambiguous in this +        zone. + +        :param dt: +            A :py:class:`datetime.datetime`, naive or time zone aware. + + +        :return: +            Returns ``True`` if ambiguous, ``False`` otherwise. + +        .. versionadded:: 2.6.0 +        """ +        if idx is None: +            idx = self._find_last_transition(dt) + +        # Calculate the difference in offsets from current to previous +        timestamp = _datetime_to_timestamp(dt) +        tti = self._get_ttinfo(idx) + +        if idx is None or idx <= 0: +            return False + +        od = self._get_ttinfo(idx - 1).offset - tti.offset +        tt = self._trans_list[idx]          # Transition time + +        return timestamp < tt + od + +    def _resolve_ambiguous_time(self, dt): +        idx = self._find_last_transition(dt) + +        # If we have no transitions, return the index +        _fold = self._fold(dt) +        if idx is None or idx == 0: +            return idx + +        # If it's ambiguous and we're in a fold, shift to a different index. +        idx_offset = int(not _fold and self.is_ambiguous(dt, idx)) + +        return idx - idx_offset + +    def utcoffset(self, dt): +        if dt is None: +            return None + +        if not self._ttinfo_std: +            return ZERO + +        return self._find_ttinfo(dt).delta + +    def dst(self, dt): +        if dt is None: +            return None + +        if not self._ttinfo_dst: +            return ZERO + +        tti = self._find_ttinfo(dt) + +        if not tti.isdst: +            return ZERO + +        # The documentation says that utcoffset()-dst() must +        # be constant for every dt. +        return tti.dstoffset + +    @tzname_in_python2 +    def tzname(self, dt): +        if not self._ttinfo_std or dt is None: +            return None +        return self._find_ttinfo(dt).abbr + +    def __eq__(self, other): +        if not isinstance(other, tzfile): +            return NotImplemented +        return (self._trans_list == other._trans_list and +                self._trans_idx == other._trans_idx and +                self._ttinfo_list == other._ttinfo_list) + +    __hash__ = None + +    def __ne__(self, other): +        return not (self == other) + +    def __repr__(self): +        return "%s(%s)" % (self.__class__.__name__, repr(self._filename)) + +    def __reduce__(self): +        return self.__reduce_ex__(None) + +    def __reduce_ex__(self, protocol): +        return (self.__class__, (None, self._filename), self.__dict__) + + +class tzrange(tzrangebase): +    """ +    The ``tzrange`` object is a time zone specified by a set of offsets and +    abbreviations, equivalent to the way the ``TZ`` variable can be specified +    in POSIX-like systems, but using Python delta objects to specify DST +    start, end and offsets. + +    :param stdabbr: +        The abbreviation for standard time (e.g. ``'EST'``). + +    :param stdoffset: +        An integer or :class:`datetime.timedelta` object or equivalent +        specifying the base offset from UTC. + +        If unspecified, +00:00 is used. + +    :param dstabbr: +        The abbreviation for DST / "Summer" time (e.g. ``'EDT'``). + +        If specified, with no other DST information, DST is assumed to occur +        and the default behavior or ``dstoffset``, ``start`` and ``end`` is +        used. If unspecified and no other DST information is specified, it +        is assumed that this zone has no DST. + +        If this is unspecified and other DST information is *is* specified, +        DST occurs in the zone but the time zone abbreviation is left +        unchanged. + +    :param dstoffset: +        A an integer or :class:`datetime.timedelta` object or equivalent +        specifying the UTC offset during DST. If unspecified and any other DST +        information is specified, it is assumed to be the STD offset +1 hour. + +    :param start: +        A :class:`relativedelta.relativedelta` object or equivalent specifying +        the time and time of year that daylight savings time starts. To specify, +        for example, that DST starts at 2AM on the 2nd Sunday in March, pass: + +            ``relativedelta(hours=2, month=3, day=1, weekday=SU(+2))`` + +        If unspecified and any other DST information is specified, the default +        value is 2 AM on the first Sunday in April. + +    :param end: +        A :class:`relativedelta.relativedelta` object or equivalent representing +        the time and time of year that daylight savings time ends, with the +        same specification method as in ``start``. One note is that this should +        point to the first time in the *standard* zone, so if a transition +        occurs at 2AM in the DST zone and the clocks are set back 1 hour to 1AM, +        set the `hours` parameter to +1. + + +    **Examples:** + +    .. testsetup:: tzrange + +        from dateutil.tz import tzrange, tzstr + +    .. doctest:: tzrange + +        >>> tzstr('EST5EDT') == tzrange("EST", -18000, "EDT") +        True + +        >>> from dateutil.relativedelta import * +        >>> range1 = tzrange("EST", -18000, "EDT") +        >>> range2 = tzrange("EST", -18000, "EDT", -14400, +        ...                  relativedelta(hours=+2, month=4, day=1, +        ...                                weekday=SU(+1)), +        ...                  relativedelta(hours=+1, month=10, day=31, +        ...                                weekday=SU(-1))) +        >>> tzstr('EST5EDT') == range1 == range2 +        True + +    """ +    def __init__(self, stdabbr, stdoffset=None, +                 dstabbr=None, dstoffset=None, +                 start=None, end=None): + +        global relativedelta +        from dateutil import relativedelta + +        self._std_abbr = stdabbr +        self._dst_abbr = dstabbr + +        try: +            stdoffset = _total_seconds(stdoffset) +        except (TypeError, AttributeError): +            pass + +        try: +            dstoffset = _total_seconds(dstoffset) +        except (TypeError, AttributeError): +            pass + +        if stdoffset is not None: +            self._std_offset = datetime.timedelta(seconds=stdoffset) +        else: +            self._std_offset = ZERO + +        if dstoffset is not None: +            self._dst_offset = datetime.timedelta(seconds=dstoffset) +        elif dstabbr and stdoffset is not None: +            self._dst_offset = self._std_offset + datetime.timedelta(hours=+1) +        else: +            self._dst_offset = ZERO + +        if dstabbr and start is None: +            self._start_delta = relativedelta.relativedelta( +                hours=+2, month=4, day=1, weekday=relativedelta.SU(+1)) +        else: +            self._start_delta = start + +        if dstabbr and end is None: +            self._end_delta = relativedelta.relativedelta( +                hours=+1, month=10, day=31, weekday=relativedelta.SU(-1)) +        else: +            self._end_delta = end + +        self._dst_base_offset_ = self._dst_offset - self._std_offset +        self.hasdst = bool(self._start_delta) + +    def transitions(self, year): +        """ +        For a given year, get the DST on and off transition times, expressed +        always on the standard time side. For zones with no transitions, this +        function returns ``None``. + +        :param year: +            The year whose transitions you would like to query. + +        :return: +            Returns a :class:`tuple` of :class:`datetime.datetime` objects, +            ``(dston, dstoff)`` for zones with an annual DST transition, or +            ``None`` for fixed offset zones. +        """ +        if not self.hasdst: +            return None + +        base_year = datetime.datetime(year, 1, 1) + +        start = base_year + self._start_delta +        end = base_year + self._end_delta + +        return (start, end) + +    def __eq__(self, other): +        if not isinstance(other, tzrange): +            return NotImplemented + +        return (self._std_abbr == other._std_abbr and +                self._dst_abbr == other._dst_abbr and +                self._std_offset == other._std_offset and +                self._dst_offset == other._dst_offset and +                self._start_delta == other._start_delta and +                self._end_delta == other._end_delta) + +    @property +    def _dst_base_offset(self): +        return self._dst_base_offset_ + + +class tzstr(tzrange): +    """ +    ``tzstr`` objects are time zone objects specified by a time-zone string as +    it would be passed to a ``TZ`` variable on POSIX-style systems (see +    the `GNU C Library: TZ Variable`_ for more details). + +    There is one notable exception, which is that POSIX-style time zones use an +    inverted offset format, so normally ``GMT+3`` would be parsed as an offset +    3 hours *behind* GMT. The ``tzstr`` time zone object will parse this as an +    offset 3 hours *ahead* of GMT. If you would like to maintain the POSIX +    behavior, pass a ``True`` value to ``posix_offset``. + +    The :class:`tzrange` object provides the same functionality, but is +    specified using :class:`relativedelta.relativedelta` objects. rather than +    strings. + +    :param s: +        A time zone string in ``TZ`` variable format. This can be a +        :class:`bytes` (2.x: :class:`str`), :class:`str` (2.x: :class:`unicode`) +        or a stream emitting unicode characters (e.g. :class:`StringIO`). + +    :param posix_offset: +        Optional. If set to ``True``, interpret strings such as ``GMT+3`` or +        ``UTC+3`` as being 3 hours *behind* UTC rather than ahead, per the +        POSIX standard. + +    .. _`GNU C Library: TZ Variable`: +        https://www.gnu.org/software/libc/manual/html_node/TZ-Variable.html +    """ +    def __init__(self, s, posix_offset=False): +        global parser +        from dateutil import parser + +        self._s = s + +        res = parser._parsetz(s) +        if res is None: +            raise ValueError("unknown string format") + +        # Here we break the compatibility with the TZ variable handling. +        # GMT-3 actually *means* the timezone -3. +        if res.stdabbr in ("GMT", "UTC") and not posix_offset: +            res.stdoffset *= -1 + +        # We must initialize it first, since _delta() needs +        # _std_offset and _dst_offset set. Use False in start/end +        # to avoid building it two times. +        tzrange.__init__(self, res.stdabbr, res.stdoffset, +                         res.dstabbr, res.dstoffset, +                         start=False, end=False) + +        if not res.dstabbr: +            self._start_delta = None +            self._end_delta = None +        else: +            self._start_delta = self._delta(res.start) +            if self._start_delta: +                self._end_delta = self._delta(res.end, isend=1) + +        self.hasdst = bool(self._start_delta) + +    def _delta(self, x, isend=0): +        from dateutil import relativedelta +        kwargs = {} +        if x.month is not None: +            kwargs["month"] = x.month +            if x.weekday is not None: +                kwargs["weekday"] = relativedelta.weekday(x.weekday, x.week) +                if x.week > 0: +                    kwargs["day"] = 1 +                else: +                    kwargs["day"] = 31 +            elif x.day: +                kwargs["day"] = x.day +        elif x.yday is not None: +            kwargs["yearday"] = x.yday +        elif x.jyday is not None: +            kwargs["nlyearday"] = x.jyday +        if not kwargs: +            # Default is to start on first sunday of april, and end +            # on last sunday of october. +            if not isend: +                kwargs["month"] = 4 +                kwargs["day"] = 1 +                kwargs["weekday"] = relativedelta.SU(+1) +            else: +                kwargs["month"] = 10 +                kwargs["day"] = 31 +                kwargs["weekday"] = relativedelta.SU(-1) +        if x.time is not None: +            kwargs["seconds"] = x.time +        else: +            # Default is 2AM. +            kwargs["seconds"] = 7200 +        if isend: +            # Convert to standard time, to follow the documented way +            # of working with the extra hour. See the documentation +            # of the tzinfo class. +            delta = self._dst_offset - self._std_offset +            kwargs["seconds"] -= delta.seconds + delta.days * 86400 +        return relativedelta.relativedelta(**kwargs) + +    def __repr__(self): +        return "%s(%s)" % (self.__class__.__name__, repr(self._s)) + + +class _tzicalvtzcomp(object): +    def __init__(self, tzoffsetfrom, tzoffsetto, isdst, +                 tzname=None, rrule=None): +        self.tzoffsetfrom = datetime.timedelta(seconds=tzoffsetfrom) +        self.tzoffsetto = datetime.timedelta(seconds=tzoffsetto) +        self.tzoffsetdiff = self.tzoffsetto - self.tzoffsetfrom +        self.isdst = isdst +        self.tzname = tzname +        self.rrule = rrule + + +class _tzicalvtz(_tzinfo): +    def __init__(self, tzid, comps=[]): +        super(_tzicalvtz, self).__init__() + +        self._tzid = tzid +        self._comps = comps +        self._cachedate = [] +        self._cachecomp = [] + +    def _find_comp(self, dt): +        if len(self._comps) == 1: +            return self._comps[0] + +        dt = dt.replace(tzinfo=None) + +        try: +            return self._cachecomp[self._cachedate.index((dt, self._fold(dt)))] +        except ValueError: +            pass + +        lastcompdt = None +        lastcomp = None + +        for comp in self._comps: +            compdt = self._find_compdt(comp, dt) + +            if compdt and (not lastcompdt or lastcompdt < compdt): +                lastcompdt = compdt +                lastcomp = comp + +        if not lastcomp: +            # RFC says nothing about what to do when a given +            # time is before the first onset date. We'll look for the +            # first standard component, or the first component, if +            # none is found. +            for comp in self._comps: +                if not comp.isdst: +                    lastcomp = comp +                    break +            else: +                lastcomp = comp[0] + +        self._cachedate.insert(0, (dt, self._fold(dt))) +        self._cachecomp.insert(0, lastcomp) + +        if len(self._cachedate) > 10: +            self._cachedate.pop() +            self._cachecomp.pop() + +        return lastcomp + +    def _find_compdt(self, comp, dt): +        if comp.tzoffsetdiff < ZERO and self._fold(dt): +            dt -= comp.tzoffsetdiff + +        compdt = comp.rrule.before(dt, inc=True) + +        return compdt + +    def utcoffset(self, dt): +        if dt is None: +            return None + +        return self._find_comp(dt).tzoffsetto + +    def dst(self, dt): +        comp = self._find_comp(dt) +        if comp.isdst: +            return comp.tzoffsetdiff +        else: +            return ZERO + +    @tzname_in_python2 +    def tzname(self, dt): +        return self._find_comp(dt).tzname + +    def __repr__(self): +        return "<tzicalvtz %s>" % repr(self._tzid) + +    __reduce__ = object.__reduce__ + + +class tzical(object): +    """ +    This object is designed to parse an iCalendar-style ``VTIMEZONE`` structure +    as set out in `RFC 2445`_ Section 4.6.5 into one or more `tzinfo` objects. + +    :param `fileobj`: +        A file or stream in iCalendar format, which should be UTF-8 encoded +        with CRLF endings. + +    .. _`RFC 2445`: https://www.ietf.org/rfc/rfc2445.txt +    """ +    def __init__(self, fileobj): +        global rrule +        from dateutil import rrule + +        if isinstance(fileobj, string_types): +            self._s = fileobj +            # ical should be encoded in UTF-8 with CRLF +            fileobj = open(fileobj, 'r') +        else: +            self._s = getattr(fileobj, 'name', repr(fileobj)) +            fileobj = _ContextWrapper(fileobj) + +        self._vtz = {} + +        with fileobj as fobj: +            self._parse_rfc(fobj.read()) + +    def keys(self): +        """ +        Retrieves the available time zones as a list. +        """ +        return list(self._vtz.keys()) + +    def get(self, tzid=None): +        """ +        Retrieve a :py:class:`datetime.tzinfo` object by its ``tzid``. + +        :param tzid: +            If there is exactly one time zone available, omitting ``tzid`` +            or passing :py:const:`None` value returns it. Otherwise a valid +            key (which can be retrieved from :func:`keys`) is required. + +        :raises ValueError: +            Raised if ``tzid`` is not specified but there are either more +            or fewer than 1 zone defined. + +        :returns: +            Returns either a :py:class:`datetime.tzinfo` object representing +            the relevant time zone or :py:const:`None` if the ``tzid`` was +            not found. +        """ +        if tzid is None: +            if len(self._vtz) == 0: +                raise ValueError("no timezones defined") +            elif len(self._vtz) > 1: +                raise ValueError("more than one timezone available") +            tzid = next(iter(self._vtz)) + +        return self._vtz.get(tzid) + +    def _parse_offset(self, s): +        s = s.strip() +        if not s: +            raise ValueError("empty offset") +        if s[0] in ('+', '-'): +            signal = (-1, +1)[s[0] == '+'] +            s = s[1:] +        else: +            signal = +1 +        if len(s) == 4: +            return (int(s[:2]) * 3600 + int(s[2:]) * 60) * signal +        elif len(s) == 6: +            return (int(s[:2]) * 3600 + int(s[2:4]) * 60 + int(s[4:])) * signal +        else: +            raise ValueError("invalid offset: " + s) + +    def _parse_rfc(self, s): +        lines = s.splitlines() +        if not lines: +            raise ValueError("empty string") + +        # Unfold +        i = 0 +        while i < len(lines): +            line = lines[i].rstrip() +            if not line: +                del lines[i] +            elif i > 0 and line[0] == " ": +                lines[i-1] += line[1:] +                del lines[i] +            else: +                i += 1 + +        tzid = None +        comps = [] +        invtz = False +        comptype = None +        for line in lines: +            if not line: +                continue +            name, value = line.split(':', 1) +            parms = name.split(';') +            if not parms: +                raise ValueError("empty property name") +            name = parms[0].upper() +            parms = parms[1:] +            if invtz: +                if name == "BEGIN": +                    if value in ("STANDARD", "DAYLIGHT"): +                        # Process component +                        pass +                    else: +                        raise ValueError("unknown component: "+value) +                    comptype = value +                    founddtstart = False +                    tzoffsetfrom = None +                    tzoffsetto = None +                    rrulelines = [] +                    tzname = None +                elif name == "END": +                    if value == "VTIMEZONE": +                        if comptype: +                            raise ValueError("component not closed: "+comptype) +                        if not tzid: +                            raise ValueError("mandatory TZID not found") +                        if not comps: +                            raise ValueError( +                                "at least one component is needed") +                        # Process vtimezone +                        self._vtz[tzid] = _tzicalvtz(tzid, comps) +                        invtz = False +                    elif value == comptype: +                        if not founddtstart: +                            raise ValueError("mandatory DTSTART not found") +                        if tzoffsetfrom is None: +                            raise ValueError( +                                "mandatory TZOFFSETFROM not found") +                        if tzoffsetto is None: +                            raise ValueError( +                                "mandatory TZOFFSETFROM not found") +                        # Process component +                        rr = None +                        if rrulelines: +                            rr = rrule.rrulestr("\n".join(rrulelines), +                                                compatible=True, +                                                ignoretz=True, +                                                cache=True) +                        comp = _tzicalvtzcomp(tzoffsetfrom, tzoffsetto, +                                              (comptype == "DAYLIGHT"), +                                              tzname, rr) +                        comps.append(comp) +                        comptype = None +                    else: +                        raise ValueError("invalid component end: "+value) +                elif comptype: +                    if name == "DTSTART": +                        rrulelines.append(line) +                        founddtstart = True +                    elif name in ("RRULE", "RDATE", "EXRULE", "EXDATE"): +                        rrulelines.append(line) +                    elif name == "TZOFFSETFROM": +                        if parms: +                            raise ValueError( +                                "unsupported %s parm: %s " % (name, parms[0])) +                        tzoffsetfrom = self._parse_offset(value) +                    elif name == "TZOFFSETTO": +                        if parms: +                            raise ValueError( +                                "unsupported TZOFFSETTO parm: "+parms[0]) +                        tzoffsetto = self._parse_offset(value) +                    elif name == "TZNAME": +                        if parms: +                            raise ValueError( +                                "unsupported TZNAME parm: "+parms[0]) +                        tzname = value +                    elif name == "COMMENT": +                        pass +                    else: +                        raise ValueError("unsupported property: "+name) +                else: +                    if name == "TZID": +                        if parms: +                            raise ValueError( +                                "unsupported TZID parm: "+parms[0]) +                        tzid = value +                    elif name in ("TZURL", "LAST-MODIFIED", "COMMENT"): +                        pass +                    else: +                        raise ValueError("unsupported property: "+name) +            elif name == "BEGIN" and value == "VTIMEZONE": +                tzid = None +                comps = [] +                invtz = True + +    def __repr__(self): +        return "%s(%s)" % (self.__class__.__name__, repr(self._s)) + + +if sys.platform != "win32": +    TZFILES = ["/etc/localtime", "localtime"] +    TZPATHS = ["/usr/share/zoneinfo", +               "/usr/lib/zoneinfo", +               "/usr/share/lib/zoneinfo", +               "/etc/zoneinfo"] +else: +    TZFILES = [] +    TZPATHS = [] + + +def gettz(name=None): +    tz = None +    if not name: +        try: +            name = os.environ["TZ"] +        except KeyError: +            pass +    if name is None or name == ":": +        for filepath in TZFILES: +            if not os.path.isabs(filepath): +                filename = filepath +                for path in TZPATHS: +                    filepath = os.path.join(path, filename) +                    if os.path.isfile(filepath): +                        break +                else: +                    continue +            if os.path.isfile(filepath): +                try: +                    tz = tzfile(filepath) +                    break +                except (IOError, OSError, ValueError): +                    pass +        else: +            tz = tzlocal() +    else: +        if name.startswith(":"): +            name = name[:-1] +        if os.path.isabs(name): +            if os.path.isfile(name): +                tz = tzfile(name) +            else: +                tz = None +        else: +            for path in TZPATHS: +                filepath = os.path.join(path, name) +                if not os.path.isfile(filepath): +                    filepath = filepath.replace(' ', '_') +                    if not os.path.isfile(filepath): +                        continue +                try: +                    tz = tzfile(filepath) +                    break +                except (IOError, OSError, ValueError): +                    pass +            else: +                tz = None +                if tzwin is not None: +                    try: +                        tz = tzwin(name) +                    except WindowsError: +                        tz = None + +                if not tz: +                    from dateutil.zoneinfo import get_zonefile_instance +                    tz = get_zonefile_instance().get(name) + +                if not tz: +                    for c in name: +                        # name must have at least one offset to be a tzstr +                        if c in "0123456789": +                            try: +                                tz = tzstr(name) +                            except ValueError: +                                pass +                            break +                    else: +                        if name in ("GMT", "UTC"): +                            tz = tzutc() +                        elif name in time.tzname: +                            tz = tzlocal() +    return tz + + +def datetime_exists(dt, tz=None): +    """ +    Given a datetime and a time zone, determine whether or not a given datetime +    would fall in a gap. + +    :param dt: +        A :class:`datetime.datetime` (whose time zone will be ignored if ``tz`` +        is provided.) + +    :param tz: +        A :class:`datetime.tzinfo` with support for the ``fold`` attribute. If +        ``None`` or not provided, the datetime's own time zone will be used. + +    :return: +        Returns a boolean value whether or not the "wall time" exists in ``tz``. +    """ +    if tz is None: +        if dt.tzinfo is None: +            raise ValueError('Datetime is naive and no time zone provided.') +        tz = dt.tzinfo + +    dt = dt.replace(tzinfo=None) + +    # This is essentially a test of whether or not the datetime can survive +    # a round trip to UTC. +    dt_rt = dt.replace(tzinfo=tz).astimezone(tzutc()).astimezone(tz) +    dt_rt = dt_rt.replace(tzinfo=None) + +    return dt == dt_rt + + +def datetime_ambiguous(dt, tz=None): +    """ +    Given a datetime and a time zone, determine whether or not a given datetime +    is ambiguous (i.e if there are two times differentiated only by their DST +    status). + +    :param dt: +        A :class:`datetime.datetime` (whose time zone will be ignored if ``tz`` +        is provided.) + +    :param tz: +        A :class:`datetime.tzinfo` with support for the ``fold`` attribute. If +        ``None`` or not provided, the datetime's own time zone will be used. + +    :return: +        Returns a boolean value whether or not the "wall time" is ambiguous in +        ``tz``. + +    .. versionadded:: 2.6.0 +    """ +    if tz is None: +        if dt.tzinfo is None: +            raise ValueError('Datetime is naive and no time zone provided.') + +        tz = dt.tzinfo + +    # If a time zone defines its own "is_ambiguous" function, we'll use that. +    is_ambiguous_fn = getattr(tz, 'is_ambiguous', None) +    if is_ambiguous_fn is not None: +        try: +            return tz.is_ambiguous(dt) +        except: +            pass + +    # If it doesn't come out and tell us it's ambiguous, we'll just check if +    # the fold attribute has any effect on this particular date and time. +    dt = dt.replace(tzinfo=tz) +    wall_0 = enfold(dt, fold=0) +    wall_1 = enfold(dt, fold=1) + +    same_offset = wall_0.utcoffset() == wall_1.utcoffset() +    same_dst = wall_0.dst() == wall_1.dst() + +    return not (same_offset and same_dst) + + +def _datetime_to_timestamp(dt): +    """ +    Convert a :class:`datetime.datetime` object to an epoch timestamp in seconds +    since January 1, 1970, ignoring the time zone. +    """ +    return _total_seconds((dt.replace(tzinfo=None) - EPOCH)) + + +class _ContextWrapper(object): +    """ +    Class for wrapping contexts so that they are passed through in a +    with statement. +    """ +    def __init__(self, context): +        self.context = context + +    def __enter__(self): +        return self.context + +    def __exit__(*args, **kwargs): +        pass + +# vim:ts=4:sw=4:et diff --git a/python/dateutil/tz/win.py b/python/dateutil/tz/win.py new file mode 100644 index 0000000..36a1c26 --- /dev/null +++ b/python/dateutil/tz/win.py @@ -0,0 +1,332 @@ +# This code was originally contributed by Jeffrey Harris. +import datetime +import struct + +from six.moves import winreg +from six import text_type + +try: +    import ctypes +    from ctypes import wintypes +except ValueError: +    # ValueError is raised on non-Windows systems for some horrible reason. +    raise ImportError("Running tzwin on non-Windows system") + +from ._common import tzrangebase + +__all__ = ["tzwin", "tzwinlocal", "tzres"] + +ONEWEEK = datetime.timedelta(7) + +TZKEYNAMENT = r"SOFTWARE\Microsoft\Windows NT\CurrentVersion\Time Zones" +TZKEYNAME9X = r"SOFTWARE\Microsoft\Windows\CurrentVersion\Time Zones" +TZLOCALKEYNAME = r"SYSTEM\CurrentControlSet\Control\TimeZoneInformation" + + +def _settzkeyname(): +    handle = winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) +    try: +        winreg.OpenKey(handle, TZKEYNAMENT).Close() +        TZKEYNAME = TZKEYNAMENT +    except WindowsError: +        TZKEYNAME = TZKEYNAME9X +    handle.Close() +    return TZKEYNAME + + +TZKEYNAME = _settzkeyname() + + +class tzres(object): +    """ +    Class for accessing `tzres.dll`, which contains timezone name related +    resources. + +    .. versionadded:: 2.5.0 +    """ +    p_wchar = ctypes.POINTER(wintypes.WCHAR)        # Pointer to a wide char + +    def __init__(self, tzres_loc='tzres.dll'): +        # Load the user32 DLL so we can load strings from tzres +        user32 = ctypes.WinDLL('user32') + +        # Specify the LoadStringW function +        user32.LoadStringW.argtypes = (wintypes.HINSTANCE, +                                       wintypes.UINT, +                                       wintypes.LPWSTR, +                                       ctypes.c_int) + +        self.LoadStringW = user32.LoadStringW +        self._tzres = ctypes.WinDLL(tzres_loc) +        self.tzres_loc = tzres_loc + +    def load_name(self, offset): +        """ +        Load a timezone name from a DLL offset (integer). + +        >>> from dateutil.tzwin import tzres +        >>> tzr = tzres() +        >>> print(tzr.load_name(112)) +        'Eastern Standard Time' + +        :param offset: +            A positive integer value referring to a string from the tzres dll. + +        ..note: +            Offsets found in the registry are generally of the form +            `@tzres.dll,-114`. The offset in this case if 114, not -114. + +        """ +        resource = self.p_wchar() +        lpBuffer = ctypes.cast(ctypes.byref(resource), wintypes.LPWSTR) +        nchar = self.LoadStringW(self._tzres._handle, offset, lpBuffer, 0) +        return resource[:nchar] + +    def name_from_string(self, tzname_str): +        """ +        Parse strings as returned from the Windows registry into the time zone +        name as defined in the registry. + +        >>> from dateutil.tzwin import tzres +        >>> tzr = tzres() +        >>> print(tzr.name_from_string('@tzres.dll,-251')) +        'Dateline Daylight Time' +        >>> print(tzr.name_from_string('Eastern Standard Time')) +        'Eastern Standard Time' + +        :param tzname_str: +            A timezone name string as returned from a Windows registry key. + +        :return: +            Returns the localized timezone string from tzres.dll if the string +            is of the form `@tzres.dll,-offset`, else returns the input string. +        """ +        if not tzname_str.startswith('@'): +            return tzname_str + +        name_splt = tzname_str.split(',-') +        try: +            offset = int(name_splt[1]) +        except: +            raise ValueError("Malformed timezone string.") + +        return self.load_name(offset) + + +class tzwinbase(tzrangebase): +    """tzinfo class based on win32's timezones available in the registry.""" +    def __init__(self): +        raise NotImplementedError('tzwinbase is an abstract base class') + +    def __eq__(self, other): +        # Compare on all relevant dimensions, including name. +        if not isinstance(other, tzwinbase): +            return NotImplemented + +        return  (self._std_offset == other._std_offset and +                 self._dst_offset == other._dst_offset and +                 self._stddayofweek == other._stddayofweek and +                 self._dstdayofweek == other._dstdayofweek and +                 self._stdweeknumber == other._stdweeknumber and +                 self._dstweeknumber == other._dstweeknumber and +                 self._stdhour == other._stdhour and +                 self._dsthour == other._dsthour and +                 self._stdminute == other._stdminute and +                 self._dstminute == other._dstminute and +                 self._std_abbr == other._std_abbr and +                 self._dst_abbr == other._dst_abbr) + +    @staticmethod +    def list(): +        """Return a list of all time zones known to the system.""" +        with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: +            with winreg.OpenKey(handle, TZKEYNAME) as tzkey: +                result = [winreg.EnumKey(tzkey, i) +                          for i in range(winreg.QueryInfoKey(tzkey)[0])] +        return result + +    def display(self): +        return self._display + +    def transitions(self, year): +        """ +        For a given year, get the DST on and off transition times, expressed +        always on the standard time side. For zones with no transitions, this +        function returns ``None``. + +        :param year: +            The year whose transitions you would like to query. + +        :return: +            Returns a :class:`tuple` of :class:`datetime.datetime` objects, +            ``(dston, dstoff)`` for zones with an annual DST transition, or +            ``None`` for fixed offset zones. +        """ + +        if not self.hasdst: +            return None + +        dston = picknthweekday(year, self._dstmonth, self._dstdayofweek, +                               self._dsthour, self._dstminute, +                               self._dstweeknumber) + +        dstoff = picknthweekday(year, self._stdmonth, self._stddayofweek, +                                self._stdhour, self._stdminute, +                                self._stdweeknumber) + +        # Ambiguous dates default to the STD side +        dstoff -= self._dst_base_offset + +        return dston, dstoff + +    def _get_hasdst(self): +        return self._dstmonth != 0 + +    @property +    def _dst_base_offset(self): +        return self._dst_base_offset_ + + +class tzwin(tzwinbase): + +    def __init__(self, name): +        self._name = name + +        # multiple contexts only possible in 2.7 and 3.1, we still support 2.6 +        with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: +            tzkeyname = text_type("{kn}\\{name}").format(kn=TZKEYNAME, name=name) +            with winreg.OpenKey(handle, tzkeyname) as tzkey: +                keydict = valuestodict(tzkey) + +        self._std_abbr = keydict["Std"] +        self._dst_abbr = keydict["Dlt"] + +        self._display = keydict["Display"] + +        # See http://ww_winreg.jsiinc.com/SUBA/tip0300/rh0398.htm +        tup = struct.unpack("=3l16h", keydict["TZI"]) +        stdoffset = -tup[0]-tup[1]          # Bias + StandardBias * -1 +        dstoffset = stdoffset-tup[2]        # + DaylightBias * -1 +        self._std_offset = datetime.timedelta(minutes=stdoffset) +        self._dst_offset = datetime.timedelta(minutes=dstoffset) + +        # for the meaning see the win32 TIME_ZONE_INFORMATION structure docs +        # http://msdn.microsoft.com/en-us/library/windows/desktop/ms725481(v=vs.85).aspx +        (self._stdmonth, +         self._stddayofweek,   # Sunday = 0 +         self._stdweeknumber,  # Last = 5 +         self._stdhour, +         self._stdminute) = tup[4:9] + +        (self._dstmonth, +         self._dstdayofweek,   # Sunday = 0 +         self._dstweeknumber,  # Last = 5 +         self._dsthour, +         self._dstminute) = tup[12:17] + +        self._dst_base_offset_ = self._dst_offset - self._std_offset +        self.hasdst = self._get_hasdst() + +    def __repr__(self): +        return "tzwin(%s)" % repr(self._name) + +    def __reduce__(self): +        return (self.__class__, (self._name,)) + + +class tzwinlocal(tzwinbase): +    def __init__(self): +        with winreg.ConnectRegistry(None, winreg.HKEY_LOCAL_MACHINE) as handle: +            with winreg.OpenKey(handle, TZLOCALKEYNAME) as tzlocalkey: +                keydict = valuestodict(tzlocalkey) + +            self._std_abbr = keydict["StandardName"] +            self._dst_abbr = keydict["DaylightName"] + +            try: +                tzkeyname = text_type('{kn}\\{sn}').format(kn=TZKEYNAME, +                                                          sn=self._std_abbr) +                with winreg.OpenKey(handle, tzkeyname) as tzkey: +                    _keydict = valuestodict(tzkey) +                    self._display = _keydict["Display"] +            except OSError: +                self._display = None + +        stdoffset = -keydict["Bias"]-keydict["StandardBias"] +        dstoffset = stdoffset-keydict["DaylightBias"] + +        self._std_offset = datetime.timedelta(minutes=stdoffset) +        self._dst_offset = datetime.timedelta(minutes=dstoffset) + +        # For reasons unclear, in this particular key, the day of week has been +        # moved to the END of the SYSTEMTIME structure. +        tup = struct.unpack("=8h", keydict["StandardStart"]) + +        (self._stdmonth, +         self._stdweeknumber,  # Last = 5 +         self._stdhour, +         self._stdminute) = tup[1:5] + +        self._stddayofweek = tup[7] + +        tup = struct.unpack("=8h", keydict["DaylightStart"]) + +        (self._dstmonth, +         self._dstweeknumber,  # Last = 5 +         self._dsthour, +         self._dstminute) = tup[1:5] + +        self._dstdayofweek = tup[7] + +        self._dst_base_offset_ = self._dst_offset - self._std_offset +        self.hasdst = self._get_hasdst() + +    def __repr__(self): +        return "tzwinlocal()" + +    def __str__(self): +        # str will return the standard name, not the daylight name. +        return "tzwinlocal(%s)" % repr(self._std_abbr) + +    def __reduce__(self): +        return (self.__class__, ()) + + +def picknthweekday(year, month, dayofweek, hour, minute, whichweek): +    """ dayofweek == 0 means Sunday, whichweek 5 means last instance """ +    first = datetime.datetime(year, month, 1, hour, minute) + +    # This will work if dayofweek is ISO weekday (1-7) or Microsoft-style (0-6), +    # Because 7 % 7 = 0 +    weekdayone = first.replace(day=((dayofweek - first.isoweekday()) % 7) + 1) +    wd = weekdayone + ((whichweek - 1) * ONEWEEK) +    if (wd.month != month): +        wd -= ONEWEEK + +    return wd + + +def valuestodict(key): +    """Convert a registry key's values to a dictionary.""" +    dout = {} +    size = winreg.QueryInfoKey(key)[1] +    tz_res = None + +    for i in range(size): +        key_name, value, dtype = winreg.EnumValue(key, i) +        if dtype == winreg.REG_DWORD or dtype == winreg.REG_DWORD_LITTLE_ENDIAN: +            # If it's a DWORD (32-bit integer), it's stored as unsigned - convert +            # that to a proper signed integer +            if value & (1 << 31): +                value = value - (1 << 32) +        elif dtype == winreg.REG_SZ: +            # If it's a reference to the tzres DLL, load the actual string +            if value.startswith('@tzres'): +                tz_res = tz_res or tzres() +                value = tz_res.name_from_string(value) + +            value = value.rstrip('\x00')    # Remove trailing nulls + +        dout[key_name] = value + +    return dout diff --git a/python/dateutil/tzwin.py b/python/dateutil/tzwin.py new file mode 100644 index 0000000..cebc673 --- /dev/null +++ b/python/dateutil/tzwin.py @@ -0,0 +1,2 @@ +# tzwin has moved to dateutil.tz.win +from .tz.win import * diff --git a/python/dateutil/zoneinfo/__init__.py b/python/dateutil/zoneinfo/__init__.py new file mode 100644 index 0000000..a2ed4f9 --- /dev/null +++ b/python/dateutil/zoneinfo/__init__.py @@ -0,0 +1,183 @@ +# -*- coding: utf-8 -*- +import warnings +import json + +from tarfile import TarFile +from pkgutil import get_data +from io import BytesIO +from contextlib import closing + +from dateutil.tz import tzfile + +__all__ = ["get_zonefile_instance", "gettz", "gettz_db_metadata", "rebuild"] + +ZONEFILENAME = "dateutil-zoneinfo.tar.gz" +METADATA_FN = 'METADATA' + +# python2.6 compatability. Note that TarFile.__exit__ != TarFile.close, but +# it's close enough for python2.6 +tar_open = TarFile.open +if not hasattr(TarFile, '__exit__'): +    def tar_open(*args, **kwargs): +        return closing(TarFile.open(*args, **kwargs)) + + +class tzfile(tzfile): +    def __reduce__(self): +        return (gettz, (self._filename,)) + + +def getzoneinfofile_stream(): +    try: +        return BytesIO(get_data(__name__, ZONEFILENAME)) +    except IOError as e:  # TODO  switch to FileNotFoundError? +        warnings.warn("I/O error({0}): {1}".format(e.errno, e.strerror)) +        return None + + +class ZoneInfoFile(object): +    def __init__(self, zonefile_stream=None): +        if zonefile_stream is not None: +            with tar_open(fileobj=zonefile_stream, mode='r') as tf: +                # dict comprehension does not work on python2.6 +                # TODO: get back to the nicer syntax when we ditch python2.6 +                # self.zones = {zf.name: tzfile(tf.extractfile(zf), +                #               filename = zf.name) +                #              for zf in tf.getmembers() if zf.isfile()} +                self.zones = dict((zf.name, tzfile(tf.extractfile(zf), +                                                   filename=zf.name)) +                                  for zf in tf.getmembers() +                                  if zf.isfile() and zf.name != METADATA_FN) +                # deal with links: They'll point to their parent object. Less +                # waste of memory +                # links = {zl.name: self.zones[zl.linkname] +                #        for zl in tf.getmembers() if zl.islnk() or zl.issym()} +                links = dict((zl.name, self.zones[zl.linkname]) +                             for zl in tf.getmembers() if +                             zl.islnk() or zl.issym()) +                self.zones.update(links) +                try: +                    metadata_json = tf.extractfile(tf.getmember(METADATA_FN)) +                    metadata_str = metadata_json.read().decode('UTF-8') +                    self.metadata = json.loads(metadata_str) +                except KeyError: +                    # no metadata in tar file +                    self.metadata = None +        else: +            self.zones = dict() +            self.metadata = None + +    def get(self, name, default=None): +        """ +        Wrapper for :func:`ZoneInfoFile.zones.get`. This is a convenience method +        for retrieving zones from the zone dictionary. + +        :param name: +            The name of the zone to retrieve. (Generally IANA zone names) + +        :param default: +            The value to return in the event of a missing key. + +        .. versionadded:: 2.6.0 + +        """ +        return self.zones.get(name, default) + + +# The current API has gettz as a module function, although in fact it taps into +# a stateful class. So as a workaround for now, without changing the API, we +# will create a new "global" class instance the first time a user requests a +# timezone. Ugly, but adheres to the api. +# +# TODO: Remove after deprecation period. +_CLASS_ZONE_INSTANCE = list() + + +def get_zonefile_instance(new_instance=False): +    """ +    This is a convenience function which provides a :class:`ZoneInfoFile` +    instance using the data provided by the ``dateutil`` package. By default, it +    caches a single instance of the ZoneInfoFile object and returns that. + +    :param new_instance: +        If ``True``, a new instance of :class:`ZoneInfoFile` is instantiated and +        used as the cached instance for the next call. Otherwise, new instances +        are created only as necessary. + +    :return: +        Returns a :class:`ZoneInfoFile` object. + +    .. versionadded:: 2.6 +    """ +    if new_instance: +        zif = None +    else: +        zif = getattr(get_zonefile_instance, '_cached_instance', None) + +    if zif is None: +        zif = ZoneInfoFile(getzoneinfofile_stream()) + +        get_zonefile_instance._cached_instance = zif + +    return zif + + +def gettz(name): +    """ +    This retrieves a time zone from the local zoneinfo tarball that is packaged +    with dateutil. + +    :param name: +        An IANA-style time zone name, as found in the zoneinfo file. + +    :return: +        Returns a :class:`dateutil.tz.tzfile` time zone object. + +    .. warning:: +        It is generally inadvisable to use this function, and it is only +        provided for API compatibility with earlier versions. This is *not* +        equivalent to ``dateutil.tz.gettz()``, which selects an appropriate +        time zone based on the inputs, favoring system zoneinfo. This is ONLY +        for accessing the dateutil-specific zoneinfo (which may be out of +        date compared to the system zoneinfo). + +    .. deprecated:: 2.6 +        If you need to use a specific zoneinfofile over the system zoneinfo, +        instantiate a :class:`dateutil.zoneinfo.ZoneInfoFile` object and call +        :func:`dateutil.zoneinfo.ZoneInfoFile.get(name)` instead. + +        Use :func:`get_zonefile_instance` to retrieve an instance of the +        dateutil-provided zoneinfo. +    """ +    warnings.warn("zoneinfo.gettz() will be removed in future versions, " +                  "to use the dateutil-provided zoneinfo files, instantiate a " +                  "ZoneInfoFile object and use ZoneInfoFile.zones.get() " +                  "instead. See the documentation for details.", +                  DeprecationWarning) + +    if len(_CLASS_ZONE_INSTANCE) == 0: +        _CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream())) +    return _CLASS_ZONE_INSTANCE[0].zones.get(name) + + +def gettz_db_metadata(): +    """ Get the zonefile metadata + +    See `zonefile_metadata`_ + +    :returns: +        A dictionary with the database metadata + +    .. deprecated:: 2.6 +        See deprecation warning in :func:`zoneinfo.gettz`. To get metadata, +        query the attribute ``zoneinfo.ZoneInfoFile.metadata``. +    """ +    warnings.warn("zoneinfo.gettz_db_metadata() will be removed in future " +                  "versions, to use the dateutil-provided zoneinfo files, " +                  "ZoneInfoFile object and query the 'metadata' attribute " +                  "instead. See the documentation for details.", +                  DeprecationWarning) + +    if len(_CLASS_ZONE_INSTANCE) == 0: +        _CLASS_ZONE_INSTANCE.append(ZoneInfoFile(getzoneinfofile_stream())) +    return _CLASS_ZONE_INSTANCE[0].metadata diff --git a/python/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz b/python/dateutil/zoneinfo/dateutil-zoneinfo.tar.gzBinary files differ new file mode 100644 index 0000000..613c0ff --- /dev/null +++ b/python/dateutil/zoneinfo/dateutil-zoneinfo.tar.gz diff --git a/python/dateutil/zoneinfo/rebuild.py b/python/dateutil/zoneinfo/rebuild.py new file mode 100644 index 0000000..9d53bb8 --- /dev/null +++ b/python/dateutil/zoneinfo/rebuild.py @@ -0,0 +1,52 @@ +import logging +import os +import tempfile +import shutil +import json +from subprocess import check_call + +from dateutil.zoneinfo import tar_open, METADATA_FN, ZONEFILENAME + + +def rebuild(filename, tag=None, format="gz", zonegroups=[], metadata=None): +    """Rebuild the internal timezone info in dateutil/zoneinfo/zoneinfo*tar* + +    filename is the timezone tarball from ftp.iana.org/tz. + +    """ +    tmpdir = tempfile.mkdtemp() +    zonedir = os.path.join(tmpdir, "zoneinfo") +    moduledir = os.path.dirname(__file__) +    try: +        with tar_open(filename) as tf: +            for name in zonegroups: +                tf.extract(name, tmpdir) +            filepaths = [os.path.join(tmpdir, n) for n in zonegroups] +            try: +                check_call(["zic", "-d", zonedir] + filepaths) +            except OSError as e: +                _print_on_nosuchfile(e) +                raise +        # write metadata file +        with open(os.path.join(zonedir, METADATA_FN), 'w') as f: +            json.dump(metadata, f, indent=4, sort_keys=True) +        target = os.path.join(moduledir, ZONEFILENAME) +        with tar_open(target, "w:%s" % format) as tf: +            for entry in os.listdir(zonedir): +                entrypath = os.path.join(zonedir, entry) +                tf.add(entrypath, entry) +    finally: +        shutil.rmtree(tmpdir) + + +def _print_on_nosuchfile(e): +    """Print helpful troubleshooting message + +    e is an exception raised by subprocess.check_call() + +    """ +    if e.errno == 2: +        logging.error( +            "Could not find zic. Perhaps you need to install " +            "libc-bin or some other package that provides it, " +            "or it's not in your PATH?") diff --git a/python/defusedxml/ElementTree.py b/python/defusedxml/ElementTree.py new file mode 100644 index 0000000..41b2ea8 --- /dev/null +++ b/python/defusedxml/ElementTree.py @@ -0,0 +1,112 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xml.etree.ElementTree facade +""" +from __future__ import print_function, absolute_import + +import sys +from xml.etree.ElementTree import TreeBuilder as _TreeBuilder +from xml.etree.ElementTree import parse as _parse +from xml.etree.ElementTree import tostring + +from .common import PY3 + + +if PY3: +    import importlib +else: +    from xml.etree.ElementTree import XMLParser as _XMLParser +    from xml.etree.ElementTree import iterparse as _iterparse +    from xml.etree.ElementTree import ParseError + + +from .common import (DTDForbidden, EntitiesForbidden, +                     ExternalReferenceForbidden, _generate_etree_functions) + +__origin__ = "xml.etree.ElementTree" + + +def _get_py3_cls(): +    """Python 3.3 hides the pure Python code but defusedxml requires it. + +    The code is based on test.support.import_fresh_module(). +    """ +    pymodname = "xml.etree.ElementTree" +    cmodname = "_elementtree" + +    pymod = sys.modules.pop(pymodname, None) +    cmod = sys.modules.pop(cmodname, None) + +    sys.modules[cmodname] = None +    pure_pymod = importlib.import_module(pymodname) +    if cmod is not None: +        sys.modules[cmodname] = cmod +    else: +        sys.modules.pop(cmodname) +    sys.modules[pymodname] = pymod + +    _XMLParser = pure_pymod.XMLParser +    _iterparse = pure_pymod.iterparse +    ParseError = pure_pymod.ParseError + +    return _XMLParser, _iterparse, ParseError + + +if PY3: +    _XMLParser, _iterparse, ParseError = _get_py3_cls() + + +class DefusedXMLParser(_XMLParser): + +    def __init__(self, html=0, target=None, encoding=None, +                 forbid_dtd=False, forbid_entities=True, +                 forbid_external=True): +        # Python 2.x old style class +        _XMLParser.__init__(self, html, target, encoding) +        self.forbid_dtd = forbid_dtd +        self.forbid_entities = forbid_entities +        self.forbid_external = forbid_external +        if PY3: +            parser = self.parser +        else: +            parser = self._parser +        if self.forbid_dtd: +            parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl +        if self.forbid_entities: +            parser.EntityDeclHandler = self.defused_entity_decl +            parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl +        if self.forbid_external: +            parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler + +    def defused_start_doctype_decl(self, name, sysid, pubid, +                                   has_internal_subset): +        raise DTDForbidden(name, sysid, pubid) + +    def defused_entity_decl(self, name, is_parameter_entity, value, base, +                            sysid, pubid, notation_name): +        raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name) + +    def defused_unparsed_entity_decl(self, name, base, sysid, pubid, +                                     notation_name): +        # expat 1.2 +        raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name) + +    def defused_external_entity_ref_handler(self, context, base, sysid, +                                            pubid): +        raise ExternalReferenceForbidden(context, base, sysid, pubid) + + +# aliases +XMLTreeBuilder = XMLParse = DefusedXMLParser + +parse, iterparse, fromstring = _generate_etree_functions(DefusedXMLParser, +                                                         _TreeBuilder, _parse, +                                                         _iterparse) +XML = fromstring + + +__all__ = ['XML', 'XMLParse', 'XMLTreeBuilder', 'fromstring', 'iterparse', +           'parse', 'tostring'] diff --git a/python/defusedxml/__init__.py b/python/defusedxml/__init__.py new file mode 100644 index 0000000..590a5a9 --- /dev/null +++ b/python/defusedxml/__init__.py @@ -0,0 +1,45 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defuse XML bomb denial of service vulnerabilities +""" +from __future__ import print_function, absolute_import + +from .common import (DefusedXmlException, DTDForbidden, EntitiesForbidden, +                     ExternalReferenceForbidden, NotSupportedError, +                     _apply_defusing) + + +def defuse_stdlib(): +    """Monkey patch and defuse all stdlib packages + +    :warning: The monkey patch is an EXPERIMETNAL feature. +    """ +    defused = {} + +    from . import cElementTree +    from . import ElementTree +    from . import minidom +    from . import pulldom +    from . import sax +    from . import expatbuilder +    from . import expatreader +    from . import xmlrpc + +    xmlrpc.monkey_patch() +    defused[xmlrpc] = None + +    for defused_mod in [cElementTree, ElementTree, minidom, pulldom, sax, +                        expatbuilder, expatreader]: +        stdlib_mod = _apply_defusing(defused_mod) +        defused[defused_mod] = stdlib_mod + +    return defused + + +__version__ = "0.5.0" + +__all__ = ['DefusedXmlException', 'DTDForbidden', 'EntitiesForbidden', +           'ExternalReferenceForbidden', 'NotSupportedError'] diff --git a/python/defusedxml/cElementTree.py b/python/defusedxml/cElementTree.py new file mode 100644 index 0000000..cc13689 --- /dev/null +++ b/python/defusedxml/cElementTree.py @@ -0,0 +1,30 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xml.etree.cElementTree +""" +from __future__ import absolute_import + +from xml.etree.cElementTree import TreeBuilder as _TreeBuilder +from xml.etree.cElementTree import parse as _parse +from xml.etree.cElementTree import tostring +# iterparse from ElementTree! +from xml.etree.ElementTree import iterparse as _iterparse + +from .ElementTree import DefusedXMLParser +from .common import _generate_etree_functions + +__origin__ = "xml.etree.cElementTree" + + +XMLTreeBuilder = XMLParse = DefusedXMLParser + +parse, iterparse, fromstring = _generate_etree_functions(DefusedXMLParser, +                                                         _TreeBuilder, _parse, +                                                         _iterparse) +XML = fromstring + +__all__ = ['XML', 'XMLParse', 'XMLTreeBuilder', 'fromstring', 'iterparse', +           'parse', 'tostring'] diff --git a/python/defusedxml/common.py b/python/defusedxml/common.py new file mode 100644 index 0000000..668b609 --- /dev/null +++ b/python/defusedxml/common.py @@ -0,0 +1,120 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Common constants, exceptions and helpe functions +""" +import sys + +PY3 = sys.version_info[0] == 3 + + +class DefusedXmlException(ValueError): +    """Base exception +    """ + +    def __repr__(self): +        return str(self) + + +class DTDForbidden(DefusedXmlException): +    """Document type definition is forbidden +    """ + +    def __init__(self, name, sysid, pubid): +        super(DTDForbidden, self).__init__() +        self.name = name +        self.sysid = sysid +        self.pubid = pubid + +    def __str__(self): +        tpl = "DTDForbidden(name='{}', system_id={!r}, public_id={!r})" +        return tpl.format(self.name, self.sysid, self.pubid) + + +class EntitiesForbidden(DefusedXmlException): +    """Entity definition is forbidden +    """ + +    def __init__(self, name, value, base, sysid, pubid, notation_name): +        super(EntitiesForbidden, self).__init__() +        self.name = name +        self.value = value +        self.base = base +        self.sysid = sysid +        self.pubid = pubid +        self.notation_name = notation_name + +    def __str__(self): +        tpl = "EntitiesForbidden(name='{}', system_id={!r}, public_id={!r})" +        return tpl.format(self.name, self.sysid, self.pubid) + + +class ExternalReferenceForbidden(DefusedXmlException): +    """Resolving an external reference is forbidden +    """ + +    def __init__(self, context, base, sysid, pubid): +        super(ExternalReferenceForbidden, self).__init__() +        self.context = context +        self.base = base +        self.sysid = sysid +        self.pubid = pubid + +    def __str__(self): +        tpl = "ExternalReferenceForbidden(system_id='{}', public_id={})" +        return tpl.format(self.sysid, self.pubid) + + +class NotSupportedError(DefusedXmlException): +    """The operation is not supported +    """ + + +def _apply_defusing(defused_mod): +    assert defused_mod is sys.modules[defused_mod.__name__] +    stdlib_name = defused_mod.__origin__ +    __import__(stdlib_name, {}, {}, ["*"]) +    stdlib_mod = sys.modules[stdlib_name] +    stdlib_names = set(dir(stdlib_mod)) +    for name, obj in vars(defused_mod).items(): +        if name.startswith("_") or name not in stdlib_names: +            continue +        setattr(stdlib_mod, name, obj) +    return stdlib_mod + + +def _generate_etree_functions(DefusedXMLParser, _TreeBuilder, +                              _parse, _iterparse): +    """Factory for functions needed by etree, dependent on whether +    cElementTree or ElementTree is used.""" + +    def parse(source, parser=None, forbid_dtd=False, forbid_entities=True, +              forbid_external=True): +        if parser is None: +            parser = DefusedXMLParser(target=_TreeBuilder(), +                                      forbid_dtd=forbid_dtd, +                                      forbid_entities=forbid_entities, +                                      forbid_external=forbid_external) +        return _parse(source, parser) + +    def iterparse(source, events=None, parser=None, forbid_dtd=False, +                  forbid_entities=True, forbid_external=True): +        if parser is None: +            parser = DefusedXMLParser(target=_TreeBuilder(), +                                      forbid_dtd=forbid_dtd, +                                      forbid_entities=forbid_entities, +                                      forbid_external=forbid_external) +        return _iterparse(source, events, parser) + +    def fromstring(text, forbid_dtd=False, forbid_entities=True, +                   forbid_external=True): +        parser = DefusedXMLParser(target=_TreeBuilder(), +                                  forbid_dtd=forbid_dtd, +                                  forbid_entities=forbid_entities, +                                  forbid_external=forbid_external) +        parser.feed(text) +        return parser.close() + +    return parse, iterparse, fromstring diff --git a/python/defusedxml/expatbuilder.py b/python/defusedxml/expatbuilder.py new file mode 100644 index 0000000..0eb6b91 --- /dev/null +++ b/python/defusedxml/expatbuilder.py @@ -0,0 +1,110 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xml.dom.expatbuilder +""" +from __future__ import print_function, absolute_import + +from xml.dom.expatbuilder import ExpatBuilder as _ExpatBuilder +from xml.dom.expatbuilder import Namespaces as _Namespaces + +from .common import (DTDForbidden, EntitiesForbidden, +                     ExternalReferenceForbidden) + +__origin__ = "xml.dom.expatbuilder" + + +class DefusedExpatBuilder(_ExpatBuilder): +    """Defused document builder""" + +    def __init__(self, options=None, forbid_dtd=False, forbid_entities=True, +                 forbid_external=True): +        _ExpatBuilder.__init__(self, options) +        self.forbid_dtd = forbid_dtd +        self.forbid_entities = forbid_entities +        self.forbid_external = forbid_external + +    def defused_start_doctype_decl(self, name, sysid, pubid, +                                   has_internal_subset): +        raise DTDForbidden(name, sysid, pubid) + +    def defused_entity_decl(self, name, is_parameter_entity, value, base, +                            sysid, pubid, notation_name): +        raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name) + +    def defused_unparsed_entity_decl(self, name, base, sysid, pubid, +                                     notation_name): +        # expat 1.2 +        raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name) + +    def defused_external_entity_ref_handler(self, context, base, sysid, +                                            pubid): +        raise ExternalReferenceForbidden(context, base, sysid, pubid) + +    def install(self, parser): +        _ExpatBuilder.install(self, parser) + +        if self.forbid_dtd: +            parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl +        if self.forbid_entities: +            # if self._options.entities: +            parser.EntityDeclHandler = self.defused_entity_decl +            parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl +        if self.forbid_external: +            parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler + + +class DefusedExpatBuilderNS(_Namespaces, DefusedExpatBuilder): +    """Defused document builder that supports namespaces.""" + +    def install(self, parser): +        DefusedExpatBuilder.install(self, parser) +        if self._options.namespace_declarations: +            parser.StartNamespaceDeclHandler = ( +                self.start_namespace_decl_handler) + +    def reset(self): +        DefusedExpatBuilder.reset(self) +        self._initNamespaces() + + +def parse(file, namespaces=True, forbid_dtd=False, forbid_entities=True, +          forbid_external=True): +    """Parse a document, returning the resulting Document node. + +    'file' may be either a file name or an open file object. +    """ +    if namespaces: +        build_builder = DefusedExpatBuilderNS +    else: +        build_builder = DefusedExpatBuilder +    builder = build_builder(forbid_dtd=forbid_dtd, +                            forbid_entities=forbid_entities, +                            forbid_external=forbid_external) + +    if isinstance(file, str): +        fp = open(file, 'rb') +        try: +            result = builder.parseFile(fp) +        finally: +            fp.close() +    else: +        result = builder.parseFile(file) +    return result + + +def parseString(string, namespaces=True, forbid_dtd=False, +                forbid_entities=True, forbid_external=True): +    """Parse a document from a string, returning the resulting +    Document node. +    """ +    if namespaces: +        build_builder = DefusedExpatBuilderNS +    else: +        build_builder = DefusedExpatBuilder +    builder = build_builder(forbid_dtd=forbid_dtd, +                            forbid_entities=forbid_entities, +                            forbid_external=forbid_external) +    return builder.parseString(string) diff --git a/python/defusedxml/expatreader.py b/python/defusedxml/expatreader.py new file mode 100644 index 0000000..ef6bc39 --- /dev/null +++ b/python/defusedxml/expatreader.py @@ -0,0 +1,59 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xml.sax.expatreader +""" +from __future__ import print_function, absolute_import + +from xml.sax.expatreader import ExpatParser as _ExpatParser + +from .common import (DTDForbidden, EntitiesForbidden, +                     ExternalReferenceForbidden) + +__origin__ = "xml.sax.expatreader" + + +class DefusedExpatParser(_ExpatParser): +    """Defused SAX driver for the pyexpat C module.""" + +    def __init__(self, namespaceHandling=0, bufsize=2 ** 16 - 20, +                 forbid_dtd=False, forbid_entities=True, +                 forbid_external=True): +        _ExpatParser.__init__(self, namespaceHandling, bufsize) +        self.forbid_dtd = forbid_dtd +        self.forbid_entities = forbid_entities +        self.forbid_external = forbid_external + +    def defused_start_doctype_decl(self, name, sysid, pubid, +                                   has_internal_subset): +        raise DTDForbidden(name, sysid, pubid) + +    def defused_entity_decl(self, name, is_parameter_entity, value, base, +                            sysid, pubid, notation_name): +        raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name) + +    def defused_unparsed_entity_decl(self, name, base, sysid, pubid, +                                     notation_name): +        # expat 1.2 +        raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name) + +    def defused_external_entity_ref_handler(self, context, base, sysid, +                                            pubid): +        raise ExternalReferenceForbidden(context, base, sysid, pubid) + +    def reset(self): +        _ExpatParser.reset(self) +        parser = self._parser +        if self.forbid_dtd: +            parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl +        if self.forbid_entities: +            parser.EntityDeclHandler = self.defused_entity_decl +            parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl +        if self.forbid_external: +            parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler + + +def create_parser(*args, **kwargs): +    return DefusedExpatParser(*args, **kwargs) diff --git a/python/defusedxml/lxml.py b/python/defusedxml/lxml.py new file mode 100644 index 0000000..7f3ee0b --- /dev/null +++ b/python/defusedxml/lxml.py @@ -0,0 +1,153 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Example code for lxml.etree protection + +The code has NO protection against decompression bombs. +""" +from __future__ import print_function, absolute_import + +import threading +from lxml import etree as _etree + +from .common import DTDForbidden, EntitiesForbidden, NotSupportedError + +LXML3 = _etree.LXML_VERSION[0] >= 3 + +__origin__ = "lxml.etree" + +tostring = _etree.tostring + + +class RestrictedElement(_etree.ElementBase): +    """A restricted Element class that filters out instances of some classes +    """ +    __slots__ = () +    # blacklist = (etree._Entity, etree._ProcessingInstruction, etree._Comment) +    blacklist = _etree._Entity + +    def _filter(self, iterator): +        blacklist = self.blacklist +        for child in iterator: +            if isinstance(child, blacklist): +                continue +            yield child + +    def __iter__(self): +        iterator = super(RestrictedElement, self).__iter__() +        return self._filter(iterator) + +    def iterchildren(self, tag=None, reversed=False): +        iterator = super(RestrictedElement, self).iterchildren( +            tag=tag, reversed=reversed) +        return self._filter(iterator) + +    def iter(self, tag=None, *tags): +        iterator = super(RestrictedElement, self).iter(tag=tag, *tags) +        return self._filter(iterator) + +    def iterdescendants(self, tag=None, *tags): +        iterator = super(RestrictedElement, +                         self).iterdescendants(tag=tag, *tags) +        return self._filter(iterator) + +    def itersiblings(self, tag=None, preceding=False): +        iterator = super(RestrictedElement, self).itersiblings( +            tag=tag, preceding=preceding) +        return self._filter(iterator) + +    def getchildren(self): +        iterator = super(RestrictedElement, self).__iter__() +        return list(self._filter(iterator)) + +    def getiterator(self, tag=None): +        iterator = super(RestrictedElement, self).getiterator(tag) +        return self._filter(iterator) + + +class GlobalParserTLS(threading.local): +    """Thread local context for custom parser instances +    """ +    parser_config = { +        'resolve_entities': False, +        # 'remove_comments': True, +        # 'remove_pis': True, +    } + +    element_class = RestrictedElement + +    def createDefaultParser(self): +        parser = _etree.XMLParser(**self.parser_config) +        element_class = self.element_class +        if self.element_class is not None: +            lookup = _etree.ElementDefaultClassLookup(element=element_class) +            parser.set_element_class_lookup(lookup) +        return parser + +    def setDefaultParser(self, parser): +        self._default_parser = parser + +    def getDefaultParser(self): +        parser = getattr(self, "_default_parser", None) +        if parser is None: +            parser = self.createDefaultParser() +            self.setDefaultParser(parser) +        return parser + + +_parser_tls = GlobalParserTLS() +getDefaultParser = _parser_tls.getDefaultParser + + +def check_docinfo(elementtree, forbid_dtd=False, forbid_entities=True): +    """Check docinfo of an element tree for DTD and entity declarations + +    The check for entity declarations needs lxml 3 or newer. lxml 2.x does +    not support dtd.iterentities(). +    """ +    docinfo = elementtree.docinfo +    if docinfo.doctype: +        if forbid_dtd: +            raise DTDForbidden(docinfo.doctype, +                               docinfo.system_url, +                               docinfo.public_id) +        if forbid_entities and not LXML3: +            # lxml < 3 has no iterentities() +            raise NotSupportedError("Unable to check for entity declarations " +                                    "in lxml 2.x") + +    if forbid_entities: +        for dtd in docinfo.internalDTD, docinfo.externalDTD: +            if dtd is None: +                continue +            for entity in dtd.iterentities(): +                raise EntitiesForbidden(entity.name, entity.content, None, +                                        None, None, None) + + +def parse(source, parser=None, base_url=None, forbid_dtd=False, +          forbid_entities=True): +    if parser is None: +        parser = getDefaultParser() +    elementtree = _etree.parse(source, parser, base_url=base_url) +    check_docinfo(elementtree, forbid_dtd, forbid_entities) +    return elementtree + + +def fromstring(text, parser=None, base_url=None, forbid_dtd=False, +               forbid_entities=True): +    if parser is None: +        parser = getDefaultParser() +    rootelement = _etree.fromstring(text, parser, base_url=base_url) +    elementtree = rootelement.getroottree() +    check_docinfo(elementtree, forbid_dtd, forbid_entities) +    return rootelement + + +XML = fromstring + + +def iterparse(*args, **kwargs): +    raise NotSupportedError("defused lxml.etree.iterparse not available") diff --git a/python/defusedxml/minidom.py b/python/defusedxml/minidom.py new file mode 100644 index 0000000..0fd8684 --- /dev/null +++ b/python/defusedxml/minidom.py @@ -0,0 +1,42 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xml.dom.minidom +""" +from __future__ import print_function, absolute_import + +from xml.dom.minidom import _do_pulldom_parse +from . import expatbuilder as _expatbuilder +from . import pulldom as _pulldom + +__origin__ = "xml.dom.minidom" + + +def parse(file, parser=None, bufsize=None, forbid_dtd=False, +          forbid_entities=True, forbid_external=True): +    """Parse a file into a DOM by filename or file object.""" +    if parser is None and not bufsize: +        return _expatbuilder.parse(file, forbid_dtd=forbid_dtd, +                                   forbid_entities=forbid_entities, +                                   forbid_external=forbid_external) +    else: +        return _do_pulldom_parse(_pulldom.parse, (file,), +                                 {'parser': parser, 'bufsize': bufsize, +                                  'forbid_dtd': forbid_dtd, 'forbid_entities': forbid_entities, +                                  'forbid_external': forbid_external}) + + +def parseString(string, parser=None, forbid_dtd=False, +                forbid_entities=True, forbid_external=True): +    """Parse a file into a DOM from a string.""" +    if parser is None: +        return _expatbuilder.parseString(string, forbid_dtd=forbid_dtd, +                                         forbid_entities=forbid_entities, +                                         forbid_external=forbid_external) +    else: +        return _do_pulldom_parse(_pulldom.parseString, (string,), +                                 {'parser': parser, 'forbid_dtd': forbid_dtd, +                                  'forbid_entities': forbid_entities, +                                  'forbid_external': forbid_external}) diff --git a/python/defusedxml/pulldom.py b/python/defusedxml/pulldom.py new file mode 100644 index 0000000..fc9e466 --- /dev/null +++ b/python/defusedxml/pulldom.py @@ -0,0 +1,34 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xml.dom.pulldom +""" +from __future__ import print_function, absolute_import + +from xml.dom.pulldom import parse as _parse +from xml.dom.pulldom import parseString as _parseString +from .sax import make_parser + +__origin__ = "xml.dom.pulldom" + + +def parse(stream_or_string, parser=None, bufsize=None, forbid_dtd=False, +          forbid_entities=True, forbid_external=True): +    if parser is None: +        parser = make_parser() +        parser.forbid_dtd = forbid_dtd +        parser.forbid_entities = forbid_entities +        parser.forbid_external = forbid_external +    return _parse(stream_or_string, parser, bufsize) + + +def parseString(string, parser=None, forbid_dtd=False, +                forbid_entities=True, forbid_external=True): +    if parser is None: +        parser = make_parser() +        parser.forbid_dtd = forbid_dtd +        parser.forbid_entities = forbid_entities +        parser.forbid_external = forbid_external +    return _parseString(string, parser) diff --git a/python/defusedxml/sax.py b/python/defusedxml/sax.py new file mode 100644 index 0000000..534d0ca --- /dev/null +++ b/python/defusedxml/sax.py @@ -0,0 +1,49 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xml.sax +""" +from __future__ import print_function, absolute_import + +from xml.sax import InputSource as _InputSource +from xml.sax import ErrorHandler as _ErrorHandler + +from . import expatreader + +__origin__ = "xml.sax" + + +def parse(source, handler, errorHandler=_ErrorHandler(), forbid_dtd=False, +          forbid_entities=True, forbid_external=True): +    parser = make_parser() +    parser.setContentHandler(handler) +    parser.setErrorHandler(errorHandler) +    parser.forbid_dtd = forbid_dtd +    parser.forbid_entities = forbid_entities +    parser.forbid_external = forbid_external +    parser.parse(source) + + +def parseString(string, handler, errorHandler=_ErrorHandler(), +                forbid_dtd=False, forbid_entities=True, +                forbid_external=True): +    from io import BytesIO + +    if errorHandler is None: +        errorHandler = _ErrorHandler() +    parser = make_parser() +    parser.setContentHandler(handler) +    parser.setErrorHandler(errorHandler) +    parser.forbid_dtd = forbid_dtd +    parser.forbid_entities = forbid_entities +    parser.forbid_external = forbid_external + +    inpsrc = _InputSource() +    inpsrc.setByteStream(BytesIO(string)) +    parser.parse(inpsrc) + + +def make_parser(parser_list=[]): +    return expatreader.create_parser() diff --git a/python/defusedxml/xmlrpc.py b/python/defusedxml/xmlrpc.py new file mode 100644 index 0000000..2a456e6 --- /dev/null +++ b/python/defusedxml/xmlrpc.py @@ -0,0 +1,157 @@ +# defusedxml +# +# Copyright (c) 2013 by Christian Heimes <christian@python.org> +# Licensed to PSF under a Contributor Agreement. +# See http://www.python.org/psf/license for licensing details. +"""Defused xmlrpclib + +Also defuses gzip bomb +""" +from __future__ import print_function, absolute_import + +import io + +from .common import ( +    DTDForbidden, EntitiesForbidden, ExternalReferenceForbidden, PY3) + +if PY3: +    __origin__ = "xmlrpc.client" +    from xmlrpc.client import ExpatParser +    from xmlrpc import client as xmlrpc_client +    from xmlrpc import server as xmlrpc_server +    from xmlrpc.client import gzip_decode as _orig_gzip_decode +    from xmlrpc.client import GzipDecodedResponse as _OrigGzipDecodedResponse +else: +    __origin__ = "xmlrpclib" +    from xmlrpclib import ExpatParser +    import xmlrpclib as xmlrpc_client +    xmlrpc_server = None +    from xmlrpclib import gzip_decode as _orig_gzip_decode +    from xmlrpclib import GzipDecodedResponse as _OrigGzipDecodedResponse + +try: +    import gzip +except ImportError: +    gzip = None + + +# Limit maximum request size to prevent resource exhaustion DoS +# Also used to limit maximum amount of gzip decoded data in order to prevent +# decompression bombs +# A value of -1 or smaller disables the limit +MAX_DATA = 30 * 1024 * 1024  # 30 MB + + +def defused_gzip_decode(data, limit=None): +    """gzip encoded data -> unencoded data + +    Decode data using the gzip content encoding as described in RFC 1952 +    """ +    if not gzip: +        raise NotImplementedError +    if limit is None: +        limit = MAX_DATA +    f = io.BytesIO(data) +    gzf = gzip.GzipFile(mode="rb", fileobj=f) +    try: +        if limit < 0:  # no limit +            decoded = gzf.read() +        else: +            decoded = gzf.read(limit + 1) +    except IOError: +        raise ValueError("invalid data") +    f.close() +    gzf.close() +    if limit >= 0 and len(decoded) > limit: +        raise ValueError("max gzipped payload length exceeded") +    return decoded + + +class DefusedGzipDecodedResponse(gzip.GzipFile if gzip else object): +    """a file-like object to decode a response encoded with the gzip +    method, as described in RFC 1952. +    """ + +    def __init__(self, response, limit=None): +        # response doesn't support tell() and read(), required by +        # GzipFile +        if not gzip: +            raise NotImplementedError +        self.limit = limit = limit if limit is not None else MAX_DATA +        if limit < 0:  # no limit +            data = response.read() +            self.readlength = None +        else: +            data = response.read(limit + 1) +            self.readlength = 0 +        if limit >= 0 and len(data) > limit: +            raise ValueError("max payload length exceeded") +        self.stringio = io.BytesIO(data) +        gzip.GzipFile.__init__(self, mode="rb", fileobj=self.stringio) + +    def read(self, n): +        if self.limit >= 0: +            left = self.limit - self.readlength +            n = min(n, left + 1) +            data = gzip.GzipFile.read(self, n) +            self.readlength += len(data) +            if self.readlength > self.limit: +                raise ValueError("max payload length exceeded") +            return data +        else: +            return gzip.GzipFile.read(self, n) + +    def close(self): +        gzip.GzipFile.close(self) +        self.stringio.close() + + +class DefusedExpatParser(ExpatParser): + +    def __init__(self, target, forbid_dtd=False, forbid_entities=True, +                 forbid_external=True): +        ExpatParser.__init__(self, target) +        self.forbid_dtd = forbid_dtd +        self.forbid_entities = forbid_entities +        self.forbid_external = forbid_external +        parser = self._parser +        if self.forbid_dtd: +            parser.StartDoctypeDeclHandler = self.defused_start_doctype_decl +        if self.forbid_entities: +            parser.EntityDeclHandler = self.defused_entity_decl +            parser.UnparsedEntityDeclHandler = self.defused_unparsed_entity_decl +        if self.forbid_external: +            parser.ExternalEntityRefHandler = self.defused_external_entity_ref_handler + +    def defused_start_doctype_decl(self, name, sysid, pubid, +                                   has_internal_subset): +        raise DTDForbidden(name, sysid, pubid) + +    def defused_entity_decl(self, name, is_parameter_entity, value, base, +                            sysid, pubid, notation_name): +        raise EntitiesForbidden(name, value, base, sysid, pubid, notation_name) + +    def defused_unparsed_entity_decl(self, name, base, sysid, pubid, +                                     notation_name): +        # expat 1.2 +        raise EntitiesForbidden(name, None, base, sysid, pubid, notation_name) + +    def defused_external_entity_ref_handler(self, context, base, sysid, +                                            pubid): +        raise ExternalReferenceForbidden(context, base, sysid, pubid) + + +def monkey_patch(): +    xmlrpc_client.FastParser = DefusedExpatParser +    xmlrpc_client.GzipDecodedResponse = DefusedGzipDecodedResponse +    xmlrpc_client.gzip_decode = defused_gzip_decode +    if xmlrpc_server: +        xmlrpc_server.gzip_decode = defused_gzip_decode + + +def unmonkey_patch(): +    xmlrpc_client.FastParser = None +    xmlrpc_client.GzipDecodedResponse = _OrigGzipDecodedResponse +    xmlrpc_client.gzip_decode = _orig_gzip_decode +    if xmlrpc_server: +        xmlrpc_server.gzip_decode = _orig_gzip_decode diff --git a/python/six.py b/python/six.py new file mode 100644 index 0000000..6bf4fd3 --- /dev/null +++ b/python/six.py @@ -0,0 +1,891 @@ +# Copyright (c) 2010-2017 Benjamin Peterson +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Utilities for writing code that runs on Python 2 and 3""" + +from __future__ import absolute_import + +import functools +import itertools +import operator +import sys +import types + +__author__ = "Benjamin Peterson <benjamin@python.org>" +__version__ = "1.11.0" + + +# Useful for very coarse version differentiation. +PY2 = sys.version_info[0] == 2 +PY3 = sys.version_info[0] == 3 +PY34 = sys.version_info[0:2] >= (3, 4) + +if PY3: +    string_types = str, +    integer_types = int, +    class_types = type, +    text_type = str +    binary_type = bytes + +    MAXSIZE = sys.maxsize +else: +    string_types = basestring, +    integer_types = (int, long) +    class_types = (type, types.ClassType) +    text_type = unicode +    binary_type = str + +    if sys.platform.startswith("java"): +        # Jython always uses 32 bits. +        MAXSIZE = int((1 << 31) - 1) +    else: +        # It's possible to have sizeof(long) != sizeof(Py_ssize_t). +        class X(object): + +            def __len__(self): +                return 1 << 31 +        try: +            len(X()) +        except OverflowError: +            # 32-bit +            MAXSIZE = int((1 << 31) - 1) +        else: +            # 64-bit +            MAXSIZE = int((1 << 63) - 1) +        del X + + +def _add_doc(func, doc): +    """Add documentation to a function.""" +    func.__doc__ = doc + + +def _import_module(name): +    """Import module, returning the module after the last dot.""" +    __import__(name) +    return sys.modules[name] + + +class _LazyDescr(object): + +    def __init__(self, name): +        self.name = name + +    def __get__(self, obj, tp): +        result = self._resolve() +        setattr(obj, self.name, result)  # Invokes __set__. +        try: +            # This is a bit ugly, but it avoids running this again by +            # removing this descriptor. +            delattr(obj.__class__, self.name) +        except AttributeError: +            pass +        return result + + +class MovedModule(_LazyDescr): + +    def __init__(self, name, old, new=None): +        super(MovedModule, self).__init__(name) +        if PY3: +            if new is None: +                new = name +            self.mod = new +        else: +            self.mod = old + +    def _resolve(self): +        return _import_module(self.mod) + +    def __getattr__(self, attr): +        _module = self._resolve() +        value = getattr(_module, attr) +        setattr(self, attr, value) +        return value + + +class _LazyModule(types.ModuleType): + +    def __init__(self, name): +        super(_LazyModule, self).__init__(name) +        self.__doc__ = self.__class__.__doc__ + +    def __dir__(self): +        attrs = ["__doc__", "__name__"] +        attrs += [attr.name for attr in self._moved_attributes] +        return attrs + +    # Subclasses should override this +    _moved_attributes = [] + + +class MovedAttribute(_LazyDescr): + +    def __init__(self, name, old_mod, new_mod, old_attr=None, new_attr=None): +        super(MovedAttribute, self).__init__(name) +        if PY3: +            if new_mod is None: +                new_mod = name +            self.mod = new_mod +            if new_attr is None: +                if old_attr is None: +                    new_attr = name +                else: +                    new_attr = old_attr +            self.attr = new_attr +        else: +            self.mod = old_mod +            if old_attr is None: +                old_attr = name +            self.attr = old_attr + +    def _resolve(self): +        module = _import_module(self.mod) +        return getattr(module, self.attr) + + +class _SixMetaPathImporter(object): + +    """ +    A meta path importer to import six.moves and its submodules. + +    This class implements a PEP302 finder and loader. It should be compatible +    with Python 2.5 and all existing versions of Python3 +    """ + +    def __init__(self, six_module_name): +        self.name = six_module_name +        self.known_modules = {} + +    def _add_module(self, mod, *fullnames): +        for fullname in fullnames: +            self.known_modules[self.name + "." + fullname] = mod + +    def _get_module(self, fullname): +        return self.known_modules[self.name + "." + fullname] + +    def find_module(self, fullname, path=None): +        if fullname in self.known_modules: +            return self +        return None + +    def __get_module(self, fullname): +        try: +            return self.known_modules[fullname] +        except KeyError: +            raise ImportError("This loader does not know module " + fullname) + +    def load_module(self, fullname): +        try: +            # in case of a reload +            return sys.modules[fullname] +        except KeyError: +            pass +        mod = self.__get_module(fullname) +        if isinstance(mod, MovedModule): +            mod = mod._resolve() +        else: +            mod.__loader__ = self +        sys.modules[fullname] = mod +        return mod + +    def is_package(self, fullname): +        """ +        Return true, if the named module is a package. + +        We need this method to get correct spec objects with +        Python 3.4 (see PEP451) +        """ +        return hasattr(self.__get_module(fullname), "__path__") + +    def get_code(self, fullname): +        """Return None + +        Required, if is_package is implemented""" +        self.__get_module(fullname)  # eventually raises ImportError +        return None +    get_source = get_code  # same as get_code + +_importer = _SixMetaPathImporter(__name__) + + +class _MovedItems(_LazyModule): + +    """Lazy loading of moved objects""" +    __path__ = []  # mark as package + + +_moved_attributes = [ +    MovedAttribute("cStringIO", "cStringIO", "io", "StringIO"), +    MovedAttribute("filter", "itertools", "builtins", "ifilter", "filter"), +    MovedAttribute("filterfalse", "itertools", "itertools", "ifilterfalse", "filterfalse"), +    MovedAttribute("input", "__builtin__", "builtins", "raw_input", "input"), +    MovedAttribute("intern", "__builtin__", "sys"), +    MovedAttribute("map", "itertools", "builtins", "imap", "map"), +    MovedAttribute("getcwd", "os", "os", "getcwdu", "getcwd"), +    MovedAttribute("getcwdb", "os", "os", "getcwd", "getcwdb"), +    MovedAttribute("getoutput", "commands", "subprocess"), +    MovedAttribute("range", "__builtin__", "builtins", "xrange", "range"), +    MovedAttribute("reload_module", "__builtin__", "importlib" if PY34 else "imp", "reload"), +    MovedAttribute("reduce", "__builtin__", "functools"), +    MovedAttribute("shlex_quote", "pipes", "shlex", "quote"), +    MovedAttribute("StringIO", "StringIO", "io"), +    MovedAttribute("UserDict", "UserDict", "collections"), +    MovedAttribute("UserList", "UserList", "collections"), +    MovedAttribute("UserString", "UserString", "collections"), +    MovedAttribute("xrange", "__builtin__", "builtins", "xrange", "range"), +    MovedAttribute("zip", "itertools", "builtins", "izip", "zip"), +    MovedAttribute("zip_longest", "itertools", "itertools", "izip_longest", "zip_longest"), +    MovedModule("builtins", "__builtin__"), +    MovedModule("configparser", "ConfigParser"), +    MovedModule("copyreg", "copy_reg"), +    MovedModule("dbm_gnu", "gdbm", "dbm.gnu"), +    MovedModule("_dummy_thread", "dummy_thread", "_dummy_thread"), +    MovedModule("http_cookiejar", "cookielib", "http.cookiejar"), +    MovedModule("http_cookies", "Cookie", "http.cookies"), +    MovedModule("html_entities", "htmlentitydefs", "html.entities"), +    MovedModule("html_parser", "HTMLParser", "html.parser"), +    MovedModule("http_client", "httplib", "http.client"), +    MovedModule("email_mime_base", "email.MIMEBase", "email.mime.base"), +    MovedModule("email_mime_image", "email.MIMEImage", "email.mime.image"), +    MovedModule("email_mime_multipart", "email.MIMEMultipart", "email.mime.multipart"), +    MovedModule("email_mime_nonmultipart", "email.MIMENonMultipart", "email.mime.nonmultipart"), +    MovedModule("email_mime_text", "email.MIMEText", "email.mime.text"), +    MovedModule("BaseHTTPServer", "BaseHTTPServer", "http.server"), +    MovedModule("CGIHTTPServer", "CGIHTTPServer", "http.server"), +    MovedModule("SimpleHTTPServer", "SimpleHTTPServer", "http.server"), +    MovedModule("cPickle", "cPickle", "pickle"), +    MovedModule("queue", "Queue"), +    MovedModule("reprlib", "repr"), +    MovedModule("socketserver", "SocketServer"), +    MovedModule("_thread", "thread", "_thread"), +    MovedModule("tkinter", "Tkinter"), +    MovedModule("tkinter_dialog", "Dialog", "tkinter.dialog"), +    MovedModule("tkinter_filedialog", "FileDialog", "tkinter.filedialog"), +    MovedModule("tkinter_scrolledtext", "ScrolledText", "tkinter.scrolledtext"), +    MovedModule("tkinter_simpledialog", "SimpleDialog", "tkinter.simpledialog"), +    MovedModule("tkinter_tix", "Tix", "tkinter.tix"), +    MovedModule("tkinter_ttk", "ttk", "tkinter.ttk"), +    MovedModule("tkinter_constants", "Tkconstants", "tkinter.constants"), +    MovedModule("tkinter_dnd", "Tkdnd", "tkinter.dnd"), +    MovedModule("tkinter_colorchooser", "tkColorChooser", +                "tkinter.colorchooser"), +    MovedModule("tkinter_commondialog", "tkCommonDialog", +                "tkinter.commondialog"), +    MovedModule("tkinter_tkfiledialog", "tkFileDialog", "tkinter.filedialog"), +    MovedModule("tkinter_font", "tkFont", "tkinter.font"), +    MovedModule("tkinter_messagebox", "tkMessageBox", "tkinter.messagebox"), +    MovedModule("tkinter_tksimpledialog", "tkSimpleDialog", +                "tkinter.simpledialog"), +    MovedModule("urllib_parse", __name__ + ".moves.urllib_parse", "urllib.parse"), +    MovedModule("urllib_error", __name__ + ".moves.urllib_error", "urllib.error"), +    MovedModule("urllib", __name__ + ".moves.urllib", __name__ + ".moves.urllib"), +    MovedModule("urllib_robotparser", "robotparser", "urllib.robotparser"), +    MovedModule("xmlrpc_client", "xmlrpclib", "xmlrpc.client"), +    MovedModule("xmlrpc_server", "SimpleXMLRPCServer", "xmlrpc.server"), +] +# Add windows specific modules. +if sys.platform == "win32": +    _moved_attributes += [ +        MovedModule("winreg", "_winreg"), +    ] + +for attr in _moved_attributes: +    setattr(_MovedItems, attr.name, attr) +    if isinstance(attr, MovedModule): +        _importer._add_module(attr, "moves." + attr.name) +del attr + +_MovedItems._moved_attributes = _moved_attributes + +moves = _MovedItems(__name__ + ".moves") +_importer._add_module(moves, "moves") + + +class Module_six_moves_urllib_parse(_LazyModule): + +    """Lazy loading of moved objects in six.moves.urllib_parse""" + + +_urllib_parse_moved_attributes = [ +    MovedAttribute("ParseResult", "urlparse", "urllib.parse"), +    MovedAttribute("SplitResult", "urlparse", "urllib.parse"), +    MovedAttribute("parse_qs", "urlparse", "urllib.parse"), +    MovedAttribute("parse_qsl", "urlparse", "urllib.parse"), +    MovedAttribute("urldefrag", "urlparse", "urllib.parse"), +    MovedAttribute("urljoin", "urlparse", "urllib.parse"), +    MovedAttribute("urlparse", "urlparse", "urllib.parse"), +    MovedAttribute("urlsplit", "urlparse", "urllib.parse"), +    MovedAttribute("urlunparse", "urlparse", "urllib.parse"), +    MovedAttribute("urlunsplit", "urlparse", "urllib.parse"), +    MovedAttribute("quote", "urllib", "urllib.parse"), +    MovedAttribute("quote_plus", "urllib", "urllib.parse"), +    MovedAttribute("unquote", "urllib", "urllib.parse"), +    MovedAttribute("unquote_plus", "urllib", "urllib.parse"), +    MovedAttribute("unquote_to_bytes", "urllib", "urllib.parse", "unquote", "unquote_to_bytes"), +    MovedAttribute("urlencode", "urllib", "urllib.parse"), +    MovedAttribute("splitquery", "urllib", "urllib.parse"), +    MovedAttribute("splittag", "urllib", "urllib.parse"), +    MovedAttribute("splituser", "urllib", "urllib.parse"), +    MovedAttribute("splitvalue", "urllib", "urllib.parse"), +    MovedAttribute("uses_fragment", "urlparse", "urllib.parse"), +    MovedAttribute("uses_netloc", "urlparse", "urllib.parse"), +    MovedAttribute("uses_params", "urlparse", "urllib.parse"), +    MovedAttribute("uses_query", "urlparse", "urllib.parse"), +    MovedAttribute("uses_relative", "urlparse", "urllib.parse"), +] +for attr in _urllib_parse_moved_attributes: +    setattr(Module_six_moves_urllib_parse, attr.name, attr) +del attr + +Module_six_moves_urllib_parse._moved_attributes = _urllib_parse_moved_attributes + +_importer._add_module(Module_six_moves_urllib_parse(__name__ + ".moves.urllib_parse"), +                      "moves.urllib_parse", "moves.urllib.parse") + + +class Module_six_moves_urllib_error(_LazyModule): + +    """Lazy loading of moved objects in six.moves.urllib_error""" + + +_urllib_error_moved_attributes = [ +    MovedAttribute("URLError", "urllib2", "urllib.error"), +    MovedAttribute("HTTPError", "urllib2", "urllib.error"), +    MovedAttribute("ContentTooShortError", "urllib", "urllib.error"), +] +for attr in _urllib_error_moved_attributes: +    setattr(Module_six_moves_urllib_error, attr.name, attr) +del attr + +Module_six_moves_urllib_error._moved_attributes = _urllib_error_moved_attributes + +_importer._add_module(Module_six_moves_urllib_error(__name__ + ".moves.urllib.error"), +                      "moves.urllib_error", "moves.urllib.error") + + +class Module_six_moves_urllib_request(_LazyModule): + +    """Lazy loading of moved objects in six.moves.urllib_request""" + + +_urllib_request_moved_attributes = [ +    MovedAttribute("urlopen", "urllib2", "urllib.request"), +    MovedAttribute("install_opener", "urllib2", "urllib.request"), +    MovedAttribute("build_opener", "urllib2", "urllib.request"), +    MovedAttribute("pathname2url", "urllib", "urllib.request"), +    MovedAttribute("url2pathname", "urllib", "urllib.request"), +    MovedAttribute("getproxies", "urllib", "urllib.request"), +    MovedAttribute("Request", "urllib2", "urllib.request"), +    MovedAttribute("OpenerDirector", "urllib2", "urllib.request"), +    MovedAttribute("HTTPDefaultErrorHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPRedirectHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPCookieProcessor", "urllib2", "urllib.request"), +    MovedAttribute("ProxyHandler", "urllib2", "urllib.request"), +    MovedAttribute("BaseHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPPasswordMgr", "urllib2", "urllib.request"), +    MovedAttribute("HTTPPasswordMgrWithDefaultRealm", "urllib2", "urllib.request"), +    MovedAttribute("AbstractBasicAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPBasicAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("ProxyBasicAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("AbstractDigestAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPDigestAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("ProxyDigestAuthHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPSHandler", "urllib2", "urllib.request"), +    MovedAttribute("FileHandler", "urllib2", "urllib.request"), +    MovedAttribute("FTPHandler", "urllib2", "urllib.request"), +    MovedAttribute("CacheFTPHandler", "urllib2", "urllib.request"), +    MovedAttribute("UnknownHandler", "urllib2", "urllib.request"), +    MovedAttribute("HTTPErrorProcessor", "urllib2", "urllib.request"), +    MovedAttribute("urlretrieve", "urllib", "urllib.request"), +    MovedAttribute("urlcleanup", "urllib", "urllib.request"), +    MovedAttribute("URLopener", "urllib", "urllib.request"), +    MovedAttribute("FancyURLopener", "urllib", "urllib.request"), +    MovedAttribute("proxy_bypass", "urllib", "urllib.request"), +    MovedAttribute("parse_http_list", "urllib2", "urllib.request"), +    MovedAttribute("parse_keqv_list", "urllib2", "urllib.request"), +] +for attr in _urllib_request_moved_attributes: +    setattr(Module_six_moves_urllib_request, attr.name, attr) +del attr + +Module_six_moves_urllib_request._moved_attributes = _urllib_request_moved_attributes + +_importer._add_module(Module_six_moves_urllib_request(__name__ + ".moves.urllib.request"), +                      "moves.urllib_request", "moves.urllib.request") + + +class Module_six_moves_urllib_response(_LazyModule): + +    """Lazy loading of moved objects in six.moves.urllib_response""" + + +_urllib_response_moved_attributes = [ +    MovedAttribute("addbase", "urllib", "urllib.response"), +    MovedAttribute("addclosehook", "urllib", "urllib.response"), +    MovedAttribute("addinfo", "urllib", "urllib.response"), +    MovedAttribute("addinfourl", "urllib", "urllib.response"), +] +for attr in _urllib_response_moved_attributes: +    setattr(Module_six_moves_urllib_response, attr.name, attr) +del attr + +Module_six_moves_urllib_response._moved_attributes = _urllib_response_moved_attributes + +_importer._add_module(Module_six_moves_urllib_response(__name__ + ".moves.urllib.response"), +                      "moves.urllib_response", "moves.urllib.response") + + +class Module_six_moves_urllib_robotparser(_LazyModule): + +    """Lazy loading of moved objects in six.moves.urllib_robotparser""" + + +_urllib_robotparser_moved_attributes = [ +    MovedAttribute("RobotFileParser", "robotparser", "urllib.robotparser"), +] +for attr in _urllib_robotparser_moved_attributes: +    setattr(Module_six_moves_urllib_robotparser, attr.name, attr) +del attr + +Module_six_moves_urllib_robotparser._moved_attributes = _urllib_robotparser_moved_attributes + +_importer._add_module(Module_six_moves_urllib_robotparser(__name__ + ".moves.urllib.robotparser"), +                      "moves.urllib_robotparser", "moves.urllib.robotparser") + + +class Module_six_moves_urllib(types.ModuleType): + +    """Create a six.moves.urllib namespace that resembles the Python 3 namespace""" +    __path__ = []  # mark as package +    parse = _importer._get_module("moves.urllib_parse") +    error = _importer._get_module("moves.urllib_error") +    request = _importer._get_module("moves.urllib_request") +    response = _importer._get_module("moves.urllib_response") +    robotparser = _importer._get_module("moves.urllib_robotparser") + +    def __dir__(self): +        return ['parse', 'error', 'request', 'response', 'robotparser'] + +_importer._add_module(Module_six_moves_urllib(__name__ + ".moves.urllib"), +                      "moves.urllib") + + +def add_move(move): +    """Add an item to six.moves.""" +    setattr(_MovedItems, move.name, move) + + +def remove_move(name): +    """Remove item from six.moves.""" +    try: +        delattr(_MovedItems, name) +    except AttributeError: +        try: +            del moves.__dict__[name] +        except KeyError: +            raise AttributeError("no such move, %r" % (name,)) + + +if PY3: +    _meth_func = "__func__" +    _meth_self = "__self__" + +    _func_closure = "__closure__" +    _func_code = "__code__" +    _func_defaults = "__defaults__" +    _func_globals = "__globals__" +else: +    _meth_func = "im_func" +    _meth_self = "im_self" + +    _func_closure = "func_closure" +    _func_code = "func_code" +    _func_defaults = "func_defaults" +    _func_globals = "func_globals" + + +try: +    advance_iterator = next +except NameError: +    def advance_iterator(it): +        return it.next() +next = advance_iterator + + +try: +    callable = callable +except NameError: +    def callable(obj): +        return any("__call__" in klass.__dict__ for klass in type(obj).__mro__) + + +if PY3: +    def get_unbound_function(unbound): +        return unbound + +    create_bound_method = types.MethodType + +    def create_unbound_method(func, cls): +        return func + +    Iterator = object +else: +    def get_unbound_function(unbound): +        return unbound.im_func + +    def create_bound_method(func, obj): +        return types.MethodType(func, obj, obj.__class__) + +    def create_unbound_method(func, cls): +        return types.MethodType(func, None, cls) + +    class Iterator(object): + +        def next(self): +            return type(self).__next__(self) + +    callable = callable +_add_doc(get_unbound_function, +         """Get the function out of a possibly unbound function""") + + +get_method_function = operator.attrgetter(_meth_func) +get_method_self = operator.attrgetter(_meth_self) +get_function_closure = operator.attrgetter(_func_closure) +get_function_code = operator.attrgetter(_func_code) +get_function_defaults = operator.attrgetter(_func_defaults) +get_function_globals = operator.attrgetter(_func_globals) + + +if PY3: +    def iterkeys(d, **kw): +        return iter(d.keys(**kw)) + +    def itervalues(d, **kw): +        return iter(d.values(**kw)) + +    def iteritems(d, **kw): +        return iter(d.items(**kw)) + +    def iterlists(d, **kw): +        return iter(d.lists(**kw)) + +    viewkeys = operator.methodcaller("keys") + +    viewvalues = operator.methodcaller("values") + +    viewitems = operator.methodcaller("items") +else: +    def iterkeys(d, **kw): +        return d.iterkeys(**kw) + +    def itervalues(d, **kw): +        return d.itervalues(**kw) + +    def iteritems(d, **kw): +        return d.iteritems(**kw) + +    def iterlists(d, **kw): +        return d.iterlists(**kw) + +    viewkeys = operator.methodcaller("viewkeys") + +    viewvalues = operator.methodcaller("viewvalues") + +    viewitems = operator.methodcaller("viewitems") + +_add_doc(iterkeys, "Return an iterator over the keys of a dictionary.") +_add_doc(itervalues, "Return an iterator over the values of a dictionary.") +_add_doc(iteritems, +         "Return an iterator over the (key, value) pairs of a dictionary.") +_add_doc(iterlists, +         "Return an iterator over the (key, [values]) pairs of a dictionary.") + + +if PY3: +    def b(s): +        return s.encode("latin-1") + +    def u(s): +        return s +    unichr = chr +    import struct +    int2byte = struct.Struct(">B").pack +    del struct +    byte2int = operator.itemgetter(0) +    indexbytes = operator.getitem +    iterbytes = iter +    import io +    StringIO = io.StringIO +    BytesIO = io.BytesIO +    _assertCountEqual = "assertCountEqual" +    if sys.version_info[1] <= 1: +        _assertRaisesRegex = "assertRaisesRegexp" +        _assertRegex = "assertRegexpMatches" +    else: +        _assertRaisesRegex = "assertRaisesRegex" +        _assertRegex = "assertRegex" +else: +    def b(s): +        return s +    # Workaround for standalone backslash + +    def u(s): +        return unicode(s.replace(r'\\', r'\\\\'), "unicode_escape") +    unichr = unichr +    int2byte = chr + +    def byte2int(bs): +        return ord(bs[0]) + +    def indexbytes(buf, i): +        return ord(buf[i]) +    iterbytes = functools.partial(itertools.imap, ord) +    import StringIO +    StringIO = BytesIO = StringIO.StringIO +    _assertCountEqual = "assertItemsEqual" +    _assertRaisesRegex = "assertRaisesRegexp" +    _assertRegex = "assertRegexpMatches" +_add_doc(b, """Byte literal""") +_add_doc(u, """Text literal""") + + +def assertCountEqual(self, *args, **kwargs): +    return getattr(self, _assertCountEqual)(*args, **kwargs) + + +def assertRaisesRegex(self, *args, **kwargs): +    return getattr(self, _assertRaisesRegex)(*args, **kwargs) + + +def assertRegex(self, *args, **kwargs): +    return getattr(self, _assertRegex)(*args, **kwargs) + + +if PY3: +    exec_ = getattr(moves.builtins, "exec") + +    def reraise(tp, value, tb=None): +        try: +            if value is None: +                value = tp() +            if value.__traceback__ is not tb: +                raise value.with_traceback(tb) +            raise value +        finally: +            value = None +            tb = None + +else: +    def exec_(_code_, _globs_=None, _locs_=None): +        """Execute code in a namespace.""" +        if _globs_ is None: +            frame = sys._getframe(1) +            _globs_ = frame.f_globals +            if _locs_ is None: +                _locs_ = frame.f_locals +            del frame +        elif _locs_ is None: +            _locs_ = _globs_ +        exec("""exec _code_ in _globs_, _locs_""") + +    exec_("""def reraise(tp, value, tb=None): +    try: +        raise tp, value, tb +    finally: +        tb = None +""") + + +if sys.version_info[:2] == (3, 2): +    exec_("""def raise_from(value, from_value): +    try: +        if from_value is None: +            raise value +        raise value from from_value +    finally: +        value = None +""") +elif sys.version_info[:2] > (3, 2): +    exec_("""def raise_from(value, from_value): +    try: +        raise value from from_value +    finally: +        value = None +""") +else: +    def raise_from(value, from_value): +        raise value + + +print_ = getattr(moves.builtins, "print", None) +if print_ is None: +    def print_(*args, **kwargs): +        """The new-style print function for Python 2.4 and 2.5.""" +        fp = kwargs.pop("file", sys.stdout) +        if fp is None: +            return + +        def write(data): +            if not isinstance(data, basestring): +                data = str(data) +            # If the file has an encoding, encode unicode with it. +            if (isinstance(fp, file) and +                    isinstance(data, unicode) and +                    fp.encoding is not None): +                errors = getattr(fp, "errors", None) +                if errors is None: +                    errors = "strict" +                data = data.encode(fp.encoding, errors) +            fp.write(data) +        want_unicode = False +        sep = kwargs.pop("sep", None) +        if sep is not None: +            if isinstance(sep, unicode): +                want_unicode = True +            elif not isinstance(sep, str): +                raise TypeError("sep must be None or a string") +        end = kwargs.pop("end", None) +        if end is not None: +            if isinstance(end, unicode): +                want_unicode = True +            elif not isinstance(end, str): +                raise TypeError("end must be None or a string") +        if kwargs: +            raise TypeError("invalid keyword arguments to print()") +        if not want_unicode: +            for arg in args: +                if isinstance(arg, unicode): +                    want_unicode = True +                    break +        if want_unicode: +            newline = unicode("\n") +            space = unicode(" ") +        else: +            newline = "\n" +            space = " " +        if sep is None: +            sep = space +        if end is None: +            end = newline +        for i, arg in enumerate(args): +            if i: +                write(sep) +            write(arg) +        write(end) +if sys.version_info[:2] < (3, 3): +    _print = print_ + +    def print_(*args, **kwargs): +        fp = kwargs.get("file", sys.stdout) +        flush = kwargs.pop("flush", False) +        _print(*args, **kwargs) +        if flush and fp is not None: +            fp.flush() + +_add_doc(reraise, """Reraise an exception.""") + +if sys.version_info[0:2] < (3, 4): +    def wraps(wrapped, assigned=functools.WRAPPER_ASSIGNMENTS, +              updated=functools.WRAPPER_UPDATES): +        def wrapper(f): +            f = functools.wraps(wrapped, assigned, updated)(f) +            f.__wrapped__ = wrapped +            return f +        return wrapper +else: +    wraps = functools.wraps + + +def with_metaclass(meta, *bases): +    """Create a base class with a metaclass.""" +    # This requires a bit of explanation: the basic idea is to make a dummy +    # metaclass for one level of class instantiation that replaces itself with +    # the actual metaclass. +    class metaclass(type): + +        def __new__(cls, name, this_bases, d): +            return meta(name, bases, d) + +        @classmethod +        def __prepare__(cls, name, this_bases): +            return meta.__prepare__(name, bases) +    return type.__new__(metaclass, 'temporary_class', (), {}) + + +def add_metaclass(metaclass): +    """Class decorator for creating a class with a metaclass.""" +    def wrapper(cls): +        orig_vars = cls.__dict__.copy() +        slots = orig_vars.get('__slots__') +        if slots is not None: +            if isinstance(slots, str): +                slots = [slots] +            for slots_var in slots: +                orig_vars.pop(slots_var) +        orig_vars.pop('__dict__', None) +        orig_vars.pop('__weakref__', None) +        return metaclass(cls.__name__, cls.__bases__, orig_vars) +    return wrapper + + +def python_2_unicode_compatible(klass): +    """ +    A decorator that defines __unicode__ and __str__ methods under Python 2. +    Under Python 3 it does nothing. + +    To support Python 2 and 3 with a single code base, define a __str__ method +    returning text and apply this decorator to the class. +    """ +    if PY2: +        if '__str__' not in klass.__dict__: +            raise ValueError("@python_2_unicode_compatible cannot be applied " +                             "to %s because it doesn't define __str__()." % +                             klass.__name__) +        klass.__unicode__ = klass.__str__ +        klass.__str__ = lambda self: self.__unicode__().encode('utf-8') +    return klass + + +# Complete the moves implementation. +# This code is at the end of this module to speed up module loading. +# Turn this module into a package. +__path__ = []  # required for PEP 302 and PEP 451 +__package__ = __name__  # see PEP 366 @ReservedAssignment +if globals().get("__spec__") is not None: +    __spec__.submodule_search_locations = []  # PEP 451 @UndefinedVariable +# Remove other six meta path importers, since they cause problems. This can +# happen if six is removed from sys.modules and then reloaded. (Setuptools does +# this for some reason.) +if sys.meta_path: +    for i, importer in enumerate(sys.meta_path): +        # Here's some real nastiness: Another "instance" of the six module might +        # be floating around. Therefore, we can't use isinstance() to check for +        # the six meta path importer, since the other six instance will have +        # inserted an importer with different class. +        if (type(importer).__name__ == "_SixMetaPathImporter" and +                importer.name == __name__): +            del sys.meta_path[i] +            break +    del i, importer +# Finally, add the importer to the meta path import hook. +sys.meta_path.append(_importer) diff --git a/youtube/channel.py b/youtube/channel.py index 9577525..c83d7d1 100644 --- a/youtube/channel.py +++ b/youtube/channel.py @@ -248,6 +248,7 @@ def channel_videos_html(polymer_json, current_page=1, current_sort=3, number_of_      return yt_channel_items_template.substitute(          header              = common.get_header(),          channel_title       = microformat['title'], +        channel_id          = channel_id,          channel_tabs        = channel_tabs_html(channel_id, 'Videos'),          sort_buttons        = channel_sort_buttons_html(channel_id, 'videos', current_sort),          avatar              = '/' + microformat['thumbnail']['thumbnails'][0]['url'], @@ -269,6 +270,7 @@ def channel_playlists_html(polymer_json, current_sort=3):      return yt_channel_items_template.substitute(          header              = common.get_header(),          channel_title       = microformat['title'], +        channel_id          = channel_id,          channel_tabs        = channel_tabs_html(channel_id, 'Playlists'),          sort_buttons        = channel_sort_buttons_html(channel_id, 'playlists', current_sort),          avatar              = '/' + microformat['thumbnail']['thumbnails'][0]['url'], @@ -333,6 +335,7 @@ def channel_about_page(polymer_json):          description         = description,          links               = channel_links,          stats               = stats, +        channel_id          = channel_metadata['channelId'],          channel_tabs        = channel_tabs_html(channel_metadata['channelId'], 'About'),      ) @@ -353,6 +356,7 @@ def channel_search_page(polymer_json, query, current_page=1, number_of_videos =      return yt_channel_items_template.substitute(          header              = common.get_header(),          channel_title       = html.escape(microformat['title']), +        channel_id          = channel_id,          channel_tabs        = channel_tabs_html(channel_id, '', query),          avatar              = '/' + microformat['thumbnail']['thumbnails'][0]['url'],          page_title          = html.escape(query + ' - Channel search'), diff --git a/youtube/subscriptions.py b/youtube/subscriptions.py index 82916dd..ff7d0df 100644 --- a/youtube/subscriptions.py +++ b/youtube/subscriptions.py @@ -5,6 +5,10 @@ import sqlite3  import os  import secrets  import datetime +import itertools +import time +import urllib +import socks, sockshandler  # so as to not completely break on people who have updated but don't know of new dependency  try: @@ -51,11 +55,16 @@ def open_database():      return connection -def _subscribe(channel_id, channel_name): +def _subscribe(channels): +    ''' channels is a list of (channel_id, channel_name) ''' + +    # set time_last_checked to 0 on all channels being subscribed to +    channels = ( (channel_id, channel_name, 0) for channel_id, channel_name in channels) +      connection = open_database()      try:          cursor = connection.cursor() -        cursor.execute("INSERT INTO subscribed_channels (channel_id, name) VALUES (?, ?)", (channel_id, channel_name)) +        cursor.executemany("INSERT INTO subscribed_channels (channel_id, channel_name, time_last_checked) VALUES (?, ?, ?)", channels)          connection.commit()      except:          connection.rollback() @@ -63,11 +72,12 @@ def _subscribe(channel_id, channel_name):      finally:          connection.close() -def _unsubscribe(channel_id): +def _unsubscribe(channel_ids): +    ''' channel_ids is a list of channel_ids '''      connection = open_database()      try:          cursor = connection.cursor() -        cursor.execute("DELETE FROM subscribed_channels WHERE channel_id=?", (channel_id, )) +        cursor.executemany("DELETE FROM subscribed_channels WHERE channel_id=?", ((channel_id, ) for channel_id in channel_ids))          connection.commit()      except:          connection.rollback() @@ -125,12 +135,14 @@ def youtube_timestamp_to_posix(dumb_timestamp):  weekdays = ('Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun')  months = ('Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec') -def _get_upstream_videos(channel_id, channel_name, time_last_checked): +def _get_upstream_videos(channel_id, time_last_checked):      feed_url = "https://www.youtube.com/feeds/videos.xml?channel_id=" + channel_id      headers = {}      # randomly change time_last_checked up to one day earlier to make tracking harder      time_last_checked = time_last_checked - secrets.randbelow(24*3600) +    if time_last_checked < 0:   # happens when time_last_checked is initialized to 0 when checking for first time +        time_last_checked = 0      # If-Modified-Since header: https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/If-Modified-Since      struct_time = time.gmtime(time_last_checked) @@ -142,7 +154,7 @@ def _get_upstream_videos(channel_id, channel_name, time_last_checked):      headers['User-Agent'] = 'Python-urllib'     # Don't leak python version      headers['Accept-Encoding'] = 'gzip, br' -    req = urllib.request.Request(url, headers=headers) +    req = urllib.request.Request(feed_url, headers=headers)      if settings.route_tor:          opener = urllib.request.build_opener(sockshandler.SocksiPyHandler(socks.PROXY_TYPE_SOCKS5, "127.0.0.1", 9150))      else: @@ -165,13 +177,10 @@ def _get_upstream_videos(channel_id, channel_name, time_last_checked):      for entry in feed.entries:          video_id = entry.id_[9:]     # example of id_: yt:video:q6EoRBvdVPQ -        # standard names used in this program for purposes of html templating          atom_videos[video_id] = {              'title': entry.title.value, -            'author': entry.authors[0].name,              #'description': '',              # Not supported by atoma              #'duration': '',                 # Youtube's atom feeds don't provide it.. very frustrating -            'published':    entry.published.strftime('%m/%d/%Y'),              'time_published':   int(entry.published.timestamp()),          } @@ -182,12 +191,13 @@ def _get_upstream_videos(channel_id, channel_name, time_last_checked):      # Now check channel page to retrieve missing information for videos      json_channel_videos = channel.get_grid_items(channel.get_channel_tab(channel_id)[1]['response'])      for json_video in json_channel_videos: -        info = renderer_info(json_video) +        info = common.renderer_info(json_video['gridVideoRenderer']) +        if 'description' not in info: +            info['description'] = ''          if info['id'] in atom_videos:              info.update(atom_videos[info['id']])          else: -            info['author'] = channel_name -            info['time published'] = youtube_timestamp_to_posix(info['published']) +            info['time_published'] = youtube_timestamp_to_posix(info['published'])          videos.append(info)      return videos @@ -195,7 +205,7 @@ def get_subscriptions_page(env, start_response):      items_html = '''<nav class="item-grid">\n'''      for item in _get_videos(30, 0): -        items_html += common.video_item_html(info, common.small_video_item_template) +        items_html += common.video_item_html(item, common.small_video_item_template)      items_html += '''\n</nav>'''      start_response('200 OK', [('Content-type','text/html'),]) @@ -205,3 +215,38 @@ def get_subscriptions_page(env, start_response):          page_buttons = '',      ).encode('utf-8') +def post_subscriptions_page(env, start_response): +    params = env['parameters'] +    action = params['action'][0] +    if action == 'subscribe': +        if len(params['channel_id']) != len(params['channel_name']): +            start_response('400 Bad Request', ()) +            return b'400 Bad Request, length of channel_id != length of channel_name' +        _subscribe(zip(params['channel_id'], params['channel_name'])) + +    elif action == 'unsubscribe': +        _unsubscribe(params['channel_id']) + +    elif action == 'refresh': +        connection = open_database() +        try: +            cursor = connection.cursor() +            for uploader_id, channel_id, time_last_checked in cursor.execute('''SELECT id, channel_id, time_last_checked FROM subscribed_channels'''): +                db_videos = ( (uploader_id, info['id'], info['title'], info['duration'], info['time_published'], info['description']) for info in _get_upstream_videos(channel_id, time_last_checked) ) +                cursor.executemany('''INSERT INTO videos (uploader_id, video_id, title, duration, time_published, description) VALUES (?, ?, ?, ?, ?, ?)''', db_videos) + +            cursor.execute('''UPDATE subscribed_channels SET time_last_checked = ?''', ( int(time.time()), ) ) +            connection.commit() +        except: +            connection.rollback() +            raise +        finally: +            connection.close() + +        start_response('303 See Other', [('Location', common.URL_ORIGIN + '/subscriptions'),] ) +        return b'' +    else: +        start_response('400 Bad Request', ()) +        return b'400 Bad Request' +    start_response('204 No Content', ()) +    return b'' diff --git a/youtube/youtube.py b/youtube/youtube.py index ad73a6e..288f68b 100644 --- a/youtube/youtube.py +++ b/youtube/youtube.py @@ -35,6 +35,8 @@ post_handlers = {      'comments':         post_comment.post_comment,      'post_comment':     post_comment.post_comment,      'delete_comment':   post_comment.delete_comment, + +    'subscriptions':    subscriptions.post_subscriptions_page,  }  def youtube(env, start_response): diff --git a/yt_channel_about_template.html b/yt_channel_about_template.html index 221b838..6ed7a03 100644 --- a/yt_channel_about_template.html +++ b/yt_channel_about_template.html @@ -18,12 +18,16 @@                      height:200px;                      width:200px;                  } -                main .title{ +                .metadata{                      grid-row:1; -                    grid-column:2;                 +                    grid-column:2; +                    margin-left: 10px; +                    display:grid; +                    align-content: start; +                    grid-row-gap:10px;                  } +                  main .channel-tabs{ -                    grid-row:2;                      grid-column: 1 / span 2;                      display:grid; @@ -34,7 +38,6 @@                      padding: 3px;                  }                  main .channel-info{ -                    grid-row: 3;                      grid-column: 1 / span 3;                  }                  .tab{ @@ -51,7 +54,15 @@  $header          <main>                     <img class="avatar" src="$avatar"> -            <h2 class="title">$channel_title</h2> +            <div class="metadata"> +                <h2 class="title">$channel_title</h2> +                <form method="POST" action="/youtube.com/subscriptions" class="subscribe"> +                    <input type="submit" value="Subscribe"> +                    <input type="hidden" name="channel_id" value="$channel_id"> +                    <input type="hidden" name="channel_name" value="$channel_title"> +                    <input type="hidden" name="action" value="subscribe"> +                </form> +            </div>              <nav class="channel-tabs">  $channel_tabs              </nav> diff --git a/yt_channel_items_template.html b/yt_channel_items_template.html index 1a8551d..93c4b0a 100644 --- a/yt_channel_items_template.html +++ b/yt_channel_items_template.html @@ -18,12 +18,15 @@                      height:200px;                      width:200px;                  } -                main .title{ +                .metadata{                      grid-row:1; -                    grid-column:2;                 +                    grid-column:2; +                    margin-left: 10px; +                    display:grid; +                    align-content: start; +                    grid-row-gap:10px;                  }                  main .channel-tabs{ -                    grid-row:2;                      grid-column: 1 / span 2;                      display:grid; @@ -48,7 +51,6 @@                          font-weight:bold;                      }                  .item-grid{ -                    grid-row:4;                      grid-column: 1 / span 2;                  }                  .item-list{ @@ -68,7 +70,15 @@  $header          <main>                     <img class="avatar" src="$avatar"> -            <h2 class="title">$channel_title</h2> +            <div class="metadata"> +                <h2 class="title">$channel_title</h2> +                <form method="POST" action="/youtube.com/subscriptions" class="subscribe"> +                    <input type="submit" value="Subscribe"> +                    <input type="hidden" name="channel_id" value="$channel_id"> +                    <input type="hidden" name="channel_name" value="$channel_title"> +                    <input type="hidden" name="action" value="subscribe"> +                </form> +            </div>              <nav class="channel-tabs">  $channel_tabs              </nav> diff --git a/yt_subscriptions_template.html b/yt_subscriptions_template.html index 8477d25..6395b6c 100644 --- a/yt_subscriptions_template.html +++ b/yt_subscriptions_template.html @@ -15,6 +15,10 @@      <body>  $header          <main> +            <form method="POST"> +                <input type="submit" value="refresh"> +                <input type="hidden" name="action" value="refresh"> +            </form>  $items              <nav class="page-button-row">  $page_buttons | 
