Skip to content

BERTSubs (Intra)

BERTSubsIntraPipeline(onto, config)

Class for the intra-ontology subsumption prediction setting of BERTSubs.

Attributes:

Name Type Description
onto Ontology

The target ontology.

config CfgNode

The configuration for BERTSubs.

sampler SubsumptionSampler

The subsumption sampler for BERTSubs.

Source code in src/deeponto/complete/bertsubs/pipeline_intra.py
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
def __init__(self, onto: Ontology, config: CfgNode):
    r"""Run the whole intra-ontology BERTSubs pipeline.

    The constructor (1) builds the subsumption sampler, (2) reads or extracts the
    train/valid/test subsumptions, (3) turns them into text samples, (4) fine-tunes
    (or just loads) the BERT classifier, and (5) evaluates on the validation set and,
    if a test file is configured, evaluates or predicts on the test set.

    Args:
        onto (Ontology): The target ontology.
        config (CfgNode): The configuration for BERTSubs.
    """
    self.onto = onto
    self.config = config
    self.sampler = SubsumptionSampler(onto=onto, config=config)
    start_time = datetime.datetime.now()

    class_num = len(self.sampler.named_classes)
    label_num = sum(len(self.sampler.iri_label[k]) for k in self.sampler.named_classes)
    # Guard the average so that an ontology without named classes does not crash here.
    print(
        "%d named classes, %.1f labels per class"
        % (class_num, (label_num / class_num) if class_num > 0 else 0.0)
    )

    def read_subsumptions(file_name):
        # One subsumption per line, components separated by commas. The file is closed
        # deterministically and blank lines are skipped (they would otherwise become [""]).
        with open(file_name) as f:
            return [line.strip().split(",") for line in f if line.strip()]

    test_subsumptions = (
        None
        if config.test_subsumption_file is None or config.test_subsumption_file == "None"
        else read_subsumptions(config.test_subsumption_file)
    )

    # The train/valid subsumptions are not given. They will be extracted from the given ontology:
    if config.train_subsumption_file is None or config.train_subsumption_file == "None":
        subsumptions0 = self.extract_subsumptions_from_ontology(
            onto=onto, subsumption_type=config.subsumption_type
        )
        random.shuffle(subsumptions0)
        valid_size = int(len(subsumptions0) * config.valid.valid_ratio)
        train_subsumptions0, valid_subsumptions0 = subsumptions0[valid_size:], subsumptions0[0:valid_size]
        train_subsumptions, valid_subsumptions = [], []
        if config.subsumption_type == "named_class":
            for subs in train_subsumptions0:
                c1, c2 = subs.getSubClass(), subs.getSuperClass()
                train_subsumptions.append([str(c1.getIRI()), str(c2.getIRI())])

            size_sum = 0
            for subs in valid_subsumptions0:
                c1, c2 = subs.getSubClass(), subs.getSuperClass()
                neg_candidates = BERTSubsIntraPipeline.get_test_neg_candidates_named_class(
                    subclass=c1, gt=c2, max_neg_size=config.valid.max_neg_size, onto=onto
                )
                size = len(neg_candidates)
                size_sum += size
                # Validation items without any negative candidate are dropped.
                if size > 0:
                    item = [str(c1.getIRI()), str(c2.getIRI())] + [str(c.getIRI()) for c in neg_candidates]
                    valid_subsumptions.append(item)
            print(
                "\t average neg candidate size in validation: %.2f"
                % ((size_sum / len(valid_subsumptions)) if valid_subsumptions else 0.0)
            )

        elif config.subsumption_type == "restriction":
            for subs in train_subsumptions0:
                c1, c2 = subs.getSubClass(), subs.getSuperClass()
                train_subsumptions.append([str(c1.getIRI()), str(c2)])

            restrictions = BERTSubsIntraPipeline.extract_restrictions_from_ontology(onto=onto)
            print("restrictions: %d" % len(restrictions))
            size_sum = 0
            for subs in valid_subsumptions0:
                c1, c2 = subs.getSubClass(), subs.getSuperClass()
                c2_neg = BERTSubsIntraPipeline.get_test_neg_candidates_restriction(
                    subcls=c1, max_neg_size=config.valid.max_neg_size, restrictions=restrictions, onto=onto
                )
                size_sum += len(c2_neg)
                item = [str(c1.getIRI()), str(c2)] + [str(r) for r in c2_neg]
                valid_subsumptions.append(item)
            print(
                "valid candidate negative avg. size: %.1f"
                % ((size_sum / len(valid_subsumptions)) if valid_subsumptions else 0.0)
            )
        else:
            warnings.warn("Unknown subsumption type %s" % config.subsumption_type)
            sys.exit(0)

    # The train/valid subsumptions are given:
    else:
        train_subsumptions = read_subsumptions(config.train_subsumption_file)
        valid_subsumptions = read_subsumptions(config.valid_subsumption_file)

    print("Positive train/valid subsumptions: %d/%d" % (len(train_subsumptions), len(valid_subsumptions)))
    tr = self.sampler.generate_samples(subsumptions=train_subsumptions)
    va = self.sampler.generate_samples(subsumptions=valid_subsumptions, duplicate=False)

    end_time = datetime.datetime.now()
    # total_seconds() covers the whole duration; timedelta.seconds would drop full days.
    print("data pre-processing costs %.1f minutes" % ((end_time - start_time).total_seconds() / 60))

    start_time = datetime.datetime.now()
    torch.cuda.empty_cache()
    bert_trainer = BERTSubsumptionClassifierTrainer(
        config.fine_tune.pretrained,
        train_data=tr,
        val_data=va,
        max_length=config.prompt.max_length,
        early_stop=config.fine_tune.early_stop,
    )

    epoch_steps = len(bert_trainer.tra) // config.fine_tune.batch_size  # total steps of an epoch
    # Log roughly every 2% of an epoch (falling back to 5 steps), evaluate/save every 5 logging intervals.
    logging_steps = int(epoch_steps * 0.02) if int(epoch_steps * 0.02) > 0 else 5
    eval_steps = 5 * logging_steps
    training_args = TrainingArguments(
        output_dir=config.fine_tune.output_dir,
        num_train_epochs=config.fine_tune.num_epochs,
        per_device_train_batch_size=config.fine_tune.batch_size,
        per_device_eval_batch_size=config.fine_tune.batch_size,
        warmup_ratio=config.fine_tune.warm_up_ratio,
        weight_decay=0.01,
        logging_steps=logging_steps,
        logging_dir=f"{config.fine_tune.output_dir}/tb",
        eval_steps=eval_steps,
        evaluation_strategy="steps",
        do_train=True,
        do_eval=True,
        save_steps=eval_steps,
        load_best_model_at_end=True,
        save_total_limit=1,
        metric_for_best_model="accuracy",
        greater_is_better=True,
    )
    # The <SUB> token is only registered when fine-tuning with a prompt type that emits it.
    if config.fine_tune.do_fine_tune and (
        config.prompt.prompt_type == "traversal"
        or (config.prompt.prompt_type == "path" and config.prompt.use_sub_special_token)
    ):
        bert_trainer.add_special_tokens(["<SUB>"])

    bert_trainer.train(train_args=training_args, do_fine_tune=config.fine_tune.do_fine_tune)
    if config.fine_tune.do_fine_tune:
        bert_trainer.trainer.save_model(
            output_dir=os.path.join(config.fine_tune.output_dir, "fine-tuned-checkpoint")
        )
        print("fine-tuning done, fine-tuned model saved")
    else:
        print("pretrained or fine-tuned model loaded.")
    end_time = datetime.datetime.now()
    print("Fine-tuning costs %.1f minutes" % ((end_time - start_time).total_seconds() / 60))

    bert_trainer.model.eval()
    self.device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    bert_trainer.model.to(self.device)
    self.tokenize = lambda x: bert_trainer.tokenizer(
        x, max_length=config.prompt.max_length, truncation=True, padding=True, return_tensors="pt"
    )
    softmax = torch.nn.Softmax(dim=1)
    # The classifier returns the softmax value at index 1 of the logits for each input.
    self.classifier = lambda x: softmax(bert_trainer.model(**x).logits)[:, 1]

    self.evaluate(target_subsumptions=valid_subsumptions, test_type="valid")
    if test_subsumptions is not None:
        if config.test_type == "evaluation":
            self.evaluate(target_subsumptions=test_subsumptions, test_type="test")
        elif config.test_type == "prediction":
            self.predict(target_subsumptions=test_subsumptions)
        else:
            warnings.warn("Unknown test_type: %s" % config.test_type)
    print("\n ------------------------- done! ---------------------------\n\n\n")

