A stock data (Shanghai/Shenzhen) crawler and stock-selection strategy testing framework written in Python

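A crawler for Shanghai and Shenzhen stock data plus a framework for testing stock-selection strategies. The data comes from Yahoo YQL and Sina Finance.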
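The two data sources name stocks differently. The Sina list returns codes such as sh600000, while Yahoo expects 600000.SS for Shanghai listings and 000001.SZ for Shenzhen listings, so load_all_quote_symbol converts between them. Here is a minimal standalone sketch of that mapping. The function name sina_to_yahoo is hypothetical, and it checks the prefix with startswith, which is stricter than the original's find.

```python
def sina_to_yahoo(code):
    """Convert a Sina-style code such as 'sh600000' to the Yahoo form '600000.SS'."""
    # Shanghai listings get the '.SS' suffix, Shenzhen listings '.SZ'
    if code.startswith("sh"):
        return code[2:] + ".SS"
    if code.startswith("sz"):
        return code[2:] + ".SZ"
    # Anything else is returned unchanged, like the original fall-through
    return code


print(sina_to_yahoo("sh600000"))  # 600000.SS
print(sina_to_yahoo("sz000001"))  # 000001.SZ
```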

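  • Fetch market data for all stocks on the Shanghai and Shenzhen exchanges over a chosen date range.
  • Run stock-selection tests with the specified strategies on the specified dates.
  • Compute the actual results of the selection tests (including a comparison with the CSI 300 index).
  • Save data to JSON and CSV files.
  • Define stock-selection strategies with expressions.
  • Multi-threaded processing.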
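The comparison with the CSI 300 index works on closing prices. For each of the 1 to 10 days after the target date, profit_test computes three things: the stock's relative change, the index's relative change over the same horizon, and their difference. Below is a minimal sketch over plain lists of closes. excess_return is a hypothetical helper, and rounding the difference to 5 places is a choice made here. Unlike profit_test, it also stops as soon as either series runs out of prices.

```python
def profit_rate(price1, price2):
    """Relative change from price1 to price2, rounded to 5 places (as get_profit_rate does)."""
    if price1 == 0:
        return None
    return round((price2 - price1) / price1, 5)


def excess_return(stock_closes, index_closes, target_idx, index_idx, days=10):
    """Map each holding period i (1..days) to (stock profit, index change, difference)."""
    result = {}
    for i in range(1, days + 1):
        # Stop as soon as either series has no price i days after the target date
        if target_idx + i >= len(stock_closes) or index_idx + i >= len(index_closes):
            break
        stock = profit_rate(stock_closes[target_idx], stock_closes[target_idx + i])
        index = profit_rate(index_closes[index_idx], index_closes[index_idx + i])
        diff = None if stock is None or index is None else round(stock - index, 5)
        result[i] = (stock, index, diff)
    return result


print(excess_return([10.0, 10.5, 11.0], [100.0, 101.0, 100.0], 0, 0))
# {1: (0.05, 0.01, 0.04), 2: (0.1, 0.0, 0.1)}
```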

Code

Project structure

main.py

from stockholm import Stockholm
import option
import os

def checkFoldPermission(path):
    ## expand the placeholder default to <home>/tmp/stockholm_export
    if(path == 'USER_HOME/tmp/stockholm_export'):
        path = os.path.expanduser('~') + '/tmp/stockholm_export'
    try:
        if not os.path.exists(path):
            os.makedirs(path)
        else:
            ## write and delete a probe file to verify write permission
            probe_file = os.path.join(path, "test.txt")
            with open(probe_file, "w") as txt:
                txt.write("test")
            os.remove(probe_file)

    except Exception as e:
        print(e)
        return False
    return True

def main():
    args = option.parser.parse_args()
    if not checkFoldPermission(args.store_path):
        print('\nPermission denied: %s' % args.store_path)
        print('Please make sure you have the permission to save the data!\n')
    else:
        print('Stockholm is starting...\n')
        stockh = Stockholm(args)
        stockh.run()
        print('Stockholm is done...\n')

if __name__ == '__main__':
    main()
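checkFoldPermission has two weaknesses. Its probe file is called test.txt, so it would overwrite and then delete a real file with that name. It also only probes folders that already existed. Below is a minimal alternative sketch that uses the standard tempfile module. resolve_store_path and is_writable_dir are hypothetical names, and the placeholder string is the default from option.py.

```python
import os
import tempfile

DEFAULT_PLACEHOLDER = "USER_HOME/tmp/stockholm_export"


def resolve_store_path(path):
    # Expand the placeholder default into <home>/tmp/stockholm_export
    if path == DEFAULT_PLACEHOLDER:
        return os.path.join(os.path.expanduser("~"), "tmp", "stockholm_export")
    return path


def is_writable_dir(path):
    """Create the directory if needed, then prove it is writable with a temporary file."""
    try:
        os.makedirs(path, exist_ok=True)
        # TemporaryFile picks a unique name, so no existing file is touched,
        # and the file is removed automatically when it is closed
        with tempfile.TemporaryFile(dir=path):
            pass
    except OSError as e:
        print(e)
        return False
    return True
```

main() could then call is_writable_dir(resolve_store_path(args.store_path)) in place of checkFoldPermission.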

option.py

import argparse 
import datetime

def get_date_str(offset):
    if(offset is None):
        offset = 0
    date_str = (datetime.datetime.today() + datetime.timedelta(days=offset)).strftime("%Y-%m-%d")
    return date_str

_default = dict(
    reload_data = 'Y',
    gen_portfolio = 'N',
    output_type = 'json',
    charset = 'utf-8',
    test_date_range = 60,
    start_date = get_date_str(-90),
    end_date = get_date_str(None),
    target_date = get_date_str(None),
    store_path = 'USER_HOME/tmp/stockholm_export',
    thread = 10,
    testfile_path = './portfolio_test.txt',
    db_name = 'stockholm',
    methods = ''
    )

parser = argparse.ArgumentParser(description='A stock crawler and portfolio testing framework.') 

parser.add_argument('--reload', type=str, default=_default['reload_data'], dest='reload_data', help='Reload the stock data or not (Y/N), Default: %s' % _default['reload_data'])

parser.add_argument('--portfolio', type=str, default=_default['gen_portfolio'], dest='gen_portfolio', help='Generate the portfolio or not (Y/N), Default: %s' % _default['gen_portfolio'])

parser.add_argument('--output', type=str, default=_default['output_type'], dest='output_type', help='Data output type (json/csv/all), Default: %s' % _default['output_type'])

parser.add_argument('--charset', type=str, default=_default['charset'], dest='charset', help='Data output charset (utf-8/gbk), Default: %s' % _default['charset'])

parser.add_argument('--testrange', type=int, default=_default['test_date_range'], dest='test_date_range', help='Test date range (days), Default: %s' % _default['test_date_range'])

parser.add_argument('--startdate', type=str, default=_default['start_date'], dest='start_date', help='Data loading start date, Default: %s' % _default['start_date'])

parser.add_argument('--enddate', type=str, default=_default['end_date'], dest='end_date', help='Data loading end date, Default: %s' % _default['end_date'])

parser.add_argument('--targetdate', type=str, default=_default['target_date'], dest='target_date', help='Portfolio generating target date, Default: %s' % _default['target_date'])

parser.add_argument('--storepath', type=str, default=_default['store_path'], dest='store_path', help='Data file store path, Default: %s' % _default['store_path'])

parser.add_argument('--thread', type=int, default=_default['thread'], dest='thread', help='Thread number, Default: %s' % _default['thread'])

parser.add_argument('--testfile', type=str, default=_default['testfile_path'], dest='testfile_path', help='Portfolio test file path, Default: %s' % _default['testfile_path'])

parser.add_argument('--dbname', type=str, default=_default['db_name'], dest='db_name', help='MongoDB DB name, Default: %s' % _default['db_name'])

parser.add_argument('--methods', type=str, default=_default['methods'], dest='methods', help='Target methods for back testing, Default: %s' % _default['methods'])

def main():
    args = parser.parse_args()
    print(args)

if __name__ == '__main__':
    main()
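The usage text advertises fixed value sets such as {Y,N} and {json,csv,all}, but option.py accepts any string. A typo like --output cvs therefore silently produces no output file. Below is a minimal sketch of how argparse's choices parameter could reject such values up front. Only a few of the options are shown, and parse_args is given an explicit list so the example runs on its own.

```python
import argparse

parser = argparse.ArgumentParser(description='A stock crawler and portfolio testing framework.')
# choices= makes argparse reject values outside the documented sets (exit status 2)
parser.add_argument('--reload', dest='reload_data', choices=['Y', 'N'], default='Y')
parser.add_argument('--portfolio', dest='gen_portfolio', choices=['Y', 'N'], default='N')
parser.add_argument('--output', dest='output_type', choices=['json', 'csv', 'all'], default='json')
parser.add_argument('--charset', dest='charset', choices=['utf-8', 'gbk'], default='utf-8')
parser.add_argument('--thread', dest='thread', type=int, default=10)

args = parser.parse_args(['--output', 'csv', '--thread', '4'])
print(args.output_type, args.thread, args.reload_data)  # csv 4 Y
```

Calling parser.parse_args(['--output', 'cvs']) instead prints an "invalid choice" error and raises SystemExit.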

stockholm.py

#coding:utf-8
import requests
import json
import datetime
import timeit
import time
import io
import os
import csv
import re
from pymongo import MongoClient
from multiprocessing.dummy import Pool as ThreadPool
from functools import partial

class Stockholm(object):

    def __init__(self, args):
        ## whether to reload all stock data (Y/N)
        self.reload_data = args.reload_data
        ## whether to generate the portfolio (Y/N)
        self.gen_portfolio = args.gen_portfolio
        ## type of output file json/csv or both
        self.output_type = args.output_type
        ## charset of output file utf-8/gbk
        self.charset = args.charset
        ## portfolio testing date range(# of days)
        self.test_date_range = args.test_date_range
        ## stock data loading start date(e.g. 2014-09-14)
        self.start_date = args.start_date
        ## stock data loading end date
        self.end_date = args.end_date
        ## portfolio generating target date
        self.target_date = args.target_date
        ## thread number
        self.thread = args.thread
        ## data file store path
        if(args.store_path == 'USER_HOME/tmp/stockholm_export'):
            self.export_folder = os.path.expanduser('~') + '/tmp/stockholm_export'
        else:
            self.export_folder = args.store_path
        ## portfolio testing file path
        self.testfile_path = args.testfile_path
        ## methods for back testing
        self.methods = args.methods

        ## for getting quote symbols
        self.all_quotes_url = 'http://money.finance.sina.com.cn/d/api/openapi_proxy.php'
        ## for loading quote data
        self.yql_url = 'http://query.yahooapis.com/v1/public/yql'
        ## export file name
        self.export_file_name = 'stockholm_export'

        self.index_array = ['000001.SS', '399001.SZ', '000300.SS']
        self.sh000001 = {'Symbol': '000001.SS', 'Name': '上證指數'} ## SSE Composite Index
        self.sz399001 = {'Symbol': '399001.SZ', 'Name': '深證成指'} ## SZSE Component Index
        self.sh000300 = {'Symbol': '000300.SS', 'Name': '滬深300'} ## CSI 300 Index
        ## self.sz399005 = {'Symbol': '399005.SZ', 'Name': '中小板指'} ## SME Board Index
        ## self.sz399006 = {'Symbol': '399006.SZ', 'Name': '創業板指'} ## ChiNext Index

        ## mongodb info
        self.mongo_url = 'localhost'
        self.mongo_port = 27017
        self.database_name = args.db_name
        self.collection_name = 'testing_method'

    def get_columns(self, quote):
        columns = []
        if(quote is not None):
            for key in quote.keys():
                if(key == 'Data'):
                    for data_key in quote['Data'][-1]:
                        columns.append("data." + data_key)
                else:
                    columns.append(key)
            columns.sort()
        return columns

    def get_profit_rate(self, price1, price2):
        if(price1 == 0):
            return None
        else:
            return round((price2-price1)/price1, 5)

    def get_MA(self, number_array):
        total = 0
        n = 0
        for num in number_array:
            if num is not None and num != 0:
                n += 1
                total += num
        ## avoid ZeroDivisionError when no valid price is available
        if n == 0:
            return None
        return round(total/n, 3)

    def convert_value_check(self, exp):
        val = exp.replace('day', 'quote[\'Data\']').replace('(0)', '(-0)')
        val = re.sub(r'\(((-)?\d+)\)', r'[target_idx\g<1>]', val)
        val = re.sub(r'\.\{((-)?\w+)\}', r"['\g<1>']", val)
        return val

    def convert_null_check(self, exp):
        p = re.compile(r'\((-)?\d+...\w+\}')
        iterator = p.finditer(exp.replace('(0)', '(-0)'))
        array = []
        for match in iterator:
            v = 'quote[\'Data\']' + match.group()
            v = re.sub(r'\(((-)?\d+)\)', r'[target_idx\g<1>]', v)
            v = re.sub(r'\.\{((-)?\w+)\}', r"['\g<1>']", v)
            v += ' is not None'
            array.append(v)
        val = ' and '.join(array)
        return val

    class KDJ():
        def _avg(self, array):
            length = len(array)
            return sum(array)/length

        def _getMA(self, values, window):
            array = []
            x = window
            while x <= len(values):
                curmb = 50
                if(x-window == 0):
                    curmb = self._avg(values[x-window:x])
                else:
                    curmb = (array[-1]*2+values[x-1])/3
                array.append(round(curmb,3))
                x += 1
            return array

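        def _getRSV(self, arrays):
            rsv = []
            x = 9
            while x <= len(arrays):
                high = max(map(lambda q: q['High'], arrays[x-9:x]))
                low = min(map(lambda q: q['Low'], arrays[x-9:x]))
                close = arrays[x-1]['Close']
                ## assumption: a flat 9-day window (high == low) counts as neutral (RSV = 50)
                ## instead of raising ZeroDivisionError, which data_process does not catch
                if(high == low):
                    rsv.append(50)
                else:
                    rsv.append((close-low)/(high-low)*100)
                x += 1
            return rsv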

        def getKDJ(self, quote_data):
            if(len(quote_data) > 12):
                rsv = self._getRSV(quote_data)
                k = self._getMA(rsv,3)
                d = self._getMA(k,3)
                j = list(map(lambda x: round(3*x[0]-2*x[1],3), zip(k[2:], d)))

                for idx, data in enumerate(quote_data[0:12]):
                    data['KDJ_K'] = None
                    data['KDJ_D'] = None
                    data['KDJ_J'] = None
                for idx, data in enumerate(quote_data[12:]):
                    data['KDJ_K'] = k[2:][idx]
                    data['KDJ_D'] = d[idx]
                    if(j[idx] > 100):
                        data['KDJ_J'] = 100
                    elif(j[idx] < 0):
                        data['KDJ_J'] = 0
                    else:
                        data['KDJ_J'] = j[idx]

            return quote_data

    def load_all_quote_symbol(self):
        print("load_all_quote_symbol start..." + "\n")

        start = timeit.default_timer()

        all_quotes = []

        all_quotes.append(self.sh000001)
        all_quotes.append(self.sz399001)
        all_quotes.append(self.sh000300)
        ## all_quotes.append(self.sz399005)
        ## all_quotes.append(self.sz399006)

        try:
            count = 1
            while (count < 100):
                para_val = '[["hq","hs_a","",0,' + str(count) + ',500]]'
                r_params = {'__s': para_val}
                r = requests.get(self.all_quotes_url, params=r_params)
                if(len(r.json()[0]['items']) == 0):
                    break
                for item in r.json()[0]['items']:
                    quote = {}
                    code = item[0]
                    name = item[2]
                    ## convert quote code
                    if(code.find('sh') > -1):
                        code = code[2:] + '.SS'
                    elif(code.find('sz') > -1):
                        code = code[2:] + '.SZ'
                    ## convert quote code end
                    quote['Symbol'] = code
                    quote['Name'] = name
                    all_quotes.append(quote)
                count += 1
        except Exception as e:
            print("Error: Failed to load all stock symbol..." + "\n")
            print(e)

        print("load_all_quote_symbol end... time cost: " + str(round(timeit.default_timer() - start)) + "s" + "\n")
        return all_quotes

    def load_quote_info(self, quote, is_retry):
        print("load_quote_info start..." + "\n")

        start = timeit.default_timer()

        if(quote is not None and quote['Symbol'] is not None):
            yquery = 'select * from yahoo.finance.quotes where symbol = "' + quote['Symbol'].lower() + '"'
            r_params = {'q': yquery, 'format': 'json', 'env': 'http://datatables.org/alltables.env'}
            r = requests.get(self.yql_url, params=r_params)
            ## print(r.url)
            ## print(r.text)
            rjson = r.json()
            try:
                quote_info = rjson['query']['results']['quote']
                quote['LastTradeDate'] = quote_info['LastTradeDate']
                quote['LastTradePrice'] = quote_info['LastTradePriceOnly']
                quote['PreviousClose'] = quote_info['PreviousClose']
                quote['Open'] = quote_info['Open']
                quote['DaysLow'] = quote_info['DaysLow']
                quote['DaysHigh'] = quote_info['DaysHigh']
                quote['Change'] = quote_info['Change']
                quote['ChangeinPercent'] = quote_info['ChangeinPercent']
                quote['Volume'] = quote_info['Volume']
                quote['MarketCap'] = quote_info['MarketCapitalization']
                quote['StockExchange'] = quote_info['StockExchange']

            except Exception as e:
                print("Error: Failed to load stock info... " + quote['Symbol'] + "/" + quote['Name'] + "\n")
                print(e)
                if(not is_retry):
                    time.sleep(1)
                    self.load_quote_info(quote, True) ## retry once for network issue

        ## print(quote)
        print("load_quote_info end... time cost: " + str(round(timeit.default_timer() - start)) + "s" + "\n")
        return quote

    def load_all_quote_info(self, all_quotes):
        print("load_all_quote_info start...")

        start = timeit.default_timer()
        for idx, quote in enumerate(all_quotes):
            print("#" + str(idx + 1))
            self.load_quote_info(quote, False)

        print("load_all_quote_info end... time cost: " + str(round(timeit.default_timer() - start)) + "s")
        return all_quotes

    def load_quote_data(self, quote, start_date, end_date, is_retry, counter):
        ## print("load_quote_data start..." + "\n")

        start = timeit.default_timer()

        if(quote is not None and quote['Symbol'] is not None):        
            yquery = 'select * from yahoo.finance.historicaldata where symbol = "' + quote['Symbol'].upper() + '" and startDate = "' + start_date + '" and endDate = "' + end_date + '"'
            r_params = {'q': yquery, 'format': 'json', 'env': 'http://datatables.org/alltables.env'}
            try:
                r = requests.get(self.yql_url, params=r_params)
                ## print(r.url)
                ## print(r.text)
                rjson = r.json()
                quote_data = rjson['query']['results']['quote']
                quote_data.reverse()
                quote['Data'] = quote_data
                if(not is_retry):
                    counter.append(1)          

            except Exception:
                print("Error: Failed to load stock data... " + quote['Symbol'] + "/" + quote['Name'] + "\n")
                if(not is_retry):
                    time.sleep(2)
                    self.load_quote_data(quote, start_date, end_date, True, counter) ## retry once for network issue

            print("load_quote_data " + quote['Symbol'] + "/" + quote['Name'] + " end..." + "\n")
            ## print("time cost: " + str(round(timeit.default_timer() - start)) + "s." + "\n")
            ## print("total count: " + str(len(counter)) + "\n")
        return quote

    def load_all_quote_data(self, all_quotes, start_date, end_date):
        print("load_all_quote_data start..." + "\n")

        start = timeit.default_timer()

        counter = []
        mapfunc = partial(self.load_quote_data, start_date=start_date, end_date=end_date, is_retry=False, counter=counter)
        pool = ThreadPool(self.thread)
        pool.map(mapfunc, all_quotes) ## multi-threads executing
        pool.close() 
        pool.join()

        print("load_all_quote_data end... time cost: " + str(round(timeit.default_timer() - start)) + "s" + "\n")
        return all_quotes

    def data_process(self, all_quotes):
        print("data_process start..." + "\n")

        kdj = self.KDJ()
        start = timeit.default_timer()

        for quote in all_quotes:

            ## board type: 創業板 = ChiNext, 中小板 = SME board, 主板 = main board
            if(quote['Symbol'].startswith('300')):
                quote['Type'] = '創業板'
            elif(quote['Symbol'].startswith('002')):
                quote['Type'] = '中小板'
            else:
                quote['Type'] = '主板'

            if('Data' in quote):
                try:
                    temp_data = []
                    for quote_data in quote['Data']:
                        if(quote_data['Volume'] != '000' or quote_data['Symbol'] in self.index_array):
                            d = {}
                            d['Open'] = float(quote_data['Open'])
                            ## d['Adj_Close'] = float(quote_data['Adj_Close'])
                            d['Close'] = float(quote_data['Close'])
                            d['High'] = float(quote_data['High'])
                            d['Low'] = float(quote_data['Low'])
                            d['Volume'] = int(quote_data['Volume'])
                            d['Date'] = quote_data['Date']
                            temp_data.append(d)
                    quote['Data'] = temp_data
                except KeyError as e:
                    print("Data Process: Key Error")
                    print(e)
                    print(quote)

        ## calculate Change / 5 10 20 30 Day MA
        for quote in all_quotes:
            if('Data' in quote):
                try:
                    for i, quote_data in enumerate(quote['Data']):
                        if(i > 0):
                            quote_data['Change'] = self.get_profit_rate(quote['Data'][i-1]['Close'], quote_data['Close'])
                            quote_data['Vol_Change'] = self.get_profit_rate(quote['Data'][i-1]['Volume'], quote_data['Volume'])                        
                        else:
                            quote_data['Change'] = None
                            quote_data['Vol_Change'] = None

                    last_5_array = []
                    last_10_array = []
                    last_20_array = []
                    last_30_array = []
                    for i, quote_data in enumerate(quote['Data']):
                        last_5_array.append(quote_data['Close'])
                        last_10_array.append(quote_data['Close'])
                        last_20_array.append(quote_data['Close'])
                        last_30_array.append(quote_data['Close'])
                        quote_data['MA_5'] = None
                        quote_data['MA_10'] = None
                        quote_data['MA_20'] = None
                        quote_data['MA_30'] = None

                        if(i < 4):
                            continue
                        ## keep exactly the last N closes: pop only once the window overflows
                        if(len(last_5_array) > 5):
                            last_5_array.pop(0)
                        quote_data['MA_5'] = self.get_MA(last_5_array)

                        if(i < 9):
                            continue
                        if(len(last_10_array) > 10):
                            last_10_array.pop(0)
                        quote_data['MA_10'] = self.get_MA(last_10_array)

                        if(i < 19):
                            continue
                        if(len(last_20_array) > 20):
                            last_20_array.pop(0)
                        quote_data['MA_20'] = self.get_MA(last_20_array)

                        if(i < 29):
                            continue
                        if(len(last_30_array) > 30):
                            last_30_array.pop(0)
                        quote_data['MA_30'] = self.get_MA(last_30_array)


                except KeyError as e:
                    print("Key Error")
                    print(e)
                    print(quote)

        ## calculate KDJ
        for quote in all_quotes:
            if('Data' in quote):
                try:
                    kdj.getKDJ(quote['Data'])
                except KeyError as e:
                    print("Key Error")
                    print(e)
                    print(quote)

        print("data_process end... time cost: " + str(round(timeit.default_timer() - start)) + "s" + "\n")

    def data_export(self, all_quotes, export_type_array, file_name):

        start = timeit.default_timer()
        directory = self.export_folder
        if(file_name is None):
            file_name = self.export_file_name
        if not os.path.exists(directory):
            os.makedirs(directory)

        if(all_quotes is None or len(all_quotes) == 0):
            print("no data to export...\n")
            return

        if('json' in export_type_array):
            print("start export to JSON file...\n")
            with io.open(directory + '/' + file_name + '.json', 'w', encoding=self.charset) as f:
                json.dump(all_quotes, f, ensure_ascii=False)

        if('csv' in export_type_array):
            print("start export to CSV file...\n")
            columns = []
            if(all_quotes is not None and len(all_quotes) > 0):
                columns = self.get_columns(all_quotes[0])
            ## newline='' lets the csv module control line endings (no blank rows on Windows)
            with open(directory + '/' + file_name + '.csv', 'w', encoding=self.charset, newline='') as csv_file:
                writer = csv.writer(csv_file)
                writer.writerow(columns)

                for quote in all_quotes:
                    if('Data' in quote):
                        for quote_data in quote['Data']:
                            try:
                                line = []
                                for column in columns:
                                    if(column.find('data.') > -1):
                                        ## an empty cell for a missing field keeps the row aligned with the header
                                        line.append(quote_data.get(column[5:], ''))
                                    else:
                                        line.append(quote.get(column, ''))
                                writer.writerow(line)
                            except Exception as e:
                                print(e)
                                print("write csv error: " + str(quote))

        if('mongo' in export_type_array):
            ## placeholder: exporting to MongoDB is not implemented
            print("start export to MongoDB...\n")

        print("export is complete... time cost: " + str(round(timeit.default_timer() - start)) + "s" + "\n")

    def file_data_load(self):
        print("file_data_load start..." + "\n")

        start = timeit.default_timer()
        directory = self.export_folder
        file_name = self.export_file_name

        all_quotes_data = []
        ## read with the same charset that data_export used to write the file
        with io.open(directory + '/' + file_name + '.json', 'r', encoding=self.charset) as f:
            all_quotes_data = json.load(f)

        print("file_data_load end... time cost: " + str(round(timeit.default_timer() - start)) + "s" + "\n")
        return all_quotes_data

    def check_date(self, all_quotes, date):

        is_date_valid = False
        for quote in all_quotes:
            if(quote['Symbol'] in self.index_array):
                for quote_data in quote['Data']:
                    if(quote_data['Date'] == date):
                        is_date_valid = True
        if not is_date_valid:
            print(date + " is not valid...\n")
        return is_date_valid

    def quote_pick(self, all_quotes, target_date, methods):
        print("quote_pick start..." + "\n")

        start = timeit.default_timer()

        results = []
        data_issue_count = 0

        for quote in all_quotes:
            try:
                if(quote['Symbol'] in self.index_array):
                    results.append(quote)
                    continue

                target_idx = None
                for idx, quote_data in enumerate(quote['Data']):
                    if(quote_data['Date'] == target_date):
                        target_idx = idx
                if(target_idx is None):
                    ## print(quote['Name'] + " data is not available at this date..." + "\n")
                    data_issue_count+=1
                    continue

                ## pick logic ##
                valid = False
                for method in methods:
                    ## print(method['name'])
                    ## null_check = eval(method['null_check'])
                    try:
                        value_check = eval(method['value_check'])
                        if(value_check):
                            quote['Method'] = method['name']
                            results.append(quote)
                            valid = True
                            break
                    except Exception:
                        ## e.g. comparisons with None values: treat as no match
                        valid = False
                if(valid):
                    continue

                ## pick logic end ##

            except KeyError as e:
                ## print("KeyError: " + quote['Name'] + " data is not available..." + "\n")
                data_issue_count+=1

        print("quote_pick end... time cost: " + str(round(timeit.default_timer() - start)) + "s" + "\n")
        print(str(data_issue_count) + " quotes have no data available...\n")
        return results

    def profit_test(self, selected_quotes, target_date):
        print("profit_test start..." + "\n")

        start = timeit.default_timer()

        results = []
        INDEX = None
        INDEX_idx = 0

        for quote in selected_quotes:
            if(quote['Symbol'] == self.sh000300['Symbol']):
                INDEX = quote
                for idx, quote_data in enumerate(quote['Data']):
                    if(quote_data['Date'] == target_date):
                        INDEX_idx = idx
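                break
        ## the CSI 300 index is the benchmark for every comparison below
        if(INDEX is None):
            print("CSI 300 index data is not available, testing is aborted...\n")
            return results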

        for quote in selected_quotes:
            target_idx = None

            if(quote['Symbol'] in self.index_array):
                continue

            for idx, quote_data in enumerate(quote['Data']):
                if(quote_data['Date'] == target_date):
                    target_idx = idx
            if(target_idx is None):
                print(quote['Name'] + " data is not available for testing..." + "\n")
                continue

            test = {}
            test['Name'] = quote['Name']
            test['Symbol'] = quote['Symbol']
            test['Method'] = quote['Method']
            test['Type'] = quote['Type']
            if('KDJ_K' in quote['Data'][target_idx]):
                test['KDJ_K'] = quote['Data'][target_idx]['KDJ_K']
                test['KDJ_D'] = quote['Data'][target_idx]['KDJ_D']
                test['KDJ_J'] = quote['Data'][target_idx]['KDJ_J']
            test['Close'] = quote['Data'][target_idx]['Close']
            test['Change'] = quote['Data'][target_idx]['Change']
            test['Vol_Change'] = quote['Data'][target_idx]['Vol_Change']
            test['MA_5'] = quote['Data'][target_idx]['MA_5']
            test['MA_10'] = quote['Data'][target_idx]['MA_10']
            test['MA_20'] = quote['Data'][target_idx]['MA_20']
            test['MA_30'] = quote['Data'][target_idx]['MA_30']
            test['Data'] = [{}]

            for i in range(1,11):
                if(target_idx+i >= len(quote['Data'])):
                    print(quote['Name'] + " data is not available for " + str(i) + " day testing..." + "\n")
                    break

                day2day_profit = self.get_profit_rate(quote['Data'][target_idx]['Close'], quote['Data'][target_idx+i]['Close'])
                test['Data'][0]['Day_' + str(i) + '_Profit'] = day2day_profit
                if(INDEX_idx+i < len(INDEX['Data'])):
                    day2day_INDEX_change = self.get_profit_rate(INDEX['Data'][INDEX_idx]['Close'], INDEX['Data'][INDEX_idx+i]['Close'])
                    test['Data'][0]['Day_' + str(i) + '_INDEX_Change'] = day2day_INDEX_change
                    test['Data'][0]['Day_' + str(i) + '_Differ'] = day2day_profit-day2day_INDEX_change

            results.append(test)

        print("profit_test end... time cost: " + str(round(timeit.default_timer() - start)) + "s" + "\n")
        return results

    def data_load(self, start_date, end_date, output_types):
        all_quotes = self.load_all_quote_symbol()
        print("total " + str(len(all_quotes)) + " quotes are loaded..." + "\n")
        ## self.load_all_quote_info(all_quotes)
        self.load_all_quote_data(all_quotes, start_date, end_date)
        self.data_process(all_quotes)

        self.data_export(all_quotes, output_types, None)

    def data_test(self, target_date, test_range, output_types):
        ## loading test methods
        methods = []
        path = self.testfile_path

        ## from mongodb
        if(path == 'mongodb'):
            print("Load testing methods from Mongodb...\n")
            client = MongoClient(self.mongo_url, self.mongo_port)
            db = client[self.database_name]
            col = db[self.collection_name]
            q = None
            if(len(self.methods) > 0):
                applied_methods = list(map(int, self.methods.split(',')))
                q = {"method_id": {"$in": applied_methods}}
            for doc in col.find(q, ['name','desc','method']):
                print(doc)
                m = {'name': doc['name'], 'value_check': self.convert_value_check(doc['method'])}
                methods.append(m)

        ## from test file
        else:
            if not os.path.exists(path):
                print("Portfolio test file does not exist, testing is aborted...\n")
                return
            with io.open(path, 'r', encoding='utf-8') as f:
                for line in f:
                    if(line.startswith('##') or len(line.strip()) == 0):
                        continue
                    line = line.strip()
                    name = line[line.find('[')+1:line.find(']:')]
                    value = line[line.find(']:')+2:]
                    m = {'name': name, 'value_check': self.convert_value_check(value)}
                    methods.append(m)

        if(len(methods) == 0):
            print("No method is loaded, testing is aborted...\n")
            return

        ## portfolio testing 
        all_quotes = self.file_data_load()
        target_date_time = datetime.datetime.strptime(target_date, "%Y-%m-%d")
        for i in range(test_range):
            date = (target_date_time - datetime.timedelta(days=i)).strftime("%Y-%m-%d")
            is_date_valid = self.check_date(all_quotes, date)
            if is_date_valid:
                selected_quotes = self.quote_pick(all_quotes, date, methods)
                res = self.profit_test(selected_quotes, date)
                self.data_export(res, output_types, 'result_' + date)

    def run(self):
        ## output types
        output_types = []
        if(self.output_type == "json"):
            output_types.append("json")
        elif(self.output_type == "csv"):
            output_types.append("csv")
        elif(self.output_type == "all"):
            output_types = ["json", "csv"]

        ## loading stock data
        if(self.reload_data == 'Y'):
            print("Start loading stock data...\n")
            self.data_load(self.start_date, self.end_date, output_types)

        ## test & generate portfolio
        if(self.gen_portfolio == 'Y'):
            print("Start portfolio testing...\n")
            self.data_test(self.target_date, self.test_date_range, output_types)


mongo_scripts.txt

use stockholm

db.counters.insert(
   {
      _id: "method_id",
      seq: 0
   }
)

function getNextSequence(name) {
   var ret = db.counters.findAndModify(
          {
            query: { _id: name },
            update: { $inc: { seq: 1 } },
            new: true
          }
   );

   return ret.seq;
}

db.testing_method.insert({"method_id": getNextSequence("method_id"), "name":"測試方法1", "desc":"這是一個測試方法。", "user_name":"Stockholm", "user_id":"[email protected]", "creation_date": new Date(), "modification_date": new Date(), "method":"day(-2).{KDJ_J}<20 and day(-1).{KDJ_J}<20 and day(0).{KDJ_J}-day(-1).{KDJ_J}>=40 and day(0).{Vol_Change}>=1 and day(0).{MA_10}*1.05>day(0).{Close}"})

db.testing_method.insert({"method_id": getNextSequence("method_id"), "name":"測試方法2", "desc":"這是一個測試方法。", "user_name":"Stockholm", "user_id":"[email protected]", "creation_date": new Date(), "modification_date": new Date(), "method":"day(-2).{KDJ_J}-day(-1).{KDJ_J}>20 and day(0).{KDJ_J}-day(-1).{KDJ_J}>20 and day(-1).{KDJ_J}<50 and day(0).{Vol_Change}<=1"})
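When --testfile is set to mongodb, data_test reads these documents from the testing_method collection. It can optionally filter them with the comma-separated --methods argument. Below is a minimal sketch of just the filter construction, which needs no running MongoDB. build_method_query is a hypothetical helper that also tolerates spaces and empty entries.

```python
def build_method_query(methods_arg):
    """Turn a '--methods 1,3' style string into a MongoDB filter on method_id."""
    ids = [int(part) for part in methods_arg.split(",") if part.strip()]
    # None as the filter makes pymongo's find() return every document
    return {"method_id": {"$in": ids}} if ids else None


print(build_method_query("1, 3"))  # {'method_id': {'$in': [1, 3]}}
print(build_method_query(""))      # None
```

The result can be passed straight to col.find(query, ['name', 'desc', 'method']), as data_test does.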

portfolio_test.txt

## Portfolio selection methodology sample file

[Test method 1]:day(-2).{KDJ_J}<20 and day(-1).{KDJ_J}<20 and day(0).{KDJ_J}-day(-1).{KDJ_J}>=40 and day(0).{Vol_Change}>=1 and day(0).{MA_10}*1.05>day(0).{Close}

[Test method 2]:day(-2).{KDJ_J}-day(-1).{KDJ_J}>20 and day(0).{KDJ_J}-day(-1).{KDJ_J}>20 and day(-1).{KDJ_J}<50 and day(0).{Vol_Change}<=1

##[Test method 3]:50<day(-1).{KDJ_J}<80 and day(-2).{KDJ_J}<day(-1).{KDJ_J} and day(0).{KDJ_J}<day(-1).{KDJ_J}
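The sample file holds one strategy per line in the form `[name]:expression`, and lines starting with `##` are comments (the third method is disabled that way). A minimal parser sketch, assuming exactly that format; `parse_methods` is a hypothetical helper, not the project's own loader:

```python
import re

# One strategy per line: [name]:expression
METHOD_LINE = re.compile(r"^\[(?P<name>[^\]]+)\]:(?P<expr>.+)$")


def parse_methods(text):
    """Return (name, expression) pairs from portfolio_test.txt-style content."""
    methods = []
    for raw in text.splitlines():
        line = raw.strip()
        # Skip blank lines and '##' comments, including commented-out methods
        if not line or line.startswith("##"):
            continue
        match = METHOD_LINE.match(line)
        if match:
            methods.append((match.group("name").strip(), match.group("expr").strip()))
    return methods
```

For example, `parse_methods(open("portfolio_test.txt", encoding="utf-8").read())` would return the two active methods and ignore the commented third one.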

Runtime arguments

--storepath c://test --output csv   --startdate 2015-09-01 --enddate 2015-12-07 --charset utf-8 --testfile ./portfolio_test.txt --reload Y --portfolio Y --thread 10

Sample run

What it can do

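If you want to work with Shanghai and Shenzhen market data, it can export the quote data and several technical indicators for all Shanghai and Shenzhen A-shares within a given date range, including symbol, name, open, close, high, low, volume, moving averages, KDJ and more.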

Known issues

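Quote data currently comes from Yahoo YQL, and the time at which each day's data is updated is not very consistent (usually around midnight China time).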

Environment

Python 3.4 or later

pip install requests
pip install pymongo

Usage

python main.py [-h] [--reload {Y,N}] [--portfolio {Y,N}]
               [--output {json,csv,all}] [--charset {utf-8,gbk}]
               [--storepath PATH] [--thread NUM]
               [--startdate yyyy-MM-dd] [--enddate yyyy-MM-dd]
               [--targetdate yyyy-MM-dd] [--testrange NUM] [--testfile PATH]

Optional arguments

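  -h, --help                  Show the help message and exit
  --reload {Y,N}              Whether to re-fetch stock data. Default: Y
  --portfolio {Y,N}           Whether to generate portfolio test results. Default: N
  --output {json,csv,all}     Output file format. Default: json
  --charset {utf-8,gbk}       Output file encoding. Default: utf-8
  --storepath PATH            Output directory. Default: ~/tmp/stockholm_export
  --thread NUM                Number of threads. Default: 10
  --startdate yyyy-MM-dd      Start date for fetching data. Default: current system date minus 90 days (e.g. 2015-01-01)
  --enddate yyyy-MM-dd        End date for fetching data. Default: current system date
  --targetdate yyyy-MM-dd     Target date for testing the stock-picking strategy. Default: current system date
  --testrange NUM             Number of days in the test date range. Default: 60
  --testfile PATH             Path to the strategy test file. Default: ./portfolio_test.txt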

Available data and formats

Quote data:

[
    {"Symbol": "600000.SS",
     "Name": "浦發銀行",
     "Data": [
                 {"Vol_Change": null, "MA_10": null, "Date": "2015-03-26", "High": 15.58, "Open": 15.15, "Volume": 282340700, "Close": 15.36, "Change": null, "Low": 15.04},
                 {"Vol_Change": -0.22726, "MA_10": null, "Date": "2015-03-27", "High": 15.55, "Open": 15.32, "Volume": 218174900, "Close": 15.36, "Change": 0.0, "Low": 15.17}
             ]
    }
]

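Date (trading date); Open (opening price); Close (closing price); High (daily high); Low (daily low); Change (price change ratio); Volume (trading volume); Vol_Change (volume change versus the previous day); MA_5 (5-day moving average); MA_10 (10-day moving average); MA_20 (20-day moving average); MA_30 (30-day moving average); KDJ_K (K value of the KDJ indicator); KDJ_D (D value of the KDJ indicator); KDJ_J (J value of the KDJ indicator).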
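The sample records show how some derived fields relate to the raw values: volume falls from 282340700 to 218174900, and 218174900 / 282340700 - 1 ≈ -0.22726, which is the Vol_Change shown; Change is 0.0 because the close did not move; and MA_10 is null while fewer than 10 days are available. The sketch below reproduces those relationships and adds the commonly used KDJ(9,3,3) formula (K and D smoothed with weight 1/3, J = 3K - 2D, starting from 50). The KDJ parameters, the rounding, and the function names are assumptions for illustration; the project's own calculation code is not shown in this section.

```python
def ratio_change(current, previous, digits=5):
    """Fractional change (e.g. -0.22726 = a 22.726% drop); None without a prior value."""
    if previous is None or previous == 0:
        return None
    return round(current / previous - 1, digits)


def moving_average(closes, n, digits=3):
    """Simple n-day moving average per position; None until n closes are available."""
    averages = []
    for i in range(len(closes)):
        if i + 1 < n:
            averages.append(None)
        else:
            window = closes[i + 1 - n:i + 1]
            averages.append(round(sum(window) / n, digits))
    return averages


def kdj(highs, lows, closes, n=9, digits=3):
    """Assumed KDJ(9,3,3): RSV over the last n days, K/D smoothed with weight 1/3."""
    k_prev = d_prev = 50.0  # conventional starting values (assumption)
    values = []
    for i in range(len(closes)):
        start = max(0, i + 1 - n)  # shorter window for the first n-1 days
        low_n = min(lows[start:i + 1])
        high_n = max(highs[start:i + 1])
        if high_n == low_n:
            rsv = 50.0  # flat window: neutral value (implementations differ)
        else:
            rsv = (closes[i] - low_n) / (high_n - low_n) * 100
        k = 2 / 3 * k_prev + rsv / 3
        d = 2 / 3 * d_prev + k / 3
        j = 3 * k - 2 * d
        values.append((round(k, digits), round(d, digits), round(j, digits)))
        k_prev, d_prev = k, d
    return values


volumes = [282340700, 218174900]
print(ratio_change(volumes[1], volumes[0]))  # -0.22726
```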

Stock-picking strategy test data:

[
    {
        "Symbol": "600000.SS", 
        "Name": "浦發銀行", 
        "Close": 14.51, 
        "Change": 0.06456,
        "Vol_Change": 2.39592, 
        "MA_10": 14.171, 
        "KDJ_K": 37.65, 
        "KDJ_D": 33.427, 
        "KDJ_J": 46.096, 
        "Data": [
                    {"Day_5_Differ": 0.01869, "Day_9_Profit": 0.08546, "Day_1_Profit": -0.02826, "Day_1_INDEX_Change": -0.00484, "Day_3_INDEX_Change": 0.01557, "Day_5_INDEX_Change": 0.04747, "Day_3_Differ": 0.02647, "Day_9_INDEX_Change": 0.1003, "Day_5_Profit": 0.06616, "Day_3_Profit": 0.04204, "Day_1_Differ": -0.02342, "Day_9_Differ": -0.014840000000000006}
                ]
    }
]

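Close (closing price); Change (price change ratio); Vol_Change (volume change versus the previous day); MA_10 (10-day average price); KDJ_K (K value of the KDJ indicator); KDJ_D (D value of the KDJ indicator); KDJ_J (J value of the KDJ indicator); Day_1_Profit (return over the next day); Day_1_INDEX_Change (change in the CSI 300 index over the next day); Day_1_Differ (relative return over the next day, i.e. return minus the CSI 300 change); Day_n_Profit (return over the next n days); Day_n_INDEX_Change (change in the CSI 300 index over the next n days); Day_n_Differ (relative return over the next n days, i.e. return minus the CSI 300 change).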
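The relative figures in the sample are plain differences: Day_5_Differ = 0.06616 - 0.04747 = 0.01869 and Day_1_Differ = -0.02826 - (-0.00484) = -0.02342. A sketch of how such a record could be built, assuming Day_n_Profit is the close n trading days after the selection day divided by the selection-day close, minus one, and the same for the CSI 300; `holding_returns` is a hypothetical helper and the exact price basis the project uses is not shown here.

```python
def holding_returns(stock_closes, index_closes, day_offsets=(1, 3, 5, 9), digits=5):
    """Build a Day_n_* record from closes that start at the selection day (position 0).

    Offsets beyond the available data are skipped.
    """
    record = {}
    base_stock, base_index = stock_closes[0], index_closes[0]
    for n in day_offsets:
        if n >= len(stock_closes) or n >= len(index_closes):
            continue
        profit = stock_closes[n] / base_stock - 1
        index_change = index_closes[n] / base_index - 1
        record["Day_%d_Profit" % n] = round(profit, digits)
        record["Day_%d_INDEX_Change" % n] = round(index_change, digits)
        # Relative performance: stock return minus CSI 300 return
        record["Day_%d_Differ" % n] = round(profit - index_change, digits)
    return record
```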

Quote data fetching example

Fetch quote data for all Shanghai and Shenzhen stocks over the 90 days before the current date (calendar days, not trading days).

When it finishes, the data is in tmp/stockholm_export/stockholm_export.json under the current user's home folder.

python main.py

To export a CSV file instead:

python main.py --output=csv
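The CSV layout written by `--output=csv` is not shown in this article. If you already have the JSON export and want a flat table without re-fetching, one way to turn its nested `Data` lists into one row per stock per day is sketched below; `flatten_quotes`, `quotes_to_csv` and the column order are assumptions, not the project's format.

```python
import csv
import io

DEFAULT_FIELDS = ("Date", "Open", "High", "Low", "Close", "Volume", "Change", "Vol_Change", "MA_10")


def flatten_quotes(quotes, fields=DEFAULT_FIELDS):
    """Yield one dict per stock per day from the exported JSON structure."""
    for stock in quotes:
        for day in stock.get("Data", []):
            row = {"Symbol": stock["Symbol"], "Name": stock["Name"]}
            row.update({field: day.get(field) for field in fields})
            yield row


def quotes_to_csv(quotes):
    """Render the flattened rows as CSV text (None values become empty cells)."""
    buffer = io.StringIO()
    rows = list(flatten_quotes(quotes))
    if rows:
        writer = csv.DictWriter(buffer, fieldnames=list(rows[0].keys()))
        writer.writeheader()
        writer.writerows(rows)
    return buffer.getvalue()
```

To use it on a real export, load the file with `json.load` (opened with `encoding="utf-8"`) and pass the resulting list to `quotes_to_csv`.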

Stock-picking strategy test example

The sample strategy file looks like this (it is included in the source):

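Strategy "method 1" means: the KDJ J value two trading days ago is below 20, the J value on the previous trading day is below 20, the J value on the current trading day is at least 40 higher than on the previous trading day, and the current day's volume change is at least 100% (Vol_Change >= 1).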

## Portfolio selection methodology sample file

[method 1]:day(-2).{KDJ_J}<20 and day(-1).{KDJ_J}<20 and day(0).{KDJ_J}-day(-1).{KDJ_J}>=40 and day(0).{Vol_Change}>=1
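How `quote_pick` evaluates these expressions is not shown in this section. One plausible approach, sketched here as an assumption: rewrite every `day(k).{FIELD}` into a lookup relative to the target day, then evaluate the result as a Python boolean expression (`and`, `<`, `>=` and chained comparisons such as `50<day(-1).{KDJ_J}<80` are already valid Python). `matches` and `lookup` are hypothetical names, and `eval` should only be used on strategy files you wrote yourself.

```python
import re

# day(k).{FIELD} -> lookup(k, 'FIELD')
TOKEN = re.compile(r"day\((-?\d+)\)\.\{(\w+)\}")


def matches(expr, data, index):
    """True if expr holds with day(0) = data[index]; data is ordered oldest first."""
    py_expr = TOKEN.sub(lambda m: "lookup(%s, %r)" % (m.group(1), m.group(2)), expr)

    def lookup(offset, field):
        pos = index + offset
        if not 0 <= pos < len(data):
            raise IndexError("day(%d) is outside the loaded range" % offset)
        return data[pos][field]

    try:
        # No builtins: only lookup() is reachable from the expression
        return bool(eval(py_expr, {"__builtins__": {}}, {"lookup": lookup}))
    except (IndexError, KeyError, TypeError):
        # Missing days, unknown fields or None values (e.g. early MA_10) count as no match
        return False
```

With `data` being one stock's `Data` list from the quote export, `matches(expr, data, i)` checks whether the stock would be picked on the i-th day.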

Use the current system date as the target date and test the strategy over the preceding 60 days.

Run the test without re-fetching the quote data:

When it finishes, the test results are saved under tmp/stockholm_export/ in the current user's home folder, one file per day.

File names have the form result_yyyy-MM-dd.json (for example result_2015-03-24.json).

python main.py --reload=N --portfolio=Y
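Each result_yyyy-MM-dd.json holds the stocks picked on that day, in the test data format described earlier. To compare strategies across the whole test range you might average one of the Day_n_Differ values over all files; a sketch assuming that format, where `average_differ` is a hypothetical helper and entries without the requested key are skipped:

```python
import glob
import json
import os


def average_differ(folder, key="Day_5_Differ"):
    """Mean of `key` over every picked stock in every result_*.json file in folder."""
    values = []
    for path in sorted(glob.glob(os.path.join(folder, "result_*.json"))):
        with open(path, encoding="utf-8") as f:
            picks = json.load(f)
        for stock in picks:
            for entry in stock.get("Data", []):
                if entry.get(key) is not None:
                    values.append(entry[key])
    # None when no file contains the requested key
    return sum(values) / len(values) if values else None
```

For the default output location this would be called as `average_differ(os.path.expanduser("~/tmp/stockholm_export"))`.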

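By changing the strategy formulas in the test file, you can test how any stock-picking strategy would have performed over a given date range.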
