import logging
import re
from abc import ABCMeta
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any
from typing import List
from typing import Optional
from typing import Type
from typing import TypeVar
from typing import Union
import cql
from werkzeug.wrappers import Request
from ..constants import PARAM_EXTENSION_PREFIX
from ..constants import SRUDiagnostics
from ..constants import SRUOperation
from ..constants import SRUParam
from ..constants import SRUParamValue
from ..constants import SRUQueryType
from ..constants import SRURecordPacking
from ..constants import SRURecordXmlEscaping
from ..constants import SRURenderBy
from ..constants import SRUVersion
from ..diagnostic import SRUDiagnostic
from ..diagnostic import SRUDiagnosticList
from ..exception import SRUException
from ..queryparser import CQLQueryParser
from ..queryparser import SRUQuery
from ..queryparser import SRUQueryParserRegistry
from .auth import SRUAuthenticationInfo
from .auth import SRUAuthenticationInfoProvider
from .config import SRUServerConfig
# Generic type variable.
T = TypeVar("T")

# Module-level logger. It must be named after the module (``__name__``), not
# the literal string "__name__", so that it takes part in the package's logger
# hierarchy and can be configured together with it.
LOGGER = logging.getLogger(__name__)

# A valid ``queryType`` value starts with an ASCII letter or digit and is
# followed by any number of ASCII letters, digits, underscores or hyphens.
# Intended to be used with ``fullmatch``.
QUERY_TYPE_ALLOWED_CHARS = re.compile(r"[a-zA-Z0-9][a-zA-Z0-9_-]*")
# ---------------------------------------------------------------------------
class SRURequest(metaclass=ABCMeta):
    """Provides information about a SRU request."""

    @abstractmethod
    def get_operation(self) -> SRUOperation:
        """Get the ``operation`` parameter of this request. Available
        for **explain**, **searchRetrieve** and **scan** requests.
        """

    @abstractmethod
    def get_version(self) -> SRUVersion:
        """Get the **version** parameter of this request. Available
        for **explain**, **searchRetrieve** and **scan** requests.
        """

    def is_version(self, version: SRUVersion) -> bool:
        """Check if this request is of a specific version.

        Args:
            version: the version to check

        Returns:
            bool: ``True`` if this request is in the requested
            version, ``False`` otherwise

        Raises:
            TypeError: if ``version`` is ``None``
        """
        if version is None:
            raise TypeError("version is None")
        return self.get_version() == version

    def is_version_between(self, min: SRUVersion, max: SRUVersion) -> bool:
        """Check if version of this request is at least `min` and
        at most `max` (both bounds inclusive).

        Args:
            min: the minimum version
            max: the maximum version

        Returns:
            bool: ``True`` if the version of this request lies within
            the given range, ``False`` otherwise

        Raises:
            TypeError: if ``min`` or ``max`` is ``None``
            ValueError: if ``min`` is a later version than ``max``
        """
        if min is None:
            raise TypeError("min is None")
        if max is None:
            raise TypeError("max is None")
        if min.version_number > max.version_number:
            raise ValueError("min > max")
        version = self.get_version()
        return (
            version.version_number >= min.version_number
            and version.version_number <= max.version_number
        )

    @abstractmethod
    def get_record_xml_escaping(self) -> SRURecordXmlEscaping:
        """Get the **recordXMLEscaping** (SRU 2.0) or **recordPacking**
        (SRU 1.1 and SRU 1.2) parameter of this request. Only
        available for **explain** and **searchRetrieve** requests.

        Returns:
            SRURecordXmlEscaping: the record XML escaping method
        """

    @abstractmethod
    def get_record_packing(self) -> SRURecordPacking:
        """Get the **recordPacking** (SRU 2.0) parameter of this
        request. Only available for **searchRetrieve** requests.

        Returns:
            SRURecordPacking: the record packing method
        """

    @abstractmethod
    def get_query(self) -> Optional[SRUQuery[Any]]:
        """Get the **query** parameter of this request. Only available
        for **searchRetrieve** requests.

        Returns:
            SRUQuery[Any]: an `SRUQuery` instance tailored for the
            used queryType or `None` if not a **searchRetrieve**
            request
        """

    # TODO: required; pythonic?
    # def get_query(self, type: Type[T]) -> Optional[SRUQuery[T]]:

    def get_query_type(self) -> Optional[str]:
        """Get the **queryType** parameter of this request. Only
        available for **searchRetrieve** requests.

        Returns:
            str: the queryType of the parsed query or `None` if not a
            **searchRetrieve** request
        """
        query = self.get_query()
        if query is None:
            return None
        return query.query_type

    def is_query_type(self, query_type: str) -> bool:
        """Check if the request was made with the given queryType.
        Only available for **searchRetrieve** requests.

        Args:
            query_type: the queryType to compare with

        Returns:
            bool: ``True`` if the queryType matches, ``False``
            otherwise (also if ``query_type`` is ``None``)
        """
        if query_type is None:
            return False
        return self.get_query_type() == query_type

    @abstractmethod
    def get_start_record(self) -> int:
        """Get the **startRecord** parameter of this request. Only
        available for **searchRetrieve** requests. If the client did
        not provide a value for the request, it is set to ``1``.

        Returns:
            int: the number of the start record
        """

    @abstractmethod
    def get_maximum_records(self) -> int:
        """Get the **maximumRecords** parameter of this request. Only
        available for **searchRetrieve** requests. If no value was
        supplied with the request, the server will automatically set
        a default value.

        Returns:
            int: the maximum number of records
        """

    @abstractmethod
    def get_record_schema_identifier(self) -> Optional[str]:
        """Get the record schema identifier derived from the
        **recordSchema** parameter of this request. Only available
        for **searchRetrieve** requests. If the request was sent with
        the short record schema name, it will automatically be
        expanded to the record schema identifier.

        Returns:
            str: the record schema identifier or `None` if no
            **recordSchema** parameter was supplied for this
            request
        """

    @abstractmethod
    def get_record_xpath(self) -> Optional[str]:
        """Get the **recordXPath** parameter of this request. Only
        available for **searchRetrieve** requests and version 1.1
        requests.

        Returns:
            str: the record XPath or `None` if no value was supplied
            for this request
        """

    @abstractmethod
    def get_resultSet_TTL(self) -> int:
        """Get the **resultSetTTL** parameter of this request. Only
        available for **searchRetrieve** requests.

        Returns:
            int: the result set TTL or ``-1`` if no value was
            supplied for this request
        """

    @abstractmethod
    def get_sortKeys(self) -> Optional[str]:
        """Get the **sortKeys** parameter of this request. Only
        available for **searchRetrieve** requests and version 1.1 requests.

        Returns:
            str: the sort keys or `None` if no value was supplied
            for this request
        """

    # TODO CQLQuery/CQLNode?
    @abstractmethod
    def get_scan_clause(self) -> Optional[cql.CQLQuery]:
        """Get the **scanClause** parameter of this request. Only
        available for **scan** requests.

        Returns:
            cql.CQLQuery: the parsed scan clause or `None` if not a
            **scan** request
        """

    @abstractmethod
    def get_response_position(self) -> int:
        """Get the **responsePosition** parameter of this request.
        Only available for **scan** requests. If the client did not
        provide a value for the request, it is set to ``1``.

        Returns:
            int: the response position
        """

    @abstractmethod
    def get_maximum_terms(self) -> int:
        """Get the **maximumTerms** parameter of this request.
        Only available for **scan** requests.

        Returns:
            int: the maximum number of terms or ``-1`` if no value
            was supplied for this request
        """

    @abstractmethod
    def get_stylesheet(self) -> Optional[str]:
        """Get the **stylesheet** parameter of this request.
        Available for **explain**, **searchRetrieve** and **scan**
        requests.

        Returns:
            str: the stylesheet or `None` if no value was supplied
            for this request
        """

    @abstractmethod
    def get_renderBy(self) -> Optional[SRURenderBy]:
        """Get the **renderBy** parameter of this request.

        Returns:
            SRURenderBy: the renderBy parameter or `None` if no value
            was supplied for this request
        """

    @abstractmethod
    def get_response_type(self) -> Optional[str]:
        """(SRU 2.0) The request parameter **responseType**, paired
        with the Internet media type specified for the response (via
        either the httpAccept parameter or http accept header)
        determines the schema for the response.

        Returns:
            str: the value of the responseType request parameter or
            `None` if no value was supplied for this request
        """

    @abstractmethod
    def get_http_accept(self) -> Optional[str]:
        """(SRU 2.0) The request parameter **httpAccept** may be
        supplied to indicate the preferred format of the response.
        The value is an Internet media type.

        Returns:
            str: the value of the httpAccept request parameter or
            `None` if no value was supplied for this request
        """

    @abstractmethod
    def get_protocol_schema(self) -> str:
        """Get the protocol scheme which was used for this request.
        Available for **explain**, **searchRetrieve** and **scan**
        requests.

        Returns:
            str: the protocol scheme
        """