score(samples)

The scoring function based on the fine-tuned BERT classifier.

Parameters:

Name Type Description Default
samples List[Tuple]

A list of input sentence pairs to be scored.

required
Source code in src/deeponto/complete/bertsubs/pipeline_intra.py
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
def score(self, samples: List[List]):
    r"""Score sentence pairs with the fine-tuned BERT classifier, one batch at a time.

    Args:
        samples (List[Tuple]): A list of input sentence pairs to be scored.

    Returns:
        A numpy array holding one classifier score per input sample, in input order.
    """
    batch_size = self.config.evaluation.batch_size
    total = len(samples)
    scores = np.zeros(total)
    for batch_idx in range(math.ceil(total / batch_size)):
        begin = batch_idx * batch_size
        # The last batch may be shorter than batch_size.
        end = min(begin + batch_size, total)
        inputs = self.tokenize(samples[begin:end])
        inputs.to(self.device)
        with torch.no_grad():
            batch_scores = self.classifier(inputs)
        scores[begin:end] = batch_scores.cpu().numpy()
    return scores

evaluate(target_subsumptions, test_type='test')

Test and calculate the metrics for a given list of subsumption pairs.

Parameters:

Name Type Description Default
target_subsumptions List[Tuple]

A list of subsumption pairs.

required
test_type str

test for testing or valid for validation.

'test'
Source code in src/deeponto/complete/bertsubs/pipeline_intra.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
def evaluate(self, target_subsumptions: List[List], test_type: str = "test"):
    r"""Test and calculate the metrics for a given list of subsumption pairs.

    Each item is `[sub_class, ground_truth, negative_1, negative_2, ...]`: the ground truth
    and the negatives are all scored as candidate super-classes of `sub_class`, ranked by
    their average sample score, and MRR / Hits@1 / Hits@5 / Hits@10 are printed.

    Args:
        target_subsumptions (List[Tuple]): A list of subsumption pairs.
        test_type (str): `test` for testing or `valid` for validation.
    """

    MRR_sum, hits1_sum, hits5_sum, hits10_sum = 0, 0, 0, 0
    MRR, Hits1, Hits5, Hits10 = 0, 0, 0, 0
    size_sum, size_n = 0, 0
    for num, test in enumerate(target_subsumptions, start=1):
        subcls, gt = test[0], test[1]
        # The ground truth stays in the candidate list so that it is ranked with the negatives.
        candidates = test[1:]

        candidate_scores = np.zeros(len(candidates))
        for k1, candidate in enumerate(candidates):
            samples = self.sampler.subsumptions_to_samples(subsumptions=[[subcls, candidate]], sample_label=None)
            size_sum += len(samples)
            size_n += 1
            # A candidate's score is the mean over all of its text samples.
            candidate_scores[k1] = np.average(self.score(samples=samples))

        sorted_indexes = np.argsort(candidate_scores)[::-1]
        sorted_classes = [candidates[i] for i in sorted_indexes]

        rank = sorted_classes.index(gt) + 1
        MRR_sum += 1.0 / rank
        hits1_sum += 1 if rank <= 1 else 0
        hits5_sum += 1 if rank <= 5 else 0
        hits10_sum += 1 if rank <= 10 else 0
        MRR, Hits1, Hits5, Hits10 = MRR_sum / num, hits1_sum / num, hits5_sum / num, hits10_sum / num
        if num % 500 == 0:
            print(
                "\n%d tested, MRR: %.3f, Hits@1: %.3f, Hits@5: %.3f, Hits@10: %.3f\n"
                % (num, MRR, Hits1, Hits5, Hits10)
            )
    print(
        "\n[%s], MRR: %.3f, Hits@1: %.3f, Hits@5: %.3f, Hits@10: %.3f\n" % (test_type, MRR, Hits1, Hits5, Hits10)
    )
    # With nothing to evaluate there is no candidate at all; report 0 instead of dividing by zero.
    print("%.2f samples per testing subsumption" % ((size_sum / size_n) if size_n > 0 else 0.0))

predict(target_subsumptions)

Predict a score for each given subsumption in the list.

The scores will be saved in test_subsumption_scores.csv.

Parameters:

Name Type Description Default
target_subsumptions List[List]

Each item is a list where the first element is a fixed ontology class \(C\), and the remaining elements are potential (candidate) super-classes of \(C\).

required
Source code in src/deeponto/complete/bertsubs/pipeline_intra.py
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
def predict(self, target_subsumptions: List[List]):
    r"""Predict a score for each given subsumption in the list.

    One comma-separated line of scores is written per input item to
    `test_subsumption_scores.csv`, in the order of that item's candidates.

    Args:
        target_subsumptions (List[List]): Each item is a list where the first element is a fixed ontology class $C$,
            and the remaining elements are potential (candidate) super-classes of $C$.
    """
    rows = []
    for item in target_subsumptions:
        subcls, candidates = item[0], item[1:]
        averaged = []
        for candidate in candidates:
            samples = self.sampler.subsumptions_to_samples(subsumptions=[[subcls, candidate]], sample_label=None)
            # A candidate's score is the mean over all of its text samples.
            averaged.append(np.average(self.score(samples=samples)))
        rows.append(",".join(map(str, averaged)))

    out_file = "test_subsumption_scores.csv"
    with open(out_file, "w") as f:
        f.writelines("%s\n" % row for row in rows)
    print("Predicted subsumption scores are saved to %s" % out_file)

