File size: 5,628 Bytes
e7b7078
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import re
import os
import json
import torch
import torch.nn as nn
from urllib.parse import urlparse
from transformers import AutoModel, AutoConfig, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput

# Matches a path segment that is exactly one of the listed slugs
# (case-insensitive), followed by either another '/' or the end of the string.
PROFILE_SLUGS = re.compile(
    r'/(profile|store|shop|freelancers?|biz|therapists?|counsellors?|'
    r'restaurants?|menu|cottage|actors?|celebrants?|broker-finder|'
    r'users?|usr|sellers?|vendors?|merchants?|dealers?|agents?|'
    r'members?|str|book|booking|appointments?)(/|$)',
    flags=re.IGNORECASE,
)

# Matches a path segment made up solely of three or more digits.
NUMERIC_ID_IN_PATH = re.compile(r'/\d{3,}(/|$)')

# Length of the vector produced by extract_tabular_features().
NUM_TABULAR_FEATURES = 6
# Output width of the tabular projection layer in UrlBertWithTabular.
TABULAR_HIDDEN_SIZE = 128

# Set of platform domains, loaded once at import time from a JSON array that
# lives next to this module. Membership is tested against the lower-cased
# registered domain of each URL, so entries are presumably stored lower-case
# (NOTE(review): confirm against known_platforms.json).
KNOWN_PLATFORMS_PATH = os.path.join(os.path.dirname(__file__), "known_platforms.json")
# Decode explicitly as UTF-8 so the result does not depend on the host locale.
with open(KNOWN_PLATFORMS_PATH, encoding="utf-8") as _f:
    KNOWN_PLATFORMS = set(json.load(_f))

try:
    import tldextract
    _get_registered_domain = lambda url: tldextract.extract(url).registered_domain.lower()
    _tld = lambda url: tldextract.extract(url).suffix.lower()
except ImportError:
    _get_registered_domain = lambda url: '.'.join(urlparse(url).netloc.lower().split('.')[-2:])
    _tld = lambda url: urlparse(url).netloc.lower().split('.')[-1]

_subdomain_dot_count = lambda url: max(0, urlparse(url).netloc.count('.') - 1)
_path_depth = lambda url: len([s for s in urlparse(url).path.split('/') if s])

def extract_tabular_features(url):
    """Return the hand-built feature vector for *url*.

    The list has six floats in [0, 1], in this fixed order:

    0. path contains a PROFILE_SLUGS segment (0/1)
    1. registered domain is in KNOWN_PLATFORMS (0/1)
    2. path depth / 10, capped at 1
    3. subdomain dot count / 3, capped at 1
    4. path contains an all-digit segment of 3+ digits (0/1)
    5. TLD is 'jp' (0/1)

    The order is what the model's tabular input layer receives, so it must
    not be changed.
    """
    # Parse once; both path regexes read the same path string.
    path = urlparse(url).path
    return [
        1.0 if PROFILE_SLUGS.search(path.lower()) else 0.0,
        1.0 if _get_registered_domain(url) in KNOWN_PLATFORMS else 0.0,
        min(_path_depth(url) / 10.0, 1.0),
        min(_subdomain_dot_count(url) / 3.0, 1.0),
        1.0 if NUMERIC_ID_IN_PATH.search(path) else 0.0,
        1.0 if _tld(url) == 'jp' else 0.0,
    ]


