| | import onnxscript |
| | import onnx_ir as ir |
| | import onnx_ir.passes.common |
| | import numpy as np |
| | import onnxslim |
| |
|
| |
|
class ReplaceDftWithMatMulRule(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace ``Squeeze(Slice(DFT(Reshape(x), dft_length)))`` with one ``MatMul``.

    The matched subgraph slices index 0 of the DFT output's last axis and
    squeezes that axis away, i.e. it keeps only the real component. The
    replacement multiplies by a precomputed cosine matrix of shape
    ``(dft_length, dft_length // 2 + 1)``.

    NOTE(review): using ``dft_length // 2 + 1`` output bins assumes the matched
    DFT is one-sided over a real-valued signal; the pattern does not verify
    the ``onesided`` attribute -- confirm against the target model.
    """

    def pattern(self, op, x, dft_length):
        """Describe the subgraph to match; the DFT output is bound as ``dft_output``."""
        x = op.Reshape(x, _allow_other_inputs=True)
        dft = op.DFT(x, dft_length, _outputs=["dft_output"])
        # starts=[0], ends=[1], axes=[-1]: keep only index 0 of the last axis.
        real_part = op.Slice(dft, [0], [1], [-1])
        return op.Squeeze(real_part, [-1])

    def rewrite(self, op, x: ir.Value, dft_length: ir.Value, dft_output: ir.Value):
        """Build the ``MatMul`` replacement.

        Args:
            op: Builder used to emit replacement nodes and initializers.
            x: Value bound to the pattern variable ``x``.
                NOTE(review): presumably the tensor feeding ``Reshape``, so the
                Reshape is dropped from the rewritten graph -- confirm.
            dft_length: Value holding the transform length; must be constant.
            dft_output: Output of the matched DFT node (bound by the pattern,
                not needed to build the replacement).

        Returns:
            The ``MatMul`` output value, or ``None`` to decline the rewrite when
            ``dft_length`` is not a known positive constant.
        """
        length_tensor = ir.convenience.get_const_tensor(dft_length)
        if length_tensor is None:
            # The matrix size cannot be derived from a dynamic length.
            return None

        dft_size = int(length_tensor.numpy().item())
        if dft_size <= 0:
            return None

        num_freqs = dft_size // 2 + 1

        # Integer index grids: n runs over samples (rows), k over bins (columns).
        n = np.arange(dft_size, dtype=np.int64)[:, np.newaxis]
        k = np.arange(num_freqs, dtype=np.int64)[np.newaxis, :]

        # cos is periodic in k*n with period dft_size, so reduce the product
        # exactly in integer arithmetic and evaluate the phase in float64.
        # Casting to float32 only at the end avoids the rounding error of
        # forming large phases directly in float32.
        phase = (2.0 * np.pi / dft_size) * ((n * k) % dft_size)
        cos_table = np.cos(phase).astype(np.float32)

        dft_matrix = op.initializer(
            ir.tensor(cos_table), name=f"{x.name}_dft_matrix"
        )
        return op.MatMul(x, dft_matrix)
| |
|
| |
|
class ReplaceSplit(onnxscript.rewriter.RewriteRuleClassBase):
    """Replace a two-output ``Split`` with ``Gather(x, [0])`` and a constant.

    NOTE(review): presumably ``x`` is a shape tensor whose first entry is the
    batch size and whose second entry is always 144000 -- confirm against the
    model this script targets.
    """

    def pattern(self, op, x):
        """Match any ``Split`` of ``x`` that produces exactly two outputs."""
        split_outputs = op.Split(
            x,
            _allow_other_inputs=True,
            _outputs=["split_out_1", "split_out_2"],
        )
        return split_outputs

    def rewrite(self, op, x: ir.Value, **kwargs):
        """Return ``(Gather(x, [0]), constant [144000])`` in place of the Split outputs.

        ``kwargs`` receives the named pattern outputs, which are not needed here.
        """
        index_zero = ir.tensor(np.array([0], dtype=np.int64))
        zero = op.initializer(index_zero, "zero")
        first_entry = op.Gather(x, zero)

        # Stored as int32, exactly as in the original rule.
        fixed_length = ir.tensor(np.array([144000], dtype=np.int32))
        second_entry = op.initializer(fixed_length, "sample_size")

        return first_entry, second_entry
| |
|
| |
|
class RemoveCast(onnxscript.rewriter.RewriteRuleClassBase):
    """Swap every matched ``Cast`` node for an ``Identity`` on the same input."""

    def pattern(self, op, x):
        """Match a ``Cast`` applied to ``x``."""
        cast_result = op.Cast(x)
        return cast_result

    def rewrite(self, op, x: ir.Value, **kwargs):
        """Forward ``x`` unchanged through ``Identity``; extra bindings are ignored."""
        passthrough = op.Identity(x)
        return passthrough
| |
|
| |
|
class RemoveReversedSequenceFork(onnxscript.rewriter.RewriteRuleClassBase):
    """Rewrite a two-branch Transpose/ReverseSequence fork without ReverseSequence.

    Matched:   ``Transpose(Add(Mul(Concat(branch(x), branch(y)), scale), bias))``
    where ``branch(v) = Unsqueeze(ReverseSequence(Transpose(v)))``.

    Emitted:   each input is reversed with a negative-step ``Slice`` on its
    last axis, unsqueezed at the end, concatenated on axis 3, scaled, biased
    and transposed with ``perm=[0, 3, 2, 1]``.
    """

    def pattern(self, op, x, y, scale, bias):
        """Describe the forked subgraph to match."""
        x_transposed = op.Transpose(x)
        y_transposed = op.Transpose(y)
        x_reversed = op.ReverseSequence(x_transposed, _allow_other_inputs=True)
        y_reversed = op.ReverseSequence(y_transposed, _allow_other_inputs=True)
        x_expanded = op.Unsqueeze(x_reversed, _allow_other_inputs=True)
        y_expanded = op.Unsqueeze(y_reversed, _allow_other_inputs=True)
        joined = op.Concat(x_expanded, y_expanded)
        scaled = op.Mul(joined, scale)
        shifted = op.Add(scaled, bias)
        return op.Transpose(shifted)

    def rewrite(self, op, x, y, scale, bias, **kwargs):
        """Emit the ReverseSequence-free replacement subgraph."""
        minus_one = np.array([-1], dtype=np.int64)
        lowest_int64 = np.array([np.iinfo(np.int64).min], dtype=np.int64)
        neg_one = op.initializer(ir.tensor(minus_one), "neg_one")
        int_64_min = op.initializer(ir.tensor(lowest_int64), "int_64_min")

        # Slice(data, starts=-1, ends=INT64_MIN, axes=-1, steps=-1) walks the
        # last axis backwards from its final element, i.e. reverses it.
        x_flipped = op.Slice(x, neg_one, int_64_min, neg_one, neg_one)
        y_flipped = op.Slice(y, neg_one, int_64_min, neg_one, neg_one)

        # Append a trailing axis to each branch so they can be joined on axis 3.
        x_expanded = op.Unsqueeze(x_flipped, neg_one)
        y_expanded = op.Unsqueeze(y_flipped, neg_one)
        joined = op.Concat(x_expanded, y_expanded, axis=3)

        scaled = op.Mul(joined, scale)
        shifted = op.Add(scaled, bias)
        return op.Transpose(shifted, perm=[0, 3, 2, 1])
| |
|
| |
|
# Load the model that the rewrite rules in this file are written against.
model = ir.load("model.onnx")

# Pin the first graph input/output to a symbolic batch dimension with fixed
# trailing sizes.
graph_input = model.graph.inputs[0]
graph_output = model.graph.outputs[0]
graph_input.shape = ir.Shape(["batch", 144000])
graph_output.shape = ir.Shape(["batch", 6522])

# First rewrite pass: DFT -> MatMul, Split -> Gather + constant, Cast -> Identity.
first_pass_rules = [
    ReplaceDftWithMatMulRule().rule(),
    ReplaceSplit().rule(),
    RemoveCast().rule(),
]
onnxscript.rewriter.rewrite(model, first_pass_rules)

# Widen every int32 initializer to int64. The candidates are collected up
# front because the initializer mapping is mutated inside the loop.
int32_initializers = [
    value
    for value in model.graph.initializers.values()
    if value.dtype == ir.DataType.INT32
]
for old_value in int32_initializers:
    widened = old_value.const_value.numpy().astype(np.int64)
    replacement = ir.val(old_value.name, const_value=ir.tensor(widened))
    # Swap the registered initializer, then repoint every consumer at it.
    model.graph.initializers.pop(old_value.name)
    model.graph.initializers.add(replacement)
    old_value.replace_all_uses_with(replacement)

# Optimize with 1 GiB size limits on both inputs and outputs.
onnxscript.optimizer.optimize(
    model,
    input_size_limit=1024**3,
    output_size_limit=1024**3,
)
| |
|
| |
|
| | |
def remove_slice_reshape(model: ir.Model):
    """Give four named Reshape nodes constant target shapes, in place.

    For each node a new INT64 shape initializer is registered on the graph and
    wired in as input 1. The two ``Reshape_1`` nodes additionally have input 0
    replaced by the output of ``model/MEL_SPEC1/Mul``.

    All node names are hard-coded, so this only applies to the specific model
    this script targets.

    Args:
        model: Model whose graph is modified in place.
    """
    graph = model.graph
    mul_output = graph.node("model/MEL_SPEC1/Mul").outputs[0]

    # (node name, initializer name, target shape, also rewire data input?)
    plan = (
        ("model/MEL_SPEC1/stft/frame/Reshape_1", "first_shape", [-1, 72000, 2], True),
        ("model/MEL_SPEC2/stft/frame/Reshape_1", "second_shape", [-1, 18000, 8], True),
        ("model/MEL_SPEC1/stft/frame/Reshape_4", "third_shape", [-1, 511, 2048], False),
        ("model/MEL_SPEC2/stft/frame/Reshape_4", "fourth_shape", [-1, 511, 1024], False),
    )

    # Phase 1: resolve every node and register its shape initializer.
    staged = []
    for node_name, shape_name, dims, rewire_data in plan:
        reshape_node = graph.node(node_name)
        shape_value = ir.val(
            shape_name, const_value=ir.tensor(dims, dtype=ir.DataType.INT64)
        )
        graph.initializers.add(shape_value)
        staged.append((reshape_node, shape_value, rewire_data))

    # Phase 2: rewire inputs only after all lookups have succeeded.
    for reshape_node, shape_value, rewire_data in staged:
        if rewire_data:
            reshape_node.replace_input_with(0, mul_output)
        reshape_node.replace_input_with(1, shape_value)
| |
|
| |
|
# Rewire the hard-coded Reshape nodes, then optimize again with the same
# 1 GiB size limits used for the first optimization pass.
remove_slice_reshape(model)
size_limit = 1024**3
onnxscript.optimizer.optimize(
    model, input_size_limit=size_limit, output_size_limit=size_limit
)

# onnxslim operates on protobuf models, so round-trip through the proto form.
print("Slimming model...")
slimmed_proto = onnxslim.slim(ir.to_proto(model))
model = ir.from_proto(slimmed_proto)

print("Removing reversed sequence fork...")
second_pass_rules = [RemoveReversedSequenceFork.rule()]
onnxscript.rewriter.rewrite(model, second_pass_rules)

# Slim once more after the second rewrite pass.
slimmed_proto = onnxslim.slim(ir.to_proto(model))
model = ir.from_proto(slimmed_proto)

# Clear metadata and doc strings, then set the final names and model header.
onnx_ir.passes.common.ClearMetadataAndDocStringPass()(model)
final_graph = model.graph
final_graph.inputs[0].name = "input"
final_graph.outputs[0].name = "output"
model.ir_version = 10
model.producer_name = "onnx-ir"
final_graph.name = "BirdNET-v2.4"

ir.save(model, "birdnet.onnx")
| |
|