File size: 3,540 Bytes
4d919ad
 
1e43f57
4d919ad
 
 
49c6030
4d919ad
 
 
 
 
 
 
 
 
 
 
 
 
 
1e43f57
4d919ad
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
import os
from io import StringIO

import gradio as gr
import numpy as np
import pandas as pd
import torch
from chronos import ChronosPipeline

# --- Model Loading ---
# Executed once, at import time, rather than inside the request handler, so
# every forecast request reuses the same pipeline object. If loading fails,
# ``pipeline`` is set to None and the handler checks for that instead of
# letting the app crash on startup.
model_name = "amazon/chronos-t5-small"
try:
    pipeline = ChronosPipeline.from_pretrained(
        model_name,
        device_map="cpu",  # Force CPU usage for free tier
        torch_dtype=torch.float32,
    )
except Exception as exc:
    # Fallback: keep the app importable and surface the failure in the logs.
    print(f"Error loading model: {exc}")
    pipeline = None
else:
    print(f"Loaded model: {model_name}")

# --- Prediction Function ---
def _read_last_column(csv_file):
    """Load the uploaded CSV and return its last column as a float32 array.

    Gradio 4+ passes ``gr.File`` values as a file path (``str``) by default;
    older versions pass a tempfile wrapper exposing the path as ``.name``.
    Raw ``bytes`` and file-like objects with ``read()`` are also accepted.
    Cells that cannot be parsed as numbers become NaN.
    """
    if isinstance(csv_file, (str, os.PathLike)):
        df = pd.read_csv(csv_file)
    elif isinstance(getattr(csv_file, "name", None), str):
        df = pd.read_csv(csv_file.name)
    else:
        raw = csv_file if isinstance(csv_file, (bytes, bytearray)) else csv_file.read()
        text = raw.decode("utf-8") if isinstance(raw, (bytes, bytearray)) else raw
        df = pd.read_csv(StringIO(text))

    if df.shape[1] == 0:
        raise gr.Error("The uploaded CSV has no columns.")
    # Coerce instead of trusting the dtype: a stray header/text cell would
    # otherwise produce an object array that torch.tensor cannot convert.
    return pd.to_numeric(df.iloc[:, -1], errors="coerce").to_numpy(dtype=np.float32)


def _build_plot_frame(historical_data, median_forecast):
    """Return a long-format DataFrame (``step``, ``value``, ``series``) for ``gr.LinePlot``."""
    n_hist = len(historical_data)
    history = pd.DataFrame({
        "step": np.arange(n_hist),
        "value": historical_data,
        "series": "Historical",
    })
    forecast = pd.DataFrame({
        "step": np.arange(n_hist, n_hist + len(median_forecast)),
        "value": median_forecast,
        "series": "Forecast",
    })
    # Missing observations keep their step index; they are dropped only for display.
    return pd.concat([history, forecast], ignore_index=True).dropna(subset=["value"])


def forecast_time_series(csv_file, prediction_length):
    """
    Forecast the time series stored in the last column of an uploaded CSV.

    Args:
        csv_file: Value of the ``gr.File`` input (a path, a tempfile wrapper,
            bytes, or a file-like object).
        prediction_length: Number of future steps to forecast (cast to ``int``).

    Returns:
        A long-format ``pandas.DataFrame`` with columns ``step``, ``value`` and
        ``series`` ("Historical" / "Forecast"), rendered by ``gr.LinePlot``.

    Raises:
        gr.Error: For every user-facing failure, so Gradio shows the message
            in the UI instead of trying to render a string as a plot.
    """
    if pipeline is None:
        raise gr.Error("Model failed to load. Please check logs/dependencies.")
    if csv_file is None:
        raise gr.Error("Please upload a CSV file.")

    try:
        horizon = int(prediction_length)
        historical_data = _read_last_column(csv_file)

        # Only observed (non-NaN) points count toward the minimum history.
        if np.count_nonzero(~np.isnan(historical_data)) < 50:
            raise gr.Error("Please upload a time series with at least 50 historical points for a good forecast.")

        # NaN gaps are passed through; Chronos masks NaN context values.
        forecast_samples = pipeline.predict(
            torch.tensor(historical_data, dtype=torch.float32),
            prediction_length=horizon,
            num_samples=20,  # Number of probabilistic paths to generate
        )

        # predict() returns (num_series, num_samples, prediction_length). Take
        # the single series first, then the median across the sample paths.
        median_forecast = np.median(forecast_samples[0].numpy(), axis=0)

        return _build_plot_frame(historical_data, median_forecast)

    except gr.Error:
        raise
    except Exception as e:
        raise gr.Error(f"An error occurred: {e}") from e


# --- Gradio Interface Setup ---
# Illustrative CSV layout only. Gradio examples for a ``gr.File`` input must
# be real file paths, so this list is not wired into the interface
# (``examples=None``); users upload their own CSV.
example_data = [
    [
        'date,value\n2025-01-01,10.0\n2025-01-02,11.5\n...\n2025-03-20,15.2',
        7
    ]
]

# gr.Plot only renders figure objects (matplotlib/plotly/bokeh/altair), so the
# forecast DataFrame is drawn with gr.LinePlot: one coloured line per ``series``.
gr_plot = gr.LinePlot(
    x="step",
    y="value",
    color="series",
    label="Time Series Forecast (Historical + Predicted Median)",
)

demo = gr.Interface(
    fn=forecast_time_series,
    inputs=[
        gr.File(
            label="Upload a CSV file (Time series must be in the last column)",
            file_types=[".csv"],
        ),
        gr.Slider(minimum=7, maximum=30, step=1, value=14, label="Number of Future Steps (Days) to Predict"),
    ],
    outputs=gr_plot,
    title="Chronos Time Series Forecasting Demo on Hugging Face",
    description="Upload a CSV file containing a single historical time series. This demo uses the Chronos-T5-Small Foundation Model to generate a 14-day (default) forecast.",
    examples=None,
    live=False,
)
demo.launch()