# ---------------------------------------------------------------------------
@dataclass(frozen=True)
class ParameterInfo:
    """Immutable description of one request parameter: which parameter
    it is, whether it is mandatory, and the inclusive SRU version range
    (``min`` .. ``max``) in which it is valid.
    """

    class Parameter(str, Enum):
        """Library-internal identifiers of the known request parameters."""

        STYLESHEET = "stylesheet"
        RENDER_BY = "render_by"
        HTTP_ACCEPT = "http_accept"
        RESPONSE_TYPE = "response_type"
        START_RECORD = "start_record"
        MAXIMUM_RECORDS = "maximum_records"
        RECORD_XML_ESCAPING = "record_xml_escaping"
        RECORD_PACKING = "record_packing"
        RECORD_SCHEMA = "record_schema"
        RECORD_XPATH = "record_xpath"
        RESULT_SET_TTL = "result_set_ttl"
        SORT_KEYS = "sort_keys"
        SCAN_CLAUSE = "scan_clause"
        RESPONSE_POSITION = "response_position"
        MAXIMUM_TERMS = "maximum_terms"

    # ----------------------------------------------------

    parameter: Parameter
    mandatory: bool
    min: SRUVersion
    max: SRUVersion

    def name(self, version: SRUVersion) -> Optional[str]:
        """Return the request parameter name for the given SRU version.

        Args:
            version: the SRU version of the request

        Returns:
            the ``SRUParam`` name to look up in the request, or ``None``
            if the parameter does not exist in ``version``

        Raises:
            ValueError: if ``self.parameter`` is not a known parameter
        """
        kinds = ParameterInfo.Parameter
        kind = self.parameter

        # parameters whose name is the same in every SRU version
        version_independent = (
            (kinds.STYLESHEET, SRUParam.STYLESHEET),
            (kinds.RENDER_BY, SRUParam.RENDER_BY),
            (kinds.HTTP_ACCEPT, SRUParam.HTTP_ACCEPT),
            (kinds.RESPONSE_TYPE, SRUParam.RESPONSE_TYPE),
            (kinds.START_RECORD, SRUParam.START_RECORD),
            (kinds.MAXIMUM_RECORDS, SRUParam.MAXIMUM_RECORDS),
            (kinds.RECORD_SCHEMA, SRUParam.RECORD_SCHEMA),
            (kinds.RECORD_XPATH, SRUParam.RECORD_XPATH),
            (kinds.RESULT_SET_TTL, SRUParam.RESULT_SET_TTL),
            (kinds.SORT_KEYS, SRUParam.SORT_KEYS),
            (kinds.SCAN_CLAUSE, SRUParam.SCAN_CLAUSE),
            (kinds.RESPONSE_POSITION, SRUParam.RESPONSE_POSITION),
            (kinds.MAXIMUM_TERMS, SRUParam.MAXIMUM_TERMS),
        )
        for candidate, param_name in version_independent:
            if kind == candidate:
                return param_name

        if kind == kinds.RECORD_XML_ESCAPING:
            # 'recordPacking' was renamed to 'recordXMLEscaping' in SRU 2.0,
            # so SRU 1.1 and SRU 1.2 requests carry the XML escaping mode
            # under the old 'recordPacking' name.
            return (
                SRUParam.RECORD_XML_ESCAPING
                if version == SRUVersion.VERSION_2_0
                else SRUParam.RECORD_PACKING
            )

        if kind == kinds.RECORD_PACKING:
            # the SRU 2.0 meaning of 'recordPacking' only exists in SRU 2.0;
            # the old variant is handled by the RECORD_XML_ESCAPING case
            return SRUParam.RECORD_PACKING if version == SRUVersion.VERSION_2_0 else None

        raise ValueError(f"unknown ParameterInfo.Parameter? {self.parameter}")

    def is_for_version(self, version: SRUVersion) -> bool:
        """Tell whether ``version`` lies within ``min`` .. ``max`` (inclusive)."""
        requested = version.version_number
        return self.min.version_number <= requested <= self.max.version_number
