TensorFlow結合OpenCV實現手寫數字識別

利用OpenCV的繪圖功能與TensorFlow的模型來識別手寫數字。

完整代碼在GitHub上:跳轉按鈕。

1.加載數據

# Read the training table from 'train.csv'.
data = pd.read_csv('train.csv')

# The 'label' column holds the targets; every other column is a feature,
# cast to float32 and divided by 255.
y = data['label'].values
x = data.loc[:, data.columns != 'label'].values.astype(np.float32) / 255.0

# `data` is rebound from the DataFrame to a tf.data dataset of (x, y) pairs.
data = tf.data.Dataset.from_tensor_slices((x, y))
data_loader = (
    data.repeat()
    .shuffle(5000)
    .batch(128)
    .prefetch(1)
)

2.構造模型

class network(tf.keras.Model):
    """Convolutional classifier for 28x28 single-channel images.

    Two convolution + max-pooling stages are followed by a flatten, a
    1024-unit dense layer, dropout (rate 0.5) and a 10-unit output layer.
    """

    def __init__(self):
        super(network, self).__init__()
        self.conv1 = tf.keras.layers.Conv2D(32, kernel_size=5, activation=tf.nn.relu)
        self.maxpool1 = tf.keras.layers.MaxPool2D(2, strides=2)

        self.conv2 = tf.keras.layers.Conv2D(64, kernel_size=3, activation=tf.nn.relu)
        self.maxpool2 = tf.keras.layers.MaxPool2D(2, strides=2)

        self.flatten = tf.keras.layers.Flatten()
        # NOTE(review): fc1 has no activation, so fc1 -> out is linear apart
        # from dropout -- confirm this is intended.
        self.fc1 = tf.keras.layers.Dense(1024)
        self.dropout = tf.keras.layers.Dropout(rate=0.5)
        self.out = tf.keras.layers.Dense(10)

    def call(self, x, is_training=False):
        """Run the forward pass.

        Args:
            x: Input batch; it is reshaped to [-1, 28, 28, 1] before the
                first convolution.
            is_training: When True, dropout is active and raw logits are
                returned (the loss applies softmax itself). When False,
                dropout is disabled and softmax probabilities are returned.

        Returns:
            A tensor with 10 values per example: logits if `is_training`
            is True, otherwise softmax probabilities.
        """
        x = tf.reshape(x, [-1, 28, 28, 1])
        x = self.conv1(x)
        x = self.maxpool1(x)
        x = self.conv2(x)
        x = self.maxpool2(x)
        x = self.flatten(x)
        x = self.fc1(x)
        # Forward the flag explicitly: `call` has no `training` parameter, so
        # Keras cannot infer the mode and dropout would otherwise never fire.
        x = self.dropout(x, training=is_training)
        x = self.out(x)
        if not is_training:
            x = tf.nn.softmax(x)
        return x


conv = network()
3.定義損失函數和精度函數
def cross_entropy_loss(x, y):
    """Return the mean sparse softmax cross-entropy over a batch.

    Args:
        x: Logits passed to the loss.
        y: Labels; they are cast to int64 before use.

    Returns:
        A scalar tensor: the per-example losses averaged with reduce_mean.
    """
    labels = tf.cast(y, tf.int64)
    per_example = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels,
        logits=x,
    )
    return tf.reduce_mean(per_example)

def accuracy(y_pred, y_true):
    """Return the fraction of examples whose argmax prediction equals the label.

    Args:
        y_pred: Per-class scores; the argmax is taken along axis 1.
        y_true: Labels; they are cast to int64 for the comparison.

    Returns:
        A scalar float32 tensor holding the mean of the per-example matches.
    """
    predicted_class = tf.argmax(y_pred, 1)
    expected_class = tf.cast(y_true, tf.int64)
    hits = tf.cast(tf.equal(predicted_class, expected_class), tf.float32)
    return tf.reduce_mean(hits)
