Back to home page

Project CMSSW displayed by LXR

 
 

    


File indexing completed on 2024-04-06 12:01:53

0001 """
0002 
0003 connection class translates a connection string for sqlite, oracle or frontier into a connection object.
0004 Also sets up ORM with SQLAlchemy.
0005 
0006 connection class can also take a pre-constructed engine - useful for web services.
0007 
0008 """
0009 
0010 import sqlalchemy
0011 from sqlalchemy import create_engine, text, or_
0012 from sqlalchemy.orm import sessionmaker
0013 from sqlalchemy.pool import NullPool
0014 import datetime
0015 from .data_sources import json_data_node
0016 from copy import deepcopy
0017 from . import models
0018 import traceback
0019 import os
0020 import netrc
0021 import sys
0022 
class connection(object):
    """
    Given a connection string, parses the connection string and connects.

    A pre-constructed engine may be given instead of a connection string; it is
    then used as the engine when setup() is called.  setup() must be called before
    any of the querying or writing methods are used, since it creates the session,
    the factory and the final (lower-case keyed) dictionary of models.
    """
    engine = None
    connection = None
    session = None
    connection_data = None
    netrc_authenticators = None
    secrets = None

    def __init__(self, connection_data, mode=None, map_blobs=False, secrets=None, pooling=False):
        """
        Store the connection data and generate the models.

        connection_data - a connection string (handed to new_connection_dictionary)
                          or anything else, which is assumed to be an engine.
        mode            - passed to new_connection_dictionary with the connection string.
        map_blobs       - passed to models.generate.
        secrets         - passed to new_connection_dictionary with the connection string.
        pooling         - passed to engine_from_dictionary when setup() builds the engine.
        """
        self._pooling = pooling

        # add querying utility properties
        # these must belong to the connection since the way in which their values are handled
        # depends on the database being connected to.
        # NOTE(review): these are the classes held by the models module itself, so the
        # attributes assigned to them below are shared by every connection object - confirm
        # that this is intended when several connections are open at once.
        self.range = models.Range
        self.radius = models.Radius
        self.regexp = models.RegExp
        self.regexp.connection_object = self

        if isinstance(connection_data, str):
            # if we've been given a connection string, process it
            self.connection_data = new_connection_dictionary(connection_data, secrets=secrets, mode=mode)
            schema = self.connection_data.get("schema")
            self.schema = schema if schema is not None else ""
            db_type = self.connection_data["host"]
        else:
            # assume we have an engine
            self.connection_data = connection_data
            # no schema can be extracted from an engine - use the same empty value as the
            # schema-less sqlite case so that write() still finds the attribute
            self.schema = ""
            # we need to take the string representation so we know which type of db we're aiming at
            engine_string = str(connection_data)
            db_type = None
            # the order matters: the first name found in the engine string wins
            for candidate in ("oracle", "frontier", "sqlite"):
                if candidate in engine_string:
                    db_type = candidate
                    break

        self.range.database_type = db_type
        self.radius.database_type = db_type
        self.regexp.database_type = db_type

        self.models = models.generate(map_blobs)
        #self.base = self.models["Base"]

    def setup(self):
        """
        Setup engine with the stored connection data, make a session maker and a session,
        and attach the session and this connection to each model.

        Returns this connection object.
        """
        # imported from the public package rather than the private
        # sqlalchemy.ext.declarative.api path
        from sqlalchemy.ext.declarative import DeclarativeMeta

        have_dictionary = isinstance(self.connection_data, dict)
        if have_dictionary:
            self.engine = engine_from_dictionary(self.connection_data, pooling=self._pooling)
        else:
            # we've been given an engine by the user
            # use it as the engine
            self.engine = self.connection_data

        self.sessionmaker = sessionmaker(bind=self.engine)
        self.session = self.sessionmaker()
        self.factory = factory(self)

        # assign correct schema for database name to each model
        configured_models = {}
        for key, model in self.models.items():
            # only declarative model classes are kept; the declarative base itself is dropped
            if not isinstance(model, DeclarativeMeta) or str(model.__name__) == "Base":
                continue

            if have_dictionary:
                # we can only extract the secrets and schema individually
                # if we were given a dictionary...  if we were given an engine
                # we can't do this without parsing the connection string from the engine
                # - a wide range of which it will be difficult to support!
                model.__table__.schema = self.connection_data["schema"]
                model.secrets = self.connection_data["secrets"]

            model.session = self.session
            # isn't used anywhere - comment it out for now
            #model.authentication = self.netrc_authenticators
            model.connection = self
            model.empty = False
            configured_models[key.lower()] = model

        self.models = configured_models

        return self

    @staticmethod
    def _get_CMS_frontier_connection_string(database):
        """
        Run cmsGetFnConnect for 'frontier://<database>' and return its stripped standard output.

        Raises an Exception if the command cannot be run.
        """
        import subprocess
        try:
            process = subprocess.Popen(['cmsGetFnConnect', 'frontier://%s' % database], stdout = subprocess.PIPE)
            return process.communicate()[0].strip()
        except Exception as error:
            raise Exception("Frontier connections can only be constructed when inside a CMSSW environment.") from error

    @staticmethod
    def _cms_frontier_string(database, schema="cms_conditions"):
        """
        Get database string for frontier.
        """
        import urllib.parse
        frontier_connect = connection._get_CMS_frontier_connection_string(database)
        return 'oracle+frontier://@%s/%s' % (urllib.parse.quote_plus(frontier_connect), schema)

    @staticmethod
    def _cms_oracle_string(user, pwd, db_name):
        """
        Get database string for oracle.

        The user name and password are percent-encoded so that characters such as
        '@', ':' or '/' in them cannot be mistaken for URL delimiters.
        """
        import urllib.parse
        quoted_user = urllib.parse.quote(str(user), safe="")
        quoted_pwd = urllib.parse.quote(str(pwd), safe="")
        return 'oracle://%s:%s@%s' % (quoted_user, quoted_pwd, db_name)

    @staticmethod
    def build_oracle_url(user, pwd, db_name):
        """
        Build the connection url from the given credentials and database name.

        If the oracle string cannot be parsed, a sqlite url for db_name is returned instead.
        """
        database_url = connection._cms_oracle_string(user, pwd, db_name)

        try:
            url = sqlalchemy.engine.url.make_url(database_url)
        except sqlalchemy.exc.ArgumentError:
            return sqlalchemy.engine.url.make_url('sqlite:///%s' % db_name)

        if url.password is None:
            if hasattr(url, "set"):
                # immutable URL objects: a modified copy has to be requested
                url = url.set(password=pwd)
            else:
                url.password = pwd
        return url

    @staticmethod
    def build_frontier_url(db_name, schema):
        """
        Build the connection url for frontier.

        If the frontier string cannot be parsed, a sqlite url for db_name is returned instead.
        """
        database_url = connection._cms_frontier_string(db_name, schema)

        try:
            return sqlalchemy.engine.url.make_url(database_url)
        except sqlalchemy.exc.ArgumentError:
            # Is this needed for a use case?
            return sqlalchemy.engine.url.make_url('sqlite:///%s' % db_name)

    # currently just commits and closes the current session (ends transaction, closes connection)
    # may do other things later
    def tear_down(self):
        """
        Commit and close the session.

        Returns None on success, or an error message string if committing or closing failed.
        """
        try:
            self.session.commit()
            self.close_session()
        except Exception:
            return "Couldn't tear down connection on engine %s." % str(self.engine)

    def close_session(self):
        """
        Close the current session.  Returns True.
        """
        self.session.close()
        return True

    def hard_close(self):
        """
        Dispose of the engine.  Returns True.
        """
        self.engine.dispose()
        return True

    # get model based on given model name
    def model(self, model_name):
        """
        Return the model stored under model_name.

        model_name may be a string or a model class (its __name__ is then used).
        Underscores are removed from the name.  The name is looked up as given first
        and, failing that, in lower case (setup() stores the models under lower-case keys).
        """
        if not isinstance(model_name, str):
            model_name = model_name.__name__
        model_name = model_name.replace("_", "")
        try:
            return self.models[model_name]
        except KeyError:
            return self.models[model_name.lower()]

    # model should be the class the developer wants to be instantiated
    # pk_to_value maps primary keys to values
    def object(self, model, pk_to_value):
        """
        Return the first object of type model whose attributes match pk_to_value,
        or None if there is no session.
        """
        if self.session is None:
            return None
        query = self.session.query(model)
        for pk, value in pk_to_value.items():
            query = query.filter(getattr(model, pk) == value)
        return query.first()

    def global_tag(self, **pkargs):
        """Ask the factory for "globaltag" object(s) matching pkargs."""
        return self.factory.object("globaltag", **pkargs)

    def global_tag_map(self, **pkargs):
        """Ask the factory for "globaltagmap" object(s) matching pkargs."""
        return self.factory.object("globaltagmap", **pkargs)

    # def global_tag_map_request(self, **pkargs):
    #     return self.factory.object("globaltagmaprequest", **pkargs)

    def tag(self, **pkargs):
        """Ask the factory for "tag" object(s) matching pkargs."""
        return self.factory.object("tag", **pkargs)

    def tag_authorization(self, **pkargs):
        """Ask the factory for "tagauthorization" object(s) matching pkargs."""
        return self.factory.object("tagauthorization", **pkargs)

    def iov(self, **pkargs):
        """Ask the factory for "iov" object(s) matching pkargs."""
        return self.factory.object("iov", **pkargs)

    def payload(self, **pkargs):
        """Ask the factory for "payload" object(s) matching pkargs."""
        return self.factory.object("payload", **pkargs)

    # def record(self, **pkargs):
    #     return self.factory.object("record", **pkargs)

    # adds % at the beginning and end so LIKE in SQL searches all of the string
    def _oracle_match_format(self, string):
        """Return string wrapped in '%' wildcards."""
        return "%%%s%%" % string

    # returns dictionary mapping object type to a list of all objects found in the search
    def search_everything(self, string, amount=10):
        """
        Search global tags, tags, iovs and payloads for string (case-insensitive LIKE).

        At most amount results are fetched per object type.  The result is built with
        json_data_node.make from a dictionary with the keys "global_tags", "tags",
        "iovs" and "payloads".
        """
        pattern = self._oracle_match_format(string)
        session = self.session

        gt = self.model("globaltag")
        global_tags = session.query(gt).filter(or_(
            gt.name.ilike(pattern),
            gt.description.ilike(pattern),
            gt.release.ilike(pattern)
        )).limit(amount)

        tag = self.model("tag")
        tags = session.query(tag).filter(or_(
            tag.name.ilike(pattern),
            tag.object_type.ilike(pattern),
            tag.description.ilike(pattern)
        )).limit(amount)

        iov = self.model("iov")
        iovs = session.query(iov).filter(or_(
            iov.tag_name.ilike(pattern),
            iov.since.ilike(pattern),
            iov.payload_hash.ilike(pattern),
            iov.insertion_time.ilike(pattern)
        )).limit(amount)

        payload = self.model("payload")
        payloads = session.query(payload).filter(or_(
            payload.hash.ilike(pattern),
            payload.object_type.ilike(pattern),
            payload.insertion_time.ilike(pattern)
        )).limit(amount)

        return json_data_node.make({
            "global_tags" : global_tags.all(),
            "tags" : tags.all(),
            "iovs" : iovs.all(),
            "payloads" : payloads.all()
        })

    def write(self, object):
        """
        Add the result of models.session_independent_object(object, schema=self.schema)
        to the session and return it.  Nothing is committed.
        """
        new_object = models.session_independent_object(object, schema=self.schema)
        self.session.add(new_object)
        return new_object

    def commit(self):
        """
        Commit the session.  On failure the traceback is printed and the session rolled back.
        """
        try:
            self.session.commit()
        except Exception:
            traceback.print_exc()
            self.session.rollback()

    def write_and_commit(self, object):
        """
        Write and commit object; if object is a list, each item is written and committed in turn.
        """
        if isinstance(object, list):
            for item in object:
                self.write_and_commit(item)
        else:
            # should be changed to deal with errors - add them to exception handling if they appear
            self.write(object)
            self.commit()

    def rollback(self):
        """
        Roll the session back.  On failure the traceback and a message are printed.
        """
        try:
            self.session.rollback()
        except Exception:
            traceback.print_exc()
            print("Session couldn't be rolled back.")
