Source code for tract_querier.query_processor

import ast
from os import path
from copy import deepcopy
from operator import lt, gt
from itertools import takewhile
import fnmatch

from .code_util import DocStringInheritor

# Names exported by ``from tract_querier.query_processor import *``.
__all__ = [
    'keywords',
    'EvaluateQueries',
    'eval_queries',
    'queries_syntax_check',
    'queries_preprocess',
    'TractQuerierSyntaxError',
    'TractQuerierLabelNotFound',
]

# Operator and function names of the White Matter Query Language:
# boolean operators, tract filters, and relative-position terms.
keywords = (
    ['and', 'or', 'not in', 'not', 'only'] +
    ['endpoints_in', 'both_endpoints_in'] +
    [
        'anterior_of', 'posterior_of',
        'medial_of', 'lateral_of',
        'inferior_of', 'superior_of',
    ]
)


class FiberQueryInfo(object):

    r"""
    Information about a processed query

    Attributes
    ----------
        tracts : set
            set of tract indices resulting from the query
        labels : set
            set of labels resulting by the query
        tracts_endpoints : (set, set)
            two sets of tract indices, one per tract endpoint, holding the
            tracts whose first / second endpoint was matched by the query

    The set methods ``update``, ``intersection_update``, ``union``,
    ``intersection``, ``difference`` and ``difference_update`` are available
    on instances. Each takes another ``FiberQueryInfo`` and applies the
    operation component-wise (see :py:meth:`set_operation`).
    """

    # Set methods forwarded component-wise to the three fields.
    _SET_OPERATIONS = frozenset((
        'update', 'intersection_update', 'union', 'intersection',
        'difference', 'difference_update'
    ))

    # For these operations ``labels`` uses a different set method than
    # ``tracts``. Intersecting two queries keeps the labels of both.
    _LABEL_OPERATIONS = {
        'intersection': 'union',
        'intersection_update': 'update',
    }

    def __init__(self, tracts=None, labels=None, tracts_endpoints=None):
        if tracts is None:
            tracts = set()
        if labels is None:
            labels = set()
        if tracts_endpoints is None:
            tracts_endpoints = (set(), set())
        self.tracts = tracts
        self.labels = labels
        self.tracts_endpoints = tracts_endpoints

    def __getattr__(self, name):
        # Only invoked when ordinary lookup fails. Regular attribute access
        # (self.tracts, ...) therefore no longer pays for a Python-level hook
        # the way the former ``__getattribute__`` override did.
        if name in FiberQueryInfo._SET_OPERATIONS:
            return self.set_operation(name)
        raise AttributeError(
            "%r object has no attribute %r" % (type(self).__name__, name)
        )

    def copy(self):
        """Return a copy whose sets are independent of this instance."""
        return FiberQueryInfo(
            self.tracts.copy(), self.labels.copy(),
            (self.tracts_endpoints[0].copy(), self.tracts_endpoints[1].copy()),
        )

    def set_operation(self, name):
        r"""
        Build a callable that applies the set method ``name`` component-wise.

        Parameters
        ----------
        name : str
            One of ``update``, ``intersection_update``, ``union``,
            ``intersection``, ``difference``, ``difference_update``.

        Returns
        -------
        operation : callable
            ``operation(other)`` applies ``name`` to ``tracts`` and to both
            ``tracts_endpoints`` sets. It applies ``name`` to ``labels`` too,
            except that the intersection variants use union/update for labels.
            The ``*_update`` variants mutate ``self`` and return it; the
            others return a new ``FiberQueryInfo``.
        """
        label_name = self._LABEL_OPERATIONS.get(name, name)

        def operation(tract_query_info):
            new_tracts = getattr(self.tracts, name)(tract_query_info.tracts)
            new_labels = getattr(self.labels, label_name)(
                tract_query_info.labels
            )
            new_tracts_endpoints = tuple(
                getattr(own, name)(other)
                for own, other in zip(
                    self.tracts_endpoints, tract_query_info.tracts_endpoints
                )
            )

            if name.endswith('update'):
                # In-place set methods return None; the result is self.
                return self
            return FiberQueryInfo(
                new_tracts, new_labels, new_tracts_endpoints,
            )

        return operation


