實用:python中的ORM對象關係映射

class Field:
    """Data descriptor describing one column of a mapped model.

    Assigning through the descriptor runs ``validate`` (implemented by
    subclasses) and stores the value in the instance's ``__dict__`` under
    ``self.name``; reading it back returns that stored value.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False):
        self.name = name            # key used in the owning instance's __dict__
        self.fieldname = fieldname  # column name; may be None
        self.pk = pk
        self.unique = unique
        self.default = default
        self.nullable = nullable
        self.index = index

    def validate(self, value):
        """Check ``value`` before it is stored; subclasses must override."""
        raise NotImplementedError

    def __get__(self, instance, owner):
        # Class-level access (e.g. ``Student.id``) yields the descriptor itself.
        # This must test ``instance``: comparing the builtin ``isinstance``
        # with None is always false, which made class-level access fail with
        # AttributeError on ``None.__dict__``.
        if instance is None:
            return self
        return instance.__dict__[self.name]

    def __set__(self, instance, value):
        self.validate(value)
        instance.__dict__[self.name] = value

    def __str__(self):
        return '<{} {}>'.format(self.__class__.__name__, self.name)

    __repr__ = __str__

class IntField(Field):
    """Integer column.

    ``None`` is accepted only for a nullable, non-primary-key column; any
    other value must be an ``int``. ``auto_increment`` is stored as an
    attribute but not used by validation.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False,
                 auto_increment=False):
        self.auto_increment = auto_increment
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            ok = self.nullable and not self.pk
        else:
            ok = isinstance(value, int)
        if not ok:
            raise TypeError('{} = {} error!!'.format(self.name, value))

class StringField(Field):
    """String column holding at most ``length`` characters.

    ``None`` is accepted only for a nullable, non-primary-key column.
    """

    def __init__(self, length=32, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True,
                 index=False, auto_increment=False):
        # NOTE(review): auto_increment is accepted but neither stored nor
        # forwarded; presumably only for signature parity with IntField.
        self.length = length
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            if self.pk or not self.nullable:
                raise TypeError('{} = {} error!!'.format(self.name, value))
            return
        if not isinstance(value, str):
            raise TypeError('{} = {} error!!'.format(self.name, value))
        if len(value) > self.length:
            raise ValueError('{} is too long,value={}'.format(self.name, value))

class Student:
    """Plain record with validated ``id`` and ``name`` attributes."""

    id = IntField('id')
    name = StringField(length=24, name='name')

    def __init__(self,id,name):
        # Both assignments go through the Field descriptors and are validated.
        self.id, self.name = id, name

# Build one record (validated by the descriptors) and show its values.
s = Student(1, name='tom')
print(s.id, s.name)

運行結果:

1 tom

改進:

import pymysql

class Field:
    """Data descriptor describing one column of a mapped model.

    Assigning through the descriptor runs ``validate`` (implemented by
    subclasses) and stores the value in the instance's ``__dict__`` under
    ``self.name``. Reading it on an instance returns a demonstration tuple
    ``(owner class name, stored value, this field's attribute dict)``.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False):
        self.name = name            # key used in the owning instance's __dict__
        self.fieldname = fieldname  # column name; may be None
        self.pk = pk
        self.unique = unique
        self.default = default
        self.nullable = nullable
        self.index = index

    def validate(self, value):
        """Check ``value`` before it is stored; subclasses must override."""
        raise NotImplementedError

    def __get__(self, instance, owner):
        # Class-level access (e.g. ``Student.id``) yields the descriptor itself.
        # This must test ``instance``: comparing the builtin ``isinstance``
        # with None is always false, which made class-level access fail with
        # AttributeError on ``None.__dict__``.
        if instance is None:
            return self
        # Return table(class) name + field value + field definition.
        return instance.__class__.__name__,instance.__dict__[self.name],self.__dict__

    def __set__(self, instance, value):
        self.validate(value)
        instance.__dict__[self.name] = value

    def __str__(self):
        return '<{} {}>'.format(self.__class__.__name__, self.name)

    __repr__ = __str__

class IntField(Field):
    """Integer column.

    ``None`` is accepted only for a nullable, non-primary-key column; any
    other value must be an ``int``. ``auto_increment`` is stored as an
    attribute but not used by validation.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False,
                 auto_increment=False):
        self.auto_increment = auto_increment
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            ok = self.nullable and not self.pk
        else:
            ok = isinstance(value, int)
        if not ok:
            raise TypeError('{} = {} error!!'.format(self.name, value))

