-
Notifications
You must be signed in to change notification settings - Fork 3
/
Copy pathStockDataMod.py
107 lines (90 loc) · 4.3 KB
/
StockDataMod.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
#! /usr/bin/env python
#-*- encoding: utf-8 -*-
import numpy as np
import pandas as pd
import pandas_datareader.data as web
import datetime
import csv,os
import codecs
import talib
import requests
import re
# Stock data access interface: serve requests from a local CSV cache when
# possible, otherwise download from Yahoo via pandas_datareader.
def _parse_cache_range(filename, stockName):
    """Return the (start, end) datetimes encoded in a cache file name, or None.

    Cache files are the ones GetStockDatApi writes, named
    ``<stockName>+<YYYY-MM-DD>+<YYYY-MM-DD>.csv``. Files for other stocks,
    with another extension, or with unparseable dates yield None so the
    caller simply skips them.
    """
    parts = filename.split('+')
    if len(parts) != 3 or parts[0] != stockName:
        return None
    str_start, str_end = parts[1], parts[2]
    if not str_end.endswith('.csv'):
        return None
    try:
        start = datetime.datetime.strptime(str_start, '%Y-%m-%d')
        end = datetime.datetime.strptime(str_end[:-len('.csv')], '%Y-%m-%d')
    except ValueError:
        return None
    return start, end


def GetStockDatApi(stockName=None, stockTimeS=None, stockTimeE=None):
    """Return price data for ``stockName`` between ``stockTimeS`` and ``stockTimeE``.

    The current working directory is used as a cache. If a file named
    ``<stockName>+<start>+<end>.csv`` there covers the requested range
    (compared at day granularity), it is read and sliced to
    ``[stockTimeS, stockTimeE]``. Otherwise the data is downloaded with
    ``web.DataReader(stockName, "yahoo", stockTimeS, stockTimeE)`` and saved
    as ``<stockName>+<stockTimeS>+<stockTimeE>.csv``. Older cache files for
    the same stock whose range lies inside the new one are then deleted,
    because they are redundant.

    :param stockName: ticker passed to DataReader (e.g. '600000.SS').
    :param stockTimeS: start of the range; required (datetime.datetime).
    :param stockTimeE: end of the range; required (datetime.datetime).
    :return: pandas DataFrame indexed by date.
    """
    path = os.getcwd()
    str_stockTimeS = stockTimeS.strftime('%Y-%m-%d')
    str_stockTimeE = stockTimeE.strftime('%Y-%m-%d')
    newname = stockName + '+' + str_stockTimeS + '+' + str_stockTimeE + '.csv'
    newpath = os.path.join(path, newname)
    # Cache names carry only dates, so drop any time-of-day from the request
    # before comparing. Otherwise e.g. stockTimeE=datetime.now() could never
    # match the very file written for the same request.
    reqS = datetime.datetime.strptime(str_stockTimeS, '%Y-%m-%d')
    reqE = datetime.datetime.strptime(str_stockTimeE, '%Y-%m-%d')

    # Scan every cache file for this stock rather than stopping at the first
    # one: a narrower file listed first must not hide one that covers the
    # request.
    stale = []
    for filename in os.listdir(path):
        cached = _parse_cache_range(filename, stockName)
        if cached is None:
            continue
        cacheS, cacheE = cached
        if cacheS <= reqS and cacheE >= reqE:
            # NOTE(review): files are written with to_csv's default encoding
            # but read as gb2312; identical for ASCII price data - confirm if
            # non-ASCII content is ever cached.
            stockDat = pd.read_csv(os.path.join(path, filename), parse_dates=True,
                                   index_col=0, encoding='gb2312')
            # .copy() so callers can add columns without SettingWithCopyWarning.
            return stockDat.loc[stockTimeS:stockTimeE].copy()
        stale.append((filename, cacheS, cacheE))

    # No cache covers the request: download and save under the requested range.
    stockDat = web.DataReader(stockName, "yahoo", stockTimeS, stockTimeE)
    stockDat.to_csv(newpath, columns=stockDat.columns, index=True)
    # Remove only caches fully contained in the new file, so no data is lost.
    for filename, cacheS, cacheE in stale:
        if reqS <= cacheS and cacheE <= reqE:
            os.remove(os.path.join(path, filename))
    return stockDat
