File indexing completed on 2024-04-06 12:01:53
0001 """
0002
0003 The connection class translates a connection string for sqlite, oracle or frontier into a connection object.
0004 Also sets up ORM with SQLAlchemy.
0005
0006 connection class can also take a pre-constructed engine - useful for web services.
0007
0008 """
0009
0010 import sqlalchemy
0011 from sqlalchemy import create_engine, text, or_
0012 from sqlalchemy.orm import sessionmaker
0013 from sqlalchemy.pool import NullPool
0014 import datetime
0015 from .data_sources import json_data_node
0016 from copy import deepcopy
0017 from . import models
0018 import traceback
0019 import os
0020 import netrc
0021 import sys
0022
class connection(object):
    """
    Wraps a SQLAlchemy engine and session for a sqlite, oracle or frontier database.

    ``connection_data`` is either a connection string (parsed by
    ``new_connection_dictionary``) or a pre-constructed engine, which is useful
    for web services. Nothing is connected until ``setup()`` is called.
    """
    engine = None
    connection = None
    session = None
    connection_data = None
    netrc_authenticators = None
    secrets = None

    def __init__(self, connection_data, mode=None, map_blobs=False, secrets=None, pooling=False):
        """
        Parse ``connection_data`` and generate the ORM models; does not connect yet.

        :param connection_data: a connection string ("sqlite://file",
            "frontier://database/schema", "oracle://user:password@database" or
            "oracle://database/schema"), or a pre-constructed engine.
        :param mode: "r" or "w"; forwarded to ``new_connection_dictionary`` where it
            selects the netrc key suffix for netrc-authenticated oracle connections.
        :param map_blobs: forwarded to ``models.generate``.
        :param secrets: for "oracle://database/schema" strings, a netrc file path or a
            dict with "user" and "password" keys.
        :param pooling: if False, ``setup()`` builds oracle/frontier engines with NullPool.
        """
        self._pooling = pooling

        self.range = models.Range
        self.radius = models.Radius
        self.regexp = models.RegExp
        # NOTE: these are module-level objects from ``models``, so the attributes assigned
        # on them here are shared by every connection instance.
        self.regexp.connection_object = self

        if isinstance(connection_data, str):
            self.connection_data = new_connection_dictionary(connection_data, secrets=secrets, mode=mode)
            schema = self.connection_data.get("schema")
            self.schema = schema if schema is not None else ""
            db_type = self.connection_data["host"]
        else:
            # A pre-constructed engine: infer the database type from its string form.
            self.connection_data = connection_data
            # No schema can be parsed from an engine; default to "" (as for sqlite) so
            # that write() does not fail with AttributeError.
            self.schema = ""
            engine_string = str(connection_data)
            db_type = None
            if "oracle" in engine_string:
                db_type = "oracle"
            elif "frontier" in engine_string:
                db_type = "frontier"
            elif "sqlite" in engine_string:
                db_type = "sqlite"

        self.range.database_type = db_type
        self.radius.database_type = db_type
        self.regexp.database_type = db_type

        from . import models as ms
        self.models = ms.generate(map_blobs)

    def setup(self):
        """
        Create the engine (unless one was given), a session and the model registry.

        Every declarative model other than ``Base`` gets ``session``, ``connection``
        and ``empty = False`` attributes; for dictionary connection data its table
        schema and ``secrets`` are set too. ``self.models`` is re-keyed by
        lower-cased name and non-model entries are dropped.

        :returns: self, so calls can be chained.
        """
        if isinstance(self.connection_data, dict):
            self.engine = engine_from_dictionary(self.connection_data, pooling=self._pooling)
        else:
            # connection_data is already an engine.
            self.engine = self.connection_data

        self.sessionmaker = sessionmaker(bind=self.engine)
        self.session = self.sessionmaker()
        self.factory = factory(self)

        tmp_models_dict = {}
        for key, model in self.models.items():
            if model.__class__ != sqlalchemy.ext.declarative.api.DeclarativeMeta \
                    or str(model.__name__) == "Base":
                continue
            if isinstance(self.connection_data, dict):
                model.__table__.schema = self.connection_data["schema"]
                model.secrets = self.connection_data["secrets"]
            model.session = self.session
            model.connection = self
            model.empty = False
            tmp_models_dict[key.lower()] = model

        self.models = tmp_models_dict
        return self

    @staticmethod
    def _get_CMS_frontier_connection_string(database):
        """
        Resolve the frontier connection string for ``database`` with ``cmsGetFnConnect``.

        :returns: the tool's stripped standard output (bytes).
        :raises Exception: if the tool cannot be run (e.g. outside a CMSSW environment).
        """
        import subprocess
        try:
            process = subprocess.Popen(['cmsGetFnConnect', 'frontier://%s' % database], stdout=subprocess.PIPE)
            return process.communicate()[0].strip()
        except OSError as error:
            raise Exception("Frontier connections can only be constructed when inside a CMSSW environment.") from error

    @staticmethod
    def _cms_frontier_string(database, schema="cms_conditions"):
        """
        Get the SQLAlchemy URL string for a frontier database.
        """
        import urllib.parse
        frontier_connection = connection._get_CMS_frontier_connection_string(database)
        return 'oracle+frontier://@%s/%s' % (urllib.parse.quote_plus(frontier_connection), schema)

    @staticmethod
    def _cms_oracle_string(user, pwd, db_name):
        """
        Get the SQLAlchemy URL string for an oracle database.
        """
        return 'oracle://%s:%s@%s' % (user, pwd, db_name)

    @staticmethod
    def build_oracle_url(user, pwd, db_name):
        """
        Build the SQLAlchemy URL for an oracle database from explicit credentials.

        Falls back to a sqlite URL for ``db_name`` if the oracle string cannot be parsed.
        """
        database_url = connection._cms_oracle_string(user, pwd, db_name)

        try:
            url = sqlalchemy.engine.url.make_url(database_url)
            if url.password is None:
                # NOTE(review): URL objects are immutable in newer SQLAlchemy releases,
                # where this assignment would fail - confirm the supported version.
                url.password = pwd
        except sqlalchemy.exc.ArgumentError:
            url = sqlalchemy.engine.url.make_url('sqlite:///%s' % db_name)
        return url

    @staticmethod
    def build_frontier_url(db_name, schema):
        """
        Build the SQLAlchemy URL for a frontier database.

        Falls back to a sqlite URL for ``db_name`` if the frontier string cannot be parsed.
        """
        database_url = connection._cms_frontier_string(db_name, schema)

        try:
            url = sqlalchemy.engine.url.make_url(database_url)
        except sqlalchemy.exc.ArgumentError:
            # TODO(review): unclear whether this sqlite fallback is needed for any use case.
            url = sqlalchemy.engine.url.make_url('sqlite:///%s' % db_name)
        return url

    def tear_down(self):
        """
        Commit pending changes and close the session.

        :returns: None on success, or an error message string if anything failed.
        """
        try:
            self.session.commit()
            self.close_session()
        except Exception:
            return "Couldn't tear down connection on engine %s." % str(self.engine)

    def close_session(self):
        """Close the session and return True."""
        self.session.close()
        return True

    def hard_close(self):
        """Dispose of the engine and return True."""
        self.engine.dispose()
        return True

    def model(self, model_name):
        """
        Look up a model class in ``self.models``.

        :param model_name: a model name such as "global_tag" or "GlobalTag", or a
            declarative model class (its ``__name__`` is used). Underscores are
            removed and the name is lower-cased, matching the keys built by ``setup()``.
        :raises KeyError: if there is no such model.
        """
        if not isinstance(model_name, str) \
                and model_name.__class__ == sqlalchemy.ext.declarative.api.DeclarativeMeta:
            model_name = model_name.__name__
        return self.models[model_name.replace("_", "").lower()]

    def object(self, model, pk_to_value):
        """
        Return the first ``model`` row whose columns equal the values in ``pk_to_value``.

        :param pk_to_value: mapping of column attribute name -> required value.
        :returns: the first matching row, or None if there is no session or no match.
        """
        if self.session is None:
            return None
        query = self.session.query(model)
        for column_name, value in pk_to_value.items():
            query = query.filter(model.__dict__[column_name] == value)
        return query.first()

    def global_tag(self, **pkargs):
        """Query or create global tag objects via the factory."""
        return self.factory.object("globaltag", **pkargs)

    def global_tag_map(self, **pkargs):
        """Query or create global tag map objects via the factory."""
        return self.factory.object("globaltagmap", **pkargs)

    # global_tag_map_request() and record() accessors are currently disabled.

    def tag(self, **pkargs):
        """Query or create tag objects via the factory."""
        return self.factory.object("tag", **pkargs)

    def tag_authorization(self, **pkargs):
        """Query or create tag authorization objects via the factory."""
        return self.factory.object("tagauthorization", **pkargs)

    def iov(self, **pkargs):
        """Query or create IOV objects via the factory."""
        return self.factory.object("iov", **pkargs)

    def payload(self, **pkargs):
        """Query or create payload objects via the factory."""
        return self.factory.object("payload", **pkargs)

    def _oracle_match_format(self, string):
        """Wrap ``string`` in ``%`` wildcards for a substring LIKE/ILIKE match."""
        return "%%%s%%" % string

    def search_everything(self, string, amount=10):
        """
        Case-insensitive substring search over global tags, tags, IOVs and payloads.

        :param string: text to search for in the listed columns of each table.
        :param amount: maximum number of results per table.
        :returns: ``json_data_node.make`` of a dict with "global_tags", "tags",
            "iovs" and "payloads" result lists.
        """
        string = self._oracle_match_format(string)

        gt = self.model("globaltag")
        global_tags = self.session.query(gt).filter(or_(
            gt.name.ilike(string),
            gt.description.ilike(string),
            gt.release.ilike(string)
        )).limit(amount)
        tag = self.model("tag")
        tags = self.session.query(tag).filter(or_(
            tag.name.ilike(string),
            tag.object_type.ilike(string),
            tag.description.ilike(string)
        )).limit(amount)
        iov = self.model("iov")
        iovs = self.session.query(iov).filter(or_(
            iov.tag_name.ilike(string),
            iov.since.ilike(string),
            iov.payload_hash.ilike(string),
            iov.insertion_time.ilike(string)
        )).limit(amount)
        payload = self.model("payload")
        payloads = self.session.query(payload).filter(or_(
            payload.hash.ilike(string),
            payload.object_type.ilike(string),
            payload.insertion_time.ilike(string)
        )).limit(amount)

        return json_data_node.make({
            "global_tags": global_tags.all(),
            "tags": tags.all(),
            "iovs": iovs.all(),
            "payloads": payloads.all()
        })

    def write(self, object):
        """
        Add ``models.session_independent_object(object, schema=self.schema)`` to the
        session without committing, and return it.
        """
        new_object = models.session_independent_object(object, schema=self.schema)
        self.session.add(new_object)
        return new_object

    def commit(self):
        """Commit the session; on failure print the traceback and roll back."""
        try:
            self.session.commit()
        except Exception:
            traceback.print_exc()
            self.session.rollback()

    def write_and_commit(self, object):
        """Write and commit ``object``; a list is handled item by item (one commit each)."""
        if isinstance(object, list):
            for item in object:
                self.write_and_commit(item)
        else:
            self.write(object)
            self.commit()

    def rollback(self):
        """Roll back the session; on failure print the traceback and a message."""
        try:
            self.session.rollback()
        except Exception:
            traceback.print_exc()
            print("Session couldn't be rolled back.")