class StringField(Field):
    """String column holding at most ``length`` characters.

    ``None`` is accepted only for a nullable, non-primary-key column.
    """

    def __init__(self, length=32, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True,
                 index=False, auto_increment=False):
        # NOTE(review): auto_increment is accepted but neither stored nor
        # forwarded; presumably only for signature parity with IntField.
        self.length = length
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            if self.pk or not self.nullable:
                raise TypeError('{} = {} error!!'.format(self.name, value))
            return
        if not isinstance(value, str):
            raise TypeError('{} = {} error!!'.format(self.name, value))
        if len(value) > self.length:
            raise ValueError('{} is too long,value={}'.format(self.name, value))

class Student:
    """Record with validated ``id``/``name``/``age`` fields, saved to table ``t``."""

    id = IntField('id')
    name = StringField(24,'name')
    age = IntField('age')

    def __init__(self,id,name,age=18):
        self.id = id
        self.name = name
        self.age = age

    def save(self,conn:pymysql.connections.Connection):
        """Insert this record into table ``t`` and commit.

        On any error the transaction is rolled back and the exception is
        re-raised, instead of being silently discarded by a bare ``except``.
        """
        sql = 'insert into t(id,name,age) values(%s,%s,%s);'
        # Read the raw stored values. In this version Field.__get__ returns a
        # (class name, value, field definition) tuple, which is not a usable
        # SQL parameter.
        values = (self.__dict__['id'], self.__dict__['name'], self.__dict__['age'])
        try:
            # Request a cursor explicitly. What ``with conn`` yields differs
            # between PyMySQL versions: older releases give a cursor, while
            # 1.0+ gives the connection itself and closes it on exit.
            with conn.cursor() as cursor:
                cursor.execute(sql, values)
            conn.commit()
        except Exception:
            conn.rollback()
            raise

# Connect to the database. Keyword arguments work on every PyMySQL release,
# whereas PyMySQL 1.0+ rejects positional connect() parameters.
# NOTE(review): credentials are hard-coded for the demo.
conn = pymysql.connect(host='172.20.10.11', user='root', password='123456', database='test')
s = Student(1,'haha',27)
s.save(conn)

# In this version each field read returns (class name, field value, field definition).
print(s.id)
print(s.name)
print(s.age)

運行結果:

('Student', 1, {'auto_increment': False, 'unique': False, 'default': False, 'fieldname': None, 'index': False, 'name': 'id', 'pk': False, 'nullable': True})
('Student', 'haha', {'unique': False, 'default': False, 'fieldname': None, 'index': False, 'name': 'name', 'pk': False, 'length': 24, 'nullable': True})
('Student', 27, {'auto_increment': False, 'unique': False, 'default': False, 'fieldname': None, 'index': False, 'name': 'age', 'pk': False, 'nullable': True})

查看數據庫:
(此處原為數據庫查詢結果截圖)
改進1:

import pymysql

class Field:
    """Data descriptor describing one column of a mapped model.

    Assigning through the descriptor runs ``validate`` (implemented by
    subclasses) and stores the value in the instance's ``__dict__`` under
    ``self.name``; reading it back returns that stored value.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False):
        self.name = name            # key used in the owning instance's __dict__
        self.fieldname = fieldname  # column name; may be None
        self.pk = pk
        self.unique = unique
        self.default = default
        self.nullable = nullable
        self.index = index

    def validate(self, value):
        """Check ``value`` before it is stored; subclasses must override."""
        raise NotImplementedError

    def __get__(self, instance, owner):
        # Class-level access (e.g. ``Student.id``) yields the descriptor itself.
        # This must test ``instance``: comparing the builtin ``isinstance``
        # with None is always false, which made class-level access fail with
        # AttributeError on ``None.__dict__``.
        if instance is None:
            return self
        return instance.__dict__[self.name]

    def __set__(self, instance, value):
        self.validate(value)
        instance.__dict__[self.name] = value

    def __str__(self):
        return '<{} {}>'.format(self.__class__.__name__, self.name)

    __repr__ = __str__

class IntField(Field):
    """Integer column.

    ``None`` is accepted only for a nullable, non-primary-key column; any
    other value must be an ``int``. ``auto_increment`` is stored as an
    attribute but not used by validation.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False,
                 auto_increment=False):
        self.auto_increment = auto_increment
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            ok = self.nullable and not self.pk
        else:
            ok = isinstance(value, int)
        if not ok:
            raise TypeError('{} = {} error!!'.format(self.name, value))