# Stock data processing interface: add moving-average, MACD and KDJ columns.
def GetStockDatPro(stockName=None, stockTimeS=None, stockTimeE=None):
    """Fetch data for an A-share code and append technical-indicator columns.

    Codes starting with '6' get the suffix '.SS'; all others get '.SZ'. The
    data itself comes from GetStockDatApi.

    Added columns:
      Ma20, Ma60, Ma120            -- rolling means of Close.
      macd_dif, macd_dea, macd_bar -- talib.MACD(Close, 12, 26, 9).
      RSV, K, D, J                 -- 9-row KDJ. K and D are seeded at 50 on
                                      the 9th row; earlier rows are NaN.

    :param stockName: numeric code string, e.g. '600000'.
    :param stockTimeS: start of the range (passed to GetStockDatApi).
    :param stockTimeE: end of the range (passed to GetStockDatApi).
    :return: the DataFrame with the columns above added.
    """
    if stockName.startswith('6'):
        stockName = stockName + '.SS'
    else:
        stockName = stockName + '.SZ'
    stockPro = GetStockDatApi(stockName, stockTimeS, stockTimeE)

    # Moving averages
    for window in (20, 60, 120):
        stockPro['Ma%d' % window] = stockPro.Close.rolling(window=window).mean()

    # MACD
    stockPro['macd_dif'], stockPro['macd_dea'], stockPro['macd_bar'] = talib.MACD(
        stockPro['Close'].values, fastperiod=12, slowperiod=26, signalperiod=9)

    # KDJ over the last n rows (inclusive of the current row).
    n = 9
    xd = n - 1
    # min_periods=1 makes the rolling extremes skip NaNs exactly like
    # Series.min()/max() on a slice would.
    low_n = stockPro['Low'].rolling(window=n, min_periods=1).min()
    high_n = stockPro['High'].rolling(window=n, min_periods=1).max()
    rsv_all = (stockPro['Close'] - low_n) / (high_n - low_n) * 100
    # Keep only rows that have a full n-row window behind them; an empty
    # result (fewer than n rows) simply yields all-NaN KDJ columns.
    RSV = rsv_all.iloc[xd:]

    # Recurrence on plain NumPy arrays: avoids positional Series indexing
    # (Series[int] on a DatetimeIndex) that pandas has deprecated.
    rsv_vals = RSV.to_numpy(dtype=float)
    k_vals = np.empty(len(rsv_vals))
    d_vals = np.empty(len(rsv_vals))
    for i, rsv in enumerate(rsv_vals):
        if i == 0:
            # Seed value; the first RSV does not enter K.
            k_vals[i], d_vals[i] = 50.0, 50.0
        else:
            k_vals[i] = 2.0 / 3 * k_vals[i - 1] + rsv / 3
            d_vals[i] = 2.0 / 3 * d_vals[i - 1] + k_vals[i] / 3
    Kvalue = pd.Series(k_vals, index=RSV.index)
    Dvalue = pd.Series(d_vals, index=RSV.index)

    stockPro['RSV'] = RSV
    stockPro['K'] = Kvalue
    stockPro['D'] = Dvalue
    stockPro['J'] = 3 * Kvalue - 2 * Dvalue
    return stockPro
def format_code_name(code, timeout=10):
    """Look up the Chinese short name of an A-share code via Sina's quote API.

    Codes starting with '6' are queried with the 'sh' prefix, all others with
    'sz'. The name is taken as the first run of CJK characters
    (U+4E00..U+9FA5) in the response body.

    :param code: numeric stock code string, e.g. '600000'.
    :param timeout: seconds to wait for the HTTP request (default 10).
    :return: the stock name.
    :raises requests.RequestException: on network failure, timeout or an
        HTTP error status.
    :raises IndexError: if the response contains no CJK text (e.g. unknown code).
    """
    if code.startswith('6'):
        code = 'sh' + code
    else:
        code = 'sz' + code
    url = "https://hq.sinajs.cn/list={}".format(code)
    # hq.sinajs.cn rejects requests lacking a finance.sina.com.cn Referer.
    # The timeout stops a stalled connection from blocking the caller forever.
    resp = requests.get(url, headers={'Referer': 'https://finance.sina.com.cn'},
                        timeout=timeout)
    resp.raise_for_status()
    content = resp.text
    stockName = re.findall('[\u4e00-\u9fa5]+', content)[0]
    return stockName