【实战笔记】一个通用方法团灭6道股票问题

强烈建议阅读原文:一个通用方法团灭6道股票问题。受益于这篇文章的启发,团灭了6道股票问题,下面将我自己的理解整理如下。

解动态规划类问题有两个套路:一是先定义清楚状态 ,二是定义出状态转移方程。

这6道问题,我们可以抽象出这些状态:第i天的股票价格、最多的交易次数k、每天可能买入或卖出也可能什么都不做的行为j。将行为j进行简化,把它定义成当天持有股票、不持有股票两种状态。我们可以得到以下状态定义:

1
2
3
4
5
6
7
8
# 状态定义:dp[i][k][1 or 0]
i表示第i天,0 <= i <= n-1
k表示第k天,1 <= k <= K,假设在buy时k才加1(你也可以定义在sell时才加1
1表示当天持有者股票,0表示当天不持有

dp[i][k][1 or 0]表示第i天交易k次持有或不持有股票的最大利益

最终要求的其实就是dp[i][k][0](最后一天抛售完股票利益最大化)

进一步递推状态方程,可以如下描述:

1
2
3
4
5
6
7
8
9
# 转移方程:
dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1]+prices[i])
max(昨天就没持有, 昨天持有今天卖出)
dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0]-prices[i])
max(昨天就持有, 昨天没有今天买入)

在实际编程中,要注意dp[i][k][1 or 0]中的i不能为-1,k不能为0,此为边界条件。按照常识不难得出以下结论:
dp[0][k][0] = 0 # 第0天不持有股票,无盈亏,最大利益为0
dp[0][k][1] = 0 - prices[0] # 第0天买入股票,最大利益为负的prices[0]

下面按照上述的递推方程来团灭6道股票问题,代码中含有详尽的日志说明。

股票问题1

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Solution(object):
def maxProfit(self, prices):
"""
:type prices: List[int]
:rtype: int
"""
# 状态定义:dp[i][k][1 or 0]
# i表示第i天,0 <= i <= n-1
# k表示第k天,1 <= k <= K,假设在buy时k才加1
# 1表示当天持有者股票,0表示当天不持有
# 转移方程:
# dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1]+prices[i])
# max(昨天就没持有, 昨天持有今天卖出)
# dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0]-prices[i])
# max(昨天就持有, 昨天没有今天买入)
# 边界:i为-1,k为0
# dp[-1][k][0] = dp[i][0][0] = 0
# dp[-1][k][1] = dp[i][0][1] = -infinite
#
# k = 1(dp[i-1][0][0]为0,k都为1可舍去)
# dp[i][1][0] = max(dp[i-1][1][0], dp[i-1][1][1]+prices[i])
# dp[i][1][1] = max(dp[i-1][1][1], dp[i-1][0][0]-prices[i])
#
# dp[i][0] = max(dp[i-1][0], dp[i-1][1]+prices[i])
# dp[i][1] = max(dp[i-1][1], -prices[i])
if len(prices) < 2:
return 0
dp = [[0] * 2 for _ in range(len(prices))]
dp[0][0] = 0
dp[0][1] = 0 - prices[0]
for i in range(1, len(prices)):
dp[i][0] = max(dp[i-1][0], dp[i-1][1]+prices[i])
dp[i][1] = max(dp[i-1][1], -prices[i])
return dp[-1][0]

股票问题2

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
class Solution(object):
def maxProfit(self, prices):
"""
:type prices: List[int]
:rtype: int
"""
# 状态定义:dp[i][k][1 or 0]
# i表示第i天,0 <= i <= n-1
# k表示第k天,1 <= k <= K,假设在buy时k才加1
# 1表示当天持有者股票,0表示当天不持有
# 转移方程:
# dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1]+prices[i])
# max(昨天就没持有, 昨天持有今天卖出)
# dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0]-prices[i])
# max(昨天就持有, 昨天没有今天买入)
#
# k这个状态可摘去,目标是求最大化利益,对k值取多少无要求
# dp[i][0] = max(dp[i-1][0], dp[i-1][1]+prices[i])
# max(昨天就没持有, 昨天持有今天卖出)
# dp[i][1] = max(dp[i-1][1], dp[i-1][0]-prices[i])
# max(昨天就持有, 昨天没有今天买入)
if len(prices) < 2:
return 0
dp = [[0] * 2 for _ in range(len(prices))]
dp[0][0] = 0
dp[0][1] = 0 - prices[0]
for i in range(1, len(prices)):
dp[i][0] = max(dp[i-1][0], dp[i-1][1]+prices[i])
dp[i][1] = max(dp[i-1][1], dp[i-1][0]-prices[i])
return dp[-1][0]

股票问题3

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
class Solution(object):
def maxProfit(self, prices):
"""
:type prices: List[int]
:rtype: int
"""
# 状态定义:dp[i][k][1 or 0]
# i表示第i天,0 <= i <= n-1
# k表示第k天,1 <= k <= K,假设在buy时k才加1
# 1表示当天持有者股票,0表示当天不持有
# 转移方程:
# dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1]+prices[i])
# max(昨天就没持有, 昨天持有今天卖出)
# dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0]-prices[i])
# max(昨天就持有, 昨天没有今天买入)
# 边界:i为-1,k为0
# dp[-1][k][0] = dp[i][0][0] = 0
# dp[-1][k][1] = dp[i][0][1] = -infinite
#
# k = 2
if len(prices) < 2:
return 0
max_k = 2 + 1
# 这样定义多维数组有坑啊,[0] * 2为一个list,后面的乘法都是乘的引用数...
# dp = [[[0] * 2] * max_k] * len(prices)
dp = [[[0] * 2 for _ in range(max_k)] for _ in range(len(prices))]
print(dp)
for k in range(1, max_k):
dp[0][k][0] = 0
dp[0][k][1] = 0 - prices[0]
for i in range(1, len(prices)):
dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1]+prices[i])
dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0]-prices[i])
return dp[-1][-1][0]

