Add support for "trailing_" and "out" variants of various ops.

We already had the `promoteTrailingOutTensor` flag, but weren't using
it. An `inplaceVariantKernelName` flag needed to be added.

This change is a little dissatisfying, as the conversions done by the
RecognizeKernelsPass are currently non-orthogonal. In particular,
`kDropResultAndAliasArg0` probably won't work as intended if mixed with
these (we probably need to promote `kDropResultAndAliasArg0` to not be an
arg-level thing anyway, as we have done with `promoteTrailingOutTensor`).

This involved adding a new op `numpy.overwrite_array`.

```
numpy.overwrite_array %arg2 overwrites %arg0 : tensor<2x3xf32>, !numpy.ndarray<[2,3]:f32>
```

This models the destructive update behavior. Note that in the above op,
we cannot simply RAUW %arg0 with a suitably converted %arg2 (for example,
%arg0 might have uses that are not dominated by %arg2, or might have an
alias relation with some other array in the program). In general, we
need a pass analogous to "SSA-formation" which knows how to see through
these to uncover an underlying tensor program.
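
To make the hazard concrete, here is a small eager-mode PyTorch example
(illustration only, not part of this change) where naively forwarding the
result value in place of the array would go wrong because of aliasing:

```
import torch

a = torch.ones(2, 2)
b = a[0]                         # `b` aliases `a`'s storage
a.div_(torch.full((2, 2), 2.0))  # destructive update of `a`
# The update is visible through the alias, so `a` cannot simply be replaced
# by the op's result value; uses reachable through `b` (or uses of `a` that
# precede the update) would otherwise observe the wrong contents.
print(b)  # tensor([0.5000, 0.5000])
```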

Also, add tanh_out_e2e.py/div_inplace_e2e.py and fix some bitrot in
refjit.py, which is the running example I'm trying to get working.
Sean Silva 2021-03-18 13:13:40 -07:00
parent a53ed850bd
commit 703428eff4
10 changed files with 203 additions and 6 deletions


@@ -0,0 +1,32 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
import npcomp
from npcomp.compiler.pytorch.backend import refjit
from npcomp.compiler.utils import logging
import test_utils
logging.enable()
torch.manual_seed(0)
arg0 = torch.ones(2, 2)
arg1 = torch.ones(2, 2)
def fun(a, b):
  return a.div_(b)
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("test", [arg0, arg1]) as f:
  f.returns([fun(arg0, arg1)])
backend = refjit.CompilerBackend()
jit_module = backend.load(backend.compile(mb.module))
test_utils.compare_outputs(torch.div, jit_module.test, arg0, arg1)
test_utils.compare_outputs(torch.div, jit_module.test, arg0 + 1, arg1 + 1)


@@ -0,0 +1,33 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import torch
import torch_mlir
import npcomp
from npcomp.compiler.pytorch.backend import refjit
from npcomp.compiler.utils import logging
import test_utils
logging.enable()
torch.manual_seed(0)
arg0 = torch.ones(2, 2)
def fun(a):
  z = torch.zeros(2, 2)
  torch.tanh(a, out=z)
  return z
mb = torch_mlir.ModuleBuilder()
with mb.capture_function("test", [arg0]) as f:
  f.returns([fun(arg0)])
backend = refjit.CompilerBackend()
jit_module = backend.load(backend.compile(mb.module))
test_utils.compare_outputs(torch.tanh, jit_module.test, arg0)
test_utils.compare_outputs(torch.tanh, jit_module.test, arg0 + 1)
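
For reference, the variant behavior these two tests exercise, in plain eager
PyTorch (independent of npcomp): both the `out=` form and the trailing-`_`
form return the very tensor they write into.

```
import torch

a = torch.ones(2, 2)
z = torch.zeros(2, 2)
r = torch.tanh(a, out=z)
assert r is z   # the "out" variant returns its out tensor
b = torch.full((2, 2), 2.0)
r2 = a.div_(b)
assert r2 is a  # the trailing-"_" variant returns self
```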


@@ -209,11 +209,13 @@ class OpGenerator:
      - Setting all arguments and returns to kImmutableTensor
      - Enabling kPromoteScalarToTensor on the second argument.
    """
    kernel_name = kernel_sig.partition("(")[0]
    opdef = self.define_op(
        kernel_sig=kernel_sig,
        ods_name=ods_name,
        op_name=op_name,
        promote_trailing_out_tensor=promote_trailing_out_tensor,
        inplace_variant_kernel_name=kernel_name + "_",
        traits=list(traits) + ["NoSideEffect"],
        **kwargs)
    opdef.arg_transforms(
@@ -443,6 +445,7 @@ class InflightOpDef:
traits: Sequence[str] = (),
alias_kernel_names: Sequence[str] = (),
promote_trailing_out_tensor: bool = False,
inplace_variant_kernel_name: Optional[str] = None,
override_arg_types: Sequence[str] = None,
override_return_types: Sequence[str] = None,
drop_arg_indices: Sequence[int] = (),
@@ -455,6 +458,7 @@ class InflightOpDef:
self.traits = list(traits)
self.alias_kernel_names = list(alias_kernel_names)
self.promote_trailing_out_tensor = promote_trailing_out_tensor
self.inplace_variant_kernel_name = inplace_variant_kernel_name
self.override_arg_types = override_arg_types
self.override_return_types = override_return_types
self.drop_arg_indices = drop_arg_indices
@@ -548,6 +552,7 @@ class InflightOpDef:
arg_type_flags=self.arg_type_flags,
return_type_flags=self.return_type_flags,
promote_trailing_out_tensor=self.promote_trailing_out_tensor,
inplace_variant_kernel_name=self.inplace_variant_kernel_name,
alias_kernel_names=self.alias_kernel_names)
@@ -647,6 +652,7 @@ class CCImplEmitter(EmitterBase):
arg_type_flags: List[Tuple[str, List[Tuple[str]]]],
return_type_flags: List[Tuple[str, List[Tuple[str]]]],
promote_trailing_out_tensor: bool = False,
inplace_variant_kernel_name: Optional[str] = None,
alias_kernel_names: Sequence[str] = ()):
# getTorchKernelMetadata() method.
self.print(
@@ -671,6 +677,9 @@ class CCImplEmitter(EmitterBase):
f"m.aliasKernelNames.push_back({self.quote(alias_kernel_name)});")
if promote_trailing_out_tensor:
self.print("m.promoteTrailingOutTensor = true;")
if inplace_variant_kernel_name is not None:
self.print(
f"m.inplaceVariantKernelName = {self.quote(inplace_variant_kernel_name)};")
# Arg types/flags.
arg_types = self._format_cpp_str_initlist(
[t[0] for t in arg_type_flags])
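
A minimal runnable sketch of the name derivation these hunks implement; the
signature string is a plausible example, not taken verbatim from the
generator's op tables:

```
# The in-place variant kernel name is the base kernel name with a
# trailing "_" appended.
kernel_sig = "aten::div(Tensor self, Tensor other) -> Tensor"  # example shape
kernel_name = kernel_sig.partition("(")[0]       # "aten::div"
inplace_variant_kernel_name = kernel_name + "_"  # "aten::div_"
assert inplace_variant_kernel_name == "aten::div_"
```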


@@ -29,6 +29,7 @@ const Torch::BuildKernelMetadata &AddOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::add";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::add_";
m.addArgTypes({"Tensor", "Tensor", "Scalar"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar, KVC::kNone});
m.addReturnTypes({"Tensor"});
@@ -48,6 +49,7 @@ const Torch::BuildKernelMetadata &Atan2Op::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::atan2";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::atan2_";
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
@@ -67,6 +69,7 @@ const Torch::BuildKernelMetadata &DivOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::div";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::div_";
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
@@ -86,6 +89,7 @@ const Torch::BuildKernelMetadata &FloorDivideOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::floor_divide";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::floor_divide_";
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
@@ -105,6 +109,7 @@ const Torch::BuildKernelMetadata &MulOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::mul";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::mul_";
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
@@ -124,6 +129,7 @@ const Torch::BuildKernelMetadata &RemainderOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::remainder";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::remainder_";
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
@@ -143,6 +149,7 @@ const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::true_divide";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::true_divide_";
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
@@ -162,6 +169,7 @@ const Torch::BuildKernelMetadata &MaximumOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::maximum";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::maximum_";
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});
@@ -181,6 +189,7 @@ const Torch::BuildKernelMetadata &MinimumOp::getTorchBuildKernelMetadata() {
Torch::BuildKernelMetadata m;
m.kernelName = "aten::minimum";
m.promoteTrailingOutTensor = true;
m.inplaceVariantKernelName = "aten::minimum_";
m.addArgTypes({"Tensor", "Tensor"});
m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
m.addReturnTypes({"Tensor"});


@@ -79,6 +79,30 @@ def Numpy_CopyToTensorOp : Numpy_Op<"copy_to_tensor", [
let hasCanonicalizer = 1;
}

def Numpy_OverwriteArrayOp : Numpy_Op<"overwrite_array", []> {
  let summary = "Overwrite the contents of an array with a tensor.";
  let description = [{
    Replaces the contents of `array` with corresponding values from `tensor`.

    Immediately after this op has completed, indexing `array` will yield the
    same values as indexing into `tensor`. Of course, later ops might mutate
    `array`, so this relationship need not hold for the entire program.

    This op has undefined behavior if the tensor and array have different
    shapes or dtypes.
  }];
  let arguments = (ins
    Numpy_AnyTensor:$tensor,
    Numpy_NdArrayType:$array
  );
  let results = (outs);
  let assemblyFormat = [{
    $tensor `overwrites` $array attr-dict `:` type($tensor) `,` type($array)
  }];
}

//----------------------------------------------------------------------------//
// Universal function ops (ufunc)
// See: https://docs.scipy.org/doc/numpy/reference/ufuncs.html


@@ -80,6 +80,11 @@ struct BuildKernelMetadata : public KernelMetadata {
/// all be handled with this flag.
bool promoteTrailingOutTensor = false;
/// Many ops have a variant that treats the first (self) argument as an out
/// param (usually denoted with a trailing `_`, such as `aten::div_`).
/// When this string is set, it indicates the name of such a variant op.
Optional<StringRef> inplaceVariantKernelName = None;
SmallVector<KernelValueConversion::BitMask, 4> argConversions;
SmallVector<KernelValueConversion::BitMask, 4> returnConversions;


@@ -188,6 +188,12 @@ public:
CandidateTransformList &candidates = kernelTransforms[aliasKernelName];
candidates.emplace_back(opName, buildMetadata);
}
if (buildMetadata.inplaceVariantKernelName) {
CandidateTransformList &candidates =
kernelTransforms[*buildMetadata.inplaceVariantKernelName];
candidates.emplace_back(opName, buildMetadata);
}
}
LogicalResult transformKernelCall(KernelCallOp kernelCall,
@@ -235,8 +241,10 @@ public:
return failure();
}
-    // TODO: Detect trailing outref.
-    bool sourceHasTrailingOutRef = false;
+    bool sourceHasTrailingOutRef =
+        candidate.buildMetadata.promoteTrailingOutTensor &&
+        sourceMetadata.argTypes.size() ==
+            candidate.buildMetadata.argTypes.size() + 1;
if (sourceHasTrailingOutRef ||
sourceMetadata.argTypes.size() ==
candidate.buildMetadata.argTypes.size()) {
@@ -261,7 +269,6 @@ public:
PatternRewriter &rewriter) const {
using KVC = KernelValueConversion::BitMask;
// Pre-conditions.
-    assert(!sourceHasTrailingOutRef && "trailing outref not yet implemented");
if (sourceHasTrailingOutRef)
assert((sourceMetadata.argTypes.size() ==
candidate.buildMetadata.argTypes.size() + 1) &&
@@ -270,6 +277,10 @@ public:
assert(sourceMetadata.argTypes.size() ==
candidate.buildMetadata.argTypes.size() &&
"arg arity mismatch");
bool isInplaceVariant =
candidate.buildMetadata.inplaceVariantKernelName &&
kernelCall.kernelName() ==
*candidate.buildMetadata.inplaceVariantKernelName;
// Convert fixed return types.
using PostConversionCallback = std::function<void()>;
@@ -368,6 +379,9 @@ public:
Operation *newOp = rewriter.createOperation(state);
// Materialize conversions for results.
// For out params, we need to save off the converted first result -- we will
// just RAUW it with the out param later.
Value firstResultConverted;
for (auto it : llvm::enumerate(resultConversions)) {
ConversionInfo &info = it.value();
Value origOpResultValue = info.originalValue;
@@ -379,12 +393,31 @@ public:
newOpResultValue, rewriter);
}
origOpResultValue.replaceAllUsesWith(convertedValue);
if (it.index() == 0)
firstResultConverted = convertedValue;
}
// Post conversion callbacks.
for (auto &callback : postConversionCallbacks)
callback();
if (sourceHasTrailingOutRef || isInplaceVariant) {
assert(newOp->getNumResults() > 0 &&
newOp->getResultTypes()[0].isa<TensorType>() &&
"must have tensor first result");
LLVM_DEBUG(llvm::dbgs()
<< " - Ovewriting out param with result tensor.\n");
Value out;
if (sourceHasTrailingOutRef)
out = kernelCall.getOperand(fixedArgArity);
else // isInplaceVariant
out = kernelCall.getOperand(0);
rewriter.create<Numpy::OverwriteArrayOp>(kernelCall.getLoc(),
newOp->getResult(0), out);
assert(firstResultConverted && "must have a first result");
firstResultConverted.replaceAllUsesWith(out);
}
// Done.
rewriter.eraseOp(kernelCall);
return success();


@@ -61,13 +61,15 @@ class CompilerBackend:
     for IREE, it is a serialized VM flatbuffer) but the contract is that
     it is operated on by methods on this class.
     """
-    # TODO: Once transitioned to new Python API, don't reparse the module.
-    with Context() as context:
+    with imported_module.context as context:
       if self._debug:
         logging.debug("Initial PyTorch IR:\n{}", imported_module)
       # Frontend.
-      pm = PassManager.parse(",".join(TORCH_TO_TCF_PASSES))
+      pipeline_str = ",".join(TORCH_TO_TCF_PASSES)
+      if self._debug:
+        logging.debug("Running Torch->TCF pipeline '{}'", pipeline_str)
+      pm = PassManager.parse(pipeline_str)
       pm.run(imported_module)
       if self._debug:
         logging.debug("TCF IR:\n{}", imported_module)


@@ -108,3 +108,42 @@ func @copy_inplace(%arg0: !numpy.ndarray<[4]:f32>, %arg1: !numpy.ndarray<[4]:f32
// CHECK: return %arg0
return %0 : !numpy.ndarray<[4]:f32>
}
// -----
// Out params.
// Some torch ops allow an extra argument which the result is written into.
// The return value is identical to this out argument, so it can be RAUW'ed.
//
// CHECK-LABEL: func @out_param(
// CHECK-SAME: %[[ARRAY:.*]]: !numpy.ndarray<[2,2]:f32>,
// CHECK-SAME: %[[OUT:.*]]: !numpy.ndarray<[2,2]:f32>) -> (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) {
func @out_param(%arg0: !numpy.ndarray<[2,2]:f32>, %arg1: !numpy.ndarray<[2,2]:f32>) -> (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) {
// CHECK: %[[TENSOR:.*]] = numpy.copy_to_tensor %[[ARRAY]] : (!numpy.ndarray<[2,2]:f32>) -> tensor<2x2xf32>
// CHECK: %[[RESULT_TENSOR:.*]] = "aten.tanh"(%[[TENSOR]]) : (tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: numpy.overwrite_array %[[RESULT_TENSOR]] overwrites %[[OUT]] : tensor<2x2xf32>, !numpy.ndarray<[2,2]:f32>
%3 = torch.kernel_call "aten::tanh" %arg0, %arg1 : (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) -> !numpy.ndarray<[2,2]:f32> {sigArgTypes = ["Tensor", "Tensor"], sigIsMutable = true, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
// CHECK: return %[[OUT]], %[[OUT]] : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
return %3, %arg1 : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
}
// -----
// Inplace variant.
// Some torch ops have a trailing "_" variant which updates the first (self)
// parameter in place.
// These are equivalent to the out param versions, as if the self param
// were appended as the out param.
//
// CHECK-LABEL: func @inplace_variant(
// CHECK-SAME: %[[LHS_OUT:.*]]: !numpy.ndarray<[2,2]:f32>,
// CHECK-SAME: %[[RHS:.*]]: !numpy.ndarray<[2,2]:f32>) -> (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) {
func @inplace_variant(%arg0: !numpy.ndarray<[2,2]:f32>, %arg1: !numpy.ndarray<[2,2]:f32>) -> (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) {
// CHECK: %[[LHS_TENSOR:.*]] = numpy.copy_to_tensor %[[LHS_OUT]] : (!numpy.ndarray<[2,2]:f32>) -> tensor<2x2xf32>
// CHECK: %[[RHS_TENSOR:.*]] = numpy.copy_to_tensor %[[RHS]] : (!numpy.ndarray<[2,2]:f32>) -> tensor<2x2xf32>
// CHECK: %[[RESULT_TENSOR:.*]] = "aten.div"(%[[LHS_TENSOR]], %[[RHS_TENSOR]]) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
// CHECK: numpy.overwrite_array %[[RESULT_TENSOR]] overwrites %[[LHS_OUT]] : tensor<2x2xf32>, !numpy.ndarray<[2,2]:f32>
%0 = torch.kernel_call "aten::div_" %arg0, %arg1 : (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) -> !numpy.ndarray<[2,2]:f32> {sigArgTypes = ["Tensor", "Tensor"], sigIsMutable = true, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
// CHECK: return %[[LHS_OUT]], %[[LHS_OUT]] : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
return %0, %arg0 : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
}


@@ -6,3 +6,14 @@ func @builtin_ufunc(%arg0 : tensor<3xf64>, %arg1 : tensor<3xf64>) -> tensor<3xf6
%0 = numpy.builtin_ufunc_call<"numpy.add"> (%arg0, %arg1) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
return %0 : tensor<3xf64>
}
// CHECK-LABEL: @ndarray_tensor_bridging
func @ndarray_tensor_bridging(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.ndarray<[2,3]:f32>, %arg2: tensor<2x3xf32>) {
// CHECK-NEXT: numpy.copy_to_tensor
%t = numpy.copy_to_tensor %arg1 : (!numpy.ndarray<[2,3]:f32>) -> tensor<2x3xf32>
// CHECK-NEXT: numpy.create_array_from_tensor
%a = numpy.create_array_from_tensor %arg2 : (tensor<2x3xf32>) -> !numpy.ndarray<[2,3]:f32>
// CHECK-NEXT: numpy.overwrite_array
numpy.overwrite_array %arg2 overwrites %arg0 : tensor<2x3xf32>, !numpy.ndarray<[2,3]:f32>
return
}