Skip to content

Commit

Permalink
0.9.59 update
Browse files Browse the repository at this point in the history
  • Loading branch information
zengbin93 committed Sep 2, 2024
1 parent 907d53c commit dd65528
Show file tree
Hide file tree
Showing 2 changed files with 164 additions and 7 deletions.
168 changes: 162 additions & 6 deletions czsc/utils/ta.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
describe: 常用技术分析指标
"""
import numpy as np
import pandas as pd


def SMA(close: np.array, timeperiod=5):
Expand All @@ -22,9 +23,9 @@ def SMA(close: np.array, timeperiod=5):
res = []
for i in range(len(close)):
if i < timeperiod:
seq = close[0: i + 1]
seq = close[0 : i + 1]
else:
seq = close[i - timeperiod + 1: i + 1]
seq = close[i - timeperiod + 1 : i + 1]
res.append(seq.mean())
return np.array(res, dtype=np.double).round(4)

Expand Down Expand Up @@ -85,11 +86,11 @@ def KDJ(close: np.array, high: np.array, low: np.array):
lv = []
for i in range(len(close)):
if i < n:
h_ = high[0: i + 1]
l_ = low[0: i + 1]
h_ = high[0 : i + 1]
l_ = low[0 : i + 1]
else:
h_ = high[i - n + 1: i + 1]
l_ = low[i - n + 1: i + 1]
h_ = high[i - n + 1 : i + 1]
l_ = low[i - n + 1 : i + 1]
hv.append(max(h_))
lv.append(min(l_))

Expand Down Expand Up @@ -143,3 +144,158 @@ def RSQ(close: [np.array, list]) -> float:
rsq = 1 - ss_err / ss_tot

return round(rsq, 4)


def plus_di(high, low, close, timeperiod=14):
"""
Calculate Plus Directional Indicator (PLUS_DI) manually.
Parameters:
high (pd.Series): High price series.
low (pd.Series): Low price series.
close (pd.Series): Closing price series.
timeperiod (int): Number of periods to consider for the calculation.
Returns:
pd.Series: Plus Directional Indicator values.
"""
# Calculate the +DM (Directional Movement)
dm_plus = high - high.shift(1)
dm_plus[dm_plus < 0] = 0 # Only positive differences are considered

# Calculate the True Range (TR)
tr = pd.concat([high - low, (high - close.shift(1)).abs(), (low - close.shift(1)).abs()], axis=1).max(axis=1)

# Smooth the +DM and TR with Wilder's smoothing method
smooth_dm_plus = dm_plus.rolling(window=timeperiod).sum()
smooth_tr = tr.rolling(window=timeperiod).sum()

# Avoid division by zero
smooth_tr[smooth_tr == 0] = np.nan

# Calculate the Directional Indicator
plus_di_ = 100 * (smooth_dm_plus / smooth_tr)

return plus_di_


def minus_di(high, low, close, timeperiod=14):
"""
Calculate Minus Directional Indicator (MINUS_DI) manually.
Parameters:
high (pd.Series): High price series.
low (pd.Series): Low price series.
close (pd.Series): Closing price series.
timeperiod (int): Number of periods to consider for the calculation.
Returns:
pd.Series: Minus Directional Indicator values.
"""
# Calculate the -DM (Directional Movement)
dm_minus = (low.shift(1) - low).where((low.shift(1) - low) > (high - low.shift(1)), 0)

# Smooth the -DM with Wilder's smoothing method
smooth_dm_minus = dm_minus.rolling(window=timeperiod).sum()

# Calculate the True Range (TR)
tr = pd.concat([high - low, (high - close.shift(1)).abs(), (low - close.shift(1)).abs()], axis=1).max(axis=1)

# Smooth the TR with Wilder's smoothing method
smooth_tr = tr.rolling(window=timeperiod).sum()

# Avoid division by zero
smooth_tr[smooth_tr == 0] = pd.NA

# Calculate the Directional Indicator
minus_di_ = 100 * (smooth_dm_minus / smooth_tr.fillna(method="ffill"))

return minus_di_


def atr(high, low, close, timeperiod=14):
"""
Calculate Average True Range (ATR).
Parameters:
high (pd.Series): High price series.
low (pd.Series): Low price series.
close (pd.Series): Closing price series.
timeperiod (int): Number of periods to consider for the calculation.
Returns:
pd.Series: Average True Range values.
"""
# Calculate True Range (TR)
tr1 = high - low
tr2 = (high - close.shift()).abs()
tr3 = (close.shift() - low).abs()
tr = pd.concat([tr1, tr2, tr3], axis=1).max(axis=1)

# Calculate ATR
atr_ = tr.rolling(window=timeperiod).mean()

return atr_


def MFI(high, low, close, volume, timeperiod=14):
"""
Calculate Money Flow Index (MFI).
Parameters:
high (np.array): Array of high prices.
low (np.array): Array of low prices.
close (np.array): Array of closing prices.
volume (np.array): Array of trading volumes.
timeperiod (int): Number of periods to consider for the calculation.
Returns:
np.array: Array of Money Flow Index values.
"""
# Calculate Typical Price
typical_price = (high + low + close) / 3

# Calculate Raw Money Flow
raw_money_flow = typical_price * volume

# Calculate Positive and Negative Money Flow
positive_money_flow = np.where(typical_price > typical_price.shift(1), raw_money_flow, 0)
negative_money_flow = np.where(typical_price < typical_price.shift(1), raw_money_flow, 0)

# Calculate Money Ratio
money_ratio = (
positive_money_flow.rolling(window=timeperiod).sum() / negative_money_flow.rolling(window=timeperiod).sum()
)

# Calculate Money Flow Index
mfi = 100 - (100 / (1 + money_ratio))

return mfi


def CCI(high, low, close, timeperiod=14):
"""
Calculate Commodity Channel Index (CCI).
Parameters:
high (np.array): Array of high prices.
low (np.array): Array of low prices.
close (np.array): Array of closing prices.
timeperiod (int): Number of periods to consider for the calculation.
Returns:
np.array: Array of Commodity Channel Index values.
"""
# Typical Price
typical_price = (high + low + close) / 3

# Mean Deviation
mean_typical_price = np.mean(typical_price, axis=0)
mean_deviation = np.mean(np.abs(typical_price - mean_typical_price), axis=0)

# Constant
constant = 1 / (0.015 * timeperiod)

# CCI Calculation
cci = (typical_price - mean_typical_price) / (constant * mean_deviation)
return cci
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -27,4 +27,5 @@ optuna
cryptography
pytz
flask
scipy
scipy
requests_toolbelt

0 comments on commit dd65528

Please sign in to comment.