仓位系统

用一个简单的字典作为持仓列表:key为标的代码,value为持有手数(1手=100股)。

class Portfolio:
    """Mutable holdings ledger: security code -> number of lots (手, 100 shares each) held."""

    def __init__(self):
        self.position = collections.defaultdict(int)

    def update_position(self, code='000001.SZ', hand=100):
        """Add ``hand`` lots of ``code`` to the holding (a negative ``hand`` reduces it)."""
        self.position[code] += hand

    def get_position(self, code='000001.SZ'):
        """Return the number of lots held of ``code`` (0 if none).

        Uses ``.get`` so that merely querying a code does not insert a zero entry
        into the defaultdict; such entries would otherwise show up whenever the
        holdings are iterated.
        """
        return self.position.get(code, 0)

仓位信息

单纯仓位系统不能直接开放给策略,策略应该拿到的是一个实际持仓的只读副本。

class PortfolioInfo:
    """Snapshot of the portfolio handed to a strategy, plus an order book.

    ``position`` is a fresh dict built from the live portfolio, so a strategy
    editing it does not change the real holdings. The strategy records orders via
    ``order_by_hand`` / ``order_by_cash``; the simulator reads ``self.order`` back.
    """

    def __init__(self, portfolio: 'Portfolio', start_date: str, end_date: str, date: str, avail_cash: float):
        # code -> {'hand': lots held, 'value': lots * 100 * close price on `date`}
        self.position = {}
        for code, hand in portfolio.position.items():
            # An empty holding is worth 0; skip the price lookup so a code we no
            # longer hold (which may have no bar on `date`) cannot raise here.
            value = Util.get_price(code).close(date) * hand * 100 if hand else 0.0
            self.position[code] = {'hand': hand, 'value': value}
        self.avail_cash = avail_cash
        self.order = []
        self.date = date
        self.start_date = start_date
        self.end_date = end_date

    def total_value(self) -> float:
        """Return the summed market value of all holdings."""
        return sum(holding['value'] for holding in self.position.values())

    def rate(self) -> float:
        """Return the fraction of total assets held in stock (0.0 when assets are zero)."""
        stock = self.total_value()
        assets = self.avail_cash + stock
        return stock / assets if assets else 0.0

    def order_by_hand(self, code='000001.SZ', hand=100.):
        """Record an order for ``hand`` lots of ``code`` (negative sells)."""
        self.order.append({"code": code, "hand": hand})

    def order_by_cash(self, code='000001.SZ', cash=10000.):
        """Record an order for about ``cash`` worth of ``code`` (negative sells)."""
        self.order.append({"code": code, "cash": cash})

此外,这个信息结构里面还要带上订单系统,供策略调用。策略调用结束后,此结构记录了策略的下单信息,并被带回模拟器进行模拟成交(买入或卖出)。

模拟器

模拟器主要负责维护一份持仓列表。在每个交易日,它把持仓信息交给策略去做买入卖出操作,策略返回后再逐条处理其下单信息。对于按手数或按金额的买入,在计入佣金(含最低佣金)的情况下,二分查找在资金范围内能买到的最大手数。对于卖出,卖出数量不超过当前持仓,到账金额扣除印花税和佣金。成交数量均以手(100股)为单位。另外,模拟器在每个交易日的模拟结束后收集仓位信息,最终输出统计表,并绘制收益率与基准的对比曲线。

class Simulator:
    """Daily backtest driver.

    Owns the real Portfolio and cash balance. For every trading day it hands the
    strategy a PortfolioInfo snapshot (``on_close``), fills the orders the
    strategy recorded on it, and appends one row to ``self.statistics``.
    """

    def __init__(self, strategy, init_cash=100000., start_date='20200101', end_date='20210101', duty=0.001,
                 brokerage=0.00025, min_brokerage=5.):
        """
        :param strategy: object exposing ``on_close(info)``.
        :param init_cash: starting cash.
        :param start_date: first calendar date (YYYYMMDD) passed to ``pro.trade_cal``.
        :param end_date: last calendar date (YYYYMMDD) passed to ``pro.trade_cal``.
        :param duty: stamp-duty rate charged on sell turnover.
        :param brokerage: commission rate charged on buy and sell turnover.
        :param min_brokerage: minimum commission per filled order.
        """
        self.init_cash = init_cash
        self.cash = init_cash
        self.start_date = start_date
        self.end_date = end_date
        self.duty = duty
        self.brokerage = brokerage
        self.min_brokerage = min_brokerage
        self.strategy = strategy
        self.portfolio = Portfolio()
        df: pd.DataFrame = pro.trade_cal(start_date=start_date, end_date=end_date)
        df = df.loc[df['is_open'] == 1]
        # Sort explicitly: the loop below walks forward in time and must not rely
        # on the row order returned by the API.
        self.trade_dates = sorted(df['cal_date'])
        # Index into trade_dates of the current day.
        self.date = 0
        self.statistics = pd.DataFrame(columns=['date', 'cash', 'stock', 'total', 'rate', 'base'])

    def _commission(self, turnover: float) -> float:
        """Return the commission for ``turnover``, never below ``min_brokerage``."""
        return max(turnover * self.brokerage, self.min_brokerage)

    def _produce_order_hand(self, o: dict, code: str, hand: float, price: float, buy_limit: float):
        """Fill one order of ``hand`` lots (1 lot = 100 shares) of ``code`` at per-share ``price``.

        A positive ``hand`` buys the largest whole number of lots (at most ``hand``)
        whose cost, including commission, does not exceed ``buy_limit``. A negative
        ``hand`` sells up to ``-hand`` lots, capped at the current holding. The
        proceeds are reduced by stamp duty and commission. Fractional lots are
        truncated toward zero. Nothing is charged when no lot can be filled.
        """
        hand = int(hand)  # integral bounds keep the binary search below terminating
        if hand == 0:
            return
        lot_price = price * 100
        if hand > 0:
            # Binary search for the largest affordable lot count in [0, hand].
            lo, hi = 0, hand
            while lo < hi:
                mid = (lo + hi + 1) // 2
                if lot_price * mid + self._commission(lot_price * mid) <= buy_limit:
                    lo = mid
                else:
                    hi = mid - 1
            if lo == 0:
                logging.info('%s资金不足,未成交' % (str(o)))
                return
            cost = lot_price * lo + self._commission(lot_price * lo)
            logging.info('%s交易成功,买入%d手,花去%f' % (str(o), lo, cost))
            self.portfolio.update_position(code, lo)
            self.cash -= cost
        else:
            lots = min(-hand, int(self.portfolio.get_position(code)))
            if lots <= 0:
                logging.info('%s无持仓,未成交' % (str(o)))
                return
            turnover = lot_price * lots
            gain = turnover * (1 - self.duty) - self._commission(turnover)
            logging.info('%s交易成功,卖出%d手,得到%f' % (str(o), lots, gain))
            self.portfolio.update_position(code, -lots)
            self.cash += gain

    def _produce_order(self, info: 'PortfolioInfo'):
        """Fill every order recorded on ``info`` at the open price of ``info.date``.

        NOTE(review): fills use the open of the same date the strategy was called
        for, so a strategy must only use data strictly before ``info.date`` to avoid
        look-ahead (MACDStrategy does, via Price.get).
        """
        for o in info.order:
            p = Util.get_price(o['code']).open(info.date)
            if 'hand' in o:
                self._produce_order_hand(o, o['code'], o['hand'], p, self.cash)
            elif 'cash' in o:
                cash = o['cash']
                # Truncate toward zero for both signs, so a sell never exceeds the requested amount.
                lots = abs(cash) // (p * 100)
                if cash < 0:
                    lots = -lots
                # Never spend more than the cash actually available.
                self._produce_order_hand(o, o['code'], lots, p, min(cash, self.cash))
            else:
                logging.info('%s非法交易' % (str(o)))

    def _collect(self):
        """Append today's cash / stock value / return vs. benchmark to ``self.statistics``."""
        date = self.trade_dates[self.date]
        info = PortfolioInfo(self.portfolio, self.start_date, self.end_date, date, self.cash)
        stock = info.total_value()
        total = self.cash + stock
        data = {
            'date': date,
            'cash': self.cash,
            'stock': stock,
            'total': total,
            'rate': total / self.init_cash - 1,
            'base': Util.get_base().get_base(self.start_date, date)
        }
        row = pd.DataFrame([data], columns=self.statistics.columns)
        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        if self.statistics.empty:
            self.statistics = row
        else:
            self.statistics = pd.concat([self.statistics, row], ignore_index=True)
        logging.info(data)

    def _open(self, info: 'PortfolioInfo'):
        """Advance to the next trading day, fill ``info``'s orders and record the day."""
        self.date += 1
        if self.date >= len(self.trade_dates):
            return
        self._produce_order(info)
        self._collect()

    def _close(self):
        """Give the strategy a snapshot for the current day, then move to the next one."""
        info = PortfolioInfo(self.portfolio, self.start_date, self.end_date, self.trade_dates[self.date], self.cash)
        self.strategy.on_close(info)
        self._open(info)

    def run(self):
        """Run the whole backtest and plot the strategy's return against the benchmark."""
        while self.date < len(self.trade_dates):
            self._close()

        def to_percent(temp, position):
            return '%.2f' % (100 * temp) + '%'

        plt.figure()
        plt.plot(self.statistics.index, self.statistics['rate'], label='rate')
        plt.plot(self.statistics.index, self.statistics['base'], label='base')
        # 'rate' and 'base' are fractions; show the y axis as percentages.
        plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
        plt.legend()
        plt.show()

工具类

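基准类用于获取基准行情以做收益对比:代码默认通过tushare的us_daily接口,获取美股上市的A股ETF ASHS的日线作为基准。注意:ASHS跟踪的是中证500小盘指数,若要以沪深300为基准,应改用跟踪沪深300的ASHR。价格类用于通过daily接口获取个股的历史日线价格。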

class Base:
    """Benchmark price series used to compute the reference return.

    The daily bars of ``code`` are fetched once via tushare ``pro.us_daily`` and
    cached as ``<code>.pkl`` in the working directory; later runs load the pickle.
    """

    def __init__(self, code='ASHS'):
        file_name = code + ".pkl"
        if os.path.exists(file_name):
            # NOTE: the cache file is keyed only by code, not by date range.
            self.df = pd.read_pickle(file_name)
        else:
            self.df: pd.DataFrame = pro.us_daily(ts_code=code, start_date='20190101', end_date='20210201')
            # Ascending order is required by _last_close, which takes the last row.
            self.df = self.df.sort_values('trade_date')
            self.df.to_pickle(file_name)

    def _last_close(self, date: str) -> float:
        """Return the close of the latest bar dated on or before ``date``."""
        closes = self.df.loc[self.df['trade_date'] <= date, 'close']
        if closes.empty:
            raise IndexError('benchmark has no close on or before %s' % date)
        return float(closes.iloc[-1])

    def get_base(self, start_date: str, end_date: str):
        """Return the benchmark's simple return from ``start_date`` to ``end_date``.

        Each date falls back to the latest bar on or before it, so days on which the
        benchmark did not trade are tolerated.
        """
        return self._last_close(end_date) / self._last_close(start_date) - 1


class Price:
    """Daily bars of one security, fetched once from tushare ``pro.daily``."""

    def __init__(self, code='000001.SZ', start_date='20190101', end_date='20210201'):
        self.code = code
        self.df: pd.DataFrame = pro.daily(ts_code=code, start_date=start_date, end_date=end_date)

    def _field(self, item, column):
        """Return ``column`` of the bar dated ``item``; raise IndexError if there is none."""
        values = self.df.loc[self.df['trade_date'] == item, column]
        if values.empty:
            # e.g. a non-trading day for this code, or a date outside the fetched range
            raise IndexError('%s has no %s price on %s' % (self.code, column, item))
        return float(values.iloc[0])

    def close(self, item):
        """Return the close price on trade date ``item`` (YYYYMMDD)."""
        return self._field(item, 'close')

    def open(self, item):
        """Return the open price on trade date ``item`` (YYYYMMDD)."""
        return self._field(item, 'open')

    def get(self, end_date: str):
        """Return bars dated strictly before ``end_date``, in the fetched row order."""
        return self.df.loc[self.df['trade_date'] < end_date]

简单策略

这里随便写了一个(没啥实际用处的)策略,用于测试框架效果:当MACD的DIF线(EMA12−EMA26)上穿0轴时,用一半可用资金买入;下穿0轴时,卖出约一半的持仓市值。

class MACDStrategy:
    """Toy strategy trading on zero crossings of the MACD DIF line (EMA12 - EMA26).

    When DIF crosses above 0 it orders half the available cash; when DIF crosses
    below 0 it sells half the market value currently held in ``self.code``.
    """

    def __init__(self, code='000001.SZ'):
        self.code = code

    def on_close(self, info: 'PortfolioInfo'):
        """Inspect bars strictly before ``info.date`` and record at most one order on ``info``."""
        bars = Util.get_price(self.code).get(info.date).sort_values('trade_date')
        close = bars['close']
        dif = close.ewm(span=12).mean() - close.ewm(span=26).mean()
        if len(dif) < 2:
            # Not enough history yet to detect a crossing.
            return
        today, yesterday = dif.iloc[-1], dif.iloc[-2]
        if today > 0 > yesterday:
            info.order_by_cash(self.code, info.avail_cash / 2)
        elif today < 0 < yesterday:
            # Size the sell on this code's holding, not on the whole portfolio.
            held = info.position.get(self.code, {}).get('value', 0)
            if held > 0:
                info.order_by_cash(self.code, -held / 2)

执行效果

2021-02-14 15:51:04,669 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201215', 'cash': 314244.5519999999, 'stock': 818937.9999999999, 'total': 1133182.5519999997, 'rate': 0.13318255199999962, 'base': 0.2669077757685354}
2021-02-14 15:51:04,679 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201216', 'cash': 314244.5519999999, 'stock': 830737.0000000001, 'total': 1144981.5520000001, 'rate': 0.14498155200000018, 'base': 0.26039783001808336}
2021-02-14 15:51:04,689 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201217', 'cash': 314244.5519999999, 'stock': 828115.0, 'total': 1142359.552, 'rate': 0.14235955199999983, 'base': 0.2730560578661847}
2021-02-14 15:51:04,699 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201218', 'cash': 314244.5519999999, 'stock': 802332.0, 'total': 1116576.552, 'rate': 0.11657655199999994, 'base': 0.2600361663652804}
2021-02-14 15:51:04,709 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201221', 'cash': 314244.5519999999, 'stock': 798836.0, 'total': 1113080.552, 'rate': 0.113080552, 'base': 0.27377938517179023}
2021-02-14 15:51:04,719 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201222', 'cash': 314244.5519999999, 'stock': 782229.9999999999, 'total': 1096474.5519999997, 'rate': 0.09647455199999966, 'base': 0.2484629294755878}
2021-02-14 15:51:04,727 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201223', 'cash': 314244.5519999999, 'stock': 786600.0, 'total': 1100844.552, 'rate': 0.10084455199999987, 'base': 0.26075949367088613}
2021-02-14 15:51:04,735 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201224', 'cash': 314244.5519999999, 'stock': 797962.0000000001, 'total': 1112206.5520000001, 'rate': 0.11220655200000018, 'base': 0.24448462929475578}
2021-02-14 15:51:04,740 - /Users/lizile/PycharmProjects/quant/main.py[line:139] - INFO: {'code': '000001.SZ', 'cash': -398981.00000000006}交易成功,卖出220手,得到398900.700000
2021-02-14 15:51:04,744 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201225', 'cash': 713145.2519999999, 'stock': 391468.0, 'total': 1104613.2519999999, 'rate': 0.10461325199999982, 'base': 0.24448462929475578}
2021-02-14 15:51:04,756 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201228', 'cash': 713145.2519999999, 'stock': 409045.0, 'total': 1122190.2519999999, 'rate': 0.12219025199999978, 'base': 0.2484629294755878}
2021-02-14 15:51:04,769 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201229', 'cash': 713145.2519999999, 'stock': 415989.00000000006, 'total': 1129134.2519999999, 'rate': 0.12913425199999984, 'base': 0.24520795660036176}
2021-02-14 15:51:04,784 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201230', 'cash': 713145.2519999999, 'stock': 416639.99999999994, 'total': 1129785.2519999999, 'rate': 0.1297852519999998, 'base': 0.26039783001808336}
2021-02-14 15:51:04,793 - /Users/lizile/PycharmProjects/quant/main.py[line:132] - INFO: {'code': '000001.SZ', 'cash': 356572.62599999993}交易成功,买入187手,花去355388.825000
2021-02-14 15:51:04,800 - /Users/lizile/PycharmProjects/quant/main.py[line:165] - INFO: {'date': '20201231', 'cash': 357756.42699999985, 'stock': 781336.0, 'total': 1139092.427, 'rate': 0.13909242700000002, 'base': 0.27667269439421327}

完整代码

import logging
from functools import lru_cache
import os
import tushare as ts
import pandas as pd
from matplotlib import pyplot as plt
import collections

from matplotlib.ticker import FuncFormatter

logging.basicConfig(format='%(asctime)s - %(pathname)s[line:%(lineno)d] - %(levelname)s: %(message)s',
                    level=logging.INFO)

# tushare API token: read it from the TUSHARE_TOKEN environment variable so that it
# does not have to be committed to source. '???' is only a placeholder fallback.
token = os.environ.get('TUSHARE_TOKEN', '???')
ts.set_token(token)
pro = ts.pro_api()


class Base:
    """Benchmark price series used to compute the reference return.

    The daily bars of ``code`` are fetched once via tushare ``pro.us_daily`` and
    cached as ``<code>.pkl`` in the working directory; later runs load the pickle.
    """

    def __init__(self, code='ASHS'):
        file_name = code + ".pkl"
        if os.path.exists(file_name):
            # NOTE: the cache file is keyed only by code, not by date range.
            self.df = pd.read_pickle(file_name)
        else:
            self.df: pd.DataFrame = pro.us_daily(ts_code=code, start_date='20190101', end_date='20210201')
            # Ascending order is required by _last_close, which takes the last row.
            self.df = self.df.sort_values('trade_date')
            self.df.to_pickle(file_name)

    def _last_close(self, date: str) -> float:
        """Return the close of the latest bar dated on or before ``date``."""
        closes = self.df.loc[self.df['trade_date'] <= date, 'close']
        if closes.empty:
            raise IndexError('benchmark has no close on or before %s' % date)
        return float(closes.iloc[-1])

    def get_base(self, start_date: str, end_date: str):
        """Return the benchmark's simple return from ``start_date`` to ``end_date``.

        Each date falls back to the latest bar on or before it, so days on which the
        benchmark did not trade are tolerated.
        """
        return self._last_close(end_date) / self._last_close(start_date) - 1


class Price:
    """Daily bars of one security, fetched once from tushare ``pro.daily``."""

    def __init__(self, code='000001.SZ', start_date='20190101', end_date='20210201'):
        self.code = code
        self.df: pd.DataFrame = pro.daily(ts_code=code, start_date=start_date, end_date=end_date)

    def _field(self, item, column):
        """Return ``column`` of the bar dated ``item``; raise IndexError if there is none."""
        values = self.df.loc[self.df['trade_date'] == item, column]
        if values.empty:
            # e.g. a non-trading day for this code, or a date outside the fetched range
            raise IndexError('%s has no %s price on %s' % (self.code, column, item))
        return float(values.iloc[0])

    def close(self, item):
        """Return the close price on trade date ``item`` (YYYYMMDD)."""
        return self._field(item, 'close')

    def open(self, item):
        """Return the open price on trade date ``item`` (YYYYMMDD)."""
        return self._field(item, 'open')

    def get(self, end_date: str):
        """Return bars dated strictly before ``end_date``, in the fetched row order."""
        return self.df.loc[self.df['trade_date'] < end_date]


class Util:
    """Memoised factories for market-data objects.

    Each distinct argument list builds its object once; later calls with the same
    arguments return the cached instance, so the underlying data is fetched only
    once per process. (lru_cache keys on the arguments as passed, so calling with
    no argument and calling with the default value explicitly are cached separately.)
    """

    @staticmethod
    @lru_cache(maxsize=None)
    def get_price(code='000001.SZ') -> 'Price':
        """Return the cached Price object for ``code``."""
        price = Price(code)
        return price

    @staticmethod
    @lru_cache(maxsize=None)
    def get_base(code='ASHS') -> 'Base':
        """Return the cached benchmark Base object for ``code``."""
        benchmark = Base(code)
        return benchmark


class Portfolio:
    """Mutable holdings ledger: security code -> number of lots (手, 100 shares each) held."""

    def __init__(self):
        self.position = collections.defaultdict(int)

    def update_position(self, code='000001.SZ', hand=100):
        """Add ``hand`` lots of ``code`` to the holding (a negative ``hand`` reduces it)."""
        self.position[code] += hand

    def get_position(self, code='000001.SZ'):
        """Return the number of lots held of ``code`` (0 if none).

        Uses ``.get`` so that merely querying a code does not insert a zero entry
        into the defaultdict; such entries would otherwise show up whenever the
        holdings are iterated.
        """
        return self.position.get(code, 0)


class PortfolioInfo:
    """Snapshot of the portfolio handed to a strategy, plus an order book.

    ``position`` is a fresh dict built from the live portfolio, so a strategy
    editing it does not change the real holdings. The strategy records orders via
    ``order_by_hand`` / ``order_by_cash``; the simulator reads ``self.order`` back.
    """

    def __init__(self, portfolio: 'Portfolio', start_date: str, end_date: str, date: str, avail_cash: float):
        # code -> {'hand': lots held, 'value': lots * 100 * close price on `date`}
        self.position = {}
        for code, hand in portfolio.position.items():
            # An empty holding is worth 0; skip the price lookup so a code we no
            # longer hold (which may have no bar on `date`) cannot raise here.
            value = Util.get_price(code).close(date) * hand * 100 if hand else 0.0
            self.position[code] = {'hand': hand, 'value': value}
        self.avail_cash = avail_cash
        self.order = []
        self.date = date
        self.start_date = start_date
        self.end_date = end_date

    def total_value(self) -> float:
        """Return the summed market value of all holdings."""
        return sum(holding['value'] for holding in self.position.values())

    def rate(self) -> float:
        """Return the fraction of total assets held in stock (0.0 when assets are zero)."""
        stock = self.total_value()
        assets = self.avail_cash + stock
        return stock / assets if assets else 0.0

    def order_by_hand(self, code='000001.SZ', hand=100.):
        """Record an order for ``hand`` lots of ``code`` (negative sells)."""
        self.order.append({"code": code, "hand": hand})

    def order_by_cash(self, code='000001.SZ', cash=10000.):
        """Record an order for about ``cash`` worth of ``code`` (negative sells)."""
        self.order.append({"code": code, "cash": cash})


class Simulator:
    """Daily backtest driver.

    Owns the real Portfolio and cash balance. For every trading day it hands the
    strategy a PortfolioInfo snapshot (``on_close``), fills the orders the
    strategy recorded on it, and appends one row to ``self.statistics``.
    """

    def __init__(self, strategy, init_cash=100000., start_date='20200101', end_date='20210101', duty=0.001,
                 brokerage=0.00025, min_brokerage=5.):
        """
        :param strategy: object exposing ``on_close(info)``.
        :param init_cash: starting cash.
        :param start_date: first calendar date (YYYYMMDD) passed to ``pro.trade_cal``.
        :param end_date: last calendar date (YYYYMMDD) passed to ``pro.trade_cal``.
        :param duty: stamp-duty rate charged on sell turnover.
        :param brokerage: commission rate charged on buy and sell turnover.
        :param min_brokerage: minimum commission per filled order.
        """
        self.init_cash = init_cash
        self.cash = init_cash
        self.start_date = start_date
        self.end_date = end_date
        self.duty = duty
        self.brokerage = brokerage
        self.min_brokerage = min_brokerage
        self.strategy = strategy
        self.portfolio = Portfolio()
        df: pd.DataFrame = pro.trade_cal(start_date=start_date, end_date=end_date)
        df = df.loc[df['is_open'] == 1]
        # Sort explicitly: the loop below walks forward in time and must not rely
        # on the row order returned by the API.
        self.trade_dates = sorted(df['cal_date'])
        # Index into trade_dates of the current day.
        self.date = 0
        self.statistics = pd.DataFrame(columns=['date', 'cash', 'stock', 'total', 'rate', 'base'])

    def _commission(self, turnover: float) -> float:
        """Return the commission for ``turnover``, never below ``min_brokerage``."""
        return max(turnover * self.brokerage, self.min_brokerage)

    def _produce_order_hand(self, o: dict, code: str, hand: float, price: float, buy_limit: float):
        """Fill one order of ``hand`` lots (1 lot = 100 shares) of ``code`` at per-share ``price``.

        A positive ``hand`` buys the largest whole number of lots (at most ``hand``)
        whose cost, including commission, does not exceed ``buy_limit``. A negative
        ``hand`` sells up to ``-hand`` lots, capped at the current holding. The
        proceeds are reduced by stamp duty and commission. Fractional lots are
        truncated toward zero. Nothing is charged when no lot can be filled.
        """
        hand = int(hand)  # integral bounds keep the binary search below terminating
        if hand == 0:
            return
        lot_price = price * 100
        if hand > 0:
            # Binary search for the largest affordable lot count in [0, hand].
            lo, hi = 0, hand
            while lo < hi:
                mid = (lo + hi + 1) // 2
                if lot_price * mid + self._commission(lot_price * mid) <= buy_limit:
                    lo = mid
                else:
                    hi = mid - 1
            if lo == 0:
                logging.info('%s资金不足,未成交' % (str(o)))
                return
            cost = lot_price * lo + self._commission(lot_price * lo)
            logging.info('%s交易成功,买入%d手,花去%f' % (str(o), lo, cost))
            self.portfolio.update_position(code, lo)
            self.cash -= cost
        else:
            lots = min(-hand, int(self.portfolio.get_position(code)))
            if lots <= 0:
                logging.info('%s无持仓,未成交' % (str(o)))
                return
            turnover = lot_price * lots
            gain = turnover * (1 - self.duty) - self._commission(turnover)
            logging.info('%s交易成功,卖出%d手,得到%f' % (str(o), lots, gain))
            self.portfolio.update_position(code, -lots)
            self.cash += gain

    def _produce_order(self, info: 'PortfolioInfo'):
        """Fill every order recorded on ``info`` at the open price of ``info.date``.

        NOTE(review): fills use the open of the same date the strategy was called
        for, so a strategy must only use data strictly before ``info.date`` to avoid
        look-ahead (MACDStrategy does, via Price.get).
        """
        for o in info.order:
            p = Util.get_price(o['code']).open(info.date)
            if 'hand' in o:
                self._produce_order_hand(o, o['code'], o['hand'], p, self.cash)
            elif 'cash' in o:
                cash = o['cash']
                # Truncate toward zero for both signs, so a sell never exceeds the requested amount.
                lots = abs(cash) // (p * 100)
                if cash < 0:
                    lots = -lots
                # Never spend more than the cash actually available.
                self._produce_order_hand(o, o['code'], lots, p, min(cash, self.cash))
            else:
                logging.info('%s非法交易' % (str(o)))

    def _collect(self):
        """Append today's cash / stock value / return vs. benchmark to ``self.statistics``."""
        date = self.trade_dates[self.date]
        info = PortfolioInfo(self.portfolio, self.start_date, self.end_date, date, self.cash)
        stock = info.total_value()
        total = self.cash + stock
        data = {
            'date': date,
            'cash': self.cash,
            'stock': stock,
            'total': total,
            'rate': total / self.init_cash - 1,
            'base': Util.get_base().get_base(self.start_date, date)
        }
        row = pd.DataFrame([data], columns=self.statistics.columns)
        # DataFrame.append was removed in pandas 2.0; concat is the replacement.
        if self.statistics.empty:
            self.statistics = row
        else:
            self.statistics = pd.concat([self.statistics, row], ignore_index=True)
        logging.info(data)

    def _open(self, info: 'PortfolioInfo'):
        """Advance to the next trading day, fill ``info``'s orders and record the day."""
        self.date += 1
        if self.date >= len(self.trade_dates):
            return
        self._produce_order(info)
        self._collect()

    def _close(self):
        """Give the strategy a snapshot for the current day, then move to the next one."""
        info = PortfolioInfo(self.portfolio, self.start_date, self.end_date, self.trade_dates[self.date], self.cash)
        self.strategy.on_close(info)
        self._open(info)

    def run(self):
        """Run the whole backtest and plot the strategy's return against the benchmark."""
        while self.date < len(self.trade_dates):
            self._close()

        def to_percent(temp, position):
            return '%.2f' % (100 * temp) + '%'

        plt.figure()
        plt.plot(self.statistics.index, self.statistics['rate'], label='rate')
        plt.plot(self.statistics.index, self.statistics['base'], label='base')
        # 'rate' and 'base' are fractions; show the y axis as percentages.
        plt.gca().yaxis.set_major_formatter(FuncFormatter(to_percent))
        plt.legend()
        plt.show()


class MACDStrategy:
    """Toy strategy trading on zero crossings of the MACD DIF line (EMA12 - EMA26).

    When DIF crosses above 0 it orders half the available cash; when DIF crosses
    below 0 it sells half the market value currently held in ``self.code``.
    """

    def __init__(self, code='000001.SZ'):
        self.code = code

    def on_close(self, info: 'PortfolioInfo'):
        """Inspect bars strictly before ``info.date`` and record at most one order on ``info``."""
        bars = Util.get_price(self.code).get(info.date).sort_values('trade_date')
        close = bars['close']
        dif = close.ewm(span=12).mean() - close.ewm(span=26).mean()
        if len(dif) < 2:
            # Not enough history yet to detect a crossing.
            return
        today, yesterday = dif.iloc[-1], dif.iloc[-2]
        if today > 0 > yesterday:
            info.order_by_cash(self.code, info.avail_cash / 2)
        elif today < 0 < yesterday:
            # Size the sell on this code's holding, not on the whole portfolio.
            held = info.position.get(self.code, {}).get('value', 0)
            if held > 0:
                info.order_by_cash(self.code, -held / 2)


if __name__ == '__main__':
    # Demo run: the toy MACD strategy with 1,000,000 starting cash over the default date range.
    simulator = Simulator(MACDStrategy(), 1000000)
    simulator.run()