extract_subsumptions_from_ontology(onto, subsumption_type) staticmethod

Extract target subsumptions from a given ontology.

Parameters:

Name Type Description Default
onto Ontology

The target ontology.

required
subsumption_type str

the type of subsumptions, options are "named_class" or "restriction".

required
Source code in src/deeponto/complete/bertsubs/pipeline_intra.py
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
@staticmethod
def extract_subsumptions_from_ontology(onto: Ontology, subsumption_type: str):
    r"""Collect the class subsumption axioms of an ontology that match the requested type.

    Args:
        onto (Ontology): The target ontology.
        subsumption_type (str): the type of subsumptions, options are `"named_class"` or `"restriction"`.

    Returns:
        A list of the selected subsumption axioms; empty (after a warning) for an unknown type.
    """
    axioms = onto.get_subsumption_axioms(entity_type="Classes")

    if subsumption_type == "restriction":
        # Keep axioms whose sub-class is not deprecated and whose super-class is an
        # unnamed expression in the basic existential restriction form.
        return [
            axiom
            for axiom in axioms
            if not onto.check_deprecated(owl_object=axiom.getSubClass())
            and not onto.check_named_entity(owl_object=axiom.getSuperClass())
            and SubsumptionSampler.is_basic_existential_restriction(
                complex_class_str=str(axiom.getSuperClass())
            )
        ]

    if subsumption_type == "named_class":
        # Keep axioms where both sides are named and neither side is deprecated.
        selected = []
        for axiom in axioms:
            sub_cls, super_cls = axiom.getSubClass(), axiom.getSuperClass()
            if not onto.check_named_entity(owl_object=sub_cls) or onto.check_deprecated(owl_object=sub_cls):
                continue
            if not onto.check_named_entity(owl_object=super_cls) or onto.check_deprecated(owl_object=super_cls):
                continue
            selected.append(axiom)
        return selected

    warnings.warn("\nUnknown subsumption type: %s\n" % subsumption_type)
    return []

extract_restrictions_from_ontology(onto) staticmethod

Extract basic existential restriction from an ontology.

Parameters:

Name Type Description Default
onto Ontology

The target ontology.

required

Returns:

Name Type Description
restrictions List

a list of existential restrictions.

Source code in src/deeponto/complete/bertsubs/pipeline_intra.py
326
327
328
329
330
331
332
333
334
335
336
337
338
339
@staticmethod
def extract_restrictions_from_ontology(onto: Ontology):
    r"""Collect the asserted complex classes that are basic existential restrictions.

    Args:
        onto (Ontology): The target ontology.
    Returns:
        restrictions (List): a list of existential restrictions.
    """
    return [
        complex_cls
        for complex_cls in onto.get_asserted_complex_classes()
        if SubsumptionSampler.is_basic_existential_restriction(complex_class_str=str(complex_cls))
    ]

get_test_neg_candidates_restriction(subcls, max_neg_size, restrictions, onto) staticmethod

Get a list of negative candidate class restrictions for testing.

Source code in src/deeponto/complete/bertsubs/pipeline_intra.py
341
342
343
344
345
346
347
348
349
350
351
@staticmethod
def get_test_neg_candidates_restriction(subcls, max_neg_size, restrictions, onto):
    """Get a list of negative candidate class restrictions for testing."""
    neg_restrictions = list()
    n = max_neg_size * 2 if max_neg_size * 2 <= len(restrictions) else len(restrictions)
    for r in random.sample(restrictions, n):
        if not onto.reasoner.check_subsumption(sub_entity=subcls, super_entity=r):
            neg_restrictions.append(r)
            if len(neg_restrictions) >= max_neg_size:
                break
    return neg_restrictions

get_test_neg_candidates_named_class(subclass, gt, max_neg_size, onto, max_depth=3, max_width=8) staticmethod

Get a list of negative candidate named classes for testing.

Source code in src/deeponto/complete/bertsubs/pipeline_intra.py
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
@staticmethod
def get_test_neg_candidates_named_class(subclass, gt, max_neg_size, onto, max_depth=3, max_width=8):
    """Get a list of negative candidate named classes for testing."""
    all_nebs, seeds = set(), [gt]
    depth = 1
    while depth <= max_depth:
        new_seeds = set()
        for seed in seeds:
            nebs = set()
            for nc_iri in onto.reasoner.get_inferred_sub_entities(
                seed, direct=True
            ) + onto.reasoner.get_inferred_super_entities(seed, direct=True):
                nc = onto.owl_classes[nc_iri]
                if onto.check_named_entity(owl_object=nc) and not onto.check_deprecated(owl_object=nc):
                    nebs.add(nc)
            new_seeds = new_seeds.union(nebs)
            all_nebs = all_nebs.union(nebs)
        depth += 1
        seeds = random.sample(new_seeds, max_width) if len(new_seeds) > max_width else new_seeds
    all_nebs = (
        all_nebs
        - {onto.owl_classes[iri] for iri in onto.reasoner.get_inferred_super_entities(subclass, direct=False)}
        - {subclass}
    )
    if len(all_nebs) > max_neg_size:
        return random.sample(all_nebs, max_neg_size)
    else:
        return list(all_nebs)

SubsumptionSampler(onto, config)

Class for sampling functions for training the subsumption prediction model.

Attributes:

Name Type Description
onto Ontology

The target ontology.

config CfgNode

The loaded configuration.

named_classes Set[str]

IRIs of named classes that are not deprecated.

iri_label Dict[str, List]

key -- class iris from named_classes, value -- a list of labels.

restrictionObjects Set[OWLClassExpression]

Basic existential restrictions that appear in the ontology.

restrictions set[str]

Strings of basic existential restrictions corresponding to restrictionObjects.

restriction_label Dict[str, List]

key -- existential restriction string, value -- a list of existential restriction labels.

verb OntologyVerbaliser

object for verbalisation.

