diff --git a/radon/lambdas.py b/radon/lambdas.py index 6759251..b2cf0bd 100644 --- a/radon/lambdas.py +++ b/radon/lambdas.py @@ -5,7 +5,7 @@ import time import boto3 -from . import virus_total as vt +from .virus_total import attachment_scan, uri_scan from . import phish_report from .investigation_message_searcher import InvestigationMessageSearcher from .pull_message import MessagePull @@ -16,6 +16,7 @@ logger.setLevel(logging.DEBUG) sqs = boto3.client('sqs') +s3 = boto3.client('s3') def sqs_event_handler(event, context): @@ -69,27 +70,22 @@ def pull_message(org_id, mailbox, msg_id): # TODO Batch URI dispatch for uri in mp.uris(): - print(uri) dispatch('lookup_url', uri['uri']) for a in mp.attachments(): - print(a) # TODO: let's upload this to s3 - dispatch('lookup_attachment', a['sha256'], a['filename'], a['content']) + s3_path = '' + dispatch('lookup_attachment', a['sha256'], a['filename'], s3_path) def lookup_url(url): - print(url) - return - r = vt.url_report(url) - if r.status == 'pending': - dispatch('lookup_url', url, wait=5) - return + scan = uri_scan(url) + if scan is None: + dispatch('lookup_url', url, wait=30) - # TODO persist scan - -def lookup_attachment(file_hash, file_name, content): - pass +def lookup_attachment(file_hash, file_name, s3_path): + s3_file = s3.get_object() + attachment_scan(file_hash, file_name, s3_file['Body']) # vim:sw=4 sts=4 diff --git a/radon/models.py b/radon/models.py index cd4a018..01e1574 100644 --- a/radon/models.py +++ b/radon/models.py @@ -1,6 +1,6 @@ import os from orator import Model, DatabaseManager -from orator.orm import has_many, belongs_to +from orator.orm import has_many, belongs_to, accessor config = { 'postgres': { @@ -57,8 +57,12 @@ def attachments(self): class UriScan(Model): - pass + @accessor + def resource(self): + return self.get_raw_attribute('uri') class AttachmentScan(Model): - pass + @accessor + def resource(self): + return self.get_raw_attribute('hash_256') diff --git a/radon/scanner.py b/radon/scanner.py new file mode 100644 index 0000000..786def5 --- /dev/null +++ b/radon/scanner.py @@ -0,0 +1,65 @@ +import time +from collections import OrderedDict +from orator import Model +from typing import Callable + +MAX_CACHE_SIZE = 4096 +MALIGN_TTL = 12 * 60 * 60 +BENIGN_TTL = 30 * 60 +PENDING_TTL = 5 * 60 + + +class ExpiringLRUCache: + def __init__(self, max_len: int) -> None: + self.dict: OrderedDict = OrderedDict() + self.max_len: int = max_len + + def get(self, key: str): + item, age_out_time = self.dict.get(key, (None, 0)) + if age_out_time >= time.time(): + del self.dict[key] + return item + + def set(self, key: str, value: Model, age_out_seconds: int): + if key in self.dict: + del self.dict[key] + + self.dict[key] = (value, time.time() + age_out_seconds) + + if len(self.dict) > self.max_len: + self.dict.popitem(last=False) + + def __repr__(self): + return '%s(%r)' % (self.__class__.__name__, self.dict.items()) + + +class Scanner: + def __init__(self, model: Model, scanner: Callable) -> None: + self.cache = ExpiringLRUCache(MAX_CACHE_SIZE) + self.model = model + self.scanner: Callable = scanner + + def cache_model(self, model): + self.cache.set(model.resource, model, MALIGN_TTL if model.malicious else BENIGN_TTL) + + def scan(self, *args) -> Model: + resource = args[0] + + mem = self.cache.get(resource) + if mem: + return mem + + model = self.model.first_or_new(resource=resource) + if model.exists: + self.cache_model(model) + return model + + scan = self.scanner(*args) + if scan: + model.results = scan + model.save() + self.cache_model(model) + else: + model.pending = True + self.cache.set(resource, model, PENDING_TTL) + return model diff --git a/radon/virus_total.py b/radon/virus_total.py index 467a73d..b7dac62 100644 --- a/radon/virus_total.py +++ b/radon/virus_total.py @@ -1,16 +1,20 @@ import logging import os import requests -from collections import namedtuple +import boto3 +from typing import Optional + +from .scanner import Scanner +from .models import UriScan, AttachmentScan logger = logging.getLogger() logger.setLevel(logging.DEBUG) -VT_API_KEY = os.environ.get('VT_API_KEY') + FILE_NOTIFY_URL = os.environ.get('FILE_NOTIFY_URL') -Result = namedtuple('Result', ['status', 'result']) +VT_API_KEY = os.environ.get('VT_API_KEY') or \ + boto3.client('secretsmanager').get_secret_value(SecretId='VT_API_KEY') -# TODO use local -> dynamo caching first def api(endpoint, **opts): logger.debug(f'VT call to {endpoint} {opts}') headers = { 'Accept-Encoding': 'gzip, deflate', 'User-Agent': 'gzip, Agari' } @@ -24,16 +28,28 @@ def api(endpoint, **opts): return response.json() -def url_report(url): +def vt_url_report(url): vt = api('url/report', resource=url, scan=1, allinfo=1) if 'scans' not in vt: - return Result(status='pending', result=vt) - return Result(status='found', result=vt) + return None + return vt -def file_report(file_hash, file_name, file_handle): +def vt_file_report(file_hash, file_name, file_handle): vt = api('file/report', resource=file_hash, allinfo=1) if 'scans' not in vt: vt = api('file/scan', notify_url=FILE_NOTIFY_URL, file_name=file_name, file_handle=file_handle) - return Result(status='pending', result=vt) - return Result(status='found', result=vt) + return None + return vt + + +URI_SCANNER = Scanner(UriScan, vt_url_report) +ATTACHMENT_SCANNER = Scanner(AttachmentScan, vt_file_report) + + +def uri_scan(uri: str) -> Optional[UriScan]: + return URI_SCANNER.scan(uri) + + +def attachment_scan(file_hash, filename, file_handle) -> Optional[AttachmentScan]: + return ATTACHMENT_SCANNER.scan(file_hash, filename, file_handle) diff --git a/requirements-test.txt b/requirements-test.txt index bf6255f..03579f6 100644 --- a/requirements-test.txt +++ b/requirements-test.txt @@ -1,2 +1,3 @@ pytest moto +requests_mock diff --git a/requirements.txt b/requirements.txt index 344c748..45ca3d9 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,19 +4,17 @@ # # pip-compile --output-file requirements.txt requirements.in # ---extra-index-url https://257ee1b1133eba72dbc6a1c6f544ba507eb56a25bcdbcadd:@packagecloud.io/agari/private/pypi/simple - backpack==0.1 # via orator blinker==1.4 # via orator -boto3==1.7.82 -botocore==1.10.82 # via boto3, s3transfer +boto3==1.7.83 +botocore==1.10.83 # via boto3, s3transfer certifi==2018.8.13 # via requests chardet==3.0.4 # via html5-parser, requests cleo==0.6.8 # via orator docutils==0.14 # via botocore elasticsearch==6.3.1 faker==0.7.18 # via orator -futures==3.2.0 # via s3transfer +futures==3.1.1 # via s3transfer html5-parser==0.4.5 idna==2.7 # via requests inflection==0.3.1 # via orator diff --git a/test/test_scanner.py b/test/test_scanner.py new file mode 100644 index 0000000..3439be5 --- /dev/null +++ b/test/test_scanner.py @@ -0,0 +1,57 @@ +import unittest +import itertools +from unittest.mock import MagicMock +from orator.orm import Model + +from radon.scanner import Scanner + +class TestScanner(unittest.TestCase): + + def run_scan(self, in_cache: bool, in_db: bool, in_api: bool): + def lookup(r): + if r == 'lookup': + return {'stuff': 'api'} + else: + return None + + model_instance = MagicMock(spec=Model) + model_instance.resource = 'frogs' + model_instance.exists = False + model_instance.pending = None + model_instance.malicious = None + model_instance.results = {'stuff': 'in_db'} + + DummyModel = MagicMock(spec=Model) + DummyModel.first_or_new.return_value = model_instance + + scanner = Scanner(DummyModel, lookup) + + if in_api: + model_instance.resource = 'lookup' + + if in_db: + model_instance.exists = True + + if in_cache: + model_instance.results = {'stuff': 'cached'} + scanner.cache.set(model_instance.resource, model_instance, 10) + + return scanner.scan(model_instance.resource) + + def test_scanner(self): + for args in itertools.product([True, False], repeat=3): + m = self.run_scan(*args) + in_cache, in_db, in_api = args + if in_cache: + assert not m.pending + assert m.results['stuff'] == 'cached' + else: + if in_db: + assert not m.pending + assert m.results['stuff'] == 'in_db' + else: + if in_api: + assert not m.pending + assert m.results['stuff'] == 'api' + else: + assert m.pending diff --git a/test/test_virus_total.py b/test/test_virus_total.py new file mode 100644 index 0000000..82181e7 --- /dev/null +++ b/test/test_virus_total.py @@ -0,0 +1,25 @@ +import unittest +import requests_mock +from io import StringIO + +from radon.virus_total import vt_url_report, vt_file_report + +class TestVirusTotal(unittest.TestCase): + + def test_uri(self): + with requests_mock.Mocker() as mock: + mock.post('/vtapi/v2/url/report', json={}) + assert vt_url_report('http://agari.com/') is None + mock.post('/vtapi/v2/url/report', json={'scans': [1,2,3], 'positives': 3}) + assert vt_url_report('http://agari.com/')['positives'] == 3 + + def test_file(self): + with requests_mock.Mocker() as mock: + mock.post('/vtapi/v2/file/report', json={}) + mock.post('/vtapi/v2/file/scan', json={}) + fh = StringIO('contents') + assert vt_file_report('abc','name', fh) is None + assert mock.call_count == 2 + mock.post('/vtapi/v2/file/report', json={'scans': [1,2], 'positives': 2}) + assert vt_file_report('abc','name', fh)['positives'] == 2 + assert mock.call_count == 3