class EvaluateQueries(ast.NodeVisitor):

    r"""
    This class implements the parser to process
    White Matter Query Language modules. By inheriting from
    :py:mod:`ast.NodeVisitor` it uses a syntax close to the
    python language.

    Statement visitors (``visit_Module``, ``visit_Assign``,
    ``visit_AugAssign``, ``visit_Expr``, ``visit_For``) update
    ``evaluated_queries_info`` / ``queries_to_save`` and return ``None``.
    Every expression visitor has the following signature

    Parameters
    ----------
    node : ast.Node

    Returns
    -------
    query_info : FiberQueryInfo
        ``tracts``: numbers of the tracts that result of this query
        ``labels``: numbers of the labels that are traversed by the tracts
        resulting from this query
        ``tracts_endpoints``: pair of sets of tract numbers whose first /
        second endpoint was matched by this query
    """

    # Python 2 metaclass hook; under Python 3 this is a plain class attribute.
    __metaclass__ = DocStringInheritor

    relative_terms = [
        'anterior_of',
        'posterior_of',
        'medial_of',
        'lateral_of',
        'inferior_of',
        'superior_of'
    ]

    def __init__(
        self,
        tractography_spatial_indexing,
    ):
        self.tractography_spatial_indexing = tractography_spatial_indexing
        # query name -> FiberQueryInfo for every query evaluated so far
        self.evaluated_queries_info = {}
        # names of the queries whose tracts are reported by eval_queries
        self.queries_to_save = set()
        self.evaluating_endpoints = False

    def visit_Module(self, node):
        """Evaluate every statement of a (possibly nested) module."""
        for line in node.body:
            self.visit(line)

    def visit_Compare(self, node):
        """Evaluate ``a not in b [not in c ...]`` as a set difference."""
        if any(not isinstance(op, ast.NotIn) for op in node.ops):
            raise TractQuerierSyntaxError(
                "Invalid syntax in query line %d" % node.lineno
            )

        # Copy first: visit() may hand back the stored info of a named query,
        # which difference_update would otherwise mutate in place.
        query_info = self.visit(node.left).copy()
        for value in node.comparators:
            query_info_ = self.visit(value)
            query_info.difference_update(query_info_)
        return query_info

    def visit_BoolOp(self, node):
        """Evaluate ``or`` as a union and ``and`` as an intersection."""
        # Copy for the same reason as in visit_Compare: the *_update
        # operations below mutate query_info.
        query_info = self.visit(node.values[0])
        query_info = query_info.copy()

        if isinstance(node.op, ast.Or):
            for value in node.values[1:]:
                query_info_ = self.visit(value)
                query_info.update(query_info_)

        elif isinstance(node.op, ast.And):
            for value in node.values[1:]:
                query_info_ = self.visit(value)
                query_info.intersection_update(query_info_)

        else:
            return self.generic_visit(node)

        return query_info

    def visit_BinOp(self, node):
        """Evaluate ``+`` (union), ``*`` (intersection), ``-`` (difference)."""
        info_left = self.visit(node.left)
        info_right = self.visit(node.right)
        if isinstance(node.op, ast.Add):
            return info_left.union(info_right)
        if isinstance(node.op, ast.Mult):
            return info_left.intersection(info_right)
        if isinstance(node.op, ast.Sub):
            return info_left.difference(info_right)
        return self.generic_visit(node)

    def visit_UnaryOp(self, node):
        r"""
        Evaluate the unary operators.

        * ``~q``: tracts of ``q`` whose crossed labels are all labels of
          ``q``. The endpoint sets are not carried over.
        * ``+q``: ``q`` itself.
        * ``-q`` / ``not q``: the tracts crossing any label that is not a
          label of ``q``, with those labels.
        """
        query_info = self.visit(node.operand)
        if isinstance(node.op, ast.Invert):
            return FiberQueryInfo(
                set(
                    tract for tract in query_info.tracts
                    if (
                        self.tractography_spatial_indexing.
                        crossing_tracts_labels[tract].
                        issubset(query_info.labels)
                    )
                ),
                query_info.labels
            )
        elif isinstance(node.op, ast.UAdd):
            return query_info
        elif isinstance(node.op, ast.USub) or isinstance(node.op, ast.Not):
            all_labels = set(
                self.tractography_spatial_indexing.
                crossing_labels_tracts.keys()
            )
            all_labels.difference_update(query_info.labels)
            all_tracts = set().union(*tuple(
                (
                    self.tractography_spatial_indexing.
                    crossing_labels_tracts[label]
                    for label in all_labels
                )
            ))

            new_info = FiberQueryInfo(all_tracts, all_labels)
            return new_info
        else:
            raise TractQuerierSyntaxError(
                "Syntax error in query line %d" % node.lineno)

    def visit_Str(self, node):
        """Union of the evaluated queries whose names match the glob ``node.s``."""
        query_info = FiberQueryInfo()
        for name in fnmatch.filter(self.evaluated_queries_info.keys(), node.s):
            query_info.update(self.evaluated_queries_info[name])
        return query_info

    def visit_Call(self, node):
        r"""
        Evaluate a single-argument function call.

        * ``only(q)``: tracts and endpoints of ``q`` that stay within the
          labels of ``q``
        * ``endpoints_in(q)``: tracts with at least one endpoint matched
          by ``q``
        * ``both_endpoints_in(q)``: tracts with both endpoints matched by
          ``q``
        * ``save('name')``: mark an evaluated query to be saved; returns
          ``None``
        * ``anterior_of(q)`` etc.: see :py:meth:`process_relative_term`
        """
        # ``starargs``/``kwargs`` only exist in Python < 3.5 ASTs. On newer
        # versions, ``*x`` shows up in ``args`` (and is rejected when it is
        # visited) and ``**x`` in ``keywords`` (rejected here).
        if (
            isinstance(node.func, ast.Name) and
            len(node.args) == 1 and
            getattr(node, 'starargs', None) is None and
            node.keywords == [] and
            getattr(node, 'kwargs', None) is None
        ):
            function_name = node.func.id.lower()
            if function_name == 'only':
                query_info = self.visit(node.args[0])
                only_tracts = set(
                    tract for tract in query_info.tracts
                    if (
                        self.tractography_spatial_indexing.
                        crossing_tracts_labels[tract].
                        issubset(query_info.labels)
                    )
                )
                only_endpoints = tuple((
                    set(
                        tract for tract in query_info.tracts_endpoints[i]
                        if (
                            self.tractography_spatial_indexing.
                            ending_tracts_labels[i][tract]
                            in query_info.labels
                        )
                    )
                    for i in (0, 1)
                ))
                return FiberQueryInfo(
                    only_tracts,
                    query_info.labels,
                    only_endpoints
                )
            elif function_name == 'endpoints_in':
                query_info = self.visit(node.args[0])
                new_tracts = (
                    query_info.tracts_endpoints[0].
                    union(query_info.tracts_endpoints[1])
                )
                return FiberQueryInfo(
                    new_tracts,
                    query_info.labels,
                    query_info.tracts_endpoints
                )
            elif function_name == 'both_endpoints_in':
                query_info = self.visit(node.args[0])
                new_tracts = (
                    query_info.tracts_endpoints[0].
                    intersection(query_info.tracts_endpoints[1])
                )
                return FiberQueryInfo(
                    new_tracts,
                    query_info.labels,
                    query_info.tracts_endpoints
                )
            elif (
                function_name == 'save' and
                # node.args is a list; the string literal is its element.
                isinstance(node.args[0], ast.Str)
            ):
                query_name = node.args[0].s
                if query_name not in self.evaluated_queries_info:
                    raise TractQuerierSyntaxError(
                        "Query %s not known line: %d" %
                        (query_name, node.lineno)
                    )
                self.queries_to_save.add(query_name)
                return
            elif function_name in self.relative_terms:
                return self.process_relative_term(node)

        raise TractQuerierSyntaxError("Invalid query in line %d" % node.lineno)

    def process_relative_term(self, node):
        r"""
        Processes the relative terms

        * anterior_of
        * posterior_of
        * superior_of
        * inferior_of
        * medial_of
        * lateral_of

        ``medial_of`` and ``lateral_of`` need a hemisphere, so their argument
        must be a ``query.left`` or ``query.right`` attribute.

        Parameters
        ----------
        node : :py:class:`ast.Node`
            Parsed call; ``node.args[0]`` is the reference query

        Returns
        -------
        query_info : FiberQueryInfo
            ``tracts``: numbers of the tracts that result of this query
            ``labels``: numbers of the labels that are traversed by the
            tracts resulting from this query
            ``tracts_endpoints``: per endpoint, the tracts whose endpoint
            coordinate passes the same comparison

        Raises
        ------
        TractQuerierSyntaxError
            If the argument is neither a query name nor a ``.left``/``.right``
            attribute, or if ``medial_of``/``lateral_of`` get no side.
        TractQuerierLabelNotFound
            If a label of the reference query has no bounding box.
        """
        # Without label bounding boxes (e.g. the dummy index used by
        # queries_syntax_check) no spatial relation can be computed.
        if len(self.tractography_spatial_indexing.label_bounding_boxes) == 0:
            return FiberQueryInfo()

        function_name = node.func.id.lower()
        name = function_name.replace('_of', '')

        arg = node.args[0]
        side = None
        if isinstance(arg, ast.Name):
            query_info = self.visit(arg)
        elif isinstance(arg, ast.Attribute):
            if arg.attr.lower() in ('left', 'right'):
                side = arg.attr.lower()
                query_info = self.visit(arg)
            else:
                raise TractQuerierSyntaxError(
                    "Attribute not recognized for relative specification. "
                    "Line %d" % node.lineno
                )
        else:
            raise TractQuerierSyntaxError(
                "Invalid argument for relative specification. "
                "Line %d" % node.lineno
            )

        if name in ('medial', 'lateral') and side is None:
            raise TractQuerierSyntaxError(
                "%s requires a query with a .left or .right side. "
                "Line %d" % (function_name, node.lineno)
            )

        labels = query_info.labels
        # No labels means no reference region to compare against.
        if not labels:
            return FiberQueryInfo()

        # Bounding box enclosing every label of the reference query.
        label_bounding_boxes = (
            self.tractography_spatial_indexing.label_bounding_boxes
        )
        labels_iter = iter(labels)
        try:
            bounding_box = label_bounding_boxes[next(labels_iter)]
            for label in labels_iter:
                bounding_box = bounding_box.union(label_bounding_boxes[label])
        except KeyError as e:
            raise TractQuerierLabelNotFound(
                "Label %s not found in atlas file" % e
            )

        # NOTE(review): the direction of each comparison depends on how
        # tract_bounding_boxes and the label bounding boxes define their
        # faces; it is unchanged from the original implementation.
        if (
            name in ('anterior', 'inferior') or
            (name == 'medial' and side == 'left') or
            (name == 'lateral' and side == 'right')
        ):
            operator = gt
        else:
            operator = lt

        # medial/lateral resolve to the left or right face of the region,
        # depending on the hemisphere of the reference query.
        if name == 'medial':
            name = 'right' if side == 'left' else 'left'
        elif name == 'lateral':
            name = 'left' if side == 'left' else 'right'

        tract_bounding_box_coordinate =\
            self.tractography_spatial_indexing.tract_bounding_boxes[name]
        tract_endpoints_pos =\
            self.tractography_spatial_indexing.tract_endpoints_pos

        bounding_box_coordinate = getattr(bounding_box, name)

        # Axis of tract_endpoints_pos[:, endpoint, axis] matching the face.
        column = {
            'left': 0, 'right': 0,
            'anterior': 1, 'posterior': 1,
            'superior': 2, 'inferior': 2,
        }[name]

        tracts = set(
            operator(
                tract_bounding_box_coordinate,
                bounding_box_coordinate
            ).nonzero()[0]
        )

        endpoints = tuple((
            set(
                operator(
                    tract_endpoints_pos[:, i, column],
                    bounding_box_coordinate
                ).nonzero()[0]
            )
            for i in (0, 1)
        ))

        labels = set().union(*tuple((
            self.tractography_spatial_indexing.crossing_tracts_labels[tract]
            for tract in tracts
        )))

        return FiberQueryInfo(tracts, labels, endpoints)

    def visit_Assign(self, node):
        """Evaluate ``name = query`` and mark ``name`` to be saved."""
        if len(node.targets) > 1:
            raise TractQuerierSyntaxError(
                "Invalid assignment in line %d" % node.lineno)

        queries_to_evaluate = self.process_assignment(node)

        for query_name, value_node in queries_to_evaluate.items():
            self.queries_to_save.add(query_name)
            self.evaluated_queries_info[query_name] = self.visit(value_node)

    def visit_AugAssign(self, node):
        """Evaluate ``name |= query``; unlike ``=``, the query is not saved."""
        if not isinstance(node.op, ast.BitOr):
            raise TractQuerierSyntaxError(
                "Invalid assignment in line %d" % node.lineno)

        queries_to_evaluate = self.process_assignment(node)

        for query_name, value_node in queries_to_evaluate.items():
            query_info = self.visit(value_node)
            self.evaluated_queries_info[query_name] = query_info

    def process_assignment(self, node):
        r"""
        Processes the assignment operations

        Parameters
        ----------
        node : :py:class:`ast.Node`
            Parsed tree

        Returns
        -------
        queries_to_evaluate: dict
            A dictionary or pairs '<name of the query>'= <node to evaluate>.
            ``x.side`` targets are expanded into a left and a right
            assignment that are visited immediately; for them the returned
            dictionary is empty.
        """
        queries_to_evaluate = {}
        if 'target' in node._fields:
            target = node.target
        if 'targets' in node._fields:
            target = node.targets[0]

        if isinstance(target, ast.Name):
            queries_to_evaluate[target.id] = node.value
        elif (
            isinstance(target, ast.Attribute) and
            target.attr == 'side'
        ):
            node_left, node_right = self.rewrite_side_query(node)
            self.visit(node_left)
            self.visit(node_right)
        elif (
            isinstance(target, ast.Attribute) and
            isinstance(target.value, ast.Name)
        ):
            queries_to_evaluate[
                target.value.id.lower() + '.' + target.attr.lower()
            ] = node.value
        else:
            raise TractQuerierSyntaxError(
                "Invalid assignment in line %d" % node.lineno)
        return queries_to_evaluate

    def rewrite_side_query(self, node):
        r"""
        Processes the side suffixes in a query

        Parameters
        ----------
        node : :py:class:`ast.Node`
            Parsed tree

        Returns
        -------
        node_left, node_right: nodes
            two AST nodes, one for the query instantiated on the left
            hemisphere (``.side`` -> ``.left``, ``.opposite`` -> ``.right``)
            and one for the right hemisphere (the reverse mapping)
        """
        node_left = deepcopy(node)
        node_right = deepcopy(node)

        for node_ in ast.walk(node_left):
            if isinstance(node_, ast.Attribute):
                if node_.attr == 'side':
                    node_.attr = 'left'
                elif node_.attr == 'opposite':
                    node_.attr = 'right'

        for node_ in ast.walk(node_right):
            if isinstance(node_, ast.Attribute):
                if node_.attr == 'side':
                    node_.attr = 'right'
                elif node_.attr == 'opposite':
                    node_.attr = 'left'

        return node_left, node_right

    def visit_Name(self, node):
        """Return the stored (not copied) info of an evaluated query."""
        if node.id in self.evaluated_queries_info:
            return self.evaluated_queries_info[node.id]
        else:
            raise TractQuerierSyntaxError(
                "Invalid query name in line %d: %s" % (node.lineno, node.id))

    def visit_Attribute(self, node):
        """Return the stored info of a ``name.attr`` query (e.g. ``q.left``)."""
        if not isinstance(node.value, ast.Name):
            raise TractQuerierSyntaxError(
                "Invalid query in line %d: %s" % (node.lineno, node.attr))

        query_name = node.value.id + '.' + node.attr
        if query_name in self.evaluated_queries_info:
            return self.evaluated_queries_info[query_name]
        else:
            raise TractQuerierSyntaxError(
                "Invalid query name in line %d: %s" %
                (node.lineno, query_name)
            )

    def visit_Num(self, node):
        """Evaluate an atlas label number: its crossing and ending tracts."""
        if (
            node.n in
            self.tractography_spatial_indexing.crossing_labels_tracts
        ):
            tracts = (
                self.tractography_spatial_indexing.
                crossing_labels_tracts[node.n]
            )
        else:
            tracts = set()

        endpoints = (set(), set())
        for i in (0, 1):
            elt = self.tractography_spatial_indexing.ending_labels_tracts[i]
            if node.n in elt:
                endpoints[i].update(elt[node.n])

        labelset = set((node.n,))
        tract_info = FiberQueryInfo(
            tracts, labelset,
            endpoints
        )

        return tract_info

    def visit_Expr(self, node):
        r"""
        Evaluate an expression statement.

        * ``name``: mark the already-evaluated query ``name`` to be saved
        * ``save('name')``: same, through the ``save`` function
        """
        if isinstance(node.value, ast.Name):
            if node.value.id in self.evaluated_queries_info.keys():
                self.queries_to_save.add(node.value.id)
            else:
                raise TractQuerierSyntaxError(
                    "Query %s not known line: %d" %
                    (node.value.id, node.lineno)
                )
        elif (
            isinstance(node.value, ast.Call) and
            isinstance(node.value.func, ast.Name) and
            node.value.func.id.lower() == 'save'
        ):
            self.visit(node.value)
        elif isinstance(node.value, ast.Module):
            self.visit(node.value)
        else:
            raise TractQuerierSyntaxError(
                "Invalid expression at line: %d" % (node.lineno))

    def generic_visit(self, node):
        """Reject every AST node without a dedicated visitor."""
        raise TractQuerierSyntaxError(
            "Invalid Operation %s line: %d" % (type(node), node.lineno))

    def visit_For(self, node):
        r"""
        Expand ``for x in <names>: body`` by textual substitution.

        ``<names>`` is a string glob over the evaluated query names or a
        list of query names. ``body`` is evaluated once per name, with every
        ``x`` replaced by that name.
        """
        id_to_replace = node.target.id.lower()

        iter_ = node.iter
        if isinstance(iter_, ast.Str):
            list_items = fnmatch.filter(
                self.evaluated_queries_info.keys(), iter_.s.lower())
        elif isinstance(iter_, ast.List):
            list_items = []
            for item in iter_.elts:
                if isinstance(item, ast.Name):
                    list_items.append(item.id.lower())
                else:
                    raise TractQuerierSyntaxError(
                        'Error in FOR statement in line %d,'
                        ' elements in the list must be query names' %
                        node.lineno
                    )
        else:
            raise TractQuerierSyntaxError(
                'Error in FOR statement in line %d,'
                ' the iterable must be a query name pattern'
                ' or a list of query names' % node.lineno
            )

        original_body = ast.Module(body=node.body)

        for item in list_items:
            aux_body = deepcopy(original_body)
            for node_ in ast.walk(aux_body):
                if (
                    isinstance(node_, ast.Name) and
                    node_.id.lower() == id_to_replace
                ):
                    node_.id = item

            self.visit(aux_body)