Source code in src/deeponto/complete/bertsubs/text_semantics.py
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
def __init__(self, onto: Ontology, config: CfgNode):
    r"""Index the labels of named classes and of basic existential restrictions.

    Args:
        onto (Ontology): The target ontology.
        config (CfgNode): The loaded configuration.
    """
    self.onto = onto
    self.config = config
    self.named_classes = self.extract_named_classes(onto=onto)

    # For every named class, gather its distinct annotation strings over all configured
    # label properties, keeping the order in which they are first seen.
    self.iri_label = dict()
    for iri in self.named_classes:
        labels = []
        self.iri_label[iri] = labels
        for label_property in config.label_property:
            annotations = onto.get_annotations(
                owl_object=onto.get_owl_object(iri),
                annotation_property_iri=label_property,
                annotation_language_tag=None,
                apply_lowercasing=False,
                normalise_identifiers=False,
            )
            for text in annotations:
                if text not in labels:
                    labels.append(text)

    self.restrictionObjects = set()
    self.restrictions = set()
    self.restriction_label = dict()
    self.verb = OntologyVerbaliser(onto=onto)
    for complexC in onto.get_asserted_complex_classes():
        expr_str = str(complexC)
        # Every complex class gets an entry; only basic existential restrictions get a label.
        self.restriction_label[expr_str] = []
        if not self.is_basic_existential_restriction(complex_class_str=expr_str):
            continue
        self.restrictionObjects.add(complexC)
        self.restrictions.add(expr_str)
        self.restriction_label[expr_str].append(self.verb.verbalise_class_expression(complexC).verbal)

is_basic_existential_restriction(complex_class_str) staticmethod

Determine if a complex class expression is a basic existential restriction.

Source code in src/deeponto/complete/bertsubs/text_semantics.py
75
76
77
78
79
80
81
82
83
@staticmethod
def is_basic_existential_restriction(complex_class_str: str):
    """Determine if a complex class expression is a basic existential restriction."""
    IRI = "<https?:\\/\\/(?:www\\.)?[-a-zA-Z0-9@:%._\\+~#=]{1,256}\\.[a-zA-Z0-9()]{1,6}\\b(?:[-a-zA-Z0-9()@:%_\\+.~#?&\\/=]*)>"
    p = rf"ObjectSomeValuesFrom\({IRI}\s{IRI}\)"
    if re.match(p, complex_class_str):
        return True
    else:
        return False

generate_samples(subsumptions, duplicate=True)

Generate text samples from subsumptions.

Parameters:

Name Type Description Default
subsumptions List[List]

A list of subsumptions, each of which is a two-component list (sub_class_iri, super_class_iri_or_str).

required
duplicate bool

True -- duplicate the positive and negative samples, False -- do not duplicate.

True

Returns:

Type Description
List[List]

A list of samples, each element is a triple in the form of (sub_class_string, super_class_string, label_index).

Source code in src/deeponto/complete/bertsubs/text_semantics.py
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
def generate_samples(self, subsumptions: List[List], duplicate: bool = True):
    r"""Generate text samples from subsumptions.

    For every positive subsumption, negative subsumptions are drawn with
    `get_negative_sample`; both sets are turned into text samples, the smaller set is
    padded by random re-sampling until the two have equal size, samples with an empty
    string on either side are dropped, and the result is shuffled.

    Args:
        subsumptions (List[List]): A list of subsumptions, each of which is a two-component list `(sub_class_iri, super_class_iri_or_str)`.
        duplicate (bool): `True` -- duplicate the positive and negative samples, `False` -- do not duplicate.

    Returns:
        (List[List]): A list of samples, each element is a triple
            in the form of `(sub_class_string, super_class_string, label_index)`.
    """
    if duplicate:
        pos_dup, neg_dup = self.config.fine_tune.train_pos_dup, self.config.fine_tune.train_neg_dup
    else:
        pos_dup, neg_dup = 1, 1

    neg_subsumptions = list()
    for subs in subsumptions:
        c1 = subs[0]
        for _ in range(neg_dup):
            neg_c = self.get_negative_sample(subclass_iri=c1, subsumption_type=self.config.subsumption_type)
            # The sampler may fail to find a negative; such draws are skipped.
            if neg_c is not None:
                neg_subsumptions.append([c1, neg_c])

    pos_samples = pos_dup * self.subsumptions_to_samples(subsumptions=subsumptions, sample_label=1)
    neg_samples = self.subsumptions_to_samples(subsumptions=neg_subsumptions, sample_label=0)

    # Balance the two classes by re-sampling the smaller one. An empty side cannot be
    # re-sampled (random.choice would raise IndexError), so it is left empty.
    gap = len(pos_samples) - len(neg_samples)
    if gap > 0 and neg_samples:
        neg_samples = neg_samples + [random.choice(neg_samples) for _ in range(gap)]
    elif gap < 0 and pos_samples:
        pos_samples = pos_samples + [random.choice(pos_samples) for _ in range(-gap)]

    print("pos_samples: %d, neg_samples: %d" % (len(pos_samples), len(neg_samples)))
    all_samples = [s for s in pos_samples + neg_samples if s[0] != "" and s[1] != ""]
    random.shuffle(all_samples)
    return all_samples

subsumptions_to_samples(subsumptions, sample_label)

Transform subsumptions into samples of strings.

Parameters:

Name Type Description Default
subsumptions List[List]

The given subsumptions.

required
sample_label Union[int, None]

1 (positive), 0 (negative), None (no label).

required

Returns:

Type Description
List[List]

A list of samples, each element is a triple in the form of (sub_class_string, super_class_string, label_index).

