Skip to content

Commit

Permalink
fix bug #45; implement RFE #43 #44 #46 #47
Browse files Browse the repository at this point in the history
  • Loading branch information
ryran committed Nov 21, 2016
1 parent 5e9cb5a commit e0c9072
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 39 deletions.
52 changes: 29 additions & 23 deletions rhsda.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,16 +127,12 @@ def _reduce_method(m):
numThreadsDefault = multiprocessing.cpu_count() * 2


def jprint(jsoninput, printOutput=True):
def jprint(jsoninput):
"""Pretty-print jsoninput."""
j = json.dumps(jsoninput, sort_keys=True, indent=2) + "\n"
if printOutput:
print(j)
else:
return j
return json.dumps(jsoninput, sort_keys=True, indent=2) + "\n"


def extract_cves_from_input(obj):
def extract_cves_from_input(obj, descriptiveNoun=None):
"""Use case-insensitive regex to extract CVE ids from input object.
*obj* can be a list, a file, or a string.
Expand All @@ -146,9 +142,9 @@ def extract_cves_from_input(obj):
# Array to store found CVEs
found = []
if obj == sys.stdin:
noun = "stdin"
else:
noun = "input"
descriptiveNoun = "stdin"
elif not descriptiveNoun:
descriptiveNoun = "input"
if isinstance(obj, str):
obj = obj.splitlines()
for line in obj:
Expand All @@ -159,10 +155,15 @@ def extract_cves_from_input(obj):
# Converting to a set removes duplicates
found = list(set(found))
uniqueCount = len(found)
logger.log(25, "Found {0} CVEs in {1}; {2} duplicates removed".format(uniqueCount, noun, matchCount-uniqueCount))
if matchCount-uniqueCount:
dupes = "; {0} duplicates removed".format(matchCount-uniqueCount)
else:
dupes = ""
logger.log(25, "Found {0} CVEs on {1}{2}".format(uniqueCount, descriptiveNoun, dupes))
return [x.upper() for x in found]
else:
logger.log(25, "No CVEs (matching regex: '{0}') found in {1}".format(cve_regex_string, noun))
logger.warning("No CVEs (matching regex: '{0}') found on {1}".format(cve_regex_string, descriptiveNoun))
return []


class ApiClient:
Expand Down Expand Up @@ -198,7 +199,7 @@ def __get(self, url, params={}):
if params[k]:
u += "&{0}={1}".format(k, params[k])
u = u.replace("&", "?", 1)
logger.info("Getting '{0}{1}' ...".format(url, u))
logger.info("Getting {0}{1}".format(url, u))
try:
r = requests.get(url, params=params)
except requests.exceptions.ConnectionError as e:
Expand All @@ -207,7 +208,10 @@ def __get(self, url, params={}):
except requests.exceptions.RequestException as e:
logger.error(e)
raise
logger.debug("Return status: '{0}'; Content-Type: '{1}'".format(r.status_code, r.headers['Content-Type']))
baseurl = r.url.split("/")[-1]
if not baseurl:
baseurl = r.url.split("/")[-2]
logger.debug("Return '.../{0}': Status {1}, Content-Type {2}".format(baseurl, r.status_code, r.headers['Content-Type'].split(";")[0]))
r.raise_for_status()
if 'application/xml' in r.headers['Content-Type']:
return r.content
Expand Down Expand Up @@ -581,6 +585,7 @@ def _set_cve_plaintext_width(self, wrapWidth):
self.wrapper = textwrap.TextWrapper(width=wrapWidth, initial_indent=" ", subsequent_indent=" ", replace_whitespace=False)
else:
self.wrapper = 0
logger.debug("Set wrapWidth to '{0}'".format(wrapWidth))

def mget_cves(self, cves, numThreads=0, onlyCount=False, outFormat='plaintext',
urls=False, fields='ALL', wrapWidth=70, product=None, timeout=300):
Expand Down Expand Up @@ -644,9 +649,7 @@ def mget_cves(self, cves, numThreads=0, onlyCount=False, outFormat='plaintext',
raise ValueError("Invalid outFormat ('{0}') requested; should be one of: 'plaintext', 'json', 'jsonpretty'".format(outFormat))
if isinstance(cves, str) or isinstance(cves, file):
cves = extract_cves_from_input(cves)
elif isinstance(cves, list):
cves = [x.upper() for x in cves]
else:
elif not isinstance(cves, list):
raise ValueError("Invalid 'cves=' argument input; must be list, string, or file obj")
# Configure threads
if not numThreads:
Expand All @@ -673,9 +676,9 @@ def mget_cves(self, cves, numThreads=0, onlyCount=False, outFormat='plaintext',
# Need to specify timeout; see: http://stackoverflow.com/a/35134329
results = p.get(timeout=timeout)
except KeyboardInterrupt:
logger.error("\nReceived KeyboardInterrupt; terminating worker threads")
logger.error("Received KeyboardInterrupt; terminating worker threads")
pool.terminate()
return
raise
else:
pool.close()
pool.join()
Expand Down Expand Up @@ -703,7 +706,7 @@ def mget_cves(self, cves, numThreads=0, onlyCount=False, outFormat='plaintext',
elif outFormat == 'json':
return cveOutput
elif outFormat == 'jsonpretty':
return jprint(cveOutput, False)
return jprint(cveOutput)

def cve_search_query(self, params, outFormat='list'):
"""Perform a CVE search query.
Expand All @@ -721,7 +724,7 @@ def cve_search_query(self, params, outFormat='list'):
if outFormat == 'json':
return result
if outFormat == 'jsonpretty':
return jprint(result, False)
return jprint(result)
cves = []
for i in result:
cves.append(i['CVE'])
Expand All @@ -739,7 +742,7 @@ def _err_print_support_urls(self, msg=None):