class StringField(Field):
    """String column holding at most ``length`` characters.

    ``None`` is accepted only for a nullable, non-primary-key column.
    """

    def __init__(self, length=32, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True,
                 index=False, auto_increment=False):
        # NOTE(review): auto_increment is accepted but neither stored nor
        # forwarded; presumably only for signature parity with IntField.
        self.length = length
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            if self.pk or not self.nullable:
                raise TypeError('{} = {} error!!'.format(self.name, value))
            return
        if not isinstance(value, str):
            raise TypeError('{} = {} error!!'.format(self.name, value))
        if len(value) > self.length:
            raise ValueError('{} is too long,value={}'.format(self.name, value))

class Session:
    """Transaction scope over a PyMySQL connection.

    ``with session as cursor:`` hands out a fresh cursor. On leaving the
    block the connection is committed, or rolled back if the block raised.
    Exceptions are never suppressed.
    """

    def __init__(self,conn:pymysql.connections.Connection):
        self.conn = conn

    def __enter__(self):
        cursor = self.conn.cursor()
        return cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        finish = self.conn.rollback if exc_type else self.conn.commit
        finish()

class Student:
    """Record with validated id/name/age fields that saves into table ``t``."""

    id = IntField('id')
    name = StringField(24,'name')
    age = IntField('age')

    def __init__(self,id,name=None,age=18):
        # Each assignment is validated by the corresponding Field descriptor.
        self.id, self.name, self.age = id, name, age

    def save(self,session:Session):
        """Insert this record into table ``t`` using a cursor from ``session``.

        ``session`` is expected to commit on success and roll back if the
        block raises. The cursor's own context manager closes it.
        """
        insert_sql = 'insert into t(id,name,age) values(%s,%s,%s);'
        with session as cursor, cursor:
            cursor.execute(insert_sql, (self.id, self.name, self.age))

# Connect to the database. Keyword arguments work on every PyMySQL release,
# whereas PyMySQL 1.0+ rejects positional connect() parameters.
# NOTE(review): credentials are hard-coded for the demo.
conn = pymysql.connect(host='172.20.10.11', user='root', password='123456', database='test')
session = Session(conn)
s = Student(1,'lala',27)
s.save(session)

print(s.id,s.name,s.age)

運行結果:

1 lala 27

查看數據庫結果:
(此處原為數據庫查詢結果截圖)

改進2:(增加元類的使用)

import pymysql
import logging
# Configure the root logger at INFO level so the logging.info() calls made
# by ModelMeta in this snippet are actually emitted.
logging.basicConfig(level=logging.INFO)


class Field:
    """Data descriptor describing one column of a mapped model.

    Assigning through the descriptor runs ``validate`` (implemented by
    subclasses) and stores the value in the instance's ``__dict__`` under
    ``self.name``; reading it back returns that stored value.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False):
        self.name = name            # key used in the owning instance's __dict__
        self.fieldname = fieldname  # column name; may be None
        self.pk = pk
        self.unique = unique
        self.default = default
        self.nullable = nullable
        self.index = index

    def validate(self, value):
        """Check ``value`` before it is stored; subclasses must override."""
        raise NotImplementedError

    def __get__(self, instance, owner):
        # Class-level access (e.g. ``Student.id``) yields the descriptor itself.
        # This must test ``instance``: comparing the builtin ``isinstance``
        # with None is always false, which made class-level access fail with
        # AttributeError on ``None.__dict__``.
        if instance is None:
            return self
        return instance.__dict__[self.name]

    def __set__(self, instance, value):
        self.validate(value)
        instance.__dict__[self.name] = value

    def __str__(self):
        return '<{} {}>'.format(self.__class__.__name__, self.name)

    __repr__ = __str__

class IntField(Field):
    """Integer column.

    ``None`` is accepted only for a nullable, non-primary-key column; any
    other value must be an ``int``. ``auto_increment`` is stored as an
    attribute but not used by validation.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False,
                 auto_increment=False):
        self.auto_increment = auto_increment
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            ok = self.nullable and not self.pk
        else:
            ok = isinstance(value, int)
        if not ok:
            raise TypeError('{} = {} error!!'.format(self.name, value))

