Spaces:
Sleeping
Sleeping
Fangzhi Xu
commited on
Commit
·
006d68a
1
Parent(s):
a21ffea
Config
Browse files- app.py +69 -22
- requirements.txt +2 -1
app.py
CHANGED
|
@@ -5,6 +5,8 @@ import pandas as pd
|
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import matplotlib
|
| 7 |
import os
|
|
|
|
|
|
|
| 8 |
matplotlib.use('Agg')
|
| 9 |
|
| 10 |
class TradeArenaEnv_Deterministic:
|
|
@@ -238,35 +240,80 @@ def create_news_display(obs):
|
|
| 238 |
# ax.grid(True, alpha=0.3)
|
| 239 |
# return fig
|
| 240 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 241 |
def create_price_chart():
|
| 242 |
-
"""Create
|
| 243 |
if len(history) <= 1:
|
| 244 |
-
|
| 245 |
-
|
| 246 |
-
|
| 247 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 248 |
return fig
|
| 249 |
|
| 250 |
df = pd.DataFrame(history)
|
| 251 |
-
|
| 252 |
-
fig, axs = plt.subplots(num_stocks, 1, figsize=(10, 4*num_stocks), sharex=True)
|
| 253 |
-
|
| 254 |
-
# 如果只有一个股票,axs不是数组,需要处理
|
| 255 |
-
if num_stocks == 1:
|
| 256 |
-
axs = [axs]
|
| 257 |
-
|
| 258 |
colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
|
| 259 |
-
|
| 260 |
for i, stock in enumerate(env.stocks):
|
| 261 |
-
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 270 |
return fig
|
| 271 |
|
| 272 |
|
|
|
|
| 5 |
import matplotlib.pyplot as plt
|
| 6 |
import matplotlib
|
| 7 |
import os
|
| 8 |
+
import plotly.graph_objects as go
|
| 9 |
+
import pandas as pd
|
| 10 |
matplotlib.use('Agg')
|
| 11 |
|
| 12 |
class TradeArenaEnv_Deterministic:
|
|
|
|
| 240 |
# ax.grid(True, alpha=0.3)
|
| 241 |
# return fig
|
| 242 |
|
| 243 |
+
# def create_price_chart():
|
| 244 |
+
# """Create individual price chart for each stock"""
|
| 245 |
+
# if len(history) <= 1:
|
| 246 |
+
# fig, axs = plt.subplots(1, 1, figsize=(10, 6))
|
| 247 |
+
# axs.text(0.5, 0.5, 'Trade to see price history',
|
| 248 |
+
# ha='center', va='center', fontsize=14, color='gray')
|
| 249 |
+
# axs.axis('off')
|
| 250 |
+
# return fig
|
| 251 |
+
|
| 252 |
+
# df = pd.DataFrame(history)
|
| 253 |
+
# num_stocks = len(env.stocks)
|
| 254 |
+
# fig, axs = plt.subplots(num_stocks, 1, figsize=(10, 4*num_stocks), sharex=True)
|
| 255 |
+
|
| 256 |
+
# # 如果只有一个股票,axs不是数组,需要处理
|
| 257 |
+
# if num_stocks == 1:
|
| 258 |
+
# axs = [axs]
|
| 259 |
+
|
| 260 |
+
# colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
|
| 261 |
+
|
| 262 |
+
# for i, stock in enumerate(env.stocks):
|
| 263 |
+
# ax = axs[i]
|
| 264 |
+
# ax.plot(df['day'], df[stock], marker='o', linewidth=2, color=colors[i % len(colors)], label=stock)
|
| 265 |
+
# ax.set_ylabel(f'{stock} ($)')
|
| 266 |
+
# ax.set_title(f'{stock} Price History')
|
| 267 |
+
# ax.legend(loc='best', framealpha=0.8)
|
| 268 |
+
# ax.grid(True, alpha=0.3)
|
| 269 |
+
|
| 270 |
+
# axs[-1].set_xlabel('Day')
|
| 271 |
+
# plt.tight_layout()
|
| 272 |
+
# return fig
|
| 273 |
+
|
| 274 |
+
|
| 275 |
def create_price_chart():
|
| 276 |
+
"""Create stock price chart using Plotly"""
|
| 277 |
if len(history) <= 1:
|
| 278 |
+
# 没有交易历史时,返回空白图
|
| 279 |
+
fig = go.Figure()
|
| 280 |
+
fig.add_annotation(
|
| 281 |
+
text="Trade to see price history",
|
| 282 |
+
xref="paper", yref="paper",
|
| 283 |
+
showarrow=False,
|
| 284 |
+
font=dict(size=16, color="gray")
|
| 285 |
+
)
|
| 286 |
+
fig.update_layout(
|
| 287 |
+
xaxis=dict(visible=False),
|
| 288 |
+
yaxis=dict(visible=False),
|
| 289 |
+
template="plotly_white",
|
| 290 |
+
height=400
|
| 291 |
+
)
|
| 292 |
return fig
|
| 293 |
|
| 294 |
df = pd.DataFrame(history)
|
| 295 |
+
fig = go.Figure()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
|
| 297 |
+
|
| 298 |
for i, stock in enumerate(env.stocks):
|
| 299 |
+
fig.add_trace(go.Scatter(
|
| 300 |
+
x=df['day'],
|
| 301 |
+
y=df[stock],
|
| 302 |
+
mode='lines+markers',
|
| 303 |
+
name=stock,
|
| 304 |
+
line=dict(color=colors[i % len(colors)], width=2),
|
| 305 |
+
marker=dict(size=6)
|
| 306 |
+
))
|
| 307 |
+
|
| 308 |
+
fig.update_layout(
|
| 309 |
+
title="Stock Price History",
|
| 310 |
+
xaxis_title="Day",
|
| 311 |
+
yaxis_title="Price ($)",
|
| 312 |
+
template="plotly_white",
|
| 313 |
+
legend=dict(title="Stocks", orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
|
| 314 |
+
height=400 + 50 * len(env.stocks)
|
| 315 |
+
)
|
| 316 |
+
|
| 317 |
return fig
|
| 318 |
|
| 319 |
|
requirements.txt
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
gradio
|
| 2 |
numpy
|
| 3 |
pandas
|
| 4 |
-
matplotlib
|
|
|
|
|
|
| 1 |
gradio
|
| 2 |
numpy
|
| 3 |
pandas
|
| 4 |
+
matplotlib
|
| 5 |
+
plotly
|