Fixing the load_data() timeout
When loading the MNIST dataset, TF 2.0 does provide a loader, but it simply doesn't work: it fails with a connection error ("the connection attempt failed because the connected party did not properly respond after a period of time, or the established connection failed because the connected host has failed to respond"). Even going through a proxy doesn't help, because Google removed the file from the bucket (https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz) without updating the code. It's possible TF 2.1 has fixed this.
Here is how to work around the problem:
1. Rewrite the load function yourself, as follows:
```python
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
from tensorflow.python.keras.utils.data_utils import get_file


def load_data(path=r'D:\code\Python\Learn\mnist.npz'):
    origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
    # get_file() checks the local copy against file_hash and skips the download
    # when it matches, so the dead URL is never actually hit.
    path = get_file(
        path,
        origin=origin_folder + 'mnist.npz',
        file_hash='8a61469f7ea1b51cbae51d4f78837e45'
    )
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)
```
This is essentially the original implementation with two changes: the path and the MD5 value.
Set the path to wherever you keep your local mnist.npz, and generate the corresponding MD5 for it yourself: ![screenshot](https://img-blog.csdnimg.cn/20200117175708190.png)
If you look at the source, the hash algorithm used here is MD5; the comments say so.
With these changes, the loader works.
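If you'd rather not use an external tool, the MD5 can be computed with Python's standard library alone. A minimal sketch (the helper name is my own; point it at your local mnist.npz):

```python
import hashlib


def md5_of_file(path, chunk_size=65536):
    """Return the hex MD5 digest of a file, reading it in chunks."""
    h = hashlib.md5()
    with open(path, 'rb') as f:
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()

# Pass the resulting digest as the file_hash= argument to get_file().
```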
As for the mnist.npz file itself, a Baidu search turns up plenty of copies, about 11 MB in size.
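Since these copies come from random mirrors, it's worth sanity-checking the archive before wiring it into load_data. A quick sketch (the helper name is my own) that reports the arrays a genuine mnist.npz should contain:

```python
import numpy as np


def check_mnist_npz(path):
    """Return the shape of each expected array in an MNIST .npz archive."""
    with np.load(path) as f:
        return {name: f[name].shape
                for name in ('x_train', 'y_train', 'x_test', 'y_test')}

# A genuine mnist.npz reports x_train (60000, 28, 28), y_train (60000,),
# x_test (10000, 28, 28), y_test (10000,).
```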
The complete code:
```python
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
# @Time    : 2020/1/17 15:39
# @Author  : shitao.li
# @Email   : [email protected]
# @File    : mnist_test.py
# @Software: PyCharm
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tensorflow.python.keras.utils.data_utils import get_file


def load_data(path=r'D:\code\Python\Learn\mnist.npz'):
    origin_folder = 'https://storage.googleapis.com/tensorflow/tf-keras-datasets/'
    path = get_file(
        path,
        origin=origin_folder + 'mnist.npz',
        file_hash='8a61469f7ea1b51cbae51d4f78837e45'
    )
    with np.load(path) as f:
        x_train, y_train = f['x_train'], f['y_train']
        x_test, y_test = f['x_test'], f['y_test']
    return (x_train, y_train), (x_test, y_test)


(xs, ys), _ = load_data()
print('dataSet: ', xs.shape, ys.shape)

# Scale pixel values to [0, 1] before building the dataset.
xs = tf.convert_to_tensor(xs, dtype=tf.float32) / 255
db = tf.data.Dataset.from_tensor_slices((xs, ys))
for step, (x, y) in enumerate(db):
    print(step, x.shape, y, y.shape)
```