class StringField(Field):
    """String column holding at most ``length`` characters.

    ``None`` is accepted only for a nullable, non-primary-key column.
    """

    def __init__(self, length=32, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True,
                 index=False, auto_increment=False):
        # NOTE(review): auto_increment is accepted but neither stored nor
        # forwarded; presumably only for signature parity with IntField.
        self.length = length
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            if self.pk or not self.nullable:
                raise TypeError('{} = {} error!!'.format(self.name, value))
            return
        if not isinstance(value, str):
            raise TypeError('{} = {} error!!'.format(self.name, value))
        if len(value) > self.length:
            raise ValueError('{} is too long,value={}'.format(self.name, value))

class Session:
    """Transaction scope over a PyMySQL connection.

    ``with session as cursor:`` hands out a fresh cursor. On leaving the
    block the connection is committed, or rolled back if the block raised.
    Exceptions are never suppressed.
    """

    def __init__(self,conn:pymysql.connections.Connection):
        self.conn = conn

    def __enter__(self):
        cursor = self.conn.cursor()
        return cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        finish = self.conn.rollback if exc_type else self.conn.commit
        finish()

class ModelMeta(type):
    """Metaclass that gives each class a ``__tablename__`` and reports its Fields.

    Every class attribute is printed; Field attributes are additionally
    logged at INFO level.
    """

    def __new__(cls,name:str,bases,attrs:dict):
        # Default the table name to the lower-cased class name.
        if attrs.get('__tablename__') is None:
            attrs['__tablename__'] = name.lower()

        for attr_name in attrs:
            attr_value = attrs[attr_name]
            print('{}:  -----> {}'.format(attr_name, attr_value))
            if not isinstance(attr_value, Field):
                continue
            logging.info('{}:{}'.format(attr_name, attr_value))

        return super().__new__(cls,name,bases,attrs)

class Student(metaclass=ModelMeta):
    # No __tablename__ is declared, so ModelMeta sets it to the lower-cased
    # class name ('student'); e.g. __tablename__ = 'student_2019' would
    # override it. save() below still writes to the hard-coded table `t`.
    # Comments only, no docstring: ModelMeta prints every class attribute,
    # and a docstring would add __doc__ to that output.
    id = IntField('id')
    name = StringField(24,'name')
    age = IntField('age')

    def __init__(self,id,name=None,age=18):
        # Each assignment is validated by the corresponding Field descriptor.
        self.id, self.name, self.age = id, name, age

    def save(self,session:Session):
        # Insert this record into table `t`. `session` is expected to commit
        # on success and roll back on error; the cursor's context closes it.
        insert_sql = 'insert into t(id,name,age) values(%s,%s,%s);'
        with session as cursor, cursor:
            cursor.execute(insert_sql, (self.id, self.name, self.age))

# Connect to the database. Keyword arguments work on every PyMySQL release,
# whereas PyMySQL 1.0+ rejects positional connect() parameters.
# NOTE(review): credentials are hard-coded for the demo.
conn = pymysql.connect(host='172.20.10.11', user='root', password='123456', database='test')
session = Session(conn)
s = Student(1,'lala',27)
s.save(session)

print(s.id,s.name,s.age)

運行結果:

__tablename__:  -----> student
__module__:  -----> __main__
__qualname__:  -----> Student
name:  -----> <StringField name>
age:  -----> <IntField age>
__init__:  -----> <function Student.__init__ at 0x7f751299e048>
save:  -----> <function Student.save at 0x7f751299e0d0>
id:  -----> <IntField id>
INFO:root:name:<StringField name>
INFO:root:age:<IntField age>
INFO:root:id:<IntField id>
1 lala 27

改進3:

