# Source code for rdm.db.context

from collections import defaultdict
import pprint
import copy

import mysql.connector as mysql
import psycopg2 as postgresql
import converters

from datasource import MySQLDataSource, PgSQLDataSource


class DBVendor:
    '''
    String identifiers of the supported database vendors.
    '''

    # Plain string constants; they are compared with ``==`` against a
    # connection's ``vendor`` attribute.
    PostgreSQL = "postgresql"
    MySQL = "mysql"


class DBConnection:
    '''
    Database credentials.

    Stores the user, password, host, database name and vendor, verifies on
    construction that a connection can be opened, and exposes the
    vendor-specific data source as ``self.src``.
    '''

    class Manager:
        '''
        Context Manager.

        Opens the connection when constructed, yields it from ``with`` and
        closes it when the ``with`` block is left.
        '''

        def __init__(self, user, password, host, database, dal_connect_fun):
            '''
            :param user: user name passed to the connect function
            :param password: password passed to the connect function
            :param host: host passed to the connect function
            :param database: database name passed to the connect function
            :param dal_connect_fun: callable that opens the connection; it is called
                                    with ``user``, ``password``, ``host`` and ``database``
                                    keyword arguments (e.g., ``mysql.connect``)
            '''
            self.con = dal_connect_fun(
                user=user,
                password=password,
                host=host,
                database=database
            )

        def __enter__(self):
            return self.con

        def __exit__(self, exc_type, exc_value, traceback):
            # Returns None, so an exception raised inside the ``with`` body
            # still propagates after the connection has been closed.
            self.con.close()

    def __init__(self, user, password, host, database, vendor=DBVendor.MySQL):
        '''
        :param user: database user name
        :param password: database password
        :param host: database host
        :param database: database name
        :param vendor: one of the ``DBVendor`` constants

        :raises Exception: if the vendor is unknown or a test connection
                           cannot be opened
        '''
        self.user = user
        self.password = password
        self.host = host
        self.database = database
        self.vendor = vendor

        # Validate the vendor *before* the connection test: check_connection()
        # wraps every failure in a "re-check your credentials" error, which
        # would otherwise hide the real cause for an unknown vendor.
        if self.vendor == DBVendor.MySQL:
            src_class = MySQLDataSource
        elif self.vendor == DBVendor.PostgreSQL:
            src_class = PgSQLDataSource
        else:
            raise Exception("Unknown DB vendor: {}".format(vendor))

        self.check_connection()
        self.src = src_class(self)

    def check_connection(self):
        '''
        Opens and immediately closes a connection to verify the credentials.

        :raises Exception: if the connection cannot be established; the
                           underlying error is appended to the message
        '''
        try:
            with self.connect():
                pass
        except Exception as e:
            # ``except ... as`` is valid on both Python 2.6+ and Python 3.
            # repr() keeps the original cause visible without risking an
            # encoding error while building the message.
            raise Exception(
                'Problem connecting to the database. Please re-check your credentials.'
                ' ({!r})'.format(e)
            )

    def connection(self):
        '''
        Returns a raw, open connection object.

        The connection is not wrapped in a context manager here, so the
        caller is responsible for closing it.
        '''
        return self.connect().con

    def connect(self):
        '''
        Opens a new connection using the vendor's connect function.

        :return: a ``DBConnection.Manager`` wrapping the open connection,
                 usable in a ``with`` statement
        :raises Exception: if ``self.vendor`` is not a supported vendor
        '''
        if self.vendor == DBVendor.MySQL:
            dal_connect_fun = mysql.connect
        elif self.vendor == DBVendor.PostgreSQL:
            dal_connect_fun = postgresql.connect
        else:
            # Report the offending vendor value rather than the (always
            # unset) connect function.
            raise Exception('Unsupported or unset database vendor: {}'.format(self.vendor))

        return DBConnection.Manager(self.user, self.password, self.host, self.database, dal_connect_fun)


