Add aten.mm op and "test" it e2e.

Note that unlike aten.matmul, which has dynamic behavior
depending on the argument ranks (it can do matrix-matrix, matrix-vector,
batch matmul, etc.), aten.mm is just a vanilla matrix
multiply and can be lowered precisely to tcf.matmul.
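
For instance, a quick illustration of the dispatch difference (not part
of this change):

```python
import torch

lhs = torch.rand(2, 3)
# torch.matmul dispatches on operand ranks:
print(torch.matmul(lhs, torch.rand(3)).shape)        # matrix-vector: torch.Size([2])
print(torch.matmul(torch.rand(5, 2, 3), torch.rand(5, 3, 4)).shape)  # batch: torch.Size([5, 2, 4])
# torch.mm accepts only two rank-2 tensors (plain matrix-matrix):
print(torch.mm(lhs, torch.rand(3, 4)).shape)         # torch.Size([2, 4])
```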

The "test" is really just an example that I stared at while getting my
feet wet with this. We probably want something that actually tests this
as part of `ninja check-npcomp`.
pull/122/head
Sean Silva 2020-11-20 15:59:55 -08:00
parent ec1336a8a3
commit 1dfcfa9cd1
8 changed files with 65 additions and 50 deletions

@@ -0,0 +1,33 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import sys

import numpy as np
import torch

import torch_mlir
import npcomp
from npcomp.compiler.pytorch.backend.refjit import *
from npcomp.compiler.utils import logging

logging.enable()

torch.manual_seed(0)
lhs = torch.rand(2, 3)
rhs = torch.rand(3, 4)

# Capture the torch.mm call into an MLIR module.
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("mm", [lhs, rhs]) as f:
  result = torch.mm(lhs, rhs)
  f.returns([result])

# Compile with the reference JIT backend and run the compiled module.
backend = CompilerBackend()
jit_module = backend.load(backend.compile(mb.module))
jit_result = jit_module.mm(lhs.numpy(), rhs.numpy())

print(f"PyTorch Result = {result.numpy()}", file=sys.stderr)
print(f"JIT Result = {jit_result}", file=sys.stderr)
np.testing.assert_allclose(result.numpy(), jit_result)
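
The commit message's `ninja check-npcomp` wish could be met by running this
file under lit and letting FileCheck verify a stable success marker instead
of a human comparing the printed matrices. A rough sketch of the tail end of
such a variant (the `%PYTHON` substitution and the marker are assumptions,
not part of this commit):

```python
# RUN: %PYTHON %s | FileCheck %s
# ... same body as the example above ...
# np.testing.assert_allclose raises on any mismatch, so reaching this
# print means the JIT result agreed with PyTorch.
# CHECK: MM-E2E: OK
print("MM-E2E: OK")
```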

@@ -102,6 +102,7 @@ def generate_ops(g: "OpGenerator"):
  g.ordinary_immutable_op(
      "aten::_log_softmax_backward_data(Tensor,Tensor,int,Tensor)",
      "LogSoftmaxBackwardDataOp", "log_softmax_backward_data")
  g.ordinary_immutable_op("aten::mm(Tensor,Tensor)", "MmOp", "mm")

  # Loss functions.
  g.print_banner("Loss function ops")
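
For readers new to the generator: the three arguments are the Torch kernel
signature, the generated ODS class name, and the MLIR op mnemonic. The
signature string decomposes as sketched below (`parse_sig` is illustrative
only, not the generator's actual API):

```python
def parse_sig(sig: str):
  # "aten::mm(Tensor,Tensor)" -> ("aten::mm", ["Tensor", "Tensor"])
  name, _, args = sig.partition("(")
  return name, args.rstrip(")").split(",")

print(parse_sig("aten::mm(Tensor,Tensor)"))  # ('aten::mm', ['Tensor', 'Tensor'])
```

The generated MmOp definition and its build-kernel metadata appear in the
later hunks of this commit.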

@@ -215,32 +215,6 @@ uint64_t getConv2dResultTransferVolume(T *o, unsigned int idx, bool write) {
  }
}

// Return the op statistics for matrix-multiply-like operations.
template <typename T> std::map<std::string, uint64_t> getMMOpStatistics(T op) {
  std::map<std::string, uint64_t> toReturn;

  TensorType resultTy = op.getResult().getType().template cast<TensorType>();
  uint64_t ofm_volume = getTensorVolume(resultTy);

  // Use the weight tensor to find the number of input neurons.
  TensorType lossTy = op.getOperand(0).getType().template cast<TensorType>();
  TensorType weightTy = op.getOperand(1).getType().template cast<TensorType>();
  uint64_t num_input_neurons = weightTy.getShape()[0];
  uint64_t total_MACs = ofm_volume * num_input_neurons;
  toReturn["ops:MAC"] = total_MACs;

  uint64_t loss_in_volume = getTensorVolume(lossTy);
  uint64_t weight_volume = getTensorVolume(weightTy);
  toReturn["reads"] = loss_in_volume + weight_volume;
  toReturn["writes"] = ofm_volume;

  toReturn["operand:0:activation_in"] = loss_in_volume;
  toReturn["operand:1:activation_in"] = weight_volume;
  toReturn["result:0:activation_out"] = ofm_volume;
  return toReturn;
}

// Return the op statistics for ReLU-like operations.
template <typename T>
std::map<std::string, uint64_t> getReLUOpStatistics(T op) {
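
For intuition about what the removed helper computed, here is the same
arithmetic in plain Python, using the 2x3 by 3x4 shapes from the e2e example
above (values worked out by hand, not produced by this commit):

```python
m, k, n = 2, 3, 4        # operand 0 is m x k, operand 1 ("weight") is k x n
ofm_volume = m * n       # result volume: 8
num_input_neurons = k    # weightTy.getShape()[0]
total_MACs = ofm_volume * num_input_neurons  # 8 * 3 = 24 multiply-accumulates
reads = m * k + k * n    # loss_in_volume + weight_volume = 6 + 12 = 18
writes = ofm_volume      # 8
print({"ops:MAC": total_MACs, "reads": reads, "writes": writes})
```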

@@ -941,6 +941,25 @@ const Torch::BuildKernelMetadata &LogSoftmaxBackwardDataOp::getTorchBuildKernelMetadata() {
  return metadata;
}

Torch::KernelMetadata MmOp::getTorchKernelMetadata() {
  return getTorchBuildKernelMetadata();
}

const Torch::BuildKernelMetadata &MmOp::getTorchBuildKernelMetadata() {
  using KVC = Torch::KernelValueConversion::BitMask;
  static Torch::BuildKernelMetadata metadata = ([]() {
    Torch::BuildKernelMetadata m;
    m.kernelName = "aten::mm";
    m.promoteTrailingOutTensor = true;
    m.addArgTypes({"Tensor", "Tensor"});
    m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor});
    m.addReturnTypes({"Tensor"});
    m.addReturnConversions({KVC::kImmutableTensor});
    return m;
  })();
  return metadata;
}
// -----------------------------------------------------------------------------
// Loss function ops
// -----------------------------------------------------------------------------
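
The `kernelName` above is the string the capture machinery matches against;
it names the same operator PyTorch exposes through its dispatcher, which is
easy to sanity-check from Python:

```python
import torch

lhs, rhs = torch.rand(2, 3), torch.rand(3, 4)
# torch.ops.aten.mm invokes the registered aten::mm kernel directly.
assert torch.allclose(torch.ops.aten.mm(lhs, rhs), torch.mm(lhs, rhs))
```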

@@ -540,6 +540,17 @@ def aten_LogSoftmaxBackwardDataOp: aten_Op<"log_softmax_backward_data", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  );
}

def aten_MmOp: aten_Op<"mm", [NoSideEffect, DeclareOpInterfaceMethods<TorchBuildableKernelOpInterface>, DeclareOpInterfaceMethods<TorchKernelOpInterface>]> {
  let summary = "Recognized op for kernel aten::mm";
  let arguments = (ins
    AnyTorchImmutableTensor:$self,
    AnyTorchImmutableTensor:$mat2
  );
  let results = (outs
    AnyTorchImmutableTensor
  );
}
// -----------------------------------------------------------------------------
// Loss function ops
// -----------------------------------------------------------------------------

@@ -92,22 +92,6 @@ def aten_MeanOp: aten_Op<"mean", [NoSideEffect, StatisticsOpInterface]>,
  }];
}

def aten_MmOp: aten_Op<"mm", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor)> {
  let arguments = (
    ins AnyTensor:$self,
    AnyTensor:$mat2
  );
  let summary = "aten mm operator";
  let description = [{
    MmOp
    aten mm operator
  }];
  let extraClassDeclaration = [{
    std::map<std::string, uint64_t> getStatistics();
  }];
}

def aten_MulUnderOp: aten_Op<"mul_", [NoSideEffect, StatisticsOpInterface]>,
    Results<(outs AnyTensor)> {
  let arguments = (

@@ -74,4 +74,5 @@ void mlir::NPCOMP::populateCoreATenToTCFPatterns(
  patterns.insert<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
  patterns.insert<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(
      context);
  patterns.insert<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
}

@@ -371,14 +371,6 @@ std::map<std::string, uint64_t> MeanOp::getStatistics() {
  return toReturn;
}

// mm
// std::map<std::string, uint64_t> MMOp::getStatistics() {
//   getMMOpStatistics(*this);
// }
std::map<std::string, uint64_t> MmOp::getStatistics() {
  return getMMOpStatistics(*this);
}

// mul_
std::map<std::string, uint64_t> MulUnderOp::getStatistics() {