[tosa] Support for some ops and fix for Issue #532 (#575)

* [tosa] Support for AtenNe[Tensor|Scalar]Op, AtenLog2Op,
AtenBitwiseAndTensorOp, AtenSquareOp and AtenThresholdOp
* Fix for Issue #532 - Mixed input types for a few ops, and updated a few
tests to use i32 instead of i64

Signed-off-by: Anup Gangwar <anup.gangwar@arm.com>

Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>
Authored by Anup Gangwar on 2022-02-11 14:30:02 -06:00, committed by GitHub
parent c1167853db
commit 756b75fb2d
5 changed files with 320 additions and 75 deletions
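For context, the eager-mode semantics of the ops this patch legalizes can be sketched in plain PyTorch (illustrative only; the variable names below are not part of the patch):

import torch

x = torch.randint(0, 10, (3, 4), dtype=torch.int32)
y = torch.randint(0, 10, (3, 4), dtype=torch.int32)
f = torch.rand(3, 4)

torch.ne(x, y)                         # AtenNeTensorOp: elementwise !=
torch.ne(x, 3)                         # AtenNeScalarOp
torch.log2(f)                          # AtenLog2Op
torch.bitwise_and(x, y)                # AtenBitwiseAndTensorOp (integer tensors only)
torch.square(f)                        # AtenSquareOp, i.e. f * f
torch.ops.aten.threshold(f, 0.5, 0.0)  # AtenThresholdOp: keep f where f > 0.5, else 0.0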


@@ -739,14 +739,14 @@ class ElementwiseMulScalarModule(torch.nn.Module):
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.mul(x, 8.0)
@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 4)))
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
class ElementwiseMulTensorFloatModule(torch.nn.Module):
@@ -1045,7 +1045,7 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
@export
@annotate_args([
None,
([-1, -1], torch.int64, True),
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.sub(x, 2.1, alpha=2)
@@ -1053,7 +1053,7 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: ElementwiseSubScalarIntModule())
def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 4)))
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
class ElementwiseSubScalarFloatModule(torch.nn.Module):
@@ -1072,8 +1072,7 @@ class ElementwiseSubScalarFloatModule(torch.nn.Module):
def ElementwiseSubScalarFloatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4))
class ElementwiseAddScalarIntModule(torch.nn.Module):
class ElementwiseAddScalarInt64Module(torch.nn.Module):
def __init__(self):
super().__init__()
@@ -1085,9 +1084,26 @@ class ElementwiseAddScalarIntModule(torch.nn.Module):
def forward(self, x):
return torch.add(x, 3.0)
@register_test_case(module_factory=lambda: ElementwiseAddScalarInt64Module())
def ElementwiseAddScalarInt64Module_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 4)))
class ElementwiseAddScalarIntModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1], torch.int32, True),
])
def forward(self, x):
return torch.add(x, 3.0)
@register_test_case(module_factory=lambda: ElementwiseAddScalarIntModule())
def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (3, 4)))
module.forward(torch.randint(10, (2, 3), dtype=torch.int32))
class ElementwiseAddScalarFloatModule(torch.nn.Module):


@@ -12,6 +12,24 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
# ==============================================================================
class Threshold1dIntI32Module(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1], torch.int32, True),
])
def forward(self, input):
return torch.ops.aten.threshold(input, 1, 2)
@register_test_case(module_factory=lambda: Threshold1dIntI32Module())
def Threshold1dIntI32Module_basic(module, tu: TestUtils):
module.forward(torch.randint(10, (4,), dtype=torch.int32))
class Threshold1dIntModule(torch.nn.Module):
def __init__(self):
super().__init__()


@@ -99,4 +99,12 @@ TOSA_PASS_SET = {
"LayerNormNormalizeOverAllDimsModule_basic",
"PermuteModule_basic",
"PermuteNegativeIndexModule_basic",
"ElementwiseLog2Module_basic",
"Threshold1dIntI32Module_basic",
"Threshold1dFloatModule_basic",
"Threshold2dFloatModule_basic",
"Threshold3dFloatModule_basic",
"ElementwiseSubScalarIntModule_basic",
"ElementwiseAddScalarIntModule_basic",
"ElementwiseMulScalarModule_basic",
}


@@ -111,38 +111,63 @@ public:
}
};
template <typename T>
static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
const int64_t &intValue) {
if (isFloat) {
return (doubleValue >= std::numeric_limits<T>::min()) &&
(doubleValue <= std::numeric_limits<T>::max());
} else {
assert(isInt);
return (intValue >= std::numeric_limits<T>::min()) &&
(intValue <= std::numeric_limits<T>::max());
}
return true;
}
// FIXME: This will eventually go into a Tosa*Utils file.
LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
Operation *op, Value torchScalarValue,
Value &tosaTensor, Type dtype) {
Value &tosaTensor, Type dtype,
llvm::ArrayRef<int64_t> dshape) {
// Retrieve a const float or int value but create the out Tensor with dtype.
double doubleValue;
auto isFloat =
matchPattern(torchScalarValue, m_TorchConstantFloat(&doubleValue));
int64_t intValue;
auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue));
if (!isFloat && !isInt)
return op->emitError("Unable to extract the scalar constant");
if (dtype.isa<mlir::FloatType>()) {
double scalarValue;
if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
return failure();
tosaTensor =
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
tosaTensor = tosa::getConstTensor<float>(
rewriter, op, (isFloat ? doubleValue : intValue), dshape)
.getValue();
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
int64_t scalarValue;
if (!matchPattern(torchScalarValue, m_TorchConstantInt(&scalarValue)))
return failure();
auto w = intType.getWidth();
if (w != 32 && w != 64)
return op->emitError("Unsupported integer type") << intType;
if (w == 32) {
tosaTensor = tosa::getConstTensor<int32_t>(
rewriter, op, {static_cast<int32_t>(scalarValue)}, {})
.getValue();
} else if (w == 64) {
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
return op->emitError("Supplied value of scalar constant exceeds limits "
"of destination type");
}
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
: static_cast<int32_t>(intValue);
tosaTensor =
tosa::getConstTensor<int64_t>(rewriter, op, {scalarValue}, {})
.getValue();
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
} else if (w == 64) {
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
return op->emitError("Supplied value of scalar constant exceeds limits "
"of destination type");
}
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
tosaTensor =
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
}
return success();
} else
return op->emitError("Usupported element type");
@@ -154,7 +179,7 @@ LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter,
Value &alphaTensor, Type dtype,
bool checkForUnity) {
if (succeeded(torchScalarToTosaTensor(rewriter, op, alphaScalar, alphaTensor,
dtype)))
dtype, {})))
return success();
// `alpha` has not been specified.
@@ -183,44 +208,59 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
auto lhsType = lhs.getType().dyn_cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
auto rhsType = rhs.getType().dyn_cast<TensorType>();
if (!lhsTy)
if (!lhsType)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
if (!lhsElemTy.isIntOrFloat())
if (auto lhsElemTy = lhsType.getElementType().dyn_cast<IntegerType>()) {
if (lhsElemTy.getWidth() > 32)
return op.emitError(
"Integers with widths greater than 32 are not supported");
}
auto outType =
static_cast<Type>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()))
.cast<TensorType>();
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat()) {
return op.emitError(
"Only floating-point or integer datatype legalization supported");
}
Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
op.other(), rhsAsTensor, lhsElemTy)))
if (!rhsType) {
if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor,
outElemTy, {})))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA operation");
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
auto rhsTensor = rhsType ? rhs : rhsAsTensor;
// Handle alpha.
Value alphaTensor;
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), op.alpha(),
alphaTensor, lhsElemTy, false)))
alphaTensor, outElemTy,
/*checkForUnity=*/false))) {
return op.emitError("Currently only scalar constants are supported for "
"alpha in conversion to TOSA operation");
}
auto multTensor = rewriter.create<tosa::MulOp>(
op.getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
rhsTensor, alphaTensor, /*shift*/ 0);
op.getLoc(), rhsType ? rhsType : RankedTensorType::get({}, outElemTy),
rhsTensor, alphaTensor, /*shift=*/0);
if (outElemTy.isa<mlir::FloatType>()) {
if (lhsType.getElementType() != outElemTy)
lhs = rewriter.create<tosa::CastOp>(op.getLoc(), outType, lhs);
rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, lhs, multTensor);
if (lhsElemTy.isa<mlir::FloatType>()) {
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
lhs, multTensor);
return success();
} else {
return op.emitError(
@@ -251,23 +291,42 @@ public:
return op.emitError(
"Only floating-point or integer datatype legalization supported");
// For bitwise operators, only integer datatype legalization is supported
if (lhsElemTy.isa<mlir::FloatType>() &&
std::is_same<AtenOpT, AtenBitwiseAndTensorOp>()) {
return op.emitError("For bitwise operators, only integer datatype "
"legalization is supported");
}
Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
op.other(), rhsAsTensor, lhsElemTy)))
if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor,
lhsElemTy, {})))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA operation");
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
// There is no Lesser operator in TOSA
// There is no Lesser operator in TOSA.
auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
std::is_same<AtenOpT, AtenLtScalarOp>());
rewriter.replaceOpWithNewOp<TosaOpT>(
op,
auto resultOp = rewriter.create<TosaOpT>(
op.getLoc(),
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
(swapLhsRhs ? rhsTensor : lhs), (swapLhsRhs ? lhs : rhsTensor));
// There is no NE operator in TOSA.
if (std::is_same<AtenOpT, AtenNeTensorOp>() ||
std::is_same<AtenOpT, AtenNeScalarOp>())
rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()),
resultOp.getResult());
else
rewriter.replaceOp(op, resultOp.getResult());
return success();
}
};
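TOSA has no NotEqual or Less operators, so aten.ne is emitted as tosa.equal followed by tosa.logical_not, and aten.lt reuses tosa.greater with the operands swapped. The intended equivalences, checked in eager PyTorch (a sketch, not part of the patch):

import torch

a, b = torch.rand(2, 3), torch.rand(2, 3)
assert torch.equal(torch.ne(a, b), torch.logical_not(torch.eq(a, b)))  # ne == not(eq)
assert torch.equal(torch.lt(a, b), torch.gt(b, a))                     # lt == gt with swapped operands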
@@ -282,29 +341,44 @@ public:
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Value lhs = adaptor.self();
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
Value rhs = adaptor.other();
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
auto lhsType = lhs.getType().dyn_cast<TensorType>();
if (!lhsTy)
if (!lhsType)
return op.emitError("Only Tensor types supported in TOSA");
auto lhsElemTy = lhsTy.getElementType();
if (!lhsElemTy.isIntOrFloat())
auto outType =
static_cast<Type>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()))
.cast<TensorType>();
Type outElemTy = outType.getElementType();
if (!outElemTy.isIntOrFloat())
return op.emitError(
"Only floating-point or integer datatype legalization supported");
Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
op.other(), rhsAsTensor, lhsElemTy)))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA operation");
Value rhsTensor;
if (std::is_same<AtenOpT, AtenSquareOp>()) {
rhsTensor = lhs;
} else {
Value rhsAsTensor;
Value rhs = adaptor.other();
auto rhsType = rhs.getType().dyn_cast<TensorType>();
if (!rhsType) {
if (failed(torchScalarToTosaTensor(rewriter, op, op.other(),
rhsAsTensor, outElemTy, {})))
return op.emitError(
"Currently only scalar constants are supported for "
"conversion in TOSA operation");
}
rhsTensor = rhsType ? rhs : rhsAsTensor;
}
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
if (lhsElemTy.isa<mlir::FloatType>() ||
lhsElemTy.isa<mlir::IntegerType>()) {
if (outElemTy.isa<mlir::FloatType>() ||
outElemTy.isa<mlir::IntegerType>()) {
if (lhsType.getElementType() != outElemTy)
lhs = rewriter.create<tosa::CastOp>(op.getLoc(), outType, lhs);
rewriter.replaceOpWithNewOp<tosa::MulOp>(
op,
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
@@ -343,8 +417,8 @@ public:
Value rhsAsTensor;
if (!rhsTy) {
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
op.other(), rhsAsTensor, lhsElemTy)))
if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor,
lhsElemTy, {})))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA operation");
}
@@ -807,8 +881,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
Value expTensor;
Value expScalar = op.exponent();
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), expScalar,
expTensor, selfTy.getElementType())))
if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
selfTy.getElementType(), {})))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA Pow operation");
@@ -1584,19 +1658,19 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
Value otherTensor, alphaTensor;
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), otherScalar,
otherTensor, selfTy.getElementType())))
if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
selfTy.getElementType(), {})))
return op.emitError("Currently only scalar constants are supported for "
"conversion in TOSA Rsub operation");
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
alphaTensor, selfTy.getElementType(),
true)))
/*checkForUnity=*/true)))
return failure();
auto multTensor = rewriter.create<tosa::MulOp>(
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
alphaTensor, /*shift*/ 0);
alphaTensor, /*shift=*/0);
rewriter.replaceOpWithNewOp<tosa::SubOp>(
op, getTypeConverter()->convertType(op.getType()), otherTensor,
@@ -2142,7 +2216,7 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
auto newType = RankedTensorType::get(newShape, selfType.getElementType());
auto reshapeOp =
rewriter.create<tosa::ReshapeOp>(op->getLoc(), newType, adaptor.self(),
rewriter.create<tosa::ReshapeOp>(op.getLoc(), newType, adaptor.self(),
rewriter.getI64ArrayAttr(newShape));
rewriter.replaceOpWithNewOp<tensor::CastOp>(
@@ -2184,6 +2258,80 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
return success();
}
template <>
LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
AtenLog2Op op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
if (!selfType)
return op.emitError("Only tensor types are currently supported");
// Constant value of ln2.
SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
auto ln2Op =
tosa::getConstTensor<float>(rewriter, op, {0.69314718056}, ln2Shape)
.getValue();
auto rcpOp =
rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
auto outType = getTypeConverter()->convertType(op.getType());
auto logOp =
rewriter.create<tosa::LogOp>(op.getLoc(), outType, adaptor.self());
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, outType, logOp, rcpOp,
/*shift=*/0);
return success();
}
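The lowering above uses the identity log2(x) = ln(x) * (1 / ln 2), which is what the const/reciprocal/log/mul sequence computes. A quick eager-mode check (a sketch, not part of the patch):

import math
import torch

x = torch.rand(3, 4) + 0.1
assert torch.allclose(torch.log2(x), torch.log(x) * (1.0 / math.log(2.0)))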
template <>
LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
AtenThresholdOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const {
// Not a tensor type.
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
if (!selfType)
return op.emitError("Only tensor types are currently supported");
auto selfElemTy = selfType.getElementType();
if (!selfElemTy.isIntOrFloat())
return op.emitError(
"Only floating-point or integer datatype legalization supported");
// Integer types with width > 32 are not supported
auto selfIntType = selfElemTy.dyn_cast<IntegerType>();
if (selfIntType && selfIntType.getWidth() > 32) {
return op.emitError(
"Integer types with width greater than 32 are not supported");
}
SmallVector<int64_t> constTypeShape(selfType.getRank(), 1);
Value threshold, value;
if (failed(torchScalarToTosaTensor(rewriter, op, op.threshold(), threshold,
selfElemTy, constTypeShape)))
return op.emitError("Only scalar constant is supported for threshold");
if (failed(torchScalarToTosaTensor(rewriter, op, op.value(), value,
selfElemTy, constTypeShape)))
return op.emitError("Only scalar constant is supported for value");
// Threshold only clamps the upper values. tosa::ClampOp has the same
// value for both threshold and clamped value so cannot be used.
auto outType = getTypeConverter()->convertType(op.getType());
auto cmpOp = rewriter.create<tosa::GreaterOp>(
op.getLoc(),
RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
adaptor.self(), threshold);
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmpOp,
adaptor.self(), value);
return success();
}
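aten.threshold keeps elements strictly greater than the threshold and replaces everything else with value, which is exactly the greater-then-select sequence above. An eager-mode equivalent using torch.where (a sketch, not part of the patch):

import torch

x = torch.randint(0, 10, (4,), dtype=torch.int32)
ref = torch.ops.aten.threshold(x, 1, 2)
assert torch.equal(ref, torch.where(x > 1, x, torch.tensor(2, dtype=torch.int32)))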
template <typename AtenOpT, typename TosaOpT>
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
public:
@@ -2527,6 +2675,9 @@ public:
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
#undef INSERT_BINARY_COMPARE_PATTERN
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
@@ -2629,6 +2780,8 @@ public:
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
INSERT_ATENOP_PATTERN(AtenPermuteOp);
INSERT_ATENOP_PATTERN(AtenLog2Op);
INSERT_ATENOP_PATTERN(AtenThresholdOp);
#undef INSERT_ATENOP_PATTERN
if (failed(applyPartialConversion(getOperation(), target,


@@ -595,6 +595,23 @@ func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,
// -----
// CHECK-LABEL: func @torch.aten.ne.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.equal"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
// CHECK: %[[VAL_5:.*]] = "tosa.logical_not"(%[[VAL_4]]) : (tensor<?x?xi1>) -> tensor<?x?xi1>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],i1>
// CHECK: }
func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
%0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
return %0 : !torch.vtensor<[?,?],i1>
}
// -----
// CHECK-LABEL: func @forward(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32>
@@ -615,3 +632,36 @@ func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32
%1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[3,4,2],f32>, !torch.list<!torch.int> -> !torch.vtensor<[3,2,4],f32>
return %1 : !torch.vtensor<[3,2,4],f32>
}
// -----
// CHECK-LABEL: func @torch.aten.bitwise_and.Tensor$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
// CHECK: %[[VAL_4:.*]] = "tosa.bitwise_and"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
// CHECK: }
func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
%0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
return %0 : !torch.vtensor<[?,?],si32>
}
// -----
// CHECK-LABEL: func @torch.aten.log2$basic(
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
// CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.693147182> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
// CHECK: %[[VAL_3:.*]] = "tosa.reciprocal"(%[[VAL_2]]) : (tensor<1x1xf32>) -> tensor<1x1xf32>
// CHECK: %[[VAL_4:.*]] = "tosa.log"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_5:.*]] = "tosa.mul"(%[[VAL_4]], %[[VAL_3]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
// CHECK: }
func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
%0 = torch.aten.log2 %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
return %0 : !torch.vtensor<[?,?],f32>
}