class TractQuerierSyntaxError(ValueError):

    """Raised when a query file cannot be parsed or evaluated."""

    def __init__(self, value):
        ValueError.__init__(self, value)
        # Message payload; __str__ renders it via repr().
        self.value = value

    def __str__(self):
        return repr(self.value)


class TractQuerierLabelNotFound(ValueError):

    """Raised when a query refers to a label missing from the atlas."""

    def __init__(self, value):
        ValueError.__init__(self, value)
        # Message payload; __str__ renders it via repr().
        self.value = value

    def __str__(self):
        return repr(self.value)


class RewriteChangeNotInPrescedence(ast.NodeTransformer):

    r"""
    Regroup boolean expressions so that ``not in`` applies to everything
    before it.

    Python parses ``a and b not in c`` as ``a and (b not in c)``. This
    transformer rewrites it to ``(a and b) not in c``. Any operands after
    the comparison are kept in an enclosing boolean expression, which is
    rewritten again.

    NOTE(review): when a boolean expression contains no ``not in``
    operand it is returned without visiting its children, so boolean
    expressions nested directly inside it are not rewritten.
    """

    def visit_BoolOp(self, node):
        def is_not_in(value):
            return (
                isinstance(value, ast.Compare) and
                isinstance(value.ops[0], ast.NotIn)
            )

        # Operands after the first one, up to the first ``not in``.
        leading_values = list(takewhile(
            lambda value: not is_not_in(value), node.values[1:]
        ))
        if len(leading_values) == len(node.values) - 1:
            return node

        compare_index = len(leading_values) + 1
        old_compare = node.values[compare_index]
        new_compare_left = ast.copy_location(
            ast.BoolOp(
                op=node.op,
                values=(
                    [node.values[0]] + leading_values + [old_compare.left]
                )
            ),
            node
        )
        new_compare = ast.copy_location(
            ast.Compare(
                left=new_compare_left,
                ops=old_compare.ops,
                comparators=old_compare.comparators
            ),
            node
        )

        rest_of_the_values = node.values[compare_index + 1:]
        if not rest_of_the_values:
            return self.visit(new_compare)
        return self.visit(ast.copy_location(
            ast.BoolOp(
                op=node.op,
                values=[new_compare] + rest_of_the_values
            ),
            node
        ))