0295
class factory():
    """
    Builds model objects for a ``connection``: either by querying its session with
    filters, or, when no filters are given, as fresh empty instances.
    """

    def __init__(self, connection):
        """
        :param connection: the connection whose ``model()`` lookup and ``session`` are used.
        """
        self.connection = connection

    def object(self, class_name, **pkargs):
        """
        Query or create objects of the model registered under ``class_name``.

        :param class_name: name passed to ``connection.model()``.
        :param pkargs: filters passed to ``models.apply_filters``; an ``amount`` key,
            if present, also limits the number of rows.
        :returns: None if the connection has no session. With filters: a
            ``json_list`` of rows when several match, the single row when one
            matches, None when none match. Without filters: a new model instance
            with ``empty = True``.
        """
        model = self.connection.model(class_name)

        if self.connection.session is None:
            return None

        if not pkargs:
            new_object = model()
            new_object.empty = True
            return new_object

        # Local imports; only the filtering path needs them.
        from .data_sources import json_list
        from .models import apply_filters

        model_data = self.connection.session.query(model)
        model_data = apply_filters(model_data, model, **pkargs)
        model_data = model_data.limit(pkargs.get("amount"))
        # count() issues a query each time it is called, so call it once.
        row_count = model_data.count()
        if row_count > 1:
            return json_list(model_data.all())
        if row_count == 1:
            return model_data.first()
        return None