import pymysql

class Field:
    """Data descriptor describing one column of a mapped model.

    Assigning through the descriptor runs ``validate`` (implemented by
    subclasses) and stores the value in the instance's ``__dict__`` under
    ``self.name``; reading it back returns that stored value.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False):
        self.name = name            # key used in the owning instance's __dict__
        self.fieldname = fieldname  # column name; may be None
        self.pk = pk
        self.unique = unique
        self.default = default
        self.nullable = nullable
        self.index = index

    def validate(self, value):
        """Check ``value`` before it is stored; subclasses must override."""
        raise NotImplementedError

    def __get__(self, instance, owner):
        # Class-level access (e.g. ``Student.id``) yields the descriptor itself.
        # This must test ``instance``: comparing the builtin ``isinstance``
        # with None is always false, which made class-level access fail with
        # AttributeError on ``None.__dict__``.
        if instance is None:
            return self
        return instance.__dict__[self.name]

    def __set__(self, instance, value):
        self.validate(value)
        instance.__dict__[self.name] = value

    def __str__(self):
        return '<{} {}>'.format(self.__class__.__name__, self.name)

    __repr__ = __str__

class IntField(Field):
    """Integer column.

    ``None`` is accepted only for a nullable, non-primary-key column; any
    other value must be an ``int``. ``auto_increment`` is stored as an
    attribute but not used by validation.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False,
                 auto_increment=False):
        self.auto_increment = auto_increment
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            ok = self.nullable and not self.pk
        else:
            ok = isinstance(value, int)
        if not ok:
            raise TypeError('{} = {} error!!'.format(self.name, value))

class StringField(Field):
    """String column holding at most ``length`` characters.

    ``None`` is accepted only for a nullable, non-primary-key column.
    """

    def __init__(self, length=32, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True,
                 index=False, auto_increment=False):
        # NOTE(review): auto_increment is accepted but neither stored nor
        # forwarded; presumably only for signature parity with IntField.
        self.length = length
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            if self.pk or not self.nullable:
                raise TypeError('{} = {} error!!'.format(self.name, value))
            return
        if not isinstance(value, str):
            raise TypeError('{} = {} error!!'.format(self.name, value))
        if len(value) > self.length:
            raise ValueError('{} is too long,value={}'.format(self.name, value))

class Session:
    """Transaction scope over a PyMySQL connection.

    ``with session as cursor:`` hands out a fresh cursor. On leaving the
    block the connection is committed, or rolled back if the block raised.
    Exceptions are never suppressed.
    """

    def __init__(self,conn:pymysql.connections.Connection):
        self.conn = conn

    def __enter__(self):
        cursor = self.conn.cursor()
        return cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        finish = self.conn.rollback if exc_type else self.conn.commit
        finish()

class ModelMeta(type):
    """Metaclass that sets ``__tablename__`` and a ``mapping`` of the class's Fields."""

    def __new__(cls,name:str,bases,attrs:dict):
        # Default the table name to the lower-cased class name.
        if attrs.get('__tablename__') is None:
            attrs['__tablename__'] = name.lower()
        # Only attributes declared directly on this class are collected.
        attrs['mapping'] = {key: value for key, value in attrs.items() if isinstance(value, Field)}
        return super().__new__(cls,name,bases,attrs)

class Model(metaclass=ModelMeta):
    """Base class for mapped models; provides ``save``."""

    def save(self,session:Session):
        """Insert this instance's Field values as one row of ``__tablename__``.

        Columns are the Field attributes declared directly on the instance's
        class. Values are read from the instance ``__dict__`` by attribute name.
        """
        columns = [key for key, value in type(self).__dict__.items() if isinstance(value, Field)]
        row = [self.__dict__[key] for key in columns]
        placeholders = ','.join(['%s'] * len(columns))
        sql = 'insert into {0}({1}) values({2});'.format(self.__tablename__, ','.join(columns), placeholders)

        with session as cursor:
            with cursor:
                print(columns, row)
                cursor.execute(sql, row)


class Student(Model):
    """Example model stored in table ``t``."""

    # Explicit table name; without it ModelMeta would use the lower-cased
    # class name.
    __tablename__ = 't'
    id = IntField('id')
    name = StringField(24,'name')
    age = IntField('age')

    def __init__(self,id,name=None,age=18):
        # Each assignment is validated by the corresponding Field descriptor.
        self.id, self.name, self.age = id, name, age