0295 
class factory():
    """
    Contains methods for creating objects.
    """
    def __init__(self, connection):
        # the connection object whose models and session are used by object()
        self.connection = connection

    # class_name is the class name of the model to be used
    # pkargs is a dictionary of keyword arguments used as primary key values
    # this dictionary will be used to populate the object of type name class_name
    def object(self, class_name, **pkargs):
        """
        Create or look up an object of the model named class_name.

        Without keyword arguments a new, empty model instance is returned (its
        "empty" attribute is set to True).  With keyword arguments they are applied
        as filters through apply_filters, the optional "amount" argument limits the
        number of rows, and the result is None (nothing found), the single object
        found, or a json_list of the objects found.

        Returns None if the connection has no session.
        """
        from .data_sources import json_list
        from .models import apply_filters
        # get the class that self.connection holds from the class name
        model = self.connection.model(class_name)

        if self.connection.session is None:
            return None

        if not pkargs:
            # no column arguments were given, so return an empty object
            new_object = model()
            new_object.empty = True
            return new_object

        # query for the ORM object, and return the appropriate object (None, CondDBFW object, or json_list)
        model_data = self.connection.session.query(model)
        # apply the filters defined in **kwargs
        model_data = apply_filters(model_data, model, **pkargs)
        model_data = model_data.limit(pkargs.get("amount"))

        # fetch once and decide on the fetched rows, rather than issuing separate
        # count queries before fetching
        results = model_data.all()
        if len(results) > 1:
            # if we have multiple objects, return a json_list
            return json_list(results)
        if len(results) == 1:
            # if we have a single object, return that object
            return results[0]
        # if we have no objects returned, return None
        return None
