TF 2.0: load_data() times out

Fixing the load_data() timeout

When loading the MNIST dataset, TF 2.0 does provide a loader, but it simply doesn't work: it fails with a connection timeout ("the connection attempt failed because the connected party did not properly respond after a period of time, or the connected host has failed to respond"). Going through a proxy doesn't help either. The download fails because Google apparently removed the file (https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz), yet the loader code was never updated to match; TF 2.1 may have fixed this.
Here is how to work around the problem:
1. Rewrite the load function yourself, as follows:


```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers

import numpy as np

from tensorflow.python.keras.utils.data_utils import get_file


def load_data(path=r'D:\code\Python\Learn\mnist.npz'):
    # Raw string, so the backslashes in the Windows path are not treated as escapes.
    origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
    path = get_file(
        path,
        origin=origin_folder + 'mnist.npz',
        file_hash='8a61469f7ea1b51cbae51d4f78837e45'  # md5 of the local file
    )
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
        return (x_train, y_train), (x_test, y_test)
```
This is essentially the original implementation with two things changed: the path and the md5 value.
Set the path to wherever you keep your local mnist.npz, and generate the matching md5 for that file yourself: ![generating the md5 of the local mnist.npz](https://img-blog.csdnimg.cn/20200117175708190.png)
If you look at the source of get_file, the hash algorithm used here is md5 (the docstring says so).
After this change, load_data() works.
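If you'd rather compute the md5 in Python instead of an external tool, the standard hashlib module is enough. A minimal sketch (the helper name `file_md5` is just illustrative, not from TF):

```python
import hashlib


def file_md5(path, chunk_size=65536):
    """Compute the md5 hex digest of a file, reading in chunks to handle large files."""
    md5 = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            md5.update(chunk)
    return md5.hexdigest()


# Hash your local copy and paste the result into file_hash=:
# print(file_md5(r'D:\code\Python\Learn\mnist.npz'))
```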

As for the mnist.npz file itself, a Baidu search turns up plenty of copies, about 11 MB in size.
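Before trusting a file downloaded from a random mirror, it's worth checking that it really is an mnist.npz with the four expected arrays. A small sanity check using only NumPy (the helper name `check_mnist_npz` is just illustrative):

```python
import numpy as np


def check_mnist_npz(path):
    """Verify an mnist.npz contains the four expected arrays; return their shapes."""
    with np.load(path) as f:
        expected = ('x_train', 'y_train', 'x_test', 'y_test')
        missing = [k for k in expected if k not in f.files]
        if missing:
            raise ValueError('not a valid mnist.npz, missing arrays: %s' % missing)
        return {k: f[k].shape for k in expected}


# Point it at your own copy; a genuine file reports
# x_train (60000, 28, 28) and x_test (10000, 28, 28):
# print(check_mnist_npz(r'D:\code\Python\Learn\mnist.npz'))
```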

The complete code:

```python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# @Time : 2020/1/17 15:39 

# @Author : shitao.li
# @Email : [email protected]
# @File : mnist_test.py 

# @Software: PyCharm


from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers

import numpy as np

from tensorflow.python.keras.utils.data_utils import get_file


def load_data(path=r'D:\code\Python\Learn\mnist.npz'):
    # Raw string, so the backslashes in the Windows path are not treated as escapes.
    origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
    path = get_file(
        path,
        origin=origin_folder + 'mnist.npz',
        file_hash='8a61469f7ea1b51cbae51d4f78837e45'  # md5 of the local file
    )
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
        return (x_train, y_train), (x_test, y_test)


(xs, ys), _ = load_data()
print('dataSet: ', xs.shape, ys.shape)

xs = tf.convert_to_tensor(xs, dtype=tf.float32) / 255
db = tf.data.Dataset.from_tensor_slices((xs, ys))

for step, (x, y) in enumerate(db):
    print(step, x.shape, y, y.shape)
```