# Connect to the database. Keyword arguments work on every PyMySQL release,
# whereas PyMySQL 1.0+ rejects positional connect() parameters.
# NOTE(review): credentials are hard-coded for the demo.
conn = pymysql.connect(host='172.20.10.11', user='root', password='123456', database='test')
session = Session(conn)
s = Student(1,'yzx2',27)
s.save(session)

運行結果:

['age', 'name', 'id'] [27, 'yzx2', 1]

查看數據庫變化:
(此處原為數據庫查詢結果截圖)
改進4:

import pymysql


class Field:
    """Data descriptor describing one column of a mapped model.

    Assigning through the descriptor runs ``validate`` (implemented by
    subclasses) and stores the value in the instance's ``__dict__`` under
    ``self.name``; reading it back returns that stored value.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False):
        self.name = name            # key used in the owning instance's __dict__
        self.fieldname = fieldname  # column name; may be None
        self.pk = pk
        self.unique = unique
        self.default = default
        self.nullable = nullable
        self.index = index

    def validate(self, value):
        """Check ``value`` before it is stored; subclasses must override."""
        raise NotImplementedError

    def __get__(self, instance, owner):
        # Class-level access (e.g. ``Student.id``) yields the descriptor itself.
        # This must test ``instance``: comparing the builtin ``isinstance``
        # with None is always false, which made class-level access fail with
        # AttributeError on ``None.__dict__``.
        if instance is None:
            return self
        return instance.__dict__[self.name]

    def __set__(self, instance, value):
        self.validate(value)
        instance.__dict__[self.name] = value

    def __str__(self):
        return '<{} {}>'.format(self.__class__.__name__, self.name)

    __repr__ = __str__


class IntField(Field):
    """Integer column.

    ``None`` is accepted only for a nullable, non-primary-key column; any
    other value must be an ``int``. ``auto_increment`` is stored as an
    attribute but not used by validation.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False,
                 auto_increment=False):
        self.auto_increment = auto_increment
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            ok = self.nullable and not self.pk
        else:
            ok = isinstance(value, int)
        if not ok:
            raise TypeError('{} = {} error!!'.format(self.name, value))


class StringField(Field):
    """String column holding at most ``length`` characters.

    ``None`` is accepted only for a nullable, non-primary-key column.
    """

    def __init__(self, length=32, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True,
                 index=False, auto_increment=False):
        # NOTE(review): auto_increment is accepted but neither stored nor
        # forwarded; presumably only for signature parity with IntField.
        self.length = length
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            if self.pk or not self.nullable:
                raise TypeError('{} = {} error!!'.format(self.name, value))
            return
        if not isinstance(value, str):
            raise TypeError('{} = {} error!!'.format(self.name, value))
        if len(value) > self.length:
            raise ValueError('{} is too long,value={}'.format(self.name, value))


class Session:
    """Transaction scope over a PyMySQL connection.

    ``with session as cursor:`` hands out a fresh cursor. On leaving the
    block the connection is committed, or rolled back if the block raised.
    Exceptions are never suppressed.
    """

    def __init__(self, conn: pymysql.connections.Connection):
        self.conn = conn

    def __enter__(self):
        cursor = self.conn.cursor()
        return cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        finish = self.conn.rollback if exc_type else self.conn.commit
        finish()


class ModelMeta(type):
    """Metaclass that prepares a model class from its declared Fields.

    It sets ``__tablename__`` (defaulting to the lower-cased class name),
    fills in each Field's missing ``name`` (the attribute name) and
    ``fieldname`` (the Field's name), and records ``mapping`` (attribute
    name -> Field) plus ``primarykey`` (Fields with ``pk`` set, in
    declaration order).
    """

    def __new__(cls, name: str, bases, attrs: dict):
        if attrs.get('__tablename__') is None:
            attrs['__tablename__'] = name.lower()

        fields = {key: value for key, value in attrs.items() if isinstance(value, Field)}
        for key, field in fields.items():
            if field.name is None:
                field.name = key
            if field.fieldname is None:
                field.fieldname = field.name

        attrs['mapping'] = fields
        attrs['primarykey'] = [field for field in fields.values() if field.pk]
        return super().__new__(cls, name, bases, attrs)