0336
0337 def _get_netrc_data(netrc_file, key):
0338 """
0339 Returns a dictionary {login : ..., account : ..., password : ...}
0340 """
0341 try:
0342 headers = ["login", "account", "password"]
0343 authenticator_tuple = netrc.netrc(netrc_file).authenticators(key)
0344 if authenticator_tuple == None:
0345 raise Exception("netrc file must contain key '%s'." % key)
0346 except:
0347 raise Exception("Couldn't get credentials from netrc file.")
0348 return dict(list(zip(headers, authenticator_tuple)))
0349
def new_connection_dictionary(connection_data, secrets=None, mode="r"):
    """
    Parse a connection string into the connection data dictionary used internally.

    Supported forms:
        frontier://database_name/schema
        sqlite://database_file_name
        oracle://account:password@database_name
        oracle://database_name/schema   (credentials come from ``secrets``, or are
                                         asked for interactively when it is None)

    :param connection_data: the connection string. Anything that is not a string
        starting with one of the prefixes above is returned unchanged.
    :param secrets: for "oracle://database_name/schema" only - a netrc file path or
        a dict with "user" and "password" keys.
    :param mode: "r" or "w", selecting the "read"/"write" netrc key suffix;
        None is treated as "r".
    :returns: a dict with "database_name", "schema", "host" and "secrets" keys
        (oracle connections also carry "password").
    :raises Exception: if ``secrets`` is neither None, a str nor a dict.
    """
    if not isinstance(connection_data, str):
        return connection_data

    if connection_data.startswith("frontier://"):
        # frontier://database_name/schema
        parts = connection_data[len("frontier://"):].split("/")
        return {
            "database_name": parts[0],
            "schema": parts[1],
            "host": "frontier",
            "secrets": None,
        }

    if connection_data.startswith("sqlite://"):
        # sqlite://database_file_name
        return {
            "database_name": os.path.abspath(connection_data[len("sqlite://"):]),
            "schema": "",
            "host": "sqlite",
            "secrets": None,
        }

    if not connection_data.startswith("oracle://"):
        return connection_data

    new_connection_string = connection_data[len("oracle://"):]

    if ":" in new_connection_string:
        # oracle://account:password@database_name - the account doubles as the schema.
        colon = new_connection_string.index(":")
        at = new_connection_string.index("@")
        database_name = new_connection_string[at + 1:]
        username = new_connection_string[:colon]
        password = new_connection_string[colon + 1:at]
        schema_name = username
    else:
        # oracle://database_name/schema - credentials come from secrets or the user.
        slash = new_connection_string.index("/")
        database_name = new_connection_string[:slash]
        schema_name = new_connection_string[slash + 1:]
        if secrets is None:
            # NOTE(review): the password is echoed as it is typed; getpass.getpass would
            # hide it, but reads from the terminal rather than stdin - confirm callers.
            username = str(input("Enter the username you want to connect to the schema '%s' with: " % (schema_name)))
            password = str(input("Enter the password for the user '%s' in database '%s': " % (username, database_name)))
        elif isinstance(secrets, str):
            mode_to_netrc_key_suffix = {"r": "read", "w": "write"}
            # connection.__init__ passes mode=None by default; treat it as read mode.
            suffix = mode_to_netrc_key_suffix["r" if mode is None else mode]
            netrc_key = "%s/%s/%s" % (database_name, schema_name, suffix)
            netrc_data = _get_netrc_data(secrets, key=netrc_key)
            username = netrc_data["login"]
            password = netrc_data["password"]
        elif isinstance(secrets, dict):
            username = secrets["user"]
            password = secrets["password"]
        else:
            raise Exception("Invalid type given for secrets. Either an str or a dict must be given.")

    return {
        "database_name": database_name,
        "schema": schema_name,
        "password": password,
        "host": "oracle",
        "secrets": {"login": username, "password": password},
    }
