Fangzhi Xu commited on
Commit
006d68a
·
1 Parent(s): a21ffea
Files changed (2) hide show
  1. app.py +69 -22
  2. requirements.txt +2 -1
app.py CHANGED
@@ -5,6 +5,8 @@ import pandas as pd
5
  import matplotlib.pyplot as plt
6
  import matplotlib
7
  import os
 
 
8
  matplotlib.use('Agg')
9
 
10
  class TradeArenaEnv_Deterministic:
@@ -238,35 +240,80 @@ def create_news_display(obs):
238
  # ax.grid(True, alpha=0.3)
239
  # return fig
240
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
  def create_price_chart():
242
- """Create individual price chart for each stock"""
243
  if len(history) <= 1:
244
- fig, axs = plt.subplots(1, 1, figsize=(10, 6))
245
- axs.text(0.5, 0.5, 'Trade to see price history',
246
- ha='center', va='center', fontsize=14, color='gray')
247
- axs.axis('off')
 
 
 
 
 
 
 
 
 
 
248
  return fig
249
 
250
  df = pd.DataFrame(history)
251
- num_stocks = len(env.stocks)
252
- fig, axs = plt.subplots(num_stocks, 1, figsize=(10, 4*num_stocks), sharex=True)
253
-
254
- # 如果只有一个股票,axs不是数组,需要处理
255
- if num_stocks == 1:
256
- axs = [axs]
257
-
258
  colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
259
-
260
  for i, stock in enumerate(env.stocks):
261
- ax = axs[i]
262
- ax.plot(df['day'], df[stock], marker='o', linewidth=2, color=colors[i % len(colors)], label=stock)
263
- ax.set_ylabel(f'{stock} ($)')
264
- ax.set_title(f'{stock} Price History')
265
- ax.legend(loc='best', framealpha=0.8)
266
- ax.grid(True, alpha=0.3)
267
-
268
- axs[-1].set_xlabel('Day')
269
- plt.tight_layout()
 
 
 
 
 
 
 
 
 
270
  return fig
271
 
272
 
 
5
  import matplotlib.pyplot as plt
6
  import matplotlib
7
  import os
8
+ import plotly.graph_objects as go
9
+ import pandas as pd
10
  matplotlib.use('Agg')
11
 
12
  class TradeArenaEnv_Deterministic:
 
240
  # ax.grid(True, alpha=0.3)
241
  # return fig
242
 
243
+ # def create_price_chart():
244
+ # """Create individual price chart for each stock"""
245
+ # if len(history) <= 1:
246
+ # fig, axs = plt.subplots(1, 1, figsize=(10, 6))
247
+ # axs.text(0.5, 0.5, 'Trade to see price history',
248
+ # ha='center', va='center', fontsize=14, color='gray')
249
+ # axs.axis('off')
250
+ # return fig
251
+
252
+ # df = pd.DataFrame(history)
253
+ # num_stocks = len(env.stocks)
254
+ # fig, axs = plt.subplots(num_stocks, 1, figsize=(10, 4*num_stocks), sharex=True)
255
+
256
+ # # 如果只有一个股票,axs不是数组,需要处理
257
+ # if num_stocks == 1:
258
+ # axs = [axs]
259
+
260
+ # colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
261
+
262
+ # for i, stock in enumerate(env.stocks):
263
+ # ax = axs[i]
264
+ # ax.plot(df['day'], df[stock], marker='o', linewidth=2, color=colors[i % len(colors)], label=stock)
265
+ # ax.set_ylabel(f'{stock} ($)')
266
+ # ax.set_title(f'{stock} Price History')
267
+ # ax.legend(loc='best', framealpha=0.8)
268
+ # ax.grid(True, alpha=0.3)
269
+
270
+ # axs[-1].set_xlabel('Day')
271
+ # plt.tight_layout()
272
+ # return fig
273
+
274
+
275
  def create_price_chart():
276
+ """Create stock price chart using Plotly"""
277
  if len(history) <= 1:
278
+ # 没有交易历史时,返回空白图
279
+ fig = go.Figure()
280
+ fig.add_annotation(
281
+ text="Trade to see price history",
282
+ xref="paper", yref="paper",
283
+ showarrow=False,
284
+ font=dict(size=16, color="gray")
285
+ )
286
+ fig.update_layout(
287
+ xaxis=dict(visible=False),
288
+ yaxis=dict(visible=False),
289
+ template="plotly_white",
290
+ height=400
291
+ )
292
  return fig
293
 
294
  df = pd.DataFrame(history)
295
+ fig = go.Figure()
 
 
 
 
 
 
296
  colors = ['#3b82f6', '#10b981', '#f59e0b', '#ef4444', '#8b5cf6']
297
+
298
  for i, stock in enumerate(env.stocks):
299
+ fig.add_trace(go.Scatter(
300
+ x=df['day'],
301
+ y=df[stock],
302
+ mode='lines+markers',
303
+ name=stock,
304
+ line=dict(color=colors[i % len(colors)], width=2),
305
+ marker=dict(size=6)
306
+ ))
307
+
308
+ fig.update_layout(
309
+ title="Stock Price History",
310
+ xaxis_title="Day",
311
+ yaxis_title="Price ($)",
312
+ template="plotly_white",
313
+ legend=dict(title="Stocks", orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
314
+ height=400 + 50 * len(env.stocks)
315
+ )
316
+
317
  return fig
318
 
319
 
requirements.txt CHANGED
@@ -1,4 +1,5 @@
1
  gradio
2
  numpy
3
  pandas
4
- matplotlib
 
 
1
  gradio
2
  numpy
3
  pandas
4
+ matplotlib
5
+ plotly