class ParameterInfoSets(Enum):
    """Per-operation lists of the request parameters that are checked
    (besides ``operation``, ``version`` and the query parameters).

    Each member's value is a list of `ParameterInfo` entries giving the
    parameter, whether it is mandatory, and the inclusive version range.
    """

    EXPLAIN = [
        ParameterInfo(
            parameter=ParameterInfo.Parameter.STYLESHEET,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_1_2,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.RECORD_XML_ESCAPING,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_1_2,
        ),
    ]

    SCAN = [
        ParameterInfo(
            parameter=ParameterInfo.Parameter.STYLESHEET,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.HTTP_ACCEPT,
            mandatory=False,
            min=SRUVersion.VERSION_2_0,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.SCAN_CLAUSE,
            mandatory=True,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.RESPONSE_POSITION,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.MAXIMUM_TERMS,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_2_0,
        ),
    ]

    SEARCH_RETRIEVE = [
        ParameterInfo(
            parameter=ParameterInfo.Parameter.STYLESHEET,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_1_2,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.HTTP_ACCEPT,
            mandatory=False,
            min=SRUVersion.VERSION_2_0,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.RENDER_BY,
            mandatory=False,
            min=SRUVersion.VERSION_2_0,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.RESPONSE_TYPE,
            mandatory=False,
            min=SRUVersion.VERSION_2_0,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.START_RECORD,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.MAXIMUM_RECORDS,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.RECORD_XML_ESCAPING,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.RECORD_PACKING,
            mandatory=False,
            min=SRUVersion.VERSION_2_0,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.RECORD_SCHEMA,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.RESULT_SET_TTL,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_2_0,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.RECORD_XPATH,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_1_2,
        ),
        ParameterInfo(
            parameter=ParameterInfo.Parameter.SORT_KEYS,
            mandatory=False,
            min=SRUVersion.VERSION_1_1,
            max=SRUVersion.VERSION_2_0,
        ),
    ]

    @classmethod
    def for_operation(
        cls, operation: Optional[SRUOperation]
    ) -> Optional[List[ParameterInfo]]:
        """Return the parameter list for ``operation``.

        Args:
            operation: the SRU operation, may be ``None``

        Returns:
            the list stored on the matching member, or ``None`` if
            ``operation`` is falsy or not one of the known operations
        """
        if not operation:
            return None
        known_operations = (
            (SRUOperation.EXPLAIN, cls.EXPLAIN),
            (SRUOperation.SCAN, cls.SCAN),
            (SRUOperation.SEARCH_RETRIEVE, cls.SEARCH_RETRIEVE),
        )
        for candidate, parameter_set in known_operations:
            if operation == candidate:
                return parameter_set.value
        # actually cannot happen
        return None