Source code in src/deeponto/complete/bertsubs/text_semantics.py
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
def subsumptions_to_samples(self, subsumptions: List[List], sample_label: Union[int, None]):
    r"""Transform subsumptions into samples of strings.

    Depending on `config.prompt.prompt_type`, each subsumption is rendered as plain label
    pairs (`isolated`), as labels followed by neighbouring subsumptions (`traversal`), or
    as labels embedded in a chain of class labels (`path`). An unknown prompt type prints
    a message and exits the process.

    Args:
        subsumptions (List[List]): The given subsumptions.
        sample_label (Union[int, None]): `1` (positive), `0` (negative), `None` (no label).

    Returns:
        (List[List]): A list of samples, each element is a triple
            in the form of `(sub_class_string, super_class_string, label_index)`;
            when `sample_label` is `None` the label is omitted and each element is a pair.

    """
    local_samples = list()
    for subs in subsumptions:
        subcls, supcls = subs[0], subs[1]
        # Labels of the sub-class; a single empty string stands in when it has no label.
        substrs = self.iri_label[subcls] if subcls in self.iri_label and len(self.iri_label[subcls]) > 0 else [""]

        # Labels of the super-class: class labels for named classes, otherwise the stored
        # restriction label or, failing that, a verbalisation produced on the fly.
        if self.config.subsumption_type == "named_class":
            supstrs = self.iri_label[supcls] if supcls in self.iri_label and len(self.iri_label[supcls]) else [""]
        else:
            if supcls in self.restriction_label and len(self.restriction_label[supcls]) > 0:
                supstrs = self.restriction_label[supcls]
            else:
                supstrs = [self.verb.verbalise_class_expression(supcls).verbal]

        # Optionally keep only the first label on each side.
        if self.config.use_one_label:
            substrs, supstrs = substrs[0:1], supstrs[0:1]

        # "isolated": every (sub label, super label) combination, without any context.
        if self.config.prompt.prompt_type == "isolated":
            for substr in substrs:
                for supstr in supstrs:
                    local_samples.append([substr, supstr])

        # "traversal": each label is followed by " <SEP> "-joined context subsumptions.
        elif self.config.prompt.prompt_type == "traversal":
            subs_list_strs = set()
            # Draw up to context_dup contexts; the loop stops early when the helper's second
            # return value is true (presumably further draws would repeat -- TODO confirm).
            for _ in range(self.config.prompt.context_dup):
                context_sub, no_duplicate = self.traversal_subsumptions(
                    cls=subcls,
                    hop=self.config.prompt.prompt_hop,
                    direction="subclass",
                    max_subsumptions=self.config.prompt.prompt_max_subsumptions,
                )
                subs_list = [self.named_subsumption_to_str(subsum) for subsum in context_sub]
                subs_list_str = " <SEP> ".join(subs_list)
                subs_list_strs.add(subs_list_str)
                if no_duplicate:
                    break

            if self.config.subsumption_type == "named_class":
                sups_list_strs = set()
                for _ in range(self.config.prompt.context_dup):
                    context_sup, no_duplicate = self.traversal_subsumptions(
                        cls=supcls,
                        hop=self.config.prompt.prompt_hop,
                        direction="supclass",
                        max_subsumptions=self.config.prompt.prompt_max_subsumptions,
                    )
                    sups_list = [self.named_subsumption_to_str(subsum) for subsum in context_sup]
                    sups_list_str = " <SEP> ".join(sups_list)
                    sups_list_strs.add(sups_list_str)
                    if no_duplicate:
                        break
            else:
                # NOTE(review): for restrictions the "context" set is the label set itself, so
                # the loop below appends every super label to every super label -- confirm intended.
                sups_list_strs = set(supstrs)

            # One sample per (sub context, sub label, super context, super label) combination.
            for subs_list_str in subs_list_strs:
                for substr in substrs:
                    s1 = substr + " <SEP> " + subs_list_str
                    for sups_list_str in sups_list_strs:
                        for supstr in supstrs:
                            s2 = supstr + " <SEP> " + sups_list_str
                            local_samples.append([s1, s2])

        # "path": labels are embedded in a chain of class labels joined by the separator token.
        elif self.config.prompt.prompt_type == "path":
            sep_token = "<SUB>" if self.config.prompt.use_sub_special_token else "<SEP>"

            s1_set = set()
            for _ in range(self.config.prompt.context_dup):
                context_sub, no_duplicate = self.path_subsumptions(
                    cls=subcls, hop=self.config.prompt.prompt_hop, direction="subclass"
                )
                if len(context_sub) > 0:
                    s1 = ""
                    # Walk the context in reverse order, emitting the first label of each
                    # subsumption's sub-class (or an empty string) followed by the separator.
                    for i in range(len(context_sub)):
                        subsum = context_sub[len(context_sub) - i - 1]
                        subc = subsum[0]
                        s1 += "%s %s " % (
                            self.iri_label[subc][0]
                            if subc in self.iri_label and len(self.iri_label[subc]) > 0
                            else "",
                            sep_token,
                        )
                    for substr in substrs:
                        s1_set.add(s1 + substr)
                else:
                    # No context available: the label is still prefixed with the separator.
                    for substr in substrs:
                        s1_set.add("%s %s" % (sep_token, substr))

                if no_duplicate:
                    break

            if self.config.subsumption_type == "named_class":
                s2_set = set()
                for _ in range(self.config.prompt.context_dup):
                    context_sup, no_duplicate = self.path_subsumptions(
                        cls=supcls, hop=self.config.prompt.prompt_hop, direction="supclass"
                    )
                    if len(context_sup) > 0:
                        s2 = ""
                        # Walk the context in order, emitting the separator followed by the first
                        # label of each subsumption's super-class (or an empty string).
                        for subsum in context_sup:
                            supc = subsum[1]
                            s2 += " %s %s" % (
                                sep_token,
                                self.iri_label[supc][0]
                                if supc in self.iri_label and len(self.iri_label[supc]) > 0
                                else "",
                            )
                        for supstr in supstrs:
                            s2_set.add(supstr + s2)
                    else:
                        # No context available: the label is still suffixed with the separator.
                        for supstr in supstrs:
                            s2_set.add("%s %s" % (supstr, sep_token))

                    if no_duplicate:
                        break
            else:
                # Restrictions get no path context; their labels are used as they are.
                s2_set = set(supstrs)

            for s1 in s1_set:
                for s2 in s2_set:
                    local_samples.append([s1, s2])

        else:
            # Unknown prompt type: report it and terminate the process (exit status 0).
            print(f"unknown context type {self.config.prompt.prompt_type}")
            sys.exit(0)

    # Attach the label (if any) as the third component of every sample.
    if sample_label is not None:
        for i in range(len(local_samples)):
            local_samples[i].append(sample_label)

    return local_samples

get_negative_sample(subclass_iri, subsumption_type='named_class')

Given a named subclass, get a negative class for a negative subsumption.

Parameters:

Name Type Description Default
subclass_iri str

IRI of a given sub-class.

required
subsumption_type str

named_class or restriction.

'named_class'
Source code in src/deeponto/complete/bertsubs/text_semantics.py
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
def get_negative_sample(self, subclass_iri: str, subsumption_type: str = "named_class"):
    r"""Given a named subclass, get a negative class for a negative subsumption.

    For `named_class`, one entry is drawn at random from `self.named_classes` after removing
    the super-classes of the given class (asserted parents when `config.no_reasoning` is set,
    otherwise all inferred super-classes). For any other type, up to five randomly drawn
    entries of `self.restrictionObjects` are tried in turn.

    Args:
        subclass_iri (str): IRI of a given sub-class.
        subsumption_type (str): `named_class` or `restriction`.

    Returns:
        For `named_class`, an element of `self.named_classes`. Otherwise the string form of
        the first acceptable restriction, or `None` when none of the tried restrictions
        qualifies (or there are no restrictions at all).

    Raises:
        ValueError: If `subsumption_type` is `named_class` and no candidate class is left.
    """
    subclass = self.onto.get_owl_object(iri=subclass_iri)
    if subsumption_type == "named_class":
        if self.config.no_reasoning:
            parents = self.onto.get_asserted_parents(owl_object=subclass, named_only=True)
            ancestors = {str(item.getIRI()) for item in parents}
        else:
            ancestors = set(self.onto.reasoner.get_inferred_super_entities(subclass, direct=False))
        # `random.sample` rejects sets since Python 3.11, so materialise the candidates first.
        # NOTE(review): the class itself is not excluded from the candidates - confirm intended.
        candidates = list(self.named_classes - ancestors)
        return random.sample(candidates, 1)[0]

    # Copy into a list so both set and sequence containers work, and cap the sample size
    # so that fewer than five restrictions no longer raises `ValueError`.
    pool = list(self.restrictionObjects)
    for neg_c in random.sample(pool, min(5, len(pool))):
        if self.config.no_reasoning:
            # Without reasoning there is nothing to verify; accept the first draw.
            return str(neg_c)
        if not self.onto.reasoner.check_subsumption(sub_entity=subclass, super_entity=neg_c):
            return str(neg_c)
    return None