class RewritePreprocess(ast.NodeTransformer):

    r"""
    Normalise a parsed query module before evaluation.

    Name, attribute and string nodes are lower-cased. ``import file``
    statements are replaced by the parsed (and recursively preprocessed)
    module of the first matching file in ``include_folders``.

    Parameters
    ----------
    include_folders : list of str, optional keyword
        Folders searched, in order, for imported files. Defaults to
        ``['.']`` when the keyword is not given.
    """

    def __init__(self, *args, **kwargs):
        self.include_folders = kwargs.pop('include_folders', ['.'])
        super(RewritePreprocess, self).__init__(*args, **kwargs)

    def visit_Attribute(self, node):
        return ast.copy_location(
            ast.Attribute(
                value=self.visit(node.value),
                attr=node.attr.lower()
            ),
            node
        )

    def visit_Name(self, node):
        return ast.copy_location(
            ast.Name(id=node.id.lower()),
            node
        )

    def visit_Str(self, node):
        return ast.copy_location(
            ast.Str(s=node.s.lower()),
            node
        )

    def visit_Import(self, node):
        # Resolve every imported file before parsing any of them.
        module_names = []
        for alias in node.names:
            file_name = alias.name
            for folder in self.include_folders:
                file_ = path.join(folder, file_name)
                if path.isfile(file_):
                    module_names.append(file_)
                    break
            else:
                raise TractQuerierSyntaxError(
                    'Imported file not found: %s' % file_name
                )

        imported_modules = []
        for module_name in module_names:
            with open(module_name) as module_file:
                source = module_file.read()
            try:
                imported_modules.append(
                    ast.parse(source, filename=module_name)
                )
            except SyntaxError as e:
                import traceback
                # The last traceback lines of a SyntaxError are the
                # offending source line and the caret marker.
                formatted_lines = traceback.format_exc().splitlines()
                raise TractQuerierSyntaxError(
                    'syntax error in line %s line %d: \n%s\n%s' %
                    (
                        module_name, e.lineno,
                        formatted_lines[-3], formatted_lines[-2]
                    )
                )

        new_node = ast.Module(body=imported_modules)

        return ast.copy_location(
            self.visit(new_node),
            node
        )


