summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJannis M. Hoffmann <jannis@fehcom.de>2024-12-09 17:58:17 +0100
committerJannis M. Hoffmann <jannis@fehcom.de>2024-12-09 17:58:17 +0100
commit55688b969a645fbc6d94c76f51da3be976c1d098 (patch)
treeb0b72173c041fa1253fba2e3a0b55ce65d0c8006
parent44b719671fe73b7378789968ecc8d48d7f9c00ca (diff)
make QMailAuthuser a context manager
-rw-r--r--src/jwebmail/__init__.py10
-rw-r--r--src/jwebmail/model/read_mails.py25
-rw-r--r--src/jwebmail/read_mails.py22
-rw-r--r--src/jwebmail/webmail.py143
4 files changed, 100 insertions, 100 deletions
diff --git a/src/jwebmail/__init__.py b/src/jwebmail/__init__.py
index a1bd00a..b4c01a1 100644
--- a/src/jwebmail/__init__.py
+++ b/src/jwebmail/__init__.py
@@ -4,7 +4,7 @@ from os import environ
from shutil import which
from babel import parse_locale
-from flask import Flask, abort, g, redirect, request_finished, url_for
+from flask import Flask, abort, g, redirect, url_for
from flask_babel import Babel, get_locale
from flask_login import LoginManager, login_required
from flask_wtf.csrf import CSRFProtect
@@ -36,7 +36,7 @@ else:
toml_read_file = dict(load=toml_load, text=True)
-__version__ = "2.8.2.dev0"
+__version__ = "2.8.2.dev1"
csrf = CSRFProtect()
@@ -122,12 +122,6 @@ def create_app():
except ValueError:
abort(404)
- @request_finished.connect_via(app)
- def close_qma(_app, **_):
- if "read_mails" in g:
- g.read_mails.close()
- g.read_mails = None
-
return app
diff --git a/src/jwebmail/model/read_mails.py b/src/jwebmail/model/read_mails.py
index fc89c8e..534f2f7 100644
--- a/src/jwebmail/model/read_mails.py
+++ b/src/jwebmail/model/read_mails.py
@@ -12,11 +12,7 @@ class QMAuthError(Exception):
class QMailAuthuser:
- def __init__(
- self, username, password, prog, mailbox_path, virtual_user, authenticator
- ):
- self._username = username
- self._password = password
+ def __init__(self, prog, mailbox_path, virtual_user, authenticator):
self._prog = prog
self._mailbox_path = mailbox_path
self._virtual_user = virtual_user
@@ -179,7 +175,7 @@ class QMailAuthuser:
else:
assert False
- def open(self):
+ def open(self, username, password):
(rp, wp) = os.pipe()
(sp, sc) = socketpair()
cmdline = [self._authenticator, self._prog]
@@ -196,7 +192,7 @@ class QMailAuthuser:
assert False
sc.close()
os.close(rp)
- os.write(wp, f"{self._username}\0{self._password}\0\0".encode())
+ os.write(wp, f"{username}\0{password}\0\0".encode())
os.close(wp)
self._pid = pid
@@ -214,11 +210,12 @@ class QMailAuthuser:
else:
raise
- user = self._username[: self._username.index("@")]
+ user = username[: username.index("@")]
self._connection.Init(
unix_user=self._virtual_user,
mailbox_path=os.path.join(self._mailbox_path, user),
)
+
return self
def close(self):
@@ -229,3 +226,15 @@ class QMailAuthuser:
rc = os.waitstatus_to_exitcode(status)
if rc != 0:
raise QMAuthError(rc)
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, ex_type, ex_val, ex_tb):
+ if ex_val is None:
+ self.close()
+ elif issubclass(ex_type, BrokenPipeError):
+ (pid, _status) = os.waitpid(self._pid, 0)
+ assert pid == self._pid
+
+ return False
diff --git a/src/jwebmail/read_mails.py b/src/jwebmail/read_mails.py
index 5aed8d2..8e1a23d 100644
--- a/src/jwebmail/read_mails.py
+++ b/src/jwebmail/read_mails.py
@@ -2,7 +2,7 @@ import pwd
from contextlib import closing
from os.path import join as path_join
-from flask import current_app, g
+from flask import current_app
from flask_login import UserMixin, current_user, login_user
from .model.read_mails import QMailAuthuser, QMAuthError
@@ -150,14 +150,12 @@ def _select_timeout_session():
raise ValueError(f"unknown session_type {session_type!r}")
-def _build_qma(username, password):
+def _build_qma(domain):
authenticator = current_app.config["JWEBMAIL"]["READ_MAILS"]["AUTHENTICATOR"]
backend = current_app.config["JWEBMAIL"]["READ_MAILS"]["BACKEND"]
virt_users = current_app.config["JWEBMAIL"]["READ_MAILS"].get("VIRTUAL_USERS")
if virt_users:
- _, domain = username.split("@")
-
with open(virt_users, encoding="ASCII") as file:
for virt_dom in file:
dom, unix_user = virt_dom.rstrip().split(":")
@@ -171,14 +169,13 @@ def _build_qma(username, password):
mailbox_user = current_app.config["JWEBMAIL"]["READ_MAILS"]["MAILBOX_USER"]
mailbox = current_app.config["JWEBMAIL"]["READ_MAILS"]["MAILBOX"]
- return QMailAuthuser(
- username, password, backend, mailbox, mailbox_user, authenticator
- )
+ return QMailAuthuser(backend, mailbox, mailbox_user, authenticator)
def login(username, password):
try:
- _build_qma(username, password).open()
+ _, domain = username.split("@")
+ _build_qma(domain).open(username, password)
except QMAuthError as err:
if err.rc == 1:
return False
@@ -203,9 +200,6 @@ def load_user(username: str) -> JWebmailUser:
def get_read_mails_logged_in():
- if "read_mails" in g:
- return g.read_mails
-
- qma = _build_qma(current_user.get_id(), current_user.password).open()
- g.read_mails = qma
- return qma
+ username = current_user.get_id()
+ _, domain = username.split("@")
+ return _build_qma(domain).open(username, current_user.password)
diff --git a/src/jwebmail/webmail.py b/src/jwebmail/webmail.py
index d183b30..d260e8a 100644
--- a/src/jwebmail/webmail.py
+++ b/src/jwebmail/webmail.py
@@ -86,31 +86,32 @@ def about():
def displayheaders(folder=""):
- folders = get_read_mails_logged_in().folders()
-
- if folder and folder not in folders:
- return render_template("error", error="no_folder", links=folders), 404
-
- page_bound = request.args.get("page_bound")
- page_after = bool(request.args.get("page_after", type=int, default=True))
- per_page = request.args.get("per_page", type=int, default=25)
- sort = request.args.get("sort", "!date")
- search = request.args.get("search")
-
- s = sort[1:] if sort[0] == "!" else sort
- if s not in ["date", "size", "sender"]:
- abort(400)
-
- count = get_read_mails_logged_in().count(folder)
-
- headers, first, last = get_read_mails_logged_in().list_search(
- folder=folder,
- bound=page_bound,
- after=page_after,
- limit=per_page,
- sort=sort,
- search=search,
- )
+ with get_read_mails_logged_in() as read_mails:
+ folders = read_mails.folders()
+
+ if folder and folder not in folders:
+ return render_template("error", error="no_folder", links=folders), 404
+
+ page_bound = request.args.get("page_bound")
+ page_after = bool(request.args.get("page_after", type=int, default=True))
+ per_page = request.args.get("per_page", type=int, default=25)
+ sort = request.args.get("sort", "!date")
+ search = request.args.get("search")
+
+ s = sort[1:] if sort[0] == "!" else sort
+ if s not in ["date", "size", "sender"]:
+ abort(400)
+
+ count = read_mails.count(folder)
+
+ headers, first, last = read_mails.list_search(
+ folder=folder,
+ bound=page_bound,
+ after=page_after,
+ limit=per_page,
+ sort=sort,
+ search=search,
+ )
if headers:
match s:
@@ -160,41 +161,41 @@ def displayheaders(folder=""):
def readmail(msgid, folder=""):
format = request.args.get("format", "html").lower()
- if format == "html":
- try:
- mail = get_read_mails_logged_in().show(folder, msgid)
- except QMAuthError:
- return render_template("not_found.html"), 404
+ with get_read_mails_logged_in() as read_mails:
+ if format == "html":
+ try:
+ mail = read_mails.show(folder, msgid)
+ except QMAuthError:
+ return render_template("not_found.html"), 404
- return render_template("readmail.html", msg=mail, folder=folder)
+ return render_template("readmail.html", msg=mail, folder=folder)
- elif format == "raw":
- path = request.args.get("path", "")
+ elif format == "raw":
+ path = request.args.get("path", "")
- content = get_read_mails_logged_in().raw(folder, msgid, path)
+ content = read_mails.raw(folder, msgid, path)
- headers = []
+ headers = []
- cd = content["head"].get("content_disposition")
- if cd and cd.lower() == "attachment":
- headers.append(
- (
- "Content-Disposition",
- f"attachment; filename={content['head']['filename']}",
+ cd = content["head"].get("content_disposition")
+ if cd and cd.lower() == "attachment":
+ headers.append(
+ (
+ "Content-Disposition",
+ f"attachment; filename={content['head']['filename']}",
+ )
)
- )
- ct = to_mime_type(content["head"])
- if ct.startswith("text/"):
- ct += "; charset=UTF-8"
- headers.append(("Content-Type", ct))
+ ct = to_mime_type(content["head"])
+ if ct.startswith("text/"):
+ ct += "; charset=UTF-8"
+ headers.append(("Content-Type", ct))
- return content["body"], headers
+ return content["body"], headers
- elif format == "json":
- mail = get_read_mails_logged_in().show(folder, msgid)
- return mail
- else:
- abort(404)
+ elif format == "json":
+ return read_mails.show(folder, msgid)
+ else:
+ abort(404)
def writemail():
@@ -208,16 +209,17 @@ def _take_common_req_args(mapping):
def move(folder=""):
- folders = get_read_mails_logged_in().folders()
+ with get_read_mails_logged_in() as read_mails:
+ folders = read_mails.folders()
- mm = request.form.getlist("mail")
- to_folder = request.form["select-folder"]
+ mm = request.form.getlist("mail")
+ to_folder = request.form["select-folder"]
- if folder not in folders or to_folder not in folders:
- raise ValueError("folder not valid")
+ if folder not in folders or to_folder not in folders:
+ raise ValueError("folder not valid")
- for m in mm:
- get_read_mails_logged_in().move(m, folder, to_folder)
+ for m in mm:
+ read_mails.move(m, folder, to_folder)
flash(gettext("succ_move"))
args = _take_common_req_args(request.form)
@@ -225,21 +227,22 @@ def move(folder=""):
def remove(folder=""):
- folders = get_read_mails_logged_in().add_folder("Trash")
+ with get_read_mails_logged_in() as read_mails:
+ folders = read_mails.add_folder("Trash")
- mm = request.form.getlist("mail")
+ mm = request.form.getlist("mail")
- folders = get_read_mails_logged_in().folders()
+ folders = read_mails.folders()
- if folder not in folders:
- raise ValueError("folder not valid")
+ if folder not in folders:
+ raise ValueError("folder not valid")
- if folder == "Trash":
- for m in mm:
- get_read_mails_logged_in().remove(folder, m)
- else:
- for m in mm:
- get_read_mails_logged_in().move(m, folder, "Trash")
+ if folder == "Trash":
+ for m in mm:
+ read_mails.remove(folder, m)
+ else:
+ for m in mm:
+ read_mails.move(m, folder, "Trash")
flash(gettext("succ_remove"))
args = _take_common_req_args(request.form)