股票问题4

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
class Solution(object):
def maxProfit(self, k, prices):
"""
:type prices: List[int]
:rtype: int
"""
# 状态定义:dp[i][k][1 or 0]
# i表示第i天,0 <= i <= n-1
# k表示第k天,1 <= k <= K,假设在buy时k才加1
# 1表示当天持有者股票,0表示当天不持有
# 转移方程:
# dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1]+prices[i])
# max(昨天就没持有, 昨天持有今天卖出)
# dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0]-prices[i])
# max(昨天就持有, 昨天没有今天买入)
# 边界:i为-1,k为0
# dp[-1][k][0] = dp[i][0][0] = 0
# dp[-1][k][1] = dp[i][0][1] = -infinite
#
# k = 2
if len(prices) < 2:
return 0
# 等同于k不限次数的情况
if k > len(prices) / 2:
dp = [[0] * 2 for _ in range(len(prices))]
dp[0][0] = 0
dp[0][1] = 0 - prices[0]
for i in range(1, len(prices)):
dp[i][0] = max(dp[i-1][0], dp[i-1][1]+prices[i])
dp[i][1] = max(dp[i-1][1], dp[i-1][0]-prices[i])
return dp[-1][0]
max_k = k + 1
# 这样定义多维数组有坑啊,[0] * 2为一个list,后面的乘法都是乘的引用数...
# dp = [[[0] * 2] * max_k] * len(prices)
dp = [[[0] * 2 for _ in range(max_k)] for _ in range(len(prices))]
for _k in range(1, max_k):
dp[0][_k][0] = 0
dp[0][_k][1] = 0 - prices[0]
for i in range(1, len(prices)):
dp[i][_k][0] = max(dp[i-1][_k][0], dp[i-1][_k][1]+prices[i])
dp[i][_k][1] = max(dp[i-1][_k][1], dp[i-1][_k-1][0]-prices[i])
return dp[-1][-1][0]

股票问题含手续费

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution(object):
def maxProfit(self, prices, fee):
"""
:type prices: List[int]
:rtype: int
"""
# 状态定义:dp[i][k][1 or 0]
# i表示第i天,0 <= i <= n-1
# k表示第k天,1 <= k <= K,假设在buy时k才加1
# 1表示当天持有者股票,0表示当天不持有
# 转移方程:
# dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1]+prices[i])
# max(昨天就没持有, 昨天持有今天卖出)
# dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0]-prices[i])
# max(昨天就持有, 昨天没有今天买入)
# 边界:i为-1,k为0
# dp[-1][k][0] = dp[i][0][0] = 0
# dp[-1][k][1] = dp[i][0][1] = -infinite
#
# k这个状态可摘去,目标是求最大化利益,对k值取多少无要求
# dp[i][0] = max(dp[i-1][0], dp[i-1][1]+prices[i])
# max(昨天就没持有, 昨天持有今天卖出)
# dp[i][1] = max(dp[i-1][1], dp[i-1][0]-prices[i])
# max(昨天就持有, 昨天没有今天买入)
#
# 新增的限定条件为卖出股票时需要给手续费
# dp[i][k][0] = max(dp[i-1][0], dp[i-1][1]+prices[i]-fee)
if len(prices) < 2:
return 0
dp = [[0] * 2 for _ in range(len(prices))]
dp[0][0] = 0
dp[0][1] = 0 - prices[0]
for i in range(1, len(prices)):
dp[i][0] = max(dp[i-1][0], dp[i-1][1]+prices[i]-fee)
dp[i][1] = max(dp[i-1][1], dp[i-1][0]-prices[i])
return dp[-1][0]

股票问题含冻结时间

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
class Solution(object):
def maxProfit(self, prices):
"""
:type prices: List[int]
:rtype: int
"""
# 状态定义:dp[i][k][1 or 0]
# i表示第i天,0 <= i <= n-1
# k表示第k天,1 <= k <= K,假设在buy时k才加1
# 1表示当天持有者股票,0表示当天不持有
# 转移方程:
# dp[i][k][0] = max(dp[i-1][k][0], dp[i-1][k][1]+prices[i])
# max(昨天就没持有, 昨天持有今天卖出)
# dp[i][k][1] = max(dp[i-1][k][1], dp[i-1][k-1][0]-prices[i])
# max(昨天就持有, 昨天没有今天买入)
# 边界:i为-1,k为0
# dp[-1][k][0] = dp[i][0][0] = 0
# dp[-1][k][1] = dp[i][0][1] = -infinite
#
# k这个状态可摘去,目标是求最大化利益,对k值取多少无要求
# dp[i][0] = max(dp[i-1][0], dp[i-1][1]+prices[i])
# max(昨天就没持有, 昨天持有今天卖出)
# dp[i][1] = max(dp[i-1][1], dp[i-1][0]-prices[i])
# max(昨天就持有, 昨天没有今天买入)
#
# 新增的限定条件为卖出股票后无法在第二天买入,更新转移方程
# dp[i][k][1] = max(dp[i-1][1], dp[i-2][0]-prices[i])
if len(prices) < 2:
return 0
dp = [[0] * 2 for _ in range(len(prices))]
dp[0][0] = 0
dp[0][1] = 0 - prices[0]
for i in range(1, len(prices)):
dp[i][0] = max(dp[i-1][0], dp[i-1][1]+prices[i])
dp[i][1] = max(dp[i-1][1], dp[i-2][0]-prices[i]) # i-2表示隔天交易
return dp[-1][0]