def queries_preprocess(query_file, filename='<unknown>', include_folders=None):
    r"""
    Parse a query file and apply the preprocessing rewrites.

    Parameters
    ----------
    query_file : str
        Source text of the query file
    filename : str
        Name used in syntax-error reports
    include_folders : list of str, optional
        Folders searched for imported files (default: none)

    Returns
    -------
    body : list of ast nodes
        Statements of the lower-cased, import-expanded, ``not in``-regrouped
        module

    Raises
    ------
    TractQuerierSyntaxError
        If ``query_file`` is not syntactically valid.
    """
    if include_folders is None:
        include_folders = []

    try:
        query_file_module = ast.parse(query_file, filename=filename)
    except SyntaxError as e:
        import traceback
        # The last traceback lines of a SyntaxError are the offending
        # source line and the caret marker.
        formatted_lines = traceback.format_exc().splitlines()
        raise TractQuerierSyntaxError(
            'syntax error in line %s line %d: \n%s\n%s' %
            (
                filename, e.lineno,
                formatted_lines[-3], formatted_lines[-2]
            )
        )

    rewrite_preprocess = RewritePreprocess(include_folders=include_folders)
    rewrite_precedence_not_in = RewriteChangeNotInPrescedence()

    preprocessed_module = rewrite_precedence_not_in.visit(
        rewrite_preprocess.visit(query_file_module)
    )

    return preprocessed_module.body


