Add support for "trailing_" and "out" variants of various ops.

We already had the `promoteTrailingOutTensor` flag, but weren't using it. An `inplaceVariantKernelName` flag needed to be added.

This change is a little dissatisfying, as the conversions done by the RecognizeKernelsPass are currently non-orthogonal. In particular, `kDropResultAndAliasArg0` probably won't work as intended if mixed with these (we probably need to promote `kDropResultAndAliasArg0` to not be an arg-level thing anyway, as we have done with `promoteTrailingOutTensor`).

This involved adding a new op, `numpy.overwrite_array`:

```
numpy.overwrite_array %arg2 overwrites %arg0 : tensor<2x3xf32>, !numpy.ndarray<[2,3]:f32>
```

This models the destructive update behavior. Note that in the above op, we cannot simply RAUW %arg0 with a suitably converted %arg2: for example, %arg0 might have uses that are not dominated by %arg2, or it might have an alias relation with some other array in the program. In general, we need a pass analogous to "SSA-formation" which knows how to see through these to uncover the underlying tensor program.

Also, add tanh_out_e2e.py and div_inplace_e2e.py, and fix some bitrot in refjit.py, which is my running example that I'm trying to get working.
parent: a53ed850bd
commit: 703428eff4
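For readers less familiar with these Torch conventions: the functional, "out", and trailing-underscore variants of an op compute the same values and differ only in where the result is stored. A minimal illustration in plain PyTorch (not part of this change):

```python
import torch

a = torch.ones(2, 2)
b = torch.full((2, 2), 2.0)

r1 = torch.div(a, b)         # functional: allocates a fresh result tensor

z = torch.empty(2, 2)
r2 = torch.div(a, b, out=z)  # "out" variant: writes into z and returns z
assert r2 is z

r3 = a.div_(b)               # trailing "_" variant: mutates a and returns a
assert r3 is a
```

The `numpy.overwrite_array` op models the storage update in the last two cases once the computation itself has been rewritten into tensor (value) form.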
@ -0,0 +1,32 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import torch
import torch_mlir

import npcomp
from npcomp.compiler.pytorch.backend import refjit
from npcomp.compiler.utils import logging

import test_utils

logging.enable()

torch.manual_seed(0)

arg0 = torch.ones(2, 2)
arg1 = torch.ones(2, 2)

def fun(a, b):
  return a.div_(b)

mb = torch_mlir.ModuleBuilder()
with mb.capture_function("test", [arg0, arg1]) as f:
  f.returns([fun(arg0, arg1)])

backend = refjit.CompilerBackend()
jit_module = backend.load(backend.compile(mb.module))

test_utils.compare_outputs(torch.div, jit_module.test, arg0, arg1)
test_utils.compare_outputs(torch.div, jit_module.test, arg0 + 1, arg1 + 1)
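Note that `div_` returns `self`, so `fun` above returns the mutated `arg0`. A standalone check of that PyTorch behavior (not part of the diff):

```python
import torch

a = torch.ones(2, 2)
b = torch.full((2, 2), 2.0)
r = a.div_(b)  # divides a by b in place
assert r is a  # in-place variants return self
```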
@ -0,0 +1,33 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.

import torch
import torch_mlir

import npcomp
from npcomp.compiler.pytorch.backend import refjit
from npcomp.compiler.utils import logging

import test_utils

logging.enable()

torch.manual_seed(0)

arg0 = torch.ones(2, 2)

def fun(a):
  z = torch.zeros(2, 2)
  torch.tanh(a, out=z)
  return z

mb = torch_mlir.ModuleBuilder()
with mb.capture_function("test", [arg0]) as f:
  f.returns([fun(arg0)])

backend = refjit.CompilerBackend()
jit_module = backend.load(backend.compile(mb.module))

test_utils.compare_outputs(torch.tanh, jit_module.test, arg0)
test_utils.compare_outputs(torch.tanh, jit_module.test, arg0 + 1)
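Here `torch.tanh(a, out=z)` writes its result into `z` and also returns `z`, which is why `fun` can return `z` directly. A standalone check (not part of the diff):

```python
import torch

z = torch.zeros(2, 2)
r = torch.tanh(torch.ones(2, 2), out=z)  # result is written into z
assert r is z  # the "out" variant returns the out tensor
```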
@ -209,11 +209,13 @@ class OpGenerator:
    - Setting all arguments and returns to kImmutableTensor
    - Enabling kPromoteScalarToTensor on the second argument.
    """
    kernel_name = kernel_sig.partition("(")[0]
    opdef = self.define_op(
        kernel_sig=kernel_sig,
        ods_name=ods_name,
        op_name=op_name,
        promote_trailing_out_tensor=promote_trailing_out_tensor,
        inplace_variant_kernel_name=kernel_name + "_",
        traits=list(traits) + ["NoSideEffect"],
        **kwargs)
    opdef.arg_transforms(
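The inplace-variant kernel name is derived purely textually, as `kernel_name + "_"` above shows. A minimal sketch of that derivation (the signature string is a hypothetical example):

```python
kernel_sig = "aten::div(Tensor self, Tensor other) -> Tensor"
kernel_name = kernel_sig.partition("(")[0]       # "aten::div"
inplace_variant_kernel_name = kernel_name + "_"  # "aten::div_"
```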
@ -443,6 +445,7 @@ class InflightOpDef:
               traits: Sequence[str] = (),
               alias_kernel_names: Sequence[str] = (),
               promote_trailing_out_tensor: bool = False,
               inplace_variant_kernel_name: Optional[str] = None,
               override_arg_types: Sequence[str] = None,
               override_return_types: Sequence[str] = None,
               drop_arg_indices: Sequence[int] = (),
@ -455,6 +458,7 @@ class InflightOpDef:
    self.traits = list(traits)
    self.alias_kernel_names = list(alias_kernel_names)
    self.promote_trailing_out_tensor = promote_trailing_out_tensor
    self.inplace_variant_kernel_name = inplace_variant_kernel_name
    self.override_arg_types = override_arg_types
    self.override_return_types = override_return_types
    self.drop_arg_indices = drop_arg_indices
@ -548,6 +552,7 @@ class InflightOpDef:
        arg_type_flags=self.arg_type_flags,
        return_type_flags=self.return_type_flags,
        promote_trailing_out_tensor=self.promote_trailing_out_tensor,
        inplace_variant_kernel_name=self.inplace_variant_kernel_name,
        alias_kernel_names=self.alias_kernel_names)
@ -647,6 +652,7 @@ class CCImplEmitter(EmitterBase):
                     arg_type_flags: List[Tuple[str, List[Tuple[str]]]],
                     return_type_flags: List[Tuple[str, List[Tuple[str]]]],
                     promote_trailing_out_tensor: bool = False,
                     inplace_variant_kernel_name: Optional[str] = None,
                     alias_kernel_names: Sequence[str] = ()):
    # getTorchKernelMetadata() method.
    self.print(
@ -671,6 +677,9 @@ class CCImplEmitter(EmitterBase):
          f"m.aliasKernelNames.push_back({self.quote(alias_kernel_name)});")
    if promote_trailing_out_tensor:
      self.print("m.promoteTrailingOutTensor = true;")
    if inplace_variant_kernel_name is not None:
      self.print(
          f"m.inplaceVariantKernelName = {self.quote(inplace_variant_kernel_name)};")
    # Arg types/flags.
    arg_types = self._format_cpp_str_initlist(
        [t[0] for t in arg_type_flags])
@ -29,6 +29,7 @@ const Torch::BuildKernelMetadata &AddOp::getTorchBuildKernelMetadata() {
  Torch::BuildKernelMetadata m;
  m.kernelName = "aten::add";
  m.promoteTrailingOutTensor = true;
  m.inplaceVariantKernelName = "aten::add_";
  m.addArgTypes({"Tensor", "Tensor", "Scalar"});
  m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar, KVC::kNone});
  m.addReturnTypes({"Tensor"});

@ -48,6 +49,7 @@ const Torch::BuildKernelMetadata &Atan2Op::getTorchBuildKernelMetadata() {
  Torch::BuildKernelMetadata m;
  m.kernelName = "aten::atan2";
  m.promoteTrailingOutTensor = true;
  m.inplaceVariantKernelName = "aten::atan2_";
  m.addArgTypes({"Tensor", "Tensor"});
  m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
  m.addReturnTypes({"Tensor"});

@ -67,6 +69,7 @@ const Torch::BuildKernelMetadata &DivOp::getTorchBuildKernelMetadata() {
  Torch::BuildKernelMetadata m;
  m.kernelName = "aten::div";
  m.promoteTrailingOutTensor = true;
  m.inplaceVariantKernelName = "aten::div_";
  m.addArgTypes({"Tensor", "Tensor"});
  m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
  m.addReturnTypes({"Tensor"});

@ -86,6 +89,7 @@ const Torch::BuildKernelMetadata &FloorDivideOp::getTorchBuildKernelMetadata() {
  Torch::BuildKernelMetadata m;
  m.kernelName = "aten::floor_divide";
  m.promoteTrailingOutTensor = true;
  m.inplaceVariantKernelName = "aten::floor_divide_";
  m.addArgTypes({"Tensor", "Tensor"});
  m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
  m.addReturnTypes({"Tensor"});

@ -105,6 +109,7 @@ const Torch::BuildKernelMetadata &MulOp::getTorchBuildKernelMetadata() {
  Torch::BuildKernelMetadata m;
  m.kernelName = "aten::mul";
  m.promoteTrailingOutTensor = true;
  m.inplaceVariantKernelName = "aten::mul_";
  m.addArgTypes({"Tensor", "Tensor"});
  m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
  m.addReturnTypes({"Tensor"});

@ -124,6 +129,7 @@ const Torch::BuildKernelMetadata &RemainderOp::getTorchBuildKernelMetadata() {
  Torch::BuildKernelMetadata m;
  m.kernelName = "aten::remainder";
  m.promoteTrailingOutTensor = true;
  m.inplaceVariantKernelName = "aten::remainder_";
  m.addArgTypes({"Tensor", "Tensor"});
  m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
  m.addReturnTypes({"Tensor"});

@ -143,6 +149,7 @@ const Torch::BuildKernelMetadata &TrueDivideOp::getTorchBuildKernelMetadata() {
  Torch::BuildKernelMetadata m;
  m.kernelName = "aten::true_divide";
  m.promoteTrailingOutTensor = true;
  m.inplaceVariantKernelName = "aten::true_divide_";
  m.addArgTypes({"Tensor", "Tensor"});
  m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
  m.addReturnTypes({"Tensor"});

@ -162,6 +169,7 @@ const Torch::BuildKernelMetadata &MaximumOp::getTorchBuildKernelMetadata() {
  Torch::BuildKernelMetadata m;
  m.kernelName = "aten::maximum";
  m.promoteTrailingOutTensor = true;
  m.inplaceVariantKernelName = "aten::maximum_";
  m.addArgTypes({"Tensor", "Tensor"});
  m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
  m.addReturnTypes({"Tensor"});

@ -181,6 +189,7 @@ const Torch::BuildKernelMetadata &MinimumOp::getTorchBuildKernelMetadata() {
  Torch::BuildKernelMetadata m;
  m.kernelName = "aten::minimum";
  m.promoteTrailingOutTensor = true;
  m.inplaceVariantKernelName = "aten::minimum_";
  m.addArgTypes({"Tensor", "Tensor"});
  m.addArgConversions({KVC::kImmutableTensor, KVC::kImmutableTensor|KVC::kPromoteScalar});
  m.addReturnTypes({"Tensor"});
@ -79,6 +79,30 @@ def Numpy_CopyToTensorOp : Numpy_Op<"copy_to_tensor", [
  let hasCanonicalizer = 1;
}

def Numpy_OverwriteArrayOp : Numpy_Op<"overwrite_array", []> {
  let summary = "Overwrite the contents of an array with a tensor.";
  let description = [{
    Replaces the contents of `array` with corresponding values from `tensor`.

    Immediately after this op has completed, indexing `array` will result
    in identical values as indexing into `tensor`. Of course, later ops
    might mutate `array`, so this relationship need not hold for the entire
    program.

    This op has undefined behavior if the tensor and array have different
    shapes or dtypes.
  }];
  let arguments = (ins
    Numpy_AnyTensor:$tensor,
    Numpy_NdArrayType:$array
  );
  let results = (outs
  );
  let assemblyFormat = [{
    $tensor `overwrites` $array attr-dict `:` type($tensor) `,` type($array)
  }];
}

//----------------------------------------------------------------------------//
// Universal function ops (ufunc)
// See: https://docs.scipy.org/doc/numpy/reference/ufuncs.html
@ -80,6 +80,11 @@ struct BuildKernelMetadata : public KernelMetadata {
  /// all be handled with this flag.
  bool promoteTrailingOutTensor = false;

  /// Many ops have a variant that treats the first (self) argument as an out
  /// param (usually denoted with a trailing `_`, such as `aten::div_`).
  /// When this string is set, it indicates the name of such a variant op.
  Optional<StringRef> inplaceVariantKernelName = None;

  SmallVector<KernelValueConversion::BitMask, 4> argConversions;
  SmallVector<KernelValueConversion::BitMask, 4> returnConversions;
@ -188,6 +188,12 @@ public:
      CandidateTransformList &candidates = kernelTransforms[aliasKernelName];
      candidates.emplace_back(opName, buildMetadata);
    }
    if (buildMetadata.inplaceVariantKernelName) {
      CandidateTransformList &candidates =
          kernelTransforms[*buildMetadata.inplaceVariantKernelName];
      candidates.emplace_back(opName, buildMetadata);
    }
  }

  LogicalResult transformKernelCall(KernelCallOp kernelCall,
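In effect this registers the same op as a rewrite candidate under both the functional and the inplace kernel names. A rough Python model of the table update (names hypothetical; the authoritative logic is the C++ above):

```python
kernel_transforms = {}  # kernel name -> list of candidate op names

def register_candidate(op_name, kernel_name, inplace_variant_kernel_name=None):
  # Every op is a candidate for its own kernel name...
  kernel_transforms.setdefault(kernel_name, []).append(op_name)
  # ...and, when set, also for its inplace-variant kernel name.
  if inplace_variant_kernel_name is not None:
    kernel_transforms.setdefault(inplace_variant_kernel_name, []).append(op_name)

register_candidate("aten.div", "aten::div", "aten::div_")
assert kernel_transforms["aten::div"] == kernel_transforms["aten::div_"]
```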
@ -235,8 +241,10 @@ public:
       return failure();
     }

-    // TODO: Detect trailing outref.
-    bool sourceHasTrailingOutRef = false;
+    bool sourceHasTrailingOutRef =
+        candidate.buildMetadata.promoteTrailingOutTensor &&
+        sourceMetadata.argTypes.size() ==
+            candidate.buildMetadata.argTypes.size() + 1;
     if (sourceHasTrailingOutRef ||
         sourceMetadata.argTypes.size() ==
             candidate.buildMetadata.argTypes.size()) {
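The detection is purely arity-based: a source call matches the trailing-out form when it has exactly one more argument than the functional candidate. A rough Python model (helper name hypothetical):

```python
def source_has_trailing_out_ref(source_arg_types, candidate_arg_types,
                                promote_trailing_out_tensor):
  # e.g. "aten::tanh(self, out)" matched against functional "aten::tanh(self)"
  return (promote_trailing_out_tensor and
          len(source_arg_types) == len(candidate_arg_types) + 1)

assert source_has_trailing_out_ref(["Tensor", "Tensor"], ["Tensor"], True)
assert not source_has_trailing_out_ref(["Tensor"], ["Tensor"], True)
```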
@ -261,7 +269,6 @@ public:
                              PatternRewriter &rewriter) const {
     using KVC = KernelValueConversion::BitMask;
     // Pre-conditions.
-    assert(!sourceHasTrailingOutRef && "trailing outref not yet implemented");
     if (sourceHasTrailingOutRef)
       assert((sourceMetadata.argTypes.size() ==
                   candidate.buildMetadata.argTypes.size() + 1) &&
@ -270,6 +277,10 @@ public:
      assert(sourceMetadata.argTypes.size() ==
                 candidate.buildMetadata.argTypes.size() &&
             "arg arity mismatch");
    bool isInplaceVariant =
        candidate.buildMetadata.inplaceVariantKernelName &&
        kernelCall.kernelName() ==
            *candidate.buildMetadata.inplaceVariantKernelName;

    // Convert fixed return types.
    using PostConversionCallback = std::function<void()>;
@ -368,6 +379,9 @@ public:
    Operation *newOp = rewriter.createOperation(state);

    // Materialize conversions for results.
    // For out params, we need to save off the converted first result -- we will
    // just RAUW it with the out param later.
    Value firstResultConverted;
    for (auto it : llvm::enumerate(resultConversions)) {
      ConversionInfo &info = it.value();
      Value origOpResultValue = info.originalValue;
@ -379,12 +393,31 @@ public:
                                           newOpResultValue, rewriter);
      }
      origOpResultValue.replaceAllUsesWith(convertedValue);
      if (it.index() == 0)
        firstResultConverted = convertedValue;
    }

    // Post conversion callbacks.
    for (auto &callback : postConversionCallbacks)
      callback();

    if (sourceHasTrailingOutRef || isInplaceVariant) {
      assert(newOp->getNumResults() > 0 &&
             newOp->getResultTypes()[0].isa<TensorType>() &&
             "must have tensor first result");
      LLVM_DEBUG(llvm::dbgs()
                 << " - Overwriting out param with result tensor.\n");
      Value out;
      if (sourceHasTrailingOutRef)
        out = kernelCall.getOperand(fixedArgArity);
      else // isInplaceVariant
        out = kernelCall.getOperand(0);
      rewriter.create<Numpy::OverwriteArrayOp>(kernelCall.getLoc(),
                                               newOp->getResult(0), out);
      assert(firstResultConverted && "must have a first result");
      firstResultConverted.replaceAllUsesWith(out);
    }

    // Done.
    rewriter.eraseOp(kernelCall);
    return success();
@ -61,13 +61,15 @@ class CompilerBackend:
     for IREE, it is a serialized VM flatbuffer) but the contract is that
     it is operated on by methods on this class.
     """
-    # TODO: Once transitioned to new Python API, don't reparse the module.
-    with Context() as context:
+    with imported_module.context as context:
       if self._debug:
         logging.debug("Initial PyTorch IR:\n{}", imported_module)

       # Frontend.
-      pm = PassManager.parse(",".join(TORCH_TO_TCF_PASSES))
+      pipeline_str = ",".join(TORCH_TO_TCF_PASSES)
+      if self._debug:
+        logging.debug("Running Torch->TCF pipeline '{}'", pipeline_str)
+      pm = PassManager.parse(pipeline_str)
       pm.run(imported_module)
       if self._debug:
         logging.debug("TCF IR:\n{}", imported_module)
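The practical effect is that the exact pipeline string is now logged alongside the IR when debugging is enabled. Assuming the backend's `_debug` flag follows the global logging switch (as the e2e tests above suggest), usage looks like:

```python
from npcomp.compiler.pytorch.backend import refjit
from npcomp.compiler.utils import logging

logging.enable()  # assumption: this is what turns on self._debug in the backend

backend = refjit.CompilerBackend()
# mb is a torch_mlir.ModuleBuilder, as in the e2e tests above.
# compile() now logs the initial PyTorch IR, the Torch->TCF pipeline string,
# and the resulting TCF IR.
jit_module = backend.load(backend.compile(mb.module))
```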
@ -108,3 +108,42 @@ func @copy_inplace(%arg0: !numpy.ndarray<[4]:f32>, %arg1: !numpy.ndarray<[4]:f32
  // CHECK: return %arg0
  return %0 : !numpy.ndarray<[4]:f32>
}

// -----

// Out params.
// Some torch ops allow an extra argument into which the result is written.
// The return value is identical to this out argument, so it can be RAUW'ed.
//
// CHECK-LABEL: func @out_param(
// CHECK-SAME: %[[ARRAY:.*]]: !numpy.ndarray<[2,2]:f32>,
// CHECK-SAME: %[[OUT:.*]]: !numpy.ndarray<[2,2]:f32>) -> (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) {
func @out_param(%arg0: !numpy.ndarray<[2,2]:f32>, %arg1: !numpy.ndarray<[2,2]:f32>) -> (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) {
  // CHECK: %[[TENSOR:.*]] = numpy.copy_to_tensor %[[ARRAY]] : (!numpy.ndarray<[2,2]:f32>) -> tensor<2x2xf32>
  // CHECK: %[[RESULT_TENSOR:.*]] = "aten.tanh"(%[[TENSOR]]) : (tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: numpy.overwrite_array %[[RESULT_TENSOR]] overwrites %[[OUT]] : tensor<2x2xf32>, !numpy.ndarray<[2,2]:f32>
  %3 = torch.kernel_call "aten::tanh" %arg0, %arg1 : (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) -> !numpy.ndarray<[2,2]:f32> {sigArgTypes = ["Tensor", "Tensor"], sigIsMutable = true, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
  // CHECK: return %[[OUT]], %[[OUT]] : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
  return %3, %arg1 : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
}

// -----

// Inplace variant.
// Some torch ops have a trailing "_" variant which updates the first (self)
// parameter in place.
// These are equivalent to the out param versions, as if the self param
// were appended as the out param.
//
// CHECK-LABEL: func @inplace_variant(
// CHECK-SAME: %[[LHS_OUT:.*]]: !numpy.ndarray<[2,2]:f32>,
// CHECK-SAME: %[[RHS:.*]]: !numpy.ndarray<[2,2]:f32>) -> (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) {
func @inplace_variant(%arg0: !numpy.ndarray<[2,2]:f32>, %arg1: !numpy.ndarray<[2,2]:f32>) -> (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) {
  // CHECK: %[[LHS_TENSOR:.*]] = numpy.copy_to_tensor %[[LHS_OUT]] : (!numpy.ndarray<[2,2]:f32>) -> tensor<2x2xf32>
  // CHECK: %[[RHS_TENSOR:.*]] = numpy.copy_to_tensor %[[RHS]] : (!numpy.ndarray<[2,2]:f32>) -> tensor<2x2xf32>
  // CHECK: %[[RESULT_TENSOR:.*]] = "aten.div"(%[[LHS_TENSOR]], %[[RHS_TENSOR]]) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
  // CHECK: numpy.overwrite_array %[[RESULT_TENSOR]] overwrites %[[LHS_OUT]] : tensor<2x2xf32>, !numpy.ndarray<[2,2]:f32>
  %0 = torch.kernel_call "aten::div_" %arg0, %arg1 : (!numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>) -> !numpy.ndarray<[2,2]:f32> {sigArgTypes = ["Tensor", "Tensor"], sigIsMutable = true, sigIsVararg = false, sigIsVarret = false, sigRetTypes = ["Tensor"]}
  // CHECK: return %[[LHS_OUT]], %[[LHS_OUT]] : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
  return %0, %arg0 : !numpy.ndarray<[2,2]:f32>, !numpy.ndarray<[2,2]:f32>
}
@ -6,3 +6,14 @@ func @builtin_ufunc(%arg0 : tensor<3xf64>, %arg1 : tensor<3xf64>) -> tensor<3xf6
  %0 = numpy.builtin_ufunc_call<"numpy.add"> (%arg0, %arg1) : (tensor<3xf64>, tensor<3xf64>) -> tensor<3xf64>
  return %0 : tensor<3xf64>
}

// CHECK-LABEL: @ndarray_tensor_bridging
func @ndarray_tensor_bridging(%arg0: !numpy.ndarray<[2,3]:f32>, %arg1: !numpy.ndarray<[2,3]:f32>, %arg2: tensor<2x3xf32>) {
  // CHECK-NEXT: numpy.copy_to_tensor
  %t = numpy.copy_to_tensor %arg1 : (!numpy.ndarray<[2,3]:f32>) -> tensor<2x3xf32>
  // CHECK-NEXT: numpy.create_array_from_tensor
  %a = numpy.create_array_from_tensor %arg2 : (tensor<2x3xf32>) -> !numpy.ndarray<[2,3]:f32>
  // CHECK-NEXT: numpy.overwrite_array
  numpy.overwrite_array %arg2 overwrites %arg0 : tensor<2x3xf32>, !numpy.ndarray<[2,3]:f32>
  return
}