class DBContext:
    '''
    Schema description of a database (tables, columns, keys and table
    relationships) together with row access, optionally served from tables
    loaded into memory.
    '''

    def __init__(self, connection, target_table=None, target_att=None,
                 find_connections=False, in_memory=True):
        '''
        Initializes a new DBContext object from the given DBConnection.

        :param connection: a DBConnection instance
        :param target_table: set a target table for learning
        :param target_att: set a target table attribute for learning
        :param find_connections: set to True if you want to detect relationships based on
                                 attribute and table names, e.g., ``train_id`` is the foreign
                                 key referring to ``id`` in table ``train``.
        :param in_memory: Load the database into main memory (currently required for most
                          approaches and pre-processing)
        '''
        self.src = connection.src
        self.tables = self.src.tables()
        self.cols = {table: self.src.table_columns(table) for table in self.tables}
        # Shallow copy: the mapping is independent of ``self.cols``, but the
        # per-table column containers are shared with it.
        self.all_cols = dict(self.cols)
        self.col_vals = {}

        conn_data = self.src.connected(
            self.tables,
            self.cols,
            find_connections=find_connections
        )
        self.connected, self.pkeys, self.fkeys, self.reverse_fkeys = conn_data

        # Without an explicit target, the first table reported by the data
        # source is used.
        self.target_table = target_table if target_table else self.tables[0]
        self.target_att = target_att if target_att else None

        # ``orng_tables`` must exist before read_into_orange() runs, because
        # the converter receives this (partially initialised) context.
        self.orng_tables = None
        self.in_memory = in_memory
        if in_memory:
            self.orng_tables = self.read_into_orange()

    def read_into_orange(self):
        '''
        Loads all tables through ``converters.OrangeConverter``.

        :return: a dictionary mapping table names to the converted tables
        :rtype: dict
        '''
        conv = converters.OrangeConverter(self)
        tables = {self.target_table: conv.target_Orange_table()}
        other_tbl_names = [table for table in self.tables if table != self.target_table]
        # NOTE(review): pairs names with converted tables by position; this
        # presumes other_Orange_tables() follows the order of self.tables.
        tables.update(zip(other_tbl_names, conv.other_Orange_tables()))
        return tables

    def fetch(self, table, cols):
        '''
        Fetches rows from the db.

        :param table: table name to select
        :param cols: list of columns to select
        :return: list of rows
        :rtype: list
        '''
        return self.src.fetch(table, cols)

    def rows(self, table, cols):
        '''
        Fetches rows from the local cache or from the db if there's no cache.

        :param table: table name to select
        :param cols: list of columns to select
        :return: list of rows
        :rtype: list
        '''
        if self.orng_tables:
            return [[ex[str(col)] for col in cols] for ex in self.orng_tables[table]]
        return self.fetch(table, cols)

    def select_where(self, table, cols, pk_att, pk):
        '''
        SELECT with WHERE clause.

        :param table: target table
        :param cols: list of columns to select
        :param pk_att: attribute for the where clause
        :param pk: the id that the pk_att should match
        :return: rows from the given table and cols, with the condition pk_att==pk
        :rtype: list
        '''
        if self.orng_tables:
            # In-memory values are compared by their string representation.
            wanted = str(pk)
            return [
                [ex[str(col)] for col in cols]
                for ex in self.orng_tables[table]
                if str(ex[str(pk_att)]) == wanted
            ]
        return self.src.select_where(table, cols, pk_att, pk)

    def fetch_types(self, table, cols):
        '''
        Returns a dictionary of field types for the given table and columns.

        :param table: target table
        :param cols: list of columns to select
        :return: a dictionary of types for each attribute
        :rtype: dict
        '''
        return self.src.fetch_types(table, cols)

    def compute_col_vals(self):
        '''
        Fills ``self.col_vals[table][col]`` with the result of the data
        source's ``column_values(table, col)`` for every column in ``self.cols``.
        '''
        for table, cols in self.cols.items():
            self.col_vals[table] = {
                col: self.src.column_values(table, col) for col in cols
            }

    def copy(self):
        '''
        Makes a deepcopy of the DBContext object (e.g., for making folds)

        :returns: a deep copy of ``self``.
        :rtype: DBContext
        '''
        return copy.deepcopy(self)

    def __repr__(self):
        '''
        Pretty-printed summary of the context; in-memory tables are shown as
        (name, number of rows) pairs.
        '''
        if self.orng_tables:
            orng_summary = [(name, len(table)) for name, table in self.orng_tables.items()]
        else:
            orng_summary = 'not in memory'
        return pprint.pformat({
            'target_table': self.target_table,
            'target attribute': self.target_att,
            'tables': self.tables,
            'cols': self.cols,
            'connected': self.connected,
            'pkeys': self.pkeys,
            'fkeys': self.fkeys,
            'orng_tables': orng_summary
        })