# ---------------------------------------------------------------------------
# Initial value of ``SRURequestImpl.start_record``; stays in effect unless a
# ``startRecord`` parameter is parsed from the request.
DEFAULT_START_RECORD = 1
# Initial value of ``SRURequestImpl.response_position``; stays in effect
# unless a ``responsePosition`` parameter is parsed from the request.
DEFAULT_RESPONSE_POSITION = 1
class SRURequestImpl(SRUDiagnosticList, SRURequest):
    """`SRURequest` implementation backed by a werkzeug `Request`.

    The instance also acts as the `SRUDiagnosticList` that collects all
    diagnostics raised while validating the request parameters. Call
    `check_parameters` once after construction to parse and validate
    the request before using the getters.
    """

    def __init__(
        self,
        config: SRUServerConfig,
        query_parsers: SRUQueryParserRegistry,
        request: Request,
        authentication_info_provider: Optional[SRUAuthenticationInfoProvider] = None,
    ):
        """Store the collaborators and initialise all parsed values to
        their "not supplied" defaults.

        Args:
            config: the server configuration (defaults and limits)
            query_parsers: registry used to find a parser by queryType
            request: the incoming HTTP request
            authentication_info_provider: optional provider used to
                extract authentication information from the request
        """
        self.config = config
        self.query_parsers = query_parsers
        self.authentication_info_provider = authentication_info_provider
        self.authentication_info: Optional[SRUAuthenticationInfo] = None
        self.request = request
        self.diagnostics: List[SRUDiagnostic] = list()

        # NOTE: set default to EXPLAIN
        # (usually correctly set when parameters validated but operations
        # expect some value to be set, not None allowed)
        # FIXME: default value version None handling?
        self.operation: SRUOperation = SRUOperation.EXPLAIN
        self.version: Optional[SRUVersion] = None
        self.response_type: Optional[str] = None
        self.http_accept: Optional[str] = None
        self.record_xml_escaping: Optional[SRURecordXmlEscaping] = None
        self.record_packing: Optional[SRURecordPacking] = None
        self.renderBy: Optional[SRURenderBy] = None
        self.stylesheet: Optional[str] = None
        self.query: Optional[SRUQuery[Any]] = None
        self.start_record = DEFAULT_START_RECORD
        # -1 marks "not supplied" for the numeric parameters below
        self.maximum_records = -1
        self.response_position = DEFAULT_RESPONSE_POSITION
        self.maximum_terms = -1
        self.record_schema_identifier: Optional[str] = None
        self.record_xpath: Optional[str] = None
        self.resultSet_TTL = -1
        self.sortKeys: Optional[str] = None
        self.scan_clause: Optional[cql.CQLQuery] = None

    # ----------------------------------------------------

    def get_request(self) -> Request:
        """Return the underlying HTTP request."""
        return self.request

    def get_operation(self) -> SRUOperation:
        """Return the detected operation (``EXPLAIN`` until validated)."""
        return self.operation

    def get_version(self) -> SRUVersion:
        """Return the request version, or the configured default
        version if none has been determined."""
        if self.version is not None:
            return self.version
        return self.config.default_version

    def get_authentication(self) -> Optional[SRUAuthenticationInfo]:
        """Return the authentication information, if any was extracted."""
        return self.authentication_info

    def get_authentication_subject(self) -> Optional[str]:
        """Return the subject of the authentication information or
        ``None`` if there is no authentication information."""
        if not self.authentication_info:
            return None
        return self.authentication_info.subject

    # ----------------------------------------------------

    def get_query(self) -> Optional[SRUQuery[Any]]:
        """Return the parsed query or ``None``."""
        return self.query

    def get_record_xml_escaping(self) -> SRURecordXmlEscaping:
        """Return the requested record XML escaping, falling back to
        the configured default."""
        if self.record_xml_escaping is not None:
            return self.record_xml_escaping
        return self.config.default_record_xml_escaping

    def get_record_packing(self) -> SRURecordPacking:
        """Return the requested record packing, falling back to the
        configured default."""
        if self.record_packing is not None:
            return self.record_packing
        return self.config.default_record_packing

    def get_start_record(self) -> int:
        """Return the parsed ``startRecord`` value."""
        return self.start_record

    def get_maximum_records(self) -> int:
        """Return the effective maximum number of records.

        Returns:
            int: ``-1`` if the configuration allows overriding and the
            ``X_UNLIMITED_RESULTSET`` extra request data is set; the
            configured ``number_of_records`` if the client supplied no
            value; otherwise the client value capped at the configured
            ``maximum_records``
        """
        if self.config.allow_override_maximum_records and self.get_extra_request_data(
            SRUParam.X_UNLIMITED_RESULTSET
        ):
            return -1
        if self.maximum_records == -1:
            return self.config.number_of_records
        if self.maximum_records > self.config.maximum_records:
            return self.config.maximum_records
        return self.maximum_records

    def get_record_schema_identifier(self) -> Optional[str]:
        """Return the resolved record schema identifier or ``None``."""
        return self.record_schema_identifier

    def get_record_xpath(self) -> Optional[str]:
        """Return the ``recordXPath`` value or ``None``."""
        return self.record_xpath

    def get_resultSet_TTL(self) -> int:
        """Return the ``resultSetTTL`` value or ``-1`` if not supplied."""
        return self.resultSet_TTL

    def get_sortKeys(self) -> Optional[str]:
        """Return the ``sortKeys`` value or ``None``."""
        return self.sortKeys

    def get_scan_clause(self) -> Optional[cql.CQLQuery]:
        """Return the parsed scan clause or ``None``."""
        return self.scan_clause

    def get_response_position(self) -> int:
        """Return the parsed ``responsePosition`` value."""
        return self.response_position

    def get_maximum_terms(self) -> int:
        """Return the effective maximum number of terms.

        Returns:
            int: ``-1`` if the configuration allows overriding and the
            ``X_UNLIMITED_TERMLIST`` extra request data is set; the
            configured ``number_of_terms`` if the client supplied no
            value; otherwise the client value capped at the configured
            ``maximum_terms``
        """
        if self.config.allow_override_maximum_terms and self.get_extra_request_data(
            SRUParam.X_UNLIMITED_TERMLIST
        ):
            return -1
        if self.maximum_terms == -1:
            return self.config.number_of_terms
        # cap the requested number of *terms* (not records) at the limit
        if self.maximum_terms > self.config.maximum_terms:
            return self.config.maximum_terms
        return self.maximum_terms

    def get_stylesheet(self) -> Optional[str]:
        """Return the ``stylesheet`` value or ``None``."""
        return self.stylesheet

    def get_renderBy(self) -> Optional[SRURenderBy]:
        """Return the ``renderBy`` value or ``None``."""
        return self.renderBy

    def get_response_type(self) -> Optional[str]:
        """Return the ``responseType`` value or ``None``."""
        return self.response_type

    # ----------------------------------------------------
    # raw/parameter grabby stuff

    def get_version_raw(self) -> Optional[SRUVersion]:
        """Return the version without falling back to the default."""
        return self.version

    def get_record_xml_escaping_raw(self) -> Optional[str]:
        """Return the raw XML escaping parameter value; it is read from
        ``recordXMLEscaping`` for SRU 2.0 and from ``recordPacking``
        for older versions."""
        if self.is_version(SRUVersion.VERSION_2_0):
            return self.get_parameter(SRUParam.RECORD_XML_ESCAPING, True, False)
        else:
            return self.get_parameter(SRUParam.RECORD_PACKING, True, False)

    def get_record_packing_raw(self) -> Optional[str]:
        """Return the raw ``recordPacking`` value for SRU 2.0 requests,
        ``None`` for any other version."""
        if self.is_version(SRUVersion.VERSION_2_0):
            return self.get_parameter(SRUParam.RECORD_PACKING, True, False)
        else:
            return None

    def get_record_schema_identifier_raw(self) -> Optional[str]:
        """Return the raw ``recordSchema`` parameter value."""
        return self.get_parameter(SRUParam.RECORD_SCHEMA, True, False)

    def get_query_raw(self) -> Optional[str]:
        """Return the raw ``query`` parameter value."""
        return self.get_parameter(SRUParam.QUERY, True, False)

    def get_maximum_records_raw(self) -> int:
        """Return the parsed ``maximumRecords`` value without applying
        defaults or limits (``-1`` if not supplied)."""
        return self.maximum_records

    def get_scan_clause_raw(self) -> Optional[str]:
        """Return the raw ``scanClause`` parameter value."""
        return self.get_parameter(SRUParam.SCAN_CLAUSE, True, False)

    def get_http_accept_raw(self) -> Optional[str]:
        """Return the raw ``httpAccept`` parameter value."""
        return self.get_parameter(SRUParam.HTTP_ACCEPT, True, False)

    # FIXME: access request
    def get_indent_response(self) -> int:
        """Return the indentation to use for the response.

        If the configuration allows overriding, the ``X_INDENT_RESPONSE``
        extra request data is used when it parses as an integer in the
        range ``-1`` .. ``8``; any other value is ignored and the
        configured ``indent_response`` is returned.
        """
        if self.config.allow_override_indent_response:
            value_str = self.get_extra_request_data(SRUParam.X_INDENT_RESPONSE)
            if value_str:
                try:
                    value = int(value_str)
                except (TypeError, ValueError):
                    # best effort: an unparsable override is simply ignored
                    pass
                else:
                    if value > -2 and value < 9:
                        return value
        return self.config.indent_response

    def get_http_accept(self) -> Optional[str]:
        """Return the ``httpAccept`` parameter value, falling back to
        the HTTP ``ACCEPT`` header of the request."""
        if self.http_accept is not None:
            return self.http_accept
        return self.request.headers.get("ACCEPT")

    def get_protocol_schema(self) -> str:
        """Return ``"https://"`` for secure requests, else ``"http://"``."""
        return "https://" if self.request.is_secure else "http://"

    # ----------------------------------------------------

    def add_diagnostic(
        self, uri: str, details: Optional[str] = None, message: Optional[str] = None
    ) -> None:
        """Create a diagnostic from its parts and record it."""
        self.add_diagnostic_obj(SRUDiagnostic(uri, details, message))

    def add_diagnostic_obj(self, diagnostic: SRUDiagnostic):
        """Append ``diagnostic`` to the list of recorded diagnostics."""
        if self.diagnostics is None:
            self.diagnostics = list()
        self.diagnostics.append(diagnostic)

    # ----------------------------------------------------

    def _parse_number_parameter(self, param: str, value: str, min: int) -> int:
        """Parse ``value`` as an integer request parameter.

        A diagnostic is recorded if the value is not a number or is
        smaller than ``min``.

        Args:
            param: the parameter name (used as diagnostic details)
            value: the raw parameter value
            min: the smallest acceptable value

        Returns:
            int: the parsed number (even if it is below ``min``), or
            ``-1`` if ``value`` is empty or not a number
        """
        result = -1
        if value:
            try:
                result = int(value)
            except (TypeError, ValueError):
                self.add_diagnostic(
                    SRUDiagnostics.UNSUPPORTED_PARAMETER_VALUE,
                    param,
                    "Invalid number format.",
                )
            else:
                if result < min:
                    self.add_diagnostic(
                        SRUDiagnostics.UNSUPPORTED_PARAMETER_VALUE,
                        param,
                        f"Value is less than {min}.",
                    )
        return result

    def _parse_scan_query_parameter(
        self, param: str, value: str
    ) -> Optional[cql.CQLQuery]:
        """Parse a scan clause with the CQL query parser.

        Returns:
            the parsed CQL query or ``None`` if the parser returned no
            query
        """
        # NOTE: this should only be called in `check_parameters_rest`
        # when version is not None anymore
        sru_query = CQLQueryParser().parse_query(
            self.version, {SRUParam.QUERY: value}, self  # type: ignore
        )
        if sru_query is None:
            return None
        return sru_query.parsed_query

    def _parse_and_check_version_parameter(
        self, operation: Optional[SRUOperation]
    ) -> Optional[SRUVersion]:
        """Parse the ``version`` parameter of a legacy (SRU 1.1/1.2)
        request.

        Args:
            operation: the operation of the request, ``None`` if it
                could not be determined

        Returns:
            the parsed version; ``None`` (plus a diagnostic) if the
            supplied version is neither 1.1 nor 1.2; the configured
            default version if no version was supplied (a diagnostic is
            recorded unless the operation is explain)
        """
        version_str = self.get_parameter(SRUParam.VERSION, True, True)
        if version_str:
            if version_str == SRUVersion.VERSION_1_1:
                return SRUVersion.VERSION_1_1
            if version_str == SRUVersion.VERSION_1_2:
                return SRUVersion.VERSION_1_2
            self.add_diagnostic(
                SRUDiagnostics.UNSUPPORTED_VERSION,
                SRUVersion.VERSION_1_2,
                f"Version '{version_str}' is not supported",
            )
            return None

        # except for "explain" operation, complain if "version" parameter
        # was not supplied.
        if operation != SRUOperation.EXPLAIN:
            self.add_diagnostic(
                SRUDiagnostics.MANDATORY_PARAMETER_NOT_SUPPLIED,
                str(SRUParam.VERSION),
                f"Mandatory parameter '{SRUParam.VERSION!s}' was not supplied.",
            )
        # this is an explain operation, assume default version
        return self.config.default_version

    # ----------------------------------------------------

    def check_parameters(self) -> bool:
        """Validate incoming request parameters

        Returns:
            bool: ``True`` if successful, ``False`` if something
            went wrong
        """
        if not self.check_parameters_version_operation():
            return False
        self._check_parameters_rest()
        self._check_parameters_auth()
        # no diagnostics recorded -> consider as success
        # FIXME: this should be done nicer!
        return not self.diagnostics

    def check_parameters_version_operation(self) -> bool:
        """Validate incoming request parameters **version** and
        **operation**.

        On success ``self.version`` and ``self.operation`` are set.

        Returns:
            bool: ``True`` if successful, ``False`` if something
            went wrong
        """
        # generally assume, we will also allow processing of SRU 1.1 or 1.2
        process_SRU_old = True

        # Bind both results up-front: on the legacy path an unsupported or
        # empty 'operation' value leaves the operation undetermined, and the
        # code below must be able to test for that instead of failing with
        # an UnboundLocalError.
        operation: Optional[SRUOperation] = None
        version: Optional[SRUVersion] = None

        # Heuristic to detect SRU version and operation ...
        if self.config.max_version >= SRUVersion.VERSION_2_0:
            if not self.get_parameter(SRUParam.VERSION, False, False):
                # Ok, we're committed to SRU 2.0 now, so don't allow processing
                # of SRU 1.1 and 1.2 ...
                process_SRU_old = False
                LOGGER.debug(
                    "handling request as SRU 2.0, because no '%s' parameter was found in the request",
                    SRUParam.VERSION,
                )
                if self.get_parameter(
                    SRUParam.QUERY, False, False
                ) or self.get_parameter(SRUParam.QUERY_TYPE, False, False):
                    LOGGER.debug(
                        "found parameter '%s' or '%s' therefore assuming '%s' operation",
                        SRUParam.QUERY,
                        SRUParam.QUERY_TYPE,
                        SRUOperation.SEARCH_RETRIEVE,
                    )
                    operation = SRUOperation.SEARCH_RETRIEVE
                elif self.get_parameter(SRUParam.SCAN_CLAUSE, False, False):
                    LOGGER.debug(
                        "found parameter '%s' therefore assuming '%s' operation",
                        SRUParam.SCAN_CLAUSE,
                        SRUOperation.SCAN,
                    )
                    operation = SRUOperation.SCAN
                else:
                    LOGGER.debug(
                        "no special parameter found therefore assuming '%s' operation",
                        SRUOperation.EXPLAIN,
                    )
                    operation = SRUOperation.EXPLAIN

                # record version ...
                version = SRUVersion.VERSION_2_0

                # do pedantic check for 'operation' parameter
                operation_str = self.get_parameter(SRUParam.OPERATION, False, False)
                if operation_str:
                    # XXX: if operation is searchRetrive and the 'operation'
                    # parameter is also searchRetrieve, should the server just
                    # ignore it?
                    if (
                        operation != SRUOperation.SEARCH_RETRIEVE
                        and operation_str == SRUOperation.SEARCH_RETRIEVE
                    ):
                        self.add_diagnostic(
                            SRUDiagnostics.UNSUPPORTED_PARAMETER,
                            SRUParam.OPERATION,
                            message=f"Parameter '{SRUParam.OPERATION}' is not valid for SRU version 2.0",
                        )
            else:
                LOGGER.debug(
                    "handling request as legacy SRU, because found parameter '%s' in request",
                    SRUParam.VERSION,
                )

        if process_SRU_old:
            # parse mandatory operation parameter
            operation_str = self.get_parameter(SRUParam.OPERATION, False, False)
            # `get_parameter` strips the value, so a supplied-but-blank
            # parameter arrives here as "" while an absent one is None.
            if operation_str is not None:
                if operation_str:
                    if operation_str == SRUOperation.EXPLAIN:
                        operation = SRUOperation.EXPLAIN
                    elif operation_str == SRUOperation.SCAN:
                        operation = SRUOperation.SCAN
                    elif operation_str == SRUOperation.SEARCH_RETRIEVE:
                        operation = SRUOperation.SEARCH_RETRIEVE
                    else:
                        self.add_diagnostic(
                            SRUDiagnostics.UNSUPPORTED_OPERATION,
                            message=f"Operation '{operation_str}' is not supported",
                        )
                else:
                    self.add_diagnostic(
                        SRUDiagnostics.UNSUPPORTED_OPERATION,
                        message=f"An empty parameter '{SRUParam.OPERATION}' is not supported.",
                    )
                # parse and check version
                version = self._parse_and_check_version_parameter(operation)
            else:
                # absent parameter should be interpreted as "explain"
                operation = SRUOperation.EXPLAIN
                # parse and check version
                version = self._parse_and_check_version_parameter(operation)

        # sanity check
        if version and operation:
            LOGGER.debug(
                "min = %s, min? = %s, max = %s, max? = %s, version = %s",
                self.config.min_version,
                version == self.config.min_version,
                self.config.max_version,
                version == self.config.max_version,
                version,
            )
            if (
                version >= self.config.min_version
                and version <= self.config.max_version
            ):
                self.version = version
                self.operation = operation
                return True
            else:
                self.add_diagnostic(
                    SRUDiagnostics.UNSUPPORTED_VERSION,
                    self.config.max_version,
                    message=f"Version '{version}' is not supported by this endpoint.",
                )
        LOGGER.debug("bailed")
        return False

    def _check_parameters_rest(self) -> bool:
        """Validate incoming request parameters.

        Expects ``self.version`` and ``self.operation`` to have been
        set by `check_parameters_version_operation`. Parsed values are
        stored on the instance, problems are recorded as diagnostics.

        Returns:
            bool: ``True`` if successful, ``False`` if something
            went wrong
        """
        if self.diagnostics:
            # this should only happen if repeatedly called
            # which is not done usually
            return False

        # check mandatory/optional parameters for operation
        parameters = ParameterInfoSets.for_operation(self.operation)
        if not parameters:
            self.add_diagnostic(
                SRUDiagnostics.GENERAL_SYSTEM_ERROR,
                message="internal error (invalid operation)",
            )
            return False

        # keep list of all submitted parameters (except "operation" and
        # "version"), so we can later warn if an unsupported parameter
        # was sent (= not all parameters were consumed).
        parameter_names = self.get_parameter_names()

        # check parameters ...
        for parameter in parameters:
            name = parameter.name(self.version)  # type: ignore
            if not name:
                # this parameter is not supported in the SRU version that
                # was used for the request
                continue
            value = self.get_parameter(name, True, True)
            if value is None:
                if parameter.mandatory:
                    self.add_diagnostic(
                        SRUDiagnostics.MANDATORY_PARAMETER_NOT_SUPPLIED,
                        name,
                        message=f"Mandatory parameter '{name}' was not supplied.",
                    )
                continue

            # remove supported parameter from list
            if name in parameter_names:
                parameter_names.remove(name)

            # if parameter is not supported in this version, skip it
            # and add a diagnostic.
            # NOTE: version is not None
            if not parameter.is_for_version(self.version):  # type: ignore
                self.add_diagnostic(
                    SRUDiagnostics.UNSUPPORTED_PARAMETER,
                    name,
                    message=f"Version {self.version} does not support parameter '{name}'.",
                )
                continue

            # validate and parse parameters ...
            if parameter.parameter == ParameterInfo.Parameter.RECORD_XML_ESCAPING:
                if value == SRUParamValue.RECORD_XML_ESCAPING_XML:
                    self.record_xml_escaping = SRURecordXmlEscaping.XML
                elif value == SRUParamValue.RECORD_XML_ESCAPING_STRING:
                    self.record_xml_escaping = SRURecordXmlEscaping.STRING
                else:
                    self.add_diagnostic(
                        SRUDiagnostics.UNSUPPORTED_XML_ESCAPING_VALUE,
                        message=f"Record XML escaping '{value}' is not supported.",
                    )
            elif parameter.parameter == ParameterInfo.Parameter.RECORD_PACKING:
                if value == SRUParamValue.RECORD_PACKING_PACKED:
                    self.record_packing = SRURecordPacking.PACKED
                elif value == SRUParamValue.RECORD_PACKING_UNPACKED:
                    self.record_packing = SRURecordPacking.UNPACKED
                else:
                    self.add_diagnostic(
                        SRUDiagnostics.UNSUPPORTED_PARAMETER_VALUE,
                        message=f"Record packing '{value}' is not supported.",
                    )
            elif parameter.parameter == ParameterInfo.Parameter.RENDER_BY:
                if value == SRUParamValue.RENDER_BY_CLIENT:
                    self.renderBy = SRURenderBy.CLIENT
                elif value == SRUParamValue.RENDER_BY_SERVER:
                    self.renderBy = SRURenderBy.SERVER
                else:
                    self.add_diagnostic(
                        SRUDiagnostics.UNSUPPORTED_PARAMETER_VALUE,
                        message=f"Value '{value}' for parameter '{name}' is not supported.",
                    )
            elif parameter.parameter == ParameterInfo.Parameter.RECORD_SCHEMA:
                # The parameter recordSchema may contain either schema
                # identifier or the short name. If available, set to
                # appropriate schema identifier in the request object.
                schema_info = self.config.find_schema_info(value)
                if schema_info:
                    self.record_schema_identifier = schema_info.identifier
                else:
                    # SRU servers are supposed to raise a non-surrogate
                    # (fatal) diagnostic in case the record schema is not
                    # known to the server.
                    self.add_diagnostic(
                        SRUDiagnostics.UNKNOWN_SCHEMA_FOR_RETRIEVAL,
                        value,
                        message=f"Record schema '{value}' is not supported for retrieval.",
                    )
            elif parameter.parameter == ParameterInfo.Parameter.START_RECORD:
                self.start_record = self._parse_number_parameter(name, value, 1)
            elif parameter.parameter == ParameterInfo.Parameter.RESPONSE_POSITION:
                self.response_position = self._parse_number_parameter(name, value, 0)
            elif parameter.parameter == ParameterInfo.Parameter.MAXIMUM_RECORDS:
                self.maximum_records = self._parse_number_parameter(name, value, 0)
            elif parameter.parameter == ParameterInfo.Parameter.MAXIMUM_TERMS:
                self.maximum_terms = self._parse_number_parameter(name, value, 0)
            elif parameter.parameter == ParameterInfo.Parameter.RESULT_SET_TTL:
                self.resultSet_TTL = self._parse_number_parameter(name, value, 0)
            elif parameter.parameter == ParameterInfo.Parameter.SCAN_CLAUSE:
                self.scan_clause = self._parse_scan_query_parameter(name, value)
            elif parameter.parameter == ParameterInfo.Parameter.RECORD_XPATH:
                self.record_xpath = value
            elif parameter.parameter == ParameterInfo.Parameter.SORT_KEYS:
                self.sortKeys = value
            elif parameter.parameter == ParameterInfo.Parameter.STYLESHEET:
                self.stylesheet = value
            elif parameter.parameter == ParameterInfo.Parameter.RESPONSE_TYPE:
                # FIXME: check parameter validity?!
                self.response_type = value
            elif parameter.parameter == ParameterInfo.Parameter.HTTP_ACCEPT:
                # FIXME: check parameter validity?!
                self.http_accept = value

        # handle query and queryType
        if self.operation == SRUOperation.SEARCH_RETRIEVE:
            # determine queryType
            query_type: Optional[str] = None
            if self.version == SRUVersion.VERSION_2_0:
                if SRUParam.QUERY_TYPE in parameter_names:
                    parameter_names.remove(SRUParam.QUERY_TYPE)
                value = self.get_parameter(SRUParam.QUERY_TYPE, True, True)
                if value is None:
                    # no queryType supplied: CQL is the default
                    query_type = SRUQueryType.CQL.value
                else:
                    has_bad_chars = QUERY_TYPE_ALLOWED_CHARS.fullmatch(value) is None
                    if has_bad_chars:
                        self.add_diagnostic(
                            SRUDiagnostics.UNSUPPORTED_PARAMETER_VALUE,
                            SRUParam.QUERY_TYPE,
                            message="Value contains illegal characters.",
                        )
                    else:
                        query_type = value
            else:
                # SRU 1.1 and SRU 1.2 only support CQL
                query_type = SRUQueryType.CQL.value

            if query_type:
                LOGGER.debug("looking for query parser for query type '%s'", query_type)
                query_parser = self.query_parsers.find_query_parser(query_type)
                if query_parser:
                    if query_parser.supports_version(self.version):
                        # gather query parameters
                        # (as required by QueryParser implementation)
                        query_parameters = dict()
                        missing_parameters = list()
                        for name in query_parser.query_parameter_names:
                            if name in parameter_names:
                                parameter_names.remove(name)
                            value = self.get_parameter(name, True, False)
                            if value is not None:
                                query_parameters[name] = value
                            else:
                                missing_parameters.append(name)

                        if not missing_parameters:
                            LOGGER.debug(
                                "parsing query with parser for type '%s' and parameters %s",
                                query_parser.query_type,
                                query_parameters,
                            )
                            # NOTE: version is not None
                            self.query = query_parser.parse_query(
                                self.version, query_parameters, self  # type: ignore
                            )
                            if not self.query:
                                LOGGER.debug("query parser failed to parse query")
                                self.add_diagnostic(
                                    SRUDiagnostics.QUERY_SYNTAX_ERROR,
                                    message="Query could not be parsed.",
                                )
                        else:
                            LOGGER.debug(
                                "parameters %s missing, cannot parse query",
                                missing_parameters,
                            )
                            for name in missing_parameters:
                                self.add_diagnostic(
                                    SRUDiagnostics.MANDATORY_PARAMETER_NOT_SUPPLIED,
                                    name,
                                    message=(
                                        f"Mandatory parameter '{name}' is missing or empty. "
                                        f"Required to perform query of query type '{query_type}'."
                                    ),
                                )
                    else:
                        LOGGER.debug(
                            "query parser for query type '%s' is not supported by SRU version %s",
                            query_type,
                            self.version,
                        )
                        self.add_diagnostic(
                            SRUDiagnostics.CANNOT_PROCESS_QUERY_REASON_UNKNOWN,
                            message=(
                                f"Query parser for query type '{query_type}' is not"
                                f" supported by SRU version '{self.version}'."
                            ),
                        )
                else:
                    LOGGER.debug("no parser for query type '%s' found", query_type)
                    self.add_diagnostic(
                        SRUDiagnostics.CANNOT_PROCESS_QUERY_REASON_UNKNOWN,
                        message=f"Cannot find query parser for query type '{query_type}'.",
                    )
            else:
                LOGGER.debug("cannot determine query type")
                self.add_diagnostic(
                    SRUDiagnostics.CANNOT_PROCESS_QUERY_REASON_UNKNOWN,
                    message="Cannot determine query type.",
                )

        # check if any parameters were not consumed and
        # add appropriate warnings
        if parameter_names:
            for name in parameter_names:
                # skip extraRequestData (aka extensions)
                if not name.startswith(PARAM_EXTENSION_PREFIX):
                    self.add_diagnostic(
                        SRUDiagnostics.UNSUPPORTED_PARAMETER,
                        name,
                        message=f"Parameter '{name}' is not supported for this operation.",
                    )

        # no diagnostics recorded -> consider as success
        # FIXME: this should be done nicer!
        return not self.diagnostics

    def _check_parameters_auth(self) -> None:
        """Extract authentication information from the request, if an
        authentication provider is set. An `SRUException` raised by the
        provider is recorded as a diagnostic."""
        if self.authentication_info_provider:
            try:
                self.authentication_info = (
                    self.authentication_info_provider.get_AuthenticationInfo(
                        self.request
                    )
                )
            except SRUException as ex:
                self.add_diagnostic_obj(ex.get_diagnostic())

    # ----------------------------------------------------

    def get_parameter_names(self) -> List[str]:
        """Return the names of all query-string parameters of the
        request, except ``operation`` and ``version``."""
        parameters = list(self.request.args.keys())
        parameters = [
            p for p in parameters if p not in (SRUParam.OPERATION, SRUParam.VERSION)
        ]
        return parameters

    def get_parameter(
        self, name: Union[SRUParam, str], nullify: bool, diagnostic_if_empty: bool
    ) -> Optional[str]:
        """Read a query-string parameter and strip surrounding
        whitespace.

        Args:
            name: the parameter name
            nullify: if ``True``, an empty (after stripping) value is
                returned as ``None``
            diagnostic_if_empty: if ``True`` (and ``nullify`` is set),
                record a diagnostic when the value is empty

        Returns:
            the stripped value, or ``None`` if the parameter is absent
            (or empty and ``nullify`` is set)
        """
        value = self.request.args.get(name)
        if value is not None:
            value = value.strip()
            if nullify and not value:
                value = None
                if diagnostic_if_empty:
                    self.add_diagnostic(
                        SRUDiagnostics.UNSUPPORTED_PARAMETER_VALUE,
                        name,
                        message=f"An empty parameter '{name}' is not supported.",
                    )
        return value
# ---------------------------------------------------------------------------