0428
def engine_from_dictionary(dictionary, pooling=True):
    """
    Create a SQLAlchemy engine from a connection data dictionary.

    sqlite databases get a plain engine and ``pooling`` is ignored. Frontier and
    oracle (any other host) engines are created with ``label_length=6`` and, when
    ``pooling`` is False, with ``NullPool``.
    """
    host = dictionary["host"]
    if host == "sqlite":
        return create_engine("sqlite:///%s" % dictionary["database_name"])

    if host == "frontier":
        url = connection.build_frontier_url(dictionary["database_name"], dictionary["schema"])
    else:
        # Every host other than sqlite and frontier is treated as oracle.
        credentials = dictionary["secrets"]
        url = connection.build_oracle_url(credentials["login"], credentials["password"], dictionary["database_name"])

    engine_options = {"label_length": 6}
    if not pooling:
        engine_options["poolclass"] = NullPool
    return create_engine(url, **engine_options)
0451
0452
def connect(connection_data, mode="r", map_blobs=False, secrets=None, pooling=True):
    """
    User entry point: build a ``connection`` and return the result of its ``setup()``.

    All arguments are forwarded unchanged to the ``connection`` constructor.
    """
    return connection(
        connection_data=connection_data,
        mode=mode,
        map_blobs=map_blobs,
        secrets=secrets,
        pooling=pooling,
    ).setup()