named_subsumption_to_str(subsum)

Transform a named subsumption into string with <SUB> and classes' labels.

Parameters:

Name Type Description Default
subsum List[Tuple]

A single subsumption pair in the form of (sub_class_iri, super_class_iri).

required
Source code in src/deeponto/complete/bertsubs/text_semantics.py
298
299
300
301
302
303
304
305
306
307
def named_subsumption_to_str(self, subsum: List):
    r"""Render one named subsumption as `"<sub label> <SUB> <super label>"`.

    Each class is represented by its first label in `self.iri_label`; a class that is
    unknown or has no label is rendered as the empty string.

    Args:
        subsum (List): A subsumption pair whose first two items are
            `(sub_class_iri, super_class_iri)`.
    """

    def first_label(iri):
        # Fall back to "" for IRIs that are missing or carry no labels.
        if iri not in self.iri_label:
            return ""
        labels = self.iri_label[iri]
        return labels[0] if len(labels) > 0 else ""

    child_iri, parent_iri = subsum[0], subsum[1]
    child_text = first_label(child_iri)
    parent_text = first_label(parent_iri)
    return "%s <SUB> %s" % (child_text, parent_text)

subclass_to_strings(subcls)

Transform a sub-class into strings (with the path or traversal context template).

Parameters:

Name Type Description Default
subcls str

IRI of the sub-class.

required
Source code in src/deeponto/complete/bertsubs/text_semantics.py
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
def subclass_to_strings(self, subcls):
    r"""Transform a sub-class into strings (with the path or traversal context template).

    The behaviour is selected by `config.prompt.prompt_type`:

    - `isolated`: the labels of the class on their own.
    - `traversal`: each label followed by `" <SEP> "` and a `" <SEP> "`-joined list of
      subsumptions sampled below the class with `traversal_subsumptions`.
    - `path`: each label preceded by the labels of a sampled downward path, separated by
      `<SUB>` or `<SEP>` depending on `config.prompt.use_sub_special_token`.

    Context sampling is repeated up to `config.prompt.context_dup` times and stops early
    when the sampler reports that no alternative context exists.

    Args:
        subcls (str): IRI of the sub-class.

    Returns:
        A list of strings. For `traversal` and `path` the contexts are de-duplicated
        through a set, so the order of the result is not fixed.

    Raises:
        ValueError: If `config.prompt.prompt_type` is not one of the types above
            (previously this case silently returned `None`).
    """

    def first_label(iri):
        # First label of a class, or "" when the class is unknown or unlabelled.
        if iri in self.iri_label and len(self.iri_label[iri]) > 0:
            return self.iri_label[iri][0]
        return ""

    substrs = self.iri_label[subcls] if subcls in self.iri_label and len(self.iri_label[subcls]) > 0 else [""]

    if self.config.use_one_label:
        substrs = substrs[0:1]

    prompt = self.config.prompt

    if prompt.prompt_type == "isolated":
        return substrs

    if prompt.prompt_type == "traversal":
        subs_list_strs = set()
        for _ in range(prompt.context_dup):
            context_sub, no_duplicate = self.traversal_subsumptions(
                cls=subcls,
                hop=prompt.prompt_hop,
                direction="subclass",
                max_subsumptions=prompt.prompt_max_subsumptions,
            )
            subs_list_strs.add(" <SEP> ".join(self.named_subsumption_to_str(subsum) for subsum in context_sub))
            if no_duplicate:
                # Re-sampling cannot yield a different context.
                break

        return [substr + " <SEP> " + subs_list_str for subs_list_str in subs_list_strs for substr in substrs]

    if prompt.prompt_type == "path":
        sep_token = "<SUB>" if prompt.use_sub_special_token else "<SEP>"

        s1_set = set()
        for _ in range(prompt.context_dup):
            context_sub, no_duplicate = self.path_subsumptions(
                cls=subcls, hop=prompt.prompt_hop, direction="subclass"
            )
            if len(context_sub) > 0:
                # Walk the path from its far end back towards the class so that the
                # class itself comes last in the rendered string.
                prefix = "".join(
                    "%s %s " % (first_label(subsum[0]), sep_token) for subsum in reversed(context_sub)
                )
                for substr in substrs:
                    s1_set.add(prefix + substr)
            else:
                for substr in substrs:
                    s1_set.add("%s %s" % (sep_token, substr))
            if no_duplicate:
                break

        return list(s1_set)

    raise ValueError("unknown context type %s" % prompt.prompt_type)

supclass_to_strings(supcls, subsumption_type='named_class')

Transform a super-class into strings (with the path or traversal context template if the subsumption type is "named_class").

Parameters:

Name Type Description Default
supcls str

IRI of the super-class.

required
subsumption_type str

The type of the subsumption.

'named_class'
Source code in src/deeponto/complete/bertsubs/text_semantics.py
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
def supclass_to_strings(self, supcls: str, subsumption_type: str = "named_class"):
    r"""Transform a super-class into strings (with the path or traversal context template if the subsumption type is `"named_class"`).

    Labels come from `self.iri_label` for named classes and from `self.restriction_label`
    otherwise; a warning is issued when a restriction has no description. The behaviour
    is then selected by `config.prompt.prompt_type` (`isolated`, `traversal` or `path`).
    Context sampling only happens for named classes and is repeated up to
    `config.prompt.context_dup` times.

    Args:
        supcls (str): IRI of the super-class.
        subsumption_type (str): The type of the subsumption.

    Returns:
        A list of strings. For `traversal` and `path` the contexts are de-duplicated
        through a set, so the order of the result is not fixed.

    Raises:
        ValueError: If `config.prompt.prompt_type` is unknown. Previously this printed a
            message and called `sys.exit(0)`, terminating the process with a success status.
    """

    def first_label(iri):
        # First label of a class, or "" when the class is unknown or unlabelled.
        if iri in self.iri_label and len(self.iri_label[iri]) > 0:
            return self.iri_label[iri][0]
        return ""

    is_named = subsumption_type == "named_class"

    if is_named:
        supstrs = self.iri_label[supcls] if supcls in self.iri_label and len(self.iri_label[supcls]) else [""]
    elif supcls in self.restriction_label and len(self.restriction_label[supcls]) > 0:
        supstrs = self.restriction_label[supcls]
    else:
        warnings.warn("Warning: %s has no descriptions" % supcls)
        supstrs = [""]

    # Only named classes are cut down to a single label; restrictions keep all of theirs.
    if self.config.use_one_label and is_named:
        supstrs = supstrs[0:1]

    prompt = self.config.prompt

    if prompt.prompt_type == "isolated":
        return supstrs

    if prompt.prompt_type == "traversal":
        if is_named:
            sups_list_strs = set()
            for _ in range(prompt.context_dup):
                context_sup, no_duplicate = self.traversal_subsumptions(
                    cls=supcls,
                    hop=prompt.prompt_hop,
                    direction="supclass",
                    max_subsumptions=prompt.prompt_max_subsumptions,
                )
                sups_list_strs.add(" <SEP> ".join(self.named_subsumption_to_str(subsum) for subsum in context_sup))
                if no_duplicate:
                    # Re-sampling cannot yield a different context.
                    break
        else:
            # NOTE(review): for restrictions the "context" is the label set itself, so each
            # label gets paired with every label (itself included) - confirm this is intended.
            sups_list_strs = set(supstrs)

        return [supstr + " <SEP> " + sups_list_str for sups_list_str in sups_list_strs for supstr in supstrs]

    if prompt.prompt_type == "path":
        sep_token = "<SUB>" if prompt.use_sub_special_token else "<SEP>"

        if is_named:
            s2_set = set()
            for _ in range(prompt.context_dup):
                context_sup, no_duplicate = self.path_subsumptions(
                    cls=supcls, hop=prompt.prompt_hop, direction="supclass"
                )
                if len(context_sup) > 0:
                    suffix = "".join(" %s %s" % (sep_token, first_label(subsum[1])) for subsum in context_sup)
                    for supstr in supstrs:
                        s2_set.add(supstr + suffix)
                else:
                    for supstr in supstrs:
                        s2_set.add("%s %s" % (supstr, sep_token))

                if no_duplicate:
                    break
        else:
            s2_set = set(supstrs)

        return list(s2_set)

    raise ValueError("unknown context type %s" % prompt.prompt_type)

