-->

whaust

2023年11月2日 星期四

採用「LSTM模型預測股價」作為模型的說明範例

 採用「LSTM模型預測股價」作為模型的說明範例 


當您想要使用 LSTM 模型來預測股價時,您可以使用 Python 和相關的套件,如 TensorFlow 或 PyTorch 來建立和訓練模型。以下是一個簡單的 Python 範例,演示如何使用 LSTM 預測股價。


首先,請確保您已安裝所需的套件,例如 TensorFlow 或 PyTorch。以下示例將使用 TensorFlow:






import numpy as np

import pandas as pd

import matplotlib.pyplot as plt

from sklearn.preprocessing import MinMaxScaler

from tensorflow.keras.models import Sequential

from tensorflow.keras.layers import LSTM, Dense


# 讀取股價數據

data = pd.read_csv('股價數據.csv')  # 請將 '股價數據.csv' 替換為實際數據文件


# 提取收盤價作為特徵

prices = data['Close'].values.astype(float)


# 對股價數據進行歸一化

scaler = MinMaxScaler()

prices = scaler.fit_transform(prices.reshape(-1, 1))


# 將數據分為訓練集和測試集

train_size = int(len(prices) * 0.8)

train_data = prices[:train_size]

test_data = prices[train_size:]


# 創建訓練和測試數據集

def create_dataset(dataset, look_back=1):

    X, Y = [], []

    for i in range(len(dataset) - look_back - 1):

        X.append(dataset[i:(i + look_back), 0])

        Y.append(dataset[i + look_back, 0])

    return np.array(X), np.array(Y)


look_back = 10  # 定義時間窗口大小

trainX, trainY = create_dataset(train_data, look_back)

testX, testY = create_dataset(test_data, look_back)


# 創建並訓練 LSTM 模型

model = Sequential()

model.add(LSTM(50, input_shape=(look_back, 1)))

model.add(Dense(1))

model.compile(loss='mean_squared_error', optimizer='adam')

model.fit(trainX, trainY, epochs=100, batch_size=1, verbose=2)


# 使用模型進行預測

train_predict = model.predict(trainX)

test_predict = model.predict(testX)


# 還原歸一化的數據

train_predict = scaler.inverse_transform(train_predict)

test_predict = scaler.inverse_transform(test_predict)


# 繪製訓練集和測試集的預測結果

plt.plot(prices, label='Actual Prices')

plt.plot(range(look_back, len(train_predict) + look_back), train_predict, label='Train Predictions')

plt.plot(range(len(train_predict) + (look_back * 2) + 1, len(prices)), test_predict, label='Test Predictions')

plt.legend()

plt.show()



這只是一個簡單的 LSTM 模型預測股價的示例。實際上,股價預測是一個相對複雜的問題,需要更多的特徵工程、調參和數據處理。該示例僅供參考,您可以根據您的需求進行進一步的優化和改進。

2 則留言:

  1. 請問'股價數據.csv'的欄位格式?謝謝

    回覆刪除
  2. 同學,欄位可以自己定義啊,看你要分析甚麼資料

    回覆刪除

Popular