0336 
0337 def _get_netrc_data(netrc_file, key):
0338     """
0339     Returns a dictionary {login : ..., account : ..., password : ...}
0340     """
0341     try:
0342         headers = ["login", "account", "password"]
0343         authenticator_tuple = netrc.netrc(netrc_file).authenticators(key)
0344         if authenticator_tuple == None:
0345             raise Exception("netrc file must contain key '%s'." % key)
0346     except:
0347         raise Exception("Couldn't get credentials from netrc file.")
0348     return dict(list(zip(headers, authenticator_tuple)))
0349 
def new_connection_dictionary(connection_data, secrets=None, mode="r"):
    """
    Function used to construct connection data dictionaries - internal to framework.

    Supported connection strings:
        frontier://database_name/schema
        sqlite://database_file_name
        oracle://account:password@database_name
        oracle://database_name/schema (requires a separate method of authentication - either dictionary or netrc)

    secrets - for the oracle://database_name/schema form: None (the user is prompted),
              the name of a netrc file, or a dictionary with the keys "user" and "password".
    mode    - "r" or "w"; selects the "read" or "write" netrc entry.  None is treated as "r".

    Anything that is not a string starting with one of the prefixes above is returned unchanged.
    """
    frontier_prefix = "frontier://"
    sqlite_prefix = "sqlite://"
    #sqlite_file_prefix = "sqlite_file://"
    oracle_prefix = "oracle://"

    if not isinstance(connection_data, str):
        return connection_data

    if connection_data.startswith(frontier_prefix):
        # frontier://database_name/schema
        parts = connection_data[len(frontier_prefix):].split("/")
        db_name = parts[0]
        schema = parts[1]
        connection_data = {}
        connection_data["database_name"] = db_name
        connection_data["schema"] = schema
        connection_data["host"] = "frontier"
        connection_data["secrets"] = None
    elif connection_data.startswith(sqlite_prefix):
        # sqlite://database_file_name
        # for now, just support "sqlite://" format for sqlite connection strings
        db_name = connection_data[len(sqlite_prefix):]
        connection_data = {}
        connection_data["database_name"] = os.path.abspath(db_name)
        connection_data["schema"] = ""
        connection_data["host"] = "sqlite"
        connection_data["secrets"] = None
    elif connection_data.startswith(oracle_prefix):
        # oracle://account:password@database_name
        # or
        # oracle://database_name/schema
        new_connection_string = connection_data[len(oracle_prefix):]

        if ":" in new_connection_string:
            # the user has given a password - usually in the case of the db upload service
            colon_position = new_connection_string.index(":")
            # split at the last "@" so that a password containing "@" stays in one piece
            at_position = new_connection_string.rindex("@")
            database_name = new_connection_string[at_position+1:]
            # set username based on connection string; the schema has the same name
            username = new_connection_string[0:colon_position]
            schema_name = username
            password = new_connection_string[colon_position+1:at_position]
        else:
            mode_to_netrc_key_suffix = {"r" : "read", "w" : "write"}
            slash_position = new_connection_string.index("/")
            database_name = new_connection_string[0:slash_position]
            schema_name = new_connection_string[slash_position+1:]
            if secrets is None:
                username = str(input("Enter the username you want to connect to the schema '%s' with: " % (schema_name)))
                password = str(input("Enter the password for the user '%s' in database '%s': " % (username, database_name)))
            elif isinstance(secrets, str):
                # no mode given: fall back to this function's own default, read access
                netrc_key_suffix = mode_to_netrc_key_suffix["r" if mode is None else mode]
                netrc_key = "%s/%s/%s" % (database_name, schema_name, netrc_key_suffix)
                netrc_data = _get_netrc_data(secrets, key=netrc_key)
                # take the username from the netrc entry corresponding to the mode the database is opened in
                # eg, if the user has given mode="read", the database_name/schema_name/read entry will be taken
                username = netrc_data["login"]
                password = netrc_data["password"]
            elif isinstance(secrets, dict):
                username = secrets["user"]
                password = secrets["password"]
            else:
                raise Exception("Invalid type given for secrets.  Either an str or a dict must be given.")

        #print("Connected to database %s, schema %s, with username %s." % (database_name, schema_name, username))

        connection_data = {}
        connection_data["database_name"] = database_name
        connection_data["schema"] = schema_name
        connection_data["password"] = password
        connection_data["host"] = "oracle"
        connection_data["secrets"] = {"login" : username, "password" : password}

    return connection_data