class Model(metaclass=ModelMeta):
    """Base class for mapped models; provides ``save``."""

    def save(self, session: Session):
        """Insert this instance as one row of its ``__tablename__`` table.

        Column names come from each field's ``fieldname``. Values are read
        from the instance ``__dict__`` under the field's ``name``, the key
        ``Field.__set__`` writes. A field whose attribute name, stored name
        and column name differ is therefore still saved correctly; the old
        attribute-name lookup raised KeyError or wrote the wrong column.
        """
        names = []
        values = []
        for field in type(self).mapping.values():
            names.append(field.fieldname)
            values.append(self.__dict__[field.name])
        sql = 'insert into {0}({1}) values({2});'.format(
            self.__tablename__, ','.join(names), ','.join(['%s'] * len(names)))

        # The session commits on success and rolls back if execute() raises.
        with session as cursor:
            with cursor:
                print(names, values)
                cursor.execute(sql, values)


class Student(Model):
    """Example model stored in table ``t``; field names come from the attribute names."""

    # Explicit table name; without it ModelMeta would use the lower-cased
    # class name.
    __tablename__ = 't'
    id = IntField()
    name = StringField()
    age = IntField()

    def __init__(self, id, name=None, age=18):
        # Each assignment is validated by the corresponding Field descriptor.
        self.id, self.name, self.age = id, name, age


# Connect to the database. Keyword arguments work on every PyMySQL release,
# whereas PyMySQL 1.0+ rejects positional connect() parameters.
# NOTE(review): credentials are hard-coded for the demo.
conn = pymysql.connect(host='172.20.10.11', user='root', password='123456', database='test')
session = Session(conn)
s = Student(1, 'yzx3', 27)
s.save(session)

運行結果:

['id', 'age', 'name'] [1, 27, 'yzx3']

查看數據庫變化:
(此處原為數據庫查詢結果截圖)
改進5:(增加引擎)

import pymysql


