# Source code for abilian.core.extensions

"""Create all standard extensions."""

# Note: Because of issues with circular dependencies, Abilian-specific
# extensions are created later.
import sqlite3
from typing import Any

import flask_mail
import sqlalchemy as sa
import sqlalchemy.event
import sqlalchemy.orm
from flask import current_app
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.sql.schema import MetaData

from abilian.core.logging import patch_logger
from abilian.core.sqlalchemy import SQLAlchemy

from . import upstream_info
from .csrf import abilian_csrf
from .csrf import wtf_csrf as csrf
from .jinjaext import DeferredJS
from .login import login_manager
from .redis import Redis

# Names exported by `from abilian.core.extensions import *`.
# `deferred_js` and `abilian_csrf` are available in this module but are not
# listed here.
__all__ = (
    "get_extension",
    "redis",
    "db",
    "mail",
    "login_manager",
    "csrf",
    "upstream_info",
)


# Patch flask_mail.Message.send to always set envelope_from to the default
# mail sender.
# FIXME: we'd rather subclass Message and update all imports
def _message_send(self: Any, connection: flask_mail.Connection) -> None:
    """Send a single message instance with the configured sender enforced.

    The ``MAIL_SENDER`` value of the current application's config is stored
    in the message's ``Sender`` extra header and handed to
    ``connection.send`` together with the message.

    :param self: a Message instance.
    :param connection: the mail connection the message is sent through.
    """
    default_sender = current_app.config["MAIL_SENDER"]
    # Messages may carry no extra headers at all: start from an empty mapping.
    self.extra_headers = self.extra_headers or {}
    self.extra_headers["Sender"] = default_sender
    connection.send(self, default_sender)


# Report the original `flask_mail.Message.send` to `patch_logger` before it is
# replaced (NOTE(review): presumably to keep a trace of the monkey patch --
# confirm against abilian.core.logging).
patch_logger.info(flask_mail.Message.send)
# Monkey patch: `Message.send` is now `_message_send` for every Message.
flask_mail.Message.send = _message_send

# Module-level extension instances, each instantiated without arguments.
mail = flask_mail.Mail()

db = SQLAlchemy()

redis = Redis()

deferred_js = DeferredJS()


@sa.event.listens_for(db.metadata, "before_create")
@sa.event.listens_for(db.metadata, "before_drop")
def _filter_metadata_for_connection(
    target: MetaData, connection: Connection, **kw: Any
) -> None:
    """Listener to control what indexes get created.

    Useful for skipping postgres-specific indexes on a sqlite for example.

    It's looking for info entry `engines` on an index
    (`Index(info=dict(engines=['postgresql']))`), an iterable of engine names.
    An index without an `engines` entry is kept whatever the engine is.

    Indexes that do not match the connection's engine are removed from
    `table.indexes` in place.

    :param target: the object the DDL event fired for.
    :param connection: connection the DDL runs on; its engine name decides
        which indexes are kept.
    :param kw: event keyword arguments; only `tables` is read.
    """
    engine = connection.engine.name
    default_engines = (engine,)
    if isinstance(target, sa.Table):
        # A Table is not a collection of tables: wrap it so that the loop
        # below inspects the table itself instead of trying to iterate it.
        tables = [target]
    else:
        # Tolerate both a missing `tables` entry and an explicit None.
        tables = kw.get("tables") or ()
    for table in tables:
        # Iterate over a snapshot: `table.indexes` is mutated inside the loop.
        for idx in list(table.indexes):
            if engine not in idx.info.get("engines", default_engines):
                table.indexes.remove(idx)


#
# CSRF
#
def get_extension(name: str) -> Any:
    """Get the named extension from the current app, returning None if not
    found.

    :param name: key to look up in ``current_app.extensions``.
    :return: the object stored under `name`, or ``None`` when there is none.
    """
    return current_app.extensions.get(name)
def _install_get_display_value(cls: Any) -> None:
    """Attach a `display_value` method to `cls` unless it already has one.

    :param cls: the class being instrumented.
    """
    # Private sentinel: lets callers pass any value, including None.
    _MARK = object()

    def display_value(self, field_name, value=_MARK):
        """Return display value for fields having a 'choices' mapping
        (stored value -> human readable value).

        For other fields it will simply return the field value.

        `display_value` should be used instead of directly getting field value.

        If `value` is provided, it is "translated" to a human-readable value.
        This is useful for obtaining a human readable label from a raw value.
        """
        if value is _MARK:
            # No explicit value: read the attribute, falling back to the
            # literal string "ERROR" when the instance has no such attribute.
            val = getattr(self, field_name, "ERROR")
        else:
            val = value

        mapper = sa.orm.object_mapper(self)
        try:
            field = getattr(mapper.c, field_name)
        except AttributeError:
            # `field_name` is not a mapped column: nothing to translate.
            return val

        if "choices" not in field.info:
            return val

        choices = field.info["choices"]
        # Values missing from the mapping are returned unchanged.
        if isinstance(val, list):
            return [choices.get(v, v) for v in val]
        return choices.get(val, val)

    if not hasattr(cls, "display_value"):
        cls.display_value = display_value


sa.event.listen(db.Model, "class_instrument", _install_get_display_value)


#
# Make Sqlite a bit more well-behaved.
#
@sa.event.listens_for(Engine, "connect")
def _set_sqlite_pragma(dbapi_connection: Any, connection_record: Any) -> None:
    """Run ``PRAGMA foreign_keys=ON`` on every new sqlite3 connection.

    Connections that are not `sqlite3.Connection` instances are left alone.

    :param dbapi_connection: the raw DBAPI connection that was just opened.
    :param connection_record: unused; part of the "connect" event signature.
    """
    if isinstance(dbapi_connection, sqlite3.Connection):  # pragma: no cover
        cursor = dbapi_connection.cursor()
        try:
            cursor.execute("PRAGMA foreign_keys=ON;")
        finally:
            # Release the cursor even if the PRAGMA statement fails.
            cursor.close()