traversal_subsumptions(cls, hop=1, direction='subclass', max_subsumptions=5)

Given a class, get its subsumptions by traversing the class hierarchy.

If the class is the sub-class in the subsumption axiom, collect subsumptions downwards (towards its sub-classes).
If the class is the super-class in the subsumption axiom, collect subsumptions upwards (towards its super-classes).

Parameters:

Name Type Description Default
cls str

IRI of a named class.

required
hop int

The depth of the path.

1
direction str

subclass (downside path) or supclass (upside path).

'subclass'
max_subsumptions int

The maximum number of subsumptions to consider.

5
Source code in src/deeponto/complete/bertsubs/text_semantics.py
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
def traversal_subsumptions(self, cls: str, hop: int = 1, direction: str = "subclass", max_subsumptions: int = 5):
    r"""Collect subsumptions around a class by expanding the hierarchy level by level.

    With `direction="subclass"` the direct inferred sub-classes are followed (downwards);
    with `direction="supclass"` the direct inferred super-classes are followed (upwards).
    Deprecated neighbours are skipped. Expansion stops after `hop` levels, or as soon as
    at least `max_subsumptions` pairs were gathered, in which case a random subset of
    exactly that size is kept.

    Args:
        cls (str): IRI of a named class.
        hop (int): The depth of the traversal.
        direction (str): `subclass` (downside path) or `supclass` (upside path).
        max_subsumptions (int): The maximum number of subsumptions to consider.

    Returns:
        A tuple `(subsumptions, no_duplicate)`: a list of `[sub_iri, super_iri]` pairs and
        a flag that is `False` once any expanded class had more than one direct neighbour.
    """
    collected = []
    frontier = [cls]
    single_choice_only = True
    depth = 1
    while depth <= hop:
        next_frontier = []
        for node in frontier:
            if direction == "subclass":
                neighbours = self.onto.reasoner.get_inferred_sub_entities(
                    self.onto.get_owl_object(iri=node), direct=True
                )
            elif direction == "supclass":
                neighbours = self.onto.reasoner.get_inferred_super_entities(
                    self.onto.get_owl_object(iri=node), direct=True
                )
            else:
                warnings.warn("Unknown direction: %s" % direction)
                continue

            if len(neighbours) > 1:
                single_choice_only = False
            random.shuffle(neighbours)
            for other in neighbours:
                if self.onto.check_deprecated(owl_object=self.onto.get_owl_object(iri=other)):
                    continue
                # Pairs are always stored as [sub-class, super-class].
                collected.append([other, node] if direction == "subclass" else [node, other])
                if other not in next_frontier:
                    next_frontier.append(other)

        if len(collected) >= max_subsumptions:
            collected = random.sample(collected, max_subsumptions)
            break

        frontier = next_frontier
        random.shuffle(frontier)
        depth += 1
    return collected, single_choice_only

path_subsumptions(cls, hop=1, direction='subclass')

Given a class, get its path subsumptions.

If the class is the sub-class in the subsumption axiom, follow a path downwards (towards its sub-classes).
If the class is the super-class in the subsumption axiom, follow a path upwards (towards its super-classes).

Parameters:

Name Type Description Default
cls str

IRI of a named class.

required
hop int

The depth of the path.

1
direction str

subclass (downside path) or supclass (upside path).

'subclass'
Source code in src/deeponto/complete/bertsubs/text_semantics.py
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
def path_subsumptions(self, cls: str, hop: int = 1, direction: str = "subclass"):
    r"""Sample a single chain of subsumptions starting at a class.

    At every step one non-deprecated direct neighbour of the current class is picked at
    random: a direct inferred sub-class for `direction="subclass"` (downwards) or a direct
    inferred super-class for `direction="supclass"` (upwards). The chain ends after `hop`
    steps or when no usable neighbour exists.

    Args:
        cls (str): IRI of a named class.
        hop (int): The depth of the path.
        direction (str): `subclass` (downside path) or `supclass` (upside path).

    Returns:
        A tuple `(subsumptions, no_duplicate)`: the chain as a list of `[sub_iri, super_iri]`
        pairs and a flag that is `False` once any visited class had more than one direct
        neighbour.
    """
    chain = []
    current = cls
    single_choice_only = True
    step = 1
    while step <= hop:
        if direction == "subclass":
            neighbours = self.onto.reasoner.get_inferred_sub_entities(
                self.onto.get_owl_object(iri=current), direct=True
            )
        elif direction == "supclass":
            neighbours = self.onto.reasoner.get_inferred_super_entities(
                self.onto.get_owl_object(iri=current), direct=True
            )
        else:
            warnings.warn("Unknown direction: %s" % direction)
            step += 1
            continue

        if len(neighbours) > 1:
            single_choice_only = False
        if len(neighbours) > 0:
            random.shuffle(neighbours)

        for candidate in neighbours:
            if self.onto.check_deprecated(owl_object=self.onto.get_owl_object(iri=candidate)):
                continue
            # Pairs are always stored as [sub-class, super-class].
            chain.append([candidate, current] if direction == "subclass" else [current, candidate])
            current = candidate
            break
        else:
            # No usable neighbour: the path cannot be extended any further.
            break

        step += 1
    return chain, single_choice_only

BERTSubsumptionClassifierTrainer(bert_checkpoint, train_data, val_data, max_length=128, early_stop=False, early_stop_patience=10)