0428 
def engine_from_dictionary(dictionary, pooling=True):
    """
    Create an engine from a connection data dictionary.

    The "host" entry selects the kind of url: "sqlite" needs only the database name,
    "frontier" needs the database name and schema, and anything else is treated as
    oracle and needs the login and password from the "secrets" entry.

    pooling only affects the non-sqlite engines: when it is false they are created
    with poolclass=NullPool.
    """
    host = dictionary["host"]

    if host == "sqlite":
        # if host is sqlite, making the url is easy - no authentication
        return create_engine("sqlite:///%s" % dictionary["database_name"])

    if host == "frontier":
        # if frontier, no need to authenticate
        url = connection.build_frontier_url(dictionary["database_name"], dictionary["schema"])
    else:
        # probably oracle
        # if not frontier, we have to authenticate
        credentials = dictionary["secrets"]
        user = credentials["login"]
        pwd = credentials["password"]
        url = connection.build_oracle_url(user, pwd, dictionary["database_name"])

    # set max label length for oracle and frontier
    engine_options = {"label_length" : 6}
    if not pooling:
        engine_options["poolclass"] = NullPool
    return create_engine(url, **engine_options)
0451 
0452 
def connect(connection_data, mode="r", map_blobs=False, secrets=None, pooling=True):
    """
    Utility method for user - construct a connection object from the given
    arguments, call its setup method and return the result.
    """
    new_connection = connection(
        connection_data=connection_data,
        mode=mode,
        map_blobs=map_blobs,
        secrets=secrets,
        pooling=pooling,
    )
    return new_connection.setup()