Questions and discussion are welcome.
Leave a comment or contact me by email: jiaohaibin@ruc.edu.cn
The data is provided by JQData's local quantitative financial data service.
Experiment 2:
Use the open, close, high, low, volume, and money values of the previous 5 time steps
to predict the closing price at the current time step,
i.e. [None, 5, 6] => [None, 1]  # None is the batch_size
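To make that shape mapping concrete, here is a minimal sliding-window sketch (standalone NumPy; the random `demo` array and the `make_windows` helper are hypothetical stand-ins, and the notebook builds the same arrays inline further down):

import numpy as np

def make_windows(data, seq_len=5):
    # Turn a (T, 6) feature array into (T - seq_len, seq_len, 6) windows and
    # (T - seq_len,) targets, matching [None, 5, 6] => [None, 1].
    # Assumes the target (close) is stored in column 0, as done later in the notebook.
    X = np.array([data[i: i + seq_len, :] for i in range(len(data) - seq_len)])
    y = np.array([data[i + seq_len, 0] for i in range(len(data) - seq_len)])
    return X, y

demo = np.random.rand(100, 6)   # stand-in for 100 rows of scaled OHLC + volume + money
X, y = make_windows(demo)
print(X.shape, y.shape)         # (95, 5, 6) (95,)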
This post extends the model from Experiment 2 by adding an attention mechanism.
First, a short introduction to attention.
The essence of attention:
it is really nothing more than a weighted sum.
Attention typically addresses the following scenario:
you have k feature vectors h_i (i = 1, 2, ..., k), each d-dimensional, and you want to combine the information in these k vectors into a single vector h* (usually also d-dimensional).
Solutions:
1. The simplest, crudest approach is to take the element-wise average of the k vectors as h*, which is clearly not very satisfactory.
2. A more reasonable approach is a weighted average, h* = Σ_i α_i h_i (where the α_i are the weights); what attention does is compute these weights α_i in a sensible way (see the sketch below).
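As a minimal illustration of this weighted sum (pure NumPy; the softmax-normalized `scores` are a stand-in for whatever scoring network a real attention module would learn):

import numpy as np

def attention_pool(H, scores):
    # H: (k, d) matrix whose rows are the feature vectors h_i
    # scores: (k,) unnormalized relevance scores, one per vector
    alpha = np.exp(scores - scores.max())
    alpha = alpha / alpha.sum()              # softmax -> weights alpha_i sum to 1
    return (alpha[:, None] * H).sum(axis=0)  # h* = sum_i alpha_i * h_i

k, d = 5, 4
H = np.random.rand(k, d)
scores = np.random.rand(k)
h_star = attention_pool(H, scores)
print(h_star.shape)                          # (4,)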
Attention mechanisms have been studied extensively as neural processes in neuroscience and computational neuroscience. Visual attention is a particularly well-studied direction: many animals focus on specific parts of their visual input in order to compute an appropriate response. This principle matters for neural computation because we need to select the most relevant information rather than use all available information, much of which is irrelevant to the response being computed. The same idea of focusing on specific parts of the input, i.e. the attention mechanism, has been applied in deep learning to speech recognition, translation, reasoning, and visual recognition.
Model architecture
Experiment result, measured by the error: MSE Test loss: 0.0005358342003804944
import pandas as pd
import time, datetime

df_data_5minute = pd.read_csv('黄金主力5分钟数据.csv')
df_data_5minute.head()
 | Unnamed: 0 | open | close | high | low | volume | money |
---|---|---|---|---|---|---|---|
0 | 2016-01-04 09:05:00 | 226.70 | 226.65 | 226.85 | 226.45 | 5890.0 | 1.335146e+09 |
1 | 2016-01-04 09:10:00 | 226.75 | 226.50 | 226.75 | 226.40 | 2562.0 | 5.804133e+08 |
2 | 2016-01-04 09:15:00 | 226.45 | 226.45 | 226.60 | 226.40 | 1638.0 | 3.709666e+08 |
3 | 2016-01-04 09:20:00 | 226.45 | 226.25 | 226.50 | 226.20 | 3162.0 | 7.157891e+08 |
4 | 2016-01-04 09:25:00 | 226.25 | 226.25 | 226.30 | 226.20 | 1684.0 | 3.809907e+08 |
df_data_5minute
 | Unnamed: 0 | open | close | high | low | volume | money |
---|---|---|---|---|---|---|---|
0 | 2016-01-04 09:05:00 | 226.70 | 226.65 | 226.85 | 226.45 | 5890.0 | 1.335146e+09 |
1 | 2016-01-04 09:10:00 | 226.75 | 226.50 | 226.75 | 226.40 | 2562.0 | 5.804133e+08 |
2 | 2016-01-04 09:15:00 | 226.45 | 226.45 | 226.60 | 226.40 | 1638.0 | 3.709666e+08 |
3 | 2016-01-04 09:20:00 | 226.45 | 226.25 | 226.50 | 226.20 | 3162.0 | 7.157891e+08 |
4 | 2016-01-04 09:25:00 | 226.25 | 226.25 | 226.30 | 226.20 | 1684.0 | 3.809907e+08 |
5 | 2016-01-04 09:30:00 | 226.25 | 226.30 | 226.35 | 226.20 | 922.0 | 2.086313e+08 |
6 | 2016-01-04 09:35:00 | 226.30 | 226.35 | 226.40 | 226.20 | 2476.0 | 5.603541e+08 |
7 | 2016-01-04 09:40:00 | 226.30 | 226.45 | 226.45 | 226.25 | 2516.0 | 5.695246e+08 |
8 | 2016-01-04 09:45:00 | 226.45 | 226.35 | 226.45 | 226.30 | 1344.0 | 3.042327e+08 |
9 | 2016-01-04 09:50:00 | 226.30 | 226.30 | 226.35 | 226.20 | 1414.0 | 3.199363e+08 |
10 | 2016-01-04 09:55:00 | 226.35 | 226.45 | 226.50 | 226.30 | 1610.0 | 3.645328e+08 |
11 | 2016-01-04 10:00:00 | 226.45 | 226.40 | 226.50 | 226.40 | 972.0 | 2.200957e+08 |
12 | 2016-01-04 10:05:00 | 226.40 | 226.50 | 226.55 | 226.35 | 2004.0 | 4.538166e+08 |
13 | 2016-01-04 10:10:00 | 226.50 | 226.45 | 226.55 | 226.40 | 780.0 | 1.766423e+08 |
14 | 2016-01-04 10:15:00 | 226.45 | 226.45 | 226.50 | 226.40 | 1530.0 | 3.464690e+08 |
15 | 2016-01-04 10:35:00 | 226.55 | 226.45 | 226.65 | 226.45 | 2564.0 | 5.807784e+08 |
16 | 2016-01-04 10:40:00 | 226.45 | 226.50 | 226.55 | 226.45 | 900.0 | 2.038475e+08 |
17 | 2016-01-04 10:45:00 | 226.55 | 226.70 | 226.80 | 226.50 | 3008.0 | 6.817039e+08 |
18 | 2016-01-04 10:50:00 | 226.70 | 226.65 | 226.85 | 226.60 | 2510.0 | 5.691306e+08 |
19 | 2016-01-04 10:55:00 | 226.65 | 226.60 | 226.65 | 226.60 | 930.0 | 2.107595e+08 |
20 | 2016-01-04 11:00:00 | 226.65 | 226.75 | 226.75 | 226.60 | 1184.0 | 2.683818e+08 |
21 | 2016-01-04 11:05:00 | 226.75 | 226.65 | 226.75 | 226.60 | 1044.0 | 2.366603e+08 |
22 | 2016-01-04 11:10:00 | 226.65 | 226.60 | 226.70 | 226.60 | 342.0 | 7.751130e+07 |
23 | 2016-01-04 11:15:00 | 226.60 | 226.60 | 226.65 | 226.55 | 640.0 | 1.450196e+08 |
24 | 2016-01-04 11:20:00 | 226.60 | 226.65 | 226.70 | 226.60 | 502.0 | 1.137778e+08 |
25 | 2016-01-04 11:25:00 | 226.65 | 226.95 | 226.95 | 226.65 | 3222.0 | 7.308042e+08 |
26 | 2016-01-04 11:30:00 | 226.90 | 226.90 | 226.95 | 226.80 | 1472.0 | 3.339398e+08 |
27 | 2016-01-04 13:35:00 | 227.10 | 227.25 | 227.25 | 227.00 | 4894.0 | 1.111496e+09 |
28 | 2016-01-04 13:40:00 | 227.25 | 227.55 | 227.60 | 227.20 | 5338.0 | 1.214103e+09 |
29 | 2016-01-04 13:45:00 | 227.60 | 227.75 | 228.00 | 227.50 | 8612.0 | 1.961599e+09 |
... | ... | ... | ... | ... | ... | ... | ... |
53280 | 2017-12-29 10:35:00 | 278.05 | 277.95 | 278.05 | 277.90 | 448.0 | 1.245318e+08 |
53281 | 2017-12-29 10:40:00 | 277.90 | 277.95 | 278.00 | 277.90 | 506.0 | 1.406423e+08 |
53282 | 2017-12-29 10:45:00 | 277.95 | 277.95 | 278.00 | 277.95 | 180.0 | 5.003790e+07 |
53283 | 2017-12-29 10:50:00 | 277.95 | 278.00 | 278.05 | 277.95 | 936.0 | 2.602273e+08 |
53284 | 2017-12-29 10:55:00 | 278.05 | 277.90 | 278.05 | 277.90 | 942.0 | 2.618281e+08 |
53285 | 2017-12-29 11:00:00 | 277.85 | 277.90 | 277.95 | 277.85 | 518.0 | 1.439454e+08 |
53286 | 2017-12-29 11:05:00 | 277.95 | 277.95 | 277.95 | 277.90 | 614.0 | 1.706443e+08 |
53287 | 2017-12-29 11:10:00 | 277.90 | 277.90 | 277.95 | 277.85 | 1046.0 | 2.906776e+08 |
53288 | 2017-12-29 11:15:00 | 277.95 | 277.90 | 277.95 | 277.90 | 206.0 | 5.725350e+07 |
53289 | 2017-12-29 11:20:00 | 277.90 | 277.90 | 277.95 | 277.85 | 740.0 | 2.056435e+08 |
53290 | 2017-12-29 11:25:00 | 277.90 | 277.85 | 277.90 | 277.85 | 200.0 | 5.557570e+07 |
53291 | 2017-12-29 11:30:00 | 277.90 | 277.90 | 277.95 | 277.85 | 756.0 | 2.100840e+08 |
53292 | 2017-12-29 13:35:00 | 277.90 | 278.00 | 278.00 | 277.90 | 490.0 | 1.362097e+08 |
53293 | 2017-12-29 13:40:00 | 278.00 | 278.05 | 278.15 | 278.00 | 768.0 | 2.135675e+08 |
53294 | 2017-12-29 13:45:00 | 278.10 | 278.15 | 278.15 | 278.05 | 252.0 | 7.008070e+07 |
53295 | 2017-12-29 13:50:00 | 278.10 | 278.05 | 278.10 | 278.00 | 800.0 | 2.224430e+08 |
53296 | 2017-12-29 13:55:00 | 278.00 | 278.00 | 278.05 | 277.95 | 184.0 | 5.115390e+07 |
53297 | 2017-12-29 14:00:00 | 278.00 | 277.95 | 278.00 | 277.90 | 474.0 | 1.317464e+08 |
53298 | 2017-12-29 14:05:00 | 277.95 | 277.95 | 277.95 | 277.90 | 334.0 | 9.282880e+07 |
53299 | 2017-12-29 14:10:00 | 277.95 | 277.90 | 277.95 | 277.90 | 332.0 | 9.226560e+07 |
53300 | 2017-12-29 14:15:00 | 277.90 | 277.95 | 277.95 | 277.90 | 672.0 | 1.867720e+08 |
53301 | 2017-12-29 14:20:00 | 277.90 | 277.85 | 277.95 | 277.85 | 994.0 | 2.762458e+08 |
53302 | 2017-12-29 14:25:00 | 277.90 | 277.90 | 277.95 | 277.85 | 352.0 | 9.781830e+07 |
53303 | 2017-12-29 14:30:00 | 277.90 | 277.80 | 277.95 | 277.80 | 784.0 | 2.178426e+08 |
53304 | 2017-12-29 14:35:00 | 277.85 | 277.80 | 277.85 | 277.75 | 920.0 | 2.555711e+08 |
53305 | 2017-12-29 14:40:00 | 277.80 | 277.80 | 277.85 | 277.75 | 606.0 | 1.683349e+08 |
53306 | 2017-12-29 14:45:00 | 277.80 | 277.85 | 277.85 | 277.80 | 560.0 | 1.555840e+08 |
53307 | 2017-12-29 14:50:00 | 277.85 | 277.85 | 277.90 | 277.80 | 802.0 | 2.228271e+08 |
53308 | 2017-12-29 14:55:00 | 277.85 | 277.75 | 277.90 | 277.75 | 1236.0 | 3.433855e+08 |
53309 | 2017-12-29 15:00:00 | 277.80 | 277.80 | 277.90 | 277.70 | 1790.0 | 4.972797e+08 |
53310 rows × 7 columns
df_data_5minute.drop('Unnamed: 0', axis=1, inplace=True)
df_data_5minute
 | open | close | high | low | volume | money |
---|---|---|---|---|---|---|
0 | 226.70 | 226.65 | 226.85 | 226.45 | 5890.0 | 1.335146e+09 |
1 | 226.75 | 226.50 | 226.75 | 226.40 | 2562.0 | 5.804133e+08 |
2 | 226.45 | 226.45 | 226.60 | 226.40 | 1638.0 | 3.709666e+08 |
3 | 226.45 | 226.25 | 226.50 | 226.20 | 3162.0 | 7.157891e+08 |
4 | 226.25 | 226.25 | 226.30 | 226.20 | 1684.0 | 3.809907e+08 |
5 | 226.25 | 226.30 | 226.35 | 226.20 | 922.0 | 2.086313e+08 |
6 | 226.30 | 226.35 | 226.40 | 226.20 | 2476.0 | 5.603541e+08 |
7 | 226.30 | 226.45 | 226.45 | 226.25 | 2516.0 | 5.695246e+08 |
8 | 226.45 | 226.35 | 226.45 | 226.30 | 1344.0 | 3.042327e+08 |
9 | 226.30 | 226.30 | 226.35 | 226.20 | 1414.0 | 3.199363e+08 |
10 | 226.35 | 226.45 | 226.50 | 226.30 | 1610.0 | 3.645328e+08 |
11 | 226.45 | 226.40 | 226.50 | 226.40 | 972.0 | 2.200957e+08 |
12 | 226.40 | 226.50 | 226.55 | 226.35 | 2004.0 | 4.538166e+08 |
13 | 226.50 | 226.45 | 226.55 | 226.40 | 780.0 | 1.766423e+08 |
14 | 226.45 | 226.45 | 226.50 | 226.40 | 1530.0 | 3.464690e+08 |
15 | 226.55 | 226.45 | 226.65 | 226.45 | 2564.0 | 5.807784e+08 |
16 | 226.45 | 226.50 | 226.55 | 226.45 | 900.0 | 2.038475e+08 |
17 | 226.55 | 226.70 | 226.80 | 226.50 | 3008.0 | 6.817039e+08 |
18 | 226.70 | 226.65 | 226.85 | 226.60 | 2510.0 | 5.691306e+08 |
19 | 226.65 | 226.60 | 226.65 | 226.60 | 930.0 | 2.107595e+08 |
20 | 226.65 | 226.75 | 226.75 | 226.60 | 1184.0 | 2.683818e+08 |
21 | 226.75 | 226.65 | 226.75 | 226.60 | 1044.0 | 2.366603e+08 |
22 | 226.65 | 226.60 | 226.70 | 226.60 | 342.0 | 7.751130e+07 |
23 | 226.60 | 226.60 | 226.65 | 226.55 | 640.0 | 1.450196e+08 |
24 | 226.60 | 226.65 | 226.70 | 226.60 | 502.0 | 1.137778e+08 |
25 | 226.65 | 226.95 | 226.95 | 226.65 | 3222.0 | 7.308042e+08 |
26 | 226.90 | 226.90 | 226.95 | 226.80 | 1472.0 | 3.339398e+08 |
27 | 227.10 | 227.25 | 227.25 | 227.00 | 4894.0 | 1.111496e+09 |
28 | 227.25 | 227.55 | 227.60 | 227.20 | 5338.0 | 1.214103e+09 |
29 | 227.60 | 227.75 | 228.00 | 227.50 | 8612.0 | 1.961599e+09 |
... | ... | ... | ... | ... | ... | ... |
53280 | 278.05 | 277.95 | 278.05 | 277.90 | 448.0 | 1.245318e+08 |
53281 | 277.90 | 277.95 | 278.00 | 277.90 | 506.0 | 1.406423e+08 |
53282 | 277.95 | 277.95 | 278.00 | 277.95 | 180.0 | 5.003790e+07 |
53283 | 277.95 | 278.00 | 278.05 | 277.95 | 936.0 | 2.602273e+08 |
53284 | 278.05 | 277.90 | 278.05 | 277.90 | 942.0 | 2.618281e+08 |
53285 | 277.85 | 277.90 | 277.95 | 277.85 | 518.0 | 1.439454e+08 |
53286 | 277.95 | 277.95 | 277.95 | 277.90 | 614.0 | 1.706443e+08 |
53287 | 277.90 | 277.90 | 277.95 | 277.85 | 1046.0 | 2.906776e+08 |
53288 | 277.95 | 277.90 | 277.95 | 277.90 | 206.0 | 5.725350e+07 |
53289 | 277.90 | 277.90 | 277.95 | 277.85 | 740.0 | 2.056435e+08 |
53290 | 277.90 | 277.85 | 277.90 | 277.85 | 200.0 | 5.557570e+07 |
53291 | 277.90 | 277.90 | 277.95 | 277.85 | 756.0 | 2.100840e+08 |
53292 | 277.90 | 278.00 | 278.00 | 277.90 | 490.0 | 1.362097e+08 |
53293 | 278.00 | 278.05 | 278.15 | 278.00 | 768.0 | 2.135675e+08 |
53294 | 278.10 | 278.15 | 278.15 | 278.05 | 252.0 | 7.008070e+07 |
53295 | 278.10 | 278.05 | 278.10 | 278.00 | 800.0 | 2.224430e+08 |
53296 | 278.00 | 278.00 | 278.05 | 277.95 | 184.0 | 5.115390e+07 |
53297 | 278.00 | 277.95 | 278.00 | 277.90 | 474.0 | 1.317464e+08 |
53298 | 277.95 | 277.95 | 277.95 | 277.90 | 334.0 | 9.282880e+07 |
53299 | 277.95 | 277.90 | 277.95 | 277.90 | 332.0 | 9.226560e+07 |
53300 | 277.90 | 277.95 | 277.95 | 277.90 | 672.0 | 1.867720e+08 |
53301 | 277.90 | 277.85 | 277.95 | 277.85 | 994.0 | 2.762458e+08 |
53302 | 277.90 | 277.90 | 277.95 | 277.85 | 352.0 | 9.781830e+07 |
53303 | 277.90 | 277.80 | 277.95 | 277.80 | 784.0 | 2.178426e+08 |
53304 | 277.85 | 277.80 | 277.85 | 277.75 | 920.0 | 2.555711e+08 |
53305 | 277.80 | 277.80 | 277.85 | 277.75 | 606.0 | 1.683349e+08 |
53306 | 277.80 | 277.85 | 277.85 | 277.80 | 560.0 | 1.555840e+08 |
53307 | 277.85 | 277.85 | 277.90 | 277.80 | 802.0 | 2.228271e+08 |
53308 | 277.85 | 277.75 | 277.90 | 277.75 | 1236.0 | 3.433855e+08 |
53309 | 277.80 | 277.80 | 277.90 | 277.70 | 1790.0 | 4.972797e+08 |
53310 rows × 6 columns
df = df_data_5minute
close = df['close']
df.drop(labels=['close'], axis=1, inplace=True)
df.insert(0, 'close', close)
df
 | close | open | high | low | volume | money |
---|---|---|---|---|---|---|
0 | 226.65 | 226.70 | 226.85 | 226.45 | 5890.0 | 1.335146e+09 |
1 | 226.50 | 226.75 | 226.75 | 226.40 | 2562.0 | 5.804133e+08 |
2 | 226.45 | 226.45 | 226.60 | 226.40 | 1638.0 | 3.709666e+08 |
3 | 226.25 | 226.45 | 226.50 | 226.20 | 3162.0 | 7.157891e+08 |
4 | 226.25 | 226.25 | 226.30 | 226.20 | 1684.0 | 3.809907e+08 |
5 | 226.30 | 226.25 | 226.35 | 226.20 | 922.0 | 2.086313e+08 |
6 | 226.35 | 226.30 | 226.40 | 226.20 | 2476.0 | 5.603541e+08 |
7 | 226.45 | 226.30 | 226.45 | 226.25 | 2516.0 | 5.695246e+08 |
8 | 226.35 | 226.45 | 226.45 | 226.30 | 1344.0 | 3.042327e+08 |
9 | 226.30 | 226.30 | 226.35 | 226.20 | 1414.0 | 3.199363e+08 |
10 | 226.45 | 226.35 | 226.50 | 226.30 | 1610.0 | 3.645328e+08 |
11 | 226.40 | 226.45 | 226.50 | 226.40 | 972.0 | 2.200957e+08 |
12 | 226.50 | 226.40 | 226.55 | 226.35 | 2004.0 | 4.538166e+08 |
13 | 226.45 | 226.50 | 226.55 | 226.40 | 780.0 | 1.766423e+08 |
14 | 226.45 | 226.45 | 226.50 | 226.40 | 1530.0 | 3.464690e+08 |
15 | 226.45 | 226.55 | 226.65 | 226.45 | 2564.0 | 5.807784e+08 |
16 | 226.50 | 226.45 | 226.55 | 226.45 | 900.0 | 2.038475e+08 |
17 | 226.70 | 226.55 | 226.80 | 226.50 | 3008.0 | 6.817039e+08 |
18 | 226.65 | 226.70 | 226.85 | 226.60 | 2510.0 | 5.691306e+08 |
19 | 226.60 | 226.65 | 226.65 | 226.60 | 930.0 | 2.107595e+08 |
20 | 226.75 | 226.65 | 226.75 | 226.60 | 1184.0 | 2.683818e+08 |
21 | 226.65 | 226.75 | 226.75 | 226.60 | 1044.0 | 2.366603e+08 |
22 | 226.60 | 226.65 | 226.70 | 226.60 | 342.0 | 7.751130e+07 |
23 | 226.60 | 226.60 | 226.65 | 226.55 | 640.0 | 1.450196e+08 |
24 | 226.65 | 226.60 | 226.70 | 226.60 | 502.0 | 1.137778e+08 |
25 | 226.95 | 226.65 | 226.95 | 226.65 | 3222.0 | 7.308042e+08 |
26 | 226.90 | 226.90 | 226.95 | 226.80 | 1472.0 | 3.339398e+08 |
27 | 227.25 | 227.10 | 227.25 | 227.00 | 4894.0 | 1.111496e+09 |
28 | 227.55 | 227.25 | 227.60 | 227.20 | 5338.0 | 1.214103e+09 |
29 | 227.75 | 227.60 | 228.00 | 227.50 | 8612.0 | 1.961599e+09 |
... | ... | ... | ... | ... | ... | ... |
53280 | 277.95 | 278.05 | 278.05 | 277.90 | 448.0 | 1.245318e+08 |
53281 | 277.95 | 277.90 | 278.00 | 277.90 | 506.0 | 1.406423e+08 |
53282 | 277.95 | 277.95 | 278.00 | 277.95 | 180.0 | 5.003790e+07 |
53283 | 278.00 | 277.95 | 278.05 | 277.95 | 936.0 | 2.602273e+08 |
53284 | 277.90 | 278.05 | 278.05 | 277.90 | 942.0 | 2.618281e+08 |
53285 | 277.90 | 277.85 | 277.95 | 277.85 | 518.0 | 1.439454e+08 |
53286 | 277.95 | 277.95 | 277.95 | 277.90 | 614.0 | 1.706443e+08 |
53287 | 277.90 | 277.90 | 277.95 | 277.85 | 1046.0 | 2.906776e+08 |
53288 | 277.90 | 277.95 | 277.95 | 277.90 | 206.0 | 5.725350e+07 |
53289 | 277.90 | 277.90 | 277.95 | 277.85 | 740.0 | 2.056435e+08 |
53290 | 277.85 | 277.90 | 277.90 | 277.85 | 200.0 | 5.557570e+07 |
53291 | 277.90 | 277.90 | 277.95 | 277.85 | 756.0 | 2.100840e+08 |
53292 | 278.00 | 277.90 | 278.00 | 277.90 | 490.0 | 1.362097e+08 |
53293 | 278.05 | 278.00 | 278.15 | 278.00 | 768.0 | 2.135675e+08 |
53294 | 278.15 | 278.10 | 278.15 | 278.05 | 252.0 | 7.008070e+07 |
53295 | 278.05 | 278.10 | 278.10 | 278.00 | 800.0 | 2.224430e+08 |
53296 | 278.00 | 278.00 | 278.05 | 277.95 | 184.0 | 5.115390e+07 |
53297 | 277.95 | 278.00 | 278.00 | 277.90 | 474.0 | 1.317464e+08 |
53298 | 277.95 | 277.95 | 277.95 | 277.90 | 334.0 | 9.282880e+07 |
53299 | 277.90 | 277.95 | 277.95 | 277.90 | 332.0 | 9.226560e+07 |
53300 | 277.95 | 277.90 | 277.95 | 277.90 | 672.0 | 1.867720e+08 |
53301 | 277.85 | 277.90 | 277.95 | 277.85 | 994.0 | 2.762458e+08 |
53302 | 277.90 | 277.90 | 277.95 | 277.85 | 352.0 | 9.781830e+07 |
53303 | 277.80 | 277.90 | 277.95 | 277.80 | 784.0 | 2.178426e+08 |
53304 | 277.80 | 277.85 | 277.85 | 277.75 | 920.0 | 2.555711e+08 |
53305 | 277.80 | 277.80 | 277.85 | 277.75 | 606.0 | 1.683349e+08 |
53306 | 277.85 | 277.80 | 277.85 | 277.80 | 560.0 | 1.555840e+08 |
53307 | 277.85 | 277.85 | 277.90 | 277.80 | 802.0 | 2.228271e+08 |
53308 | 277.75 | 277.85 | 277.90 | 277.75 | 1236.0 | 3.433855e+08 |
53309 | 277.80 | 277.80 | 277.90 | 277.70 | 1790.0 | 4.972797e+08 |
53310 rows × 6 columns
data_train = df.iloc[:int(df.shape[0] * 0.7), :]
data_test = df.iloc[int(df.shape[0] * 0.7):, :]
print(data_train.shape, data_test.shape)
(37317, 6) (15993, 6)
# -*- coding: utf-8 -*-
import pandas as pd
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline
from sklearn.preprocessing import MinMaxScaler
import time

scaler = MinMaxScaler(feature_range=(-1, 1))
scaler.fit(data_train)
/Users/jiaohaibin/anaconda3/lib/python3.6/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`. from ._conv import register_converters as _register_converters
MinMaxScaler(copy=True, feature_range=(-1, 1))
data_train = scaler.transform(data_train)
data_test = scaler.transform(data_test)
data_train
array([[-0.98877193, -0.98736842, -0.98459384, -0.99297259, -0.82504604, -0.85978547],
       [-0.99298246, -0.98596491, -0.98739496, -0.99437807, -0.92389948, -0.93904608],
       [-0.99438596, -0.99438596, -0.99159664, -0.99437807, -0.95134557, -0.96104178],
       ...,
       [ 0.61263158,  0.61824561,  0.61484594,  0.61349262, -0.90916652, -0.90885626],
       [ 0.61684211,  0.61403509,  0.61204482,  0.61630358, -0.94754352, -0.94737162],
       [ 0.6154386 ,  0.6154386 ,  0.61064426,  0.61349262, -0.94445435, -0.9442865 ]])
from keras.layers import Input, Dense, LSTM
from keras.models import Model
from keras.layers import *
from keras.models import *
from keras.optimizers import Adam

output_dim = 1
batch_size = 256    # number of samples per training batch
epochs = 60         # train for 60 epochs
seq_len = 5
hidden_size = 128
TIME_STEPS = 5
INPUT_DIM = 6
lstm_units = 64

X_train = np.array([data_train[i: i + seq_len, :] for i in range(data_train.shape[0] - seq_len)])
y_train = np.array([data_train[i + seq_len, 0] for i in range(data_train.shape[0] - seq_len)])
X_test = np.array([data_test[i: i + seq_len, :] for i in range(data_test.shape[0] - seq_len)])
y_test = np.array([data_test[i + seq_len, 0] for i in range(data_test.shape[0] - seq_len)])
print(X_train.shape, y_train.shape, X_test.shape, y_test.shape)
(37312, 5, 6) (37312,) (15988, 5, 6) (15988,)
Using TensorFlow backend.
inputs = Input(shape=(TIME_STEPS, INPUT_DIM))
# drop1 = Dropout(0.3)(inputs)
x = Conv1D(filters=64, kernel_size=1, activation='relu')(inputs)  # , padding='same'
# x = Conv1D(filters=128, kernel_size=5, activation='relu')(output1)  # embedded_sequences
x = MaxPooling1D(pool_size=5)(x)
x = Dropout(0.2)(x)
print(x.shape)
WARNING:tensorflow:From /Users/jiaohaibin/anaconda3/lib/python3.6/site-packages/tensorflow/python/util/deprecation.py:497: calling conv1d (from tensorflow.python.ops.nn_ops) with data_format=NHWC is deprecated and will be removed in a future version.
Instructions for updating:
`NHWC` for data_format is deprecated, use `NWC` instead
(?, 1, 64)
lstm_out = Bidirectional(LSTM(lstm_units, activation='relu'), name='bilstm')(x)
# lstm_out = LSTM(lstm_units, activation='relu')(x)
print(lstm_out.shape)
(?, 128)
from keras import backend as K
from keras.engine.topology import Layer
import numpy as np
from keras import initializers

# Attention GRU network (kept for reference, not used in this experiment)
class AttLayer(Layer):
    def __init__(self, **kwargs):
        self.init = initializers.get('normal')
        # self.input_spec = [InputSpec(ndim=3)]
        super(AttLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        assert len(input_shape) == 3  # expects (batch, time_steps, features)
        self.W = self.add_weight(name='att_weight', shape=(input_shape[-1],),
                                 initializer=self.init, trainable=True)
        super(AttLayer, self).build(input_shape)  # be sure you call this somewhere!

    def call(self, x, mask=None):
        # score each time step, softmax over time, then take the weighted sum
        eij = K.tanh(K.dot(x, K.expand_dims(self.W)))    # (batch, time_steps, 1)
        ai = K.exp(eij)
        weights = ai / K.sum(ai, axis=1, keepdims=True)  # attention weights
        weighted_input = x * weights
        return K.sum(weighted_input, axis=1)             # (batch, features)

    def compute_output_shape(self, input_shape):
        return (input_shape[0], input_shape[-1])
'''
l_att = AttLayer()(lstm_out)
output = Dense(1, activation='sigmoid')(l_att)
print(output.shape)
'''
"\nl_att = AttLayer()(lstm_out)\noutput = Dense(1, activation='sigmoid')(l_att)\nprint(output.shape)"
from keras.layers import Input, Dense, merge
from keras import layers

# ATTENTION PART STARTS HERE
attention_probs = Dense(128, activation='sigmoid', name='attention_vec')(lstm_out)
# attention_mul = layers.merge([lstm_out, attention_probs], mode='concat', concat_axis=1)
attention_mul = Multiply()([lstm_out, attention_probs])
# attention_mul = merge([lstm_out, attention_probs], output_shape=32, name='attention_mul', mode='mul')
output = Dense(1, activation='sigmoid')(attention_mul)
# output = Dense(10, activation='sigmoid')(drop2)
model = Model(inputs=inputs, outputs=output)
print(model.summary())
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_1 (InputLayer)            (None, 5, 6)         0
__________________________________________________________________________________________________
conv1d_1 (Conv1D)               (None, 5, 64)        448         input_1[0][0]
__________________________________________________________________________________________________
max_pooling1d_1 (MaxPooling1D)  (None, 1, 64)        0           conv1d_1[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 1, 64)        0           max_pooling1d_1[0][0]
__________________________________________________________________________________________________
bilstm (Bidirectional)          (None, 128)          66048       dropout_1[0][0]
__________________________________________________________________________________________________
attention_vec (Dense)           (None, 128)          16512       bilstm[0][0]
__________________________________________________________________________________________________
multiply_1 (Multiply)           (None, 128)          0           bilstm[0][0]
                                                                 attention_vec[0][0]
__________________________________________________________________________________________________
dense_1 (Dense)                 (None, 1)            129         multiply_1[0][0]
==================================================================================================
Total params: 83,137
Trainable params: 83,137
Non-trainable params: 0
__________________________________________________________________________________________________
None
model.compile(loss='mean_squared_error', optimizer='adam')
model.fit(X_train, y_train, epochs=epochs, batch_size=batch_size, shuffle=False)
y_pred = model.predict(X_test)
print('MSE Train loss:', model.evaluate(X_train, y_train, batch_size=batch_size))
print('MSE Test loss:', model.evaluate(X_test, y_test, batch_size=batch_size))
plt.plot(y_test, label='test')
plt.plot(y_pred, label='pred')
plt.legend()
plt.show()
Epoch 1/60  37312/37312 [==============================] - 3s 92us/step - loss: 0.1865
Epoch 2/60  37312/37312 [==============================] - 2s 46us/step - loss: 0.0514
Epoch 3/60  37312/37312 [==============================] - 2s 42us/step - loss: 0.0442
Epoch 4/60  37312/37312 [==============================] - 2s 44us/step - loss: 0.0439
Epoch 5/60  37312/37312 [==============================] - 1s 39us/step - loss: 0.0436
...
(the training loss decreases slowly from 0.0432 to 0.0410 over epochs 6-59)
...
Epoch 60/60 37312/37312 [==============================] - 2s 45us/step - loss: 0.0410
37312/37312 [==============================] - 1s 20us/step
MSE Train loss: 0.04148709757006819
15988/15988 [==============================] - 0s 15us/step
MSE Test loss: 0.0005358342003804944