class UrlBertWithTabular(nn.Module):
    """Classifier that fuses a transformer encoding of a URL with tabular features.

    The encoder's hidden state at sequence position 0 is concatenated with a
    learned projection of the tabular feature vector, and one linear layer
    maps the result to ``num_labels`` logits.
    """

    def __init__(self, bert_model_name, num_labels, num_tabular_features=NUM_TABULAR_FEATURES):
        """Build the model around pretrained encoder weights.

        Args:
            bert_model_name: Name or path handed to ``AutoModel.from_pretrained``.
            num_labels: Number of output classes.
            num_tabular_features: Width of the tabular feature vector.
        """
        super().__init__()
        self._init_layers(
            AutoModel.from_pretrained(bert_model_name), num_labels, num_tabular_features
        )

    def _init_layers(self, bert, num_labels, num_tabular_features):
        """Attach *bert* and create the tabular projection and classifier head."""
        # Single source of truth for the architecture: both __init__ and
        # from_pretrained go through here, so the submodule names and shapes
        # cannot drift apart from what saved state dicts contain.
        self.bert = bert
        self.hidden_size = bert.config.hidden_size
        self.num_labels = num_labels
        self.num_tabular_features = num_tabular_features
        self.tabular_proj = nn.Sequential(
            nn.Linear(num_tabular_features, TABULAR_HIDDEN_SIZE),
            nn.ReLU(),
            nn.Dropout(0.1),
        )
        self.classifier = nn.Linear(self.hidden_size + TABULAR_HIDDEN_SIZE, num_labels)

    def forward(self, input_ids=None, attention_mask=None, token_type_ids=None, tabular_features=None, **kwargs):
        """Compute class logits for a batch.

        Args:
            input_ids, attention_mask, token_type_ids: Forwarded to the encoder.
            tabular_features: Tensor with one feature row per batch item; it is
                cast to float before projection.
            **kwargs: Accepted and ignored.

        Returns:
            A ``SequenceClassifierOutput`` with only ``logits`` populated.

        Raises:
            ValueError: If ``tabular_features`` is not provided.
        """
        if tabular_features is None:
            raise ValueError("tabular_features is required by UrlBertWithTabular.forward")
        encoded = self.bert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
        )
        # Use the hidden state of the first token as the sequence summary.
        first_token = encoded.last_hidden_state[:, 0, :]
        projected = self.tabular_proj(tabular_features.float())
        fused = torch.cat([first_token, projected], dim=1)
        return SequenceClassifierOutput(logits=self.classifier(fused))

    @classmethod
    def from_pretrained(cls, save_directory):
        """Rebuild a saved model from *save_directory*.

        The directory must contain ``tabular_config.json`` (keys ``num_labels``
        and ``num_tabular_features``), an encoder config readable by
        ``AutoConfig``, and the weights as ``model.safetensors`` or, failing
        that, ``pytorch_model.bin``.

        Returns:
            The model with all weights loaded. The train/eval mode is left at
            the ``nn.Module`` default; callers switch modes themselves.
        """
        with open(os.path.join(save_directory, "tabular_config.json")) as f:
            tabular_config = json.load(f)
        bert_config = AutoConfig.from_pretrained(save_directory)

        # Bypass __init__ so the encoder is created from its config only;
        # every weight is then supplied by the state dict loaded below.
        model = cls.__new__(cls)
        nn.Module.__init__(model)
        model._init_layers(
            AutoModel.from_config(bert_config),
            tabular_config["num_labels"],
            tabular_config["num_tabular_features"],
        )

        safetensors_path = os.path.join(save_directory, "model.safetensors")
        if os.path.exists(safetensors_path):
            from safetensors.torch import load_file
            state_dict = load_file(safetensors_path)
        else:
            bin_path = os.path.join(save_directory, "pytorch_model.bin")
            # weights_only=True restricts unpickling to tensor data.
            state_dict = torch.load(bin_path, map_location="cpu", weights_only=True)
        model.load_state_dict(state_dict)
        return model


# Class index -> label string returned in the endpoint's predictions.
LABEL_MAP = {
    0: "official_website",
    1: "platform",
}


class EndpointHandler:
    """Callable inference handler that classifies one URL or a batch of URLs."""

    def __init__(self, path=""):
        """Load the model and tokenizer from *path* and move the model to GPU if available."""
        self.model = UrlBertWithTabular.from_pretrained(path)
        self.model.eval()
        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)

    @staticmethod
    def _coerce_inputs(data):
        """Normalize a request payload into a list of URL strings."""
        # A dict payload carries the URLs under "inputs"; a bare str/list is
        # also accepted so the handler can be called directly.
        payload = data.get("inputs", data) if isinstance(data, dict) else data
        if isinstance(payload, str):
            return [payload]
        if isinstance(payload, (list, tuple)) and all(isinstance(u, str) for u in payload):
            return list(payload)
        raise ValueError(
            "Expected a URL string or a list of URL strings under 'inputs'"
        )

    def __call__(self, data):
        """Classify the URLs in *data*.

        Args:
            data: ``{"inputs": <str | list[str]>}``, or the string/list itself.

        Returns:
            One entry per input URL, in input order. Each entry is a list of
            ``{"label": str, "score": float}`` dicts covering every class,
            sorted by descending score. An empty batch yields ``[]``.

        Raises:
            ValueError: If the payload is not a string or a list of strings.
        """
        urls = self._coerce_inputs(data)
        if not urls:
            # Nothing to tokenize; the tokenizer and tensor construction
            # below do not accept an empty batch.
            return []

        encodings = self.tokenizer(
            urls, padding=True, truncation=True, max_length=128, return_tensors="pt"
        ).to(self.device)

        tabular = torch.tensor(
            [extract_tabular_features(url) for url in urls], dtype=torch.float32
        ).to(self.device)

        with torch.no_grad():
            outputs = self.model(
                input_ids=encodings["input_ids"],
                attention_mask=encodings["attention_mask"],
                tabular_features=tabular,
            )
            probs = torch.softmax(outputs.logits, dim=-1)

        results = []
        for scores in probs.tolist():
            predictions = [
                {"label": LABEL_MAP.get(idx, f"LABEL_{idx}"), "score": score}
                for idx, score in enumerate(scores)
            ]
            predictions.sort(key=lambda p: p["score"], reverse=True)
            results.append(predictions)
        return results