Skip to content

Commit

Permalink
XOAUTH2 support for Outlook SMTP (#15064)
Browse files Browse the repository at this point in the history
(cherry picked from commit 56fd2c9)

Co-authored-by: themylogin <[email protected]>
  • Loading branch information
bugclerk and themylogin authored Dec 2, 2024
1 parent 254bcae commit bc57695
Show file tree
Hide file tree
Showing 4 changed files with 119 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
"""Mail OAuth provider
Revision ID: bda3a0ff206e
Revises: bb352e66987f
Create Date: 2024-12-02 13:45:00.262906+00:00
"""
import json

from alembic import op
import sqlalchemy as sa

from middlewared.plugins.pwenc import encrypt, decrypt


# revision identifiers, used by Alembic.
revision = 'bda3a0ff206e'
down_revision = 'bb352e66987f'
branch_labels = None
depends_on = None


def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
conn = op.get_bind()
for id, em_oauth in conn.execute("SELECT id, em_oauth FROM system_email").fetchall():
if em_oauth := decrypt(em_oauth):
em_oauth = json.loads(em_oauth)
if em_oauth:
em_oauth["provider"] = "gmail"
conn.execute(
"UPDATE system_email SET em_oauth = ? WHERE id = ?",
(encrypt(json.dumps(em_oauth)), id)
)

# ### end Alembic commands ###


def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
pass
# ### end Alembic commands ###
19 changes: 9 additions & 10 deletions src/middlewared/middlewared/plugins/mail.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ class DenyNetworkActivity(Exception):
pass


class QueueItem(object):

class QueueItem:
def __init__(self, message):
self.attempts = 0
self.message = message


class MailQueue(object):

class MailQueue:
MAX_ATTEMPTS = 3
MAX_QUEUE_LIMIT = 20

Expand Down Expand Up @@ -72,10 +70,7 @@ class MailModel(sa.Model):


class MailService(ConfigService):

mail_queue = MailQueue()
oauth_access_token = None
oauth_access_token_expires_at = None

class Config:
datastore = 'system.email'
Expand All @@ -95,6 +90,7 @@ class Config:
Password('pass', null=True, required=True),
Dict(
'oauth',
Str('provider'),
Str('client_id'),
Str('client_secret'),
Password('refresh_token'),
Expand All @@ -118,6 +114,7 @@ async def mail_extend(self, cfg):
(
'replace', Dict(
'oauth',
Str('provider'),
Str('client_id', required=True),
Str('client_secret', required=True),
Password('refresh_token', required=True),
Expand Down Expand Up @@ -369,7 +366,7 @@ def read_json():
msg[key] = val

try:
if config['oauth']:
if config['oauth'] and config['oauth']['provider'] == 'gmail':
self.middleware.call_sync('mail.gmail_send', msg, config)
else:
server = self._get_smtp_server(config, message['timeout'], local_hostname=local_hostname)
Expand Down Expand Up @@ -428,7 +425,9 @@ def _get_smtp_server(self, config, timeout=300, local_hostname=None):
local_hostname=local_hostname)
if config['security'] == 'TLS':
server.starttls()
if config['smtp']:
if config['oauth'] and config['oauth']['provider'] == 'outlook':
self.middleware.call_sync('mail.outlook_xoauth2', server, config)
elif config['smtp']:
server.login(config['user'], config['pass'])

return server
Expand All @@ -440,7 +439,7 @@ def send_mail_queue(self):
for queue in list(mq.queue):
try:
config = self.middleware.call_sync('mail.config')
if config['oauth']:
if config['oauth'] and config['oauth']['provider'] == 'gmail':
self.middleware.call_sync('mail.gmail_send', queue.message, config)
else:
server = self._get_smtp_server(config)
Expand Down
2 changes: 1 addition & 1 deletion src/middlewared/middlewared/plugins/mail_/gmail.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def gmail_initialize(self):

@private
def gmail_build_service(self, config):
if config["oauth"]:
if config["oauth"] and config["oauth"]["provider"] == "gmail":
return GmailService(config)

return None
Expand Down
67 changes: 67 additions & 0 deletions src/middlewared/middlewared/plugins/mail_/outlook.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
import base64
from dataclasses import dataclass
from smtplib import SMTP
import time

import requests

from middlewared.service import CallError, private, Service


@dataclass
class OutlookToken:
token: str
expires_at: float


class MailService(Service):
outlook_tokens: dict[str, OutlookToken] = {}

@private
def outlook_xoauth2(self, server: SMTP, config: dict):
server.ehlo()

if token := self._get_outlook_token(config["fromemail"], config["oauth"]["refresh_token"]):
code, response = self._do_xoauth2(server, config["fromemail"], token)
if 200 <= code <= 299:
return

self.logger.warning("Outlook XOAUTH2 failed: %r %r. Refreshing access token", code, response)

self.logger.debug("Requesting Outlook access token")
r = requests.post(
"https://login.microsoftonline.com/common/oauth2/v2.0/token",
data={
"grant_type": "refresh_token",
"client_id": config["oauth"]["client_id"],
"client_secret": config["oauth"]["client_secret"],
"refresh_token": config["oauth"]["refresh_token"],
"scope": "https://outlook.office.com/SMTP.Send openid offline_access",
}
)
r.raise_for_status()
response = r.json()

token = response["access_token"]
self._set_outlook_token(config["fromemail"], config["oauth"]["refresh_token"], token, response["expires_in"])

code, response = self._do_xoauth2(server, config["fromemail"], token)
if 200 <= code <= 299:
return

raise CallError("Outlook XOAUTH2 failed: %r %r" % (code, response))

def _get_outlook_token(self, email: str, refresh_token: str) -> str | None:
for key, token in list(self.outlook_tokens.items()):
if token.expires_at < time.monotonic() - 5:
self.outlook_tokens.pop(key)

if token := self.outlook_tokens.get(email + refresh_token):
return token.token

def _set_outlook_token(self, email: str, refresh_token: str, token: str, expires_in: int):
self.outlook_tokens[email + refresh_token] = OutlookToken(token, time.monotonic() + expires_in)

def _do_xoauth2(self, server: SMTP, email: str, access_token: str):
auth_string = f"user={email}\1auth=Bearer {access_token}\1\1"
return server.docmd("AUTH XOAUTH2", base64.b64encode(auth_string.encode()).decode())

0 comments on commit bc57695

Please sign in to comment.