def eval_queries(
    query_file_body,
    tractography_spatial_indexing
):
    r"""
    Evaluate preprocessed queries against a tractography spatial index.

    Parameters
    ----------
    query_file_body : list of ast nodes or ast.Module
        Output of ``queries_preprocess``, or a module wrapping it
    tractography_spatial_indexing : object
        Spatial index the query visitors read from

    Returns
    -------
    dict
        Maps every saved query name to its set of tract numbers
    """
    evaluator = EvaluateQueries(tractography_spatial_indexing)

    module = query_file_body
    if isinstance(module, list):
        module = ast.Module(module)
    evaluator.visit(module)

    saved_tracts = {}
    for query_name in evaluator.queries_to_save:
        saved_tracts[query_name] = (
            evaluator.evaluated_queries_info[query_name].tracts
        )
    return saved_tracts


def queries_syntax_check(query_file_body):
    r"""
    Evaluate the queries against an empty spatial index.

    Any ``TractQuerierSyntaxError`` raised during evaluation propagates to
    the caller; nothing is returned.
    """
    class DummySpatialIndexing:
        # Empty stand-in exposing the attributes the evaluator reads.
        def __init__(self):
            self.crossing_tracts_labels, self.crossing_labels_tracts = {}, {}
            self.ending_tracts_labels, self.ending_labels_tracts = (
                ({}, {}), ({}, {})
            )
            self.label_bounding_boxes, self.tract_bounding_boxes = {}, {}

    eval_queries(query_file_body, DummySpatialIndexing())


def labels_for_tracts(crossing_tracts_labels):
    r"""
    Invert a tract -> labels mapping.

    Parameters
    ----------
    crossing_tracts_labels : dict
        Maps each tract to an iterable of the labels it crosses

    Returns
    -------
    dict
        Maps each label to the set of tracts crossing it
    """
    crossing_labels_tracts = {}
    for tract, tract_labels in crossing_tracts_labels.items():
        for label in tract_labels:
            crossing_labels_tracts.setdefault(label, set()).add(tract)
    return crossing_labels_tracts