def _iavm_query(self, url):
"""Get IAVA json from IAVM Mapper App."""
logger.info("Getting '{0}' ...".format(url))
logger.info("Getting {0}".format(url))
try:
r = requests.get(url, auth=())
except requests.exceptions.ConnectionError as e:
Expand All @@ -752,7 +755,10 @@ def _iavm_query(self, url):
self._err_print_support_urls(e)
raise
r.raise_for_status()
logger.debug("Return status: '{0}'; Content-Type: '{1}'".format(r.status_code, r.headers['Content-Type']))
baseurl = r.url.split("/")[-1]
if not baseurl:
baseurl = r.url.split("/")[-2]
logger.debug("Return '.../{0}': Status {1}, Content-Type {2}".format(baseurl, r.status_code, r.headers['Content-Type'].split(";")[0]))
if 'application/json' in r.headers['Content-Type']:
result = r.json()
elif '<title>Login - Red Hat Customer Portal</title>' in r.content:
Expand Down
31 changes: 15 additions & 16 deletions rhsecapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@
# Globals
prog = 'rhsecapi'
vers = {}
vers['version'] = '1.0.0_rc3'
vers['date'] = '2016/18/10'
vers['version'] = '1.0.0_rc4'
vers['date'] = '2016/11/20'


# Logging
Expand Down Expand Up @@ -154,10 +154,10 @@ def parse_args():
g_listByAttr = p.add_argument_group(
'FIND CVES BY ATTRIBUTE')
g_listByAttr.add_argument(
'--q-before', metavar="YEAR-MM-DD",
'--q-before', metavar="YYYY-MM-DD",
help="Narrow down results to before a certain time period")
g_listByAttr.add_argument(
'--q-after', metavar="YEAR-MM-DD",
'--q-after', metavar="YYYY-MM-DD",
help="Narrow down results to after a certain time period")
g_listByAttr.add_argument(
'--q-bug', metavar="BZID",
Expand Down Expand Up @@ -202,8 +202,8 @@ def parse_args():
g_getCve = p.add_argument_group(
'QUERY SPECIFIC CVES')
g_getCve.add_argument(
'cves', metavar="CVE", nargs='*',
help="Retrieve a CVE or space-separated list of CVEs (e.g.: 'CVE-2016-5387')")
'cves', metavar="CVE-YYYY-NNNN", nargs='*',
help="Retrieve a CVE or list of CVEs (e.g.: 'CVE-2016-5387'); note that case-insensitive regex-matching is done -- extra characters & duplicate CVEs will be discarded")
g_getCve.add_argument(
'-s', '--extract-search', action='store_true',
help="Extract CVEs them from search query (as initiated by at least one of the --q-xxx options)")
Expand Down Expand Up @@ -314,13 +314,13 @@ def parse_args():
if o.q_iava and o.doSearch:
logger.error("The --q-iava option is incompatible with other --q-xxx options; it can only be used alone")
sys.exit(1)
if o.cves:
o.cves = rhsda.extract_cves_from_input(o.cves, "cmdline")
if not o.cves:
o.showUsage = True
if o.extract_stdin and not sys.stdin.isatty():
found = rhsda.extract_cves_from_input(sys.stdin)
o.cves.extend(found)
# If only one CVE (common use-case), let's validate its format
if len(o.cves) == 1 and not rhsda.cve_regex.match(o.cves[0]):
logger.error("Invalid CVE format '{0}'; expected: 'CVE-YYYY-XXXX'".format(o.cves[0]))
o.showUsage = True
# If no search (--q-xxx) and no CVEs mentioned
if not o.showUsage and not (o.doSearch or o.cves or o.q_iava):
logger.error("Must specify a search to perform (one of the --q-xxx opts) or CVEs to retrieve")
Expand Down Expand Up @@ -352,7 +352,7 @@ def main(opts):
opts.cves.append(cve['CVE'])
elif not opts.count:
if opts.json:
searchOutput.append(rhsda.jprint(result, False))
searchOutput.append(rhsda.jprint(result))
else:
for cve in result:
searchOutput.append(cve['CVE'] + "\n")
Expand All @@ -367,15 +367,15 @@ def main(opts):
opts.cves.extend(result['IAVM']['CVEs']['CVENumber'])
elif not opts.count:
if opts.json:
iavaOutput.append(rhsda.jprint(result, False))
iavaOutput.append(rhsda.jprint(result))
else:
for cve in result['IAVM']['CVEs']['CVENumber']:
iavaOutput.append(cve + "\n")
if not opts.pastebin:
print(file=sys.stderr)
print("".join(iavaOutput), end="")
if opts.dryrun and opts.cves:
logger.notice("Skipping CVE retrieval due to --dryrun; would have retrieved: {0}".format(len(opts.cves)))
logger.log(25, "Skipping CVE retrieval due to --dryrun; would have retrieved: {0}".format(len(opts.cves)))
cveOutput = " ".join(opts.cves) + "\n"
elif opts.cves:
if searchOutput or iavaOutput:
Expand All @@ -388,8 +388,6 @@ def main(opts):
fields=opts.fields,
wrapWidth=opts.wrapWidth,
product=opts.product)
if not opts.count:
print(file=sys.stderr)
if opts.count:
return
if opts.pastebin:
Expand All @@ -408,6 +406,7 @@ def main(opts):
else:
print(response)
elif opts.cves:
print(file=sys.stderr)
print(cveOutput, end="")


Expand All @@ -416,5 +415,5 @@ def main(opts):
opts = parse_args()
main(opts)
except KeyboardInterrupt:
print("\n{0}: Received KeyboardInterrupt; exiting".format(prog))
logger.error("Received KeyboardInterrupt; exiting")
sys.exit()

0 comments on commit e0c9072

Please sign in to comment.