4.構造優化函數以及訓練代碼
# Plain SGD with a learning rate of 0.01.
optimizer = tf.optimizers.SGD(0.01)


def run_optimizer(x, y):
    """Perform one gradient step on `conv` for the batch (x, y).

    The forward pass and loss are recorded on a gradient tape, then the
    gradients with respect to the model's trainable variables are applied
    through the module-level `optimizer`.
    """
    with tf.GradientTape() as tape:
        logits = conv(x, is_training=True)
        batch_loss = cross_entropy_loss(logits, y)
    weights = conv.trainable_variables
    grads = tape.gradient(batch_loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

# Train for 1000 batches; every 50th step, print the accuracy on that batch.
for step, (x_batch, y_batch) in enumerate(data_loader.take(1000), 1):
    run_optimizer(x_batch, y_batch)
    if step % 50 != 0:
        continue
    pred = conv(x_batch)
    acc = accuracy(pred, y_batch)
    print("%f" % (acc))
5.保存模型
# Save the model's weights under the name 'mannul_checkpoint'.
conv.save_weights('mannul_checkpoint')
6.加載模型
# Restore the weights previously saved under the name 'mannul_checkpoint'.
conv.load_weights('mannul_checkpoint')
7.鼠標繪圖

drawing = False  # True while a stroke is in progress (left button held)
mode = True  # True: draw rectangles, False: draw circles (not read by the callback below)
start = (-1, -1)


def mouse_event(event, x, y, flags, param):
    """OpenCV mouse callback that paints on the global `img` while dragging.

    Args:
        event: OpenCV mouse event code.
        x: Cursor x coordinate.
        y: Cursor y coordinate.
        flags: Unused.
        param: Unused.
    """
    global start, drawing

    if event == cv2.EVENT_LBUTTONDOWN:
        # Left button pressed: begin a stroke and remember where it started.
        drawing, start = True, (x, y)
        return

    if event == cv2.EVENT_LBUTTONUP:
        # Left button released: end the stroke.
        drawing = False
        return

    if event == cv2.EVENT_MOUSEMOVE and drawing:
        # Mouse moved mid-stroke: stamp a filled circle of radius 8 with a
        # random intensity between 185 and 255.
        brush = random.randint(185, 255)
        cv2.circle(img, (x, y), 8, brush, -1)
8.將繪製的圖片送入TensorFlow模型中進行識別,並在圖像中顯示類別

# Single-channel 512x512 canvas filled with 100; `mouse_event` draws on it.
img = np.zeros((512, 512, 1), np.uint8)
img[:, :] = 100
cv2.namedWindow('image')
cv2.setMouseCallback('image', mouse_event)

while True:
    # Only the top-left 400x400 region is recognised; the rest of the canvas
    # is left for the result text.
    a = img[:400, :400]
    cv2.imshow('a', cv2.resize(a, (28, 28)))

    # Build the model input: downscale to 28x28, flatten to [1, 784], and
    # divide by 255 so it matches the scaling applied to the training data.
    temp = cv2.resize(a.astype(np.float32), (28, 28))
    temp = np.reshape(temp, [1, 784]) / 255.0

    pred = conv(temp, is_training=False)
    b = np.argmax(pred.numpy(), axis=1)

    cv2.imshow('image', img)

    # Poll the keyboard once per frame. Calling waitKey twice would let the
    # first call swallow a key press intended for the second check.
    key = cv2.waitKey(1) & 0xFF
    if key == ord('a'):
        # 'a': clear the canvas and print the predicted class on it.
        img = np.zeros((512, 512, 1), np.uint8)
        cv2.putText(img, str(b), (250, 450), cv2.FONT_HERSHEY_COMPLEX, 2.0, (100, 200, 200), 5)
    elif key == 27:
        # ESC: leave the loop.
        break

cv2.destroyAllWindows()

效果圖

發表評論
所有評論
還沒有人評論,想成為第一個評論的人麼? 請在上方評論欄輸入並且點擊發布.
相關文章