diff --git a/antivirus_service/consumer.py b/antivirus_service/consumer.py index dca1ba5..52745f2 100644 --- a/antivirus_service/consumer.py +++ b/antivirus_service/consumer.py @@ -49,9 +49,3 @@ class ScanFileConsumer(ScanConsumer): def __init__(self, settings, handler): super().__init__(settings, handler) self.amqp_queue = self.amqp_config['scan_file']['queue'] - - -class ScanUrlConsumer(ScanConsumer): - def __init__(self, settings, handler): - super().__init__(settings, handler) - self.amqp_queue = self.amqp_config['scan_url']['queue'] \ No newline at end of file diff --git a/antivirus_service/handler.py b/antivirus_service/handler.py index 3544a8e..b702c45 100644 --- a/antivirus_service/handler.py +++ b/antivirus_service/handler.py @@ -11,6 +11,7 @@ class ScanHandler(object): def __init__(self, settings): self.config = settings.config[settings.env] self.clamd = Clamd(self.config['clamd']) + self.download_retry_count = self.config['download_retry_count'] def handle_error_message(self, payload, error_message): ''' @@ -48,7 +49,7 @@ def scan(self, download_uri, access_token): # that the scan request was triggered before the file upload # has been finished last_exception_message = '' - count = 3 + count = self.download_retry_count for i in range(1, count): try: # defining stream and timeout = just specifying the initial socket connect timeout, that's what we want @@ -102,71 +103,3 @@ def handle_message(self, payload): scan_result, signature = self.scan(download_uri, access_token) self.callback(callback_uri, access_token, scan_result, signature) - -class ScanUrlHandler(ScanHandler): - def parse_body(self, body): - payload = json.loads(bytes(body).decode('utf-8')) - assert 'url' in payload - assert 'callback_uri' in payload - return payload - - def callback(self, callback_uri, access_token, blacklisted, full_report): - logging.info('Start callback') - headers = {} - if access_token: - headers['Authorization'] = 'Bearer %s' % access_token - - result = { - 'blacklisted': blacklisted, - 'full_report': full_report - } - logging.info(result) - - response = requests.put(callback_uri, headers=headers, data=json.dumps(result)) - logging.info(response.status_code) - logging.info('------------- END PROCESS SCAN -------------') - - def scan_url(self, url): - logging.info('Start scan') - - api_key = self.config['virustotal']['api_key'] - - params = {'apikey': api_key, 'url': url} - response = requests.post('https://www.virustotal.com/vtapi/v2/url/scan', data=params) - scan_response = response.json() - logging.info(scan_response) - - # we have to wait for the result - headers = { - 'Accept-Encoding': 'gzip, deflate', - 'User-Agent' : 'gzip, Antivirus checker' - } - for i in range(10): - params = {'apikey': api_key, 'resource': scan_response['scan_id']} - response = requests.post('http://www.virustotal.com/vtapi/v2/url/report', data=params, headers=headers) - result_response = response.json() - logging.info(result_response) - - # scan job is processed - if 'positives' in result_response: - break - time.sleep(2**i) - else: - d = 2**(i+1) - 1 - raise Exception('Scan report could not downloaded after {0} seconds'.format(d)) - - return result_response['positives'] != 0, result_response - - def handle_message(self, payload): - ''' - handles antivirus scan url requests - ''' - logging.info('------------- INCOMING MESSAGE -------------') - logging.info(payload) - - url = payload['url'] - callback_uri = payload['callback_uri'] - access_token = payload.get('access_token', None) - - blacklisted, full_report = self.scan_url(url) - self.callback(callback_uri, access_token, blacklisted, full_report) diff --git a/antivirus_service/service.py b/antivirus_service/service.py index 80416ab..226fc46 100644 --- a/antivirus_service/service.py +++ b/antivirus_service/service.py @@ -6,8 +6,8 @@ from environs import Env from antivirus_service.webserver import Webserver -from antivirus_service.handler import ScanFileHandler, ScanUrlHandler -from antivirus_service.consumer import ScanFileConsumer, ScanUrlConsumer +from antivirus_service.handler import ScanFileHandler +from antivirus_service.consumer import ScanFileConsumer class AntivirusSettings(object): @@ -19,6 +19,7 @@ def __init__(self, env, debug): env = Env() with env.prefixed(self.env.upper() + "_"): self.config[self.env] = {} + self.config[self.env]['download_retry_count'] = env.int("DOWNLOAD_RETRY_COUNT", 3) param = "clamd" with env.prefixed(param.upper() + "_"): self.config[self.env][param] = {} @@ -65,35 +66,25 @@ def cli(ctx, env, debug): @click.pass_context def scan_file(ctx): """Run Antivirus Service - listen on message queue""" + consumer = None try: handler = ScanFileHandler(ctx.obj) consumer = ScanFileConsumer(ctx.obj, handler) consumer.run() except KeyboardInterrupt: - consumer.stop() - - -@cli.command() -@click.pass_context -def scan_url(ctx): - """Run Antivirus Service - listen on message queue""" - try: - handler = ScanUrlHandler(ctx.obj) - consumer = ScanUrlConsumer(ctx.obj, handler) - consumer.run() - except KeyboardInterrupt: - consumer.stop() + consumer and consumer.stop() @cli.command() @click.pass_context def webserver(ctx): """Run Antivirus Service - webserver""" + web_server = None try: - webserver = Webserver(ctx.obj) - webserver.run() + web_server = Webserver(ctx.obj) + web_server.run() except KeyboardInterrupt: - webserver.stop() + web_server and web_server.stop() def main():