Add aten.mm op and "test" it e2e.
Note that unlike aten.matmul, which has dynamic behavior depending on the argument ranks (can do matrix-matrix, matrix-vector, batch matmul, etc.), aten.mm is just a vanilla matrix multiply, which can be lowered precisely to tcf.matmul. The "test" is really just an example that I stared at while getting my feet wet with this. We probably want something that actually tests this as part of `ninja check-npcomp`.
parent ec1336a8a3
commit 1dfcfa9cd1
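For context on the rank point in the message, here is a small PyTorch-only sketch (illustrative, not part of the commit; shapes chosen to match the test below). torch.mm accepts only rank-2 tensors, while torch.matmul dispatches on the ranks of its arguments:

import torch

a = torch.rand(2, 3)
b = torch.rand(3, 4)

# torch.mm (aten::mm) is a plain matrix multiply: both operands must be rank-2.
print(torch.mm(a, b).shape)          # torch.Size([2, 4])

# torch.matmul (aten::matmul) dispatches on rank.
v = torch.rand(3)
print(torch.matmul(a, v).shape)      # torch.Size([2])        (matrix-vector)
batch = torch.rand(5, 2, 3)
print(torch.matmul(batch, b).shape)  # torch.Size([5, 2, 4])  (batch matmul)

# torch.mm(batch, b) would raise a RuntimeError, and that rank restriction is
# what makes a 1:1 lowering to tcf.matmul possible without rank dispatch.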
@@ -0,0 +1,33 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import sys
import numpy as np
import torch
import torch_mlir

import npcomp
from npcomp.compiler.pytorch.backend.refjit import *
from npcomp.compiler.utils import logging

logging.enable()

torch.manual_seed(0)
lhs = torch.rand(2, 3)
rhs = torch.rand(3, 4)

mb = torch_mlir.ModuleBuilder()
with mb.capture_function("mm", [lhs, rhs]) as f:
  result = torch.mm(lhs, rhs)
  f.returns([result])

backend = CompilerBackend()
jit_module = backend.load(backend.compile(mb.module))

jit_result = jit_module.mm(lhs.numpy(), rhs.numpy())

print(f"PyTorch Result = {result.numpy()}", file=sys.stderr)
print(f"JIT Result = {jit_result}", file=sys.stderr)

np.testing.assert_allclose(result.numpy(), jit_result)
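As the message notes, this script is a demo rather than something run by `ninja check-npcomp`. One possible lit-ified shape (a hypothetical sketch, assuming the lit config provides a %PYTHON substitution and FileCheck, as is common in this tree) would wrap the same script like so:

# RUN: %PYTHON %s | FileCheck %s
# ... same capture/compile code as above ...
# CHECK: SUCCESS
if np.allclose(result.numpy(), jit_result):
  print("SUCCESS")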
@@ -102,6 +102,7 @@ def generate_ops(g: "OpGenerator"):
  g.ordinary_immutable_op(
      "aten::_log_softmax_backward_data(Tensor,Tensor,int,Tensor)",
      "LogSoftmaxBackwardDataOp", "log_softmax_backward_data")
  g.ordinary_immutable_op("aten::mm(Tensor,Tensor)", "MmOp", "mm")

  # Loss functions.
  g.print_banner("Loss function ops")
@@ -215,32 +215,6 @@ uint64_t getConv2dResultTransferVolume(T *o, unsigned int idx, bool write) {
  }
}

// Return the op statistics for matrixmultiply-like operations.
template <typename T> std::map<std::string, uint64_t> getMMOpStatistics(T op) {

  std::map<std::string, uint64_t> toReturn;

  TensorType resultTy = op.getResult().getType().template cast<TensorType>();
  uint64_t ofm_volume = getTensorVolume(resultTy);

  // Use the weight tensor to find the number of input neurons
  TensorType lossTy = op.getOperand(0).getType().template cast<TensorType>();
  TensorType weightTy = op.getOperand(1).getType().template cast<TensorType>();
  uint64_t num_input_neurons = weightTy.getShape()[0];
  uint64_t total_MACs = ofm_volume * num_input_neurons;
  toReturn["ops:MAC"] = total_MACs;

  uint64_t loss_in_volume = getTensorVolume(lossTy);
  uint64_t weight_volume = getTensorVolume(weightTy);
  toReturn["reads"] = loss_in_volume + weight_volume;
  toReturn["writes"] = ofm_volume;

  toReturn["operand:0:activation_in"] = loss_in_volume;
  toReturn["operand:1:activation_in"] = weight_volume;
  toReturn["result:0:activation_out"] = ofm_volume;
  return toReturn;
}

// Return the op statistics for ReLU-like operations.
template <typename T>
std::map<std::string, uint64_t> getReLUOpStatistics(T op) {
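To make the bookkeeping concrete, here is the same arithmetic in Python for the 2x3 @ 3x4 shapes used in the test above (an illustration only; names mirror the C++, and the per-operand activation keys are omitted):

def mm_op_statistics(lhs_shape, rhs_shape):
  # Mirrors getMMOpStatistics for a rank-2 matmul.
  ofm_volume = lhs_shape[0] * rhs_shape[1]   # elements in the result
  num_input_neurons = rhs_shape[0]           # weightTy.getShape()[0]
  return {
      "ops:MAC": ofm_volume * num_input_neurons,
      "reads": lhs_shape[0] * lhs_shape[1] + rhs_shape[0] * rhs_shape[1],
      "writes": ofm_volume,
  }

# 8 outputs x 3 MACs each = 24 MACs; 6 + 12 = 18 element reads; 8 writes.
print(mm_op_statistics((2, 3), (3, 4)))
# -> {'ops:MAC': 24, 'reads': 18, 'writes': 8}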
@@ -941,6 +941,25 @@ const Torch::BuildKernelMetadata &LogSoftmaxBackwardDataOp::getTorchBuildKernelMetadata() {
  return metadata;
}

Torch::KernelMetadata MmOp::getTorchKernelMetadata() {
  return getTorchBuildKernelMetadata();
}

const Torch::BuildKernelMetadata &MmOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = ([]() {
    Torch::BuildKernelMetadata m;
    m.kernelName = "aten::mm";
    m.promoteTrailingOutTensor = true;
    m.addArgTypes({"Tensor", "Tensor"});
    m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
    m.addReturnTypes({"Tensor"});
    m.addReturnConversions({KVC::kImmutableTensor});
    return m;
  })();
  return metadata;
}

// -----------------------------------------------------------------------------
// Loss function ops
// -----------------------------------------------------------------------------
@@ -540,6 +540,17 @@ def aten_LogSoftmaxBackwardDataOp: aten_Op<"log_softmax_backward_data", [NoSideEffect
  );
}

def aten_MmOp: aten_Op<"mm", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::mm";
  let arguments = (ins
    AnyTorchImmutableTensor:$self,
    AnyTorchImmutableTensor:$mat2
  );
  let results = (outs
    AnyTorchImmutableTensor
  );
}

// -----------------------------------------------------------------------------
// Loss function ops
// -----------------------------------------------------------------------------
@@ -92,22 +92,6 @@ def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>,
  }];
}

def aten_MmOp: aten_Op<"mm", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor)> {
  let arguments = (
    ins AnyTensor:$self,
    AnyTensor:$mat2
  );
  let summary = "aten mm operator";
  let description = [{
    MmOp
    aten mm operator
  }];
  let extraClassDeclaration = [{
    std::map<std::string, uint64_t> getStatistics();
  }];
}

def aten_MulUnderOp: aten_Op<"mul_", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor)> {
  let arguments = (
@@ -74,4 +74,5 @@ void mlir::NPCOMP::populateCoreATenToTCFPatterns(
  patterns.insert<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
  patterns.insert<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(
      context);
  patterns.insert<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
}
@@ -371,14 +371,6 @@ std::map<std::string, uint64_t> MeanOp::getStatistics() {
  return toReturn;
}

// mm
// std::map<std::string, uint64_t> MMOp::getStatistics() {
//   getMMOpStatistics(*this);
// }
std::map<std::string, uint64_t> MmOp::getStatistics() {
  return getMMOpStatistics(*this);
}

// mul_
std::map<std::string, uint64_t> MulUnderOp::getStatistics() {