Source code in src/deeponto/complete/bertsubs/bert_classifier.py
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
def __init__(
    self,
    bert_checkpoint: str,
    train_data: List,
    val_data: List,
    max_length: int = 128,
    early_stop: bool = False,
    early_stop_patience: int = 10,
):
    r"""Load the pretrained model and tokenizer and tokenize the train/validation samples.

    Args:
        bert_checkpoint (str): Checkpoint passed to `from_pretrained` for model and tokenizer.
        train_data (List): Training samples, forwarded to `load_dataset`.
        val_data (List): Validation samples, forwarded to `load_dataset`.
        max_length (int): Maximum length of the input sequence used for truncation.
        early_stop (bool): Whether `train` registers an early-stopping callback.
        early_stop_patience (int): Patience handed to the early-stopping callback.
    """
    # Plain settings first; none of them depends on the model being loaded.
    self.max_length = max_length
    self.early_stop = early_stop
    self.early_stop_patience = early_stop_patience
    self.trainer = None  # created lazily by `train`

    print(f"initialize BERT for Binary Classification from the Pretrained BERT model at: {bert_checkpoint} ...")
    self.model = AutoModelForSequenceClassification.from_pretrained(bert_checkpoint)
    self.tokenizer = AutoTokenizer.from_pretrained(bert_checkpoint)

    # Tokenize both splits (training first) and report token-size statistics for each.
    self.tra, self.val = [
        self.load_dataset(samples, max_length=self.max_length, count_token_size=True)
        for samples in (train_data, val_data)
    ]
    n_train, n_val = len(self.tra), len(self.val)
    print(f"text max length: {self.max_length}")
    print(f"data files loaded with sizes:")
    print(f"\t[# Train]: {n_train}, [# Val]: {n_val}")

add_special_tokens(tokens)

Add additional special tokens into the tokenizer's vocab.

Parameters:

Name Type Description Default
tokens List[str]

additional tokens to add, e.g., ["<SUB>","<EOA>","<EOC>"]

required
Source code in src/deeponto/complete/bertsubs/bert_classifier.py
60
61
62
63
64
65
66
67
def add_special_tokens(self, tokens: List):
    r"""Register extra special tokens with the tokenizer and resize the model embeddings.

    Args:
        tokens (List[str]): additional tokens to add, e.g., `["<SUB>","<EOA>","<EOC>"]`
    """
    self.tokenizer.add_special_tokens({"additional_special_tokens": tokens})
    # The embedding matrix has to follow the (possibly grown) tokenizer size.
    vocab_size = len(self.tokenizer)
    self.model.resize_token_embeddings(vocab_size)

train(train_args, do_fine_tune=True)

Initiate the Huggingface trainer with input arguments and start training.

Parameters:

Name Type Description Default
train_args TrainingArguments

Arguments for training.

required
do_fine_tune bool

False means loading the checkpoint without training. Defaults to True.

True
Source code in src/deeponto/complete/bertsubs/bert_classifier.py
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
def train(self, train_args: TrainingArguments, do_fine_tune: bool = True):
    r"""Build the Huggingface trainer from the given arguments and optionally run training.

    The trainer is stored in `self.trainer`. An early-stopping callback is attached when
    `self.early_stop` is set.

    Args:
        train_args (TrainingArguments): Arguments for training.
        do_fine_tune (bool): `False` means loading the checkpoint without training. Defaults to `True`.
    """
    trainer_options = dict(
        model=self.model,
        args=train_args,
        train_dataset=self.tra,
        eval_dataset=self.val,
        compute_metrics=self.compute_metrics,
        tokenizer=self.tokenizer,
    )
    self.trainer = Trainer(**trainer_options)

    if self.early_stop:
        stopper = EarlyStoppingCallback(early_stopping_patience=self.early_stop_patience)
        self.trainer.add_callback(stopper)

    if not do_fine_tune:
        # Only the trainer is needed, e.g. for evaluating an existing checkpoint.
        return
    self.trainer.train()

compute_metrics(pred) staticmethod

Auxiliary function that adds the accuracy metric to the evaluation.

Source code in src/deeponto/complete/bertsubs/bert_classifier.py
88
89
90
91
92
93
94
95
@staticmethod
def compute_metrics(pred):
    """Auxiliary function that adds the accuracy metric to the evaluation.

    The predicted class is the arg-max over the last axis of `pred.predictions`; it is
    compared with `pred.label_ids` and the score is returned under the key `"accuracy"`.
    """
    gold, guessed = pred.label_ids, pred.predictions.argmax(-1)
    return {"accuracy": accuracy_score(gold, guessed)}

load_dataset(data, max_length=512, count_token_size=False)

Load a Huggingface dataset from a list of samples.

Parameters:

Name Type Description Default
data List[Tuple]

Data samples in a list.

required
max_length int

Maximum length of the input sequence.

512
count_token_size bool

Whether or not to count the token sizes of the data. Defaults to False.

False
Source code in src/deeponto/complete/bertsubs/bert_classifier.py
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
def load_dataset(self, data: List, max_length: int = 512, count_token_size: bool = False) -> Dataset:
    r"""Build a tokenized Huggingface dataset from a list of samples.

    Each sample is read by position: item 0 becomes `sent1`, item 1 becomes `sent2` and
    item 2 becomes `labels`. The sentence pairs are then tokenized with truncation to
    `max_length`.

    Args:
        data (List[Tuple]): Data samples in a list.
        max_length (int): Maximum length of the input sequence.
        count_token_size (bool): Whether or not to count the token sizes of the data. Defaults to `False`.

    Returns:
        The dataset with the tokenizer outputs added by `Dataset.map`.
    """

    def sample_records():
        for row in data:
            yield {"sent1": row[0], "sent2": row[1], "labels": row[2]}

    def tokenize_batch(examples):
        return self.tokenizer(examples["sent1"], examples["sent2"], max_length=max_length, truncation=True)

    dataset = Dataset.from_generator(sample_records)

    if count_token_size:
        # Untruncated tokenization, used only to report how long the inputs really are.
        encoded = self.tokenizer(dataset["sent1"], dataset["sent2"])
        sizes = [len(ids) for ids in encoded["input_ids"]]
        total = len(encoded["input_ids"])
        print("average token size: %.2f" % (sum(sizes) / total))
        print("ratio of token size <= 128: %.3f" % (sum(1 for size in sizes if size <= 128) / total))
        print("ratio of token size <= 256: %.3f" % (sum(1 for size in sizes if size <= 256) / total))
        print("ratio of token size <= 512: %.3f" % (sum(1 for size in sizes if size <= 512) / total))
        print("max token size: %d" % max(sizes, default=0))

    return dataset.map(tokenize_batch, batched=True, num_proc=1)

Last update: February 2, 2024
Created: January 13, 2023
GitHub: @Lawhy   Personal Page: yuanhe.wiki