mirror of https://github.com/llvm/torch-mlir
* [tosa] Support for AtenNe[Tensor|Scalar]Op, AtenLog2Op, AtenBitwiseAndTensorOp, AtenSquareOp and AtenThresholdOp * Fix for Issue #532 - Mixed input types for few ops and updated few tests to use i32 instead of i64 Signed-off-by: Anup Gangwar <anup.gangwar@arm.com> Co-authored-by: Anup Gangwar <anup.gangwar@arm.com>pull/587/head
parent
c1167853db
commit
756b75fb2d
|
@ -739,14 +739,14 @@ class ElementwiseMulScalarModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.mul(x, 8.0)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseMulScalarModule())
|
||||
def ElementwiseMulScalarModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4)))
|
||||
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
class ElementwiseMulTensorFloatModule(torch.nn.Module):
|
||||
|
@ -1045,7 +1045,7 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
|
|||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int64, True),
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.sub(x, 2.1, alpha=2)
|
||||
|
@ -1053,7 +1053,7 @@ class ElementwiseSubScalarIntModule(torch.nn.Module):
|
|||
|
||||
@register_test_case(module_factory=lambda: ElementwiseSubScalarIntModule())
|
||||
def ElementwiseSubScalarIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4)))
|
||||
module.forward(torch.randint(10, (3, 4), dtype=torch.int32))
|
||||
|
||||
|
||||
class ElementwiseSubScalarFloatModule(torch.nn.Module):
|
||||
|
@ -1072,8 +1072,7 @@ class ElementwiseSubScalarFloatModule(torch.nn.Module):
|
|||
def ElementwiseSubScalarFloatModule_basic(module, tu: TestUtils):
|
||||
module.forward(tu.rand(3, 4))
|
||||
|
||||
|
||||
class ElementwiseAddScalarIntModule(torch.nn.Module):
|
||||
class ElementwiseAddScalarInt64Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
|
@ -1085,9 +1084,26 @@ class ElementwiseAddScalarIntModule(torch.nn.Module):
|
|||
def forward(self, x):
|
||||
return torch.add(x, 3.0)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAddScalarInt64Module())
|
||||
def ElementwiseAddScalarInt64Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4)))
|
||||
|
||||
|
||||
class ElementwiseAddScalarIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1, -1], torch.int32, True),
|
||||
])
|
||||
def forward(self, x):
|
||||
return torch.add(x, 3.0)
|
||||
|
||||
@register_test_case(module_factory=lambda: ElementwiseAddScalarIntModule())
|
||||
def ElementwiseAddScalarIntModule_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (3, 4)))
|
||||
module.forward(torch.randint(10, (2, 3), dtype=torch.int32))
|
||||
|
||||
|
||||
class ElementwiseAddScalarFloatModule(torch.nn.Module):
|
||||
|
|
|
@ -12,6 +12,24 @@ from torch_mlir_e2e_test.torchscript.annotations import annotate_args, export
|
|||
# ==============================================================================
|
||||
|
||||
|
||||
class Threshold1dIntI32Module(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
@export
|
||||
@annotate_args([
|
||||
None,
|
||||
([-1], torch.int32, True),
|
||||
])
|
||||
|
||||
def forward(self, input):
|
||||
return torch.ops.aten.threshold(input, 1, 2)
|
||||
|
||||
@register_test_case(module_factory=lambda: Threshold1dIntI32Module())
|
||||
def Threshold1dIntI32Module_basic(module, tu: TestUtils):
|
||||
module.forward(torch.randint(10, (4,), dtype=torch.int32))
|
||||
|
||||
|
||||
class Threshold1dIntModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
|
|
@ -99,4 +99,12 @@ TOSA_PASS_SET = {
|
|||
"LayerNormNormalizeOverAllDimsModule_basic",
|
||||
"PermuteModule_basic",
|
||||
"PermuteNegativeIndexModule_basic",
|
||||
"ElementwiseLog2Module_basic",
|
||||
"Threshold1dIntI32Module_basic",
|
||||
"Threshold1dFloatModule_basic",
|
||||
"Threshold2dFloatModule_basic",
|
||||
"Threshold3dFloatModule_basic",
|
||||
"ElementwiseSubScalarIntModule_basic",
|
||||
"ElementwiseAddScalarIntModule_basic",
|
||||
"ElementwiseMulScalarModule_basic",
|
||||
}
|
||||
|
|
|
@ -111,38 +111,63 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
static bool isInValidRange(bool isFloat, const double &doubleValue, bool isInt,
|
||||
const int64_t &intValue) {
|
||||
if (isFloat) {
|
||||
return (doubleValue >= std::numeric_limits<T>::min()) &&
|
||||
(doubleValue <= std::numeric_limits<T>::max());
|
||||
} else {
|
||||
assert(isInt);
|
||||
return (intValue >= std::numeric_limits<T>::min()) &&
|
||||
(intValue <= std::numeric_limits<T>::max());
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
// FIXME: This will eventually go into a Tosa*Utils file.
|
||||
LogicalResult torchScalarToTosaTensor(ConversionPatternRewriter &rewriter,
|
||||
Operation *op, Value torchScalarValue,
|
||||
Value &tosaTensor, Type dtype) {
|
||||
Value &tosaTensor, Type dtype,
|
||||
llvm::ArrayRef<int64_t> dshape) {
|
||||
// Retrieve a const float or int value but create the out Tensor with dtype.
|
||||
double doubleValue;
|
||||
auto isFloat =
|
||||
matchPattern(torchScalarValue, m_TorchConstantFloat(&doubleValue));
|
||||
|
||||
int64_t intValue;
|
||||
auto isInt = matchPattern(torchScalarValue, m_TorchConstantInt(&intValue));
|
||||
|
||||
if (!isFloat && !isInt)
|
||||
return op->emitError("Unable to extract the scalar constant");
|
||||
|
||||
if (dtype.isa<mlir::FloatType>()) {
|
||||
double scalarValue;
|
||||
|
||||
if (!matchPattern(torchScalarValue, m_TorchConstantFloat(&scalarValue)))
|
||||
return failure();
|
||||
|
||||
tosaTensor =
|
||||
mlir::tosa::getTosaConstTensorSingleF32(rewriter, op, scalarValue);
|
||||
tosaTensor = tosa::getConstTensor<float>(
|
||||
rewriter, op, (isFloat ? doubleValue : intValue), dshape)
|
||||
.getValue();
|
||||
} else if (auto intType = dtype.dyn_cast<mlir::IntegerType>()) {
|
||||
int64_t scalarValue;
|
||||
|
||||
if (!matchPattern(torchScalarValue, m_TorchConstantInt(&scalarValue)))
|
||||
return failure();
|
||||
|
||||
auto w = intType.getWidth();
|
||||
if (w != 32 && w != 64)
|
||||
return op->emitError("Unsupported integer type") << intType;
|
||||
|
||||
if (w == 32) {
|
||||
tosaTensor = tosa::getConstTensor<int32_t>(
|
||||
rewriter, op, {static_cast<int32_t>(scalarValue)}, {})
|
||||
.getValue();
|
||||
} else if (w == 64) {
|
||||
if (!isInValidRange<int32_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return op->emitError("Supplied value of scalar constant exceeds limits "
|
||||
"of destination type");
|
||||
}
|
||||
int32_t d = isFloat ? static_cast<int32_t>(doubleValue)
|
||||
: static_cast<int32_t>(intValue);
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<int64_t>(rewriter, op, {scalarValue}, {})
|
||||
.getValue();
|
||||
tosa::getConstTensor<int32_t>(rewriter, op, {d}, dshape).getValue();
|
||||
} else if (w == 64) {
|
||||
if (!isInValidRange<int64_t>(isFloat, doubleValue, isInt, intValue)) {
|
||||
return op->emitError("Supplied value of scalar constant exceeds limits "
|
||||
"of destination type");
|
||||
}
|
||||
int64_t d = (isFloat ? static_cast<int64_t>(doubleValue) : intValue);
|
||||
tosaTensor =
|
||||
tosa::getConstTensor<int64_t>(rewriter, op, {d}, dshape).getValue();
|
||||
}
|
||||
return success();
|
||||
} else
|
||||
return op->emitError("Usupported element type");
|
||||
|
||||
|
@ -154,7 +179,7 @@ LogicalResult torchAlphaToTosaTensor(ConversionPatternRewriter &rewriter,
|
|||
Value &alphaTensor, Type dtype,
|
||||
bool checkForUnity) {
|
||||
if (succeeded(torchScalarToTosaTensor(rewriter, op, alphaScalar, alphaTensor,
|
||||
dtype)))
|
||||
dtype, {})))
|
||||
return success();
|
||||
|
||||
// `alpha` has not been specified.
|
||||
|
@ -183,44 +208,59 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
|
||||
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
|
||||
auto rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (!lhsTy)
|
||||
if (!lhsType)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
if (!lhsElemTy.isIntOrFloat())
|
||||
if (auto lhsElemTy = lhsType.getElementType().dyn_cast<IntegerType>()) {
|
||||
if (lhsElemTy.getWidth() > 32)
|
||||
return op.emitError(
|
||||
"Integers with widths greater than 32 are not supported");
|
||||
}
|
||||
|
||||
auto outType =
|
||||
static_cast<Type>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()))
|
||||
.cast<TensorType>();
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat()) {
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
}
|
||||
|
||||
Value rhsAsTensor;
|
||||
if (!rhsTy) {
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
|
||||
op.other(), rhsAsTensor, lhsElemTy)))
|
||||
if (!rhsType) {
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor,
|
||||
outElemTy, {})))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA operation");
|
||||
}
|
||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
auto rhsTensor = rhsType ? rhs : rhsAsTensor;
|
||||
|
||||
// Handle alpha.
|
||||
Value alphaTensor;
|
||||
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), op.alpha(),
|
||||
alphaTensor, lhsElemTy, false)))
|
||||
alphaTensor, outElemTy,
|
||||
/*checkForUnity=*/false))) {
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"alpha in conversion to TOSA operation");
|
||||
}
|
||||
|
||||
auto multTensor = rewriter.create<tosa::MulOp>(
|
||||
op.getLoc(), rhsTy ? rhsTy : RankedTensorType::get({}, lhsElemTy),
|
||||
rhsTensor, alphaTensor, /*shift*/ 0);
|
||||
op.getLoc(), rhsType ? rhsType : RankedTensorType::get({}, outElemTy),
|
||||
rhsTensor, alphaTensor, /*shift=*/0);
|
||||
|
||||
if (outElemTy.isa<mlir::FloatType>()) {
|
||||
if (lhsType.getElementType() != outElemTy)
|
||||
lhs = rewriter.create<tosa::CastOp>(op.getLoc(), outType, lhs);
|
||||
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(op, outType, lhs, multTensor);
|
||||
|
||||
if (lhsElemTy.isa<mlir::FloatType>()) {
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
lhs, multTensor);
|
||||
return success();
|
||||
} else {
|
||||
return op.emitError(
|
||||
|
@ -251,23 +291,42 @@ public:
|
|||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
|
||||
// For bitwise operators, only integer datatype legalization is supported
|
||||
if (lhsElemTy.isa<mlir::FloatType>() &&
|
||||
std::is_same<AtenOpT, AtenBitwiseAndTensorOp>()) {
|
||||
return op.emitError("For bitwise operators, only integer datatype "
|
||||
"legalization is supported");
|
||||
}
|
||||
|
||||
Value rhsAsTensor;
|
||||
if (!rhsTy) {
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
|
||||
op.other(), rhsAsTensor, lhsElemTy)))
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor,
|
||||
lhsElemTy, {})))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA operation");
|
||||
}
|
||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
// There is no Lesser operator in TOSA
|
||||
// There is no Lesser operator in TOSA.
|
||||
auto swapLhsRhs = (std::is_same<AtenOpT, AtenLtTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenLtScalarOp>());
|
||||
|
||||
rewriter.replaceOpWithNewOp<TosaOpT>(
|
||||
op,
|
||||
auto resultOp = rewriter.create<TosaOpT>(
|
||||
op.getLoc(),
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
(swapLhsRhs ? rhsTensor : lhs), (swapLhsRhs ? lhs : rhsTensor));
|
||||
|
||||
// There is no NE operator in TOSA.
|
||||
if (std::is_same<AtenOpT, AtenNeTensorOp>() ||
|
||||
std::is_same<AtenOpT, AtenNeScalarOp>())
|
||||
rewriter.replaceOpWithNewOp<tosa::LogicalNotOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()),
|
||||
resultOp.getResult());
|
||||
else
|
||||
rewriter.replaceOp(op, resultOp.getResult());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
@ -282,29 +341,44 @@ public:
|
|||
matchAndRewrite(AtenOpT op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Value lhs = adaptor.self();
|
||||
auto lhsTy = lhs.getType().dyn_cast<TensorType>();
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsTy = rhs.getType().dyn_cast<TensorType>();
|
||||
auto lhsType = lhs.getType().dyn_cast<TensorType>();
|
||||
|
||||
if (!lhsTy)
|
||||
if (!lhsType)
|
||||
return op.emitError("Only Tensor types supported in TOSA");
|
||||
|
||||
auto lhsElemTy = lhsTy.getElementType();
|
||||
if (!lhsElemTy.isIntOrFloat())
|
||||
auto outType =
|
||||
static_cast<Type>(
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
op.getType()))
|
||||
.cast<TensorType>();
|
||||
|
||||
Type outElemTy = outType.getElementType();
|
||||
if (!outElemTy.isIntOrFloat())
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
|
||||
Value rhsAsTensor;
|
||||
if (!rhsTy) {
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
|
||||
op.other(), rhsAsTensor, lhsElemTy)))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA operation");
|
||||
Value rhsTensor;
|
||||
if (std::is_same<AtenOpT, AtenSquareOp>()) {
|
||||
rhsTensor = lhs;
|
||||
} else {
|
||||
Value rhsAsTensor;
|
||||
Value rhs = adaptor.other();
|
||||
auto rhsType = rhs.getType().dyn_cast<TensorType>();
|
||||
if (!rhsType) {
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.other(),
|
||||
rhsAsTensor, outElemTy, {})))
|
||||
return op.emitError(
|
||||
"Currently only scalar constants are supported for "
|
||||
"conversion in TOSA operation");
|
||||
}
|
||||
rhsTensor = rhsType ? rhs : rhsAsTensor;
|
||||
}
|
||||
auto rhsTensor = rhsTy ? rhs : rhsAsTensor;
|
||||
|
||||
if (lhsElemTy.isa<mlir::FloatType>() ||
|
||||
lhsElemTy.isa<mlir::IntegerType>()) {
|
||||
if (outElemTy.isa<mlir::FloatType>() ||
|
||||
outElemTy.isa<mlir::IntegerType>()) {
|
||||
if (lhsType.getElementType() != outElemTy)
|
||||
lhs = rewriter.create<tosa::CastOp>(op.getLoc(), outType, lhs);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(
|
||||
op,
|
||||
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
|
||||
|
@ -343,8 +417,8 @@ public:
|
|||
|
||||
Value rhsAsTensor;
|
||||
if (!rhsTy) {
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(),
|
||||
op.other(), rhsAsTensor, lhsElemTy)))
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.other(), rhsAsTensor,
|
||||
lhsElemTy, {})))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA operation");
|
||||
}
|
||||
|
@ -807,8 +881,8 @@ LogicalResult ConvertAtenOp<AtenPowTensorScalarOp>::matchAndRewrite(
|
|||
|
||||
Value expTensor;
|
||||
Value expScalar = op.exponent();
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), expScalar,
|
||||
expTensor, selfTy.getElementType())))
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, expScalar, expTensor,
|
||||
selfTy.getElementType(), {})))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Pow operation");
|
||||
|
||||
|
@ -1584,19 +1658,19 @@ LogicalResult ConvertAtenOp<AtenRsubScalarOp>::matchAndRewrite(
|
|||
|
||||
Value otherTensor, alphaTensor;
|
||||
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op.getOperation(), otherScalar,
|
||||
otherTensor, selfTy.getElementType())))
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, otherScalar, otherTensor,
|
||||
selfTy.getElementType(), {})))
|
||||
return op.emitError("Currently only scalar constants are supported for "
|
||||
"conversion in TOSA Rsub operation");
|
||||
|
||||
if (failed(torchAlphaToTosaTensor(rewriter, op.getOperation(), alphaScalar,
|
||||
alphaTensor, selfTy.getElementType(),
|
||||
true)))
|
||||
/*checkForUnity=*/true)))
|
||||
return failure();
|
||||
|
||||
auto multTensor = rewriter.create<tosa::MulOp>(
|
||||
op->getLoc(), getTypeConverter()->convertType(op.getType()), self,
|
||||
alphaTensor, /*shift*/ 0);
|
||||
alphaTensor, /*shift=*/0);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::SubOp>(
|
||||
op, getTypeConverter()->convertType(op.getType()), otherTensor,
|
||||
|
@ -2142,7 +2216,7 @@ LogicalResult ConvertAtenOp<AtenFlattenUsingIntsOp>::matchAndRewrite(
|
|||
|
||||
auto newType = RankedTensorType::get(newShape, selfType.getElementType());
|
||||
auto reshapeOp =
|
||||
rewriter.create<tosa::ReshapeOp>(op->getLoc(), newType, adaptor.self(),
|
||||
rewriter.create<tosa::ReshapeOp>(op.getLoc(), newType, adaptor.self(),
|
||||
rewriter.getI64ArrayAttr(newShape));
|
||||
|
||||
rewriter.replaceOpWithNewOp<tensor::CastOp>(
|
||||
|
@ -2184,6 +2258,80 @@ LogicalResult ConvertAtenOp<AtenPermuteOp>::matchAndRewrite(
|
|||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenLog2Op>::matchAndRewrite(
|
||||
AtenLog2Op op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
||||
if (!selfType)
|
||||
return op.emitError("Only tensor types are currently supported");
|
||||
|
||||
// Constant value of ln2.
|
||||
SmallVector<int64_t> ln2Shape(selfType.getRank(), 1);
|
||||
auto ln2Op =
|
||||
tosa::getConstTensor<float>(rewriter, op, {0.69314718056}, ln2Shape)
|
||||
.getValue();
|
||||
auto rcpOp =
|
||||
rewriter.create<tosa::ReciprocalOp>(op.getLoc(), ln2Op.getType(), ln2Op);
|
||||
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
auto logOp =
|
||||
rewriter.create<tosa::LogOp>(op.getLoc(), outType, adaptor.self());
|
||||
rewriter.replaceOpWithNewOp<tosa::MulOp>(op, outType, logOp, rcpOp,
|
||||
/*shift=*/0);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <>
|
||||
LogicalResult ConvertAtenOp<AtenThresholdOp>::matchAndRewrite(
|
||||
AtenThresholdOp op, OpAdaptor adaptor,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
||||
// Not a tensor type.
|
||||
auto selfType = adaptor.self().getType().dyn_cast<TensorType>();
|
||||
if (!selfType)
|
||||
return op.emitError("Only tensor types are currently supported");
|
||||
|
||||
auto selfElemTy = selfType.getElementType();
|
||||
if (!selfElemTy.isIntOrFloat())
|
||||
return op.emitError(
|
||||
"Only floating-point or integer datatype legalization supported");
|
||||
|
||||
// Integer types with width > 32 are not supported
|
||||
auto selfIntType = selfElemTy.dyn_cast<IntegerType>();
|
||||
if (selfIntType && selfIntType.getWidth() > 32) {
|
||||
return op.emitError(
|
||||
"Integer types with width greater than 32 are not supported");
|
||||
}
|
||||
|
||||
SmallVector<int64_t> constTypeShape(selfType.getRank(), 1);
|
||||
Value threshold, value;
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.threshold(), threshold,
|
||||
selfElemTy, constTypeShape)))
|
||||
return op.emitError("Only scalar constant is supported for threshold");
|
||||
|
||||
if (failed(torchScalarToTosaTensor(rewriter, op, op.value(), value,
|
||||
selfElemTy, constTypeShape)))
|
||||
return op.emitError("Only scalar constant is supported for value");
|
||||
|
||||
// Threshold only clamps the upper values. tosa::ClampOp has the same
|
||||
// value for both threshold and clamped value so cannot be used.
|
||||
auto outType = getTypeConverter()->convertType(op.getType());
|
||||
|
||||
auto cmpOp = rewriter.create<tosa::GreaterOp>(
|
||||
op.getLoc(),
|
||||
RankedTensorType::get(selfType.getShape(), rewriter.getIntegerType(1)),
|
||||
adaptor.self(), threshold);
|
||||
|
||||
rewriter.replaceOpWithNewOp<tosa::SelectOp>(op, outType, cmpOp,
|
||||
adaptor.self(), value);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
template <typename AtenOpT, typename TosaOpT>
|
||||
class ConvertAtenPoolingBaseOp : public OpConversionPattern<AtenOpT> {
|
||||
public:
|
||||
|
@ -2527,6 +2675,9 @@ public:
|
|||
INSERT_BINARY_COMPARE_PATTERN(AtenLtScalarOp, tosa::GreaterOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenEqTensorOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenEqScalarOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenNeTensorOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenNeScalarOp, tosa::EqualOp)
|
||||
INSERT_BINARY_COMPARE_PATTERN(AtenBitwiseAndTensorOp, tosa::BitwiseAndOp)
|
||||
#undef INSERT_BINARY_COMPARE_PATTERN
|
||||
|
||||
#define INSERT_BINARY_MUL_PATTERN(AtenOp) \
|
||||
|
@ -2629,6 +2780,8 @@ public:
|
|||
INSERT_ATENOP_PATTERN(AtenNativeLayerNormOp);
|
||||
INSERT_ATENOP_PATTERN(AtenFlattenUsingIntsOp);
|
||||
INSERT_ATENOP_PATTERN(AtenPermuteOp);
|
||||
INSERT_ATENOP_PATTERN(AtenLog2Op);
|
||||
INSERT_ATENOP_PATTERN(AtenThresholdOp);
|
||||
#undef INSERT_ATENOP_PATTERN
|
||||
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
|
|
|
@ -595,6 +595,23 @@ func @forward(%arg0: !torch.vtensor<[5,2,2,3],f32> , %arg1: !torch.vtensor<[2,2,
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.ne.Tensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.equal"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.logical_not"(%[[VAL_4]]) : (tensor<?x?xi1>) -> tensor<?x?xi1>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xi1> -> !torch.vtensor<[?,?],i1>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],i1>
|
||||
// CHECK: }
|
||||
func @torch.aten.ne.Tensor$basic(%arg0: !torch.vtensor<[?,?],f32>, %arg1: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],i1> {
|
||||
%0 = torch.aten.ne.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],f32>, !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],i1>
|
||||
return %0 : !torch.vtensor<[?,?],i1>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @forward(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[3,4,2],f32>) -> !torch.vtensor<[3,2,4],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[3,4,2],f32> -> tensor<3x4x2xf32>
|
||||
|
@ -615,3 +632,36 @@ func @forward(%arg0: !torch.vtensor<[3,4,2],f32> ) -> !torch.vtensor<[3,2,4],f32
|
|||
%1 = torch.aten.permute %arg0, %0 : !torch.vtensor<[3,4,2],f32>, !torch.list<!torch.int> -> !torch.vtensor<[3,2,4],f32>
|
||||
return %1 : !torch.vtensor<[3,2,4],f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.bitwise_and.Tensor$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],si32>,
|
||||
// CHECK-SAME: %[[VAL_1:.*]]: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
|
||||
// CHECK: %[[VAL_2:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_3:.*]] = torch_c.to_builtin_tensor %[[VAL_1]] : !torch.vtensor<[?,?],si32> -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.bitwise_and"(%[[VAL_2]], %[[VAL_3]]) : (tensor<?x?xi32>, tensor<?x?xi32>) -> tensor<?x?xi32>
|
||||
// CHECK: %[[VAL_5:.*]] = torch_c.from_builtin_tensor %[[VAL_4]] : tensor<?x?xi32> -> !torch.vtensor<[?,?],si32>
|
||||
// CHECK: return %[[VAL_5]] : !torch.vtensor<[?,?],si32>
|
||||
// CHECK: }
|
||||
func @torch.aten.bitwise_and.Tensor$basic(%arg0: !torch.vtensor<[?,?],si32>, %arg1: !torch.vtensor<[?,?],si32>) -> !torch.vtensor<[?,?],si32> {
|
||||
%0 = torch.aten.bitwise_and.Tensor %arg0, %arg1 : !torch.vtensor<[?,?],si32>, !torch.vtensor<[?,?],si32> -> !torch.vtensor<[?,?],si32>
|
||||
return %0 : !torch.vtensor<[?,?],si32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @torch.aten.log2$basic(
|
||||
// CHECK-SAME: %[[VAL_0:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
|
||||
// CHECK: %[[VAL_1:.*]] = torch_c.to_builtin_tensor %[[VAL_0]] : !torch.vtensor<[?,?],f32> -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_2:.*]] = "tosa.const"() {value = dense<0.693147182> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
|
||||
// CHECK: %[[VAL_3:.*]] = "tosa.reciprocal"(%[[VAL_2]]) : (tensor<1x1xf32>) -> tensor<1x1xf32>
|
||||
// CHECK: %[[VAL_4:.*]] = "tosa.log"(%[[VAL_1]]) : (tensor<?x?xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_5:.*]] = "tosa.mul"(%[[VAL_4]], %[[VAL_3]]) {shift = 0 : i32} : (tensor<?x?xf32>, tensor<1x1xf32>) -> tensor<?x?xf32>
|
||||
// CHECK: %[[VAL_6:.*]] = torch_c.from_builtin_tensor %[[VAL_5]] : tensor<?x?xf32> -> !torch.vtensor<[?,?],f32>
|
||||
// CHECK: return %[[VAL_6]] : !torch.vtensor<[?,?],f32>
|
||||
// CHECK: }
|
||||
func @torch.aten.log2$basic(%arg0: !torch.vtensor<[?,?],f32> ) -> !torch.vtensor<[?,?],f32> {
|
||||
%0 = torch.aten.log2 %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
|
||||
return %0 : !torch.vtensor<[?,?],f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue