#!/usr/bin/env python3

"""qmauth.py

Extracts and delivers information about emails from a maildir.
Runs with elevated privileges.

This program is started by qmail-authuser with elevated privileges after
a successful login.

The run method is provided by a command line argument.
Additional data is read from STDIN as protobuf.
Output is delivered via STDOUT as protobuf and log information via STDERR.

Exit codes::

      1  reserved
      2  reserved
      3  operational error       (error message in output)
      4  user error              (no output)
      5  issue switching to user (no output)
    110  reserved
    111  reserved
"""

import email.parser
import email.policy
import logging
import re
from argparse import ArgumentParser
from base64 import b64encode
from datetime import datetime
from email.message import EmailMessage
from itertools import islice
from mailbox import Maildir, MaildirMessage
from os import environ, getpid, mkdir, path, setuid
from pathlib import Path
from pwd import getpwnam
from sys import exit as sysexit
from sys import stdin, stdout

import jwebmail.model.jwebmail_pb2 as jwebmail


class MyMaildir(Maildir):
    """Maildir subclass with selectable message factories.

    On top of ``Maildir`` it offers the on-disk file name of a message,
    a ``get_folder`` that returns instances of this class and a table of
    contents without entries whose key starts with a dot.
    """

    def __init__(self, dirname, *args, **kwargs):
        """Remember *dirname* as a ``Path`` and initialise the base class."""
        self.__path = Path(dirname)
        self.set_msgtype("MaildirMessage")

        super().__init__(dirname, *args, **kwargs)

    def set_msgtype(self, typ):
        """Select how messages read from this mailbox are represented.

        Supported values of *typ*:

        ``"MaildirMessage"``
            no factory is installed.
        ``"EmailMessage"``
            the whole file is parsed with ``email.policy.default``.
        ``"Maildir(EmailMessageHeader)"``
            only the headers are parsed; the result is wrapped in a
            ``MaildirMessage`` carrying the info part of the file name.
        ``"None"``
            an empty ``MaildirMessage`` is returned without parsing.

        Raises ``ValueError`` for any other value.
        """

        def whole_message(mail_file):
            parser = email.parser.BytesParser(policy=email.policy.default)
            return parser.parse(mail_file)

        def headers_only(mail_file):
            # Mirrors what the base class does for its default message
            # type: the info part after the colon is taken from the name
            # of the underlying file.
            file_name = Path(mail_file._file.name).name
            parser = email.parser.BytesHeaderParser(policy=email.policy.default)
            wrapped = MaildirMessage(parser.parse(mail_file))
            if self.colon in file_name:
                wrapped.set_info(file_name.split(self.colon)[-1])
            return wrapped

        def empty_message(_file):
            return MaildirMessage(None)

        factories = {
            "MaildirMessage": None,
            "EmailMessage": whole_message,
            "Maildir(EmailMessageHeader)": headers_only,
            "None": empty_message,
        }
        try:
            self._factory = factories[typ]
        except (KeyError, TypeError):
            raise ValueError(typ) from None

    def get_filename(self, mid):
        """Return the path of the file that stores the message *mid*."""
        relative = self._lookup(mid)
        return self.__path / relative

    def get_folder(self, folder):
        """Return the sub-folder *folder* as an instance of this class.

        ``Maildir.get_folder`` always hands out a plain ``Maildir``; this
        override builds the folder with ``type(self)`` instead. State
        added by further subclasses is not carried over.
        """
        folder_dir = path.join(self._path, "." + folder)
        return type(self)(folder_dir, create=False)

    def _refresh(self):
        """Refresh the table of contents and drop keys starting with a dot."""
        super()._refresh()
        # Collect first: the mapping must not change size while iterating.
        for hidden in [key for key in self._toc if key.startswith(".")]:
            del self._toc[hidden]


class QMAuthError(Exception):
    """Error raised by the request handlers of this program.

    Attributes:
        msg: short description of the failure.
        info: keyword arguments given at construction, holding details
            about the failure (for example the offending mail id).
    """

    def __init__(self, msg, **args):
        # Hand the message to Exception so that ``args``, ``str()`` and
        # logged tracebacks are not empty.
        super().__init__(msg)
        self.msg = msg
        self.info = args

    def __str__(self):
        """Return the message followed by the detail values, if any."""
        if not self.info:
            return str(self.msg)
        details = ", ".join(f"{key}={value!r}" for key, value in self.info.items())
        return f"{self.msg} ({details})"


def _adr(addrs):
    """Convert an address header into a list of ``MailAddr`` messages.

    *addrs* is either ``None`` (header absent), which yields an empty
    list, or an object whose ``addresses`` attribute holds entries with
    ``addr_spec`` and ``display_name``.
    """
    if addrs is None:
        return []

    converted = []
    for entry in addrs.addresses:
        converted.append(
            jwebmail.MailHeader.MailAddr(
                address=entry.addr_spec,
                name=entry.display_name,
            )
        )
    return converted


def _get_rcv_time(mid):
    """Return the leading timestamp of the mail id *mid* as a float.

    The part of *mid* in front of the first dot is converted with
    ``float``.

    Raises:
        ValueError: if *mid* has no dot, starts with a dot or its first
            component is not a number.
    """
    seconds, dot, _rest = mid.partition(".")
    # Explicit check instead of ``assert`` so that validation also
    # happens when Python runs with -O.
    if not dot or not seconds:
        raise ValueError(f"malformed mail id {mid!r}")
    return float(seconds)


def startup(maildir, su, user, mode):
    """Drop privileges to the OS user *su* and open the maildir of *user*.

    Args:
        maildir: ``Path`` of the directory holding the per-user maildirs.
        su: name of the OS user to switch to; must not resolve to uid 0.
        user: name of the mail user, i.e. the directory below *maildir*.
        mode: requested method; currently unused.

    Returns:
        The opened ``MyMaildir`` at ``maildir / user``.

    Exits with code 5 when *su* is unknown, resolves to root or the
    call to ``setuid`` fails.
    """
    # PATH may already be absent; that must not abort the start-up.
    environ.pop("PATH", None)

    try:
        target_uid = getpwnam(su).pw_uid
    except KeyError:
        # Unknown account: report it as a user-switching issue (exit 5)
        # rather than letting it surface as a generic error.
        logging.error("unknown user %r", su)
        sysexit(5)

    if not target_uid:
        logging.error("user must not be root")
        sysexit(5)

    # NOTE(review): only the user id is changed here; group ids are left
    # untouched -- confirm that this is intended.
    try:
        setuid(target_uid)
    except OSError:
        logging.exception("error setting uid")
        sysexit(5)

    return MyMaildir(maildir / user, create=False)


def _sort_by_sender(midmsg):
    """Sort key: the address a mail was sent from.

    *midmsg* is a ``(mid, msg)`` pair. If the From header holds exactly
    one address that address is used, otherwise the Sender header is
    consulted. When neither yields a result the first From address or,
    as a last resort, the empty string is returned so that sorting never
    fails on incomplete headers.
    """
    _, msg = midmsg

    from_addrs = getattr(msg["from"], "addresses", ())
    if len(from_addrs) == 1:
        return from_addrs[0].addr_spec

    sender_addrs = getattr(msg["sender"], "addresses", ())
    if sender_addrs:
        return sender_addrs[0].addr_spec

    if from_addrs:
        return from_addrs[0].addr_spec
    return ""


def _sort_mails(f, sort):
    """Translate the sort verb *sort* into a key function.

    Args:
        f: mailbox providing ``get_filename``; only used for ``"size"``.
        sort: one of ``"date"``, ``"sender"``, ``"subject"``, ``"size"``
            or ``""`` (same as ``"date"``), optionally prefixed with
            ``"!"`` to reverse the order. Unknown verbs are logged and
            fall back to ascending date order.

    Returns:
        ``(keyfn, reverse)`` where *keyfn* takes a ``(mid, msg)`` pair.
    """
    reverse = sort.startswith("!")
    if reverse:
        sort = sort[1:]

    def by_rec_date(midmsg):
        # The fractional part is optional so that ids such as
        # "<seconds>.M<usec>P<pid>.<host>" sort by their seconds instead
        # of failing; ids without a leading number sort first.
        found = re.match(r"\d+(?:\.\d+)?", midmsg[0], re.ASCII)
        return float(found[0]) if found else 0.0

    def by_subject(midmsg):
        # A missing Subject header is None, which cannot be compared to
        # strings; treat it as an empty subject.
        return midmsg[1]["subject"] or ""

    def by_size(midmsg):
        return path.getsize(f.get_filename(midmsg[0]))

    if sort in ("date", ""):
        keyfn = by_rec_date
    elif sort == "sender":
        keyfn = _sort_by_sender
    elif sort == "subject":
        keyfn = by_subject
    elif sort == "size":
        keyfn = by_size
    else:
        logging.warning("unknown sort-verb %r", sort)
        reverse = False
        keyfn = by_rec_date

    return keyfn, reverse


def _get_mime_head_info(msg):
    """Build the ``MIMEHeader`` message describing *msg*.

    Copies the main and sub content type, the content disposition and,
    if present, the file name of *msg*.

    A disposition other than ``inline`` or ``attachment`` is reported as
    attachment instead of aborting the request; RFC 2183 asks for
    unrecognised disposition types to be treated that way.
    """
    dispositions = jwebmail.MIMEHeader.ContentDisposition

    mh = jwebmail.MIMEHeader(
        maintype=msg.get_content_maintype(),
        subtype=msg.get_content_subtype(),
    )

    cd = msg.get_content_disposition()
    if cd is None:
        mh.contentdispo = dispositions.CONTENT_DISPOSITION_NONE
    elif cd == "inline":
        mh.contentdispo = dispositions.CONTENT_DISPOSITION_INLINE
    else:
        # "attachment" and every unrecognised value.
        mh.contentdispo = dispositions.CONTENT_DISPOSITION_ATTACHMENT

    if fn := msg.get_filename():
        mh.file_name = fn

    return mh


def _get_head_info(msg):
    """Build the ``MailHeader`` message for *msg*.

    Address headers are converted with ``_adr`` and the MIME information
    with ``_get_mime_head_info``. ``send_date`` is only set when the mail
    carries a Date header that could be parsed, ``sender`` only when a
    Sender header is present.
    """
    mh = jwebmail.MailHeader(
        written_from=_adr(msg["from"]),
        reply_to=_adr(msg["reply-to"]),
        send_to=_adr(msg["to"]),
        cc=_adr(msg["cc"]),
        bcc=_adr(msg["bcc"]),
        subject=msg["subject"],
        comments=msg["comments"],
        keywords=msg["keywords"],
        mime=_get_mime_head_info(msg),
    )

    # The Date header may be missing, or its ``datetime`` may be None if
    # the value could not be parsed; neither must abort the request.
    sent = getattr(msg["date"], "datetime", None)
    if sent is not None:
        mh.send_date = sent.isoformat()

    if s := _adr(msg["sender"]):
        # Assumes ``sender`` is a singular message field, which protobuf
        # does not allow to be assigned directly -- TODO confirm against
        # the .proto definition.
        mh.sender.CopyFrom(s[0])

    return mh


def list_mails(f, req):
    """Handle a ``ListReq``: return one sorted page of mail headers.

    Args:
        f: root mailbox of the user.
        req: serialized ``ListReq`` with ``start``, ``end``, ``sort`` and
            an optional ``folder``.

    Returns:
        A serialized ``ListResp`` with the mails at positions
        ``start`` (inclusive) to ``end`` (exclusive) in sort order.

    Raises:
        QMAuthError: if ``0 <= start <= end`` does not hold.
    """
    r = jwebmail.ListReq()
    r.ParseFromString(req)

    # Explicit check instead of ``assert``: request data is input and
    # must also be validated when Python runs with -O.
    if not 0 <= r.start <= r.end:
        raise QMAuthError("invalid range", start=r.start, end=r.end)

    if r.folder:
        f = f.get_folder(r.folder)

    f.set_msgtype("Maildir(EmailMessageHeader)")

    if r.start == r.end:
        # The caller writes the result to a binary stream, so an empty
        # page has to be a serialized response as well, not a list.
        return jwebmail.ListResp(mail_heads=[]).SerializeToString()

    kfn, reverse = _sort_mails(f, r.sort)
    msgs = sorted(f.items(), key=kfn, reverse=reverse)
    page = msgs[r.start : r.end]

    items = [
        jwebmail.ListMailHeader(
            mid=mid,
            byte_size=path.getsize(f.get_filename(mid)),
            unread="S" not in msg.get_flags(),
            rec_date=datetime.fromtimestamp(_get_rcv_time(mid)).isoformat(),
            header=_get_head_info(msg),
        )
        for mid, msg in page
    ]
    return jwebmail.ListResp(mail_heads=items).SerializeToString()


def count_mails(f, req):
    """Handle a ``StatsReq``: count mails, unread mails and bytes.

    Args:
        f: root mailbox of the user.
        req: serialized ``StatsReq`` with an optional ``folder``.

    Returns:
        A serialized ``StatsResp``. A mail counts as unread when the
        info part of its file name does not carry the ``S`` flag, which
        matches the ``unread`` value reported when listing mails.
    """
    r = jwebmail.StatsReq()
    r.ParseFromString(req)
    if r.folder:
        f = f.get_folder(r.folder)

    mids = list(f.keys())
    unread = 0
    total_bytes = 0
    for mid in mids:
        mail_file = f.get_filename(mid)
        total_bytes += path.getsize(mail_file)

        # The flags are read from the file name ("<mid>:2,<flags>") so
        # that no mail has to be opened or parsed for counting.
        name = mail_file.name
        info = name.split(f.colon)[-1] if f.colon in name else ""
        flags = info[2:] if info.startswith("2,") else ""
        if "S" not in flags:
            unread += 1

    resp = jwebmail.StatsResp(
        mail_count=len(mids),
        unread_count=unread,
        byte_size=total_bytes,
    )
    return resp.SerializeToString()


def _get_body(mail):
    """Convert the body of *mail* into a ``MailBody`` message.

    * A text leaf is returned as its decoded content.
    * Any other leaf is returned as text if its content is pure ASCII,
      otherwise base64 encoded as long as it is at most 128 KiB.
    * A ``message/*`` part is converted into a nested ``Mail``.
    * A ``multipart/*`` part is converted part by part; parts with an
      attachment disposition only get their MIME header, no body.

    Raises:
        QMAuthError: if a non-ASCII, non-text leaf exceeds the size limit.
        ValueError: if a multipart container has an unexpected main type.
    """
    max_inline_bytes = 128 * 1024

    if not mail.is_multipart():
        if mail.get_content_maintype() == "text":
            return jwebmail.MailBody(discrete=mail.get_content())

        # assumes get_content() yields bytes for non-text leaves -- TODO confirm
        ret = mail.get_content()
        if ret.isascii():
            return jwebmail.MailBody(discrete=ret.decode(encoding="ascii"))
        if len(ret) <= max_inline_bytes:
            return jwebmail.MailBody(
                discrete=b64encode(ret).decode(encoding="ascii")
            )
        # The message is derived from the limit so that both always agree.
        raise QMAuthError(
            f"non attachment part too large (>{max_inline_bytes // 1024}kB)",
            size=len(ret),
        )

    mctype = mail.get_content_maintype()

    if mctype == "message":
        msg = mail.get_content()
        return jwebmail.MailBody(
            mail=jwebmail.Mail(head=_get_head_info(msg), body=_get_body(msg))
        )

    if mctype == "multipart":
        ret = jwebmail.MailBody.Multipart(
            preamble=mail.preamble,
            epilogue=mail.epilogue,
        )
        attachment = (
            jwebmail.MIMEHeader.ContentDisposition.CONTENT_DISPOSITION_ATTACHMENT
        )
        for part in mail.iter_parts():
            head = _get_mime_head_info(part)
            # Attachments are announced by their header only.
            body = _get_body(part) if head.contentdispo != attachment else None
            ret.parts.append(
                jwebmail.MIMEPart(
                    mime_header=head,
                    body=body,
                )
            )
        return jwebmail.MailBody(multipart=ret)

    raise ValueError(f"unknown major content-type {mctype!r}")


def read_mail(f, req):
    """Handle a ``ShowReq``: return a mail and mark it as seen.

    Args:
        f: root mailbox of the user.
        req: serialized ``ShowReq`` with ``mid`` and an optional ``folder``.

    Returns:
        A serialized ``ShowResp`` holding head and body of the mail.

    Raises:
        QMAuthError: if no mail with the requested id exists.
    """
    r = jwebmail.ShowReq()
    r.ParseFromString(req)

    if r.folder:
        f = f.get_folder(r.folder)

    f.set_msgtype("EmailMessage")

    msg = f.get(r.mid, None)
    # ``is None`` rather than truthiness: a message object may be falsy
    # although it exists.
    if msg is None:
        raise QMAuthError("no such message", mid=r.mid)

    res = jwebmail.Mail(
        head=_get_head_info(msg),
        body=_get_body(msg),
    )

    # Setting a flag on a message object does not change the mailbox;
    # the flag lives in the file name ("<mid>:2,<flags>"), so the file is
    # renamed. Mails carrying flags belong into "cur".
    mail_file = f.get_filename(r.mid)
    name = mail_file.name
    info = name.split(f.colon)[-1] if f.colon in name else ""
    flags = info[2:] if info.startswith("2,") else ""
    if "S" not in flags:
        new_flags = "".join(sorted(set(flags) | {"S"}))
        target = mail_file.parent.parent / "cur" / f"{r.mid}{f.colon}2,{new_flags}"
        try:
            mail_file.rename(target)
        except OSError:
            # Best effort: the mail was read successfully, failing to
            # record that must not fail the request.
            logging.warning("could not mark mail %r as seen", r.mid, exc_info=True)

    return jwebmail.ShowResp(mail=res).SerializeToString()


def _descent(xx):
    """Return ``(head, body)`` for the MIME part *xx*.

    *head* is the result of ``_get_mime_head_info``. For a part whose
    main type is ``multipart`` *body* is the iterator over its sub-parts,
    for every other part it is the part's content.
    """
    head = _get_mime_head_info(xx)
    body = xx.iter_parts() if head.maintype == "multipart" else xx.get_content()
    return head, body


def raw_mail(f, req):
    """Handle a ``RawReq``: return one part of a mail as raw bytes.

    Args:
        f: root mailbox of the user.
        req: serialized ``RawReq`` with ``mid``, an optional ``folder``
            and an optional ``path`` of dot separated indices that
            selects the part. Descending into a message uses index 0,
            descending into a multipart uses the index of the sub-part.

    Returns:
        A serialized ``RawResp`` with the MIME header and the content of
        the selected part. An empty path selects the whole mail.

    Raises:
        QMAuthError: if the mail does not exist, the path is malformed
            or out of bounds, or it ends at a multipart section.
    """
    r = jwebmail.RawReq()
    r.ParseFromString(req)

    if r.folder:
        f = f.get_folder(r.folder)

    f.set_msgtype("EmailMessage")

    msg = f.get(r.mid, None)
    if msg is None:
        raise QMAuthError("no such message", mid=r.mid)

    try:
        pth = [int(seg) for seg in r.path.split(".")] if r.path else []
    except ValueError:
        raise QMAuthError("malformed path for mail", path=r.path) from None
    if any(n < 0 for n in pth):
        raise QMAuthError("out of bounds path for mail", path=pth)

    # The whole mail is treated like an embedded message.
    h = jwebmail.MIMEHeader(maintype="message", subtype="rfc822")
    b = msg

    for n in pth:
        mctype = h.maintype

        if mctype == "multipart":
            # Here ``b`` is the iterator over the sub-parts.
            try:
                res = next(islice(b, n, None))
            except StopIteration:
                raise QMAuthError("out of bounds path for mail", path=pth) from None
            (h, b) = _descent(res)
        elif mctype == "message":
            # A message has exactly one child; checked explicitly so the
            # validation is not stripped under -O.
            if n != 0:
                raise QMAuthError("out of bounds path for mail", path=pth)
            (h, b) = _descent(b)
        else:
            raise QMAuthError(
                f"can not descent into non multipart content type {mctype}"
            )

    if hasattr(b, "__next__"):
        raise QMAuthError("can not stop at multipart section", path=pth)
    elif isinstance(b, str):
        b = b.encode()
    elif isinstance(b, EmailMessage):
        b = b.as_bytes()

    return jwebmail.RawResp(header=h, body=b).SerializeToString()


def _matches(m, pattern):
    """Tell whether the mail *m* matches the regular expression *pattern*.

    The pattern is searched in the Subject header and in the content of
    every ``text/*`` part of the mail, including nested ones.

    Args:
        m: an ``email.message.EmailMessage``.
        pattern: regular expression as understood by ``re``.

    Returns:
        ``True`` if the pattern was found, ``False`` otherwise.
    """
    # NOTE(review): the pattern is compiled as given; if it comes from an
    # untrusted client a costly expression can stall the search.
    regex = re.compile(pattern)

    subject = m["subject"]
    if subject is not None and regex.search(str(subject)):
        return True

    for part in m.walk():
        if part.is_multipart() or part.get_content_maintype() != "text":
            continue
        try:
            text = part.get_content()
        except (LookupError, UnicodeError):
            # Unknown charset or undecodable payload: nothing to search.
            continue
        if isinstance(text, str) and regex.search(text):
            return True

    return False


def search_mails(f, req):
    """Handle a ``SearchReq``: return the headers of all matching mails.

    Every mail of the selected folder is parsed completely and tested
    with ``_matches`` against the pattern of the request. The result is a
    serialized ``SearchResp``.
    """
    r = jwebmail.SearchReq()
    r.ParseFromString(req)

    if r.folder:
        f = f.get_folder(r.folder)

    f.set_msgtype("EmailMessage")

    found = []
    for candidate in f.values():
        if not _matches(candidate, r.pattern):
            continue
        found.append(jwebmail.ListMailHeader(header=_get_head_info(candidate)))

    response = jwebmail.SearchResp(found=found)
    return response.SerializeToString()


def folders(f, req):
    """Handle a ``FoldersReq``: return the folder names of mailbox *f*.

    The request is parsed but carries nothing that is used here. The
    result is a serialized ``FoldersResp``.
    """
    request = jwebmail.FoldersReq()
    request.ParseFromString(req)

    names = f.list_folders()
    response = jwebmail.FoldersResp(folders=names)
    return response.SerializeToString()


def move_mail(f, req):
    """Handle a ``MoveReq``: move a mail into another folder.

    Args:
        f: root mailbox of the user.
        req: serialized ``MoveReq`` with ``mid``, the source folder
            ``from_f`` and the target folder ``to_f``; an empty folder
            name denotes the root mailbox.

    Returns:
        A serialized, empty ``MoveResp``.

    Raises:
        QMAuthError: if the target folder does not exist.
    """
    r = jwebmail.MoveReq()
    r.ParseFromString(req)

    # The target has to be checked against the folders of the root
    # mailbox, before ``f`` is replaced by the source folder. The
    # membership test also keeps the target inside the maildir.
    if r.to_f and r.to_f not in f.list_folders():
        raise QMAuthError("no such folder", folder=r.to_f)

    if r.from_f:
        f = f.get_folder(r.from_f)

    fname = Path(f.get_filename(r.mid))

    # Number of trailing path components below the root mailbox:
    # "<cur|new>/<file>" or ".<folder>/<cur|new>/<file>".
    sep = -3 if r.from_f else -2
    root_parts = fname.parts[:sep]
    tail_parts = fname.parts[-2:]

    if r.to_f:
        res = root_parts + ("." + r.to_f,) + tail_parts
    else:
        res = root_parts + tail_parts

    fname.rename(Path(*res))

    return jwebmail.MoveResp().SerializeToString()


def remove_mail(f, req):
    """Handle a ``RemoveReq``: flag a mail as trashed.

    Args:
        f: root mailbox of the user.
        req: serialized ``RemoveReq`` with ``mid`` and an optional
            ``folder``.

    Returns:
        A serialized, empty ``RemoveResp``.

    Raises:
        QMAuthError: if no mail with the requested id exists.
    """
    r = jwebmail.RemoveReq()
    r.ParseFromString(req)

    if r.folder:
        f = f.get_folder(r.folder)

    f.set_msgtype("MaildirMessage")

    try:
        mail_file = f.get_filename(r.mid)
    except KeyError:
        raise QMAuthError("no such message", mid=r.mid) from None

    # Setting a flag on a message object does not change the mailbox;
    # the flag lives in the file name ("<mid>:2,<flags>"), so the file is
    # renamed. Mails carrying flags belong into "cur".
    name = mail_file.name
    info = name.split(f.colon)[-1] if f.colon in name else ""
    flags = info[2:] if info.startswith("2,") else ""
    if "T" not in flags:
        new_flags = "".join(sorted(set(flags) | {"T"}))
        target = mail_file.parent.parent / "cur" / f"{r.mid}{f.colon}2,{new_flags}"
        mail_file.rename(target)

    return jwebmail.RemoveResp().SerializeToString()


def add_folder(f, req):
    """Handle an ``AddFolderReq``: create a new folder below mailbox *f*.

    Slashes in the requested name are replaced by dots and the folder
    directory gets a leading dot. If the directory already exists the
    response carries status 1; otherwise the directory and its ``cur``,
    ``new`` and ``tmp`` sub-directories are created and the status is 0.
    """
    r = jwebmail.AddFolderReq()
    r.ParseFromString(req)

    dotted = r.name.translate(str.maketrans("/", "."))
    name = path.join(f._path, "." + dotted)

    already_there = path.isdir(name)
    if already_there:
        return jwebmail.AddFolderResp(status=1).SerializeToString()

    mkdir(name)
    for sub in ("cur", "new", "tmp"):
        mkdir(path.join(name, sub))

    return jwebmail.AddFolderResp(status=0).SerializeToString()


def method_to_run(value):
    """Return the request handler registered for the method name *value*.

    Raises ``ValueError`` carrying *value* if no handler is known.
    """
    handlers = (
        ("list", list_mails),
        ("count", count_mails),
        ("read", read_mail),
        ("raw", raw_mail),
        ("folders", folders),
        ("move", move_mail),
        ("remove", remove_mail),
        ("search", search_mails),
        ("add_folder", add_folder),
    )
    for name, handler in handlers:
        if value == name:
            return handler
    raise ValueError(value)


def parse_arguments():
    """Parse the command line and return the arguments as a dict.

    The four positional arguments are ``maildir_path`` (converted to a
    ``Path``), ``os_user``, ``mail_user`` and ``method``; the latter is
    restricted to the names listed below. Abbreviations are disabled.
    """
    methods = [
        "list",
        "count",
        "read",
        "raw",
        "folders",
        "move",
        "remove",
        "add_folder",
    ]

    parser = ArgumentParser(allow_abbrev=False)
    parser.add_argument("maildir_path", type=Path)
    for positional in ("os_user", "mail_user"):
        parser.add_argument(positional)
    parser.add_argument("method", choices=methods)

    namespace = parser.parse_args()
    return vars(namespace)


def main():
    """Run one request: parse arguments, switch user, answer on STDOUT.

    After the mailbox has been opened the line ``OPEN`` is written to
    STDOUT, then the request is read from STDIN until end of file, passed
    to the handler selected by the ``method`` argument and the reply is
    written to STDOUT. Any exception is logged and ends the program with
    exit code 4.
    """
    try:
        log_format = "%(levelname)s:" + str(getpid()) + ":%(message)s"
        logging.basicConfig(level="INFO", format=log_format)

        args = parse_arguments()
        logging.debug("started with %s", args)

        mailbox = startup(
            args["maildir_path"],
            args["os_user"],
            args["mail_user"],
            args["method"],
        )
        logging.debug("setuid successful")

        # Announce readiness before blocking on the request.
        stdout.write("OPEN\n")
        stdout.flush()

        request = stdin.buffer.read()
        handler = method_to_run(args["method"])
        reply = handler(mailbox, request)
        logging.debug("pb method(%s) size(%d)", args["method"], len(reply))
        stdout.buffer.write(reply)
    # NOTE: QMAuthError has no handler of its own yet and is reported
    # like every other exception; a dedicated branch exiting with code 3
    # was sketched here before:
    # except QMAuthError as qerr:
    #    errmsg = dict(error=qerr.msg, **qerr.info)
    #    sysexit(3)
    except Exception:
        logging.exception("qmauth.py error")
        sysexit(4)


# Run the request handling only when executed as a script, not on import.
if __name__ == "__main__":
    main()