class Field:
    """Data descriptor describing one column of a mapped model.

    Assigning through the descriptor runs ``validate`` (implemented by
    subclasses) and stores the value in the instance's ``__dict__`` under
    ``self.name``; reading it back returns that stored value.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False):
        self.name = name            # key used in the owning instance's __dict__
        self.fieldname = fieldname  # column name; may be None
        self.pk = pk
        self.unique = unique
        self.default = default
        self.nullable = nullable
        self.index = index

    def validate(self, value):
        """Check ``value`` before it is stored; subclasses must override."""
        raise NotImplementedError

    def __get__(self, instance, owner):
        # Class-level access (e.g. ``Student.id``) yields the descriptor itself.
        # This must test ``instance``: comparing the builtin ``isinstance``
        # with None is always false, which made class-level access fail with
        # AttributeError on ``None.__dict__``.
        if instance is None:
            return self
        return instance.__dict__[self.name]

    def __set__(self, instance, value):
        self.validate(value)
        instance.__dict__[self.name] = value

    def __str__(self):
        return '<{} {}>'.format(self.__class__.__name__, self.name)

    __repr__ = __str__


class IntField(Field):
    """Integer column.

    ``None`` is accepted only for a nullable, non-primary-key column; any
    other value must be an ``int``. ``auto_increment`` is stored as an
    attribute but not used by validation.
    """

    def __init__(self, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True, index=False,
                 auto_increment=False):
        self.auto_increment = auto_increment
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            ok = self.nullable and not self.pk
        else:
            ok = isinstance(value, int)
        if not ok:
            raise TypeError('{} = {} error!!'.format(self.name, value))


class StringField(Field):
    """String column holding at most ``length`` characters.

    ``None`` is accepted only for a nullable, non-primary-key column.
    """

    def __init__(self, length=32, name=None, fieldname=None, pk=False, unique=False, default=False, nullable=True,
                 index=False, auto_increment=False):
        # NOTE(review): auto_increment is accepted but neither stored nor
        # forwarded; presumably only for signature parity with IntField.
        self.length = length
        super().__init__(name=name, fieldname=fieldname, pk=pk, unique=unique,
                         default=default, nullable=nullable, index=index)

    def validate(self, value):
        if value is None:
            if self.pk or not self.nullable:
                raise TypeError('{} = {} error!!'.format(self.name, value))
            return
        if not isinstance(value, str):
            raise TypeError('{} = {} error!!'.format(self.name, value))
        if len(value) > self.length:
            raise ValueError('{} is too long,value={}'.format(self.name, value))


class Session:
    """Transaction scope over a PyMySQL connection.

    ``with session as cursor:`` hands out a fresh cursor. On leaving the
    block the connection is committed, or rolled back if the block raised.
    Exceptions are never suppressed.
    """

    def __init__(self, conn: pymysql.connections.Connection):
        self.conn = conn

    def __enter__(self):
        cursor = self.conn.cursor()
        return cursor

    def __exit__(self, exc_type, exc_val, exc_tb):
        finish = self.conn.rollback if exc_type else self.conn.commit
        finish()


class ModelMeta(type):
    """Metaclass that prepares a model class from its declared Fields.

    It sets ``__tablename__`` (defaulting to the lower-cased class name),
    fills in each Field's missing ``name`` (the attribute name) and
    ``fieldname`` (the Field's name), and records ``mapping`` (attribute
    name -> Field) plus ``primarykey`` (Fields with ``pk`` set, in
    declaration order).
    """

    def __new__(cls, name: str, bases, attrs: dict):
        if attrs.get('__tablename__') is None:
            attrs['__tablename__'] = name.lower()

        fields = {key: value for key, value in attrs.items() if isinstance(value, Field)}
        for key, field in fields.items():
            if field.name is None:
                field.name = key
            if field.fieldname is None:
                field.fieldname = field.name

        attrs['mapping'] = fields
        attrs['primarykey'] = [field for field in fields.values() if field.pk]
        return super().__new__(cls, name, bases, attrs)


class Model(metaclass=ModelMeta):
    """Empty base class; subclasses are built by ModelMeta, which attaches
    ``__tablename__``, ``mapping`` and ``primarykey``."""


class Student(Model):
    """Example model stored in table ``t``; field names come from the attribute names."""

    # Explicit table name; without it ModelMeta would use the lower-cased
    # class name.
    __tablename__ = 't'
    id = IntField()
    name = StringField()
    age = IntField()

    def __init__(self, id, name=None, age=18):
        # Each assignment is validated by the corresponding Field descriptor.
        self.id, self.name, self.age = id, name, age

class Engine:
    """Owns a PyMySQL connection and persists model instances through it."""

    def __init__(self,*args,**kwargs):
        # Arguments are forwarded unchanged to pymysql.connect().
        self.conn = pymysql.connect(*args,**kwargs)

    def save(self, instance: Model):
        """Insert ``instance`` as one row of its ``__tablename__`` table.

        Column names come from each field's ``fieldname``. Values are read
        from the instance ``__dict__`` under the field's ``name``, the key
        ``Field.__set__`` writes. A field whose attribute name, stored name
        and column name differ is therefore still saved correctly.
        """
        names = []
        values = []
        # ``mapping`` holds only Field objects, so no type filter is needed.
        for field in instance.mapping.values():
            names.append(field.fieldname)
            values.append(instance.__dict__[field.name])
        sql = 'insert into {0}({1}) values({2});'.format(
            instance.__tablename__, ','.join(names), ','.join(['%s'] * len(names)))

        # Session commits on success and rolls back if execute() raises.
        with Session(self.conn) as cursor:
            with cursor:
                print(names, values)
                cursor.execute(sql, values)

    def close(self):
        """Close the underlying database connection."""
        self.conn.close()

# Connect to the database through the engine. Keyword arguments work on every
# PyMySQL release, whereas PyMySQL 1.0+ rejects positional connect() parameters.
# NOTE(review): credentials are hard-coded for the demo.
engine = Engine(host='172.20.10.11', user='root', password='123456', database='test')
s = Student(1, 'yzx4', 27)
engine.save(s)

運行結果:

['id', 'name', 'age'] [1, 'yzx4', 27]

查看數據庫變化:
(此處原為數據庫查詢結果截圖)

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章